Unverified Commit 20228cb8 authored by Matthew Bonanni's avatar Matthew Bonanni Committed by GitHub
Browse files

[3/N][Attention] Move AttentionMetadata-related code from utils.py to backend.py (#32054)


Signed-off-by: default avatarMatthew Bonanni <mbonanni@redhat.com>
parent 7c0d3c51
...@@ -217,12 +217,12 @@ from vllm.v1.attention.backend import ( ...@@ -217,12 +217,12 @@ from vllm.v1.attention.backend import (
AttentionBackend, AttentionBackend,
AttentionLayer, AttentionLayer,
AttentionMetadata, AttentionMetadata,
AttentionMetadataBuilder,
CommonAttentionMetadata,
MLAAttentionImpl, MLAAttentionImpl,
) )
from vllm.v1.attention.backends.fa_utils import get_flash_attn_version from vllm.v1.attention.backends.fa_utils import get_flash_attn_version
from vllm.v1.attention.backends.utils import ( from vllm.v1.attention.backends.utils import (
AttentionMetadataBuilder,
CommonAttentionMetadata,
get_dcp_local_seq_lens, get_dcp_local_seq_lens,
get_per_layer_parameters, get_per_layer_parameters,
infer_global_hyperparameters, infer_global_hyperparameters,
......
...@@ -11,6 +11,7 @@ from vllm.config.cache import CacheDType ...@@ -11,6 +11,7 @@ from vllm.config.cache import CacheDType
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms.interface import DeviceCapability from vllm.platforms.interface import DeviceCapability
from vllm.v1.attention.backend import ( from vllm.v1.attention.backend import (
AttentionCGSupport,
AttentionLayer, AttentionLayer,
AttentionType, AttentionType,
MultipleOf, MultipleOf,
...@@ -22,7 +23,6 @@ from vllm.v1.attention.backends.mla.common import ( ...@@ -22,7 +23,6 @@ from vllm.v1.attention.backends.mla.common import (
MLACommonMetadata, MLACommonMetadata,
MLACommonMetadataBuilder, MLACommonMetadataBuilder,
) )
from vllm.v1.attention.backends.utils import AttentionCGSupport
logger = init_logger(__name__) logger = init_logger(__name__)
......
...@@ -14,6 +14,7 @@ from vllm.model_executor.layers.batch_invariant import ( ...@@ -14,6 +14,7 @@ from vllm.model_executor.layers.batch_invariant import (
) )
from vllm.platforms.interface import DeviceCapability from vllm.platforms.interface import DeviceCapability
from vllm.v1.attention.backend import ( from vllm.v1.attention.backend import (
AttentionCGSupport,
AttentionLayer, AttentionLayer,
AttentionType, AttentionType,
MultipleOf, MultipleOf,
...@@ -31,7 +32,6 @@ from vllm.v1.attention.backends.mla.common import ( ...@@ -31,7 +32,6 @@ from vllm.v1.attention.backends.mla.common import (
MLACommonMetadataBuilder, MLACommonMetadataBuilder,
QueryLenSupport, QueryLenSupport,
) )
from vllm.v1.attention.backends.utils import AttentionCGSupport
from vllm.v1.kv_cache_interface import AttentionSpec from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.vllm_flash_attn import ( # type: ignore[attr-defined] from vllm.vllm_flash_attn import ( # type: ignore[attr-defined]
flash_attn_varlen_func, flash_attn_varlen_func,
......
...@@ -10,6 +10,7 @@ from vllm.config.cache import CacheDType ...@@ -10,6 +10,7 @@ from vllm.config.cache import CacheDType
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms.interface import DeviceCapability from vllm.platforms.interface import DeviceCapability
from vllm.v1.attention.backend import ( from vllm.v1.attention.backend import (
AttentionCGSupport,
AttentionLayer, AttentionLayer,
AttentionType, AttentionType,
MultipleOf, MultipleOf,
...@@ -21,7 +22,7 @@ from vllm.v1.attention.backends.mla.common import ( ...@@ -21,7 +22,7 @@ from vllm.v1.attention.backends.mla.common import (
MLACommonMetadataBuilder, MLACommonMetadataBuilder,
QueryLenSupport, QueryLenSupport,
) )
from vllm.v1.attention.backends.utils import AttentionCGSupport, KVCacheLayoutType from vllm.v1.attention.backends.utils import KVCacheLayoutType
logger = init_logger(__name__) logger = init_logger(__name__)
......
...@@ -13,7 +13,12 @@ from vllm.model_executor.layers.batch_invariant import ( ...@@ -13,7 +13,12 @@ from vllm.model_executor.layers.batch_invariant import (
vllm_is_batch_invariant, vllm_is_batch_invariant,
) )
from vllm.platforms.interface import DeviceCapability from vllm.platforms.interface import DeviceCapability
from vllm.v1.attention.backend import AttentionLayer, AttentionType, MultipleOf from vllm.v1.attention.backend import (
AttentionCGSupport,
AttentionLayer,
AttentionType,
MultipleOf,
)
from vllm.v1.attention.backends.mla.common import ( from vllm.v1.attention.backends.mla.common import (
MLACommonBackend, MLACommonBackend,
MLACommonDecodeMetadata, MLACommonDecodeMetadata,
...@@ -23,7 +28,6 @@ from vllm.v1.attention.backends.mla.common import ( ...@@ -23,7 +28,6 @@ from vllm.v1.attention.backends.mla.common import (
QueryLenSupport, QueryLenSupport,
) )
from vllm.v1.attention.backends.utils import ( from vllm.v1.attention.backends.utils import (
AttentionCGSupport,
reshape_attn_output_for_spec_decode, reshape_attn_output_for_spec_decode,
reshape_query_for_spec_decode, reshape_query_for_spec_decode,
) )
......
...@@ -16,15 +16,15 @@ from vllm.triton_utils import tl, triton ...@@ -16,15 +16,15 @@ from vllm.triton_utils import tl, triton
from vllm.utils.math_utils import cdiv from vllm.utils.math_utils import cdiv
from vllm.v1.attention.backend import ( from vllm.v1.attention.backend import (
AttentionBackend, AttentionBackend,
AttentionCGSupport,
AttentionLayer, AttentionLayer,
AttentionMetadata, AttentionMetadata,
AttentionMetadataBuilder,
CommonAttentionMetadata,
MultipleOf, MultipleOf,
) )
from vllm.v1.attention.backends.mla.common import MLACommonBaseImpl, get_mla_dims from vllm.v1.attention.backends.mla.common import MLACommonBaseImpl, get_mla_dims
from vllm.v1.attention.backends.utils import ( from vllm.v1.attention.backends.utils import (
AttentionCGSupport,
AttentionMetadataBuilder,
CommonAttentionMetadata,
reshape_attn_output_for_spec_decode, reshape_attn_output_for_spec_decode,
reshape_query_for_spec_decode, reshape_query_for_spec_decode,
split_decodes_and_prefills, split_decodes_and_prefills,
......
...@@ -11,12 +11,12 @@ from vllm.platforms import current_platform ...@@ -11,12 +11,12 @@ from vllm.platforms import current_platform
from vllm.utils.deep_gemm import get_paged_mqa_logits_metadata, is_deep_gemm_supported from vllm.utils.deep_gemm import get_paged_mqa_logits_metadata, is_deep_gemm_supported
from vllm.v1.attention.backend import ( from vllm.v1.attention.backend import (
AttentionBackend, AttentionBackend,
MultipleOf,
)
from vllm.v1.attention.backends.utils import (
AttentionCGSupport, AttentionCGSupport,
AttentionMetadataBuilder, AttentionMetadataBuilder,
CommonAttentionMetadata, CommonAttentionMetadata,
MultipleOf,
)
from vllm.v1.attention.backends.utils import (
split_decodes_and_prefills, split_decodes_and_prefills,
split_prefill_chunks, split_prefill_chunks,
) )
......
...@@ -8,7 +8,7 @@ import torch ...@@ -8,7 +8,7 @@ import torch
from vllm._aiter_ops import rocm_aiter_ops from vllm._aiter_ops import rocm_aiter_ops
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.v1.attention.backend import AttentionLayer, MultipleOf from vllm.v1.attention.backend import AttentionCGSupport, AttentionLayer, MultipleOf
from vllm.v1.attention.backends.mla.common import ( from vllm.v1.attention.backends.mla.common import (
MLACommonBackend, MLACommonBackend,
MLACommonDecodeMetadata, MLACommonDecodeMetadata,
...@@ -17,7 +17,6 @@ from vllm.v1.attention.backends.mla.common import ( ...@@ -17,7 +17,6 @@ from vllm.v1.attention.backends.mla.common import (
MLACommonMetadataBuilder, MLACommonMetadataBuilder,
QueryLenSupport, QueryLenSupport,
) )
from vllm.v1.attention.backends.utils import AttentionCGSupport
from vllm.v1.kv_cache_interface import AttentionSpec from vllm.v1.kv_cache_interface import AttentionSpec
......
...@@ -13,18 +13,16 @@ from vllm.config import VllmConfig ...@@ -13,18 +13,16 @@ from vllm.config import VllmConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.v1.attention.backend import ( from vllm.v1.attention.backend import (
AttentionBackend, AttentionBackend,
AttentionCGSupport,
AttentionLayer, AttentionLayer,
AttentionMetadata, AttentionMetadata,
AttentionMetadataBuilder,
CommonAttentionMetadata,
) )
from vllm.v1.attention.backends.mla.common import MLACommonBaseImpl, get_mla_dims from vllm.v1.attention.backends.mla.common import MLACommonBaseImpl, get_mla_dims
from vllm.v1.attention.backends.mla.flashmla_sparse import ( from vllm.v1.attention.backends.mla.flashmla_sparse import (
triton_convert_req_index_to_global_index, triton_convert_req_index_to_global_index,
) )
from vllm.v1.attention.backends.utils import (
AttentionCGSupport,
AttentionMetadataBuilder,
CommonAttentionMetadata,
)
from vllm.v1.kv_cache_interface import AttentionSpec from vllm.v1.kv_cache_interface import AttentionSpec
if TYPE_CHECKING: if TYPE_CHECKING:
......
...@@ -15,14 +15,14 @@ from vllm.utils.math_utils import cdiv ...@@ -15,14 +15,14 @@ from vllm.utils.math_utils import cdiv
from vllm.utils.platform_utils import get_cu_count from vllm.utils.platform_utils import get_cu_count
from vllm.v1.attention.backend import ( from vllm.v1.attention.backend import (
AttentionBackend, AttentionBackend,
AttentionCGSupport,
AttentionImpl, AttentionImpl,
AttentionMetadataBuilder,
AttentionType, AttentionType,
CommonAttentionMetadata,
MultipleOf, MultipleOf,
) )
from vllm.v1.attention.backends.utils import ( from vllm.v1.attention.backends.utils import (
AttentionCGSupport,
AttentionMetadataBuilder,
CommonAttentionMetadata,
split_decodes_prefills_and_extends, split_decodes_prefills_and_extends,
) )
from vllm.v1.attention.ops.merge_attn_states import merge_attn_states from vllm.v1.attention.ops.merge_attn_states import merge_attn_states
......
...@@ -16,16 +16,14 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( ...@@ -16,16 +16,14 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.v1.attention.backend import ( from vllm.v1.attention.backend import (
AttentionBackend, AttentionBackend,
AttentionCGSupport,
AttentionImpl, AttentionImpl,
AttentionMetadataBuilder,
AttentionType, AttentionType,
CommonAttentionMetadata,
MultipleOf, MultipleOf,
) )
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
from vllm.v1.attention.backends.utils import (
AttentionCGSupport,
AttentionMetadataBuilder,
CommonAttentionMetadata,
)
from vllm.v1.attention.ops.chunked_prefill_paged_decode import ( from vllm.v1.attention.ops.chunked_prefill_paged_decode import (
chunked_prefill_paged_decode, chunked_prefill_paged_decode,
) )
......
...@@ -14,12 +14,12 @@ from vllm.logger import init_logger ...@@ -14,12 +14,12 @@ from vllm.logger import init_logger
from vllm.v1.attention.backend import ( from vllm.v1.attention.backend import (
AttentionBackend, AttentionBackend,
AttentionImpl, AttentionImpl,
AttentionMetadataBuilder,
AttentionType, AttentionType,
CommonAttentionMetadata,
MultipleOf, MultipleOf,
) )
from vllm.v1.attention.backends.utils import ( from vllm.v1.attention.backends.utils import (
AttentionMetadataBuilder,
CommonAttentionMetadata,
split_decodes_and_prefills, split_decodes_and_prefills,
) )
from vllm.v1.attention.ops.triton_unified_attention import unified_attention from vllm.v1.attention.ops.triton_unified_attention import unified_attention
......
...@@ -19,14 +19,12 @@ from vllm.platforms.interface import DeviceCapability ...@@ -19,14 +19,12 @@ from vllm.platforms.interface import DeviceCapability
from vllm.utils.math_utils import next_power_of_2 from vllm.utils.math_utils import next_power_of_2
from vllm.v1.attention.backend import ( from vllm.v1.attention.backend import (
AttentionBackend, AttentionBackend,
AttentionImpl,
AttentionType,
MultipleOf,
)
from vllm.v1.attention.backends.utils import (
AttentionCGSupport, AttentionCGSupport,
AttentionImpl,
AttentionMetadataBuilder, AttentionMetadataBuilder,
AttentionType,
CommonAttentionMetadata, CommonAttentionMetadata,
MultipleOf,
) )
from vllm.v1.attention.ops.triton_prefill_attention import context_attention_fwd from vllm.v1.attention.ops.triton_prefill_attention import context_attention_fwd
from vllm.v1.attention.ops.triton_reshape_and_cache_flash import ( from vllm.v1.attention.ops.triton_reshape_and_cache_flash import (
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import abc
import enum
import functools import functools
from abc import abstractmethod
from collections.abc import Callable from collections.abc import Callable
from dataclasses import dataclass, field, fields, make_dataclass from dataclasses import dataclass, field, fields, make_dataclass
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
Any, Any,
ClassVar,
Generic,
Literal, Literal,
Protocol, Protocol,
TypeVar, TypeVar,
...@@ -19,7 +14,7 @@ from typing import ( ...@@ -19,7 +14,7 @@ from typing import (
import numpy as np import numpy as np
import torch import torch
from typing_extensions import deprecated, runtime_checkable from typing_extensions import runtime_checkable
from vllm.config import VllmConfig, get_layers_from_vllm_config from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.utils.math_utils import cdiv from vllm.utils.math_utils import cdiv
...@@ -38,8 +33,9 @@ from vllm.v1.attention.backend import ( ...@@ -38,8 +33,9 @@ from vllm.v1.attention.backend import (
AttentionBackend, AttentionBackend,
AttentionImpl, AttentionImpl,
AttentionMetadata, AttentionMetadata,
AttentionMetadataBuilder,
CommonAttentionMetadata,
) )
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.worker.ubatch_utils import UBatchSlice from vllm.v1.worker.ubatch_utils import UBatchSlice
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -53,123 +49,6 @@ def is_valid_kv_cache_layout(value: str) -> bool: ...@@ -53,123 +49,6 @@ def is_valid_kv_cache_layout(value: str) -> bool:
return value in get_args(KVCacheLayoutType) return value in get_args(KVCacheLayoutType)
@dataclass
class CommonAttentionMetadata:
"""
Per-batch attention metadata, shared across layers and backends.
AttentionMetadataBuilder instances use it to construct per-layer metadata.
For many of the tensors we keep both GPU and CPU versions.
"""
query_start_loc: torch.Tensor
query_start_loc_cpu: torch.Tensor
"""(batch_size + 1,), the start location of each request in query Tensor"""
seq_lens: torch.Tensor
"""(batch_size,), the number of computed tokens for each request"""
num_reqs: int
"""Number of requests"""
# TODO(lucas): rename to num_tokens since it may be padded and this is misleading
num_actual_tokens: int
"""Total number of tokens in batch"""
max_query_len: int
"""Longest query in batch"""
max_seq_len: int
"""Longest context length (may be an upper bound)"""
block_table_tensor: torch.Tensor
slot_mapping: torch.Tensor
causal: bool = True
# Needed by FastPrefillAttentionBuilder
logits_indices_padded: torch.Tensor | None = None
num_logits_indices: int | None = None
# Needed by CrossAttentionBuilder
encoder_seq_lens: torch.Tensor | None = None
encoder_seq_lens_cpu: np.ndarray | None = None
dcp_local_seq_lens: torch.Tensor | None = None
dcp_local_seq_lens_cpu: torch.Tensor | None = None
"""Sequence lengths of the local rank in decode context parallelism world"""
# WARNING: Deprecated fields. Will be removed in a future release (v0.15.0)
_seq_lens_cpu: torch.Tensor | None = None
_num_computed_tokens_cpu: torch.Tensor | None = None
_num_computed_tokens_cache: torch.Tensor | None = None
@property
@deprecated(
"""
Prefer using device seq_lens directly to avoid implicit H<>D sync.
If a CPU copy is needed, use `seq_lens.cpu()` instead.
Will be removed in a future release (v0.15.0)
"""
)
def seq_lens_cpu(self) -> torch.Tensor:
if self._seq_lens_cpu is None:
self._seq_lens_cpu = self.seq_lens.to("cpu")
return self._seq_lens_cpu
@property
@deprecated(
"""
Prefer using device seq_lens directly to avoid implicit H<>D sync which breaks full
async scheduling. If a CPU copy is needed, it can be derived from
query_start_loc_cpu and seq_lens.
Will be removed in a future release (v0.15.0)
"""
)
def num_computed_tokens_cpu(self) -> torch.Tensor:
if self._num_computed_tokens_cpu is None:
query_seq_lens = (
self.query_start_loc_cpu[1:] - self.query_start_loc_cpu[:-1]
)
self._num_computed_tokens_cpu = self.seq_lens_cpu - query_seq_lens
return self._num_computed_tokens_cpu
def compute_num_computed_tokens(self) -> torch.Tensor:
"""Compute num_computed_tokens on device (seq_lens - query_lens)."""
if self._num_computed_tokens_cache is None:
query_lens = self.query_start_loc[1:] - self.query_start_loc[:-1]
self._num_computed_tokens_cache = self.seq_lens - query_lens
return self._num_computed_tokens_cache
# TODO(lucas): remove once we have FULL-CG spec-decode support
def unpadded(
self, num_actual_tokens: int, num_actual_reqs: int
) -> "CommonAttentionMetadata":
maybe_slice_reqs = lambda x: x[:num_actual_reqs] if x is not None else None
return CommonAttentionMetadata(
query_start_loc=self.query_start_loc[: num_actual_reqs + 1],
query_start_loc_cpu=self.query_start_loc_cpu[: num_actual_reqs + 1],
seq_lens=self.seq_lens[:num_actual_reqs],
_seq_lens_cpu=self._seq_lens_cpu[:num_actual_reqs]
if self._seq_lens_cpu is not None
else None,
_num_computed_tokens_cpu=self._num_computed_tokens_cpu[:num_actual_reqs]
if self._num_computed_tokens_cpu is not None
else None,
num_reqs=num_actual_reqs,
num_actual_tokens=num_actual_tokens,
max_query_len=self.max_query_len,
max_seq_len=self.max_seq_len,
block_table_tensor=self.block_table_tensor[:num_actual_reqs],
slot_mapping=self.slot_mapping[:num_actual_tokens],
causal=self.causal,
logits_indices_padded=self.logits_indices_padded,
num_logits_indices=self.num_logits_indices,
encoder_seq_lens=maybe_slice_reqs(self.encoder_seq_lens),
encoder_seq_lens_cpu=maybe_slice_reqs(self.encoder_seq_lens_cpu),
dcp_local_seq_lens=maybe_slice_reqs(self.dcp_local_seq_lens),
dcp_local_seq_lens_cpu=maybe_slice_reqs(self.dcp_local_seq_lens_cpu),
)
def slice_query_start_locs( def slice_query_start_locs(
query_start_loc: torch.Tensor, query_start_loc: torch.Tensor,
request_slice: slice, request_slice: slice,
...@@ -299,171 +178,6 @@ def split_attn_metadata( ...@@ -299,171 +178,6 @@ def split_attn_metadata(
return results return results
M = TypeVar("M")
class AttentionCGSupport(enum.Enum):
"""Constants for the cudagraph support of the attention backend
Here we do not consider the cascade attention, as currently
it is never cudagraph supported."""
ALWAYS = 3
"""Cudagraph always supported; supports mixed-prefill-decode"""
UNIFORM_BATCH = 2
"""Cudagraph supported for batches the only contain query lengths that are
the same, this can be used for spec-decode
i.e. "decodes" are 1 + num_speculative_tokens"""
UNIFORM_SINGLE_TOKEN_DECODE = 1
"""Cudagraph supported for batches the only contain query_len==1 decodes"""
NEVER = 0
"""NO cudagraph support"""
class AttentionMetadataBuilder(abc.ABC, Generic[M]):
# Does this backend/builder support CUDA Graphs for attention (default: no).
# Do not access directly. Call get_cudagraph_support() instead.
_cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.NEVER
# Does this backend/builder reorder the batch?
# If not, set this to None. Otherwise set it to the query
# length that will be pulled into the front of the batch.
reorder_batch_threshold: int | None = None
# Does this backend/builder support updating the block table in existing
# metadata
supports_update_block_table: bool = False
@abstractmethod
def __init__(
self,
kv_cache_spec: AttentionSpec,
layer_names: list[str],
vllm_config: VllmConfig,
device: torch.device,
):
self.kv_cache_spec = kv_cache_spec
self.layer_names = layer_names
self.vllm_config = vllm_config
self.device = device
@classmethod
def get_cudagraph_support(
cls: type["AttentionMetadataBuilder"],
vllm_config: VllmConfig,
kv_cache_spec: AttentionSpec,
) -> AttentionCGSupport:
"""Get the cudagraph support level of this builder class."""
return cls._cudagraph_support
def _init_reorder_batch_threshold(
self,
reorder_batch_threshold: int | None = 1,
supports_spec_as_decode: bool = False,
supports_dcp_with_varlen: bool = False,
) -> None:
self.reorder_batch_threshold = reorder_batch_threshold
if self.reorder_batch_threshold is not None and supports_spec_as_decode:
# If the backend supports spec-as-decode kernels, then we can set
# the reorder_batch_threshold based on the number of speculative
# tokens from the config.
speculative_config = self.vllm_config.speculative_config
if (
speculative_config is not None
and speculative_config.num_speculative_tokens is not None
):
self.reorder_batch_threshold = max(
self.reorder_batch_threshold,
1 + speculative_config.num_speculative_tokens,
)
if (
self.vllm_config.parallel_config.decode_context_parallel_size > 1
and not supports_dcp_with_varlen
):
self.reorder_batch_threshold = 1
@abstractmethod
def build(
self,
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata,
fast_build: bool = False,
) -> M:
"""
Central method that builds attention metadata.
Some builders (MLA) require reorder_batch to be called prior to build.
Args:
common_prefix_len: The length of the common prefix of the batch.
common_attn_metadata: The common attention metadata.
fast_build: The meta-data will prioritize speed of building over
then speed at execution. Can be used for spec-decode where the
result of a build call may only be used for few layers/iters.
"""
raise NotImplementedError
def update_block_table(
self,
metadata: M,
blk_table: torch.Tensor,
slot_mapping: torch.Tensor,
) -> M:
"""
Update the block table for the attention metadata.
Faster when theres multiple kv-cache groups that create virtually the
same metadata but just with different block tables.
Only needs to be implemented if supports_update_block_table is True.
"""
raise NotImplementedError
def build_for_cudagraph_capture(
self, common_attn_metadata: CommonAttentionMetadata
) -> M:
"""
Build attention metadata for CUDA graph capture. Uses build by default.
Subclasses that override this method should call self.build or
super().build_for_cudagraph_capture.
"""
return self.build(
common_prefix_len=0, common_attn_metadata=common_attn_metadata
)
def build_for_drafting(
self,
common_attn_metadata: CommonAttentionMetadata,
draft_index: int,
) -> M:
"""
Build attention metadata for draft model. Uses build by default.
Args:
common_attn_metadata: The common attention metadata.
draft_index: The index of the current draft operation.
When speculating a chain of tokens, this index refers to the
draft attempt for the i-th token.
For tree-based attention, this index instead refers to the
draft attempt for the i-th level in the tree of tokens.
"""
return self.build(
common_prefix_len=0,
common_attn_metadata=common_attn_metadata,
fast_build=True,
)
def use_cascade_attention(
self,
common_prefix_len: int,
query_lens: np.ndarray,
num_query_heads: int,
num_kv_heads: int,
use_alibi: bool,
use_sliding_window: bool,
use_local_attention: bool,
num_sms: int,
dcp_world_size: int,
) -> bool:
return False
@functools.lru_cache @functools.lru_cache
def get_kv_cache_layout(): def get_kv_cache_layout():
# Format specified by the code. # Format specified by the code.
...@@ -834,6 +548,9 @@ def make_kv_sharing_fast_prefill_common_attn_metadata( ...@@ -834,6 +548,9 @@ def make_kv_sharing_fast_prefill_common_attn_metadata(
return common_attn_metadata return common_attn_metadata
M = TypeVar("M")
def subclass_attention_backend( def subclass_attention_backend(
name_prefix: str, name_prefix: str,
attention_backend_cls: type[AttentionBackend], attention_backend_cls: type[AttentionBackend],
......
...@@ -26,16 +26,16 @@ from vllm.multimodal import MULTIMODAL_REGISTRY ...@@ -26,16 +26,16 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.triton_utils import triton from vllm.triton_utils import triton
from vllm.utils.platform_utils import is_pin_memory_available from vllm.utils.platform_utils import is_pin_memory_available
from vllm.v1.attention.backend import (
AttentionMetadataBuilder,
CommonAttentionMetadata,
)
from vllm.v1.attention.backends.registry import AttentionBackendEnum from vllm.v1.attention.backends.registry import AttentionBackendEnum
from vllm.v1.attention.backends.tree_attn import ( from vllm.v1.attention.backends.tree_attn import (
TreeAttentionMetadata, TreeAttentionMetadata,
TreeAttentionMetadataBuilder, TreeAttentionMetadataBuilder,
) )
from vllm.v1.attention.backends.triton_attn import TritonAttentionMetadata from vllm.v1.attention.backends.triton_attn import TritonAttentionMetadata
from vllm.v1.attention.backends.utils import (
AttentionMetadataBuilder,
CommonAttentionMetadata,
)
from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.sampler import _SAMPLING_EPS from vllm.v1.sample.sampler import _SAMPLING_EPS
......
...@@ -7,8 +7,8 @@ import torch ...@@ -7,8 +7,8 @@ import torch
from vllm.config import VllmConfig, get_layers_from_vllm_config from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.v1.attention.backend import AttentionBackend from vllm.v1.attention.backend import (
from vllm.v1.attention.backends.utils import ( AttentionBackend,
AttentionMetadataBuilder, AttentionMetadataBuilder,
CommonAttentionMetadata, CommonAttentionMetadata,
) )
......
...@@ -100,15 +100,15 @@ from vllm.utils.torch_utils import ( ...@@ -100,15 +100,15 @@ from vllm.utils.torch_utils import (
) )
from vllm.v1.attention.backend import ( from vllm.v1.attention.backend import (
AttentionBackend, AttentionBackend,
AttentionCGSupport,
AttentionMetadata, AttentionMetadata,
AttentionMetadataBuilder,
AttentionType, AttentionType,
CommonAttentionMetadata,
MultipleOf, MultipleOf,
) )
from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder
from vllm.v1.attention.backends.utils import ( from vllm.v1.attention.backends.utils import (
AttentionCGSupport,
AttentionMetadataBuilder,
CommonAttentionMetadata,
create_fast_prefill_custom_backend, create_fast_prefill_custom_backend,
get_dcp_local_seq_lens, get_dcp_local_seq_lens,
reorder_batch_to_split_decodes_and_prefills, reorder_batch_to_split_decodes_and_prefills,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment