Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
fc443d52
Commit
fc443d52
authored
Sep 14, 2025
by
wujl5
Committed by
zhuwenwen
Sep 14, 2025
Browse files
deepseek-r1-w4a8使用rmsquant融合算子及横向融合
parent
0627b53a
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
335 additions
and
94 deletions
+335
-94
vllm/envs.py
vllm/envs.py
+5
-0
vllm/model_executor/layers/linear.py
vllm/model_executor/layers/linear.py
+146
-23
vllm/model_executor/layers/quantization/slimquant_w4a8.py
vllm/model_executor/layers/quantization/slimquant_w4a8.py
+7
-1
vllm/model_executor/models/deepseek_v2.py
vllm/model_executor/models/deepseek_v2.py
+177
-70
No files found.
vllm/envs.py
View file @
fc443d52
...
@@ -166,6 +166,7 @@ if TYPE_CHECKING:
...
@@ -166,6 +166,7 @@ if TYPE_CHECKING:
VLLM_USE_GLOBAL_CACHE13
:
bool
=
False
VLLM_USE_GLOBAL_CACHE13
:
bool
=
False
VLLM_USE_LIGHT_OP
:
bool
=
False
VLLM_USE_LIGHT_OP
:
bool
=
False
VLLM_USE_TRITON_CAT
:
bool
=
False
VLLM_USE_TRITON_CAT
:
bool
=
False
USE_FUSED_RMS_QUANT
:
bool
=
False
VLLM_USE_MERGE_ATTN_STATES_OPT
:
bool
=
False
VLLM_USE_MERGE_ATTN_STATES_OPT
:
bool
=
False
def
get_default_cache_root
():
def
get_default_cache_root
():
...
@@ -1104,6 +1105,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
...
@@ -1104,6 +1105,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_MERGE_ATTN_STATES_OPT"
:
"VLLM_USE_MERGE_ATTN_STATES_OPT"
:
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_MERGE_ATTN_STATES_OPT"
,
"True"
).
lower
()
in
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_MERGE_ATTN_STATES_OPT"
,
"True"
).
lower
()
in
(
"true"
,
"1"
)),
(
"true"
,
"1"
)),
# vllm will use rmsquant fused op
"USE_FUSED_RMS_QUANT"
:
lambda
:
(
os
.
getenv
(
'USE_FUSED_RMS_QUANT'
,
'0'
).
lower
()
in
(
"true"
,
"1"
)),
}
}
# --8<-- [end:env-vars-definition]
# --8<-- [end:env-vars-definition]
...
...
vllm/model_executor/layers/linear.py
View file @
fc443d52
...
@@ -33,6 +33,12 @@ from vllm.platforms import current_platform
...
@@ -33,6 +33,12 @@ from vllm.platforms import current_platform
import
os
import
os
from
vllm.model_executor.utils
import
gemm_bank_conf
from
vllm.model_executor.utils
import
gemm_bank_conf
if
envs
.
USE_FUSED_RMS_QUANT
:
try
:
from
lmslim.quantize.quant_ops
import
lm_faster_rmsquant
except
Exception
as
e
:
print
(
f
"Error: Import fused rmsquant error:
{
e
}
"
)
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
WEIGHT_LOADER_V2_SUPPORTED
=
[
WEIGHT_LOADER_V2_SUPPORTED
=
[
...
@@ -327,6 +333,7 @@ class ReplicatedLinear(LinearBase):
...
@@ -327,6 +333,7 @@ class ReplicatedLinear(LinearBase):
skip_bias_add
:
bool
=
False
,
skip_bias_add
:
bool
=
False
,
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
eps
:
Optional
[
float
]
=
1e-6
,
prefix
:
str
=
""
,
prefix
:
str
=
""
,
*
,
*
,
return_bias
:
bool
=
True
,
return_bias
:
bool
=
True
,
...
@@ -338,6 +345,7 @@ class ReplicatedLinear(LinearBase):
...
@@ -338,6 +345,7 @@ class ReplicatedLinear(LinearBase):
quant_config
,
quant_config
,
prefix
=
prefix
,
prefix
=
prefix
,
return_bias
=
return_bias
)
return_bias
=
return_bias
)
self
.
eps
=
eps
# All the linear layer supports quant method.
# All the linear layer supports quant method.
assert
self
.
quant_method
is
not
None
assert
self
.
quant_method
is
not
None
...
@@ -385,15 +393,53 @@ class ReplicatedLinear(LinearBase):
...
@@ -385,15 +393,53 @@ class ReplicatedLinear(LinearBase):
param
.
data
.
copy_
(
loaded_weight
)
param
.
data
.
copy_
(
loaded_weight
)
def
forward
(
def
forward
(
self
,
x
:
torch
.
Tensor
self
,
input_
:
torch
.
Tensor
,
rms_weight
:
Optional
[
torch
.
Tensor
]
=
None
,
residual
:
Optional
[
torch
.
Tensor
]
=
None
,
quant_args
:
Optional
[
list
]
=
None
,
update_hd
:
Optional
[
bool
]
=
True
)
->
Union
[
torch
.
Tensor
,
tuple
[
torch
.
Tensor
,
Optional
[
Parameter
]]]:
)
->
Union
[
torch
.
Tensor
,
tuple
[
torch
.
Tensor
,
Optional
[
Parameter
]]]:
bias
=
self
.
bias
if
not
self
.
skip_bias_add
else
None
if
envs
.
USE_FUSED_RMS_QUANT
and
(
rms_weight
is
not
None
or
quant_args
is
not
None
):
assert
self
.
quant_method
is
not
None
if
quant_args
is
not
None
:
output
=
self
.
quant_method
.
apply
(
self
,
x
,
bias
)
input_quant_args
=
quant_args
output_bias
=
self
.
bias
if
self
.
skip_bias_add
else
None
if
not
self
.
return_bias
:
bias
=
self
.
bias
if
not
self
.
skip_bias_add
else
None
return
output
assert
self
.
quant_method
is
not
None
return
output
,
output_bias
output
=
self
.
quant_method
.
apply
(
self
,
input_
,
bias
,
input_quant_args
)
output_bias
=
self
.
bias
if
self
.
skip_bias_add
else
None
if
not
self
.
return_bias
:
return
output
return
output
,
output_bias
else
:
i_q
,
_scales
=
lm_faster_rmsquant
(
input
=
input_
,
rms_weight
=
rms_weight
,
epsilon
=
self
.
eps
,
quant_dtype
=
torch
.
int8
,
residual
=
residual
,
update_input
=
update_hd
)
new_residual
=
residual
input_quant_args
=
[
i_q
,
_scales
]
bias
=
self
.
bias
if
not
self
.
skip_bias_add
else
None
assert
self
.
quant_method
is
not
None
output
=
self
.
quant_method
.
apply
(
self
,
input_
,
bias
,
input_quant_args
)
output_bias
=
self
.
bias
if
self
.
skip_bias_add
else
None
if
not
self
.
return_bias
:
return
output
return
output
,
new_residual
,
output_bias
,
input_quant_args
else
:
bias
=
self
.
bias
if
not
self
.
skip_bias_add
else
None
assert
self
.
quant_method
is
not
None
output
=
self
.
quant_method
.
apply
(
self
,
input_
,
bias
)
output_bias
=
self
.
bias
if
self
.
skip_bias_add
else
None
if
not
self
.
return_bias
:
return
output
return
output
,
output_bias
def
extra_repr
(
self
)
->
str
:
def
extra_repr
(
self
)
->
str
:
s
=
f
"in_features=
{
self
.
input_size
}
"
s
=
f
"in_features=
{
self
.
input_size
}
"
...
@@ -436,6 +482,7 @@ class ColumnParallelLinear(LinearBase):
...
@@ -436,6 +482,7 @@ class ColumnParallelLinear(LinearBase):
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
output_sizes
:
Optional
[
list
[
int
]]
=
None
,
output_sizes
:
Optional
[
list
[
int
]]
=
None
,
eps
:
Optional
[
float
]
=
1e-6
,
prefix
:
str
=
""
,
prefix
:
str
=
""
,
*
,
*
,
return_bias
:
bool
=
True
,
return_bias
:
bool
=
True
,
...
@@ -459,7 +506,7 @@ class ColumnParallelLinear(LinearBase):
...
@@ -459,7 +506,7 @@ class ColumnParallelLinear(LinearBase):
quant_config
,
quant_config
,
prefix
,
prefix
,
return_bias
=
return_bias
)
return_bias
=
return_bias
)
self
.
eps
=
eps
self
.
gather_output
=
gather_output
self
.
gather_output
=
gather_output
if
output_sizes
is
None
:
if
output_sizes
is
None
:
...
@@ -543,22 +590,49 @@ class ColumnParallelLinear(LinearBase):
...
@@ -543,22 +590,49 @@ class ColumnParallelLinear(LinearBase):
param
.
load_column_parallel_weight
(
loaded_weight
=
loaded_weight
)
param
.
load_column_parallel_weight
(
loaded_weight
=
loaded_weight
)
def
forward
(
def
forward
(
self
,
input_
self
,
input_
,
rms_weight
:
Optional
[
torch
.
Tensor
]
=
None
,
residual
:
Optional
[
torch
.
Tensor
]
=
None
,
update_hd
:
Optional
[
bool
]
=
True
)
->
Union
[
torch
.
Tensor
,
tuple
[
torch
.
Tensor
,
Optional
[
Parameter
]]]:
)
->
Union
[
torch
.
Tensor
,
tuple
[
torch
.
Tensor
,
Optional
[
Parameter
]]]:
bias
=
self
.
bias
if
not
self
.
skip_bias_add
else
None
if
envs
.
USE_FUSED_RMS_QUANT
and
rms_weight
is
not
None
:
input_quant_args
=
None
assert
rms_weight
is
not
None
i_q
,
_scales
=
lm_faster_rmsquant
(
input
=
input_
,
rms_weight
=
rms_weight
,
epsilon
=
self
.
eps
,
quant_dtype
=
torch
.
int8
,
residual
=
residual
,
update_input
=
update_hd
)
new_residual
=
residual
input_quant_args
=
[
i_q
,
_scales
]
bias
=
self
.
bias
if
not
self
.
skip_bias_add
else
None
# Matrix multiply.
assert
self
.
quant_method
is
not
None
assert
self
.
quant_method
is
not
None
output_parallel
=
self
.
quant_method
.
apply
(
self
,
input_
,
bias
,
input_quant_args
)
output_parallel
=
self
.
quant_method
.
apply
(
self
,
input_
,
bias
)
if
self
.
gather_output
:
if
self
.
gather_output
:
output
=
tensor_model_parallel_all_gather
(
output_parallel
)
# All-gather across the partitions.
else
:
output
=
tensor_model_parallel_all_gather
(
output_parallel
)
output
=
output_parallel
output_bias
=
self
.
bias
if
self
.
skip_bias_add
else
None
if
not
self
.
return_bias
:
return
output
return
output
,
new_residual
,
output_bias
else
:
else
:
output
=
output_parallel
bias
=
self
.
bias
if
not
self
.
skip_bias_add
else
None
output_bias
=
self
.
bias
if
self
.
skip_bias_add
else
None
# Matrix multiply.
if
not
self
.
return_bias
:
assert
self
.
quant_method
is
not
None
return
output
output_parallel
=
self
.
quant_method
.
apply
(
self
,
input_
,
bias
)
return
output
,
output_bias
if
self
.
gather_output
:
# All-gather across the partitions.
output
=
tensor_model_parallel_all_gather
(
output_parallel
)
else
:
output
=
output_parallel
output_bias
=
self
.
bias
if
self
.
skip_bias_add
else
None
if
not
self
.
return_bias
:
return
output
return
output
,
output_bias
def
extra_repr
(
self
)
->
str
:
def
extra_repr
(
self
)
->
str
:
s
=
f
"in_features=
{
self
.
input_size
}
"
s
=
f
"in_features=
{
self
.
input_size
}
"
...
@@ -593,6 +667,54 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
...
@@ -593,6 +667,54 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
return_bias: If true, return bias together with outputs in forward pass.
return_bias: If true, return bias together with outputs in forward pass.
"""
"""
def
forward
(
self
,
input_
,
rms_weight
:
Optional
[
torch
.
Tensor
]
=
None
,
residual
:
Optional
[
torch
.
Tensor
]
=
None
,
update_hd
:
Optional
[
bool
]
=
True
)
->
Union
[
torch
.
Tensor
,
tuple
[
torch
.
Tensor
,
Optional
[
Parameter
]]]:
if
envs
.
USE_FUSED_RMS_QUANT
and
rms_weight
is
not
None
:
input_quant_args
=
None
assert
residual
is
not
None
and
rms_weight
is
not
None
i_q
,
_scales
=
lm_faster_rmsquant
(
input
=
input_
,
rms_weight
=
rms_weight
,
epsilon
=
self
.
eps
,
quant_dtype
=
torch
.
int8
,
residual
=
residual
,
update_input
=
update_hd
)
new_residual
=
residual
input_quant_args
=
[
i_q
,
_scales
]
bias
=
self
.
bias
if
not
self
.
skip_bias_add
else
None
assert
self
.
quant_method
is
not
None
output_parallel
=
self
.
quant_method
.
apply
(
self
,
input_
,
bias
,
input_quant_args
)
if
self
.
gather_output
:
# All-gather across the partitions.
output
=
tensor_model_parallel_all_gather
(
output_parallel
)
else
:
output
=
output_parallel
output_bias
=
self
.
bias
if
self
.
skip_bias_add
else
None
if
not
self
.
return_bias
:
return
output
return
output
,
new_residual
,
output_bias
else
:
# not USE_FUSED_RMS_QUANT
bias
=
self
.
bias
if
not
self
.
skip_bias_add
else
None
assert
self
.
quant_method
is
not
None
output_parallel
=
self
.
quant_method
.
apply
(
self
,
input_
,
bias
)
if
self
.
gather_output
:
# All-gather across the partitions.
output
=
tensor_model_parallel_all_gather
(
output_parallel
)
else
:
output
=
output_parallel
output_bias
=
self
.
bias
if
self
.
skip_bias_add
else
None
if
not
self
.
return_bias
:
return
output
return
output
,
output_bias
def
__init__
(
def
__init__
(
self
,
self
,
input_size
:
int
,
input_size
:
int
,
...
@@ -602,10 +724,12 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
...
@@ -602,10 +724,12 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
skip_bias_add
:
bool
=
False
,
skip_bias_add
:
bool
=
False
,
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
eps
:
Optional
[
float
]
=
1e-6
,
prefix
:
str
=
""
,
prefix
:
str
=
""
,
*
,
*
,
return_bias
:
bool
=
True
,
return_bias
:
bool
=
True
,
):
):
self
.
eps
=
eps
self
.
output_sizes
=
output_sizes
self
.
output_sizes
=
output_sizes
tp_size
=
get_tensor_model_parallel_world_size
()
tp_size
=
get_tensor_model_parallel_world_size
()
assert
all
(
output_size
%
tp_size
==
0
for
output_size
in
output_sizes
)
assert
all
(
output_size
%
tp_size
==
0
for
output_size
in
output_sizes
)
...
@@ -856,7 +980,6 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
...
@@ -856,7 +980,6 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
shard_offset
=
shard_offset
,
shard_offset
=
shard_offset
,
shard_size
=
shard_size
)
shard_size
=
shard_size
)
class
QKVParallelLinear
(
ColumnParallelLinear
):
class
QKVParallelLinear
(
ColumnParallelLinear
):
"""Linear layers for the attention's QKV transformation.
"""Linear layers for the attention's QKV transformation.
...
...
vllm/model_executor/layers/quantization/slimquant_w4a8.py
View file @
fc443d52
...
@@ -21,6 +21,7 @@ from vllm.utils import W8a8GetCacheJSON
...
@@ -21,6 +21,7 @@ from vllm.utils import W8a8GetCacheJSON
import
os
import
os
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
from
vllm
import
envs
W8A8_TRITONJSON
=
W8a8GetCacheJSON
()
W8A8_TRITONJSON
=
W8a8GetCacheJSON
()
...
@@ -153,8 +154,13 @@ class SlimQuantW4A8Int8LinearMethod(LinearMethodBase):
...
@@ -153,8 +154,13 @@ class SlimQuantW4A8Int8LinearMethod(LinearMethodBase):
layer
:
torch
.
nn
.
Module
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
,
input_quant_args
:
Optional
[
list
[
torch
.
Tensor
]]
=
None
):
):
x_q
,
x_scale
=
per_token_quant_int8
(
x
)
if
envs
.
USE_FUSED_RMS_QUANT
and
input_quant_args
is
not
None
:
assert
len
(
input_quant_args
)
==
2
x_q
,
x_scale
=
input_quant_args
else
:
x_q
,
x_scale
=
per_token_quant_int8
(
x
)
if
self
.
w8a8_strategy
==
1
:
if
self
.
w8a8_strategy
==
1
:
m
=
x_q
.
shape
[
0
]
m
=
x_q
.
shape
[
0
]
...
...
vllm/model_executor/models/deepseek_v2.py
View file @
fc443d52
...
@@ -94,11 +94,21 @@ class DeepseekV2MLP(nn.Module):
...
@@ -94,11 +94,21 @@ class DeepseekV2MLP(nn.Module):
"Only silu is supported for now."
)
"Only silu is supported for now."
)
self
.
act_fn
=
SiluAndMul
()
self
.
act_fn
=
SiluAndMul
()
def
forward
(
self
,
x
):
def
forward
(
self
,
x
,
gate_up
,
_
=
self
.
gate_up_proj
(
x
)
rms_weight
:
Optional
[
torch
.
Tensor
]
=
None
,
x
=
self
.
act_fn
(
gate_up
)
residual
:
Optional
[
torch
.
Tensor
]
=
None
,
x
,
_
=
self
.
down_proj
(
x
)
update_hd
:
Optional
[
bool
]
=
False
return
x
):
if
envs
.
USE_FUSED_RMS_QUANT
:
gate_up
,
new_resi
,
_
=
self
.
gate_up_proj
(
x
,
rms_weight
,
residual
,
update_hd
=
update_hd
)
x
=
self
.
act_fn
(
gate_up
)
x
,
_
=
self
.
down_proj
(
x
)
return
x
,
new_resi
else
:
gate_up
,
_
=
self
.
gate_up_proj
(
x
)
x
=
self
.
act_fn
(
gate_up
)
x
,
_
=
self
.
down_proj
(
x
)
return
x
class
DeepseekV2MoE
(
nn
.
Module
):
class
DeepseekV2MoE
(
nn
.
Module
):
...
@@ -185,11 +195,17 @@ class DeepseekV2MoE(nn.Module):
...
@@ -185,11 +195,17 @@ class DeepseekV2MoE(nn.Module):
from
vllm.two_batch_overlap.two_batch_overlap
import
tbo_all_reduce
from
vllm.two_batch_overlap.two_batch_overlap
import
tbo_all_reduce
self
.
tbo_all_reduce
=
tbo_all_reduce
self
.
tbo_all_reduce
=
tbo_all_reduce
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
rms_weight
:
Optional
[
torch
.
Tensor
]
=
None
,
residual
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
num_tokens
,
hidden_dim
=
hidden_states
.
shape
num_tokens
,
hidden_dim
=
hidden_states
.
shape
hidden_states
=
hidden_states
.
view
(
-
1
,
hidden_dim
)
hidden_states
=
hidden_states
.
view
(
-
1
,
hidden_dim
)
if
self
.
n_shared_experts
is
not
None
:
if
self
.
n_shared_experts
is
not
None
:
shared_output
=
self
.
shared_experts
(
hidden_states
)
if
envs
.
USE_FUSED_RMS_QUANT
:
shared_output
,
new_resi
=
self
.
shared_experts
(
hidden_states
,
rms_weight
,
residual
,
update_hd
=
True
)
else
:
shared_output
=
self
.
shared_experts
(
hidden_states
)
# router_logits: (num_tokens, n_experts)
# router_logits: (num_tokens, n_experts)
router_logits
,
_
=
self
.
gate
(
hidden_states
)
router_logits
,
_
=
self
.
gate
(
hidden_states
)
...
@@ -219,8 +235,10 @@ class DeepseekV2MoE(nn.Module):
...
@@ -219,8 +235,10 @@ class DeepseekV2MoE(nn.Module):
final_hidden_states
=
(
final_hidden_states
=
(
self
.
experts
.
maybe_all_reduce_tensor_model_parallel
(
self
.
experts
.
maybe_all_reduce_tensor_model_parallel
(
final_hidden_states
))
final_hidden_states
))
if
envs
.
USE_FUSED_RMS_QUANT
:
return
final_hidden_states
.
view
(
num_tokens
,
hidden_dim
)
return
final_hidden_states
.
view
(
num_tokens
,
hidden_dim
),
new_resi
else
:
return
final_hidden_states
.
view
(
num_tokens
,
hidden_dim
)
def
yarn_get_mscale
(
scale
:
float
=
1
,
mscale
:
float
=
1
)
->
float
:
def
yarn_get_mscale
(
scale
:
float
=
1
,
mscale
:
float
=
1
)
->
float
:
...
@@ -421,19 +439,36 @@ class DeepseekV2MLAAttention(nn.Module):
...
@@ -421,19 +439,36 @@ class DeepseekV2MLAAttention(nn.Module):
self
.
max_position_embeddings
=
max_position_embeddings
self
.
max_position_embeddings
=
max_position_embeddings
if
self
.
q_lora_rank
is
not
None
:
if
self
.
q_lora_rank
is
not
None
:
self
.
q_a_proj
=
ReplicatedLinear
(
self
.
hidden_size
,
if
envs
.
USE_FUSED_RMS_QUANT
:
self
.
q_a_proj
=
ReplicatedLinear
(
self
.
hidden_size
,
self
.
q_lora_rank
,
self
.
q_lora_rank
,
bias
=
False
,
bias
=
False
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
eps
=
config
.
rms_norm_eps
,
prefix
=
f
"
{
prefix
}
.q_a_proj"
)
prefix
=
f
"
{
prefix
}
.q_a_proj"
)
self
.
q_a_layernorm
=
RMSNorm
(
self
.
q_lora_rank
,
self
.
q_b_proj
=
ColumnParallelLinear
(
q_lora_rank
,
eps
=
config
.
rms_norm_eps
)
self
.
num_heads
*
self
.
q_b_proj
=
ColumnParallelLinear
(
q_lora_rank
,
self
.
qk_head_dim
,
bias
=
False
,
quant_config
=
quant_config
,
eps
=
config
.
rms_norm_eps
,
prefix
=
f
"
{
prefix
}
.q_b_proj"
)
else
:
self
.
q_a_proj
=
ReplicatedLinear
(
self
.
hidden_size
,
self
.
q_lora_rank
,
bias
=
False
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.q_a_proj"
)
self
.
q_b_proj
=
ColumnParallelLinear
(
q_lora_rank
,
self
.
num_heads
*
self
.
num_heads
*
self
.
qk_head_dim
,
self
.
qk_head_dim
,
bias
=
False
,
bias
=
False
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.q_b_proj"
)
prefix
=
f
"
{
prefix
}
.q_b_proj"
)
self
.
q_a_layernorm
=
RMSNorm
(
self
.
q_lora_rank
,
eps
=
config
.
rms_norm_eps
)
else
:
else
:
self
.
q_proj
=
ColumnParallelLinear
(
self
.
hidden_size
,
self
.
q_proj
=
ColumnParallelLinear
(
self
.
hidden_size
,
self
.
num_heads
*
self
.
num_heads
*
...
@@ -508,31 +543,60 @@ class DeepseekV2MLAAttention(nn.Module):
...
@@ -508,31 +543,60 @@ class DeepseekV2MLAAttention(nn.Module):
self
,
self
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
rms_weight
:
Optional
[
torch
.
Tensor
]
=
None
,
residual
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
if
self
.
q_lora_rank
is
not
None
:
if
envs
.
USE_FUSED_RMS_QUANT
and
rms_weight
is
not
None
:
q_c
=
self
.
q_a_proj
(
hidden_states
)[
0
]
if
self
.
q_lora_rank
is
not
None
:
q_c
=
self
.
q_a_layernorm
(
q_c
)
q_c
,
new_residual
,
_
,
input_quant_args
=
self
.
q_a_proj
(
hidden_states
,
rms_weight
=
rms_weight
,
residual
=
residual
,
update_hd
=
False
)
q
=
self
.
q_b_proj
(
q_c
)[
0
]
q
,
_
,
_
=
self
.
q_b_proj
(
q_c
,
rms_weight
=
self
.
q_a_layernorm
.
weight
.
data
,
residual
=
None
,
update_hd
=
False
)
else
:
q
=
self
.
q_proj
(
hidden_states
)[
0
]
kv_c
,
k_pe
=
self
.
kv_a_proj_with_mqa
(
hidden_states
,
quant_args
=
input_quant_args
,
update_hd
=
False
)[
0
].
split
(
[
self
.
kv_lora_rank
,
self
.
qk_rope_head_dim
],
dim
=-
1
)
kv_c_normed
=
self
.
kv_a_layernorm
(
kv_c
.
contiguous
())
q
=
q
.
view
(
-
1
,
self
.
num_local_heads
,
self
.
qk_head_dim
)
# Add head dim of 1 to k_pe
k_pe
=
k_pe
.
unsqueeze
(
1
)
q
[...,
self
.
qk_nope_head_dim
:],
k_pe
=
self
.
rotary_emb
(
positions
,
q
[...,
self
.
qk_nope_head_dim
:],
k_pe
)
attn_out
=
self
.
mla_attn
(
q
,
kv_c_normed
,
k_pe
,
output_shape
=
(
hidden_states
.
shape
[
0
],
self
.
num_local_heads
*
self
.
v_head_dim
))
return
self
.
o_proj
(
attn_out
)[
0
],
new_residual
else
:
else
:
q
=
self
.
q_proj
(
hidden_states
)[
0
]
if
self
.
q_lora_rank
is
not
None
:
kv_c
,
k_pe
=
self
.
kv_a_proj_with_mqa
(
hidden_states
)[
0
].
split
(
q_c
=
self
.
q_a_proj
(
hidden_states
)[
0
]
[
self
.
kv_lora_rank
,
self
.
qk_rope_head_dim
],
dim
=-
1
)
q_c
=
self
.
q_a_layernorm
(
q_c
)
kv_c_normed
=
self
.
kv_a_layernorm
(
kv_c
.
contiguous
())
q
=
self
.
q_b_proj
(
q_c
)[
0
]
else
:
q
=
self
.
q_proj
(
hidden_states
)[
0
]
kv_c
,
k_pe
=
self
.
kv_a_proj_with_mqa
(
hidden_states
)[
0
].
split
(
[
self
.
kv_lora_rank
,
self
.
qk_rope_head_dim
],
dim
=-
1
)
kv_c_normed
=
self
.
kv_a_layernorm
(
kv_c
.
contiguous
())
q
=
q
.
view
(
-
1
,
self
.
num_local_heads
,
self
.
qk_head_dim
)
q
=
q
.
view
(
-
1
,
self
.
num_local_heads
,
self
.
qk_head_dim
)
# Add head dim of 1 to k_pe
# Add head dim of 1 to k_pe
k_pe
=
k_pe
.
unsqueeze
(
1
)
k_pe
=
k_pe
.
unsqueeze
(
1
)
q
[...,
self
.
qk_nope_head_dim
:],
k_pe
=
self
.
rotary_emb
(
q
[...,
self
.
qk_nope_head_dim
:],
k_pe
=
self
.
rotary_emb
(
positions
,
q
[...,
self
.
qk_nope_head_dim
:],
k_pe
)
positions
,
q
[...,
self
.
qk_nope_head_dim
:],
k_pe
)
attn_out
=
self
.
mla_attn
(
attn_out
=
self
.
mla_attn
(
q
,
q
,
kv_c_normed
,
kv_c_normed
,
k_pe
,
k_pe
,
output_shape
=
(
hidden_states
.
shape
[
0
],
output_shape
=
(
hidden_states
.
shape
[
0
],
self
.
num_local_heads
*
self
.
v_head_dim
))
self
.
num_local_heads
*
self
.
v_head_dim
))
return
self
.
o_proj
(
attn_out
)[
0
]
return
self
.
o_proj
(
attn_out
)[
0
]
class
DeepseekV2DecoderLayer
(
nn
.
Module
):
class
DeepseekV2DecoderLayer
(
nn
.
Module
):
...
@@ -607,47 +671,90 @@ class DeepseekV2DecoderLayer(nn.Module):
...
@@ -607,47 +671,90 @@ class DeepseekV2DecoderLayer(nn.Module):
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
residual
:
Optional
[
torch
.
Tensor
],
residual
:
Optional
[
torch
.
Tensor
],
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
# Self Attention
if
envs
.
USE_FUSED_RMS_QUANT
:
# Fix residual FP16 overflow
# Fix residual FP16 overflow
residual_fix_overflow
=
False
residual_fix_overflow
=
False
if
residual
is
None
:
assert
self
.
input_layernorm
.
has_weight
is
True
residual
=
hidden_states
if
residual
is
None
:
hidden_states
=
self
.
input_layernorm
(
hidden_states
)
residual
=
hidden_states
residual_fix_overflow
=
True
hidden_states
,
_
=
self
.
self_attn
(
positions
=
positions
,
hidden_states
=
hidden_states
,
rms_weight
=
self
.
input_layernorm
.
weight
.
data
,
residual
=
None
)
residual_fix_overflow
=
True
else
:
hidden_states
,
new_residual
=
self
.
self_attn
(
positions
=
positions
,
hidden_states
=
hidden_states
,
rms_weight
=
self
.
input_layernorm
.
weight
.
data
,
residual
=
residual
)
residual
=
new_residual
if
hidden_states
.
dtype
==
torch
.
float16
and
not
self
.
dpsk_fp16_quick
:
# rmsnorm, and rmsnorm result would not affect by scale.
hidden_states
*=
1.
/
self
.
routed_scaling_factor
if
self
.
layer_idx
==
0
or
residual_fix_overflow
:
# The residual is shared by all layers, we only scale it on
# first layer.
residual
*=
1.
/
self
.
routed_scaling_factor
hidden_states
,
new_resi
=
self
.
mlp
(
hidden_states
,
self
.
post_attention_layernorm
.
weight
.
data
,
residual
)
if
isinstance
(
self
.
mlp
,
DeepseekV2MLP
)
and
hidden_states
.
dtype
==
torch
.
float16
and
not
self
.
dpsk_fp16_quick
:
# Fix FP16 overflow
# Scaling the DeepseekV2MLP output, it is the input of
# input_layernorm of next decoder layer.
# The scaling of DeepseekV2MOE output would be done in the forward
# of DeepseekV2MOE
hidden_states
*=
1.
/
self
.
routed_scaling_factor
return
hidden_states
,
new_resi
else
:
else
:
hidden_states
,
residual
=
self
.
input_layernorm
(
# Self Attention
# Fix residual FP16 overflow
residual_fix_overflow
=
False
if
residual
is
None
:
residual
=
hidden_states
hidden_states
=
self
.
input_layernorm
(
hidden_states
)
residual_fix_overflow
=
True
else
:
hidden_states
,
residual
=
self
.
input_layernorm
(
hidden_states
,
residual
)
hidden_states
=
self
.
self_attn
(
positions
=
positions
,
hidden_states
=
hidden_states
,
)
if
hidden_states
.
dtype
==
torch
.
float16
and
not
self
.
dpsk_fp16_quick
:
# Fix FP16 overflow
# We scale both hidden_states and residual before
# rmsnorm, and rmsnorm result would not affect by scale.
hidden_states
*=
1.
/
self
.
routed_scaling_factor
if
self
.
layer_idx
==
0
or
residual_fix_overflow
:
# The residual is shared by all layers, we only scale it on
# first layer.
residual
*=
1.
/
self
.
routed_scaling_factor
# Fully Connected
hidden_states
,
residual
=
self
.
post_attention_layernorm
(
hidden_states
,
residual
)
hidden_states
,
residual
)
hidden_states
=
self
.
self_attn
(
hidden_states
=
self
.
mlp
(
hidden_states
)
positions
=
positions
,
hidden_states
=
hidden_states
,
)
if
hidden_states
.
dtype
==
torch
.
float16
and
not
self
.
dpsk_fp16_quick
:
if
isinstance
(
self
.
mlp
,
# Fix FP16 overflow
DeepseekV2MLP
)
and
hidden_states
.
dtype
==
torch
.
float16
and
not
self
.
dpsk_fp16_quick
:
# We scale both hidden_states and residual before
# Fix FP16 overflow
# rmsnorm, and rmsnorm result would not affect by scale.
# Scaling the DeepseekV2MLP output, it is the input of
hidden_states
*=
1.
/
self
.
routed_scaling_factor
# input_layernorm of next decoder layer.
if
self
.
layer_idx
==
0
or
residual_fix_overflow
:
# The scaling of DeepseekV2MOE output would be done in the forward
# The residual is shared by all layers, we only scale it on
# of DeepseekV2MOE
# first layer.
hidden_states
*=
1.
/
self
.
routed_scaling_factor
residual
*=
1.
/
self
.
routed_scaling_factor
# Fully Connected
hidden_states
,
residual
=
self
.
post_attention_layernorm
(
hidden_states
,
residual
)
hidden_states
=
self
.
mlp
(
hidden_states
)
if
isinstance
(
self
.
mlp
,
DeepseekV2MLP
)
and
hidden_states
.
dtype
==
torch
.
float16
and
not
self
.
dpsk_fp16_quick
:
# Fix FP16 overflow
# Scaling the DeepseekV2MLP output, it is the input of
# input_layernorm of next decoder layer.
# The scaling of DeepseekV2MOE output would be done in the forward
# of DeepseekV2MOE
hidden_states
*=
1.
/
self
.
routed_scaling_factor
return
hidden_states
,
residual
return
hidden_states
,
residual
@
support_torch_compile
@
support_torch_compile
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment