Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
bdf13965
Unverified
Commit
bdf13965
authored
Jun 03, 2025
by
Yong Hoon Shin
Committed by
GitHub
Jun 03, 2025
Browse files
[V1] Support cross-layer KV sharing (#18212)
Signed-off-by:
Yong Hoon Shin
<
yhshin@meta.com
>
parent
fa98d777
Changes
31
Show whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
191 additions
and
45 deletions
+191
-45
vllm/v1/attention/backends/flashinfer.py
vllm/v1/attention/backends/flashinfer.py
+21
-15
vllm/v1/attention/backends/mla/common.py
vllm/v1/attention/backends/mla/common.py
+4
-0
vllm/v1/attention/backends/mla/flashmla.py
vllm/v1/attention/backends/mla/flashmla.py
+2
-1
vllm/v1/attention/backends/mla/rocm_aiter_mla.py
vllm/v1/attention/backends/mla/rocm_aiter_mla.py
+2
-1
vllm/v1/attention/backends/mla/triton_mla.py
vllm/v1/attention/backends/mla/triton_mla.py
+2
-1
vllm/v1/attention/backends/pallas.py
vllm/v1/attention/backends/pallas.py
+5
-1
vllm/v1/attention/backends/triton_attn.py
vllm/v1/attention/backends/triton_attn.py
+28
-23
vllm/v1/attention/backends/utils.py
vllm/v1/attention/backends/utils.py
+33
-0
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+29
-2
vllm/v1/worker/tpu_model_runner.py
vllm/v1/worker/tpu_model_runner.py
+29
-1
vllm/v1/worker/utils.py
vllm/v1/worker/utils.py
+36
-0
No files found.
vllm/v1/attention/backends/flashinfer.py
View file @
bdf13965
...
@@ -507,6 +507,7 @@ class FlashInferImpl(AttentionImpl):
...
@@ -507,6 +507,7 @@ class FlashInferImpl(AttentionImpl):
blocksparse_params
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
blocksparse_params
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
logits_soft_cap
:
Optional
[
float
]
=
None
,
logits_soft_cap
:
Optional
[
float
]
=
None
,
attn_type
:
AttentionType
=
AttentionType
.
DECODER
,
attn_type
:
AttentionType
=
AttentionType
.
DECODER
,
kv_sharing_target_layer_name
:
Optional
[
int
]
=
None
,
)
->
None
:
)
->
None
:
self
.
num_heads
=
num_heads
self
.
num_heads
=
num_heads
self
.
head_size
=
head_size
self
.
head_size
=
head_size
...
@@ -521,6 +522,7 @@ class FlashInferImpl(AttentionImpl):
...
@@ -521,6 +522,7 @@ class FlashInferImpl(AttentionImpl):
self
.
sliding_window
=
(
sliding_window
-
1
,
0
)
self
.
sliding_window
=
(
sliding_window
-
1
,
0
)
self
.
kv_cache_dtype
=
kv_cache_dtype
self
.
kv_cache_dtype
=
kv_cache_dtype
self
.
logits_soft_cap
=
logits_soft_cap
self
.
logits_soft_cap
=
logits_soft_cap
self
.
kv_sharing_target_layer_name
=
kv_sharing_target_layer_name
assert
self
.
num_heads
%
self
.
num_kv_heads
==
0
assert
self
.
num_heads
%
self
.
num_kv_heads
==
0
self
.
num_queries_per_kv
=
self
.
num_heads
//
self
.
num_kv_heads
self
.
num_queries_per_kv
=
self
.
num_heads
//
self
.
num_kv_heads
...
@@ -568,11 +570,15 @@ class FlashInferImpl(AttentionImpl):
...
@@ -568,11 +570,15 @@ class FlashInferImpl(AttentionImpl):
# performance to make sure it does not introduce any overhead.
# performance to make sure it does not introduce any overhead.
num_actual_tokens
=
attn_metadata
.
num_actual_tokens
num_actual_tokens
=
attn_metadata
.
num_actual_tokens
if
self
.
kv_sharing_target_layer_name
is
None
:
# Reshape the input keys and values and store them in the cache.
# Reshape the input keys and values and store them in the cache.
# Skip this if sharing KV cache with an earlier attention layer.
# NOTE(woosuk): Here, key and value are padded while slot_mapping is
# NOTE(woosuk): Here, key and value are padded while slot_mapping is
# not padded. However, we don't need to do key[:num_actual_tokens] and
# not padded. However, we don't need to do key[:num_actual_tokens]
# value[:num_actual_tokens] because the reshape_and_cache_flash op uses
# and value[:num_actual_tokens] because the reshape_and_cache_flash
# the slot_mapping's shape to determine the number of actual tokens.
# op uses the slot_mapping's shape to determine the number of
# actual tokens.
torch
.
ops
.
_C_cache_ops
.
reshape_and_cache_flash
(
torch
.
ops
.
_C_cache_ops
.
reshape_and_cache_flash
(
key
,
key
,
value
,
value
,
...
...
vllm/v1/attention/backends/mla/common.py
View file @
bdf13965
...
@@ -586,6 +586,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
...
@@ -586,6 +586,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
blocksparse_params
:
Optional
[
dict
[
str
,
Any
]],
blocksparse_params
:
Optional
[
dict
[
str
,
Any
]],
logits_soft_cap
:
Optional
[
float
],
logits_soft_cap
:
Optional
[
float
],
attn_type
:
str
,
attn_type
:
str
,
kv_sharing_target_layer_name
:
Optional
[
str
],
# MLA Specific Arguments
# MLA Specific Arguments
q_lora_rank
:
Optional
[
int
],
q_lora_rank
:
Optional
[
int
],
kv_lora_rank
:
int
,
kv_lora_rank
:
int
,
...
@@ -595,6 +596,9 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
...
@@ -595,6 +596,9 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
v_head_dim
:
int
,
v_head_dim
:
int
,
kv_b_proj
:
ColumnParallelLinear
,
kv_b_proj
:
ColumnParallelLinear
,
)
->
None
:
)
->
None
:
if
kv_sharing_target_layer_name
is
not
None
:
raise
NotImplementedError
(
"KV sharing is not supported for MLA"
)
self
.
num_heads
=
num_heads
self
.
num_heads
=
num_heads
self
.
head_size
=
head_size
self
.
head_size
=
head_size
self
.
scale
=
float
(
scale
)
self
.
scale
=
float
(
scale
)
...
...
vllm/v1/attention/backends/mla/flashmla.py
View file @
bdf13965
...
@@ -93,12 +93,13 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
...
@@ -93,12 +93,13 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
blocksparse_params
:
Optional
[
dict
[
str
,
Any
]],
blocksparse_params
:
Optional
[
dict
[
str
,
Any
]],
logits_soft_cap
:
Optional
[
float
],
logits_soft_cap
:
Optional
[
float
],
attn_type
:
str
,
attn_type
:
str
,
kv_sharing_target_layer_name
:
Optional
[
str
],
# MLA Specific Arguments
# MLA Specific Arguments
**
mla_args
)
->
None
:
**
mla_args
)
->
None
:
super
().
__init__
(
num_heads
,
head_size
,
scale
,
num_kv_heads
,
super
().
__init__
(
num_heads
,
head_size
,
scale
,
num_kv_heads
,
alibi_slopes
,
sliding_window
,
kv_cache_dtype
,
alibi_slopes
,
sliding_window
,
kv_cache_dtype
,
blocksparse_params
,
logits_soft_cap
,
attn_type
,
blocksparse_params
,
logits_soft_cap
,
attn_type
,
**
mla_args
)
kv_sharing_target_layer_name
,
**
mla_args
)
assert
is_flashmla_supported
(),
\
assert
is_flashmla_supported
(),
\
"FlashMLA is not supported on this device"
"FlashMLA is not supported on this device"
...
...
vllm/v1/attention/backends/mla/rocm_aiter_mla.py
View file @
bdf13965
...
@@ -139,12 +139,13 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
...
@@ -139,12 +139,13 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
blocksparse_params
:
Optional
[
dict
[
str
,
Any
]],
blocksparse_params
:
Optional
[
dict
[
str
,
Any
]],
logits_soft_cap
:
Optional
[
float
],
logits_soft_cap
:
Optional
[
float
],
attn_type
:
str
,
attn_type
:
str
,
kv_sharing_target_layer_name
:
Optional
[
str
],
# MLA Specific Arguments
# MLA Specific Arguments
**
mla_args
)
->
None
:
**
mla_args
)
->
None
:
super
().
__init__
(
num_heads
,
head_size
,
scale
,
num_kv_heads
,
super
().
__init__
(
num_heads
,
head_size
,
scale
,
num_kv_heads
,
alibi_slopes
,
sliding_window
,
kv_cache_dtype
,
alibi_slopes
,
sliding_window
,
kv_cache_dtype
,
blocksparse_params
,
logits_soft_cap
,
attn_type
,
blocksparse_params
,
logits_soft_cap
,
attn_type
,
**
mla_args
)
kv_sharing_target_layer_name
,
**
mla_args
)
assert
(
num_heads
==
16
or
num_heads
==
128
),
(
assert
(
num_heads
==
16
or
num_heads
==
128
),
(
f
"Aiter MLA only supports 16 or 128 number of heads.
\n
"
f
"Aiter MLA only supports 16 or 128 number of heads.
\n
"
f
"Provided
{
num_heads
}
number of heads.
\n
"
f
"Provided
{
num_heads
}
number of heads.
\n
"
...
...
vllm/v1/attention/backends/mla/triton_mla.py
View file @
bdf13965
...
@@ -41,12 +41,13 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
...
@@ -41,12 +41,13 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
blocksparse_params
:
Optional
[
dict
[
str
,
Any
]],
blocksparse_params
:
Optional
[
dict
[
str
,
Any
]],
logits_soft_cap
:
Optional
[
float
],
logits_soft_cap
:
Optional
[
float
],
attn_type
:
str
,
attn_type
:
str
,
kv_sharing_target_layer_name
:
Optional
[
str
],
# MLA Specific Arguments
# MLA Specific Arguments
**
mla_args
)
->
None
:
**
mla_args
)
->
None
:
super
().
__init__
(
num_heads
,
head_size
,
scale
,
num_kv_heads
,
super
().
__init__
(
num_heads
,
head_size
,
scale
,
num_kv_heads
,
alibi_slopes
,
sliding_window
,
kv_cache_dtype
,
alibi_slopes
,
sliding_window
,
kv_cache_dtype
,
blocksparse_params
,
logits_soft_cap
,
attn_type
,
blocksparse_params
,
logits_soft_cap
,
attn_type
,
**
mla_args
)
kv_sharing_target_layer_name
,
**
mla_args
)
unsupported_features
=
[
unsupported_features
=
[
alibi_slopes
,
sliding_window
,
blocksparse_params
,
logits_soft_cap
alibi_slopes
,
sliding_window
,
blocksparse_params
,
logits_soft_cap
...
...
vllm/v1/attention/backends/pallas.py
View file @
bdf13965
...
@@ -113,6 +113,7 @@ class PallasAttentionBackendImpl(AttentionImpl):
...
@@ -113,6 +113,7 @@ class PallasAttentionBackendImpl(AttentionImpl):
blocksparse_params
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
blocksparse_params
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
logits_soft_cap
:
Optional
[
float
]
=
None
,
logits_soft_cap
:
Optional
[
float
]
=
None
,
attn_type
:
str
=
AttentionType
.
DECODER
,
attn_type
:
str
=
AttentionType
.
DECODER
,
kv_sharing_target_layer_name
:
Optional
[
int
]
=
None
,
use_irope
:
bool
=
False
,
use_irope
:
bool
=
False
,
)
->
None
:
)
->
None
:
if
use_irope
:
if
use_irope
:
...
@@ -128,6 +129,7 @@ class PallasAttentionBackendImpl(AttentionImpl):
...
@@ -128,6 +129,7 @@ class PallasAttentionBackendImpl(AttentionImpl):
self
.
num_kv_heads
=
num_kv_heads
self
.
num_kv_heads
=
num_kv_heads
self
.
sliding_window
=
sliding_window
self
.
sliding_window
=
sliding_window
self
.
logits_soft_cap
=
logits_soft_cap
self
.
logits_soft_cap
=
logits_soft_cap
self
.
kv_sharing_target_layer_name
=
kv_sharing_target_layer_name
assert
self
.
num_heads
%
self
.
num_kv_heads
==
0
assert
self
.
num_heads
%
self
.
num_kv_heads
==
0
self
.
num_queries_per_kv
=
self
.
num_heads
//
self
.
num_kv_heads
self
.
num_queries_per_kv
=
self
.
num_heads
//
self
.
num_kv_heads
...
@@ -181,7 +183,9 @@ class PallasAttentionBackendImpl(AttentionImpl):
...
@@ -181,7 +183,9 @@ class PallasAttentionBackendImpl(AttentionImpl):
num_tokens
,
hidden_size
=
query
.
shape
num_tokens
,
hidden_size
=
query
.
shape
query
=
query
.
view
(
num_tokens
,
self
.
num_heads
,
self
.
head_size
)
query
=
query
.
view
(
num_tokens
,
self
.
num_heads
,
self
.
head_size
)
if
kv_cache
.
numel
()
>
0
:
if
self
.
kv_sharing_target_layer_name
is
None
and
kv_cache
.
numel
()
>
0
:
# Write input keys and values to the KV cache.
# Skip this if sharing KV cache with an earlier attention layer.
slot_mapping
=
attn_metadata
.
slot_mapping
slot_mapping
=
attn_metadata
.
slot_mapping
write_to_kv_cache
(
key
,
value
,
kv_cache
,
slot_mapping
)
write_to_kv_cache
(
key
,
value
,
kv_cache
,
slot_mapping
)
...
...
vllm/v1/attention/backends/triton_attn.py
View file @
bdf13965
...
@@ -88,6 +88,7 @@ class TritonAttentionImpl(AttentionImpl):
...
@@ -88,6 +88,7 @@ class TritonAttentionImpl(AttentionImpl):
blocksparse_params
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
blocksparse_params
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
logits_soft_cap
:
Optional
[
float
]
=
None
,
logits_soft_cap
:
Optional
[
float
]
=
None
,
attn_type
:
AttentionType
=
AttentionType
.
DECODER
,
attn_type
:
AttentionType
=
AttentionType
.
DECODER
,
kv_sharing_target_layer_name
:
Optional
[
int
]
=
None
,
use_irope
:
bool
=
False
,
use_irope
:
bool
=
False
,
)
->
None
:
)
->
None
:
if
blocksparse_params
is
not
None
:
if
blocksparse_params
is
not
None
:
...
@@ -109,6 +110,7 @@ class TritonAttentionImpl(AttentionImpl):
...
@@ -109,6 +110,7 @@ class TritonAttentionImpl(AttentionImpl):
# In flash-attn, setting logits_soft_cap as 0 means no soft cap.
# In flash-attn, setting logits_soft_cap as 0 means no soft cap.
logits_soft_cap
=
0
logits_soft_cap
=
0
self
.
logits_soft_cap
=
logits_soft_cap
self
.
logits_soft_cap
=
logits_soft_cap
self
.
kv_sharing_target_layer_name
=
kv_sharing_target_layer_name
self
.
use_irope
=
use_irope
self
.
use_irope
=
use_irope
...
@@ -178,8 +180,13 @@ class TritonAttentionImpl(AttentionImpl):
...
@@ -178,8 +180,13 @@ class TritonAttentionImpl(AttentionImpl):
if
use_prefill_decode_attn
:
if
use_prefill_decode_attn
:
key_cache
,
value_cache
=
PagedAttention
.
split_kv_cache
(
key_cache
,
value_cache
=
PagedAttention
.
split_kv_cache
(
kv_cache
,
self
.
num_kv_heads
,
self
.
head_size
)
kv_cache
,
self
.
num_kv_heads
,
self
.
head_size
)
else
:
key_cache
,
value_cache
=
kv_cache
.
unbind
(
0
)
if
self
.
kv_sharing_target_layer_name
is
None
:
# Reshape the input keys and values and store them in the cache.
# Reshape the input keys and values and store them in the cache.
# Skip this if sharing KV cache with an earlier attention layer.
if
use_prefill_decode_attn
:
PagedAttention
.
write_to_paged_cache
(
PagedAttention
.
write_to_paged_cache
(
key
,
key
,
value
,
value
,
...
@@ -190,9 +197,7 @@ class TritonAttentionImpl(AttentionImpl):
...
@@ -190,9 +197,7 @@ class TritonAttentionImpl(AttentionImpl):
layer
.
_k_scale
,
layer
.
_k_scale
,
layer
.
_v_scale
,
layer
.
_v_scale
,
)
)
else
:
else
:
key_cache
,
value_cache
=
kv_cache
.
unbind
(
0
)
torch
.
ops
.
_C_cache_ops
.
reshape_and_cache_flash
(
torch
.
ops
.
_C_cache_ops
.
reshape_and_cache_flash
(
key
,
key
,
value
,
value
,
...
...
vllm/v1/attention/backends/utils.py
View file @
bdf13965
...
@@ -17,3 +17,36 @@ class CommonAttentionMetadata:
...
@@ -17,3 +17,36 @@ class CommonAttentionMetadata:
seq_lens
:
torch
.
Tensor
seq_lens
:
torch
.
Tensor
"""(batch_size,), the length of each request including both computed tokens
"""(batch_size,), the length of each request including both computed tokens
and newly scheduled tokens"""
and newly scheduled tokens"""
def validate_kv_sharing_target(current_layer_name, target_layer_name,
                               static_forward_context):
    """Check that ``target_layer_name`` is a usable KV sharing target.

    Args:
        current_layer_name: Name of the layer that wants to reuse another
            layer's KV cache.
        target_layer_name: Name of the layer whose KV cache would be reused.
        static_forward_context: Mapping from layer name to layer object. The
            objects looked up here are read for their ``attn_type``.

    Raises:
        ValueError: If the target is the current layer itself, if the target
            is absent from ``static_forward_context`` (reported either as
            not preceding the current layer or as not being an Attention
            layer, depending on the two layer indices), or if the target's
            ``attn_type`` differs from the current layer's.
    """
    # Every failure shares this prefix; only the trailing reason differs.
    msg_prefix = (
        f"Specified KV sharing target layer for {current_layer_name} "
        f"is not valid: target layer {target_layer_name} ")

    if current_layer_name == target_layer_name:
        raise ValueError(msg_prefix +
                         "cannot be the same as the current layer.")

    if target_layer_name not in static_forward_context:
        from vllm.model_executor.models.utils import extract_layer_index

        # A target missing from the static forward context is either
        #   a) a layer that does not come BEFORE the current layer, or
        #   b) not an Attention layer that exists in the model.
        # The layer indices tell the two cases apart.
        current_idx = extract_layer_index(current_layer_name)
        target_idx = extract_layer_index(target_layer_name)
        if current_idx <= target_idx:
            reason = "must come before the current layer."
        else:
            reason = "is not a valid Attention layer in the model."
        raise ValueError(msg_prefix + reason)

    # Currently KV sharing is only supported between layers of the same type.
    target_attn_type = static_forward_context[target_layer_name].attn_type
    expected = static_forward_context[current_layer_name].attn_type
    if target_attn_type == expected:
        return
    raise ValueError(
        msg_prefix +
        f"must be the same type as the current layer ({expected}).")
vllm/v1/worker/gpu_model_runner.py
View file @
bdf13965
...
@@ -59,8 +59,8 @@ from vllm.v1.worker.block_table import BlockTable
...
@@ -59,8 +59,8 @@ from vllm.v1.worker.block_table import BlockTable
from
vllm.v1.worker.gpu_input_batch
import
CachedRequestState
,
InputBatch
from
vllm.v1.worker.gpu_input_batch
import
CachedRequestState
,
InputBatch
from
vllm.v1.worker.lora_model_runner_mixin
import
LoRAModelRunnerMixin
from
vllm.v1.worker.lora_model_runner_mixin
import
LoRAModelRunnerMixin
from
.utils
import
(
gather_mm_placeholders
,
sa
nit
y_check_mm_encoder_outputs
,
from
.utils
import
(
gather_mm_placeholders
,
i
nit
ialize_kv_cache_for_kv_sharing
,
scatter_mm_placeholders
)
sanity_check_mm_encoder_outputs
,
scatter_mm_placeholders
)
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
import
xgrammar
as
xgr
import
xgrammar
as
xgr
...
@@ -276,6 +276,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -276,6 +276,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
pin_memory
=
self
.
pin_memory
)
pin_memory
=
self
.
pin_memory
)
self
.
seq_lens_np
=
self
.
seq_lens_cpu
.
numpy
()
self
.
seq_lens_np
=
self
.
seq_lens_cpu
.
numpy
()
# Layer pairings for cross-layer KV sharing.
# If an Attention layer `layer_name` is in the keys of this dict, it
# means this layer will perform attention using the keys and values
# from the KV cache of `shared_kv_cache_layers[layer_name]`.
self
.
shared_kv_cache_layers
:
dict
[
str
,
str
]
=
{}
def
_may_reorder_batch
(
self
,
scheduler_output
:
"SchedulerOutput"
)
->
bool
:
def
_may_reorder_batch
(
self
,
scheduler_output
:
"SchedulerOutput"
)
->
bool
:
"""
"""
Update the order of requests in the batch based on the attention
Update the order of requests in the batch based on the attention
...
@@ -2097,6 +2103,15 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -2097,6 +2103,15 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# KV cache specs.
# KV cache specs.
raise
ValueError
(
"Unknown KV cache spec type."
)
raise
ValueError
(
"Unknown KV cache spec type."
)
# Setup `kv_cache_config` and `kv_caches` for models
# with cross-layer KV sharing
if
self
.
shared_kv_cache_layers
:
initialize_kv_cache_for_kv_sharing
(
self
.
shared_kv_cache_layers
,
kv_cache_config
.
kv_cache_groups
,
kv_caches
,
)
if
self
.
speculative_config
and
self
.
speculative_config
.
use_eagle
():
if
self
.
speculative_config
and
self
.
speculative_config
.
use_eagle
():
assert
isinstance
(
self
.
drafter
,
EagleProposer
)
assert
isinstance
(
self
.
drafter
,
EagleProposer
)
# validate all draft model layers belong to the same kv cache
# validate all draft model layers belong to the same kv cache
...
@@ -2125,6 +2140,18 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -2125,6 +2140,18 @@ class GPUModelRunner(LoRAModelRunnerMixin):
use_mla
=
self
.
vllm_config
.
model_config
.
use_mla
use_mla
=
self
.
vllm_config
.
model_config
.
use_mla
kv_cache_spec
:
dict
[
str
,
KVCacheSpec
]
=
{}
kv_cache_spec
:
dict
[
str
,
KVCacheSpec
]
=
{}
for
layer_name
,
attn_module
in
layers
.
items
():
for
layer_name
,
attn_module
in
layers
.
items
():
if
(
kv_tgt_layer
:
=
attn_module
.
kv_sharing_target_layer_name
)
is
not
None
:
# The layer doesn't need its own KV cache and will use that of
# the target layer. We skip creating a KVCacheSpec for it, so
# that KV cache management logic will act as if this layer does
# not exist, and doesn't allocate a KV cache for the layer. This
# enables the memory saving of cross-layer kv sharing, allowing
# a given amount of memory to accommodate longer context lengths
# or enable more requests to be processed simultaneously.
self
.
shared_kv_cache_layers
[
layer_name
]
=
kv_tgt_layer
continue
# TODO: Support other attention modules, e.g., cross-attention
# TODO: Support other attention modules, e.g., cross-attention
if
attn_module
.
attn_type
==
AttentionType
.
DECODER
:
if
attn_module
.
attn_type
==
AttentionType
.
DECODER
:
if
attn_module
.
sliding_window
is
not
None
:
if
attn_module
.
sliding_window
is
not
None
:
...
...
vllm/v1/worker/tpu_model_runner.py
View file @
bdf13965
...
@@ -44,7 +44,8 @@ from vllm.v1.utils import bind_kv_cache
...
@@ -44,7 +44,8 @@ from vllm.v1.utils import bind_kv_cache
from
vllm.v1.worker.gpu_input_batch
import
CachedRequestState
,
InputBatch
from
vllm.v1.worker.gpu_input_batch
import
CachedRequestState
,
InputBatch
from
vllm.v1.worker.lora_model_runner_mixin
import
LoRAModelRunnerMixin
from
vllm.v1.worker.lora_model_runner_mixin
import
LoRAModelRunnerMixin
from
.utils
import
sanity_check_mm_encoder_outputs
from
.utils
import
(
initialize_kv_cache_for_kv_sharing
,
sanity_check_mm_encoder_outputs
)
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
vllm.v1.core.sched.output
import
SchedulerOutput
from
vllm.v1.core.sched.output
import
SchedulerOutput
...
@@ -238,6 +239,12 @@ class TPUModelRunner(LoRAModelRunnerMixin):
...
@@ -238,6 +239,12 @@ class TPUModelRunner(LoRAModelRunnerMixin):
self
.
num_reqs_paddings
=
_get_req_paddings
(
self
.
num_reqs_paddings
=
_get_req_paddings
(
min_req_size
=
MIN_NUM_SEQS
,
max_req_size
=
self
.
max_num_reqs
)
min_req_size
=
MIN_NUM_SEQS
,
max_req_size
=
self
.
max_num_reqs
)
# Layer pairings for cross-layer KV sharing.
# If an Attention layer `layer_name` is in the keys of this dict, it
# means this layer will perform attention using the keys and values
# from the KV cache of `shared_kv_cache_layers[layer_name]`.
self
.
shared_kv_cache_layers
:
dict
[
str
,
str
]
=
{}
# tensors for structured decoding
# tensors for structured decoding
self
.
grammar_bitmask_cpu
=
torch
.
zeros
(
self
.
grammar_bitmask_cpu
=
torch
.
zeros
(
(
self
.
max_num_reqs
,
cdiv
(
self
.
vocab_size
,
32
)),
(
self
.
max_num_reqs
,
cdiv
(
self
.
vocab_size
,
32
)),
...
@@ -455,6 +462,18 @@ class TPUModelRunner(LoRAModelRunnerMixin):
...
@@ -455,6 +462,18 @@ class TPUModelRunner(LoRAModelRunnerMixin):
block_size
=
self
.
vllm_config
.
cache_config
.
block_size
block_size
=
self
.
vllm_config
.
cache_config
.
block_size
kv_cache_spec
:
dict
[
str
,
KVCacheSpec
]
=
{}
kv_cache_spec
:
dict
[
str
,
KVCacheSpec
]
=
{}
for
layer_name
,
attn_module
in
layers
.
items
():
for
layer_name
,
attn_module
in
layers
.
items
():
if
(
kv_tgt_layer
:
=
attn_module
.
kv_sharing_target_layer_name
)
is
not
None
:
# The layer doesn't need its own KV cache and will use that of
# the target layer. We skip creating a KVCacheSpec for it, so
# that KV cache management logic will act as if this layer does
# not exist, and doesn't allocate a KV cache for the layer. This
# enables the memory saving of cross-layer kv sharing, allowing
# a given amount of memory to accommodate longer context lengths
# or enable more requests to be processed simultaneously.
self
.
shared_kv_cache_layers
[
layer_name
]
=
kv_tgt_layer
continue
if
attn_module
.
attn_type
==
AttentionType
.
DECODER
:
if
attn_module
.
attn_type
==
AttentionType
.
DECODER
:
if
attn_module
.
sliding_window
is
not
None
:
if
attn_module
.
sliding_window
is
not
None
:
kv_cache_spec
[
layer_name
]
=
SlidingWindowSpec
(
kv_cache_spec
[
layer_name
]
=
SlidingWindowSpec
(
...
@@ -1376,6 +1395,15 @@ class TPUModelRunner(LoRAModelRunnerMixin):
...
@@ -1376,6 +1395,15 @@ class TPUModelRunner(LoRAModelRunnerMixin):
else
:
else
:
raise
NotImplementedError
raise
NotImplementedError
# Setup `kv_cache_config` and `kv_caches` for models
# with cross-layer KV sharing
if
self
.
shared_kv_cache_layers
:
initialize_kv_cache_for_kv_sharing
(
self
.
shared_kv_cache_layers
,
kv_cache_config
.
kv_cache_groups
,
kv_caches
,
)
bind_kv_cache
(
bind_kv_cache
(
kv_caches
,
kv_caches
,
self
.
vllm_config
.
compilation_config
.
static_forward_context
,
self
.
vllm_config
.
compilation_config
.
static_forward_context
,
...
...
vllm/v1/worker/utils.py
View file @
bdf13965
...
@@ -4,6 +4,8 @@ from typing import Optional
...
@@ -4,6 +4,8 @@ from typing import Optional
import
torch
import
torch
from
vllm.v1.kv_cache_interface
import
KVCacheGroupSpec
def
sanity_check_mm_encoder_outputs
(
def
sanity_check_mm_encoder_outputs
(
mm_embeddings
:
object
,
mm_embeddings
:
object
,
...
@@ -73,3 +75,37 @@ def gather_mm_placeholders(
...
@@ -73,3 +75,37 @@ def gather_mm_placeholders(
return
placeholders
return
placeholders
return
placeholders
[
is_embed
]
return
placeholders
[
is_embed
]
def initialize_kv_cache_for_kv_sharing(
    shared_kv_cache_layers: dict[str, str],
    kv_cache_groups: list["KVCacheGroupSpec"],
    kv_caches: dict[str, torch.Tensor],
) -> None:
    """
    Sets up KV cache sharing by reusing the allocated KV caches in `kv_caches`
    for layers that do not allocate their own KV cache, based on the mapping
    in `shared_kv_cache_layers`. Adds these layers to the corresponding KV
    cache group, which is needed to ensure that attention metadata is
    assigned later.

    Both `kv_cache_groups` and `kv_caches` are modified in place.

    A sharing layer may itself be used as the target of a later entry in
    `shared_kv_cache_layers` (chained sharing); the later layer then resolves
    to the same KV cache and KV cache group as the earlier one. Entries are
    processed in the dict's iteration order, so a target must either own a
    KV cache or appear earlier in the mapping.

    Args:
        shared_kv_cache_layers: Layer pairings for cross-layer KV sharing.
            If an Attention layer `layer_name` is in the keys of this dict, it
            means this layer will perform attention using the keys and values
            from the KV cache of `shared_kv_cache_layers[layer_name]`.
        kv_cache_groups: The KV cache groups of the model.
        kv_caches: The allocated kv_caches with layer names as keys.
            Note that layers in shared_kv_cache_layers.keys() are not
            originally included as it only contains layers which have their
            own KV cache allocation.

    Raises:
        KeyError: If a target layer has no entry in `kv_caches` or does not
            belong to any KV cache group.
    """
    # Record index of KV cache group for each layer that allocates a KV cache.
    layer_to_kv_cache_group_idx: dict[str, int] = {
        member_name: idx
        for idx, kv_cache_group in enumerate(kv_cache_groups)
        for member_name in kv_cache_group.layer_names
    }

    for layer_name, target_layer_name in shared_kv_cache_layers.items():
        # Resolve both lookups before mutating anything, so a bad target
        # does not leave this layer half-registered.
        shared_cache = kv_caches[target_layer_name]
        group_idx = layer_to_kv_cache_group_idx[target_layer_name]

        kv_caches[layer_name] = shared_cache
        kv_cache_groups[group_idx].layer_names.append(layer_name)
        # Keep the index in sync so this layer can itself be a sharing
        # target for a later entry.
        layer_to_kv_cache_group_idx[layer_name] = group_idx
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment