change / sglang · Commits · e3c76844

Commit e3c76844, authored Nov 26, 2025 by lizhigong

Merge branch 'v0.5.4_dev_yiqa' into 'v0.5.4_dev'

V0.5.4 dev yiqa

See merge request OpenDAS/sglang!40

Parents: 078de197, 0a71d6b1

Showing 3 changed files with 40 additions and 48 deletions (+40 -48)
python/sglang/srt/layers/moe/ep_moe/layer.py                      +27 -29
python/sglang/srt/layers/moe/token_dispatcher/deepep.py            +8 -15
python/sglang/srt/layers/quantization/slimquant_w4a8_marlin.py     +5  -4
python/sglang/srt/layers/moe/ep_moe/layer.py
@@ -3,6 +3,7 @@ from __future__ import annotations
import logging
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
from collections import defaultdict
from sglang.srt.distributed import get_moe_expert_parallel_rank, get_moe_expert_parallel_world_size
from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors_marlin import SlimQuantCompressedTensorsMarlinConfig
from sglang.srt.layers.quantization.slimquant_w4a8_marlin import SlimQuantW4A8Int8MarlinConfig
import torch
@@ -119,8 +120,8 @@ def fuse_silu_mul_quant_ep_wrapper(
    topk, expect_m)

def fuse_silu_mul_quant_ep_fake(
    input: torch.Tensor,
    tokens_per_expert: Optional[torch.Tensor] = None,
@@ -695,6 +696,11 @@ class DeepEPMoE(EPMoE):
        if all_tokens <= 0:
            return hidden_states.bfloat16()
        rank_expert_offset = get_moe_expert_parallel_rank() * (self.num_experts // get_moe_expert_parallel_world_size())
        topk_idx = torch.where(
            topk_idx == -1,
            self.num_experts - 1 if rank_expert_offset == 0 else 0,
            topk_idx + rank_expert_offset,
        )
        expert_output = self.quant_method.apply_ep(
            x=hidden_states,
            w1=self.w13_weight,
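For context on the lines added above (this note is not part of the commit): after a normal DeepEP dispatch, unused top-k slots arrive as -1 and the received expert ids appear to be local to this EP rank. The torch.where call shifts valid ids by the rank's expert offset and parks the -1 slots on an expert id this rank does not own, so they drop out of the grouped computation. A minimal sketch with made-up sizes (8 experts over 4 EP ranks, viewed from rank 1):

import torch

num_experts = 8          # hypothetical global expert count
experts_per_rank = 2     # hypothetical: 4 EP ranks
rank = 1
rank_expert_offset = rank * experts_per_rank

topk_idx = torch.tensor([[0, 1, -1], [-1, 1, 0]])        # -1 = padded slot from dispatch
remapped = torch.where(
    topk_idx == -1,
    num_experts - 1 if rank_expert_offset == 0 else 0,   # an expert this rank does not own
    topk_idx + rank_expert_offset,                        # local id -> global id
)
print(remapped)  # tensor([[2, 3, 0], [0, 3, 2]])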
@@ -708,6 +714,7 @@ class DeepEPMoE(EPMoE):
            use_nn_moe=False,
            w1_scale=self.w13_weight_scale,
            w2_scale=self.w2_weight_scale,
            a1_scale=hidden_states_scale,
            routed_scaling_factor=self.moe_runner_config.routed_scaling_factor,
        )
        return expert_output
@@ -740,10 +747,9 @@ class DeepEPMoE(EPMoE):
                    active_experts.add(e)
            token_expert_pos[t] = lst
-        active_experts = sorted(list(active_experts))
-        num_active = len(active_experts)
-        if num_active == 0:
+        if not active_experts:
            return hidden_states.bfloat16()
+        active_experts = sorted(list(active_experts))
+        counts = defaultdict(int)
        for t in range(M):
@@ -752,12 +758,9 @@ class DeepEPMoE(EPMoE):
        per_expert_block = {}
        for e in active_experts:
-            cnt = counts.get(e, 0)
-            if cnt <= 0:
-                per_expert_block[e] = 0
-            else:
-                needed = ((cnt + 256 - 1) // 256) * 256  # next multiple of 256
-                per_expert_block[e] = max(256, needed)
+            cnt = counts[e]
+            needed = ((cnt + 255) // 256) * 256  # same as ceil(cnt/256)*256
+            per_expert_block[e] = max(256, needed)
        expert_slot_offset = {}
        offset = 0
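A quick standalone check of the padding rule above (illustrative; block_rows is a hypothetical helper, not from the commit): ((cnt + 255) // 256) * 256 rounds a token count up to the next multiple of 256 exactly like the removed ((cnt + 256 - 1) // 256) * 256 form, and the outer max keeps every active expert at one full 256-row block minimum. The removed cnt <= 0 branch is unnecessary because only experts that received at least one token end up in active_experts.

def block_rows(cnt: int, tile: int = 256) -> int:
    # rows reserved for an expert that received `cnt` tokens
    return max(tile, ((cnt + tile - 1) // tile) * tile)

assert block_rows(1) == 256
assert block_rows(256) == 256
assert block_rows(257) == 512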
@@ -766,7 +769,8 @@ class DeepEPMoE(EPMoE):
            offset += per_expert_block[e]
        pad_M = offset
-        hidden_states_packed = torch.zeros((pad_M, K), device=device, dtype=hidden_states.dtype)
+        hidden_states_packed = torch.empty((pad_M, K), device=device, dtype=hidden_states.dtype)
+        hidden_states_scale_packed = torch.empty((pad_M,), device=device, dtype=hidden_states_scale.dtype)
        m_indices = torch.full((pad_M,), -1, device=device, dtype=torch.int32)
        slot_counters = {e: 0 for e in active_experts}
@@ -776,26 +780,27 @@ class DeepEPMoE(EPMoE):
            for (e, pos) in token_expert_pos[t]:
                start = expert_slot_offset[e]
                slot = slot_counters[e]
                if slot >= per_expert_block[e]:
                    raise RuntimeError(f"Internal error: expert {e} slot {slot} >= block {per_expert_block[e]}")
                row = start + slot
                hidden_states_packed[row] = hidden_states[t]
-                m_indices[row] = int(e)
+                hidden_states_scale_packed[row] = hidden_states_scale[t]
+                m_indices[row] = e
                slot_counters[e] += 1
-                w = topk_weights[t, pos].to(device=device)
                # record weight (as float32 on device)
+                w = topk_weights[t, pos]
                w_f = w.float() if w.dtype != torch.float32 else w
                token_row_weight_list[t].append((row, w_f))
-        q_a1_all, q_a1_scale = per_token_quant_int8(hidden_states_packed)
+        # q_a1_all, q_a1_scale = per_token_quant_int8(hidden_states_packed)
        N = self.w13_weight.size(1)
        gateup_output = torch.empty((pad_M, N * 16), device=device, dtype=torch.bfloat16)
        m_grouped_w8a8_gemm_nt_contig_asm(
-            (q_a1_all, q_a1_scale),
+            (hidden_states_packed, hidden_states_scale_packed),
            (self.w13_weight, self.w13_weight_scale),
            gateup_output,
            m_indices,
        )
+        del hidden_states_packed, hidden_states_scale_packed
        q_a2_all, q_a2_scale = fuse_silu_mul_quant(gateup_output)
        down_output = torch.empty((pad_M, K), device=device, dtype=torch.bfloat16)
        down_output = m_grouped_w8a8_gemm_nt_contig_asm(
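To make the packing step above easier to follow, here is a toy recreation of how rows and m_indices are laid out (illustrative only, not from the commit; the 256-row tile is shrunk to 4 and the routing is hypothetical; the real code also copies hidden_states and their per-token scales into the packed buffers):

import torch
from collections import defaultdict

TILE = 4                                   # the real code uses 256
token_experts = [[0], [2], [0], [2], [2]]  # hypothetical top-k hits per token
counts = defaultdict(int)
for hits in token_experts:
    for e in hits:
        counts[e] += 1

per_expert_block = {e: max(TILE, ((c + TILE - 1) // TILE) * TILE) for e, c in counts.items()}
expert_slot_offset, offset = {}, 0
for e in sorted(per_expert_block):
    expert_slot_offset[e] = offset
    offset += per_expert_block[e]

m_indices = torch.full((offset,), -1, dtype=torch.int32)  # -1 marks padding rows
slot_counters = defaultdict(int)
for t, hits in enumerate(token_experts):
    for e in hits:
        row = expert_slot_offset[e] + slot_counters[e]     # next free slot in expert e's block
        m_indices[row] = e
        slot_counters[e] += 1

print(m_indices.tolist())  # [0, 0, -1, -1, 2, 2, 2, -1]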
@@ -806,17 +811,10 @@ class DeepEPMoE(EPMoE):
        )
        result = torch.zeros((M, K), device=device, dtype=down_output.dtype)
        for t in range(M):
-            pairs = token_row_weight_list[t]
-            if not pairs:
-                continue
-            acc = None
-            for (row, w) in pairs:
-                vec = down_output[row].float()
-                weighted = vec * w
-                acc = weighted if acc is None else (acc + weighted)
-            result[t] = acc.to(result.dtype)
-        return result
+            for (row, w) in token_row_weight_list[t]:
+                result[t].addcmul_(down_output[row].float(), w)
+        return result.to(down_output.dtype)

    def forward_deepgemm_contiguous(
        self,
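The rewritten combine loop above accumulates each token's output as a weighted sum of its packed expert rows via in-place addcmul_. A small self-contained sketch (illustrative values, not from the commit; a float32 accumulator is used here so the in-place op is well typed):

import torch

M, K = 2, 4
down_output = torch.randn(6, K)                        # packed per-expert GEMM output
token_row_weight_list = {
    0: [(0, torch.tensor(0.7)), (4, torch.tensor(0.3))],
    1: [(1, torch.tensor(1.0))],
}
result = torch.zeros((M, K), dtype=torch.float32)
for t in range(M):
    for row, w in token_row_weight_list[t]:
        # result[t] += down_output[row] * w, without building a temporary per pair
        result[t].addcmul_(down_output[row].float(), w)
print(result)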
python/sglang/srt/layers/moe/token_dispatcher/deepep.py
@@ -4,7 +4,6 @@ import logging
from contextlib import nullcontext
from dataclasses import dataclass
from typing import TYPE_CHECKING, List, NamedTuple, Optional, Tuple, Union

from sglang.srt.distributed import get_moe_expert_parallel_rank, get_moe_expert_parallel_world_size
from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
from sglang.srt.layers import deep_gemm_wrapper
from sglang.srt.layers.dp_attention import get_is_extend_in_batch
@@ -30,7 +29,7 @@ from sglang.srt.utils import (
    is_npu,
    load_json_config,
)
from lmslim.layers.gemm.int8_utils import per_token_quant_int8

_is_npu = is_npu()

if TYPE_CHECKING:
@@ -369,6 +368,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
            #     scale_tma_aligned=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
            #     scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
            # )
            hidden_states = per_token_quant_int8(hidden_states)
        previous_event = Buffer.capture() if self.async_finish else None
        return hidden_states, topk_ids, topk_weights, previous_event
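per_token_quant_int8 is imported from lmslim, whose internals are not shown in this diff. As a rough mental model only (this sketch is an assumption, not lmslim's implementation), symmetric per-token int8 quantization computes one scale per row and returns the int8 values together with those scales:

import torch

def per_token_quant_int8_sketch(x: torch.Tensor):
    # one scale per token (row), symmetric around zero
    scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 127.0
    q = torch.round(x / scale).clamp(-127, 127).to(torch.int8)
    return q, scale.squeeze(-1)

x = torch.randn(4, 8)
q, s = per_token_quant_int8_sketch(x)
x_approx = q.float() * s.unsqueeze(-1)   # dequantized approximation of x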
@@ -441,19 +441,12 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
            expert_alignment=1,
            config=DeepEPConfig.get_instance().normal_dispatch_config,
        )
-        if self.quant_config.get("quant_method") == "slimquant_w4a8_marlin":
-            self.rank_expert_offset = get_moe_expert_parallel_rank() * (self.num_experts // get_moe_expert_parallel_world_size())
-            recv_topk_ids = torch.where(
-                recv_topk_ids == -1,
-                self.num_experts - 1 if self.rank_expert_offset == 0 else 0,
-                recv_topk_ids + self.rank_expert_offset)
-        else:
-            get_global_expert_distribution_recorder().on_deepep_dispatch_normal(
-                num_recv_tokens_per_expert,
-                num_tokens_per_rank=num_tokens_per_rank,
-                num_tokens_per_rdma_rank=num_tokens_per_rdma_rank,
-                num_tokens_per_expert=num_tokens_per_expert,
-            )
+        get_global_expert_distribution_recorder().on_deepep_dispatch_normal(
+            num_recv_tokens_per_expert,
+            num_tokens_per_rank=num_tokens_per_rank,
+            num_tokens_per_rdma_rank=num_tokens_per_rdma_rank,
+            num_tokens_per_expert=num_tokens_per_expert,
+        )
        return (
            recv_x,
            recv_topk_ids,
python/sglang/srt/layers/quantization/slimquant_w4a8_marlin.py
@@ -213,7 +213,7 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
    ):
        self.moe_runner_config = moe_runner_config
        self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)

    @torch._dynamo.disable()  # TODO: performance optimization needs lmslim/lightop support
    def apply(
        self,
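The decorator shown above is a real torch API: torch._dynamo.disable keeps torch.compile/dynamo from tracing into the wrapped function, so it always runs eagerly, which matches the TODO about waiting on lmslim/lightop before optimizing this path. A minimal illustration with a hypothetical stand-in function:

import torch

@torch._dynamo.disable()
def moe_apply_eager(x: torch.Tensor) -> torch.Tensor:
    # stand-in for SlimQuantW4A8Int8MarlinMoEMethod.apply; dynamo will not trace this body
    return x * 2

# When a torch.compile'd caller reaches moe_apply_eager, it graph-breaks and
# executes the function eagerly instead of compiling it.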
@@ -307,7 +307,7 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
    # use_nn_moe: Optional[bool] = False,
    # routed_scaling_factor: Optional[float] = None,
    # use_fused_gate: Optional[bool] = False,
-    # **_
+    # **_
    # ) -> torch.Tensor:
    # from sglang.srt.layers.moe.fused_moe_triton import (FusedMoE, FusedMoeWeightScaleSupported)
    # from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
@@ -351,8 +351,8 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
    #     a2_scale=layer.w2_input_scale,
    #     use_nn_moe=use_nn_moe,
    # )

-    def apply_ep(
-        self,
+    def apply_ep(
+        self,
        x: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
@@ -392,6 +392,7 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
            global_num_experts=global_num_experts,
            w1_scale=w1_scale,
            w2_scale=w2_scale,
            a1_scale=a1_scale,
            use_nn_moe=use_nn_moe,
            shared_output=shared_output,
            routed_scaling_factor=routed_scaling_factor,