Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
aaa901ad
Unverified
Commit
aaa901ad
authored
Jan 30, 2026
by
Matthew Bonanni
Committed by
GitHub
Jan 30, 2026
Browse files
[Attention] Move MLA `forward` from backend to layer (#33284)
Signed-off-by:
Matthew Bonanni
<
mbonanni@redhat.com
>
parent
010ec0c3
Changes
13
Show whitespace changes
Inline
Side-by-side
Showing
13 changed files
with
841 additions
and
623 deletions
+841
-623
tests/v1/attention/test_mla_backends.py
tests/v1/attention/test_mla_backends.py
+247
-10
tests/v1/attention/test_sparse_mla_backends.py
tests/v1/attention/test_sparse_mla_backends.py
+16
-5
vllm/model_executor/layers/attention/attention.py
vllm/model_executor/layers/attention/attention.py
+2
-2
vllm/model_executor/layers/attention/mla_attention.py
vllm/model_executor/layers/attention/mla_attention.py
+443
-440
vllm/v1/attention/backend.py
vllm/v1/attention/backend.py
+82
-13
vllm/v1/attention/backends/mla/cutlass_mla.py
vllm/v1/attention/backends/mla/cutlass_mla.py
+1
-1
vllm/v1/attention/backends/mla/flashattn_mla.py
vllm/v1/attention/backends/mla/flashattn_mla.py
+1
-1
vllm/v1/attention/backends/mla/flashinfer_mla.py
vllm/v1/attention/backends/mla/flashinfer_mla.py
+1
-1
vllm/v1/attention/backends/mla/flashmla.py
vllm/v1/attention/backends/mla/flashmla.py
+1
-1
vllm/v1/attention/backends/mla/flashmla_sparse.py
vllm/v1/attention/backends/mla/flashmla_sparse.py
+24
-70
vllm/v1/attention/backends/mla/rocm_aiter_mla.py
vllm/v1/attention/backends/mla/rocm_aiter_mla.py
+1
-1
vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py
vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py
+21
-77
vllm/v1/attention/backends/mla/triton_mla.py
vllm/v1/attention/backends/mla/triton_mla.py
+1
-1
No files found.
tests/v1/attention/test_mla_backends.py
View file @
aaa901ad
...
...
@@ -274,11 +274,157 @@ class MockAttentionLayer:
raise
NotImplementedError
class MockSparseMLAAttentionLayer:
    """Test double standing in for the MLA layer around a sparse MLA impl.

    Every token is routed through ``impl.forward_mqa``; there is no separate
    prefill path in this mock. The up-projection matrices are held by this
    layer object (not by the impl) and are supplied directly by the caller
    as ``W_UK`` and ``W_UV``.
    """

    def __init__(
        self,
        impl,
        num_heads: int,
        qk_nope_head_dim: int,
        qk_rope_head_dim: int,
        v_head_dim: int,
        kv_lora_rank: int,
        device: torch.device,
        W_UK: torch.Tensor,
        W_UV: torch.Tensor,
    ):
        """Store the impl, the head geometry and the projection weights.

        Args:
            impl: Backend implementation; must expose ``forward_mqa``.
            num_heads: Number of attention heads (N).
            qk_nope_head_dim: Width of the non-rotary part of each query head (P).
            qk_rope_head_dim: Width of the rotary part of each query head.
            v_head_dim: Per-head output width (V).
            kv_lora_rank: Latent width (L).
            device: Device on which the unit scale tensors are created.
            W_UK: Key up-projection laid out as (L, N, P).
            W_UV: Value up-projection laid out as (L, N, V).
        """
        self.impl = impl
        self.num_heads = num_heads
        self.qk_nope_head_dim = qk_nope_head_dim
        self.qk_rope_head_dim = qk_rope_head_dim
        self.v_head_dim = v_head_dim
        self.kv_lora_rank = kv_lora_rank

        # Rearranged so both can be fed straight into torch.bmm:
        #   W_UK (L, N, P) -> W_UK_T (N, P, L)
        #   W_UV (L, N, V) -> W_UV   (N, L, V)
        self.W_UK_T = W_UK.permute(1, 2, 0)
        self.W_UV = W_UV.transpose(0, 1)

        # Scale attributes that attention backends read off the layer;
        # each tensor scale is an independent scalar 1.0 on `device`.
        for scale_attr in ("_q_scale", "_k_scale", "_v_scale", "_prob_scale"):
            setattr(self, scale_attr, torch.tensor(1.0, device=device))
        for float_attr in ("_q_scale_float", "_k_scale_float", "_v_scale_float"):
            setattr(self, float_attr, 1.0)

    def forward_impl(
        self,
        q: torch.Tensor,
        kv_c: torch.Tensor,
        k_pe: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata,
        output: torch.Tensor,
    ) -> torch.Tensor:
        """Run the sparse MLA path: cache write, q absorb, forward_mqa, v up-proj.

        Args:
            q: Queries, indexed as (tokens, heads, nope + rope).
            kv_c: Compressed KV latents written to the cache.
            k_pe: Rotary key part; dimension 1 is squeezed before caching.
            kv_cache: Cache tensor; the write is skipped when it is empty.
            attn_metadata: Passed through to the impl; ``slot_mapping`` is
                read from it only when the cache is written.
            output: Buffer whose first ``q.shape[0]`` rows are overwritten.

        Returns:
            The same ``output`` tensor that was passed in.
        """
        # Looked up unconditionally, falling back to "auto" if the impl
        # does not define it.
        cache_dtype = getattr(self.impl, "kv_cache_dtype", "auto")
        if kv_cache.numel() > 0:
            ops.concat_and_cache_mla(
                kv_c,
                k_pe.squeeze(1),
                kv_cache,
                attn_metadata.slot_mapping.flatten(),
                kv_cache_dtype=cache_dtype,
                scale=self._k_scale,
            )

        token_count = q.shape[0]

        # Separate the non-rotary and rotary slices of every query head.
        q_nope, q_pe = torch.split(
            q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
        )

        # (B, N, P) -> (N, B, P), bmm with (N, P, L) -> (N, B, L) -> (B, N, L)
        ql_nope = torch.bmm(q_nope.transpose(0, 1), self.W_UK_T).transpose(0, 1)

        # forward_mqa takes the query as a (latent part, rope part) tuple and
        # returns a pair; only the first element is used here.
        latent_out, _ = self.impl.forward_mqa(
            (ql_nope, q_pe), kv_cache, attn_metadata, self
        )

        # (B, N, L) -> (N, B, L), bmm with (N, L, V) -> (N, B, V) -> (B, N, V)
        projected = torch.bmm(latent_out.transpose(0, 1), self.W_UV).transpose(0, 1)

        # Flatten heads into the feature axis: (B, N, V) -> (B, N * V)
        output[:token_count] = projected.reshape(
            token_count, self.num_heads * self.v_head_dim
        )
        return output
class
MockMLAAttentionLayer
(
AttentionLayerBase
):
"""A mock MLA attention layer for populating static_forward_context."""
"""A mock MLA attention layer for testing.
This replicates the forward_impl logic from MLAAttention to allow
testing MLA backends without the full layer infrastructure.
The W_UK_T and W_UV weight matrices are created on the layer (like in
MLAAttention.process_weights_after_loading), not on the impl.
"""
def
__init__
(
self
,
impl
):
def
__init__
(
self
,
impl
,
num_heads
:
int
,
qk_nope_head_dim
:
int
,
qk_rope_head_dim
:
int
,
v_head_dim
:
int
,
kv_lora_rank
:
int
,
device
:
torch
.
device
,
kv_b_proj
,
):
self
.
impl
=
impl
self
.
num_heads
=
num_heads
self
.
qk_nope_head_dim
=
qk_nope_head_dim
self
.
qk_rope_head_dim
=
qk_rope_head_dim
self
.
v_head_dim
=
v_head_dim
self
.
kv_lora_rank
=
kv_lora_rank
# Compute weight matrices from kv_b_proj (like MLAAttention does)
# This replicates MLAAttention.process_weights_after_loading logic
kv_b_proj_weight
=
kv_b_proj
.
weight
.
T
kv_b_proj_weight
=
kv_b_proj_weight
.
view
(
kv_lora_rank
,
num_heads
,
qk_nope_head_dim
+
v_head_dim
,
)
W_UK
,
W_UV
=
kv_b_proj_weight
.
split
([
qk_nope_head_dim
,
v_head_dim
],
dim
=-
1
)
# Convert from (L, N, V) to (N, L, V)
self
.
W_UV
=
W_UV
.
transpose
(
0
,
1
)
# Convert from (L, N, P) to (N, P, L)
self
.
W_UK_T
=
W_UK
.
permute
(
1
,
2
,
0
)
# Scale attributes needed by attention backends
self
.
_q_scale
=
torch
.
tensor
(
1.0
,
device
=
device
)
self
.
_k_scale
=
torch
.
tensor
(
1.0
,
device
=
device
)
self
.
_v_scale
=
torch
.
tensor
(
1.0
,
device
=
device
)
self
.
_prob_scale
=
torch
.
tensor
(
1.0
,
device
=
device
)
self
.
_q_scale_float
=
1.0
self
.
_k_scale_float
=
1.0
self
.
_v_scale_float
=
1.0
def get_attn_backend(self):
    """Not provided by this mock layer; always raises NotImplementedError."""
    raise NotImplementedError
...
...
@@ -286,6 +432,83 @@ class MockMLAAttentionLayer(AttentionLayerBase):
def get_kv_cache_spec(self, vllm_config):
    """Not provided by this mock layer; always raises NotImplementedError.

    ``vllm_config`` is accepted but never read.
    """
    raise NotImplementedError
def forward_impl(
    self,
    q: torch.Tensor,
    kv_c: torch.Tensor,
    k_pe: torch.Tensor,
    kv_cache: torch.Tensor,
    attn_metadata,
    output: torch.Tensor,
) -> torch.Tensor:
    """Test-side copy of the layer forward: cache write, prefill, then decode.

    Tokens before ``attn_metadata.num_decode_tokens`` are treated as decode
    tokens and go through ``impl.forward_mqa``; the remaining tokens are
    treated as prefill tokens and go through ``impl.forward_mha``.

    Args:
        q: Queries, indexed as (tokens, heads, nope + rope).
        kv_c: Compressed KV latents.
        k_pe: Rotary key part; dimension 1 is squeezed before caching.
        kv_cache: Cache tensor; the write is skipped when it is empty.
        attn_metadata: Supplies ``num_decode_tokens``, ``num_decodes`` and
            ``num_prefills`` (each may be None, read as 0) and, when the
            cache is written, ``slot_mapping``.
        output: Buffer written in place by both paths.

    Returns:
        The same ``output`` tensor that was passed in.
    """
    if kv_cache.numel() > 0:
        ops.concat_and_cache_mla(
            kv_c,
            k_pe.squeeze(1),
            kv_cache,
            attn_metadata.slot_mapping.flatten(),
            kv_cache_dtype="auto",
            scale=self._k_scale,
        )

    # None counts are read as zero.
    decode_len = attn_metadata.num_decode_tokens or 0
    run_decode = (attn_metadata.num_decodes or 0) > 0
    run_prefill = (attn_metadata.num_prefills or 0) > 0

    if run_prefill:
        # Prefill tokens are everything after the decode tokens; the impl
        # writes into the matching tail of `output`.
        self.impl.forward_mha(
            q[decode_len:],
            kv_c[decode_len:],
            k_pe[decode_len:],
            kv_cache,
            attn_metadata,
            self._k_scale,
            output=output[decode_len:],
        )

    if not run_decode:
        return output

    # Separate the non-rotary and rotary slices of each decode query head.
    q_nope, q_pe = torch.split(
        q[:decode_len], [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
    )

    # (B, N, P) -> (N, B, P), bmm with (N, P, L) -> (N, B, L) -> (B, N, L)
    ql_nope = torch.bmm(q_nope.transpose(0, 1), self.W_UK_T).transpose(0, 1)

    # forward_mqa takes the query as a (latent part, rope part) tuple; only
    # the first element of its returned pair is used.
    latent_out, _ = self.impl.forward_mqa(
        (ql_nope, q_pe), kv_cache, attn_metadata, self
    )

    # (B, N, L) -> (N, B, L), bmm with (N, L, V) -> (N, B, V) -> (B, N, V)
    projected = torch.bmm(latent_out.transpose(0, 1), self.W_UV).transpose(0, 1)

    # Flatten heads into the feature axis: (B, N, V) -> (B, N * V)
    output[:decode_len] = projected.reshape(
        decode_len, self.num_heads * self.v_head_dim
    )
    return output
def
run_attention_backend
(
backend
:
AttentionBackendEnum
,
...
...
@@ -340,14 +563,31 @@ def run_attention_backend(
kv_b_proj
=
mock_kv_b_proj
,
)
# Process weights
to create W_UK_T and W_UV attributes needed by MLA
# Process weights
on the impl
act_dtype
=
_convert_dtype_to_torch
(
vllm_config
.
model_config
.
dtype
)
impl
.
process_weights_after_loading
(
act_dtype
)
# Initialize DCP attributes (normally set by MLAAttention.forward
# before calling forward_mha, see mla_attention.py:511-512)
if
impl
.
dcp_world_size
==
-
1
:
impl
.
dcp_world_size
=
1
# Create mock MLA layer
mock_layer
=
MockMLAAttentionLayer
(
impl
=
impl
,
num_heads
=
num_heads
,
qk_nope_head_dim
=
qk_nope_head_dim
,
qk_rope_head_dim
=
qk_rope_head_dim
,
v_head_dim
=
v_head_dim
,
kv_lora_rank
=
kv_lora_rank
,
device
=
device
,
kv_b_proj
=
mock_kv_b_proj
,
)
# Populate static_forward_context with mock attention layers
for
layer_name
in
layer_names
:
vllm_config
.
compilation_config
.
static_forward_context
[
layer_name
]
=
(
M
ock
MLAAttentionLayer
(
impl
)
m
ock
_layer
)
# Build metadata
...
...
@@ -357,18 +597,15 @@ def run_attention_backend(
common_attn_metadata
=
common_attn_metadata
,
)
# Create mock layer and output buffer
mock_layer
=
MockAttentionLayer
(
device
)
# Create output buffer
num_tokens
=
query
.
shape
[
0
]
output
=
torch
.
empty
(
num_tokens
,
num_heads
*
v_head_dim
,
dtype
=
query
.
dtype
,
device
=
query
.
device
)
# Run forward pass
# NOTE: The query, key, and value are already shaped correctly
# in the calling test function.
output
=
impl
.
forward
(
mock_layer
,
query
,
kv_c
,
k_pe
,
kv_cache
,
attn_metadata
,
output
=
output
output
=
mock_layer
.
forward_impl
(
query
,
kv_c
,
k_pe
,
kv_cache
,
attn_metadata
,
output
)
return
output
...
...
tests/v1/attention/test_sparse_mla_backends.py
View file @
aaa901ad
...
...
@@ -12,7 +12,7 @@ import torch
from
tests.v1.attention.test_mla_backends
import
(
BATCH_SPECS
,
BatchSpec
,
MockAttentionLayer
,
Mock
SparseMLA
AttentionLayer
,
create_and_prepopulate_kv_cache
,
)
from
tests.v1.attention.utils
import
(
...
...
@@ -408,20 +408,31 @@ def test_sparse_backend_decode_correctness(
impl
.
process_weights_after_loading
(
dtype
)
layer
=
MockAttentionLayer
(
device
)
# Create mock sparse MLA layer with weight matrices
mock_layer
=
MockSparseMLAAttentionLayer
(
impl
=
impl
,
num_heads
=
num_heads
,
qk_nope_head_dim
=
qk_nope_head_dim
,
qk_rope_head_dim
=
qk_rope_head_dim
,
v_head_dim
=
v_head_dim
,
kv_lora_rank
=
kv_lora_rank
,
device
=
device
,
W_UK
=
W_UK
,
W_UV
=
W_UV
,
)
out_buffer
=
torch
.
empty
(
metadata
.
num_actual_tokens
,
num_heads
*
v_head_dim
,
dtype
=
dtype
,
device
=
device
)
with
torch
.
inference_mode
():
backend_output
=
impl
.
forward
(
layer
,
backend_output
=
mock_layer
.
forward_impl
(
query_vllm
,
kv_c_vllm
,
k_pe_vllm
,
kv_cache
,
metadata
,
output
=
out_buffer
,
out_buffer
,
)
assert
backend_output
.
shape
==
sdpa_reference
.
shape
...
...
vllm/model_executor/layers/attention/attention.py
View file @
aaa901ad
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
typing
import
TYPE_CHECKING
from
typing
import
TYPE_CHECKING
,
Any
import
torch
import
torch.nn
as
nn
...
...
@@ -562,7 +562,7 @@ direct_register_custom_op(
def
get_attention_context
(
layer_name
:
str
,
)
->
tuple
[
dict
|
object
|
None
,
"Attention | MLAAttention"
,
torch
.
Tensor
]:
)
->
tuple
[
Any
,
"Attention | MLAAttention"
,
torch
.
Tensor
]:
"""Extract attention context for a given layer.
This helper function extracts the attention metadata, attention layer
...
...
vllm/model_executor/layers/attention/mla_attention.py
View file @
aaa901ad
...
...
@@ -63,7 +63,7 @@ W_UV project kv_c to v shape [Lkv, N, V]
W_O project v to h_t shape [N * V, H]
## Compute Friendly Approach (i.e. "
_
forward_
prefill
"):
## Compute Friendly Approach (i.e. "forward_
mha
"):
q_c = h_t @ W_DQ
q_nope = (q_c @ W_UQ).view(Sq, N, P)
...
...
@@ -91,7 +91,7 @@ NOTE: in the actual code,
`out_proj` is W_O
## Data-Movement Friendly Approach (i.e. "
_
forward_
decode
"):
## Data-Movement Friendly Approach (i.e. "forward_
mqa
"):
Runtime
q_c = h_t @ W_DQ
...
...
@@ -243,6 +243,7 @@ from vllm.v1.attention.backend import (
AttentionType
,
CommonAttentionMetadata
,
MLAAttentionImpl
,
SparseMLAAttentionImpl
,
)
from
vllm.v1.attention.backends.fa_utils
import
get_flash_attn_version
from
vllm.v1.attention.backends.utils
import
(
...
...
@@ -266,6 +267,9 @@ logger = init_logger(__name__)
class
MLAAttention
(
nn
.
Module
,
AttentionLayerBase
):
"""Multi-Head Latent Attention layer.
NOTE: Please read the comment at the top of the file before trying to
understand this class
This class takes query, and compressed key/value tensors as input.
The class does the following:
...
...
@@ -289,6 +293,7 @@ class MLAAttention(nn.Module, AttentionLayerBase):
prefix
:
str
=
""
,
use_sparse
:
bool
=
False
,
indexer
:
object
|
None
=
None
,
q_pad_num_heads
:
int
|
None
=
None
,
**
extra_impl_args
,
):
super
().
__init__
()
...
...
@@ -299,8 +304,14 @@ class MLAAttention(nn.Module, AttentionLayerBase):
self
.
v_head_dim
=
v_head_dim
self
.
q_lora_rank
=
q_lora_rank
self
.
kv_lora_rank
=
kv_lora_rank
self
.
kv_b_proj
=
kv_b_proj
self
.
head_size
=
kv_lora_rank
+
qk_rope_head_dim
self
.
layer_name
=
prefix
self
.
indexer
=
indexer
self
.
q_pad_num_heads
=
q_pad_num_heads
self
.
num_kv_heads
=
1
self
.
qk_head_dim
=
self
.
qk_nope_head_dim
+
self
.
qk_rope_head_dim
if
cache_config
is
not
None
:
kv_cache_dtype
=
cache_config
.
cache_dtype
...
...
@@ -364,6 +375,7 @@ class MLAAttention(nn.Module, AttentionLayerBase):
v_head_dim
=
self
.
v_head_dim
,
kv_b_proj
=
kv_b_proj
,
indexer
=
indexer
,
q_pad_num_heads
=
q_pad_num_heads
,
**
extra_impl_args
,
)
...
...
@@ -388,6 +400,26 @@ class MLAAttention(nn.Module, AttentionLayerBase):
self
.
k_range
=
torch
.
tensor
(
envs
.
K_SCALE_CONSTANT
,
dtype
=
torch
.
float32
)
self
.
v_range
=
torch
.
tensor
(
envs
.
V_SCALE_CONSTANT
,
dtype
=
torch
.
float32
)
self
.
is_aiter_triton_fp8_bmm_enabled
=
rocm_aiter_ops
.
is_fp8bmm_enabled
()
# If kv_b_proj_weight is unquantized, quantize it to mxfp4 if supported
self
.
is_aiter_triton_fp4_bmm_enabled
=
(
rocm_aiter_ops
.
is_fp4bmm_enabled
()
and
self
.
kv_b_proj
.
weight
.
dtype
==
torch
.
bfloat16
)
# Attributes for forward_impl method
self
.
chunked_prefill_workspace_size
=
(
MLACommonMetadataBuilder
.
determine_chunked_prefill_workspace_size
(
get_current_vllm_config
()
)
)
self
.
_decode_concat_quant_fp8_op
=
_DecodeConcatQuantFP8
(
static
=
True
,
group_shape
=
GroupShape
.
PER_TENSOR
,
compile_native
=
True
,
)
def
forward
(
self
,
q
:
torch
.
Tensor
,
...
...
@@ -407,8 +439,7 @@ class MLAAttention(nn.Module, AttentionLayerBase):
if
self
.
attn_backend
.
accept_output_buffer
:
output
=
torch
.
empty
(
output_shape
,
dtype
=
q
.
dtype
,
device
=
q
.
device
)
self
.
impl
.
forward
(
self
,
self
.
forward_impl
(
q
,
kv_c_normed
,
k_pe
,
...
...
@@ -418,8 +449,8 @@ class MLAAttention(nn.Module, AttentionLayerBase):
)
return
output
else
:
return
self
.
impl
.
forward
(
self
,
q
,
kv_c_normed
,
k_pe
,
self_kv_cache
,
attn_metadata
return
self
.
forward
_impl
(
q
,
kv_c_normed
,
k_pe
,
self_kv_cache
,
attn_metadata
)
else
:
if
self
.
attn_backend
.
accept_output_buffer
:
...
...
@@ -440,9 +471,282 @@ class MLAAttention(nn.Module, AttentionLayerBase):
self
.
layer_name
,
)
def forward_impl(
    self,
    q: torch.Tensor,
    k_c_normed: torch.Tensor,  # key in unified attn
    k_pe: torch.Tensor,  # value in unified attn
    kv_cache: torch.Tensor,
    attn_metadata: "MLACommonMetadata",
    output: torch.Tensor | None = None,
    output_scale: torch.Tensor | None = None,
    output_block_scale: torch.Tensor | None = None,
) -> torch.Tensor:
    """Run MLA attention for one step, dispatching to the backend impl.

    The layer performs the KV-cache write, the decode/prefill split, the
    query absorption by the key up-projection and the value up-projection;
    the backend impl is only called through ``forward_mha`` (prefill) and
    ``forward_mqa`` (decode, or all tokens for sparse impls).

    Args:
        q: Query tensor; its last dim is split into
            ``qk_nope_head_dim`` + ``qk_rope_head_dim``.
        k_c_normed: Compressed KV latents written to the cache.
        k_pe: Rotary key part; dimension 1 is squeezed before caching.
        kv_cache: Cache tensor; the write is skipped when it is empty.
        attn_metadata: Per-step metadata, or None during the profile run.
        output: Output buffer; required despite the ``None`` default.
        output_scale: Must be None (fused output quantization unsupported).
        output_block_scale: Must be None (fused output quantization
            unsupported).

    Returns:
        The ``output`` buffer as passed in (including any padding rows);
        in the profile-run case, that buffer after being zero-filled.

    Raises:
        AssertionError: If ``output`` is None, if any of the
            decode/prefill counters in ``attn_metadata`` is None, or if
            fp8 KV cache is combined with ``dcp_world_size > 1``.
        NotImplementedError: If ``output_scale`` or ``output_block_scale``
            is provided.
    """
    assert output is not None, "Output tensor must be provided."

    if output_scale is not None or output_block_scale is not None:
        raise NotImplementedError(
            "fused output quantization is not yet supported for MLA"
        )

    if attn_metadata is None:
        # During the profile run, simulate the worst-case output size
        # of `self.kv_b_proj(kv_c_normed)` in `_compute_prefill_context`,
        # since this can be large. The tensor is allocated and dropped.
        _ = torch.empty(
            (
                self.chunked_prefill_workspace_size,
                self.num_heads,
                self.qk_nope_head_dim + self.v_head_dim,
            ),
            device=k_c_normed.device,
            dtype=k_c_normed.dtype,
        )

        # The zero fill is required when used with DP + EP
        # to ensure all ranks within a DP group compute the
        # same expert outputs.
        return output.fill_(0)

    # -1 is treated as "not yet initialized"; resolve it from the DCP
    # group on first use and store it back on the impl.
    if self.impl.dcp_world_size == -1:
        self.impl.dcp_world_size = get_dcp_group().world_size

    fp8_attention = self.kv_cache_dtype.startswith("fp8")

    num_actual_toks = attn_metadata.num_actual_tokens

    # Inputs and outputs may be padded for CUDA graphs.
    # Keep a handle on the padded buffer: it is what gets returned, while
    # all writes below go through the trimmed view.
    output_padded = output
    output = output[:num_actual_toks, ...]
    q = q[:num_actual_toks, ...]
    k_c_normed = k_c_normed[:num_actual_toks, ...]
    k_pe = k_pe[:num_actual_toks, ...]

    assert (
        attn_metadata.num_decodes is not None
        and attn_metadata.num_prefills is not None
        and attn_metadata.num_decode_tokens is not None
    )

    has_decode = attn_metadata.num_decodes > 0
    has_prefill = attn_metadata.num_prefills > 0
    num_decode_tokens = attn_metadata.num_decode_tokens

    # Decode tokens occupy the first `num_decode_tokens` rows; prefill
    # tokens are everything after them.
    decode_q = q[:num_decode_tokens]

    prefill_q = q[num_decode_tokens:]
    prefill_k_pe = k_pe[num_decode_tokens:]
    prefill_k_c_normed = k_c_normed[num_decode_tokens:]

    # write the latent and rope to kv cache
    if kv_cache.numel() > 0:
        ops.concat_and_cache_mla(
            k_c_normed,
            k_pe.squeeze(1),
            kv_cache,
            attn_metadata.slot_mapping.flatten(),
            kv_cache_dtype=self.kv_cache_dtype,
            scale=self._k_scale,
        )

    # Reinterpret the cache as the platform fp8 dtype for the attention
    # calls below (done after the cache write above).
    if fp8_attention:
        kv_cache = kv_cache.view(current_platform.fp8_dtype())

    # Sparse MLA impls only support forward_mqa (decode-style attention)
    is_sparse_impl = isinstance(self.impl, SparseMLAAttentionImpl)

    if has_prefill and not is_sparse_impl:
        # The impl writes prefill results into the tail of `output`.
        self.impl.forward_mha(
            prefill_q,
            prefill_k_c_normed,
            prefill_k_pe,
            kv_cache,
            attn_metadata,
            self._k_scale,
            output=output[num_decode_tokens:],
        )

    if has_decode or (has_prefill and is_sparse_impl):
        # For sparse impl, we always use forward_mqa for all tokens
        # For non-sparse impl, we only use forward_mqa for decode tokens
        if is_sparse_impl:
            mqa_q = q
            mqa_output_slice = output
        else:
            assert attn_metadata.decode is not None
            mqa_q = decode_q
            mqa_output_slice = output[:num_decode_tokens]

        mqa_q_nope, mqa_q_pe = mqa_q.split(
            [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
        )

        # Convert from (B, N, P) to (N, B, P)
        mqa_q_nope = mqa_q_nope.transpose(0, 1)

        if self.q_pad_num_heads is not None:
            # Allocate storage sized for `q_pad_num_heads` heads, then
            # resize the tensor back to the real N heads and copy the data
            # in, so the tensor is backed by the larger allocation.
            # NOTE(review): presumably a downstream kernel relies on the
            # extra storage - confirm against the backend that sets
            # q_pad_num_heads.
            B, N, L = mqa_q_pe.shape
            mqa_pe_padded = mqa_q_pe.new_empty((B, self.q_pad_num_heads, L))
            mqa_pe_padded.resize_((B, N, L))
            mqa_pe_padded.copy_(mqa_q_pe)
            mqa_q_pe = mqa_pe_padded

        if self.is_aiter_triton_fp4_bmm_enabled:
            from aiter.ops.triton.batched_gemm_a16wfp4 import batched_gemm_a16wfp4

            # The q scale is only passed as `y_scale` on the fp8 KV-cache
            # path.
            mqa_ql_nope = batched_gemm_a16wfp4(
                mqa_q_nope,
                self.W_K,
                self.W_K_scale,
                transpose_bm=True,
                prequant=True,
                y_scale=self._q_scale if fp8_attention else None,
            )
        elif self.is_aiter_triton_fp8_bmm_enabled:
            # Multiply+Transpose (N, B, P)x(N, P, L)->(N, B, L)->(B, N, L)
            mqa_ql_nope = rocm_aiter_ops.triton_fp8_bmm(
                mqa_q_nope,
                self.W_K,
                self.W_K_scale,
                group_size=128,
                transpose_bm=True,
            )
        else:
            # Pads the head_dim if necessary (for the underlying kernel)
            N, B, P = mqa_q_nope.shape
            _, _, L = self.W_UK_T.shape

            if self.q_pad_num_heads is not None:
                # Same over-allocate-then-resize pattern as for mqa_q_pe
                # above.
                mqa_ql_nope = mqa_q_nope.new_empty((self.q_pad_num_heads, B, L))
                mqa_ql_nope.resize_((N, B, L))
            else:
                mqa_ql_nope = mqa_q_nope.new_empty((N, B, L))

            # Multiply (N, B, P) x (N, P, L) -> (N, B, L)
            torch.bmm(mqa_q_nope, self.W_UK_T, out=mqa_ql_nope)

            # Convert from (N, B, L) to (B, N, L)
            mqa_ql_nope = mqa_ql_nope.transpose(0, 1)

        if fp8_attention:
            # The fused concat+quant op yields a single query tensor; the
            # non-fp8 path passes the two parts as a tuple instead.
            assert mqa_ql_nope.shape[0] == mqa_q_pe.shape[0]
            assert mqa_ql_nope.shape[1] == mqa_q_pe.shape[1]
            mqa_q = self._decode_concat_quant_fp8_op(
                mqa_ql_nope, mqa_q_pe, self._q_scale
            )
        else:
            mqa_q = (mqa_ql_nope, mqa_q_pe)
        if self.impl.dcp_world_size > 1:
            assert not fp8_attention, "DCP not support fp8 kvcache now."
            # concatenate mqa_ql_nope and mqa_q_pe -> (B, N, L + P)
            mqa_q = torch.cat(mqa_q, dim=-1)
            # all-gather mqa_q along the head dim (dim=1).
            mqa_q = get_dcp_group().all_gather(mqa_q, dim=1)

        # call decode attn
        attn_out, lse = self.impl.forward_mqa(mqa_q, kv_cache, attn_metadata, self)

        # correct dcp attn_out with lse.
        if self.impl.dcp_world_size > 1:
            # NOTE(review): `_use_fi_prefill` is looked up on the layer
            # (`self`), whereas `dcp_world_size` is read from `self.impl`.
            # If the flag is only ever set on the impl, this getattr always
            # falls back to False - confirm which object owns it.
            attn_out = cp_lse_ag_out_rs(
                attn_out,
                lse,
                get_dcp_group(),
                is_lse_base_on_e=not getattr(self, "_use_fi_prefill", False),
            )

        # v_up projection, written in place into the selected output rows
        self._v_up_proj(attn_out, out=mqa_output_slice)
    return output_padded
def
process_weights_after_loading
(
self
,
act_dtype
:
torch
.
dtype
):
if
hasattr
(
self
.
impl
,
"process_weights_after_loading"
):
self
.
impl
.
process_weights_after_loading
(
act_dtype
)
# we currently do not have quantized bmm's which are needed for
# `W_UV` and `W_UK_T`, we just store fp16/bf16 copies and perform
# the bmm's in 16-bit, the extra memory overhead of this is fairly low
kv_b_proj_weight
=
get_and_maybe_dequant_weights
(
self
.
kv_b_proj
,
out_dtype
=
act_dtype
).
T
assert
kv_b_proj_weight
.
shape
==
(
self
.
kv_lora_rank
,
self
.
num_heads
*
(
self
.
qk_nope_head_dim
+
self
.
v_head_dim
),
),
(
f
"
{
kv_b_proj_weight
.
shape
=
}
, "
f
"
{
self
.
kv_lora_rank
=
}
, "
f
"
{
self
.
num_heads
=
}
, "
f
"
{
self
.
qk_nope_head_dim
=
}
, "
f
"
{
self
.
v_head_dim
=
}
"
)
kv_b_proj_weight
=
kv_b_proj_weight
.
view
(
self
.
kv_lora_rank
,
self
.
num_heads
,
self
.
qk_nope_head_dim
+
self
.
v_head_dim
,
)
W_UK
,
W_UV
=
kv_b_proj_weight
.
split
(
[
self
.
qk_nope_head_dim
,
self
.
v_head_dim
],
dim
=-
1
)
# If kv_b_proj_weight is unquantized, quantize it to mxfp4 if supported
if
self
.
is_aiter_triton_fp4_bmm_enabled
:
from
vllm.model_executor.layers.quantization.quark.utils
import
(
quark_quantize_weight_to_mxfp4
,
)
self
.
W_K
,
self
.
W_K_scale
=
quark_quantize_weight_to_mxfp4
(
W_UK
)
# Convert from (L, N, P) to (N, L, P)
self
.
W_K
=
self
.
W_K
.
transpose
(
0
,
1
)
self
.
W_K_scale
=
self
.
W_K_scale
.
transpose
(
0
,
1
)
self
.
W_V
,
self
.
W_V_scale
=
quark_quantize_weight_to_mxfp4
(
W_UV
.
permute
(
1
,
2
,
0
)
)
elif
self
.
is_aiter_triton_fp8_bmm_enabled
:
W_K
=
W_UK
.
transpose
(
0
,
1
)
# 16 512 128
W_V
=
W_UV
.
permute
(
1
,
2
,
0
)
# 16 128 512
self
.
W_K
,
self
.
W_K_scale
=
dynamic_per_batched_tensor_quant
(
W_K
,
dtype
=
current_platform
.
fp8_dtype
()
)
self
.
W_V
,
self
.
W_V_scale
=
dynamic_per_batched_tensor_quant
(
W_V
,
dtype
=
current_platform
.
fp8_dtype
()
)
# The kernel operates on non-padded inputs. Hence, pre-compiling
# triton kernel to avoid runtime compilation for unseen batch sizes
# Pre-compile for batch sizes 1 to 1024 to cover most use-cases.
# On DS-R1, this step adds roughly 50s to the model loading time.
max_batch_size
=
1024
# [ToDo] Find the optimal upper limit
pre_compilation_list
=
list
(
range
(
1
,
max_batch_size
+
1
))
if
is_global_first_rank
():
pre_compilation_list
=
tqdm
(
pre_compilation_list
,
desc
=
"[Aiter Triton] Pre-compiling fp8 BMM kernel"
,
total
=
max_batch_size
,
)
for
m
in
pre_compilation_list
:
x
=
torch
.
empty
(
(
self
.
W_K
.
shape
[
0
],
m
,
self
.
W_K
.
shape
[
2
]),
dtype
=
torch
.
bfloat16
,
device
=
self
.
W_K
.
device
,
)
rocm_aiter_ops
.
triton_fp8_bmm
(
x
,
self
.
W_K
,
self
.
W_K_scale
,
group_size
=
128
,
transpose_bm
=
True
)
x
=
torch
.
empty
(
(
self
.
W_V
.
shape
[
0
],
m
,
self
.
W_V
.
shape
[
2
]),
dtype
=
torch
.
bfloat16
,
device
=
self
.
W_V
.
device
,
)
rocm_aiter_ops
.
triton_fp8_bmm
(
x
,
self
.
W_V
,
self
.
W_V_scale
,
group_size
=
128
,
transpose_bm
=
True
)
else
:
# Convert from (L, N, V) to (N, L, V)
self
.
W_UV
=
W_UV
.
transpose
(
0
,
1
)
# Convert from (L, N, P) to (N, P, L)
self
.
W_UK_T
=
W_UK
.
permute
(
1
,
2
,
0
)
# If we should not load quant weights, we initialize the scales to 1.0
# as the default value. See [Note: Register q/k/v/prob scales in state dict]
...
...
@@ -492,6 +796,41 @@ class MLAAttention(nn.Module, AttentionLayerBase):
cache_dtype_str
=
vllm_config
.
cache_config
.
cache_dtype
,
)
def _v_up_proj(self, x: torch.Tensor, out: torch.Tensor):
    """Apply the value up-projection to ``x`` and store the result in ``out``.

    Args:
        x: Attention output, viewed here as (B, N, L) with
            N = ``num_heads`` and L = ``kv_lora_rank``.
        out: Destination buffer, viewed here as (B, N, V) with
            V = ``v_head_dim``; it is written in place.

    Returns:
        None. The result is delivered through ``out``.
    """
    heads = self.num_heads
    v_dim = self.v_head_dim

    # (B, N, L) -> (N, B, L): batch the matmul over heads.
    latent = x.view(-1, heads, self.kv_lora_rank).transpose(0, 1)
    out = out.view(-1, heads, v_dim)

    if self.is_aiter_triton_fp4_bmm_enabled:
        out = rocm_aiter_ops.batched_gemm_a16wfp4(
            latent,
            self.W_V,
            self.W_V_scale,
            out,
            transpose_bm=True,
            prequant=True,
            y_scale=None,
        )
        # The flattened view is not used afterwards; the call is retained
        # so the behavior matches the original statement-for-statement.
        out.view(-1, heads * v_dim)
        return

    if self.is_aiter_triton_fp8_bmm_enabled:
        # Multiply + Transpose (N, B, L) x (N, L, V)->(N, B, V)->(B, N, V)
        rocm_aiter_ops.triton_fp8_bmm(
            latent,
            self.W_V,
            self.W_V_scale,
            group_size=128,
            transpose_bm=True,
            YQ=out,
        )
        return

    # Plain torch path.
    # View the destination as (N, B, V) so bmm can write into it directly.
    head_major = out.transpose(0, 1)

    # Multiply (N, B, L) x (N, L, V) -> (N, B, V)
    torch.bmm(latent, self.W_UV, out=head_major)

    # Back to token-major and flatten heads: (N, B, V) -> (B, N * V)
    token_major = head_major.transpose(0, 1).reshape(-1, heads * v_dim)

    # Give the destination tensor its (B, N * V) shape again, then copy the
    # token-major result into it.
    n_heads, n_tokens, n_v = head_major.shape
    head_major.resize_((n_tokens, n_heads * n_v))
    head_major.copy_(token_major)
# Copy result
@
maybe_transfer_kv_layer
def
unified_mla_attention
(
...
...
@@ -500,8 +839,8 @@ def unified_mla_attention(
k_pe
:
torch
.
Tensor
,
layer_name
:
str
,
)
->
torch
.
Tensor
:
attn_metadata
,
self
,
kv_cache
=
get_attention_context
(
layer_name
)
output
=
self
.
impl
.
forward
(
self
,
q
,
kv_c_normed
,
k_pe
,
kv_cache
,
attn_metadata
)
attn_metadata
,
layer
,
kv_cache
=
get_attention_context
(
layer_name
)
output
=
layer
.
forward
_impl
(
q
,
kv_c_normed
,
k_pe
,
kv_cache
,
attn_metadata
)
return
output
...
...
@@ -534,9 +873,8 @@ def unified_mla_attention_with_output(
output_scale
:
torch
.
Tensor
|
None
=
None
,
output_block_scale
:
torch
.
Tensor
|
None
=
None
,
)
->
None
:
attn_metadata
,
self
,
kv_cache
=
get_attention_context
(
layer_name
)
self
.
impl
.
forward
(
self
,
attn_metadata
,
layer
,
kv_cache
=
get_attention_context
(
layer_name
)
layer
.
forward_impl
(
q
,
kv_c_normed
,
k_pe
,
...
...
@@ -1461,246 +1799,103 @@ def reorg_kvcache(
under current CP rank.
local_context_lens_allranks: local context lengths on each CP rank.
sum_seq_len: the sum of cp_chunk_seq_lens_lst.
max_seq_len: the max value of cp_chunk_seq_lens_lst.
chunk_size: the local padded max context chunk from
chunked_context_metadata building.
chunk_idx: chunk idx of chunked_prefill.
toks: the number of tokens for local gather cache.
"""
kv_c_segments
=
[]
k_pe_segments
=
[]
src_token_idx
=
0
max_seq_len_check
=
0
for
padded_local_chunk_seq_len
,
local_context_lens
in
zip
(
padded_local_chunk_seq_lens_lst
,
local_context_lens_allranks
):
cur_seq_len
=
0
for
rank
,
local_context_len
in
enumerate
(
local_context_lens
):
# Note(qcs): We split the context into multiple chunks,
# depending on the size of the workspace.
# local_context in dcp0: |-----------------|
# local_context in dcp1: |--------------|
# n*padded_local_chunk: |-----|-----|-----|
# local_chunk_len in dcp1: |-----|-----|--|
# so we need update the last chunk length in dcp1.
local_chunk_len
=
min
(
max
(
0
,
local_context_len
-
chunk_idx
*
chunk_size
),
padded_local_chunk_seq_len
,
)
if
local_chunk_len
!=
0
:
kv_c_segment
=
allgatered_kv_c_normed
[
rank
*
toks
+
src_token_idx
:
rank
*
toks
+
src_token_idx
+
local_chunk_len
]
k_pe_segment
=
allgatered_k_pe
[
rank
*
toks
+
src_token_idx
:
rank
*
toks
+
src_token_idx
+
local_chunk_len
]
kv_c_segments
.
append
(
kv_c_segment
)
k_pe_segments
.
append
(
k_pe_segment
)
cur_seq_len
+=
local_chunk_len
max_seq_len_check
=
max
(
max_seq_len_check
,
cur_seq_len
)
src_token_idx
+=
padded_local_chunk_seq_len
reorganized_kv_c_normed
=
torch
.
cat
(
kv_c_segments
,
dim
=
0
)
reorganized_k_pe
=
torch
.
cat
(
k_pe_segments
,
dim
=
0
)
assert
reorganized_kv_c_normed
.
shape
[
0
]
==
sum_seq_len
assert
reorganized_k_pe
.
shape
[
0
]
==
sum_seq_len
assert
max_seq_len_check
==
max_seq_len
return
reorganized_kv_c_normed
,
reorganized_k_pe
# TODO(Lucas): rename MLACommonBaseImpl -> MLACommonImpl,
# and MLACommonImpl -> MLACommonDenseImpl or somthing like that
class
MLACommonBaseImpl
(
MLAAttentionImpl
[
A
],
Generic
[
A
]):
"""
NOTE: Please read the comment at the top of the file before trying to
understand this class
"""
def
__init__
(
self
,
num_heads
:
int
,
head_size
:
int
,
scale
:
float
,
num_kv_heads
:
int
,
alibi_slopes
:
list
[
float
]
|
None
,
sliding_window
:
int
|
None
,
kv_cache_dtype
:
str
,
logits_soft_cap
:
float
|
None
,
attn_type
:
str
,
kv_sharing_target_layer_name
:
str
|
None
,
# MLA Specific Arguments
q_lora_rank
:
int
|
None
,
kv_lora_rank
:
int
,
qk_nope_head_dim
:
int
,
qk_rope_head_dim
:
int
,
qk_head_dim
:
int
,
v_head_dim
:
int
,
kv_b_proj
:
ColumnParallelLinear
,
indexer
=
None
,
q_pad_num_heads
:
int
|
None
=
None
,
)
->
None
:
if
kv_sharing_target_layer_name
is
not
None
:
raise
NotImplementedError
(
"KV sharing is not supported for MLA"
)
self
.
num_heads
=
num_heads
self
.
head_size
=
head_size
self
.
scale
=
float
(
scale
)
self
.
num_kv_heads
=
num_kv_heads
self
.
kv_cache_dtype
=
kv_cache_dtype
self
.
q_lora_rank
=
q_lora_rank
self
.
kv_lora_rank
=
kv_lora_rank
self
.
qk_nope_head_dim
=
qk_nope_head_dim
self
.
qk_rope_head_dim
=
qk_rope_head_dim
self
.
qk_head_dim
=
qk_head_dim
self
.
v_head_dim
=
v_head_dim
self
.
kv_b_proj
=
kv_b_proj
self
.
indexer
=
indexer
self
.
q_pad_num_heads
=
q_pad_num_heads
self
.
is_aiter_triton_fp8_bmm_enabled
=
rocm_aiter_ops
.
is_fp8bmm_enabled
()
# If kv_b_proj_weight is unquantized, quantize it to mxfp4 if supported
self
.
is_aiter_triton_fp4_bmm_enabled
=
(
rocm_aiter_ops
.
is_fp4bmm_enabled
()
and
self
.
kv_b_proj
.
weight
.
dtype
==
torch
.
bfloat16
)
def
process_weights_after_loading
(
self
,
act_dtype
:
torch
.
dtype
):
# we currently do not have quantized bmm's which are needed for
# `W_UV` and `W_UK_T`, we just store fp16/bf16 copies and perform
# the bmm's in 16-bit, the extra memory overhead of this is fairly low
kv_b_proj_weight
=
get_and_maybe_dequant_weights
(
self
.
kv_b_proj
,
out_dtype
=
act_dtype
).
T
assert
kv_b_proj_weight
.
shape
==
(
self
.
kv_lora_rank
,
self
.
num_heads
*
(
self
.
qk_nope_head_dim
+
self
.
v_head_dim
),
),
(
f
"
{
kv_b_proj_weight
.
shape
=
}
, "
f
"
{
self
.
kv_lora_rank
=
}
, "
f
"
{
self
.
num_heads
=
}
, "
f
"
{
self
.
qk_nope_head_dim
=
}
, "
f
"
{
self
.
v_head_dim
=
}
"
)
kv_b_proj_weight
=
kv_b_proj_weight
.
view
(
self
.
kv_lora_rank
,
self
.
num_heads
,
self
.
qk_nope_head_dim
+
self
.
v_head_dim
,
)
W_UK
,
W_UV
=
kv_b_proj_weight
.
split
(
[
self
.
qk_nope_head_dim
,
self
.
v_head_dim
],
dim
=-
1
)
# If kv_b_proj_weight is unquantized, quantize it to mxfp4 if supported
if
self
.
is_aiter_triton_fp4_bmm_enabled
:
from
vllm.model_executor.layers.quantization.quark.utils
import
(
quark_quantize_weight_to_mxfp4
,
)
self
.
W_K
,
self
.
W_K_scale
=
quark_quantize_weight_to_mxfp4
(
W_UK
)
# Convert from (L, N, P) to (N, L, P)
self
.
W_K
=
self
.
W_K
.
transpose
(
0
,
1
)
self
.
W_K_scale
=
self
.
W_K_scale
.
transpose
(
0
,
1
)
self
.
W_V
,
self
.
W_V_scale
=
quark_quantize_weight_to_mxfp4
(
W_UV
.
permute
(
1
,
2
,
0
)
)
elif
self
.
is_aiter_triton_fp8_bmm_enabled
:
W_K
=
W_UK
.
transpose
(
0
,
1
)
# 16 512 128
W_V
=
W_UV
.
permute
(
1
,
2
,
0
)
# 16 128 512
self
.
W_K
,
self
.
W_K_scale
=
dynamic_per_batched_tensor_quant
(
W_K
,
dtype
=
current_platform
.
fp8_dtype
()
)
self
.
W_V
,
self
.
W_V_scale
=
dynamic_per_batched_tensor_quant
(
W_V
,
dtype
=
current_platform
.
fp8_dtype
()
)
# The kernel operates on non-padded inputs. Hence, pre-compiling
# triton kernel to avoid runtime compilation for unseen batch sizes
# Pre-compile for batch sizes 1 to 1024 to cover most use-cases.
# On DS-R1, this step adds roughly 50s to the model loading time.
max_batch_size
=
1024
# [ToDo] Find the optimal upper limit
pre_compilation_list
=
list
(
range
(
1
,
max_batch_size
+
1
))
if
is_global_first_rank
():
pre_compilation_list
=
tqdm
(
pre_compilation_list
,
desc
=
"[Aiter Triton] Pre-compiling fp8 BMM kernel"
,
total
=
max_batch_size
,
)
for
m
in
pre_compilation_list
:
x
=
torch
.
empty
(
(
self
.
W_K
.
shape
[
0
],
m
,
self
.
W_K
.
shape
[
2
]),
dtype
=
torch
.
bfloat16
,
device
=
self
.
W_K
.
device
,
)
rocm_aiter_ops
.
triton_fp8_bmm
(
x
,
self
.
W_K
,
self
.
W_K_scale
,
group_size
=
128
,
transpose_bm
=
True
)
x
=
torch
.
empty
(
(
self
.
W_V
.
shape
[
0
],
m
,
self
.
W_V
.
shape
[
2
]),
dtype
=
torch
.
bfloat16
,
device
=
self
.
W_V
.
device
,
)
rocm_aiter_ops
.
triton_fp8_bmm
(
x
,
self
.
W_V
,
self
.
W_V_scale
,
group_size
=
128
,
transpose_bm
=
True
)
else
:
# Convert from (L, N, V) to (N, L, V)
self
.
W_UV
=
W_UV
.
transpose
(
0
,
1
)
# Convert from (L, N, P) to (N, P, L)
self
.
W_UK_T
=
W_UK
.
permute
(
1
,
2
,
0
)
def
_v_up_proj
(
self
,
x
:
torch
.
Tensor
,
out
:
torch
.
Tensor
):
# Convert from (B, N, L) to (N, B, L)
x
=
x
.
view
(
-
1
,
self
.
num_heads
,
self
.
kv_lora_rank
).
transpose
(
0
,
1
)
out
=
out
.
view
(
-
1
,
self
.
num_heads
,
self
.
v_head_dim
)
if
self
.
is_aiter_triton_fp4_bmm_enabled
:
out
=
rocm_aiter_ops
.
batched_gemm_a16wfp4
(
x
,
self
.
W_V
,
self
.
W_V_scale
,
out
,
transpose_bm
=
True
,
prequant
=
True
,
y_scale
=
None
,
)
x
=
out
.
view
(
-
1
,
self
.
num_heads
*
self
.
v_head_dim
)
elif
self
.
is_aiter_triton_fp8_bmm_enabled
:
# Multiply + Transpose (N, B, L) x (N, L, V)->(N, B, V)->(B, N, V)
x
=
rocm_aiter_ops
.
triton_fp8_bmm
(
x
,
self
.
W_V
,
self
.
W_V_scale
,
group_size
=
128
,
transpose_bm
=
True
,
YQ
=
out
max_seq_len: the max value of cp_chunk_seq_lens_lst.
chunk_size: the local padded max context chunk from
chunked_context_metadata building.
chunk_idx: chunk idx of chunked_prefill.
toks: the number of tokens for local gather cache.
"""
kv_c_segments
=
[]
k_pe_segments
=
[]
src_token_idx
=
0
max_seq_len_check
=
0
for
padded_local_chunk_seq_len
,
local_context_lens
in
zip
(
padded_local_chunk_seq_lens_lst
,
local_context_lens_allranks
):
cur_seq_len
=
0
for
rank
,
local_context_len
in
enumerate
(
local_context_lens
):
# Note(qcs): We split the context into multiple chunks,
# depending on the size of the workspace.
# local_context in dcp0: |-----------------|
# local_context in dcp1: |--------------|
# n*padded_local_chunk: |-----|-----|-----|
# local_chunk_len in dcp1: |-----|-----|--|
# so we need update the last chunk length in dcp1.
local_chunk_len
=
min
(
max
(
0
,
local_context_len
-
chunk_idx
*
chunk_size
),
padded_local_chunk_seq_len
,
)
else
:
# Convert from (B, N * V) to (N, B, V)
out
=
out
.
transpose
(
0
,
1
)
# Multiply (N, B, L) x (N, L, V) -> (N, B, V)
torch
.
bmm
(
x
,
self
.
W_UV
,
out
=
out
)
# Reuse "out" to make it "hot"
# Convert from (N, B, V) to (B, N * V)
out_new
=
out
.
transpose
(
0
,
1
).
reshape
(
-
1
,
self
.
num_heads
*
self
.
v_head_dim
)
# Adjust output buffer shape back to the original (B, N * V)
N
,
B
,
V
=
out
.
shape
out
.
resize_
((
B
,
N
*
V
))
out
.
copy_
(
out_new
)
# Copy result
if
local_chunk_len
!=
0
:
kv_c_segment
=
allgatered_kv_c_normed
[
rank
*
toks
+
src_token_idx
:
rank
*
toks
+
src_token_idx
+
local_chunk_len
]
k_pe_segment
=
allgatered_k_pe
[
rank
*
toks
+
src_token_idx
:
rank
*
toks
+
src_token_idx
+
local_chunk_len
]
kv_c_segments
.
append
(
kv_c_segment
)
k_pe_segments
.
append
(
k_pe_segment
)
cur_seq_len
+=
local_chunk_len
max_seq_len_check
=
max
(
max_seq_len_check
,
cur_seq_len
)
src_token_idx
+=
padded_local_chunk_seq_len
reorganized_kv_c_normed
=
torch
.
cat
(
kv_c_segments
,
dim
=
0
)
reorganized_k_pe
=
torch
.
cat
(
k_pe_segments
,
dim
=
0
)
assert
reorganized_kv_c_normed
.
shape
[
0
]
==
sum_seq_len
assert
reorganized_k_pe
.
shape
[
0
]
==
sum_seq_len
assert
max_seq_len_check
==
max_seq_len
return
reorganized_kv_c_normed
,
reorganized_k_pe
class
MLACommonImpl
(
MLA
CommonBase
Impl
[
M
],
Generic
[
M
]):
class
MLACommonImpl
(
MLA
Attention
Impl
[
M
],
Generic
[
M
]):
"""
NOTE: Please read the comment at the top of the file before trying to
understand this class
"""
def
__init__
(
self
,
*
args
,
**
kwargs
)
->
None
:
super
().
__init__
(
*
args
,
**
kwargs
)
def
__init__
(
self
,
num_heads
:
int
,
head_size
:
int
,
scale
:
float
,
num_kv_heads
:
int
,
alibi_slopes
:
list
[
float
]
|
None
,
sliding_window
:
int
|
None
,
kv_cache_dtype
:
str
,
logits_soft_cap
:
float
|
None
,
attn_type
:
str
,
kv_sharing_target_layer_name
:
str
|
None
,
# MLA Specific Arguments
q_lora_rank
:
int
|
None
,
kv_lora_rank
:
int
,
qk_nope_head_dim
:
int
,
qk_rope_head_dim
:
int
,
qk_head_dim
:
int
,
v_head_dim
:
int
,
kv_b_proj
:
ColumnParallelLinear
,
indexer
:
object
|
None
=
None
,
q_pad_num_heads
:
int
|
None
=
None
,
)
->
None
:
if
kv_sharing_target_layer_name
is
not
None
:
raise
NotImplementedError
(
"KV sharing is not supported for MLA"
)
self
.
num_heads
=
num_heads
self
.
head_size
=
head_size
self
.
scale
=
float
(
scale
)
self
.
num_kv_heads
=
num_kv_heads
self
.
kv_cache_dtype
=
kv_cache_dtype
self
.
q_lora_rank
=
q_lora_rank
self
.
kv_lora_rank
=
kv_lora_rank
self
.
qk_nope_head_dim
=
qk_nope_head_dim
self
.
qk_rope_head_dim
=
qk_rope_head_dim
self
.
qk_head_dim
=
qk_head_dim
self
.
v_head_dim
=
v_head_dim
self
.
kv_b_proj
=
kv_b_proj
self
.
indexer
=
indexer
self
.
q_pad_num_heads
=
q_pad_num_heads
if
use_trtllm_ragged_deepseek_prefill
():
logger
.
info_once
(
...
...
@@ -1750,19 +1945,9 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
self
.
dcp_world_size
:
int
=
-
1
self
.
chunked_prefill_workspace_size
=
(
MLACommonMetadataBuilder
.
determine_chunked_prefill_workspace_size
(
get_current_vllm_config
()
)
)
self
.
cp_kv_cache_interleave_size
:
int
=
(
get_current_vllm_config
().
parallel_config
.
cp_kv_cache_interleave_size
)
self
.
_decode_concat_quant_fp8_op
=
_DecodeConcatQuantFP8
(
static
=
True
,
group_shape
=
GroupShape
.
PER_TENSOR
,
compile_native
=
True
,
)
def
_flash_attn_varlen_diff_headdims
(
self
,
q
,
k
,
v
,
return_softmax_lse
=
False
,
softmax_scale
=
None
,
**
kwargs
...
...
@@ -2193,7 +2378,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
return
output
,
output_lse
def
_
forward_
prefill
(
def
forward_
mha
(
self
,
q
:
torch
.
Tensor
,
kv_c_normed
:
torch
.
Tensor
,
...
...
@@ -2258,7 +2443,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
output
.
copy_
(
output_prefill
)
@
abstractmethod
def
_
forward_
decode
(
def
forward_
mqa
(
self
,
q
:
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
],
kv_c_and_k_pe_cache
:
torch
.
Tensor
,
...
...
@@ -2266,185 +2451,3 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
layer
:
AttentionLayer
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
|
None
]:
raise
NotImplementedError
def
forward
(
self
,
layer
:
AttentionLayer
,
q
:
torch
.
Tensor
,
k_c_normed
:
torch
.
Tensor
,
# key in unified attn
k_pe
:
torch
.
Tensor
,
# value in unified attn
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
M
,
output
:
torch
.
Tensor
|
None
=
None
,
output_scale
:
torch
.
Tensor
|
None
=
None
,
output_block_scale
:
torch
.
Tensor
|
None
=
None
,
)
->
torch
.
Tensor
:
assert
output
is
not
None
,
"Output tensor must be provided."
if
output_scale
is
not
None
or
output_block_scale
is
not
None
:
raise
NotImplementedError
(
"fused output quantization is not yet supported for MLACommonImpl"
)
if
attn_metadata
is
None
:
# During the profile run try to simulate to worse case output size
# for `self.kv_b_proj(kv_c_normed)` in `_compute_prefill_context`
# since this can be large
_
=
torch
.
empty
(
(
self
.
chunked_prefill_workspace_size
,
self
.
num_heads
,
self
.
qk_nope_head_dim
+
self
.
v_head_dim
,
),
device
=
k_c_normed
.
device
,
dtype
=
k_c_normed
.
dtype
,
)
# The zero fill is required when used with DP + EP
# to ensure all ranks within a DP group compute the
# same expert outputs.
return
output
.
fill_
(
0
)
if
self
.
dcp_world_size
==
-
1
:
self
.
dcp_world_size
=
get_dcp_group
().
world_size
fp8_attention
=
self
.
kv_cache_dtype
.
startswith
(
"fp8"
)
num_actual_toks
=
attn_metadata
.
num_actual_tokens
# Inputs and outputs may be padded for CUDA graphs
output_padded
=
output
output
=
output
[:
num_actual_toks
,
...]
q
=
q
[:
num_actual_toks
,
...]
k_c_normed
=
k_c_normed
[:
num_actual_toks
,
...]
k_pe
=
k_pe
[:
num_actual_toks
,
...]
assert
(
attn_metadata
.
num_decodes
is
not
None
and
attn_metadata
.
num_prefills
is
not
None
and
attn_metadata
.
num_decode_tokens
is
not
None
)
has_decode
=
attn_metadata
.
num_decodes
>
0
has_prefill
=
attn_metadata
.
num_prefills
>
0
num_decode_tokens
=
attn_metadata
.
num_decode_tokens
decode_q
=
q
[:
num_decode_tokens
]
prefill_q
=
q
[
num_decode_tokens
:]
prefill_k_pe
=
k_pe
[
num_decode_tokens
:]
prefill_k_c_normed
=
k_c_normed
[
num_decode_tokens
:]
# write the latent and rope to kv cache
if
kv_cache
.
numel
()
>
0
:
ops
.
concat_and_cache_mla
(
k_c_normed
,
k_pe
.
squeeze
(
1
),
kv_cache
,
attn_metadata
.
slot_mapping
.
flatten
(),
kv_cache_dtype
=
self
.
kv_cache_dtype
,
scale
=
layer
.
_k_scale
,
)
if
fp8_attention
:
kv_cache
=
kv_cache
.
view
(
current_platform
.
fp8_dtype
())
if
has_prefill
:
self
.
_forward_prefill
(
prefill_q
,
prefill_k_c_normed
,
prefill_k_pe
,
kv_cache
,
attn_metadata
,
layer
.
_k_scale
,
output
=
output
[
num_decode_tokens
:],
)
if
has_decode
:
assert
attn_metadata
.
decode
is
not
None
decode_q_nope
,
decode_q_pe
=
decode_q
.
split
(
[
self
.
qk_nope_head_dim
,
self
.
qk_rope_head_dim
],
dim
=-
1
)
# Convert from (B, N, P) to (N, B, P)
decode_q_nope
=
decode_q_nope
.
transpose
(
0
,
1
)
if
self
.
q_pad_num_heads
is
not
None
:
B
,
N
,
L
=
decode_q_pe
.
shape
decode_pe_padded
=
decode_q_pe
.
new_empty
((
B
,
self
.
q_pad_num_heads
,
L
))
decode_pe_padded
.
resize_
((
B
,
N
,
L
))
decode_pe_padded
.
copy_
(
decode_q_pe
)
decode_q_pe
=
decode_pe_padded
if
self
.
is_aiter_triton_fp4_bmm_enabled
:
from
aiter.ops.triton.batched_gemm_a16wfp4
import
batched_gemm_a16wfp4
decode_ql_nope
=
batched_gemm_a16wfp4
(
decode_q_nope
,
self
.
W_K
,
self
.
W_K_scale
,
transpose_bm
=
True
,
prequant
=
True
,
y_scale
=
layer
.
_q_scale
if
fp8_attention
else
None
,
)
elif
self
.
is_aiter_triton_fp8_bmm_enabled
:
# Multiply+Transpose (N, B, P)x(N, P, L)->(N, B, L)->(B, N, L)
decode_ql_nope
=
rocm_aiter_ops
.
triton_fp8_bmm
(
decode_q_nope
,
self
.
W_K
,
self
.
W_K_scale
,
group_size
=
128
,
transpose_bm
=
True
,
)
else
:
# Pads the head_dim if necessary (for the underlying kernel)
N
,
B
,
P
=
decode_q_nope
.
shape
_
,
_
,
L
=
self
.
W_UK_T
.
shape
if
self
.
q_pad_num_heads
is
not
None
:
decode_ql_nope
=
decode_q_nope
.
new_empty
(
(
self
.
q_pad_num_heads
,
B
,
L
)
)
decode_ql_nope
.
resize_
((
N
,
B
,
L
))
else
:
decode_ql_nope
=
decode_q_nope
.
new_empty
((
N
,
B
,
L
))
# Multiply (N, B, P) x (N, P, L) -> (N, B, L)
torch
.
bmm
(
decode_q_nope
,
self
.
W_UK_T
,
out
=
decode_ql_nope
)
# Convert from (N, B, L) to (B, N, L)
decode_ql_nope
=
decode_ql_nope
.
transpose
(
0
,
1
)
if
fp8_attention
:
assert
decode_ql_nope
.
shape
[
0
]
==
decode_q_pe
.
shape
[
0
]
assert
decode_ql_nope
.
shape
[
1
]
==
decode_q_pe
.
shape
[
1
]
decode_q
=
self
.
_decode_concat_quant_fp8_op
(
decode_ql_nope
,
decode_q_pe
,
layer
.
_q_scale
)
else
:
decode_q
=
(
decode_ql_nope
,
decode_q_pe
)
if
self
.
dcp_world_size
>
1
:
assert
not
fp8_attention
,
"DCP not support fp8 kvcache now."
# concatenate decode_ql_nope and decode_q_pe -> (B, N, L + P)
decode_q
=
torch
.
cat
(
decode_q
,
dim
=-
1
)
# decode_q do allgather in head dim.
decode_q
=
get_dcp_group
().
all_gather
(
decode_q
,
dim
=
1
)
# call decode attn
attn_out
,
lse
=
self
.
_forward_decode
(
decode_q
,
kv_cache
,
attn_metadata
,
layer
)
# correct dcp attn_out with lse.
if
self
.
dcp_world_size
>
1
:
attn_out
=
cp_lse_ag_out_rs
(
attn_out
,
lse
,
get_dcp_group
(),
is_lse_base_on_e
=
not
getattr
(
self
,
"_use_fi_prefill"
,
False
),
)
# v_up projection
self
.
_v_up_proj
(
attn_out
,
out
=
output
[:
num_decode_tokens
])
return
output_padded
vllm/v1/attention/backend.py
View file @
aaa901ad
...
...
@@ -67,7 +67,7 @@ class AttentionBackend(ABC):
@
staticmethod
@
abstractmethod
def
get_impl_cls
()
->
type
[
"AttentionImpl"
]:
def
get_impl_cls
()
->
type
[
"AttentionImpl
Base
"
]:
raise
NotImplementedError
@
staticmethod
...
...
@@ -594,7 +594,14 @@ class AttentionLayer(Protocol):
)
->
torch
.
Tensor
:
...
class
AttentionImpl
(
ABC
,
Generic
[
T
]):
class
AttentionImplBase
(
ABC
,
Generic
[
T
]):
"""Base class for attention implementations.
Contains common attributes and initialization logic shared by both
standard AttentionImpl and MLAAttentionImpl. Does not define a forward
method - subclasses define their own forward interfaces.
"""
# Required attributes that all impls should have
num_heads
:
int
head_size
:
int
...
...
@@ -662,6 +669,13 @@ class AttentionImpl(ABC, Generic[T]):
)
return
self
def
process_weights_after_loading
(
self
,
act_dtype
:
torch
.
dtype
):
pass
class
AttentionImpl
(
AttentionImplBase
[
T
],
Generic
[
T
]):
"""Standard attention implementation with forward method."""
@
abstractmethod
def
__init__
(
self
,
...
...
@@ -704,11 +718,10 @@ class AttentionImpl(ABC, Generic[T]):
"""
return
False
def
process_weights_after_loading
(
self
,
act_dtype
:
torch
.
dtype
):
pass
class
MLAAttentionImpl
(
AttentionImplBase
[
T
],
Generic
[
T
]):
"""MLA attention implementation with forward_mqa and forward_mha methods."""
class
MLAAttentionImpl
(
AttentionImpl
[
T
],
Generic
[
T
]):
@
abstractmethod
def
__init__
(
self
,
...
...
@@ -731,22 +744,78 @@ class MLAAttentionImpl(AttentionImpl[T], Generic[T]):
v_head_dim
:
int
,
kv_b_proj
:
"ColumnParallelLinear"
,
indexer
:
object
|
None
=
None
,
q_pad_num_heads
:
int
|
None
=
None
,
)
->
None
:
raise
NotImplementedError
@
abstractmethod
def
forward
(
def
forward
_mha
(
self
,
layer
:
AttentionLayer
,
hidden_states_or_cq
:
torch
.
Tensor
,
q
:
torch
.
Tensor
,
kv_c_normed
:
torch
.
Tensor
,
k_pe
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
kv_
c_and_k_pe_
cache
:
torch
.
Tensor
,
attn_metadata
:
T
,
output
:
torch
.
Tensor
|
None
=
None
,
output_scale
:
torch
.
Tensor
|
None
=
None
,
output_block_scale
:
torch
.
Tensor
|
None
=
None
,
)
->
torch
.
Tensor
:
k_scale
:
torch
.
Tensor
,
output
:
torch
.
Tensor
,
)
->
None
:
"""MHA-style prefill forward pass."""
raise
NotImplementedError
@
abstractmethod
def
forward_mqa
(
self
,
q
:
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
],
kv_c_and_k_pe_cache
:
torch
.
Tensor
,
attn_metadata
:
T
,
layer
:
AttentionLayer
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
|
None
]:
"""MQA-style decode forward pass."""
raise
NotImplementedError
class
SparseMLAAttentionImpl
(
AttentionImplBase
[
T
],
Generic
[
T
]):
"""Sparse MLA attention implementation with only forward_mqa method.
Sparse MLA implementations only support decode (MQA-style) attention.
They do not support prefill (MHA-style) attention.
"""
@
abstractmethod
def
__init__
(
self
,
num_heads
:
int
,
head_size
:
int
,
scale
:
float
,
num_kv_heads
:
int
,
alibi_slopes
:
list
[
float
]
|
None
,
sliding_window
:
int
|
None
,
kv_cache_dtype
:
str
,
logits_soft_cap
:
float
|
None
,
attn_type
:
str
,
kv_sharing_target_layer_name
:
str
|
None
,
# MLA Specific Arguments
q_lora_rank
:
int
|
None
,
kv_lora_rank
:
int
,
qk_nope_head_dim
:
int
,
qk_rope_head_dim
:
int
,
qk_head_dim
:
int
,
v_head_dim
:
int
,
kv_b_proj
:
"ColumnParallelLinear"
,
indexer
:
object
|
None
=
None
,
q_pad_num_heads
:
int
|
None
=
None
,
)
->
None
:
raise
NotImplementedError
@
abstractmethod
def
forward_mqa
(
self
,
q
:
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
],
kv_c_and_k_pe_cache
:
torch
.
Tensor
,
attn_metadata
:
T
,
layer
:
AttentionLayer
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
|
None
]:
"""MQA-style decode forward pass."""
raise
NotImplementedError
...
...
vllm/v1/attention/backends/mla/cutlass_mla.py
View file @
aaa901ad
...
...
@@ -244,7 +244,7 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]):
return
out
,
lse
def
_
forward_
decode
(
def
forward_
mqa
(
self
,
q
:
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
],
kv_c_and_k_pe_cache
:
torch
.
Tensor
,
...
...
vllm/v1/attention/backends/mla/flashattn_mla.py
View file @
aaa901ad
...
...
@@ -293,7 +293,7 @@ class FlashAttnMLAImpl(MLACommonImpl[FlashAttnMLAMetadata]):
"FlashAttnMLA V1 with FP8 KV cache not yet supported"
)
def
_
forward_
decode
(
def
forward_
mqa
(
self
,
q
:
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
],
kv_c_and_k_pe_cache
:
torch
.
Tensor
,
...
...
vllm/v1/attention/backends/mla/flashinfer_mla.py
View file @
aaa901ad
...
...
@@ -150,7 +150,7 @@ class FlashInferMLAImpl(MLACommonImpl[MLACommonMetadata]):
self
.
bmm1_scale
:
float
|
None
=
None
self
.
bmm2_scale
:
float
|
None
=
None
def
_
forward_
decode
(
def
forward_
mqa
(
self
,
q
:
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
],
kv_c_and_k_pe_cache
:
torch
.
Tensor
,
...
...
vllm/v1/attention/backends/mla/flashmla.py
View file @
aaa901ad
...
...
@@ -234,7 +234,7 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
"FlashMLAImpl"
)
def
_
forward_
decode
(
def
forward_
mqa
(
self
,
q
:
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
],
kv_c_and_k_pe_cache
:
torch
.
Tensor
,
...
...
vllm/v1/attention/backends/mla/flashmla_sparse.py
View file @
aaa901ad
...
...
@@ -11,7 +11,6 @@ from vllm.config import VllmConfig, get_current_vllm_config
from
vllm.config.cache
import
CacheDType
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.attention.mla_attention
import
(
MLACommonBaseImpl
,
get_mla_dims
,
)
from
vllm.platforms
import
current_platform
...
...
@@ -25,6 +24,7 @@ from vllm.v1.attention.backend import (
AttentionMetadataBuilder
,
CommonAttentionMetadata
,
MultipleOf
,
SparseMLAAttentionImpl
,
)
from
vllm.v1.attention.backends.utils
import
(
reshape_attn_output_for_spec_decode
,
...
...
@@ -686,7 +686,7 @@ class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetad
return
metadata
class
FlashMLASparseImpl
(
MLACommonBase
Impl
[
FlashMLASparseMetadata
]):
class
FlashMLASparseImpl
(
SparseMLAAttention
Impl
[
FlashMLASparseMetadata
]):
@
staticmethod
def
_compute_fp8_decode_padded_heads
(
num_heads
:
int
)
->
int
:
# FP8 decode kernel only supports h_q = 64 or 128
...
...
@@ -710,19 +710,12 @@ class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]):
indexer
:
"Indexer | None"
=
None
,
**
mla_args
,
)
->
None
:
super
().
__init__
(
num_heads
,
head_size
,
scale
,
num_kv_heads
,
alibi_slopes
,
sliding_window
,
kv_cache_dtype
,
logits_soft_cap
,
attn_type
,
kv_sharing_target_layer_name
,
**
mla_args
,
)
self
.
num_heads
=
num_heads
self
.
head_size
=
head_size
self
.
scale
=
float
(
scale
)
self
.
num_kv_heads
=
num_kv_heads
self
.
kv_cache_dtype
=
kv_cache_dtype
self
.
kv_lora_rank
:
int
=
mla_args
[
"kv_lora_rank"
]
self
.
softmax_scale
=
scale
assert
indexer
is
not
None
self
.
topk_indices_buffer
:
torch
.
Tensor
|
None
=
indexer
.
topk_indices_buffer
...
...
@@ -974,78 +967,39 @@ class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]):
output
=
output
[:,
:
self
.
num_heads
,
:]
return
output
def
forward
(
def
forward
_mqa
(
self
,
q
:
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
],
kv_c_and_k_pe_cache
:
torch
.
Tensor
,
attn_metadata
:
FlashMLASparseMetadata
,
layer
:
AttentionLayer
,
q
:
torch
.
Tensor
,
k_c_normed
:
torch
.
Tensor
,
# key in unified attn
k_pe
:
torch
.
Tensor
,
# value in unified attn
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
FlashMLASparseMetadata
|
None
,
output
:
torch
.
Tensor
|
None
=
None
,
output_scale
:
torch
.
Tensor
|
None
=
None
,
output_block_scale
:
torch
.
Tensor
|
None
=
None
,
)
->
torch
.
Tensor
:
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
|
None
]:
# NOTE(lucas): for the sparse FlashMLA kernels the kernels want to use
# MQA 576/512 approach for both prefill and decode
assert
output
is
not
None
,
"Output tensor must be provided."
if
output_scale
is
not
None
or
output_block_scale
is
not
None
:
raise
NotImplementedError
(
"fused output quantization is not yet supported for MLACommonImpl"
)
# Concatenate q if it's a tuple (ql_nope, q_pe)
if
isinstance
(
q
,
tuple
):
q
=
torch
.
cat
(
q
,
dim
=-
1
)
if
attn_metadata
is
None
:
# Dummy run - no need to allocate buffers
# The zero fill is required when used with DP + EP
# to ensure all ranks within a DP group compute the
# same expert outputs.
return
output
.
fill_
(
0
)
num_actual_toks
=
q
.
shape
[
0
]
num_actual_toks
=
attn_metadata
.
num_actual_tokens
# Inputs and outputs may be padded for CUDA graphs
q
=
q
[:
num_actual_toks
,
...]
k_c_normed
=
k_c_normed
[:
num_actual_toks
,
...]
k_pe
=
k_pe
[:
num_actual_toks
,
...]
# Get topk indices
assert
self
.
topk_indices_buffer
is
not
None
topk_indices
=
self
.
topk_indices_buffer
[:
num_actual_toks
]
q_nope
,
q_pe
=
q
.
split
([
self
.
qk_nope_head_dim
,
self
.
qk_rope_head_dim
],
dim
=-
1
)
# Convert from (B, N, P) to (N, B, P)
q_nope
=
q_nope
.
transpose
(
0
,
1
)
# Multiply (N, B, P) x (N, P, L) -> (N, B, L)
ql_nope
=
torch
.
bmm
(
q_nope
,
self
.
W_UK_T
)
# Convert from (N, B, L) to (B, N, L)
ql_nope
=
ql_nope
.
transpose
(
0
,
1
)
use_fp8_cache
=
self
.
kv_cache_dtype
==
"fp8_ds_mla"
q
=
torch
.
cat
([
ql_nope
,
q_pe
],
dim
=-
1
)
# write the latent and rope to kv cache
if
kv_cache
.
numel
()
>
0
:
ops
.
concat_and_cache_mla
(
k_c_normed
,
k_pe
.
squeeze
(
1
),
kv_cache
,
attn_metadata
.
slot_mapping
.
flatten
(),
kv_cache_dtype
=
self
.
kv_cache_dtype
,
scale
=
layer
.
_k_scale
,
)
if
not
use_fp8_cache
:
attn_out
=
self
.
_forward_bf16_kv
(
q
,
kv_cache
,
topk_indices
,
attn_metadata
)
attn_out
=
self
.
_forward_bf16_kv
(
q
,
kv_c_and_k_pe_cache
,
topk_indices
,
attn_metadata
)
elif
attn_metadata
.
fp8_use_mixed_batch
:
attn_out
=
self
.
_forward_fp8_kv_mixed_batch
(
q
,
kv_cache
,
topk_indices
,
attn_metadata
q
,
kv_
c_and_k_pe_
cache
,
topk_indices
,
attn_metadata
)
else
:
attn_out
=
self
.
_forward_fp8_kv_separate_prefill_decode
(
q
,
kv_cache
,
topk_indices
,
attn_metadata
q
,
kv_
c_and_k_pe_
cache
,
topk_indices
,
attn_metadata
)
self
.
_v_up_proj
(
attn_out
,
out
=
output
[:
num_actual_toks
])
return
output
return
attn_out
,
None
vllm/v1/attention/backends/mla/rocm_aiter_mla.py
View file @
aaa901ad
...
...
@@ -241,7 +241,7 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
return
output
def
_
forward_
decode
(
def
forward_
mqa
(
self
,
q
:
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
],
kv_c_and_k_pe_cache
:
torch
.
Tensor
,
...
...
vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py
View file @
aaa901ad
...
...
@@ -7,12 +7,10 @@ from typing import TYPE_CHECKING, ClassVar
import
numpy
as
np
import
torch
from
vllm
import
_custom_ops
as
ops
from
vllm._aiter_ops
import
rocm_aiter_ops
from
vllm.config
import
VllmConfig
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.attention.mla_attention
import
(
MLACommonBaseImpl
,
get_mla_dims
,
)
from
vllm.triton_utils
import
tl
,
triton
...
...
@@ -23,6 +21,7 @@ from vllm.v1.attention.backend import (
AttentionMetadata
,
AttentionMetadataBuilder
,
CommonAttentionMetadata
,
SparseMLAAttentionImpl
,
)
from
vllm.v1.attention.backends.mla.flashmla_sparse
import
(
triton_convert_req_index_to_global_index
,
...
...
@@ -269,7 +268,7 @@ def reference_mla_sparse_prefill(
return
(
result
,
lse
)
class
ROCMAiterMLASparseImpl
(
MLACommonBase
Impl
[
ROCMAiterMLASparseMetadata
]):
class
ROCMAiterMLASparseImpl
(
SparseMLAAttention
Impl
[
ROCMAiterMLASparseMetadata
]):
def
__init__
(
self
,
num_heads
:
int
,
...
...
@@ -287,23 +286,15 @@ class ROCMAiterMLASparseImpl(MLACommonBaseImpl[ROCMAiterMLASparseMetadata]):
indexer
:
"Indexer | None"
=
None
,
**
mla_args
,
)
->
None
:
super
().
__init__
(
num_heads
,
head_size
,
scale
,
num_kv_heads
,
alibi_slopes
,
sliding_window
,
kv_cache_dtype
,
logits_soft_cap
,
attn_type
,
kv_sharing_target_layer_name
,
**
mla_args
,
)
self
.
num_heads
=
num_heads
self
.
head_size
=
head_size
self
.
scale
=
float
(
scale
)
self
.
num_kv_heads
=
num_kv_heads
self
.
kv_cache_dtype
=
kv_cache_dtype
self
.
kv_lora_rank
:
int
=
mla_args
[
"kv_lora_rank"
]
self
.
softmax_scale
=
scale
assert
indexer
is
not
None
self
.
topk_indices_buffer
:
torch
.
Tensor
|
None
=
indexer
.
topk_indices_buffer
self
.
is_fp8bmm_enabled
=
rocm_aiter_ops
.
is_fp8bmm_enabled
()
def
_forward_bf16_kv
(
self
,
...
...
@@ -342,56 +333,23 @@ class ROCMAiterMLASparseImpl(MLACommonBaseImpl[ROCMAiterMLASparseMetadata]):
return
output
[:,
:
self
.
num_heads
,
:]
def
forward
(
def
forward
_mqa
(
self
,
layer
:
AttentionLayer
,
q
:
torch
.
Tensor
,
k_c_normed
:
torch
.
Tensor
,
# key in unified attn
k_pe
:
torch
.
Tensor
,
# value in unified attn
kv_cache
:
torch
.
Tensor
,
q
:
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
],
kv_c_and_k_pe_cache
:
torch
.
Tensor
,
attn_metadata
:
ROCMAiterMLASparseMetadata
,
output
:
torch
.
Tensor
|
None
=
None
,
output_scale
:
torch
.
Tensor
|
None
=
None
,
output_block_scale
:
torch
.
Tensor
|
None
=
None
,
)
->
torch
.
Tensor
:
layer
:
AttentionLayer
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
|
None
]:
# NOTE(lucas): for the sparse FlashMLA kernels the kernels want to use
# MQA 576/512 approach for both prefill and decode
assert
output
is
not
None
,
"Output tensor must be provided."
if
output_scale
is
not
None
or
output_block_scale
is
not
None
:
raise
NotImplementedError
(
"fused output quantization is not yet supported for ROCMAiterMLASparse"
)
if
attn_metadata
is
None
:
# The zero fill is required when used with DP + EP
# to ensure all ranks within a DP group compute the
# same expert outputs.
return
output
.
fill_
(
0
)
num_actual_toks
=
attn_metadata
.
num_actual_tokens
# Concatenate q if it's a tuple (ql_nope, q_pe)
if
isinstance
(
q
,
tuple
):
q
=
torch
.
cat
(
q
,
dim
=-
1
)
# Inputs and outputs may be padded for CUDA graphs
q
=
q
[:
num_actual_toks
,
...]
k_c_normed
=
k_c_normed
[:
num_actual_toks
,
...]
k_pe
=
k_pe
[:
num_actual_toks
,
...]
q_nope
,
q_pe
=
q
.
split
([
self
.
qk_nope_head_dim
,
self
.
qk_rope_head_dim
],
dim
=-
1
)
# Convert from (B, N, P) to (N, B, P)
q_nope
=
q_nope
.
transpose
(
0
,
1
)
if
self
.
is_fp8bmm_enabled
:
# Multiply+Transpose (N, B, P)x(N, P, L)->(N, B, L)->(B, N, L)
ql_nope
=
rocm_aiter_ops
.
triton_fp8_bmm
(
q_nope
,
self
.
W_K
,
self
.
W_K_scale
,
group_size
=
128
,
transpose_bm
=
True
)
else
:
# Multiply (N, B, P) x (N, P, L) -> (N, B, L)
ql_nope
=
torch
.
bmm
(
q_nope
,
self
.
W_UK_T
)
# Convert from (N, B, L) to (B, N, L)
ql_nope
=
ql_nope
.
transpose
(
0
,
1
)
num_actual_toks
=
q
.
shape
[
0
]
# Get topk indices
assert
self
.
topk_indices_buffer
is
not
None
topk_indices
=
self
.
topk_indices_buffer
[:
num_actual_toks
]
...
...
@@ -403,22 +361,8 @@ class ROCMAiterMLASparseImpl(MLACommonBaseImpl[ROCMAiterMLASparseMetadata]):
NUM_TOPK_TOKENS
=
attn_metadata
.
topk_tokens
,
)
q
=
torch
.
cat
([
ql_nope
,
q_pe
],
dim
=-
1
)
# write the latent and rope to kv cache
if
kv_cache
.
numel
()
>
0
:
ops
.
concat_and_cache_mla
(
k_c_normed
,
k_pe
.
squeeze
(
1
),
kv_cache
,
attn_metadata
.
slot_mapping
.
flatten
(),
kv_cache_dtype
=
self
.
kv_cache_dtype
,
scale
=
layer
.
_k_scale
,
)
attn_out
=
self
.
_forward_bf16_kv
(
q
,
kv_cache
,
topk_indices_global
,
attn_metadata
q
,
kv_
c_and_k_pe_
cache
,
topk_indices_global
,
attn_metadata
)
self
.
_v_up_proj
(
attn_out
,
out
=
output
[:
num_actual_toks
])
return
output
return
attn_out
,
None
vllm/v1/attention/backends/mla/triton_mla.py
View file @
aaa901ad
...
...
@@ -110,7 +110,7 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
**
kwargs
,
)
def
_
forward_
decode
(
def
forward_
mqa
(
self
,
q
:
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
],
kv_c_and_k_pe_cache
:
torch
.
Tensor
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment