Commit a15aa367, authored by Zhao Shijie, committed by zhaoshijie
Browse files

issue/563 Add metax support for topkrouter

parent cad2d45a
...@@ -6,16 +6,16 @@ ...@@ -6,16 +6,16 @@
#include <cub/block/block_reduce.cuh> #include <cub/block/block_reduce.cuh>
#include <cub/block/block_store.cuh> #include <cub/block/block_store.cuh>
#include <cub/cub.cuh> #include <cub/cub.cuh>
#include <cuda_bf16.h> // #include <cuda_bf16.h>
#include <cuda_fp16.h> // #include <cuda_fp16.h>
#include <cuda_runtime.h> // #include <cuda_runtime.h>
template <typename T> template <typename T>
inline __device__ float exp_func(T x) { inline __device__ float exp_func(T x) {
float data; float data;
if constexpr (std::is_same_v<T, float>) { if constexpr (std::is_same_v<T, float>) {
data = x; data = x;
} else if constexpr (std::is_same_v<T, __nv_bfloat16>) { } else if constexpr (std::is_same_v<T, cuda_bfloat16>) {
data = __bfloat162float(x); data = __bfloat162float(x);
} else if constexpr (std::is_same_v<T, half>) { } else if constexpr (std::is_same_v<T, half>) {
data = __half2float(x); data = __half2float(x);
......
// Header for the MetaX backend of the topkrouter operator.
// The include guard keeps DESCRIPTOR(metax) from being expanded twice.
#ifndef __TOPKROUTER_METAX_H__
#define __TOPKROUTER_METAX_H__
#include "../topkrouter.h"
// NOTE(review): DESCRIPTOR is defined outside this view (presumably in
// ../topkrouter.h). It appears to declare op::topkrouter::metax::Descriptor,
// whose members are defined in the accompanying implementation file - confirm.
DESCRIPTOR(metax)
#endif // __TOPKROUTER_METAX_H__
#include "../../../devices/metax/metax_common.h"
#include "../../../devices/metax/metax_kernel_common.h"
#include "../cuda/kernel.cuh"
#include "topkrouter_metax.h"
#include <cub/block/block_reduce.cuh>
namespace op::topkrouter::metax {
/// Backend-private state owned by a Descriptor.
struct Descriptor::Opaque {
    /// Shared reference to the MetaX handle internals, taken from the
    /// handle passed to Descriptor::create.
    std::shared_ptr<device::metax::Handle::Internal> internal;
};
/// Releases the backend-private Opaque state allocated in Descriptor::create.
Descriptor::~Descriptor() {
    delete _opaque;
}
/**
 * Builds a MetaX topkrouter descriptor from the input tensor descriptor.
 *
 * @param handle               Operator handle; reinterpreted as a
 *                             device::metax::Handle to obtain its internals.
 * @param desc_ptr             Receives the newly allocated Descriptor.
 * @param x_desc               Descriptor of the input tensor, validated by
 *                             TopkrouterInfo::create.
 * @param correction_bias_desc Not inspected by this function.
 * @return INFINI_STATUS_SUCCESS on success,
 *         INFINI_STATUS_BAD_TENSOR_STRIDES when x_strides[1] is not 1, or
 *         whatever CHECK_RESULT propagates from TopkrouterInfo::create.
 */
infiniStatus_t Descriptor::create(
    infiniopHandle_t handle,
    Descriptor **desc_ptr,
    infiniopTensorDescriptor_t x_desc,
    infiniopTensorDescriptor_t correction_bias_desc) {
    auto parsed = TopkrouterInfo::create(x_desc);
    CHECK_RESULT(parsed);
    auto router_info = parsed.take();

    // The second dimension of the input must have unit stride.
    const bool unit_inner_stride = (router_info.x_strides[1] == 1);
    if (!unit_inner_stride) {
        return INFINI_STATUS_BAD_TENSOR_STRIDES;
    }

    auto *metax_handle = reinterpret_cast<device::metax::Handle *>(handle);
    auto *opaque = new Opaque{metax_handle->internal()};

    // NOTE(review): the literal 0 is presumably the required workspace size
    // (calculate() compares against _workspace_size) - confirm against the
    // DESCRIPTOR macro.
    *desc_ptr = new Descriptor(
        opaque,
        std::move(router_info),
        0,
        handle->device,
        handle->device_id);

    return INFINI_STATUS_SUCCESS;
}
namespace {

/**
 * Selects the topkrouter_kernel instantiation matching the input dtype and
 * launches it on `stream` with N blocks of BLOCK_SIZE threads each.
 *
 * @tparam BLOCK_SIZE          Threads per block; also forwarded to the kernel
 *                             as its second template argument.
 * @param d_values_out         Output buffer for the selected values.
 * @param d_indices_out        Output buffer for the selected indices.
 * @param d_input              Input buffer, reinterpreted according to xtype.
 * @param d_correction_bias    Correction bias forwarded to the kernel.
 * @param routed_scaling_factor Scaling factor forwarded to the kernel.
 * @param N                    Number of blocks to launch; forwarded to the kernel.
 * @param width                Forwarded to the kernel.
 * @param topk                 Forwarded to the kernel.
 * @param xtype                Element type of d_input (F32, F16 or BF16).
 * @param stream               Stream the kernel is enqueued on.
 * @return INFINI_STATUS_SUCCESS after enqueueing a kernel, or
 *         INFINI_STATUS_BAD_TENSOR_DTYPE for any other dtype. The launch
 *         itself is not checked for errors here.
 */
template <int BLOCK_SIZE = 128>
infiniStatus_t launch_topkrouter(float *d_values_out, int *d_indices_out, const void *d_input, const float *d_correction_bias,
                                 const float routed_scaling_factor, const size_t N, const size_t width, const size_t topk, infiniDtype_t xtype,
                                 hcStream_t stream) {
    dim3 grid(N);
    dim3 block(BLOCK_SIZE);

    switch (xtype) {
    case INFINI_DTYPE_F32:
        topkrouter_kernel<float, BLOCK_SIZE><<<grid, block, 0, stream>>>(
            d_values_out, d_indices_out, (float *)d_input, d_correction_bias, routed_scaling_factor, N, width, topk);
        return INFINI_STATUS_SUCCESS;
    case INFINI_DTYPE_F16:
        topkrouter_kernel<half, BLOCK_SIZE><<<grid, block, 0, stream>>>(
            d_values_out, d_indices_out, (half *)d_input, d_correction_bias, routed_scaling_factor, N, width, topk);
        return INFINI_STATUS_SUCCESS;
    case INFINI_DTYPE_BF16:
        topkrouter_kernel<cuda_bfloat16, BLOCK_SIZE><<<grid, block, 0, stream>>>(
            d_values_out, d_indices_out, (cuda_bfloat16 *)d_input, d_correction_bias, routed_scaling_factor, N, width, topk);
        return INFINI_STATUS_SUCCESS;
    default:
        return INFINI_STATUS_BAD_TENSOR_DTYPE;
    }
}

} // namespace
/**
 * Runs the topkrouter operator for this descriptor on the given stream.
 *
 * @param workspace             Scratch buffer; not used by this function.
 * @param workspace_size        Size of `workspace`; must be at least the
 *                              descriptor's _workspace_size.
 * @param values                Output buffer for the selected values.
 * @param indices               Output buffer for the selected indices.
 * @param x                     Input buffer, interpreted according to _info.xtype.
 * @param correction_bias       Correction bias forwarded to the kernel.
 * @param routed_scaling_factor Scaling factor forwarded to the kernel.
 * @param topk                  Forwarded to the kernel.
 * @param stream                Opaque stream pointer, reinterpreted as hcStream_t.
 * @return INFINI_STATUS_INSUFFICIENT_WORKSPACE if the workspace is too small,
 *         INFINI_STATUS_BAD_PARAM if _info.width is not 256, otherwise the
 *         status reported by launch_topkrouter (for example
 *         INFINI_STATUS_BAD_TENSOR_DTYPE for an unsupported input dtype).
 */
infiniStatus_t Descriptor::calculate(
    void *workspace,
    size_t workspace_size,
    float *values,
    int *indices,
    const void *x,
    const float *correction_bias,
    const float routed_scaling_factor,
    const size_t topk,
    void *stream) const {
    if (workspace_size < _workspace_size) {
        return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
    }

    const size_t N = _info.N;
    const size_t width = _info.width;

    // Only a width of 256 has a kernel instantiation in this backend.
    if (width != 256) {
        return INFINI_STATUS_BAD_PARAM;
    }

    auto metax_stream = reinterpret_cast<hcStream_t>(stream);

    // Return the launcher's status instead of discarding it; otherwise an
    // unsupported dtype would be reported to the caller as success.
    return launch_topkrouter<256>(values, indices, x, correction_bias, routed_scaling_factor,
                                  N, width, topk, _info.xtype, metax_stream);
}
} // namespace op::topkrouter::metax
...@@ -8,6 +8,9 @@ ...@@ -8,6 +8,9 @@
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_QY_API) #if defined(ENABLE_NVIDIA_API) || defined(ENABLE_QY_API)
#include "nvidia/topkrouter_nvidia.cuh" #include "nvidia/topkrouter_nvidia.cuh"
#endif #endif
#ifdef ENABLE_METAX_API
#include "metax/topkrouter_metax.h"
#endif
__C infiniStatus_t infiniopCreateTopkrouterDescriptor(infiniopHandle_t handle, infiniopTopkrouterDescriptor_t *desc_ptr, __C infiniStatus_t infiniopCreateTopkrouterDescriptor(infiniopHandle_t handle, infiniopTopkrouterDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t x_desc, infiniopTensorDescriptor_t x_desc,
...@@ -26,6 +29,9 @@ __C infiniStatus_t infiniopCreateTopkrouterDescriptor(infiniopHandle_t handle, i ...@@ -26,6 +29,9 @@ __C infiniStatus_t infiniopCreateTopkrouterDescriptor(infiniopHandle_t handle, i
#endif #endif
#ifdef ENABLE_QY_API #ifdef ENABLE_QY_API
CREATE(INFINI_DEVICE_QY, nvidia); CREATE(INFINI_DEVICE_QY, nvidia);
#endif
#ifdef ENABLE_METAX_API
CREATE(INFINI_DEVICE_METAX, metax);
#endif #endif
} }
...@@ -49,6 +55,9 @@ __C infiniStatus_t infiniopGetTopkrouterWorkspaceSize(infiniopTopkrouterDescript ...@@ -49,6 +55,9 @@ __C infiniStatus_t infiniopGetTopkrouterWorkspaceSize(infiniopTopkrouterDescript
#endif #endif
#ifdef ENABLE_QY_API #ifdef ENABLE_QY_API
GET(INFINI_DEVICE_QY, nvidia); GET(INFINI_DEVICE_QY, nvidia);
#endif
#ifdef ENABLE_METAX_API
GET(INFINI_DEVICE_METAX, metax);
#endif #endif
} }
...@@ -75,6 +84,9 @@ __C infiniStatus_t infiniopTopkrouter(infiniopTopkrouterDescriptor_t desc, void ...@@ -75,6 +84,9 @@ __C infiniStatus_t infiniopTopkrouter(infiniopTopkrouterDescriptor_t desc, void
#endif #endif
#ifdef ENABLE_QY_API #ifdef ENABLE_QY_API
CALCULATE(INFINI_DEVICE_QY, nvidia); CALCULATE(INFINI_DEVICE_QY, nvidia);
#endif
#ifdef ENABLE_METAX_API
CALCULATE(INFINI_DEVICE_METAX, metax);
#endif #endif
} }
...@@ -98,6 +110,9 @@ __C infiniStatus_t infiniopDestroyTopkrouterDescriptor(infiniopTopkrouterDescrip ...@@ -98,6 +110,9 @@ __C infiniStatus_t infiniopDestroyTopkrouterDescriptor(infiniopTopkrouterDescrip
#endif #endif
#ifdef ENABLE_QY_API #ifdef ENABLE_QY_API
DESTROY(INFINI_DEVICE_QY, nvidia); DESTROY(INFINI_DEVICE_QY, nvidia);
#endif
#ifdef ENABLE_METAX_API
DESTROY(INFINI_DEVICE_METAX, metax);
#endif #endif
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment