Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
e3b8a722
Unverified
Commit
e3b8a722
authored
May 18, 2025
by
xutizhou
Committed by
GitHub
May 17, 2025
Browse files
[fix] illegal memory in _fwd_kernel_ep_scatter_2 and _fwd_kernel_ep_gather (#6348)
parent
3cf1473a
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
23 additions
and
9 deletions
+23
-9
python/sglang/srt/layers/moe/ep_moe/kernels.py
python/sglang/srt/layers/moe/ep_moe/kernels.py
+23
-9
No files found.
python/sglang/srt/layers/moe/ep_moe/kernels.py
View file @
e3b8a722
...
@@ -791,19 +791,23 @@ def _fwd_kernel_ep_scatter_2(
...
@@ -791,19 +791,23 @@ def _fwd_kernel_ep_scatter_2(
offset_in_s
=
tl
.
arange
(
0
,
SCALE_HIDDEN_SIZE_PAD
)
offset_in_s
=
tl
.
arange
(
0
,
SCALE_HIDDEN_SIZE_PAD
)
mask_s
=
offset_in_s
<
SCALE_HIDDEN_SIZE
mask_s
=
offset_in_s
<
SCALE_HIDDEN_SIZE
for
token_id
in
range
(
start_token_id
,
total_token_num
,
grid_num
):
for
token_id_int32
in
range
(
start_token_id
,
total_token_num
,
grid_num
):
token_id
=
token_id_int32
.
to
(
tl
.
int64
)
to_copy
=
tl
.
load
(
recv_x
+
token_id
*
recv_x_stride0
+
offset_in
,
mask
=
mask
)
to_copy
=
tl
.
load
(
recv_x
+
token_id
*
recv_x_stride0
+
offset_in
,
mask
=
mask
)
to_copy_s
=
tl
.
load
(
to_copy_s
=
tl
.
load
(
recv_x_scale
+
token_id
*
recv_x_scale_stride0
+
offset_in_s
,
mask
=
mask_s
recv_x_scale
+
token_id
*
recv_x_scale_stride0
+
offset_in_s
,
mask
=
mask_s
)
)
for
topk_index
in
tl
.
range
(
0
,
topk_num
,
1
,
num_stages
=
4
):
for
topk_idx_int32
in
tl
.
range
(
0
,
topk_num
,
1
,
num_stages
=
4
):
topk_index
=
topk_idx_int32
.
to
(
tl
.
int64
)
expert_id
=
tl
.
load
(
recv_topk
+
token_id
*
recv_topk_stride0
+
topk_index
)
expert_id
=
tl
.
load
(
recv_topk
+
token_id
*
recv_topk_stride0
+
topk_index
)
if
expert_id
>=
0
:
if
expert_id
>=
0
:
dest_token_index
=
tl
.
atomic_add
(
expert_start_loc
+
expert_id
,
1
)
dest_token_index_int32
=
tl
.
atomic_add
(
expert_start_loc
+
expert_id
,
1
)
dest_token_index
=
dest_token_index_int32
.
to
(
tl
.
int64
)
tl
.
store
(
tl
.
store
(
output_index
+
token_id
*
output_index_stride0
+
topk_index
,
output_index
+
token_id
*
output_index_stride0
+
topk_index
,
dest_token_index
,
dest_token_index
_int32
,
)
)
output_tensor_ptr
=
(
output_tensor_ptr
=
(
output_tensor
+
dest_token_index
*
output_tensor_stride0
output_tensor
+
dest_token_index
*
output_tensor_stride0
...
@@ -902,21 +906,31 @@ def _fwd_kernel_ep_gather(
...
@@ -902,21 +906,31 @@ def _fwd_kernel_ep_gather(
topk_num
:
tl
.
constexpr
,
topk_num
:
tl
.
constexpr
,
BLOCK_D
:
tl
.
constexpr
,
BLOCK_D
:
tl
.
constexpr
,
):
):
cur_block
=
tl
.
program_id
(
0
)
cur_block_int32
=
tl
.
program_id
(
0
)
start_cur_token
=
tl
.
program_id
(
1
)
cur_block
=
cur_block_int32
.
to
(
tl
.
int64
)
start_cur_token_int32
=
tl
.
program_id
(
1
)
grid_num
=
tl
.
num_programs
(
1
)
grid_num
=
tl
.
num_programs
(
1
)
for
cur_token
in
range
(
start_cur_token
,
total_token_num
,
grid_num
):
for
cur_token_int32
in
range
(
start_cur_token_int32
,
total_token_num
,
grid_num
):
cur_token
=
cur_token_int32
.
to
(
tl
.
int64
)
off_d
=
tl
.
arange
(
0
,
BLOCK_D
)
off_d
=
tl
.
arange
(
0
,
BLOCK_D
)
accumulator
=
tl
.
zeros
([
BLOCK_D
],
dtype
=
tl
.
float32
)
accumulator
=
tl
.
zeros
([
BLOCK_D
],
dtype
=
tl
.
float32
)
for
topk_index
in
range
(
0
,
topk_num
):
for
topk_index_int32
in
range
(
0
,
topk_num
):
topk_index
=
topk_index_int32
.
to
(
tl
.
int64
)
expert_id
=
tl
.
load
(
expert_id
=
tl
.
load
(
recv_topk_ids
+
cur_token
*
recv_topk_ids_stride0
+
topk_index
recv_topk_ids
+
cur_token
*
recv_topk_ids_stride0
+
topk_index
)
)
if
expert_id
>=
0
:
if
expert_id
>=
0
:
source_token_index
=
tl
.
load
(
source_token_index
_int32
=
tl
.
load
(
input_index
+
cur_token
*
input_index_stride0
+
topk_index
input_index
+
cur_token
*
input_index_stride0
+
topk_index
)
)
source_token_index
=
source_token_index_int32
.
to
(
tl
.
int64
)
acc_weight
=
tl
.
load
(
acc_weight
=
tl
.
load
(
recv_topk_weight
+
cur_token
*
recv_topk_weight_stride0
+
topk_index
recv_topk_weight
+
cur_token
*
recv_topk_weight_stride0
+
topk_index
)
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment