sglang · Commit 00c7b1ad (unverified)
Authored Jun 28, 2025 by fzyzcjy; committed via GitHub on Jun 28, 2025.
Parent: 82eccae4

Let EP prefill support new DeepGEMM (#7310)

Showing 2 changed files with 31 additions and 8 deletions (+31 -8):

  python/sglang/srt/layers/moe/ep_moe/layer.py             +24 -7
  python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py   +7 -1
python/sglang/srt/layers/moe/ep_moe/layer.py
@@ -46,6 +46,7 @@ from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardMode
 from sglang.srt.utils import (
     DeepEPMode,
+    ceil_div,
     dispose_tensor,
     get_bool_env_var,
     is_hip,
@@ -1370,10 +1371,19 @@ class DeepEPMoE(EPMoE):
                 device=hidden_states_fp8.device,
                 dtype=hidden_states_fp8.dtype,
             ),
-            torch.empty(
-                (all_tokens, K // 128),
-                device=hidden_states_fp8.device,
-                dtype=torch.float32,
+            (
+                # TODO check whether need `zeros`
+                torch.zeros(
+                    (ceil_div(K // 128, 4), all_tokens),
+                    device=hidden_states_fp8.device,
+                    dtype=torch.int,
+                ).transpose(0, 1)
+                if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
+                else torch.empty(
+                    (all_tokens, K // 128),
+                    device=hidden_states_fp8.device,
+                    dtype=torch.float32,
+                )
             ),
         ]
         m_indices = torch.empty(
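If I read the new allocation right, the UE8M0 branch packs the K // 128 per-group scale bytes four to an int32 and stores them transposed, so the logical (tokens, packed_groups) view is column-major in memory; that packing rationale is inferred from the shapes, not stated in the diff. A quick shape check with hypothetical sizes:

import torch

def ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b

K, all_tokens = 7168, 4096            # hypothetical sizes
num_groups = K // 128                 # one scale per 128-wide group -> 56

# UE8M0 path: four 1-byte exponents per int32, allocated transposed.
scale_ue8m0 = torch.zeros(
    (ceil_div(num_groups, 4), all_tokens), dtype=torch.int
).transpose(0, 1)
print(scale_ue8m0.shape)              # torch.Size([4096, 14])
print(scale_ue8m0.stride())           # (1, 4096) -> column-major

# Legacy path: one float32 scale per group, row-major.
scale_fp32 = torch.empty((all_tokens, num_groups), dtype=torch.float32)
print(scale_fp32.shape)               # torch.Size([4096, 56])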
@@ -1399,6 +1409,7 @@ class DeepEPMoE(EPMoE):
             input_tensor[1],
             m_indices,
             output_index,
+            scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
         )
         dispose_tensor(hidden_states_fp8)
@@ -1407,6 +1418,7 @@ class DeepEPMoE(EPMoE):
             device=hidden_states_fp8_device,
             dtype=torch.bfloat16,
         )
-        input_tensor[1] = tma_align_input_scale(input_tensor[1])
+        if not deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
+            input_tensor[1] = tma_align_input_scale(input_tensor[1])
         deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_contig(
             input_tensor, self.w13_weight_fp8, gateup_output, m_indices
@@ -1428,9 +1440,14 @@ class DeepEPMoE(EPMoE):
             dtype=torch.bfloat16,
         )
         down_input_fp8, down_input_scale = sglang_per_token_group_quant_fp8(
-            down_input, scale_block_size
+            down_input,
+            scale_block_size,
+            column_major_scales=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
+            scale_tma_aligned=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
+            scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
         )
         del down_input
-        down_input_scale = tma_align_input_scale(down_input_scale)
+        if not deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
+            down_input_scale = tma_align_input_scale(down_input_scale)
         deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_contig(
             (down_input_fp8, down_input_scale),
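Both GEMM call sites in this file now share the same guard: tma_align_input_scale only runs on the legacy float32-scale path, since UE8M0 scales leave the quant kernel already packed in the layout the grouped GEMM expects. A runnable toy of that control flow; the alignment stub below is a hypothetical stand-in for sglang's real kernel:

import torch

DEEPGEMM_SCALE_UE8M0 = False  # stand-in for deep_gemm_wrapper's flag

def tma_align_input_scale_stub(scale: torch.Tensor) -> torch.Tensor:
    # Hypothetical: pad the token dimension to a multiple of 4 so TMA
    # can transfer aligned rows, then hand back the trimmed view.
    tokens, groups = scale.shape
    padded = scale.new_zeros(((tokens + 3) // 4 * 4, groups))
    padded[:tokens] = scale
    return padded[:tokens]

def prepare_scale(scale: torch.Tensor) -> torch.Tensor:
    # The commit's pattern: skip the alignment pass when UE8M0 is on.
    if not DEEPGEMM_SCALE_UE8M0:
        scale = tma_align_input_scale_stub(scale)
    return scale

scale = torch.rand(6, 56)             # (tokens, K // 128) float32 scales
print(prepare_scale(scale).shape)     # torch.Size([6, 56])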
python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py
@@ -246,7 +246,13 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
         topk_idx = topk_idx.to(torch.int64)
         if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
             # TODO hard code 128 block quant,use fp8 communication
-            hidden_states = sglang_per_token_group_quant_fp8(hidden_states, 128)
+            hidden_states = sglang_per_token_group_quant_fp8(
+                hidden_states,
+                128,
+                column_major_scales=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
+                scale_tma_aligned=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
+                scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
+            )
         previous_event = Buffer.capture() if self.async_finish else None
         return hidden_states, topk_idx, topk_weights, previous_event
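sglang_per_token_group_quant_fp8 is a fused kernel, but the group-of-128 quantization it performs can be approximated in plain PyTorch. A reference sketch assuming a simple absmax scheme; the real kernel's rounding, scale layout, and the three UE8M0 keyword arguments are not modeled here:

import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for e4m3fn

def per_token_group_quant_fp8_ref(x: torch.Tensor, group_size: int = 128):
    # One absmax-derived scale per contiguous group of `group_size`
    # channels per token, mirroring the dispatcher's fp8 communication.
    tokens, hidden = x.shape
    assert hidden % group_size == 0
    grouped = x.view(tokens, hidden // group_size, group_size)
    scale = grouped.abs().amax(dim=-1, keepdim=True).float() / FP8_MAX
    scale = scale.clamp(min=1e-12)    # avoid division by zero
    q = (grouped / scale).to(torch.float8_e4m3fn)
    return q.view(tokens, hidden), scale.squeeze(-1)

x = torch.randn(4, 256, dtype=torch.bfloat16)
x_fp8, x_scale = per_token_group_quant_fp8_ref(x)
print(x_fp8.dtype, x_scale.shape)     # torch.float8_e4m3fn torch.Size([4, 2])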