Commit a15aa367 authored by Zhao Shijie's avatar Zhao Shijie Committed by zhaoshijie
Browse files

issue/563 Add metax support for topkrouter

parent cad2d45a
......@@ -6,16 +6,16 @@
#include <cub/block/block_reduce.cuh>
#include <cub/block/block_store.cuh>
#include <cub/cub.cuh>
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
// #include <cuda_bf16.h>
// #include <cuda_fp16.h>
// #include <cuda_runtime.h>
template <typename T>
inline __device__ float exp_func(T x) {
float data;
if constexpr (std::is_same_v<T, float>) {
data = x;
} else if constexpr (std::is_same_v<T, __nv_bfloat16>) {
} else if constexpr (std::is_same_v<T, cuda_bfloat16>) {
data = __bfloat162float(x);
} else if constexpr (std::is_same_v<T, half>) {
data = __half2float(x);
......
#ifndef __TOPKROUTER_METAX_H__
#define __TOPKROUTER_METAX_H__
#include "../topkrouter.h"
DESCRIPTOR(metax)
#endif
#include "../../../devices/metax/metax_common.h"
#include "../../../devices/metax/metax_kernel_common.h"
#include "../cuda/kernel.cuh"
#include "topkrouter_metax.h"
#include <cub/block/block_reduce.cuh>
namespace op::topkrouter::metax {
struct Descriptor::Opaque {
std::shared_ptr<device::metax::Handle::Internal> internal;
};
Descriptor::~Descriptor() {
delete _opaque;
}
infiniStatus_t Descriptor::create(
infiniopHandle_t handle,
Descriptor **desc_ptr,
infiniopTensorDescriptor_t x_desc,
infiniopTensorDescriptor_t correction_bias_desc) {
auto result = TopkrouterInfo::create(x_desc);
CHECK_RESULT(result);
auto info = result.take();
if (info.x_strides[1] != 1) {
return INFINI_STATUS_BAD_TENSOR_STRIDES;
}
*desc_ptr = new Descriptor(
new Opaque{reinterpret_cast<device::metax::Handle *>(handle)->internal()},
std::move(info),
0,
handle->device, handle->device_id);
return INFINI_STATUS_SUCCESS;
}
namespace {
template <int BLOCK_SIZE = 128>
infiniStatus_t launch_topkrouter(float *d_values_out, int *d_indices_out, const void *d_input, const float *d_correction_bias,
const float routed_scaling_factor, const size_t N, const size_t width, const size_t topk, infiniDtype_t xtype,
hcStream_t stream) {
const int block_threads = BLOCK_SIZE;
dim3 blocks(N);
dim3 threads(block_threads);
if (xtype == INFINI_DTYPE_F32) {
topkrouter_kernel<float, BLOCK_SIZE><<<blocks, threads, 0, stream>>>(d_values_out, d_indices_out, (float *)d_input, d_correction_bias, routed_scaling_factor, N, width, topk);
} else if (xtype == INFINI_DTYPE_F16) {
topkrouter_kernel<half, BLOCK_SIZE><<<blocks, threads, 0, stream>>>(d_values_out, d_indices_out, (half *)d_input, d_correction_bias, routed_scaling_factor, N, width, topk);
} else if (xtype == INFINI_DTYPE_BF16) {
topkrouter_kernel<cuda_bfloat16, BLOCK_SIZE><<<blocks, threads, 0, stream>>>(d_values_out, d_indices_out, (cuda_bfloat16 *)d_input, d_correction_bias, routed_scaling_factor, N, width, topk);
} else {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
return INFINI_STATUS_SUCCESS;
}
}; // namespace
infiniStatus_t Descriptor::calculate(
void *workspace,
size_t workspace_size,
float *values,
int *indices,
const void *x,
const float *correction_bias,
const float routed_scaling_factor,
const size_t topk,
void *stream) const {
if (workspace_size < _workspace_size) {
return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
}
size_t N = _info.N;
size_t width = _info.width; // 256
// size_t n_routed_experts = 256;
// size_t n_group = 8;
// size_t topk_group = 4;
auto cuda_stream = reinterpret_cast<hcStream_t>(stream);
if (256 == width) {
launch_topkrouter<256>(values, indices, x, correction_bias, routed_scaling_factor, N, width, topk, _info.xtype, cuda_stream);
} else {
return INFINI_STATUS_BAD_PARAM;
}
return INFINI_STATUS_SUCCESS;
}
} // namespace op::topkrouter::metax
......@@ -8,6 +8,9 @@
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_QY_API)
#include "nvidia/topkrouter_nvidia.cuh"
#endif
#ifdef ENABLE_METAX_API
#include "metax/topkrouter_metax.h"
#endif
__C infiniStatus_t infiniopCreateTopkrouterDescriptor(infiniopHandle_t handle, infiniopTopkrouterDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t x_desc,
......@@ -26,6 +29,9 @@ __C infiniStatus_t infiniopCreateTopkrouterDescriptor(infiniopHandle_t handle, i
#endif
#ifdef ENABLE_QY_API
CREATE(INFINI_DEVICE_QY, nvidia);
#endif
#ifdef ENABLE_METAX_API
CREATE(INFINI_DEVICE_METAX, metax);
#endif
}
......@@ -49,6 +55,9 @@ __C infiniStatus_t infiniopGetTopkrouterWorkspaceSize(infiniopTopkrouterDescript
#endif
#ifdef ENABLE_QY_API
GET(INFINI_DEVICE_QY, nvidia);
#endif
#ifdef ENABLE_METAX_API
GET(INFINI_DEVICE_METAX, metax);
#endif
}
......@@ -75,6 +84,9 @@ __C infiniStatus_t infiniopTopkrouter(infiniopTopkrouterDescriptor_t desc, void
#endif
#ifdef ENABLE_QY_API
CALCULATE(INFINI_DEVICE_QY, nvidia);
#endif
#ifdef ENABLE_METAX_API
CALCULATE(INFINI_DEVICE_METAX, metax);
#endif
}
......@@ -98,6 +110,9 @@ __C infiniStatus_t infiniopDestroyTopkrouterDescriptor(infiniopTopkrouterDescrip
#endif
#ifdef ENABLE_QY_API
DESTROY(INFINI_DEVICE_QY, nvidia);
#endif
#ifdef ENABLE_METAX_API
DESTROY(INFINI_DEVICE_METAX, metax);
#endif
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment