zhaoyu6 / sglang
Commit 43ff40f8, authored Nov 19, 2025 by maxiao1; committed by lizhigong, Nov 19, 2025
Optimize the quantization kernels, allow setting tp equal to dp, and fix asynchronous copies from non-pinned memory in the scheduling layer
Parent 62d065ca
Showing 4 changed files with 140 additions and 9 deletions (+140 −9)
python/sglang/srt/layers/moe/ep_moe/kernels.py (+131 −0)
python/sglang/srt/layers/moe/ep_moe/layer.py (+4 −4)
python/sglang/srt/layers/moe/token_dispatcher/deepep.py (+1 −1)
python/sglang/srt/model_executor/forward_batch_info.py (+4 −4)
python/sglang/srt/layers/moe/ep_moe/kernels.py
@@ -1013,3 +1013,134 @@ def zero_experts_compute_triton(
     )
     return output
+
+from triton.language.extra import libdevice
+from typing import Optional
+
+
+@triton.jit
+def _per_token_quant_int8_one_kernel_opt(
+    x_ptr,
+    xq_ptr,
+    scale_ptr,
+    stride_x,
+    stride_xq,
+    N,
+    T_dim,
+    tokens_per_expert_ptr,
+    BLOCK: tl.constexpr,
+):
+    row_id = tl.program_id(0)
+
+    if tokens_per_expert_ptr is not None:
+        e = row_id // T_dim
+        t = row_id % T_dim
+        num_valid_tokens_for_e = tl.load(tokens_per_expert_ptr + e)
+        if t >= num_valid_tokens_for_e:
+            return
+
+    cols = tl.arange(0, BLOCK)
+    mask = cols < N
+
+    x = tl.load(x_ptr + row_id * stride_x + cols, mask=mask, other=0.0).to(tl.float32)
+    absmax = tl.maximum(tl.max(tl.abs(x)), 1e-10)
+    scale_x = absmax / 127
+    x_q = x * (127 / absmax)
+    x_q = libdevice.nearbyint(x_q).to(tl.int8)
+
+    tl.store(xq_ptr + row_id * stride_xq + cols, x_q, mask=mask)
+    tl.store(scale_ptr + row_id, scale_x)
+
+
+@triton.jit
+def _per_token_quant_int8_kernel_opt(
+    x_ptr,
+    xq_ptr,
+    scale_ptr,
+    stride_x,
+    stride_xq,
+    N,
+    E_dim,
+    T_dim,
+    tokens_per_expert_ptr,
+    BLOCK: tl.constexpr,
+):
+    token_idx_start = tl.program_id(0)
+    grid_size = tl.num_programs(0)
+    num_total_tokens = E_dim * T_dim
+
+    for token_idx in range(token_idx_start, num_total_tokens, grid_size):
+        is_valid_token = True
+        if tokens_per_expert_ptr is not None:
+            e = token_idx // T_dim
+            t = token_idx % T_dim
+            num_valid_tokens_for_e = tl.load(tokens_per_expert_ptr + e)
+            if t >= num_valid_tokens_for_e:
+                is_valid_token = False
+
+        if is_valid_token:
+            cols = tl.arange(0, BLOCK)
+            mask = cols < N
+
+            x = tl.load(
+                x_ptr + token_idx * stride_x + cols, mask=mask, other=0.0
+            ).to(tl.float32)
+            absmax = tl.maximum(tl.max(tl.abs(x)), 1e-10)
+            scale_x = absmax / 127
+            x_q = x * (127 / absmax)
+            x_q = libdevice.nearbyint(x_q).to(tl.int8)
+
+            tl.store(xq_ptr + token_idx * stride_xq + cols, x_q, mask=mask)
+            tl.store(scale_ptr + token_idx, scale_x)
+
+
+def per_token_quant_int8_triton_opt(
+    x: torch.Tensor, tokens_per_expert: Optional[torch.Tensor] = None
+):
+    if x.dim() != 3:
+        raise ValueError(f"Input must be 3D [E, T, H], but got {x.shape}")
+
+    E, T, H = x.shape
+    N = H
+
+    x_q = torch.empty_like(x, device=x.device, dtype=torch.int8)
+    scales = torch.empty(x.shape[:-1] + (1,), device=x.device, dtype=torch.float32)
+
+    BLOCK = triton.next_power_of_2(N)
+    num_warps = min(max(BLOCK // 256, 1), 8)
+    if (E == 8 and T >= 1024) or (E == 16 and T >= 512):
+        num_warps = 1
+
+    num_tokens = E * T
+    grid_opt = num_tokens
+    if (E == 8 and T >= 1024) or (E == 16 and T >= 512):
+        grid_opt = max(1, num_tokens // (T // 256))
+        _per_token_quant_int8_kernel_opt[(grid_opt,)](
+            x,
+            x_q,
+            scales,
+            stride_x=x.stride(-2),
+            stride_xq=x_q.stride(-2),
+            N=N,
+            E_dim=E,
+            T_dim=T,
+            tokens_per_expert_ptr=tokens_per_expert,
+            BLOCK=BLOCK,
+            num_warps=num_warps,
+            num_stages=1,
+        )
+    else:
+        _per_token_quant_int8_one_kernel_opt[(grid_opt,)](
+            x,
+            x_q,
+            scales,
+            stride_x=x.stride(-2),
+            stride_xq=x_q.stride(-2),
+            N=N,
+            T_dim=T,
+            tokens_per_expert_ptr=tokens_per_expert,
+            BLOCK=BLOCK,
+            num_warps=num_warps,
+            num_stages=1,
+        )
+
+    return x_q, scales
\ No newline at end of file
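For orientation (not part of the commit), here is a minimal usage sketch of the new entry point. The shapes and the meaning of tokens_per_expert follow the masked [E, T, H] layout implied by the kernels: each expert owns T slots and only the first tokens_per_expert[e] rows are valid; the concrete values below are made up for illustration.

# Hedged usage sketch: assumes a CUDA device and the masked [E, T, H] layout above.
import torch

from sglang.srt.layers.moe.ep_moe.kernels import per_token_quant_int8_triton_opt

E, T, H = 8, 1024, 4096
x = torch.randn(E, T, H, device="cuda", dtype=torch.bfloat16)
# Number of valid token rows per expert (illustrative values).
tokens_per_expert = torch.randint(1, T + 1, (E,), device="cuda", dtype=torch.int32)

x_q, scales = per_token_quant_int8_triton_opt(x, tokens_per_expert)
# x_q:    torch.int8,    shape [E, T, H]  (rows past tokens_per_expert[e] are skipped)
# scales: torch.float32, shape [E, T, 1]  (per-row symmetric scale = absmax / 127)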
python/sglang/srt/layers/moe/ep_moe/layer.py
@@ -20,6 +20,7 @@ from sglang.srt.layers.moe.ep_moe.kernels import (
     ep_scatter,
     silu_and_mul_masked_post_quant_fwd,
     tma_align_input_scale,
+    per_token_quant_int8_triton_opt,
 )
 from sglang.srt.layers.moe.fused_moe_triton.layer import FlashInferFusedMoE, FusedMoE
 from sglang.srt.layers.moe.topk import TopKOutput
@@ -902,7 +903,7 @@ class DeepEPMoE(EPMoE):
         expected_m = min(m, expected_m)

         # ---- first quant: ensure float input for quantizer ----
-        q_a1_all, q_a1_scale = per_token_quant_int8(hidden_states)
+        q_a1_all, q_a1_scale = per_token_quant_int8_triton_opt(hidden_states, masked_m)

         # ---- weights & scales ----
         w13_weight = self.w13_weight
@@ -943,16 +944,15 @@ class DeepEPMoE(EPMoE):
         dispatch_output: DeepEPLLOutput,
     ):
-        hidden_states, _, _, _, masked_m, expected_m = dispatch_output
+        hidden_states, _, topk_ids, _, masked_m, expected_m = dispatch_output
         assert self.quant_method is not None
         assert self.moe_runner_config.activation == "silu"

         # base shapes
         num_groups, m, k = hidden_states.size()
         expected_m = min(m, expected_m)

         # ---- first quant: ensure float input for quantizer ----
-        q_a1_all, q_a1_scale = per_token_quant_int8(hidden_states)
+        q_a1_all, q_a1_scale = per_token_quant_int8_triton_opt(hidden_states, masked_m)

         # ---- weights & scales ----
         w13_weight = self.w13_weight
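Both hunks above replace the dense per_token_quant_int8(hidden_states) call with the masked Triton variant and pass masked_m so padded rows are skipped. A rough consistency check one could run (a sketch, not part of the commit; per_token_quant_int8_ref is a hypothetical plain-PyTorch reference written here for illustration):

import torch

from sglang.srt.layers.moe.ep_moe.kernels import per_token_quant_int8_triton_opt


def per_token_quant_int8_ref(x: torch.Tensor):
    # Hypothetical reference: symmetric per-row int8 quantization (scale = absmax / 127).
    absmax = x.float().abs().amax(dim=-1, keepdim=True).clamp_min(1e-10)
    scales = absmax / 127
    x_q = torch.round(x.float() * (127 / absmax)).to(torch.int8)
    return x_q, scales


E, T, H = 8, 256, 4096
hidden_states = torch.randn(E, T, H, device="cuda", dtype=torch.bfloat16)
masked_m = torch.randint(1, T + 1, (E,), device="cuda", dtype=torch.int32)

q_opt, s_opt = per_token_quant_int8_triton_opt(hidden_states, masked_m)
q_ref, s_ref = per_token_quant_int8_ref(hidden_states)

for e in range(E):
    n = int(masked_m[e])
    # Allow off-by-one on values that land exactly on a rounding boundary.
    assert (q_opt[e, :n].int() - q_ref[e, :n].int()).abs().max() <= 1
    torch.testing.assert_close(s_opt[e, :n, 0], s_ref[e, :n, 0])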
python/sglang/srt/layers/moe/token_dispatcher/deepep.py
@@ -308,7 +308,7 @@ class _DeepEPDispatcherImplBase:
         self.params_bytes = 2
         self.num_max_dispatch_tokens_per_rank = get_int_env_var(
-            "SGLANG_DEEPEP_NUM_MAX_DISPATCH_TOKENS_PER_RANK", 128
+            "SGLANG_DEEPEP_NUM_MAX_DISPATCH_TOKENS_PER_RANK", 64
         )
         # DeepEP internode_ll dispatch uses FINISHED_SUM_TAG=1024
         # and the logic requires num-tokens-sent-from-one-rank-to-another-rank less than it
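The hunk above only changes the default for SGLANG_DEEPEP_NUM_MAX_DISPATCH_TOKENS_PER_RANK from 128 to 64; the cap can still be raised through the environment. A minimal sketch of restoring the previous value (assuming, as the hunk suggests, that the variable is read once when the dispatcher is constructed):

import os

# Set before the DeepEP dispatcher is constructed, since get_int_env_var
# is evaluated in _DeepEPDispatcherImplBase.__init__.
os.environ["SGLANG_DEEPEP_NUM_MAX_DISPATCH_TOKENS_PER_RANK"] = "128"

Per the comment in the surrounding code, the per-rank token count must stay below DeepEP's FINISHED_SUM_TAG (1024).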
python/sglang/srt/model_executor/forward_batch_info.py
@@ -127,7 +127,7 @@ class ForwardMode(IntEnum):
         # For fixed shape logits output in v2 eagle worker
         return self == ForwardMode.DRAFT_EXTEND_V2

-    def is_extend_or_draft_extend_or_mixed(self):  #nhb
+    def is_extend_or_draft_extend_or_mixed(self):
         return (
             self == ForwardMode.EXTEND
             or self == ForwardMode.DRAFT_EXTEND
@@ -375,7 +375,7 @@ class ForwardBatch:
         if enable_num_token_non_padded(model_runner.server_args):
             ret.num_token_non_padded = torch.tensor(
                 len(batch.input_ids), dtype=torch.int32
-            ).to(device, non_blocking=True)
+            ).pin_memory().to(device, non_blocking=True)
             ret.num_token_non_padded_cpu = len(batch.input_ids)

         # For MLP sync
@@ -395,12 +395,12 @@ class ForwardBatch:
             ret.global_num_tokens_cpu = global_num_tokens
             ret.global_num_tokens_gpu = torch.tensor(
                 global_num_tokens, dtype=torch.int64
-            ).to(device, non_blocking=True)
+            ).pin_memory().to(device, non_blocking=True)

             ret.global_num_tokens_for_logprob_cpu = global_num_tokens_for_logprob
             ret.global_num_tokens_for_logprob_gpu = torch.tensor(
                 global_num_tokens_for_logprob, dtype=torch.int64
-            ).to(device, non_blocking=True)
+            ).pin_memory().to(device, non_blocking=True)

         if ret.forward_mode.is_idle():
             ret.positions = torch.empty((0,), dtype=torch.int64, device=device)
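The ForwardBatch hunks stage the small CPU tensors in pinned (page-locked) memory before copying them to the GPU. With a pageable source, .to(device, non_blocking=True) still blocks the host on a staging copy; only a pinned source lets the transfer overlap with other host work. A minimal timing sketch of that difference (an illustration, not part of the commit):

import time

import torch

device = torch.device("cuda")
# A pageable CPU tensor, standing in for batch metadata such as global_num_tokens.
src = torch.randint(0, 1 << 20, (1 << 20,), dtype=torch.int64)


def host_side_time(tensor: torch.Tensor, iters: int = 100) -> float:
    # Measures how long the host is blocked issuing the copies; transfers
    # from pinned memory may still be in flight when the loop returns.
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        tensor.to(device, non_blocking=True)
    elapsed = time.perf_counter() - start
    torch.cuda.synchronize()
    return elapsed


print("pageable:", host_side_time(src))
print("pinned:  ", host_side_time(src.pin_memory()))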