Unverified Commit 8da2f28f authored by Pleaplusone's avatar Pleaplusone Committed by GitHub
Browse files

[ROCm][BugFix]Fix `get_cu_count` in rocm_aiter_fa.py (#28618)


Signed-off-by: default avatarganyi <ygan@amd.com>
parent 86d15bfd
......@@ -18,6 +18,7 @@ from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils.math_utils import cdiv
from vllm.utils.platform_utils import get_cu_count
from vllm.v1.attention.backends.utils import (
AttentionCGSupport,
AttentionMetadataBuilder,
......@@ -38,7 +39,7 @@ if current_platform.is_rocm():
return min(65536 // x.element_size(), triton.next_power_of_2(head_dim))
def num_programs(total_tokens):
return min(total_tokens, current_platform.get_cu_count())
return min(total_tokens, get_cu_count())
@triton.jit
def cp_mha_gather_cache_kernel(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment