Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
26645e58
Commit
26645e58
authored
Apr 02, 2026
by
王敏
Browse files
[feat]基于mla sp实现pcp
parent
d1fd831b
Changes
11
Hide whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
530 additions
and
34 deletions
+530
-34
vllm/engine/arg_utils.py
vllm/engine/arg_utils.py
+14
-0
vllm/forward_context.py
vllm/forward_context.py
+11
-0
vllm/model_executor/layers/mla.py
vllm/model_executor/layers/mla.py
+23
-0
vllm/model_executor/layers/sparse_attn_indexer.py
vllm/model_executor/layers/sparse_attn_indexer.py
+1
-1
vllm/model_executor/models/deepseek_v2.py
vllm/model_executor/models/deepseek_v2.py
+118
-4
vllm/v1/attention/backend.py
vllm/v1/attention/backend.py
+21
-0
vllm/v1/attention/backends/mla/flashmla_sparse.py
vllm/v1/attention/backends/mla/flashmla_sparse.py
+5
-2
vllm/v1/attention/backends/mla/indexer.py
vllm/v1/attention/backends/mla/indexer.py
+2
-0
vllm/v1/spec_decode/eagle.py
vllm/v1/spec_decode/eagle.py
+1
-0
vllm/v1/worker/block_table.py
vllm/v1/worker/block_table.py
+4
-0
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+330
-27
No files found.
vllm/engine/arg_utils.py
View file @
26645e58
...
...
@@ -1500,6 +1500,20 @@ class EngineArgs:
data_parallel_external_lb
=
(
self
.
data_parallel_external_lb
or
self
.
data_parallel_rank
is
not
None
)
if
(
envs
.
VLLM_MLA_CP
and
self
.
max_num_batched_tokens
is
not
None
and
self
.
max_num_batched_tokens
<
self
.
tensor_parallel_size
**
3
):
raise
ValueError
(
"max_num_batched_tokens should be larger than "
"tensor_parallel_size ** 3 when enabled VLLM_MLA_CP"
)
logger
.
info
(
"[MLACP] VLLM_MLA_CP is %s"
,
envs
.
VLLM_MLA_CP
)
logger
.
info
(
"[MLACP] VLLM_MLA_CPLB is %s"
,
envs
.
VLLM_MLA_CPLB
)
# Local DP rank = 1, use pure-external LB.
if
data_parallel_external_lb
:
assert
self
.
data_parallel_rank
is
not
None
,
(
...
...
vllm/forward_context.py
View file @
26645e58
...
...
@@ -240,6 +240,9 @@ class ForwardContext:
additional_kwargs
:
dict
[
str
,
Any
]
=
field
(
default_factory
=
dict
)
scatter_indexes_tensor
:
torch
.
Tensor
|
None
=
None
gather_indexes_tensor
:
torch
.
Tensor
|
None
=
None
def
__post_init__
(
self
):
assert
self
.
cudagraph_runtime_mode
.
valid_runtime_modes
(),
(
f
"Invalid cudagraph runtime mode:
{
self
.
cudagraph_runtime_mode
}
"
...
...
@@ -273,6 +276,8 @@ def create_forward_context(
slot_mapping
:
dict
[
str
,
torch
.
Tensor
]
|
None
=
None
,
additional_kwargs
:
dict
[
str
,
Any
]
|
None
=
None
,
skip_compiled
:
bool
=
False
,
scatter_indexes_tensor
:
torch
.
Tensor
|
None
=
None
,
gather_indexes_tensor
:
torch
.
Tensor
|
None
=
None
,
):
if
vllm_config
.
compilation_config
.
fast_moe_cold_start
:
if
vllm_config
.
speculative_config
is
None
:
...
...
@@ -298,6 +303,8 @@ def create_forward_context(
batch_descriptor
=
batch_descriptor
,
ubatch_slices
=
ubatch_slices
,
skip_compiled
=
skip_compiled
,
scatter_indexes_tensor
=
scatter_indexes_tensor
,
gather_indexes_tensor
=
gather_indexes_tensor
,
additional_kwargs
=
additional_kwargs
or
{},
)
...
...
@@ -329,6 +336,8 @@ def set_forward_context(
ubatch_slices
:
UBatchSlices
|
None
=
None
,
slot_mapping
:
dict
[
str
,
torch
.
Tensor
]
|
list
[
dict
[
str
,
torch
.
Tensor
]]
|
None
=
None
,
skip_compiled
:
bool
=
False
,
scatter_indexes_tensor
:
torch
.
Tensor
|
None
=
None
,
gather_indexes_tensor
:
torch
.
Tensor
|
None
=
None
,
):
"""A context manager that stores the current forward context,
can be attention metadata, etc.
...
...
@@ -389,6 +398,8 @@ def set_forward_context(
slot_mapping
,
additional_kwargs
,
skip_compiled
,
scatter_indexes_tensor
,
gather_indexes_tensor
,
)
try
:
...
...
vllm/model_executor/layers/mla.py
View file @
26645e58
...
...
@@ -9,6 +9,9 @@ from vllm.config import CacheConfig
import
vllm.envs
as
envs
from
vllm.model_executor.custom_op
import
PluggableLayer
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.distributed
import
(
tensor_model_parallel_all_gather
,
)
@
dataclass
...
...
@@ -183,8 +186,19 @@ class MultiHeadLatentAttentionWrapper(PluggableLayer):
if
llama_4_scaling
is
not
None
:
q
*=
llama_4_scaling
enable_mla_cp
=
envs
.
VLLM_MLA_CP
# and not get_forward_context().draft_model
# if not use_fused_rms_rope_concat:
if
not
envs
.
VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT
:
if
enable_mla_cp
:
kv_c_normed
=
tensor_model_parallel_all_gather
(
kv_c_normed
.
contiguous
(),
0
)
k_pe
=
tensor_model_parallel_all_gather
(
k_pe
.
contiguous
(),
0
)
attn_out
=
self
.
mla_attn
(
q
,
kv_c_normed
,
...
...
@@ -220,6 +234,15 @@ class MultiHeadLatentAttentionWrapper(PluggableLayer):
"VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT requires rotary_emb to "
"expose 'cos_sin_cache'."
)
if
enable_mla_cp
:
kv_c
=
tensor_model_parallel_all_gather
(
kv_c
.
contiguous
(),
0
)
k_pe
=
tensor_model_parallel_all_gather
(
k_pe
.
contiguous
(),
0
)
attn_out
=
self
.
mla_attn
(
q
[...,
self
.
qk_nope_head_dim
:],
kv_c
,
...
...
vllm/model_executor/layers/sparse_attn_indexer.py
View file @
26645e58
...
...
@@ -71,7 +71,7 @@ def sparse_attn_indexer(
)
attn_metadata
=
attn_metadata
[
k_cache_prefix
]
assert
isinstance
(
attn_metadata
,
DeepseekV32IndexerMetadata
)
slot_mapping
=
attn_metadata
.
slot_mapping
slot_mapping
=
attn_metadata
.
slot_mapping
[:
attn_metadata
.
num_kv_actual_tokens
]
has_decode
=
attn_metadata
.
num_decodes
>
0
has_prefill
=
attn_metadata
.
num_prefills
>
0
num_decode_tokens
=
attn_metadata
.
num_decode_tokens
...
...
vllm/model_executor/models/deepseek_v2.py
View file @
26645e58
...
...
@@ -46,6 +46,7 @@ from vllm.distributed import (
get_tensor_model_parallel_world_size
,
tensor_model_parallel_all_gather
,
)
from
vllm.distributed.communication_op
import
tensor_model_parallel_all_reduce
,
tensor_model_parallel_reduce_scatter
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.attention_layer_base
import
AttentionLayerBase
...
...
@@ -211,10 +212,82 @@ class DeepseekV2MLP(nn.Module):
hidden_size
,
bias
=
False
,
quant_config
=
quant_config
,
reduce_results
=
reduce_results
,
#reduce_results=reduce_results,
reduce_results
=
False
,
disable_tp
=
is_sequence_parallel
,
prefix
=
f
"
{
prefix
}
.down_proj"
,
)
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
if
hidden_act
!=
"silu"
:
raise
ValueError
(
f
"Unsupported activation:
{
hidden_act
}
. Only silu is supported for now."
)
self
.
act_fn
=
SiluAndMul
()
def
forward
(
self
,
x
,
*
,
iqis
:
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
|
None
=
None
):
enable_mla_cp
=
envs
.
VLLM_MLA_CP
# and not get_forward_context().draft_model
if
enable_mla_cp
:
x
=
tensor_model_parallel_all_gather
(
x
.
contiguous
(),
0
)
if
envs
.
USE_FUSED_RMS_QUANT
:
gate_up
,
_
=
self
.
gate_up_proj
(
x
,
iqis
=
iqis
)
if
envs
.
USE_FUSED_SILU_MUL_QUANT
:
from
lmslim.quantize.quant_ops
import
lm_fuse_silu_mul_quant
xq
,
xs
=
lm_fuse_silu_mul_quant
(
gate_up
)
x
,
_
=
self
.
down_proj
(
gate_up
,
iqis
=
(
xq
,
xs
))
else
:
x
=
self
.
act_fn
(
gate_up
)
x
,
_
=
self
.
down_proj
(
x
)
else
:
gate_up
,
_
=
self
.
gate_up_proj
(
x
)
x
=
self
.
act_fn
(
gate_up
)
x
,
_
=
self
.
down_proj
(
x
)
if
enable_mla_cp
:
x
=
tensor_model_parallel_reduce_scatter
(
x
.
contiguous
(),
dim
=
0
)
elif
self
.
tp_size
>
1
:
x
=
tensor_model_parallel_all_reduce
(
x
)
return
x
class
DeepseekV2SharedMLP
(
nn
.
Module
):
def
__init__
(
self
,
hidden_size
:
int
,
intermediate_size
:
int
,
hidden_act
:
str
,
quant_config
:
QuantizationConfig
|
None
=
None
,
reduce_results
:
bool
=
True
,
is_sequence_parallel
=
False
,
prefix
:
str
=
""
,
)
->
None
:
super
().
__init__
()
# If is_sequence_parallel, the input and output tensors are sharded
# across the ranks within the tp_group. In this case the weights are
# replicated and no collective ops are needed.
# Otherwise we use standard TP with an allreduce at the end.
self
.
gate_up_proj
=
MergedColumnParallelLinear
(
hidden_size
,
[
intermediate_size
]
*
2
,
bias
=
False
,
quant_config
=
quant_config
,
disable_tp
=
is_sequence_parallel
,
prefix
=
f
"
{
prefix
}
.gate_up_proj"
,
)
self
.
down_proj
=
RowParallelLinear
(
intermediate_size
,
hidden_size
,
bias
=
False
,
quant_config
=
quant_config
,
reduce_results
=
reduce_results
,
disable_tp
=
is_sequence_parallel
,
prefix
=
f
"
{
prefix
}
.down_proj"
)
if
hidden_act
!=
"silu"
:
raise
ValueError
(
f
"Unsupported activation:
{
hidden_act
}
. Only silu is supported for now."
...
...
@@ -311,7 +384,7 @@ class DeepseekV2MoE(nn.Module):
else
:
intermediate_size
=
config
.
moe_intermediate_size
*
config
.
n_shared_experts
self
.
shared_experts
=
DeepseekV2MLP
(
self
.
shared_experts
=
DeepseekV2
Shared
MLP
(
hidden_size
=
config
.
hidden_size
,
intermediate_size
=
intermediate_size
,
hidden_act
=
config
.
hidden_act
,
...
...
@@ -357,6 +430,11 @@ class DeepseekV2MoE(nn.Module):
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
*
,
iqis
:
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
|
None
=
None
)
->
torch
.
Tensor
:
enable_mla_cp
=
envs
.
VLLM_MLA_CP
#and not get_forward_context().draft_model
if
enable_mla_cp
:
hidden_states
=
tensor_model_parallel_all_gather
(
hidden_states
.
contiguous
(),
0
)
num_tokens
,
hidden_dim
=
hidden_states
.
shape
hidden_states
=
hidden_states
.
view
(
-
1
,
hidden_dim
)
...
...
@@ -428,7 +506,12 @@ class DeepseekV2MoE(nn.Module):
assert
shared_output
is
not
None
final_hidden_states
+=
shared_output
if
self
.
is_sequence_parallel
:
if
enable_mla_cp
:
final_hidden_states
=
tensor_model_parallel_reduce_scatter
(
final_hidden_states
.
contiguous
(),
0
)
return
final_hidden_states
elif
self
.
is_sequence_parallel
:
final_hidden_states
=
tensor_model_parallel_all_gather
(
final_hidden_states
,
0
)
...
...
@@ -756,6 +839,12 @@ class Indexer(nn.Module):
# `k_pe` is [num_tokens, 1, rope_dim] (MQA).
k
=
torch
.
cat
([
k_pe
.
squeeze
(
-
2
),
k_nope
],
dim
=-
1
)
enable_mla_cp
=
envs
.
VLLM_MLA_CP
# and not get_forward_context().draft_model
if
enable_mla_cp
:
k
=
tensor_model_parallel_all_gather
(
k
.
contiguous
(),
0
)
# we only quant q here since k quant is fused with cache insertion
if
not
current_platform
.
is_rocm
()
or
torch
.
cuda
.
get_device_properties
(
"cuda"
).
gcnArchName
.
split
(
':'
)[
0
]
==
"gfx938"
:
q
=
q
.
view
(
-
1
,
self
.
head_dim
)
...
...
@@ -819,7 +908,8 @@ class DeepseekV2MLAAttention(nn.Module):
self
.
num_heads
=
num_heads
tp_size
=
get_tensor_model_parallel_world_size
()
assert
num_heads
%
tp_size
==
0
self
.
num_local_heads
=
num_heads
//
tp_size
#self.num_local_heads = num_heads // tp_size
self
.
num_local_heads
=
num_heads
//
tp_size
if
not
envs
.
VLLM_MLA_CP
else
self
.
num_heads
self
.
scaling
=
self
.
qk_head_dim
**-
0.5
self
.
max_position_embeddings
=
max_position_embeddings
...
...
@@ -853,6 +943,7 @@ class DeepseekV2MLAAttention(nn.Module):
bias
=
False
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.q_b_proj"
,
disable_tp
=
envs
.
VLLM_MLA_CP
,
)
else
:
self
.
q_proj
=
ColumnParallelLinear
(
...
...
@@ -861,6 +952,7 @@ class DeepseekV2MLAAttention(nn.Module):
bias
=
False
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.q_proj"
,
disable_tp
=
envs
.
VLLM_MLA_CP
,
)
self
.
kv_a_layernorm
=
RMSNorm
(
self
.
kv_lora_rank
,
eps
=
config
.
rms_norm_eps
)
self
.
kv_b_proj
=
ColumnParallelLinear
(
...
...
@@ -869,6 +961,7 @@ class DeepseekV2MLAAttention(nn.Module):
bias
=
False
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.kv_b_proj"
,
disable_tp
=
envs
.
VLLM_MLA_CP
,
)
self
.
o_proj
=
RowParallelLinear
(
self
.
num_heads
*
self
.
v_head_dim
,
...
...
@@ -876,6 +969,7 @@ class DeepseekV2MLAAttention(nn.Module):
bias
=
False
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.o_proj"
,
disable_tp
=
envs
.
VLLM_MLA_CP
,
)
if
config
.
rope_parameters
[
"rope_type"
]
!=
"default"
:
...
...
@@ -1217,6 +1311,9 @@ class DeepseekV2Model(nn.Module):
self
.
config
=
config
self
.
device
=
current_platform
.
device_type
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
tp_rank
=
get_tensor_model_parallel_rank
()
self
.
vocab_size
=
config
.
vocab_size
#添加判断,默认开启DSA
force_disable_dsa
=
os
.
environ
.
get
(
"VLLM_DISABLE_DSA"
,
"0"
)
==
"1"
...
...
@@ -1279,6 +1376,19 @@ class DeepseekV2Model(nn.Module):
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
residual
=
intermediate_tensors
[
"residual"
]
enable_mla_cp
=
envs
.
VLLM_MLA_CP
# and not get_forward_context().draft_model
if
enable_mla_cp
:
hidden_states_per_rank
=
torch
.
chunk
(
hidden_states
,
chunks
=
self
.
tp_size
,
dim
=
0
)
hidden_states
=
hidden_states_per_rank
[
self
.
tp_rank
].
contiguous
()
if
residual
is
not
None
:
residual_per_rank
=
torch
.
chunk
(
residual
,
chunks
=
self
.
tp_size
,
dim
=
0
)
residual
=
residual_per_rank
[
self
.
tp_rank
].
contiguous
()
if
positions
is
not
None
:
positions_per_rank
=
torch
.
chunk
(
positions
,
chunks
=
self
.
tp_size
,
dim
=
0
)
positions
=
positions_per_rank
[
self
.
tp_rank
].
contiguous
()
# Compute llama 4 scaling once per forward pass if enabled
llama_4_scaling_config
=
getattr
(
self
.
config
,
"llama_4_scaling"
,
None
)
llama_4_scaling
:
torch
.
Tensor
|
None
...
...
@@ -1304,6 +1414,10 @@ class DeepseekV2Model(nn.Module):
)
hidden_states
,
_
=
self
.
norm
(
hidden_states
,
residual
)
if
enable_mla_cp
:
hidden_states
=
tensor_model_parallel_all_gather
(
hidden_states
.
contiguous
(),
dim
=
0
)
return
hidden_states
...
...
vllm/v1/attention/backend.py
View file @
26645e58
...
...
@@ -285,6 +285,18 @@ class AttentionMetadata:
T
=
TypeVar
(
"T"
,
bound
=
AttentionMetadata
)
@
dataclass
class
CpCommonAttentionMetadata
:
# sp related metadata
query_start_loc
:
torch
.
Tensor
query_start_loc_cpu
:
torch
.
Tensor
seq_lens
:
torch
.
Tensor
_seq_lens_cpu
:
torch
.
Tensor
num_actual_tokens
:
int
max_query_len
:
int
num_reqs
:
int
req_ids
:
list
[
str
]
@
dataclass
class
CommonAttentionMetadata
:
...
...
@@ -306,6 +318,7 @@ class CommonAttentionMetadata:
"""Number of requests"""
# TODO(lucas): rename to num_tokens since it may be padded and this is misleading
num_actual_tokens
:
int
"""Total number of tokens in batch"""
max_query_len
:
int
"""Longest query in batch"""
...
...
@@ -315,6 +328,14 @@ class CommonAttentionMetadata:
block_table_tensor
:
torch
.
Tensor
slot_mapping
:
torch
.
Tensor
num_kv_actual_tokens
:
int
seq_indexes_list
:
list
[
int
]
|
None
=
None
scatter_indexes_tensor
:
torch
.
Tensor
|
None
=
None
gather_indexes_tensor
:
torch
.
Tensor
|
None
=
None
cp_common_metadata
:
CpCommonAttentionMetadata
|
None
=
None
enable_mla_cp
:
bool
=
False
causal
:
bool
=
True
# Needed by FastPrefillAttentionBuilder
...
...
vllm/v1/attention/backends/mla/flashmla_sparse.py
View file @
26645e58
...
...
@@ -138,6 +138,7 @@ class FlashMLASparseMetadata(AttentionMetadata):
max_seq_len
:
int
num_actual_tokens
:
int
# Number of tokens excluding padding.
num_kv_actual_tokens
:
int
query_start_loc
:
torch
.
Tensor
slot_mapping
:
torch
.
Tensor
...
...
@@ -693,6 +694,7 @@ class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetad
max_query_len
=
cm
.
max_query_len
,
max_seq_len
=
cm
.
max_seq_len
,
num_actual_tokens
=
cm
.
num_actual_tokens
,
num_kv_actual_tokens
=
cm
.
num_kv_actual_tokens
,
query_start_loc
=
cm
.
query_start_loc
,
slot_mapping
=
cm
.
slot_mapping
,
block_table
=
cm
.
block_table_tensor
,
...
...
@@ -1024,12 +1026,13 @@ class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]):
return
output
.
fill_
(
0
)
num_actual_toks
=
attn_metadata
.
num_actual_tokens
num_kv_actual_toks
=
attn_metadata
.
num_kv_actual_tokens
# Inputs and outputs may be padded for CUDA graphs
q
=
q
[:
num_actual_toks
,
...]
k_c_normed
=
k_c_normed
[:
num_actual_toks
,
...]
k_pe
=
k_pe
[:
num_actual_toks
,
...]
k_c_normed
=
k_c_normed
[:
num_
kv_
actual_toks
,
...]
k_pe
=
k_pe
[:
num_
kv_
actual_toks
,
...]
assert
self
.
topk_indices_buffer
is
not
None
topk_indices
=
self
.
topk_indices_buffer
[:
num_actual_toks
]
...
...
vllm/v1/attention/backends/mla/indexer.py
View file @
26645e58
...
...
@@ -105,6 +105,7 @@ class DeepseekV32IndexerMetadata:
max_seq_len
:
int
num_actual_tokens
:
int
# Number of tokens excluding padding.
num_kv_actual_tokens
:
int
query_start_loc
:
torch
.
Tensor
slot_mapping
:
torch
.
Tensor
# The dimension of the attention heads
...
...
@@ -437,6 +438,7 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
max_query_len
=
common_attn_metadata
.
max_query_len
,
max_seq_len
=
common_attn_metadata
.
max_seq_len
,
num_actual_tokens
=
common_attn_metadata
.
num_actual_tokens
,
num_kv_actual_tokens
=
common_attn_metadata
.
num_kv_actual_tokens
,
query_start_loc
=
common_attn_metadata
.
query_start_loc
,
slot_mapping
=
common_attn_metadata
.
slot_mapping
,
head_dim
=
128
,
...
...
vllm/v1/spec_decode/eagle.py
View file @
26645e58
...
...
@@ -802,6 +802,7 @@ class SpecDecodeBaseProposer:
_num_computed_tokens_cpu
=
common_attn_metadata
.
_num_computed_tokens_cpu
,
num_reqs
=
common_attn_metadata
.
num_reqs
,
num_actual_tokens
=
total_num_tokens
,
num_kv_actual_tokens
=
total_num_tokens
,
max_query_len
=
new_query_len_per_req
.
max
().
item
(),
max_seq_len
=
common_attn_metadata
.
seq_lens_cpu
.
max
().
item
(),
block_table_tensor
=
common_attn_metadata
.
block_table_tensor
,
...
...
vllm/v1/worker/block_table.py
View file @
26645e58
...
...
@@ -233,6 +233,10 @@ class BlockTable:
def
get_device_tensor
(
self
,
num_reqs
:
int
)
->
torch
.
Tensor
:
"""Returns the device tensor of the block table."""
return
self
.
block_table
.
gpu
[:
num_reqs
]
def
get_device_tensor_range
(
self
,
start_req
:
int
,
end_req
:
int
)
->
torch
.
Tensor
:
"""Returns the device tensor of the block table."""
return
self
.
block_table
.
gpu
[
start_req
:
end_req
]
def
get_cpu_tensor
(
self
)
->
torch
.
Tensor
:
"""Returns the CPU tensor of the block table."""
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
26645e58
...
...
@@ -42,8 +42,13 @@ from vllm.distributed.parallel_state import (
get_tp_group
,
graph_capture
,
is_global_first_rank
,
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
,
prepare_communication_buffer_for_model
,
)
from
vllm.distributed
import
(
tensor_model_parallel_all_gather
)
from
vllm.forward_context
import
(
BatchDescriptor
,
set_forward_context
,
...
...
@@ -104,6 +109,7 @@ from vllm.v1.attention.backend import (
AttentionMetadataBuilder
,
AttentionType
,
CommonAttentionMetadata
,
CpCommonAttentionMetadata
,
MultipleOf
,
)
from
vllm.v1.attention.backends.gdn_attn
import
GDNAttentionMetadataBuilder
...
...
@@ -371,10 +377,16 @@ class GPUModelRunner(
# Always set to false after the first forward pass
self
.
calculate_kv_scales
=
self
.
cache_config
.
calculate_kv_scales
self
.
tp_size
=
self
.
parallel_config
.
tensor_parallel_size
self
.
dcp_world_size
=
self
.
parallel_config
.
decode_context_parallel_size
self
.
dcp_rank
=
0
if
self
.
dcp_world_size
<=
1
else
get_dcp_group
().
rank_in_group
self
.
max_num_tokens
=
scheduler_config
.
max_num_batched_tokens
self
.
max_num_reqs
=
scheduler_config
.
max_num_seqs
#self.max_num_reqs = scheduler_config.max_num_seqs
self
.
max_num_reqs
=
(
scheduler_config
.
max_num_seqs
if
not
envs
.
VLLM_MLA_CPLB
else
scheduler_config
.
max_num_seqs
*
2
)
# Broadcast PP output for external_launcher (torchrun)
# to make sure we are synced across pp ranks
...
...
@@ -1485,6 +1497,236 @@ class GPUModelRunner(
return
encoder_seq_lens
,
encoder_seq_lens_cpu
def
_distribute_tokens_to_cp_ranks
(
self
,
total_q_len
:
int
,
q_lens_cpu
:
np
.
ndarray
,
kv_lens_cpu
:
np
.
ndarray
,
tp_rank
:
int
,
tp_size
:
int
,
req_ids
:
list
[
str
],
):
tokens_per_rank
=
(
total_q_len
+
tp_size
-
1
)
//
tp_size
start_token
=
tp_rank
*
tokens_per_rank
end_token
=
min
((
tp_rank
+
1
)
*
tokens_per_rank
,
total_q_len
)
q_lens
=
[]
seq_count
=
0
seq_indexes
=
[]
kv_lens
=
[]
local_req_ids
=
[]
local_scatter_indexes_tensor
=
None
gather_indexes_tensor
=
None
if
envs
.
VLLM_MLA_CPLB
:
rank_tokens
=
0
rank_pad_tokens
=
0
accu_q_start
=
0
scatter_indexes
:
list
[
int
]
=
[]
num_requests
=
len
(
q_lens_cpu
)
for
i
in
range
(
num_requests
):
req_q_len
=
q_lens_cpu
[
i
]
req_pad_q_len
=
round_up
(
q_lens_cpu
[
i
],
2
*
tp_size
)
kv_len
=
kv_lens_cpu
[
i
]
chunk_q_len
=
req_pad_q_len
//
(
2
*
tp_size
)
q_1_start
=
tp_rank
*
chunk_q_len
q_1_end
=
(
tp_rank
+
1
)
*
chunk_q_len
q_2_start
=
req_pad_q_len
-
(
tp_rank
+
1
)
*
chunk_q_len
q_2_end
=
req_pad_q_len
-
tp_rank
*
chunk_q_len
q_len_1
=
(
chunk_q_len
if
q_1_end
<=
req_q_len
else
max
(
0
,
req_q_len
-
q_1_start
)
)
q_len_2
=
(
chunk_q_len
if
q_2_end
<=
req_q_len
else
max
(
0
,
req_q_len
-
q_2_start
)
)
kv_len_1
=
kv_len
-
req_q_len
+
min
(
req_q_len
,
q_1_end
)
kv_len_2
=
kv_len
-
req_q_len
+
min
(
req_q_len
,
q_2_end
)
scatter_index1
=
range
(
accu_q_start
+
q_1_start
,
accu_q_start
+
q_1_start
+
q_len_1
)
scatter_index2
=
range
(
accu_q_start
+
q_2_start
,
accu_q_start
+
q_2_start
+
q_len_2
)
accu_q_start
+=
req_q_len
if
q_len_1
>
0
:
q_lens
.
append
(
q_len_1
)
kv_lens
.
append
(
kv_len_1
)
seq_indexes
.
append
(
i
)
local_req_ids
.
append
(
req_ids
[
i
])
scatter_indexes
.
extend
(
scatter_index1
)
seq_count
+=
1
rank_tokens
+=
q_len_1
if
q_len_2
>
0
:
q_lens
.
append
(
q_len_2
)
kv_lens
.
append
(
kv_len_2
)
seq_indexes
.
append
(
i
)
local_req_ids
.
append
(
req_ids
[
i
])
scatter_indexes
.
extend
(
scatter_index2
)
seq_count
+=
1
rank_tokens
+=
q_len_2
rank_pad_tokens
+=
chunk_q_len
*
2
if
len
(
scatter_indexes
)
<
rank_pad_tokens
:
scatter_indexes
.
extend
([
-
1
]
*
(
rank_pad_tokens
-
len
(
scatter_indexes
)))
local_scatter_indexes_tensor
=
torch
.
tensor
(
scatter_indexes
,
dtype
=
torch
.
int64
,
device
=
self
.
device
)
global_scatter_indexes_tensor
=
tensor_model_parallel_all_gather
(
local_scatter_indexes_tensor
.
contiguous
(),
dim
=
0
)
non_neg_mask
=
global_scatter_indexes_tensor
!=
-
1
non_neg_values
=
global_scatter_indexes_tensor
[
non_neg_mask
]
non_neg_positions
=
torch
.
where
(
non_neg_mask
)[
0
]
sorted_indices
=
torch
.
argsort
(
non_neg_values
)
gather_indexes_tensor
=
non_neg_positions
[
sorted_indices
]
if
isinstance
(
rank_tokens
,
torch
.
Tensor
):
rank_tokens
=
rank_tokens
.
item
()
else
:
current_seq
=
0
current_pos
=
0
rank_tokens
=
min
(
tokens_per_rank
,
end_token
-
start_token
)
while
start_token
<
end_token
and
current_seq
<
len
(
q_lens_cpu
):
q_len
=
q_lens_cpu
[
current_seq
]
q_start
=
current_pos
q_end
=
current_pos
+
q_len
kv_len
=
kv_lens_cpu
[
current_seq
]
# Find overlap between this sequence and rank's token range
overlap_start
=
max
(
start_token
,
q_start
)
overlap_end
=
min
(
end_token
,
q_end
)
if
overlap_start
<
overlap_end
:
# This sequence contributes tokens to this rank
token_count
=
overlap_end
-
overlap_start
q_lens
.
append
(
token_count
)
start_token
=
overlap_end
seq_count
+=
1
seq_indexes
.
append
(
current_seq
)
local_req_ids
.
append
(
req_ids
[
current_seq
])
if
q_end
<=
end_token
:
kv_lens
.
append
(
kv_len
)
else
:
kv_lens
.
append
(
kv_len
-
(
q_end
-
end_token
))
current_pos
=
q_end
current_seq
+=
1
return
(
rank_tokens
,
np
.
array
(
q_lens
,
dtype
=
np
.
int32
),
seq_count
,
np
.
array
(
kv_lens
,
dtype
=
np
.
int32
),
np
.
array
(
local_req_ids
,
dtype
=
str
),
local_scatter_indexes_tensor
,
gather_indexes_tensor
,
seq_indexes
,
)
def
_prepare_cp_metadata
(
self
,
num_reqs_padded
,
max_query_len
,
num_tokens
,
block_table_gid_0
,
slot_mapping_gid_0
,
):
tp_size
=
self
.
vllm_config
.
parallel_config
.
tensor_parallel_size
tp_rank
=
get_tensor_model_parallel_rank
()
cp_common_metadata
=
CpCommonAttentionMetadata
(
query_start_loc
=
self
.
query_start_loc
.
gpu
[:
num_reqs_padded
+
1
].
clone
(),
query_start_loc_cpu
=
self
.
query_start_loc
.
cpu
[:
num_reqs_padded
+
1
].
clone
(),
seq_lens
=
self
.
seq_lens
.
gpu
[:
num_reqs_padded
].
clone
(),
_seq_lens_cpu
=
self
.
seq_lens
.
cpu
[:
num_reqs_padded
].
clone
(),
max_query_len
=
max_query_len
,
num_reqs
=
num_reqs_padded
,
req_ids
=
self
.
input_batch
.
req_ids
,
num_actual_tokens
=
num_tokens
,
)
query_start_loc_cpu
=
self
.
query_start_loc
.
cpu
[:
num_reqs_padded
+
1
]
q_lens_cpu
=
query_start_loc_cpu
[
1
:]
-
query_start_loc_cpu
[:
-
1
]
kv_lens_cpu
=
self
.
seq_lens
.
cpu
[:
num_reqs_padded
]
total_q_len
=
num_tokens
total_kv_len
=
num_tokens
(
total_q_len
,
q_lens_cpu
,
seq_count
,
kv_lens_cpu
,
local_req_ids
,
scatter_indexes_tensor
,
gather_indexes_tensor
,
seq_indexes_list
,
)
=
self
.
_distribute_tokens_to_cp_ranks
(
total_q_len
,
q_lens_cpu
,
kv_lens_cpu
,
tp_rank
,
tp_size
,
self
.
input_batch
.
req_ids
,
)
num_reqs
=
seq_count
cu_num_tokens
=
np
.
cumsum
(
q_lens_cpu
)
self
.
query_start_loc
.
np
[
0
]
=
0
self
.
query_start_loc
.
np
[
1
:
num_reqs
+
1
]
=
cu_num_tokens
self
.
query_start_loc
.
np
[
num_reqs
+
1
:].
fill
(
cu_num_tokens
[
-
1
])
self
.
query_start_loc
.
copy_to_gpu
()
q_acc_lens
=
self
.
query_start_loc
.
gpu
[:
num_reqs
+
1
]
q_acc_lens_cpu
=
self
.
query_start_loc
.
cpu
[:
num_reqs
+
1
]
max_q_len
=
max
(
q_acc_lens_cpu
)
self
.
seq_lens
.
np
[:
num_reqs
]
=
kv_lens_cpu
self
.
seq_lens
.
np
[
num_reqs
:].
fill
(
0
)
self
.
seq_lens
.
copy_to_gpu
()
kv_lens
=
self
.
seq_lens
.
gpu
[:
num_reqs
]
kv_lens_cpu
=
self
.
seq_lens
.
cpu
[:
num_reqs
]
max_kv_len
=
max
(
kv_lens_cpu
)
num_computed_tokens_cpu
=
kv_lens_cpu
-
q_acc_lens_cpu
[
1
:]
blk_table_tensor
=
block_table_gid_0
[
seq_indexes_list
]
cm_base
=
CommonAttentionMetadata
(
query_start_loc
=
q_acc_lens
,
query_start_loc_cpu
=
q_acc_lens_cpu
,
seq_lens
=
kv_lens
,
_seq_lens_cpu
=
kv_lens_cpu
,
_num_computed_tokens_cpu
=
num_computed_tokens_cpu
,
num_reqs
=
num_reqs
,
num_actual_tokens
=
total_q_len
,
max_query_len
=
max_q_len
,
max_seq_len
=
max_kv_len
,
block_table_tensor
=
blk_table_tensor
,
slot_mapping
=
slot_mapping_gid_0
,
causal
=
True
,
num_kv_actual_tokens
=
total_kv_len
,
seq_indexes_list
=
seq_indexes_list
,
cp_common_metadata
=
cp_common_metadata
,
scatter_indexes_tensor
=
scatter_indexes_tensor
,
gather_indexes_tensor
=
gather_indexes_tensor
,
)
return
cm_base
def
_prepare_inputs
(
self
,
scheduler_output
:
"SchedulerOutput"
,
...
...
@@ -1718,13 +1960,20 @@ class GPUModelRunner(
num_scheduled_tokens
:
dict
[
str
,
int
]
|
None
=
None
,
cascade_attn_prefix_lens
:
list
[
list
[
int
]]
|
None
=
None
,
slot_mappings
:
dict
[
int
,
torch
.
Tensor
]
|
None
=
None
,
)
->
tuple
[
PerLayerAttnMetadata
,
CommonAttentionMetadata
|
None
]:
)
->
tuple
[
PerLayerAttnMetadata
,
CommonAttentionMetadata
|
None
,
torch
.
Tensor
|
None
,
torch
.
Tensor
|
None
,
]:
"""
:return: tuple[attn_metadata, spec_decode_common_attn_metadata]
"""
# Attention metadata is not needed for attention free models
if
len
(
self
.
kv_cache_config
.
kv_cache_groups
)
==
0
:
return
{},
None
return
{},
None
,
None
,
None
tp_size
=
self
.
vllm_config
.
parallel_config
.
tensor_parallel_size
num_tokens_padded
=
num_tokens_padded
or
num_tokens
num_reqs_padded
=
num_reqs_padded
or
num_reqs
...
...
@@ -1772,25 +2021,40 @@ class GPUModelRunner(
assert
slot_mappings
is
not
None
block_table_gid_0
=
_get_block_table
(
0
)
slot_mapping_gid_0
=
slot_mappings
[
0
]
scatter_indexes_tensor
=
None
gather_indexes_tensor
=
None
if
self
.
model_config
.
enable_return_routed_experts
:
self
.
slot_mapping
=
slot_mapping_gid_0
[:
num_tokens
].
cpu
().
numpy
()
cm_base
=
CommonAttentionMetadata
(
query_start_loc
=
self
.
query_start_loc
.
gpu
[:
num_reqs_padded
+
1
],
query_start_loc_cpu
=
self
.
query_start_loc
.
cpu
[:
num_reqs_padded
+
1
],
seq_lens
=
self
.
seq_lens
.
gpu
[:
num_reqs_padded
],
_seq_lens_cpu
=
self
.
seq_lens
.
cpu
[:
num_reqs_padded
],
_num_computed_tokens_cpu
=
self
.
input_batch
.
num_computed_tokens_cpu_tensor
[
:
num_reqs_padded
],
num_reqs
=
num_reqs_padded
,
num_actual_tokens
=
num_tokens_padded
,
max_query_len
=
max_query_len
,
max_seq_len
=
max_seq_len
,
block_table_tensor
=
block_table_gid_0
,
slot_mapping
=
slot_mapping_gid_0
,
causal
=
True
,
)
if
not
envs
.
VLLM_MLA_CP
or
num_tokens
<=
tp_size
*
tp_size
:
cm_base
=
CommonAttentionMetadata
(
query_start_loc
=
self
.
query_start_loc
.
gpu
[:
num_reqs_padded
+
1
],
query_start_loc_cpu
=
self
.
query_start_loc
.
cpu
[:
num_reqs_padded
+
1
],
seq_lens
=
self
.
seq_lens
.
gpu
[:
num_reqs_padded
],
_seq_lens_cpu
=
self
.
seq_lens
.
cpu
[:
num_reqs_padded
],
_num_computed_tokens_cpu
=
self
.
input_batch
.
num_computed_tokens_cpu_tensor
[
:
num_reqs_padded
],
num_reqs
=
num_reqs_padded
,
num_actual_tokens
=
num_tokens_padded
,
num_kv_actual_tokens
=
num_tokens_padded
,
max_query_len
=
max_query_len
,
max_seq_len
=
max_seq_len
,
block_table_tensor
=
block_table_gid_0
,
slot_mapping
=
slot_mapping_gid_0
,
causal
=
True
,
)
else
:
cm_base
=
self
.
_prepare_cp_metadata
(
num_reqs_padded
,
max_query_len
,
num_tokens
,
block_table_gid_0
,
slot_mapping_gid_0
,
)
scatter_indexes_tensor
=
cm_base
.
scatter_indexes_tensor
gather_indexes_tensor
=
cm_base
.
gather_indexes_tensor
if
self
.
dcp_world_size
>
1
:
self
.
dcp_local_seq_lens
.
cpu
[:
num_reqs
]
=
get_dcp_local_seq_lens
(
...
...
@@ -1901,6 +2165,9 @@ class GPUModelRunner(
cm
.
block_table_tensor
=
_get_block_table
(
kv_cache_gid
)
cm
.
slot_mapping
=
slot_mappings
[
kv_cache_gid
]
if
cm
.
seq_indexes_list
is
not
None
:
cm
.
block_table_tensor
=
cm
.
block_table_tensor
[
cm
.
seq_indexes_list
]
if
self
.
speculative_config
and
spec_decode_common_attn_metadata
is
None
and
hasattr
(
self
,
"drafter"
):
if
isinstance
(
self
.
drafter
,
EagleProposer
):
if
self
.
drafter
.
attn_layer_names
[
0
]
in
kv_cache_group
.
layer_names
:
...
...
@@ -1936,8 +2203,10 @@ class GPUModelRunner(
for
_metadata
in
attn_metadata
.
values
():
_metadata
.
mm_prefix_range
=
req_doc_ranges
# type: ignore[attr-defined]
if
spec_decode_common_attn_metadata
is
not
None
and
(
num_reqs
!=
num_reqs_padded
or
num_tokens
!=
num_tokens_padded
if
(
(
not
envs
.
VLLM_MLA_CP
)
and
spec_decode_common_attn_metadata
is
not
None
and
(
num_reqs
!=
num_reqs_padded
or
num_tokens
!=
num_tokens_padded
)
):
# Currently the drafter still only uses piecewise cudagraphs (and modifies
# the attention metadata in directly), and therefore does not want to use
...
...
@@ -1946,7 +2215,12 @@ class GPUModelRunner(
spec_decode_common_attn_metadata
.
unpadded
(
num_tokens
,
num_reqs
)
)
return
attn_metadata
,
spec_decode_common_attn_metadata
return
(
attn_metadata
,
spec_decode_common_attn_metadata
,
scatter_indexes_tensor
,
gather_indexes_tensor
)
def
_compute_cascade_attn_prefix_lens
(
self
,
...
...
@@ -2798,9 +3072,19 @@ class GPUModelRunner(
return
model_runner_output
def
_pad_for_mla_cp
(
self
,
num_scheduled_tokens
:
int
)
->
int
:
tp_size
=
self
.
vllm_config
.
parallel_config
.
tensor_parallel_size
if
num_scheduled_tokens
<=
tp_size
*
tp_size
:
return
num_scheduled_tokens
*
tp_size
else
:
return
round_up
(
num_scheduled_tokens
,
tp_size
)
def
_pad_for_sequence_parallelism
(
self
,
num_scheduled_tokens
:
int
)
->
int
:
# Pad tokens to multiple of tensor_parallel_size when
# enabled collective fusion for SP
if
envs
.
VLLM_MLA_CP
:
return
self
.
_pad_for_mla_cp
(
num_scheduled_tokens
)
tp_size
=
self
.
vllm_config
.
parallel_config
.
tensor_parallel_size
if
self
.
compilation_config
.
pass_config
.
enable_sp
and
tp_size
>
1
:
return
round_up
(
num_scheduled_tokens
,
tp_size
)
...
...
@@ -3497,6 +3781,8 @@ class GPUModelRunner(
)
num_tokens_padded
=
batch_desc
.
num_tokens
if
envs
.
VLLM_MLA_CP
:
num_tokens_padded
=
self
.
_pad_for_mla_cp
(
num_tokens_unpadded
)
num_reqs_padded
=
(
batch_desc
.
num_reqs
if
batch_desc
.
num_reqs
is
not
None
else
num_reqs
)
...
...
@@ -3553,8 +3839,12 @@ class GPUModelRunner(
ubatch_slices
=
ubatch_slices_padded
,
)
attn_metadata
,
spec_decode_common_attn_metadata
=
(
self
.
_build_attention_metadata
(
(
attn_metadata
,
spec_decode_common_attn_metadata
,
scatter_indexes_tensor
,
gather_indexes_tensor
,
)
=
self
.
_build_attention_metadata
(
num_tokens
=
num_tokens_unpadded
,
num_tokens_padded
=
num_tokens_padded
if
pad_attn
else
None
,
num_reqs
=
num_reqs
,
...
...
@@ -3567,7 +3857,6 @@ class GPUModelRunner(
cascade_attn_prefix_lens
=
cascade_attn_prefix_lens
,
slot_mappings
=
slot_mappings_by_group
,
)
)
(
input_ids
,
...
...
@@ -3608,6 +3897,8 @@ class GPUModelRunner(
ubatch_slices
=
ubatch_slices_padded
,
slot_mapping
=
slot_mappings
,
skip_compiled
=
has_encoder_input
,
scatter_indexes_tensor
=
scatter_indexes_tensor
,
gather_indexes_tensor
=
gather_indexes_tensor
,
),
record_function_or_nullcontext
(
"gpu_model_runner: forward"
),
self
.
maybe_get_kv_connector_output
(
scheduler_output
)
as
kv_connector_output
,
...
...
@@ -4094,7 +4385,16 @@ class GPUModelRunner(
spec_decode_metadata
,
valid_sampled_tokens_count
,
)
total_num_tokens
=
common_attn_metadata
.
num_actual_tokens
#total_num_tokens = common_attn_metadata.num_actual_tokens
if
(
envs
.
VLLM_MLA_CP
and
common_attn_metadata
.
cp_common_metadata
is
not
None
):
total_num_tokens
=
(
common_attn_metadata
.
cp_common_metadata
.
num_actual_tokens
)
else
:
total_num_tokens
=
common_attn_metadata
.
num_actual_tokens
# When padding the batch, token_indices is just a range
target_token_ids
=
self
.
input_ids
.
gpu
[:
total_num_tokens
]
target_positions
=
self
.
_get_positions
(
total_num_tokens
)
...
...
@@ -4618,6 +4918,9 @@ class GPUModelRunner(
or
cudagraph_runtime_mode
.
valid_runtime_modes
()
)
if
envs
.
VLLM_MLA_CP
:
num_tokens
=
max
(
self
.
tp_size
,
num_tokens
)
# If cudagraph_mode.decode_mode() == FULL and
# cudagraph_mode.separate_routine(). This means that we are using
# different graphs and/or modes for mixed prefill-decode batches vs.
...
...
@@ -4748,7 +5051,7 @@ class GPUModelRunner(
self
.
query_start_loc
.
copy_to_gpu
()
pad_attn
=
cudagraph_runtime_mode
==
CUDAGraphMode
.
FULL
attn_metadata
,
_
=
self
.
_build_attention_metadata
(
attn_metadata
,
_
,
_
,
_
=
self
.
_build_attention_metadata
(
num_tokens
=
num_tokens_unpadded
,
num_reqs
=
num_reqs_padded
,
max_query_len
=
max_query_len
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment