Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
577eb49f
Commit
577eb49f
authored
Jan 04, 2026
by
wujl5
Committed by
zhuwenwen
Jan 04, 2026
Browse files
perf: DS-量化模型融合qa和kva的gemm
parent
d4e72be3
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
261 additions
and
23 deletions
+261
-23
vllm/envs.py
vllm/envs.py
+6
-0
vllm/model_executor/layers/linear.py
vllm/model_executor/layers/linear.py
+185
-1
vllm/model_executor/models/deepseek_mtp.py
vllm/model_executor/models/deepseek_mtp.py
+14
-1
vllm/model_executor/models/deepseek_v2.py
vllm/model_executor/models/deepseek_v2.py
+56
-21
No files found.
vllm/envs.py
View file @
577eb49f
...
...
@@ -201,6 +201,7 @@ if TYPE_CHECKING:
VLLM_USE_FUSED_FILL_RMS_CAT
:
bool
=
False
VLLM_ENABLE_DEEPEP_HT_DEEPGEMM
:
bool
=
True
VLLM_ZERO_OVERHEAD_ENHANCE
:
bool
=
False
VLLM_USE_FUSED_QA_KVA_GEMM
:
bool
=
False
def
get_default_cache_root
():
return
os
.
getenv
(
...
...
@@ -1299,6 +1300,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_ENABLE_DEEPEP_HT_DEEPGEMM"
:
lambda
:
(
os
.
getenv
(
'VLLM_ENABLE_DEEPEP_HT_DEEPGEMM'
,
'1'
).
lower
()
in
(
"true"
,
"1"
)),
# Only quantized DeepSeek models supported.
# Unquantized versions are not supported.
"VLLM_USE_FUSED_QA_KVA_GEMM"
:
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_FUSED_QA_KVA_GEMM"
,
"False"
).
lower
()
in
(
"true"
,
"1"
)),
"VLLM_ZERO_OVERHEAD_ENHANCE"
:
lambda
:
(
os
.
getenv
(
'VLLM_ZERO_OVERHEAD_ENHANCE'
,
'0'
).
lower
()
in
(
"true"
,
"1"
)),
...
...
vllm/model_executor/layers/linear.py
View file @
577eb49f
...
...
@@ -32,6 +32,7 @@ from vllm.model_executor.utils import set_weight_attrs
from
vllm.platforms
import
current_platform
import
os
import
re
from
vllm.model_executor.utils
import
gemm_bank_conf
from
lmslim.quantize.quant_ops
import
lm_faster_rmsquant
from
lmslim.quantize.quant_ops
import
lm_fuse_silu_mul_quant
...
...
@@ -447,6 +448,189 @@ class ReplicatedLinear(LinearBase):
return
s
class FusedQuantedReplicatedLinear(LinearBase):
    """Replicated, quantized linear layer fusing ``q_a_proj`` and ``kv_a_proj``.

    The layer owns a single weight whose output dimension is
    ``q_lora_rank + kv_lora_rank + qk_rope_head_dim``.  Checkpoints still
    ship the two projections separately, so :meth:`weight_loader` caches each
    half as it arrives and writes the concatenation (``q_a`` rows first, then
    ``kv_a`` rows) into the fused parameter once both halves of the same kind
    (weight or weight scale) are available.

    Only quantized linear methods are supported; loading with
    ``UnquantizedLinearMethod`` raises ``RuntimeError``.

    Args:
        input_size: Input dimension of the fused projection.
        q_lora_rank: Output rows contributed by ``q_a_proj``.
        kv_lora_rank: With ``qk_rope_head_dim``, the output rows contributed
            by ``kv_a_proj``.
        qk_rope_head_dim: See ``kv_lora_rank``.
        bias: If true a bias parameter is allocated, but a warning is logged
            because bias is not currently supported by this implementation.
        skip_bias_add: If true the bias is returned instead of being applied.
        params_dtype: Data type forwarded to the base class / weight creation.
        quant_config: Quantization configuration forwarded to the base class.
        eps: Epsilon passed to the fused RMS-norm + quantization kernel.
        prefix: Parameter name prefix forwarded to the base class.
        return_bias: Must be true; :meth:`forward` always returns the bias slot.
    """

    def __init__(
        self,
        input_size: int,
        q_lora_rank,
        kv_lora_rank,
        qk_rope_head_dim,
        bias: bool = True,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
        eps: Optional[float] = 1e-6,
        prefix: str = "",
        *,
        return_bias: bool = True,
    ):
        output_size = q_lora_rank + kv_lora_rank + qk_rope_head_dim
        super().__init__(input_size,
                         output_size,
                         skip_bias_add,
                         params_dtype,
                         quant_config,
                         prefix=prefix,
                         return_bias=return_bias)
        self.eps = eps
        self.q_lora_rank = q_lora_rank
        self.kv_lora_rank = kv_lora_rank
        self.qk_rope_head_dim = qk_rope_head_dim

        # Halves of the fused tensors, cached until their counterpart arrives.
        self.q_a_weight = None
        self.kv_a_weight = None
        self.q_a_wscale = None
        self.kv_a_wscale = None
        self.weight_loaded = False

        assert self.quant_method is not None
        self.quant_method.create_weights(self,
                                         self.input_size, [self.output_size],
                                         self.input_size,
                                         self.output_size,
                                         self.params_dtype,
                                         weight_loader=self.weight_loader)
        # Decoder layer index parsed from the first loaded weight name;
        # -1 means "not recorded yet".
        self.layer_num = -1

        if bias:
            logger.warning("Quanted DeepSeek-specific implementation. "
                           "Bias is not currently supported.")
            self.bias = Parameter(
                torch.empty(self.output_size, dtype=self.params_dtype))
            set_weight_attrs(self.bias, {
                "output_dim": 0,
                "weight_loader": self.weight_loader,
            })
        else:
            self.register_parameter("bias", None)

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor,
                      weight_name: str):
        """Cache one checkpoint tensor and fuse it once its partner is present.

        Args:
            param: Fused destination parameter (weight or weight scale).
            loaded_weight: Tensor read from the checkpoint.
            weight_name: Original checkpoint name; it must contain
                ``q_a_proj`` or ``kv_a_proj`` and identifies both the half
                and the tensor kind.

        Raises:
            ValueError: For GGUF parameters or unrecognised weight names.
            RuntimeError: If the layer uses ``UnquantizedLinearMethod`` or the
                name refers to a zero-point tensor.
        """
        if getattr(param, "is_gguf_weight", False):
            raise ValueError("Unexpected is_gguf_weight")

        # Promote 0-dim tensors to 1-dim, as the original loader did.
        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        if isinstance(self.quant_method, UnquantizedLinearMethod):
            raise RuntimeError("Quanted DeepSeek-specific implementation. "
                               "not support UnquantizedLinearMethod")

        self._record_layer_num(weight_name)

        if "q_a_proj" in weight_name:
            self._store_qa_weight(loaded_weight, weight_name)
        elif "kv_a_proj" in weight_name:
            self._store_kva_weight(loaded_weight, weight_name)
        else:
            # Silently ignoring the tensor could re-trigger a fusion into the
            # wrong parameter, so fail loudly instead.
            raise ValueError(f"Unexpected weight: {weight_name}")

        # Only fuse when both halves of *this* tensor kind are present; the
        # other kind may still be incomplete depending on checkpoint order.
        if self._received_two_weight(weight_name):
            self._fused_quantized_weight(weight_name, param)

    def _record_layer_num(self, source: str):
        """Parse the decoder layer index from ``source`` and check consistency.

        Raises:
            ValueError: If ``source`` has no ``model.layers.<N>...self_attn``
                component, or if it names a different layer than the one
                recorded by an earlier call.
        """
        match = re.search(r"model\.layers\.(\d+)(?:\.\w+)?\.self_attn", source)
        if match is None:
            raise ValueError(
                f"Cannot parse layer index from weight name: {source}")
        layer_idx = int(match.group(1))

        if self.layer_num == -1:
            self.layer_num = layer_idx
        elif self.layer_num != layer_idx:
            raise ValueError(f"self.layer_num: {self.layer_num} "
                             f"!= numbers: {layer_idx}")

    @staticmethod
    def _weight_kind(source: str) -> str:
        """Classify a checkpoint name as ``"scale"`` or ``"weight"``.

        Raises:
            RuntimeError: If the name refers to a zero-point tensor.
            ValueError: If the name is neither a weight nor a weight scale.
        """
        if "zero" in source:
            raise RuntimeError("Unsupported zero point weight now.")
        # "weight_scale" must be tested first: it also contains "weight".
        if "weight_scale" in source:
            return "scale"
        if "weight" in source:
            return "weight"
        raise ValueError(f"Unexpected weight: {source}")

    def _store_qa_weight(self, loaded_weight: torch.Tensor, source: str):
        """Cache the ``q_a_proj`` weight or weight scale named by ``source``."""
        if self._weight_kind(source) == "scale":
            self.q_a_wscale = loaded_weight
        else:
            self.q_a_weight = loaded_weight

    def _store_kva_weight(self, loaded_weight: torch.Tensor, source: str):
        """Cache the ``kv_a_proj`` weight or weight scale named by ``source``."""
        if self._weight_kind(source) == "scale":
            self.kv_a_wscale = loaded_weight
        else:
            self.kv_a_weight = loaded_weight

    def _received_two_weight(self, source: Optional[str] = None) -> bool:
        """Report whether a q_a/kv_a pair is ready to be fused.

        Args:
            source: Name of the tensor that was just loaded.  When given,
                only the pair of the same kind (weight vs. weight scale) is
                considered.  When omitted, returns true if either pair is
                complete.
        """
        weights_ready = (self.q_a_weight is not None
                         and self.kv_a_weight is not None)
        scales_ready = (self.q_a_wscale is not None
                        and self.kv_a_wscale is not None)
        if source is None:
            return weights_ready or scales_ready
        if "weight_scale" in source:
            return scales_ready
        return weights_ready

    def _fused_quantized_weight(self, source: str, param: Parameter):
        """Concatenate the cached halves along dim 0 and copy into ``param``.

        The cached halves of the fused kind are released afterwards so they
        do not keep their memory alive.

        Raises:
            ValueError: If ``source`` is not a weight / weight-scale name, a
                half is not 2-D, or the fused shape differs from ``param``.
            RuntimeError: If one of the required halves has not been loaded.
        """
        if "weight_scale" in source:
            is_scale = True
            q_part, kv_part = self.q_a_wscale, self.kv_a_wscale
        elif "weight" in source:
            is_scale = False
            q_part, kv_part = self.q_a_weight, self.kv_a_weight
        else:
            raise ValueError(f"Unexpected weight: {source}")

        if q_part is None or kv_part is None:
            raise RuntimeError(
                f"Both q_a and kv_a tensors are required before fusing "
                f"{source}")
        if q_part.dim() != 2 or kv_part.dim() != 2:
            raise ValueError(
                f"Expected 2-D tensors to fuse for {source}, got "
                f"{tuple(q_part.shape)} and {tuple(kv_part.shape)}")

        # q_a rows first, then kv_a rows ("TN" layout note in the original).
        fused = torch.cat([q_part, kv_part], dim=0)
        # copy_ broadcasts, so a mismatch must be rejected explicitly.
        if param.data.shape != fused.shape:
            raise ValueError(
                f"Fused shape {tuple(fused.shape)} does not match parameter "
                f"shape {tuple(param.data.shape)} for {source}")
        param.data.copy_(fused)

        # Drop the cached halves now that their content lives in ``param``.
        if is_scale:
            self.q_a_wscale = None
            self.kv_a_wscale = None
        else:
            self.q_a_weight = None
            self.kv_a_weight = None

    def forward(
        self,
        input_: torch.Tensor,
        rms_weight: Optional[torch.Tensor] = None,
        residual: Optional[torch.Tensor] = None,
        update_hd: Optional[bool] = True,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[Parameter]]:
        """Run the fused RMS-norm/quantization kernel followed by the GEMM.

        Args:
            input_: Activations fed to ``lm_faster_rmsquant`` and the GEMM.
            rms_weight: RMS-norm weight; required.
            residual: Optional residual forwarded to ``lm_faster_rmsquant``.
            update_hd: Forwarded to the kernel as ``update_input``.

        Returns:
            ``(output, new_residual, output_bias)`` where ``new_residual`` is
            the ``residual`` object that was passed in (presumably updated in
            place by the kernel -- TODO confirm) and ``output_bias`` is the
            bias only when ``skip_bias_add`` is set.

        Raises:
            RuntimeError: If ``USE_FUSED_RMS_QUANT`` is disabled, ``rms_weight``
                is missing, or the layer was built with ``return_bias=False``.
        """
        if not (envs.USE_FUSED_RMS_QUANT and rms_weight is not None):
            raise RuntimeError(
                "Unexpected Error. FusedQuantedReplicatedLinear requires "
                "USE_FUSED_RMS_QUANT to be enabled and rms_weight to be set.")
        if not self.return_bias:
            raise RuntimeError("Not return bias. Unexpected Error.")

        i_q, scales = lm_faster_rmsquant(input=input_,
                                         rms_weight=rms_weight,
                                         epsilon=self.eps,
                                         quant_dtype=torch.int8,
                                         residual=residual,
                                         update_input=update_hd)
        new_residual = residual
        input_quant_args = [i_q, scales]

        bias = self.bias if not self.skip_bias_add else None
        assert self.quant_method is not None
        output = self.quant_method.apply(self, input_, bias, input_quant_args)
        output_bias = self.bias if self.skip_bias_add else None
        return output, new_residual, output_bias

    def extra_repr(self) -> str:
        """Return a summary of the layer's sizes and whether it has a bias."""
        s = f"in_features={self.input_size}"
        s += f", output_features={self.output_size}"
        s += f", bias={self.bias is not None}"
        return s
