Unverified Commit 7243e02a authored by larryli2-amd's avatar larryli2-amd Committed by GitHub
Browse files

[ROCm][Feature] Enable AITER MLA attention backend to work with Eagle3...


[ROCm][Feature] Enable AITER MLA attention backend to work with Eagle3 speculative decoding on ROCm (#39616)
Signed-off-by: default avatarlarryli2-amd <larryli2@amd.com>
Co-authored-by: default avatarTJian <tunjian.tan@embeddedllm.com>
parent def8f522
...@@ -215,7 +215,7 @@ configuration. ...@@ -215,7 +215,7 @@ configuration.
| `FLASHMLA` | fp16, bf16 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3` | 64 | Any | ❌ | ❌ | ❌ | ✅ | Decoder | 9.x-10.x | | `FLASHMLA` | fp16, bf16 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3` | 64 | Any | ❌ | ❌ | ❌ | ✅ | Decoder | 9.x-10.x |
| `FLASHMLA_SPARSE` | bf16 | `auto`, `bfloat16`, `fp8_ds_mla` | 64 | 576 | ❌ | ✅ | ❌ | ❌ | Decoder | 9.x-10.x | | `FLASHMLA_SPARSE` | bf16 | `auto`, `bfloat16`, `fp8_ds_mla` | 64 | 576 | ❌ | ✅ | ❌ | ❌ | Decoder | 9.x-10.x |
| `FLASH_ATTN_MLA` | fp16, bf16 | `auto`, `float16`, `bfloat16` | %16 | Any | ❌ | ❌ | ❌ | ✅ | Decoder | 9.x | | `FLASH_ATTN_MLA` | fp16, bf16 | `auto`, `float16`, `bfloat16` | %16 | Any | ❌ | ❌ | ❌ | ✅ | Decoder | 9.x |
| `ROCM_AITER_MLA` | fp16, bf16 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | 1 | Any | ❌ | ❌ | ❌ | ❌ | Decoder | N/A | | `ROCM_AITER_MLA` | fp16, bf16 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | %1 | Any | ❌ | ❌ | ❌ | ❌ | Decoder | N/A |
| `ROCM_AITER_MLA_SPARSE` | fp16, bf16 | `auto`, `float16`, `bfloat16` | 1 | Any | ❌ | ✅ | ❌ | ❌ | Decoder | N/A | | `ROCM_AITER_MLA_SPARSE` | fp16, bf16 | `auto`, `float16`, `bfloat16` | 1 | Any | ❌ | ✅ | ❌ | ❌ | Decoder | N/A |
| `ROCM_AITER_TRITON_MLA` | fp16, bf16 | `auto` | Any | Any | ❌ | ❌ | ❌ | ❌ | Decoder | N/A | | `ROCM_AITER_TRITON_MLA` | fp16, bf16 | `auto` | Any | Any | ❌ | ❌ | ❌ | ❌ | Decoder | N/A |
| `TRITON_MLA` | fp16, bf16 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3` | %16 | Any | ❌ | ❌ | ❌ | ✅ | Decoder | Any | | `TRITON_MLA` | fp16, bf16 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3` | %16 | Any | ❌ | ❌ | ❌ | ✅ | Decoder | Any |
......
...@@ -44,7 +44,11 @@ class AiterMLABackend(MLACommonBackend): ...@@ -44,7 +44,11 @@ class AiterMLABackend(MLACommonBackend):
@staticmethod @staticmethod
def get_supported_kernel_block_sizes() -> list[int | MultipleOf]: def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
return [1] # The aiter MLA decode kernel always operates with page_size=1
# internally (the wrapper flattens kv_buffer via .view(-1, 1, 1, H)).
# We support any kernel_block_size by expanding block-level indices
# into per-token flat indices in the metadata builder.
return [MultipleOf(1)]
@staticmethod @staticmethod
def get_name() -> str: def get_name() -> str:
...@@ -74,6 +78,8 @@ class AiterMLADecodeMetadata(MLACommonDecodeMetadata): ...@@ -74,6 +78,8 @@ class AiterMLADecodeMetadata(MLACommonDecodeMetadata):
attn_out_dtype: torch.dtype = torch.bfloat16 attn_out_dtype: torch.dtype = torch.bfloat16
# The max query output length: int # The max query output length: int
max_qo_len: int | None = None max_qo_len: int | None = None
# Whether persistent MLA metadata was computed (only for qseqlen=1)
has_persistent_metadata: bool = False
@dataclass @dataclass
...@@ -105,7 +111,16 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]): ...@@ -105,7 +111,16 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
self.compilation_config = vllm_config.compilation_config self.compilation_config = vllm_config.compilation_config
self.decode_attn_out_dtype = vllm_config.model_config.dtype self.decode_attn_out_dtype = vllm_config.model_config.dtype
# kernel block size is always 1.
# Store the kernel block size from the spec. When kernel_block_size=1
# (no spec-dec), behavior is identical to the original. When > 1
# (e.g. 16 with Eagle3), we expand block-level indices into per-token
# flat indices since the aiter kernel always uses page_size=1 internally.
self.kernel_block_size = kv_cache_spec.block_size
# In the flat view (.view(-1,1,1,H)), each token is its own page,
# so max_num_pages_per_req = max_model_len regardless of
# kernel_block_size.
max_num_pages_per_req = vllm_config.model_config.max_model_len max_num_pages_per_req = vllm_config.model_config.max_model_len
max_num_reqs = vllm_config.scheduler_config.max_num_seqs max_num_reqs = vllm_config.scheduler_config.max_num_seqs
max_num_pages = max_num_reqs * max_num_pages_per_req max_num_pages = max_num_reqs * max_num_pages_per_req
...@@ -115,8 +130,9 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]): ...@@ -115,8 +130,9 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
# so we can only use the persistent buffer if a cudagraph is actually # so we can only use the persistent buffer if a cudagraph is actually
# being used. # being used.
# paged_kv_last_page_len is always 1s (kernel block size is always 1), # paged_kv_last_page_len is always 1s (the aiter kernel always sees
# so we create it once and reuse slices in both eager and cudagraph modes. # page_size=1 after .view(-1,1,1,H) flattening), so we create it
# once and reuse slices in both eager and cudagraph modes.
self.paged_kv_last_page_len = torch.ones( self.paged_kv_last_page_len = torch.ones(
max_num_reqs, dtype=torch.int32, device=device max_num_reqs, dtype=torch.int32, device=device
) )
...@@ -196,14 +212,14 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]): ...@@ -196,14 +212,14 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
num_decode_tokens: int, num_decode_tokens: int,
dcp_tot_seq_lens_device: torch.Tensor | None, dcp_tot_seq_lens_device: torch.Tensor | None,
) -> AiterMLADecodeMetadata: ) -> AiterMLADecodeMetadata:
# kernel block size is always 1, although the kv block size is not 1.
device = self.device device = self.device
num_reqs = seq_lens_device.size(0) num_reqs = seq_lens_device.size(0)
# kernel block size is always 1, so each page has exactly 1 token. # The aiter kernel always operates with page_size=1 (the wrapper
# last_page_len is always 1 - just slice the pre-initialized buffer. # flattens kv_buffer). last_page_len is always 1.
paged_kv_last_page_len = self.paged_kv_last_page_len[:num_reqs] paged_kv_last_page_len = self.paged_kv_last_page_len[:num_reqs]
# indptr: cumsum of seq_lens (one page per token in the flat view)
paged_kv_indptr = torch.cat( paged_kv_indptr = torch.cat(
[ [
torch.zeros(1, dtype=seq_lens_device.dtype, device=device), torch.zeros(1, dtype=seq_lens_device.dtype, device=device),
...@@ -215,11 +231,19 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]): ...@@ -215,11 +231,19 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
if self.compilation_config.cudagraph_mode.has_full_cudagraphs(): if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
self.paged_kv_indices.fill_(-1) self.paged_kv_indices.fill_(-1)
_copy_page_indices_kernel[(num_reqs,)](
# Expand block_table entries into per-token flat indices.
# When kernel_block_size=1, this degrades to a direct copy (identical
# to the original _copy_page_indices_kernel).
# When kernel_block_size=K>1, block_table entry b covering K tokens
# gets expanded to flat indices b*K, b*K+1, ..., b*K+(K-1).
_expand_page_indices_kernel[(num_reqs,)](
self.paged_kv_indices, self.paged_kv_indices,
block_table_tensor, block_table_tensor,
block_table_tensor.stride(0), block_table_tensor.stride(0),
paged_kv_indptr, paged_kv_indptr,
seq_lens_device,
KERNEL_BLOCK_SIZE=self.kernel_block_size,
BLOCK_SIZE=1024, BLOCK_SIZE=1024,
) )
paged_kv_indices = self.paged_kv_indices paged_kv_indices = self.paged_kv_indices
...@@ -245,27 +269,37 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]): ...@@ -245,27 +269,37 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
0, num_reqs + 1, step=1, dtype=torch.int32, device=device 0, num_reqs + 1, step=1, dtype=torch.int32, device=device
) )
from aiter import get_mla_metadata_v1 # The aiter MLA ASM kernel only supports qseqlen=1 (single-token
# decode). With speculative decoding, the verification step has
get_mla_metadata_v1( # qseqlen > 1 (e.g. 8 for spec7). get_mla_metadata_v1 calls
qo_indptr, # get_heuristic_kernel_mla which fails for qseqlen > 1.
paged_kv_indptr, # We track whether persistent metadata was successfully computed
paged_kv_last_page_len, # so forward_mqa can skip passing it (falling back to the kernel
self._num_attention_heads, # computing its own metadata internally, like v0.18.0).
1, has_persistent_metadata = False
True, if max_qo_len == 1:
self._mla_work_meta_data, from aiter import get_mla_metadata_v1
self._mla_work_info_set,
self._mla_work_indptr, get_mla_metadata_v1(
self._mla_reduce_indptr, qo_indptr,
self._mla_reduce_final_map, paged_kv_indptr,
self._mla_reduce_partial_map, paged_kv_last_page_len,
page_size=1, self._num_attention_heads,
kv_granularity=16, 1,
max_seqlen_qo=max_qo_len, True,
uni_seqlen_qo=max_qo_len, self._mla_work_meta_data,
fast_mode=True, self._mla_work_info_set,
) self._mla_work_indptr,
self._mla_reduce_indptr,
self._mla_reduce_final_map,
self._mla_reduce_partial_map,
page_size=1,
kv_granularity=16,
max_seqlen_qo=max_qo_len,
uni_seqlen_qo=max_qo_len,
fast_mode=True,
)
has_persistent_metadata = True
attn_metadata = AiterMLADecodeMetadata( attn_metadata = AiterMLADecodeMetadata(
block_table=block_table_tensor, block_table=block_table_tensor,
...@@ -277,6 +311,7 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]): ...@@ -277,6 +311,7 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
dcp_tot_seq_lens=dcp_tot_seq_lens_device, dcp_tot_seq_lens=dcp_tot_seq_lens_device,
max_qo_len=max_qo_len, max_qo_len=max_qo_len,
attn_out_dtype=self.decode_attn_out_dtype, attn_out_dtype=self.decode_attn_out_dtype,
has_persistent_metadata=has_persistent_metadata,
) )
return attn_metadata return attn_metadata
...@@ -290,41 +325,67 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]): ...@@ -290,41 +325,67 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
attn_metadata = super().build( attn_metadata = super().build(
common_prefix_len, common_attn_metadata, fast_build common_prefix_len, common_attn_metadata, fast_build
) )
attn_metadata.work_meta_data = self._mla_work_meta_data if (
attn_metadata.work_indptr = self._mla_work_indptr attn_metadata.decode is not None
attn_metadata.work_info_set = self._mla_work_info_set and attn_metadata.decode.has_persistent_metadata
attn_metadata.reduce_indptr = self._mla_reduce_indptr ):
attn_metadata.reduce_final_map = self._mla_reduce_final_map attn_metadata.work_meta_data = self._mla_work_meta_data
attn_metadata.reduce_partial_map = self._mla_reduce_partial_map attn_metadata.work_indptr = self._mla_work_indptr
attn_metadata.work_info_set = self._mla_work_info_set
attn_metadata.reduce_indptr = self._mla_reduce_indptr
attn_metadata.reduce_final_map = self._mla_reduce_final_map
attn_metadata.reduce_partial_map = self._mla_reduce_partial_map
return attn_metadata return attn_metadata
@triton.jit @triton.jit
def _copy_page_indices_kernel( def _expand_page_indices_kernel(
page_indices, page_indices,
block_table, block_table,
block_table_stride, block_table_stride,
cu_num_blocks, cu_num_tokens,
seq_lens,
KERNEL_BLOCK_SIZE: tl.constexpr,
BLOCK_SIZE: tl.constexpr, BLOCK_SIZE: tl.constexpr,
): ):
"""Copy block table rows into a flat page_indices buffer using indptr. """Expand block table entries into per-token flat page indices.
Avoids blocking boolean mask indexing (tensor[mask]) which has
data-dependent output size and forces sync. The aiter MLA kernel always operates with page_size=1 internally
This is the same kernel as introduced in backends/flashinfer.py. (kv_buffer is flattened via .view(-1, 1, 1, H)). This kernel converts
block-level indices from the block table into individual token positions
in the flattened KV buffer.
When KERNEL_BLOCK_SIZE=1: block_idx=t, offset=0, flat=block_id
(equivalent to a direct copy -- no regression from the original kernel).
When KERNEL_BLOCK_SIZE=K: block table entry b (covering K tokens)
is expanded to flat indices b*K, b*K+1, ..., b*K+(K-1).
""" """
req_idx = tl.program_id(0) req_idx = tl.program_id(0)
row_ptr = block_table + req_idx * block_table_stride row_ptr = block_table + req_idx * block_table_stride
start_idx = tl.load(cu_num_blocks + req_idx) start_idx = tl.load(cu_num_tokens + req_idx)
end_idx = tl.load(cu_num_blocks + req_idx + 1) num_tokens = tl.load(seq_lens + req_idx)
num_blocks = end_idx - start_idx
offset = tl.arange(0, BLOCK_SIZE) offset = tl.arange(0, BLOCK_SIZE)
for i in tl.range(0, num_blocks, BLOCK_SIZE): for i in tl.range(0, num_tokens, BLOCK_SIZE):
block_ids = tl.load(row_ptr + i + offset, mask=i + offset < num_blocks) token_offsets = i + offset
mask = token_offsets < num_tokens
# Which block in the block table does this token belong to?
block_idx = token_offsets // KERNEL_BLOCK_SIZE
# Offset within that block
offset_in_block = token_offsets % KERNEL_BLOCK_SIZE
# Load the block ID from the block table
block_ids = tl.load(row_ptr + block_idx, mask=mask)
# Compute flat index in the flattened kv_buffer
flat_indices = block_ids * KERNEL_BLOCK_SIZE + offset_in_block
tl.store( tl.store(
page_indices + start_idx + i + offset, page_indices + start_idx + token_offsets,
block_ids, flat_indices,
mask=i + offset < num_blocks, mask=mask,
) )
...@@ -426,6 +487,24 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]): ...@@ -426,6 +487,24 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
kv_buffer = kv_c_and_k_pe_cache.unsqueeze(2) kv_buffer = kv_c_and_k_pe_cache.unsqueeze(2)
# Build kwargs for mla_decode_fwd. Pass persistent metadata only
# when it was successfully computed (qseqlen=1 decode steps).
# For multi-token verification steps (spec-dec), the kernel falls
# back to computing metadata internally.
mla_kwargs = dict(
q_scale=layer._q_scale,
kv_scale=layer._k_scale,
)
if attn_metadata.work_meta_data is not None:
mla_kwargs.update(
work_meta_data=attn_metadata.work_meta_data,
work_indptr=attn_metadata.work_indptr,
work_info_set=attn_metadata.work_info_set,
reduce_indptr=attn_metadata.reduce_indptr,
reduce_final_map=attn_metadata.reduce_final_map,
reduce_partial_map=attn_metadata.reduce_partial_map,
)
rocm_aiter_ops.mla_decode_fwd( rocm_aiter_ops.mla_decode_fwd(
q, q,
kv_buffer, kv_buffer,
...@@ -436,14 +515,7 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]): ...@@ -436,14 +515,7 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
attn_metadata.decode.paged_kv_indptr, attn_metadata.decode.paged_kv_indptr,
attn_metadata.decode.paged_kv_indices, attn_metadata.decode.paged_kv_indices,
attn_metadata.decode.paged_kv_last_page_len, attn_metadata.decode.paged_kv_last_page_len,
q_scale=layer._q_scale, **mla_kwargs,
kv_scale=layer._k_scale,
work_meta_data=attn_metadata.work_meta_data,
work_indptr=attn_metadata.work_indptr,
work_info_set=attn_metadata.work_info_set,
reduce_indptr=attn_metadata.reduce_indptr,
reduce_final_map=attn_metadata.reduce_final_map,
reduce_partial_map=attn_metadata.reduce_partial_map,
) )
if self._needs_head_repeat: if self._needs_head_repeat:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment