Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
aef3c487
Commit
aef3c487
authored
Apr 22, 2026
by
wangmin6
Committed by
zhangzbb
Apr 22, 2026
Browse files
[Feature]添加PCP功能,只支持mla架构,CPLB待验证
parent
c1819454
Changes
15
Hide whitespace changes
Inline
Side-by-side
Showing
15 changed files
with
886 additions
and
47 deletions
+886
-47
vllm/config/parallel.py
vllm/config/parallel.py
+7
-0
vllm/config/vllm.py
vllm/config/vllm.py
+6
-0
vllm/engine/arg_utils.py
vllm/engine/arg_utils.py
+14
-0
vllm/envs.py
vllm/envs.py
+8
-1
vllm/forward_context.py
vllm/forward_context.py
+21
-0
vllm/model_executor/layers/mla.py
vllm/model_executor/layers/mla.py
+36
-0
vllm/model_executor/layers/sparse_attn_indexer.py
vllm/model_executor/layers/sparse_attn_indexer.py
+1
-1
vllm/model_executor/models/deepseek_mtp.py
vllm/model_executor/models/deepseek_mtp.py
+37
-1
vllm/model_executor/models/deepseek_v2.py
vllm/model_executor/models/deepseek_v2.py
+186
-7
vllm/v1/attention/backend.py
vllm/v1/attention/backend.py
+38
-0
vllm/v1/attention/backends/mla/flashmla_sparse.py
vllm/v1/attention/backends/mla/flashmla_sparse.py
+5
-2
vllm/v1/attention/backends/mla/indexer.py
vllm/v1/attention/backends/mla/indexer.py
+2
-0
vllm/v1/spec_decode/eagle.py
vllm/v1/spec_decode/eagle.py
+163
-6
vllm/v1/worker/block_table.py
vllm/v1/worker/block_table.py
+4
-0
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+358
-29
No files found.
vllm/config/parallel.py
View file @
aef3c487
...
...
@@ -299,6 +299,13 @@ class ParallelConfig:
should only be set by API server scale-out.
"""
enable_lightly_cp
:
bool
=
False
"""Use lightly context parallel."""
enable_lightly_cplb
:
bool
=
False
"""Use lightly context parallel load balancing."""
@
field_validator
(
"disable_nccl_for_dp_synchronization"
,
mode
=
"wrap"
)
@
classmethod
def
_skip_none_validation
(
cls
,
value
:
Any
,
handler
:
Callable
)
->
Any
:
...
...
vllm/config/vllm.py
View file @
aef3c487
...
...
@@ -1061,6 +1061,12 @@ class VllmConfig:
# Handle the KV connector configs
self
.
_post_init_kv_transfer_config
()
if
self
.
parallel_config
.
enable_lightly_cp
and
not
self
.
model_config
.
enforce_eager
:
raise
ValueError
(
"Lightly context parallel currently only supports the eager mode!!!"
)
def
update_sizes_for_sequence_parallelism
(
self
,
possible_sizes
:
list
)
->
list
:
# Remove the sizes that are not multiples of tp_size when
# sequence parallelism is enabled.
...
...
vllm/engine/arg_utils.py
View file @
aef3c487
...
...
@@ -585,6 +585,9 @@ class EngineArgs:
kv_offloading_backend
:
KVOffloadingBackend
=
CacheConfig
.
kv_offloading_backend
tokens_only
:
bool
=
False
enable_lightly_cp
:
bool
=
ParallelConfig
.
enable_lightly_cp
enable_lightly_cplb
:
bool
=
ParallelConfig
.
enable_lightly_cplb
def
__post_init__
(
self
):
# support `EngineArgs(compilation_config={...})`
# without having to manually construct a
...
...
@@ -902,6 +905,15 @@ class EngineArgs:
"--worker-extension-cls"
,
**
parallel_kwargs
[
"worker_extension_cls"
]
)
parallel_group
.
add_argument
(
"--enable-lightly-cp"
,
**
parallel_kwargs
[
"enable_lightly_cp"
],
)
parallel_group
.
add_argument
(
"--enable-lightly-cplb"
,
**
parallel_kwargs
[
"enable_lightly_cplb"
],
)
# KV cache arguments
cache_kwargs
=
get_kwargs
(
CacheConfig
)
cache_group
=
parser
.
add_argument_group
(
...
...
@@ -1661,6 +1673,8 @@ class EngineArgs:
cp_kv_cache_interleave_size
=
self
.
cp_kv_cache_interleave_size
,
_api_process_count
=
self
.
_api_process_count
,
_api_process_rank
=
self
.
_api_process_rank
,
enable_lightly_cp
=
self
.
enable_lightly_cp
,
enable_lightly_cplb
=
self
.
enable_lightly_cplb
,
)
speculative_config
=
self
.
create_speculative_config
(
...
...
vllm/envs.py
View file @
aef3c487
...
...
@@ -324,6 +324,9 @@ if TYPE_CHECKING:
USE_LIGHTOP_TOPK
:
bool
=
False
USE_LIGHTOP_CONVERT_REQ_INDEX_TO_GLOBAL_INDEX
:
bool
=
False
VLLM_DISABLE_DSA
:
bool
=
False
VLLM_LIGHTLY_CP_THRESHOULD
:
int
=
2048
def
get_default_cache_root
():
return
os
.
getenv
(
"XDG_CACHE_HOME"
,
...
...
@@ -2004,7 +2007,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
# If set to 1/True, disable the DSA.
"VLLM_DISABLE_DSA"
:
lambda
:
(
os
.
environ
.
get
(
"VLLM_DISABLE_DSA"
,
"False"
).
lower
()
in
(
"true"
,
"1"
)),
(
"true"
,
"1"
)),
# MLA_CP open threshold
"VLLM_LIGHTLY_CP_THRESHOULD"
:
lambda
:
int
(
os
.
getenv
(
"VLLM_LIGHTLY_CP_THRESHOULD"
,
"2048"
)),
}
# --8<-- [end:env-vars-definition]
...
...
vllm/forward_context.py
View file @
aef3c487
...
...
@@ -240,6 +240,11 @@ class ForwardContext:
additional_kwargs
:
dict
[
str
,
Any
]
=
field
(
default_factory
=
dict
)
scatter_indexes_tensor
:
torch
.
Tensor
|
None
=
None
gather_indexes_tensor
:
torch
.
Tensor
|
None
=
None
enable_lightly_cp
:
bool
=
False
enable_lightly_cplb
:
bool
=
False
def
__post_init__
(
self
):
assert
self
.
cudagraph_runtime_mode
.
valid_runtime_modes
(),
(
f
"Invalid cudagraph runtime mode:
{
self
.
cudagraph_runtime_mode
}
"
...
...
@@ -273,6 +278,10 @@ def create_forward_context(
slot_mapping
:
dict
[
str
,
torch
.
Tensor
]
|
None
=
None
,
additional_kwargs
:
dict
[
str
,
Any
]
|
None
=
None
,
skip_compiled
:
bool
=
False
,
scatter_indexes_tensor
:
torch
.
Tensor
|
None
=
None
,
gather_indexes_tensor
:
torch
.
Tensor
|
None
=
None
,
enable_lightly_cp
:
bool
=
False
,
enable_lightly_cplb
:
bool
=
False
):
if
vllm_config
.
compilation_config
.
fast_moe_cold_start
:
if
vllm_config
.
speculative_config
is
None
:
...
...
@@ -298,6 +307,10 @@ def create_forward_context(
batch_descriptor
=
batch_descriptor
,
ubatch_slices
=
ubatch_slices
,
skip_compiled
=
skip_compiled
,
scatter_indexes_tensor
=
scatter_indexes_tensor
,
gather_indexes_tensor
=
gather_indexes_tensor
,
enable_lightly_cp
=
enable_lightly_cp
,
enable_lightly_cplb
=
enable_lightly_cplb
,
additional_kwargs
=
additional_kwargs
or
{},
)
...
...
@@ -329,6 +342,10 @@ def set_forward_context(
ubatch_slices
:
UBatchSlices
|
None
=
None
,
slot_mapping
:
dict
[
str
,
torch
.
Tensor
]
|
list
[
dict
[
str
,
torch
.
Tensor
]]
|
None
=
None
,
skip_compiled
:
bool
=
False
,
scatter_indexes_tensor
:
torch
.
Tensor
|
None
=
None
,
gather_indexes_tensor
:
torch
.
Tensor
|
None
=
None
,
enable_lightly_cp
:
bool
=
False
,
enable_lightly_cplb
:
bool
=
False
,
):
"""A context manager that stores the current forward context,
can be attention metadata, etc.
...
...
@@ -390,6 +407,10 @@ def set_forward_context(
slot_mapping
,
additional_kwargs
,
skip_compiled
,
scatter_indexes_tensor
,
gather_indexes_tensor
,
enable_lightly_cp
,
enable_lightly_cplb
)
try
:
...
...
vllm/model_executor/layers/mla.py
View file @
aef3c487
...
...
@@ -7,8 +7,12 @@ import torch
from
vllm.attention.layer
import
MLAAttention
from
vllm.config
import
CacheConfig
import
vllm.envs
as
envs
from
vllm.forward_context
import
get_forward_context
from
vllm.model_executor.custom_op
import
PluggableLayer
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.distributed
import
(
tensor_model_parallel_all_gather
,
)
@
dataclass
...
...
@@ -184,8 +188,26 @@ class MultiHeadLatentAttentionWrapper(PluggableLayer):
if
llama_4_scaling
is
not
None
:
q
*=
llama_4_scaling
enable_lightly_cp
=
get_forward_context
().
enable_lightly_cp
enable_lightly_cplb
=
get_forward_context
().
enable_lightly_cplb
# if not use_fused_rms_rope_concat:
if
not
envs
.
VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT
:
if
enable_lightly_cp
:
kv_c_normed
=
tensor_model_parallel_all_gather
(
kv_c_normed
.
contiguous
(),
0
)
k_pe
=
tensor_model_parallel_all_gather
(
k_pe
.
contiguous
(),
0
)
gather_indexes_tensor
=
get_forward_context
().
gather_indexes_tensor
if
enable_lightly_cplb
and
gather_indexes_tensor
is
not
None
:
# Reorder kv after pcp allgather.
kv_c_normed
=
torch
.
index_select
(
kv_c_normed
,
0
,
gather_indexes_tensor
)
k_pe
=
torch
.
index_select
(
k_pe
,
0
,
gather_indexes_tensor
)
attn_out
=
self
.
mla_attn
(
q
,
kv_c_normed
,
...
...
@@ -221,6 +243,20 @@ class MultiHeadLatentAttentionWrapper(PluggableLayer):
"VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT requires rotary_emb to "
"expose 'cos_sin_cache'."
)
if
enable_lightly_cp
:
kv_c
=
tensor_model_parallel_all_gather
(
kv_c
.
contiguous
(),
0
)
k_pe
=
tensor_model_parallel_all_gather
(
k_pe
.
contiguous
(),
0
)
gather_indexes_tensor
=
get_forward_context
().
gather_indexes_tensor
if
enable_lightly_cplb
and
gather_indexes_tensor
is
not
None
:
# Reorder kv after pcp allgather.
kv_c
=
torch
.
index_select
(
kv_c
,
0
,
gather_indexes_tensor
)
k_pe
=
torch
.
index_select
(
k_pe
,
0
,
gather_indexes_tensor
)
attn_out
=
self
.
mla_attn
(
q
[...,
self
.
qk_nope_head_dim
:],
kv_c
,
...
...
vllm/model_executor/layers/sparse_attn_indexer.py
View file @
aef3c487
...
...
@@ -90,7 +90,7 @@ def sparse_attn_indexer(
)
attn_metadata
=
attn_metadata
[
layer_name
]
assert
isinstance
(
attn_metadata
,
DeepseekV32IndexerMetadata
)
slot_mapping
=
attn_metadata
.
slot_mapping
slot_mapping
=
attn_metadata
.
slot_mapping
[:
attn_metadata
.
num_kv_actual_tokens
]
has_decode
=
attn_metadata
.
num_decodes
>
0
has_prefill
=
attn_metadata
.
num_prefills
>
0
num_decode_tokens
=
attn_metadata
.
num_decode_tokens
...
...
vllm/model_executor/models/deepseek_mtp.py
View file @
aef3c487
...
...
@@ -11,6 +11,7 @@ import torch
import
torch.nn
as
nn
from
transformers
import
PretrainedConfig
from
vllm.forward_context
import
get_forward_context
from
vllm._aiter_ops
import
rocm_aiter_ops
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
VllmConfig
...
...
@@ -36,6 +37,9 @@ from .deepseek_v2 import (
DeepseekV2MoE
,
get_spec_layer_idx_from_weight_name
,
)
from
vllm.distributed
import
(
tensor_model_parallel_all_gather
,
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
)
from
.utils
import
maybe_prefix
from
.interfaces
import
SupportsPP
from
vllm
import
_custom_ops
as
ops
...
...
@@ -177,6 +181,9 @@ class DeepSeekMultiTokenPredictor(nn.Module):
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
tp_rank
=
get_tensor_model_parallel_rank
()
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
    """Return the embeddings produced by ``self.embed_tokens`` for ``input_ids``."""
    embeddings = self.embed_tokens(input_ids)
    return embeddings
...
...
@@ -191,7 +198,28 @@ class DeepSeekMultiTokenPredictor(nn.Module):
if
inputs_embeds
is
None
:
inputs_embeds
=
self
.
embed_tokens
(
input_ids
)
current_step_idx
=
spec_step_idx
%
self
.
num_mtp_layers
return
self
.
layers
[
str
(
self
.
mtp_start_layer_idx
+
current_step_idx
)](
enable_lightly_cp
=
get_forward_context
().
enable_lightly_cp
if
enable_lightly_cp
:
scatter_indexes_tensor
=
get_forward_context
().
scatter_indexes_tensor
if
scatter_indexes_tensor
is
None
:
inputs_embeds_per_rank
=
torch
.
chunk
(
inputs_embeds
,
chunks
=
self
.
tp_size
,
dim
=
0
)
inputs_embeds
=
inputs_embeds_per_rank
[
self
.
tp_rank
].
contiguous
()
previous_hidden_states_per_rank
=
torch
.
chunk
(
previous_hidden_states
,
chunks
=
self
.
tp_size
,
dim
=
0
)
previous_hidden_states
=
previous_hidden_states_per_rank
[
self
.
tp_rank
].
contiguous
()
if
positions
is
not
None
:
positions_per_rank
=
torch
.
chunk
(
positions
,
chunks
=
self
.
tp_size
,
dim
=
0
)
positions
=
positions_per_rank
[
self
.
tp_rank
].
contiguous
()
else
:
scatter_indexes_tensor
=
torch
.
where
(
scatter_indexes_tensor
==
-
1
,
0
,
scatter_indexes_tensor
)
inputs_embeds
=
torch
.
index_select
(
inputs_embeds
,
0
,
scatter_indexes_tensor
)
previous_hidden_states
=
torch
.
index_select
(
previous_hidden_states
,
0
,
scatter_indexes_tensor
)
if
positions
is
not
None
:
positions
=
torch
.
index_select
(
positions
,
0
,
scatter_indexes_tensor
)
hidden_states
=
self
.
layers
[
str
(
self
.
mtp_start_layer_idx
+
current_step_idx
)](
input_ids
,
positions
,
previous_hidden_states
,
...
...
@@ -199,6 +227,14 @@ class DeepSeekMultiTokenPredictor(nn.Module):
current_step_idx
,
)
if
enable_lightly_cp
:
hidden_states
=
tensor_model_parallel_all_gather
(
hidden_states
.
contiguous
(),
dim
=
0
)
gather_indexes_tensor
=
get_forward_context
().
gather_indexes_tensor
if
gather_indexes_tensor
is
not
None
:
hidden_states
=
torch
.
index_select
(
hidden_states
,
0
,
gather_indexes_tensor
)
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
...
...
vllm/model_executor/models/deepseek_v2.py
View file @
aef3c487
...
...
@@ -46,6 +46,8 @@ from vllm.distributed import (
get_tensor_model_parallel_world_size
,
tensor_model_parallel_all_gather
,
)
from
vllm.forward_context
import
get_forward_context
from
vllm.distributed.communication_op
import
tensor_model_parallel_all_reduce
,
tensor_model_parallel_reduce_scatter
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.attention_layer_base
import
AttentionLayerBase
...
...
@@ -181,6 +183,44 @@ class DeepseekAttention(nn.Module):
return
output
def iqis_all_gather(
    iqis: tuple[torch.Tensor, torch.Tensor],
    tp_size: int | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """All-gather a quantized (values, scales) pair along dim 0 in one collective.

    The float32 scale column is reinterpreted as 4 raw int8 bytes and
    concatenated to the int8 values, so that a single
    ``tensor_model_parallel_all_gather`` call moves both tensors. After the
    gather the two parts are split again and the scale bytes are viewed back
    as float32.

    Args:
        iqis: ``(iq_tensor, is_tensor)`` where ``iq_tensor`` is a 2-D int8
            tensor of shape ``[m_local, n]`` and ``is_tensor`` is a 2-D
            float32 tensor of shape ``[m_local, 1]``.
        tp_size: Number of ranks taking part in the gather; used only to
            validate the gathered row count. When ``None`` it is resolved
            with ``get_tensor_model_parallel_world_size()``.

    Returns:
        ``(iq_gathered, is_gathered)`` with shapes ``[m_local * tp_size, n]``
        (int8) and ``[m_local * tp_size, 1]`` (float32).

    Raises:
        AssertionError: If the inputs do not have the dtypes/shapes described
            above, or the gathered tensors do not have the expected shape.
    """
    assert iqis is not None
    iq_tensor, is_tensor = iqis
    assert isinstance(iq_tensor, torch.Tensor)
    assert isinstance(is_tensor, torch.Tensor)
    assert iq_tensor.dtype == torch.int8, f"iq_tensor dtype is {iq_tensor.dtype}"
    assert is_tensor.dtype == torch.float32, f"is_tensor dtype is {is_tensor.dtype}"
    assert iq_tensor.dim() == 2
    assert is_tensor.dim() == 2

    m_local, n = iq_tensor.shape
    assert is_tensor.shape[0] == m_local, (
        f"{is_tensor.shape[0]} != {iq_tensor.shape[0]}"
    )
    assert is_tensor.shape[1] == 1, f"is_tensor dim 1 = {is_tensor.shape[1]}"

    # The original default of None made the shape checks below fail with a
    # TypeError (``m_local * None``); fall back to the actual TP world size.
    if tp_size is None:
        tp_size = get_tensor_model_parallel_world_size()

    # A dtype-changing view needs a unit stride on the last dim, so make the
    # scales contiguous before reinterpreting each float32 as 4 int8 bytes.
    is_int8_2d = is_tensor.contiguous().view(torch.int8)
    # [m_local, n + 4]
    combined_2d = torch.cat([iq_tensor, is_int8_2d], dim=1).contiguous()

    combined_gathered = tensor_model_parallel_all_gather(combined_2d, dim=0)

    # Column slices are strided; copy them so the float32 view below is legal.
    iq_gathered = combined_gathered[:, :n].contiguous()
    is_gathered_int8 = combined_gathered[:, n:].contiguous()

    expected_rows = m_local * tp_size
    assert iq_gathered.shape[0] == expected_rows, (
        f"iq_gathered dim0={iq_gathered.shape[0]}, expected {expected_rows}"
    )
    # is_gathered_int8 should be [m_local*tp_size, 4]
    assert is_gathered_int8.shape[0] == expected_rows, (
        f"is_gathered_int8 dim0={is_gathered_int8.shape[0]}, expected {expected_rows}"
    )
    assert is_gathered_int8.shape[1] == 4, (
        f"is_gathered_int8 dim1={is_gathered_int8.shape[1]}"
    )

    is_gathered = is_gathered_int8.view(torch.float32)
    return (iq_gathered, is_gathered)
class
DeepseekV2MLP
(
nn
.
Module
):
def
__init__
(
self
,
...
...
@@ -211,10 +251,85 @@ class DeepseekV2MLP(nn.Module):
hidden_size
,
bias
=
False
,
quant_config
=
quant_config
,
reduce_results
=
reduce_results
,
#reduce_results=reduce_results,
reduce_results
=
False
,
disable_tp
=
is_sequence_parallel
,
prefix
=
f
"
{
prefix
}
.down_proj"
,
)
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
if
hidden_act
!=
"silu"
:
raise
ValueError
(
f
"Unsupported activation:
{
hidden_act
}
. Only silu is supported for now."
)
self
.
act_fn
=
SiluAndMul
()
def forward(self, x, *, iqis: tuple[torch.Tensor, torch.Tensor] | None = None):
    """Run gate/up projection, SiLU-and-mul, and down projection.

    Args:
        x: Input activations passed to ``gate_up_proj``.
        iqis: Optional pre-quantized ``(values, scales)`` pair that is
            forwarded to ``gate_up_proj`` when ``envs.USE_FUSED_RMS_QUANT``
            is set.

    Returns:
        The down-projection output. When the forward context has
        ``enable_lightly_cp`` set it is reduce-scattered along dim 0;
        otherwise it is all-reduced when ``self.tp_size > 1``.
    """
    cp_active = get_forward_context().enable_lightly_cp

    if cp_active:
        # Gather the quantized pair when both halves exist, otherwise gather
        # the raw activations along dim 0.
        have_quantized = (
            iqis is not None and iqis[0] is not None and iqis[1] is not None
        )
        if have_quantized:
            iqis = iqis_all_gather(iqis, tp_size=self.tp_size)
        else:
            x = tensor_model_parallel_all_gather(x.contiguous(), 0)

    if envs.USE_FUSED_RMS_QUANT:
        gate_up, _ = self.gate_up_proj(x, iqis=iqis)
        if envs.USE_FUSED_SILU_MUL_QUANT:
            # Imported lazily, only on the fused SiLU-mul-quant path.
            from lmslim.quantize.quant_ops import lm_fuse_silu_mul_quant

            xq, xs = lm_fuse_silu_mul_quant(gate_up)
            x, _ = self.down_proj(gate_up, iqis=(xq, xs))
        else:
            x = self.act_fn(gate_up)
            x, _ = self.down_proj(x)
    else:
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)

    if cp_active:
        return tensor_model_parallel_reduce_scatter(x.contiguous(), dim=0)
    if self.tp_size > 1:
        x = tensor_model_parallel_all_reduce(x)
    return x
class
DeepseekV2SharedMLP
(
nn
.
Module
):
def
__init__
(
self
,
hidden_size
:
int
,
intermediate_size
:
int
,
hidden_act
:
str
,
quant_config
:
QuantizationConfig
|
None
=
None
,
reduce_results
:
bool
=
True
,
is_sequence_parallel
=
False
,
prefix
:
str
=
""
,
)
->
None
:
super
().
__init__
()
# If is_sequence_parallel, the input and output tensors are sharded
# across the ranks within the tp_group. In this case the weights are
# replicated and no collective ops are needed.
# Otherwise we use standard TP with an allreduce at the end.
self
.
gate_up_proj
=
MergedColumnParallelLinear
(
hidden_size
,
[
intermediate_size
]
*
2
,
bias
=
False
,
quant_config
=
quant_config
,
disable_tp
=
is_sequence_parallel
,
prefix
=
f
"
{
prefix
}
.gate_up_proj"
,
)
self
.
down_proj
=
RowParallelLinear
(
intermediate_size
,
hidden_size
,
bias
=
False
,
quant_config
=
quant_config
,
reduce_results
=
reduce_results
,
disable_tp
=
is_sequence_parallel
,
prefix
=
f
"
{
prefix
}
.down_proj"
)
if
hidden_act
!=
"silu"
:
raise
ValueError
(
f
"Unsupported activation:
{
hidden_act
}
. Only silu is supported for now."
...
...
@@ -311,7 +426,7 @@ class DeepseekV2MoE(nn.Module):
else
:
intermediate_size
=
config
.
moe_intermediate_size
*
config
.
n_shared_experts
self
.
shared_experts
=
DeepseekV2MLP
(
self
.
shared_experts
=
DeepseekV2
Shared
MLP
(
hidden_size
=
config
.
hidden_size
,
intermediate_size
=
intermediate_size
,
hidden_act
=
config
.
hidden_act
,
...
...
@@ -357,6 +472,11 @@ class DeepseekV2MoE(nn.Module):
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
*
,
iqis
:
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
|
None
=
None
)
->
torch
.
Tensor
:
enable_lightly_cp
=
get_forward_context
().
enable_lightly_cp
if
enable_lightly_cp
:
hidden_states
=
tensor_model_parallel_all_gather
(
hidden_states
.
contiguous
(),
0
)
num_tokens
,
hidden_dim
=
hidden_states
.
shape
hidden_states
=
hidden_states
.
view
(
-
1
,
hidden_dim
)
...
...
@@ -428,7 +548,12 @@ class DeepseekV2MoE(nn.Module):
assert
shared_output
is
not
None
final_hidden_states
+=
shared_output
if
self
.
is_sequence_parallel
:
if
enable_lightly_cp
:
final_hidden_states
=
tensor_model_parallel_reduce_scatter
(
final_hidden_states
.
contiguous
(),
0
)
return
final_hidden_states
elif
self
.
is_sequence_parallel
:
final_hidden_states
=
tensor_model_parallel_all_gather
(
final_hidden_states
,
0
)
...
...
@@ -759,6 +884,16 @@ class Indexer(nn.Module):
# `k_pe` is [num_tokens, 1, rope_dim] (MQA).
k
=
torch
.
cat
([
k_pe
.
squeeze
(
-
2
),
k_nope
],
dim
=-
1
)
enable_lightly_cp
=
get_forward_context
().
enable_lightly_cp
if
enable_lightly_cp
:
k
=
tensor_model_parallel_all_gather
(
k
.
contiguous
(),
0
)
gather_indexes_tensor
=
get_forward_context
().
gather_indexes_tensor
enable_lightly_cplb
=
get_forward_context
().
enable_lightly_cplb
if
enable_lightly_cplb
and
gather_indexes_tensor
is
not
None
:
k
=
torch
.
index_select
(
k
,
0
,
gather_indexes_tensor
)
# we only quant q here since k quant is fused with cache insertion
if
not
current_platform
.
is_rocm
()
or
torch
.
cuda
.
get_device_properties
(
"cuda"
).
gcnArchName
.
split
(
':'
)[
0
]
==
"gfx938"
:
q
=
q
.
view
(
-
1
,
self
.
head_dim
)
...
...
@@ -825,7 +960,8 @@ class DeepseekV2MLAAttention(nn.Module):
self
.
num_heads
=
num_heads
tp_size
=
get_tensor_model_parallel_world_size
()
assert
num_heads
%
tp_size
==
0
self
.
num_local_heads
=
num_heads
//
tp_size
self
.
num_local_heads
=
num_heads
//
tp_size
if
not
\
vllm_config
.
parallel_config
.
enable_lightly_cp
else
self
.
num_heads
self
.
scaling
=
self
.
qk_head_dim
**-
0.5
self
.
max_position_embeddings
=
max_position_embeddings
...
...
@@ -859,6 +995,7 @@ class DeepseekV2MLAAttention(nn.Module):
bias
=
False
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.q_b_proj"
,
disable_tp
=
vllm_config
.
parallel_config
.
enable_lightly_cp
)
else
:
self
.
q_proj
=
ColumnParallelLinear
(
...
...
@@ -867,6 +1004,7 @@ class DeepseekV2MLAAttention(nn.Module):
bias
=
False
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.q_proj"
,
disable_tp
=
vllm_config
.
parallel_config
.
enable_lightly_cp
,
)
self
.
kv_a_layernorm
=
RMSNorm
(
self
.
kv_lora_rank
,
eps
=
config
.
rms_norm_eps
)
self
.
kv_b_proj
=
ColumnParallelLinear
(
...
...
@@ -875,6 +1013,7 @@ class DeepseekV2MLAAttention(nn.Module):
bias
=
False
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.kv_b_proj"
,
disable_tp
=
vllm_config
.
parallel_config
.
enable_lightly_cp
,
)
self
.
o_proj
=
RowParallelLinear
(
self
.
num_heads
*
self
.
v_head_dim
,
...
...
@@ -882,6 +1021,7 @@ class DeepseekV2MLAAttention(nn.Module):
bias
=
False
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.o_proj"
,
disable_tp
=
vllm_config
.
parallel_config
.
enable_lightly_cp
,
)
if
config
.
rope_parameters
[
"rope_type"
]
!=
"default"
:
...
...
@@ -1118,6 +1258,7 @@ class DeepseekV2DecoderLayer(nn.Module):
residual
*=
1.0
/
self
.
routed_scaling_factor
# Fully Connected
enable_lightly_cp
=
get_forward_context
().
enable_lightly_cp
update_hs
=
True
if
isinstance
(
self
.
mlp
,
DeepseekV2MoE
)
else
False
assert
self
.
post_attention_layernorm
.
has_weight
is
True
_i_q
,
_i_s
,
residual
=
self
.
post_attention_layernorm
(
x
=
hidden_states
,
...
...
@@ -1126,9 +1267,10 @@ class DeepseekV2DecoderLayer(nn.Module):
update_input
=
update_hs
)
new_resi
=
residual
hidden_states
=
self
.
mlp
(
hidden_states
,
iqis
=
(
_i_q
,
_i_s
)
)
if
enable_lightly_cp
and
isinstance
(
self
.
mlp
,
DeepseekV2MoE
):
hidden_states
=
self
.
mlp
(
hidden_states
)
else
:
hidden_states
=
self
.
mlp
(
hidden_states
,
iqis
=
(
_i_q
,
_i_s
))
if
isinstance
(
self
.
mlp
,
DeepseekV2MLP
)
and
hidden_states
.
dtype
==
torch
.
float16
:
# Fix FP16 overflow
...
...
@@ -1225,6 +1367,9 @@ class DeepseekV2Model(nn.Module):
self
.
config
=
config
self
.
device
=
current_platform
.
device_type
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
tp_rank
=
get_tensor_model_parallel_rank
()
self
.
vocab_size
=
config
.
vocab_size
# Add a check; DSA is enabled by default.
force_disable_dsa
=
envs
.
VLLM_DISABLE_DSA
...
...
@@ -1287,6 +1432,30 @@ class DeepseekV2Model(nn.Module):
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
residual
=
intermediate_tensors
[
"residual"
]
enable_lightly_cp
=
get_forward_context
().
enable_lightly_cp
if
enable_lightly_cp
:
scatter_indexes_tensor
=
get_forward_context
().
scatter_indexes_tensor
if
scatter_indexes_tensor
is
None
:
hidden_states_per_rank
=
torch
.
chunk
(
hidden_states
,
chunks
=
self
.
tp_size
,
dim
=
0
)
hidden_states
=
hidden_states_per_rank
[
self
.
tp_rank
].
contiguous
()
if
residual
is
not
None
:
residual_per_rank
=
torch
.
chunk
(
residual
,
chunks
=
self
.
tp_size
,
dim
=
0
)
residual
=
residual_per_rank
[
self
.
tp_rank
].
contiguous
()
if
positions
is
not
None
:
positions_per_rank
=
torch
.
chunk
(
positions
,
chunks
=
self
.
tp_size
,
dim
=
0
)
positions
=
positions_per_rank
[
self
.
tp_rank
].
contiguous
()
else
:
scatter_indexes_tensor
=
torch
.
where
(
scatter_indexes_tensor
==
-
1
,
0
,
scatter_indexes_tensor
)
hidden_states
=
torch
.
index_select
(
hidden_states
,
0
,
scatter_indexes_tensor
)
if
residual
is
not
None
:
residual
=
torch
.
index_select
(
residual
,
0
,
scatter_indexes_tensor
)
if
positions
is
not
None
:
positions
=
torch
.
index_select
(
positions
,
0
,
scatter_indexes_tensor
)
# Compute llama 4 scaling once per forward pass if enabled
llama_4_scaling_config
=
getattr
(
self
.
config
,
"llama_4_scaling"
,
None
)
llama_4_scaling
:
torch
.
Tensor
|
None
...
...
@@ -1307,11 +1476,21 @@ class DeepseekV2Model(nn.Module):
)
if
not
get_pp_group
().
is_last_rank
:
if
enable_lightly_cp
:
hidden_states
=
tensor_model_parallel_all_gather
(
hidden_states
.
contiguous
(),
dim
=
0
)
residual
=
tensor_model_parallel_all_gather
(
residual
.
contiguous
(),
dim
=
0
)
return
IntermediateTensors
(
{
"hidden_states"
:
hidden_states
,
"residual"
:
residual
}
)
hidden_states
,
_
=
self
.
norm
(
hidden_states
,
residual
)
if
enable_lightly_cp
:
hidden_states
=
tensor_model_parallel_all_gather
(
hidden_states
.
contiguous
(),
dim
=
0
)
gather_indexes_tensor
=
get_forward_context
().
gather_indexes_tensor
if
gather_indexes_tensor
is
not
None
:
hidden_states
=
torch
.
index_select
(
hidden_states
,
0
,
gather_indexes_tensor
)
return
hidden_states
...
...
vllm/v1/attention/backend.py
View file @
aef3c487
...
...
@@ -282,6 +282,35 @@ class AttentionMetadata:
T
=
TypeVar
(
"T"
,
bound
=
AttentionMetadata
)
@
dataclass
class
CpCommonAttentionMetadata
:
# sp related metadata
query_start_loc
:
torch
.
Tensor
query_start_loc_cpu
:
torch
.
Tensor
seq_lens
:
torch
.
Tensor
_seq_lens_cpu
:
torch
.
Tensor
num_actual_tokens
:
int
num_kv_actual_tokens
:
int
max_query_len
:
int
max_seq_len
:
int
num_reqs
:
int
req_ids
:
list
[
str
]
block_table_tensor
:
torch
.
Tensor
slot_mapping
:
torch
.
Tensor
_num_computed_tokens_cpu
:
torch
.
Tensor
dcp_local_seq_lens
:
torch
.
Tensor
|
None
=
None
dcp_local_seq_lens_cpu
:
torch
.
Tensor
|
None
=
None
def
batch_size
(
self
)
->
int
:
return
self
.
seq_lens
.
shape
[
0
]
@
property
def
seq_lens_cpu
(
self
)
->
torch
.
Tensor
:
if
self
.
_seq_lens_cpu
is
None
:
self
.
_seq_lens_cpu
=
self
.
seq_lens
.
to
(
"cpu"
)
return
self
.
_seq_lens_cpu
@
dataclass
class
CommonAttentionMetadata
:
...
...
@@ -312,6 +341,14 @@ class CommonAttentionMetadata:
block_table_tensor
:
torch
.
Tensor
slot_mapping
:
torch
.
Tensor
num_kv_actual_tokens
:
int
|
None
=
None
seq_indexes_list
:
list
[
int
]
|
None
=
None
scatter_indexes_tensor
:
torch
.
Tensor
|
None
=
None
gather_indexes_tensor
:
torch
.
Tensor
|
None
=
None
cp_common_metadata
:
CpCommonAttentionMetadata
|
None
=
None
enable_lightly_cp
:
bool
=
False
causal
:
bool
=
True
# Needed by FastPrefillAttentionBuilder
...
...
@@ -396,6 +433,7 @@ class CommonAttentionMetadata:
else
None
,
num_reqs
=
num_actual_reqs
,
num_actual_tokens
=
num_actual_tokens
,
num_kv_actual_tokens
=
num_actual_tokens
,
max_query_len
=
self
.
max_query_len
,
max_seq_len
=
self
.
max_seq_len
,
block_table_tensor
=
self
.
block_table_tensor
[:
num_actual_reqs
],
...
...
vllm/v1/attention/backends/mla/flashmla_sparse.py
View file @
aef3c487
...
...
@@ -138,6 +138,7 @@ class FlashMLASparseMetadata(AttentionMetadata):
max_seq_len
:
int
num_actual_tokens
:
int
# Number of tokens excluding padding.
num_kv_actual_tokens
:
int
query_start_loc
:
torch
.
Tensor
slot_mapping
:
torch
.
Tensor
...
...
@@ -693,6 +694,7 @@ class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetad
max_query_len
=
cm
.
max_query_len
,
max_seq_len
=
cm
.
max_seq_len
,
num_actual_tokens
=
cm
.
num_actual_tokens
,
num_kv_actual_tokens
=
cm
.
num_kv_actual_tokens
,
query_start_loc
=
cm
.
query_start_loc
,
slot_mapping
=
cm
.
slot_mapping
,
block_table
=
cm
.
block_table_tensor
,
...
...
@@ -1024,12 +1026,13 @@ class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]):
return
output
.
fill_
(
0
)
num_actual_toks
=
attn_metadata
.
num_actual_tokens
num_kv_actual_toks
=
attn_metadata
.
num_kv_actual_tokens
# Inputs and outputs may be padded for CUDA graphs
q
=
q
[:
num_actual_toks
,
...]
k_c_normed
=
k_c_normed
[:
num_actual_toks
,
...]
k_pe
=
k_pe
[:
num_actual_toks
,
...]
k_c_normed
=
k_c_normed
[:
num_
kv_
actual_toks
,
...]
k_pe
=
k_pe
[:
num_
kv_
actual_toks
,
...]
assert
self
.
topk_indices_buffer
is
not
None
topk_indices
=
self
.
topk_indices_buffer
[:
num_actual_toks
]
...
...
vllm/v1/attention/backends/mla/indexer.py
View file @
aef3c487
...
...
@@ -105,6 +105,7 @@ class DeepseekV32IndexerMetadata:
max_seq_len
:
int
num_actual_tokens
:
int
# Number of tokens excluding padding.
num_kv_actual_tokens
:
int
query_start_loc
:
torch
.
Tensor
slot_mapping
:
torch
.
Tensor
# The dimension of the attention heads
...
...
@@ -438,6 +439,7 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
max_query_len
=
common_attn_metadata
.
max_query_len
,
max_seq_len
=
common_attn_metadata
.
max_seq_len
,
num_actual_tokens
=
common_attn_metadata
.
num_actual_tokens
,
num_kv_actual_tokens
=
common_attn_metadata
.
num_kv_actual_tokens
,
query_start_loc
=
common_attn_metadata
.
query_start_loc
,
slot_mapping
=
common_attn_metadata
.
slot_mapping
,
head_dim
=
128
,
...
...
vllm/v1/spec_decode/eagle.py
View file @
aef3c487
...
...
@@ -14,7 +14,7 @@ from vllm.config import (
VllmConfig
,
get_layers_from_vllm_config
,
)
from
vllm.distributed.parallel_state
import
get_pp_group
from
vllm.distributed.parallel_state
import
get_pp_group
,
get_tensor_model_parallel_rank
from
vllm.forward_context
import
set_forward_context
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.attention_layer_base
import
AttentionLayerBase
...
...
@@ -29,6 +29,7 @@ from vllm.utils.platform_utils import is_pin_memory_available
from
vllm.v1.attention.backend
import
(
AttentionMetadataBuilder
,
CommonAttentionMetadata
,
CpCommonAttentionMetadata
,
)
from
vllm.v1.attention.backends.registry
import
AttentionBackendEnum
from
vllm.v1.attention.backends.tree_attn
import
(
...
...
@@ -48,6 +49,7 @@ from vllm.v1.spec_decode.utils import (
from
vllm.v1.utils
import
CpuGpuBuffer
from
vllm.v1.worker.dp_utils
import
coordinate_batch_across_dp
from
vllm.v1.worker.gpu_input_batch
import
CachedRequestState
,
InputBatch
from
vllm.utils.math_utils
import
cdiv
,
round_up
logger
=
init_logger
(
__name__
)
...
...
@@ -76,7 +78,9 @@ class SpecDecodeBaseProposer:
self
.
dp_rank
=
vllm_config
.
parallel_config
.
data_parallel_rank
self
.
num_speculative_tokens
=
self
.
speculative_config
.
num_speculative_tokens
# The drafter can get longer sequences than the target model.
max_batch_size
=
vllm_config
.
scheduler_config
.
max_num_seqs
max_batch_size
=
vllm_config
.
scheduler_config
.
max_num_seqs
if
not
\
vllm_config
.
parallel_config
.
enable_lightly_cplb
\
else
vllm_config
.
scheduler_config
.
max_num_seqs
*
2
self
.
max_num_tokens
=
(
vllm_config
.
scheduler_config
.
max_num_batched_tokens
+
max_batch_size
)
...
...
@@ -219,6 +223,28 @@ class SpecDecodeBaseProposer:
1
,
len
(
self
.
tree_choices
)
+
1
,
device
=
device
,
dtype
=
torch
.
int32
).
repeat
(
max_batch_size
,
1
)
self
.
scatter_indexes_tensor
=
None
self
.
gather_indexes_tensor
=
None
self
.
enable_lightly_cp
=
vllm_config
.
parallel_config
.
enable_lightly_cp
self
.
enable_lightly_cplb
=
self
.
enable_lightly_cp
and
vllm_config
.
parallel_config
.
enable_lightly_cplb
if
self
.
enable_lightly_cp
:
self
.
query_start_loc
=
CpuGpuBuffer
(
max_batch_size
+
1
,
dtype
=
torch
.
int32
,
pin_memory
=
is_pin_memory_available
(),
device
=
device
,
with_numpy
=
True
,
)
self
.
seq_lens
=
CpuGpuBuffer
(
max_batch_size
,
dtype
=
torch
.
int32
,
pin_memory
=
is_pin_memory_available
(),
device
=
device
,
with_numpy
=
True
,
)
def
_get_positions
(
self
,
num_tokens
:
int
):
if
self
.
uses_mrope
:
return
self
.
mrope_positions
[:,
:
num_tokens
]
...
...
@@ -270,6 +296,10 @@ class SpecDecodeBaseProposer:
self
.
cudagraph_dispatcher
.
initialize_cudagraph_keys
(
eagle_cudagraph_mode
)
def _pad_for_mla_cp(self, num_scheduled_tokens: int) -> int:
    """Pad a token count for the context-parallel path.

    Passes ``num_scheduled_tokens`` and the configured
    ``tensor_parallel_size`` to ``round_up`` and returns its result
    (presumably the next multiple of the TP size — confirm against
    ``round_up``).
    """
    parallel_config = self.vllm_config.parallel_config
    return round_up(num_scheduled_tokens, parallel_config.tensor_parallel_size)
def
propose
(
self
,
# [num_tokens]
...
...
@@ -309,6 +339,31 @@ class SpecDecodeBaseProposer:
)
)
num_tokens_dp_padded
,
num_tokens_across_dp
=
self
.
_pad_batch_across_dp
(
num_tokens_unpadded
=
num_tokens
,
num_tokens_padded
=
num_tokens
)
enable_lightly_cp
=
self
.
enable_lightly_cp
and
num_tokens
>
self
.
runner
.
lightly_cp_threshould
if
enable_lightly_cp
:
num_tokens_dp_padded
=
self
.
_pad_for_mla_cp
(
num_tokens_dp_padded
)
common_attn_metadata
=
self
.
_prepare_cp_metadata
(
num_reqs_padded
=
common_attn_metadata
.
num_reqs
,
max_query_len
=
common_attn_metadata
.
max_query_len
,
max_seq_len
=
common_attn_metadata
.
seq_lens_cpu
.
max
().
item
(),
num_tokens
=
num_tokens
,
block_table_gid_0
=
common_attn_metadata
.
block_table_tensor
,
slot_mapping_gid_0
=
common_attn_metadata
.
slot_mapping
,
query_start_loc
=
common_attn_metadata
.
query_start_loc
,
query_start_loc_cpu
=
common_attn_metadata
.
query_start_loc_cpu
,
seq_lens
=
common_attn_metadata
.
seq_lens
,
seq_lens_cpu
=
common_attn_metadata
.
seq_lens_cpu
,
num_computed_tokens_cpu
=
common_attn_metadata
.
_num_computed_tokens_cpu
,
)
self
.
scatter_indexes_tensor
=
common_attn_metadata
.
scatter_indexes_tensor
self
.
gather_indexes_tensor
=
common_attn_metadata
.
gather_indexes_tensor
assert
self
.
runner
is
not
None
if
self
.
attn_metadata_builder
is
None
:
...
...
@@ -339,10 +394,6 @@ class SpecDecodeBaseProposer:
assert
draft_indexer_metadata
is
not
None
per_layer_attn_metadata
[
layer_name
]
=
draft_indexer_metadata
num_tokens_dp_padded
,
num_tokens_across_dp
=
self
.
_pad_batch_across_dp
(
num_tokens_unpadded
=
num_tokens
,
num_tokens_padded
=
num_tokens
)
cudagraph_runtime_mode
,
batch_desc
=
self
.
cudagraph_dispatcher
.
dispatch
(
num_tokens_dp_padded
)
...
...
@@ -387,6 +438,10 @@ class SpecDecodeBaseProposer:
slot_mapping
=
self
.
_get_slot_mapping
(
num_input_tokens
,
common_attn_metadata
.
slot_mapping
),
scatter_indexes_tensor
=
self
.
scatter_indexes_tensor
,
gather_indexes_tensor
=
self
.
gather_indexes_tensor
,
enable_lightly_cp
=
self
.
enable_lightly_cp
and
num_tokens
>
self
.
runner
.
lightly_cp_threshould
,
enable_lightly_cplb
=
self
.
enable_lightly_cplb
):
ret_hidden_states
=
self
.
model
(
**
model_kwargs
)
if
not
self
.
model_returns_tuple
():
...
...
@@ -463,6 +518,9 @@ class SpecDecodeBaseProposer:
if
batch_size_across_dp
is
not
None
:
batch_size_across_dp
[
self
.
dp_rank
]
=
input_batch_size
if
enable_lightly_cp
:
common_attn_metadata
=
common_attn_metadata
.
cp_common_metadata
common_attn_metadata
.
num_actual_tokens
=
batch_size
common_attn_metadata
.
max_query_len
=
1
common_attn_metadata
.
query_start_loc
=
self
.
arange
[:
batch_size
+
1
]
...
...
@@ -802,6 +860,7 @@ class SpecDecodeBaseProposer:
_num_computed_tokens_cpu
=
common_attn_metadata
.
_num_computed_tokens_cpu
,
num_reqs
=
common_attn_metadata
.
num_reqs
,
num_actual_tokens
=
total_num_tokens
,
num_kv_actual_tokens
=
total_num_tokens
,
max_query_len
=
new_query_len_per_req
.
max
().
item
(),
max_seq_len
=
common_attn_metadata
.
seq_lens_cpu
.
max
().
item
(),
block_table_tensor
=
common_attn_metadata
.
block_table_tensor
,
...
...
@@ -988,6 +1047,104 @@ class SpecDecodeBaseProposer:
level_num_drafts
=
self
.
cu_drafts_per_level
[
level
+
1
]
-
total_num_drafts
total_num_drafts
=
self
.
cu_drafts_per_level
[
level
+
1
]
return
draft_token_ids_list
def _prepare_cp_metadata(
    self,
    num_reqs_padded,
    max_query_len,
    max_seq_len,
    num_tokens,
    block_table_gid_0,
    slot_mapping_gid_0,
    query_start_loc,
    query_start_loc_cpu,
    seq_lens,
    seq_lens_cpu,
    num_computed_tokens_cpu,
):
    """Build this rank's attention metadata for lightly context parallelism.

    The incoming (whole-batch) metadata is snapshotted into a
    ``CpCommonAttentionMetadata``; the batch is then split across
    tensor-parallel ranks by ``runner._distribute_tokens_to_cp_ranks`` and a
    ``CommonAttentionMetadata`` describing only this rank's share is
    returned, with the snapshot attached as ``cp_common_metadata``.

    Side effects: overwrites ``self.query_start_loc`` and ``self.seq_lens``
    (CPU and GPU sides) with the per-rank values.

    Args:
        num_reqs_padded: Request count recorded in the snapshot.
        max_query_len: Whole-batch max query length (snapshot only).
        max_seq_len: Whole-batch max sequence length (snapshot only).
        num_tokens: Total number of query tokens in the whole batch.
        block_table_gid_0: Block table indexed by request; rows are
            re-gathered for the local sequences.
        slot_mapping_gid_0: Slot mapping, passed through unchanged.
        query_start_loc: Whole-batch cumulative query offsets (device).
        query_start_loc_cpu: Same offsets on the CPU.
        seq_lens: Whole-batch sequence lengths (device).
        seq_lens_cpu: Same lengths on the CPU.
        num_computed_tokens_cpu: Whole-batch computed-token counts
            (snapshot only).

    Returns:
        The per-rank ``CommonAttentionMetadata``.
    """
    tp_size = self.vllm_config.parallel_config.tensor_parallel_size
    tp_rank = get_tensor_model_parallel_rank()
    req_ids = self.runner.input_batch.req_ids

    # Snapshot of the undistributed batch. Clones are required because the
    # persistent buffers are rewritten below.
    cp_common_metadata = CpCommonAttentionMetadata(
        query_start_loc=query_start_loc.clone(),
        query_start_loc_cpu=query_start_loc_cpu.clone(),
        seq_lens=seq_lens.clone(),
        _seq_lens_cpu=seq_lens_cpu.clone(),
        max_query_len=max_query_len,
        max_seq_len=max_seq_len,
        num_reqs=num_reqs_padded,
        req_ids=req_ids,
        num_actual_tokens=num_tokens,
        num_kv_actual_tokens=num_tokens,
        block_table_tensor=block_table_gid_0,
        slot_mapping=slot_mapping_gid_0,
        _num_computed_tokens_cpu=num_computed_tokens_cpu,
    )

    (
        local_q_tokens,
        local_q_lens,
        num_reqs,
        local_kv_lens,
        _local_req_ids,  # not needed here
        scatter_indexes_tensor,
        gather_indexes_tensor,
        seq_indexes_list,
    ) = self.runner._distribute_tokens_to_cp_ranks(
        num_tokens,
        query_start_loc_cpu[1:] - query_start_loc_cpu[:-1],
        seq_lens_cpu,
        tp_rank,
        tp_size,
        req_ids,
    )

    # Rewrite the cumulative query offsets for the local sequences; the
    # unused tail repeats the last offset. A rank may own no sequence at all.
    cu_num_tokens = np.cumsum(local_q_lens)
    last_offset = int(cu_num_tokens[-1]) if num_reqs > 0 else 0
    self.query_start_loc.np[0] = 0
    self.query_start_loc.np[1 : num_reqs + 1] = cu_num_tokens
    self.query_start_loc.np[num_reqs + 1 :].fill(last_offset)
    self.query_start_loc.copy_to_gpu()
    q_acc_lens = self.query_start_loc.gpu[: num_reqs + 1]
    q_acc_lens_cpu = self.query_start_loc.cpu[: num_reqs + 1]
    # Longest *individual* local query (not the cumulative total), as an int.
    max_q_len = int(local_q_lens.max()) if num_reqs > 0 else 0

    self.seq_lens.np[:num_reqs] = local_kv_lens
    self.seq_lens.np[num_reqs:].fill(0)
    self.seq_lens.copy_to_gpu()
    kv_lens = self.seq_lens.gpu[:num_reqs]
    kv_lens_cpu = self.seq_lens.cpu[:num_reqs]
    max_kv_len = int(local_kv_lens.max()) if num_reqs > 0 else 0

    # Tokens already computed per sequence = kv length - that sequence's own
    # query length (difference of consecutive offsets, not the offsets).
    local_num_computed_cpu = kv_lens_cpu - (
        q_acc_lens_cpu[1:] - q_acc_lens_cpu[:-1]
    )

    # One block-table row per local sequence (a request may appear twice).
    blk_table_tensor = block_table_gid_0[seq_indexes_list]

    return CommonAttentionMetadata(
        query_start_loc=q_acc_lens,
        query_start_loc_cpu=q_acc_lens_cpu,
        seq_lens=kv_lens,
        _seq_lens_cpu=kv_lens_cpu,
        _num_computed_tokens_cpu=local_num_computed_cpu,
        num_reqs=num_reqs,
        num_actual_tokens=local_q_tokens,
        max_query_len=max_q_len,
        max_seq_len=max_kv_len,
        block_table_tensor=blk_table_tensor,
        slot_mapping=slot_mapping_gid_0,
        causal=True,
        num_kv_actual_tokens=num_tokens,
        seq_indexes_list=seq_indexes_list,
        cp_common_metadata=cp_common_metadata,
        scatter_indexes_tensor=scatter_indexes_tensor,
        gather_indexes_tensor=gather_indexes_tensor,
    )
def
prepare_inputs
(
self
,
...
...
vllm/v1/worker/block_table.py
View file @
aef3c487
...
...
@@ -233,6 +233,10 @@ class BlockTable:
def get_device_tensor(self, num_reqs: int) -> torch.Tensor:
    """Return the first ``num_reqs`` rows of the device-side block table."""
    device_table = self.block_table.gpu
    return device_table[:num_reqs]
def get_device_tensor_range(self, start_req: int, end_req: int) -> torch.Tensor:
    """Return rows ``[start_req, end_req)`` of the device-side block table.

    Args:
        start_req: Index of the first row to include.
        end_req: Index one past the last row to include.
    """
    row_span = slice(start_req, end_req)
    return self.block_table.gpu[row_span]
def
get_cpu_tensor
(
self
)
->
torch
.
Tensor
:
"""Returns the CPU tensor of the block table."""
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
aef3c487
...
...
@@ -42,8 +42,13 @@ from vllm.distributed.parallel_state import (
get_tp_group
,
graph_capture
,
is_global_first_rank
,
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
,
prepare_communication_buffer_for_model
,
)
from
vllm.distributed
import
(
tensor_model_parallel_all_gather
)
from
vllm.forward_context
import
(
BatchDescriptor
,
set_forward_context
,
...
...
@@ -104,6 +109,7 @@ from vllm.v1.attention.backend import (
AttentionMetadataBuilder
,
AttentionType
,
CommonAttentionMetadata
,
CpCommonAttentionMetadata
,
MultipleOf
,
)
from
vllm.v1.attention.backends.gdn_attn
import
GDNAttentionMetadataBuilder
...
...
@@ -372,10 +378,20 @@ class GPUModelRunner(
# Always set to false after the first forward pass
self
.
calculate_kv_scales
=
self
.
cache_config
.
calculate_kv_scales
self
.
tp_size
=
self
.
parallel_config
.
tensor_parallel_size
self
.
dcp_world_size
=
self
.
parallel_config
.
decode_context_parallel_size
self
.
dcp_rank
=
0
if
self
.
dcp_world_size
<=
1
else
get_dcp_group
().
rank_in_group
self
.
max_num_tokens
=
scheduler_config
.
max_num_batched_tokens
self
.
max_num_reqs
=
scheduler_config
.
max_num_seqs
#self.max_num_reqs = scheduler_config.max_num_seqs
self
.
enable_lightly_cp
=
self
.
parallel_config
.
enable_lightly_cp
self
.
enable_lightly_cplb
=
self
.
enable_lightly_cp
and
self
.
parallel_config
.
enable_lightly_cplb
self
.
max_num_reqs
=
(
scheduler_config
.
max_num_seqs
if
not
self
.
enable_lightly_cplb
else
scheduler_config
.
max_num_seqs
*
2
)
self
.
lightly_cp_threshould
=
envs
.
VLLM_LIGHTLY_CP_THRESHOULD
# Broadcast PP output for external_launcher (torchrun)
# to make sure we are synced across pp ranks
...
...
@@ -1490,6 +1506,243 @@ class GPUModelRunner(
return
encoder_seq_lens
,
encoder_seq_lens_cpu
def
_distribute_tokens_to_cp_ranks
(
self
,
total_q_len
:
int
,
q_lens_cpu
:
np
.
ndarray
,
kv_lens_cpu
:
np
.
ndarray
,
tp_rank
:
int
,
tp_size
:
int
,
req_ids
:
list
[
str
],
):
q_lens
=
[]
seq_count
=
0
seq_indexes
=
[]
kv_lens
=
[]
local_req_ids
=
[]
local_scatter_indexes_tensor
=
None
gather_indexes_tensor
=
None
if
self
.
enable_lightly_cplb
:
rank_tokens
=
0
rank_pad_tokens
=
0
accu_q_start
=
0
scatter_indexes
:
list
[
int
]
=
[]
num_requests
=
len
(
q_lens_cpu
)
for
i
in
range
(
num_requests
):
req_q_len
=
q_lens_cpu
[
i
]
req_pad_q_len
=
round_up
(
q_lens_cpu
[
i
],
2
*
tp_size
)
kv_len
=
kv_lens_cpu
[
i
]
chunk_q_len
=
req_pad_q_len
//
(
2
*
tp_size
)
q_1_start
=
tp_rank
*
chunk_q_len
q_1_end
=
(
tp_rank
+
1
)
*
chunk_q_len
q_2_start
=
req_pad_q_len
-
(
tp_rank
+
1
)
*
chunk_q_len
q_2_end
=
req_pad_q_len
-
tp_rank
*
chunk_q_len
q_len_1
=
(
chunk_q_len
if
q_1_end
<=
req_q_len
else
max
(
0
,
req_q_len
-
q_1_start
)
)
q_len_2
=
(
chunk_q_len
if
q_2_end
<=
req_q_len
else
max
(
0
,
req_q_len
-
q_2_start
)
)
kv_len_1
=
kv_len
-
req_q_len
+
min
(
req_q_len
,
q_1_end
)
kv_len_2
=
kv_len
-
req_q_len
+
min
(
req_q_len
,
q_2_end
)
scatter_index1
=
range
(
accu_q_start
+
q_1_start
,
accu_q_start
+
q_1_start
+
q_len_1
)
scatter_index2
=
range
(
accu_q_start
+
q_2_start
,
accu_q_start
+
q_2_start
+
q_len_2
)
accu_q_start
+=
req_q_len
if
q_len_1
>
0
:
q_lens
.
append
(
q_len_1
)
kv_lens
.
append
(
kv_len_1
)
seq_indexes
.
append
(
i
)
local_req_ids
.
append
(
req_ids
[
i
])
scatter_indexes
.
extend
(
scatter_index1
)
seq_count
+=
1
rank_tokens
+=
q_len_1
if
q_len_2
>
0
:
q_lens
.
append
(
q_len_2
)
kv_lens
.
append
(
kv_len_2
)
seq_indexes
.
append
(
i
)
local_req_ids
.
append
(
req_ids
[
i
])
scatter_indexes
.
extend
(
scatter_index2
)
seq_count
+=
1
rank_tokens
+=
q_len_2
rank_pad_tokens
+=
chunk_q_len
*
2
if
len
(
scatter_indexes
)
<
rank_pad_tokens
:
scatter_indexes
.
extend
([
-
1
]
*
(
rank_pad_tokens
-
len
(
scatter_indexes
)))
local_scatter_indexes_tensor
=
torch
.
tensor
(
scatter_indexes
,
dtype
=
torch
.
int64
,
device
=
self
.
device
)
global_scatter_indexes_tensor
=
tensor_model_parallel_all_gather
(
local_scatter_indexes_tensor
.
contiguous
(),
dim
=
0
)
non_neg_mask
=
global_scatter_indexes_tensor
!=
-
1
non_neg_values
=
global_scatter_indexes_tensor
[
non_neg_mask
]
non_neg_positions
=
torch
.
where
(
non_neg_mask
)[
0
]
sorted_indices
=
torch
.
argsort
(
non_neg_values
)
gather_indexes_tensor
=
non_neg_positions
[
sorted_indices
]
if
isinstance
(
rank_tokens
,
torch
.
Tensor
):
rank_tokens
=
rank_tokens
.
item
()
else
:
tokens_per_rank
=
(
total_q_len
+
tp_size
-
1
)
//
tp_size
start_token
=
tp_rank
*
tokens_per_rank
end_token
=
min
((
tp_rank
+
1
)
*
tokens_per_rank
,
total_q_len
)
current_seq
=
0
current_pos
=
0
rank_tokens
=
min
(
tokens_per_rank
,
end_token
-
start_token
)
while
start_token
<
end_token
and
current_seq
<
len
(
q_lens_cpu
):
q_len
=
q_lens_cpu
[
current_seq
]
q_start
=
current_pos
q_end
=
current_pos
+
q_len
kv_len
=
kv_lens_cpu
[
current_seq
]
# Find overlap between this sequence and rank's token range
overlap_start
=
max
(
start_token
,
q_start
)
overlap_end
=
min
(
end_token
,
q_end
)
if
overlap_start
<
overlap_end
:
# This sequence contributes tokens to this rank
token_count
=
overlap_end
-
overlap_start
q_lens
.
append
(
token_count
)
start_token
=
overlap_end
seq_count
+=
1
seq_indexes
.
append
(
current_seq
)
local_req_ids
.
append
(
req_ids
[
current_seq
])
if
q_end
<=
end_token
:
kv_lens
.
append
(
kv_len
)
else
:
kv_lens
.
append
(
kv_len
-
(
q_end
-
end_token
))
current_pos
=
q_end
current_seq
+=
1
return
(
rank_tokens
,
np
.
array
(
q_lens
,
dtype
=
np
.
int32
),
seq_count
,
np
.
array
(
kv_lens
,
dtype
=
np
.
int32
),
np
.
array
(
local_req_ids
,
dtype
=
str
),
local_scatter_indexes_tensor
,
gather_indexes_tensor
,
seq_indexes
,
)
def _prepare_cp_metadata(
    self,
    num_reqs_padded,
    max_query_len,
    max_seq_len,
    num_tokens,
    block_table_gid_0,
    slot_mapping_gid_0,
    num_computed_tokens_cpu,
):
    """Build this rank's attention metadata for lightly context parallelism.

    The current whole-batch state held in ``self.query_start_loc`` and
    ``self.seq_lens`` is snapshotted into a ``CpCommonAttentionMetadata``;
    the batch is then split across tensor-parallel ranks by
    ``_distribute_tokens_to_cp_ranks`` and a ``CommonAttentionMetadata``
    describing only this rank's share is returned, with the snapshot
    attached as ``cp_common_metadata``.

    Side effects: overwrites ``self.query_start_loc`` and ``self.seq_lens``
    (CPU and GPU sides) with the per-rank values.

    Args:
        num_reqs_padded: Number of requests to read from the buffers.
        max_query_len: Whole-batch max query length (snapshot only).
        max_seq_len: Whole-batch max sequence length (snapshot only).
        num_tokens: Total number of query tokens in the whole batch.
        block_table_gid_0: Block table indexed by request; rows are
            re-gathered for the local sequences.
        slot_mapping_gid_0: Slot mapping, passed through unchanged.
        num_computed_tokens_cpu: Whole-batch computed-token counts
            (snapshot only).

    Returns:
        The per-rank ``CommonAttentionMetadata`` (``enable_lightly_cp=True``).
    """
    tp_size = self.vllm_config.parallel_config.tensor_parallel_size
    tp_rank = get_tensor_model_parallel_rank()
    req_ids = self.input_batch.req_ids

    # Snapshot of the undistributed batch. Clones are required because the
    # persistent buffers are rewritten below.
    cp_common_metadata = CpCommonAttentionMetadata(
        query_start_loc=self.query_start_loc.gpu[: num_reqs_padded + 1].clone(),
        query_start_loc_cpu=self.query_start_loc.cpu[: num_reqs_padded + 1].clone(),
        seq_lens=self.seq_lens.gpu[:num_reqs_padded].clone(),
        _seq_lens_cpu=self.seq_lens.cpu[:num_reqs_padded].clone(),
        max_query_len=max_query_len,
        max_seq_len=max_seq_len,
        num_reqs=num_reqs_padded,
        req_ids=req_ids,
        num_actual_tokens=num_tokens,
        num_kv_actual_tokens=num_tokens,
        block_table_tensor=block_table_gid_0,
        slot_mapping=slot_mapping_gid_0,
        _num_computed_tokens_cpu=num_computed_tokens_cpu,
    )

    # Whole-batch per-request lengths, read before the buffers are reused.
    batch_offsets_cpu = self.query_start_loc.cpu[: num_reqs_padded + 1]
    (
        local_q_tokens,
        local_q_lens,
        num_reqs,
        local_kv_lens,
        _local_req_ids,  # not needed here
        scatter_indexes_tensor,
        gather_indexes_tensor,
        seq_indexes_list,
    ) = self._distribute_tokens_to_cp_ranks(
        num_tokens,
        batch_offsets_cpu[1:] - batch_offsets_cpu[:-1],
        self.seq_lens.cpu[:num_reqs_padded],
        tp_rank,
        tp_size,
        req_ids,
    )

    # Rewrite the cumulative query offsets for the local sequences; the
    # unused tail repeats the last offset. A rank may own no sequence at all.
    cu_num_tokens = np.cumsum(local_q_lens)
    last_offset = int(cu_num_tokens[-1]) if num_reqs > 0 else 0
    self.query_start_loc.np[0] = 0
    self.query_start_loc.np[1 : num_reqs + 1] = cu_num_tokens
    self.query_start_loc.np[num_reqs + 1 :].fill(last_offset)
    self.query_start_loc.copy_to_gpu()
    q_acc_lens = self.query_start_loc.gpu[: num_reqs + 1]
    q_acc_lens_cpu = self.query_start_loc.cpu[: num_reqs + 1]
    # Longest *individual* local query (not the cumulative total), as an int.
    max_q_len = int(local_q_lens.max()) if num_reqs > 0 else 0

    self.seq_lens.np[:num_reqs] = local_kv_lens
    self.seq_lens.np[num_reqs:].fill(0)
    self.seq_lens.copy_to_gpu()
    kv_lens = self.seq_lens.gpu[:num_reqs]
    kv_lens_cpu = self.seq_lens.cpu[:num_reqs]
    max_kv_len = int(local_kv_lens.max()) if num_reqs > 0 else 0

    # Tokens already computed per sequence = kv length - that sequence's own
    # query length (difference of consecutive offsets, not the offsets).
    local_num_computed_cpu = kv_lens_cpu - (
        q_acc_lens_cpu[1:] - q_acc_lens_cpu[:-1]
    )

    # One block-table row per local sequence (a request may appear twice).
    blk_table_tensor = block_table_gid_0[seq_indexes_list]

    return CommonAttentionMetadata(
        query_start_loc=q_acc_lens,
        query_start_loc_cpu=q_acc_lens_cpu,
        seq_lens=kv_lens,
        _seq_lens_cpu=kv_lens_cpu,
        _num_computed_tokens_cpu=local_num_computed_cpu,
        num_reqs=num_reqs,
        num_actual_tokens=local_q_tokens,
        max_query_len=max_q_len,
        max_seq_len=max_kv_len,
        block_table_tensor=blk_table_tensor,
        slot_mapping=slot_mapping_gid_0,
        causal=True,
        num_kv_actual_tokens=num_tokens,
        seq_indexes_list=seq_indexes_list,
        cp_common_metadata=cp_common_metadata,
        scatter_indexes_tensor=scatter_indexes_tensor,
        gather_indexes_tensor=gather_indexes_tensor,
        enable_lightly_cp=True,
    )
def
_prepare_inputs
(
self
,
scheduler_output
:
"SchedulerOutput"
,
...
...
@@ -1723,13 +1976,20 @@ class GPUModelRunner(
num_scheduled_tokens
:
dict
[
str
,
int
]
|
None
=
None
,
cascade_attn_prefix_lens
:
list
[
list
[
int
]]
|
None
=
None
,
slot_mappings
:
dict
[
int
,
torch
.
Tensor
]
|
None
=
None
,
)
->
tuple
[
PerLayerAttnMetadata
,
CommonAttentionMetadata
|
None
]:
)
->
tuple
[
PerLayerAttnMetadata
,
CommonAttentionMetadata
|
None
,
torch
.
Tensor
|
None
,
torch
.
Tensor
|
None
,
]:
"""
:return: tuple[attn_metadata, spec_decode_common_attn_metadata, scatter_indexes_tensor, gather_indexes_tensor]
"""
# Attention metadata is not needed for attention free models
if
len
(
self
.
kv_cache_config
.
kv_cache_groups
)
==
0
:
return
{},
None
return
{},
None
,
None
,
None
tp_size
=
self
.
vllm_config
.
parallel_config
.
tensor_parallel_size
num_tokens_padded
=
num_tokens_padded
or
num_tokens
num_reqs_padded
=
num_reqs_padded
or
num_reqs
...
...
@@ -1777,25 +2037,45 @@ class GPUModelRunner(
assert
slot_mappings
is
not
None
block_table_gid_0
=
_get_block_table
(
0
)
slot_mapping_gid_0
=
slot_mappings
[
0
]
scatter_indexes_tensor
=
None
gather_indexes_tensor
=
None
if
self
.
model_config
.
enable_return_routed_experts
:
self
.
slot_mapping
=
slot_mapping_gid_0
[:
num_tokens
].
cpu
().
numpy
()
cm_base
=
CommonAttentionMetadata
(
query_start_loc
=
self
.
query_start_loc
.
gpu
[:
num_reqs_padded
+
1
],
query_start_loc_cpu
=
self
.
query_start_loc
.
cpu
[:
num_reqs_padded
+
1
],
seq_lens
=
self
.
seq_lens
.
gpu
[:
num_reqs_padded
],
_seq_lens_cpu
=
self
.
seq_lens
.
cpu
[:
num_reqs_padded
],
_num_computed_tokens_cpu
=
self
.
input_batch
.
num_computed_tokens_cpu_tensor
[
:
num_reqs_padded
],
num_reqs
=
num_reqs_padded
,
num_actual_tokens
=
num_tokens_padded
,
max_query_len
=
max_query_len
,
max_seq_len
=
max_seq_len
,
block_table_tensor
=
block_table_gid_0
,
slot_mapping
=
slot_mapping_gid_0
,
causal
=
True
,
)
enable_lightly_cp
=
self
.
enable_lightly_cp
and
num_tokens
>
self
.
lightly_cp_threshould
if
not
enable_lightly_cp
:
cm_base
=
CommonAttentionMetadata
(
query_start_loc
=
self
.
query_start_loc
.
gpu
[:
num_reqs_padded
+
1
],
query_start_loc_cpu
=
self
.
query_start_loc
.
cpu
[:
num_reqs_padded
+
1
],
seq_lens
=
self
.
seq_lens
.
gpu
[:
num_reqs_padded
],
_seq_lens_cpu
=
self
.
seq_lens
.
cpu
[:
num_reqs_padded
],
_num_computed_tokens_cpu
=
self
.
input_batch
.
num_computed_tokens_cpu_tensor
[
:
num_reqs_padded
],
num_reqs
=
num_reqs_padded
,
num_actual_tokens
=
num_tokens_padded
,
num_kv_actual_tokens
=
num_tokens_padded
,
max_query_len
=
max_query_len
,
max_seq_len
=
max_seq_len
,
block_table_tensor
=
block_table_gid_0
,
slot_mapping
=
slot_mapping_gid_0
,
causal
=
True
,
)
else
:
cm_base
=
self
.
_prepare_cp_metadata
(
num_reqs_padded
,
max_query_len
,
max_seq_len
,
num_tokens
,
block_table_gid_0
,
slot_mapping_gid_0
,
self
.
input_batch
.
num_computed_tokens_cpu_tensor
[
:
num_reqs_padded
],
)
scatter_indexes_tensor
=
cm_base
.
scatter_indexes_tensor
gather_indexes_tensor
=
cm_base
.
gather_indexes_tensor
if
self
.
dcp_world_size
>
1
:
self
.
dcp_local_seq_lens
.
cpu
[:
num_reqs
]
=
get_dcp_local_seq_lens
(
...
...
@@ -1906,12 +2186,23 @@ class GPUModelRunner(
cm
.
block_table_tensor
=
_get_block_table
(
kv_cache_gid
)
cm
.
slot_mapping
=
slot_mappings
[
kv_cache_gid
]
if
enable_lightly_cp
and
cm
.
seq_indexes_list
is
not
None
:
cm
.
block_table_tensor
=
cm
.
block_table_tensor
[
cm
.
seq_indexes_list
]
if
self
.
speculative_config
and
spec_decode_common_attn_metadata
is
None
and
hasattr
(
self
,
"drafter"
):
if
isinstance
(
self
.
drafter
,
EagleProposer
):
if
self
.
drafter
.
attn_layer_names
[
0
]
in
kv_cache_group
.
layer_names
:
spec_decode_common_attn_metadata
=
cm
if
enable_lightly_cp
:
spec_decode_common_attn_metadata
=
cm
.
cp_common_metadata
else
:
spec_decode_common_attn_metadata
=
cm
#spec_decode_common_attn_metadata = cm
else
:
spec_decode_common_attn_metadata
=
cm
if
enable_lightly_cp
:
spec_decode_common_attn_metadata
=
cm
.
cp_common_metadata
else
:
spec_decode_common_attn_metadata
=
cm
#spec_decode_common_attn_metadata = cm
for
attn_gid
in
range
(
len
(
self
.
attn_groups
[
kv_cache_gid
])):
if
ubatch_slices
is
not
None
:
...
...
@@ -1941,8 +2232,10 @@ class GPUModelRunner(
for
_metadata
in
attn_metadata
.
values
():
_metadata
.
mm_prefix_range
=
req_doc_ranges
# type: ignore[attr-defined]
if
spec_decode_common_attn_metadata
is
not
None
and
(
num_reqs
!=
num_reqs_padded
or
num_tokens
!=
num_tokens_padded
if
(
(
not
self
.
enable_lightly_cp
)
and
spec_decode_common_attn_metadata
is
not
None
and
(
num_reqs
!=
num_reqs_padded
or
num_tokens
!=
num_tokens_padded
)
):
# Currently the drafter still only uses piecewise cudagraphs (and modifies
# the attention metadata in directly), and therefore does not want to use
...
...
@@ -1951,7 +2244,12 @@ class GPUModelRunner(
spec_decode_common_attn_metadata
.
unpadded
(
num_tokens
,
num_reqs
)
)
return
attn_metadata
,
spec_decode_common_attn_metadata
return
(
attn_metadata
,
spec_decode_common_attn_metadata
,
scatter_indexes_tensor
,
gather_indexes_tensor
)
def
_compute_cascade_attn_prefix_lens
(
self
,
...
...
@@ -2803,9 +3101,20 @@ class GPUModelRunner(
return
model_runner_output
def _pad_for_mla_cp(self, num_scheduled_tokens: int) -> int:
    """Pad the scheduled token count for MLA context parallelism.

    Args:
        num_scheduled_tokens: Number of tokens before padding.

    Returns:
        The count rounded up (via ``round_up``) to the next multiple of the
        configured tensor-parallel size.
    """
    multiple = self.vllm_config.parallel_config.tensor_parallel_size
    padded_tokens = round_up(num_scheduled_tokens, multiple)
    return padded_tokens
def
_pad_for_sequence_parallelism
(
self
,
num_scheduled_tokens
:
int
)
->
int
:
# Pad tokens to multiple of tensor_parallel_size when
# enabled collective fusion for SP
if
self
.
enable_lightly_cp
and
num_scheduled_tokens
>
self
.
lightly_cp_threshould
:
return
self
.
_pad_for_mla_cp
(
num_scheduled_tokens
)
tp_size
=
self
.
vllm_config
.
parallel_config
.
tensor_parallel_size
if
self
.
compilation_config
.
pass_config
.
enable_sp
and
tp_size
>
1
:
return
round_up
(
num_scheduled_tokens
,
tp_size
)
...
...
@@ -3502,6 +3811,8 @@ class GPUModelRunner(
)
num_tokens_padded
=
batch_desc
.
num_tokens
if
self
.
enable_lightly_cp
and
num_tokens_unpadded
>
self
.
lightly_cp_threshould
:
num_tokens_padded
=
self
.
_pad_for_mla_cp
(
num_tokens_unpadded
)
num_reqs_padded
=
(
batch_desc
.
num_reqs
if
batch_desc
.
num_reqs
is
not
None
else
num_reqs
)
...
...
@@ -3558,8 +3869,12 @@ class GPUModelRunner(
ubatch_slices
=
ubatch_slices_padded
,
)
attn_metadata
,
spec_decode_common_attn_metadata
=
(
self
.
_build_attention_metadata
(
(
attn_metadata
,
spec_decode_common_attn_metadata
,
scatter_indexes_tensor
,
gather_indexes_tensor
,
)
=
self
.
_build_attention_metadata
(
num_tokens
=
num_tokens_unpadded
,
num_tokens_padded
=
num_tokens_padded
if
pad_attn
else
None
,
num_reqs
=
num_reqs
,
...
...
@@ -3572,7 +3887,6 @@ class GPUModelRunner(
cascade_attn_prefix_lens
=
cascade_attn_prefix_lens
,
slot_mappings
=
slot_mappings_by_group
,
)
)
(
input_ids
,
...
...
@@ -3614,6 +3928,10 @@ class GPUModelRunner(
ubatch_slices
=
ubatch_slices_padded
,
slot_mapping
=
slot_mappings
,
skip_compiled
=
has_encoder_input
,
scatter_indexes_tensor
=
scatter_indexes_tensor
,
gather_indexes_tensor
=
gather_indexes_tensor
,
enable_lightly_cp
=
self
.
enable_lightly_cp
and
num_tokens_unpadded
>
self
.
lightly_cp_threshould
,
enable_lightly_cplb
=
self
.
enable_lightly_cplb
),
record_function_or_nullcontext
(
"gpu_model_runner: forward"
),
self
.
maybe_get_kv_connector_output
(
...
...
@@ -4105,7 +4423,16 @@ class GPUModelRunner(
spec_decode_metadata
,
valid_sampled_tokens_count
,
)
total_num_tokens
=
common_attn_metadata
.
num_actual_tokens
#total_num_tokens = common_attn_metadata.num_actual_tokens
if
(
self
.
enable_lightly_cp
and
common_attn_metadata
.
cp_common_metadata
is
not
None
):
total_num_tokens
=
(
common_attn_metadata
.
cp_common_metadata
.
num_actual_tokens
)
else
:
total_num_tokens
=
common_attn_metadata
.
num_actual_tokens
# When padding the batch, token_indices is just a range
target_token_ids
=
self
.
input_ids
.
gpu
[:
total_num_tokens
]
target_positions
=
self
.
_get_positions
(
total_num_tokens
)
...
...
@@ -4759,7 +5086,7 @@ class GPUModelRunner(
self
.
query_start_loc
.
copy_to_gpu
()
pad_attn
=
cudagraph_runtime_mode
==
CUDAGraphMode
.
FULL
attn_metadata
,
_
=
self
.
_build_attention_metadata
(
attn_metadata
,
_
,
_
,
_
=
self
.
_build_attention_metadata
(
num_tokens
=
num_tokens_unpadded
,
num_reqs
=
num_reqs_padded
,
max_query_len
=
max_query_len
,
...
...
@@ -4830,6 +5157,8 @@ class GPUModelRunner(
batch_descriptor
=
batch_desc
,
ubatch_slices
=
ubatch_slices_padded
,
slot_mapping
=
slot_mappings
,
enable_lightly_cp
=
self
.
enable_lightly_cp
and
num_tokens_unpadded
>
self
.
lightly_cp_threshould
,
enable_lightly_cplb
=
self
.
enable_lightly_cplb
),
):
outputs
=
self
.
model
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment