Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
cba31c47
Unverified
Commit
cba31c47
authored
May 06, 2025
by
Chen Zhang
Committed by
GitHub
May 06, 2025
Browse files
[v1] AttentionMetadata for each layer (#17394)
Signed-off-by:
Chen Zhang
<
zhangch99@outlook.com
>
parent
a6fed020
Changes
9
Show whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
126 additions
and
46 deletions
+126
-46
vllm/attention/layer.py
vllm/attention/layer.py
+12
-3
vllm/forward_context.py
vllm/forward_context.py
+8
-3
vllm/v1/attention/backends/flash_attn.py
vllm/v1/attention/backends/flash_attn.py
+5
-6
vllm/v1/attention/backends/flashinfer.py
vllm/v1/attention/backends/flashinfer.py
+5
-5
vllm/v1/attention/backends/mla/common.py
vllm/v1/attention/backends/mla/common.py
+5
-5
vllm/v1/attention/backends/utils.py
vllm/v1/attention/backends/utils.py
+18
-0
vllm/v1/spec_decode/eagle.py
vllm/v1/spec_decode/eagle.py
+10
-1
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+47
-21
vllm/v1/worker/tpu_model_runner.py
vllm/v1/worker/tpu_model_runner.py
+16
-2
No files found.
vllm/attention/layer.py
View file @
cba31c47
...
...
@@ -210,6 +210,8 @@ class Attention(nn.Module):
if
self
.
use_direct_call
:
forward_context
:
ForwardContext
=
get_forward_context
()
attn_metadata
=
forward_context
.
attn_metadata
if
isinstance
(
attn_metadata
,
dict
):
attn_metadata
=
attn_metadata
[
self
.
layer_name
]
self_kv_cache
=
self
.
kv_cache
[
forward_context
.
virtual_engine
]
self
.
impl
.
forward
(
self
,
query
,
...
...
@@ -226,6 +228,8 @@ class Attention(nn.Module):
if
self
.
use_direct_call
:
forward_context
=
get_forward_context
()
attn_metadata
=
forward_context
.
attn_metadata
if
isinstance
(
attn_metadata
,
dict
):
attn_metadata
=
attn_metadata
[
self
.
layer_name
]
self_kv_cache
=
self
.
kv_cache
[
forward_context
.
virtual_engine
]
return
self
.
impl
.
forward
(
self
,
query
,
key
,
value
,
self_kv_cache
,
attn_metadata
)
...
...
@@ -343,7 +347,7 @@ def wait_for_kv_layer_from_connector(layer_name: str):
attn_metadata
=
forward_context
.
attn_metadata
if
attn_metadata
is
None
:
return
assert
isinstance
(
attn_metadata
,
dict
)
connector
.
wait_for_layer_load
(
layer_name
)
...
...
@@ -360,8 +364,9 @@ def maybe_save_kv_layer_to_connector(
attn_metadata
=
forward_context
.
attn_metadata
if
attn_metadata
is
None
:
return
connector
.
save_kv_layer
(
layer_name
,
kv_cache_layer
,
attn_metadata
)
assert
isinstance
(
attn_metadata
,
dict
)
connector
.
save_kv_layer
(
layer_name
,
kv_cache_layer
,
attn_metadata
[
layer_name
])
def
unified_attention
(
...
...
@@ -374,6 +379,8 @@ def unified_attention(
forward_context
:
ForwardContext
=
get_forward_context
()
attn_metadata
=
forward_context
.
attn_metadata
if
isinstance
(
attn_metadata
,
dict
):
attn_metadata
=
attn_metadata
[
layer_name
]
self
=
forward_context
.
no_compile_layers
[
layer_name
]
kv_cache
=
self
.
kv_cache
[
forward_context
.
virtual_engine
]
output
=
self
.
impl
.
forward
(
self
,
query
,
key
,
value
,
kv_cache
,
...
...
@@ -411,6 +418,8 @@ def unified_attention_with_output(
wait_for_kv_layer_from_connector
(
layer_name
)
forward_context
:
ForwardContext
=
get_forward_context
()
attn_metadata
=
forward_context
.
attn_metadata
if
isinstance
(
attn_metadata
,
dict
):
attn_metadata
=
attn_metadata
[
layer_name
]
self
=
forward_context
.
no_compile_layers
[
layer_name
]
kv_cache
=
self
.
kv_cache
[
forward_context
.
virtual_engine
]
self
.
impl
.
forward
(
self
,
...
...
vllm/forward_context.py
View file @
cba31c47
...
...
@@ -4,7 +4,7 @@ import time
from
collections
import
defaultdict
from
contextlib
import
contextmanager
from
dataclasses
import
dataclass
from
typing
import
TYPE_CHECKING
,
Any
,
Optional
from
typing
import
TYPE_CHECKING
,
Any
,
Optional
,
Union
import
torch
import
torch.distributed
as
dist
...
...
@@ -38,8 +38,13 @@ class DPMetadata:
class
ForwardContext
:
# copy from vllm_config.compilation_config.static_forward_context
no_compile_layers
:
dict
[
str
,
Any
]
# TODO: extend to support per-layer dynamic forward context
attn_metadata
:
"AttentionMetadata"
# set dynamically for each forward pass
"""
Type AttentionMetadata for v0,
Type Dict[str, AttentionMetadata] for v1, map from layer_name of each
attention layer to its attention metadata
set dynamically for each forward pass
"""
attn_metadata
:
Union
[
"AttentionMetadata"
,
dict
[
str
,
"AttentionMetadata"
]]
# TODO: remove after making all virtual_engines share the same kv cache
virtual_engine
:
int
# set dynamically for each forward pass
# set dynamically for each forward pass
...
...
vllm/v1/attention/backends/flash_attn.py
View file @
cba31c47
...
...
@@ -18,6 +18,7 @@ from vllm.config import VllmConfig, get_layers_from_vllm_config
from
vllm.logger
import
init_logger
from
vllm.platforms
import
current_platform
from
vllm.utils
import
cdiv
from
vllm.v1.attention.backends.utils
import
CommonAttentionMetadata
if
TYPE_CHECKING
:
from
vllm.v1.core.sched.output
import
SchedulerOutput
...
...
@@ -309,13 +310,11 @@ class FlashAttentionMetadataBuilder:
return
False
def
build
(
self
,
num_reqs
:
int
,
num_actual_tokens
:
int
,
max_query_len
:
int
,
common_prefix_len
:
int
):
common_prefix_len
:
int
,
common_attn_metadata
:
CommonAttentionMetadata
):
max_seq_len
=
self
.
runner
.
seq_lens_np
[:
num_reqs
].
max
()
query_start_loc_cpu
=
self
.
runner
.
query_start_loc_cpu
[:
num_reqs
+
1
]
query_start_loc
=
query_start_loc_cpu
.
to
(
self
.
runner
.
device
,
non_blocking
=
True
)
seq_lens_cpu
=
self
.
runner
.
seq_lens_cpu
[:
num_reqs
]
seq_lens
=
seq_lens_cpu
.
to
(
self
.
runner
.
device
,
non_blocking
=
True
)
query_start_loc
=
common_attn_metadata
.
query_start_loc
seq_lens
=
common_attn_metadata
.
seq_lens
block_table
=
(
self
.
runner
.
input_batch
.
block_table
.
get_device_tensor
()[:
num_reqs
])
slot_mapping
=
self
.
runner
.
slot_mapping_cpu
[:
num_actual_tokens
].
to
(
...
...
vllm/v1/attention/backends/flashinfer.py
View file @
cba31c47
...
...
@@ -18,6 +18,7 @@ from vllm.config import (VllmConfig, get_current_vllm_config,
get_layers_from_vllm_config
)
from
vllm.logger
import
init_logger
from
vllm.v1.attention.backends.flash_attn
import
use_cascade_attention
from
vllm.v1.attention.backends.utils
import
CommonAttentionMetadata
if
TYPE_CHECKING
:
from
vllm.v1.core.sched.output
import
SchedulerOutput
...
...
@@ -394,16 +395,15 @@ class FlashInferMetadataBuilder:
)
def
build
(
self
,
num_reqs
:
int
,
num_actual_tokens
:
int
,
max_query_len
:
int
,
common_prefix_len
:
int
):
common_prefix_len
:
int
,
common_attn_metadata
:
CommonAttentionMetadata
):
assert
self
.
_num_decodes
+
self
.
_num_prefills
==
num_reqs
assert
(
self
.
_num_decode_tokens
+
self
.
_num_prefill_tokens
==
num_actual_tokens
)
page_size
=
self
.
runner
.
block_size
device
=
self
.
runner
.
device
qo_indptr
=
self
.
runner
.
query_start_loc_cpu
[:
num_reqs
+
1
].
to
(
self
.
runner
.
device
,
non_blocking
=
True
)
seq_lens
=
self
.
runner
.
seq_lens_cpu
[:
num_reqs
].
to
(
self
.
runner
.
device
,
non_blocking
=
True
)
qo_indptr
=
common_attn_metadata
.
query_start_loc
seq_lens
=
common_attn_metadata
.
seq_lens
block_table
=
(
self
.
runner
.
input_batch
.
block_table
.
get_device_tensor
()[:
num_reqs
])
slot_mapping
=
self
.
runner
.
slot_mapping_cpu
[:
num_actual_tokens
].
to
(
...
...
vllm/v1/attention/backends/mla/common.py
View file @
cba31c47
...
...
@@ -207,6 +207,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
from
vllm.model_executor.layers.rotary_embedding
import
RotaryEmbedding
from
vllm.platforms
import
current_platform
from
vllm.utils
import
cdiv
,
round_down
from
vllm.v1.attention.backends.utils
import
CommonAttentionMetadata
try
:
from
vllm.vllm_flash_attn
import
flash_attn_varlen_func
...
...
@@ -451,7 +452,8 @@ class MLACommonMetadataBuilder(Generic[M]):
)
def
build
(
self
,
num_reqs
:
int
,
num_actual_tokens
:
int
,
max_query_len
:
int
,
common_prefix_len
:
int
)
->
M
:
common_prefix_len
:
int
,
common_attn_metadata
:
CommonAttentionMetadata
)
->
M
:
assert
self
.
_num_decodes
+
self
.
_num_prefills
==
num_reqs
# Note(simon): be careful about the CPU <> GPU memory movement in this
...
...
@@ -460,15 +462,13 @@ class MLACommonMetadataBuilder(Generic[M]):
device
=
self
.
runner
.
device
block_table
=
(
self
.
runner
.
input_batch
.
block_table
.
get_device_tensor
()[:
num_reqs
])
query_start_loc
=
self
.
runner
.
query_start_loc_cpu
[:
num_reqs
+
1
].
to
(
device
,
non_blocking
=
True
)
slot_mapping
=
self
.
runner
.
slot_mapping_cpu
[:
num_actual_tokens
].
to
(
device
,
non_blocking
=
True
).
long
()
input_positions
=
self
.
runner
.
positions_cpu
[:
num_actual_tokens
].
to
(
device
,
non_blocking
=
True
).
long
()
seq_lens_cpu
=
self
.
runner
.
seq_lens_cpu
[:
num_reqs
]
seq_lens
=
seq_lens_cpu
.
to
(
device
,
non_blocking
=
True
)
query_start_loc
=
common_attn_metadata
.
query_start_loc
seq_lens
=
common_attn_metadata
.
seq_lens
prefill_metadata
=
None
if
self
.
_num_prefills
>
0
:
...
...
vllm/v1/attention/backends/utils.py
0 → 100644
View file @
cba31c47
# SPDX-License-Identifier: Apache-2.0
from
dataclasses
import
dataclass
import
torch
@
dataclass
class
CommonAttentionMetadata
:
"""
Attention metadata attributes that can be shared by layers in different KV
cache groups and thus having different block table.
"""
query_start_loc
:
torch
.
Tensor
"""(batch_size + 1,), the start location of each request in query Tensor"""
seq_lens
:
torch
.
Tensor
"""(batch_size,), the length of each request including both computed tokens
and newly scheduled tokens"""
vllm/v1/spec_decode/eagle.py
View file @
cba31c47
...
...
@@ -2,7 +2,9 @@
import
torch
import
torch.nn
as
nn
from
vllm.config
import
CompilationLevel
,
VllmConfig
,
set_current_vllm_config
from
vllm.attention.layer
import
Attention
from
vllm.config
import
(
CompilationLevel
,
VllmConfig
,
get_layers_from_vllm_config
,
set_current_vllm_config
)
from
vllm.forward_context
import
set_forward_context
from
vllm.logger
import
init_logger
from
vllm.model_executor.model_loader.loader
import
get_model_loader
...
...
@@ -276,6 +278,8 @@ class EagleProposer:
loader
=
get_model_loader
(
self
.
vllm_config
.
load_config
)
target_layer_num
=
self
.
vllm_config
.
model_config
.
get_num_layers
(
self
.
vllm_config
.
parallel_config
)
target_attn_layer_names
=
set
(
get_layers_from_vllm_config
(
self
.
vllm_config
,
Attention
).
keys
())
draft_model_config
=
\
self
.
vllm_config
.
speculative_config
.
draft_model_config
...
...
@@ -292,6 +296,11 @@ class EagleProposer:
vllm_config
=
self
.
vllm_config
,
start_layer_id
=
target_layer_num
).
to
(
target_device
)
draft_attn_layer_names
=
(
get_layers_from_vllm_config
(
self
.
vllm_config
,
Attention
).
keys
()
-
target_attn_layer_names
)
assert
len
(
draft_attn_layer_names
)
==
1
self
.
attn_layer_name
=
next
(
iter
(
draft_attn_layer_names
))
loaded_weights
=
self
.
model
.
load_weights
(
loader
.
get_all_weights
(
draft_model_config
,
self
.
model
))
if
self
.
vllm_config
.
speculative_config
.
method
==
"eagle3"
:
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
cba31c47
...
...
@@ -30,6 +30,7 @@ from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
GiB_bytes
,
LayerBlockType
,
LazyLoader
,
cdiv
,
check_use_alibi
,
is_pin_memory_available
)
from
vllm.v1.attention.backends.flash_attn
import
FlashAttentionMetadata
from
vllm.v1.attention.backends.utils
import
CommonAttentionMetadata
from
vllm.v1.core.encoder_cache_manager
import
compute_encoder_budget
from
vllm.v1.kv_cache_interface
import
(
AttentionSpec
,
FullAttentionSpec
,
KVCacheConfig
,
KVCacheSpec
,
...
...
@@ -157,9 +158,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Sampler
self
.
sampler
=
Sampler
()
# Lazy initialization
# Lazy initialization
s
# self.model: nn.Module # Set after load_model
# Initialize in initialize_kv_cache
self
.
kv_caches
:
list
[
torch
.
Tensor
]
=
[]
# self.kv_cache_config: KVCacheConfig
# req_id -> (input_id -> encoder_output)
self
.
encoder_cache
:
dict
[
str
,
dict
[
int
,
torch
.
Tensor
]]
=
{}
...
...
@@ -488,7 +492,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
def
_prepare_inputs
(
self
,
scheduler_output
:
"SchedulerOutput"
,
)
->
tuple
[
FlashAttentionMetadata
,
torch
.
Tensor
,
)
->
tuple
[
dict
[
str
,
FlashAttentionMetadata
]
,
torch
.
Tensor
,
Optional
[
SpecDecodeMetadata
]]:
total_num_scheduled_tokens
=
scheduler_output
.
total_num_scheduled_tokens
assert
total_num_scheduled_tokens
>
0
...
...
@@ -585,6 +589,23 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self
.
positions_cpu
[:
total_num_scheduled_tokens
],
non_blocking
=
True
)
query_start_loc
=
self
.
query_start_loc_cpu
[:
num_reqs
+
1
].
to
(
self
.
device
,
non_blocking
=
True
)
seq_lens
=
self
.
seq_lens_cpu
[:
num_reqs
].
to
(
self
.
device
,
non_blocking
=
True
)
common_attn_metadata
=
CommonAttentionMetadata
(
query_start_loc
=
query_start_loc
,
seq_lens
=
seq_lens
)
attn_metadata
:
dict
[
str
,
FlashAttentionMetadata
]
=
{}
# Prepare the attention metadata for each KV cache group and make layers
# in the same group share the same metadata.
# NOTE(Chen): there is exactly one KV cache group that contains all
# attetnion layers in the model for now, so the current logic for
# getting attn_metadata is not related to kv_cache_group information.
# Will extend this part to support multiple KV cache groups later.
for
kv_cache_group_id
,
kv_cache_group_spec
in
enumerate
(
self
.
kv_cache_config
.
kv_cache_groups
):
# Prepare for cascade attention if enabled & beneficial.
common_prefix_len
=
0
if
self
.
cascade_attn_enabled
:
...
...
@@ -593,12 +614,14 @@ class GPUModelRunner(LoRAModelRunnerMixin):
scheduler_output
.
num_common_prefix_blocks
,
)
attn_metadata
=
self
.
attn_metadata_builder
.
build
(
attn_metadata
_i
=
self
.
attn_metadata_builder
.
build
(
num_reqs
=
num_reqs
,
num_actual_tokens
=
total_num_scheduled_tokens
,
max_query_len
=
max_num_scheduled_tokens
,
common_prefix_len
=
common_prefix_len
,
)
common_attn_metadata
=
common_attn_metadata
)
for
layer_name
in
kv_cache_group_spec
.
layer_names
:
attn_metadata
[
layer_name
]
=
attn_metadata_i
use_spec_decode
=
len
(
scheduler_output
.
scheduled_spec_decode_tokens
)
>
0
...
...
@@ -608,7 +631,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# from these partial requests, we do so for simplicity.
# We will ignore the sampled tokens from the partial requests.
# TODO: Support prompt logprobs.
logits_indices
=
attn_metadata
.
query_start_loc
[
1
:]
-
1
logits_indices
=
query_start_loc
[
1
:]
-
1
spec_decode_metadata
=
None
else
:
# Get the number of draft tokens for each request.
...
...
@@ -1230,6 +1253,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
next_token_ids
=
torch
.
tensor
(
next_token_ids
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
eagle_attn_metadata
=
attn_metadata
[
self
.
drafter
.
attn_layer_name
]
if
spec_decode_metadata
is
None
:
# input_ids can be None for multimodal models.
...
...
@@ -1241,8 +1265,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
dim
=-
1
)
else
:
target_hidden_states
=
hidden_states
[:
num_scheduled_tokens
]
target_slot_mapping
=
attn_metadata
.
slot_mapping
cu_num_tokens
=
attn_metadata
.
query_start_loc
target_slot_mapping
=
eagle_
attn_metadata
.
slot_mapping
cu_num_tokens
=
eagle_
attn_metadata
.
query_start_loc
else
:
# TODO(woosuk): Refactor this.
num_draft_tokens
=
spec_decode_metadata
.
num_draft_tokens
...
...
@@ -1256,7 +1280,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
device
=
self
.
device
,
)
cu_num_tokens
,
token_indices
=
self
.
drafter
.
prepare_inputs
(
attn_metadata
.
query_start_loc
,
eagle_
attn_metadata
.
query_start_loc
,
num_rejected_tokens
,
)
target_token_ids
=
self
.
input_ids
[
token_indices
]
...
...
@@ -1266,7 +1290,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
[
h
[
token_indices
]
for
h
in
aux_hidden_states
],
dim
=-
1
)
else
:
target_hidden_states
=
hidden_states
[
token_indices
]
target_slot_mapping
=
attn_metadata
.
slot_mapping
[
token_indices
]
target_slot_mapping
=
eagle_attn_metadata
.
slot_mapping
[
token_indices
]
draft_token_ids
=
self
.
drafter
.
propose
(
target_token_ids
=
target_token_ids
,
...
...
@@ -1275,7 +1300,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
target_slot_mapping
=
target_slot_mapping
,
next_token_ids
=
next_token_ids
,
cu_num_tokens
=
cu_num_tokens
,
block_table
=
attn_metadata
.
block_table
,
block_table
=
eagle_
attn_metadata
.
block_table
,
sampling_metadata
=
sampling_metadata
,
)
spec_token_ids
=
draft_token_ids
.
tolist
()
...
...
@@ -1708,6 +1733,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
raise
NotImplementedError
(
"Hybrid models with more than one KV cache type are not "
"supported yet."
)
self
.
kv_cache_config
=
kv_cache_config
kv_caches
:
dict
[
str
,
torch
.
Tensor
]
=
{}
...
...
vllm/v1/worker/tpu_model_runner.py
View file @
cba31c47
...
...
@@ -588,7 +588,14 @@ class TPUModelRunner:
# Padded to avoid recompiling when `num_reqs` varies.
logits_indices
=
self
.
query_start_loc_cpu
[
1
:
padded_num_reqs
+
1
]
-
1
logits_indices
=
logits_indices
.
to
(
self
.
device
)
return
attn_metadata
,
logits_indices
,
padded_num_reqs
layer_names
=
get_layers_from_vllm_config
(
self
.
vllm_config
,
Attention
).
keys
()
per_layer_attn_metadata
=
{
layer_name
:
attn_metadata
for
layer_name
in
layer_names
}
return
per_layer_attn_metadata
,
logits_indices
,
padded_num_reqs
def
_scatter_placeholders
(
self
,
...
...
@@ -956,7 +963,14 @@ class TPUModelRunner:
torch
.
_dynamo
.
mark_dynamic
(
position_ids
,
0
)
torch
.
_dynamo
.
mark_dynamic
(
attn_metadata
.
slot_mapping
,
0
)
with
set_forward_context
(
attn_metadata
,
self
.
vllm_config
,
0
):
layer_names
=
get_layers_from_vllm_config
(
self
.
vllm_config
,
Attention
).
keys
()
per_layer_attn_metadata
=
{
layer_name
:
attn_metadata
for
layer_name
in
layer_names
}
with
set_forward_context
(
per_layer_attn_metadata
,
self
.
vllm_config
,
0
):
out
=
self
.
model
(
input_ids
=
input_ids
,
positions
=
position_ids
,
inputs_embeds
=
inputs_embeds
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment