Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
ad510309
Unverified
Commit
ad510309
authored
Jul 30, 2025
by
Yong Hoon Shin
Committed by
GitHub
Jul 30, 2025
Browse files
Override attention metadata for fast prefill in some KV sharing setups (#21590)
Signed-off-by:
Yong Hoon Shin
<
yhshin@meta.com
>
parent
366f6b3a
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
287 additions
and
26 deletions
+287
-26
tests/v1/e2e/test_kv_sharing_fast_prefill.py
tests/v1/e2e/test_kv_sharing_fast_prefill.py
+143
-0
vllm/config.py
vllm/config.py
+15
-0
vllm/engine/arg_utils.py
vllm/engine/arg_utils.py
+6
-0
vllm/model_executor/models/gemma3n.py
vllm/model_executor/models/gemma3n.py
+1
-0
vllm/v1/attention/backends/utils.py
vllm/v1/attention/backends/utils.py
+33
-2
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+89
-24
No files found.
tests/v1/e2e/test_kv_sharing_fast_prefill.py
0 → 100644
View file @
ad510309
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
gc
import
random
from
typing
import
Optional
,
Union
import
pytest
import
torch
from
vllm
import
LLM
,
SamplingParams
from
vllm.config
import
CompilationConfig
,
CompilationLevel
from
vllm.forward_context
import
get_forward_context
from
vllm.model_executor.models.gemma3n
import
Gemma3nForConditionalGeneration
from
vllm.model_executor.models.registry
import
ModelRegistry
from
vllm.model_executor.models.utils
import
extract_layer_index
from
vllm.sequence
import
IntermediateTensors
from
...utils
import
fork_new_process_for_each_test
class
TestGemma3nForConditionalGeneration
(
Gemma3nForConditionalGeneration
):
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
**
kwargs
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
hidden_states
=
self
.
model
(
input_ids
,
positions
,
intermediate_tensors
,
inputs_embeds
,
**
kwargs
)
attn_metadata
=
get_forward_context
().
attn_metadata
# attn_metadata is None during dummy runs
if
(
attn_metadata
is
not
None
and
self
.
cache_config
.
kv_sharing_fast_prefill
):
assert
isinstance
(
attn_metadata
,
dict
)
# true in V1
# Gemma3n-E2B has 30 layers, with last 20 layers being
# cross-decoder layers. Check attention metadata is correct
for
layer_name
,
metadata
in
attn_metadata
.
items
():
layer_idx
=
extract_layer_index
(
layer_name
)
if
layer_idx
>=
20
:
assert
hasattr
(
metadata
,
'logits_indices_padded'
)
assert
hasattr
(
metadata
,
'num_logits_indices'
)
else
:
assert
not
hasattr
(
metadata
,
'logits_indices_padded'
)
assert
not
hasattr
(
metadata
,
'num_logits_indices'
)
# Last layer will be a KV sharing layer
layer_attn_metadata
=
attn_metadata
[
self
.
model
.
language_model
.
layers
[
-
1
].
self_attn
.
attn
.
layer_name
]
logits_indices_padded
=
(
layer_attn_metadata
.
logits_indices_padded
)
assert
logits_indices_padded
is
not
None
num_logits_indices
=
layer_attn_metadata
.
num_logits_indices
assert
num_logits_indices
>
0
# Reset hidden states to random values and
# only set logits at logits_indices to valid values
# Because logits_indices are the only positions that are used
# for output token sampling, this still produces same outputs
logits_hs
=
hidden_states
[
logits_indices_padded
]
hidden_states
=
torch
.
randn_like
(
hidden_states
)
gen_indices
=
logits_indices_padded
[:
num_logits_indices
]
hidden_states
[
gen_indices
]
=
logits_hs
[:
num_logits_indices
]
return
hidden_states
@
pytest
.
fixture
def
test_prompts
():
"""
Adapted from tests/v1/e2e/test_spec_decode.py
"""
prompt_types
=
[
"repeat"
,
"sentence"
]
# Setting higher num prompts increases the chance of numerics mismatch
# due to matrix multiplication numerics depending on batch dimension
num_prompts
=
10
prompts
=
[]
random
.
seed
(
0
)
random_prompt_type_choices
=
random
.
choices
(
prompt_types
,
k
=
num_prompts
)
for
kind
in
random_prompt_type_choices
:
word_choices
=
[
"test"
,
"temp"
,
"hello"
,
"where"
]
word
=
random
.
choice
(
word_choices
)
if
kind
==
"repeat"
:
prompt
=
f
"""please repeat the word '
{
word
}
' 10 times."""
elif
kind
==
"sentence"
:
prompt
=
f
"""please give a ten-word sentence that
uses the word
{
word
}
at least once."""
else
:
raise
ValueError
(
f
"Unknown prompt type:
{
kind
}
"
)
prompts
.
append
(
prompt
)
return
prompts
@
fork_new_process_for_each_test
@
pytest
.
mark
.
parametrize
(
"enforce_eager"
,
[
True
,
False
])
def
test_kv_sharing_fast_prefill
(
monkeypatch
:
pytest
.
MonkeyPatch
,
enforce_eager
:
bool
,
test_prompts
:
list
[
str
],
):
ModelRegistry
.
register_model
(
"Gemma3nForConditionalGeneration"
,
TestGemma3nForConditionalGeneration
)
sampling_params
=
SamplingParams
(
temperature
=
0.0
,
max_tokens
=
100
)
compilation_config
=
CompilationConfig
(
# This allows vLLM compilation backend to handle allocating and
# managing buffers for cudagraph
cudagraph_copy_inputs
=
True
,
level
=
CompilationLevel
.
PIECEWISE
if
not
enforce_eager
else
CompilationLevel
.
NO_COMPILATION
)
with
monkeypatch
.
context
()
as
m
:
m
.
setenv
(
"VLLM_USE_V1"
,
"1"
)
llm
=
LLM
(
model
=
"google/gemma-3n-E2B-it"
,
enforce_eager
=
enforce_eager
,
compilation_config
=
compilation_config
,
)
ref_responses
=
llm
.
generate
(
test_prompts
,
sampling_params
)
del
llm
gc
.
collect
()
torch
.
cuda
.
empty_cache
()
llm
=
LLM
(
model
=
"google/gemma-3n-E2B-it"
,
enforce_eager
=
enforce_eager
,
compilation_config
=
compilation_config
,
kv_sharing_fast_prefill
=
True
)
optimized_responses
=
llm
.
generate
(
test_prompts
,
sampling_params
)
misses
=
0
for
ref_response
,
optimized_response
in
zip
(
ref_responses
,
optimized_responses
):
if
ref_response
.
outputs
[
0
].
text
!=
optimized_response
.
outputs
[
0
].
text
:
misses
+=
1
assert
misses
==
0
vllm/config.py
View file @
ad510309
...
...
@@ -1795,6 +1795,16 @@ class CacheConfig:
num_cpu_blocks
:
Optional
[
int
]
=
field
(
default
=
None
,
init
=
False
)
"""The number of blocks to allocate for CPU memory."""
kv_sharing_fast_prefill
:
bool
=
False
"""This feature is work in progress and no prefill optimization takes place
with this flag enabled currently.
In some KV sharing setups, e.g. YOCO (https://arxiv.org/abs/2405.05254),
some layers can skip tokens corresponding to prefill. This flag enables
attention metadata for eligible layers to be overriden with metadata
necessary for implementating this optimization in some models (e.g. Gemma3n)
"""
def
compute_hash
(
self
)
->
str
:
"""
WARNING: Whenever a new field is added to this config,
...
...
@@ -1836,6 +1846,11 @@ class CacheConfig:
"GPU memory utilization must be less than 1.0. Got "
f
"
{
self
.
gpu_memory_utilization
}
."
)
if
self
.
kv_sharing_fast_prefill
:
logger
.
warning_once
(
"--kv-sharing-fast-prefill is currently work in progress "
"and not functional yet (i.e. no prefill savings)"
)
return
self
def
_verify_cache_dtype
(
self
)
->
None
:
...
...
vllm/engine/arg_utils.py
View file @
ad510309
...
...
@@ -445,6 +445,9 @@ class EngineArgs:
# DEPRECATED
enable_prompt_adapter
:
bool
=
False
kv_sharing_fast_prefill
:
bool
=
\
CacheConfig
.
kv_sharing_fast_prefill
def
__post_init__
(
self
):
# support `EngineArgs(compilation_config={...})`
# without having to manually construct a
...
...
@@ -697,6 +700,8 @@ class EngineArgs:
**
cache_kwargs
[
"cpu_offload_gb"
])
cache_group
.
add_argument
(
"--calculate-kv-scales"
,
**
cache_kwargs
[
"calculate_kv_scales"
])
cache_group
.
add_argument
(
"--kv-sharing-fast-prefill"
,
**
cache_kwargs
[
"kv_sharing_fast_prefill"
])
# Multimodal related configs
multimodal_kwargs
=
get_kwargs
(
MultiModalConfig
)
...
...
@@ -1069,6 +1074,7 @@ class EngineArgs:
prefix_caching_hash_algo
=
self
.
prefix_caching_hash_algo
,
cpu_offload_gb
=
self
.
cpu_offload_gb
,
calculate_kv_scales
=
self
.
calculate_kv_scales
,
kv_sharing_fast_prefill
=
self
.
kv_sharing_fast_prefill
,
)
# Get the current placement group if Ray is initialized and
...
...
vllm/model_executor/models/gemma3n.py
View file @
ad510309
...
...
@@ -793,6 +793,7 @@ class Gemma3nForConditionalGeneration(nn.Module):
del
lora_config
# Unused.
super
().
__init__
()
self
.
config
=
config
self
.
cache_config
=
vllm_config
.
cache_config
self
.
model
=
Gemma3nModel
(
vllm_config
=
vllm_config
,
prefix
=
maybe_prefix
(
prefix
,
"model"
))
self
.
logits_processor
=
LogitsProcessor
(
...
...
vllm/v1/attention/backends/utils.py
View file @
ad510309
...
...
@@ -3,8 +3,8 @@
import
abc
import
functools
from
abc
import
abstractmethod
from
dataclasses
import
dataclass
from
typing
import
TYPE_CHECKING
,
ClassVar
,
Generic
,
Optional
,
TypeVar
from
dataclasses
import
dataclass
,
make_dataclass
from
typing
import
TYPE_CHECKING
,
Any
,
ClassVar
,
Generic
,
Optional
,
TypeVar
import
numpy
as
np
import
torch
...
...
@@ -508,3 +508,34 @@ def reorder_batch_to_split_decodes_and_prefills(
modified_batch
=
True
return
modified_batch
KV_SHARING_FAST_PREFILL_METADATA_FIELDS
=
[
(
'logits_indices_padded'
,
Optional
[
torch
.
Tensor
],
None
),
(
'num_logits_indices'
,
int
,
0
),
]
def
subclass_attention_metadata
(
name_prefix
:
str
,
metadata_cls
:
Any
,
fields
:
list
[
tuple
[
str
,
Any
,
Any
]],
)
->
Any
:
"""
Return a new subclass of `metadata_cls` with additional fields
"""
name
:
str
=
name_prefix
+
metadata_cls
.
__name__
# type: ignore
Wrapped
=
make_dataclass
(
name
,
fields
,
bases
=
(
metadata_cls
,
))
return
Wrapped
def
make_kv_sharing_fast_prefill_attention_metadata
(
metadata_cls
:
Any
,
)
->
Any
:
"""
Return a new subclass of `metadata_cls` for fast prefill
"""
return
subclass_attention_metadata
(
name_prefix
=
"KVSharingFastPrefill"
,
metadata_cls
=
metadata_cls
,
fields
=
KV_SHARING_FAST_PREFILL_METADATA_FIELDS
,
)
vllm/v1/worker/gpu_model_runner.py
View file @
ad510309
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
dataclasses
import
gc
import
time
from
contextlib
import
contextmanager
...
...
@@ -47,6 +48,7 @@ from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
from
vllm.v1.attention.backends.mamba_selectors
import
get_mamba_attn_backend
from
vllm.v1.attention.backends.utils
import
(
AttentionMetadataBuilder
,
CommonAttentionMetadata
,
make_kv_sharing_fast_prefill_attention_metadata
,
make_local_attention_virtual_batches
)
from
vllm.v1.core.encoder_cache_manager
import
compute_encoder_budget
from
vllm.v1.kv_cache_interface
import
(
AttentionSpec
,
...
...
@@ -320,6 +322,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# means this layer will perform attention using the keys and values
# from the KV cache of `shared_kv_cache_layers[layer_name]`.
self
.
shared_kv_cache_layers
:
dict
[
str
,
str
]
=
{}
self
.
kv_sharing_fast_prefill_eligible_layers
:
set
[
str
]
=
set
()
self
.
kv_sharing_fast_prefill_logits_indices
=
None
if
self
.
cache_config
.
kv_sharing_fast_prefill
:
self
.
kv_sharing_fast_prefill_logits_indices
=
torch
.
zeros
(
self
.
max_num_tokens
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
def
_may_reorder_batch
(
self
,
scheduler_output
:
"SchedulerOutput"
)
->
None
:
"""
...
...
@@ -735,6 +743,55 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
spec_decode_common_attn_metadata
=
None
use_spec_decode
=
len
(
scheduler_output
.
scheduled_spec_decode_tokens
)
>
0
if
not
use_spec_decode
:
# NOTE(woosuk): Due to chunked prefills, the batch may contain
# partial requests. While we should not sample any token
# from these partial requests, we do so for simplicity.
# We will ignore the sampled tokens from the partial requests.
# TODO: Support prompt logprobs.
logits_indices
=
query_start_loc
[
1
:]
-
1
spec_decode_metadata
=
None
else
:
# Get the number of draft tokens for each request.
# Iterate over the dictionary rather than all requests since not all
# requests have draft tokens.
num_draft_tokens
=
np
.
zeros
(
num_reqs
,
dtype
=
np
.
int32
)
for
req_id
,
draft_token_ids
in
(
scheduler_output
.
scheduled_spec_decode_tokens
.
items
()):
req_idx
=
self
.
input_batch
.
req_id_to_index
[
req_id
]
num_draft_tokens
[
req_idx
]
=
len
(
draft_token_ids
)
spec_decode_metadata
=
self
.
_calc_spec_decode_metadata
(
num_draft_tokens
,
cu_num_tokens
)
logits_indices
=
spec_decode_metadata
.
logits_indices
logits_indices_padded
=
None
if
self
.
cache_config
.
kv_sharing_fast_prefill
:
assert
self
.
kv_sharing_fast_prefill_logits_indices
is
not
None
num_logits
=
logits_indices
.
shape
[
0
]
assert
num_logits
>
0
self
.
kv_sharing_fast_prefill_logits_indices
[:
num_logits
].
copy_
(
logits_indices
)
# There might have leftover indices in logits_indices[num_logits:]
# from previous iterations, whose values may be greater than the
# batch size in the current iteration. To ensure indices are always
# valid, we fill the padded indices with the last index.
self
.
kv_sharing_fast_prefill_logits_indices
[
num_logits
:].
fill_
(
logits_indices
[
-
1
].
item
())
if
(
self
.
use_cuda_graph
and
num_logits
<=
self
.
cudagraph_batch_sizes
[
-
1
]):
# Use piecewise CUDA graphs.
# Add padding to the batch size.
num_logits_padded
=
self
.
vllm_config
.
pad_for_cudagraph
(
num_logits
)
else
:
num_logits_padded
=
num_logits
logits_indices_padded
=
(
self
.
kv_sharing_fast_prefill_logits_indices
[:
num_logits_padded
]
)
attn_metadata
:
dict
[
str
,
Any
]
=
{}
# Prepare encoder attention metadata separately
...
...
@@ -806,7 +863,28 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
common_attn_metadata
=
common_attn_metadata
,
))
fast_prefill_metadata
=
attn_metadata_i
if
(
self
.
cache_config
.
kv_sharing_fast_prefill
and
self
.
kv_sharing_fast_prefill_eligible_layers
):
# Dynamically create a a dataclass type that inherits
# from attention metadata type but includes additional
# fields logits_indices_padded and num_logits_indices
# which are required for prefill truncation
fast_prefill_metadata_type
=
(
make_kv_sharing_fast_prefill_attention_metadata
(
metadata_cls
=
type
(
attn_metadata_i
),
))
fast_prefill_metadata
=
fast_prefill_metadata_type
(
**
dataclasses
.
asdict
(
attn_metadata_i
),
logits_indices_padded
=
logits_indices_padded
,
num_logits_indices
=
logits_indices
.
size
(
0
),
)
for
layer_name
in
kv_cache_group_spec
.
layer_names
:
if
(
self
.
cache_config
.
kv_sharing_fast_prefill
and
layer_name
in
self
.
kv_sharing_fast_prefill_eligible_layers
):
attn_metadata
[
layer_name
]
=
fast_prefill_metadata
continue
attn_metadata
[
layer_name
]
=
attn_metadata_i
# Hack for now to fix chunked local attention + no hybrid kv cache
...
...
@@ -838,30 +916,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
b
.
can_run_in_cudagraph
(
common_attn_metadata
)
for
b
in
self
.
attn_metadata_builders
)
use_spec_decode
=
len
(
scheduler_output
.
scheduled_spec_decode_tokens
)
>
0
if
not
use_spec_decode
:
# NOTE(woosuk): Due to chunked prefills, the batch may contain
# partial requests. While we should not sample any token
# from these partial requests, we do so for simplicity.
# We will ignore the sampled tokens from the partial requests.
# TODO: Support prompt logprobs.
logits_indices
=
query_start_loc
[
1
:]
-
1
spec_decode_metadata
=
None
else
:
# Get the number of draft tokens for each request.
# Iterate over the dictionary rather than all requests since not all
# requests have draft tokens.
num_draft_tokens
=
np
.
zeros
(
num_reqs
,
dtype
=
np
.
int32
)
for
req_id
,
draft_token_ids
in
(
scheduler_output
.
scheduled_spec_decode_tokens
.
items
()):
req_idx
=
self
.
input_batch
.
req_id_to_index
[
req_id
]
num_draft_tokens
[
req_idx
]
=
len
(
draft_token_ids
)
spec_decode_metadata
=
self
.
_calc_spec_decode_metadata
(
num_draft_tokens
,
cu_num_tokens
)
logits_indices
=
spec_decode_metadata
.
logits_indices
# Hot-Swap lora model
if
self
.
lora_config
:
self
.
set_active_loras
(
self
.
input_batch
,
num_scheduled_tokens
)
...
...
@@ -1433,6 +1487,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
spec_decode_metadata
,
num_scheduled_tokens_np
,
spec_decode_common_attn_metadata
)
=
(
self
.
_prepare_inputs
(
scheduler_output
))
num_scheduled_tokens
=
scheduler_output
.
total_num_scheduled_tokens
if
(
self
.
use_cuda_graph
and
num_scheduled_tokens
<=
self
.
cudagraph_batch_sizes
[
-
1
]):
...
...
@@ -2814,6 +2869,16 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
kv_cache_config
.
kv_cache_groups
,
kv_caches
,
)
attn_layers
=
get_layers_from_vllm_config
(
self
.
vllm_config
,
Attention
)
# Iterate in reversed order and add layers that re-use KV cache
# e.g. in YOCO-like KV sharing setups (e.g. Gemma3n)
for
layer_name
in
reversed
(
attn_layers
):
if
layer_name
in
self
.
shared_kv_cache_layers
:
self
.
kv_sharing_fast_prefill_eligible_layers
.
add
(
layer_name
)
else
:
break
bind_kv_cache
(
kv_caches
,
self
.
compilation_config
.
static_forward_context
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment