Unverified Commit 96bf50a2 authored by vllmellm's avatar vllmellm Committed by GitHub
Browse files

[ROCm] Serving Fails on Radeon Due to AITER Dtype Import (#30952)


Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
parent f90d3636
......@@ -24,14 +24,13 @@ def is_aiter_found() -> bool:
# we keep this global outside to not cause torch compile breaks.
IS_AITER_FOUND = is_aiter_found()
# Can't use dtypes.fp8 directly inside an op
# because it returns wrong result on gfx942.
# This is a workaround to get the correct FP8 dtype.
# This might be because get_gfx() is wrapped as a custom op.
if IS_AITER_FOUND:
from aiter import dtypes
AITER_FP8_DTYPE = dtypes.fp8
def is_aiter_found_and_supported() -> bool:
    """Report whether AITER can be used in the current environment.

    Returns:
        The result of ``on_gfx9()`` when the current platform is ROCm and
        the aiter library was found (``IS_AITER_FOUND``); ``False``
        otherwise.
    """
    # Guard clause: bail out unless we are on ROCm *and* aiter is present.
    if not (current_platform.is_rocm() and IS_AITER_FOUND):
        return False

    # Imported lazily, only once the ROCm check has passed -- presumably so
    # that other platforms never import the ROCm platform module
    # (NOTE(review): confirm intent).
    from vllm.platforms.rocm import on_gfx9

    return on_gfx9()
def if_aiter_supported(func: Callable) -> Callable:
......@@ -43,17 +42,24 @@ def if_aiter_supported(func: Callable) -> Callable:
def wrapper(*args, **kwargs):
# checks the platform, device arch and aiter library existence.
if current_platform.is_rocm() and IS_AITER_FOUND:
from vllm.platforms.rocm import on_gfx9
if on_gfx9():
return func(*args, **kwargs)
if is_aiter_found_and_supported():
return func(*args, **kwargs)
return None
return wrapper
# Can't use dtypes.fp8 directly inside an op
# because it returns wrong result on gfx942.
# This is a workaround to get the correct FP8 dtype.
# This might be because get_gfx() is wrapped as a custom op.
if is_aiter_found_and_supported():
from aiter import dtypes
AITER_FP8_DTYPE = dtypes.fp8
def _rocm_aiter_fused_moe_impl(
hidden_states: torch.Tensor,
w1: torch.Tensor,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment