"examples/vscode:/vscode.git/clone" did not exist on "6b1391ca7ece0350ecbe07e91f2c91d901c7f525"
Unverified commit 107cf8e9, authored by Rabi Mishra, committed by GitHub
Browse files

fix(rocm): Add get_supported_kernel_block_sizes() to ROCM_ATTN (#31712)


Signed-off-by: rabi <ramishra@redhat.com>
parent 63baa28c
...@@ -11,6 +11,7 @@ from vllm.attention.backends.abstract import ( ...@@ -11,6 +11,7 @@ from vllm.attention.backends.abstract import (
AttentionBackend, AttentionBackend,
AttentionImpl, AttentionImpl,
AttentionType, AttentionType,
MultipleOf,
) )
from vllm.attention.ops.chunked_prefill_paged_decode import chunked_prefill_paged_decode from vllm.attention.ops.chunked_prefill_paged_decode import chunked_prefill_paged_decode
from vllm.attention.ops.paged_attn import PagedAttention from vllm.attention.ops.paged_attn import PagedAttention
...@@ -158,6 +159,13 @@ class RocmAttentionBackend(AttentionBackend): ...@@ -158,6 +159,13 @@ class RocmAttentionBackend(AttentionBackend):
torch.float32, torch.float32,
] ]
@staticmethod
def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
    """Return the KV-cache block sizes the ROCM paged attention kernel accepts.

    Only block sizes 16 and 32 are supported, due to shared memory (LDS)
    constraints on AMD GPUs. See the CALL_CUSTOM_LAUNCHER_BLK macro in
    csrc/rocm/attention.cu.
    """
    supported_block_sizes: list[int | MultipleOf] = [16, 32]
    return supported_block_sizes
@classmethod @classmethod
def get_supported_head_sizes(cls) -> list[int]: def get_supported_head_sizes(cls) -> list[int]:
return [32, 64, 96, 128, 160, 192, 224, 256] return [32, 64, 96, 128, 160, 192, 224, 256]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment