sglang / Commits / 57ab7769

Commit 57ab7769 (unverified), authored Jun 25, 2025 by Ke Bao, committed by GitHub on Jun 24, 2025.
Fuse sorted_token_ids padding to moe_align_block_size kernel (#7437)
Parent: 112b496a
Showing 7 changed files with 160 additions and 67 deletions (+160, -67):

  sgl-kernel/benchmark/bench_moe_align_block_size.py   +53  -48
  sgl-kernel/csrc/common_extension.cc                   +2   -1
  sgl-kernel/csrc/moe/moe_align_kernel.cu               +57  -5
  sgl-kernel/csrc/torch_extension_rocm.cc               +2   -1
  sgl-kernel/include/sgl_kernel_ops.h                   +2   -1
  sgl-kernel/python/sgl_kernel/moe.py                   +2   -0
  sgl-kernel/tests/test_moe_align.py                    +42  -11
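Background for reading the diffs: before this commit, callers pre-filled sorted_token_ids with the sentinel value topk_ids.numel() on the host (an extra elementwise fill_ launch) before invoking moe_align_block_size; the new pad_sorted_token_ids flag moves that padding write into the alignment kernel itself. For orientation, a pure-PyTorch sketch of the output the kernel is expected to produce (moe_align_reference and its per-expert loop are illustrative only, not the kernel's actual algorithm):

import torch

def moe_align_reference(topk_ids: torch.Tensor, num_experts: int, block_size: int):
    # Sentinel for padding slots: one past the last valid flat index.
    sentinel = topk_ids.numel()
    flat = topk_ids.flatten()
    chunks = []
    for e in range(num_experts):
        # Flat indices of the token slots routed to expert e.
        idx = (flat == e).nonzero(as_tuple=True)[0].to(torch.int32)
        # Pad this expert's run up to a whole number of blocks.
        pad = (-idx.numel()) % block_size
        chunks.append(idx)
        chunks.append(torch.full((pad,), sentinel, dtype=torch.int32))
    sorted_token_ids = torch.cat(chunks)
    return sorted_token_ids, sorted_token_ids.numel()

Each expert's run of token slots is rounded up to whole blocks of block_size, and the filler slots carry the out-of-range sentinel so downstream kernels can recognize them. Fusing the sentinel write means the host no longer has to blanket the whole buffer before every call.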
sgl-kernel/benchmark/bench_moe_align_block_size.py

@@ -5,7 +5,11 @@ import torch
 import triton
 import triton.language as tl
 from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size
-from vllm import _custom_ops as ops
+
+try:
+    from vllm import _custom_ops as ops
+except ImportError:
+    ops = None
 
 USE_RANDOM_PERM = False
@@ -208,7 +212,7 @@ def calculate_diff(num_tokens, num_experts=256, block_size=128, topk=8):
         )
         print(f"✅ VLLM implementation works with {num_experts} experts!")
         vllm_works = True
-    except RuntimeError as e:
+    except Exception as e:
         print(f"❌ VLLM implementation failed with {num_experts} experts: {e}")
         vllm_works = False
@@ -257,13 +261,47 @@ def get_topk_ids(num_tokens: int, num_experts: int, topk: int) -> torch.Tensor:
     return topk_ids
 
 
+def sgl_moe_align_block_size_with_empty(
+    topk_ids,
+    num_experts,
+    block_size,
+    sorted_ids,
+    expert_ids,
+    num_tokens_post_pad,
+    pad_sorted_token_ids=False,
+):
+    if not pad_sorted_token_ids:
+        sorted_ids.fill_(topk_ids.numel())
+    token_cnts_buffer = torch.empty(
+        (num_experts + 1) * num_experts,
+        dtype=torch.int32,
+        device=topk_ids.device,
+    )
+    cumsum_buffer = torch.empty(
+        num_experts + 1, dtype=torch.int32, device=topk_ids.device
+    )
+
+    sgl_moe_align_block_size(
+        topk_ids,
+        num_experts,
+        block_size,
+        sorted_ids.clone(),
+        expert_ids.clone(),
+        num_tokens_post_pad.clone(),
+        token_cnts_buffer,
+        cumsum_buffer,
+        pad_sorted_token_ids,
+    )
+
+
 @triton.testing.perf_report(
     triton.testing.Benchmark(
         x_names=["num_tokens", "num_experts", "topk"],
         x_vals=configs,
         line_arg="provider",
-        line_vals=["sgl", "triton", "vllm"],
-        line_names=["SGL", "Triton", "VLLM"],
+        line_vals=["sgl", "sgl_fusion", "triton"],
+        line_names=["SGL", "SGL Fusion", "Triton"],
         styles=[("blue", "-"), ("red", "-"), ("green", "-")],
         ylabel="us",
         plot_name="moe-align-block-size-performance",
@@ -288,7 +326,6 @@ def benchmark(num_tokens, num_experts, topk, provider):
     sorted_ids = torch.empty(
         (max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device
     )
-    sorted_ids.fill_(topk_ids.numel())
     max_num_m_blocks = max_num_tokens_padded // block_size
     expert_ids = torch.empty(
         (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
@@ -297,35 +334,18 @@ def benchmark(num_tokens, num_experts, topk, provider):
     quantiles = [0.5, 0.2, 0.8]
     if provider == "sgl":
-
-        def sgl_moe_align_block_size_with_empty(
-            topk_ids,
-            num_experts,
-            block_size,
-            sorted_ids,
-            expert_ids,
-            num_tokens_post_pad,
-        ):
-            token_cnts_buffer = torch.empty(
-                (num_experts + 1) * num_experts,
-                dtype=torch.int32,
-                device=topk_ids.device,
-            )
-            cumsum_buffer = torch.empty(
-                num_experts + 1, dtype=torch.int32, device=topk_ids.device
-            )
-
-            sgl_moe_align_block_size(
-                topk_ids,
-                num_experts,
-                block_size,
-                sorted_ids.clone(),
-                expert_ids.clone(),
-                num_tokens_post_pad.clone(),
-                token_cnts_buffer,
-                cumsum_buffer,
-            )
-
         ms, min_ms, max_ms = triton.testing.do_bench(
             lambda: sgl_moe_align_block_size_with_empty(
                 topk_ids,
                 num_experts,
                 block_size,
                 sorted_ids,
                 expert_ids,
                 num_tokens_post_pad,
             ),
             quantiles=quantiles,
         )
+    elif provider == "sgl_fusion":
+        ms, min_ms, max_ms = triton.testing.do_bench(
+            lambda: sgl_moe_align_block_size_with_empty(
+                topk_ids,
@@ -334,10 +354,12 @@ def benchmark(num_tokens, num_experts, topk, provider):
                 sorted_ids,
                 expert_ids,
                 num_tokens_post_pad,
+                pad_sorted_token_ids=True,
             ),
             quantiles=quantiles,
         )
     elif provider == "triton":
+        sorted_ids.fill_(topk_ids.numel())
         ms, min_ms, max_ms = triton.testing.do_bench(
             lambda: moe_align_block_size_triton(
                 topk_ids,
@@ -349,23 +371,6 @@ def benchmark(num_tokens, num_experts, topk, provider):
             ),
             quantiles=quantiles,
         )
-    else:  # vllm
-        try:
-            ms, min_ms, max_ms = triton.testing.do_bench(
-                lambda: ops.moe_align_block_size(
-                    topk_ids,
-                    num_experts,
-                    block_size,
-                    sorted_ids.clone(),
-                    expert_ids.clone(),
-                    num_tokens_post_pad.clone(),
-                ),
-                quantiles=quantiles,
-            )
-        except RuntimeError as e:
-            print(f"❌ VLLM benchmark failed with {num_experts} experts: {e}")
-            # Return extreme values to indicate failure in the chart
-            return float("inf"), float("inf"), float("inf")
 
     return 1000 * ms, 1000 * max_ms, 1000 * min_ms
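For scale, the scratch buffers that sgl_moe_align_block_size_with_empty allocates on each call, worked for the largest benchmarked expert count:

num_experts = 256
token_cnts_elems = (num_experts + 1) * num_experts  # 65,792 int32 entries
cumsum_elems = num_experts + 1                      # 257 int32 entries
print(token_cnts_elems * 4)                         # 263168 bytes, about 257 KiB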
sgl-kernel/csrc/common_extension.cc

@@ -160,7 +160,8 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
   */
   m.def(
       "moe_align_block_size(Tensor topk_ids, int num_experts, int block_size, Tensor! sorted_token_ids, Tensor! "
-      "experts_ids, Tensor! num_tokens_post_pad, Tensor! token_cnts_buffer, Tensor! cumsum_buffer) -> ()");
+      "experts_ids, Tensor! num_tokens_post_pad, Tensor! token_cnts_buffer, Tensor! cumsum_buffer, bool "
+      "pad_sorted_token_ids) -> ()");
   m.impl("moe_align_block_size", torch::kCUDA, &moe_align_block_size);
 
   m.def(
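The only change here is the op schema: bool pad_sorted_token_ids is appended, so the registered signature now takes nine arguments. If the rebuilt extension is importable, the updated schema can be sanity-checked from Python (a quick probe, assuming an sgl_kernel build that includes this commit):

import sgl_kernel  # registers the ops under torch.ops.sgl_kernel
import torch

print(torch.ops.sgl_kernel.moe_align_block_size.default._schema)
# ...Tensor! cumsum_buffer, bool pad_sorted_token_ids) -> ()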
sgl-kernel/csrc/moe/moe_align_kernel.cu

@@ -21,8 +21,17 @@ limitations under the License.
 
 #include "utils.h"
 
+template <typename T, int N, int Alignment = sizeof(T) * N>
+class alignas(Alignment) AlignedArray {
+ public:
+  T data[N];
+};
+
 #define WARP_SIZE 32
+#define VEC_SIZE 4
+using Vec = AlignedArray<int32_t, VEC_SIZE>;
 
 template <typename scalar_t>
 __global__ void count_and_sort_expert_tokens_kernel(
     const scalar_t* __restrict__ topk_ids,
@@ -50,7 +59,8 @@ __global__ void moe_align_block_size_kernel(
     int32_t experts_per_warp,
     int32_t block_size,
     size_t numel,
-    int32_t* __restrict__ cumsum) {
+    int32_t* __restrict__ cumsum,
+    bool pad_sorted_token_ids) {
   extern __shared__ int32_t shared_counts[];
 
   const int warp_id = threadIdx.x / WARP_SIZE;
@@ -96,6 +106,24 @@ __global__ void moe_align_block_size_kernel(
       expert_ids[i / block_size] = threadIdx.x;
     }
   }
+
+  if (pad_sorted_token_ids) {
+    int32_t fill_val = static_cast<int32_t>(numel);
+    int32_t total = *total_tokens_post_pad;
+
+    Vec fill_vec;
+#pragma unroll
+    for (int i = 0; i < VEC_SIZE; ++i) {
+      fill_vec.data[i] = fill_val;
+    }
+
+    int32_t total_vec_count = (total + VEC_SIZE - 1) / VEC_SIZE;
+    Vec* out_ptr = reinterpret_cast<Vec*>(sorted_token_ids);
+
+    for (int32_t idx = tid; idx < total_vec_count; idx += stride) {
+      out_ptr[idx] = fill_vec;
+    }
+  }
 }
 
 template <typename scalar_t>
@@ -106,7 +134,8 @@ __global__ void moe_align_block_size_small_batch_expert_kernel(
     int32_t* __restrict__ total_tokens_post_pad,
     int32_t num_experts,
     int32_t block_size,
-    size_t numel) {
+    size_t numel,
+    bool pad_sorted_token_ids) {
   const size_t tid = threadIdx.x;
   const size_t stride = blockDim.x;
@@ -149,6 +178,26 @@ __global__ void moe_align_block_size_small_batch_expert_kernel(
     }
   }
 
+  if (pad_sorted_token_ids) {
+    int32_t fill_val = static_cast<int32_t>(numel);
+    int32_t total = *total_tokens_post_pad;
+
+    Vec fill_vec;
+#pragma unroll
+    for (int i = 0; i < VEC_SIZE; ++i) {
+      fill_vec.data[i] = fill_val;
+    }
+
+    int32_t total_vec_count = (total + VEC_SIZE - 1) / VEC_SIZE;
+    Vec* out_ptr = reinterpret_cast<Vec*>(sorted_token_ids);
+
+    for (int32_t idx = tid; idx < total_vec_count; idx += stride) {
+      out_ptr[idx] = fill_vec;
+    }
+  }
+
+  __syncthreads();
+
   for (size_t i = tid; i < numel; i += stride) {
     int32_t expert_id = topk_ids[i];
     int32_t rank_post_pad = tokens_cnts[threadIdx.x * num_experts + expert_id] + cumsum[expert_id];
@@ -165,7 +214,8 @@ void moe_align_block_size(
     torch::Tensor experts_ids,
     torch::Tensor num_tokens_post_pad,
     torch::Tensor token_cnts_buffer,
-    torch::Tensor cumsum_buffer) {
+    torch::Tensor cumsum_buffer,
+    bool pad_sorted_token_ids) {
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
 
   int64_t padded_num_experts = ((num_experts + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE;
@@ -190,7 +240,8 @@ void moe_align_block_size(
         num_tokens_post_pad.data_ptr<int32_t>(),
         num_experts,
         block_size,
-        topk_ids.numel());
+        topk_ids.numel(),
+        pad_sorted_token_ids);
   } else {
     auto align_kernel = moe_align_block_size_kernel<scalar_t>;
@@ -207,7 +258,8 @@ void moe_align_block_size(
         experts_per_warp,
         block_size,
         topk_ids.numel(),
-        cumsum_buffer.data_ptr<int32_t>());
+        cumsum_buffer.data_ptr<int32_t>(),
+        pad_sorted_token_ids);
 
     const int block_threads = std::min(256, (int)threads);
     const int num_blocks = (topk_ids.numel() + block_threads - 1) / block_threads;
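The added fill block is the heart of the fusion: once total_tokens_post_pad is known, the threads blanket the padded prefix of sorted_token_ids with the sentinel using 16-byte int32x4 stores (the AlignedArray type above); the subsequent scatter then overwrites the slots that hold real token indices, leaving only padding slots at the sentinel. A small Python model of what the vectorized loop computes (grid-stride scheduling omitted since the stores are order-independent; fill_sorted_ids_reference is an illustrative name):

import torch

VEC_SIZE = 4  # the kernel stores one int32x4 vector (16 bytes) per iteration

def fill_sorted_ids_reference(sorted_token_ids: torch.Tensor, total: int, numel: int):
    # Round the fill length up to whole vectors, mirroring the kernel's
    # total_vec_count = (total + VEC_SIZE - 1) / VEC_SIZE; up to VEC_SIZE - 1
    # slots past `total` may also receive the sentinel.
    total_vec_count = (total + VEC_SIZE - 1) // VEC_SIZE
    sorted_token_ids[: total_vec_count * VEC_SIZE] = numel  # sentinel fill

In the configurations exercised by the tests, total is a multiple of block_size (itself a multiple of VEC_SIZE), so the rounded-up write stays inside the max_num_tokens_padded allocation.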
sgl-kernel/csrc/torch_extension_rocm.cc

@@ -59,7 +59,8 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
   */
   m.def(
       "moe_align_block_size(Tensor topk_ids, int num_experts, int block_size, Tensor! sorted_token_ids, Tensor! "
-      "experts_ids, Tensor! num_tokens_post_pad, Tensor! token_cnts_buffer, Tensor! cumsum_buffer) -> ()");
+      "experts_ids, Tensor! num_tokens_post_pad, Tensor! token_cnts_buffer, Tensor! cumsum_buffer, bool "
+      "pad_sorted_token_ids) -> ()");
   m.impl("moe_align_block_size", torch::kCUDA, &moe_align_block_size);
 
   m.def(
sgl-kernel/include/sgl_kernel_ops.h

@@ -212,7 +212,8 @@ void moe_align_block_size(
     torch::Tensor experts_ids,
     torch::Tensor num_tokens_post_pad,
     torch::Tensor token_cnts_buffer,
-    torch::Tensor cumsum_buffer);
+    torch::Tensor cumsum_buffer,
+    bool pad_sorted_token_ids);
 
 void topk_softmax(
     torch::Tensor& topk_weights,
sgl-kernel/python/sgl_kernel/moe.py

@@ -12,6 +12,7 @@ def moe_align_block_size(
     num_tokens_post_pad,
     token_cnts_buffer,
     cumsum_buffer,
+    pad_sorted_token_ids=False,
 ):
     torch.ops.sgl_kernel.moe_align_block_size.default(
         topk_ids,
@@ -22,6 +23,7 @@ def moe_align_block_size(
         num_tokens_post_pad,
         token_cnts_buffer,
         cumsum_buffer,
+        pad_sorted_token_ids,
     )
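A hedged end-to-end sketch of calling the updated wrapper with the fused padding enabled (buffer shapes mirror the benchmark and test in this commit; assumes a CUDA device and an sgl_kernel build that includes this change):

import torch
from sgl_kernel import moe_align_block_size

num_tokens, topk, num_experts, block_size = 16, 4, 64, 32
# Unique experts per token, mirroring the updated test.
topk_ids = torch.argsort(torch.rand(num_tokens, num_experts, device="cuda"), dim=1)[
    :, :topk
]

max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
# No host-side sorted_ids.fill_() needed when pad_sorted_token_ids=True.
sorted_ids = torch.empty((max_num_tokens_padded,), dtype=torch.int32, device="cuda")
expert_ids = torch.zeros(
    (max_num_tokens_padded // block_size,), dtype=torch.int32, device="cuda"
)
num_tokens_post_pad = torch.empty((1,), dtype=torch.int32, device="cuda")
token_cnts_buffer = torch.empty(
    ((num_experts + 1) * num_experts,), dtype=torch.int32, device="cuda"
)
cumsum_buffer = torch.empty((num_experts + 1,), dtype=torch.int32, device="cuda")

moe_align_block_size(
    topk_ids,
    num_experts,
    block_size,
    sorted_ids,
    expert_ids,
    num_tokens_post_pad,
    token_cnts_buffer,
    cumsum_buffer,
    pad_sorted_token_ids=True,
)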
sgl-kernel/tests/test_moe_align.py

@@ -138,33 +138,32 @@ def moe_align_block_size_triton(
 @pytest.mark.parametrize(
-    "block_size,num_tokens,topk,num_experts",
+    "block_size,num_tokens,topk,num_experts,pad_sorted_token_ids",
     list(
         itertools.product(
             [32, 64, 128, 256],  # block_size
             [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096],  # num_tokens
             [1, 2, 4, 8, 16, 32, 64],  # topk
             [64, 160, 256, 257, 260, 264],  # num_experts
+            [True, False],  # pad_sorted_token_ids
         )
     ),
 )
 def test_moe_align_block_size_compare_implementations(
-    block_size, num_tokens, topk, num_experts
+    block_size, num_tokens, topk, num_experts, pad_sorted_token_ids
 ):
-    topk_ids = torch.stack(
-        [
-            torch.randperm(num_experts, dtype=torch.int32, device="cuda")[:topk]
-            for _ in range(num_tokens)
-        ]
-    )
+    topk_ids = torch.argsort(torch.rand(num_tokens, num_experts, device="cuda"), dim=1)[
+        :, :topk
+    ]
 
     max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
     sorted_ids_cuda = torch.empty(
         (max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device
     )
-    sorted_ids_cuda.fill_(topk_ids.numel())
+    if not pad_sorted_token_ids:
+        sorted_ids_cuda.fill_(topk_ids.numel())
     max_num_m_blocks = max_num_tokens_padded // block_size
     expert_ids_cuda = torch.zeros(
         (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
@@ -195,6 +194,7 @@ def test_moe_align_block_size_compare_implementations(
         num_tokens_post_pad_cuda,
         token_cnts_buffer,
         cumsum_buffer,
+        pad_sorted_token_ids,
     )
 
     moe_align_block_size_triton(
@@ -206,20 +206,51 @@ def test_moe_align_block_size_compare_implementations(
         num_tokens_post_pad_triton,
     )
 
-    assert torch.allclose(expert_ids_cuda, expert_ids_triton), (
+    assert torch.allclose(expert_ids_cuda, expert_ids_triton, atol=0, rtol=0), (
         f"Expert IDs mismatch for block_size={block_size}, "
         f"num_tokens={num_tokens}, topk={topk}\n"
         f"CUDA expert_ids: {expert_ids_cuda}\n"
         f"Triton expert_ids: {expert_ids_triton}"
     )
 
-    assert torch.allclose(num_tokens_post_pad_cuda, num_tokens_post_pad_triton), (
+    assert torch.allclose(
+        num_tokens_post_pad_cuda, num_tokens_post_pad_triton, atol=0, rtol=0
+    ), (
         f"Num tokens post pad mismatch for block_size={block_size}, "
         f"num_tokens={num_tokens}, topk={topk}\n"
         f"CUDA num_tokens_post_pad: {num_tokens_post_pad_cuda}\n"
         f"Triton num_tokens_post_pad: {num_tokens_post_pad_triton}"
     )
 
+    # Select an expert to check
+    expert_idx = expert_ids_cuda.max().item()
+    # Get the first and last block id where expert_ids_cuda == expert_idx
+    matching_indices = torch.where(expert_ids_cuda == expert_idx)[0]
+    block_sorted_start = matching_indices[0].item() * block_size
+    block_sorted_end = min(
+        (matching_indices[-1].item() + 1) * block_size, max_num_tokens_padded
+    )
+    selected_sorted_ids_cuda = sorted_ids_cuda[
+        block_sorted_start:block_sorted_end
+    ].sort()[0]
+    selected_sorted_ids_triton = sorted_ids_triton[
+        block_sorted_start:block_sorted_end
+    ].sort()[0]
+    assert torch.allclose(
+        selected_sorted_ids_cuda,
+        selected_sorted_ids_triton,
+        atol=0,
+        rtol=0,
+    ), (
+        f"Sorted IDs mismatch for block_size={block_size}, "
+        f"num_tokens={num_tokens}, topk={topk}\n"
+        f"CUDA sorted_ids: {selected_sorted_ids_cuda}\n"
+        f"Triton sorted_ids: {selected_sorted_ids_triton}"
+    )
+
 
 if __name__ == "__main__":
     pytest.main([__file__])
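For a feel of the allocation arithmetic the test relies on, one point from the parametrize grid worked by hand:

block_size, num_tokens, topk, num_experts = 128, 4096, 8, 256

numel = num_tokens * topk                                       # 32768 token slots
max_num_tokens_padded = numel + num_experts * (block_size - 1)  # 65280
max_num_m_blocks = max_num_tokens_padded // block_size          # 510 blocks

Valid entries of sorted_ids satisfy id < numel, while padding slots hold the sentinel numel itself; the block-level check above sorts each selected range before comparing, since the two implementations need not agree on intra-block order.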