Unverified Commit 1a834df2 authored by Pleaplusone's avatar Pleaplusone Committed by GitHub
Browse files

[ROCm][Bugfix] Fix accuracy issue on fmoe when...


[ROCm][Bugfix] Fix accuracy issue on fmoe when `VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS` enabled (#31523)
Signed-off-by: default avatarganyi <ygan@amd.com>
parent 51085c2a
...@@ -407,8 +407,10 @@ class RocmPlatform(Platform): ...@@ -407,8 +407,10 @@ class RocmPlatform(Platform):
compilation_config = vllm_config.compilation_config compilation_config = vllm_config.compilation_config
parallel_config = vllm_config.parallel_config parallel_config = vllm_config.parallel_config
is_eager_execution = compilation_config == CUDAGraphMode.NONE is_eager_execution = compilation_config == CUDAGraphMode.NONE
use_aiter_fused_moe = rocm_aiter_ops.is_fused_moe_enabled()
use_aiter_rms_norm = rocm_aiter_ops.is_rmsnorm_enabled() use_aiter_rms_norm = rocm_aiter_ops.is_rmsnorm_enabled()
use_aiter_fp8_linear = rocm_aiter_ops.is_linear_fp8_enabled() use_aiter_fp8_linear = rocm_aiter_ops.is_linear_fp8_enabled()
use_aiter_fused_se = rocm_aiter_ops.is_fusion_moe_shared_experts_enabled()
if compilation_config.cudagraph_mode.has_full_cudagraphs(): if compilation_config.cudagraph_mode.has_full_cudagraphs():
# decode context parallel does not support full cudagraphs # decode context parallel does not support full cudagraphs
...@@ -458,6 +460,22 @@ class RocmPlatform(Platform): ...@@ -458,6 +460,22 @@ class RocmPlatform(Platform):
if use_aiter_fp8_linear and "-quant_fp8" not in compilation_config.custom_ops: if use_aiter_fp8_linear and "-quant_fp8" not in compilation_config.custom_ops:
compilation_config.custom_ops.append("+quant_fp8") compilation_config.custom_ops.append("+quant_fp8")
if use_aiter_fused_se and "-grouped_topk" in compilation_config.custom_ops:
logger.warning_once(
"VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS is enabled, which "
"requires the 'grouped_topk' custom op. Overriding the "
"user-provided '-grouped_topk'."
)
compilation_config.custom_ops.remove("-grouped_topk")
# Ensure grouped_topk is always enabled when using AITER if
# its not disabled by user
if (
use_aiter_fused_moe
and "+grouped_topk" not in compilation_config.custom_ops
and "-grouped_topk" not in compilation_config.custom_ops
):
compilation_config.custom_ops.append("+grouped_topk")
@classmethod @classmethod
def verify_model_arch(cls, model_arch: str) -> None: def verify_model_arch(cls, model_arch: str) -> None:
if model_arch in _ROCM_UNSUPPORTED_MODELS: if model_arch in _ROCM_UNSUPPORTED_MODELS:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment