Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
ba6f2101
Commit
ba6f2101
authored
Apr 24, 2026
by
chenhw5
Committed by
zhangzbb
Apr 24, 2026
Browse files
[FEATURE] GLM5 FP8 EP适配
parent
e7dee10f
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
88 additions
and
26 deletions
+88
-26
vllm/model_executor/layers/fused_moe/all2all_utils.py
vllm/model_executor/layers/fused_moe/all2all_utils.py
+2
-5
vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py
.../model_executor/layers/fused_moe/batched_deep_gemm_moe.py
+10
-11
vllm/model_executor/layers/fused_moe/config.py
vllm/model_executor/layers/fused_moe/config.py
+3
-3
vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py
...l_executor/layers/fused_moe/deepep_ll_prepare_finalize.py
+2
-2
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe_marlin.py
...ation/compressed_tensors/compressed_tensors_moe_marlin.py
+71
-5
No files found.
vllm/model_executor/layers/fused_moe/all2all_utils.py
View file @
ba6f2101
...
...
@@ -157,11 +157,8 @@ def maybe_make_prepare_finalize(
# Note: We may want to use FP8 dispatch just to reduce
# data movement.
use_fp8_dispatch
=
(
quant_config
.
quant_dtype
==
current_platform
.
fp8_dtype
()
and
quant_config
.
block_shape
==
DEEPEP_QUANT_BLOCK_SHAPE
)
use_fp8_dispatch
=
quant_config
.
quant_dtype
==
current_platform
.
fp8_dtype
()
use_int8_dispatch
=
quant_config
.
quant_dtype
==
torch
.
int8
prepare_finalize
=
DeepEPLLPrepareAndFinalize
(
...
...
vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py
View file @
ba6f2101
...
...
@@ -38,10 +38,10 @@ from vllm.utils.math_utils import cdiv, round_up
from
vllm.utils.import_utils
import
has_deep_gemm
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
lightop
import
fuse_silu_mul_quant_ep
from
lightop
import
fuse_silu_mul_quant_ep
,
fuse_silu_mul_fp8_quant_ep
from
lmslim.layers.gemm.int8_utils
import
per_token_quant_int8
if
has_deep_gemm
():
from
deepgemm
import
m_grouped_w8a8_gemm_nt_masked
from
deepgemm
import
m_grouped_w8a8_gemm_nt_masked
,
m_grouped_fp8_gemm_nt_masked
else
:
from
lightop
import
m_grouped_w8a8_gemm_nt_masked
...
...
@@ -452,8 +452,8 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
num_dispatchers
=
num_dispatchers
,
)
if
quant_config
.
use_fp8_w8a8
:
assert
self
.
block_shape
==
get_mk_alignment_for_contiguous_layout
()
#
if quant_config.use_fp8_w8a8:
#
assert self.block_shape == get_mk_alignment_for_contiguous_layout()
self
.
N
=
N
self
.
K
=
K
...
...
@@ -606,7 +606,7 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
expected_m
=
self
.
get_expected_m
()
if
self
.
quant_config
.
use_fp8_w8a16
or
self
.
quant_config
.
use_fp8_w8a8
:
fp8_
m_grouped_gemm_nt_masked
(
m_grouped_
fp8_
gemm_nt_masked
(
(
a1q
,
a1q_scale
),
(
w1
,
self
.
w1_scale
),
workspace1
,
...
...
@@ -614,14 +614,13 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
expected_m
,
)
quant_scale_fmt
=
DeepGemmQuantScaleFMT
.
from_oracle
()
a2q
,
a2q_scale
=
persistent_masked_m_silu_mul_quant
(
workspace1
,
expert_num_tokens
,
quant_scale_fmt
=
quant_scale_fmt
,
a2q
,
a2q_scale
=
fuse_silu_mul_fp8_quant_ep
(
input
=
workspace1
,
fp8type
=
0
,
tokens_per_expert
=
expert_num_tokens
,
)
fp8_
m_grouped_gemm_nt_masked
(
m_grouped_
fp8_
gemm_nt_masked
(
(
a2q
,
a2q_scale
),
(
w2
,
self
.
w2_scale
),
output
,
...
...
vllm/model_executor/layers/fused_moe/config.py
View file @
ba6f2101
...
...
@@ -87,14 +87,14 @@ def _quant_flags_to_group_shape(
"""
a_shape
:
GroupShape
|
None
w_shape
:
GroupShape
|
None
if
block_shape
is
not
None
and
quant_dtype
!=
torch
.
int8
:
if
block_shape
is
not
None
and
quant_dtype
!=
torch
.
int8
and
quant_dtype
!=
current_platform
.
fp8_dtype
()
:
assert
not
per_act_token_quant
assert
not
per_out_ch_quant
# TODO(bnell): this is not quite right for activations since first
# dim should be 1.
a_shape
=
GroupShape
(
row
=
block_shape
[
0
],
col
=
block_shape
[
1
])
w_shape
=
GroupShape
(
row
=
block_shape
[
0
],
col
=
block_shape
[
1
])
elif
block_shape
is
not
None
and
quant_dtype
==
torch
.
int8
:
elif
block_shape
is
not
None
and
(
quant_dtype
==
torch
.
int8
or
quant_dtype
==
current_platform
.
fp8_dtype
())
:
a_shape
=
GroupShape
(
row
=
block_shape
[
0
],
col
=
block_shape
[
1
])
w_shape
=
GroupShape
(
row
=
block_shape
[
0
],
col
=
block_shape
[
1
])
else
:
...
...
@@ -518,7 +518,7 @@ class FusedMoEQuantConfig:
weight_dtype
,
w_shape
,
w2_scale
,
g2_alphas
,
w2_zp
,
w2_bias
),
)
if
quant_dtype
!=
torch
.
int8
:
if
quant_dtype
!=
torch
.
int8
and
quant_dtype
!=
current_platform
.
fp8_dtype
()
:
assert
quant_config
.
per_act_token_quant
==
per_act_token_quant
assert
quant_config
.
per_out_ch_quant
==
per_out_ch_quant
assert
quant_config
.
block_shape
==
block_shape
...
...
vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py
View file @
ba6f2101
...
...
@@ -22,7 +22,7 @@ from vllm.v1.worker.ubatching import (
dbo_enabled
,
dbo_maybe_run_recv_hook
,
)
from
vllm.platforms
import
current_platform
logger
=
init_logger
(
__name__
)
# DeepEP kernels quantize dispatch inputs in 128 element chunks.
...
...
@@ -179,7 +179,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
if
quant_config
.
block_shape
is
not
None
else
None
)
if
block_k
==
DEEPEP_QUANT_BLOCK_SIZE
:
if
block_k
==
DEEPEP_QUANT_BLOCK_SIZE
or
(
isinstance
(
x
,
tuple
)
and
x
[
0
].
dtype
==
current_platform
.
fp8_dtype
())
:
# DeepEP kernels did the quantization for us.
x
,
x_scales
=
x
return
x
,
x_scales
...
...
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe_marlin.py
View file @
ba6f2101
...
...
@@ -18,7 +18,7 @@ from vllm.model_executor.layers.fused_moe import (
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
get_w8a8_int8_marlin_weights
,
w8a8_nt_kpack2_marlin_weight
,
weight8bit_nt_kpack2_marlin1
)
from
vllm.model_executor.layers.fused_moe.config
import
(
FusedMoEQuantConfig
,
int8_w8a8_moe_quant_config
)
from
vllm.model_executor.layers.fused_moe.config
import
(
FusedMoEQuantConfig
,
int8_w8a8_moe_quant_config
,
fp8_w8a8_moe_quant_config
)
from
vllm.model_executor.layers.fused_moe
import
(
FusedMoE
,
FusedMoEMethodBase
,
...
...
@@ -120,14 +120,38 @@ class CompressedTensorsW8A8FP8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod):
"dynamic per token quantization. Found static input scales."
)
self
.
fused_experts
=
self
.
fused_moe_forward
vllm_config
=
get_current_vllm_config
()
parallel_config
=
vllm_config
.
parallel_config
self
.
dp_size
=
get_dp_group
().
world_size
self
.
use_deepep
=
self
.
dp_size
>
1
and
parallel_config
.
enable_expert_parallel
and
\
(
envs
.
VLLM_ALL2ALL_BACKEND
==
"deepep_high_throughput"
or
\
envs
.
VLLM_ALL2ALL_BACKEND
==
"deepep_low_latency"
or
\
envs
.
VLLM_ALL2ALL_BACKEND
==
"deepep_auto"
)
if
self
.
use_deepep
:
all2all_manager
=
get_ep_group
().
device_communicator
.
all2all_manager
assert
all2all_manager
is
not
None
self
.
num_dispatchers
=
all2all_manager
.
world_size
def
get_fused_moe_quant_config
(
self
,
layer
:
torch
.
nn
.
Module
)
->
Optional
[
FusedMoEQuantConfig
]:
return
None
return
fp8_w8a8_moe_quant_config
(
w1_scale
=
layer
.
w13_weight_scale
,
w2_scale
=
layer
.
w2_weight_scale
,
a1_scale
=
layer
.
w13_input_scale
,
a2_scale
=
layer
.
w2_input_scale
,
per_act_token_quant
=
True
,
per_out_ch_quant
=
False
,
block_shape
=
[
256
,
256
]
if
self
.
use_deepep
else
None
,
)
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
num_experts
:
int
,
hidden_size
:
int
,
intermediate_size_per_partition
:
int
,
params_dtype
:
torch
.
dtype
,
**
extra_weight_attrs
):
if
self
.
use_deepep
:
self
.
N
=
2
*
intermediate_size_per_partition
self
.
K
=
hidden_size
params_dtype
=
torch
.
float8_e4m3fn
# WEIGHTS
...
...
@@ -200,7 +224,10 @@ class CompressedTensorsW8A8FP8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod):
else
:
w1_marlin_list
=
[]
for
ii
in
range
(
layer
.
w13_weight
.
shape
[
0
]):
w1_marlin_in
=
get_w8a8_int8_marlin_weights
(
layer
.
w13_weight
[
ii
])
if
not
self
.
use_deepep
:
w1_marlin_in
=
get_w8a8_int8_marlin_weights
(
layer
.
w13_weight
[
ii
])
else
:
w1_marlin_in
=
weight8bit_nt_kpack2_marlin1
(
layer
.
w13_weight
[
ii
])
w1_marlin_list
.
append
(
w1_marlin_in
.
float
()
if
w1_marlin_in
.
dtype
==
torch
.
float8_e4m3fn
else
w1_marlin_in
)
w1_marlin
=
torch
.
stack
(
w1_marlin_list
,
dim
=
0
)
w1_marlin
=
fp32_to_fp8_e4m3fn
(
w1_marlin
)
...
...
@@ -208,7 +235,10 @@ class CompressedTensorsW8A8FP8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod):
del
w1_marlin_list
w2_marlin_list
=
[]
for
ii
in
range
(
layer
.
w2_weight
.
shape
[
0
]):
w2_marlin_in
=
get_w8a8_int8_marlin_weights
(
layer
.
w2_weight
[
ii
])
if
not
self
.
use_deepep
:
w2_marlin_in
=
get_w8a8_int8_marlin_weights
(
layer
.
w2_weight
[
ii
])
else
:
w2_marlin_in
=
weight8bit_nt_kpack2_marlin1
(
layer
.
w2_weight
[
ii
])
w2_marlin_list
.
append
(
w2_marlin_in
.
float
()
if
w2_marlin_in
.
dtype
==
torch
.
float8_e4m3fn
else
w2_marlin_in
)
w2_marlin
=
torch
.
stack
(
w2_marlin_list
,
dim
=
0
)
w2_marlin
=
fp32_to_fp8_e4m3fn
(
w2_marlin
)
...
...
@@ -328,6 +358,42 @@ class CompressedTensorsW8A8FP8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod):
routed_scaling_factor
=
routed_scaling_factor
,
shared_output
=
shared_output
,
)
def
select_gemm_impl
(
self
,
prepare_finalize
:
FusedMoEPrepareAndFinalize
,
layer
:
torch
.
nn
.
Module
,
)
->
FusedMoEPermuteExpertsUnpermute
:
from
vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe
import
(
BatchedDeepGemmExperts
,
)
from
vllm.model_executor.layers.fused_moe.deep_gemm_moe
import
(
DeepGemmExperts
,
)
if
(
prepare_finalize
.
activation_format
==
FusedMoEActivationFormat
.
BatchedExperts
):
max_num_tokens_per_rank
=
prepare_finalize
.
max_num_tokens_per_rank
()
assert
max_num_tokens_per_rank
is
not
None
logger
.
debug
(
"BatchedDeepGemmExperts(%s)"
,
self
.
__class__
.
__name__
)
return
BatchedDeepGemmExperts
(
moe_config
=
self
.
moe
,
max_num_tokens
=
max_num_tokens_per_rank
,
num_dispatchers
=
prepare_finalize
.
num_dispatchers
(),
quant_config
=
self
.
moe_quant_config
,
N
=
self
.
N
,
K
=
self
.
K
)
else
:
logger
.
debug
(
"DeepGemmExperts(%s)"
,
self
.
__class__
.
__name__
)
return
DeepGemmExperts
(
moe_config
=
self
.
moe
,
quant_config
=
self
.
moe_quant_config
,
N
=
self
.
N
,
K
=
self
.
K
)
class
CompressedTensorsW8A8Int8MarlinMoEMethod
(
CompressedTensorsMarlinMoEMethod
):
def
__init__
(
self
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment