Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
d5538a81
Commit
d5538a81
authored
Apr 23, 2026
by
王敏
Browse files
[Feature]w4a8适配低延迟模式
parent
aef3c487
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
112 additions
and
8 deletions
+112
-8
vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py
.../model_executor/layers/fused_moe/batched_deep_gemm_moe.py
+16
-1
vllm/model_executor/layers/fused_moe/config.py
vllm/model_executor/layers/fused_moe/config.py
+22
-0
vllm/model_executor/layers/quantization/slimquant_w4a8_marlin.py
...del_executor/layers/quantization/slimquant_w4a8_marlin.py
+74
-7
No files found.
vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py
View file @
d5538a81
...
@@ -38,7 +38,7 @@ from vllm.utils.math_utils import cdiv, round_up
...
@@ -38,7 +38,7 @@ from vllm.utils.math_utils import cdiv, round_up
from
vllm.utils.import_utils
import
has_deep_gemm
from
vllm.utils.import_utils
import
has_deep_gemm
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
lightop
import
fuse_silu_mul_quant_ep
from
lightop
import
fuse_silu_mul_quant_ep
,
m_grouped_w4a8_gemm_nt_masked
from
lmslim.layers.gemm.int8_utils
import
per_token_quant_int8
from
lmslim.layers.gemm.int8_utils
import
per_token_quant_int8
if
has_deep_gemm
():
if
has_deep_gemm
():
from
deepgemm
import
m_grouped_w8a8_gemm_nt_masked
from
deepgemm
import
m_grouped_w8a8_gemm_nt_masked
...
@@ -650,5 +650,20 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
...
@@ -650,5 +650,20 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
# a2q, a2q_scale = per_token_quant_int8(act_out)
# a2q, a2q_scale = per_token_quant_int8(act_out)
# moe_grouped_gemm(a2q, w2, a2q_scale, self.w2_scale, expert_num_tokens, output)
# moe_grouped_gemm(a2q, w2, a2q_scale, self.w2_scale, expert_num_tokens, output)
elif
self
.
quant_config
.
use_int4_w4a8
:
m_grouped_w4a8_gemm_nt_masked
((
a1q
,
a1q_scale
),
(
w1
,
self
.
w1_scale
),
workspace1
,
expert_num_tokens
,
expected_m
,
)
a2q
,
a2q_scale
=
fuse_silu_mul_quant_ep
(
workspace1
,
expert_num_tokens
)
m_grouped_w4a8_gemm_nt_masked
((
a2q
,
a2q_scale
),
(
w2
,
self
.
w2_scale
),
output
,
expert_num_tokens
,
expected_m
)
else
:
else
:
raise
ValueError
(
f
"Unsupported dtype
{
self
.
quant_config
.
quant_dtype
}
"
)
raise
ValueError
(
f
"Unsupported dtype
{
self
.
quant_config
.
quant_dtype
}
"
)
vllm/model_executor/layers/fused_moe/config.py
View file @
d5538a81
...
@@ -579,6 +579,28 @@ def int8_w8a8_moe_quant_config(
...
@@ -579,6 +579,28 @@ def int8_w8a8_moe_quant_config(
block_shape
=
block_shape
,
block_shape
=
block_shape
,
)
)
def int8_w4a8_moe_quant_config(
    w1_scale: torch.Tensor,
    w2_scale: torch.Tensor,
    a1_scale: torch.Tensor | None,
    a2_scale: torch.Tensor | None,
    per_act_token_quant: bool = False,
    block_shape: list[int] | None = None,
) -> FusedMoEQuantConfig:
    """
    Construct a quant config for the W4A8 MoE path (int8 activations with,
    per the function name, 4-bit weights).

    Args:
        w1_scale: Scale tensor forwarded as the first expert weight's scale.
        w2_scale: Scale tensor forwarded as the second expert weight's scale.
        a1_scale: Optional scale forwarded for the first GEMM's activations.
        a2_scale: Optional scale forwarded for the second GEMM's activations.
        per_act_token_quant: Forwarded unchanged to the config.
        block_shape: Optional block-quantization shape, forwarded unchanged.

    Returns:
        The config produced by ``FusedMoEQuantConfig.make``.
    """
    # torch.int8 is the only dtype handed to make(); no separate weight dtype
    # is passed here.
    # NOTE(review): confirm how the config distinguishes this from the w8a8
    # config (e.g. how ``use_int4_w4a8`` is derived) -- not visible from here.
    return FusedMoEQuantConfig.make(
        torch.int8,
        w1_scale=w1_scale,
        w2_scale=w2_scale,
        a1_scale=a1_scale,
        a2_scale=a2_scale,
        per_act_token_quant=per_act_token_quant,
        # Per-output-channel quantization is always disabled for this path.
        per_out_ch_quant=False,
        block_shape=block_shape,
    )
