Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
aef3c487
Commit
aef3c487
authored
Apr 22, 2026
by
wangmin6
Committed by
zhangzbb
Apr 22, 2026
Browse files
[Feature]添加PCP功能,只支持mla架构,CPLB待验证
parent
c1819454
Changes
15
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
15 changed files
with
886 additions
and
47 deletions
+886
-47
vllm/config/parallel.py
vllm/config/parallel.py
+7
-0
vllm/config/vllm.py
vllm/config/vllm.py
+6
-0
vllm/engine/arg_utils.py
vllm/engine/arg_utils.py
+14
-0
vllm/envs.py
vllm/envs.py
+8
-1
vllm/forward_context.py
vllm/forward_context.py
+21
-0
vllm/model_executor/layers/mla.py
vllm/model_executor/layers/mla.py
+36
-0
vllm/model_executor/layers/sparse_attn_indexer.py
vllm/model_executor/layers/sparse_attn_indexer.py
+1
-1
vllm/model_executor/models/deepseek_mtp.py
vllm/model_executor/models/deepseek_mtp.py
+37
-1
vllm/model_executor/models/deepseek_v2.py
vllm/model_executor/models/deepseek_v2.py
+186
-7
vllm/v1/attention/backend.py
vllm/v1/attention/backend.py
+38
-0
vllm/v1/attention/backends/mla/flashmla_sparse.py
vllm/v1/attention/backends/mla/flashmla_sparse.py
+5
-2
vllm/v1/attention/backends/mla/indexer.py
vllm/v1/attention/backends/mla/indexer.py
+2
-0
vllm/v1/spec_decode/eagle.py
vllm/v1/spec_decode/eagle.py
+163
-6
vllm/v1/worker/block_table.py
vllm/v1/worker/block_table.py
+4
-0
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+358
-29
No files found.
vllm/config/parallel.py
View file @
aef3c487
...
@@ -299,6 +299,13 @@ class ParallelConfig:
...
@@ -299,6 +299,13 @@ class ParallelConfig:
should only be set by API server scale-out.
should only be set by API server scale-out.
"""
"""
enable_lightly_cp
:
bool
=
False
"""Use lightly context parallel."""
enable_lightly_cplb
:
bool
=
False
"""Use lightly context parallel load balancing."""
@
field_validator
(
"disable_nccl_for_dp_synchronization"
,
mode
=
"wrap"
)
@
field_validator
(
"disable_nccl_for_dp_synchronization"
,
mode
=
"wrap"
)
@
classmethod
@
classmethod
def
_skip_none_validation
(
cls
,
value
:
Any
,
handler
:
Callable
)
->
Any
:
def
_skip_none_validation
(
cls
,
value
:
Any
,
handler
:
Callable
)
->
Any
:
...
...
vllm/config/vllm.py
View file @
aef3c487
...
@@ -1061,6 +1061,12 @@ class VllmConfig:
...
@@ -1061,6 +1061,12 @@ class VllmConfig:
# Handle the KV connector configs
# Handle the KV connector configs
self
.
_post_init_kv_transfer_config
()
self
.
_post_init_kv_transfer_config
()
if
self
.
parallel_config
.
enable_lightly_cp
and
not
self
.
model_config
.
enforce_eager
:
raise
ValueError
(
"Lightly context parallel currently only supports the eager mode!!!"
)
def
update_sizes_for_sequence_parallelism
(
self
,
possible_sizes
:
list
)
->
list
:
def
update_sizes_for_sequence_parallelism
(
self
,
possible_sizes
:
list
)
->
list
:
# remove the sizes that are not a multiple of tp_size when
# remove the sizes that are not a multiple of tp_size when
# sequence parallelism is enabled
# sequence parallelism is enabled
...
...
vllm/engine/arg_utils.py
View file @
aef3c487
...
@@ -585,6 +585,9 @@ class EngineArgs:
...
@@ -585,6 +585,9 @@ class EngineArgs:
kv_offloading_backend
:
KVOffloadingBackend
=
CacheConfig
.
kv_offloading_backend
kv_offloading_backend
:
KVOffloadingBackend
=
CacheConfig
.
kv_offloading_backend
tokens_only
:
bool
=
False
tokens_only
:
bool
=
False
enable_lightly_cp
:
bool
=
ParallelConfig
.
enable_lightly_cp
enable_lightly_cplb
:
bool
=
ParallelConfig
.
enable_lightly_cplb
def
__post_init__
(
self
):
def
__post_init__
(
self
):
# support `EngineArgs(compilation_config={...})`
# support `EngineArgs(compilation_config={...})`
# without having to manually construct a
# without having to manually construct a
...
@@ -902,6 +905,15 @@ class EngineArgs:
...
@@ -902,6 +905,15 @@ class EngineArgs:
"--worker-extension-cls"
,
**
parallel_kwargs
[
"worker_extension_cls"
]
"--worker-extension-cls"
,
**
parallel_kwargs
[
"worker_extension_cls"
]
)
)
parallel_group
.
add_argument
(
"--enable-lightly-cp"
,
**
parallel_kwargs
[
"enable_lightly_cp"
],
)
parallel_group
.
add_argument
(
"--enable-lightly-cplb"
,
**
parallel_kwargs
[
"enable_lightly_cplb"
],
)
# KV cache arguments
# KV cache arguments
cache_kwargs
=
get_kwargs
(
CacheConfig
)
cache_kwargs
=
get_kwargs
(
CacheConfig
)
cache_group
=
parser
.
add_argument_group
(
cache_group
=
parser
.
add_argument_group
(
...
@@ -1661,6 +1673,8 @@ class EngineArgs:
...
@@ -1661,6 +1673,8 @@ class EngineArgs:
cp_kv_cache_interleave_size
=
self
.
cp_kv_cache_interleave_size
,
cp_kv_cache_interleave_size
=
self
.
cp_kv_cache_interleave_size
,
_api_process_count
=
self
.
_api_process_count
,
_api_process_count
=
self
.
_api_process_count
,
_api_process_rank
=
self
.
_api_process_rank
,
_api_process_rank
=
self
.
_api_process_rank
,
enable_lightly_cp
=
self
.
enable_lightly_cp
,
enable_lightly_cplb
=
self
.
enable_lightly_cplb
,
)
)
speculative_config
=
self
.
create_speculative_config
(
speculative_config
=
self
.
create_speculative_config
(
...
...
vllm/envs.py
View file @
aef3c487
...
@@ -324,6 +324,9 @@ if TYPE_CHECKING:
...
@@ -324,6 +324,9 @@ if TYPE_CHECKING:
USE_LIGHTOP_TOPK
:
bool
=
False
USE_LIGHTOP_TOPK
:
bool
=
False
USE_LIGHTOP_CONVERT_REQ_INDEX_TO_GLOBAL_INDEX
:
bool
=
False
USE_LIGHTOP_CONVERT_REQ_INDEX_TO_GLOBAL_INDEX
:
bool
=
False
VLLM_DISABLE_DSA
:
bool
=
False
VLLM_DISABLE_DSA
:
bool
=
False
VLLM_LIGHTLY_CP_THRESHOULD
:
int
=
2048
def
get_default_cache_root
():
def
get_default_cache_root
():
return
os
.
getenv
(
return
os
.
getenv
(
"XDG_CACHE_HOME"
,
"XDG_CACHE_HOME"
,
...
@@ -2004,7 +2007,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
...
@@ -2004,7 +2007,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
# If set to 1/True, disable the DSA.
# If set to 1/True, disable the DSA.
"VLLM_DISABLE_DSA"
:
"VLLM_DISABLE_DSA"
:
lambda
:
(
os
.
environ
.
get
(
"VLLM_DISABLE_DSA"
,
"False"
).
lower
()
in
lambda
:
(
os
.
environ
.
get
(
"VLLM_DISABLE_DSA"
,
"False"
).
lower
()
in
(
"true"
,
"1"
)),
(
"true"
,
"1"
)),
# MLA_CP open threshold
"VLLM_LIGHTLY_CP_THRESHOULD"
:
lambda
:
int
(
os
.
getenv
(
"VLLM_LIGHTLY_CP_THRESHOULD"
,
"2048"
)),
}
}
# --8<-- [end:env-vars-definition]
# --8<-- [end:env-vars-definition]
...
...
vllm/forward_context.py
View file @
aef3c487
...
@@ -240,6 +240,11 @@ class ForwardContext:
...
@@ -240,6 +240,11 @@ class ForwardContext:
additional_kwargs
:
dict
[
str
,
Any
]
=
field
(
default_factory
=
dict
)
additional_kwargs
:
dict
[
str
,
Any
]
=
field
(
default_factory
=
dict
)
scatter_indexes_tensor
:
torch
.
Tensor
|
None
=
None
gather_indexes_tensor
:
torch
.
Tensor
|
None
=
None
enable_lightly_cp
:
bool
=
False
enable_lightly_cplb
:
bool
=
False
def
__post_init__
(
self
):
def
__post_init__
(
self
):
assert
self
.
cudagraph_runtime_mode
.
valid_runtime_modes
(),
(
assert
self
.
cudagraph_runtime_mode
.
valid_runtime_modes
(),
(
f
"Invalid cudagraph runtime mode:
{
self
.
cudagraph_runtime_mode
}
"
f
"Invalid cudagraph runtime mode:
{
self
.
cudagraph_runtime_mode
}
"
...
@@ -273,6 +278,10 @@ def create_forward_context(
...
@@ -273,6 +278,10 @@ def create_forward_context(
slot_mapping
:
dict
[
str
,
torch
.
Tensor
]
|
None
=
None
,
slot_mapping
:
dict
[
str
,
torch
.
Tensor
]
|
None
=
None
,
additional_kwargs
:
dict
[
str
,
Any
]
|
None
=
None
,
additional_kwargs
:
dict
[
str
,
Any
]
|
None
=
None
,
skip_compiled
:
bool
=
False
,
skip_compiled
:
bool
=
False
,
scatter_indexes_tensor
:
torch
.
Tensor
|
None
=
None
,
gather_indexes_tensor
:
torch
.
Tensor
|
None
=
None
,
enable_lightly_cp
:
bool
=
False
,
enable_lightly_cplb
:
bool
=
False
):
):
if
vllm_config
.
compilation_config
.
fast_moe_cold_start
:
if
vllm_config
.
compilation_config
.
fast_moe_cold_start
:
if
vllm_config
.
speculative_config
is
None
:
if
vllm_config
.
speculative_config
is
None
:
...
@@ -298,6 +307,10 @@ def create_forward_context(
...
@@ -298,6 +307,10 @@ def create_forward_context(
batch_descriptor
=
batch_descriptor
,
batch_descriptor
=
batch_descriptor
,
ubatch_slices
=
ubatch_slices
,
ubatch_slices
=
ubatch_slices
,
skip_compiled
=
skip_compiled
,
skip_compiled
=
skip_compiled
,
scatter_indexes_tensor
=
scatter_indexes_tensor
,
gather_indexes_tensor
=
gather_indexes_tensor
,
enable_lightly_cp
=
enable_lightly_cp
,
enable_lightly_cplb
=
enable_lightly_cplb
,
additional_kwargs
=
additional_kwargs
or
{},
additional_kwargs
=
additional_kwargs
or
{},
)
)
...
@@ -329,6 +342,10 @@ def set_forward_context(
...
@@ -329,6 +342,10 @@ def set_forward_context(
ubatch_slices
:
UBatchSlices
|
None
=
None
,
ubatch_slices
:
UBatchSlices
|
None
=
None
,
slot_mapping
:
dict
[
str
,
torch
.
Tensor
]
|
list
[
dict
[
str
,
torch
.
Tensor
]]
|
None
=
None
,
slot_mapping
:
dict
[
str
,
torch
.
Tensor
]
|
list
[
dict
[
str
,
torch
.
Tensor
]]
|
None
=
None
,
skip_compiled
:
bool
=
False
,
skip_compiled
:
bool
=
False
,
scatter_indexes_tensor
:
torch
.
Tensor
|
None
=
None
,
gather_indexes_tensor
:
torch
.
Tensor
|
None
=
None
,
enable_lightly_cp
:
bool
=
False
,
enable_lightly_cplb
:
bool
=
False
,
):
):
"""A context manager that stores the current forward context,
"""A context manager that stores the current forward context,
can be attention metadata, etc.
can be attention metadata, etc.
...
@@ -390,6 +407,10 @@ def set_forward_context(
...
@@ -390,6 +407,10 @@ def set_forward_context(
slot_mapping
,
slot_mapping
,
additional_kwargs
,
additional_kwargs
,
skip_compiled
,
skip_compiled
,
scatter_indexes_tensor
,
gather_indexes_tensor
,
enable_lightly_cp
,
enable_lightly_cplb
)
)
try
:
try
:
...
...
vllm/model_executor/layers/mla.py
View file @
aef3c487
...
@@ -7,8 +7,12 @@ import torch
...
@@ -7,8 +7,12 @@ import torch
from
vllm.attention.layer
import
MLAAttention
from
vllm.attention.layer
import
MLAAttention
from
vllm.config
import
CacheConfig
from
vllm.config
import
CacheConfig
import
vllm.envs
as
envs
import
vllm.envs
as
envs
from
vllm.forward_context
import
get_forward_context
from
vllm.model_executor.custom_op
import
PluggableLayer
from
vllm.model_executor.custom_op
import
PluggableLayer
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.distributed
import
(
tensor_model_parallel_all_gather
,
)
@
dataclass
@
dataclass
...
@@ -184,8 +188,26 @@ class MultiHeadLatentAttentionWrapper(PluggableLayer):
...
@@ -184,8 +188,26 @@ class MultiHeadLatentAttentionWrapper(PluggableLayer):
if
llama_4_scaling
is
not
None
:
if
llama_4_scaling
is
not
None
:
q
*=
llama_4_scaling
q
*=
llama_4_scaling
enable_lightly_cp
=
get_forward_context
().
enable_lightly_cp
enable_lightly_cplb
=
get_forward_context
().
enable_lightly_cplb
# if not use_fused_rms_rope_concat:
# if not use_fused_rms_rope_concat:
if
not
envs
.
VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT
:
if
not
envs
.
VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT
:
if
enable_lightly_cp
:
kv_c_normed
=
tensor_model_parallel_all_gather
(
kv_c_normed
.
contiguous
(),
0
)
k_pe
=
tensor_model_parallel_all_gather
(
k_pe
.
contiguous
(),
0
)
gather_indexes_tensor
=
get_forward_context
().
gather_indexes_tensor
if
enable_lightly_cplb
and
gather_indexes_tensor
is
not
None
:
# Reorder kv after pcp allgather.
kv_c_normed
=
torch
.
index_select
(
kv_c_normed
,
0
,
gather_indexes_tensor
)
k_pe
=
torch
.
index_select
(
k_pe
,
0
,
gather_indexes_tensor
)
attn_out
=
self
.
mla_attn
(
attn_out
=
self
.
mla_attn
(
q
,
q
,
kv_c_normed
,
kv_c_normed
,
...
@@ -221,6 +243,20 @@ class MultiHeadLatentAttentionWrapper(PluggableLayer):
...
@@ -221,6 +243,20 @@ class MultiHeadLatentAttentionWrapper(PluggableLayer):
"VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT requires rotary_emb to "
"VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT requires rotary_emb to "
"expose 'cos_sin_cache'."
"expose 'cos_sin_cache'."
)
)
if
enable_lightly_cp
:
kv_c
=
tensor_model_parallel_all_gather
(
kv_c
.
contiguous
(),
0
)
k_pe
=
tensor_model_parallel_all_gather
(
k_pe
.
contiguous
(),
0
)
gather_indexes_tensor
=
get_forward_context
().
gather_indexes_tensor
if
enable_lightly_cplb
and
gather_indexes_tensor
is
not
None
:
# Reorder kv after pcp allgather.
kv_c
=
torch
.
index_select
(
kv_c
,
0
,
gather_indexes_tensor
)
k_pe
=
torch
.
index_select
(
k_pe
,
0
,
gather_indexes_tensor
)
attn_out
=
self
.
mla_attn
(
attn_out
=
self
.
mla_attn
(
q
[...,
self
.
qk_nope_head_dim
:],
q
[...,
self
.
qk_nope_head_dim
:],
kv_c
,
kv_c
,
...
...
vllm/model_executor/layers/sparse_attn_indexer.py
View file @
aef3c487
...
@@ -90,7 +90,7 @@ def sparse_attn_indexer(
...
@@ -90,7 +90,7 @@ def sparse_attn_indexer(
)
)
attn_metadata
=
attn_metadata
[
layer_name
]
attn_metadata
=
attn_metadata
[
layer_name
]
assert
isinstance
(
attn_metadata
,
DeepseekV32IndexerMetadata
)
assert
isinstance
(
attn_metadata
,
DeepseekV32IndexerMetadata
)
slot_mapping
=
attn_metadata
.
slot_mapping
slot_mapping
=
attn_metadata
.
slot_mapping
[:
attn_metadata
.
num_kv_actual_tokens
]
has_decode
=
attn_metadata
.
num_decodes
>
0
has_decode
=
attn_metadata
.
num_decodes
>
0
has_prefill
=
attn_metadata
.
num_prefills
>
0
has_prefill
=
attn_metadata
.
num_prefills
>
0
num_decode_tokens
=
attn_metadata
.
num_decode_tokens
num_decode_tokens
=
attn_metadata
.
num_decode_tokens
...
...
vllm/model_executor/models/deepseek_mtp.py
View file @
aef3c487
...
@@ -11,6 +11,7 @@ import torch
...
@@ -11,6 +11,7 @@ import torch
import
torch.nn
as
nn
import
torch.nn
as
nn
from
transformers
import
PretrainedConfig
from
transformers
import
PretrainedConfig
from
vllm.forward_context
import
get_forward_context
from
vllm._aiter_ops
import
rocm_aiter_ops
from
vllm._aiter_ops
import
rocm_aiter_ops
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
...
@@ -36,6 +37,9 @@ from .deepseek_v2 import (
...
@@ -36,6 +37,9 @@ from .deepseek_v2 import (
DeepseekV2MoE
,
DeepseekV2MoE
,
get_spec_layer_idx_from_weight_name
,
get_spec_layer_idx_from_weight_name
,
)
)
from
vllm.distributed
import
(
tensor_model_parallel_all_gather
,
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
)
from
.utils
import
maybe_prefix
from
.utils
import
maybe_prefix
from
.interfaces
import
SupportsPP
from
.interfaces
import
SupportsPP
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
...
@@ -177,6 +181,9 @@ class DeepSeekMultiTokenPredictor(nn.Module):
...
@@ -177,6 +181,9 @@ class DeepSeekMultiTokenPredictor(nn.Module):
)
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
tp_rank
=
get_tensor_model_parallel_rank
()
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
def
embed_input_ids
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
embed_input_ids
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
embed_tokens
(
input_ids
)
return
self
.
embed_tokens
(
input_ids
)
...
@@ -191,7 +198,28 @@ class DeepSeekMultiTokenPredictor(nn.Module):
...
@@ -191,7 +198,28 @@ class DeepSeekMultiTokenPredictor(nn.Module):
if
inputs_embeds
is
None
:
if
inputs_embeds
is
None
:
inputs_embeds
=
self
.
embed_tokens
(
input_ids
)
inputs_embeds
=
self
.
embed_tokens
(
input_ids
)
current_step_idx
=
spec_step_idx
%
self
.
num_mtp_layers
current_step_idx
=
spec_step_idx
%
self
.
num_mtp_layers
return
self
.
layers
[
str
(
self
.
mtp_start_layer_idx
+
current_step_idx
)](
enable_lightly_cp
=
get_forward_context
().
enable_lightly_cp
if
enable_lightly_cp
:
scatter_indexes_tensor
=
get_forward_context
().
scatter_indexes_tensor
if
scatter_indexes_tensor
is
None
:
inputs_embeds_per_rank
=
torch
.
chunk
(
inputs_embeds
,
chunks
=
self
.
tp_size
,
dim
=
0
)
inputs_embeds
=
inputs_embeds_per_rank
[
self
.
tp_rank
].
contiguous
()
previous_hidden_states_per_rank
=
torch
.
chunk
(
previous_hidden_states
,
chunks
=
self
.
tp_size
,
dim
=
0
)
previous_hidden_states
=
previous_hidden_states_per_rank
[
self
.
tp_rank
].
contiguous
()
if
positions
is
not
None
:
positions_per_rank
=
torch
.
chunk
(
positions
,
chunks
=
self
.
tp_size
,
dim
=
0
)
positions
=
positions_per_rank
[
self
.
tp_rank
].
contiguous
()
else
:
scatter_indexes_tensor
=
torch
.
where
(
scatter_indexes_tensor
==
-
1
,
0
,
scatter_indexes_tensor
)
inputs_embeds
=
torch
.
index_select
(
inputs_embeds
,
0
,
scatter_indexes_tensor
)
previous_hidden_states
=
torch
.
index_select
(
previous_hidden_states
,
0
,
scatter_indexes_tensor
)
if
positions
is
not
None
:
positions
=
torch
.
index_select
(
positions
,
0
,
scatter_indexes_tensor
)
hidden_states
=
self
.
layers
[
str
(
self
.
mtp_start_layer_idx
+
current_step_idx
)](
input_ids
,
input_ids
,
positions
,
positions
,
previous_hidden_states
,
previous_hidden_states
,
...
@@ -199,6 +227,14 @@ class DeepSeekMultiTokenPredictor(nn.Module):
...
@@ -199,6 +227,14 @@ class DeepSeekMultiTokenPredictor(nn.Module):
current_step_idx
,
current_step_idx
,
)
)
if
enable_lightly_cp
:
hidden_states
=
tensor_model_parallel_all_gather
(
hidden_states
.
contiguous
(),
dim
=
0
)
gather_indexes_tensor
=
get_forward_context
().
gather_indexes_tensor
if
gather_indexes_tensor
is
not
None
:
hidden_states
=
torch
.
index_select
(
hidden_states
,
0
,
gather_indexes_tensor
)
return
hidden_states
def
compute_logits
(
def
compute_logits
(
self
,
self
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
...
...
vllm/model_executor/models/deepseek_v2.py
View file @
aef3c487
...
@@ -46,6 +46,8 @@ from vllm.distributed import (
...
@@ -46,6 +46,8 @@ from vllm.distributed import (
get_tensor_model_parallel_world_size
,
get_tensor_model_parallel_world_size
,
tensor_model_parallel_all_gather
,
tensor_model_parallel_all_gather
,
)
)
from
vllm.forward_context
import
get_forward_context
from
vllm.distributed.communication_op
import
tensor_model_parallel_all_reduce
,
tensor_model_parallel_reduce_scatter
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.attention_layer_base
import
AttentionLayerBase
from
vllm.model_executor.layers.attention_layer_base
import
AttentionLayerBase
...
@@ -181,6 +183,44 @@ class DeepseekAttention(nn.Module):
...
@@ -181,6 +183,44 @@ class DeepseekAttention(nn.Module):
return
output
return
output
def iqis_all_gather(
    iqis: tuple[torch.Tensor, torch.Tensor],
    tp_size: int | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """All-gather an (int8 values, float32 scales) pair with one collective.

    The float32 scale column is reinterpreted as four raw int8 bytes per row
    and concatenated to the int8 values, so both tensors travel in a single
    tensor-parallel all-gather along dim 0. The gathered buffer is then split
    back and the scale bytes are reinterpreted as float32.

    Args:
        iqis: ``(iq_tensor, is_tensor)`` where ``iq_tensor`` is a 2-D int8
            tensor of shape ``[m_local, n]`` and ``is_tensor`` is a 2-D
            float32 tensor of shape ``[m_local, 1]``.
        tp_size: Number of ranks taking part in the all-gather; only used to
            sanity-check the gathered row count. When ``None`` the
            tensor-parallel world size is queried.

    Returns:
        ``(iq_gathered, is_gathered)``: an int8 tensor with
        ``m_local * tp_size`` rows and ``n`` columns, and a float32 tensor
        with ``m_local * tp_size`` rows and one column.
    """
    assert iqis is not None
    iq_tensor, is_tensor = iqis
    assert isinstance(iq_tensor, torch.Tensor)
    assert isinstance(is_tensor, torch.Tensor)
    assert iq_tensor.dtype == torch.int8, f"iq_tensor dtype is {iq_tensor.dtype}"
    assert is_tensor.dtype == torch.float32, f"is_tensor dtype is {is_tensor.dtype}"
    assert iq_tensor.dim() == 2
    assert is_tensor.dim() == 2

    # The default used to be multiplied as-is below, which raised TypeError
    # whenever the caller omitted tp_size.
    if tp_size is None:
        tp_size = get_tensor_model_parallel_world_size()

    m_local, n = iq_tensor.shape
    assert is_tensor.shape[0] == m_local, f"{is_tensor.shape[0]} != {iq_tensor.shape[0]}"
    assert is_tensor.shape[1] == 1, f"is_tensor dim 1 = {is_tensor.shape[1]}"

    # Reinterpret each float32 scale as 4 bytes so it can share the int8
    # buffer; torch.cat already yields a contiguous tensor.
    scale_bytes = is_tensor.view(torch.int8)
    combined_2d = torch.cat([iq_tensor, scale_bytes], dim=1)  # [m_local, n + 4]

    combined_gathered = tensor_model_parallel_all_gather(combined_2d, dim=0)

    # Column slices are strided views; copy them out so the float32
    # reinterpretation below operates on densely packed bytes.
    iq_gathered = combined_gathered[:, :n].contiguous()
    is_gathered_int8 = combined_gathered[:, n:].contiguous()

    assert iq_gathered.shape[0] == m_local * tp_size, f"iq_gathered dim0={iq_gathered.shape[0]}, expected {m_local * tp_size}"
    # is_gathered_int8 should be [m_local*tp_size, 4]
    assert is_gathered_int8.shape[0] == m_local * tp_size, f"is_gathered_int8 dim0={is_gathered_int8.shape[0]}, expected {m_local * tp_size}"
    assert is_gathered_int8.shape[1] == 4, f"is_gathered_int8 dim1={is_gathered_int8.shape[1]}"

    is_gathered = is_gathered_int8.view(torch.float32)
    return (iq_gathered, is_gathered)
class
DeepseekV2MLP
(
nn
.
Module
):
class
DeepseekV2MLP
(
nn
.
Module
):
def
__init__
(
def
__init__
(
self
,
self
,
...
@@ -211,10 +251,85 @@ class DeepseekV2MLP(nn.Module):
...
@@ -211,10 +251,85 @@ class DeepseekV2MLP(nn.Module):
hidden_size
,
hidden_size
,
bias
=
False
,
bias
=
False
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
reduce_results
=
reduce_results
,
#reduce_results=reduce_results,
reduce_results
=
False
,
disable_tp
=
is_sequence_parallel
,
disable_tp
=
is_sequence_parallel
,
prefix
=
f
"
{
prefix
}
.down_proj"
,
prefix
=
f
"
{
prefix
}
.down_proj"
,
)
)
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
if
hidden_act
!=
"silu"
:
raise
ValueError
(
f
"Unsupported activation:
{
hidden_act
}
. Only silu is supported for now."
)
self
.
act_fn
=
SiluAndMul
()
def forward(self, x, *, iqis: tuple[torch.Tensor, torch.Tensor] | None = None):
    """Apply gate_up_proj, the SiLU-and-mul activation and down_proj to ``x``.

    When the forward context has lightly context parallel enabled, the
    input is all-gathered along dim 0 across the tensor-parallel group
    before the projections and the output is reduce-scattered along dim 0
    afterwards. Otherwise the output is all-reduced when ``tp_size > 1``;
    ``down_proj`` is built with ``reduce_results=False``, so the reduction
    has to happen here.

    Args:
        x: Input activations.
        iqis: Optional pair forwarded to ``gate_up_proj`` when
            ``envs.USE_FUSED_RMS_QUANT`` is set. Presumably pre-quantized
            int8 activations and their float32 scales (that is what
            ``iqis_all_gather`` asserts) -- TODO confirm against the
            producing layernorm.

    Returns:
        The MLP output tensor.
    """
    lightly_cp = get_forward_context().enable_lightly_cp

    if lightly_cp:
        has_quantized_pair = (
            iqis is not None and iqis[0] is not None and iqis[1] is not None
        )
        if has_quantized_pair:
            # Gather the quantized pair instead of the raw activations.
            iqis = iqis_all_gather(iqis, tp_size=self.tp_size)
        else:
            x = tensor_model_parallel_all_gather(x.contiguous(), 0)

    if not envs.USE_FUSED_RMS_QUANT:
        projected, _ = self.gate_up_proj(x)
        x, _ = self.down_proj(self.act_fn(projected))
    else:
        projected, _ = self.gate_up_proj(x, iqis=iqis)
        if envs.USE_FUSED_SILU_MUL_QUANT:
            # Imported lazily so lmslim is only needed when this path runs.
            from lmslim.quantize.quant_ops import lm_fuse_silu_mul_quant

            act_q, act_scale = lm_fuse_silu_mul_quant(projected)
            x, _ = self.down_proj(projected, iqis=(act_q, act_scale))
        else:
            x, _ = self.down_proj(self.act_fn(projected))

    if lightly_cp:
        return tensor_model_parallel_reduce_scatter(x.contiguous(), dim=0)
    if self.tp_size > 1:
        return tensor_model_parallel_all_reduce(x)
    return x
class
DeepseekV2SharedMLP
(
nn
.
Module
):
def
__init__
(
self
,
hidden_size
:
int
,
intermediate_size
:
int
,
hidden_act
:
str
,
quant_config
:
QuantizationConfig
|
None
=
None
,
reduce_results
:
bool
=
True
,
is_sequence_parallel
=
False
,
prefix
:
str
=
""
,
)
->
None
:
super
().
__init__
()
# If is_sequence_parallel, the input and output tensors are sharded
# across the ranks within the tp_group. In this case the weights are
# replicated and no collective ops are needed.
# Otherwise we use standard TP with an allreduce at the end.
self
.
gate_up_proj
=
MergedColumnParallelLinear
(
hidden_size
,
[
intermediate_size
]
*
2
,
bias
=
False
,
quant_config
=
quant_config
,
disable_tp
=
is_sequence_parallel
,
prefix
=
f
"
{
prefix
}
.gate_up_proj"
,
)
self
.
down_proj
=
RowParallelLinear
(
intermediate_size
,
hidden_size
,
bias
=
False
,
quant_config
=
quant_config
,
reduce_results
=
reduce_results
,
disable_tp
=
is_sequence_parallel
,
prefix
=
f
"
{
prefix
}
.down_proj"
)
if
hidden_act
!=
"silu"
:
if
hidden_act
!=
"silu"
:
raise
ValueError
(
raise
ValueError
(
f
"Unsupported activation:
{
hidden_act
}
. Only silu is supported for now."
f
"Unsupported activation:
{
hidden_act
}
. Only silu is supported for now."
...
@@ -311,7 +426,7 @@ class DeepseekV2MoE(nn.Module):
...
@@ -311,7 +426,7 @@ class DeepseekV2MoE(nn.Module):
else
:
else
:
intermediate_size
=
config
.
moe_intermediate_size
*
config
.
n_shared_experts
intermediate_size
=
config
.
moe_intermediate_size
*
config
.
n_shared_experts
self
.
shared_experts
=
DeepseekV2MLP
(
self
.
shared_experts
=
DeepseekV2
Shared
MLP
(
hidden_size
=
config
.
hidden_size
,
hidden_size
=
config
.
hidden_size
,
intermediate_size
=
intermediate_size
,
intermediate_size
=
intermediate_size
,
hidden_act
=
config
.
hidden_act
,
hidden_act
=
config
.
hidden_act
,
...
@@ -357,6 +472,11 @@ class DeepseekV2MoE(nn.Module):
...
@@ -357,6 +472,11 @@ class DeepseekV2MoE(nn.Module):
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
*
,
iqis
:
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
|
None
=
None
*
,
iqis
:
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
|
None
=
None
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
enable_lightly_cp
=
get_forward_context
().
enable_lightly_cp
if
enable_lightly_cp
:
hidden_states
=
tensor_model_parallel_all_gather
(
hidden_states
.
contiguous
(),
0
)
num_tokens
,
hidden_dim
=
hidden_states
.
shape
num_tokens
,
hidden_dim
=
hidden_states
.
shape
hidden_states
=
hidden_states
.
view
(
-
1
,
hidden_dim
)
hidden_states
=
hidden_states
.
view
(
-
1
,
hidden_dim
)
...
@@ -428,7 +548,12 @@ class DeepseekV2MoE(nn.Module):
...
@@ -428,7 +548,12 @@ class DeepseekV2MoE(nn.Module):
assert
shared_output
is
not
None
assert
shared_output
is
not
None
final_hidden_states
+=
shared_output
final_hidden_states
+=
shared_output
if
self
.
is_sequence_parallel
:
if
enable_lightly_cp
:
final_hidden_states
=
tensor_model_parallel_reduce_scatter
(
final_hidden_states
.
contiguous
(),
0
)
return
final_hidden_states
elif
self
.
is_sequence_parallel
:
final_hidden_states
=
tensor_model_parallel_all_gather
(
final_hidden_states
=
tensor_model_parallel_all_gather
(
final_hidden_states
,
0
final_hidden_states
,
0
)
)
...
@@ -759,6 +884,16 @@ class Indexer(nn.Module):
...
@@ -759,6 +884,16 @@ class Indexer(nn.Module):
# `k_pe` is [num_tokens, 1, rope_dim] (MQA).
# `k_pe` is [num_tokens, 1, rope_dim] (MQA).
k
=
torch
.
cat
([
k_pe
.
squeeze
(
-
2
),
k_nope
],
dim
=-
1
)
k
=
torch
.
cat
([
k_pe
.
squeeze
(
-
2
),
k_nope
],
dim
=-
1
)
enable_lightly_cp
=
get_forward_context
().
enable_lightly_cp
if
enable_lightly_cp
:
k
=
tensor_model_parallel_all_gather
(
k
.
contiguous
(),
0
)
gather_indexes_tensor
=
get_forward_context
().
gather_indexes_tensor
enable_lightly_cplb
=
get_forward_context
().
enable_lightly_cplb
if
enable_lightly_cplb
and
gather_indexes_tensor
is
not
None
:
k
=
torch
.
index_select
(
k
,
0
,
gather_indexes_tensor
)
# we only quant q here since k quant is fused with cache insertion
# we only quant q here since k quant is fused with cache insertion
if
not
current_platform
.
is_rocm
()
or
torch
.
cuda
.
get_device_properties
(
"cuda"
).
gcnArchName
.
split
(
':'
)[
0
]
==
"gfx938"
:
if
not
current_platform
.
is_rocm
()
or
torch
.
cuda
.
get_device_properties
(
"cuda"
).
gcnArchName
.
split
(
':'
)[
0
]
==
"gfx938"
:
q
=
q
.
view
(
-
1
,
self
.
head_dim
)
q
=
q
.
view
(
-
1
,
self
.
head_dim
)
...
@@ -825,7 +960,8 @@ class DeepseekV2MLAAttention(nn.Module):
...
@@ -825,7 +960,8 @@ class DeepseekV2MLAAttention(nn.Module):
self
.
num_heads
=
num_heads
self
.
num_heads
=
num_heads
tp_size
=
get_tensor_model_parallel_world_size
()
tp_size
=
get_tensor_model_parallel_world_size
()
assert
num_heads
%
tp_size
==
0
assert
num_heads
%
tp_size
==
0
self
.
num_local_heads
=
num_heads
//
tp_size
self
.
num_local_heads
=
num_heads
//
tp_size
if
not
\
vllm_config
.
parallel_config
.
enable_lightly_cp
else
self
.
num_heads
self
.
scaling
=
self
.
qk_head_dim
**-
0.5
self
.
scaling
=
self
.
qk_head_dim
**-
0.5
self
.
max_position_embeddings
=
max_position_embeddings
self
.
max_position_embeddings
=
max_position_embeddings
...
@@ -859,6 +995,7 @@ class DeepseekV2MLAAttention(nn.Module):
...
@@ -859,6 +995,7 @@ class DeepseekV2MLAAttention(nn.Module):
bias
=
False
,
bias
=
False
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.q_b_proj"
,
prefix
=
f
"
{
prefix
}
.q_b_proj"
,
disable_tp
=
vllm_config
.
parallel_config
.
enable_lightly_cp
)
)
else
:
else
:
self
.
q_proj
=
ColumnParallelLinear
(
self
.
q_proj
=
ColumnParallelLinear
(
...
@@ -867,6 +1004,7 @@ class DeepseekV2MLAAttention(nn.Module):
...
@@ -867,6 +1004,7 @@ class DeepseekV2MLAAttention(nn.Module):
bias
=
False
,
bias
=
False
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.q_proj"
,
prefix
=
f
"
{
prefix
}
.q_proj"
,
disable_tp
=
vllm_config
.
parallel_config
.
enable_lightly_cp
,
)
)
self
.
kv_a_layernorm
=
RMSNorm
(
self
.
kv_lora_rank
,
eps
=
config
.
rms_norm_eps
)
self
.
kv_a_layernorm
=
RMSNorm
(
self
.
kv_lora_rank
,
eps
=
config
.
rms_norm_eps
)
self
.
kv_b_proj
=
ColumnParallelLinear
(
self
.
kv_b_proj
=
ColumnParallelLinear
(
...
@@ -875,6 +1013,7 @@ class DeepseekV2MLAAttention(nn.Module):
...
@@ -875,6 +1013,7 @@ class DeepseekV2MLAAttention(nn.Module):
bias
=
False
,
bias
=
False
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.kv_b_proj"
,
prefix
=
f
"
{
prefix
}
.kv_b_proj"
,
disable_tp
=
vllm_config
.
parallel_config
.
enable_lightly_cp
,
)
)
self
.
o_proj
=
RowParallelLinear
(
self
.
o_proj
=
RowParallelLinear
(
self
.
num_heads
*
self
.
v_head_dim
,
self
.
num_heads
*
self
.
v_head_dim
,
...
@@ -882,6 +1021,7 @@ class DeepseekV2MLAAttention(nn.Module):
...
@@ -882,6 +1021,7 @@ class DeepseekV2MLAAttention(nn.Module):
bias
=
False
,
bias
=
False
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.o_proj"
,
prefix
=
f
"
{
prefix
}
.o_proj"
,
disable_tp
=
vllm_config
.
parallel_config
.
enable_lightly_cp
,
)
)
if
config
.
rope_parameters
[
"rope_type"
]
!=
"default"
:
if
config
.
rope_parameters
[
"rope_type"
]
!=
"default"
:
...
@@ -1118,6 +1258,7 @@ class DeepseekV2DecoderLayer(nn.Module):
...
@@ -1118,6 +1258,7 @@ class DeepseekV2DecoderLayer(nn.Module):
residual
*=
1.0
/
self
.
routed_scaling_factor
residual
*=
1.0
/
self
.
routed_scaling_factor
# Fully Connected
# Fully Connected
enable_lightly_cp
=
get_forward_context
().
enable_lightly_cp
update_hs
=
True
if
isinstance
(
self
.
mlp
,
DeepseekV2MoE
)
else
False
update_hs
=
True
if
isinstance
(
self
.
mlp
,
DeepseekV2MoE
)
else
False
assert
self
.
post_attention_layernorm
.
has_weight
is
True
assert
self
.
post_attention_layernorm
.
has_weight
is
True
_i_q
,
_i_s
,
residual
=
self
.
post_attention_layernorm
(
x
=
hidden_states
,
_i_q
,
_i_s
,
residual
=
self
.
post_attention_layernorm
(
x
=
hidden_states
,
...
@@ -1126,9 +1267,10 @@ class DeepseekV2DecoderLayer(nn.Module):
...
@@ -1126,9 +1267,10 @@ class DeepseekV2DecoderLayer(nn.Module):
update_input
=
update_hs
update_input
=
update_hs
)
)
new_resi
=
residual
new_resi
=
residual
hidden_states
=
self
.
mlp
(
hidden_states
,
if
enable_lightly_cp
and
isinstance
(
self
.
mlp
,
DeepseekV2MoE
):
iqis
=
(
_i_q
,
_i_s
)
hidden_states
=
self
.
mlp
(
hidden_states
)
)
else
:
hidden_states
=
self
.
mlp
(
hidden_states
,
iqis
=
(
_i_q
,
_i_s
))
if
isinstance
(
self
.
mlp
,
DeepseekV2MLP
)
and
hidden_states
.
dtype
==
torch
.
float16
:
if
isinstance
(
self
.
mlp
,
DeepseekV2MLP
)
and
hidden_states
.
dtype
==
torch
.
float16
:
# Fix FP16 overflow
# Fix FP16 overflow
...
@@ -1225,6 +1367,9 @@ class DeepseekV2Model(nn.Module):
...
@@ -1225,6 +1367,9 @@ class DeepseekV2Model(nn.Module):
self
.
config
=
config
self
.
config
=
config
self
.
device
=
current_platform
.
device_type
self
.
device
=
current_platform
.
device_type
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
tp_rank
=
get_tensor_model_parallel_rank
()
self
.
vocab_size
=
config
.
vocab_size
self
.
vocab_size
=
config
.
vocab_size
#添加判断,默认开启DSA
#添加判断,默认开启DSA
force_disable_dsa
=
envs
.
VLLM_DISABLE_DSA
force_disable_dsa
=
envs
.
VLLM_DISABLE_DSA
...
@@ -1287,6 +1432,30 @@ class DeepseekV2Model(nn.Module):
...
@@ -1287,6 +1432,30 @@ class DeepseekV2Model(nn.Module):
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
residual
=
intermediate_tensors
[
"residual"
]
residual
=
intermediate_tensors
[
"residual"
]
enable_lightly_cp
=
get_forward_context
().
enable_lightly_cp
if
enable_lightly_cp
:
scatter_indexes_tensor
=
get_forward_context
().
scatter_indexes_tensor
if
scatter_indexes_tensor
is
None
:
hidden_states_per_rank
=
torch
.
chunk
(
hidden_states
,
chunks
=
self
.
tp_size
,
dim
=
0
)
hidden_states
=
hidden_states_per_rank
[
self
.
tp_rank
].
contiguous
()
if
residual
is
not
None
:
residual_per_rank
=
torch
.
chunk
(
residual
,
chunks
=
self
.
tp_size
,
dim
=
0
)
residual
=
residual_per_rank
[
self
.
tp_rank
].
contiguous
()
if
positions
is
not
None
:
positions_per_rank
=
torch
.
chunk
(
positions
,
chunks
=
self
.
tp_size
,
dim
=
0
)
positions
=
positions_per_rank
[
self
.
tp_rank
].
contiguous
()
else
:
scatter_indexes_tensor
=
torch
.
where
(
scatter_indexes_tensor
==
-
1
,
0
,
scatter_indexes_tensor
)
hidden_states
=
torch
.
index_select
(
hidden_states
,
0
,
scatter_indexes_tensor
)
if
residual
is
not
None
:
residual
=
torch
.
index_select
(
residual
,
0
,
scatter_indexes_tensor
)
if
positions
is
not
None
:
positions
=
torch
.
index_select
(
positions
,
0
,
scatter_indexes_tensor
)
# Compute llama 4 scaling once per forward pass if enabled
# Compute llama 4 scaling once per forward pass if enabled
llama_4_scaling_config
=
getattr
(
self
.
config
,
"llama_4_scaling"
,
None
)
llama_4_scaling_config
=
getattr
(
self
.
config
,
"llama_4_scaling"
,
None
)
llama_4_scaling
:
torch
.
Tensor
|
None
llama_4_scaling
:
torch
.
Tensor
|
None
...
@@ -1307,11 +1476,21 @@ class DeepseekV2Model(nn.Module):
...
@@ -1307,11 +1476,21 @@ class DeepseekV2Model(nn.Module):
)
)
if
not
get_pp_group
().
is_last_rank
:
if
not
get_pp_group
().
is_last_rank
:
if
enable_lightly_cp
:
hidden_states
=
tensor_model_parallel_all_gather
(
hidden_states
.
contiguous
(),
dim
=
0
)
residual
=
tensor_model_parallel_all_gather
(
residual
.
contiguous
(),
dim
=
0
)
return
IntermediateTensors
(
return
IntermediateTensors
(
{
"hidden_states"
:
hidden_states
,
"residual"
:
residual
}
{
"hidden_states"
:
hidden_states
,
"residual"
:
residual
}
)
)
hidden_states
,
_
=
self
.
norm
(
hidden_states
,
residual
)
hidden_states
,
_
=
self
.
norm
(
hidden_states
,
residual
)
if
enable_lightly_cp
:
hidden_states
=
tensor_model_parallel_all_gather
(
hidden_states
.
contiguous
(),
dim
=
0
)
gather_indexes_tensor
=
get_forward_context
().
gather_indexes_tensor
if
gather_indexes_tensor
is
not
None
:
hidden_states
=
torch
.
index_select
(
hidden_states
,
0
,
gather_indexes_tensor
)
return
hidden_states
return
hidden_states
...
...
vllm/v1/attention/backend.py
View file @
aef3c487
...
@@ -282,6 +282,35 @@ class AttentionMetadata:
...
@@ -282,6 +282,35 @@ class AttentionMetadata:
T
=
TypeVar
(
"T"
,
bound
=
AttentionMetadata
)
T
=
TypeVar
(
"T"
,
bound
=
AttentionMetadata
)
@
dataclass
class
CpCommonAttentionMetadata
:
# sp related metadata
query_start_loc
:
torch
.
Tensor
query_start_loc_cpu
:
torch
.
Tensor
seq_lens
:
torch
.
Tensor
_seq_lens_cpu
:
torch
.
Tensor
num_actual_tokens
:
int
num_kv_actual_tokens
:
int
max_query_len
:
int
max_seq_len
:
int
num_reqs
:
int
req_ids
:
list
[
str
]
block_table_tensor
:
torch
.
Tensor
slot_mapping
:
torch
.
Tensor
_num_computed_tokens_cpu
:
torch
.
Tensor
dcp_local_seq_lens
:
torch
.
Tensor
|
None
=
None
dcp_local_seq_lens_cpu
:
torch
.
Tensor
|
None
=
None
def
batch_size
(
self
)
->
int
:
return
self
.
seq_lens
.
shape
[
0
]
@
property
def
seq_lens_cpu
(
self
)
->
torch
.
Tensor
:
if
self
.
_seq_lens_cpu
is
None
:
self
.
_seq_lens_cpu
=
self
.
seq_lens
.
to
(
"cpu"
)
return
self
.
_seq_lens_cpu
@
dataclass
@
dataclass
class
CommonAttentionMetadata
:
class
CommonAttentionMetadata
:
...
@@ -312,6 +341,14 @@ class CommonAttentionMetadata:
...
@@ -312,6 +341,14 @@ class CommonAttentionMetadata:
block_table_tensor
:
torch
.
Tensor
block_table_tensor
:
torch
.
Tensor
slot_mapping
:
torch
.
Tensor
slot_mapping
:
torch
.
Tensor
num_kv_actual_tokens
:
int
|
None
=
None
seq_indexes_list
:
list
[
int
]
|
None
=
None
scatter_indexes_tensor
:
torch
.
Tensor
|
None
=
None
gather_indexes_tensor
:
torch
.
Tensor
|
None
=
None
cp_common_metadata
:
CpCommonAttentionMetadata
|
None
=
None
enable_lightly_cp
:
bool
=
False
causal
:
bool
=
True
causal
:
bool
=
True
# Needed by FastPrefillAttentionBuilder
# Needed by FastPrefillAttentionBuilder
...
@@ -396,6 +433,7 @@ class CommonAttentionMetadata:
...
@@ -396,6 +433,7 @@ class CommonAttentionMetadata:
else
None
,
else
None
,
num_reqs
=
num_actual_reqs
,
num_reqs
=
num_actual_reqs
,
num_actual_tokens
=
num_actual_tokens
,
num_actual_tokens
=
num_actual_tokens
,
num_kv_actual_tokens
=
num_actual_tokens
,
max_query_len
=
self
.
max_query_len
,
max_query_len
=
self
.
max_query_len
,
max_seq_len
=
self
.
max_seq_len
,
max_seq_len
=
self
.
max_seq_len
,
block_table_tensor
=
self
.
block_table_tensor
[:
num_actual_reqs
],
block_table_tensor
=
self
.
block_table_tensor
[:
num_actual_reqs
],
...
...
vllm/v1/attention/backends/mla/flashmla_sparse.py
View file @
aef3c487
...
@@ -138,6 +138,7 @@ class FlashMLASparseMetadata(AttentionMetadata):
...
@@ -138,6 +138,7 @@ class FlashMLASparseMetadata(AttentionMetadata):
max_seq_len
:
int
max_seq_len
:
int
num_actual_tokens
:
int
# Number of tokens excluding padding.
num_actual_tokens
:
int
# Number of tokens excluding padding.
num_kv_actual_tokens
:
int
query_start_loc
:
torch
.
Tensor
query_start_loc
:
torch
.
Tensor
slot_mapping
:
torch
.
Tensor
slot_mapping
:
torch
.
Tensor
...
@@ -693,6 +694,7 @@ class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetad
...
@@ -693,6 +694,7 @@ class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetad
max_query_len
=
cm
.
max_query_len
,
max_query_len
=
cm
.
max_query_len
,
max_seq_len
=
cm
.
max_seq_len
,
max_seq_len
=
cm
.
max_seq_len
,
num_actual_tokens
=
cm
.
num_actual_tokens
,
num_actual_tokens
=
cm
.
num_actual_tokens
,
num_kv_actual_tokens
=
cm
.
num_kv_actual_tokens
,
query_start_loc
=
cm
.
query_start_loc
,
query_start_loc
=
cm
.
query_start_loc
,
slot_mapping
=
cm
.
slot_mapping
,
slot_mapping
=
cm
.
slot_mapping
,
block_table
=
cm
.
block_table_tensor
,
block_table
=
cm
.
block_table_tensor
,
...
@@ -1024,12 +1026,13 @@ class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]):
...
@@ -1024,12 +1026,13 @@ class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]):
return
output
.
fill_
(
0
)
return
output
.
fill_
(
0
)
num_actual_toks
=
attn_metadata
.
num_actual_tokens
num_actual_toks
=
attn_metadata
.
num_actual_tokens
num_kv_actual_toks
=
attn_metadata
.
num_kv_actual_tokens
# Inputs and outputs may be padded for CUDA graphs
# Inputs and outputs may be padded for CUDA graphs
q
=
q
[:
num_actual_toks
,
...]
q
=
q
[:
num_actual_toks
,
...]
k_c_normed
=
k_c_normed
[:
num_actual_toks
,
...]
k_c_normed
=
k_c_normed
[:
num_
kv_
actual_toks
,
...]
k_pe
=
k_pe
[:
num_actual_toks
,
...]
k_pe
=
k_pe
[:
num_
kv_
actual_toks
,
...]
assert
self
.
topk_indices_buffer
is
not
None
assert
self
.
topk_indices_buffer
is
not
None
topk_indices
=
self
.
topk_indices_buffer
[:
num_actual_toks
]
topk_indices
=
self
.
topk_indices_buffer
[:
num_actual_toks
]
...
...
vllm/v1/attention/backends/mla/indexer.py
View file @
aef3c487
...
@@ -105,6 +105,7 @@ class DeepseekV32IndexerMetadata:
...
@@ -105,6 +105,7 @@ class DeepseekV32IndexerMetadata:
max_seq_len
:
int
max_seq_len
:
int
num_actual_tokens
:
int
# Number of tokens excluding padding.
num_actual_tokens
:
int
# Number of tokens excluding padding.
num_kv_actual_tokens
:
int
query_start_loc
:
torch
.
Tensor
query_start_loc
:
torch
.
Tensor
slot_mapping
:
torch
.
Tensor
slot_mapping
:
torch
.
Tensor
# The dimension of the attention heads
# The dimension of the attention heads
...
@@ -438,6 +439,7 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
...
@@ -438,6 +439,7 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
max_query_len
=
common_attn_metadata
.
max_query_len
,
max_query_len
=
common_attn_metadata
.
max_query_len
,
max_seq_len
=
common_attn_metadata
.
max_seq_len
,
max_seq_len
=
common_attn_metadata
.
max_seq_len
,
num_actual_tokens
=
common_attn_metadata
.
num_actual_tokens
,
num_actual_tokens
=
common_attn_metadata
.
num_actual_tokens
,
num_kv_actual_tokens
=
common_attn_metadata
.
num_kv_actual_tokens
,
query_start_loc
=
common_attn_metadata
.
query_start_loc
,
query_start_loc
=
common_attn_metadata
.
query_start_loc
,
slot_mapping
=
common_attn_metadata
.
slot_mapping
,
slot_mapping
=
common_attn_metadata
.
slot_mapping
,
head_dim
=
128
,
head_dim
=
128
,
...
...
vllm/v1/spec_decode/eagle.py
View file @
aef3c487
...
@@ -14,7 +14,7 @@ from vllm.config import (
...
@@ -14,7 +14,7 @@ from vllm.config import (
VllmConfig
,
VllmConfig
,
get_layers_from_vllm_config
,
get_layers_from_vllm_config
,
)
)
from
vllm.distributed.parallel_state
import
get_pp_group
from
vllm.distributed.parallel_state
import
get_pp_group
,
get_tensor_model_parallel_rank
from
vllm.forward_context
import
set_forward_context
from
vllm.forward_context
import
set_forward_context
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.attention_layer_base
import
AttentionLayerBase
from
vllm.model_executor.layers.attention_layer_base
import
AttentionLayerBase
...
@@ -29,6 +29,7 @@ from vllm.utils.platform_utils import is_pin_memory_available
...
@@ -29,6 +29,7 @@ from vllm.utils.platform_utils import is_pin_memory_available
from
vllm.v1.attention.backend
import
(
from
vllm.v1.attention.backend
import
(
AttentionMetadataBuilder
,
AttentionMetadataBuilder
,
CommonAttentionMetadata
,
CommonAttentionMetadata
,
CpCommonAttentionMetadata
,
)
)
from
vllm.v1.attention.backends.registry
import
AttentionBackendEnum
from
vllm.v1.attention.backends.registry
import
AttentionBackendEnum
from
vllm.v1.attention.backends.tree_attn
import
(
from
vllm.v1.attention.backends.tree_attn
import
(
...
@@ -48,6 +49,7 @@ from vllm.v1.spec_decode.utils import (
...
@@ -48,6 +49,7 @@ from vllm.v1.spec_decode.utils import (
from
vllm.v1.utils
import
CpuGpuBuffer
from
vllm.v1.utils
import
CpuGpuBuffer
from
vllm.v1.worker.dp_utils
import
coordinate_batch_across_dp
from
vllm.v1.worker.dp_utils
import
coordinate_batch_across_dp
from
vllm.v1.worker.gpu_input_batch
import
CachedRequestState
,
InputBatch
from
vllm.v1.worker.gpu_input_batch
import
CachedRequestState
,
InputBatch
from
vllm.utils.math_utils
import
cdiv
,
round_up
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -76,7 +78,9 @@ class SpecDecodeBaseProposer:
...
@@ -76,7 +78,9 @@ class SpecDecodeBaseProposer:
self
.
dp_rank
=
vllm_config
.
parallel_config
.
data_parallel_rank
self
.
dp_rank
=
vllm_config
.
parallel_config
.
data_parallel_rank
self
.
num_speculative_tokens
=
self
.
speculative_config
.
num_speculative_tokens
self
.
num_speculative_tokens
=
self
.
speculative_config
.
num_speculative_tokens
# The drafter can get longer sequences than the target model.
# The drafter can get longer sequences than the target model.
max_batch_size
=
vllm_config
.
scheduler_config
.
max_num_seqs
max_batch_size
=
vllm_config
.
scheduler_config
.
max_num_seqs
if
not
\
vllm_config
.
parallel_config
.
enable_lightly_cplb
\
else
vllm_config
.
scheduler_config
.
max_num_seqs
*
2
self
.
max_num_tokens
=
(
self
.
max_num_tokens
=
(
vllm_config
.
scheduler_config
.
max_num_batched_tokens
+
max_batch_size
vllm_config
.
scheduler_config
.
max_num_batched_tokens
+
max_batch_size
)
)
...
@@ -219,6 +223,28 @@ class SpecDecodeBaseProposer:
...
@@ -219,6 +223,28 @@ class SpecDecodeBaseProposer:
1
,
len
(
self
.
tree_choices
)
+
1
,
device
=
device
,
dtype
=
torch
.
int32
1
,
len
(
self
.
tree_choices
)
+
1
,
device
=
device
,
dtype
=
torch
.
int32
).
repeat
(
max_batch_size
,
1
)
).
repeat
(
max_batch_size
,
1
)
self
.
scatter_indexes_tensor
=
None
self
.
gather_indexes_tensor
=
None
self
.
enable_lightly_cp
=
vllm_config
.
parallel_config
.
enable_lightly_cp
self
.
enable_lightly_cplb
=
self
.
enable_lightly_cp
and
vllm_config
.
parallel_config
.
enable_lightly_cplb
if
self
.
enable_lightly_cp
:
self
.
query_start_loc
=
CpuGpuBuffer
(
max_batch_size
+
1
,
dtype
=
torch
.
int32
,
pin_memory
=
is_pin_memory_available
(),
device
=
device
,
with_numpy
=
True
,
)
self
.
seq_lens
=
CpuGpuBuffer
(
max_batch_size
,
dtype
=
torch
.
int32
,
pin_memory
=
is_pin_memory_available
(),
device
=
device
,
with_numpy
=
True
,
)
def
_get_positions
(
self
,
num_tokens
:
int
):
def
_get_positions
(
self
,
num_tokens
:
int
):
if
self
.
uses_mrope
:
if
self
.
uses_mrope
:
return
self
.
mrope_positions
[:,
:
num_tokens
]
return
self
.
mrope_positions
[:,
:
num_tokens
]
...
@@ -270,6 +296,10 @@ class SpecDecodeBaseProposer:
...
@@ -270,6 +296,10 @@ class SpecDecodeBaseProposer:
self
.
cudagraph_dispatcher
.
initialize_cudagraph_keys
(
eagle_cudagraph_mode
)
self
.
cudagraph_dispatcher
.
initialize_cudagraph_keys
(
eagle_cudagraph_mode
)
def _pad_for_mla_cp(self, num_scheduled_tokens: int) -> int:
    """Return ``num_scheduled_tokens`` rounded with the TP size as granule.

    The value is passed through ``round_up`` together with the configured
    ``tensor_parallel_size``.
    NOTE(review): presumably this makes the token count evenly divisible
    across tensor-parallel ranks -- confirm against ``round_up``.
    """
    parallel_cfg = self.vllm_config.parallel_config
    granule = parallel_cfg.tensor_parallel_size
    return round_up(num_scheduled_tokens, granule)
def
propose
(
def
propose
(
self
,
self
,
# [num_tokens]
# [num_tokens]
...
@@ -309,6 +339,31 @@ class SpecDecodeBaseProposer:
...
@@ -309,6 +339,31 @@ class SpecDecodeBaseProposer:
)
)
)
)
num_tokens_dp_padded
,
num_tokens_across_dp
=
self
.
_pad_batch_across_dp
(
num_tokens_unpadded
=
num_tokens
,
num_tokens_padded
=
num_tokens
)
enable_lightly_cp
=
self
.
enable_lightly_cp
and
num_tokens
>
self
.
runner
.
lightly_cp_threshould
if
enable_lightly_cp
:
num_tokens_dp_padded
=
self
.
_pad_for_mla_cp
(
num_tokens_dp_padded
)
common_attn_metadata
=
self
.
_prepare_cp_metadata
(
num_reqs_padded
=
common_attn_metadata
.
num_reqs
,
max_query_len
=
common_attn_metadata
.
max_query_len
,
max_seq_len
=
common_attn_metadata
.
seq_lens_cpu
.
max
().
item
(),
num_tokens
=
num_tokens
,
block_table_gid_0
=
common_attn_metadata
.
block_table_tensor
,
slot_mapping_gid_0
=
common_attn_metadata
.
slot_mapping
,
query_start_loc
=
common_attn_metadata
.
query_start_loc
,
query_start_loc_cpu
=
common_attn_metadata
.
query_start_loc_cpu
,
seq_lens
=
common_attn_metadata
.
seq_lens
,
seq_lens_cpu
=
common_attn_metadata
.
seq_lens_cpu
,
num_computed_tokens_cpu
=
common_attn_metadata
.
_num_computed_tokens_cpu
,
)
self
.
scatter_indexes_tensor
=
common_attn_metadata
.
scatter_indexes_tensor
self
.
gather_indexes_tensor
=
common_attn_metadata
.
gather_indexes_tensor
assert
self
.
runner
is
not
None
assert
self
.
runner
is
not
None
if
self
.
attn_metadata_builder
is
None
:
if
self
.
attn_metadata_builder
is
None
:
...
@@ -339,10 +394,6 @@ class SpecDecodeBaseProposer:
...
@@ -339,10 +394,6 @@ class SpecDecodeBaseProposer:
assert
draft_indexer_metadata
is
not
None
assert
draft_indexer_metadata
is
not
None
per_layer_attn_metadata
[
layer_name
]
=
draft_indexer_metadata
per_layer_attn_metadata
[
layer_name
]
=
draft_indexer_metadata
num_tokens_dp_padded
,
num_tokens_across_dp
=
self
.
_pad_batch_across_dp
(
num_tokens_unpadded
=
num_tokens
,
num_tokens_padded
=
num_tokens
)
cudagraph_runtime_mode
,
batch_desc
=
self
.
cudagraph_dispatcher
.
dispatch
(
cudagraph_runtime_mode
,
batch_desc
=
self
.
cudagraph_dispatcher
.
dispatch
(
num_tokens_dp_padded
num_tokens_dp_padded
)
)
...
@@ -387,6 +438,10 @@ class SpecDecodeBaseProposer:
...
@@ -387,6 +438,10 @@ class SpecDecodeBaseProposer:
slot_mapping
=
self
.
_get_slot_mapping
(
slot_mapping
=
self
.
_get_slot_mapping
(
num_input_tokens
,
common_attn_metadata
.
slot_mapping
num_input_tokens
,
common_attn_metadata
.
slot_mapping
),
),
scatter_indexes_tensor
=
self
.
scatter_indexes_tensor
,
gather_indexes_tensor
=
self
.
gather_indexes_tensor
,
enable_lightly_cp
=
self
.
enable_lightly_cp
and
num_tokens
>
self
.
runner
.
lightly_cp_threshould
,
enable_lightly_cplb
=
self
.
enable_lightly_cplb
):
):
ret_hidden_states
=
self
.
model
(
**
model_kwargs
)
ret_hidden_states
=
self
.
model
(
**
model_kwargs
)
if
not
self
.
model_returns_tuple
():
if
not
self
.
model_returns_tuple
():
...
@@ -463,6 +518,9 @@ class SpecDecodeBaseProposer:
...
@@ -463,6 +518,9 @@ class SpecDecodeBaseProposer:
if
batch_size_across_dp
is
not
None
:
if
batch_size_across_dp
is
not
None
:
batch_size_across_dp
[
self
.
dp_rank
]
=
input_batch_size
batch_size_across_dp
[
self
.
dp_rank
]
=
input_batch_size
if
enable_lightly_cp
:
common_attn_metadata
=
common_attn_metadata
.
cp_common_metadata
common_attn_metadata
.
num_actual_tokens
=
batch_size
common_attn_metadata
.
num_actual_tokens
=
batch_size
common_attn_metadata
.
max_query_len
=
1
common_attn_metadata
.
max_query_len
=
1
common_attn_metadata
.
query_start_loc
=
self
.
arange
[:
batch_size
+
1
]
common_attn_metadata
.
query_start_loc
=
self
.
arange
[:
batch_size
+
1
]
...
@@ -802,6 +860,7 @@ class SpecDecodeBaseProposer:
...
@@ -802,6 +860,7 @@ class SpecDecodeBaseProposer:
_num_computed_tokens_cpu
=
common_attn_metadata
.
_num_computed_tokens_cpu
,
_num_computed_tokens_cpu
=
common_attn_metadata
.
_num_computed_tokens_cpu
,
num_reqs
=
common_attn_metadata
.
num_reqs
,
num_reqs
=
common_attn_metadata
.
num_reqs
,
num_actual_tokens
=
total_num_tokens
,
num_actual_tokens
=
total_num_tokens
,
num_kv_actual_tokens
=
total_num_tokens
,
max_query_len
=
new_query_len_per_req
.
max
().
item
(),
max_query_len
=
new_query_len_per_req
.
max
().
item
(),
max_seq_len
=
common_attn_metadata
.
seq_lens_cpu
.
max
().
item
(),
max_seq_len
=
common_attn_metadata
.
seq_lens_cpu
.
max
().
item
(),
block_table_tensor
=
common_attn_metadata
.
block_table_tensor
,
block_table_tensor
=
common_attn_metadata
.
block_table_tensor
,
...
@@ -988,6 +1047,104 @@ class SpecDecodeBaseProposer:
...
@@ -988,6 +1047,104 @@ class SpecDecodeBaseProposer:
level_num_drafts
=
self
.
cu_drafts_per_level
[
level
+
1
]
-
total_num_drafts
level_num_drafts
=
self
.
cu_drafts_per_level
[
level
+
1
]
-
total_num_drafts
total_num_drafts
=
self
.
cu_drafts_per_level
[
level
+
1
]
total_num_drafts
=
self
.
cu_drafts_per_level
[
level
+
1
]
return
draft_token_ids_list
return
draft_token_ids_list
def _prepare_cp_metadata(
    self,
    num_reqs_padded,
    max_query_len,
    max_seq_len,
    num_tokens,
    block_table_gid_0,
    slot_mapping_gid_0,
    query_start_loc,
    query_start_loc_cpu,
    seq_lens,
    seq_lens_cpu,
    num_computed_tokens_cpu,
):
    """Build the attention metadata used when lightly-CP is active.

    Two metadata objects are produced:

    * a ``CpCommonAttentionMetadata`` built from clones of the incoming
      (not yet redistributed) tensors, and
    * a ``CommonAttentionMetadata`` describing the requests/tokens that
      ``self.runner._distribute_tokens_to_cp_ranks`` assigns to this
      tensor-parallel rank. The first object is attached to it as
      ``cp_common_metadata``.

    The rank-local query start offsets and sequence lengths are written
    into the persistent ``self.query_start_loc`` / ``self.seq_lens``
    buffers and copied to the GPU; the returned metadata holds slices of
    those buffers.

    Args:
        num_reqs_padded: Request count stored in the snapshot metadata.
        max_query_len: Maximum query length stored in the snapshot.
        max_seq_len: Maximum sequence length stored in the snapshot.
        num_tokens: Token count before distribution; also used as the
            initial total query/KV length.
        block_table_gid_0: Block table; indexed with the per-rank
            sequence index list to obtain the local block table.
        slot_mapping_gid_0: Slot mapping, passed through unchanged.
        query_start_loc: Cumulative query offsets (device).
        query_start_loc_cpu: Cumulative query offsets (CPU).
        seq_lens: Per-request sequence lengths (device).
        seq_lens_cpu: Per-request sequence lengths (CPU).
        num_computed_tokens_cpu: Stored in the snapshot only; the value
            for the returned metadata is recomputed locally.

    Returns:
        The rank-local ``CommonAttentionMetadata``.
    """
    tp_size = self.vllm_config.parallel_config.tensor_parallel_size
    tp_rank = get_tensor_model_parallel_rank()

    # Snapshot of the incoming metadata. Tensors are cloned so later
    # writes to shared buffers do not alter the snapshot.
    cp_common_metadata = CpCommonAttentionMetadata(
        query_start_loc=query_start_loc.clone(),
        query_start_loc_cpu=query_start_loc_cpu.clone(),
        seq_lens=seq_lens.clone(),
        _seq_lens_cpu=seq_lens_cpu.clone(),
        max_query_len=max_query_len,
        max_seq_len=max_seq_len,
        num_reqs=num_reqs_padded,
        req_ids=self.runner.input_batch.req_ids,
        num_actual_tokens=num_tokens,
        num_kv_actual_tokens=num_tokens,
        block_table_tensor=block_table_gid_0,
        slot_mapping=slot_mapping_gid_0,
        _num_computed_tokens_cpu=num_computed_tokens_cpu,
    )

    # Per-request query lengths from the cumulative offsets.
    q_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
    kv_lens_cpu = seq_lens_cpu
    total_q_len = num_tokens
    total_kv_len = num_tokens

    (
        total_q_len,
        q_lens_cpu,
        seq_count,
        kv_lens_cpu,
        _local_req_ids,  # returned by the runner but not needed here
        scatter_indexes_tensor,
        gather_indexes_tensor,
        seq_indexes_list,
    ) = self.runner._distribute_tokens_to_cp_ranks(
        total_q_len,
        q_lens_cpu,
        kv_lens_cpu,
        tp_rank,
        tp_size,
        self.runner.input_batch.req_ids,
    )
    num_reqs = seq_count

    # Rank-local cumulative query offsets; the unused tail of the buffer
    # is filled with the final offset.
    cu_num_tokens = np.cumsum(q_lens_cpu)
    self.query_start_loc.np[0] = 0
    self.query_start_loc.np[1:num_reqs + 1] = cu_num_tokens
    self.query_start_loc.np[num_reqs + 1:].fill(cu_num_tokens[-1])
    self.query_start_loc.copy_to_gpu()
    q_acc_lens = self.query_start_loc.gpu[:num_reqs + 1]
    q_acc_lens_cpu = self.query_start_loc.cpu[:num_reqs + 1]

    # Fix: the maximum query length must come from the per-request
    # lengths. Taking the max of the cumulative offsets (as before)
    # yields the total number of tokens instead.
    local_q_lens_cpu = q_acc_lens_cpu[1:] - q_acc_lens_cpu[:-1]
    max_q_len = int(local_q_lens_cpu.max())

    self.seq_lens.np[:num_reqs] = kv_lens_cpu
    self.seq_lens.np[num_reqs:].fill(0)
    self.seq_lens.copy_to_gpu()
    kv_lens = self.seq_lens.gpu[:num_reqs]
    kv_lens_cpu = self.seq_lens.cpu[:num_reqs]
    # int(...) so the metadata carries a plain int, not a 0-dim tensor.
    max_kv_len = int(kv_lens_cpu.max())

    # Fix: computed tokens = sequence length - query length of the same
    # request. The previous code subtracted the cumulative end offsets,
    # which is only correct for the first request.
    num_computed_tokens_cpu = kv_lens_cpu - local_q_lens_cpu

    blk_table_tensor = block_table_gid_0[seq_indexes_list]

    cm_base = CommonAttentionMetadata(
        query_start_loc=q_acc_lens,
        query_start_loc_cpu=q_acc_lens_cpu,
        seq_lens=kv_lens,
        _seq_lens_cpu=kv_lens_cpu,
        _num_computed_tokens_cpu=num_computed_tokens_cpu,
        num_reqs=num_reqs,
        num_actual_tokens=total_q_len,
        max_query_len=max_q_len,
        max_seq_len=max_kv_len,
        block_table_tensor=blk_table_tensor,
        slot_mapping=slot_mapping_gid_0,
        causal=True,
        num_kv_actual_tokens=total_kv_len,
        seq_indexes_list=seq_indexes_list,
        cp_common_metadata=cp_common_metadata,
        scatter_indexes_tensor=scatter_indexes_tensor,
        gather_indexes_tensor=gather_indexes_tensor,
    )
    return cm_base
def
prepare_inputs
(
def
prepare_inputs
(
self
,
self
,
...
...
vllm/v1/worker/block_table.py
View file @
aef3c487
...
@@ -233,6 +233,10 @@ class BlockTable:
...
@@ -233,6 +233,10 @@ class BlockTable:
def
get_device_tensor
(
self
,
num_reqs
:
int
)
->
torch
.
Tensor
:
def
get_device_tensor
(
self
,
num_reqs
:
int
)
->
torch
.
Tensor
:
"""Returns the device tensor of the block table."""
"""Returns the device tensor of the block table."""
return
self
.
block_table
.
gpu
[:
num_reqs
]
return
self
.
block_table
.
gpu
[:
num_reqs
]
def get_device_tensor_range(self, start_req: int, end_req: int) -> torch.Tensor:
    """Return the ``[start_req, end_req)`` slice of the block table's device tensor.

    The slice is taken along the first dimension of ``self.block_table.gpu``.
    """
    device_table = self.block_table.gpu
    return device_table[start_req:end_req]
def
get_cpu_tensor
(
self
)
->
torch
.
Tensor
:
def
get_cpu_tensor
(
self
)
->
torch
.
Tensor
:
"""Returns the CPU tensor of the block table."""
"""Returns the CPU tensor of the block table."""
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
aef3c487
This diff is collapsed.
Click to expand it.
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment