Unverified Commit 1dc8a70b authored by Lucas Wilkinson's avatar Lucas Wilkinson Committed by GitHub
Browse files

[Attention] Support multiple attention metadata builders per kv_cache_spec +...


[Attention] Support multiple attention metadata builders per kv_cache_spec  + proper local attention no hybrid kv cache fix (#21588)
Signed-off-by: default avatarLucas Wilkinson <lwilkins@redhat.com>
parent f825c6bd
......@@ -313,7 +313,8 @@ def test_propose(num_speculative_tokens, backend):
# Mock runner for attention metadata building
proposer.runner = mock.MagicMock()
proposer.runner.attn_metadata_builders = [attn_metadata_builder]
proposer.runner.attn_groups.append([mock.MagicMock()])
proposer.runner.attn_groups[0][0].metadata_builder = attn_metadata_builder
result = proposer.propose(target_token_ids=target_token_ids,
target_positions=target_positions,
......
......@@ -417,12 +417,12 @@ def test_kv_cache_stride_order(monkeypatch, model_runner):
return rnd_stride
# Patch the attention backend class and re-trigger the KV cache creation.
for attn_backend in model_runner.attn_backends:
for attn_group in model_runner._attn_group_iterator():
attn_backend = attn_group.backend
monkeypatch.setattr(attn_backend, "get_kv_cache_stride_order",
rnd_stride_order)
model_runner.attn_backends = []
model_runner.attn_metadata_builders = []
model_runner.attn_groups = []
model_runner.initialize_kv_cache(model_runner.kv_cache_config)
# Shape is unchanged, but layout may differ
......
......@@ -106,6 +106,10 @@ class AttentionBackend(ABC):
block_size: int, num_seqs: int, num_queries: int) -> None:
raise NotImplementedError
@classmethod
def full_cls_name(cls) -> tuple[str, str]:
return (cls.__module__, cls.__qualname__)
@dataclass
class AttentionMetadata:
......
......@@ -9,6 +9,7 @@ import torch.nn.functional as F
import vllm.envs as envs
from vllm.attention import AttentionType
from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.selector import backend_name_to_enum, get_attn_backend
from vllm.attention.utils.kv_sharing_utils import validate_kv_sharing_target
from vllm.config import CacheConfig, get_current_vllm_config
......@@ -80,6 +81,7 @@ class Attention(nn.Module):
prefix: str = "",
attn_type: str = AttentionType.DECODER,
kv_sharing_target_layer_name: Optional[str] = None,
attn_backend: Optional[type[AttentionBackend]] = None,
**extra_impl_args,
) -> None:
"""
......@@ -137,15 +139,6 @@ class Attention(nn.Module):
self.num_kv_heads = num_kv_heads
self.sliding_window = sliding_window
# For v1 we have backend agnostic iRoPE (local chunked attention)
# we have to store the flag on the layer so gpu model runner can
# set KVSpec appropriately (and pop it so it doesnt get passed to
# the backends)
if envs.VLLM_USE_V1:
self.use_irope = extra_impl_args.pop("use_irope", False)
else:
self.use_irope = extra_impl_args.get("use_irope", False)
quant_method = quant_config.get_quant_method(
self, prefix=prefix) if quant_config else None
if quant_method is not None and not isinstance(
......@@ -166,18 +159,22 @@ class Attention(nn.Module):
# During model initialization, the default dtype is set as the model
# weight and activation dtype.
dtype = torch.get_default_dtype()
attn_backend = get_attn_backend(head_size,
dtype,
kv_cache_dtype,
block_size,
is_attention_free,
use_mla=use_mla)
impl_cls = attn_backend.get_impl_cls()
if attn_backend is None:
self.attn_backend = get_attn_backend(head_size,
dtype,
kv_cache_dtype,
block_size,
is_attention_free,
use_mla=use_mla)
else:
self.attn_backend = attn_backend
impl_cls = self.attn_backend.get_impl_cls()
self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads,
alibi_slopes, sliding_window, kv_cache_dtype,
logits_soft_cap, attn_type,
kv_sharing_target_layer_name, **extra_impl_args)
self.backend = backend_name_to_enum(attn_backend.get_name())
self.backend = backend_name_to_enum(self.attn_backend.get_name())
self.dtype = dtype
# For cuda-alike (CUDA and ROCM) and cpu platforms, we control how
......@@ -187,7 +184,7 @@ class Attention(nn.Module):
self.use_direct_call = not current_platform.is_cuda_alike(
) and not current_platform.is_cpu()
self.use_output = attn_backend.accept_output_buffer
self.use_output = self.attn_backend.accept_output_buffer
compilation_config = get_current_vllm_config().compilation_config
if prefix in compilation_config.static_forward_context:
raise ValueError(f"Duplicate layer name: {prefix}")
......@@ -309,6 +306,9 @@ class Attention(nn.Module):
if hasattr(self.impl, "process_weights_after_loading"):
self.impl.process_weights_after_loading(act_dtype)
def get_attn_backend(self) -> type[AttentionBackend]:
return self.attn_backend
class MultiHeadAttention(nn.Module):
"""Multi-headed attention without any cache, used for ViT."""
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools
from typing import List, Optional
import torch
from vllm import envs
from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.selector import get_attn_backend
from vllm.config import CacheConfig, QuantizationConfig
from vllm.v1.attention.backends.utils import (
CommonAttentionMetadata, make_local_attention_virtual_batches,
subclass_attention_backend, subclass_attention_metadata_builder)
from ..layer import Attention
@functools.lru_cache
def create_chunked_local_attention_backend(
underlying_attn_backend: AttentionBackend,
attention_chunk_size: int,
block_size: int,
) -> type[AttentionBackend]:
prefix = f"ChunkedLocalAttention_{attention_chunk_size}_{block_size}_"
def build_preprocess_fn(cm: CommonAttentionMetadata):
return make_local_attention_virtual_batches(attention_chunk_size, cm,
block_size)
# Dynamically create a new attention backend that wraps the
# underlying attention backend but applies
# `make_local_attention_virtual_batches` before calling `build(...)`
builder_cls = subclass_attention_metadata_builder(
name_prefix=prefix,
builder_cls=underlying_attn_backend.get_builder_cls(),
build_preprocess_fn=build_preprocess_fn)
attn_backend = subclass_attention_backend(
name_prefix=prefix,
attention_backend_cls=underlying_attn_backend,
builder_cls=builder_cls)
return attn_backend
class ChunkedLocalAttention(Attention):
def __init__(self,
num_heads: int,
head_size: int,
scale: float,
attention_chunk_size: int,
num_kv_heads: Optional[int] = None,
alibi_slopes: Optional[List[float]] = None,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
kv_sharing_target_layer_name: Optional[str] = None,
prefix: str = ""):
dtype = torch.get_default_dtype()
if cache_config is not None:
kv_cache_dtype = cache_config.cache_dtype
block_size = cache_config.block_size
else:
kv_cache_dtype = "auto"
block_size = 16
if envs.VLLM_USE_V1:
underlying_attn_backend = get_attn_backend(head_size, dtype,
kv_cache_dtype,
block_size)
attn_backend = create_chunked_local_attention_backend(
underlying_attn_backend, attention_chunk_size, block_size)
else:
# in v0 the local attention is handled inside the backends
attn_backend = None
super().__init__(
num_heads=num_heads,
head_size=head_size,
scale=scale,
num_kv_heads=num_kv_heads,
alibi_slopes=alibi_slopes,
cache_config=cache_config,
quant_config=quant_config,
prefix=prefix,
kv_sharing_target_layer_name=kv_sharing_target_layer_name,
attn_backend=attn_backend)
......@@ -142,7 +142,7 @@ def get_attn_backend(
dtype: torch.dtype,
kv_cache_dtype: Optional[str],
block_size: int,
is_attention_free: bool,
is_attention_free: bool = False,
use_mla: bool = False,
) -> type[AttentionBackend]:
"""Selects which attention backend to use and lazily imports it."""
......
......@@ -25,6 +25,7 @@ from torch import nn
from transformers import Llama4TextConfig
from vllm.attention import Attention
from vllm.attention.layers.chunked_local_attention import ChunkedLocalAttention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
......@@ -194,17 +195,18 @@ class Llama4Attention(nn.Module):
is_neox_style=is_neox_style,
) if not self.nope else None
self.attn = Attention(
attn_cls = Attention if self.nope else ChunkedLocalAttention
self.attn = attn_cls(
self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config,
per_layer_sliding_window=None,
use_irope=not self.nope,
prefix=f"{prefix}.attn",
)
**({
"attention_chunk_size": config.attention_chunk_size
} if not self.nope else {}))
def _get_attn_scale(self, positions: torch.Tensor) -> torch.Tensor:
floor = torch.floor((positions + 1.0) / self.floor_scale)
......
......@@ -5,12 +5,12 @@ import enum
import functools
from abc import abstractmethod
from dataclasses import dataclass, make_dataclass
from typing import TYPE_CHECKING, Any, ClassVar, Generic, Optional, TypeVar
from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Generic, Optional,
TypeVar)
import numpy as np
import torch
from vllm.attention.layer import Attention
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.utils import cdiv
......@@ -20,6 +20,8 @@ if TYPE_CHECKING:
from vllm.v1.worker.gpu_input_batch import InputBatch
import vllm.envs as envs
from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.layer import Attention
from vllm.distributed.kv_transfer.kv_connector.utils import (
get_kv_connector_cache_layout)
from vllm.logger import init_logger
......@@ -532,6 +534,48 @@ def make_local_attention_virtual_batches(
)
def subclass_attention_metadata_builder(
name_prefix: str,
builder_cls: type[AttentionMetadataBuilder[M]],
build_preprocess_fn: Callable[[CommonAttentionMetadata],
CommonAttentionMetadata],
) -> type[AttentionMetadataBuilder[M]]:
"""
Return a new subclass of `builder_cls` whose .build(...) method
first calls build_preprocess_fn(common_attn_metadata) on the metadata.
"""
name: str = name_prefix + builder_cls.__name__ # type: ignore
def build(self,
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata,
fast_build: bool = False):
return builder_cls.build(self, common_prefix_len,
build_preprocess_fn(common_attn_metadata),
fast_build)
Wrapped = type(
name,
(builder_cls, ), # inherit from the original
{
"build": build,
})
return Wrapped # type: ignore
def subclass_attention_backend(
name_prefix: str, attention_backend_cls: type[AttentionBackend],
builder_cls: type[AttentionMetadataBuilder[M]]
) -> type[AttentionBackend]:
"""
Return a new subclass where `get_builder_cls` returns `builder_cls`.
"""
name: str = name_prefix + attention_backend_cls.__name__ # type: ignore
return type(name, (attention_backend_cls, ),
{"get_builder_cls": lambda: builder_cls})
def split_decodes_and_prefills(
common_attn_metadata: CommonAttentionMetadata,
decode_threshold: int = 1,
......
......@@ -158,9 +158,9 @@ class EagleProposer:
assert self.runner is not None
# FIXME: need to consider multiple kv_cache_groups
attn_metadata = self.runner.attn_metadata_builders[
0].build_for_drafting(common_attn_metadata=common_attn_metadata,
draft_index=0)
attn_metadata = self.runner.attn_groups[0][0].metadata_builder\
.build_for_drafting(common_attn_metadata=common_attn_metadata,
draft_index=0)
# At this moment, we assume all eagle layers belong to the same KV
# cache group, thus using the same attention metadata.
......@@ -349,7 +349,8 @@ class EagleProposer:
hidden_states: torch.Tensor,
common_attn_metadata: CommonAttentionMetadata,
) -> list[torch.Tensor]:
tree_attn_metadata_builder = self.runner.attn_metadata_builders[0]
tree_attn_metadata_builder = \
self.runner.attn_groups[0][0].metadata_builder
assert isinstance(tree_attn_metadata_builder,
TreeAttentionMetadataBuilder)
......
......@@ -53,11 +53,11 @@ class CPUModelRunner(GPUModelRunner):
raise ValueError("Multiple KVCacheGroups is not"
"currently supported with CPU model runner.")
assert type(
self.attn_metadata_builders[0]) is TorchSDPAMetadataBuilderV1
assert type(self.attn_groups[0]
[0].metadata_builder) is TorchSDPAMetadataBuilderV1
self.attn_metadata_builders[0].reorder_batch(self.input_batch,
scheduler_output)
self.attn_groups[0][0].metadata_builder.reorder_batch(
self.input_batch, scheduler_output)
def _postprocess_tenosrs(self) -> None:
# Note: replace device tensors with cpu tensors
......
This diff is collapsed.
......@@ -15,8 +15,9 @@ import torch_xla.distributed.spmd as xs
import torch_xla.runtime as xr
import vllm.envs as envs
from vllm.attention import Attention
from vllm.attention.backends.abstract import AttentionType
from vllm.attention.layer import Attention
from vllm.attention.layers.chunked_local_attention import ChunkedLocalAttention
from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
from vllm.config import (ParallelConfig, VllmConfig,
get_layers_from_vllm_config, update_config)
......@@ -518,7 +519,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
continue
if attn_module.attn_type == AttentionType.DECODER:
if attn_module.use_irope:
if isinstance(attn_module, ChunkedLocalAttention):
logger.warning_once(
"Using irope in Pallas is not supported yet, it "
"will fall back to global attention for long context.")
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections import defaultdict
from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional
import torch
from vllm.attention.backends.abstract import AttentionBackend
from vllm.config import ModelConfig, SchedulerConfig
from vllm.model_executor.models.interfaces import MultiModalEmbeddings
from vllm.model_executor.models.utils import extract_layer_index
from vllm.multimodal.registry import MultiModalRegistry
from vllm.v1.attention.backends.utils import AttentionMetadataBuilder
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
from vllm.v1.kv_cache_interface import KVCacheGroupSpec
......@@ -122,6 +125,13 @@ class MultiModalBudget:
return max_items_per_prompt, max_items_per_batch
@dataclass
class AttentionGroup:
backend: type[AttentionBackend]
metadata_builder: AttentionMetadataBuilder
layer_names: list[str]
def sanity_check_mm_encoder_outputs(
mm_embeddings: MultiModalEmbeddings,
expected_num_items: int,
......@@ -196,6 +206,8 @@ def initialize_kv_cache_for_kv_sharing(
shared_kv_cache_layers: dict[str, str],
kv_cache_groups: list[KVCacheGroupSpec],
kv_caches: dict[str, torch.Tensor],
# Optional for now to avoid breaking TPU
attn_groups: Optional[list[list[AttentionGroup]]] = None,
) -> None:
"""
Sets up KV cache sharing by reusing the allocated KV caches in `kv_caches`
......@@ -225,6 +237,15 @@ def initialize_kv_cache_for_kv_sharing(
group_idx = layer_to_kv_cache_group_idx[target_layer_name]
kv_cache_groups[group_idx].layer_names.append(layer_name)
if attn_groups is not None:
assert len(attn_groups[group_idx]) == 1, (
"Only one attention group per KV cache group is supported "
"for KV-cache sharing for now.")
# TODO(lucas): I think in the future the layers that re-use a
# KV cache will be in a different attention group so we can
# remove this code from here.
attn_groups[group_idx][0].layer_names.append(layer_name)
def bind_kv_cache(
kv_caches: dict[str, torch.Tensor],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment