Unverified Commit 9fec0e13 authored by Lucas Wilkinson's avatar Lucas Wilkinson Committed by GitHub
Browse files

[Attention] Cache attention metadata builds across hybrid KV-cache groups (#29627)


Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Co-authored-by: Stanislaw Wozniak <stw@zurich.ibm.com>
parent 254a7f8f
...@@ -172,7 +172,7 @@ def test_local_attention_virtual_batches(test_data: LocalAttentionTestData): ...@@ -172,7 +172,7 @@ def test_local_attention_virtual_batches(test_data: LocalAttentionTestData):
) )
# Call the function # Call the function
result = make_local_attention_virtual_batches( result, _ = make_local_attention_virtual_batches(
attn_chunk_size, common_attn_metadata, block_size attn_chunk_size, common_attn_metadata, block_size
) )
......
...@@ -4,7 +4,7 @@ import functools ...@@ -4,7 +4,7 @@ import functools
import torch import torch
from vllm.attention.backends.abstract import AttentionBackend, AttentionMetadata from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.layer import Attention from vllm.attention.layer import Attention
from vllm.attention.selector import get_attn_backend from vllm.attention.selector import get_attn_backend
from vllm.config import CacheConfig from vllm.config import CacheConfig
...@@ -51,11 +51,19 @@ def create_chunked_local_attention_backend( ...@@ -51,11 +51,19 @@ def create_chunked_local_attention_backend(
common_prefix_len: int, common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata, common_attn_metadata: CommonAttentionMetadata,
fast_build: bool = False, fast_build: bool = False,
) -> AttentionMetadata: ):
common_attn_metadata = make_local_attention_virtual_batches( cm, make_virtual_batches_block_table = make_local_attention_virtual_batches(
attention_chunk_size, common_attn_metadata, block_size attention_chunk_size, common_attn_metadata, block_size
) )
return super().build(common_prefix_len, common_attn_metadata, fast_build) metadata = super().build(common_prefix_len, cm, fast_build)
metadata.make_virtual_batches_block_table = make_virtual_batches_block_table
return metadata
def update_block_table(
    self, metadata, blk_table: torch.Tensor, slot_mapping: torch.Tensor
):
    """Swap a new block table into already-built chunked-local metadata.

    The incoming ``blk_table`` is first passed through the
    ``make_virtual_batches_block_table`` callable stored on ``metadata``,
    and the result is then handed to the parent builder's
    ``update_block_table`` together with ``slot_mapping``.

    Args:
        metadata: Previously built metadata carrying a
            ``make_virtual_batches_block_table`` attribute.
        blk_table: Block table to convert and install.
        slot_mapping: Slot mapping forwarded unchanged to the parent.

    Returns:
        Whatever the parent builder's ``update_block_table`` returns.
    """
    # Convert the block table with the callable saved on the metadata
    # before delegating to the wrapped backend's builder.
    virtual_blk_table = metadata.make_virtual_batches_block_table(blk_table)
    return super().update_block_table(metadata, virtual_blk_table, slot_mapping)
attn_backend = subclass_attention_backend( attn_backend = subclass_attention_backend(
name_prefix=prefix, name_prefix=prefix,
......
...@@ -207,7 +207,7 @@ if TYPE_CHECKING: ...@@ -207,7 +207,7 @@ if TYPE_CHECKING:
VLLM_USE_TRTLLM_RAGGED_DEEPSEEK_PREFILL: bool = False VLLM_USE_TRTLLM_RAGGED_DEEPSEEK_PREFILL: bool = False
VLLM_ENABLE_CUDAGRAPH_GC: bool = False VLLM_ENABLE_CUDAGRAPH_GC: bool = False
VLLM_LOOPBACK_IP: str = "" VLLM_LOOPBACK_IP: str = ""
VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE: bool = False VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE: bool = True
VLLM_ENABLE_RESPONSES_API_STORE: bool = False VLLM_ENABLE_RESPONSES_API_STORE: bool = False
VLLM_USE_TRTLLM_ATTENTION: str | None = None VLLM_USE_TRTLLM_ATTENTION: str | None = None
VLLM_NVFP4_GEMM_BACKEND: str | None = None VLLM_NVFP4_GEMM_BACKEND: str | None = None
...@@ -1430,7 +1430,7 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1430,7 +1430,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
# kv-cache memory usage and enable longer contexts) # kv-cache memory usage and enable longer contexts)
# TODO(lucas): Remove this flag once latency regression is resolved. # TODO(lucas): Remove this flag once latency regression is resolved.
"VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE": lambda: bool( "VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE": lambda: bool(
int(os.getenv("VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE", "0")) int(os.getenv("VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE", "1"))
), ),
# Enables support for the "store" option in the OpenAI Responses API. # Enables support for the "store" option in the OpenAI Responses API.
# When set to 1, vLLM's OpenAI server will retain the input and output # When set to 1, vLLM's OpenAI server will retain the input and output
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Attention layer with FlashAttention.""" """Attention layer with FlashAttention."""
import copy
from dataclasses import dataclass from dataclasses import dataclass
from typing import ClassVar from typing import ClassVar
...@@ -250,6 +251,7 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad ...@@ -250,6 +251,7 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
if get_flash_attn_version() == 3 if get_flash_attn_version() == 3
else AttentionCGSupport.UNIFORM_BATCH else AttentionCGSupport.UNIFORM_BATCH
) )
supports_update_block_table: bool = True
def __init__( def __init__(
self, self,
...@@ -493,6 +495,17 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad ...@@ -493,6 +495,17 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
) )
return attn_metadata return attn_metadata
def update_block_table(
    self,
    metadata: FlashAttentionMetadata,
    blk_table: torch.Tensor,
    slot_mapping: torch.Tensor,
) -> FlashAttentionMetadata:
    """Return a shallow copy of ``metadata`` with a new block table.

    Args:
        metadata: Previously built metadata to base the result on. It is
            not modified.
        blk_table: Value stored as ``block_table`` on the copy.
        slot_mapping: Value stored as ``slot_mapping`` on the copy.

    Returns:
        A shallow copy of ``metadata``; every attribute other than
        ``block_table`` and ``slot_mapping`` still refers to the same
        object as on the input.
    """
    # Shallow copy: only the two attributes below are rebound, so the
    # original metadata object keeps its own block table and slot mapping.
    refreshed = copy.copy(metadata)
    refreshed.slot_mapping = slot_mapping
    refreshed.block_table = blk_table
    return refreshed
def use_cascade_attention(self, *args, **kwargs) -> bool: def use_cascade_attention(self, *args, **kwargs) -> bool:
return use_cascade_attention(*args, **kwargs) return use_cascade_attention(*args, **kwargs)
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import copy
import itertools import itertools
from dataclasses import dataclass from dataclasses import dataclass
...@@ -134,6 +135,8 @@ class Mamba2AttentionMetadata: ...@@ -134,6 +135,8 @@ class Mamba2AttentionMetadata:
class Mamba2AttentionMetadataBuilder( class Mamba2AttentionMetadataBuilder(
BaseMambaAttentionMetadataBuilder[Mamba2AttentionMetadata] BaseMambaAttentionMetadataBuilder[Mamba2AttentionMetadata]
): ):
supports_update_block_table: bool = True
def __init__( def __init__(
self, self,
kv_cache_spec: AttentionSpec, kv_cache_spec: AttentionSpec,
...@@ -346,3 +349,27 @@ class Mamba2AttentionMetadataBuilder( ...@@ -346,3 +349,27 @@ class Mamba2AttentionMetadataBuilder(
num_computed_tokens_p=num_computed_tokens_p, num_computed_tokens_p=num_computed_tokens_p,
) )
return attn_metadata return attn_metadata
def update_block_table(
    self,
    metadata: Mamba2AttentionMetadata,
    blk_table: torch.Tensor,
    slot_mapping: torch.Tensor,
) -> Mamba2AttentionMetadata:
    """Return a shallow copy of ``metadata`` with refreshed state indices.

    The state indices are the whole ``blk_table`` when prefix caching is
    enabled in the cache config, otherwise its first column.

    Args:
        metadata: Previously built metadata to base the result on. It is
            not modified.
        blk_table: Block table the state indices are taken from.
        slot_mapping: Accepted for interface compatibility; not read here.

    Returns:
        A shallow copy of ``metadata`` whose ``state_indices_tensor`` has
        been replaced.
    """
    if self.vllm_config.cache_config.enable_prefix_caching:
        state_indices = blk_table
    else:
        state_indices = blk_table[:, 0]

    batch_size = blk_table.shape[0]

    # For CUDA graphs, the indices are copied into the builder's persistent
    # buffer and the metadata points at that buffer slice instead.
    stage_in_persistent_buffer = (
        metadata.num_prefills == 0
        and batch_size <= self.decode_cudagraph_max_bs
        and self.compilation_config.cudagraph_mode.has_full_cudagraphs()
    )
    if stage_in_persistent_buffer:
        buffer_view = self.state_indices_tensor[:batch_size]
        buffer_view.copy_(state_indices, non_blocking=True)
        state_indices = buffer_view

    updated = copy.copy(metadata)
    updated.state_indices_tensor = state_indices
    return updated
...@@ -4,6 +4,7 @@ import abc ...@@ -4,6 +4,7 @@ import abc
import enum import enum
import functools import functools
from abc import abstractmethod from abc import abstractmethod
from collections.abc import Callable
from dataclasses import dataclass, field, fields, make_dataclass from dataclasses import dataclass, field, fields, make_dataclass
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
...@@ -317,6 +318,9 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]): ...@@ -317,6 +318,9 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]):
# If not, set this to None. Otherwise set it to the query # If not, set this to None. Otherwise set it to the query
# length that will be pulled into the front of the batch. # length that will be pulled into the front of the batch.
reorder_batch_threshold: int | None = None reorder_batch_threshold: int | None = None
# Does this backend/builder support updating the block table in existing
# metadata
supports_update_block_table: bool = False
@abstractmethod @abstractmethod
def __init__( def __init__(
...@@ -387,6 +391,21 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]): ...@@ -387,6 +391,21 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]):
""" """
raise NotImplementedError raise NotImplementedError
def update_block_table(
self,
metadata: M,
blk_table: torch.Tensor,
slot_mapping: torch.Tensor,
) -> M:
"""
Update the block table for the attention metadata.

Faster when there are multiple KV-cache groups that create virtually the
same metadata, differing only in their block tables.

Only needs to be implemented if supports_update_block_table is True.

Args:
metadata: Previously built attention metadata to base the result on.
blk_table: Block table to use in the returned metadata.
slot_mapping: Slot mapping to use in the returned metadata.

Returns:
Attention metadata of the builder's metadata type ``M``.

Raises:
NotImplementedError: Always, in this base implementation.
"""
raise NotImplementedError
def build_for_cudagraph_capture( def build_for_cudagraph_capture(
self, common_attn_metadata: CommonAttentionMetadata self, common_attn_metadata: CommonAttentionMetadata
) -> M: ) -> M:
...@@ -603,7 +622,7 @@ def make_local_attention_virtual_batches( ...@@ -603,7 +622,7 @@ def make_local_attention_virtual_batches(
attn_chunk_size: int, attn_chunk_size: int,
common_attn_metadata: CommonAttentionMetadata, common_attn_metadata: CommonAttentionMetadata,
block_size: int = 0, block_size: int = 0,
) -> CommonAttentionMetadata: ) -> tuple[CommonAttentionMetadata, Callable[[torch.Tensor], torch.Tensor]]:
query_start_loc_np = common_attn_metadata.query_start_loc_cpu.numpy() query_start_loc_np = common_attn_metadata.query_start_loc_cpu.numpy()
seq_lens_np = common_attn_metadata.seq_lens_cpu.numpy() seq_lens_np = common_attn_metadata.seq_lens_cpu.numpy()
block_table = common_attn_metadata.block_table_tensor block_table = common_attn_metadata.block_table_tensor
...@@ -715,9 +734,12 @@ def make_local_attention_virtual_batches( ...@@ -715,9 +734,12 @@ def make_local_attention_virtual_batches(
# tensor first, which recovers perf. # tensor first, which recovers perf.
batch_indices_torch = torch.from_numpy(batch_indices) batch_indices_torch = torch.from_numpy(batch_indices)
block_indices_torch = torch.from_numpy(block_indices) block_indices_torch = torch.from_numpy(block_indices)
block_table_local = block_table[batch_indices_torch, block_indices_torch].view(
virtual_batches, -1 # Save as a lambda so we can return this for update_block_table
) make_block_table = lambda block_table: block_table[
batch_indices_torch, block_indices_torch
].view(virtual_batches, -1)
block_table_local = make_block_table(block_table)
query_start_loc_cpu = torch.from_numpy(cu_seqlens_q_local) query_start_loc_cpu = torch.from_numpy(cu_seqlens_q_local)
seq_lens_cpu = torch.from_numpy(seqlens_k_local) seq_lens_cpu = torch.from_numpy(seqlens_k_local)
...@@ -736,7 +758,7 @@ def make_local_attention_virtual_batches( ...@@ -736,7 +758,7 @@ def make_local_attention_virtual_batches(
causal=True, causal=True,
_seq_lens_cpu=seq_lens_cpu, _seq_lens_cpu=seq_lens_cpu,
_num_computed_tokens_cpu=torch.from_numpy(num_computed_tokens_local), _num_computed_tokens_cpu=torch.from_numpy(num_computed_tokens_local),
) ), make_block_table
def make_kv_sharing_fast_prefill_common_attn_metadata( def make_kv_sharing_fast_prefill_common_attn_metadata(
......
...@@ -1630,6 +1630,15 @@ class GPUModelRunner( ...@@ -1630,6 +1630,15 @@ class GPUModelRunner(
logits_indices logits_indices
) )
# Cache attention metadata builds across hybrid KV-cache groups.
# When two hybrid KV-cache groups share the same metadata builder type and
# the same KVCacheSpec, the only things that differ between them are the block
# table and slot mapping, so we can cache the built attention metadata and just
# swap those in via `builder.update_block_table` if the builder supports it.
cached_attn_metadata: dict[
tuple[KVCacheSpec, type[AttentionMetadataBuilder]], AttentionMetadata
] = {}
def _build_attn_group_metadata( def _build_attn_group_metadata(
kv_cache_gid: int, kv_cache_gid: int,
attn_gid: int, attn_gid: int,
...@@ -1637,13 +1646,15 @@ class GPUModelRunner( ...@@ -1637,13 +1646,15 @@ class GPUModelRunner(
ubid: int | None = None, ubid: int | None = None,
) -> None: ) -> None:
attn_group = self.attn_groups[kv_cache_gid][attn_gid] attn_group = self.attn_groups[kv_cache_gid][attn_gid]
builder = attn_group.get_metadata_builder(ubid or 0)
cache_key = (kv_cache_groups[kv_cache_gid].kv_cache_spec, type(builder))
cascade_attn_prefix_len = ( cascade_attn_prefix_len = (
cascade_attn_prefix_lens[kv_cache_gid][attn_gid] cascade_attn_prefix_lens[kv_cache_gid][attn_gid]
if cascade_attn_prefix_lens if cascade_attn_prefix_lens
else 0 else 0
) )
builder = attn_group.get_metadata_builder(ubid or 0)
extra_attn_metadata_args = {} extra_attn_metadata_args = {}
if use_spec_decode and isinstance(builder, GDNAttentionMetadataBuilder): if use_spec_decode and isinstance(builder, GDNAttentionMetadataBuilder):
assert ubid is None, "UBatching not supported with GDN yet" assert ubid is None, "UBatching not supported with GDN yet"
...@@ -1658,12 +1669,23 @@ class GPUModelRunner( ...@@ -1658,12 +1669,23 @@ class GPUModelRunner(
attn_metadata_i = builder.build_for_cudagraph_capture( attn_metadata_i = builder.build_for_cudagraph_capture(
common_attn_metadata common_attn_metadata
) )
elif (
cache_key in cached_attn_metadata
and builder.supports_update_block_table
):
attn_metadata_i = builder.update_block_table(
cached_attn_metadata[cache_key],
common_attn_metadata.block_table_tensor,
common_attn_metadata.slot_mapping,
)
else: else:
attn_metadata_i = builder.build( attn_metadata_i = builder.build(
common_prefix_len=cascade_attn_prefix_len, common_prefix_len=cascade_attn_prefix_len,
common_attn_metadata=common_attn_metadata, common_attn_metadata=common_attn_metadata,
**extra_attn_metadata_args, **extra_attn_metadata_args,
) )
if builder.supports_update_block_table:
cached_attn_metadata[cache_key] = attn_metadata_i
if ubid is None: if ubid is None:
assert isinstance(attn_metadata, dict) assert isinstance(attn_metadata, dict)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment