Unverified Commit 98a3a810 authored by Woosuk Kwon's avatar Woosuk Kwon Committed by GitHub
Browse files

[ROCm] Add attention sink to use_rocm_custom_paged_attention (#22329)


Signed-off-by: default avatarWoosuk Kwon <woosuk.kwon@berkeley.edu>
Co-authored-by: default avatarLiuXiaoxuanPKU <lilyliupku@gmail.com>
Co-authored-by: default avatarsimon-mo <xmo@berkeley.edu>
Co-authored-by: default avatarChen Zhang <zhangch99@outlook.com>
Co-authored-by: default avatarHongxia Yang <62075498+hongxiayang@users.noreply.github.com>
Co-authored-by: default avatarMinseok Lee <47620120+minseokl@users.noreply.github.com>
Co-authored-by: default avatarYongye Zhu <zyy1102000@gmail.com>
parent de98252f
...@@ -127,7 +127,8 @@ def use_rocm_custom_paged_attention( ...@@ -127,7 +127,8 @@ def use_rocm_custom_paged_attention(
max_seq_len: int, max_seq_len: int,
sliding_window: int, sliding_window: int,
kv_cache_dtype: str, kv_cache_dtype: str,
alibi_slopes: Optional[torch.Tensor] = None) -> bool: alibi_slopes: Optional[torch.Tensor] = None,
sinks: Optional[torch.Tensor] = None) -> bool:
GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName
ON_GFX9 = any(arch in GPU_ARCH for arch in ["gfx90a", "gfx942", "gfx950"]) ON_GFX9 = any(arch in GPU_ARCH for arch in ["gfx90a", "gfx942", "gfx950"])
...@@ -145,7 +146,7 @@ def use_rocm_custom_paged_attention( ...@@ -145,7 +146,7 @@ def use_rocm_custom_paged_attention(
and max_seq_len <= 128 * 1024 and max_seq_len <= 128 * 1024
and (envs.VLLM_ROCM_CUSTOM_PAGED_ATTN) and (envs.VLLM_ROCM_CUSTOM_PAGED_ATTN)
and not (envs.VLLM_ROCM_USE_AITER_PAGED_ATTN and not (envs.VLLM_ROCM_USE_AITER_PAGED_ATTN
and envs.VLLM_ROCM_USE_AITER)) and envs.VLLM_ROCM_USE_AITER) and sinks is None)
else: else:
return (ON_GFX11_GFX12 and (not envs.VLLM_USE_V1 or sliding_window == 0 return (ON_GFX11_GFX12 and (not envs.VLLM_USE_V1 or sliding_window == 0
...@@ -155,7 +156,7 @@ def use_rocm_custom_paged_attention( ...@@ -155,7 +156,7 @@ def use_rocm_custom_paged_attention(
and (gqa_ratio >= 3 and gqa_ratio <= 16) and (gqa_ratio >= 3 and gqa_ratio <= 16)
and max_seq_len <= 128 * 1024 and alibi_slopes is None and max_seq_len <= 128 * 1024 and alibi_slopes is None
and kv_cache_dtype == "auto" and kv_cache_dtype == "auto"
and envs.VLLM_ROCM_CUSTOM_PAGED_ATTN) and envs.VLLM_ROCM_CUSTOM_PAGED_ATTN and sinks is None)
class RocmPlatform(Platform): class RocmPlatform(Platform):
...@@ -170,7 +171,7 @@ class RocmPlatform(Platform): ...@@ -170,7 +171,7 @@ class RocmPlatform(Platform):
supported_quantization: list[str] = [ supported_quantization: list[str] = [
"awq", "gptq", "fp8", "compressed-tensors", "fbgemm_fp8", "gguf", "awq", "gptq", "fp8", "compressed-tensors", "fbgemm_fp8", "gguf",
"quark", "ptpc_fp8" "quark", "ptpc_fp8", "mxfp4"
] ]
@classmethod @classmethod
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment