Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
jerrrrry
infinicore
Commits
812f6726
Commit
812f6726
authored
Dec 19, 2025
by
pengcheng888
Browse files
issue/563 - 调整#include位置
parent
ac4aae48
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
85 additions
and
41 deletions
+85
-41
src/infiniop/ops/topkrouter/cuda/kernel.cuh
src/infiniop/ops/topkrouter/cuda/kernel.cuh
+0
-9
src/infiniop/ops/topkrouter/metax/topkrouter_metax.maca
src/infiniop/ops/topkrouter/metax/topkrouter_metax.maca
+9
-1
src/infiniop/ops/topkrouter/nvidia/topkrouter_nvidia.cu
src/infiniop/ops/topkrouter/nvidia/topkrouter_nvidia.cu
+9
-1
test/infiniop/topkrouter.py
test/infiniop/topkrouter.py
+67
-30
No files found.
src/infiniop/ops/topkrouter/cuda/kernel.cuh
View file @
812f6726
#ifndef _TOPKROUTER_KERNEL_CUH__
#ifndef _TOPKROUTER_KERNEL_CUH__
#define _TOPKROUTER_KERNEL_CUH__
#define _TOPKROUTER_KERNEL_CUH__
#include <cfloat>
#include <cub/block/block_load.cuh>
#include <cub/block/block_radix_sort.cuh>
#include <cub/block/block_reduce.cuh>
#include <cub/block/block_store.cuh>
#include <cub/cub.cuh>
// #include <cuda_bf16.h>
// #include <cuda_fp16.h>
// #include <cuda_runtime.h>
template
<
typename
T
>
template
<
typename
T
>
inline
__device__
float
exp_func
(
T
x
)
{
inline
__device__
float
exp_func
(
T
x
)
{
...
...
src/infiniop/ops/topkrouter/metax/topkrouter_metax.maca
View file @
812f6726
#include "../../../devices/metax/metax_common.h"
#include "../../../devices/metax/metax_common.h"
#include "../../../devices/metax/metax_kernel_common.h"
#include "../../../devices/metax/metax_kernel_common.h"
#include "../cuda/kernel.cuh"
#include "topkrouter_metax.h"
#include "topkrouter_metax.h"
#include <cfloat>
#include <cub/block/block_load.cuh>
#include <cub/block/block_radix_sort.cuh>
#include <cub/block/block_reduce.cuh>
#include <cub/block/block_reduce.cuh>
#include <cub/block/block_store.cuh>
#include <cub/cub.cuh>
#include "../cuda/kernel.cuh"
namespace op::topkrouter::metax {
namespace op::topkrouter::metax {
...
...
src/infiniop/ops/topkrouter/nvidia/topkrouter_nvidia.cu
View file @
812f6726
...
@@ -2,9 +2,17 @@
...
@@ -2,9 +2,17 @@
#include "../../../devices/nvidia/nvidia_common.cuh"
#include "../../../devices/nvidia/nvidia_common.cuh"
#include "../../../devices/nvidia/nvidia_kernel_common.cuh"
#include "../../../devices/nvidia/nvidia_kernel_common.cuh"
#include "../cuda/kernel.cuh"
#include "topkrouter_nvidia.cuh"
#include "topkrouter_nvidia.cuh"
#include <cfloat>
#include <cub/block/block_load.cuh>
#include <cub/block/block_radix_sort.cuh>
#include <cub/block/block_reduce.cuh>
#include <cub/block/block_reduce.cuh>
#include <cub/block/block_store.cuh>
#include <cub/cub.cuh>
#include "../cuda/kernel.cuh"
namespace
op
::
topkrouter
::
nvidia
{
namespace
op
::
topkrouter
::
nvidia
{
...
...
test/infiniop/topkrouter.py
View file @
812f6726
...
@@ -19,7 +19,7 @@ from libinfiniop import (
...
@@ -19,7 +19,7 @@ from libinfiniop import (
InfiniDtypeNames
,
InfiniDtypeNames
,
InfiniDeviceNames
,
InfiniDeviceNames
,
infiniopOperatorDescriptor_t
,
infiniopOperatorDescriptor_t
,
torch_device_map
torch_device_map
,
)
)
# ==============================================================================
# ==============================================================================
...
@@ -29,12 +29,14 @@ from libinfiniop import (
...
@@ -29,12 +29,14 @@ from libinfiniop import (
_TEST_CASES_
=
[
_TEST_CASES_
=
[
# x_shape, x_stride, topk, routed_scaling_factor
# x_shape, x_stride, topk, routed_scaling_factor
((
1
,
256
),
None
,
8
,
2.5
),
((
1
,
256
),
None
,
8
,
2.5
),
((
2
,
256
),
None
,
8
,
1.0
),
]
]
# w (weight) types
# w (weight) types
# Note: 'None' means the same as input dtype
# Note: 'None' means the same as input dtype
# _X_DTYPES = [InfiniDtype.F32, InfiniDtype.BF16, InfiniDtype.F16]
# _X_DTYPES = [InfiniDtype.F32, InfiniDtype.BF16, InfiniDtype.F16]
_X_DTYPES
=
[]
# CPU CI
_X_DTYPES
=
[]
# CPU CI
# x types used for testing
# x types used for testing
_VALUE_DTYPES
=
[
InfiniDtype
.
F32
]
_VALUE_DTYPES
=
[
InfiniDtype
.
F32
]
...
@@ -57,17 +59,34 @@ NUM_ITERATIONS = 1000
...
@@ -57,17 +59,34 @@ NUM_ITERATIONS = 1000
def
tensorInfo
(
data
):
def
tensorInfo
(
data
):
print
(
"data: "
,
data
.
is_contiguous
(),
data
.
device
,
data
.
dtype
,
data
.
shape
,
data
.
stride
(),
data
.
data_ptr
(),
hex
(
data
.
data_ptr
()))
print
(
"data: "
,
data
.
is_contiguous
(),
data
.
device
,
data
.
dtype
,
data
.
shape
,
data
.
stride
(),
data
.
data_ptr
(),
hex
(
data
.
data_ptr
()),
)
class
DeepseekV3TopkRouter
(
nn
.
Module
):
class
DeepseekV3TopkRouter
(
nn
.
Module
):
def
__init__
(
self
,
correction_bias
,
routed_scaling_factor
:
float
=
2.5
,
topk
:
int
=
8
,
config
=
None
):
def
__init__
(
self
,
correction_bias
,
routed_scaling_factor
:
float
=
2.5
,
topk
:
int
=
8
,
config
=
None
,
):
super
().
__init__
()
super
().
__init__
()
self
.
config
=
config
self
.
config
=
config
self
.
top_k
=
topk
# config.num_experts_per_tok 8
self
.
top_k
=
topk
# config.num_experts_per_tok 8
assert
topk
==
8
assert
topk
==
8
self
.
n_routed_experts
=
256
# config.n_routed_experts
self
.
n_routed_experts
=
256
# config.n_routed_experts
self
.
routed_scaling_factor
=
routed_scaling_factor
# config.routed_scaling_factor 2.5
self
.
routed_scaling_factor
=
(
routed_scaling_factor
# config.routed_scaling_factor 2.5
)
self
.
n_group
=
8
# config.n_group
self
.
n_group
=
8
# config.n_group
self
.
topk_group
=
4
# config.topk_group
self
.
topk_group
=
4
# config.topk_group
self
.
norm_topk_prob
=
True
# config.norm_topk_prob
self
.
norm_topk_prob
=
True
# config.norm_topk_prob
...
@@ -81,14 +100,20 @@ class DeepseekV3TopkRouter(nn.Module):
...
@@ -81,14 +100,20 @@ class DeepseekV3TopkRouter(nn.Module):
@
torch
.
no_grad
()
@
torch
.
no_grad
()
def
get_topk_indices
(
self
,
scores
):
def
get_topk_indices
(
self
,
scores
):
scores_for_choice
=
scores
.
view
(
-
1
,
self
.
n_routed_experts
)
+
self
.
e_score_correction_bias
.
unsqueeze
(
0
)
# Size([1, 256])
scores_for_choice
=
scores
.
view
(
-
1
,
self
.
n_routed_experts
)
+
self
.
e_score_correction_bias
.
unsqueeze
(
0
)
# Size([1, 256])
group_scores
=
(
group_scores
=
(
scores_for_choice
.
view
(
-
1
,
self
.
n_group
,
self
.
n_routed_experts
//
self
.
n_group
)
scores_for_choice
.
view
(
-
1
,
self
.
n_group
,
self
.
n_routed_experts
//
self
.
n_group
)
.
topk
(
2
,
dim
=-
1
)[
0
]
.
topk
(
2
,
dim
=-
1
)[
0
]
.
sum
(
dim
=-
1
)
.
sum
(
dim
=-
1
)
)
)
group_idx
=
torch
.
topk
(
group_scores
,
k
=
self
.
topk_group
,
dim
=-
1
,
sorted
=
True
)[
1
]
# Size([1, 4])
group_idx
=
torch
.
topk
(
group_scores
,
k
=
self
.
topk_group
,
dim
=-
1
,
sorted
=
True
)[
1
]
# Size([1, 4])
group_mask
=
torch
.
zeros_like
(
group_scores
)
# Size([1, 8])
group_mask
=
torch
.
zeros_like
(
group_scores
)
# Size([1, 8])
group_mask
.
scatter_
(
1
,
group_idx
,
1
)
# Size([1, 8])
group_mask
.
scatter_
(
1
,
group_idx
,
1
)
# Size([1, 8])
...
@@ -98,8 +123,12 @@ class DeepseekV3TopkRouter(nn.Module):
...
@@ -98,8 +123,12 @@ class DeepseekV3TopkRouter(nn.Module):
.
reshape
(
-
1
,
self
.
n_routed_experts
)
.
reshape
(
-
1
,
self
.
n_routed_experts
)
)
)
scores_for_choice
=
scores_for_choice
.
masked_fill
(
~
score_mask
.
bool
(),
0.0
)
# Size([1, 256])
scores_for_choice
=
scores_for_choice
.
masked_fill
(
topk_indices
=
torch
.
topk
(
scores_for_choice
,
k
=
self
.
top_k
,
dim
=-
1
,
sorted
=
True
)[
1
]
# Size([1, 8])
~
score_mask
.
bool
(),
0.0
)
# Size([1, 256])
topk_indices
=
torch
.
topk
(
scores_for_choice
,
k
=
self
.
top_k
,
dim
=-
1
,
sorted
=
True
)[
1
]
# Size([1, 8])
return
topk_indices
return
topk_indices
...
@@ -122,7 +151,9 @@ class DeepseekV3TopkRouter(nn.Module):
...
@@ -122,7 +151,9 @@ class DeepseekV3TopkRouter(nn.Module):
def
torch_topkrouter
(
router_logits
,
correction_bias
,
routed_scaling_factor
,
topk
):
def
torch_topkrouter
(
router_logits
,
correction_bias
,
routed_scaling_factor
,
topk
):
lable_indices
,
lable_values
=
DeepseekV3TopkRouter
(
correction_bias
,
routed_scaling_factor
,
topk
)(
router_logits
)
lable_indices
,
lable_values
=
DeepseekV3TopkRouter
(
correction_bias
,
routed_scaling_factor
,
topk
)(
router_logits
)
lable_indices
=
lable_indices
.
to
(
torch
.
int32
)
lable_indices
=
lable_indices
.
to
(
torch
.
int32
)
return
lable_values
,
lable_indices
return
lable_values
,
lable_indices
...
@@ -146,8 +177,12 @@ def test(
...
@@ -146,8 +177,12 @@ def test(
data
=
torch
.
arange
(
0
,
x_shape
[
0
]
*
x_shape
[
1
]).
reshape
(
x_shape
)
data
=
torch
.
arange
(
0
,
x_shape
[
0
]
*
x_shape
[
1
]).
reshape
(
x_shape
)
N
,
width
=
x_shape
N
,
width
=
x_shape
x
=
TestTensor
(
x_shape
,
data
.
stride
(),
x_dtype
,
device
,
scale
=
5.0
,
bias
=-
5.0
,
mode
=
"random"
)
x
=
TestTensor
(
correction_bias
=
TestTensor
([
x_shape
[
1
]],
[
1
],
InfiniDtype
.
F32
,
device
,
mode
=
"random"
)
x_shape
,
data
.
stride
(),
x_dtype
,
device
,
scale
=
5.0
,
bias
=-
5.0
,
mode
=
"random"
)
correction_bias
=
TestTensor
(
[
x_shape
[
1
]],
[
1
],
InfiniDtype
.
F32
,
device
,
mode
=
"random"
)
if
sync
is
not
None
:
if
sync
is
not
None
:
sync
()
sync
()
...
@@ -155,10 +190,7 @@ def test(
...
@@ -155,10 +190,7 @@ def test(
descriptor
=
infiniopOperatorDescriptor_t
()
descriptor
=
infiniopOperatorDescriptor_t
()
check_error
(
check_error
(
LIBINFINIOP
.
infiniopCreateTopkrouterDescriptor
(
LIBINFINIOP
.
infiniopCreateTopkrouterDescriptor
(
handle
,
handle
,
ctypes
.
byref
(
descriptor
),
x
.
descriptor
,
correction_bias
.
descriptor
ctypes
.
byref
(
descriptor
),
x
.
descriptor
,
correction_bias
.
descriptor
)
)
)
)
...
@@ -174,8 +206,12 @@ def test(
...
@@ -174,8 +206,12 @@ def test(
)
)
workspace
=
TestWorkspace
(
workspace_size
.
value
,
x
.
device
)
workspace
=
TestWorkspace
(
workspace_size
.
value
,
x
.
device
)
values
=
torch
.
zeros
((
N
,
topk
),
dtype
=
torch
.
float32
,
device
=
torch_device_map
[
x
.
device
])
values
=
torch
.
zeros
(
indices
=
torch
.
zeros
((
N
,
topk
),
dtype
=
torch
.
int32
,
device
=
torch_device_map
[
x
.
device
])
(
N
,
topk
),
dtype
=
torch
.
float32
,
device
=
torch_device_map
[
x
.
device
]
)
indices
=
torch
.
zeros
(
(
N
,
topk
),
dtype
=
torch
.
int32
,
device
=
torch_device_map
[
x
.
device
]
)
def
lib_topkrouter
():
def
lib_topkrouter
():
check_error
(
check_error
(
...
@@ -195,8 +231,9 @@ def test(
...
@@ -195,8 +231,9 @@ def test(
lib_topkrouter
()
lib_topkrouter
()
lable_values
,
lable_indices
=
torch_topkrouter
(
lable_values
,
lable_indices
=
torch_topkrouter
(
x
.
actual_tensor
(),
correction_bias
.
actual_tensor
(),
routed_scaling_factor
,
topk
)
x
.
actual_tensor
(),
correction_bias
.
actual_tensor
(),
routed_scaling_factor
,
topk
)
atol
,
rtol
=
get_tolerance
(
_TOLERANCE_MAP
,
dtype
)
atol
,
rtol
=
get_tolerance
(
_TOLERANCE_MAP
,
dtype
)
if
DEBUG
:
if
DEBUG
:
debug
(
lable_values
,
values
,
atol
=
atol
,
rtol
=
rtol
)
debug
(
lable_values
,
values
,
atol
=
atol
,
rtol
=
rtol
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment