Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
a05d749e
Commit
a05d749e
authored
Apr 08, 2026
by
wujl5
Committed by
zhangzbb
Apr 08, 2026
Browse files
[BUGFIX] rms_quant融合功能适配DSA
parent
456e8c10
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
25 additions
and
12 deletions
+25
-12
vllm/model_executor/layers/linear.py
vllm/model_executor/layers/linear.py
+7
-3
vllm/model_executor/layers/mla.py
vllm/model_executor/layers/mla.py
+4
-3
vllm/model_executor/models/deepseek_v2.py
vllm/model_executor/models/deepseek_v2.py
+14
-6
No files found.
vllm/model_executor/layers/linear.py
View file @
a05d749e
...
...
@@ -271,7 +271,7 @@ class UnquantizedLinearMethod(LinearMethodBase):
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
bias
:
torch
.
Tensor
|
None
=
None
,
bias
:
torch
.
Tensor
|
None
=
None
,
**
_
)
->
torch
.
Tensor
:
if
self
.
use_llama_nn
:
# if os.environ['GEMM_PAD'] == '1' and gemm_bank_conf(layer.weight.shape[1] - 32):
...
...
@@ -458,11 +458,15 @@ class ReplicatedLinear(LinearBase):
def
forward
(
self
,
x
:
torch
.
Tensor
,
*
,
iqis
:
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
|
None
=
None
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
Parameter
|
None
]:
bias
=
self
.
bias
if
not
self
.
skip_bias_add
else
None
assert
self
.
quant_method
is
not
None
output
=
self
.
quant_method
.
apply
(
self
,
x
,
bias
)
if
envs
.
USE_FUSED_RMS_QUANT
and
iqis
is
not
None
and
iqis
[
0
]
is
not
None
:
output
=
self
.
quant_method
.
apply
(
self
,
x
,
bias
,
input_quant_args
=
iqis
)
else
:
output
=
self
.
quant_method
.
apply
(
self
,
x
,
bias
)
if
not
self
.
return_bias
:
return
output
...
...
vllm/model_executor/layers/mla.py
View file @
a05d749e
...
...
@@ -177,9 +177,10 @@ class MultiHeadLatentAttentionWrapper(PluggableLayer):
)
if
self
.
indexer
and
self
.
is_sparse
:
_topk_indices
=
self
.
indexer
(
hidden_states
,
q_c
,
positions
,
self
.
indexer_rope_emb
)
if
envs
.
USE_FUSED_RMS_QUANT
and
iqis
is
not
None
:
_topk_indices
=
self
.
indexer
(
hidden_states
,
q_c
,
positions
,
self
.
indexer_rope_emb
,
iqis
=
iqis
)
else
:
_topk_indices
=
self
.
indexer
(
hidden_states
,
q_c
,
positions
,
self
.
indexer_rope_emb
)
if
llama_4_scaling
is
not
None
:
q
*=
llama_4_scaling
...
...
vllm/model_executor/models/deepseek_v2.py
View file @
a05d749e
...
...
@@ -730,15 +730,18 @@ class Indexer(nn.Module):
)
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
qr
:
torch
.
Tensor
,
positions
,
rotary_emb
self
,
hidden_states
:
torch
.
Tensor
,
qr
:
torch
.
Tensor
,
positions
,
rotary_emb
,
*
,
iqis
:
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
|
None
=
None
)
->
torch
.
Tensor
:
q
,
_
=
self
.
wq_b
(
qr
)
q
=
q
.
view
(
-
1
,
self
.
n_head
,
self
.
head_dim
)
q_pe
,
q_nope
=
torch
.
split
(
q
,
[
self
.
rope_dim
,
self
.
head_dim
-
self
.
rope_dim
],
dim
=-
1
)
k
,
_
=
self
.
wk
(
hidden_states
)
if
envs
.
USE_FUSED_RMS_QUANT
and
self
.
wk
.
weight
.
dtype
==
torch
.
int8
and
iqis
is
not
None
:
k
,
_
=
self
.
wk
(
hidden_states
,
iqis
=
iqis
)
else
:
k
,
_
=
self
.
wk
(
hidden_states
)
k
=
self
.
k_norm
(
k
)
k_pe
,
k_nope
=
torch
.
split
(
k
,
[
self
.
rope_dim
,
self
.
head_dim
-
self
.
rope_dim
],
dim
=-
1
...
...
@@ -770,7 +773,10 @@ class Indexer(nn.Module):
else
:
q_fp8
=
q
weights
,
_
=
self
.
weights_proj
(
hidden_states
)
if
envs
.
USE_FUSED_RMS_QUANT
and
self
.
weights_proj
.
weight
.
dtype
==
torch
.
int8
and
iqis
is
not
None
:
weights
,
_
=
self
.
weights_proj
(
hidden_states
,
iqis
=
iqis
)
else
:
weights
,
_
=
self
.
weights_proj
(
hidden_states
)
if
not
current_platform
.
is_rocm
()
or
torch
.
cuda
.
get_device_properties
(
"cuda"
).
gcnArchName
.
split
(
':'
)[
0
]
==
"gfx938"
:
weights
=
(
weights
.
unsqueeze
(
-
1
)
*
q_scale
*
self
.
softmax_scale
*
self
.
n_head
**-
0.5
...
...
@@ -1073,19 +1079,21 @@ class DeepseekV2DecoderLayer(nn.Module):
# Fix residual FP16 overflow
residual_fix_overflow
=
False
assert
self
.
input_layernorm
.
has_weight
is
True
# DSA should set update_input True
_dsa_flag
=
hasattr
(
self
.
self_attn
,
"indexer"
)
and
self
.
self_attn
.
indexer
is
not
None
if
residual
is
None
:
residual
=
hidden_states
.
clone
()
i_q
,
i_s
,
_
=
self
.
input_layernorm
(
x
=
hidden_states
,
residual
=
None
,
quant_dtype
=
torch
.
int8
,
update_input
=
False
update_input
=
_dsa_flag
)
residual_fix_overflow
=
True
else
:
i_q
,
i_s
,
residual
=
self
.
input_layernorm
(
x
=
hidden_states
,
residual
=
residual
,
quant_dtype
=
torch
.
int8
,
update_input
=
False
update_input
=
_dsa_flag
)
attn_kwargs
=
{
"positions"
:
positions
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment