Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
cb293f6a
Unverified
Commit
cb293f6a
authored
Aug 28, 2025
by
Yong Hoon Shin
Committed by
GitHub
Aug 28, 2025
Browse files
[V1] Enable prefill optimization for Gemma3n (#22628)
Signed-off-by:
Yong Hoon Shin
<
yhshin@meta.com
>
parent
7ffbf272
Changes
9
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
583 additions
and
228 deletions
+583
-228
tests/v1/e2e/test_kv_sharing_fast_prefill.py
tests/v1/e2e/test_kv_sharing_fast_prefill.py
+0
-57
vllm/config/cache.py
vllm/config/cache.py
+7
-5
vllm/model_executor/models/gemma3n.py
vllm/model_executor/models/gemma3n.py
+354
-65
vllm/model_executor/models/gemma3n_mm.py
vllm/model_executor/models/gemma3n_mm.py
+1
-1
vllm/v1/attention/backends/utils.py
vllm/v1/attention/backends/utils.py
+126
-13
vllm/v1/engine/async_llm.py
vllm/v1/engine/async_llm.py
+7
-0
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+53
-43
vllm/v1/worker/tpu_model_runner.py
vllm/v1/worker/tpu_model_runner.py
+28
-11
vllm/v1/worker/utils.py
vllm/v1/worker/utils.py
+7
-33
No files found.
tests/v1/e2e/test_kv_sharing_fast_prefill.py
View file @
cb293f6a
...
...
@@ -2,7 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
random
from
typing
import
Optional
,
Union
import
pytest
import
torch
...
...
@@ -10,12 +9,6 @@ import torch
from
vllm
import
LLM
,
SamplingParams
from
vllm.config
import
CompilationConfig
,
CompilationLevel
from
vllm.distributed
import
cleanup_dist_env_and_memory
from
vllm.forward_context
import
get_forward_context
from
vllm.model_executor.models.gemma3n_mm
import
(
Gemma3nForConditionalGeneration
)
from
vllm.model_executor.models.registry
import
ModelRegistry
from
vllm.model_executor.models.utils
import
extract_layer_index
from
vllm.sequence
import
IntermediateTensors
from
...utils
import
fork_new_process_for_each_test
...
...
@@ -23,54 +16,6 @@ from ...utils import fork_new_process_for_each_test
SEED
=
42
class
TestGemma3nForConditionalGeneration
(
Gemma3nForConditionalGeneration
):
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
**
kwargs
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
hidden_states
=
super
().
forward
(
input_ids
,
positions
,
intermediate_tensors
,
inputs_embeds
,
**
kwargs
)
attn_metadata
=
get_forward_context
().
attn_metadata
# attn_metadata is None during dummy runs
if
(
attn_metadata
is
not
None
and
self
.
language_model
.
cache_config
.
kv_sharing_fast_prefill
):
assert
isinstance
(
attn_metadata
,
dict
)
# true in V1
# Gemma3n-E2B has 30 layers, with last 20 layers being
# cross-decoder layers. Check attention metadata is correct
for
layer_name
,
metadata
in
attn_metadata
.
items
():
layer_idx
=
extract_layer_index
(
layer_name
)
if
layer_idx
>=
20
:
assert
hasattr
(
metadata
,
'logits_indices_padded'
)
assert
hasattr
(
metadata
,
'num_logits_indices'
)
else
:
assert
not
hasattr
(
metadata
,
'logits_indices_padded'
)
assert
not
hasattr
(
metadata
,
'num_logits_indices'
)
# Last layer will be a KV sharing layer
layer_attn_metadata
=
attn_metadata
[
self
.
language_model
.
model
.
layers
[
-
1
].
self_attn
.
attn
.
layer_name
]
logits_indices_padded
=
(
layer_attn_metadata
.
logits_indices_padded
)
assert
logits_indices_padded
is
not
None
num_logits_indices
=
layer_attn_metadata
.
num_logits_indices
assert
num_logits_indices
>
0
# Reset hidden states to random values and
# only set logits at logits_indices to valid values
# Because logits_indices are the only positions that are used
# for output token sampling, this still produces same outputs
logits_hs
=
hidden_states
[
logits_indices_padded
]
hidden_states
=
torch
.
randn_like
(
hidden_states
)
gen_indices
=
logits_indices_padded
[:
num_logits_indices
]
hidden_states
[
gen_indices
]
=
logits_hs
[:
num_logits_indices
]
return
hidden_states
@
pytest
.
fixture
def
test_prompts
():
"""
...
...
@@ -124,8 +69,6 @@ def test_kv_sharing_fast_prefill(
enforce_eager
:
bool
,
test_prompts
:
list
[
str
],
):
ModelRegistry
.
register_model
(
"Gemma3nForConditionalGeneration"
,
TestGemma3nForConditionalGeneration
)
sampling_params
=
SamplingParams
(
temperature
=
0.0
,
max_tokens
=
100
)
compilation_config
=
CompilationConfig
(
# This allows vLLM compilation backend to handle allocating and
...
...
vllm/config/cache.py
View file @
cb293f6a
...
...
@@ -145,12 +145,19 @@ class CacheConfig:
self
.
_verify_cache_dtype
()
self
.
_verify_prefix_caching
()
self
.
_verify_kv_sharing_fast_prefill
()
def
metrics_info
(
self
):
# convert cache_config to dict(key: str, value: str) for prometheus
# metrics info
return
{
key
:
str
(
value
)
for
key
,
value
in
self
.
__dict__
.
items
()}
def
_verify_kv_sharing_fast_prefill
(
self
)
->
None
:
if
self
.
kv_sharing_fast_prefill
and
not
envs
.
VLLM_USE_V1
:
raise
NotImplementedError
(
"Fast prefill optimization for KV sharing is not supported "
"in V0 currently."
)
@
model_validator
(
mode
=
'after'
)
def
_verify_args
(
self
)
->
Self
:
if
self
.
cpu_offload_gb
<
0
:
...
...
@@ -162,11 +169,6 @@ class CacheConfig:
"GPU memory utilization must be less than 1.0. Got "
f
"
{
self
.
gpu_memory_utilization
}
."
)
if
self
.
kv_sharing_fast_prefill
:
logger
.
warning_once
(
"--kv-sharing-fast-prefill is currently work in progress "
"and not functional yet (i.e. no prefill savings)"
)
return
self
def
_verify_cache_dtype
(
self
)
->
None
:
...
...
vllm/model_executor/models/gemma3n.py
View file @
cb293f6a
...
...
@@ -23,9 +23,11 @@ from torch import nn
from
transformers.models.gemma3n.configuration_gemma3n
import
Gemma3nTextConfig
from
vllm.attention
import
Attention
from
vllm.compilation.backends
import
set_model_tag
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.forward_context
import
get_forward_context
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.activation
import
(
_ACTIVATION_REGISTRY
,
GeluAndMul
,
...
...
@@ -45,6 +47,7 @@ from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader
,
maybe_remap_kv_scale_name
)
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
from
vllm.v1.attention.backends.utils
import
KVSharingFastPrefillMetadata
from
.interfaces
import
SupportsQuant
from
.utils
import
(
AutoWeightsLoader
,
extract_layer_index
,
...
...
@@ -533,7 +536,178 @@ class Gemma3nDecoderLayer(nn.Module):
return
corrected_predictions
@
support_torch_compile
# This enables torch.compile if --kv-sharing-fast-prefill passed
@
support_torch_compile
(
enable_if
=
lambda
vllm_config
:
vllm_config
.
cache_config
.
kv_sharing_fast_prefill
)
class
Gemma3nSelfDecoder
(
nn
.
Module
):
"""
Includes altup embedding and self decoder layers
"""
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
,
decoder_layers
:
list
[
Gemma3nDecoderLayer
],
layer_idx_start
:
int
,
per_layer_model_projection
:
ColumnParallelLinear
,
embed_scale_per_layer
:
torch
.
Tensor
,
embed_tokens_per_layer
:
VocabParallelEmbedding
,
per_layer_projection_norm
:
RMSNorm
,
per_layer_input_scale
:
torch
.
Tensor
,
altup_projections
:
nn
.
ModuleList
,
eps
:
torch
.
Tensor
,
embed_tokens
:
VocabParallelEmbedding
,
embed_scale
:
torch
.
Tensor
,
):
super
().
__init__
()
self
.
decoder_layers
=
decoder_layers
self
.
layer_idx_start
=
layer_idx_start
self
.
per_layer_model_projection
=
per_layer_model_projection
self
.
config
=
vllm_config
.
model_config
.
hf_config
self
.
embed_scale_per_layer
=
embed_scale_per_layer
self
.
embed_tokens_per_layer
=
embed_tokens_per_layer
self
.
per_layer_projection_norm
=
per_layer_projection_norm
self
.
per_layer_input_scale
=
per_layer_input_scale
self
.
altup_projections
=
altup_projections
self
.
eps
=
eps
self
.
embed_tokens
=
embed_tokens
self
.
embed_scale
=
embed_scale
def
get_per_layer_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
# Deal with the fact that vocab_size_per_layer_input < vocab_size
# which causes us to have some out of vocab tokens by setting
# those token ids to 0. This matches the HF implementation.
per_layer_inputs_mask
=
torch
.
logical_and
(
input_ids
>=
0
,
input_ids
<
self
.
config
.
vocab_size_per_layer_input
)
per_layer_inputs_tokens
=
torch
.
where
(
per_layer_inputs_mask
,
input_ids
,
torch
.
zeros_like
(
input_ids
))
return
self
.
embed_tokens_per_layer
(
per_layer_inputs_tokens
)
*
self
.
embed_scale_per_layer
def
get_per_layer_inputs
(
self
,
hidden_states_0
:
torch
.
Tensor
,
per_layer_inputs
:
Optional
[
torch
.
Tensor
],
)
->
torch
.
Tensor
:
per_layer_projection
=
self
.
per_layer_model_projection
(
hidden_states_0
)
per_layer_projection
=
per_layer_projection
.
reshape
(
*
hidden_states_0
.
shape
[:
-
1
],
self
.
config
.
num_hidden_layers
,
self
.
config
.
hidden_size_per_layer_input
,
)
per_layer_projection
=
self
.
per_layer_projection_norm
(
per_layer_projection
)
if
per_layer_inputs
is
not
None
:
# Profiling run does not compute per_layer_inputs
per_layer_inputs
=
per_layer_projection
+
per_layer_inputs
per_layer_inputs
*=
self
.
per_layer_input_scale
else
:
per_layer_inputs
=
per_layer_projection
return
per_layer_inputs
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
embed_tokens
(
input_ids
)
*
self
.
embed_scale
def
altup_embed
(
self
,
hidden_states_0
:
torch
.
Tensor
)
->
torch
.
Tensor
:
# Altup embed.
hidden_states
=
[
hidden_states_0
]
*
self
.
config
.
altup_num_inputs
target_magnitude
=
torch
.
mean
(
hidden_states_0
**
2
,
dim
=-
1
,
keepdim
=
True
)
**
0.5
for
i
in
range
(
1
,
self
.
config
.
altup_num_inputs
):
hidden_states
[
i
]
=
self
.
altup_projections
[
i
-
1
](
hidden_states
[
i
])
new_magnitude
=
torch
.
mean
(
hidden_states
[
i
]
**
2
,
dim
=-
1
,
keepdim
=
True
)
**
0.5
hidden_states
[
i
]
*=
target_magnitude
/
torch
.
maximum
(
new_magnitude
,
self
.
eps
)
hidden_states
=
torch
.
stack
(
hidden_states
,
dim
=-
1
)
return
hidden_states
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
per_layer_inputs
:
Optional
[
torch
.
Tensor
]
=
None
,
**
kwargs
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
if
inputs_embeds
is
not
None
:
hidden_states_0
=
inputs_embeds
else
:
hidden_states_0
=
self
.
get_input_embeddings
(
input_ids
)
adjusted_per_layer_inputs
=
self
.
get_per_layer_inputs
(
hidden_states_0
,
per_layer_inputs
)
hidden_states
=
self
.
altup_embed
(
hidden_states_0
)
# [altnum_inputs, num_tokens, hidden_size]
hidden_states
=
hidden_states
.
permute
(
2
,
0
,
1
)
for
idx
,
layer
in
enumerate
(
self
.
decoder_layers
):
layer_idx
=
idx
+
self
.
layer_idx_start
# [altup_num_inputs, num_tokens, hidden_size]
hidden_states
=
layer
(
positions
=
positions
,
hidden_states
=
hidden_states
,
per_layer_input
=
adjusted_per_layer_inputs
[:,
layer_idx
,
:],
**
kwargs
,
)
# [num_tokens, hidden_size, altnum_inputs]
hidden_states
=
hidden_states
.
permute
(
1
,
2
,
0
)
return
hidden_states
,
adjusted_per_layer_inputs
# This enables torch.compile if --kv-sharing-fast-prefill passed
@
support_torch_compile
(
enable_if
=
lambda
vllm_config
:
vllm_config
.
cache_config
.
kv_sharing_fast_prefill
)
class
Gemma3nCrossDecoder
(
nn
.
Module
):
"""
Cross-decoder layers
"""
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
,
decoder_layers
:
list
[
Gemma3nDecoderLayer
],
layer_idx_start
:
int
,
):
super
().
__init__
()
self
.
decoder_layers
=
decoder_layers
self
.
layer_idx_start
=
layer_idx_start
def
forward
(
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
per_layer_inputs
:
torch
.
Tensor
,
**
kwargs
,
)
->
torch
.
Tensor
:
# [altnum_inputs, num_tokens, hidden_size]
hidden_states
=
hidden_states
.
permute
(
2
,
0
,
1
)
for
idx
,
layer
in
enumerate
(
self
.
decoder_layers
):
layer_idx
=
idx
+
self
.
layer_idx_start
# [altup_num_inputs, num_tokens, hidden_size]
hidden_states
=
layer
(
positions
=
positions
,
hidden_states
=
hidden_states
,
per_layer_input
=
per_layer_inputs
[:,
layer_idx
,
:],
**
kwargs
,
)
# [num_tokens, hidden_size, altnum_inputs]
hidden_states
=
hidden_states
.
permute
(
1
,
2
,
0
)
return
hidden_states
# This disables torch.compile if --kv-sharing-fast-prefill passed
@
support_torch_compile
(
enable_if
=
lambda
vllm_config
:
not
vllm_config
.
cache_config
.
kv_sharing_fast_prefill
)
class
Gemma3nTextModel
(
nn
.
Module
,
SupportsQuant
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
...
...
@@ -543,7 +717,6 @@ class Gemma3nTextModel(nn.Module, SupportsQuant):
quant_config
=
vllm_config
.
quant_config
self
.
config
=
config
self
.
quant_config
=
quant_config
self
.
embed_tokens
=
VocabParallelEmbedding
(
config
.
vocab_size
,
config
.
hidden_size
,
...
...
@@ -613,95 +786,211 @@ class Gemma3nTextModel(nn.Module, SupportsQuant):
lambda
prefix
:
Gemma3nDecoderLayer
(
config
,
cache_config
,
quant_config
,
prefix
=
prefix
),
prefix
=
f
"
{
prefix
}
.layers"
)
self
.
eps
=
torch
.
tensor
(
torch
.
finfo
().
min
)
first_kv_shared_layer_idx
=
(
config
.
num_hidden_layers
-
config
.
num_kv_shared_layers
)
# Layer idx 0-19 are self-decoder layers in You Only Cache Once (YOCO)
with
set_model_tag
(
"self_decoder"
):
self
.
self_decoder
=
Gemma3nSelfDecoder
(
vllm_config
=
vllm_config
,
prefix
=
f
"
{
prefix
}
.self_decoder"
,
decoder_layers
=
self
.
layers
[:
first_kv_shared_layer_idx
],
layer_idx_start
=
0
,
per_layer_model_projection
=
self
.
per_layer_model_projection
,
embed_scale_per_layer
=
self
.
embed_scale_per_layer
,
embed_tokens_per_layer
=
self
.
embed_tokens_per_layer
,
per_layer_projection_norm
=
self
.
per_layer_projection_norm
,
per_layer_input_scale
=
self
.
per_layer_input_scale
,
altup_projections
=
self
.
altup_projections
,
eps
=
self
.
eps
,
embed_tokens
=
self
.
embed_tokens
,
embed_scale
=
self
.
embed_scale
,
)
# Layer idx 20-30 are cross-decoder layers in YOCO
with
set_model_tag
(
"cross_decoder"
):
self
.
cross_decoder
=
Gemma3nCrossDecoder
(
vllm_config
=
vllm_config
,
prefix
=
f
"
{
prefix
}
.cross_decoder"
,
decoder_layers
=
self
.
layers
[
first_kv_shared_layer_idx
:],
layer_idx_start
=
first_kv_shared_layer_idx
,
)
self
.
norm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
,
)
self
.
eps
=
torch
.
tensor
(
torch
.
finfo
().
min
)
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
embed_tokens
(
input_ids
)
*
self
.
embed_scale
self
.
fast_prefill_enabled
=
cache_config
.
kv_sharing_fast_prefill
if
self
.
fast_prefill_enabled
:
# Allocate static buffers for CUDAGraph
# TODO(sarckk): Extract this functionality to interface
max_num_tokens
=
vllm_config
.
scheduler_config
.
max_num_batched_tokens
device
=
next
(
self
.
parameters
()).
device
self
.
positions
=
torch
.
zeros
(
max_num_tokens
,
dtype
=
torch
.
int64
,
device
=
device
)
self
.
hidden_states
=
torch
.
zeros
(
(
max_num_tokens
,
config
.
hidden_size
,
self
.
config
.
altup_num_inputs
),
dtype
=
self
.
embed_tokens
.
weight
.
dtype
,
device
=
device
,
)
self
.
per_layer_inputs
=
torch
.
zeros
(
(
max_num_tokens
,
self
.
config
.
num_hidden_layers
,
self
.
config
.
hidden_size_per_layer_input
),
dtype
=
self
.
embed_tokens
.
weight
.
dtype
,
device
=
device
,
)
def
get_per_layer_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
# Deal with the fact that vocab_size_per_layer_input < vocab_size
# which causes us to have some out of vocab tokens by setting
# those token ids to 0. This matches the HF implementation.
per_layer_inputs_mask
=
torch
.
logical_and
(
input_ids
>=
0
,
input_ids
<
self
.
config
.
vocab_size_per_layer_input
)
per_layer_inputs_tokens
=
torch
.
where
(
per_layer_inputs_mask
,
input_ids
,
torch
.
zeros_like
(
input_ids
))
return
self
.
embed_tokens_per_layer
(
per_layer_inputs_tokens
)
*
self
.
embed_scale_per_layer
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
self_decoder
.
get_input_embeddings
(
input_ids
)
def
forward
(
def
fast_prefill_
forward
(
self
,
input_ids
:
Optional
[
torch
.
Tensor
]
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
per_layer_inputs
:
torch
.
Tensor
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
per_layer_inputs
:
Optional
[
torch
.
Tensor
]
=
None
,
**
kwargs
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
if
inputs_embeds
is
not
None
:
hidden_states_0
=
inputs_embeds
else
:
hidden_states_0
=
self
.
get_input_embeddings
(
input_ids
)
)
->
torch
.
Tensor
:
logits_indices_padded
,
num_logits_indices
=
None
,
None
attn_metadata
=
get_forward_context
().
attn_metadata
# attn_metadata is None during dummy runs
if
(
self
.
fast_prefill_enabled
and
attn_metadata
is
not
None
):
assert
isinstance
(
attn_metadata
,
dict
)
# Last layer is a KV sharing layer
layer_attn_metadata
=
attn_metadata
[
self
.
layers
[
-
1
].
self_attn
.
attn
.
layer_name
]
if
(
isinstance
(
layer_attn_metadata
,
KVSharingFastPrefillMetadata
)):
logits_indices_padded
=
(
layer_attn_metadata
.
logits_indices_padded
)
num_logits_indices
=
layer_attn_metadata
.
num_logits_indices
# Copy inputs for cudagraph
batch_size
=
positions
.
size
(
0
)
self
.
positions
[:
batch_size
].
copy_
(
positions
)
self_decoder_hidden_states
,
per_layer_inputs_adjusted
=
\
self
.
self_decoder
(
input_ids
=
input_ids
,
positions
=
self
.
positions
[:
batch_size
],
inputs_embeds
=
inputs_embeds
,
per_layer_inputs
=
per_layer_inputs
,
**
kwargs
,
)
per_layer_projection
=
self
.
per_layer_model_projection
(
hidden_states_0
)
per_layer_projection
=
per_layer_projection
.
reshape
(
*
hidden_states_0
.
shape
[:
-
1
],
self
.
config
.
num_hidden_layers
,
self
.
config
.
hidden_size_per_layer_input
,
if
logits_indices_padded
is
None
:
logits_indices_padded
=
torch
.
arange
(
positions
.
size
(
0
),
dtype
=
positions
.
dtype
,
device
=
positions
.
device
,
)
# NOTE(sarckk): There is currently a bug caused by
# vLLM converting output of last piecewise CUDA graph
# to weakref, causing memory to be prematurely freed
# when there are multiple compilation units
# Keep .clone() until fix in
# https://github.com/vllm-project/vllm/pull/22282
hidden_states
=
self_decoder_hidden_states
.
clone
()
# Copy inputs for cudagraph
num_padded_logits_indices
=
logits_indices_padded
.
size
(
0
)
self
.
positions
[:
num_padded_logits_indices
].
copy_
(
positions
[
logits_indices_padded
])
self
.
hidden_states
[:
num_padded_logits_indices
].
copy_
(
self_decoder_hidden_states
[
logits_indices_padded
])
self
.
per_layer_inputs
[:
num_padded_logits_indices
].
copy_
(
per_layer_inputs_adjusted
[
logits_indices_padded
])
cross_decoder_hidden_states
=
self
.
cross_decoder
(
positions
=
self
.
positions
[:
num_padded_logits_indices
],
hidden_states
=
self
.
hidden_states
[:
num_padded_logits_indices
],
per_layer_inputs
=
self
.
per_layer_inputs
[:
num_padded_logits_indices
],
**
kwargs
,
)
per_layer_projection
=
self
.
per_layer_projection_norm
(
per_layer_projection
)
if
per_layer_inputs
is
not
None
:
# Profiling run does not compute per_layer_inputs
per_layer_inputs
=
per_layer_projection
+
per_layer_inputs
per_layer_inputs
*=
self
.
per_layer_input_scale
if
num_logits_indices
is
not
None
:
assert
num_logits_indices
>
0
# Merge cross-decoder and self-decoder hidden states
hidden_states
[
logits_indices_padded
[:
num_logits_indices
]]
=
(
cross_decoder_hidden_states
[:
num_logits_indices
])
else
:
per_layer_inputs
=
per_layer_projection
hidden_states
=
cross_decoder_hidden_states
# Altup embed.
hidden_states
=
[
hidden_states_0
]
*
self
.
config
.
altup_num_inputs
target_magnitude
=
torch
.
mean
(
hidden_states_0
**
2
,
dim
=-
1
,
keepdim
=
True
)
**
0.5
for
i
in
range
(
1
,
self
.
config
.
altup_num_inputs
):
hidden_states
[
i
]
=
self
.
altup_projections
[
i
-
1
](
hidden_states
[
i
])
new_magnitude
=
torch
.
mean
(
hidden_states
[
i
]
**
2
,
dim
=-
1
,
keepdim
=
True
)
**
0.5
hidden_states
[
i
]
*=
target_magnitude
/
torch
.
maximum
(
new_magnitude
,
self
.
eps
)
hidden_states
=
torch
.
stack
(
hidden_states
,
dim
=
0
)
return
hidden_states
# Transformer blocks.
for
layer_idx
,
layer
in
enumerate
(
self
.
layers
):
# [altup_num_inputs, num_tokens, hidden_size]
hidden_states
=
layer
(
positions
=
positions
,
hidden_states
=
hidden_states
,
per_layer_input
=
per_layer_inputs
[:,
layer_idx
,
:],
**
kwargs
,
)
def
normal_forward
(
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
per_layer_inputs
:
Optional
[
torch
.
Tensor
]
=
None
,
**
kwargs
,
)
->
torch
.
Tensor
:
hidden_states
,
per_layer_inputs
=
self
.
self_decoder
(
input_ids
=
input_ids
,
positions
=
positions
,
inputs_embeds
=
inputs_embeds
,
per_layer_inputs
=
per_layer_inputs
,
**
kwargs
,
)
hidden_states
=
self
.
cross_decoder
(
positions
=
positions
,
hidden_states
=
hidden_states
,
per_layer_inputs
=
per_layer_inputs
,
**
kwargs
,
)
return
hidden_states
def
altup_unembed
(
self
,
hidden_states
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
# Altup unembed.
target_magnitude
=
torch
.
mean
(
hidden_states
[
0
]
**
2
,
target_magnitude
=
torch
.
mean
(
hidden_states
[
...,
0
]
**
2
,
dim
=-
1
,
keepdim
=
True
)
**
0.5
for
i
in
range
(
1
,
self
.
config
.
altup_num_inputs
):
hidden_states
[
i
]
=
self
.
altup_unembed_projections
[
i
-
1
](
hidden_states
[
i
])
new_magnitude
=
torch
.
mean
(
hidden_states
[
i
]
**
2
,
hidden_states
[
...,
i
]
=
self
.
altup_unembed_projections
[
i
-
1
](
hidden_states
[
...,
i
])
new_magnitude
=
torch
.
mean
(
hidden_states
[
...,
i
]
**
2
,
dim
=-
1
,
keepdim
=
True
)
**
0.5
hidden_states
[
i
]
*=
target_magnitude
/
torch
.
maximum
(
hidden_states
[
...,
i
]
*=
target_magnitude
/
torch
.
maximum
(
new_magnitude
,
self
.
eps
)
# [altup_num_inputs,num_tokens,hidden_size] -> [num_tokens,hidden_size]
hidden_states
=
torch
.
mean
(
hidden_states
,
dim
=
0
)
# [num_tokens,hidden_size, altup_num_inputs] -> [num_tokens,hidden_size]
hidden_states
=
torch
.
mean
(
hidden_states
,
dim
=-
1
)
return
hidden_states
def
forward
(
self
,
input_ids
:
Optional
[
torch
.
Tensor
],
positions
:
torch
.
Tensor
,
per_layer_inputs
:
Optional
[
torch
.
Tensor
]
=
None
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
**
kwargs
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
if
self
.
fast_prefill_enabled
:
hidden_states
=
self
.
fast_prefill_forward
(
input_ids
,
positions
,
inputs_embeds
,
per_layer_inputs
,
**
kwargs
,
)
else
:
hidden_states
=
self
.
normal_forward
(
input_ids
,
positions
,
inputs_embeds
,
per_layer_inputs
,
**
kwargs
,
)
hidden_states
=
self
.
altup_unembed
(
hidden_states
)
return
self
.
norm
(
hidden_states
)
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
...
...
vllm/model_executor/models/gemma3n_mm.py
View file @
cb293f6a
...
...
@@ -620,7 +620,7 @@ class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal):
# NOTE (NickLucche) Each pass needs tokens to compute PLE so we cache
# them here, as the model forward has only access to the input_embeds.
if
input_ids
is
not
None
:
per_layer_inputs
=
self
.
language_model
.
model
.
get_per_layer_input_embeddings
(
per_layer_inputs
=
self
.
language_model
.
model
.
self_decoder
.
get_per_layer_input_embeddings
(
input_ids
)
per_layer_inputs
=
per_layer_inputs
.
reshape
(
-
1
,
self
.
config
.
text_config
.
num_hidden_layers
,
...
...
vllm/v1/attention/backends/utils.py
View file @
cb293f6a
...
...
@@ -4,11 +4,13 @@ import abc
import
enum
import
functools
from
abc
import
abstractmethod
from
dataclasses
import
dataclass
,
make_dataclass
from
typing
import
TYPE_CHECKING
,
Any
,
ClassVar
,
Generic
,
Optional
,
TypeVar
from
dataclasses
import
dataclass
,
fields
,
make_dataclass
from
typing
import
(
TYPE_CHECKING
,
Any
,
ClassVar
,
Generic
,
Optional
,
Protocol
,
TypeVar
)
import
numpy
as
np
import
torch
from
typing_extensions
import
runtime_checkable
from
vllm.config
import
VllmConfig
,
get_layers_from_vllm_config
from
vllm.utils
import
cdiv
...
...
@@ -19,7 +21,8 @@ if TYPE_CHECKING:
from
vllm.v1.worker.gpu_input_batch
import
InputBatch
import
vllm.envs
as
envs
from
vllm.attention.backends.abstract
import
AttentionBackend
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionMetadata
)
from
vllm.attention.layer
import
Attention
from
vllm.distributed.kv_transfer.kv_connector.utils
import
(
get_kv_connector_cache_layout
)
...
...
@@ -65,6 +68,10 @@ class CommonAttentionMetadata:
causal
:
bool
=
True
# Needed by FastPrefillAttentionBuilder
logits_indices_padded
:
Optional
[
torch
.
Tensor
]
=
None
num_logits_indices
:
Optional
[
int
]
=
None
@
dataclass
class
UbatchSlice
:
...
...
@@ -542,6 +549,69 @@ def make_local_attention_virtual_batches(
)
def
make_kv_sharing_fast_prefill_common_attn_metadata
(
common_attn_metadata
:
CommonAttentionMetadata
,
)
->
CommonAttentionMetadata
:
if
common_attn_metadata
.
max_query_len
==
1
:
# All requests are decode (assume 1 token for now)
# Skip computing fast prefill path
return
common_attn_metadata
assert
common_attn_metadata
.
logits_indices_padded
is
not
None
assert
common_attn_metadata
.
num_logits_indices
is
not
None
logits_indices_padded
=
common_attn_metadata
.
logits_indices_padded
num_logits_indices
=
common_attn_metadata
.
num_logits_indices
# Get rid of CUDAGraph padding, if any
logits_indices
=
logits_indices_padded
[:
num_logits_indices
]
num_reqs
=
common_attn_metadata
.
num_reqs
query_start_loc
=
common_attn_metadata
.
query_start_loc
seq_lens
=
common_attn_metadata
.
seq_lens
# Example inputs
# num_reqs: 3
# generation_indices: [14, 18, 19, 27]
# query_start_loc: [0, 15, 20, 28]
# seq_lens: [41, 31, 40]
# Find how many decode indices belong to each request
# request_ids: [0, 1, 1, 2]
request_ids
=
torch
.
bucketize
(
logits_indices
,
query_start_loc
[
1
:],
right
=
True
)
# Figure out how many tokens are in each request
# num_decode_tokens: [1, 2, 1]
num_decode_tokens
=
torch
.
bincount
(
request_ids
,
minlength
=
num_reqs
)
# Calculate new query_start_loc with tokens in generation_indices
# decode_query_start_loc: [0, 1, 3, 4]
decode_query_start_loc
=
torch
.
empty
(
num_reqs
+
1
,
device
=
query_start_loc
.
device
,
dtype
=
query_start_loc
.
dtype
)
decode_query_start_loc
[
0
]
=
0
decode_query_start_loc
[
1
:]
=
torch
.
cumsum
(
num_decode_tokens
,
dim
=
0
)
decode_max_query_len
=
int
(
num_decode_tokens
.
max
().
item
())
total_num_decode_tokens
=
int
(
num_decode_tokens
.
sum
().
item
())
common_attn_metadata
=
CommonAttentionMetadata
(
query_start_loc
=
decode_query_start_loc
,
query_start_loc_cpu
=
decode_query_start_loc
.
to
(
"cpu"
,
non_blocking
=
True
),
seq_lens
=
seq_lens
,
seq_lens_cpu
=
seq_lens
.
to
(
"cpu"
,
non_blocking
=
True
),
num_computed_tokens_cpu
=
common_attn_metadata
.
num_computed_tokens_cpu
,
num_reqs
=
num_reqs
,
num_actual_tokens
=
total_num_decode_tokens
,
max_query_len
=
decode_max_query_len
,
max_seq_len
=
common_attn_metadata
.
max_seq_len
,
block_table_tensor
=
common_attn_metadata
.
block_table_tensor
,
slot_mapping
=
common_attn_metadata
.
slot_mapping
,
causal
=
True
,
)
return
common_attn_metadata
def
subclass_attention_backend
(
name_prefix
:
str
,
attention_backend_cls
:
type
[
AttentionBackend
],
builder_cls
:
type
[
AttentionMetadataBuilder
[
M
]]
...
...
@@ -679,13 +749,56 @@ def subclass_attention_metadata(
return
Wrapped
def
make_kv_sharing_fast_prefill_attention_metadata
(
metadata_cls
:
Any
,
)
->
Any
:
"""
Return a new subclass of `metadata_cls` for fast prefill
"""
return
subclass_attention_metadata
(
name_prefix
=
"KVSharingFastPrefill"
,
metadata_cls
=
metadata_cls
,
fields
=
KV_SHARING_FAST_PREFILL_METADATA_FIELDS
,
)
@
runtime_checkable
class
KVSharingFastPrefillMetadata
(
Protocol
):
logits_indices_padded
:
torch
.
Tensor
num_logits_indices
:
int
def
create_fast_prefill_custom_backend
(
prefix
:
str
,
underlying_attn_backend
:
AttentionBackend
,
)
->
type
[
AttentionBackend
]:
underlying_builder
=
underlying_attn_backend
.
get_builder_cls
()
class
FastPrefillAttentionBuilder
(
underlying_builder
):
# type: ignore
def
build
(
self
,
common_prefix_len
:
int
,
common_attn_metadata
:
CommonAttentionMetadata
,
fast_build
:
bool
=
False
)
->
AttentionMetadata
:
new_common_attn_metadata
=
\
make_kv_sharing_fast_prefill_common_attn_metadata
(
common_attn_metadata
)
metadata
=
super
().
build
(
common_prefix_len
,
new_common_attn_metadata
,
fast_build
)
class
KVSharingFastPrefillAttentionMetadata
(
metadata
.
__class__
,
# type: ignore
KVSharingFastPrefillMetadata
):
def
__init__
(
self
,
metadata
,
common_attn_metadata
):
# Shallow copy all fields in metadata cls
for
field
in
fields
(
metadata
.
__class__
):
setattr
(
self
,
field
.
name
,
getattr
(
metadata
,
field
.
name
))
# Set additional fields that will be used in model code
assert
(
common_attn_metadata
.
logits_indices_padded
is
not
None
and
common_attn_metadata
.
num_logits_indices
is
not
None
)
self
.
logits_indices_padded
=
\
common_attn_metadata
.
logits_indices_padded
self
.
num_logits_indices
=
\
common_attn_metadata
.
num_logits_indices
return
KVSharingFastPrefillAttentionMetadata
(
metadata
,
common_attn_metadata
)
attn_backend
=
subclass_attention_backend
(
name_prefix
=
prefix
,
attention_backend_cls
=
underlying_attn_backend
,
builder_cls
=
FastPrefillAttentionBuilder
)
return
attn_backend
vllm/v1/engine/async_llm.py
View file @
cb293f6a
...
...
@@ -335,6 +335,13 @@ class AsyncLLM(EngineClient):
returning the RequestOutput back to the caller.
"""
if
(
self
.
vllm_config
.
cache_config
.
kv_sharing_fast_prefill
and
sampling_params
.
prompt_logprobs
):
raise
ValueError
(
"--kv-sharing-fast-prefill produces incorrect logprobs for "
"prompt tokens, please disable it when the requests need "
"prompt logprobs"
)
try
:
# We start the output_handler on the first call to generate() so
# we can call __init__ before the event loop, which enables us
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
cb293f6a
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
dataclasses
import
gc
import
itertools
import
time
...
...
@@ -58,7 +57,7 @@ from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
supports_dynamo
)
from
vllm.v1.attention.backends.utils
import
(
AttentionCGSupport
,
AttentionMetadataBuilder
,
CommonAttentionMetadata
,
make_kv_sharing_fast_prefill_attention_metadata
,
create_fast_prefill_custom_backend
,
reorder_batch_to_split_decodes_and_prefills
)
from
vllm.v1.cudagraph_dispatcher
import
CudagraphDispatcher
from
vllm.v1.kv_cache_interface
import
(
AttentionSpec
,
...
...
@@ -84,9 +83,10 @@ from vllm.v1.worker.kv_connector_model_runner_mixin import (
KVConnectorModelRunnerMixin
,
KVConnectorOutput
)
from
vllm.v1.worker.lora_model_runner_mixin
import
LoRAModelRunnerMixin
from
.utils
import
(
AttentionGroup
,
MultiModalBudget
,
bind_kv_cache
,
gather_mm_placeholders
,
initialize_kv_cache_for_kv_sharing
,
sanity_check_mm_encoder_outputs
,
scatter_mm_placeholders
)
from
.utils
import
(
AttentionGroup
,
MultiModalBudget
,
add_kv_sharing_layers_to_kv_cache_groups
,
bind_kv_cache
,
gather_mm_placeholders
,
sanity_check_mm_encoder_outputs
,
scatter_mm_placeholders
)
if
TYPE_CHECKING
:
import
xgrammar
as
xgr
...
...
@@ -860,6 +860,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
max_seq_len
=
max_seq_len
,
block_table_tensor
=
blk_table_tensor
,
slot_mapping
=
slot_mapping
,
logits_indices_padded
=
logits_indices_padded
,
num_logits_indices
=
logits_indices
.
size
(
0
),
causal
=
True
,
)
...
...
@@ -884,28 +886,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
common_attn_metadata
=
common_attn_metadata
,
))
fast_prefill_metadata
=
attn_metadata_i
if
(
self
.
cache_config
.
kv_sharing_fast_prefill
and
self
.
kv_sharing_fast_prefill_eligible_layers
):
# Dynamically create a a dataclass type that inherits
# from attention metadata type but includes additional
# fields logits_indices_padded and num_logits_indices
# which are required for prefill truncation
fast_prefill_metadata_type
=
(
make_kv_sharing_fast_prefill_attention_metadata
(
metadata_cls
=
type
(
attn_metadata_i
),
))
fast_prefill_metadata
=
fast_prefill_metadata_type
(
**
dataclasses
.
asdict
(
attn_metadata_i
),
logits_indices_padded
=
logits_indices_padded
,
num_logits_indices
=
logits_indices
.
size
(
0
),
)
for
layer_name
in
attn_group
.
layer_names
:
if
(
self
.
cache_config
.
kv_sharing_fast_prefill
and
layer_name
in
self
.
kv_sharing_fast_prefill_eligible_layers
):
attn_metadata
[
layer_name
]
=
fast_prefill_metadata
continue
attn_metadata
[
layer_name
]
=
attn_metadata_i
# Hot-Swap lora model
...
...
@@ -1484,6 +1465,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
return
self
.
kv_connector_no_forward
(
scheduler_output
,
self
.
vllm_config
)
if
self
.
cache_config
.
kv_sharing_fast_prefill
:
assert
not
self
.
input_batch
.
num_prompt_logprobs
,
(
"--kv-sharing-fast-prefill produces incorrect logprobs for "
"prompt tokens, tokens, please disable it when the requests "
"need prompt logprobs"
)
# Prepare the decoder inputs.
(
attn_metadata
,
logits_indices
,
spec_decode_metadata
,
num_scheduled_tokens_np
,
spec_decode_common_attn_metadata
,
...
...
@@ -2742,6 +2729,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# layer.
for
layer_name
in
layer_names
:
attn_backend
=
layers
[
layer_name
].
get_attn_backend
()
if
layer_name
in
self
.
kv_sharing_fast_prefill_eligible_layers
:
attn_backend
=
create_fast_prefill_custom_backend
(
"FastPrefill"
,
attn_backend
,
)
key
=
attn_backend
.
full_cls_name
()
attn_backends
[
key
]
=
attn_backend
attn_backend_layers
[
key
].
append
(
layer_name
)
...
...
@@ -3074,20 +3068,40 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
kv_caches
=
self
.
_reshape_kv_cache_tensors
(
kv_cache_config
,
kv_cache_raw_tensors
)
# Setup `kv_cache_config` and `kv_caches` for models
# with cross-layer KV sharing
if
self
.
shared_kv_cache_layers
:
initialize_kv_cache_for_kv_sharing
(
self
.
shared_kv_cache_layers
,
kv_cache_config
.
kv_cache_groups
,
kv_caches
,
self
.
attn_groups
,
self
.
runner_only_attn_layers
,
)
# Set up cross-layer KV cache sharing
for
layer_name
,
target_layer_name
in
self
.
shared_kv_cache_layers
.
items
(
):
logger
.
debug
(
"%s reuses KV cache of %s"
,
layer_name
,
target_layer_name
)
kv_caches
[
layer_name
]
=
kv_caches
[
target_layer_name
]
bind_kv_cache
(
kv_caches
,
self
.
compilation_config
.
static_forward_context
,
self
.
kv_caches
)
return
kv_caches
def
maybe_add_kv_sharing_layers_to_kv_cache_groups
(
self
,
kv_cache_config
:
KVCacheConfig
)
->
None
:
"""
Add layers that re-use KV cache to KV cache group of its target layer.
Mapping of KV cache tensors happens in `initialize_kv_cache_tensors()`
"""
if
not
self
.
shared_kv_cache_layers
:
# No cross-layer KV sharing, return
return
add_kv_sharing_layers_to_kv_cache_groups
(
self
.
shared_kv_cache_layers
,
kv_cache_config
.
kv_cache_groups
,
self
.
runner_only_attn_layers
,
)
if
self
.
cache_config
.
kv_sharing_fast_prefill
:
# In You Only Cache Once (https://arxiv.org/abs/2405.05254) or other
# similar KV sharing setups, only the layers that generate KV caches
# are involved in the prefill phase, enabling prefill to early exit.
attn_layers
=
get_layers_from_vllm_config
(
self
.
vllm_config
,
Attention
)
# Iterate in reversed order and add layers that re-use KV cache
# e.g. in YOCO-like KV sharing setups (e.g. Gemma3n)
for
layer_name
in
reversed
(
attn_layers
):
if
layer_name
in
self
.
shared_kv_cache_layers
:
self
.
kv_sharing_fast_prefill_eligible_layers
.
add
(
...
...
@@ -3095,11 +3109,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
else
:
break
bind_kv_cache
(
kv_caches
,
self
.
compilation_config
.
static_forward_context
,
self
.
kv_caches
)
return
kv_caches
def
initialize_kv_cache
(
self
,
kv_cache_config
:
KVCacheConfig
)
->
None
:
"""
Initialize KV cache based on `kv_cache_config`.
...
...
@@ -3111,6 +3120,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self
.
kv_cache_config
=
kv_cache_config
self
.
may_reinitialize_input_batch
(
kv_cache_config
)
self
.
may_add_encoder_only_layers_to_kv_cache_config
()
self
.
maybe_add_kv_sharing_layers_to_kv_cache_groups
(
kv_cache_config
)
self
.
initialize_attn_backend
(
kv_cache_config
)
kv_caches
=
self
.
initialize_kv_cache_tensors
(
kv_cache_config
)
...
...
vllm/v1/worker/tpu_model_runner.py
View file @
cb293f6a
...
...
@@ -55,9 +55,8 @@ from vllm.v1.worker.kv_connector_model_runner_mixin import (
from
vllm.v1.worker.lora_model_runner_mixin
import
LoRAModelRunnerMixin
from
vllm.v1.worker.tpu_input_batch
import
CachedRequestState
,
InputBatch
from
.utils
import
(
MultiModalBudget
,
bind_kv_cache
,
initialize_kv_cache_for_kv_sharing
,
sanity_check_mm_encoder_outputs
)
from
.utils
import
(
MultiModalBudget
,
add_kv_sharing_layers_to_kv_cache_groups
,
bind_kv_cache
,
sanity_check_mm_encoder_outputs
)
if
TYPE_CHECKING
:
from
vllm.v1.core.sched.output
import
SchedulerOutput
...
...
@@ -1599,6 +1598,30 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self
.
encoder_cache
.
clear
()
gc
.
collect
()
def
maybe_setup_cross_layer_kv_sharing
(
self
,
kv_caches
:
dict
[
str
,
torch
.
Tensor
],
kv_cache_config
:
KVCacheConfig
,
)
->
None
:
"""
Add layers that re-use KV cache to KV cache group of its target layer.
Mapping of KV cache tensors happens in `initialize_kv_cache_tensors()`
"""
if
not
self
.
shared_kv_cache_layers
:
# No cross-layer KV sharing, return
return
add_kv_sharing_layers_to_kv_cache_groups
(
self
.
shared_kv_cache_layers
,
kv_cache_config
.
kv_cache_groups
,
)
for
layer_name
,
target_layer_name
in
self
.
shared_kv_cache_layers
.
items
(
):
logger
.
debug
(
"%s reuses KV cache of %s"
,
layer_name
,
target_layer_name
)
kv_caches
[
layer_name
]
=
kv_caches
[
target_layer_name
]
def
initialize_kv_cache
(
self
,
kv_cache_config
:
KVCacheConfig
)
->
None
:
"""
Initialize KV cache based on `kv_cache_config`.
...
...
@@ -1664,14 +1687,8 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
else
:
raise
NotImplementedError
# Setup `kv_cache_config` and `kv_caches` for models
# with cross-layer KV sharing
if
self
.
shared_kv_cache_layers
:
initialize_kv_cache_for_kv_sharing
(
self
.
shared_kv_cache_layers
,
kv_cache_config
.
kv_cache_groups
,
kv_caches
,
)
# Set up cross-layer KV cache sharing if needed
self
.
maybe_setup_cross_layer_kv_sharing
(
kv_caches
,
kv_cache_config
)
bind_kv_cache
(
kv_caches
,
...
...
vllm/v1/worker/utils.py
View file @
cb293f6a
...
...
@@ -203,12 +203,9 @@ def gather_mm_placeholders(
return
placeholders
[
is_embed
]
def
initialize_kv_cache_for_kv_sharing
(
def
add_kv_sharing_layers_to_kv_cache_groups
(
shared_kv_cache_layers
:
dict
[
str
,
str
],
kv_cache_groups
:
list
[
KVCacheGroupSpec
],
kv_caches
:
dict
[
str
,
torch
.
Tensor
],
# Optional for now to avoid breaking TPU
attn_groups
:
Optional
[
list
[
list
[
AttentionGroup
]]]
=
None
,
runner_only_attn_layers
:
Optional
[
set
[
str
]]
=
None
,
)
->
None
:
"""
...
...
@@ -223,38 +220,15 @@ def initialize_kv_cache_for_kv_sharing(
means this layer will perform attention using the keys and values
from the KV cache of `shared_kv_cache_layers[layer_name]`.
kv_cache_groups: The KV cache groups of the model.
kv_caches: The allocated kv_caches with layer names as keys.
Note that layers in shared_kv_cache_layers.keys() are not
originally included as it only contains layers which have its own
KV cache allocation.
attn_groups: Optional list of attention groups. Layers in the same KV
cache group may be placed in different attention groups if they
have different attention backends. Currently only provided by
GPU model runner.
"""
# mapping from layer name to tuple of (kv_cache_group_idx, attn_group_idx)
layer_to_attn_group_idx
:
dict
[
str
,
tuple
[
int
,
int
]]
=
{}
if
attn_groups
:
for
kv_cache_group_idx
,
kv_attn_groups
in
enumerate
(
attn_groups
):
for
attn_group_idx
,
attn_group
in
enumerate
(
kv_attn_groups
):
for
layer_name
in
attn_group
.
layer_names
:
layer_to_attn_group_idx
[
layer_name
]
=
(
kv_cache_group_idx
,
attn_group_idx
)
else
:
for
kv_cache_group_idx
,
kv_cache_group
in
enumerate
(
kv_cache_groups
):
for
layer_name
in
kv_cache_group
.
layer_names
:
# attn group idx default to 0 if not provided
layer_to_attn_group_idx
[
layer_name
]
=
(
kv_cache_group_idx
,
0
)
layer_to_kv_cache_group
:
dict
[
str
,
KVCacheGroupSpec
]
=
{}
for
kv_cache_group
in
kv_cache_groups
:
for
layer_name
in
kv_cache_group
.
layer_names
:
layer_to_kv_cache_group
[
layer_name
]
=
kv_cache_group
for
layer_name
,
target_layer_name
in
shared_kv_cache_layers
.
items
():
kv_caches
[
layer_name
]
=
kv_caches
[
target_layer_name
]
kv_cache_group_idx
=
layer_to_attn_group_idx
[
target_layer_name
][
0
]
kv_cache_groups
[
kv_cache_group_idx
].
layer_names
.
append
(
layer_name
)
if
attn_groups
:
attn_group_idx
=
layer_to_attn_group_idx
[
target_layer_name
][
1
]
attn_groups
[
kv_cache_group_idx
][
attn_group_idx
].
layer_names
.
append
(
layer_name
)
tgt_kv_cache_group
=
layer_to_kv_cache_group
[
target_layer_name
]
tgt_kv_cache_group
.
layer_names
.
append
(
layer_name
)
if
runner_only_attn_layers
is
not
None
:
runner_only_attn_layers
.
add
(
layer_name
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment