Unverified Commit bdf13965 authored by Yong Hoon Shin's avatar Yong Hoon Shin Committed by GitHub
Browse files

[V1] Support cross-layer KV sharing (#18212)


Signed-off-by: Yong Hoon Shin <yhshin@meta.com>
parent fa98d777
...@@ -507,6 +507,7 @@ class FlashInferImpl(AttentionImpl): ...@@ -507,6 +507,7 @@ class FlashInferImpl(AttentionImpl):
blocksparse_params: Optional[dict[str, Any]] = None, blocksparse_params: Optional[dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None, logits_soft_cap: Optional[float] = None,
attn_type: AttentionType = AttentionType.DECODER, attn_type: AttentionType = AttentionType.DECODER,
kv_sharing_target_layer_name: Optional[int] = None,
) -> None: ) -> None:
self.num_heads = num_heads self.num_heads = num_heads
self.head_size = head_size self.head_size = head_size
...@@ -521,6 +522,7 @@ class FlashInferImpl(AttentionImpl): ...@@ -521,6 +522,7 @@ class FlashInferImpl(AttentionImpl):
self.sliding_window = (sliding_window - 1, 0) self.sliding_window = (sliding_window - 1, 0)
self.kv_cache_dtype = kv_cache_dtype self.kv_cache_dtype = kv_cache_dtype
self.logits_soft_cap = logits_soft_cap self.logits_soft_cap = logits_soft_cap
self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
assert self.num_heads % self.num_kv_heads == 0 assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads self.num_queries_per_kv = self.num_heads // self.num_kv_heads
...@@ -568,21 +570,25 @@ class FlashInferImpl(AttentionImpl): ...@@ -568,21 +570,25 @@ class FlashInferImpl(AttentionImpl):
# performance to make sure it does not introduce any overhead. # performance to make sure it does not introduce any overhead.
num_actual_tokens = attn_metadata.num_actual_tokens num_actual_tokens = attn_metadata.num_actual_tokens
# Reshape the input keys and values and store them in the cache.
# NOTE(woosuk): Here, key and value are padded while slot_mapping is if self.kv_sharing_target_layer_name is None:
# not padded. However, we don't need to do key[:num_actual_tokens] and # Reshape the input keys and values and store them in the cache.
# value[:num_actual_tokens] because the reshape_and_cache_flash op uses # Skip this if sharing KV cache with an earlier attention layer.
# the slot_mapping's shape to determine the number of actual tokens. # NOTE(woosuk): Here, key and value are padded while slot_mapping is
torch.ops._C_cache_ops.reshape_and_cache_flash( # not padded. However, we don't need to do key[:num_actual_tokens]
key, # and value[:num_actual_tokens] because the reshape_and_cache_flash
value, # op uses the slot_mapping's shape to determine the number of
kv_cache[:, 0], # actual tokens.
kv_cache[:, 1], torch.ops._C_cache_ops.reshape_and_cache_flash(
attn_metadata.slot_mapping, key,
self.kv_cache_dtype, value,
layer._k_scale, kv_cache[:, 0],
layer._v_scale, kv_cache[:, 1],
) attn_metadata.slot_mapping,
self.kv_cache_dtype,
layer._k_scale,
layer._v_scale,
)
window_left = (self.sliding_window[0] window_left = (self.sliding_window[0]
if self.sliding_window is not None else -1) if self.sliding_window is not None else -1)
......
...@@ -586,6 +586,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): ...@@ -586,6 +586,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
blocksparse_params: Optional[dict[str, Any]], blocksparse_params: Optional[dict[str, Any]],
logits_soft_cap: Optional[float], logits_soft_cap: Optional[float],
attn_type: str, attn_type: str,
kv_sharing_target_layer_name: Optional[str],
# MLA Specific Arguments # MLA Specific Arguments
q_lora_rank: Optional[int], q_lora_rank: Optional[int],
kv_lora_rank: int, kv_lora_rank: int,
...@@ -595,6 +596,9 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): ...@@ -595,6 +596,9 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
v_head_dim: int, v_head_dim: int,
kv_b_proj: ColumnParallelLinear, kv_b_proj: ColumnParallelLinear,
) -> None: ) -> None:
if kv_sharing_target_layer_name is not None:
raise NotImplementedError("KV sharing is not supported for MLA")
self.num_heads = num_heads self.num_heads = num_heads
self.head_size = head_size self.head_size = head_size
self.scale = float(scale) self.scale = float(scale)
......
...@@ -93,12 +93,13 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]): ...@@ -93,12 +93,13 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
blocksparse_params: Optional[dict[str, Any]], blocksparse_params: Optional[dict[str, Any]],
logits_soft_cap: Optional[float], logits_soft_cap: Optional[float],
attn_type: str, attn_type: str,
kv_sharing_target_layer_name: Optional[str],
# MLA Specific Arguments # MLA Specific Arguments
**mla_args) -> None: **mla_args) -> None:
super().__init__(num_heads, head_size, scale, num_kv_heads, super().__init__(num_heads, head_size, scale, num_kv_heads,
alibi_slopes, sliding_window, kv_cache_dtype, alibi_slopes, sliding_window, kv_cache_dtype,
blocksparse_params, logits_soft_cap, attn_type, blocksparse_params, logits_soft_cap, attn_type,
**mla_args) kv_sharing_target_layer_name, **mla_args)
assert is_flashmla_supported(), \ assert is_flashmla_supported(), \
"FlashMLA is not supported on this device" "FlashMLA is not supported on this device"
......
...@@ -139,12 +139,13 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]): ...@@ -139,12 +139,13 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
blocksparse_params: Optional[dict[str, Any]], blocksparse_params: Optional[dict[str, Any]],
logits_soft_cap: Optional[float], logits_soft_cap: Optional[float],
attn_type: str, attn_type: str,
kv_sharing_target_layer_name: Optional[str],
# MLA Specific Arguments # MLA Specific Arguments
**mla_args) -> None: **mla_args) -> None:
super().__init__(num_heads, head_size, scale, num_kv_heads, super().__init__(num_heads, head_size, scale, num_kv_heads,
alibi_slopes, sliding_window, kv_cache_dtype, alibi_slopes, sliding_window, kv_cache_dtype,
blocksparse_params, logits_soft_cap, attn_type, blocksparse_params, logits_soft_cap, attn_type,
**mla_args) kv_sharing_target_layer_name, **mla_args)
assert (num_heads == 16 or num_heads == 128), ( assert (num_heads == 16 or num_heads == 128), (
f"Aiter MLA only supports 16 or 128 number of heads.\n" f"Aiter MLA only supports 16 or 128 number of heads.\n"
f"Provided {num_heads} number of heads.\n" f"Provided {num_heads} number of heads.\n"
......
...@@ -41,12 +41,13 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]): ...@@ -41,12 +41,13 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
blocksparse_params: Optional[dict[str, Any]], blocksparse_params: Optional[dict[str, Any]],
logits_soft_cap: Optional[float], logits_soft_cap: Optional[float],
attn_type: str, attn_type: str,
kv_sharing_target_layer_name: Optional[str],
# MLA Specific Arguments # MLA Specific Arguments
**mla_args) -> None: **mla_args) -> None:
super().__init__(num_heads, head_size, scale, num_kv_heads, super().__init__(num_heads, head_size, scale, num_kv_heads,
alibi_slopes, sliding_window, kv_cache_dtype, alibi_slopes, sliding_window, kv_cache_dtype,
blocksparse_params, logits_soft_cap, attn_type, blocksparse_params, logits_soft_cap, attn_type,
**mla_args) kv_sharing_target_layer_name, **mla_args)
unsupported_features = [ unsupported_features = [
alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap
......
...@@ -113,6 +113,7 @@ class PallasAttentionBackendImpl(AttentionImpl): ...@@ -113,6 +113,7 @@ class PallasAttentionBackendImpl(AttentionImpl):
blocksparse_params: Optional[dict[str, Any]] = None, blocksparse_params: Optional[dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None, logits_soft_cap: Optional[float] = None,
attn_type: str = AttentionType.DECODER, attn_type: str = AttentionType.DECODER,
kv_sharing_target_layer_name: Optional[int] = None,
use_irope: bool = False, use_irope: bool = False,
) -> None: ) -> None:
if use_irope: if use_irope:
...@@ -128,6 +129,7 @@ class PallasAttentionBackendImpl(AttentionImpl): ...@@ -128,6 +129,7 @@ class PallasAttentionBackendImpl(AttentionImpl):
self.num_kv_heads = num_kv_heads self.num_kv_heads = num_kv_heads
self.sliding_window = sliding_window self.sliding_window = sliding_window
self.logits_soft_cap = logits_soft_cap self.logits_soft_cap = logits_soft_cap
self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
assert self.num_heads % self.num_kv_heads == 0 assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads self.num_queries_per_kv = self.num_heads // self.num_kv_heads
...@@ -181,7 +183,9 @@ class PallasAttentionBackendImpl(AttentionImpl): ...@@ -181,7 +183,9 @@ class PallasAttentionBackendImpl(AttentionImpl):
num_tokens, hidden_size = query.shape num_tokens, hidden_size = query.shape
query = query.view(num_tokens, self.num_heads, self.head_size) query = query.view(num_tokens, self.num_heads, self.head_size)
if kv_cache.numel() > 0: if self.kv_sharing_target_layer_name is None and kv_cache.numel() > 0:
# Write input keys and values to the KV cache.
# Skip this if sharing KV cache with an earlier attention layer.
slot_mapping = attn_metadata.slot_mapping slot_mapping = attn_metadata.slot_mapping
write_to_kv_cache(key, value, kv_cache, slot_mapping) write_to_kv_cache(key, value, kv_cache, slot_mapping)
......
...@@ -88,6 +88,7 @@ class TritonAttentionImpl(AttentionImpl): ...@@ -88,6 +88,7 @@ class TritonAttentionImpl(AttentionImpl):
blocksparse_params: Optional[dict[str, Any]] = None, blocksparse_params: Optional[dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None, logits_soft_cap: Optional[float] = None,
attn_type: AttentionType = AttentionType.DECODER, attn_type: AttentionType = AttentionType.DECODER,
kv_sharing_target_layer_name: Optional[int] = None,
use_irope: bool = False, use_irope: bool = False,
) -> None: ) -> None:
if blocksparse_params is not None: if blocksparse_params is not None:
...@@ -109,6 +110,7 @@ class TritonAttentionImpl(AttentionImpl): ...@@ -109,6 +110,7 @@ class TritonAttentionImpl(AttentionImpl):
# In flash-attn, setting logits_soft_cap as 0 means no soft cap. # In flash-attn, setting logits_soft_cap as 0 means no soft cap.
logits_soft_cap = 0 logits_soft_cap = 0
self.logits_soft_cap = logits_soft_cap self.logits_soft_cap = logits_soft_cap
self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
self.use_irope = use_irope self.use_irope = use_irope
...@@ -178,31 +180,34 @@ class TritonAttentionImpl(AttentionImpl): ...@@ -178,31 +180,34 @@ class TritonAttentionImpl(AttentionImpl):
if use_prefill_decode_attn: if use_prefill_decode_attn:
key_cache, value_cache = PagedAttention.split_kv_cache( key_cache, value_cache = PagedAttention.split_kv_cache(
kv_cache, self.num_kv_heads, self.head_size) kv_cache, self.num_kv_heads, self.head_size)
# Reshape the input keys and values and store them in the cache.
PagedAttention.write_to_paged_cache(
key,
value,
key_cache,
value_cache,
attn_metadata.slot_mapping,
self.kv_cache_dtype,
layer._k_scale,
layer._v_scale,
)
else: else:
key_cache, value_cache = kv_cache.unbind(0) key_cache, value_cache = kv_cache.unbind(0)
torch.ops._C_cache_ops.reshape_and_cache_flash(
key, if self.kv_sharing_target_layer_name is None:
value, # Reshape the input keys and values and store them in the cache.
key_cache, # Skip this if sharing KV cache with an earlier attention layer.
value_cache, if use_prefill_decode_attn:
attn_metadata.slot_mapping, PagedAttention.write_to_paged_cache(
self.kv_cache_dtype, key,
layer._k_scale, value,
layer._v_scale, key_cache,
) value_cache,
attn_metadata.slot_mapping,
self.kv_cache_dtype,
layer._k_scale,
layer._v_scale,
)
else:
torch.ops._C_cache_ops.reshape_and_cache_flash(
key,
value,
key_cache,
value_cache,
attn_metadata.slot_mapping,
self.kv_cache_dtype,
layer._k_scale,
layer._v_scale,
)
if self.kv_cache_dtype.startswith("fp8"): if self.kv_cache_dtype.startswith("fp8"):
key_cache = key_cache.view(self.fp8_dtype) key_cache = key_cache.view(self.fp8_dtype)
......
...@@ -17,3 +17,36 @@ class CommonAttentionMetadata: ...@@ -17,3 +17,36 @@ class CommonAttentionMetadata:
seq_lens: torch.Tensor seq_lens: torch.Tensor
"""(batch_size,), the length of each request including both computed tokens """(batch_size,), the length of each request including both computed tokens
and newly scheduled tokens""" and newly scheduled tokens"""
def validate_kv_sharing_target(current_layer_name, target_layer_name,
                               static_forward_context):
    """Reject an invalid KV sharing target for `current_layer_name`.

    Args:
        current_layer_name: Name of the layer that wants to reuse another
            layer's KV cache.
        target_layer_name: Name of the layer whose KV cache would be reused.
        static_forward_context: Mapping from layer name to a layer object
            exposing an `attn_type` attribute. It is indexed with both
            `target_layer_name` and `current_layer_name`.

    Raises:
        ValueError: If the target is the current layer itself, is missing
            from `static_forward_context` (reported either as "must come
            before the current layer" or as "not a valid Attention layer",
            depending on the extracted layer indices), or has a different
            `attn_type` than the current layer.
    """
    # Every failure message shares this prefix; only the trailing reason
    # differs.
    msg_prefix = (
        f"Specified KV sharing target layer for {current_layer_name} "
        f"is not valid: target layer {target_layer_name} ")

    if current_layer_name == target_layer_name:
        raise ValueError(msg_prefix +
                         "cannot be the same as the current layer.")

    if target_layer_name not in static_forward_context:
        # Imported locally rather than at module level.
        # NOTE(review): presumably to avoid an import cycle - confirm before
        # hoisting this import.
        from vllm.model_executor.models.utils import extract_layer_index

        # A target missing from the static forward context is either
        #   a) a layer that does not come BEFORE the current layer, or
        #   b) not an Attention layer that exists in the model.
        # The layer indices tell the two cases apart.
        own_idx = extract_layer_index(current_layer_name)
        tgt_idx = extract_layer_index(target_layer_name)
        reason = ("must come before the current layer."
                  if own_idx <= tgt_idx else
                  "is not a valid Attention layer in the model.")
        raise ValueError(msg_prefix + reason)

    # KV sharing is currently only supported between layers of the same
    # attention type.
    tgt_attn_type = static_forward_context[target_layer_name].attn_type
    expected = static_forward_context[current_layer_name].attn_type
    if tgt_attn_type != expected:
        raise ValueError(
            msg_prefix +
            f"must be the same type as the current layer ({expected}).")
...@@ -59,8 +59,8 @@ from vllm.v1.worker.block_table import BlockTable ...@@ -59,8 +59,8 @@ from vllm.v1.worker.block_table import BlockTable
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
from .utils import (gather_mm_placeholders, sanity_check_mm_encoder_outputs, from .utils import (gather_mm_placeholders, initialize_kv_cache_for_kv_sharing,
scatter_mm_placeholders) sanity_check_mm_encoder_outputs, scatter_mm_placeholders)
if TYPE_CHECKING: if TYPE_CHECKING:
import xgrammar as xgr import xgrammar as xgr
...@@ -276,6 +276,12 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -276,6 +276,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
pin_memory=self.pin_memory) pin_memory=self.pin_memory)
self.seq_lens_np = self.seq_lens_cpu.numpy() self.seq_lens_np = self.seq_lens_cpu.numpy()
# Layer pairings for cross-layer KV sharing.
# If an Attention layer `layer_name` is in the keys of this dict, it
# means this layer will perform attention using the keys and values
# from the KV cache of `shared_kv_cache_layers[layer_name]`.
self.shared_kv_cache_layers: dict[str, str] = {}
def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> bool: def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> bool:
""" """
Update the order of requests in the batch based on the attention Update the order of requests in the batch based on the attention
...@@ -2097,6 +2103,15 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -2097,6 +2103,15 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# KV cache specs. # KV cache specs.
raise ValueError("Unknown KV cache spec type.") raise ValueError("Unknown KV cache spec type.")
# Setup `kv_cache_config` and `kv_caches` for models
# with cross-layer KV sharing
if self.shared_kv_cache_layers:
initialize_kv_cache_for_kv_sharing(
self.shared_kv_cache_layers,
kv_cache_config.kv_cache_groups,
kv_caches,
)
if self.speculative_config and self.speculative_config.use_eagle(): if self.speculative_config and self.speculative_config.use_eagle():
assert isinstance(self.drafter, EagleProposer) assert isinstance(self.drafter, EagleProposer)
# validate all draft model layers belong to the same kv cache # validate all draft model layers belong to the same kv cache
...@@ -2125,6 +2140,18 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -2125,6 +2140,18 @@ class GPUModelRunner(LoRAModelRunnerMixin):
use_mla = self.vllm_config.model_config.use_mla use_mla = self.vllm_config.model_config.use_mla
kv_cache_spec: dict[str, KVCacheSpec] = {} kv_cache_spec: dict[str, KVCacheSpec] = {}
for layer_name, attn_module in layers.items(): for layer_name, attn_module in layers.items():
if (kv_tgt_layer :=
attn_module.kv_sharing_target_layer_name) is not None:
# The layer doesn't need its own KV cache and will use that of
# the target layer. We skip creating a KVCacheSpec for it, so
# that KV cache management logic will act as this layer does
# not exist, and doesn't allocate KV cache for the layer. This
# enables the memory saving of cross-layer kv sharing, allowing
# a given amount of memory to accommodate longer context lengths
# or enable more requests to be processed simultaneously.
self.shared_kv_cache_layers[layer_name] = kv_tgt_layer
continue
# TODO: Support other attention modules, e.g., cross-attention # TODO: Support other attention modules, e.g., cross-attention
if attn_module.attn_type == AttentionType.DECODER: if attn_module.attn_type == AttentionType.DECODER:
if attn_module.sliding_window is not None: if attn_module.sliding_window is not None:
......
...@@ -44,7 +44,8 @@ from vllm.v1.utils import bind_kv_cache ...@@ -44,7 +44,8 @@ from vllm.v1.utils import bind_kv_cache
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
from .utils import sanity_check_mm_encoder_outputs from .utils import (initialize_kv_cache_for_kv_sharing,
sanity_check_mm_encoder_outputs)
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.core.sched.output import SchedulerOutput
...@@ -238,6 +239,12 @@ class TPUModelRunner(LoRAModelRunnerMixin): ...@@ -238,6 +239,12 @@ class TPUModelRunner(LoRAModelRunnerMixin):
self.num_reqs_paddings = _get_req_paddings( self.num_reqs_paddings = _get_req_paddings(
min_req_size=MIN_NUM_SEQS, max_req_size=self.max_num_reqs) min_req_size=MIN_NUM_SEQS, max_req_size=self.max_num_reqs)
# Layer pairings for cross-layer KV sharing.
# If an Attention layer `layer_name` is in the keys of this dict, it
# means this layer will perform attention using the keys and values
# from the KV cache of `shared_kv_cache_layers[layer_name]`.
self.shared_kv_cache_layers: dict[str, str] = {}
# tensors for structured decoding # tensors for structured decoding
self.grammar_bitmask_cpu = torch.zeros( self.grammar_bitmask_cpu = torch.zeros(
(self.max_num_reqs, cdiv(self.vocab_size, 32)), (self.max_num_reqs, cdiv(self.vocab_size, 32)),
...@@ -455,6 +462,18 @@ class TPUModelRunner(LoRAModelRunnerMixin): ...@@ -455,6 +462,18 @@ class TPUModelRunner(LoRAModelRunnerMixin):
block_size = self.vllm_config.cache_config.block_size block_size = self.vllm_config.cache_config.block_size
kv_cache_spec: dict[str, KVCacheSpec] = {} kv_cache_spec: dict[str, KVCacheSpec] = {}
for layer_name, attn_module in layers.items(): for layer_name, attn_module in layers.items():
if (kv_tgt_layer :=
attn_module.kv_sharing_target_layer_name) is not None:
# The layer doesn't need its own KV cache and will use that of
# the target layer. We skip creating a KVCacheSpec for it, so
# that KV cache management logic will act as this layer does
# not exist, and doesn't allocate KV cache for the layer. This
# enables the memory saving of cross-layer kv sharing, allowing
# a given amount of memory to accommodate longer context lengths
# or enable more requests to be processed simultaneously.
self.shared_kv_cache_layers[layer_name] = kv_tgt_layer
continue
if attn_module.attn_type == AttentionType.DECODER: if attn_module.attn_type == AttentionType.DECODER:
if attn_module.sliding_window is not None: if attn_module.sliding_window is not None:
kv_cache_spec[layer_name] = SlidingWindowSpec( kv_cache_spec[layer_name] = SlidingWindowSpec(
...@@ -1376,6 +1395,15 @@ class TPUModelRunner(LoRAModelRunnerMixin): ...@@ -1376,6 +1395,15 @@ class TPUModelRunner(LoRAModelRunnerMixin):
else: else:
raise NotImplementedError raise NotImplementedError
# Setup `kv_cache_config` and `kv_caches` for models
# with cross-layer KV sharing
if self.shared_kv_cache_layers:
initialize_kv_cache_for_kv_sharing(
self.shared_kv_cache_layers,
kv_cache_config.kv_cache_groups,
kv_caches,
)
bind_kv_cache( bind_kv_cache(
kv_caches, kv_caches,
self.vllm_config.compilation_config.static_forward_context, self.vllm_config.compilation_config.static_forward_context,
......
...@@ -4,6 +4,8 @@ from typing import Optional ...@@ -4,6 +4,8 @@ from typing import Optional
import torch import torch
from vllm.v1.kv_cache_interface import KVCacheGroupSpec
def sanity_check_mm_encoder_outputs( def sanity_check_mm_encoder_outputs(
mm_embeddings: object, mm_embeddings: object,
...@@ -73,3 +75,37 @@ def gather_mm_placeholders( ...@@ -73,3 +75,37 @@ def gather_mm_placeholders(
return placeholders return placeholders
return placeholders[is_embed] return placeholders[is_embed]
def initialize_kv_cache_for_kv_sharing(
    shared_kv_cache_layers: dict[str, str],
    kv_cache_groups: list[KVCacheGroupSpec],
    kv_caches: dict[str, torch.Tensor],
) -> None:
    """
    Wire up cross-layer KV sharing on top of already-allocated KV caches.

    For every layer that does not allocate its own KV cache, point its entry
    in `kv_caches` at the cache of its target layer, and append the layer to
    the KV cache group that contains the target. The group membership is
    needed to ensure that attention metadata is assigned to the layer later.

    Both `kv_caches` and the `layer_names` lists of `kv_cache_groups` are
    modified in place; nothing is returned.

    Args:
        shared_kv_cache_layers: Layer pairings for cross-layer KV sharing.
            If an Attention layer `layer_name` is a key of this dict, that
            layer performs attention using the keys and values from the KV
            cache of `shared_kv_cache_layers[layer_name]`.
        kv_cache_groups: The KV cache groups of the model.
        kv_caches: The allocated KV caches, keyed by layer name. On entry it
            only contains layers that have their own KV cache allocation,
            i.e. the keys of `shared_kv_cache_layers` are not yet present.

    Raises:
        KeyError: If a target layer has no entry in `kv_caches` or is not
            listed in any KV cache group.
    """
    # Look up, for each layer that owns a KV cache, the group it belongs to.
    # If a name is listed in several groups, the last group wins.
    group_of_layer = {
        owner_name: group
        for group in kv_cache_groups
        for owner_name in group.layer_names
    }

    for sharing_layer, target_layer in shared_kv_cache_layers.items():
        # Alias the target's cache first, then register the sharing layer in
        # the target's group.
        kv_caches[sharing_layer] = kv_caches[target_layer]
        group_of_layer[target_layer].layer_names.append(sharing_layer)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment