Commit 812f6726 authored by pengcheng888

issue/563 - Adjust #include placement

parent ac4aae48
 #ifndef _TOPKROUTER_KERNEL_CUH__
 #define _TOPKROUTER_KERNEL_CUH__
-#include <cfloat>
-#include <cub/block/block_load.cuh>
-#include <cub/block/block_radix_sort.cuh>
-#include <cub/block/block_reduce.cuh>
-#include <cub/block/block_store.cuh>
-#include <cub/cub.cuh>
-// #include <cuda_bf16.h>
-// #include <cuda_fp16.h>
-// #include <cuda_runtime.h>
 template <typename T>
 inline __device__ float exp_func(T x) {
......
#include "../../../devices/metax/metax_common.h" #include "../../../devices/metax/metax_common.h"
#include "../../../devices/metax/metax_kernel_common.h" #include "../../../devices/metax/metax_kernel_common.h"
#include "../cuda/kernel.cuh"
#include "topkrouter_metax.h" #include "topkrouter_metax.h"
#include <cfloat>
#include <cub/block/block_load.cuh>
#include <cub/block/block_radix_sort.cuh>
#include <cub/block/block_reduce.cuh> #include <cub/block/block_reduce.cuh>
#include <cub/block/block_store.cuh>
#include <cub/cub.cuh>
#include "../cuda/kernel.cuh"
namespace op::topkrouter::metax { namespace op::topkrouter::metax {
......
@@ -2,9 +2,17 @@
 #include "../../../devices/nvidia/nvidia_common.cuh"
 #include "../../../devices/nvidia/nvidia_kernel_common.cuh"
-#include "../cuda/kernel.cuh"
 #include "topkrouter_nvidia.cuh"
+
+#include <cfloat>
+#include <cub/block/block_load.cuh>
+#include <cub/block/block_radix_sort.cuh>
 #include <cub/block/block_reduce.cuh>
+#include <cub/block/block_store.cuh>
+#include <cub/cub.cuh>
+
+#include "../cuda/kernel.cuh"
 
 namespace op::topkrouter::nvidia {
......
@@ -19,7 +19,7 @@ from libinfiniop import (
     InfiniDtypeNames,
     InfiniDeviceNames,
     infiniopOperatorDescriptor_t,
-    torch_device_map
+    torch_device_map,
 )
 
 # ==============================================================================
@@ -29,12 +29,14 @@ from libinfiniop import (
 _TEST_CASES_ = [
     # x_shape, x_stride, topk, routed_scaling_factor
     ((1, 256), None, 8, 2.5),
+    ((2, 256), None, 8, 1.0),
 ]
 
 # w (weight) types
 # Note: 'None' means the same as input dtype
 # _X_DTYPES = [InfiniDtype.F32, InfiniDtype.BF16, InfiniDtype.F16]
 _X_DTYPES = []  # CPU CI
 
 # x types used for testing
 _VALUE_DTYPES = [InfiniDtype.F32]
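The added case routes a batch of two tokens and sets routed_scaling_factor to 1.0, so the reference weights are left unscaled. For orientation, a minimal sketch of how each tuple is unpacked (loop variable names assumed from the header comment; the real driver lives in the shared test main()):

import torch

_TEST_CASES_ = [
    # x_shape, x_stride, topk, routed_scaling_factor
    ((1, 256), None, 8, 2.5),
    ((2, 256), None, 8, 1.0),  # new: two tokens, unscaled routing weights
]

for x_shape, x_stride, topk, routed_scaling_factor in _TEST_CASES_:
    N, width = x_shape  # N tokens, each with `width` = 256 expert logits
    print(N, width, topk, routed_scaling_factor)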
@@ -57,17 +59,34 @@ NUM_ITERATIONS = 1000
 def tensorInfo(data):
-    print("data: ", data.is_contiguous(), data.device, data.dtype, data.shape, data.stride(), data.data_ptr(), hex(data.data_ptr()))
+    print(
+        "data: ",
+        data.is_contiguous(),
+        data.device,
+        data.dtype,
+        data.shape,
+        data.stride(),
+        data.data_ptr(),
+        hex(data.data_ptr()),
+    )
 
 
 class DeepseekV3TopkRouter(nn.Module):
-    def __init__(self, correction_bias, routed_scaling_factor: float = 2.5, topk: int = 8, config=None):
+    def __init__(
+        self,
+        correction_bias,
+        routed_scaling_factor: float = 2.5,
+        topk: int = 8,
+        config=None,
+    ):
         super().__init__()
         self.config = config
         self.top_k = topk  # config.num_experts_per_tok 8
         assert topk == 8
         self.n_routed_experts = 256  # config.n_routed_experts
-        self.routed_scaling_factor = routed_scaling_factor  # config.routed_scaling_factor 2.5
+        self.routed_scaling_factor = (
+            routed_scaling_factor  # config.routed_scaling_factor 2.5
+        )
         self.n_group = 8  # config.n_group
         self.topk_group = 4  # config.topk_group
         self.norm_topk_prob = True  # config.norm_topk_prob
@@ -81,14 +100,20 @@ class DeepseekV3TopkRouter(nn.Module):
     @torch.no_grad()
     def get_topk_indices(self, scores):
-        scores_for_choice = scores.view(-1, self.n_routed_experts) + self.e_score_correction_bias.unsqueeze(0)  # Size([1, 256])
+        scores_for_choice = scores.view(
+            -1, self.n_routed_experts
+        ) + self.e_score_correction_bias.unsqueeze(0)  # Size([1, 256])
         group_scores = (
-            scores_for_choice.view(-1, self.n_group, self.n_routed_experts // self.n_group)
+            scores_for_choice.view(
+                -1, self.n_group, self.n_routed_experts // self.n_group
+            )
             .topk(2, dim=-1)[0]
             .sum(dim=-1)
         )
-        group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=True)[1]  # Size([1, 4])
+        group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=True)[
+            1
+        ]  # Size([1, 4])
         group_mask = torch.zeros_like(group_scores)  # Size([1, 8])
         group_mask.scatter_(1, group_idx, 1)  # Size([1, 8])
@@ -98,8 +123,12 @@ class DeepseekV3TopkRouter(nn.Module):
             .reshape(-1, self.n_routed_experts)
         )
-        scores_for_choice = scores_for_choice.masked_fill(~score_mask.bool(), 0.0)  # Size([1, 256])
-        topk_indices = torch.topk(scores_for_choice, k=self.top_k, dim=-1, sorted=True)[1]  # Size([1, 8])
+        scores_for_choice = scores_for_choice.masked_fill(
+            ~score_mask.bool(), 0.0
+        )  # Size([1, 256])
+        topk_indices = torch.topk(scores_for_choice, k=self.top_k, dim=-1, sorted=True)[
+            1
+        ]  # Size([1, 8])
         return topk_indices
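Read as a whole, get_topk_indices implements DeepSeek-V3's grouped expert selection: add the correction bias, score each of the 8 groups by the sum of its two best experts, keep the topk_group = 4 best groups, zero everything else, then take the global top-8. The standalone sketch below traces those shapes for one token; the unsqueeze/expand step that the hunk elides is assumed from the upstream Hugging Face implementation.

import torch

n_experts, n_group, topk_group, top_k = 256, 8, 4, 8
scores = torch.rand(1, n_experts)   # sigmoid router scores for one token
bias = torch.zeros(n_experts)       # stand-in for e_score_correction_bias

scores_for_choice = scores + bias.unsqueeze(0)                    # [1, 256]
group_scores = (
    scores_for_choice.view(-1, n_group, n_experts // n_group)     # [1, 8, 32]
    .topk(2, dim=-1)[0]                                           # two best experts per group
    .sum(dim=-1)                                                  # [1, 8] score per group
)
group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=True)[1]  # [1, 4]
group_mask = torch.zeros_like(group_scores).scatter_(1, group_idx, 1)       # [1, 8]
score_mask = (                                                    # assumed elided step
    group_mask.unsqueeze(-1)
    .expand(-1, n_group, n_experts // n_group)
    .reshape(-1, n_experts)                                       # [1, 256]
)
masked = scores_for_choice.masked_fill(~score_mask.bool(), 0.0)
topk_indices = torch.topk(masked, k=top_k, dim=-1, sorted=True)[1]  # [1, 8]
print(topk_indices)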
@@ -122,7 +151,9 @@ class DeepseekV3TopkRouter(nn.Module):
 def torch_topkrouter(router_logits, correction_bias, routed_scaling_factor, topk):
-    lable_indices, lable_values = DeepseekV3TopkRouter(correction_bias, routed_scaling_factor, topk)(router_logits)
+    lable_indices, lable_values = DeepseekV3TopkRouter(
+        correction_bias, routed_scaling_factor, topk
+    )(router_logits)
     lable_indices = lable_indices.to(torch.int32)
     return lable_values, lable_indices
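torch_topkrouter unpacks (lable_indices, lable_values) from the module call, but the forward that produces them falls outside these hunks. A plausible sketch, assuming the upstream DeepseekV3TopkRouter formulation (sigmoid scores, optional top-k probability normalization with a 1e-20 epsilon, then scaling); treat it as a reconstruction, not the file's actual code:

import torch

@torch.no_grad()
def reference_forward(router, router_logits):
    # Reconstruction of the elided forward(); `router` is a DeepseekV3TopkRouter.
    scores = router_logits.sigmoid()                # DeepSeek-V3 routes on sigmoid scores
    topk_indices = router.get_topk_indices(scores)  # Size([N, 8])
    topk_weights = scores.gather(1, topk_indices)   # Size([N, 8])
    if router.norm_topk_prob:
        denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20
        topk_weights = topk_weights / denominator
    topk_weights = topk_weights * router.routed_scaling_factor
    return topk_indices, topk_weights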
@@ -146,8 +177,12 @@ def test(
     data = torch.arange(0, x_shape[0] * x_shape[1]).reshape(x_shape)
     N, width = x_shape
-    x = TestTensor(x_shape, data.stride(), x_dtype, device, scale=5.0, bias=-5.0, mode="random")
-    correction_bias = TestTensor([x_shape[1]], [1], InfiniDtype.F32, device, mode="random")
+    x = TestTensor(
+        x_shape, data.stride(), x_dtype, device, scale=5.0, bias=-5.0, mode="random"
+    )
+    correction_bias = TestTensor(
+        [x_shape[1]], [1], InfiniDtype.F32, device, mode="random"
+    )
 
     if sync is not None:
         sync()
@@ -155,10 +190,7 @@ def test(
     descriptor = infiniopOperatorDescriptor_t()
     check_error(
         LIBINFINIOP.infiniopCreateTopkrouterDescriptor(
-            handle,
-            ctypes.byref(descriptor),
-            x.descriptor,
-            correction_bias.descriptor
+            handle, ctypes.byref(descriptor), x.descriptor, correction_bias.descriptor
         )
     )
@@ -174,8 +206,12 @@ def test(
     )
     workspace = TestWorkspace(workspace_size.value, x.device)
-    values = torch.zeros((N, topk), dtype=torch.float32, device=torch_device_map[x.device])
-    indices = torch.zeros((N, topk), dtype=torch.int32, device=torch_device_map[x.device])
+    values = torch.zeros(
+        (N, topk), dtype=torch.float32, device=torch_device_map[x.device]
+    )
+    indices = torch.zeros(
+        (N, topk), dtype=torch.int32, device=torch_device_map[x.device]
+    )
 
     def lib_topkrouter():
         check_error(
@@ -195,8 +231,9 @@ def test(
     lib_topkrouter()
 
-    lable_values, lable_indices = torch_topkrouter(x.actual_tensor(), correction_bias.actual_tensor(), routed_scaling_factor, topk)
+    lable_values, lable_indices = torch_topkrouter(
+        x.actual_tensor(), correction_bias.actual_tensor(), routed_scaling_factor, topk
+    )
 
     atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
     if DEBUG:
         debug(lable_values, values, atol=atol, rtol=rtol)
......
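For reference, get_tolerance() and debug() come from the shared test harness; the final check is equivalent in spirit to the self-contained comparison below (the tensor values are made up for illustration, and torch.testing.assert_close stands in for the harness's own comparison):

import torch

# Illustrative stand-ins for the tensors produced above (N = 1, topk = 8).
values = torch.tensor([[0.30, 0.25, 0.15, 0.10, 0.08, 0.05, 0.04, 0.03]])
lable_values = values.clone()
indices = torch.tensor([[7, 42, 99, 128, 163, 200, 231, 255]], dtype=torch.int32)
lable_indices = indices.clone()

atol, rtol = 1e-5, 1e-5  # the real pair comes from _TOLERANCE_MAP via get_tolerance()
torch.testing.assert_close(values, lable_values, atol=atol, rtol=rtol)
assert torch.equal(indices, lable_indices)  # expert ids must match exactly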