change / sglang · Commits · 4f42c8cd

Unverified commit 4f42c8cd, authored Oct 07, 2025 by Yuan Luo, committed by GitHub on Oct 07, 2025
[sgl-kernel] Support float64 moe_sum_reduce cuda kernel (#11068)
Co-authored-by: luoyuan.luo <luoyuan.luo@antgroup.com>
Parent: 3ddd7dc9

Showing 2 changed files with 224 additions and 90 deletions:
sgl-kernel/benchmark/bench_sum_scale.py (+53 −19)
sgl-kernel/csrc/moe/moe_sum_reduce.cu (+171 −71)
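For orientation: the operator extended here sums each token's top-k expert outputs and applies the routed scaling factor, and this commit adds a float64 path to the CUDA implementation. A minimal host-side sketch of the intended semantics, inferred from the benchmark baseline below (illustrative only, not code from this commit):

#include <cstdint>
#include <vector>

// Reference semantics sketch: x is a (token_num, topk_num, hidden_dim)
// tensor flattened row-major; y, sized (token_num * hidden_dim) by the
// caller, receives y[t][d] = scale * sum_k x[t][k][d].
void moe_sum_reduce_ref(
    const std::vector<double>& x,
    std::vector<double>& y,
    int64_t token_num,
    int64_t topk_num,
    int64_t hidden_dim,
    double scale) {
  for (int64_t t = 0; t < token_num; ++t) {
    for (int64_t d = 0; d < hidden_dim; ++d) {
      double acc = 0.0;
      for (int64_t k = 0; k < topk_num; ++k) {
        acc += x[(t * topk_num + k) * hidden_dim + d];
      }
      y[t * hidden_dim + d] = acc * scale;
    }
  }
}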
benchmark/kernels/fused_moe_triton/benchmark_sum_scale.py → sgl-kernel/benchmark/bench_sum_scale.py
+import os
 import torch
 import triton
 import triton.language as tl
 from sgl_kernel import moe_sum_reduce as moe_sum_reduce_cuda
 from triton.testing import do_bench

+# CI environment detection
+IS_CI = (
+    os.getenv("CI", "false").lower() == "true"
+    or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
+)

 @triton.jit
 def _moe_sum_reduce_kernel(
...
@@ -38,7 +46,6 @@ def _moe_sum_reduce_kernel(
     base_ptrs = input_ptr + offs_token[:, None] * input_stride_0 + offs_dim[None, :]

     accumulator = tl.zeros((BLOCK_M, BLOCK_DIM), dtype=tl.float32)
-
     for i in tl.range(0, topk_num, num_stages=NUM_STAGE):
         tile = tl.load(
             base_ptrs + i * input_stride_1,
...
@@ -110,7 +117,7 @@ def compute_sum_scaled_compiled(
     return out


-def get_benchmark():
+def get_benchmark(dtype=torch.bfloat16):
     num_tokens_range = [2**i for i in range(0, 13)]

     @triton.testing.perf_report(
...
@@ -122,7 +129,7 @@ def get_benchmark():
             line_names=["Original", "TorchCompile", "TritonKernel", "CudaKernel"],
             styles=[("blue", "-"), ("green", "-"), ("red", "-"), ("yellow", "-")],
             ylabel="us",
-            plot_name="sum_scaled_performance",
+            plot_name=f"sum_scaled_performance_{str(dtype).split('.')[-1]}",
             args={},
         )
     )
...
@@ -174,8 +181,8 @@ def get_benchmark():
     return benchmark


-def verify_correctness(num_tokens=1024):
-    x = torch.randn(num_tokens, 9, 4096, device="cuda", dtype=torch.bfloat16)
+def verify_correctness(num_tokens=1024, dtype=torch.bfloat16):
+    x = torch.randn(num_tokens, 9, 4096, device="cuda", dtype=dtype)
     scaling_factor = 0.3

     out_baseline = torch.empty_like(x[:, 0])
...
@@ -184,33 +191,60 @@ def verify_correctness(num_tokens=1024):
     out_compiled = torch.empty_like(out_baseline)
     compute_sum_scaled_compiled(x, out_compiled, scaling_factor)

-    out_triton = torch.empty_like(out_baseline)
-    moe_sum_reduce_triton(x, out_triton, scaling_factor)
-
     out_cuda = torch.empty_like(out_baseline)
     moe_sum_reduce_cuda(x, out_cuda, scaling_factor)

-    if (
-        torch.allclose(out_baseline, out_compiled, atol=1e-2, rtol=1e-2)
-        and torch.allclose(out_baseline, out_triton, atol=1e-2, rtol=1e-2)
-        and torch.allclose(out_baseline, out_cuda, atol=1e-2, rtol=1e-2)
-    ):
-        print("✅ All implementations match")
+    triton_skipped = dtype == torch.float64
+    if not triton_skipped:
+        out_triton = torch.empty_like(out_baseline)
+        moe_sum_reduce_triton(x, out_triton, scaling_factor)
+
+    if dtype == torch.float64:
+        atol, rtol = 1e-12, 1e-12
+    elif dtype == torch.float32:
+        atol, rtol = 1e-6, 1e-6
+    else:  # bfloat16 / float16
+        atol, rtol = 1e-2, 1e-2
+
+    ok_compiled = torch.allclose(out_baseline, out_compiled, atol=atol, rtol=rtol)
+    ok_cuda = torch.allclose(out_baseline, out_cuda, atol=atol, rtol=rtol)
+    ok_triton = (
+        True
+        if triton_skipped
+        else torch.allclose(out_baseline, out_triton, atol=atol, rtol=rtol)
+    )
+
+    if ok_compiled and ok_triton and ok_cuda:
+        msg = "✅ All implementations match"
+        if triton_skipped:
+            msg += " (Triton skipped for float64)"
+        print(msg)
     else:
         print("❌ Implementations differ")
         print(
             f"Baseline vs Compiled: {(out_baseline - out_compiled).abs().max().item()}"
         )
-        print(f"Baseline vs Triton: {(out_baseline - out_triton).abs().max().item()}")
+        if not triton_skipped:
+            print(
+                f"Baseline vs Triton: {(out_baseline - out_triton).abs().max().item()}"
+            )
         print(f"Baseline vs Cuda: {(out_baseline - out_cuda).abs().max().item()}")

 if __name__ == "__main__":
-    print("Running correctness verification...")
-    verify_correctness()
+    print("Running correctness verification for bfloat16...")
+    verify_correctness(dtype=torch.bfloat16)
+
+    # CI environment uses simplified parameters
+    if not IS_CI:
+        print("Running correctness verification for float64...")
+        verify_correctness(dtype=torch.float64)

-    print("\nRunning performance benchmark...")
-    benchmark = get_benchmark()
+    print("\nRunning performance benchmark for bfloat16...")
+    benchmark = get_benchmark(dtype=torch.bfloat16)
     benchmark.run(
         print_data=True,
         # save_path="./configs/benchmark_ops/sum_scaled/"
...
sgl-kernel/csrc/moe/moe_sum_reduce.cu
+#include <ATen/OpMathType.h>
 #include <ATen/cuda/CUDAContext.h>
 #include <c10/cuda/CUDAGuard.h>
 #include <cuda.h>
...
@@ -12,25 +13,36 @@
 #include "utils.h"

 template <typename T>
-__device__ __forceinline__ float to_float(T x) {
-  return static_cast<float>(x);
-}
-
-template <>
-__device__ __forceinline__ float to_float<half>(half x) {
-  return __half2float(x);
-}
+using opmath_t = at::opmath_type<T>;
+
+template <typename T>
+__device__ __forceinline__ opmath_t<T> to_acc(T x) {
+  return static_cast<opmath_t<T>>(x);
+}

 template <typename T>
-__device__ __forceinline__ T from_float(float x) {
+__device__ __forceinline__ T from_acc(opmath_t<T> x) {
   return static_cast<T>(x);
 }

 template <>
-__device__ __forceinline__ half from_float<half>(float x) {
+__device__ __forceinline__ opmath_t<at::Half> to_acc<at::Half>(at::Half x) {
+  return __half2float(__nv_half(x));
+}
+
+template <>
+__device__ __forceinline__ at::Half from_acc<at::Half>(opmath_t<at::Half> x) {
   return __float2half_rn(x);
 }

+template <>
+__device__ __forceinline__ opmath_t<at::BFloat16> to_acc<at::BFloat16>(at::BFloat16 x) {
+  return __bfloat162float(__nv_bfloat16(x));
+}
+
+template <>
+__device__ __forceinline__ at::BFloat16 from_acc<at::BFloat16>(opmath_t<at::BFloat16> x) {
+  return __float2bfloat16_rn(x);
+}
+
 template <typename T>
 __device__ __forceinline__ T ldg_cg(const T* p) {
   return __ldg(p);
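The former to_float/from_float helpers pinned accumulation to float; the new to_acc/from_acc helpers defer to at::opmath_type, so each storage dtype gets its natural accumulation type. As I read ATen's OpMathType.h, half and bfloat16 widen to float while float and double map to themselves, which is what lets the new float64 path accumulate in full double precision. A compile-time sketch of that assumed mapping (illustrative, not part of the diff):

#include <ATen/OpMathType.h>
#include <type_traits>

// Accumulation-type mapping assumed by the to_acc/from_acc helpers above.
static_assert(std::is_same_v<at::opmath_type<at::Half>, float>, "half accumulates in float");
static_assert(std::is_same_v<at::opmath_type<at::BFloat16>, float>, "bf16 accumulates in float");
static_assert(std::is_same_v<at::opmath_type<float>, float>, "float stays float");
static_assert(std::is_same_v<at::opmath_type<double>, double>, "double stays double");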
...
@@ -111,22 +123,22 @@ __global__ void moe_sum_reduce_kernel_warp_token_topk(
     const int64_t stride_token,
     const int64_t stride_topk,
     const int64_t out_stride_token,
-    const float scale) {
+    const opmath_t<scalar_t> scale) {
   const int warp_id = threadIdx.x / 32;
   const int lane = threadIdx.x % 32;

   const int64_t t = (int64_t)blockIdx.y * WARPS_PER_BLOCK + warp_id;
   if (t >= token_num) return;

   for (int64_t d = (int64_t)blockIdx.x * 32 + lane; d < hidden_dim; d += (int64_t)gridDim.x * 32) {
-    float acc = 0.f;
+    opmath_t<scalar_t> acc = opmath_t<scalar_t>(0);
     const int64_t base = t * stride_token + d;
 #pragma unroll
     for (int k = 0; k < TOPK; ++k) {
-      acc += to_float<scalar_t>(ldg_cg(&x[base + (int64_t)k * stride_topk]));
+      acc += to_acc<scalar_t>(x[base + (int64_t)k * stride_topk]);
     }
     acc *= scale;
-    y[t * out_stride_token + d] = from_float<scalar_t>(acc);
+    y[t * out_stride_token + d] = from_acc<scalar_t>(acc);
   }
 }
...
@@ -139,23 +151,79 @@ __global__ void moe_sum_reduce_kernel(
     const int64_t stride_token,
     const int64_t stride_topk,
     const int64_t out_stride_token,
-    const float scale) {
+    const opmath_t<scalar_t> scale) {
   for (int t = blockIdx.y; t < token_num; t += gridDim.y) {
     for (int d = blockIdx.x * blockDim.x + threadIdx.x; d < hidden_dim; d += blockDim.x * gridDim.x) {
       const int64_t base = t * stride_token + d;
-      float acc = 0.f;
+      opmath_t<scalar_t> acc = opmath_t<scalar_t>(0);
 #pragma unroll
       for (int k = 0; k < TOPK; ++k) {
-        acc += to_float<scalar_t>(x[base + (int64_t)k * stride_topk]);
+        acc += to_acc<scalar_t>(x[base + (int64_t)k * stride_topk]);
       }
       acc *= scale;
-      y[t * out_stride_token + d] = from_float<scalar_t>(acc);
+      y[t * out_stride_token + d] = from_acc<scalar_t>(acc);
     }
   }
 }
+// -------------------- general-topk fallback kernels --------------------
+// small-token
+template <typename scalar_t>
+__global__ void moe_sum_reduce_kernel_general(
+    const scalar_t* __restrict__ x,
+    scalar_t* __restrict__ y,
+    const int64_t token_num,
+    const int64_t hidden_dim,
+    const int64_t stride_token,
+    const int64_t stride_topk,
+    const int64_t out_stride_token,
+    const int topk_num,
+    const opmath_t<scalar_t> scale) {
+  for (int t = blockIdx.y; t < token_num; t += gridDim.y) {
+    for (int d = blockIdx.x * blockDim.x + threadIdx.x; d < hidden_dim; d += blockDim.x * gridDim.x) {
+      const int64_t base = t * stride_token + d;
+      opmath_t<scalar_t> acc = opmath_t<scalar_t>(0);
+#pragma unroll 1
+      for (int k = 0; k < topk_num; ++k) {
+        acc += to_acc<scalar_t>(x[base + (int64_t)k * stride_topk]);
+      }
+      acc *= scale;
+      y[t * out_stride_token + d] = from_acc<scalar_t>(acc);
+    }
+  }
+}
+
+// warp-per-token
+template <typename scalar_t, int WARPS_PER_BLOCK>
+__global__ void moe_sum_reduce_kernel_warp_token_general(
+    const scalar_t* __restrict__ x,
+    scalar_t* __restrict__ y,
+    const int64_t token_num,
+    const int64_t hidden_dim,
+    const int64_t stride_token,
+    const int64_t stride_topk,
+    const int64_t out_stride_token,
+    const int topk_num,
+    const opmath_t<scalar_t> scale) {
+  const int warp_id = threadIdx.x / 32;
+  const int lane = threadIdx.x % 32;
+
+  const int64_t t = (int64_t)blockIdx.y * WARPS_PER_BLOCK + warp_id;
+  if (t >= token_num) return;
+
+  for (int64_t d = (int64_t)blockIdx.x * 32 + lane; d < hidden_dim; d += (int64_t)gridDim.x * 32) {
+    opmath_t<scalar_t> acc = opmath_t<scalar_t>(0);
+    const int64_t base = t * stride_token + d;
+#pragma unroll 1
+    for (int k = 0; k < topk_num; ++k) {
+      acc += to_acc<scalar_t>(x[base + (int64_t)k * stride_topk]);
+    }
+    acc *= scale;
+    y[t * out_stride_token + d] = from_acc<scalar_t>(acc);
+  }
+}
+
 void moe_sum_reduce(at::Tensor& input, at::Tensor& output, double routed_scaling_factor) {
   TORCH_CHECK(input.is_cuda(), "input must be CUDA tensor");
   TORCH_CHECK(output.is_cuda(), "output must be CUDA tensor");
...
@@ -175,8 +243,6 @@ void moe_sum_reduce(at::Tensor& input, at::Tensor& output, double routed_scaling
   const int64_t in_stride_topk = input.stride(1);
   const int64_t out_stride_token = output.stride(0);

-  const float scale = static_cast<float>(routed_scaling_factor);
-
   auto stream = at::cuda::getCurrentCUDAStream();

   const bool fast_bf16_vec_ok =
       (input.scalar_type() == at::kBFloat16) && (token_num > 256) && (hidden_dim % 8 == 0);
...
@@ -198,6 +264,7 @@ void moe_sum_reduce(at::Tensor& input, at::Tensor& output, double routed_scaling
     auto stream = at::cuda::getCurrentCUDAStream();
+    const float scale = static_cast<float>(routed_scaling_factor);

     moe_sum_reduce_warp_per_token_vec_kernel<WARPS_PER_BLOCK><<<grid, block, 0, stream>>>(
         reinterpret_cast<const at::BFloat16*>(input.data_ptr<at::BFloat16>()),
         reinterpret_cast<at::BFloat16*>(output.data_ptr<at::BFloat16>()),
...
@@ -209,32 +276,12 @@ void moe_sum_reduce(at::Tensor& input, at::Tensor& output, double routed_scaling
         out_stride_token,
         scale);
-    TORCH_CHECK(cudaGetLastError() == cudaSuccess, "moe_sum_reduce CUDA kernel launch failed");
+    TORCH_CHECK(cudaGetLastError() == cudaSuccess, "moe_sum_reduce CUDA kernel (bf16 vec) launch failed");
     return;
   }
   const bool per_token_use_one_warp = (token_num > 128);
+  auto dispatch_topk = [&](auto&& launch_kernel) {
+    switch (topk_num) {
+      case 2:
+        launch_kernel(std::integral_constant<int, 2>{});
+        break;
+      case 4:
+        launch_kernel(std::integral_constant<int, 4>{});
+        break;
+      case 8:
+        launch_kernel(std::integral_constant<int, 8>{});
+        break;
+      case 9:
+        launch_kernel(std::integral_constant<int, 9>{});
+        break;
+      default:
+        launch_kernel(std::integral_constant<int, -1>{});
+        break;
+    }
+  };
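The new dispatch_topk lambda turns the runtime topk_num into a compile-time constant: each case hands the launch lambda a distinct std::integral_constant type, so the kernels' #pragma unroll loops can be fully unrolled for the common top-k values, while -1 routes to the general fallback. A standalone, self-contained sketch of the pattern (names here are illustrative):

#include <cstdio>
#include <type_traits>

// The switch converts a runtime value into a distinct integral_constant
// type; the generic lambda recovers it as a constexpr via decltype.
template <typename F>
void dispatch_topk_sketch(int topk_num, F&& launch_kernel) {
  switch (topk_num) {
    case 2:
      launch_kernel(std::integral_constant<int, 2>{});
      break;
    case 9:
      launch_kernel(std::integral_constant<int, 9>{});
      break;
    default:
      launch_kernel(std::integral_constant<int, -1>{});  // general fallback
      break;
  }
}

int main() {
  dispatch_topk_sketch(9, [](auto topk_c) {
    constexpr int TK = decltype(topk_c)::value;  // usable as a template argument
    std::printf("instantiated with TK = %d\n", TK);
  });
}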
   if (!per_token_use_one_warp) {
     // ---------- small-token ----------
     const int block_size = 256;
...
@@ -245,28 +292,55 @@ void moe_sum_reduce(at::Tensor& input, at::Tensor& output, double routed_scaling
     dim3 block(block_size);
     dim3 grid(static_cast<unsigned>(grid_x), static_cast<unsigned>(grid_y));
-#define LAUNCH_SMALL_TOKEN_KERNEL(TOPK)                                \
-  moe_sum_reduce_kernel<scalar_t_, TOPK><<<grid, block, 0, stream>>>(  \
-      input.data_ptr<scalar_t_>(),                                     \
-      output.data_ptr<scalar_t_>(),                                    \
-      token_num,                                                       \
-      hidden_dim,                                                      \
-      in_stride_token,                                                 \
-      in_stride_topk,                                                  \
-      out_stride_token,                                                \
-      scale);
     AT_DISPATCH_FLOATING_TYPES_AND2(
         at::kHalf, at::kBFloat16, input.scalar_type(), "moe_sum_reduce_cuda_small_token", [&] {
           using scalar_t_ = scalar_t;
-          switch (topk_num) {
-            case 2:
-              LAUNCH_SMALL_TOKEN_KERNEL(2);
-              break;
-            case 4:
-              LAUNCH_SMALL_TOKEN_KERNEL(4);
-              break;
-            case 8:
-              LAUNCH_SMALL_TOKEN_KERNEL(8);
-              break;
-            case 9:
-              LAUNCH_SMALL_TOKEN_KERNEL(9);
-              break;
-            default:
+          using acc_t_ = opmath_t<scalar_t_>;
+          auto lauch_small_token_kernel = [&](auto topk_c) {
+            const acc_t_ scale = static_cast<acc_t_>(routed_scaling_factor);
+            constexpr int TK = decltype(topk_c)::value;
+            if constexpr (TK > 0) {
+              moe_sum_reduce_kernel<scalar_t_, TK><<<grid, block, 0, stream>>>(
+                  input.data_ptr<scalar_t_>(),
+                  output.data_ptr<scalar_t_>(),
+                  token_num,
+                  hidden_dim,
+                  in_stride_token,
+                  in_stride_topk,
+                  out_stride_token,
+                  scale);
+            } else {
+              // launch general kernel
+              moe_sum_reduce_kernel_general<scalar_t_><<<grid, block, 0, stream>>>(
+                  input.data_ptr<scalar_t_>(),
+                  output.data_ptr<scalar_t_>(),
+                  token_num,
+                  hidden_dim,
+                  in_stride_token,
+                  in_stride_topk,
+                  out_stride_token,
+                  static_cast<int>(topk_num),
+                  scale);
+            }
+          };
+          dispatch_topk(lauch_small_token_kernel);
         });
-#undef LAUNCH_SMALL_TOKEN_KERNEL
+    TORCH_CHECK(cudaGetLastError() == cudaSuccess, "moe_sum_reduce CUDA kernel (small-token) launch failed");
   } else {
-    // ---------- warp-token ----------
+    // ---------- warp-per-token ----------
     constexpr int WARPS_PER_BLOCK = 4;
     constexpr int THREADS = WARPS_PER_BLOCK * 32;
...
@@ -279,25 +353,51 @@ void moe_sum_reduce(at::Tensor& input, at::Tensor& output, double routed_scaling
     dim3 block(THREADS);
     dim3 grid(static_cast<unsigned>(gx), static_cast<unsigned>(gy));
-#define LAUNCH_WARP_PER_TOKEN_KERNEL(TOPK)                                                              \
-  moe_sum_reduce_kernel_warp_token_topk<scalar_t_, TOPK, WARPS_PER_BLOCK><<<grid, block, 0, stream>>>(  \
-      input.data_ptr<scalar_t_>(),                                                                      \
-      output.data_ptr<scalar_t_>(),                                                                     \
-      token_num,                                                                                        \
-      hidden_dim,                                                                                       \
-      in_stride_token,                                                                                  \
-      in_stride_topk,                                                                                   \
-      out_stride_token,                                                                                 \
-      scale);
     AT_DISPATCH_FLOATING_TYPES_AND2(
         at::kHalf, at::kBFloat16, input.scalar_type(), "moe_sum_reduce_cuda_large_token", [&] {
           using scalar_t_ = scalar_t;
-          switch (topk_num) {
-            case 2:
-              LAUNCH_WARP_PER_TOKEN_KERNEL(2);
-              break;
-            case 4:
-              LAUNCH_WARP_PER_TOKEN_KERNEL(4);
-              break;
-            case 8:
-              LAUNCH_WARP_PER_TOKEN_KERNEL(8);
-              break;
-            case 9:
-              LAUNCH_WARP_PER_TOKEN_KERNEL(9);
-              break;
-            default:
+          using acc_t_ = opmath_t<scalar_t_>;
+          auto launch_large_token_kernel = [&](auto topk_c) {
+            const acc_t_ scale = static_cast<acc_t_>(routed_scaling_factor);
+            constexpr int TK = decltype(topk_c)::value;
+            if constexpr (TK > 0) {
+              moe_sum_reduce_kernel_warp_token_topk<scalar_t_, TK, WARPS_PER_BLOCK><<<grid, block, 0, stream>>>(
+                  input.data_ptr<scalar_t_>(),
+                  output.data_ptr<scalar_t_>(),
+                  token_num,
+                  hidden_dim,
+                  in_stride_token,
+                  in_stride_topk,
+                  out_stride_token,
+                  scale);
+            } else {
+              // launch general kernel
+              moe_sum_reduce_kernel_warp_token_general<scalar_t_, WARPS_PER_BLOCK><<<grid, block, 0, stream>>>(
+                  input.data_ptr<scalar_t_>(),
+                  output.data_ptr<scalar_t_>(),
+                  token_num,
+                  hidden_dim,
+                  in_stride_token,
+                  in_stride_topk,
+                  out_stride_token,
+                  static_cast<int>(topk_num),
+                  scale);
+            }
+          };
+          dispatch_topk(launch_large_token_kernel);
         });
-#undef LAUNCH_WARP_PER_TOKEN_KERNEL
+    TORCH_CHECK(cudaGetLastError() == cudaSuccess, "moe_sum_reduce CUDA kernel (warp-token) launch failed");
   }
-
-  TORCH_CHECK(cudaGetLastError() == cudaSuccess, "CUDA kernel launch failed");
 }
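The single trailing launch check is also split into per-path checks (bf16 vec, small-token, warp-token), so a failure message names the branch that launched. A minimal standalone sketch of the underlying pattern:

#include <cstdio>
#include <cuda_runtime.h>

__global__ void noop_kernel() {}

// cudaGetLastError() returns and clears the error left by the most recent
// kernel launch, so checking immediately after each launch attributes a
// failure to that specific configuration.
int main() {
  noop_kernel<<<1, 1>>>();
  cudaError_t err = cudaGetLastError();
  std::printf("launch: %s\n", err == cudaSuccess ? "ok" : cudaGetErrorString(err));
  return err == cudaSuccess ? 0 : 1;
}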