def
gptq_marlin_moe_quant_config
(
def
gptq_marlin_moe_quant_config
(
w1_scale
:
torch
.
Tensor
,
w1_scale
:
torch
.
Tensor
,
...
...
vllm/model_executor/layers/quantization/slimquant_w4a8_marlin.py
View file @
d5538a81
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Optional
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Optional
import
os
import
os
import
torch
import
torch
from
torch.nn.parameter
import
Parameter
import
vllm.envs
as
envs
import
vllm.envs
as
envs
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.config
import
get_current_vllm_config
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.distributed
import
get_tensor_model_parallel_world_size
,
get_ep_group
,
get_dp_group
from
torch.nn.parameter
import
Parameter
from
vllm.model_executor.layers.quantization
import
QuantizationMethods
from
vllm.model_executor.layers.quantization
import
QuantizationMethods
from
vllm.model_executor.layers.linear
import
(
LinearBase
,
LinearMethodBase
)
from
vllm.model_executor.layers.linear
import
(
LinearBase
,
LinearMethodBase
)
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
,
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
,
QuantizeMethodBase
)
QuantizeMethodBase
)
from
vllm.model_executor.layers.quantization.utils.w4a8_utils
import
w4a8_weight_repack_impl
from
vllm.model_executor.layers.quantization.utils.w4a8_utils
import
w4a8_weight_repack_impl
from
vllm.model_executor.layers.fused_moe
import
(
FusedMoE
,
FusedMoEMethodBase
,
from
vllm.model_executor.layers.fused_moe
import
(
FusedMoE
,
FusedMoEMethodBase
,
FusedMoeWeightScaleSupported
)
FusedMoeWeightScaleSupported
,
from
vllm.model_executor.layers.fused_moe.config
import
FusedMoEQuantConfig
FusedMoEPermuteExpertsUnpermute
,
FusedMoEPrepareAndFinalize
,
FusedMoEActivationFormat
)
from
vllm.model_executor.layers.fused_moe.config
import
(
FusedMoEQuantConfig
,
int8_w4a8_moe_quant_config
)
from
vllm.model_executor.layers.fused_moe.modular_kernel
import
(
from
vllm.model_executor.layers.fused_moe.modular_kernel
import
(
FusedMoEModularKernel
)
FusedMoEModularKernel
)
from
vllm.model_executor.layers.quantization.slimquant_w4a8
import
SlimQuantW4A8Int8LinearMethod
from
vllm.model_executor.layers.quantization.slimquant_w4a8
import
SlimQuantW4A8Int8LinearMethod
...
@@ -23,6 +29,8 @@ try:
...
@@ -23,6 +29,8 @@ try:
except
Exception
:
except
Exception
:
print
(
"INFO: Please install lmslim if you want to infer the quantitative model of moe.
\n
"
)
print
(
"INFO: Please install lmslim if you want to infer the quantitative model of moe.
\n
"
)
logger
=
init_logger
(
__name__
)
class
MarlinMoeWorkspace
:
class
MarlinMoeWorkspace
:
"""
"""
Singleton manager for device-specific workspace buffers used by w4a8 Marlin-MoE.
Singleton manager for device-specific workspace buffers used by w4a8 Marlin-MoE.
...
@@ -149,9 +157,30 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
...
@@ -149,9 +157,30 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
self
.
moe_quant_config
:
Optional
[
FusedMoEQuantConfig
]
=
None
self
.
moe_quant_config
:
Optional
[
FusedMoEQuantConfig
]
=
None
self
.
moe_mk
:
Optional
[
FusedMoEModularKernel
]
=
None
self
.
moe_mk
:
Optional
[
FusedMoEModularKernel
]
=
None
vllm_config
=
get_current_vllm_config
()
parallel_config
=
vllm_config
.
parallel_config
self
.
dp_size
=
get_dp_group
().
world_size
self
.
use_deepep
=
self
.
dp_size
>
1
and
parallel_config
.
enable_expert_parallel
and
\
(
envs
.
VLLM_ALL2ALL_BACKEND
==
"deepep_high_throughput"
or
\
envs
.
VLLM_ALL2ALL_BACKEND
==
"deepep_low_latency"
or
\
envs
.
VLLM_ALL2ALL_BACKEND
==
"deepep_auto"
)
if
self
.
use_deepep
:
all2all_manager
=
get_ep_group
().
device_communicator
.
all2all_manager
assert
all2all_manager
is
not
None
self
.
num_dispatchers
=
all2all_manager
.
world_size
def get_fused_moe_quant_config(
    self, layer: torch.nn.Module
) -> FusedMoEQuantConfig | None:
    """Build the W4A8 MoE quant config from the scales stored on ``layer``.

    Reads ``w13_weight_scale``, ``w2_weight_scale``, ``w13_input_scale`` and
    ``w2_input_scale`` from ``layer`` and forwards them, with per-token
    activation quantization enabled, to ``int8_w4a8_moe_quant_config``.
    """
    # Attribute reads are kept in the same order as the keyword arguments
    # they feed.
    w13_scale = layer.w13_weight_scale
    w2_scale = layer.w2_weight_scale
    w13_act_scale = layer.w13_input_scale
    w2_act_scale = layer.w2_input_scale

    # A [256, 256] block shape is passed only when DeepEP is in use.
    if self.use_deepep:
        quant_block = [256, 256]
    else:
        quant_block = None

    return int8_w4a8_moe_quant_config(
        w1_scale=w13_scale,
        w2_scale=w2_scale,
        a1_scale=w13_act_scale,
        a2_scale=w2_act_scale,
        per_act_token_quant=True,
        block_shape=quant_block,
    )
