Unverified Commit 0a2c2dc3 authored by Jack Yang's avatar Jack Yang Committed by GitHub
Browse files

fixed mypy warnings for files vllm/v1/attention with TEMPORARY workaround (#31465)


Signed-off-by: default avatarZhuohao Yang <zy242@cornell.edu>
Co-authored-by: default avatarZhuohao Yang <zy242@cornell.edu>
Co-authored-by: default avatarWentao Ye <44945378+yewentao256@users.noreply.github.com>
parent f09c5feb
...@@ -41,6 +41,7 @@ FILES = [ ...@@ -41,6 +41,7 @@ FILES = [
"vllm/usage", "vllm/usage",
"vllm/utils", "vllm/utils",
"vllm/worker", "vllm/worker",
"vllm/v1/attention",
"vllm/v1/core", "vllm/v1/core",
"vllm/v1/engine", "vllm/v1/engine",
"vllm/v1/executor", "vllm/v1/executor",
...@@ -60,7 +61,6 @@ SEPARATE_GROUPS = [ ...@@ -60,7 +61,6 @@ SEPARATE_GROUPS = [
"vllm/lora", "vllm/lora",
"vllm/model_executor", "vllm/model_executor",
# v1 related # v1 related
"vllm/v1/attention",
"vllm/v1/kv_offload", "vllm/v1/kv_offload",
"vllm/v1/spec_decode", "vllm/v1/spec_decode",
"vllm/v1/structured_output", "vllm/v1/structured_output",
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from enum import Enum
from typing import TYPE_CHECKING, ClassVar, Generic, Protocol, TypeVar, get_args from typing import TYPE_CHECKING, ClassVar, Generic, Protocol, TypeVar, get_args
import torch import torch
...@@ -14,7 +15,7 @@ if TYPE_CHECKING: ...@@ -14,7 +15,7 @@ if TYPE_CHECKING:
from vllm.v1.attention.backends.utils import KVCacheLayoutType from vllm.v1.attention.backends.utils import KVCacheLayoutType
class AttentionType: class AttentionType(str, Enum):
""" """
Attention type. Attention type.
Use string to be compatible with `torch.compile`. Use string to be compatible with `torch.compile`.
...@@ -193,7 +194,7 @@ class AttentionBackend(ABC): ...@@ -193,7 +194,7 @@ class AttentionBackend(ABC):
head_size: int, head_size: int,
dtype: torch.dtype, dtype: torch.dtype,
kv_cache_dtype: "CacheDType | None", kv_cache_dtype: "CacheDType | None",
block_size: int | None, block_size: int,
use_mla: bool, use_mla: bool,
has_sink: bool, has_sink: bool,
use_sparse: bool, use_sparse: bool,
...@@ -207,7 +208,7 @@ class AttentionBackend(ABC): ...@@ -207,7 +208,7 @@ class AttentionBackend(ABC):
head_size: int, head_size: int,
dtype: torch.dtype, dtype: torch.dtype,
kv_cache_dtype: "CacheDType | None", kv_cache_dtype: "CacheDType | None",
block_size: int | None, block_size: int,
use_mla: bool, use_mla: bool,
has_sink: bool, has_sink: bool,
use_sparse: bool, use_sparse: bool,
...@@ -290,6 +291,11 @@ class AttentionLayer(Protocol): ...@@ -290,6 +291,11 @@ class AttentionLayer(Protocol):
class AttentionImpl(ABC, Generic[T]): class AttentionImpl(ABC, Generic[T]):
# Required attributes that all impls should have
num_heads: int
head_size: int
scale: float
# Whether the attention impl can return the softmax lse for decode. # Whether the attention impl can return the softmax lse for decode.
# Some features like decode context parallelism require the softmax lse. # Some features like decode context parallelism require the softmax lse.
can_return_lse_for_decode: bool = False can_return_lse_for_decode: bool = False
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from vllm.attention.backends.abstract import AttentionBackend from vllm.attention.backends.abstract import AttentionBackend, AttentionImpl
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.v1.kv_cache_interface import KVCacheSpec from vllm.v1.kv_cache_interface import KVCacheSpec
...@@ -18,6 +18,8 @@ class AttentionLayerBase(ABC): ...@@ -18,6 +18,8 @@ class AttentionLayerBase(ABC):
from different layer types. from different layer types.
""" """
impl: "AttentionImpl"
@abstractmethod @abstractmethod
def get_attn_backend(self) -> type[AttentionBackend]: def get_attn_backend(self) -> type[AttentionBackend]:
"""Get the attention backend class for this layer.""" """Get the attention backend class for this layer."""
......
...@@ -167,7 +167,7 @@ class FlashAttentionBackend(AttentionBackend): ...@@ -167,7 +167,7 @@ class FlashAttentionBackend(AttentionBackend):
head_size: int, head_size: int,
dtype: torch.dtype, dtype: torch.dtype,
kv_cache_dtype: CacheDType | None, kv_cache_dtype: CacheDType | None,
block_size: int, block_size: int | None,
use_mla: bool, use_mla: bool,
has_sink: bool, has_sink: bool,
use_sparse: bool, use_sparse: bool,
...@@ -354,7 +354,11 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad ...@@ -354,7 +354,11 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
aot_schedule = False aot_schedule = False
max_num_splits = 0 # 0 means use FA3's heuristics, not CG compatible max_num_splits = 0 # 0 means use FA3's heuristics, not CG compatible
if self.use_full_cuda_graph and num_actual_tokens <= self.max_cudagraph_size: if (
self.use_full_cuda_graph
and self.max_cudagraph_size is not None
and num_actual_tokens <= self.max_cudagraph_size
):
# NOTE(woosuk): Setting num_splits > 1 may increase the memory # NOTE(woosuk): Setting num_splits > 1 may increase the memory
# usage, because the intermediate buffers of size [num_splits, # usage, because the intermediate buffers of size [num_splits,
# num_heads, num_tokens, head_size] are allocated. Therefore, # num_heads, num_tokens, head_size] are allocated. Therefore,
...@@ -599,6 +603,9 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -599,6 +603,9 @@ class FlashAttentionImpl(AttentionImpl):
We use torch's .expand() to avoid duplicating values We use torch's .expand() to avoid duplicating values
""" """
assert output is not None, "Output tensor must be provided." assert output is not None, "Output tensor must be provided."
assert self.vllm_flash_attn_version is not None, (
"FlashAttention version not detected."
)
if output_scale is not None or output_block_scale is not None: if output_scale is not None or output_block_scale is not None:
raise NotImplementedError( raise NotImplementedError(
...@@ -697,6 +704,11 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -697,6 +704,11 @@ class FlashAttentionImpl(AttentionImpl):
) )
return output return output
else: else:
sliding_window_size = (
list(self.sliding_window)
if self.sliding_window is not None
else None
)
flash_attn_varlen_func( flash_attn_varlen_func(
q=query[:num_actual_tokens], q=query[:num_actual_tokens],
k=key_cache, k=key_cache,
...@@ -709,7 +721,7 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -709,7 +721,7 @@ class FlashAttentionImpl(AttentionImpl):
softmax_scale=self.scale, softmax_scale=self.scale,
causal=attn_metadata.causal, causal=attn_metadata.causal,
alibi_slopes=self.alibi_slopes, alibi_slopes=self.alibi_slopes,
window_size=self.sliding_window, window_size=sliding_window_size,
block_table=block_table, block_table=block_table,
softcap=self.logits_soft_cap, softcap=self.logits_soft_cap,
scheduler_metadata=scheduler_metadata, scheduler_metadata=scheduler_metadata,
...@@ -764,12 +776,19 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -764,12 +776,19 @@ class FlashAttentionImpl(AttentionImpl):
k_descale: torch.Tensor | None = None, k_descale: torch.Tensor | None = None,
v_descale: torch.Tensor | None = None, v_descale: torch.Tensor | None = None,
) -> torch.Tensor: ) -> torch.Tensor:
assert self.vllm_flash_attn_version is not None, (
"FlashAttention version not detected."
)
cu_seqlens_q = attn_metadata.query_start_loc cu_seqlens_q = attn_metadata.query_start_loc
max_seqlen_q = attn_metadata.max_query_len max_seqlen_q = attn_metadata.max_query_len
block_table = attn_metadata.block_table block_table = attn_metadata.block_table
query = query.contiguous() query = query.contiguous()
query_across_dcp = get_dcp_group().all_gather(query, dim=1) query_across_dcp = get_dcp_group().all_gather(query, dim=1)
sliding_window_size = (
list(self.sliding_window) if self.sliding_window is not None else None
)
context_attn_out, context_lse = flash_attn_varlen_func( context_attn_out, context_lse = flash_attn_varlen_func(
q=query_across_dcp, q=query_across_dcp,
k=key_cache, k=key_cache,
...@@ -782,7 +801,7 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -782,7 +801,7 @@ class FlashAttentionImpl(AttentionImpl):
softmax_scale=self.scale, softmax_scale=self.scale,
causal=False, causal=False,
alibi_slopes=self.alibi_slopes, alibi_slopes=self.alibi_slopes,
window_size=self.sliding_window, window_size=sliding_window_size,
block_table=block_table, block_table=block_table,
softcap=self.logits_soft_cap, softcap=self.logits_soft_cap,
return_softmax_lse=True, return_softmax_lse=True,
...@@ -813,7 +832,7 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -813,7 +832,7 @@ class FlashAttentionImpl(AttentionImpl):
softmax_scale=self.scale, softmax_scale=self.scale,
causal=attn_metadata.causal, causal=attn_metadata.causal,
alibi_slopes=self.alibi_slopes, alibi_slopes=self.alibi_slopes,
window_size=self.sliding_window, window_size=sliding_window_size,
softcap=self.logits_soft_cap, softcap=self.logits_soft_cap,
return_softmax_lse=True, return_softmax_lse=True,
fa_version=self.vllm_flash_attn_version, fa_version=self.vllm_flash_attn_version,
...@@ -850,6 +869,10 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -850,6 +869,10 @@ class FlashAttentionImpl(AttentionImpl):
attn_metadata: Encoder attention metadata attn_metadata: Encoder attention metadata
layer: The attention layer layer: The attention layer
""" """
assert self.vllm_flash_attn_version is not None, (
"FlashAttention version not detected."
)
# For encoder attention, process FP8 quantization if needed # For encoder attention, process FP8 quantization if needed
if self.kv_cache_dtype.startswith("fp8"): if self.kv_cache_dtype.startswith("fp8"):
raise NotImplementedError( raise NotImplementedError(
...@@ -868,6 +891,9 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -868,6 +891,9 @@ class FlashAttentionImpl(AttentionImpl):
) )
# Call flash attention directly on Q, K, V tensors # Call flash attention directly on Q, K, V tensors
sliding_window_size = (
list(self.sliding_window) if self.sliding_window is not None else None
)
flash_attn_varlen_func( flash_attn_varlen_func(
q=query, q=query,
k=key, k=key,
...@@ -880,7 +906,7 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -880,7 +906,7 @@ class FlashAttentionImpl(AttentionImpl):
softmax_scale=self.scale, softmax_scale=self.scale,
causal=False, # Encoder attention is bidirectional causal=False, # Encoder attention is bidirectional
alibi_slopes=self.alibi_slopes, alibi_slopes=self.alibi_slopes,
window_size=self.sliding_window, window_size=sliding_window_size,
softcap=self.logits_soft_cap, softcap=self.logits_soft_cap,
fa_version=self.vllm_flash_attn_version, fa_version=self.vllm_flash_attn_version,
q_descale=layer._q_scale.expand(descale_shape), q_descale=layer._q_scale.expand(descale_shape),
...@@ -1020,7 +1046,7 @@ def cascade_attention( ...@@ -1020,7 +1046,7 @@ def cascade_attention(
max_seqlen_k=common_prefix_len, max_seqlen_k=common_prefix_len,
softmax_scale=softmax_scale, softmax_scale=softmax_scale,
causal=False, causal=False,
window_size=sliding_window, window_size=list(sliding_window),
block_table=block_table[:1], block_table=block_table[:1],
softcap=logits_soft_cap, softcap=logits_soft_cap,
return_softmax_lse=True, return_softmax_lse=True,
...@@ -1048,7 +1074,7 @@ def cascade_attention( ...@@ -1048,7 +1074,7 @@ def cascade_attention(
max_seqlen_k=max_kv_len - common_prefix_len, max_seqlen_k=max_kv_len - common_prefix_len,
softmax_scale=softmax_scale, softmax_scale=softmax_scale,
causal=True, causal=True,
window_size=sliding_window, window_size=list(sliding_window),
block_table=block_table[:, num_common_kv_blocks:], block_table=block_table[:, num_common_kv_blocks:],
softcap=logits_soft_cap, softcap=logits_soft_cap,
return_softmax_lse=True, return_softmax_lse=True,
......
...@@ -113,6 +113,9 @@ class FlashAttentionDiffKVImpl(FlashAttentionImpl): ...@@ -113,6 +113,9 @@ class FlashAttentionDiffKVImpl(FlashAttentionImpl):
We use torch's .expand() to avoid duplicating values We use torch's .expand() to avoid duplicating values
""" """
assert output is not None, "Output tensor must be provided." assert output is not None, "Output tensor must be provided."
assert self.vllm_flash_attn_version is not None, (
"FlashAttention version not detected."
)
if output_scale is not None or output_block_scale is not None: if output_scale is not None or output_block_scale is not None:
raise NotImplementedError( raise NotImplementedError(
...@@ -214,6 +217,11 @@ class FlashAttentionDiffKVImpl(FlashAttentionImpl): ...@@ -214,6 +217,11 @@ class FlashAttentionDiffKVImpl(FlashAttentionImpl):
) )
return output return output
else: else:
sliding_window_size = (
list(self.sliding_window)
if self.sliding_window is not None
else None
)
flash_attn_varlen_func( flash_attn_varlen_func(
q=query[:num_actual_tokens], q=query[:num_actual_tokens],
k=key_cache, k=key_cache,
...@@ -226,7 +234,7 @@ class FlashAttentionDiffKVImpl(FlashAttentionImpl): ...@@ -226,7 +234,7 @@ class FlashAttentionDiffKVImpl(FlashAttentionImpl):
softmax_scale=self.scale, softmax_scale=self.scale,
causal=attn_metadata.causal, causal=attn_metadata.causal,
alibi_slopes=self.alibi_slopes, alibi_slopes=self.alibi_slopes,
window_size=self.sliding_window, window_size=sliding_window_size,
block_table=block_table, block_table=block_table,
softcap=self.logits_soft_cap, softcap=self.logits_soft_cap,
scheduler_metadata=scheduler_metadata, scheduler_metadata=scheduler_metadata,
......
...@@ -530,11 +530,12 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): ...@@ -530,11 +530,12 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
self._decode_wrappers_cudagraph: dict[ self._decode_wrappers_cudagraph: dict[
int, BatchDecodeWithPagedKVCacheWrapper int, BatchDecodeWithPagedKVCacheWrapper
] = {} ] = {}
self._decode_cudagraph_max_bs = min( self._decode_cudagraph_max_bs = (1 + num_spec_tokens) * max_num_reqs
(1 + num_spec_tokens) * max_num_reqs, if self.compilation_config.max_cudagraph_capture_size is not None:
self.compilation_config.max_cudagraph_capture_size, self._decode_cudagraph_max_bs = min(
) self._decode_cudagraph_max_bs,
self.compilation_config.max_cudagraph_capture_size,
)
try: try:
self.dcp_world_size = get_dcp_group().world_size self.dcp_world_size = get_dcp_group().world_size
self.dcp_rank = get_dcp_group().rank_in_group self.dcp_rank = get_dcp_group().rank_in_group
......
...@@ -215,7 +215,7 @@ def physical_to_logical_mapping( ...@@ -215,7 +215,7 @@ def physical_to_logical_mapping(
) )
# Only process valid blocks to avoid garbage values # Only process valid blocks to avoid garbage values
num_blocks_per_seq = cdiv(seq_lens, block_size) num_blocks_per_seq: torch.Tensor = cdiv(seq_lens, block_size)
mask = ( mask = (
torch.arange(max_num_blocks, device=device)[None, :] torch.arange(max_num_blocks, device=device)[None, :]
< num_blocks_per_seq[:, None] < num_blocks_per_seq[:, None]
......
...@@ -75,8 +75,10 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata] ...@@ -75,8 +75,10 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
self.compilation_config = vllm_config.compilation_config self.compilation_config = vllm_config.compilation_config
self.speculative_config = vllm_config.speculative_config self.speculative_config = vllm_config.speculative_config
self.kv_cache_spec = kv_cache_spec self.kv_cache_spec = kv_cache_spec
if self.speculative_config: if self.speculative_config:
self.num_spec = self.speculative_config.num_speculative_tokens assert self.speculative_config.num_speculative_tokens is not None
self.num_spec: int = self.speculative_config.num_speculative_tokens
else: else:
self.num_spec = 0 self.num_spec = 0
self.use_spec_decode = self.num_spec > 0 self.use_spec_decode = self.num_spec > 0
...@@ -85,10 +87,15 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata] ...@@ -85,10 +87,15 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
self.use_full_cuda_graph = ( self.use_full_cuda_graph = (
self.compilation_config.cudagraph_mode.has_full_cudagraphs() self.compilation_config.cudagraph_mode.has_full_cudagraphs()
) )
self.decode_cudagraph_max_bs = min(
self.vllm_config.scheduler_config.max_num_seqs * (self.num_spec + 1), self.decode_cudagraph_max_bs = (
self.compilation_config.max_cudagraph_capture_size, self.vllm_config.scheduler_config.max_num_seqs * (self.num_spec + 1)
) )
if self.compilation_config.max_cudagraph_capture_size is not None:
self.decode_cudagraph_max_bs = min(
self.decode_cudagraph_max_bs,
self.compilation_config.max_cudagraph_capture_size,
)
self.spec_state_indices_tensor = torch.empty( self.spec_state_indices_tensor = torch.empty(
(self.decode_cudagraph_max_bs, self.num_spec + 1), (self.decode_cudagraph_max_bs, self.num_spec + 1),
......
...@@ -123,10 +123,11 @@ class Mamba2AttentionMetadataBuilder( ...@@ -123,10 +123,11 @@ class Mamba2AttentionMetadataBuilder(
device: torch.device, device: torch.device,
): ):
super().__init__(kv_cache_spec, layer_names, vllm_config, device) super().__init__(kv_cache_spec, layer_names, vllm_config, device)
self.chunk_size = vllm_config.model_config.get_mamba_chunk_size() chunk_size = vllm_config.model_config.get_mamba_chunk_size()
assert self.chunk_size is not None, ( assert chunk_size is not None, (
"chunk_size needs to be set in the model config for Mamba2 models" "chunk_size needs to be set in the model config for Mamba2 models"
) )
self.chunk_size: int = chunk_size
def _compute_chunk_metadata( def _compute_chunk_metadata(
self, self,
......
...@@ -69,10 +69,12 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC): ...@@ -69,10 +69,12 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
assert isinstance(kv_cache_spec, MambaSpec) assert isinstance(kv_cache_spec, MambaSpec)
self.compilation_config = vllm_config.compilation_config self.compilation_config = vllm_config.compilation_config
self.decode_cudagraph_max_bs = min( self.decode_cudagraph_max_bs = self.vllm_config.scheduler_config.max_num_seqs
self.vllm_config.scheduler_config.max_num_seqs, if self.compilation_config.max_cudagraph_capture_size is not None:
self.compilation_config.max_cudagraph_capture_size, self.decode_cudagraph_max_bs = min(
) self.decode_cudagraph_max_bs,
self.compilation_config.max_cudagraph_capture_size,
)
if self.vllm_config.cache_config.enable_prefix_caching: if self.vllm_config.cache_config.enable_prefix_caching:
self.state_indices_tensor = torch.empty( self.state_indices_tensor = torch.empty(
...@@ -150,9 +152,13 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC): ...@@ -150,9 +152,13 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
cdiv(common_attn_metadata.seq_lens, mamba_block_size) - 1 cdiv(common_attn_metadata.seq_lens, mamba_block_size) - 1
) )
# -1 in case it's non-computed and causes later issues with indexing # -1 in case it's non-computed and causes later issues with indexing
block_idx_last_computed_token = block_idx_last_computed_token.clamp(min=0) block_idx_last_computed_token = torch.clamp(
block_idx_last_computed_token, min=0
)
# -1 in the case we have a padded request (0 seq-len) # -1 in the case we have a padded request (0 seq-len)
block_idx_last_scheduled_token = block_idx_last_scheduled_token.clamp(min=0) block_idx_last_scheduled_token = torch.clamp(
block_idx_last_scheduled_token, min=0
)
return ( return (
block_idx_last_computed_token, block_idx_last_computed_token,
......
...@@ -62,7 +62,7 @@ class AiterTritonMLAImpl(AiterMLAImpl): ...@@ -62,7 +62,7 @@ class AiterTritonMLAImpl(AiterMLAImpl):
k, k,
v, v,
softmax_scale=softmax_scale, softmax_scale=softmax_scale,
return_lse=return_softmax_lse, return_softmax_lse=return_softmax_lse,
**kwargs, **kwargs,
) )
# Transpose the LSE if Triton MHA is used: # Transpose the LSE if Triton MHA is used:
......
...@@ -202,6 +202,7 @@ from vllm._aiter_ops import rocm_aiter_ops ...@@ -202,6 +202,7 @@ from vllm._aiter_ops import rocm_aiter_ops
from vllm.attention.backends.abstract import ( from vllm.attention.backends.abstract import (
AttentionBackend, AttentionBackend,
AttentionLayer, AttentionLayer,
AttentionMetadata,
MLAAttentionImpl, MLAAttentionImpl,
) )
from vllm.attention.backends.utils import get_mla_dims from vllm.attention.backends.utils import get_mla_dims
...@@ -251,13 +252,15 @@ class QueryLenSupport(Enum): ...@@ -251,13 +252,15 @@ class QueryLenSupport(Enum):
try: try:
from vllm.vllm_flash_attn import flash_attn_varlen_func from vllm.vllm_flash_attn import ( # type: ignore[attr-defined]
flash_attn_varlen_func,
)
is_vllm_fa = True is_vllm_fa = True
except ImportError: except ImportError:
# For rocm use upstream flash attention # For rocm use upstream flash attention
if current_platform.is_rocm(): if current_platform.is_rocm():
from flash_attn import flash_attn_varlen_func from flash_attn import flash_attn_varlen_func # type: ignore[no-redef]
is_vllm_fa = False is_vllm_fa = False
try: try:
...@@ -386,7 +389,7 @@ D = TypeVar("D", bound=MLACommonDecodeMetadata) ...@@ -386,7 +389,7 @@ D = TypeVar("D", bound=MLACommonDecodeMetadata)
@dataclass @dataclass
class MLACommonMetadata(Generic[D]): class MLACommonMetadata(AttentionMetadata, Generic[D]):
"""Metadata for MLACommon. """Metadata for MLACommon.
NOTE: Please read the comment at the top of the file before trying to NOTE: Please read the comment at the top of the file before trying to
...@@ -434,7 +437,7 @@ class MLACommonMetadata(Generic[D]): ...@@ -434,7 +437,7 @@ class MLACommonMetadata(Generic[D]):
M = TypeVar("M", bound=MLACommonMetadata) M = TypeVar("M", bound=MLACommonMetadata)
A = TypeVar("A") A = TypeVar("A", bound=AttentionMetadata)
def use_flashinfer_prefill() -> bool: def use_flashinfer_prefill() -> bool:
...@@ -617,7 +620,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): ...@@ -617,7 +620,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
self._fi_prefill_chunks: list[BatchPrefillWithRaggedKVCacheWrapper] = [] self._fi_prefill_chunks: list[BatchPrefillWithRaggedKVCacheWrapper] = []
self._global_hyperparameters = infer_global_hyperparameters( self._global_hyperparameters = infer_global_hyperparameters(
get_per_layer_parameters(vllm_config, layer_names, MLACommonImpl) get_per_layer_parameters(vllm_config, layer_names, MLACommonImpl) # type: ignore[type-abstract]
) )
if self._use_trtllm_ragged_prefill: if self._use_trtllm_ragged_prefill:
...@@ -874,7 +877,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): ...@@ -874,7 +877,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
) )
# Note(qcs): The max local context lengths # Note(qcs): The max local context lengths
# padded to `dcp_local_block_size`. # padded to `dcp_local_block_size`.
padded_local_context_lens_cpu = ( padded_local_context_lens_cpu: torch.Tensor = (
cdiv( cdiv(
context_lens_cpu, context_lens_cpu,
self.dcp_virtual_block_size, self.dcp_virtual_block_size,
...@@ -1171,7 +1174,9 @@ class MLACommonBaseImpl(MLAAttentionImpl[A], Generic[A]): ...@@ -1171,7 +1174,9 @@ class MLACommonBaseImpl(MLAAttentionImpl[A], Generic[A]):
) )
def get_and_maybe_dequant_weights(layer: LinearBase): def get_and_maybe_dequant_weights(layer: LinearBase):
if not isinstance(layer.quant_method, UnquantizedLinearMethod): if layer.quant_method is not None and not isinstance(
layer.quant_method, UnquantizedLinearMethod
):
# NOTE: This should only be used offline, since it's O(N^3) # NOTE: This should only be used offline, since it's O(N^3)
eye = torch.eye( eye = torch.eye(
layer.input_size_per_partition, layer.input_size_per_partition,
...@@ -1327,12 +1332,14 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]): ...@@ -1327,12 +1332,14 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
# v with 0s to match the qk head dim for attention backends that do # v with 0s to match the qk head dim for attention backends that do
# not support different headdims # not support different headdims
# We don't need to pad V if we are on a hopper system with FA3 # We don't need to pad V if we are on a hopper system with FA3
device_capability = current_platform.get_device_capability()
self._pad_v = self.vllm_flash_attn_version is None or not ( self._pad_v = self.vllm_flash_attn_version is None or not (
self.vllm_flash_attn_version == 3 self.vllm_flash_attn_version == 3
and current_platform.get_device_capability()[0] == 9 and device_capability is not None
and device_capability[0] == 9
) )
self.dcp_world_size: int | None = None self.dcp_world_size: int = -1
self.chunked_prefill_workspace_size = ( self.chunked_prefill_workspace_size = (
MLACommonMetadataBuilder.determine_chunked_prefill_workspace_size( MLACommonMetadataBuilder.determine_chunked_prefill_workspace_size(
...@@ -1583,7 +1590,9 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]): ...@@ -1583,7 +1590,9 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
) )
def get_and_maybe_dequant_weights(layer: LinearBase): def get_and_maybe_dequant_weights(layer: LinearBase):
if not isinstance(layer.quant_method, UnquantizedLinearMethod): if layer.quant_method is not None and not isinstance(
layer.quant_method, UnquantizedLinearMethod
):
# NOTE: This should only be used offline, since it's O(N^3) # NOTE: This should only be used offline, since it's O(N^3)
eye = torch.eye( eye = torch.eye(
layer.input_size_per_partition, layer.input_size_per_partition,
...@@ -1875,7 +1884,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]): ...@@ -1875,7 +1884,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
) -> None: ) -> None:
# TODO (zyongye): Prefill function here # TODO (zyongye): Prefill function here
assert attn_metadata.prefill is not None assert attn_metadata.prefill is not None
assert self.dcp_world_size is not None assert self.dcp_world_size != -1
has_context = attn_metadata.prefill.chunked_context is not None has_context = attn_metadata.prefill.chunked_context is not None
kv_nope = self.kv_b_proj(kv_c_normed)[0].view( kv_nope = self.kv_b_proj(kv_c_normed)[0].view(
...@@ -1975,7 +1984,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]): ...@@ -1975,7 +1984,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
# same expert outputs. # same expert outputs.
return output.fill_(0) return output.fill_(0)
if self.dcp_world_size is None: if self.dcp_world_size == -1:
self.dcp_world_size = get_dcp_group().world_size self.dcp_world_size = get_dcp_group().world_size
fp8_attention = self.kv_cache_dtype.startswith("fp8") fp8_attention = self.kv_cache_dtype.startswith("fp8")
......
...@@ -33,7 +33,10 @@ from vllm.v1.attention.backends.mla.common import ( ...@@ -33,7 +33,10 @@ from vllm.v1.attention.backends.mla.common import (
) )
from vllm.v1.attention.backends.utils import AttentionCGSupport from vllm.v1.attention.backends.utils import AttentionCGSupport
from vllm.v1.kv_cache_interface import AttentionSpec from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.vllm_flash_attn import flash_attn_varlen_func, get_scheduler_metadata from vllm.vllm_flash_attn import ( # type: ignore[attr-defined]
flash_attn_varlen_func,
get_scheduler_metadata,
)
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -181,7 +184,11 @@ class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata] ...@@ -181,7 +184,11 @@ class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata]
# For Flash Attention MLA + full cudagraph # For Flash Attention MLA + full cudagraph
max_num_splits = 0 max_num_splits = 0
if self.use_full_cuda_graph and num_decode_tokens <= self.max_cudagraph_size: if (
self.use_full_cuda_graph
and self.max_cudagraph_size is not None
and num_decode_tokens <= self.max_cudagraph_size
):
# NOTE(woosuk): Setting num_splits > 1 may increase the memory # NOTE(woosuk): Setting num_splits > 1 may increase the memory
# usage, because the intermediate buffers of size [num_splits, # usage, because the intermediate buffers of size [num_splits,
# num_heads, num_tokens, head_size] are allocated. Therefore, # num_heads, num_tokens, head_size] are allocated. Therefore,
......
...@@ -10,6 +10,7 @@ from vllm import _custom_ops as ops ...@@ -10,6 +10,7 @@ from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import ( from vllm.attention.backends.abstract import (
AttentionBackend, AttentionBackend,
AttentionLayer, AttentionLayer,
AttentionMetadata,
MultipleOf, MultipleOf,
) )
from vllm.attention.backends.utils import get_mla_dims from vllm.attention.backends.utils import get_mla_dims
...@@ -124,7 +125,7 @@ class FlashMLASparseBackend(AttentionBackend): ...@@ -124,7 +125,7 @@ class FlashMLASparseBackend(AttentionBackend):
@dataclass @dataclass
class FlashMLASparseMetadata: class FlashMLASparseMetadata(AttentionMetadata):
num_reqs: int num_reqs: int
max_query_len: int max_query_len: int
max_seq_len: int max_seq_len: int
...@@ -718,7 +719,7 @@ class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]): ...@@ -718,7 +719,7 @@ class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]):
) )
self.softmax_scale = scale self.softmax_scale = scale
assert indexer is not None assert indexer is not None
self.topk_indices_buffer = indexer.topk_indices_buffer self.topk_indices_buffer: torch.Tensor | None = indexer.topk_indices_buffer
self.padding = 128 if current_platform.is_device_capability_family(100) else 64 self.padding = 128 if current_platform.is_device_capability_family(100) else 64
if kv_cache_dtype == "fp8_ds_mla": if kv_cache_dtype == "fp8_ds_mla":
...@@ -980,6 +981,7 @@ class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]): ...@@ -980,6 +981,7 @@ class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]):
q = q[:num_actual_toks, ...] q = q[:num_actual_toks, ...]
k_c_normed = k_c_normed[:num_actual_toks, ...] k_c_normed = k_c_normed[:num_actual_toks, ...]
k_pe = k_pe[:num_actual_toks, ...] k_pe = k_pe[:num_actual_toks, ...]
assert self.topk_indices_buffer is not None
topk_indices = self.topk_indices_buffer[:num_actual_toks] topk_indices = self.topk_indices_buffer[:num_actual_toks]
q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
......
...@@ -236,7 +236,7 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]): ...@@ -236,7 +236,7 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
k=k, k=k,
v=v, v=v,
softmax_scale=softmax_scale, softmax_scale=softmax_scale,
return_lse=return_softmax_lse, return_softmax_lse=return_softmax_lse,
**kwargs, **kwargs,
) )
...@@ -251,6 +251,7 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]): ...@@ -251,6 +251,7 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
) -> tuple[torch.Tensor, torch.Tensor | None]: ) -> tuple[torch.Tensor, torch.Tensor | None]:
assert kv_c_and_k_pe_cache.numel() > 0 assert kv_c_and_k_pe_cache.numel() > 0
assert attn_metadata.decode is not None assert attn_metadata.decode is not None
assert attn_metadata.decode.max_qo_len is not None
if type(q) is tuple: if type(q) is tuple:
q = torch.cat(q, dim=-1) q = torch.cat(q, dim=-1)
......
...@@ -43,7 +43,7 @@ class ROCMAiterMLASparseBackend(AttentionBackend): ...@@ -43,7 +43,7 @@ class ROCMAiterMLASparseBackend(AttentionBackend):
return "ROCM_AITER_MLA_SPARSE" return "ROCM_AITER_MLA_SPARSE"
@staticmethod @staticmethod
def get_metadata_cls() -> type[AttentionMetadata]: def get_metadata_cls() -> type["ROCMAiterMLASparseMetadata"]:
return ROCMAiterMLASparseMetadata return ROCMAiterMLASparseMetadata
@staticmethod @staticmethod
...@@ -74,7 +74,7 @@ class ROCMAiterMLASparseBackend(AttentionBackend): ...@@ -74,7 +74,7 @@ class ROCMAiterMLASparseBackend(AttentionBackend):
@dataclass @dataclass
class ROCMAiterMLASparseMetadata: class ROCMAiterMLASparseMetadata(AttentionMetadata):
num_reqs: int num_reqs: int
max_query_len: int max_query_len: int
max_seq_len: int max_seq_len: int
...@@ -223,7 +223,7 @@ class ROCMAiterMLASparseImpl(MLACommonBaseImpl[ROCMAiterMLASparseMetadata]): ...@@ -223,7 +223,7 @@ class ROCMAiterMLASparseImpl(MLACommonBaseImpl[ROCMAiterMLASparseMetadata]):
) )
self.softmax_scale = scale self.softmax_scale = scale
assert indexer is not None assert indexer is not None
self.topk_indices_buffer = indexer.topk_indices_buffer self.topk_indices_buffer: torch.Tensor | None = indexer.topk_indices_buffer
self.is_fp8bmm_enabled = rocm_aiter_ops.is_fp8bmm_enabled() self.is_fp8bmm_enabled = rocm_aiter_ops.is_fp8bmm_enabled()
def _forward_bf16_kv( def _forward_bf16_kv(
...@@ -294,6 +294,7 @@ class ROCMAiterMLASparseImpl(MLACommonBaseImpl[ROCMAiterMLASparseMetadata]): ...@@ -294,6 +294,7 @@ class ROCMAiterMLASparseImpl(MLACommonBaseImpl[ROCMAiterMLASparseMetadata]):
# Convert from (N, B, L) to (B, N, L) # Convert from (N, B, L) to (B, N, L)
ql_nope = ql_nope.transpose(0, 1) ql_nope = ql_nope.transpose(0, 1)
assert self.topk_indices_buffer is not None
topk_indices = self.topk_indices_buffer[:num_actual_toks] topk_indices = self.topk_indices_buffer[:num_actual_toks]
topk_indices_global = triton_convert_req_index_to_global_index( topk_indices_global = triton_convert_req_index_to_global_index(
......
...@@ -155,7 +155,9 @@ class TreeAttentionMetadataBuilder(AttentionMetadataBuilder[TreeAttentionMetadat ...@@ -155,7 +155,9 @@ class TreeAttentionMetadataBuilder(AttentionMetadataBuilder[TreeAttentionMetadat
self.block_size = kv_cache_spec.block_size self.block_size = kv_cache_spec.block_size
spec_config = vllm_config.speculative_config spec_config = vllm_config.speculative_config
spec_token_tree = (spec := spec_config) and spec.speculative_token_tree spec_token_tree: str | None = None
if spec := spec_config:
spec_token_tree = spec.speculative_token_tree
tree_choices: list[tuple[int, ...]] = ( tree_choices: list[tuple[int, ...]] = (
ast.literal_eval(spec_token_tree) if spec_token_tree is not None else [(0,)] ast.literal_eval(spec_token_tree) if spec_token_tree is not None else [(0,)]
) )
......
...@@ -469,6 +469,7 @@ def get_kv_cache_layout(): ...@@ -469,6 +469,7 @@ def get_kv_cache_layout():
# Format specified by the code. # Format specified by the code.
global _KV_CACHE_LAYOUT_OVERRIDE global _KV_CACHE_LAYOUT_OVERRIDE
cache_layout: Literal["NHD", "HND"] | None = None
if _KV_CACHE_LAYOUT_OVERRIDE is not None: if _KV_CACHE_LAYOUT_OVERRIDE is not None:
cache_layout = _KV_CACHE_LAYOUT_OVERRIDE cache_layout = _KV_CACHE_LAYOUT_OVERRIDE
logger.info_once( logger.info_once(
...@@ -524,7 +525,11 @@ def get_per_layer_parameters( ...@@ -524,7 +525,11 @@ def get_per_layer_parameters(
to use during `plan`. to use during `plan`.
""" """
layers = get_layers_from_vllm_config(vllm_config, AttentionLayerBase, layer_names) layers = get_layers_from_vllm_config(
vllm_config,
AttentionLayerBase, # type: ignore[type-abstract]
layer_names,
)
per_layer_params: dict[str, PerLayerParameters] = {} per_layer_params: dict[str, PerLayerParameters] = {}
for key, layer in layers.items(): for key, layer in layers.items():
...@@ -1125,7 +1130,7 @@ class KVSharingFastPrefillMetadata(Protocol): ...@@ -1125,7 +1130,7 @@ class KVSharingFastPrefillMetadata(Protocol):
def create_fast_prefill_custom_backend( def create_fast_prefill_custom_backend(
prefix: str, prefix: str,
underlying_attn_backend: AttentionBackend, underlying_attn_backend: type[AttentionBackend],
) -> type[AttentionBackend]: ) -> type[AttentionBackend]:
underlying_builder = underlying_attn_backend.get_builder_cls() underlying_builder = underlying_attn_backend.get_builder_cls()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment