Unverified Commit aaa901ad authored by Matthew Bonanni's avatar Matthew Bonanni Committed by GitHub
Browse files

[Attention] Move MLA `forward` from backend to layer (#33284)


Signed-off-by: default avatarMatthew Bonanni <mbonanni@redhat.com>
parent 010ec0c3
...@@ -274,11 +274,157 @@ class MockAttentionLayer: ...@@ -274,11 +274,157 @@ class MockAttentionLayer:
raise NotImplementedError raise NotImplementedError
class MockSparseMLAAttentionLayer:
"""A mock sparse MLA attention layer for testing.
Sparse MLA implementations only support forward_mqa (decode-style attention)
for all tokens, so this class only implements that path.
Unlike regular MLA impls, sparse MLA impls don't have W_UK_T and W_UV
attributes. These transformations are done by the layer (MLAAttention),
not the impl. This mock layer accepts these weight matrices directly.
"""
def __init__(
self,
impl,
num_heads: int,
qk_nope_head_dim: int,
qk_rope_head_dim: int,
v_head_dim: int,
kv_lora_rank: int,
device: torch.device,
W_UK: torch.Tensor,
W_UV: torch.Tensor,
):
self.impl = impl
self.num_heads = num_heads
self.qk_nope_head_dim = qk_nope_head_dim
self.qk_rope_head_dim = qk_rope_head_dim
self.v_head_dim = v_head_dim
self.kv_lora_rank = kv_lora_rank
# Compute weight matrices in the format expected by forward_impl
# W_UK shape: (L, N, P) -> W_UK_T shape: (N, P, L)
self.W_UK_T = W_UK.permute(1, 2, 0)
# W_UV shape: (L, N, V) -> (N, L, V)
self.W_UV = W_UV.transpose(0, 1)
# Scale attributes needed by attention backends
self._q_scale = torch.tensor(1.0, device=device)
self._k_scale = torch.tensor(1.0, device=device)
self._v_scale = torch.tensor(1.0, device=device)
self._prob_scale = torch.tensor(1.0, device=device)
self._q_scale_float = 1.0
self._k_scale_float = 1.0
self._v_scale_float = 1.0
def forward_impl(
self,
q: torch.Tensor,
kv_c: torch.Tensor,
k_pe: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata,
output: torch.Tensor,
) -> torch.Tensor:
"""Forward for sparse MLA - uses forward_mqa for all tokens."""
# Write to KV cache
kv_cache_dtype = getattr(self.impl, "kv_cache_dtype", "auto")
if kv_cache.numel() > 0:
ops.concat_and_cache_mla(
kv_c,
k_pe.squeeze(1),
kv_cache,
attn_metadata.slot_mapping.flatten(),
kv_cache_dtype=kv_cache_dtype,
scale=self._k_scale,
)
num_tokens = q.shape[0]
# Sparse MLA uses forward_mqa for all tokens
# Split q into nope and pe parts
mqa_q_nope, mqa_q_pe = q.split(
[self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
)
# Convert from (B, N, P) to (N, B, P)
mqa_q_nope = mqa_q_nope.transpose(0, 1)
# Multiply (N, B, P) x (N, P, L) -> (N, B, L)
mqa_ql_nope = torch.bmm(mqa_q_nope, self.W_UK_T)
# Convert from (N, B, L) to (B, N, L)
mqa_ql_nope = mqa_ql_nope.transpose(0, 1)
# Pass as tuple to forward_mqa
mqa_q = (mqa_ql_nope, mqa_q_pe)
attn_out, _ = self.impl.forward_mqa(mqa_q, kv_cache, attn_metadata, self)
# v_up projection: multiply by W_UV
# attn_out shape: (B, N, L) where L = kv_lora_rank
# W_UV shape: (N, L, V)
# output shape: (B, N, V) -> flatten to (B, N*V)
decode_output = torch.bmm(attn_out.transpose(0, 1), self.W_UV).transpose(0, 1)
output[:num_tokens] = decode_output.reshape(
num_tokens, self.num_heads * self.v_head_dim
)
return output
class MockMLAAttentionLayer(AttentionLayerBase): class MockMLAAttentionLayer(AttentionLayerBase):
"""A mock MLA attention layer for populating static_forward_context.""" """A mock MLA attention layer for testing.
This replicates the forward_impl logic from MLAAttention to allow
testing MLA backends without the full layer infrastructure.
The W_UK_T and W_UV weight matrices are created on the layer (like in
MLAAttention.process_weights_after_loading), not on the impl.
"""
def __init__(self, impl): def __init__(
self,
impl,
num_heads: int,
qk_nope_head_dim: int,
qk_rope_head_dim: int,
v_head_dim: int,
kv_lora_rank: int,
device: torch.device,
kv_b_proj,
):
self.impl = impl self.impl = impl
self.num_heads = num_heads
self.qk_nope_head_dim = qk_nope_head_dim
self.qk_rope_head_dim = qk_rope_head_dim
self.v_head_dim = v_head_dim
self.kv_lora_rank = kv_lora_rank
# Compute weight matrices from kv_b_proj (like MLAAttention does)
# This replicates MLAAttention.process_weights_after_loading logic
kv_b_proj_weight = kv_b_proj.weight.T
kv_b_proj_weight = kv_b_proj_weight.view(
kv_lora_rank,
num_heads,
qk_nope_head_dim + v_head_dim,
)
W_UK, W_UV = kv_b_proj_weight.split([qk_nope_head_dim, v_head_dim], dim=-1)
# Convert from (L, N, V) to (N, L, V)
self.W_UV = W_UV.transpose(0, 1)
# Convert from (L, N, P) to (N, P, L)
self.W_UK_T = W_UK.permute(1, 2, 0)
# Scale attributes needed by attention backends
self._q_scale = torch.tensor(1.0, device=device)
self._k_scale = torch.tensor(1.0, device=device)
self._v_scale = torch.tensor(1.0, device=device)
self._prob_scale = torch.tensor(1.0, device=device)
self._q_scale_float = 1.0
self._k_scale_float = 1.0
self._v_scale_float = 1.0
def get_attn_backend(self): def get_attn_backend(self):
raise NotImplementedError raise NotImplementedError
...@@ -286,6 +432,83 @@ class MockMLAAttentionLayer(AttentionLayerBase): ...@@ -286,6 +432,83 @@ class MockMLAAttentionLayer(AttentionLayerBase):
def get_kv_cache_spec(self, vllm_config): def get_kv_cache_spec(self, vllm_config):
raise NotImplementedError raise NotImplementedError
def forward_impl(
self,
q: torch.Tensor,
kv_c: torch.Tensor,
k_pe: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata,
output: torch.Tensor,
) -> torch.Tensor:
"""Replicates MLAAttention.forward_impl logic for testing."""
# Write to KV cache
if kv_cache.numel() > 0:
ops.concat_and_cache_mla(
kv_c,
k_pe.squeeze(1),
kv_cache,
attn_metadata.slot_mapping.flatten(),
kv_cache_dtype="auto",
scale=self._k_scale,
)
# Determine decode vs prefill split
num_decode_tokens = attn_metadata.num_decode_tokens or 0
has_decode = (attn_metadata.num_decodes or 0) > 0
has_prefill = (attn_metadata.num_prefills or 0) > 0
# Run prefill with forward_mha
if has_prefill:
prefill_q = q[num_decode_tokens:]
prefill_k_pe = k_pe[num_decode_tokens:]
prefill_k_c = kv_c[num_decode_tokens:]
self.impl.forward_mha(
prefill_q,
prefill_k_c,
prefill_k_pe,
kv_cache,
attn_metadata,
self._k_scale,
output=output[num_decode_tokens:],
)
# Run decode with forward_mqa
if has_decode:
decode_q = q[:num_decode_tokens]
# Split q into nope and pe parts
mqa_q_nope, mqa_q_pe = decode_q.split(
[self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
)
# Convert from (B, N, P) to (N, B, P)
mqa_q_nope = mqa_q_nope.transpose(0, 1)
# Multiply (N, B, P) x (N, P, L) -> (N, B, L)
mqa_ql_nope = torch.bmm(mqa_q_nope, self.W_UK_T)
# Convert from (N, B, L) to (B, N, L)
mqa_ql_nope = mqa_ql_nope.transpose(0, 1)
# Pass as tuple to forward_mqa
mqa_q = (mqa_ql_nope, mqa_q_pe)
attn_out, _ = self.impl.forward_mqa(mqa_q, kv_cache, attn_metadata, self)
# v_up projection: multiply by W_UV
# attn_out shape: (B, N, L) where L = kv_lora_rank
# W_UV shape: (N, L, V)
# output shape: (B, N, V) -> flatten to (B, N*V)
decode_output = torch.bmm(attn_out.transpose(0, 1), self.W_UV).transpose(
0, 1
)
output[:num_decode_tokens] = decode_output.reshape(
num_decode_tokens, self.num_heads * self.v_head_dim
)
return output
def run_attention_backend( def run_attention_backend(
backend: AttentionBackendEnum, backend: AttentionBackendEnum,
...@@ -340,14 +563,31 @@ def run_attention_backend( ...@@ -340,14 +563,31 @@ def run_attention_backend(
kv_b_proj=mock_kv_b_proj, kv_b_proj=mock_kv_b_proj,
) )
# Process weights to create W_UK_T and W_UV attributes needed by MLA # Process weights on the impl
act_dtype = _convert_dtype_to_torch(vllm_config.model_config.dtype) act_dtype = _convert_dtype_to_torch(vllm_config.model_config.dtype)
impl.process_weights_after_loading(act_dtype) impl.process_weights_after_loading(act_dtype)
# Initialize DCP attributes (normally set by MLAAttention.forward
# before calling forward_mha, see mla_attention.py:511-512)
if impl.dcp_world_size == -1:
impl.dcp_world_size = 1
# Create mock MLA layer
mock_layer = MockMLAAttentionLayer(
impl=impl,
num_heads=num_heads,
qk_nope_head_dim=qk_nope_head_dim,
qk_rope_head_dim=qk_rope_head_dim,
v_head_dim=v_head_dim,
kv_lora_rank=kv_lora_rank,
device=device,
kv_b_proj=mock_kv_b_proj,
)
# Populate static_forward_context with mock attention layers # Populate static_forward_context with mock attention layers
for layer_name in layer_names: for layer_name in layer_names:
vllm_config.compilation_config.static_forward_context[layer_name] = ( vllm_config.compilation_config.static_forward_context[layer_name] = (
MockMLAAttentionLayer(impl) mock_layer
) )
# Build metadata # Build metadata
...@@ -357,18 +597,15 @@ def run_attention_backend( ...@@ -357,18 +597,15 @@ def run_attention_backend(
common_attn_metadata=common_attn_metadata, common_attn_metadata=common_attn_metadata,
) )
# Create mock layer and output buffer # Create output buffer
mock_layer = MockAttentionLayer(device)
num_tokens = query.shape[0] num_tokens = query.shape[0]
output = torch.empty( output = torch.empty(
num_tokens, num_heads * v_head_dim, dtype=query.dtype, device=query.device num_tokens, num_heads * v_head_dim, dtype=query.dtype, device=query.device
) )
# Run forward pass # Run forward pass
# NOTE: The query, key, and value are already shaped correctly output = mock_layer.forward_impl(
# in the calling test function. query, kv_c, k_pe, kv_cache, attn_metadata, output
output = impl.forward(
mock_layer, query, kv_c, k_pe, kv_cache, attn_metadata, output=output
) )
return output return output
......
...@@ -12,7 +12,7 @@ import torch ...@@ -12,7 +12,7 @@ import torch
from tests.v1.attention.test_mla_backends import ( from tests.v1.attention.test_mla_backends import (
BATCH_SPECS, BATCH_SPECS,
BatchSpec, BatchSpec,
MockAttentionLayer, MockSparseMLAAttentionLayer,
create_and_prepopulate_kv_cache, create_and_prepopulate_kv_cache,
) )
from tests.v1.attention.utils import ( from tests.v1.attention.utils import (
...@@ -408,20 +408,31 @@ def test_sparse_backend_decode_correctness( ...@@ -408,20 +408,31 @@ def test_sparse_backend_decode_correctness(
impl.process_weights_after_loading(dtype) impl.process_weights_after_loading(dtype)
layer = MockAttentionLayer(device) # Create mock sparse MLA layer with weight matrices
mock_layer = MockSparseMLAAttentionLayer(
impl=impl,
num_heads=num_heads,
qk_nope_head_dim=qk_nope_head_dim,
qk_rope_head_dim=qk_rope_head_dim,
v_head_dim=v_head_dim,
kv_lora_rank=kv_lora_rank,
device=device,
W_UK=W_UK,
W_UV=W_UV,
)
out_buffer = torch.empty( out_buffer = torch.empty(
metadata.num_actual_tokens, num_heads * v_head_dim, dtype=dtype, device=device metadata.num_actual_tokens, num_heads * v_head_dim, dtype=dtype, device=device
) )
with torch.inference_mode(): with torch.inference_mode():
backend_output = impl.forward( backend_output = mock_layer.forward_impl(
layer,
query_vllm, query_vllm,
kv_c_vllm, kv_c_vllm,
k_pe_vllm, k_pe_vllm,
kv_cache, kv_cache,
metadata, metadata,
output=out_buffer, out_buffer,
) )
assert backend_output.shape == sdpa_reference.shape assert backend_output.shape == sdpa_reference.shape
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import TYPE_CHECKING from typing import TYPE_CHECKING, Any
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -562,7 +562,7 @@ direct_register_custom_op( ...@@ -562,7 +562,7 @@ direct_register_custom_op(
def get_attention_context( def get_attention_context(
layer_name: str, layer_name: str,
) -> tuple[dict | object | None, "Attention | MLAAttention", torch.Tensor]: ) -> tuple[Any, "Attention | MLAAttention", torch.Tensor]:
"""Extract attention context for a given layer. """Extract attention context for a given layer.
This helper function extracts the attention metadata, attention layer This helper function extracts the attention metadata, attention layer
......
...@@ -67,7 +67,7 @@ class AttentionBackend(ABC): ...@@ -67,7 +67,7 @@ class AttentionBackend(ABC):
@staticmethod @staticmethod
@abstractmethod @abstractmethod
def get_impl_cls() -> type["AttentionImpl"]: def get_impl_cls() -> type["AttentionImplBase"]:
raise NotImplementedError raise NotImplementedError
@staticmethod @staticmethod
...@@ -594,7 +594,14 @@ class AttentionLayer(Protocol): ...@@ -594,7 +594,14 @@ class AttentionLayer(Protocol):
) -> torch.Tensor: ... ) -> torch.Tensor: ...
class AttentionImpl(ABC, Generic[T]): class AttentionImplBase(ABC, Generic[T]):
"""Base class for attention implementations.
Contains common attributes and initialization logic shared by both
standard AttentionImpl and MLAAttentionImpl. Does not define a forward
method - subclasses define their own forward interfaces.
"""
# Required attributes that all impls should have # Required attributes that all impls should have
num_heads: int num_heads: int
head_size: int head_size: int
...@@ -662,6 +669,13 @@ class AttentionImpl(ABC, Generic[T]): ...@@ -662,6 +669,13 @@ class AttentionImpl(ABC, Generic[T]):
) )
return self return self
def process_weights_after_loading(self, act_dtype: torch.dtype):
pass
class AttentionImpl(AttentionImplBase[T], Generic[T]):
"""Standard attention implementation with forward method."""
@abstractmethod @abstractmethod
def __init__( def __init__(
self, self,
...@@ -704,11 +718,10 @@ class AttentionImpl(ABC, Generic[T]): ...@@ -704,11 +718,10 @@ class AttentionImpl(ABC, Generic[T]):
""" """
return False return False
def process_weights_after_loading(self, act_dtype: torch.dtype):
pass
class MLAAttentionImpl(AttentionImplBase[T], Generic[T]):
"""MLA attention implementation with forward_mqa and forward_mha methods."""
class MLAAttentionImpl(AttentionImpl[T], Generic[T]):
@abstractmethod @abstractmethod
def __init__( def __init__(
self, self,
...@@ -731,22 +744,78 @@ class MLAAttentionImpl(AttentionImpl[T], Generic[T]): ...@@ -731,22 +744,78 @@ class MLAAttentionImpl(AttentionImpl[T], Generic[T]):
v_head_dim: int, v_head_dim: int,
kv_b_proj: "ColumnParallelLinear", kv_b_proj: "ColumnParallelLinear",
indexer: object | None = None, indexer: object | None = None,
q_pad_num_heads: int | None = None,
) -> None: ) -> None:
raise NotImplementedError raise NotImplementedError
@abstractmethod @abstractmethod
def forward( def forward_mha(
self, self,
layer: AttentionLayer, q: torch.Tensor,
hidden_states_or_cq: torch.Tensor,
kv_c_normed: torch.Tensor, kv_c_normed: torch.Tensor,
k_pe: torch.Tensor, k_pe: torch.Tensor,
kv_cache: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: T, attn_metadata: T,
output: torch.Tensor | None = None, k_scale: torch.Tensor,
output_scale: torch.Tensor | None = None, output: torch.Tensor,
output_block_scale: torch.Tensor | None = None, ) -> None:
) -> torch.Tensor: """MHA-style prefill forward pass."""
raise NotImplementedError
@abstractmethod
def forward_mqa(
self,
q: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: T,
layer: AttentionLayer,
) -> tuple[torch.Tensor, torch.Tensor | None]:
"""MQA-style decode forward pass."""
raise NotImplementedError
class SparseMLAAttentionImpl(AttentionImplBase[T], Generic[T]):
"""Sparse MLA attention implementation with only forward_mqa method.
Sparse MLA implementations only support decode (MQA-style) attention.
They do not support prefill (MHA-style) attention.
"""
@abstractmethod
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: int,
alibi_slopes: list[float] | None,
sliding_window: int | None,
kv_cache_dtype: str,
logits_soft_cap: float | None,
attn_type: str,
kv_sharing_target_layer_name: str | None,
# MLA Specific Arguments
q_lora_rank: int | None,
kv_lora_rank: int,
qk_nope_head_dim: int,
qk_rope_head_dim: int,
qk_head_dim: int,
v_head_dim: int,
kv_b_proj: "ColumnParallelLinear",
indexer: object | None = None,
q_pad_num_heads: int | None = None,
) -> None:
raise NotImplementedError
@abstractmethod
def forward_mqa(
self,
q: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: T,
layer: AttentionLayer,
) -> tuple[torch.Tensor, torch.Tensor | None]:
"""MQA-style decode forward pass."""
raise NotImplementedError raise NotImplementedError
......
...@@ -244,7 +244,7 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]): ...@@ -244,7 +244,7 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]):
return out, lse return out, lse
def _forward_decode( def forward_mqa(
self, self,
q: torch.Tensor | tuple[torch.Tensor, torch.Tensor], q: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
kv_c_and_k_pe_cache: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor,
......
...@@ -293,7 +293,7 @@ class FlashAttnMLAImpl(MLACommonImpl[FlashAttnMLAMetadata]): ...@@ -293,7 +293,7 @@ class FlashAttnMLAImpl(MLACommonImpl[FlashAttnMLAMetadata]):
"FlashAttnMLA V1 with FP8 KV cache not yet supported" "FlashAttnMLA V1 with FP8 KV cache not yet supported"
) )
def _forward_decode( def forward_mqa(
self, self,
q: torch.Tensor | tuple[torch.Tensor, torch.Tensor], q: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
kv_c_and_k_pe_cache: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor,
......
...@@ -150,7 +150,7 @@ class FlashInferMLAImpl(MLACommonImpl[MLACommonMetadata]): ...@@ -150,7 +150,7 @@ class FlashInferMLAImpl(MLACommonImpl[MLACommonMetadata]):
self.bmm1_scale: float | None = None self.bmm1_scale: float | None = None
self.bmm2_scale: float | None = None self.bmm2_scale: float | None = None
def _forward_decode( def forward_mqa(
self, self,
q: torch.Tensor | tuple[torch.Tensor, torch.Tensor], q: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
kv_c_and_k_pe_cache: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor,
......
...@@ -234,7 +234,7 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]): ...@@ -234,7 +234,7 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
"FlashMLAImpl" "FlashMLAImpl"
) )
def _forward_decode( def forward_mqa(
self, self,
q: torch.Tensor | tuple[torch.Tensor, torch.Tensor], q: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
kv_c_and_k_pe_cache: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor,
......
...@@ -11,7 +11,6 @@ from vllm.config import VllmConfig, get_current_vllm_config ...@@ -11,7 +11,6 @@ from vllm.config import VllmConfig, get_current_vllm_config
from vllm.config.cache import CacheDType from vllm.config.cache import CacheDType
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.attention.mla_attention import ( from vllm.model_executor.layers.attention.mla_attention import (
MLACommonBaseImpl,
get_mla_dims, get_mla_dims,
) )
from vllm.platforms import current_platform from vllm.platforms import current_platform
...@@ -25,6 +24,7 @@ from vllm.v1.attention.backend import ( ...@@ -25,6 +24,7 @@ from vllm.v1.attention.backend import (
AttentionMetadataBuilder, AttentionMetadataBuilder,
CommonAttentionMetadata, CommonAttentionMetadata,
MultipleOf, MultipleOf,
SparseMLAAttentionImpl,
) )
from vllm.v1.attention.backends.utils import ( from vllm.v1.attention.backends.utils import (
reshape_attn_output_for_spec_decode, reshape_attn_output_for_spec_decode,
...@@ -686,7 +686,7 @@ class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetad ...@@ -686,7 +686,7 @@ class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetad
return metadata return metadata
class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]): class FlashMLASparseImpl(SparseMLAAttentionImpl[FlashMLASparseMetadata]):
@staticmethod @staticmethod
def _compute_fp8_decode_padded_heads(num_heads: int) -> int: def _compute_fp8_decode_padded_heads(num_heads: int) -> int:
# FP8 decode kernel only supports h_q = 64 or 128 # FP8 decode kernel only supports h_q = 64 or 128
...@@ -710,19 +710,12 @@ class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]): ...@@ -710,19 +710,12 @@ class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]):
indexer: "Indexer | None" = None, indexer: "Indexer | None" = None,
**mla_args, **mla_args,
) -> None: ) -> None:
super().__init__( self.num_heads = num_heads
num_heads, self.head_size = head_size
head_size, self.scale = float(scale)
scale, self.num_kv_heads = num_kv_heads
num_kv_heads, self.kv_cache_dtype = kv_cache_dtype
alibi_slopes, self.kv_lora_rank: int = mla_args["kv_lora_rank"]
sliding_window,
kv_cache_dtype,
logits_soft_cap,
attn_type,
kv_sharing_target_layer_name,
**mla_args,
)
self.softmax_scale = scale self.softmax_scale = scale
assert indexer is not None assert indexer is not None
self.topk_indices_buffer: torch.Tensor | None = indexer.topk_indices_buffer self.topk_indices_buffer: torch.Tensor | None = indexer.topk_indices_buffer
...@@ -974,78 +967,39 @@ class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]): ...@@ -974,78 +967,39 @@ class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]):
output = output[:, : self.num_heads, :] output = output[:, : self.num_heads, :]
return output return output
def forward( def forward_mqa(
self, self,
q: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: FlashMLASparseMetadata,
layer: AttentionLayer, layer: AttentionLayer,
q: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor | None]:
k_c_normed: torch.Tensor, # key in unified attn
k_pe: torch.Tensor, # value in unified attn
kv_cache: torch.Tensor,
attn_metadata: FlashMLASparseMetadata | None,
output: torch.Tensor | None = None,
output_scale: torch.Tensor | None = None,
output_block_scale: torch.Tensor | None = None,
) -> torch.Tensor:
# NOTE(lucas): for the sparse FlashMLA kernels the kernels want to use # NOTE(lucas): for the sparse FlashMLA kernels the kernels want to use
# MQA 576/512 approach for both prefill and decode # MQA 576/512 approach for both prefill and decode
assert output is not None, "Output tensor must be provided." # Concatenate q if it's a tuple (ql_nope, q_pe)
if isinstance(q, tuple):
if output_scale is not None or output_block_scale is not None: q = torch.cat(q, dim=-1)
raise NotImplementedError(
"fused output quantization is not yet supported for MLACommonImpl"
)
if attn_metadata is None:
# Dummy run - no need to allocate buffers
# The zero fill is required when used with DP + EP
# to ensure all ranks within a DP group compute the
# same expert outputs.
return output.fill_(0)
num_actual_toks = attn_metadata.num_actual_tokens
# Inputs and outputs may be padded for CUDA graphs num_actual_toks = q.shape[0]
q = q[:num_actual_toks, ...] # Get topk indices
k_c_normed = k_c_normed[:num_actual_toks, ...]
k_pe = k_pe[:num_actual_toks, ...]
assert self.topk_indices_buffer is not None assert self.topk_indices_buffer is not None
topk_indices = self.topk_indices_buffer[:num_actual_toks] topk_indices = self.topk_indices_buffer[:num_actual_toks]
q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
# Convert from (B, N, P) to (N, B, P)
q_nope = q_nope.transpose(0, 1)
# Multiply (N, B, P) x (N, P, L) -> (N, B, L)
ql_nope = torch.bmm(q_nope, self.W_UK_T)
# Convert from (N, B, L) to (B, N, L)
ql_nope = ql_nope.transpose(0, 1)
use_fp8_cache = self.kv_cache_dtype == "fp8_ds_mla" use_fp8_cache = self.kv_cache_dtype == "fp8_ds_mla"
q = torch.cat([ql_nope, q_pe], dim=-1)
# write the latent and rope to kv cache
if kv_cache.numel() > 0:
ops.concat_and_cache_mla(
k_c_normed,
k_pe.squeeze(1),
kv_cache,
attn_metadata.slot_mapping.flatten(),
kv_cache_dtype=self.kv_cache_dtype,
scale=layer._k_scale,
)
if not use_fp8_cache: if not use_fp8_cache:
attn_out = self._forward_bf16_kv(q, kv_cache, topk_indices, attn_metadata) attn_out = self._forward_bf16_kv(
q, kv_c_and_k_pe_cache, topk_indices, attn_metadata
)
elif attn_metadata.fp8_use_mixed_batch: elif attn_metadata.fp8_use_mixed_batch:
attn_out = self._forward_fp8_kv_mixed_batch( attn_out = self._forward_fp8_kv_mixed_batch(
q, kv_cache, topk_indices, attn_metadata q, kv_c_and_k_pe_cache, topk_indices, attn_metadata
) )
else: else:
attn_out = self._forward_fp8_kv_separate_prefill_decode( attn_out = self._forward_fp8_kv_separate_prefill_decode(
q, kv_cache, topk_indices, attn_metadata q, kv_c_and_k_pe_cache, topk_indices, attn_metadata
) )
self._v_up_proj(attn_out, out=output[:num_actual_toks]) return attn_out, None
return output
...@@ -241,7 +241,7 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]): ...@@ -241,7 +241,7 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
return output return output
def _forward_decode( def forward_mqa(
self, self,
q: torch.Tensor | tuple[torch.Tensor, torch.Tensor], q: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
kv_c_and_k_pe_cache: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor,
......
...@@ -7,12 +7,10 @@ from typing import TYPE_CHECKING, ClassVar ...@@ -7,12 +7,10 @@ from typing import TYPE_CHECKING, ClassVar
import numpy as np import numpy as np
import torch import torch
from vllm import _custom_ops as ops
from vllm._aiter_ops import rocm_aiter_ops from vllm._aiter_ops import rocm_aiter_ops
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.attention.mla_attention import ( from vllm.model_executor.layers.attention.mla_attention import (
MLACommonBaseImpl,
get_mla_dims, get_mla_dims,
) )
from vllm.triton_utils import tl, triton from vllm.triton_utils import tl, triton
...@@ -23,6 +21,7 @@ from vllm.v1.attention.backend import ( ...@@ -23,6 +21,7 @@ from vllm.v1.attention.backend import (
AttentionMetadata, AttentionMetadata,
AttentionMetadataBuilder, AttentionMetadataBuilder,
CommonAttentionMetadata, CommonAttentionMetadata,
SparseMLAAttentionImpl,
) )
from vllm.v1.attention.backends.mla.flashmla_sparse import ( from vllm.v1.attention.backends.mla.flashmla_sparse import (
triton_convert_req_index_to_global_index, triton_convert_req_index_to_global_index,
...@@ -269,7 +268,7 @@ def reference_mla_sparse_prefill( ...@@ -269,7 +268,7 @@ def reference_mla_sparse_prefill(
return (result, lse) return (result, lse)
class ROCMAiterMLASparseImpl(MLACommonBaseImpl[ROCMAiterMLASparseMetadata]): class ROCMAiterMLASparseImpl(SparseMLAAttentionImpl[ROCMAiterMLASparseMetadata]):
def __init__( def __init__(
self, self,
num_heads: int, num_heads: int,
...@@ -287,23 +286,15 @@ class ROCMAiterMLASparseImpl(MLACommonBaseImpl[ROCMAiterMLASparseMetadata]): ...@@ -287,23 +286,15 @@ class ROCMAiterMLASparseImpl(MLACommonBaseImpl[ROCMAiterMLASparseMetadata]):
indexer: "Indexer | None" = None, indexer: "Indexer | None" = None,
**mla_args, **mla_args,
) -> None: ) -> None:
super().__init__( self.num_heads = num_heads
num_heads, self.head_size = head_size
head_size, self.scale = float(scale)
scale, self.num_kv_heads = num_kv_heads
num_kv_heads, self.kv_cache_dtype = kv_cache_dtype
alibi_slopes, self.kv_lora_rank: int = mla_args["kv_lora_rank"]
sliding_window,
kv_cache_dtype,
logits_soft_cap,
attn_type,
kv_sharing_target_layer_name,
**mla_args,
)
self.softmax_scale = scale self.softmax_scale = scale
assert indexer is not None assert indexer is not None
self.topk_indices_buffer: torch.Tensor | None = indexer.topk_indices_buffer self.topk_indices_buffer: torch.Tensor | None = indexer.topk_indices_buffer
self.is_fp8bmm_enabled = rocm_aiter_ops.is_fp8bmm_enabled()
def _forward_bf16_kv( def _forward_bf16_kv(
self, self,
...@@ -342,56 +333,23 @@ class ROCMAiterMLASparseImpl(MLACommonBaseImpl[ROCMAiterMLASparseMetadata]): ...@@ -342,56 +333,23 @@ class ROCMAiterMLASparseImpl(MLACommonBaseImpl[ROCMAiterMLASparseMetadata]):
return output[:, : self.num_heads, :] return output[:, : self.num_heads, :]
def forward( def forward_mqa(
self, self,
layer: AttentionLayer, q: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
q: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor,
k_c_normed: torch.Tensor, # key in unified attn
k_pe: torch.Tensor, # value in unified attn
kv_cache: torch.Tensor,
attn_metadata: ROCMAiterMLASparseMetadata, attn_metadata: ROCMAiterMLASparseMetadata,
output: torch.Tensor | None = None, layer: AttentionLayer,
output_scale: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor | None]:
output_block_scale: torch.Tensor | None = None,
) -> torch.Tensor:
# NOTE(lucas): for the sparse FlashMLA kernels the kernels want to use # NOTE(lucas): for the sparse FlashMLA kernels the kernels want to use
# MQA 576/512 approach for both prefill and decode # MQA 576/512 approach for both prefill and decode
assert output is not None, "Output tensor must be provided." # Concatenate q if it's a tuple (ql_nope, q_pe)
if isinstance(q, tuple):
if output_scale is not None or output_block_scale is not None: q = torch.cat(q, dim=-1)
raise NotImplementedError(
"fused output quantization is not yet supported for ROCMAiterMLASparse"
)
if attn_metadata is None:
# The zero fill is required when used with DP + EP
# to ensure all ranks within a DP group compute the
# same expert outputs.
return output.fill_(0)
num_actual_toks = attn_metadata.num_actual_tokens
# Inputs and outputs may be padded for CUDA graphs
q = q[:num_actual_toks, ...]
k_c_normed = k_c_normed[:num_actual_toks, ...]
k_pe = k_pe[:num_actual_toks, ...]
q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
# Convert from (B, N, P) to (N, B, P)
q_nope = q_nope.transpose(0, 1)
if self.is_fp8bmm_enabled:
# Multiply+Transpose (N, B, P)x(N, P, L)->(N, B, L)->(B, N, L)
ql_nope = rocm_aiter_ops.triton_fp8_bmm(
q_nope, self.W_K, self.W_K_scale, group_size=128, transpose_bm=True
)
else:
# Multiply (N, B, P) x (N, P, L) -> (N, B, L)
ql_nope = torch.bmm(q_nope, self.W_UK_T)
# Convert from (N, B, L) to (B, N, L)
ql_nope = ql_nope.transpose(0, 1)
num_actual_toks = q.shape[0]
# Get topk indices
assert self.topk_indices_buffer is not None assert self.topk_indices_buffer is not None
topk_indices = self.topk_indices_buffer[:num_actual_toks] topk_indices = self.topk_indices_buffer[:num_actual_toks]
...@@ -403,22 +361,8 @@ class ROCMAiterMLASparseImpl(MLACommonBaseImpl[ROCMAiterMLASparseMetadata]): ...@@ -403,22 +361,8 @@ class ROCMAiterMLASparseImpl(MLACommonBaseImpl[ROCMAiterMLASparseMetadata]):
NUM_TOPK_TOKENS=attn_metadata.topk_tokens, NUM_TOPK_TOKENS=attn_metadata.topk_tokens,
) )
q = torch.cat([ql_nope, q_pe], dim=-1)
# write the latent and rope to kv cache
if kv_cache.numel() > 0:
ops.concat_and_cache_mla(
k_c_normed,
k_pe.squeeze(1),
kv_cache,
attn_metadata.slot_mapping.flatten(),
kv_cache_dtype=self.kv_cache_dtype,
scale=layer._k_scale,
)
attn_out = self._forward_bf16_kv( attn_out = self._forward_bf16_kv(
q, kv_cache, topk_indices_global, attn_metadata q, kv_c_and_k_pe_cache, topk_indices_global, attn_metadata
) )
self._v_up_proj(attn_out, out=output[:num_actual_toks]) return attn_out, None
return output
...@@ -110,7 +110,7 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]): ...@@ -110,7 +110,7 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
**kwargs, **kwargs,
) )
def _forward_decode( def forward_mqa(
self, self,
q: torch.Tensor | tuple[torch.Tensor, torch.Tensor], q: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
kv_c_and_k_pe_cache: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment