sglang · Commits · 82eccae4

Unverified commit 82eccae4, authored Jun 28, 2025 by fzyzcjy; committed via GitHub on Jun 28, 2025.
Let ep_scatter support arbitrary strides / ue8m0 format (#7309)
Parent: a8c10aee
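Note on the format (context added for this writeup, not part of the commit): ue8m0 is an exponent-only encoding in the style of the OCP MX E8M0 scale type: 8 exponent bits, 0 mantissa bits, so each scale is a power of two stored in one byte, and four such bytes pack into a single int32. A minimal sketch of that interpretation, assuming the conventional bias of 127:

import math

def encode_ue8m0(scale: float) -> int:
    # Assumed semantics: store the nearest power-of-two exponent, biased by 127.
    return max(0, min(255, round(math.log2(scale)) + 127))

def decode_ue8m0(byte: int) -> float:
    return 2.0 ** (byte - 127)

assert decode_ue8m0(encode_ue8m0(0.25)) == 0.25  # power-of-two scales round-trip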
Showing 1 changed file with 22 additions and 6 deletions (+22 -6).
python/sglang/srt/layers/moe/ep_moe/kernels.py · View file @ 82eccae4
@@ -813,14 +813,17 @@ def _fwd_kernel_ep_scatter_2(
     offset_in = tl.arange(0, HIDDEN_SIZE_PAD)
     mask = offset_in < HIDDEN_SIZE

-    offset_in_s = tl.arange(0, SCALE_HIDDEN_SIZE_PAD)
-    mask_s = offset_in_s < SCALE_HIDDEN_SIZE
+    index_in_s = tl.arange(0, SCALE_HIDDEN_SIZE_PAD)
+    mask_s = index_in_s < SCALE_HIDDEN_SIZE

     for token_id_int32 in range(start_token_id, total_token_num, grid_num):
         token_id = token_id_int32.to(tl.int64)
         to_copy = tl.load(recv_x + token_id * recv_x_stride0 + offset_in, mask=mask)
-        to_copy_s = tl.load(recv_x_scale + token_id * recv_x_scale_stride0 + offset_in_s, mask=mask_s)
+        to_copy_s = tl.load(
+            recv_x_scale + token_id * recv_x_scale_stride0 + index_in_s * recv_x_scale_stride1,
+            mask=mask_s,
+        )

         for topk_idx_int32 in tl.range(0, topk_num, 1, num_stages=4):
@@ -841,7 +844,11 @@ def _fwd_kernel_ep_scatter_2(
                 output_tensor_scale + dest_token_index * output_tensor_scale_stride0
             )
             tl.store(output_tensor_ptr + offset_in, to_copy, mask=mask)
-            tl.store(output_tensor_scale_ptr + offset_in_s, to_copy_s, mask=mask_s)
+            tl.store(
+                output_tensor_scale_ptr + index_in_s * output_tensor_scale_stride1,
+                to_copy_s,
+                mask=mask_s,
+            )


 # copy from https://github.com/ModelTC/lightllm/blob/main/lightllm/common/fused_moe/deepep_scatter_gather.py
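The two hunks above are the "arbitrary strides" half of the change: addressing that previously assumed a unit-stride scale dimension (base + offset_in_s) now multiplies the column index by recv_x_scale_stride1 / output_tensor_scale_stride1. A standalone sketch of why that matters, with illustrative shapes that are not taken from the commit:

import torch

# A non-contiguous scale tensor, e.g. a transposed view: shape (64, 56),
# strides (1, 64), so the second dimension is NOT unit-stride.
storage = torch.arange(56 * 64, dtype=torch.float32)
scales = storage.reshape(56, 64).t()

t, s = 3, 7
# Stride-aware addressing, as in the patched kernel:
#   recv_x_scale + token_id * stride0 + index_in_s * stride1
ok = t * scales.stride(0) + s * scales.stride(1)
assert storage[ok] == scales[t, s]

# The old implicit "stride1 == 1" addressing reads the wrong element here.
bad = t * scales.stride(0) + s
assert storage[bad] != scales[t, s]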
@@ -856,6 +863,7 @@ def ep_scatter(
     output_tensor_scale: torch.Tensor,
     m_indices: torch.Tensor,
     output_index: torch.Tensor,
+    scale_ue8m0: bool = False,
 ):
     BLOCK_E = 128  # token num of per expert is aligned to 128
     BLOCK_D = 128  # block size of quantization
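The next hunk calls a ceil_div helper that is defined elsewhere in sglang and does not appear in this diff. A stand-in with the conventional behavior (an assumption about the helper, not the actual sglang source):

def ceil_div(a: int, b: int) -> int:
    # Integer division rounding up: ceil_div(56, 4) == 14, ceil_div(57, 4) == 15.
    return (a + b - 1) // b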
@@ -865,7 +873,15 @@ def ep_scatter(
     # grid = (triton.cdiv(hidden_size, BLOCK_D), num_experts)
     grid = num_experts

+    scale_hidden_size = hidden_size // BLOCK_D
+    if scale_ue8m0:
+        # ue8m0 scales are packed here (4 scales per int32),
+        # hence the effective size of this dimension is divided by 4.
+        scale_hidden_size = ceil_div(scale_hidden_size, 4)
+
     assert m_indices.shape[0] % BLOCK_E == 0
+    assert recv_x_scale.dtype == output_tensor_scale.dtype
+    assert recv_x_scale.shape[1] == output_tensor_scale.shape[1] == scale_hidden_size

     _fwd_kernel_ep_scatter_1[(grid,)](
         num_recv_tokens_per_expert,
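To make the packing comment in this hunk concrete, here is the arithmetic with an illustrative hidden size (7168 is typical of DeepSeek-class models, not a value taken from the commit):

hidden_size = 7168                          # illustrative only
BLOCK_D = 128                               # quantization block size, as above
scale_hidden_size = hidden_size // BLOCK_D  # 56 scales per token
packed = (scale_hidden_size + 3) // 4       # ceil_div(56, 4) -> 14 int32 words
assert (scale_hidden_size, packed) == (56, 14)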
@@ -904,8 +920,8 @@ def ep_scatter(
         num_warps=num_warps,
         HIDDEN_SIZE=hidden_size,
         HIDDEN_SIZE_PAD=triton.next_power_of_2(hidden_size),
-        SCALE_HIDDEN_SIZE=hidden_size // BLOCK_D,
-        SCALE_HIDDEN_SIZE_PAD=triton.next_power_of_2(hidden_size // BLOCK_D),
+        SCALE_HIDDEN_SIZE=scale_hidden_size,
+        SCALE_HIDDEN_SIZE_PAD=triton.next_power_of_2(scale_hidden_size),
     )
     return
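Why both SCALE_HIDDEN_SIZE and SCALE_HIDDEN_SIZE_PAD: tl.arange in Triton requires a power-of-two extent, so the kernel pads each logical size up and disables the tail lanes with the mask / mask_s comparisons from the first hunk. Continuing the illustrative numbers from above:

import triton

scale_hidden_size = 14                           # the packed ue8m0 case above
pad = triton.next_power_of_2(scale_hidden_size)  # -> 16
assert pad == 16
# Lanes [14, 16) are masked off by mask_s = index_in_s < SCALE_HIDDEN_SIZE.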