def
create_weights
(
def
create_weights
(
self
,
self
,
...
@@ -162,7 +191,9 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
...
@@ -162,7 +191,9 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
params_dtype
:
torch
.
dtype
,
params_dtype
:
torch
.
dtype
,
**
extra_weight_attrs
,
**
extra_weight_attrs
,
):
):
tp_size
=
get_tensor_model_parallel_world_size
()
if
self
.
use_deepep
:
self
.
N
=
2
*
intermediate_size
self
.
K
=
hidden_size
# WEIGHTS
# WEIGHTS
w13_weight
=
torch
.
nn
.
Parameter
(
w13_weight
=
torch
.
nn
.
Parameter
(
...
@@ -251,3 +282,39 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
...
@@ -251,3 +282,39 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
shared_output
=
shared_output
,
shared_output
=
shared_output
,
routed_scaling_factor
=
routed_scaling_factor
,
routed_scaling_factor
=
routed_scaling_factor
,
)
)
def select_gemm_impl(
    self,
    prepare_finalize: FusedMoEPrepareAndFinalize,
    layer: torch.nn.Module,
) -> FusedMoEPermuteExpertsUnpermute:
    """Pick the experts implementation matching ``prepare_finalize``.

    Returns ``BatchedDeepGemmExperts`` when the prepare/finalize object uses
    the ``BatchedExperts`` activation format, otherwise ``DeepGemmExperts``.
    ``layer`` is accepted but not used by this method.

    NOTE(review): ``self.N`` / ``self.K`` are read on both paths; in the
    visible part of ``create_weights`` they are only assigned when
    ``self.use_deepep`` is true -- confirm this method is never reached
    otherwise.
    """
    # Imported locally, both up front, exactly as the original did.
    from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
        BatchedDeepGemmExperts,
    )
    from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
        DeepGemmExperts,
    )

    uses_batched_format = (
        prepare_finalize.activation_format
        == FusedMoEActivationFormat.BatchedExperts
    )

    if not uses_batched_format:
        logger.debug("DeepGemmExperts(%s)", self.__class__.__name__)
        return DeepGemmExperts(
            moe_config=self.moe,
            quant_config=self.moe_quant_config,
            N=self.N,
            K=self.K,
        )

    # The batched path needs the per-rank token capacity to construct the
    # experts object.
    tokens_per_rank = prepare_finalize.max_num_tokens_per_rank()
    assert tokens_per_rank is not None
    logger.debug("BatchedDeepGemmExperts(%s)", self.__class__.__name__)
    return BatchedDeepGemmExperts(
        moe_config=self.moe,
        max_num_tokens=tokens_per_rank,
        num_dispatchers=prepare_finalize.num_dispatchers(),
        quant_config=self.moe_quant_config,
        N=self.N,
        K=self.K,
    )
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment