Commit 812f6726 authored by pengcheng888's avatar pengcheng888
Browse files

issue/563 - 调整#include位置

parent ac4aae48
#ifndef _TOPKROUTER_KERNEL_CUH__
#define _TOPKROUTER_KERNEL_CUH__
#include <cfloat>
#include <cub/block/block_load.cuh>
#include <cub/block/block_radix_sort.cuh>
#include <cub/block/block_reduce.cuh>
#include <cub/block/block_store.cuh>
#include <cub/cub.cuh>
// #include <cuda_bf16.h>
// #include <cuda_fp16.h>
// #include <cuda_runtime.h>
template <typename T>
inline __device__ float exp_func(T x) {
......
#include "../../../devices/metax/metax_common.h"
#include "../../../devices/metax/metax_kernel_common.h"
#include "../cuda/kernel.cuh"
#include "topkrouter_metax.h"
#include <cfloat>
#include <cub/block/block_load.cuh>
#include <cub/block/block_radix_sort.cuh>
#include <cub/block/block_reduce.cuh>
#include <cub/block/block_store.cuh>
#include <cub/cub.cuh>
#include "../cuda/kernel.cuh"
namespace op::topkrouter::metax {
......
......@@ -2,9 +2,17 @@
#include "../../../devices/nvidia/nvidia_common.cuh"
#include "../../../devices/nvidia/nvidia_kernel_common.cuh"
#include "../cuda/kernel.cuh"
#include "topkrouter_nvidia.cuh"
#include <cfloat>
#include <cub/block/block_load.cuh>
#include <cub/block/block_radix_sort.cuh>
#include <cub/block/block_reduce.cuh>
#include <cub/block/block_store.cuh>
#include <cub/cub.cuh>
#include "../cuda/kernel.cuh"
namespace op::topkrouter::nvidia {
......
......@@ -19,7 +19,7 @@ from libinfiniop import (
InfiniDtypeNames,
InfiniDeviceNames,
infiniopOperatorDescriptor_t,
torch_device_map
torch_device_map,
)
# ==============================================================================
......@@ -29,12 +29,14 @@ from libinfiniop import (
_TEST_CASES_ = [
# x_shape, x_stride, topk, routed_scaling_factor
((1, 256), None, 8, 2.5),
((2, 256), None, 8, 1.0),
]
# w (weight) types
# Note: 'None' means the same as input dtype
# _X_DTYPES = [InfiniDtype.F32, InfiniDtype.BF16, InfiniDtype.F16]
_X_DTYPES = [] # CPU CI
_X_DTYPES = [] # CPU CI
# x types used for testing
_VALUE_DTYPES = [InfiniDtype.F32]
......@@ -57,17 +59,34 @@ NUM_ITERATIONS = 1000
def tensorInfo(data):
print("data: ", data.is_contiguous(), data.device, data.dtype, data.shape, data.stride(), data.data_ptr(), hex(data.data_ptr()))
print(
"data: ",
data.is_contiguous(),
data.device,
data.dtype,
data.shape,
data.stride(),
data.data_ptr(),
hex(data.data_ptr()),
)
class DeepseekV3TopkRouter(nn.Module):
def __init__(self, correction_bias, routed_scaling_factor: float = 2.5, topk: int = 8, config=None):
def __init__(
self,
correction_bias,
routed_scaling_factor: float = 2.5,
topk: int = 8,
config=None,
):
super().__init__()
self.config = config
self.top_k = topk # config.num_experts_per_tok 8
assert topk == 8
self.n_routed_experts = 256 # config.n_routed_experts
self.routed_scaling_factor = routed_scaling_factor # config.routed_scaling_factor 2.5
self.routed_scaling_factor = (
routed_scaling_factor # config.routed_scaling_factor 2.5
)
self.n_group = 8 # config.n_group
self.topk_group = 4 # config.topk_group
self.norm_topk_prob = True # config.norm_topk_prob
......@@ -81,14 +100,20 @@ class DeepseekV3TopkRouter(nn.Module):
@torch.no_grad()
def get_topk_indices(self, scores):
scores_for_choice = scores.view(-1, self.n_routed_experts) + self.e_score_correction_bias.unsqueeze(0) # Size([1, 256])
scores_for_choice = scores.view(
-1, self.n_routed_experts
) + self.e_score_correction_bias.unsqueeze(0) # Size([1, 256])
group_scores = (
scores_for_choice.view(-1, self.n_group, self.n_routed_experts // self.n_group)
scores_for_choice.view(
-1, self.n_group, self.n_routed_experts // self.n_group
)
.topk(2, dim=-1)[0]
.sum(dim=-1)
)
group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=True)[1] # Size([1, 4])
group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=True)[
1
] # Size([1, 4])
group_mask = torch.zeros_like(group_scores) # Size([1, 8])
group_mask.scatter_(1, group_idx, 1) # Size([1, 8])
......@@ -98,8 +123,12 @@ class DeepseekV3TopkRouter(nn.Module):
.reshape(-1, self.n_routed_experts)
)
scores_for_choice = scores_for_choice.masked_fill(~score_mask.bool(), 0.0) # Size([1, 256])
topk_indices = torch.topk(scores_for_choice, k=self.top_k, dim=-1, sorted=True)[1] # Size([1, 8])
scores_for_choice = scores_for_choice.masked_fill(
~score_mask.bool(), 0.0
) # Size([1, 256])
topk_indices = torch.topk(scores_for_choice, k=self.top_k, dim=-1, sorted=True)[
1
] # Size([1, 8])
return topk_indices
......@@ -122,21 +151,23 @@ class DeepseekV3TopkRouter(nn.Module):
def torch_topkrouter(router_logits, correction_bias, routed_scaling_factor, topk):
lable_indices, lable_values = DeepseekV3TopkRouter(correction_bias, routed_scaling_factor, topk)(router_logits)
lable_indices, lable_values = DeepseekV3TopkRouter(
correction_bias, routed_scaling_factor, topk
)(router_logits)
lable_indices = lable_indices.to(torch.int32)
return lable_values, lable_indices
def test(
handle,
device,
x_shape,
x_stride,
topk,
routed_scaling_factor,
x_dtype=InfiniDtype.F32,
dtype=InfiniDtype.F16,
sync=None,
handle,
device,
x_shape,
x_stride,
topk,
routed_scaling_factor,
x_dtype=InfiniDtype.F32,
dtype=InfiniDtype.F16,
sync=None,
):
print(
f"Testing topkrouter on {InfiniDeviceNames[device]} with x_shape:{x_shape} "
......@@ -146,8 +177,12 @@ def test(
data = torch.arange(0, x_shape[0] * x_shape[1]).reshape(x_shape)
N, width = x_shape
x = TestTensor(x_shape, data.stride(), x_dtype, device, scale=5.0, bias=-5.0, mode="random")
correction_bias = TestTensor([x_shape[1]], [1], InfiniDtype.F32, device, mode="random")
x = TestTensor(
x_shape, data.stride(), x_dtype, device, scale=5.0, bias=-5.0, mode="random"
)
correction_bias = TestTensor(
[x_shape[1]], [1], InfiniDtype.F32, device, mode="random"
)
if sync is not None:
sync()
......@@ -155,10 +190,7 @@ def test(
descriptor = infiniopOperatorDescriptor_t()
check_error(
LIBINFINIOP.infiniopCreateTopkrouterDescriptor(
handle,
ctypes.byref(descriptor),
x.descriptor,
correction_bias.descriptor
handle, ctypes.byref(descriptor), x.descriptor, correction_bias.descriptor
)
)
......@@ -174,8 +206,12 @@ def test(
)
workspace = TestWorkspace(workspace_size.value, x.device)
values = torch.zeros((N, topk), dtype=torch.float32, device=torch_device_map[x.device])
indices = torch.zeros((N, topk), dtype=torch.int32, device=torch_device_map[x.device])
values = torch.zeros(
(N, topk), dtype=torch.float32, device=torch_device_map[x.device]
)
indices = torch.zeros(
(N, topk), dtype=torch.int32, device=torch_device_map[x.device]
)
def lib_topkrouter():
check_error(
......@@ -195,8 +231,9 @@ def test(
lib_topkrouter()
lable_values, lable_indices = torch_topkrouter(x.actual_tensor(), correction_bias.actual_tensor(), routed_scaling_factor, topk)
lable_values, lable_indices = torch_topkrouter(
x.actual_tensor(), correction_bias.actual_tensor(), routed_scaling_factor, topk
)
atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
if DEBUG:
debug(lable_values, values, atol=atol, rtol=rtol)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment