Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
3f5c2eea
Commit
3f5c2eea
authored
Nov 19, 2025
by
zhuwenwen
Browse files
add mla tpsp and moe share experts computation communication overlap
parent
8375370f
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
473 additions
and
72 deletions
+473
-72
vllm/envs.py
vllm/envs.py
+25
-2
vllm/model_executor/layers/fused_moe/layer.py
vllm/model_executor/layers/fused_moe/layer.py
+7
-0
vllm/model_executor/layers/linear.py
vllm/model_executor/layers/linear.py
+88
-2
vllm/model_executor/models/deepseek_v2.py
vllm/model_executor/models/deepseek_v2.py
+353
-68
No files found.
vllm/envs.py
View file @
3f5c2eea
...
@@ -178,6 +178,8 @@ if TYPE_CHECKING:
...
@@ -178,6 +178,8 @@ if TYPE_CHECKING:
VLLM_P2P_BUF_TOKENS
:
int
=
30000
VLLM_P2P_BUF_TOKENS
:
int
=
30000
VLLM_SCHED_ENABLE_MINIMAL_INJECTION
:
bool
=
False
VLLM_SCHED_ENABLE_MINIMAL_INJECTION
:
bool
=
False
VLLM_USE_PD_SPLIT
:
bool
=
False
VLLM_USE_PD_SPLIT
:
bool
=
False
VLLM_ENABLE_MLA_SP
:
bool
=
False
VLLM_ENABLE_MLA_QKV_MERGE
:
bool
=
False
def
get_default_cache_root
():
def
get_default_cache_root
():
return
os
.
getenv
(
return
os
.
getenv
(
...
@@ -1094,68 +1096,89 @@ environment_variables: dict[str, Callable[[], Any]] = {
...
@@ -1094,68 +1096,89 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_FLASH_ATTN_PA"
:
"VLLM_USE_FLASH_ATTN_PA"
:
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_FLASH_ATTN_PA"
,
"True"
).
lower
()
in
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_FLASH_ATTN_PA"
,
"True"
).
lower
()
in
(
"true"
,
"1"
)),
(
"true"
,
"1"
)),
# vLLM will use apex for rmsnorm
# vLLM will use apex for rmsnorm
"VLLM_USE_APEX_RN"
:
"VLLM_USE_APEX_RN"
:
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_APEX_RN"
,
"False"
).
lower
()
in
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_APEX_RN"
,
"False"
).
lower
()
in
(
"true"
,
"1"
)),
(
"true"
,
"1"
)),
# vLLM will use global cache for moe
# vLLM will use global cache for moe
"VLLM_USE_GLOBAL_CACHE13"
:
"VLLM_USE_GLOBAL_CACHE13"
:
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_GLOBAL_CACHE13"
,
"False"
).
lower
()
in
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_GLOBAL_CACHE13"
,
"False"
).
lower
()
in
(
"true"
,
"1"
)),
(
"true"
,
"1"
)),
# vLLM will use lightop for deepseek-v3
# vLLM will use lightop for deepseek-v3
"VLLM_USE_LIGHTOP"
:
"VLLM_USE_LIGHTOP"
:
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_LIGHTOP"
,
"False"
).
lower
()
in
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_LIGHTOP"
,
"False"
).
lower
()
in
(
"true"
,
"1"
)),
(
"true"
,
"1"
)),
# vLLM will use elementwise not triton_
# vLLM will use elementwise not triton_
"VLLM_USE_OPT_ZEROS"
:
"VLLM_USE_OPT_ZEROS"
:
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_OPT_ZEROS"
,
"False"
).
lower
()
in
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_OPT_ZEROS"
,
"False"
).
lower
()
in
(
"true"
,
"1"
)),
(
"true"
,
"1"
)),
# vLLM will use opt cat for deepseek-v3
# vLLM will use opt cat for deepseek-v3
"VLLM_USE_OPT_CAT"
:
"VLLM_USE_OPT_CAT"
:
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_OPT_CAT"
,
"False"
).
lower
()
in
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_OPT_CAT"
,
"False"
).
lower
()
in
(
"true"
,
"1"
)),
(
"true"
,
"1"
)),
# vLLM will use triton moe_sum
# vLLM will use triton moe_sum
"VLLM_USE_OPT_MOE_SUM"
:
"VLLM_USE_OPT_MOE_SUM"
:
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_OPT_MOE_SUM"
,
"False"
).
lower
()
in
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_OPT_MOE_SUM"
,
"False"
).
lower
()
in
(
"true"
,
"1"
)),
(
"true"
,
"1"
)),
# vLLM will use lightop moe_sum_mul_add
# vLLM will use lightop moe_sum_mul_add
"VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD"
:
"VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD"
:
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD"
,
"False"
).
lower
()
in
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD"
,
"False"
).
lower
()
in
(
"true"
,
"1"
)),
(
"true"
,
"1"
)),
# vLLM will use lightop moe_sum
# vLLM will use lightop moe_sum
"VLLM_USE_LIGHTOP_MOE_SUM"
:
"VLLM_USE_LIGHTOP_MOE_SUM"
:
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_LIGHTOP_MOE_SUM"
,
"True"
).
lower
()
in
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_LIGHTOP_MOE_SUM"
,
"True"
).
lower
()
in
(
"true"
,
"1"
)),
(
"true"
,
"1"
)),
# vLLM will use lightop moe_align_block_size
# vLLM will use lightop moe_align_block_size
"VLLM_USE_LIGHTOP_MOE_ALIGN"
:
"VLLM_USE_LIGHTOP_MOE_ALIGN"
:
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_LIGHTOP_MOE_ALIGN"
,
"True"
).
lower
()
in
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_LIGHTOP_MOE_ALIGN"
,
"True"
).
lower
()
in
(
"true"
,
"1"
)),
(
"true"
,
"1"
)),
# vLLM will use opt merge_attn_states, not triton
# vLLM will use opt merge_attn_states, not triton
"VLLM_USE_MERGE_ATTN_STATES_OPT"
:
"VLLM_USE_MERGE_ATTN_STATES_OPT"
:
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_MERGE_ATTN_STATES_OPT"
,
"True"
).
lower
()
in
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_MERGE_ATTN_STATES_OPT"
,
"True"
).
lower
()
in
(
"true"
,
"1"
)),
(
"true"
,
"1"
)),
# vllm will use rmsquant fused op
# vllm will use rmsquant fused op
"USE_FUSED_RMS_QUANT"
:
"USE_FUSED_RMS_QUANT"
:
lambda
:
(
os
.
getenv
(
'USE_FUSED_RMS_QUANT'
,
'0'
).
lower
()
in
lambda
:
(
os
.
getenv
(
'USE_FUSED_RMS_QUANT'
,
'0'
).
lower
()
in
(
"true"
,
"1"
)),
(
"true"
,
"1"
)),
# vllm will use silu_mul_quant fused op
# vllm will use silu_mul_quant fused op
"USE_FUSED_SILU_MUL_QUANT"
:
"USE_FUSED_SILU_MUL_QUANT"
:
lambda
:
(
os
.
getenv
(
'USE_FUSED_SILU_MUL_QUANT'
,
'0'
).
lower
()
in
lambda
:
(
os
.
getenv
(
'USE_FUSED_SILU_MUL_QUANT'
,
'0'
).
lower
()
in
(
"true"
,
"1"
)),
(
"true"
,
"1"
)),
# vllm pd separation will be used async
# vllm pd separation will be used async
"VLLM_P2P_ASYNC"
:
"VLLM_P2P_ASYNC"
:
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_P2P_ASYNC"
,
"0"
))),
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_P2P_ASYNC"
,
"0"
))),
# pd separation p2p async buf tokens
# pd separation p2p async buf tokens
"VLLM_P2P_BUF_TOKENS"
:
"VLLM_P2P_BUF_TOKENS"
:
lambda
:
int
(
os
.
getenv
(
"VLLM_P2P_BUF_TOKENS"
,
"30000"
)),
lambda
:
int
(
os
.
getenv
(
"VLLM_P2P_BUF_TOKENS"
,
"30000"
)),
# vllm will enable minimal injection for pipeline parallel scheduling
# vllm will enable minimal injection for pipeline parallel scheduling
"VLLM_SCHED_ENABLE_MINIMAL_INJECTION"
:
"VLLM_SCHED_ENABLE_MINIMAL_INJECTION"
:
lambda
:
(
os
.
getenv
(
"VLLM_SCHED_ENABLE_MINIMAL_INJECTION"
,
"0"
).
lower
()
in
lambda
:
(
os
.
getenv
(
"VLLM_SCHED_ENABLE_MINIMAL_INJECTION"
,
"0"
).
lower
()
in
(
"true"
,
"1"
)),
(
"true"
,
"1"
)),
# vLLM will split prefill and decode, not mix up
# vLLM will split prefill and decode, not mix up
"VLLM_USE_PD_SPLIT"
:
"VLLM_USE_PD_SPLIT"
:
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_PD_SPLIT"
,
"True"
).
lower
()
in
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_PD_SPLIT"
,
"True"
).
lower
()
in
(
"true"
,
"1"
)),
(
"true"
,
"1"
)),
"VLLM_ENABLE_MLA_SP"
:
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_ENABLE_MLA_SP"
,
"0"
))),
"VLLM_ENABLE_MLA_QKV_MERGE"
:
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_ENABLE_MLA_QKV_MERGE"
,
"0"
))),
}
}
# --8<-- [end:env-vars-definition]
# --8<-- [end:env-vars-definition]
...
...
vllm/model_executor/layers/fused_moe/layer.py
View file @
3f5c2eea
...
@@ -637,6 +637,13 @@ def determine_expert_map(
...
@@ -637,6 +637,13 @@ def determine_expert_map(
return
(
local_num_experts
,
expert_map
)
return
(
local_num_experts
,
expert_map
)
class EventType(Enum):
    """Identifiers for synchronization events.

    Members are numbered consecutively from 0. In the DeepSeek-V2 model
    code they serve as dictionary keys for ``torch.cuda.Event`` objects.
    """

    Main = 0
    Attention = 1
    QCAllgather = 2
    KVFinish = 3
    MoeShared = 4
    MoeChunkingOverlap = 5
    MoeAllgather = 6
    MoeReduceScatter = 7
class
FusedMoE
(
torch
.
nn
.
Module
):
class
FusedMoE
(
torch
.
nn
.
Module
):
"""FusedMoE layer for MoE models.
"""FusedMoE layer for MoE models.
...
...
vllm/model_executor/layers/linear.py
View file @
3f5c2eea
...
@@ -14,7 +14,8 @@ from vllm.distributed import (divide, get_tensor_model_parallel_rank,
...
@@ -14,7 +14,8 @@ from vllm.distributed import (divide, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size
,
get_tensor_model_parallel_world_size
,
split_tensor_along_last_dim
,
split_tensor_along_last_dim
,
tensor_model_parallel_all_gather
,
tensor_model_parallel_all_gather
,
tensor_model_parallel_all_reduce
)
tensor_model_parallel_all_reduce
,
tensor_model_parallel_reduce_scatter
)
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.quantization.base_config
import
(
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
,
QuantizeMethodBase
)
QuantizationConfig
,
QuantizeMethodBase
)
...
@@ -454,6 +455,86 @@ class ReplicatedLinear(LinearBase):
...
@@ -454,6 +455,86 @@ class ReplicatedLinear(LinearBase):
return
s
return
s
class MergedReplicatedLinear(ReplicatedLinear):
    """Replicated linear layer that fuses several output projections.

    The fused output dimension is ``sum(output_sizes)``; each logical
    projection is loaded as a separate shard through ``weight_loader``.

    Args:
        input_size: input dimension of the linear layer.
        output_sizes: output dimension of each fused projection, in order.
        bias: If true, add bias.
        skip_bias_add: If true, skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configure.
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
        return_bias: Forwarded unchanged to ``ReplicatedLinear``.
    """

    def __init__(
        self,
        input_size: int,
        output_sizes: list[int],
        bias: bool = True,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
    ):
        # Stored before the parent constructor runs, as in the original.
        self.output_sizes = output_sizes
        super().__init__(
            input_size,
            sum(output_sizes),
            bias,
            skip_bias_add,
            params_dtype,
            quant_config,
            prefix=prefix,
            return_bias=return_bias,
        )

    def weight_loader(self,
                      param: Union[Parameter, BasevLLMParameter],
                      loaded_weight: torch.Tensor,
                      loaded_shard_id: Optional[int] = None):
        """Copy one shard of a checkpoint tensor into the fused parameter.

        Args:
            param: Destination parameter; the shard is written into a slice
                of ``param.data`` along dimension 0.
            loaded_weight: Tensor read from the checkpoint for this shard.
            loaded_shard_id: Index into ``self.output_sizes`` selecting which
                fused projection is being loaded. Must not be ``None``.

        Raises:
            AssertionError: If the shard id is missing or out of range, or if
                the loaded tensor's shape differs from the destination slice.
        """
        assert loaded_shard_id is not None
        assert loaded_shard_id < len(self.output_sizes)

        # Promote 0-d tensors to shape (1,) so they can be sliced and copied.
        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        unquantized = isinstance(self.quant_method, UnquantizedLinearMethod)

        if isinstance(param, BlockQuantScaleParameter):
            from vllm.model_executor.layers.quantization.fp8 import (
                Fp8LinearMethod, Fp8MoEMethod)
            assert self.quant_method is not None
            assert isinstance(self.quant_method,
                              (Fp8LinearMethod, Fp8MoEMethod))
            weight_block_size = self.quant_method.quant_config.weight_block_size
            block_n, _ = weight_block_size[0], weight_block_size[1]
            # Block-wise scales: offset and size are measured in blocks of
            # ``block_n`` rows, rounded up.
            rows_before = sum(self.output_sizes[:loaded_shard_id])
            start_offset = (rows_before + block_n - 1) // block_n
            shard_size = (
                (self.output_sizes[loaded_shard_id] + block_n - 1) // block_n)
        elif (isinstance(param, PerTensorScaleParameter)
              and current_platform.is_rocm()):
            # One entry per shard, indexed by the shard id.
            start_offset = loaded_shard_id
            shard_size = 1
        else:
            start_offset = sum(self.output_sizes[:loaded_shard_id])
            shard_size = self.output_sizes[loaded_shard_id]

        end_offset = start_offset + shard_size
        target = param.data[start_offset:end_offset, ...]

        assert loaded_weight.shape == target.shape, (
            f"Expected shape {target.shape}, got {loaded_weight.shape}")

        # NOTE(review): the shape check above runs before this transpose;
        # confirm this ordering is what the VLLM_USE_NN layout expects.
        if envs.VLLM_USE_NN and unquantized:
            loaded_weight = loaded_weight.t()
        target.copy_(loaded_weight)
class
ColumnParallelLinear
(
LinearBase
):
class
ColumnParallelLinear
(
LinearBase
):
"""Linear layer with column parallelism.
"""Linear layer with column parallelism.
...
@@ -1390,6 +1471,7 @@ class RowParallelLinear(LinearBase):
...
@@ -1390,6 +1471,7 @@ class RowParallelLinear(LinearBase):
prefix
:
str
=
""
,
prefix
:
str
=
""
,
*
,
*
,
return_bias
:
bool
=
True
,
return_bias
:
bool
=
True
,
sp_parallel
:
bool
=
False
,
):
):
# Divide the weight matrix along the first dimension.
# Divide the weight matrix along the first dimension.
self
.
tp_rank
=
get_tensor_model_parallel_rank
()
self
.
tp_rank
=
get_tensor_model_parallel_rank
()
...
@@ -1397,6 +1479,7 @@ class RowParallelLinear(LinearBase):
...
@@ -1397,6 +1479,7 @@ class RowParallelLinear(LinearBase):
self
.
input_size_per_partition
=
divide
(
input_size
,
self
.
tp_size
)
self
.
input_size_per_partition
=
divide
(
input_size
,
self
.
tp_size
)
self
.
output_size_per_partition
=
output_size
self
.
output_size_per_partition
=
output_size
self
.
output_partition_sizes
=
[
output_size
]
self
.
output_partition_sizes
=
[
output_size
]
self
.
sp_parallel
=
sp_parallel
super
().
__init__
(
input_size
,
super
().
__init__
(
input_size
,
output_size
,
output_size
,
...
@@ -1526,7 +1609,10 @@ class RowParallelLinear(LinearBase):
...
@@ -1526,7 +1609,10 @@ class RowParallelLinear(LinearBase):
if
envs
.
VLLM_ENABLE_TBO
:
if
envs
.
VLLM_ENABLE_TBO
:
output
=
self
.
tbo_all_reduce
(
output_parallel
)
output
=
self
.
tbo_all_reduce
(
output_parallel
)
else
:
else
:
output
=
tensor_model_parallel_all_reduce
(
output_parallel
)
if
self
.
sp_parallel
:
output
=
tensor_model_parallel_reduce_scatter
(
output_parallel
.
contiguous
(),
dim
=
0
)
else
:
output
=
tensor_model_parallel_all_reduce
(
output_parallel
)
else
:
else
:
output
=
output_parallel
output
=
output_parallel
...
...
vllm/model_executor/models/deepseek_v2.py
View file @
3f5c2eea
...
@@ -29,10 +29,11 @@ import vllm.envs as envs
...
@@ -29,10 +29,11 @@ import vllm.envs as envs
import
typing
import
typing
from
collections.abc
import
Callable
,
Iterable
from
collections.abc
import
Callable
,
Iterable
from
typing
import
Any
,
Optional
,
Union
from
typing
import
Any
,
Optional
,
Union
,
Dict
import
torch
import
torch
from
torch
import
nn
from
torch
import
nn
import
torch.nn.functional
as
F
from
transformers
import
PretrainedConfig
from
transformers
import
PretrainedConfig
from
vllm.attention
import
Attention
from
vllm.attention
import
Attention
...
@@ -40,12 +41,17 @@ from vllm.compilation.decorators import support_torch_compile
...
@@ -40,12 +41,17 @@ from vllm.compilation.decorators import support_torch_compile
from
vllm.config
import
(
CacheConfig
,
ModelConfig
,
VllmConfig
,
from
vllm.config
import
(
CacheConfig
,
ModelConfig
,
VllmConfig
,
get_current_vllm_config
)
get_current_vllm_config
)
from
vllm.distributed
import
(
get_ep_group
,
get_pp_group
,
get_dp_group
,
from
vllm.distributed
import
(
get_ep_group
,
get_pp_group
,
get_dp_group
,
get_tensor_model_parallel_world_size
)
get_tensor_model_parallel_world_size
,
tensor_model_parallel_all_gather
,
tensor_model_parallel_reduce_scatter
,
get_tensor_model_parallel_rank
)
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.fused_moe
import
FusedMoE
from
vllm.model_executor.layers.fused_moe
import
FusedMoE
from
vllm.model_executor.layers.fused_moe.layer
import
EventType
,
AuxStreamType
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
MergedColumnParallelLinear
,
MergedColumnParallelLinear
,
MergedReplicatedLinear
,
ReplicatedLinear
,
ReplicatedLinear
,
RowParallelLinear
)
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
...
@@ -64,6 +70,9 @@ from .utils import (PPMissingLayer, is_pp_missing_parameter,
...
@@ -64,6 +70,9 @@ from .utils import (PPMissingLayer, is_pp_missing_parameter,
maybe_prefix
)
maybe_prefix
)
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.utils
import
W8a8GetCacheJSON
from
vllm.utils
import
W8a8GetCacheJSON
from
vllm.logger
import
init_logger
logger
=
init_logger
(
__name__
)
class
DeepseekV2MLP
(
nn
.
Module
):
class
DeepseekV2MLP
(
nn
.
Module
):
...
@@ -75,6 +84,7 @@ class DeepseekV2MLP(nn.Module):
...
@@ -75,6 +84,7 @@ class DeepseekV2MLP(nn.Module):
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
reduce_results
:
bool
=
True
,
reduce_results
:
bool
=
True
,
prefix
:
str
=
""
,
prefix
:
str
=
""
,
enable_tpsp
:
bool
=
False
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
()
self
.
gate_up_proj
=
MergedColumnParallelLinear
(
self
.
gate_up_proj
=
MergedColumnParallelLinear
(
...
@@ -82,12 +92,14 @@ class DeepseekV2MLP(nn.Module):
...
@@ -82,12 +92,14 @@ class DeepseekV2MLP(nn.Module):
bias
=
False
,
bias
=
False
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.gate_up_proj"
)
prefix
=
f
"
{
prefix
}
.gate_up_proj"
)
self
.
enable_tpsp
=
enable_tpsp
self
.
down_proj
=
RowParallelLinear
(
intermediate_size
,
self
.
down_proj
=
RowParallelLinear
(
intermediate_size
,
hidden_size
,
hidden_size
,
bias
=
False
,
bias
=
False
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
reduce_results
=
reduce_results
,
reduce_results
=
reduce_results
,
prefix
=
f
"
{
prefix
}
.down_proj"
)
prefix
=
f
"
{
prefix
}
.down_proj"
,
sp_parallel
=
self
.
enable_tpsp
)
if
hidden_act
!=
"silu"
:
if
hidden_act
!=
"silu"
:
raise
ValueError
(
f
"Unsupported activation:
{
hidden_act
}
. "
raise
ValueError
(
f
"Unsupported activation:
{
hidden_act
}
. "
"Only silu is supported for now."
)
"Only silu is supported for now."
)
...
@@ -108,12 +120,57 @@ class DeepseekV2MLP(nn.Module):
...
@@ -108,12 +120,57 @@ class DeepseekV2MLP(nn.Module):
return
x
,
new_resi
return
x
,
new_resi
else
:
else
:
if
self
.
enable_tpsp
:
x
=
tensor_model_parallel_all_gather
(
x
.
contiguous
(),
dim
=
0
)
gate_up
,
_
=
self
.
gate_up_proj
(
x
)
gate_up
,
_
=
self
.
gate_up_proj
(
x
)
x
=
self
.
act_fn
(
gate_up
)
x
=
self
.
act_fn
(
gate_up
)
x
,
_
=
self
.
down_proj
(
x
)
x
,
_
=
self
.
down_proj
(
x
)
return
x
return
x
class SharedExpertOverlapSPMLP(nn.Module):
    """Shared-expert MLP whose stages are gated by externally recorded events.

    Both projections are replicated (non-sharded) linear layers. ``forward``
    waits on events from ``event_dict`` before each projection and records
    ``EventType.MoeShared`` once the output has been enqueued.

    Args:
        hidden_size: Input and output feature size.
        intermediate_size: Size of each of the two fused gate/up projections.
        hidden_act: Activation name; only ``"silu"`` is accepted.
        quant_config: Quantization configuration passed to both projections.
        reduce_results: Accepted but not used by this class.
        prefix: Parameter-name prefix for the sub-layers.
        event_dict: Mapping that must contain ``EventType.MoeAllgather``,
            ``EventType.MoeReduceScatter`` and ``EventType.MoeShared``.
        aux_stream: Stored on the instance; not used inside this class.

    Raises:
        ValueError: If ``hidden_act`` is not ``"silu"``.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
        reduce_results: bool = True,
        prefix: str = "",
        event_dict: dict = None,
        aux_stream: torch.cuda.Stream = None,
    ) -> None:
        super().__init__()

        # Gate and up projections fused into one replicated matmul.
        fused_sizes = [intermediate_size] * 2
        self.gate_up_proj = MergedReplicatedLinear(
            hidden_size,
            fused_sizes,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj",
        )
        self.down_proj = ReplicatedLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.down_proj",
        )

        self.event_dict = event_dict
        self.aux_stream = aux_stream

        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        self.act_fn = SiluAndMul()

    def forward(self, x):
        """Run gate/up -> SiLU-and-mul -> down, synchronized via events."""
        events = self.event_dict

        # Do not start until the MoeAllgather event has completed.
        events[EventType.MoeAllgather].wait()
        fused, _ = self.gate_up_proj(x)
        activated = self.act_fn(fused)

        # Hold the down projection until the MoeReduceScatter event.
        events[EventType.MoeReduceScatter].wait()
        out, _ = self.down_proj(activated)

        # Signal that the shared-expert output has been enqueued.
        events[EventType.MoeShared].record()
        return out
class
DeepseekV2MoE
(
nn
.
Module
):
class
DeepseekV2MoE
(
nn
.
Module
):
def
__init__
(
def
__init__
(
...
@@ -121,7 +178,9 @@ class DeepseekV2MoE(nn.Module):
...
@@ -121,7 +178,9 @@ class DeepseekV2MoE(nn.Module):
config
:
PretrainedConfig
,
config
:
PretrainedConfig
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
prefix
:
str
=
""
,
trt_aux_stream_dict
:
Dict
[
AuxStreamType
,
torch
.
cuda
.
Stream
]
=
{},
enable_eplb
:
bool
=
False
,
enable_eplb
:
bool
=
False
,
enable_tpsp
:
bool
=
False
,
):
):
super
().
__init__
()
super
().
__init__
()
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
...
@@ -137,6 +196,7 @@ class DeepseekV2MoE(nn.Module):
...
@@ -137,6 +196,7 @@ class DeepseekV2MoE(nn.Module):
raise
ValueError
(
f
"Unsupported activation:
{
config
.
hidden_act
}
. "
raise
ValueError
(
f
"Unsupported activation:
{
config
.
hidden_act
}
. "
"Only silu is supported for now."
)
"Only silu is supported for now."
)
self
.
enable_tpsp
=
enable_tpsp
self
.
gate
=
ReplicatedLinear
(
config
.
hidden_size
,
self
.
gate
=
ReplicatedLinear
(
config
.
hidden_size
,
config
.
n_routed_experts
,
config
.
n_routed_experts
,
bias
=
False
,
bias
=
False
,
...
@@ -182,28 +242,93 @@ class DeepseekV2MoE(nn.Module):
...
@@ -182,28 +242,93 @@ class DeepseekV2MoE(nn.Module):
num_redundant_experts
=
self
.
n_redundant_experts
,
num_redundant_experts
=
self
.
n_redundant_experts
,
routed_scaling_factor
=
self
.
routed_scaling_factor
)
routed_scaling_factor
=
self
.
routed_scaling_factor
)
self
.
aux_stream
=
trt_aux_stream_dict
[
AuxStreamType
.
MoeShared
]
self
.
event_dict
=
{
key
:
torch
.
cuda
.
Event
()
for
key
in
[
EventType
.
Main
,
EventType
.
MoeShared
,
EventType
.
MoeAllgather
,
EventType
.
MoeReduceScatter
]
}
if
config
.
n_shared_experts
is
not
None
:
if
config
.
n_shared_experts
is
not
None
:
intermediate_size
=
(
config
.
moe_intermediate_size
*
intermediate_size
=
(
config
.
moe_intermediate_size
*
config
.
n_shared_experts
)
config
.
n_shared_experts
)
self
.
shared_experts
=
DeepseekV2MLP
(
if
self
.
enable_tpsp
:
hidden_size
=
config
.
hidden_size
,
self
.
shared_experts
=
SharedExpertOverlapSPMLP
(
intermediate_size
=
intermediate_size
,
hidden_size
=
config
.
hidden_size
,
hidden_act
=
config
.
hidden_act
,
intermediate_size
=
intermediate_size
,
quant_config
=
quant_config
,
hidden_act
=
config
.
hidden_act
,
reduce_results
=
self
.
experts
.
must_reduce_shared_expert_outputs
(
quant_config
=
quant_config
,
),
reduce_results
=
False
,
prefix
=
f
"
{
prefix
}
.shared_experts"
,
prefix
=
f
"
{
prefix
}
.shared_experts"
,
)
event_dict
=
self
.
event_dict
,
aux_stream
=
self
.
aux_stream
)
else
:
self
.
shared_experts
=
DeepseekV2MLP
(
hidden_size
=
config
.
hidden_size
,
intermediate_size
=
intermediate_size
,
hidden_act
=
config
.
hidden_act
,
quant_config
=
quant_config
,
reduce_results
=
self
.
experts
.
must_reduce_shared_expert_outputs
(
),
prefix
=
f
"
{
prefix
}
.shared_experts"
,
)
from
vllm.two_batch_overlap.two_batch_overlap
import
tbo_all_reduce
from
vllm.two_batch_overlap.two_batch_overlap
import
tbo_all_reduce
self
.
tbo_all_reduce
=
tbo_all_reduce
self
.
tbo_all_reduce
=
tbo_all_reduce
def tpsp_forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
    """MoE forward pass for the TP + sequence-parallel path.

    The local tokens are routed, all-gathered along dim 0, run through the
    routed experts, and reduce-scattered back along dim 0. The shared
    experts run on ``self.aux_stream`` over the pre-gather tokens, with
    events from ``self.event_dict`` marking the synchronization points.

    Args:
        hidden_states: 2-D tensor; the last dimension is the hidden size.

    Returns:
        The reduce-scattered routed-expert output, plus the shared-expert
        output when shared experts are configured.
    """
    _, hidden_dim = hidden_states.shape
    local_tokens = hidden_states.view(-1, hidden_dim)

    router_logits, _ = self.gate(local_tokens)

    # Marks the point after which the shared experts may start.
    self.event_dict[EventType.MoeAllgather].record()
    gathered_tokens = tensor_model_parallel_all_gather(
        local_tokens.contiguous(), dim=0)
    gathered_logits = tensor_model_parallel_all_gather(
        router_logits.contiguous(), dim=0)

    is_fp16 = gathered_tokens.dtype == torch.float16

    final_hidden_states = self.experts(hidden_states=gathered_tokens,
                                       router_logits=gathered_logits)
    if not is_fp16:
        final_hidden_states = (final_hidden_states *
                               self.routed_scaling_factor)
    # else: Fix FP16 overflow -- the routed output is left unscaled and the
    # shared output is scaled down below instead.
    # See DeepseekV2DecoderLayer for more details.

    self.event_dict[EventType.MoeReduceScatter].record()
    final_hidden_states = tensor_model_parallel_reduce_scatter(
        final_hidden_states.contiguous(), dim=0)

    shared_output = None
    if self.n_shared_experts is not None:
        with torch.cuda.stream(self.aux_stream):
            shared_output = self.shared_experts(local_tokens)

    # Wait for the shared-expert result before combining.
    self.event_dict[EventType.MoeShared].wait()

    if shared_output is None:
        return final_hidden_states

    if is_fp16:
        # Fix FP16 overflow
        # See DeepseekV2DecoderLayer for more details.
        return final_hidden_states + shared_output \
            * (1. / self.routed_scaling_factor)
    return final_hidden_states + shared_output
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
rms_weight
:
Optional
[
torch
.
Tensor
]
=
None
,
rms_weight
:
Optional
[
torch
.
Tensor
]
=
None
,
residual
:
Optional
[
torch
.
Tensor
]
=
None
residual
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
if
self
.
enable_tpsp
:
return
self
.
tpsp_forward
(
hidden_states
)
is_graph_capturing
=
True
do_multi_stream
=
is_graph_capturing
num_tokens
,
hidden_dim
=
hidden_states
.
shape
num_tokens
,
hidden_dim
=
hidden_states
.
shape
hidden_states
=
hidden_states
.
view
(
-
1
,
hidden_dim
)
hidden_states
=
hidden_states
.
view
(
-
1
,
hidden_dim
)
if
do_multi_stream
:
self
.
event_dict
[
EventType
.
Main
].
record
()
if
self
.
n_shared_experts
is
not
None
:
if
self
.
n_shared_experts
is
not
None
:
if
envs
.
USE_FUSED_RMS_QUANT
:
if
envs
.
USE_FUSED_RMS_QUANT
:
...
@@ -211,33 +336,51 @@ class DeepseekV2MoE(nn.Module):
...
@@ -211,33 +336,51 @@ class DeepseekV2MoE(nn.Module):
else
:
else
:
shared_output
=
self
.
shared_experts
(
hidden_states
)
shared_output
=
self
.
shared_experts
(
hidden_states
)
# router_logits: (num_tokens, n_experts)
if
do_multi_stream
:
router_logits
,
_
=
self
.
gate
(
hidden_states
)
with
torch
.
cuda
.
stream
(
self
.
aux_stream
):
self
.
event_dict
[
EventType
.
Main
].
wait
()
if
envs
.
VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD
:
# router_logits: (num_tokens, n_experts)
final_hidden_states
=
self
.
experts
(
router_logits
,
_
=
self
.
gate
(
hidden_states
)
hidden_states
=
hidden_states
,
if
hidden_states
.
dtype
!=
torch
.
float16
:
router_logits
=
router_logits
,
final_hidden_states
=
self
.
experts
(
shared_output
=
shared_output
)
hidden_states
=
hidden_states
,
router_logits
=
router_logits
)
*
self
.
routed_scaling_factor
else
:
# Fix FP16 overflow
# See DeepseekV2DecoderLayer for more details.
final_hidden_states
=
self
.
experts
(
hidden_states
=
hidden_states
,
router_logits
=
router_logits
)
self
.
event_dict
[
EventType
.
MoeShared
].
record
()
self
.
event_dict
[
EventType
.
MoeShared
].
wait
()
else
:
else
:
if
hidden_states
.
dtype
!=
torch
.
float16
:
# router_logits: (num_tokens, n_experts)
router_logits
,
_
=
self
.
gate
(
hidden_states
)
if
envs
.
VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD
:
final_hidden_states
=
self
.
experts
(
final_hidden_states
=
self
.
experts
(
hidden_states
=
hidden_states
,
hidden_states
=
hidden_states
,
router_logits
=
router_logits
)
*
self
.
routed_scaling_factor
router_logits
=
router_logits
,
shared_output
=
shared_output
)
else
:
else
:
# Fix FP16 overflow
# See DeepseekV2DecoderLayer for more details.
final_hidden_states
=
self
.
experts
(
hidden_states
=
hidden_states
,
router_logits
=
router_logits
)
if
shared_output
is
not
None
:
if
hidden_states
.
dtype
!=
torch
.
float16
:
if
hidden_states
.
dtype
!=
torch
.
float16
:
final_hidden_states
=
final_hidden_states
+
shared_output
final_hidden_states
=
self
.
experts
(
hidden_states
=
hidden_states
,
router_logits
=
router_logits
)
*
self
.
routed_scaling_factor
else
:
else
:
# Fix FP16 overflow
# Fix FP16 overflow
# See DeepseekV2DecoderLayer for more details.
# See DeepseekV2DecoderLayer for more details.
final_hidden_states
=
final_hidden_states
+
shared_output
\
final_hidden_states
=
self
.
experts
(
hidden_states
=
hidden_states
,
*
(
1.
/
self
.
routed_scaling_factor
)
router_logits
=
router_logits
)
if
shared_output
is
not
None
:
if
hidden_states
.
dtype
!=
torch
.
float16
:
final_hidden_states
=
final_hidden_states
+
shared_output
else
:
# Fix FP16 overflow
# See DeepseekV2DecoderLayer for more details.
final_hidden_states
=
final_hidden_states
+
shared_output
\
*
(
1.
/
self
.
routed_scaling_factor
)
if
self
.
tp_size
>
1
:
if
self
.
tp_size
>
1
:
if
envs
.
VLLM_ENABLE_TBO
:
if
envs
.
VLLM_ENABLE_TBO
:
...
@@ -424,12 +567,15 @@ class DeepseekV2MLAAttention(nn.Module):
...
@@ -424,12 +567,15 @@ class DeepseekV2MLAAttention(nn.Module):
v_head_dim
:
int
,
v_head_dim
:
int
,
q_lora_rank
:
Optional
[
int
],
q_lora_rank
:
Optional
[
int
],
kv_lora_rank
:
int
,
kv_lora_rank
:
int
,
layer_idx
:
int
,
rope_theta
:
float
=
10000
,
rope_theta
:
float
=
10000
,
rope_scaling
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
rope_scaling
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
max_position_embeddings
:
int
=
8192
,
max_position_embeddings
:
int
=
8192
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
prefix
:
str
=
""
,
trt_aux_stream_dict
:
Dict
[
AuxStreamType
,
torch
.
cuda
.
Stream
]
=
{},
enable_tpsp
:
bool
=
False
,
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
()
self
.
hidden_size
=
hidden_size
self
.
hidden_size
=
hidden_size
...
@@ -449,6 +595,8 @@ class DeepseekV2MLAAttention(nn.Module):
...
@@ -449,6 +595,8 @@ class DeepseekV2MLAAttention(nn.Module):
self
.
scaling
=
self
.
qk_head_dim
**-
0.5
self
.
scaling
=
self
.
qk_head_dim
**-
0.5
self
.
rope_theta
=
rope_theta
self
.
rope_theta
=
rope_theta
self
.
max_position_embeddings
=
max_position_embeddings
self
.
max_position_embeddings
=
max_position_embeddings
self
.
layer_idx
=
layer_idx
self
.
enable_tpsp
=
enable_tpsp
if
self
.
q_lora_rank
is
not
None
:
if
self
.
q_lora_rank
is
not
None
:
if
envs
.
USE_FUSED_RMS_QUANT
:
if
envs
.
USE_FUSED_RMS_QUANT
:
...
@@ -489,12 +637,21 @@ class DeepseekV2MLAAttention(nn.Module):
...
@@ -489,12 +637,21 @@ class DeepseekV2MLAAttention(nn.Module):
quant_config
=
quant_config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.q_proj"
)
prefix
=
f
"
{
prefix
}
.q_proj"
)
self
.
kv_a_proj_with_mqa
=
ReplicatedLinear
(
if
not
envs
.
VLLM_ENABLE_MLA_QKV_MERGE
:
self
.
hidden_size
,
self
.
kv_a_proj_with_mqa
=
ReplicatedLinear
(
self
.
kv_lora_rank
+
self
.
qk_rope_head_dim
,
self
.
hidden_size
,
bias
=
False
,
self
.
kv_lora_rank
+
self
.
qk_rope_head_dim
,
quant_config
=
quant_config
,
bias
=
False
,
prefix
=
f
"
{
prefix
}
.kv_a_proj_with_mqa"
)
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.kv_a_proj_with_mqa"
)
else
:
self
.
q_a_and_kv_a_proj
=
MergedReplicatedLinear
(
self
.
hidden_size
,
[
self
.
q_lora_rank
,
self
.
kv_lora_rank
+
self
.
qk_rope_head_dim
],
bias
=
False
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.q_a_and_kv_a_proj"
)
self
.
kv_a_layernorm
=
RMSNorm
(
self
.
kv_lora_rank
,
self
.
kv_a_layernorm
=
RMSNorm
(
self
.
kv_lora_rank
,
eps
=
config
.
rms_norm_eps
)
eps
=
config
.
rms_norm_eps
)
self
.
kv_b_proj
=
ColumnParallelLinear
(
self
.
kv_b_proj
=
ColumnParallelLinear
(
...
@@ -507,7 +664,8 @@ class DeepseekV2MLAAttention(nn.Module):
...
@@ -507,7 +664,8 @@ class DeepseekV2MLAAttention(nn.Module):
self
.
hidden_size
,
self
.
hidden_size
,
bias
=
False
,
bias
=
False
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.o_proj"
)
prefix
=
f
"
{
prefix
}
.o_proj"
,
sp_parallel
=
self
.
enable_tpsp
)
if
rope_scaling
:
if
rope_scaling
:
rope_scaling
[
"rope_type"
]
=
'deepseek_yarn'
rope_scaling
[
"rope_type"
]
=
'deepseek_yarn'
...
@@ -550,6 +708,11 @@ class DeepseekV2MLAAttention(nn.Module):
...
@@ -550,6 +708,11 @@ class DeepseekV2MLAAttention(nn.Module):
self
.
prefix
=
prefix
self
.
prefix
=
prefix
self
.
debug_layer_idx
=
int
(
self
.
prefix
.
split
(
"."
)[
-
2
])
self
.
debug_layer_idx
=
int
(
self
.
prefix
.
split
(
"."
)[
-
2
])
self
.
aux_stream
=
trt_aux_stream_dict
[
AuxStreamType
.
Attention
]
self
.
event_dict
=
{
key
:
torch
.
cuda
.
Event
()
for
key
in
[
EventType
.
QCAllgather
,
EventType
.
KVFinish
]
}
def
forward
(
def
forward
(
self
,
self
,
...
@@ -588,34 +751,98 @@ class DeepseekV2MLAAttention(nn.Module):
...
@@ -588,34 +751,98 @@ class DeepseekV2MLAAttention(nn.Module):
self
.
num_local_heads
*
self
.
v_head_dim
))
self
.
num_local_heads
*
self
.
v_head_dim
))
return
self
.
o_proj
(
attn_out
)[
0
],
new_residual
return
self
.
o_proj
(
attn_out
)[
0
],
new_residual
else
:
else
:
if
self
.
q_lora_rank
is
not
None
:
if
not
self
.
enable_tpsp
:
q_c
=
self
.
q_a_proj
(
hidden_states
)[
0
]
if
not
envs
.
VLLM_ENABLE_MLA_QKV_MERGE
:
q_c
=
self
.
q_a_layernorm
(
q_c
)
if
self
.
q_lora_rank
is
not
None
:
q
=
self
.
q_b_proj
(
q_c
)[
0
]
q_c
=
self
.
q_a_proj
(
hidden_states
)[
0
]
else
:
q_c
=
self
.
q_a_layernorm
(
q_c
)
q
=
self
.
q_proj
(
hidden_states
)[
0
]
q
=
self
.
q_b_proj
(
q_c
)[
0
]
kv_c
,
k_pe
=
self
.
kv_a_proj_with_mqa
(
hidden_states
)[
0
].
split
(
else
:
q
=
self
.
q_proj
(
hidden_states
)[
0
]
kv_c
,
k_pe
=
self
.
kv_a_proj_with_mqa
(
hidden_states
)[
0
].
split
(
[
self
.
kv_lora_rank
,
self
.
qk_rope_head_dim
],
dim
=-
1
)
if
envs
.
VLLM_USE_LIGHTOP
:
kv_c_normed
=
self
.
kv_a_layernorm
.
forward_cuda_opt
(
kv_c
)
else
:
kv_c_normed
=
self
.
kv_a_layernorm
(
kv_c
.
contiguous
())
q
=
q
.
view
(
-
1
,
self
.
num_local_heads
,
self
.
qk_head_dim
)
# Add head dim of 1 to k_pe
k_pe
=
k_pe
.
unsqueeze
(
1
)
q
[...,
self
.
qk_nope_head_dim
:],
k_pe
=
self
.
rotary_emb
(
positions
,
q
[...,
self
.
qk_nope_head_dim
:],
k_pe
)
attn_out
=
self
.
mla_attn
(
q
,
kv_c_normed
,
k_pe
,
output_shape
=
(
hidden_states
.
shape
[
0
],
self
.
num_local_heads
*
self
.
v_head_dim
))
return
self
.
o_proj
(
attn_out
)[
0
]
else
:
if
self
.
q_lora_rank
is
not
None
:
qkv_lora
=
self
.
q_a_and_kv_a_proj
(
hidden_states
)[
0
]
q_c
,
kv_lora
=
qkv_lora
.
split
(
[
self
.
q_lora_rank
,
self
.
kv_lora_rank
+
self
.
qk_rope_head_dim
],
dim
=-
1
,
)
q_c
=
self
.
q_a_layernorm
(
q_c
)
q
=
self
.
q_b_proj
(
q_c
)[
0
]
else
:
hidden_states_or_q_c
=
hidden_states
kv_lora
=
self
.
kv_a_proj_with_mqa
(
hidden_states
)[
0
]
kv_c
,
k_pe
=
kv_lora
.
split
(
[
self
.
kv_lora_rank
,
self
.
qk_rope_head_dim
],
dim
=-
1
)
kv_c_normed
=
self
.
kv_a_layernorm
(
kv_c
)
q
=
q
.
view
(
-
1
,
self
.
num_local_heads
,
self
.
qk_head_dim
)
# Add head dim of 1 to k_pe
k_pe
=
k_pe
.
unsqueeze
(
1
)
q
[...,
self
.
qk_nope_head_dim
:],
k_pe
=
self
.
rotary_emb
(
positions
,
q
[...,
self
.
qk_nope_head_dim
:],
k_pe
)
attn_out
=
self
.
mla_attn
(
q
,
kv_c_normed
,
k_pe
,
output_shape
=
(
hidden_states
.
shape
[
0
],
self
.
num_local_heads
*
self
.
v_head_dim
))
return
self
.
o_proj
(
attn_out
)[
0
]
q_c
=
self
.
q_a_proj
(
hidden_states
)[
0
]
self
.
event_dict
[
EventType
.
QCAllgather
].
record
()
q_c
=
self
.
q_a_layernorm
(
q_c
)
if
self
.
layer_idx
>
0
:
q_c
=
tensor_model_parallel_all_gather
(
q_c
.
contiguous
(),
dim
=
0
)
with
torch
.
cuda
.
stream
(
self
.
aux_stream
):
self
.
event_dict
[
EventType
.
QCAllgather
].
wait
()
kv_a_out
=
self
.
kv_a_proj_with_mqa
(
hidden_states
)[
0
]
if
self
.
layer_idx
>
0
:
kv_a_out
=
tensor_model_parallel_all_gather
(
kv_a_out
.
contiguous
(),
dim
=
0
)
kv_c
,
k_pe
=
kv_a_out
.
split
(
[
self
.
kv_lora_rank
,
self
.
qk_rope_head_dim
],
dim
=-
1
)
[
self
.
kv_lora_rank
,
self
.
qk_rope_head_dim
],
dim
=-
1
)
if
envs
.
VLLM_USE_LIGHTOP
:
kv_c_normed
=
self
.
kv_a_layernorm
.
forward_cuda_opt
(
kv_c
)
else
:
kv_c_normed
=
self
.
kv_a_layernorm
(
kv_c
.
contiguous
())
q
=
q
.
view
(
-
1
,
self
.
num_local_heads
,
self
.
qk_head_dim
)
kv_c_normed
=
self
.
kv_a_layernorm
(
kv_c
.
contiguous
())
# Add head dim of 1 to k_pe
k_pe
=
k_pe
.
unsqueeze
(
1
)
q
[...,
self
.
qk_nope_head_dim
:],
k_pe
=
self
.
rotary_emb
(
self
.
event_dict
[
EventType
.
KVFinish
].
record
()
positions
,
q
[...,
self
.
qk_nope_head_dim
:],
k_pe
)
attn_out
=
self
.
mla_attn
(
q
=
self
.
q_b_proj
(
q_c
)[
0
]
q
,
kv_c_normed
,
k_pe
,
output_shape
=
(
hidden_states
.
shape
[
0
],
self
.
num_local_heads
*
self
.
v_head_dim
))
return
self
.
o_proj
(
attn_out
)[
0
]
self
.
event_dict
[
EventType
.
KVFinish
].
wait
()
attn_out
=
self
.
mla_attn
(
q
,
kv_c_normed
,
k_pe
,
output_shape
=
(
kv_a_out
.
shape
[
0
],
self
.
num_local_heads
*
self
.
v_head_dim
))
return
self
.
o_proj
(
attn_out
)[
0
]
class
DeepseekV2DecoderLayer
(
nn
.
Module
):
class
DeepseekV2DecoderLayer
(
nn
.
Module
):
...
@@ -627,6 +854,8 @@ class DeepseekV2DecoderLayer(nn.Module):
...
@@ -627,6 +854,8 @@ class DeepseekV2DecoderLayer(nn.Module):
cache_config
:
Optional
[
CacheConfig
]
=
None
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
enable_eplb
:
bool
=
False
,
enable_eplb
:
bool
=
False
,
trt_aux_stream_dict
:
Dict
[
AuxStreamType
,
torch
.
cuda
.
Stream
]
=
{},
mtp_layer
:
bool
=
False
,
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
()
self
.
hidden_size
=
config
.
hidden_size
self
.
hidden_size
=
config
.
hidden_size
...
@@ -638,6 +867,9 @@ class DeepseekV2DecoderLayer(nn.Module):
...
@@ -638,6 +867,9 @@ class DeepseekV2DecoderLayer(nn.Module):
# with the layer's index.
# with the layer's index.
layer_idx
=
int
(
prefix
.
split
(
sep
=
'.'
)[
-
1
])
layer_idx
=
int
(
prefix
.
split
(
sep
=
'.'
)[
-
1
])
self
.
layer_idx
=
layer_idx
self
.
layer_idx
=
layer_idx
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
tp_rank
=
get_tensor_model_parallel_rank
()
self
.
enable_tpsp
=
envs
.
VLLM_ENABLE_MLA_SP
and
self
.
tp_size
>
1
and
not
mtp_layer
if
model_config
.
use_mla
:
if
model_config
.
use_mla
:
attn_cls
=
DeepseekV2MLAAttention
attn_cls
=
DeepseekV2MLAAttention
else
:
else
:
...
@@ -658,6 +890,8 @@ class DeepseekV2DecoderLayer(nn.Module):
...
@@ -658,6 +890,8 @@ class DeepseekV2DecoderLayer(nn.Module):
cache_config
=
cache_config
,
cache_config
=
cache_config
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.self_attn"
,
prefix
=
f
"
{
prefix
}
.self_attn"
,
trt_aux_stream_dict
=
trt_aux_stream_dict
,
enable_tpsp
=
self
.
enable_tpsp
,
)
)
if
(
config
.
n_routed_experts
is
not
None
if
(
config
.
n_routed_experts
is
not
None
...
@@ -668,6 +902,8 @@ class DeepseekV2DecoderLayer(nn.Module):
...
@@ -668,6 +902,8 @@ class DeepseekV2DecoderLayer(nn.Module):
quant_config
=
quant_config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.mlp"
,
prefix
=
f
"
{
prefix
}
.mlp"
,
enable_eplb
=
enable_eplb
,
enable_eplb
=
enable_eplb
,
trt_aux_stream_dict
=
trt_aux_stream_dict
,
enable_tpsp
=
self
.
enable_tpsp
)
)
else
:
else
:
self
.
mlp
=
DeepseekV2MLP
(
self
.
mlp
=
DeepseekV2MLP
(
...
@@ -676,6 +912,7 @@ class DeepseekV2DecoderLayer(nn.Module):
...
@@ -676,6 +912,7 @@ class DeepseekV2DecoderLayer(nn.Module):
hidden_act
=
config
.
hidden_act
,
hidden_act
=
config
.
hidden_act
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.mlp"
,
prefix
=
f
"
{
prefix
}
.mlp"
,
enable_tpsp
=
self
.
enable_tpsp
,
)
)
self
.
input_layernorm
=
RMSNorm
(
config
.
hidden_size
,
self
.
input_layernorm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
eps
=
config
.
rms_norm_eps
)
...
@@ -758,6 +995,11 @@ class DeepseekV2DecoderLayer(nn.Module):
...
@@ -758,6 +995,11 @@ class DeepseekV2DecoderLayer(nn.Module):
# first layer.
# first layer.
residual
*=
1.
/
self
.
routed_scaling_factor
residual
*=
1.
/
self
.
routed_scaling_factor
# split residual into sp piece
if
self
.
layer_idx
==
0
and
self
.
enable_tpsp
:
residual_per_rank
=
torch
.
chunk
(
residual
,
chunks
=
self
.
tp_size
,
dim
=
0
)
residual
=
residual_per_rank
[
self
.
tp_rank
]
# Fully Connected
# Fully Connected
hidden_states
,
residual
=
self
.
post_attention_layernorm
(
hidden_states
,
residual
=
self
.
post_attention_layernorm
(
hidden_states
,
residual
)
hidden_states
,
residual
)
...
@@ -774,7 +1016,6 @@ class DeepseekV2DecoderLayer(nn.Module):
...
@@ -774,7 +1016,6 @@ class DeepseekV2DecoderLayer(nn.Module):
return
hidden_states
,
residual
return
hidden_states
,
residual
@
support_torch_compile
@
support_torch_compile
class
DeepseekV2Model
(
nn
.
Module
):
class
DeepseekV2Model
(
nn
.
Module
):
...
@@ -789,8 +1030,19 @@ class DeepseekV2Model(nn.Module):
...
@@ -789,8 +1030,19 @@ class DeepseekV2Model(nn.Module):
quant_config
=
vllm_config
.
quant_config
quant_config
=
vllm_config
.
quant_config
enable_eplb
=
vllm_config
.
parallel_config
.
enable_eplb
enable_eplb
=
vllm_config
.
parallel_config
.
enable_eplb
self
.
config
=
config
self
.
config
=
config
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
tp_rank
=
get_tensor_model_parallel_rank
()
self
.
vocab_size
=
config
.
vocab_size
self
.
vocab_size
=
config
.
vocab_size
self
.
aux_stream_dict
=
{
key
:
torch
.
cuda
.
Stream
()
for
key
in
[
AuxStreamType
.
Attention
,
AuxStreamType
.
MoeShared
,
AuxStreamType
.
MoeChunkingOverlap
]
}
if
get_pp_group
().
is_first_rank
:
if
get_pp_group
().
is_first_rank
:
self
.
embed_tokens
=
VocabParallelEmbedding
(
self
.
embed_tokens
=
VocabParallelEmbedding
(
...
@@ -810,9 +1062,12 @@ class DeepseekV2Model(nn.Module):
...
@@ -810,9 +1062,12 @@ class DeepseekV2Model(nn.Module):
cache_config
=
cache_config
,
cache_config
=
cache_config
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
enable_eplb
=
enable_eplb
,
enable_eplb
=
enable_eplb
,
trt_aux_stream_dict
=
self
.
aux_stream_dict
,
),
),
prefix
=
f
"
{
prefix
}
.layers"
)
prefix
=
f
"
{
prefix
}
.layers"
)
self
.
enable_tpsp
=
envs
.
VLLM_ENABLE_MLA_SP
and
self
.
tp_size
>
1
if
get_pp_group
().
is_last_rank
:
if
get_pp_group
().
is_last_rank
:
self
.
norm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
self
.
norm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
else
:
else
:
...
@@ -823,7 +1078,7 @@ class DeepseekV2Model(nn.Module):
...
@@ -823,7 +1078,7 @@ class DeepseekV2Model(nn.Module):
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
embed_tokens
(
input_ids
)
return
self
.
embed_tokens
(
input_ids
)
def
forward
(
def
forward
(
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
...
@@ -845,6 +1100,27 @@ class DeepseekV2Model(nn.Module):
...
@@ -845,6 +1100,27 @@ class DeepseekV2Model(nn.Module):
for
layer
in
self
.
layers
[
self
.
start_layer
:
self
.
end_layer
]:
for
layer
in
self
.
layers
[
self
.
start_layer
:
self
.
end_layer
]:
hidden_states
,
residual
=
layer
(
positions
,
hidden_states
,
residual
)
hidden_states
,
residual
=
layer
(
positions
,
hidden_states
,
residual
)
# padding tpsq bs to tp_size
tpsp_bs_pad
=
False
bs
=
input_ids
.
shape
[
0
]
bs_per_rank
=
(
bs
+
self
.
tp_size
-
1
)
//
self
.
tp_size
pad_bs
=
bs_per_rank
*
self
.
tp_size
if
bs
%
self
.
tp_size
!=
0
else
bs
if
self
.
enable_tpsp
and
pad_bs
!=
bs
:
tpsp_bs_pad
=
True
additional_hidden_state
=
torch
.
zeros
(
pad_bs
-
bs
,
hidden_states
.
shape
[
1
],
dtype
=
hidden_states
.
dtype
,
device
=
hidden_states
.
device
)
pad_hidden_state
=
torch
.
cat
([
hidden_states
,
additional_hidden_state
],
dim
=
0
).
contiguous
()
hidden_states
=
pad_hidden_state
if
residual
:
additional_residual
=
torch
.
zeros
(
pad_bs
-
bs
,
residual
.
shape
[
1
],
dtype
=
residual
.
dtype
,
device
=
residual
.
device
)
pad_residual
=
torch
.
cat
([
residual
,
additional_residual
],
dim
=
0
).
contiguous
()
residual
=
pad_residual
if
not
get_pp_group
().
is_last_rank
:
if
not
get_pp_group
().
is_last_rank
:
return
IntermediateTensors
({
return
IntermediateTensors
({
"hidden_states"
:
hidden_states
,
"hidden_states"
:
hidden_states
,
...
@@ -990,11 +1266,20 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts):
...
@@ -990,11 +1266,20 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts):
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]])
->
set
[
str
]:
torch
.
Tensor
]])
->
set
[
str
]:
stacked_params_mapping
=
[
if
not
envs
.
VLLM_ENABLE_MLA_QKV_MERGE
:
# (param_name, shard_name, shard_id)
stacked_params_mapping
=
[
(
"gate_up_proj"
,
"gate_proj"
,
0
),
# (param_name, shard_name, shard_id)
(
"gate_up_proj"
,
"up_proj"
,
1
),
(
"gate_up_proj"
,
"gate_proj"
,
0
),
]
(
"gate_up_proj"
,
"up_proj"
,
1
),
]
else
:
stacked_params_mapping
=
[
# (param_name, shard_name, shard_id)
(
"gate_up_proj"
,
"gate_proj"
,
0
),
(
"gate_up_proj"
,
"up_proj"
,
1
),
(
"q_a_and_kv_a_proj"
,
"q_a_proj"
,
0
),
(
"q_a_and_kv_a_proj"
,
"kv_a_proj_with_mqa"
,
1
),
]
# Params for weights, fp8 weight scales, fp8 activation scales
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
# (param_name, weight_name, expert_id, shard_id)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment