Unverified commit 6a9cceb2, authored by Duyi-Wang and committed by GitHub
Browse files

[Bugfix][ROCm] Fix MoRI + AITER FP8 dispatch compatibility for defer_input_quant (#37418)


Signed-off-by: Duyi-Wang <duyi.wang@amd.com>
parent 199f9141
...@@ -70,16 +70,13 @@ class MoriPrepareAndFinalize(mk.FusedMoEPrepareAndFinalizeModular): ...@@ -70,16 +70,13 @@ class MoriPrepareAndFinalize(mk.FusedMoEPrepareAndFinalizeModular):
- Optional dispatched expert topk IDs - Optional dispatched expert topk IDs
- Optional dispatched expert topk weight - Optional dispatched expert topk weight
""" """
if defer_input_quant:
raise NotImplementedError(
f"{self.__class__.__name__} does not support defer_input_quant=True. "
"Please select an MoE kernel that accepts quantized inputs."
)
assert not apply_router_weight_on_input, ( assert not apply_router_weight_on_input, (
"mori does not support apply_router_weight_on_input=True now." "mori does not support apply_router_weight_on_input=True now."
) )
scale = None scale = None
if self.use_fp8_dispatch: # When defer_input_quant is True, the expert kernel handles
# quantization internally, so skip FP8 dispatch quantization.
if self.use_fp8_dispatch and not defer_input_quant:
from aiter import QuantType, get_hip_quant from aiter import QuantType, get_hip_quant
if quant_config.is_block_quantized: if quant_config.is_block_quantized:
......
...@@ -295,7 +295,12 @@ def rocm_aiter_fused_experts( ...@@ -295,7 +295,12 @@ def rocm_aiter_fused_experts(
class AiterExperts(mk.FusedMoEExpertsModular): class AiterExperts(mk.FusedMoEExpertsModular):
@property @property
def expects_unquantized_inputs(self) -> bool: def expects_unquantized_inputs(self) -> bool:
return True # When paired with MoRI, the prepare/finalize handles FP8
# quantization during dispatch to reduce network traffic,
# so we should not defer input quantization.
# Otherwise, AITER fused MoE kernels handle input quantization
# internally via a single fused kernel.
return not self.moe_config.use_mori_kernels
@staticmethod @staticmethod
def activation_format() -> mk.FusedMoEActivationFormat: def activation_format() -> mk.FusedMoEActivationFormat:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment