Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
89b62a25
"vllm/vscode:/vscode.git/clone" did not exist on "948dd3443bc6b8ffb76cbdddf3f4c5ae0b6637fa"
Commit
89b62a25
authored
Jan 07, 2026
by
wujl5
Browse files
DS量化模型重构atten和moe调用rmsquant融合逻辑。
parent
fb39e61b
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
84 additions
and
137 deletions
+84
-137
vllm/model_executor/layers/fused_moe/shared_fused_moe.py
vllm/model_executor/layers/fused_moe/shared_fused_moe.py
+1
-1
vllm/model_executor/layers/fused_moe/utils.py
vllm/model_executor/layers/fused_moe/utils.py
+1
-1
vllm/model_executor/layers/linear.py
vllm/model_executor/layers/linear.py
+21
-82
vllm/model_executor/models/deepseek_v2.py
vllm/model_executor/models/deepseek_v2.py
+61
-53
No files found.
vllm/model_executor/layers/fused_moe/shared_fused_moe.py
View file @
89b62a25
...
@@ -34,7 +34,7 @@ class SharedFusedMoE(FusedMoE):
...
@@ -34,7 +34,7 @@ class SharedFusedMoE(FusedMoE):
self
,
self
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
hidden_states_copy
:
Optional
[
torch
.
Tensor
]
=
None
hidden_states_copy
:
Optional
[
torch
.
Tensor
]
=
None
,
**
_
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
|
torch
.
Tensor
:
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
|
torch
.
Tensor
:
if
not
self
.
use_overlapped
:
if
not
self
.
use_overlapped
:
shared_out
=
self
.
_shared_experts
(
hidden_states
)
shared_out
=
self
.
_shared_experts
(
hidden_states
)
...
...
vllm/model_executor/layers/fused_moe/utils.py
View file @
89b62a25
...
@@ -897,7 +897,7 @@ class EPSharedExperts(nn.Module):
...
@@ -897,7 +897,7 @@ class EPSharedExperts(nn.Module):
"Only silu is supported for now."
)
"Only silu is supported for now."
)
self
.
act_fn
=
SiluAndMul
()
self
.
act_fn
=
SiluAndMul
()
def
forward
(
self
,
x
):
def
forward
(
self
,
x
,
**
_
):
gate_up
,
_
=
self
.
gate_up_proj
(
x
)
gate_up
,
_
=
self
.
gate_up_proj
(
x
)
x
=
self
.
act_fn
(
gate_up
)
x
=
self
.
act_fn
(
gate_up
)
x
,
_
=
self
.
down_proj
(
x
)
x
,
_
=
self
.
down_proj
(
x
)
...
...
vllm/model_executor/layers/linear.py
View file @
89b62a25
...
@@ -331,7 +331,6 @@ class ReplicatedLinear(LinearBase):
...
@@ -331,7 +331,6 @@ class ReplicatedLinear(LinearBase):
skip_bias_add
:
bool
=
False
,
skip_bias_add
:
bool
=
False
,
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
eps
:
Optional
[
float
]
=
1e-6
,
prefix
:
str
=
""
,
prefix
:
str
=
""
,
*
,
*
,
return_bias
:
bool
=
True
,
return_bias
:
bool
=
True
,
...
@@ -343,7 +342,6 @@ class ReplicatedLinear(LinearBase):
...
@@ -343,7 +342,6 @@ class ReplicatedLinear(LinearBase):
quant_config
,
quant_config
,
prefix
=
prefix
,
prefix
=
prefix
,
return_bias
=
return_bias
)
return_bias
=
return_bias
)
self
.
eps
=
eps
# All the linear layer supports quant method.
# All the linear layer supports quant method.
assert
self
.
quant_method
is
not
None
assert
self
.
quant_method
is
not
None
...
@@ -393,45 +391,19 @@ class ReplicatedLinear(LinearBase):
...
@@ -393,45 +391,19 @@ class ReplicatedLinear(LinearBase):
def
forward
(
def
forward
(
self
,
self
,
input_
:
torch
.
Tensor
,
input_
:
torch
.
Tensor
,
rms_weight
:
Optional
[
torch
.
Tensor
]
=
None
,
iqis
:
Optional
[
tuple
]
=
None
,
**
_
residual
:
Optional
[
torch
.
Tensor
]
=
None
,
quant_args
:
Optional
[
list
]
=
None
,
update_hd
:
Optional
[
bool
]
=
True
)
->
Union
[
torch
.
Tensor
,
tuple
[
torch
.
Tensor
,
Optional
[
Parameter
]],
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
Optional
[
Parameter
],
list
[
torch
.
Tensor
]]]:
if
envs
.
USE_FUSED_RMS_QUANT
and
(
rms_weight
is
not
None
or
quant_args
is
not
None
):
if
quant_args
is
not
None
:
input_quant_args
=
quant_args
)
->
Union
[
torch
.
Tensor
,
tuple
[
torch
.
Tensor
,
Optional
[
Parameter
]]]:
if
envs
.
USE_FUSED_RMS_QUANT
and
iqis
is
not
None
:
bias
=
self
.
bias
if
not
self
.
skip_bias_add
else
None
bias
=
self
.
bias
if
not
self
.
skip_bias_add
else
None
assert
self
.
quant_method
is
not
None
assert
self
.
quant_method
is
not
None
output
=
self
.
quant_method
.
apply
(
self
,
input_
,
bias
,
input_quant_args
)
output
=
self
.
quant_method
.
apply
(
self
,
input_
,
bias
,
input_quant_args
=
iqis
)
output_bias
=
self
.
bias
if
self
.
skip_bias_add
else
None
output_bias
=
self
.
bias
if
self
.
skip_bias_add
else
None
if
not
self
.
return_bias
:
if
not
self
.
return_bias
:
return
output
return
output
return
output
,
output_bias
return
output
,
output_bias
else
:
i_q
,
_scales
=
lm_faster_rmsquant
(
input
=
input_
,
rms_weight
=
rms_weight
,
epsilon
=
self
.
eps
,
quant_dtype
=
torch
.
int8
,
residual
=
residual
,
update_input
=
update_hd
)
new_residual
=
residual
input_quant_args
=
[
i_q
,
_scales
]
bias
=
self
.
bias
if
not
self
.
skip_bias_add
else
None
assert
self
.
quant_method
is
not
None
output
=
self
.
quant_method
.
apply
(
self
,
input_
,
bias
,
input_quant_args
)
output_bias
=
self
.
bias
if
self
.
skip_bias_add
else
None
if
not
self
.
return_bias
:
return
output
return
output
,
new_residual
,
output_bias
,
input_quant_args
else
:
else
:
bias
=
self
.
bias
if
not
self
.
skip_bias_add
else
None
bias
=
self
.
bias
if
not
self
.
skip_bias_add
else
None
assert
self
.
quant_method
is
not
None
assert
self
.
quant_method
is
not
None
...
@@ -459,7 +431,6 @@ class FusedQuantedReplicatedLinear(LinearBase):
...
@@ -459,7 +431,6 @@ class FusedQuantedReplicatedLinear(LinearBase):
skip_bias_add
:
bool
=
False
,
skip_bias_add
:
bool
=
False
,
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
eps
:
Optional
[
float
]
=
1e-6
,
prefix
:
str
=
""
,
prefix
:
str
=
""
,
*
,
*
,
return_bias
:
bool
=
True
,
return_bias
:
bool
=
True
,
...
@@ -473,7 +444,6 @@ class FusedQuantedReplicatedLinear(LinearBase):
...
@@ -473,7 +444,6 @@ class FusedQuantedReplicatedLinear(LinearBase):
quant_config
,
quant_config
,
prefix
=
prefix
,
prefix
=
prefix
,
return_bias
=
return_bias
)
return_bias
=
return_bias
)
self
.
eps
=
eps
self
.
q_lora_rank
=
q_lora_rank
self
.
q_lora_rank
=
q_lora_rank
self
.
kv_lora_rank
=
kv_lora_rank
self
.
kv_lora_rank
=
kv_lora_rank
self
.
qk_rope_head_dim
=
qk_rope_head_dim
self
.
qk_rope_head_dim
=
qk_rope_head_dim
...
@@ -588,7 +558,6 @@ class FusedQuantedReplicatedLinear(LinearBase):
...
@@ -588,7 +558,6 @@ class FusedQuantedReplicatedLinear(LinearBase):
assert
len
(
self
.
kv_a_weight
.
shape
)
==
2
assert
len
(
self
.
kv_a_weight
.
shape
)
==
2
fused_weight
=
torch
.
cat
([
self
.
q_a_weight
,
self
.
kv_a_weight
],
dim
=
0
)
# TN
fused_weight
=
torch
.
cat
([
self
.
q_a_weight
,
self
.
kv_a_weight
],
dim
=
0
)
# TN
param
.
data
.
copy_
(
fused_weight
)
param
.
data
.
copy_
(
fused_weight
)
#TODO: wjl 删掉无用的显存tensor
else
:
else
:
raise
ValueError
(
f
"Unexpected weight:
{
source
}
"
)
raise
ValueError
(
f
"Unexpected weight:
{
source
}
"
)
...
@@ -596,31 +565,17 @@ class FusedQuantedReplicatedLinear(LinearBase):
...
@@ -596,31 +565,17 @@ class FusedQuantedReplicatedLinear(LinearBase):
def
forward
(
def
forward
(
self
,
self
,
input_
:
torch
.
Tensor
,
input_
:
torch
.
Tensor
,
rms_weight
:
Optional
[
torch
.
Tensor
]
=
None
,
iqis
:
Optional
[
tuple
]
=
None
,
**
_
residual
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
tuple
[
torch
.
Tensor
,
Optional
[
Parameter
]]:
update_hd
:
Optional
[
bool
]
=
True
if
envs
.
USE_FUSED_RMS_QUANT
and
iqis
is
not
None
:
)
->
Union
[
torch
.
Tensor
,
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
Optional
[
Parameter
],
list
[
torch
.
Tensor
]]]:
if
envs
.
USE_FUSED_RMS_QUANT
and
rms_weight
is
not
None
:
i_q
,
_scales
=
lm_faster_rmsquant
(
input
=
input_
,
rms_weight
=
rms_weight
,
epsilon
=
self
.
eps
,
quant_dtype
=
torch
.
int8
,
residual
=
residual
,
update_input
=
update_hd
)
new_residual
=
residual
input_quant_args
=
[
i_q
,
_scales
]
bias
=
self
.
bias
if
not
self
.
skip_bias_add
else
None
bias
=
self
.
bias
if
not
self
.
skip_bias_add
else
None
assert
self
.
quant_method
is
not
None
assert
self
.
quant_method
is
not
None
output
=
self
.
quant_method
.
apply
(
self
,
input_
,
bias
,
input_quant_args
)
output
=
self
.
quant_method
.
apply
(
self
,
input_
,
bias
,
input_quant_args
=
iqis
)
output_bias
=
self
.
bias
if
self
.
skip_bias_add
else
None
output_bias
=
self
.
bias
if
self
.
skip_bias_add
else
None
assert
self
.
return_bias
is
True
assert
self
.
return_bias
is
True
if
not
self
.
return_bias
:
if
not
self
.
return_bias
:
raise
RuntimeError
(
"Not return bias. Unexpected Error."
)
raise
RuntimeError
(
"Not return bias. Unexpected Error."
)
return
output
,
new_residual
,
output_bias
return
output
,
output_bias
else
:
else
:
raise
RuntimeError
(
"Unexpected Error."
)
raise
RuntimeError
(
"Unexpected Error."
)
...
@@ -858,31 +813,17 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
...
@@ -858,31 +813,17 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
def
forward
(
def
forward
(
self
,
input_
,
self
,
input_
,
rms_weight
:
Optional
[
torch
.
Tensor
]
=
None
,
xqxs
:
Optional
[
tuple
]
=
None
,
residual
:
Optional
[
torch
.
Tensor
]
=
None
,
iqis
:
Optional
[
tuple
]
=
None
,
**
_
update_hd
:
Optional
[
bool
]
=
True
,
xqxs
:
Optional
[
tuple
]
=
None
)
->
Union
[
torch
.
Tensor
,
)
->
Union
[
torch
.
Tensor
,
tuple
[
torch
.
Tensor
,
Optional
[
Parameter
]],
tuple
[
torch
.
Tensor
,
Optional
[
Parameter
]],
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
Optional
[
Parameter
]],
]:
]:
if
envs
.
USE_FUSED_RMS_QUANT
and
rms_weight
is
not
None
:
if
envs
.
USE_FUSED_RMS_QUANT
and
iqis
is
not
None
:
input_quant_args
=
None
assert
residual
is
not
None
and
rms_weight
is
not
None
i_q
,
_scales
=
lm_faster_rmsquant
(
input
=
input_
,
rms_weight
=
rms_weight
,
epsilon
=
self
.
eps
,
quant_dtype
=
torch
.
int8
,
residual
=
residual
,
update_input
=
update_hd
)
new_residual
=
residual
input_quant_args
=
[
i_q
,
_scales
]
bias
=
self
.
bias
if
not
self
.
skip_bias_add
else
None
bias
=
self
.
bias
if
not
self
.
skip_bias_add
else
None
assert
self
.
quant_method
is
not
None
assert
self
.
quant_method
is
not
None
output_parallel
=
self
.
quant_method
.
apply
(
self
,
input_
,
bias
,
input_quant_args
)
output_parallel
=
self
.
quant_method
.
apply
(
self
,
input_
,
bias
,
input_quant_args
=
iqis
)
if
self
.
gather_output
:
if
self
.
gather_output
:
# All-gather across the partitions.
# All-gather across the partitions.
...
@@ -892,7 +833,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
...
@@ -892,7 +833,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
output_bias
=
self
.
bias
if
self
.
skip_bias_add
else
None
output_bias
=
self
.
bias
if
self
.
skip_bias_add
else
None
if
not
self
.
return_bias
:
if
not
self
.
return_bias
:
return
output
return
output
return
output
,
new_residual
,
i_q
,
_scales
,
output_bias
return
output
,
output_bias
elif
envs
.
USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT
and
xqxs
is
not
None
:
elif
envs
.
USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT
and
xqxs
is
not
None
:
bias
=
self
.
bias
if
not
self
.
skip_bias_add
else
None
bias
=
self
.
bias
if
not
self
.
skip_bias_add
else
None
...
@@ -933,13 +874,11 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
...
@@ -933,13 +874,11 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
skip_bias_add
:
bool
=
False
,
skip_bias_add
:
bool
=
False
,
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
eps
:
Optional
[
float
]
=
1e-6
,
prefix
:
str
=
""
,
prefix
:
str
=
""
,
*
,
*
,
return_bias
:
bool
=
True
,
return_bias
:
bool
=
True
,
expect_tp_size
:
Optional
[
int
]
=
None
,
expect_tp_size
:
Optional
[
int
]
=
None
,
):
):
self
.
eps
=
eps
self
.
output_sizes
=
output_sizes
self
.
output_sizes
=
output_sizes
tp_size
=
get_tensor_model_parallel_world_size
()
tp_size
=
get_tensor_model_parallel_world_size
()
...
...
vllm/model_executor/models/deepseek_v2.py
View file @
89b62a25
...
@@ -70,6 +70,8 @@ from .utils import (PPMissingLayer, is_pp_missing_parameter,
...
@@ -70,6 +70,8 @@ from .utils import (PPMissingLayer, is_pp_missing_parameter,
maybe_prefix
)
maybe_prefix
)
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.utils
import
W8a8GetCacheJSON
from
vllm.utils
import
W8a8GetCacheJSON
from
lmslim.quantize.quant_ops
import
lm_faster_rmsquant
class
DeepseekV2MLP
(
nn
.
Module
):
class
DeepseekV2MLP
(
nn
.
Module
):
...
@@ -100,21 +102,18 @@ class DeepseekV2MLP(nn.Module):
...
@@ -100,21 +102,18 @@ class DeepseekV2MLP(nn.Module):
self
.
act_fn
=
SiluAndMul
()
self
.
act_fn
=
SiluAndMul
()
def
forward
(
self
,
x
,
def
forward
(
self
,
x
,
rms_weight
:
Optional
[
torch
.
Tensor
]
=
None
,
xqxs
:
Optional
[
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]]
=
None
,
residual
:
Optional
[
torch
.
Tensor
]
=
None
,
iqis
:
Optional
[
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]]
=
None
update_hd
:
Optional
[
bool
]
=
False
,
)
->
torch
.
Tensor
:
xqxs
:
Optional
[
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]]
=
None
)
->
Union
[
torch
.
Tensor
,
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]]:
if
envs
.
USE_FUSED_RMS_QUANT
:
if
envs
.
USE_FUSED_RMS_QUANT
:
gate_up
,
new_resi
,
i_q
,
_scales
,
_
=
self
.
gate_up_proj
(
x
,
rms_weight
,
residual
,
update_hd
=
update_hd
)
assert
iqis
is
not
None
gate_up
,
_
=
self
.
gate_up_proj
(
x
,
iqis
=
iqis
)
if
envs
.
USE_FUSED_SILU_MUL_QUANT
:
if
envs
.
USE_FUSED_SILU_MUL_QUANT
:
x
,
_
=
self
.
down_proj
(
gate_up
,
use_fused_silu_mul_quant
=
True
)
x
,
_
=
self
.
down_proj
(
gate_up
,
use_fused_silu_mul_quant
=
True
)
else
:
else
:
x
=
self
.
act_fn
(
gate_up
)
x
=
self
.
act_fn
(
gate_up
)
x
,
_
=
self
.
down_proj
(
x
)
x
,
_
=
self
.
down_proj
(
x
)
return
x
return
x
,
new_resi
,
i_q
,
_scales
elif
envs
.
USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT
and
xqxs
is
not
None
:
elif
envs
.
USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT
and
xqxs
is
not
None
:
gate_up
,
_
=
self
.
gate_up_proj
(
x
,
xqxs
=
xqxs
)
gate_up
,
_
=
self
.
gate_up_proj
(
x
,
xqxs
=
xqxs
)
if
envs
.
USE_FUSED_SILU_MUL_QUANT
:
if
envs
.
USE_FUSED_SILU_MUL_QUANT
:
...
@@ -279,9 +278,9 @@ class DeepseekV2MoE(nn.Module):
...
@@ -279,9 +278,9 @@ class DeepseekV2MoE(nn.Module):
self
.
tbo_all_reduce
=
tbo_all_reduce
self
.
tbo_all_reduce
=
tbo_all_reduce
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
rms_weight
:
Optional
[
torch
.
Tensor
]
=
None
,
residual
:
Optional
[
torch
.
Tensor
]
=
None
,
residual
:
Optional
[
torch
.
Tensor
]
=
None
,
xqxs
:
Optional
[
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]]
=
None
xqxs
:
Optional
[
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]]
=
None
,
iqis
:
Optional
[
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]]
=
None
)
->
Union
[
torch
.
Tensor
,
)
->
Union
[
torch
.
Tensor
,
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]]:
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]]:
num_tokens
,
hidden_dim
=
hidden_states
.
shape
num_tokens
,
hidden_dim
=
hidden_states
.
shape
...
@@ -338,12 +337,12 @@ class DeepseekV2MoE(nn.Module):
...
@@ -338,12 +337,12 @@ class DeepseekV2MoE(nn.Module):
# See DeepseekV2DecoderLayer for more details.
# See DeepseekV2DecoderLayer for more details.
final_hidden_states
=
final_hidden_states
+
shared_output
\
final_hidden_states
=
final_hidden_states
+
shared_output
\
*
(
1.
/
self
.
routed_scaling_factor
)
*
(
1.
/
self
.
routed_scaling_factor
)
else
:
else
:
# RQ
if
not
self
.
enable_expert_parallel
:
if
not
self
.
enable_expert_parallel
:
i_q
,
i_s
=
None
,
None
i_q
,
i_s
=
None
,
None
if
self
.
run_shared_expert_singlely
:
if
self
.
run_shared_expert_singlely
:
if
envs
.
USE_FUSED_RMS_QUANT
:
if
envs
.
USE_FUSED_RMS_QUANT
:
shared_output
,
new_resi
,
i_q
,
i_s
=
self
.
shared_experts
(
hidden_states
,
rms_weight
,
residual
,
update_hd
=
True
)
shared_output
=
self
.
shared_experts
(
hidden_states
,
iqis
=
iqis
)
else
:
else
:
shared_output
=
self
.
shared_experts
(
hidden_states
)
shared_output
=
self
.
shared_experts
(
hidden_states
)
...
@@ -378,7 +377,8 @@ class DeepseekV2MoE(nn.Module):
...
@@ -378,7 +377,8 @@ class DeepseekV2MoE(nn.Module):
# See DeepseekV2DecoderLayer for more details.
# See DeepseekV2DecoderLayer for more details.
# fp16 mode not fused quant
# fp16 mode not fused quant
final_hidden_states
=
self
.
experts
(
hidden_states
=
hidden_states
,
final_hidden_states
=
self
.
experts
(
hidden_states
=
hidden_states
,
router_logits
=
router_logits
)
router_logits
=
router_logits
,
i_q
=
iqis
[
0
],
i_s
=
iqis
[
1
])
if
shared_output
is
not
None
:
if
shared_output
is
not
None
:
if
hidden_states
.
dtype
!=
torch
.
float16
:
if
hidden_states
.
dtype
!=
torch
.
float16
:
...
@@ -388,7 +388,7 @@ class DeepseekV2MoE(nn.Module):
...
@@ -388,7 +388,7 @@ class DeepseekV2MoE(nn.Module):
# See DeepseekV2DecoderLayer for more details.
# See DeepseekV2DecoderLayer for more details.
final_hidden_states
=
final_hidden_states
+
shared_output
\
final_hidden_states
=
final_hidden_states
+
shared_output
\
*
(
1.
/
self
.
routed_scaling_factor
)
*
(
1.
/
self
.
routed_scaling_factor
)
else
:
else
:
# EP
router_logits
,
_
=
self
.
gate
(
hidden_states
)
router_logits
,
_
=
self
.
gate
(
hidden_states
)
if
self
.
use_deepep
:
if
self
.
use_deepep
:
shared_output
,
final_hidden_states
=
self
.
experts
(
hidden_states
=
hidden_states
,
shared_output
,
final_hidden_states
=
self
.
experts
(
hidden_states
=
hidden_states
,
...
@@ -405,7 +405,7 @@ class DeepseekV2MoE(nn.Module):
...
@@ -405,7 +405,7 @@ class DeepseekV2MoE(nn.Module):
else
:
else
:
if
self
.
run_shared_expert_singlely
:
if
self
.
run_shared_expert_singlely
:
if
envs
.
USE_FUSED_RMS_QUANT
:
if
envs
.
USE_FUSED_RMS_QUANT
:
shared_output
,
new_resi
=
self
.
shared_experts
(
hidden_states
,
rms_weight
,
residual
,
update_hd
=
True
)
shared_output
=
self
.
shared_experts
(
hidden_states
,
iqis
=
iqis
)
else
:
else
:
shared_output
=
self
.
shared_experts
(
hidden_states
)
shared_output
=
self
.
shared_experts
(
hidden_states
)
...
@@ -420,9 +420,9 @@ class DeepseekV2MoE(nn.Module):
...
@@ -420,9 +420,9 @@ class DeepseekV2MoE(nn.Module):
assert
shared_output
is
not
None
assert
shared_output
is
not
None
final_hidden_states
+=
(
shared_output
*
(
1.
/
self
.
routed_scaling_factor
))
final_hidden_states
+=
(
shared_output
*
(
1.
/
self
.
routed_scaling_factor
))
else
:
else
:
final_hidden_states
=
self
.
experts
(
final_hidden_states
=
self
.
experts
(
hidden_states
=
hidden_states
,
hidden_states
=
hidden_state
s
,
router_logits
=
router_logit
s
,
router_logits
=
router_logits
)
i_q
=
iqis
[
0
],
i_s
=
iqis
[
1
]
)
if
shared_output
is
not
None
:
if
shared_output
is
not
None
:
if
hidden_states
.
dtype
!=
torch
.
float16
:
if
hidden_states
.
dtype
!=
torch
.
float16
:
...
@@ -441,9 +441,6 @@ class DeepseekV2MoE(nn.Module):
...
@@ -441,9 +441,6 @@ class DeepseekV2MoE(nn.Module):
self
.
experts
.
maybe_all_reduce_tensor_model_parallel
(
self
.
experts
.
maybe_all_reduce_tensor_model_parallel
(
final_hidden_states
))
final_hidden_states
))
if
envs
.
USE_FUSED_RMS_QUANT
:
return
final_hidden_states
.
view
(
num_tokens
,
hidden_dim
),
new_resi
,
i_q
,
i_s
else
:
return
final_hidden_states
.
view
(
num_tokens
,
hidden_dim
)
return
final_hidden_states
.
view
(
num_tokens
,
hidden_dim
)
...
@@ -662,7 +659,6 @@ class DeepseekV2MLAAttention(nn.Module):
...
@@ -662,7 +659,6 @@ class DeepseekV2MLAAttention(nn.Module):
self
.
q_lora_rank
,
self
.
q_lora_rank
,
bias
=
False
,
bias
=
False
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
eps
=
config
.
rms_norm_eps
,
prefix
=
f
"
{
prefix
}
.q_a_proj"
)
prefix
=
f
"
{
prefix
}
.q_a_proj"
)
self
.
q_b_proj
=
ColumnParallelLinear
(
q_lora_rank
,
self
.
q_b_proj
=
ColumnParallelLinear
(
q_lora_rank
,
self
.
num_heads
*
self
.
num_heads
*
...
@@ -764,20 +760,22 @@ class DeepseekV2MLAAttention(nn.Module):
...
@@ -764,20 +760,22 @@ class DeepseekV2MLAAttention(nn.Module):
self
,
self
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
rms_weight
:
Optional
[
torch
.
Tensor
]
=
None
,
residual
:
Optional
[
torch
.
Tensor
]
=
None
,
pa_rms_weight
:
Optional
[
torch
.
Tensor
]
=
None
,
pa_rms_weight
:
Optional
[
torch
.
Tensor
]
=
None
,
pa_residual
:
Optional
[
torch
.
Tensor
]
=
None
,
pa_residual
:
Optional
[
torch
.
Tensor
]
=
None
,
pa_rms_eps
:
Optional
[
float
]
=
1e-6
,
pa_rms_eps
:
Optional
[
float
]
=
1e-6
,
pa_quant_dtype
:
Optional
[
torch
.
dtype
]
=
torch
.
int8
,
pa_quant_dtype
:
Optional
[
torch
.
dtype
]
=
torch
.
int8
,
update_input
:
Optional
[
bool
]
=
True
update_input
:
Optional
[
bool
]
=
True
,
iqis
:
Optional
[
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]]
=
None
)
->
Union
[
torch
.
Tensor
,
)
->
Union
[
torch
.
Tensor
,
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
],
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]]:
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]]:
if
envs
.
USE_FUSED_RMS_QUANT
and
rms_weight
is
not
None
:
if
envs
.
USE_FUSED_RMS_QUANT
and
iqis
is
not
None
:
if
envs
.
VLLM_USE_FUSED_QA_KVA_GEMM
:
if
envs
.
VLLM_USE_FUSED_QA_KVA_GEMM
:
if
self
.
q_lora_rank
is
not
None
:
if
self
.
q_lora_rank
is
not
None
:
qc_kvc_kpe
,
new_residual
,
_bias
=
self
.
qa_kva_proj
(
hidden_states
,
rms_weight
=
rms_weight
,
residual
=
residual
,
update_hd
=
False
)
# rms_weight=rms_weight, residual=residual, update_hd=False
qc_kvc_kpe
,
_bias
=
self
.
qa_kva_proj
(
hidden_states
,
iqis
)
q_c
=
qc_kvc_kpe
[:,
:
self
.
q_lora_rank
]
q_c
=
qc_kvc_kpe
[:,
:
self
.
q_lora_rank
]
kvc_kpe
=
qc_kvc_kpe
[:,
self
.
q_lora_rank
:]
kvc_kpe
=
qc_kvc_kpe
[:,
self
.
q_lora_rank
:]
q
,
_
,
_
=
self
.
q_b_proj
(
q_c
,
rms_weight
=
self
.
q_a_layernorm
.
weight
.
data
,
residual
=
None
,
update_hd
=
False
)
q
,
_
,
_
=
self
.
q_b_proj
(
q_c
,
rms_weight
=
self
.
q_a_layernorm
.
weight
.
data
,
residual
=
None
,
update_hd
=
False
)
...
@@ -787,12 +785,12 @@ class DeepseekV2MLAAttention(nn.Module):
...
@@ -787,12 +785,12 @@ class DeepseekV2MLAAttention(nn.Module):
kv_c
,
k_pe
=
kvc_kpe
.
split
([
self
.
kv_lora_rank
,
self
.
qk_rope_head_dim
],
dim
=-
1
)
kv_c
,
k_pe
=
kvc_kpe
.
split
([
self
.
kv_lora_rank
,
self
.
qk_rope_head_dim
],
dim
=-
1
)
else
:
else
:
if
self
.
q_lora_rank
is
not
None
:
if
self
.
q_lora_rank
is
not
None
:
q_c
,
new_residual
,
_
,
input_quant_args
=
self
.
q_a_proj
(
hidden_states
,
rms_weight
=
rms_weight
,
residual
=
residual
,
update_hd
=
False
)
q_c
,
_
=
self
.
q_a_proj
(
hidden_states
,
iqis
=
iqis
)
q
,
_
,
_
=
self
.
q_b_proj
(
q_c
,
rms_weight
=
self
.
q_a_layernorm
.
weight
.
data
,
residual
=
None
,
update_hd
=
False
)
q
,
_
,
_
=
self
.
q_b_proj
(
q_c
,
rms_weight
=
self
.
q_a_layernorm
.
weight
.
data
,
residual
=
None
,
update_hd
=
False
)
else
:
else
:
q
=
self
.
q_proj
(
hidden_states
)[
0
]
q
=
self
.
q_proj
(
hidden_states
)[
0
]
kvc_kpe
=
self
.
kv_a_proj_with_mqa
(
hidden_states
,
quant_args
=
input_quant_args
,
update_hd
=
False
)[
0
]
kvc_kpe
=
self
.
kv_a_proj_with_mqa
(
hidden_states
,
iqis
=
iqis
)[
0
]
kv_c
,
k_pe
=
kvc_kpe
.
split
(
kv_c
,
k_pe
=
kvc_kpe
.
split
(
[
self
.
kv_lora_rank
,
self
.
qk_rope_head_dim
],
dim
=-
1
)
[
self
.
kv_lora_rank
,
self
.
qk_rope_head_dim
],
dim
=-
1
)
...
@@ -835,7 +833,7 @@ class DeepseekV2MLAAttention(nn.Module):
...
@@ -835,7 +833,7 @@ class DeepseekV2MLAAttention(nn.Module):
positions
=
positions
,
positions
=
positions
,
weight
=
weight
,
weight
=
weight
,
cos_sin_cache
=
cos_sin_cache
)
cos_sin_cache
=
cos_sin_cache
)
return
self
.
o_proj
(
attn_out
)[
0
]
,
new_residual
return
self
.
o_proj
(
attn_out
)[
0
]
elif
envs
.
USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT
and
pa_rms_weight
is
not
None
and
pa_residual
is
not
None
:
elif
envs
.
USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT
and
pa_rms_weight
is
not
None
and
pa_residual
is
not
None
:
if
self
.
q_lora_rank
is
not
None
:
if
self
.
q_lora_rank
is
not
None
:
q_c
=
self
.
q_a_proj
(
hidden_states
)[
0
]
q_c
=
self
.
q_a_proj
(
hidden_states
)[
0
]
...
@@ -1035,10 +1033,11 @@ class DeepseekV2DecoderLayer(nn.Module):
...
@@ -1035,10 +1033,11 @@ class DeepseekV2DecoderLayer(nn.Module):
self
.
routed_scaling_factor
=
config
.
routed_scaling_factor
self
.
routed_scaling_factor
=
config
.
routed_scaling_factor
self
.
use_fused_rms_quant
=
envs
.
USE_FUSED_RMS_QUANT
self
.
use_fused_rms_quant
=
envs
.
USE_FUSED_RMS_QUANT
self
.
use_fused_custom_all_reduce
=
envs
.
USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT
self
.
use_fused_custom_all_reduce
=
envs
.
USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT
self
.
_eps
=
config
.
rms_norm_eps
def
forward_fused_
rmsquant
(
def
forward_fused_
RQ
(
self
,
self
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
...
@@ -1050,21 +1049,24 @@ class DeepseekV2DecoderLayer(nn.Module):
...
@@ -1050,21 +1049,24 @@ class DeepseekV2DecoderLayer(nn.Module):
assert
self
.
input_layernorm
.
has_weight
is
True
assert
self
.
input_layernorm
.
has_weight
is
True
if
residual
is
None
:
if
residual
is
None
:
residual
=
hidden_states
residual
=
hidden_states
hidden_states
,
_
=
self
.
self_attn
(
positions
=
positions
,
hidden_states
=
hidden_states
,
rms_weight
=
self
.
input_layernorm
.
weight
.
data
,
residual
=
None
)
residual_fix_overflow
=
True
residual_fix_overflow
=
True
i_q
,
i_s
=
lm_faster_rmsquant
(
input
=
hidden_states
,
rms_weight
=
self
.
input_layernorm
.
weight
.
data
,
epsilon
=
self
.
_eps
,
quant_dtype
=
torch
.
int8
,
residual
=
None
,
update_input
=
False
)
else
:
else
:
hidden_states
,
new_residual
=
self
.
self_attn
(
i_q
,
i_s
=
lm_faster_rmsquant
(
input
=
hidden_states
,
positions
=
positions
,
hidden_states
=
hidden_states
,
rms_weight
=
self
.
input_layernorm
.
weight
.
data
,
rms_weight
=
self
.
input_layernorm
.
weight
.
data
,
residual
=
residual
epsilon
=
self
.
_eps
,
)
quant_dtype
=
torch
.
int8
,
residual
=
new_residual
residual
=
residual
,
update_input
=
False
)
hidden_states
=
self
.
self_attn
(
positions
=
positions
,
hidden_states
=
hidden_states
,
# get attr
iqis
=
(
i_q
,
i_s
))
if
hidden_states
.
dtype
==
torch
.
float16
:
if
hidden_states
.
dtype
==
torch
.
float16
:
# rmsnorm, and rmsnorm result would not affect by scale.
# rmsnorm, and rmsnorm result would not affect by scale.
...
@@ -1074,10 +1076,16 @@ class DeepseekV2DecoderLayer(nn.Module):
...
@@ -1074,10 +1076,16 @@ class DeepseekV2DecoderLayer(nn.Module):
# first layer.
# first layer.
residual
*=
1.
/
self
.
routed_scaling_factor
residual
*=
1.
/
self
.
routed_scaling_factor
hidden_states
,
new_resi
,
_i_q
,
_scales
=
self
.
mlp
(
hidden_states
,
update_hs
=
True
if
isinstance
(
self
.
mlp
,
DeepseekV2MoE
)
else
False
_i_q
,
_i_s
=
lm_faster_rmsquant
(
input
=
hidden_states
,
rms_weight
=
self
.
post_attention_layernorm
.
weight
.
data
,
rms_weight
=
self
.
post_attention_layernorm
.
weight
.
data
,
epsilon
=
self
.
_eps
,
quant_dtype
=
torch
.
int8
,
residual
=
residual
,
residual
=
residual
,
)
update_input
=
update_hs
)
new_resi
=
residual
hidden_states
=
self
.
mlp
(
hidden_states
,
iqis
=
(
_i_q
,
_i_s
))
if
isinstance
(
self
.
mlp
,
if
isinstance
(
self
.
mlp
,
DeepseekV2MLP
)
and
hidden_states
.
dtype
==
torch
.
float16
:
DeepseekV2MLP
)
and
hidden_states
.
dtype
==
torch
.
float16
:
...
@@ -1211,7 +1219,7 @@ class DeepseekV2DecoderLayer(nn.Module):
...
@@ -1211,7 +1219,7 @@ class DeepseekV2DecoderLayer(nn.Module):
def
choose_forward
(
self
):
def
choose_forward
(
self
):
if
self
.
use_fused_rms_quant
:
if
self
.
use_fused_rms_quant
:
return
self
.
forward_fused_
rmsquant
return
self
.
forward_fused_
RQ
elif
self
.
use_fused_custom_all_reduce
:
elif
self
.
use_fused_custom_all_reduce
:
return
self
.
forward_fused_CRQ
return
self
.
forward_fused_CRQ
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment