Commit 43ff40f8
Authored Nov 19, 2025 by maxiao1; committed by lizhigong, Nov 19, 2025
Parent: 62d065ca

Optimize the quantization kernels, allow setting TP equal to DP, and fix the non-pinned-memory asynchronous-copy issue in the scheduling layer.
Changes: 4 changed files with 140 additions and 9 deletions (+140 −9).
python/sglang/srt/layers/moe/ep_moe/kernels.py             +131  −0
python/sglang/srt/layers/moe/ep_moe/layer.py                 +4  −4
python/sglang/srt/layers/moe/token_dispatcher/deepep.py      +1  −1
python/sglang/srt/model_executor/forward_batch_info.py       +4  −4
python/sglang/srt/layers/moe/ep_moe/kernels.py
@@ -1013,3 +1013,134 @@ def zero_experts_compute_triton(
     )
     return output
+
+
+from triton.language.extra import libdevice
+from typing import Optional
+
+
+@triton.jit
+def _per_token_quant_int8_one_kernel_opt(
+    x_ptr,
+    xq_ptr,
+    scale_ptr,
+    stride_x,
+    stride_xq,
+    N,
+    T_dim,
+    tokens_per_expert_ptr,
+    BLOCK: tl.constexpr,
+):
+    row_id = tl.program_id(0)
+    if tokens_per_expert_ptr is not None:
+        e = row_id // T_dim
+        t = row_id % T_dim
+        num_valid_tokens_for_e = tl.load(tokens_per_expert_ptr + e)
+        if t >= num_valid_tokens_for_e:
+            return
+    cols = tl.arange(0, BLOCK)
+    mask = cols < N
+    x = tl.load(x_ptr + row_id * stride_x + cols, mask=mask, other=0.0).to(tl.float32)
+    absmax = tl.maximum(tl.max(tl.abs(x)), 1e-10)
+    scale_x = absmax / 127
+    x_q = x * (127 / absmax)
+    x_q = libdevice.nearbyint(x_q).to(tl.int8)
+    tl.store(xq_ptr + row_id * stride_xq + cols, x_q, mask=mask)
+    tl.store(scale_ptr + row_id, scale_x)
+
+
+@triton.jit
+def _per_token_quant_int8_kernel_opt(
+    x_ptr,
+    xq_ptr,
+    scale_ptr,
+    stride_x,
+    stride_xq,
+    N,
+    E_dim,
+    T_dim,
+    tokens_per_expert_ptr,
+    BLOCK: tl.constexpr,
+):
+    token_idx_start = tl.program_id(0)
+    grid_size = tl.num_programs(0)
+    num_total_tokens = E_dim * T_dim
+    for token_idx in range(token_idx_start, num_total_tokens, grid_size):
+        is_valid_token = True
+        if tokens_per_expert_ptr is not None:
+            e = token_idx // T_dim
+            t = token_idx % T_dim
+            num_valid_tokens_for_e = tl.load(tokens_per_expert_ptr + e)
+            if t >= num_valid_tokens_for_e:
+                is_valid_token = False
+        if is_valid_token:
+            cols = tl.arange(0, BLOCK)
+            mask = cols < N
+            x = tl.load(
+                x_ptr + token_idx * stride_x + cols, mask=mask, other=0.0
+            ).to(tl.float32)
+            absmax = tl.maximum(tl.max(tl.abs(x)), 1e-10)
+            scale_x = absmax / 127
+            x_q = x * (127 / absmax)
+            x_q = libdevice.nearbyint(x_q).to(tl.int8)
+            tl.store(xq_ptr + token_idx * stride_xq + cols, x_q, mask=mask)
+            tl.store(scale_ptr + token_idx, scale_x)
+
+
+def per_token_quant_int8_triton_opt(
+    x: torch.Tensor, tokens_per_expert: Optional[torch.Tensor] = None
+):
+    if x.dim() != 3:
+        raise ValueError(f"Input must be 3D [E, T, H], but got {x.shape}")
+
+    E, T, H = x.shape
+    N = H
+
+    x_q = torch.empty_like(x, device=x.device, dtype=torch.int8)
+    scales = torch.empty(x.shape[:-1] + (1,), device=x.device, dtype=torch.float32)
+
+    BLOCK = triton.next_power_of_2(N)
+    num_warps = min(max(BLOCK // 256, 1), 8)
+    if (E == 8 and T >= 1024) or (E == 16 and T >= 512):
+        num_warps = 1
+
+    num_tokens = E * T
+    grid_opt = num_tokens
+    if (E == 8 and T >= 1024) or (E == 16 and T >= 512):
+        grid_opt = max(1, num_tokens // (T // 256))
+        _per_token_quant_int8_kernel_opt[(grid_opt,)](
+            x,
+            x_q,
+            scales,
+            stride_x=x.stride(-2),
+            stride_xq=x_q.stride(-2),
+            N=N,
+            E_dim=E,
+            T_dim=T,
+            tokens_per_expert_ptr=tokens_per_expert,
+            BLOCK=BLOCK,
+            num_warps=num_warps,
+            num_stages=1,
+        )
+    else:
+        _per_token_quant_int8_one_kernel_opt[(grid_opt,)](
+            x,
+            x_q,
+            scales,
+            stride_x=x.stride(-2),
+            stride_xq=x_q.stride(-2),
+            N=N,
+            T_dim=T,
+            tokens_per_expert_ptr=tokens_per_expert,
+            BLOCK=BLOCK,
+            num_warps=num_warps,
+            num_stages=1,
+        )
+    return x_q, scales
\ No newline at end of file
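
The new entry point quantizes a padded [E, T, H] expert batch to INT8 with one scale per token row, skipping rows at or beyond each expert's valid-token count; the grid-stride kernel (with num_warps=1) is used only for the two decode-typical shapes (E=8, T>=1024) and (E=16, T>=512), while every other shape launches one program per row. A minimal smoke test of how it might be called; the shapes, the masked_m values, and the dequantization tolerance below are illustrative assumptions, not part of the commit:

    import torch
    from sglang.srt.layers.moe.ep_moe.kernels import per_token_quant_int8_triton_opt

    # Illustrative shapes: 8 experts, 256 padded token slots, hidden size 4096.
    E, T, H = 8, 256, 4096
    x = torch.randn(E, T, H, device="cuda", dtype=torch.bfloat16)
    # Per-expert count of valid (non-padding) token rows, at least 1 each.
    masked_m = torch.randint(1, T + 1, (E,), device="cuda", dtype=torch.int32)

    x_q, scales = per_token_quant_int8_triton_opt(x, masked_m)
    assert x_q.shape == (E, T, H) and x_q.dtype == torch.int8
    assert scales.shape == (E, T, 1) and scales.dtype == torch.float32

    # Dequantizing a valid row should recover the input to within one
    # quantization step (the round-off error is at most scale / 2 per element).
    deq = x_q[0, 0].float() * scales[0, 0]
    torch.testing.assert_close(deq, x[0, 0].float(), atol=scales[0, 0].item(), rtol=0.0)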
python/sglang/srt/layers/moe/ep_moe/layer.py
@@ -20,6 +20,7 @@ from sglang.srt.layers.moe.ep_moe.kernels import (
     ep_scatter,
     silu_and_mul_masked_post_quant_fwd,
     tma_align_input_scale,
+    per_token_quant_int8_triton_opt,
 )
 from sglang.srt.layers.moe.fused_moe_triton.layer import FlashInferFusedMoE, FusedMoE
 from sglang.srt.layers.moe.topk import TopKOutput
@@ -902,7 +903,7 @@ class DeepEPMoE(EPMoE):
         expected_m = min(m, expected_m)

         # ---- first quant: ensure float input for quantizer ----
-        q_a1_all, q_a1_scale = per_token_quant_int8(hidden_states)
+        q_a1_all, q_a1_scale = per_token_quant_int8_triton_opt(hidden_states, masked_m)

         # ---- weights & scales ----
         w13_weight = self.w13_weight
@@ -943,16 +944,15 @@ class DeepEPMoE(EPMoE):
         dispatch_output: DeepEPLLOutput,
     ):
-        hidden_states, _, _, _, masked_m, expected_m = dispatch_output
+        hidden_states, _, topk_ids, _, masked_m, expected_m = dispatch_output
         assert self.quant_method is not None
         assert self.moe_runner_config.activation == "silu"

         # base shapes
         num_groups, m, k = hidden_states.size()
         expected_m = min(m, expected_m)

         # ---- first quant: ensure float input for quantizer ----
-        q_a1_all, q_a1_scale = per_token_quant_int8(hidden_states)
+        q_a1_all, q_a1_scale = per_token_quant_int8_triton_opt(hidden_states, masked_m)

         # ---- weights & scales ----
         w13_weight = self.w13_weight
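
Both call sites now pass masked_m (the per-expert count of valid dispatched tokens) so padding rows are never quantized. For intuition, here is a hedged eager-mode equivalent of what the fused kernel computes per valid row; the function name and the loop structure are illustrative, not part of the commit:

    import torch

    def per_token_quant_int8_ref(x: torch.Tensor, tokens_per_expert: torch.Tensor):
        # x: [E, T, H] activations; tokens_per_expert: [E] valid-row counts.
        E, T, H = x.shape
        x_q = torch.zeros(E, T, H, dtype=torch.int8)
        scales = torch.zeros(E, T, 1, dtype=torch.float32)
        for e in range(E):
            for t in range(int(tokens_per_expert[e])):  # rows past masked_m[e] are skipped
                row = x[e, t].float()
                absmax = row.abs().max().clamp_min(1e-10)  # avoid divide-by-zero
                scales[e, t, 0] = absmax / 127
                # Round-to-nearest-even, matching libdevice.nearbyint in the kernel.
                x_q[e, t] = torch.round(row * (127 / absmax)).to(torch.int8)
        return x_q, scales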
python/sglang/srt/layers/moe/token_dispatcher/deepep.py
@@ -308,7 +308,7 @@ class _DeepEPDispatcherImplBase:
         self.params_bytes = 2
         self.num_max_dispatch_tokens_per_rank = get_int_env_var(
-            "SGLANG_DEEPEP_NUM_MAX_DISPATCH_TOKENS_PER_RANK", 128
+            "SGLANG_DEEPEP_NUM_MAX_DISPATCH_TOKENS_PER_RANK", 64
         )
         # DeepEP internode_ll dispatch uses FINISHED_SUM_TAG=1024
         # and the logic requires num-tokens-sent-from-one-rank-to-another-rank less than it
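
The built-in default drops from 128 to 64 max dispatch tokens per rank. Deployments that relied on the old value can restore it through the environment variable; a minimal sketch, assuming it is set before the dispatcher is constructed:

    import os

    # Restore the pre-commit default of 128; get_int_env_var reads this at
    # dispatcher construction time, so set it before server startup.
    os.environ["SGLANG_DEEPEP_NUM_MAX_DISPATCH_TOKENS_PER_RANK"] = "128"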
python/sglang/srt/model_executor/forward_batch_info.py
@@ -127,7 +127,7 @@ class ForwardMode(IntEnum):
         # For fixed shape logits output in v2 eagle worker
         return self == ForwardMode.DRAFT_EXTEND_V2

-    def is_extend_or_draft_extend_or_mixed(self): #nhb
+    def is_extend_or_draft_extend_or_mixed(self):
         return (
             self == ForwardMode.EXTEND
             or self == ForwardMode.DRAFT_EXTEND
@@ -375,7 +375,7 @@ class ForwardBatch:
         if enable_num_token_non_padded(model_runner.server_args):
             ret.num_token_non_padded = torch.tensor(
                 len(batch.input_ids), dtype=torch.int32
-            ).to(device, non_blocking=True)
+            ).pin_memory().to(device, non_blocking=True)
             ret.num_token_non_padded_cpu = len(batch.input_ids)

         # For MLP sync
@@ -395,12 +395,12 @@ class ForwardBatch:
             ret.global_num_tokens_cpu = global_num_tokens
             ret.global_num_tokens_gpu = torch.tensor(
                 global_num_tokens, dtype=torch.int64
-            ).to(device, non_blocking=True)
+            ).pin_memory().to(device, non_blocking=True)

             ret.global_num_tokens_for_logprob_cpu = global_num_tokens_for_logprob
             ret.global_num_tokens_for_logprob_gpu = torch.tensor(
                 global_num_tokens_for_logprob, dtype=torch.int64
-            ).to(device, non_blocking=True)
+            ).pin_memory().to(device, non_blocking=True)

         if ret.forward_mode.is_idle():
             ret.positions = torch.empty((0,), dtype=torch.int64, device=device)
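
This is the scheduling-layer fix from the commit message: Tensor.to(device, non_blocking=True) is only truly asynchronous when the source lives in page-locked (pinned) host memory; from ordinary pageable memory the CUDA runtime must stage the transfer and the copy effectively serializes with the host. A minimal sketch of the pattern, assuming a CUDA device is available:

    import torch

    src = torch.tensor([4096, 4096], dtype=torch.int64)  # small CPU-side metadata

    # Pageable source: non_blocking is effectively ignored and the host may stall.
    slow = src.to("cuda", non_blocking=True)

    # Pinned source (the pattern this commit adopts): the H2D copy is queued on
    # the current stream and can overlap with subsequent host-side work.
    fast = src.pin_memory().to("cuda", non_blocking=True)

    torch.cuda.synchronize()  # ensure both copies have completed before use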