class
ColumnParallelLinear
(
LinearBase
):
"""Linear layer with column parallelism.
...
...
vllm/model_executor/models/deepseek_mtp.py
View file @
577eb49f
...
...
@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
os
import
re
import
vllm.envs
as
envs
from
collections.abc
import
Iterable
from
typing
import
Iterable
,
Optional
...
...
@@ -228,6 +229,12 @@ class DeepSeekMTP(nn.Module, SupportsPP):
(
"gate_up_proj"
,
"gate_proj"
,
0
),
(
"gate_up_proj"
,
"up_proj"
,
1
),
]
if
envs
.
USE_FUSED_RMS_QUANT
and
envs
.
VLLM_USE_FUSED_QA_KVA_GEMM
:
fused_params_mapping
=
[
(
"qa_kva_proj"
,
"q_a_proj"
,
0
),
(
"qa_kva_proj"
,
"kv_a_proj_with_mqa"
,
1
)
]
stacked_params_mapping
+=
fused_params_mapping
expert_params_mapping
=
FusedMoE
.
make_expert_params_mapping
(
ckpt_gate_proj_name
=
"gate_proj"
,
...
...
@@ -256,6 +263,7 @@ class DeepSeekMTP(nn.Module, SupportsPP):
# for mlp.experts[0].gate_gate_up_proj, which breaks load.
if
((
"mlp.experts."
in
name
)
and
name
not
in
params_dict
):
continue
old_weight_name
=
name
name
=
name
.
replace
(
weight_name
,
param_name
)
# Skip loading extra bias for GPTQ models.
...
...
@@ -264,7 +272,12 @@ class DeepSeekMTP(nn.Module, SupportsPP):
param
=
params_dict
[
name
]
weight_loader
=
param
.
weight_loader
if
envs
.
USE_FUSED_RMS_QUANT
and
envs
.
VLLM_USE_FUSED_QA_KVA_GEMM
and
((
"q_a_proj"
in
old_weight_name
)
or
(
"kv_a_proj_with_mqa"
in
old_weight_name
)):
weight_loader
(
param
,
loaded_weight
,
old_weight_name
)
else
:
weight_loader
(
param
,
loaded_weight
,
shard_id
)
break
else
:
for
mapping
in
expert_params_mapping
:
...
...
vllm/model_executor/models/deepseek_v2.py
View file @
577eb49f
...
...
@@ -52,6 +52,7 @@ from vllm.model_executor.layers.layernorm import RMSNorm
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
MergedColumnParallelLinear
,
ReplicatedLinear
,
FusedQuantedReplicatedLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
...
...
@@ -588,6 +589,15 @@ class DeepseekV2MLAAttention(nn.Module):
if
self
.
q_lora_rank
is
not
None
:
if
envs
.
USE_FUSED_RMS_QUANT
:
if
envs
.
VLLM_USE_FUSED_QA_KVA_GEMM
:
self
.
qa_kva_proj
=
FusedQuantedReplicatedLinear
(
self
.
hidden_size
,
self
.
q_lora_rank
,
self
.
kv_lora_rank
,
self
.
qk_rope_head_dim
,
bias
=
False
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.qa_kva_proj"
)
else
:
self
.
q_a_proj
=
ReplicatedLinear
(
self
.
hidden_size
,
self
.
q_lora_rank
,
bias
=
False
,
...
...
@@ -624,7 +634,7 @@ class DeepseekV2MLAAttention(nn.Module):
bias
=
False
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.q_proj"
)
if
not
envs
.
VLLM_USE_FUSED_QA_KVA_GEMM
:
self
.
kv_a_proj_with_mqa
=
ReplicatedLinear
(
self
.
hidden_size
,
self
.
kv_lora_rank
+
self
.
qk_rope_head_dim
,
...
...
@@ -688,6 +698,8 @@ class DeepseekV2MLAAttention(nn.Module):
self
.
prefix
=
prefix
self
.
debug_layer_idx
=
int
(
self
.
prefix
.
split
(
"."
)[
-
2
])
# TODO wjl: 这里的forward拆了
def
forward
(
self
,
positions
:
torch
.
Tensor
,
...
...
@@ -703,13 +715,25 @@ class DeepseekV2MLAAttention(nn.Module):
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
],
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]]:
if
envs
.
USE_FUSED_RMS_QUANT
and
rms_weight
is
not
None
:
if
envs
.
VLLM_USE_FUSED_QA_KVA_GEMM
:
if
self
.
q_lora_rank
is
not
None
:
qc_kvc_kpe
,
new_residual
,
_bias
=
self
.
qa_kva_proj
(
hidden_states
,
rms_weight
=
rms_weight
,
residual
=
residual
,
update_hd
=
False
)
q_c
=
qc_kvc_kpe
[:,
:
self
.
q_lora_rank
]
kvc_kpe
=
qc_kvc_kpe
[:,
self
.
q_lora_rank
:]
q
,
_
,
_
=
self
.
q_b_proj
(
q_c
,
rms_weight
=
self
.
q_a_layernorm
.
weight
.
data
,
residual
=
None
,
update_hd
=
False
)
else
:
q
=
self
.
q_proj
(
hidden_states
)[
0
]
kv_c
,
k_pe
=
kvc_kpe
.
split
([
self
.
kv_lora_rank
,
self
.
qk_rope_head_dim
],
dim
=-
1
)
else
:
if
self
.
q_lora_rank
is
not
None
:
q_c
,
new_residual
,
_
,
input_quant_args
=
self
.
q_a_proj
(
hidden_states
,
rms_weight
=
rms_weight
,
residual
=
residual
,
update_hd
=
False
)
q
,
_
,
_
=
self
.
q_b_proj
(
q_c
,
rms_weight
=
self
.
q_a_layernorm
.
weight
.
data
,
residual
=
None
,
update_hd
=
False
)
else
:
q
=
self
.
q_proj
(
hidden_states
)[
0
]
kv_c
,
k_pe
=
self
.
kv_a_proj_with_mqa
(
hidden_states
,
quant_args
=
input_quant_args
,
update_hd
=
False
)[
0
].
split
(
kvc_kpe
=
self
.
kv_a_proj_with_mqa
(
hidden_states
,
quant_args
=
input_quant_args
,
update_hd
=
False
)[
0
]
kv_c
,
k_pe
=
kvc_kpe
.
split
(
[
self
.
kv_lora_rank
,
self
.
qk_rope_head_dim
],
dim
=-
1
)
if
not
envs
.
VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT
:
...
...
@@ -1375,6 +1399,12 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts):
(
"gate_up_proj"
,
"gate_proj"
,
0
),
(
"gate_up_proj"
,
"up_proj"
,
1
),
]
if
envs
.
USE_FUSED_RMS_QUANT
and
envs
.
VLLM_USE_FUSED_QA_KVA_GEMM
:
fused_params_mapping
=
[
(
"qa_kva_proj"
,
"q_a_proj"
,
0
),
(
"qa_kva_proj"
,
"kv_a_proj_with_mqa"
,
1
)
]
stacked_params_mapping
+=
fused_params_mapping
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
...
...
@@ -1407,6 +1437,7 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts):
# for mlp.experts[0].gate_gate_up_proj, which breaks load.
if
((
"mlp.experts."
in
name
)
and
name
not
in
params_dict
):
continue
old_weight_name
=
name
name
=
name
.
replace
(
weight_name
,
param_name
)
# Skip loading extra bias for GPTQ models.
...
...
@@ -1418,6 +1449,10 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts):
param
=
params_dict
[
name
]
weight_loader
=
param
.
weight_loader
if
envs
.
USE_FUSED_RMS_QUANT
and
envs
.
VLLM_USE_FUSED_QA_KVA_GEMM
and
((
"q_a_proj"
in
old_weight_name
)
or
(
"kv_a_proj_with_mqa"
in
old_weight_name
)):
weight_loader
(
param
,
loaded_weight
,
old_weight_name
)
else
:
weight_loader
(
param
,
loaded_weight
,
shard_id
)
break
else
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment