Unverified Commit 20228cb8 authored by Matthew Bonanni, committed by GitHub
Browse files

[3/N][Attention] Move AttentionMetadata-related code from utils.py to backend.py (#32054)


Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
parent 7c0d3c51
...@@ -149,7 +149,7 @@ The CUDA Graphs wrapper no longer manages the warm-up logic. The warm-up process ...@@ -149,7 +149,7 @@ The CUDA Graphs wrapper no longer manages the warm-up logic. The warm-up process
## CUDA Graphs Compatibility of Attention Backends ## CUDA Graphs Compatibility of Attention Backends
To signal the CUDA Graphs compatibility of the attention backends, we introduce a new enum type, [AttentionCGSupport][vllm.v1.attention.backends.utils.AttentionCGSupport], which tracks the capability of the attention backend to support CUDA Graphs. The values are ordered by capability, i.e., `ALWAYS` > `UNIFORM_BATCH` > `UNIFORM_SINGLE_TOKEN_DECODE` > `NEVER`. To signal the CUDA Graphs compatibility of the attention backends, we introduce a new enum type, [AttentionCGSupport][vllm.v1.attention.backend.AttentionCGSupport], which tracks the capability of the attention backend to support CUDA Graphs. The values are ordered by capability, i.e., `ALWAYS` > `UNIFORM_BATCH` > `UNIFORM_SINGLE_TOKEN_DECODE` > `NEVER`.
```python ```python
class AttentionCGSupport(enum.Enum): class AttentionCGSupport(enum.Enum):
......
...@@ -23,10 +23,9 @@ from vllm.utils.torch_utils import ( ...@@ -23,10 +23,9 @@ from vllm.utils.torch_utils import (
is_torch_equal_or_newer, is_torch_equal_or_newer,
set_random_seed, set_random_seed,
) )
from vllm.v1.attention.backend import AttentionType from vllm.v1.attention.backend import AttentionType, CommonAttentionMetadata
from vllm.v1.attention.backends.registry import AttentionBackendEnum from vllm.v1.attention.backends.registry import AttentionBackendEnum
from vllm.v1.attention.backends.utils import ( from vllm.v1.attention.backends.utils import (
CommonAttentionMetadata,
set_kv_cache_layout, set_kv_cache_layout,
) )
from vllm.v1.kv_cache_interface import FullAttentionSpec from vllm.v1.kv_cache_interface import FullAttentionSpec
......
...@@ -22,10 +22,10 @@ from vllm.config.vllm import set_current_vllm_config ...@@ -22,10 +22,10 @@ from vllm.config.vllm import set_current_vllm_config
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.utils.math_utils import cdiv from vllm.utils.math_utils import cdiv
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
from vllm.v1.attention.backend import CommonAttentionMetadata
from vllm.v1.attention.backends.fa_utils import flash_attn_supports_mla from vllm.v1.attention.backends.fa_utils import flash_attn_supports_mla
from vllm.v1.attention.backends.mla.common import QueryLenSupport from vllm.v1.attention.backends.mla.common import QueryLenSupport
from vllm.v1.attention.backends.registry import AttentionBackendEnum from vllm.v1.attention.backends.registry import AttentionBackendEnum
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.attention.ops.flashmla import is_flashmla_dense_supported from vllm.v1.attention.ops.flashmla import is_flashmla_dense_supported
from vllm.v1.kv_cache_interface import FullAttentionSpec from vllm.v1.kv_cache_interface import FullAttentionSpec
......
...@@ -18,12 +18,12 @@ from vllm.config import ( ...@@ -18,12 +18,12 @@ from vllm.config import (
VllmConfig, VllmConfig,
) )
from vllm.config.model import ModelDType from vllm.config.model import ModelDType
from vllm.v1.attention.backend import AttentionImpl from vllm.v1.attention.backend import (
from vllm.v1.attention.backends.registry import AttentionBackendEnum AttentionImpl,
from vllm.v1.attention.backends.utils import (
AttentionMetadataBuilder, AttentionMetadataBuilder,
CommonAttentionMetadata, CommonAttentionMetadata,
) )
from vllm.v1.attention.backends.registry import AttentionBackendEnum
from vllm.v1.kv_cache_interface import FullAttentionSpec from vllm.v1.kv_cache_interface import FullAttentionSpec
......
...@@ -19,7 +19,7 @@ def sync_tracker(): ...@@ -19,7 +19,7 @@ def sync_tracker():
Fixture that patches CommonAttentionMetadata.seq_lens_cpu to detect Fixture that patches CommonAttentionMetadata.seq_lens_cpu to detect
lazy init syncs. Prints stack traces immediately when syncs occur. lazy init syncs. Prints stack traces immediately when syncs occur.
""" """
from vllm.v1.attention.backends.utils import CommonAttentionMetadata from vllm.v1.attention.backend import CommonAttentionMetadata
# Shared counter for cross-process communication (inherited by fork) # Shared counter for cross-process communication (inherited by fork)
sync_count = multiprocessing.Value("i", 0) sync_count = multiprocessing.Value("i", 0)
......
...@@ -12,9 +12,9 @@ from tests.v1.attention.utils import ( ...@@ -12,9 +12,9 @@ from tests.v1.attention.utils import (
try_get_attention_backend, try_get_attention_backend,
) )
from vllm.config import ParallelConfig, SpeculativeConfig from vllm.config import ParallelConfig, SpeculativeConfig
from vllm.v1.attention.backend import CommonAttentionMetadata
from vllm.v1.attention.backends.fa_utils import is_flash_attn_varlen_func_available from vllm.v1.attention.backends.fa_utils import is_flash_attn_varlen_func_available
from vllm.v1.attention.backends.registry import AttentionBackendEnum from vllm.v1.attention.backends.registry import AttentionBackendEnum
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
if not is_flash_attn_varlen_func_available(): if not is_flash_attn_varlen_func_available():
pytest.skip( pytest.skip(
......
...@@ -8,11 +8,13 @@ from vllm.attention.layer import Attention ...@@ -8,11 +8,13 @@ from vllm.attention.layer import Attention
from vllm.config import CacheConfig from vllm.config import CacheConfig
from vllm.config.vllm import VllmConfig from vllm.config.vllm import VllmConfig
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.v1.attention.backend import AttentionBackend from vllm.v1.attention.backend import (
from vllm.v1.attention.backends.utils import ( AttentionBackend,
AttentionCGSupport, AttentionCGSupport,
AttentionMetadataBuilder, AttentionMetadataBuilder,
CommonAttentionMetadata, CommonAttentionMetadata,
)
from vllm.v1.attention.backends.utils import (
make_local_attention_virtual_batches, make_local_attention_virtual_batches,
subclass_attention_backend, subclass_attention_backend,
) )
......
...@@ -14,9 +14,9 @@ from vllm.v1.attention.backend import ( ...@@ -14,9 +14,9 @@ from vllm.v1.attention.backend import (
AttentionBackend, AttentionBackend,
AttentionMetadata, AttentionMetadata,
AttentionType, AttentionType,
CommonAttentionMetadata,
) )
from vllm.v1.attention.backends.utils import ( from vllm.v1.attention.backends.utils import (
CommonAttentionMetadata,
subclass_attention_backend, subclass_attention_backend,
) )
from vllm.v1.attention.selector import get_attn_backend from vllm.v1.attention.selector import get_attn_backend
......
...@@ -12,9 +12,9 @@ from vllm.v1.attention.backend import ( ...@@ -12,9 +12,9 @@ from vllm.v1.attention.backend import (
AttentionBackend, AttentionBackend,
AttentionMetadata, AttentionMetadata,
AttentionType, AttentionType,
CommonAttentionMetadata,
) )
from vllm.v1.attention.backends.utils import ( from vllm.v1.attention.backends.utils import (
CommonAttentionMetadata,
subclass_attention_backend, subclass_attention_backend,
) )
from vllm.v1.attention.selector import get_attn_backend from vllm.v1.attention.selector import get_attn_backend
......
...@@ -15,9 +15,9 @@ from vllm.v1.attention.backend import ( ...@@ -15,9 +15,9 @@ from vllm.v1.attention.backend import (
AttentionBackend, AttentionBackend,
AttentionMetadata, AttentionMetadata,
AttentionType, AttentionType,
CommonAttentionMetadata,
) )
from vllm.v1.attention.backends.utils import ( from vllm.v1.attention.backends.utils import (
CommonAttentionMetadata,
subclass_attention_backend, subclass_attention_backend,
) )
from vllm.v1.attention.ops.triton_reshape_and_cache_flash import ( from vllm.v1.attention.ops.triton_reshape_and_cache_flash import (
......
...@@ -16,10 +16,10 @@ from vllm.v1.attention.backend import ( ...@@ -16,10 +16,10 @@ from vllm.v1.attention.backend import (
AttentionBackend, AttentionBackend,
AttentionMetadata, AttentionMetadata,
AttentionType, AttentionType,
CommonAttentionMetadata,
) )
from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend
from vllm.v1.attention.backends.utils import ( from vllm.v1.attention.backends.utils import (
CommonAttentionMetadata,
subclass_attention_backend_with_overrides, subclass_attention_backend_with_overrides,
) )
from vllm.v1.attention.selector import get_attn_backend from vllm.v1.attention.selector import get_attn_backend
......
...@@ -2,17 +2,22 @@ ...@@ -2,17 +2,22 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import Enum from enum import Enum
from typing import TYPE_CHECKING, ClassVar, Generic, Protocol, TypeVar, get_args from typing import TYPE_CHECKING, ClassVar, Generic, Protocol, TypeVar, get_args
import numpy as np
import torch import torch
from typing_extensions import deprecated
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.config import VllmConfig
from vllm.config.cache import CacheDType from vllm.config.cache import CacheDType
from vllm.model_executor.layers.linear import ColumnParallelLinear from vllm.model_executor.layers.linear import ColumnParallelLinear
from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey
from vllm.platforms.interface import DeviceCapability from vllm.platforms.interface import DeviceCapability
from vllm.v1.attention.backends.utils import KVCacheLayoutType from vllm.v1.attention.backends.utils import KVCacheLayoutType
from vllm.v1.kv_cache_interface import AttentionSpec
class AttentionType(str, Enum): class AttentionType(str, Enum):
...@@ -271,6 +276,288 @@ class AttentionMetadata: ...@@ -271,6 +276,288 @@ class AttentionMetadata:
T = TypeVar("T", bound=AttentionMetadata) T = TypeVar("T", bound=AttentionMetadata)
@dataclass
class CommonAttentionMetadata:
    """
    Per-batch attention metadata, shared across layers and backends.
    AttentionMetadataBuilder instances use it to construct per-layer metadata.

    For many of the tensors we keep both GPU and CPU versions.
    """

    query_start_loc: torch.Tensor
    query_start_loc_cpu: torch.Tensor
    """(batch_size + 1,), the start location of each request in query Tensor"""

    seq_lens: torch.Tensor
    """(batch_size,), the sequence length of each request. The number of
    already-computed tokens is derived from it as `seq_lens - query_lens`
    (see `compute_num_computed_tokens`)."""

    num_reqs: int
    """Number of requests"""
    # TODO(lucas): rename to num_tokens since it may be padded and this is misleading
    num_actual_tokens: int
    """Total number of tokens in batch"""
    max_query_len: int
    """Longest query in batch"""
    max_seq_len: int
    """Longest context length (may be an upper bound)"""

    block_table_tensor: torch.Tensor
    slot_mapping: torch.Tensor

    causal: bool = True

    # Needed by FastPrefillAttentionBuilder
    logits_indices_padded: torch.Tensor | None = None
    num_logits_indices: int | None = None

    # Needed by CrossAttentionBuilder
    encoder_seq_lens: torch.Tensor | None = None
    encoder_seq_lens_cpu: np.ndarray | None = None

    dcp_local_seq_lens: torch.Tensor | None = None
    dcp_local_seq_lens_cpu: torch.Tensor | None = None
    """Sequence lengths of the local rank in decode context parallelism world"""

    # WARNING: Deprecated fields. Will be removed in a future release (v0.15.0)
    _seq_lens_cpu: torch.Tensor | None = None
    _num_computed_tokens_cpu: torch.Tensor | None = None

    # Lazily filled by compute_num_computed_tokens(); starts out empty.
    _num_computed_tokens_cache: torch.Tensor | None = None

    @property
    @deprecated(
        # Adjacent literals keep the message value exactly as before while
        # letting the source stay indented.
        "\n"
        "Prefer using device seq_lens directly to avoid implicit H<>D sync.\n"
        "If a CPU copy is needed, use `seq_lens.cpu()` instead.\n"
        "Will be removed in a future release (v0.15.0)\n"
    )
    def seq_lens_cpu(self) -> torch.Tensor:
        """CPU copy of `seq_lens`, created on first access and then reused."""
        if self._seq_lens_cpu is None:
            self._seq_lens_cpu = self.seq_lens.to("cpu")
        return self._seq_lens_cpu

    @property
    @deprecated(
        "\n"
        "Prefer using device seq_lens directly to avoid implicit H<>D sync "
        "which breaks full\n"
        "async scheduling. If a CPU copy is needed, it can be derived from\n"
        "query_start_loc_cpu and seq_lens.\n"
        "Will be removed in a future release (v0.15.0)\n"
    )
    def num_computed_tokens_cpu(self) -> torch.Tensor:
        """CPU tensor of `seq_lens_cpu - query_lens`, computed once and reused."""
        if self._num_computed_tokens_cpu is None:
            query_seq_lens = (
                self.query_start_loc_cpu[1:] - self.query_start_loc_cpu[:-1]
            )
            self._num_computed_tokens_cpu = self.seq_lens_cpu - query_seq_lens
        return self._num_computed_tokens_cpu

    def compute_num_computed_tokens(self) -> torch.Tensor:
        """Compute num_computed_tokens on device (seq_lens - query_lens).

        The result is cached on the instance, so repeated calls return the
        same tensor object.
        """
        if self._num_computed_tokens_cache is None:
            query_lens = self.query_start_loc[1:] - self.query_start_loc[:-1]
            self._num_computed_tokens_cache = self.seq_lens - query_lens
        return self._num_computed_tokens_cache

    # TODO(lucas): remove once we have FULL-CG spec-decode support
    def unpadded(
        self, num_actual_tokens: int, num_actual_reqs: int
    ) -> "CommonAttentionMetadata":
        """Return a copy truncated to the real (non-padded) requests and tokens.

        Per-request tensors are sliced to `num_actual_reqs` entries
        (`num_actual_reqs + 1` for the query start locations) and
        `slot_mapping` to `num_actual_tokens` entries. `max_query_len`,
        `max_seq_len`, `causal` and the logits-index fields are carried over
        unchanged. `_num_computed_tokens_cache` is not forwarded, so the new
        object recomputes it on demand.

        Args:
            num_actual_tokens: Number of tokens to keep.
            num_actual_reqs: Number of requests to keep.

        Returns:
            A new CommonAttentionMetadata built from the sliced fields.
        """

        def _slice_reqs(x):
            # Optional per-request fields stay None when they were never set.
            return x[:num_actual_reqs] if x is not None else None

        return CommonAttentionMetadata(
            query_start_loc=self.query_start_loc[: num_actual_reqs + 1],
            query_start_loc_cpu=self.query_start_loc_cpu[: num_actual_reqs + 1],
            seq_lens=self.seq_lens[:num_actual_reqs],
            _seq_lens_cpu=_slice_reqs(self._seq_lens_cpu),
            _num_computed_tokens_cpu=_slice_reqs(self._num_computed_tokens_cpu),
            num_reqs=num_actual_reqs,
            num_actual_tokens=num_actual_tokens,
            max_query_len=self.max_query_len,
            max_seq_len=self.max_seq_len,
            block_table_tensor=self.block_table_tensor[:num_actual_reqs],
            slot_mapping=self.slot_mapping[:num_actual_tokens],
            causal=self.causal,
            logits_indices_padded=self.logits_indices_padded,
            num_logits_indices=self.num_logits_indices,
            encoder_seq_lens=_slice_reqs(self.encoder_seq_lens),
            encoder_seq_lens_cpu=_slice_reqs(self.encoder_seq_lens_cpu),
            dcp_local_seq_lens=_slice_reqs(self.dcp_local_seq_lens),
            dcp_local_seq_lens_cpu=_slice_reqs(self.dcp_local_seq_lens_cpu),
        )
# Type variable for the metadata object a builder produces: it parameterizes
# AttentionMetadataBuilder (Generic[M]) and is the return type of its build methods.
M = TypeVar("M")
class AttentionCGSupport(Enum):
    """Constants for the cudagraph support of the attention backend.

    Here we do not consider the cascade attention, as currently
    it is never cudagraph supported.

    Members are plain `Enum` values, so compare levels through `.value`
    (a higher value means broader support).
    """

    ALWAYS = 3
    """Cudagraph always supported; supports mixed-prefill-decode"""
    UNIFORM_BATCH = 2
    """Cudagraph supported for batches that only contain query lengths that are
    the same, this can be used for spec-decode
        i.e. "decodes" are 1 + num_speculative_tokens"""
    UNIFORM_SINGLE_TOKEN_DECODE = 1
    """Cudagraph supported for batches that only contain query_len==1 decodes"""
    NEVER = 0
    """NO cudagraph support"""
class AttentionMetadataBuilder(ABC, Generic[M]):
    """Abstract base for objects that turn CommonAttentionMetadata into
    backend-specific attention metadata of type `M`.

    Subclasses must implement `__init__` and `build`; the remaining methods
    have default implementations that either delegate to `build` or report
    that an optional feature is unsupported.
    """

    # Does this backend/builder support CUDA Graphs for attention (default: no).
    # Do not access directly. Call get_cudagraph_support() instead.
    _cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.NEVER

    # Does this backend/builder reorder the batch?
    # If not, set this to None. Otherwise set it to the query
    # length that will be pulled into the front of the batch.
    reorder_batch_threshold: int | None = None

    # Does this backend/builder support updating the block table in existing
    # metadata
    supports_update_block_table: bool = False

    @abstractmethod
    def __init__(
        self,
        kv_cache_spec: "AttentionSpec",
        layer_names: list[str],
        vllm_config: "VllmConfig",
        device: torch.device,
    ):
        """Store the construction arguments on the instance.

        Declared abstract so every concrete builder defines its own
        `__init__`; subclasses may still call this one via `super()`.
        """
        self.kv_cache_spec = kv_cache_spec
        self.layer_names = layer_names
        self.vllm_config = vllm_config
        self.device = device

    @classmethod
    def get_cudagraph_support(
        cls: type["AttentionMetadataBuilder"],
        vllm_config: "VllmConfig",
        kv_cache_spec: "AttentionSpec",
    ) -> AttentionCGSupport:
        """Get the cudagraph support level of this builder class.

        The default implementation ignores both arguments and returns the
        class-level `_cudagraph_support`.
        """
        return cls._cudagraph_support

    def _init_reorder_batch_threshold(
        self,
        reorder_batch_threshold: int | None = 1,
        supports_spec_as_decode: bool = False,
        supports_dcp_with_varlen: bool = False,
    ) -> None:
        """Set `self.reorder_batch_threshold` from the arguments and the config.

        Args:
            reorder_batch_threshold: Starting threshold, or None.
            supports_spec_as_decode: When True (and the threshold is not
                None), the threshold is raised to at least
                `1 + num_speculative_tokens` if speculative decoding is
                configured.
            supports_dcp_with_varlen: When False and
                `decode_context_parallel_size > 1`, the threshold is forced
                to 1.
        """
        self.reorder_batch_threshold = reorder_batch_threshold
        if self.reorder_batch_threshold is not None and supports_spec_as_decode:
            # If the backend supports spec-as-decode kernels, then we can set
            # the reorder_batch_threshold based on the number of speculative
            # tokens from the config.
            speculative_config = self.vllm_config.speculative_config
            if (
                speculative_config is not None
                and speculative_config.num_speculative_tokens is not None
            ):
                self.reorder_batch_threshold = max(
                    self.reorder_batch_threshold,
                    1 + speculative_config.num_speculative_tokens,
                )

        # This override wins over whatever was chosen above.
        if (
            self.vllm_config.parallel_config.decode_context_parallel_size > 1
            and not supports_dcp_with_varlen
        ):
            self.reorder_batch_threshold = 1

    @abstractmethod
    def build(
        self,
        common_prefix_len: int,
        common_attn_metadata: CommonAttentionMetadata,
        fast_build: bool = False,
    ) -> M:
        """
        Central method that builds attention metadata.
        Some builders (MLA) require reorder_batch to be called prior to build.

        Args:
            common_prefix_len: The length of the common prefix of the batch.
            common_attn_metadata: The common attention metadata.
            fast_build: The meta-data will prioritize speed of building over
                the speed at execution. Can be used for spec-decode where the
                result of a build call may only be used for few layers/iters.
        """
        raise NotImplementedError

    def update_block_table(
        self,
        metadata: M,
        blk_table: torch.Tensor,
        slot_mapping: torch.Tensor,
    ) -> M:
        """
        Update the block table for the attention metadata.
        Faster when there are multiple kv-cache groups that create virtually
        the same metadata but just with different block tables.

        Only needs to be implemented if supports_update_block_table is True.
        """
        raise NotImplementedError

    def build_for_cudagraph_capture(
        self, common_attn_metadata: CommonAttentionMetadata
    ) -> M:
        """
        Build attention metadata for CUDA graph capture. Uses build by default.
        Subclasses that override this method should call self.build or
        super().build_for_cudagraph_capture.
        """
        return self.build(
            common_prefix_len=0, common_attn_metadata=common_attn_metadata
        )

    def build_for_drafting(
        self,
        common_attn_metadata: CommonAttentionMetadata,
        draft_index: int,
    ) -> M:
        """
        Build attention metadata for draft model. Uses build by default
        (with `fast_build=True`); the default does not use `draft_index`.

        Args:
            common_attn_metadata: The common attention metadata.
            draft_index: The index of the current draft operation.
                When speculating a chain of tokens, this index refers to the
                draft attempt for the i-th token.
                For tree-based attention, this index instead refers to the
                draft attempt for the i-th level in the tree of tokens.
        """
        return self.build(
            common_prefix_len=0,
            common_attn_metadata=common_attn_metadata,
            fast_build=True,
        )

    def use_cascade_attention(
        self,
        common_prefix_len: int,
        query_lens: np.ndarray,
        num_query_heads: int,
        num_kv_heads: int,
        use_alibi: bool,
        use_sliding_window: bool,
        use_local_attention: bool,
        num_sms: int,
        dcp_world_size: int,
    ) -> bool:
        """Report whether cascade attention should be used for this batch.

        The base implementation always returns False.
        """
        return False
class AttentionLayer(Protocol): class AttentionLayer(Protocol):
_q_scale: torch.Tensor _q_scale: torch.Tensor
_k_scale: torch.Tensor _k_scale: torch.Tensor
......
...@@ -13,12 +13,12 @@ from vllm.v1.attention.backend import ( ...@@ -13,12 +13,12 @@ from vllm.v1.attention.backend import (
AttentionBackend, AttentionBackend,
AttentionImpl, AttentionImpl,
AttentionLayer, AttentionLayer,
AttentionMetadataBuilder,
AttentionType, AttentionType,
CommonAttentionMetadata,
is_quantized_kv_cache, is_quantized_kv_cache,
) )
from vllm.v1.attention.backends.utils import ( from vllm.v1.attention.backends.utils import (
AttentionMetadataBuilder,
CommonAttentionMetadata,
split_decodes_and_prefills, split_decodes_and_prefills,
) )
from vllm.v1.kv_cache_interface import AttentionSpec, CrossAttentionSpec from vllm.v1.kv_cache_interface import AttentionSpec, CrossAttentionSpec
......
...@@ -41,10 +41,12 @@ from vllm.model_executor.layers.batch_invariant import ( ...@@ -41,10 +41,12 @@ from vllm.model_executor.layers.batch_invariant import (
) )
from vllm.platforms.interface import DeviceCapability from vllm.platforms.interface import DeviceCapability
from vllm.utils.math_utils import cdiv from vllm.utils.math_utils import cdiv
from vllm.v1.attention.backends.utils import ( from vllm.v1.attention.backend import (
AttentionCGSupport, AttentionCGSupport,
AttentionMetadataBuilder, AttentionMetadataBuilder,
CommonAttentionMetadata, CommonAttentionMetadata,
)
from vllm.v1.attention.backends.utils import (
get_dcp_local_seq_lens, get_dcp_local_seq_lens,
get_kv_cache_layout, get_kv_cache_layout,
) )
......
...@@ -43,14 +43,14 @@ from vllm.utils.platform_utils import is_pin_memory_available ...@@ -43,14 +43,14 @@ from vllm.utils.platform_utils import is_pin_memory_available
from vllm.utils.torch_utils import is_strictly_contiguous from vllm.utils.torch_utils import is_strictly_contiguous
from vllm.v1.attention.backend import ( from vllm.v1.attention.backend import (
AttentionBackend, AttentionBackend,
AttentionCGSupport,
AttentionImpl, AttentionImpl,
AttentionMetadataBuilder,
AttentionType, AttentionType,
CommonAttentionMetadata,
MultipleOf, MultipleOf,
) )
from vllm.v1.attention.backends.utils import ( from vllm.v1.attention.backends.utils import (
AttentionCGSupport,
AttentionMetadataBuilder,
CommonAttentionMetadata,
KVCacheLayoutType, KVCacheLayoutType,
get_dcp_local_seq_lens, get_dcp_local_seq_lens,
get_kv_cache_layout, get_kv_cache_layout,
......
...@@ -32,12 +32,10 @@ from vllm.utils.torch_utils import is_torch_equal_or_newer ...@@ -32,12 +32,10 @@ from vllm.utils.torch_utils import is_torch_equal_or_newer
from vllm.v1.attention.backend import ( from vllm.v1.attention.backend import (
AttentionBackend, AttentionBackend,
AttentionImpl, AttentionImpl,
AttentionType,
is_quantized_kv_cache,
)
from vllm.v1.attention.backends.utils import (
AttentionMetadataBuilder, AttentionMetadataBuilder,
AttentionType,
CommonAttentionMetadata, CommonAttentionMetadata,
is_quantized_kv_cache,
) )
from vllm.v1.kv_cache_interface import AttentionSpec from vllm.v1.kv_cache_interface import AttentionSpec
......
...@@ -7,12 +7,14 @@ from dataclasses import dataclass ...@@ -7,12 +7,14 @@ from dataclasses import dataclass
import torch import torch
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.v1.attention.backend import AttentionBackend from vllm.v1.attention.backend import (
from vllm.v1.attention.backends.utils import ( AttentionBackend,
PAD_SLOT_ID,
AttentionCGSupport, AttentionCGSupport,
AttentionMetadataBuilder, AttentionMetadataBuilder,
CommonAttentionMetadata, CommonAttentionMetadata,
)
from vllm.v1.attention.backends.utils import (
PAD_SLOT_ID,
compute_causal_conv1d_metadata, compute_causal_conv1d_metadata,
split_decodes_and_prefills, split_decodes_and_prefills,
) )
......
...@@ -5,13 +5,13 @@ from dataclasses import dataclass ...@@ -5,13 +5,13 @@ from dataclasses import dataclass
import torch import torch
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.v1.attention.backend import AttentionBackend from vllm.v1.attention.backend import (
from vllm.v1.attention.backends.utils import ( AttentionBackend,
AttentionCGSupport, AttentionCGSupport,
AttentionMetadataBuilder, AttentionMetadataBuilder,
CommonAttentionMetadata, CommonAttentionMetadata,
split_decodes_and_prefills,
) )
from vllm.v1.attention.backends.utils import split_decodes_and_prefills
from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec
......
...@@ -7,14 +7,11 @@ import torch ...@@ -7,14 +7,11 @@ import torch
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.utils.math_utils import cdiv from vllm.utils.math_utils import cdiv
from vllm.v1.attention.backend import AttentionBackend from vllm.v1.attention.backend import AttentionBackend, CommonAttentionMetadata
from vllm.v1.attention.backends.mamba_attn import ( from vllm.v1.attention.backends.mamba_attn import (
BaseMambaAttentionMetadata, BaseMambaAttentionMetadata,
BaseMambaAttentionMetadataBuilder, BaseMambaAttentionMetadataBuilder,
) )
from vllm.v1.attention.backends.utils import (
CommonAttentionMetadata,
)
from vllm.v1.kv_cache_interface import AttentionSpec from vllm.v1.kv_cache_interface import AttentionSpec
......
...@@ -10,11 +10,13 @@ import torch ...@@ -10,11 +10,13 @@ import torch
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.utils.math_utils import cdiv from vllm.utils.math_utils import cdiv
from vllm.v1.attention.backends.utils import ( from vllm.v1.attention.backend import (
PAD_SLOT_ID,
AttentionCGSupport, AttentionCGSupport,
AttentionMetadataBuilder, AttentionMetadataBuilder,
CommonAttentionMetadata, CommonAttentionMetadata,
)
from vllm.v1.attention.backends.utils import (
PAD_SLOT_ID,
compute_causal_conv1d_metadata, compute_causal_conv1d_metadata,
split_decodes_and_prefills, split_decodes_and_prefills,
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment