Unverified Commit 215d1932 authored by PanZezhong1725's avatar PanZezhong1725 Committed by GitHub
Browse files

Merge pull request #814 from InfiniTensor/issue/563_pr

Issue/563 Add metax support for topkrouter
parents 73b61c50 812f6726
#ifndef _TOPKROUTER_KERNEL_CUH__ #ifndef _TOPKROUTER_KERNEL_CUH__
#define _TOPKROUTER_KERNEL_CUH__ #define _TOPKROUTER_KERNEL_CUH__
#include <cfloat>
#include <cub/block/block_load.cuh>
#include <cub/block/block_radix_sort.cuh>
#include <cub/block/block_reduce.cuh>
#include <cub/block/block_store.cuh>
#include <cub/cub.cuh>
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
template <typename T> template <typename T>
inline __device__ float exp_func(T x) { inline __device__ float exp_func(T x) {
float data; float data;
if constexpr (std::is_same_v<T, float>) { if constexpr (std::is_same_v<T, float>) {
data = x; data = x;
} else if constexpr (std::is_same_v<T, __nv_bfloat16>) { } else if constexpr (std::is_same_v<T, cuda_bfloat16>) {
data = __bfloat162float(x); data = __bfloat162float(x);
} else if constexpr (std::is_same_v<T, half>) { } else if constexpr (std::is_same_v<T, half>) {
data = __half2float(x); data = __half2float(x);
......
#ifndef __TOPKROUTER_METAX_H__
#define __TOPKROUTER_METAX_H__
#include "../topkrouter.h"
DESCRIPTOR(metax)
#endif
#include "../../../devices/metax/metax_common.h"
#include "../../../devices/metax/metax_kernel_common.h"
#include "topkrouter_metax.h"
#include <cfloat>
#include <cub/block/block_load.cuh>
#include <cub/block/block_radix_sort.cuh>
#include <cub/block/block_reduce.cuh>
#include <cub/block/block_store.cuh>
#include <cub/cub.cuh>
#include "../cuda/kernel.cuh"
namespace op::topkrouter::metax {
struct Descriptor::Opaque {
std::shared_ptr<device::metax::Handle::Internal> internal;
};
Descriptor::~Descriptor() {
delete _opaque;
}
infiniStatus_t Descriptor::create(
infiniopHandle_t handle,
Descriptor **desc_ptr,
infiniopTensorDescriptor_t x_desc,
infiniopTensorDescriptor_t correction_bias_desc) {
auto result = TopkrouterInfo::create(x_desc);
CHECK_RESULT(result);
auto info = result.take();
if (info.x_strides[1] != 1) {
return INFINI_STATUS_BAD_TENSOR_STRIDES;
}
*desc_ptr = new Descriptor(
new Opaque{reinterpret_cast<device::metax::Handle *>(handle)->internal()},
std::move(info),
0,
handle->device, handle->device_id);
return INFINI_STATUS_SUCCESS;
}
namespace {
template <int BLOCK_SIZE = 128>
infiniStatus_t launch_topkrouter(float *d_values_out, int *d_indices_out, const void *d_input, const float *d_correction_bias,
const float routed_scaling_factor, const size_t N, const size_t width, const size_t topk, infiniDtype_t xtype,
hcStream_t stream) {
const int block_threads = BLOCK_SIZE;
dim3 blocks(N);
dim3 threads(block_threads);
if (xtype == INFINI_DTYPE_F32) {
topkrouter_kernel<float, BLOCK_SIZE><<<blocks, threads, 0, stream>>>(d_values_out, d_indices_out, (float *)d_input, d_correction_bias, routed_scaling_factor, N, width, topk);
} else if (xtype == INFINI_DTYPE_F16) {
topkrouter_kernel<half, BLOCK_SIZE><<<blocks, threads, 0, stream>>>(d_values_out, d_indices_out, (half *)d_input, d_correction_bias, routed_scaling_factor, N, width, topk);
} else if (xtype == INFINI_DTYPE_BF16) {
topkrouter_kernel<cuda_bfloat16, BLOCK_SIZE><<<blocks, threads, 0, stream>>>(d_values_out, d_indices_out, (cuda_bfloat16 *)d_input, d_correction_bias, routed_scaling_factor, N, width, topk);
} else {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
return INFINI_STATUS_SUCCESS;
}
}; // namespace
infiniStatus_t Descriptor::calculate(
void *workspace,
size_t workspace_size,
float *values,
int *indices,
const void *x,
const float *correction_bias,
const float routed_scaling_factor,
const size_t topk,
void *stream) const {
if (workspace_size < _workspace_size) {
return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
}
size_t N = _info.N;
size_t width = _info.width; // 256
// size_t n_routed_experts = 256;
// size_t n_group = 8;
// size_t topk_group = 4;
auto cuda_stream = reinterpret_cast<hcStream_t>(stream);
if (256 == width) {
launch_topkrouter<256>(values, indices, x, correction_bias, routed_scaling_factor, N, width, topk, _info.xtype, cuda_stream);
} else {
return INFINI_STATUS_BAD_PARAM;
}
return INFINI_STATUS_SUCCESS;
}
} // namespace op::topkrouter::metax
...@@ -2,9 +2,17 @@ ...@@ -2,9 +2,17 @@
#include "../../../devices/nvidia/nvidia_common.cuh" #include "../../../devices/nvidia/nvidia_common.cuh"
#include "../../../devices/nvidia/nvidia_kernel_common.cuh" #include "../../../devices/nvidia/nvidia_kernel_common.cuh"
#include "../cuda/kernel.cuh"
#include "topkrouter_nvidia.cuh" #include "topkrouter_nvidia.cuh"
#include <cfloat>
#include <cub/block/block_load.cuh>
#include <cub/block/block_radix_sort.cuh>
#include <cub/block/block_reduce.cuh> #include <cub/block/block_reduce.cuh>
#include <cub/block/block_store.cuh>
#include <cub/cub.cuh>
#include "../cuda/kernel.cuh"
namespace op::topkrouter::nvidia { namespace op::topkrouter::nvidia {
......
...@@ -8,6 +8,9 @@ ...@@ -8,6 +8,9 @@
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_QY_API) #if defined(ENABLE_NVIDIA_API) || defined(ENABLE_QY_API)
#include "nvidia/topkrouter_nvidia.cuh" #include "nvidia/topkrouter_nvidia.cuh"
#endif #endif
#ifdef ENABLE_METAX_API
#include "metax/topkrouter_metax.h"
#endif
#ifdef ENABLE_KUNLUN_API #ifdef ENABLE_KUNLUN_API
#include "kunlun/topkrouter_kunlun.h" #include "kunlun/topkrouter_kunlun.h"
#endif #endif
...@@ -30,6 +33,9 @@ __C infiniStatus_t infiniopCreateTopkrouterDescriptor(infiniopHandle_t handle, i ...@@ -30,6 +33,9 @@ __C infiniStatus_t infiniopCreateTopkrouterDescriptor(infiniopHandle_t handle, i
#ifdef ENABLE_QY_API #ifdef ENABLE_QY_API
CREATE(INFINI_DEVICE_QY, nvidia); CREATE(INFINI_DEVICE_QY, nvidia);
#endif #endif
#ifdef ENABLE_METAX_API
CREATE(INFINI_DEVICE_METAX, metax);
#endif
#ifdef ENABLE_KUNLUN_API #ifdef ENABLE_KUNLUN_API
CREATE(INFINI_DEVICE_KUNLUN, kunlun); CREATE(INFINI_DEVICE_KUNLUN, kunlun);
#endif #endif
...@@ -56,6 +62,9 @@ __C infiniStatus_t infiniopGetTopkrouterWorkspaceSize(infiniopTopkrouterDescript ...@@ -56,6 +62,9 @@ __C infiniStatus_t infiniopGetTopkrouterWorkspaceSize(infiniopTopkrouterDescript
#ifdef ENABLE_QY_API #ifdef ENABLE_QY_API
GET(INFINI_DEVICE_QY, nvidia); GET(INFINI_DEVICE_QY, nvidia);
#endif #endif
#ifdef ENABLE_METAX_API
GET(INFINI_DEVICE_METAX, metax);
#endif
#ifdef ENABLE_KUNLUN_API #ifdef ENABLE_KUNLUN_API
GET(INFINI_DEVICE_KUNLUN, kunlun); GET(INFINI_DEVICE_KUNLUN, kunlun);
#endif #endif
...@@ -85,6 +94,9 @@ __C infiniStatus_t infiniopTopkrouter(infiniopTopkrouterDescriptor_t desc, void ...@@ -85,6 +94,9 @@ __C infiniStatus_t infiniopTopkrouter(infiniopTopkrouterDescriptor_t desc, void
#ifdef ENABLE_QY_API #ifdef ENABLE_QY_API
CALCULATE(INFINI_DEVICE_QY, nvidia); CALCULATE(INFINI_DEVICE_QY, nvidia);
#endif #endif
#ifdef ENABLE_METAX_API
CALCULATE(INFINI_DEVICE_METAX, metax);
#endif
#ifdef ENABLE_KUNLUN_API #ifdef ENABLE_KUNLUN_API
CALCULATE(INFINI_DEVICE_KUNLUN, kunlun); CALCULATE(INFINI_DEVICE_KUNLUN, kunlun);
#endif #endif
...@@ -111,6 +123,9 @@ __C infiniStatus_t infiniopDestroyTopkrouterDescriptor(infiniopTopkrouterDescrip ...@@ -111,6 +123,9 @@ __C infiniStatus_t infiniopDestroyTopkrouterDescriptor(infiniopTopkrouterDescrip
#ifdef ENABLE_QY_API #ifdef ENABLE_QY_API
DESTROY(INFINI_DEVICE_QY, nvidia); DESTROY(INFINI_DEVICE_QY, nvidia);
#endif #endif
#ifdef ENABLE_METAX_API
DESTROY(INFINI_DEVICE_METAX, metax);
#endif
#ifdef ENABLE_KUNLUN_API #ifdef ENABLE_KUNLUN_API
DESTROY(INFINI_DEVICE_KUNLUN, kunlun); DESTROY(INFINI_DEVICE_KUNLUN, kunlun);
#endif #endif
......
...@@ -19,7 +19,7 @@ from libinfiniop import ( ...@@ -19,7 +19,7 @@ from libinfiniop import (
InfiniDtypeNames, InfiniDtypeNames,
InfiniDeviceNames, InfiniDeviceNames,
infiniopOperatorDescriptor_t, infiniopOperatorDescriptor_t,
torch_device_map torch_device_map,
) )
# ============================================================================== # ==============================================================================
...@@ -29,12 +29,14 @@ from libinfiniop import ( ...@@ -29,12 +29,14 @@ from libinfiniop import (
_TEST_CASES_ = [ _TEST_CASES_ = [
# x_shape, x_stride, topk, routed_scaling_factor # x_shape, x_stride, topk, routed_scaling_factor
((1, 256), None, 8, 2.5), ((1, 256), None, 8, 2.5),
((2, 256), None, 8, 1.0),
] ]
# w (weight) types # w (weight) types
# Note: 'None' means the same as input dtype # Note: 'None' means the same as input dtype
# _X_DTYPES = [InfiniDtype.F32, InfiniDtype.BF16, InfiniDtype.F16] # _X_DTYPES = [InfiniDtype.F32, InfiniDtype.BF16, InfiniDtype.F16]
_X_DTYPES = [] # CPU CI _X_DTYPES = [] # CPU CI
# x types used for testing # x types used for testing
_VALUE_DTYPES = [InfiniDtype.F32] _VALUE_DTYPES = [InfiniDtype.F32]
...@@ -57,17 +59,34 @@ NUM_ITERATIONS = 1000 ...@@ -57,17 +59,34 @@ NUM_ITERATIONS = 1000
def tensorInfo(data): def tensorInfo(data):
print("data: ", data.is_contiguous(), data.device, data.dtype, data.shape, data.stride(), data.data_ptr(), hex(data.data_ptr())) print(
"data: ",
data.is_contiguous(),
data.device,
data.dtype,
data.shape,
data.stride(),
data.data_ptr(),
hex(data.data_ptr()),
)
class DeepseekV3TopkRouter(nn.Module): class DeepseekV3TopkRouter(nn.Module):
def __init__(self, correction_bias, routed_scaling_factor: float = 2.5, topk: int = 8, config=None): def __init__(
self,
correction_bias,
routed_scaling_factor: float = 2.5,
topk: int = 8,
config=None,
):
super().__init__() super().__init__()
self.config = config self.config = config
self.top_k = topk # config.num_experts_per_tok 8 self.top_k = topk # config.num_experts_per_tok 8
assert topk == 8 assert topk == 8
self.n_routed_experts = 256 # config.n_routed_experts self.n_routed_experts = 256 # config.n_routed_experts
self.routed_scaling_factor = routed_scaling_factor # config.routed_scaling_factor 2.5 self.routed_scaling_factor = (
routed_scaling_factor # config.routed_scaling_factor 2.5
)
self.n_group = 8 # config.n_group self.n_group = 8 # config.n_group
self.topk_group = 4 # config.topk_group self.topk_group = 4 # config.topk_group
self.norm_topk_prob = True # config.norm_topk_prob self.norm_topk_prob = True # config.norm_topk_prob
...@@ -81,14 +100,20 @@ class DeepseekV3TopkRouter(nn.Module): ...@@ -81,14 +100,20 @@ class DeepseekV3TopkRouter(nn.Module):
@torch.no_grad() @torch.no_grad()
def get_topk_indices(self, scores): def get_topk_indices(self, scores):
scores_for_choice = scores.view(-1, self.n_routed_experts) + self.e_score_correction_bias.unsqueeze(0) # Size([1, 256]) scores_for_choice = scores.view(
-1, self.n_routed_experts
) + self.e_score_correction_bias.unsqueeze(0) # Size([1, 256])
group_scores = ( group_scores = (
scores_for_choice.view(-1, self.n_group, self.n_routed_experts // self.n_group) scores_for_choice.view(
-1, self.n_group, self.n_routed_experts // self.n_group
)
.topk(2, dim=-1)[0] .topk(2, dim=-1)[0]
.sum(dim=-1) .sum(dim=-1)
) )
group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=True)[1] # Size([1, 4]) group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=True)[
1
] # Size([1, 4])
group_mask = torch.zeros_like(group_scores) # Size([1, 8]) group_mask = torch.zeros_like(group_scores) # Size([1, 8])
group_mask.scatter_(1, group_idx, 1) # Size([1, 8]) group_mask.scatter_(1, group_idx, 1) # Size([1, 8])
...@@ -98,8 +123,12 @@ class DeepseekV3TopkRouter(nn.Module): ...@@ -98,8 +123,12 @@ class DeepseekV3TopkRouter(nn.Module):
.reshape(-1, self.n_routed_experts) .reshape(-1, self.n_routed_experts)
) )
scores_for_choice = scores_for_choice.masked_fill(~score_mask.bool(), 0.0) # Size([1, 256]) scores_for_choice = scores_for_choice.masked_fill(
topk_indices = torch.topk(scores_for_choice, k=self.top_k, dim=-1, sorted=True)[1] # Size([1, 8]) ~score_mask.bool(), 0.0
) # Size([1, 256])
topk_indices = torch.topk(scores_for_choice, k=self.top_k, dim=-1, sorted=True)[
1
] # Size([1, 8])
return topk_indices return topk_indices
...@@ -122,21 +151,23 @@ class DeepseekV3TopkRouter(nn.Module): ...@@ -122,21 +151,23 @@ class DeepseekV3TopkRouter(nn.Module):
def torch_topkrouter(router_logits, correction_bias, routed_scaling_factor, topk): def torch_topkrouter(router_logits, correction_bias, routed_scaling_factor, topk):
lable_indices, lable_values = DeepseekV3TopkRouter(correction_bias, routed_scaling_factor, topk)(router_logits) lable_indices, lable_values = DeepseekV3TopkRouter(
correction_bias, routed_scaling_factor, topk
)(router_logits)
lable_indices = lable_indices.to(torch.int32) lable_indices = lable_indices.to(torch.int32)
return lable_values, lable_indices return lable_values, lable_indices
def test( def test(
handle, handle,
device, device,
x_shape, x_shape,
x_stride, x_stride,
topk, topk,
routed_scaling_factor, routed_scaling_factor,
x_dtype=InfiniDtype.F32, x_dtype=InfiniDtype.F32,
dtype=InfiniDtype.F16, dtype=InfiniDtype.F16,
sync=None, sync=None,
): ):
print( print(
f"Testing topkrouter on {InfiniDeviceNames[device]} with x_shape:{x_shape} " f"Testing topkrouter on {InfiniDeviceNames[device]} with x_shape:{x_shape} "
...@@ -146,8 +177,12 @@ def test( ...@@ -146,8 +177,12 @@ def test(
data = torch.arange(0, x_shape[0] * x_shape[1]).reshape(x_shape) data = torch.arange(0, x_shape[0] * x_shape[1]).reshape(x_shape)
N, width = x_shape N, width = x_shape
x = TestTensor(x_shape, data.stride(), x_dtype, device, scale=5.0, bias=-5.0, mode="random") x = TestTensor(
correction_bias = TestTensor([x_shape[1]], [1], InfiniDtype.F32, device, mode="random") x_shape, data.stride(), x_dtype, device, scale=5.0, bias=-5.0, mode="random"
)
correction_bias = TestTensor(
[x_shape[1]], [1], InfiniDtype.F32, device, mode="random"
)
if sync is not None: if sync is not None:
sync() sync()
...@@ -155,10 +190,7 @@ def test( ...@@ -155,10 +190,7 @@ def test(
descriptor = infiniopOperatorDescriptor_t() descriptor = infiniopOperatorDescriptor_t()
check_error( check_error(
LIBINFINIOP.infiniopCreateTopkrouterDescriptor( LIBINFINIOP.infiniopCreateTopkrouterDescriptor(
handle, handle, ctypes.byref(descriptor), x.descriptor, correction_bias.descriptor
ctypes.byref(descriptor),
x.descriptor,
correction_bias.descriptor
) )
) )
...@@ -174,8 +206,12 @@ def test( ...@@ -174,8 +206,12 @@ def test(
) )
workspace = TestWorkspace(workspace_size.value, x.device) workspace = TestWorkspace(workspace_size.value, x.device)
values = torch.zeros((N, topk), dtype=torch.float32, device=torch_device_map[x.device]) values = torch.zeros(
indices = torch.zeros((N, topk), dtype=torch.int32, device=torch_device_map[x.device]) (N, topk), dtype=torch.float32, device=torch_device_map[x.device]
)
indices = torch.zeros(
(N, topk), dtype=torch.int32, device=torch_device_map[x.device]
)
def lib_topkrouter(): def lib_topkrouter():
check_error( check_error(
...@@ -195,8 +231,9 @@ def test( ...@@ -195,8 +231,9 @@ def test(
lib_topkrouter() lib_topkrouter()
lable_values, lable_indices = torch_topkrouter(
lable_values, lable_indices = torch_topkrouter(x.actual_tensor(), correction_bias.actual_tensor(), routed_scaling_factor, topk) x.actual_tensor(), correction_bias.actual_tensor(), routed_scaling_factor, topk
)
atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype) atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
if DEBUG: if DEBUG:
debug(lable_values, values, atol=atol, rtol=rtol) debug(lable_values, values, atol=atol, rtol=rtol)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment