Unverified Commit 3fb0d909 authored by Qiang Zhang's avatar Qiang Zhang Committed by GitHub
Browse files

[AMD] Use Decoupled Kernel Block Size to Support AITER MLA block_size=1 (#27715)


Signed-off-by: default avatarchiangzhang <chiangzhang@tencent.com>
parent 05c2dee7
...@@ -119,14 +119,12 @@ class AttentionBackend(ABC): ...@@ -119,14 +119,12 @@ class AttentionBackend(ABC):
return True return True
for supported_size in cls.supported_kernel_block_sizes: for supported_size in cls.supported_kernel_block_sizes:
is_multiple_of = ( if isinstance(supported_size, MultipleOf):
isinstance(supported_size, MultipleOf) supported_size = supported_size.base
and block_size % supported_size.base == 0 # With hybrid_blocks feature, the framework-level block size
) # only needs to be a multiple of the kernel's requirement,
is_int_equal = ( # even if the kernel requires a fixed block_size.
isinstance(supported_size, int) and block_size == supported_size if block_size % supported_size == 0:
)
if is_multiple_of or is_int_equal:
return True return True
return False return False
......
...@@ -7,9 +7,8 @@ from typing import ClassVar ...@@ -7,9 +7,8 @@ from typing import ClassVar
import torch import torch
from vllm._aiter_ops import rocm_aiter_ops from vllm._aiter_ops import rocm_aiter_ops
from vllm.attention.backends.abstract import AttentionLayer from vllm.attention.backends.abstract import AttentionLayer, MultipleOf
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.utils.math_utils import cdiv
from vllm.v1.attention.backends.mla.common import ( from vllm.v1.attention.backends.mla.common import (
MLACommonBackend, MLACommonBackend,
MLACommonDecodeMetadata, MLACommonDecodeMetadata,
...@@ -22,6 +21,8 @@ from vllm.v1.kv_cache_interface import AttentionSpec ...@@ -22,6 +21,8 @@ from vllm.v1.kv_cache_interface import AttentionSpec
class AiterMLABackend(MLACommonBackend): class AiterMLABackend(MLACommonBackend):
supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [1]
@staticmethod @staticmethod
def get_name() -> str: def get_name() -> str:
return "ROCM_AITER_MLA" return "ROCM_AITER_MLA"
...@@ -71,9 +72,8 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]): ...@@ -71,9 +72,8 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
) )
self.compilation_config = vllm_config.compilation_config self.compilation_config = vllm_config.compilation_config
max_num_pages_per_req = cdiv( # kernel block size is always 1.
vllm_config.model_config.max_model_len, self.kv_cache_spec.block_size max_num_pages_per_req = vllm_config.model_config.max_model_len
)
max_num_reqs = vllm_config.scheduler_config.max_num_seqs max_num_reqs = vllm_config.scheduler_config.max_num_seqs
max_num_pages = max_num_reqs * max_num_pages_per_req max_num_pages = max_num_reqs * max_num_pages_per_req
...@@ -82,11 +82,6 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]): ...@@ -82,11 +82,6 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
# so we can only use the persistent buffer if a cudagraph is actually # so we can only use the persistent buffer if a cudagraph is actually
# being used. # being used.
if self.compilation_config.cudagraph_mode.has_full_cudagraphs(): if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
self.block_table_remapping = torch.zeros(
[max_num_reqs, max_num_pages_per_req * self.kv_cache_spec.block_size],
dtype=torch.int32,
device=device,
)
self.paged_kv_indptr = torch.zeros( self.paged_kv_indptr = torch.zeros(
max_num_reqs + 1, dtype=torch.int32, device=device max_num_reqs + 1, dtype=torch.int32, device=device
) )
...@@ -111,36 +106,16 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]): ...@@ -111,36 +106,16 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
num_decode_tokens: int, num_decode_tokens: int,
dcp_tot_seq_lens_device: torch.Tensor | None, dcp_tot_seq_lens_device: torch.Tensor | None,
) -> AiterMLADecodeMetadata: ) -> AiterMLADecodeMetadata:
page_size = self.kv_cache_spec.block_size # kernel block size is always 1, although the kv block size is not 1.
device = self.device device = self.device
num_reqs = seq_lens_device.size(0) num_reqs = seq_lens_device.size(0)
bs, _ = block_table_tensor.shape
block_table_tensor = (
block_table_tensor.unsqueeze(-1).expand(-1, -1, page_size) * page_size
)
block_table_tensor = (
block_table_tensor
+ torch.arange(
0,
page_size,
device=block_table_tensor.device,
dtype=block_table_tensor.dtype,
)[None, None, :]
)
block_table_tensor = block_table_tensor.view(bs, -1)
# after remapping, we assume the block size already equals to 1
max_blk_size_per_req = block_table_tensor.shape[-1]
mask = torch.arange( mask = torch.arange(
block_table_tensor.size(1), dtype=block_table_tensor.dtype, device=device block_table_tensor.size(1), dtype=block_table_tensor.dtype, device=device
).unsqueeze(0) < seq_lens_device.unsqueeze(1) ).unsqueeze(0) < seq_lens_device.unsqueeze(1)
paged_kv_indices = block_table_tensor[mask] paged_kv_indices = block_table_tensor[mask]
paged_kv_last_page_len = seq_lens_device % page_size paged_kv_last_page_len = torch.where(seq_lens_device == 0, 1, seq_lens_device)
paged_kv_last_page_len = torch.where(
paged_kv_last_page_len == 0, page_size, paged_kv_last_page_len
)
paged_kv_indptr = torch.cat( paged_kv_indptr = torch.cat(
[ [
...@@ -151,12 +126,6 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]): ...@@ -151,12 +126,6 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
if self.compilation_config.cudagraph_mode.has_full_cudagraphs(): if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
num_actual_pages = paged_kv_indices.size(0) num_actual_pages = paged_kv_indices.size(0)
self.block_table_remapping[:num_reqs, :max_blk_size_per_req].copy_(
block_table_tensor, non_blocking=True
)
block_table_tensor = self.block_table_remapping[
:num_reqs, :max_blk_size_per_req
]
self.paged_kv_indices[:num_actual_pages].copy_( self.paged_kv_indices[:num_actual_pages].copy_(
paged_kv_indices, non_blocking=True paged_kv_indices, non_blocking=True
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment