Unverified commit 3a8eef58, authored by Rohan Potdar and committed by GitHub
Browse files

[ROCm][Bugfix]: Disable AITER Triton ROPE by default (#35601)


Signed-off-by: Rohan138 <rohanpotdar138@gmail.com>
parent 97995f63
...@@ -106,7 +106,7 @@ if TYPE_CHECKING: ...@@ -106,7 +106,7 @@ if TYPE_CHECKING:
VLLM_ROCM_USE_AITER_MLA: bool = True VLLM_ROCM_USE_AITER_MLA: bool = True
VLLM_ROCM_USE_AITER_MHA: bool = True VLLM_ROCM_USE_AITER_MHA: bool = True
VLLM_ROCM_USE_AITER_FP4_ASM_GEMM: bool = False VLLM_ROCM_USE_AITER_FP4_ASM_GEMM: bool = False
VLLM_ROCM_USE_AITER_TRITON_ROPE: bool = True VLLM_ROCM_USE_AITER_TRITON_ROPE: bool = False
VLLM_ROCM_USE_AITER_FP8BMM: bool = True VLLM_ROCM_USE_AITER_FP8BMM: bool = True
VLLM_ROCM_USE_AITER_FP4BMM: bool = True VLLM_ROCM_USE_AITER_FP4BMM: bool = True
VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION: bool = False VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION: bool = False
...@@ -949,9 +949,9 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -949,9 +949,9 @@ environment_variables: dict[str, Callable[[], Any]] = {
os.getenv("VLLM_ROCM_USE_AITER_FP4_ASM_GEMM", "False").lower() in ("true", "1") os.getenv("VLLM_ROCM_USE_AITER_FP4_ASM_GEMM", "False").lower() in ("true", "1")
), ),
# Whether to use aiter rope. # Whether to use aiter rope.
# By default is enabled. # By default is disabled.
"VLLM_ROCM_USE_AITER_TRITON_ROPE": lambda: ( "VLLM_ROCM_USE_AITER_TRITON_ROPE": lambda: (
os.getenv("VLLM_ROCM_USE_AITER_TRITON_ROPE", "True").lower() in ("true", "1") os.getenv("VLLM_ROCM_USE_AITER_TRITON_ROPE", "False").lower() in ("true", "1")
), ),
# Whether to use aiter triton fp8 bmm kernel # Whether to use aiter triton fp8 bmm kernel
# By default is enabled. # By default is enabled.
......
...@@ -592,7 +592,6 @@ class RocmPlatform(Platform): ...@@ -592,7 +592,6 @@ class RocmPlatform(Platform):
use_aiter_rms_norm = rocm_aiter_ops.is_rmsnorm_enabled() use_aiter_rms_norm = rocm_aiter_ops.is_rmsnorm_enabled()
use_aiter_fp8_linear = rocm_aiter_ops.is_linear_fp8_enabled() use_aiter_fp8_linear = rocm_aiter_ops.is_linear_fp8_enabled()
use_aiter_fused_se = rocm_aiter_ops.is_fusion_moe_shared_experts_enabled() use_aiter_fused_se = rocm_aiter_ops.is_fusion_moe_shared_experts_enabled()
use_aiter_triton_rope = rocm_aiter_ops.is_triton_rotary_embed_enabled()
# Aiter rms norm perform best when CUDA Graph capture is enabled. # Aiter rms norm perform best when CUDA Graph capture is enabled.
if ( if (
use_aiter_rms_norm use_aiter_rms_norm
...@@ -619,9 +618,9 @@ class RocmPlatform(Platform): ...@@ -619,9 +618,9 @@ class RocmPlatform(Platform):
and "-grouped_topk" not in compilation_config.custom_ops and "-grouped_topk" not in compilation_config.custom_ops
): ):
compilation_config.custom_ops.append("+grouped_topk") compilation_config.custom_ops.append("+grouped_topk")
# Enable rotary embedding when using AITER if its not disabled by user # Enable rotary embedding customop when using AITER if not disabled by user
if ( if (
use_aiter_triton_rope rocm_aiter_ops.is_enabled()
and "+rotary_embedding" not in compilation_config.custom_ops and "+rotary_embedding" not in compilation_config.custom_ops
and "-rotary_embedding" not in compilation_config.custom_ops and "-rotary_embedding" not in compilation_config.custom_ops
): ):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment