Commit bd63af06
authored Nov 15, 2025 by maxiao1

Merge branch 'v0.5.4_dev_maxiao' into 'v0.5.4_dev'

Adapt the w8a8_marlin high-throughput mode. See merge request OpenDAS/sglang!25

Parents: 92f82dce, eed591c9
Showing 2 changed files with 90 additions and 5 deletions:

python/sglang/srt/layers/moe/ep_moe/layer.py (+69, -2)
python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe_marlin.py (+21, -3)
python/sglang/srt/layers/moe/ep_moe/layer.py
@@ -3,6 +3,7 @@ from __future__ import annotations
 import logging
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
 
+from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors_marlin import SlimQuantCompressedTensorsMarlinConfig
 from sglang.srt.layers.quantization.slimquant_w4a8_marlin import SlimQuantW4A8Int8MarlinConfig
 import torch
 import torch.distributed as dist
@@ -39,7 +40,7 @@ if TYPE_CHECKING:
     DeepEPNormalOutput,
     DispatchOutput,
 )
-from lightop import m_grouped_w4a8_gemm_nt_masked, fuse_silu_mul_quant_ep
+from lightop import m_grouped_w4a8_gemm_nt_masked, fuse_silu_mul_quant_ep, m_grouped_w8a8_gemm_nt_masked
 from lmslim.layers.gemm.int8_utils import per_token_quant_int8
 
 _is_hip = is_hip()
@@ -127,6 +128,7 @@ class EPMoE(FusedMoE):
             self.fp8_dtype = torch.float8_e4m3fn
             self.activation_scheme = quant_config.activation_scheme
             self.use_w4a8_marlin = False
+            self.use_w8a8_marlin = False
         elif isinstance(quant_config, SlimQuantW4A8Int8MarlinConfig):
             self.use_block_quant = getattr(self.quant_method, "block_quant", False)
             self.block_shape = (
@@ -137,12 +139,25 @@ class EPMoE(FusedMoE):
             self.use_fp8_w8a8 = False
             self.activation_scheme = None
             self.use_w4a8_marlin = True
+            self.use_w8a8_marlin = False
+        elif isinstance(quant_config, SlimQuantCompressedTensorsMarlinConfig):
+            self.use_block_quant = getattr(self.quant_method, "block_quant", False)
+            self.block_shape = (
+                self.quant_method.quant_config.weight_block_size
+                if self.use_block_quant
+                else None
+            )
+            self.use_fp8_w8a8 = False
+            self.activation_scheme = None
+            self.use_w4a8_marlin = False
+            self.use_w8a8_marlin = True
         else:
             self.use_fp8_w8a8 = False
             self.use_block_quant = False
             self.block_shape = None
             self.activation_scheme = None
             self.use_w4a8_marlin = False
+            self.use_w8a8_marlin = False
 
     def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
         if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8:
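Taken together, each quantization config now selects exactly one kernel path through a set of mutually exclusive flags. A summary sketch of the branches above (config classes reduced to names; the first branch is inferred to be the pre-existing FP8 path from its surrounding context):

# Sketch: kernel-path flags set in EPMoE.__init__ per quant config family.
KERNEL_PATH_FLAGS = {
    "fp8 (pre-existing branch)":              dict(use_fp8_w8a8=True,  use_w4a8_marlin=False, use_w8a8_marlin=False),
    "SlimQuantW4A8Int8MarlinConfig":          dict(use_fp8_w8a8=False, use_w4a8_marlin=True,  use_w8a8_marlin=False),
    "SlimQuantCompressedTensorsMarlinConfig": dict(use_fp8_w8a8=False, use_w4a8_marlin=False, use_w8a8_marlin=True),
    "no quant config (else branch)":          dict(use_fp8_w8a8=False, use_w4a8_marlin=False, use_w8a8_marlin=False),
}

# At most one path is active for any given config.
for name, flags in KERNEL_PATH_FLAGS.items():
    assert sum(flags.values()) <= 1, name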
@@ -498,6 +513,8 @@ class DeepEPMoE(EPMoE):
         elif DispatchOutputChecker.format_is_deepep_ll(dispatch_output):
             if self.use_w4a8_marlin:
                 return self.forward_groupgemm_w4a8_marlin_masked(dispatch_output)
+            elif self.use_w8a8_marlin:
+                return self.forward_groupgemm_w8a8_marlin_masked(dispatch_output)
             else:
                 if (
                     get_moe_runner_backend().is_flashinfer_cutedsl()
@@ -783,7 +800,7 @@ class DeepEPMoE(EPMoE):
 
         # base shapes
         num_groups, m, k = hidden_states.size()
-        expected_m = m // 2  # shape required by the kernel
+        expected_m = min(m, expected_m)
 
         # ---- first quant: ensure float input for quantizer ----
         q_a1_all, q_a1_scale = per_token_quant_int8(hidden_states)
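This hunk drops the hard-coded `m // 2` hint (the removed comment, translated from Chinese, read "shape required by the kernel") in favor of clamping the dispatcher-provided `expected_m` to the padded buffer capacity. A toy illustration of the quantities involved, with made-up numbers rather than the real DeepEP API:

import torch

# DeepEP low-latency dispatch hands each local expert a padded buffer of
# capacity m; masked_m[e] counts the valid rows for expert e, and expected_m
# is a scheduling hint for the masked grouped GEMM.
num_groups, m, k = 4, 128, 512
masked_m = torch.tensor([30, 0, 128, 77])
expected_m = 160                  # hint from the dispatcher
expected_m = min(m, expected_m)   # never promise more rows than the buffer holds
assert (masked_m <= m).all()
print(expected_m)                 # 128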
@@ -822,6 +839,56 @@ class DeepEPMoE(EPMoE):
 
         return down_output
 
+    def forward_groupgemm_w8a8_marlin_masked(
+        self,
+        dispatch_output: DeepEPLLOutput,
+    ):
+        hidden_states, _, _, _, masked_m, expected_m = dispatch_output
+        assert self.quant_method is not None
+        assert self.moe_runner_config.activation == "silu"
+
+        # base shapes
+        num_groups, m, k = hidden_states.size()
+        expected_m = min(m, expected_m)
+
+        # ---- first quant: ensure float input for quantizer ----
+        q_a1_all, q_a1_scale = per_token_quant_int8(hidden_states)
+
+        # ---- weights & scales ----
+        w13_weight = self.w13_weight
+        w13_scales = self.w13_weight_scale
+        w2_weight = self.w2_weight
+        w2_scales = self.w2_weight_scale
+
+        n1 = w13_scales.size(1)
+        gateup_output = torch.empty(
+            (num_groups, m, n1), device=hidden_states.device, dtype=torch.bfloat16
+        )
+
+        # ---- first GEMM ----
+        m_grouped_w8a8_gemm_nt_masked(
+            (q_a1_all, q_a1_scale),
+            (w13_weight, w13_scales),
+            gateup_output,
+            masked_m,
+            expected_m,
+        )
+
+        q_a2_all, q_a2_scale = fuse_silu_mul_quant_ep(gateup_output, masked_m)
+
+        # ---- second GEMM ----
+        n2 = w2_scales.size(1)
+        down_output = torch.empty(
+            (num_groups, m, n2), device=q_a2_all.device, dtype=torch.bfloat16
+        )
+        m_grouped_w8a8_gemm_nt_masked(
+            (q_a2_all, q_a2_scale),
+            (w2_weight, w2_scales),
+            down_output,
+            masked_m,
+            expected_m,
+        )
+
+        return down_output
+
     def forward_deepgemm_masked(
         self,
         dispatch_output: DeepEPLLOutput,
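For intuition, here is a plain-PyTorch reference of what a single masked w8a8 grouped GEMM computes, assuming per-token activation scales and per-output-channel weight scales in NT layout (an assumption consistent with `per_token_quant_int8` and `w13_scales.size(1)` above). This is a sketch of one GEMM only; the real `m_grouped_w8a8_gemm_nt_masked` kernel from lightop fuses the int8 math and scaling on-device, and the actual forward chains two of these around `fuse_silu_mul_quant_ep`:

import torch

def ref_masked_w8a8_gemm(q_a, a_scale, q_w, w_scale, masked_m):
    """Dequantize-and-matmul reference for a masked grouped int8 GEMM.

    q_a:      [G, M, K]  int8 activations (rows beyond masked_m[g] are padding)
    a_scale:  [G, M, 1]  per-token activation scales
    q_w:      [G, N, K]  int8 weights (NT layout: B transposed)
    w_scale:  [G, N]     per-output-channel weight scales
    masked_m: [G]        valid row count per expert
    """
    G, M, K = q_a.shape
    N = q_w.shape[1]
    out = torch.zeros(G, M, N, dtype=torch.bfloat16)
    for g in range(G):
        rows = masked_m[g]
        acc = q_a[g, :rows].float() @ q_w[g].float().t()  # int8 product in fp32
        acc = acc * a_scale[g, :rows] * w_scale[g]        # apply both scale sets
        out[g, :rows] = acc.to(torch.bfloat16)            # padded rows stay zero
    return out

# Tiny smoke test with random data.
G, M, K, N = 2, 8, 16, 4
q_a = torch.randint(-128, 127, (G, M, K), dtype=torch.int8)
a_scale = torch.rand(G, M, 1)
q_w = torch.randint(-128, 127, (G, N, K), dtype=torch.int8)
w_scale = torch.rand(G, N)
masked_m = torch.tensor([5, 8])
out = ref_masked_w8a8_gemm(q_a, a_scale, q_w, w_scale, masked_m)
print(out.shape)  # torch.Size([2, 8, 4])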
python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe_marlin.py
@@ -39,6 +39,18 @@ def get_w8a8_int8_marlin_weights(
 
     return weight
 
+def w8a8_nt_kpack2_marlin_weight(
+    w8a8_w,  # [size_n, size_k // 2]
+    k_tile=16,
+    n_tile=16,
+):
+    assert w8a8_w.dtype == torch.int8, "w8a8_w must be of type int8"
+    size_n, size_k = w8a8_w.shape
+    assert size_n % k_tile == 0 and size_k % n_tile == 0, "k_tile / n_tile must evenly divide the corresponding dimension"
+    q = w8a8_w.reshape((size_n // n_tile, n_tile, size_k // k_tile, k_tile))
+    q = q.permute((0, 2, 1, 3)).contiguous()
+    q = q.reshape((size_n // k_tile, size_k * k_tile))
+    return q
+
 class CompressedTensorsMarlinMoEMethod(FusedMoEMethodBase):
 
     @staticmethod
     def get_moe_method(
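A quick sanity check of the repacking above: the tile-reshape-permute sequence is a pure permutation, so element counts and values are preserved. A standalone sketch with small made-up shapes (the function body is reproduced here so the check runs on its own):

import torch

def w8a8_nt_kpack2_marlin_weight(w8a8_w, k_tile=16, n_tile=16):
    # Same tiling/permute as the diff above, minus the asserts.
    size_n, size_k = w8a8_w.shape
    q = w8a8_w.reshape((size_n // n_tile, n_tile, size_k // k_tile, k_tile))
    q = q.permute((0, 2, 1, 3)).contiguous()
    return q.reshape((size_n // k_tile, size_k * k_tile))

w = torch.randint(-128, 127, (32, 32), dtype=torch.int8)  # [size_n, size_k]
packed = w8a8_nt_kpack2_marlin_weight(w)
print(packed.shape)  # torch.Size([2, 512])
assert packed.numel() == w.numel()
# Pure permutation: the sorted element multisets match.
assert torch.equal(packed.flatten().sort().values, w.flatten().sort().values)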
@@ -65,7 +77,7 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
             "weights"
         )
         self.input_quant = self.quant_config.target_scheme_map["Linear"].get(
             "input_activations"
         )
+        self.use_deepep = True
 
         per_channel = (
             self.weight_quant.strategy == QuantizationStrategy.CHANNEL
             and self.input_quant.strategy == QuantizationStrategy.TOKEN
         )
@@ -138,13 +150,19 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         w1_marlin_list = []
         for ii in range(layer.w13_weight.shape[0]):
-            w1_marlin_in = get_w8a8_int8_marlin_weights(layer.w13_weight[ii])
+            if not self.use_deepep:
+                w1_marlin_in = get_w8a8_int8_marlin_weights(layer.w13_weight[ii])
+            else:
+                w1_marlin_in = w8a8_nt_kpack2_marlin_weight(layer.w13_weight[ii])
             w1_marlin_list.append(w1_marlin_in)
         w1_marlin = torch.stack(w1_marlin_list, dim=0)
 
         w2_marlin_list = []
         for ii in range(layer.w2_weight.shape[0]):
-            w2_marlin_in = get_w8a8_int8_marlin_weights(layer.w2_weight[ii])
+            if not self.use_deepep:
+                w2_marlin_in = get_w8a8_int8_marlin_weights(layer.w2_weight[ii])
+            else:
+                w2_marlin_in = w8a8_nt_kpack2_marlin_weight(layer.w2_weight[ii])
             w2_marlin_list.append(w2_marlin_in)
         w2_marlin = torch.stack(w2_marlin_list, dim=0)
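The loop above follows a per-expert map-and-stack pattern: repack each expert's 2-D weight slice, then stack the results back into one 3-D tensor. A standalone sketch with dummy shapes and a stand-in repack function (all names and sizes here are illustrative):

import torch

# Hypothetical expert weight bank: 4 experts, one [n, k] int8 matrix each.
num_experts, n, k = 4, 32, 64
w13_weight = torch.randint(-128, 127, (num_experts, n, k), dtype=torch.int8)

def repack(w):
    # Stand-in for get_w8a8_int8_marlin_weights / w8a8_nt_kpack2_marlin_weight:
    # any per-expert layout transform fits the pattern.
    return w.t().contiguous()

w1_marlin_list = []
for ii in range(w13_weight.shape[0]):
    w1_marlin_list.append(repack(w13_weight[ii]))
w1_marlin = torch.stack(w1_marlin_list, dim=0)
print(w1_marlin.shape)  # torch.Size([4, 64, 32])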