Unverified commit 85c0950b, authored by Tan Pin Siang, committed by GitHub
Browse files

[ROCm] Enable MORI EP for unquantized MoE with AITER backend (#37529)


Signed-off-by: Tan Pin Siang <pinsiang.tan@amd.com>
parent 57861ae4
...@@ -186,16 +186,23 @@ def maybe_make_prepare_finalize( ...@@ -186,16 +186,23 @@ def maybe_make_prepare_finalize(
use_fp8_dispatch = ( use_fp8_dispatch = (
quant_config.is_per_act_token or quant_config.is_block_quantized quant_config.is_per_act_token or quant_config.is_block_quantized
) )
# For PTPC (per token per channel) quant, the scale dim for each token is 1 if use_fp8_dispatch:
# For 1x128 quant, the scale dim for each token is hidden_dim // 128 # For PTPC (per token per channel) quant, scale dim is 1
scale_dim = 1 if quant_config.is_per_act_token else moe.hidden_dim // 128 # For 1x128 quant, scale dim is hidden_dim // 128
quant_dtype = quant_config.quant_dtype
scale_dim = 1 if quant_config.is_per_act_token else moe.hidden_dim // 128
else:
# Unquantized dispatch (e.g. AITER with defer_input_quant):
# dispatch raw BF16/FP16 data, no scales needed.
quant_dtype = moe.in_dtype
scale_dim = 0
all_to_all_args = dict( all_to_all_args = dict(
rank=all2all_manager.rank, rank=all2all_manager.rank,
num_ep_ranks=all2all_manager.world_size, num_ep_ranks=all2all_manager.world_size,
quant_dtype=quant_config.quant_dtype, quant_dtype=quant_dtype,
token_hidden_size=moe.hidden_dim, token_hidden_size=moe.hidden_dim,
scale_dim=scale_dim, scale_dim=scale_dim,
scale_type_size=torch.float32.itemsize, scale_type_size=0 if scale_dim == 0 else torch.float32.itemsize,
max_num_tokens_per_dp_rank=moe.max_num_tokens, max_num_tokens_per_dp_rank=moe.max_num_tokens,
input_dtype=moe.in_dtype, input_dtype=moe.in_dtype,
num_local_experts=moe.num_experts // all2all_manager.world_size, num_local_experts=moe.num_experts // all2all_manager.world_size,
......
...@@ -108,10 +108,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -108,10 +108,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
self, self,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None, routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
) -> FusedMoEPrepareAndFinalizeModular | None: ) -> FusedMoEPrepareAndFinalizeModular | None:
if self.unquantized_backend == UnquantizedMoeBackend.AITER: return super().maybe_make_prepare_finalize(routing_tables)
return None
else:
return super().maybe_make_prepare_finalize(routing_tables)
def select_gemm_impl( def select_gemm_impl(
self, self,
...@@ -130,6 +127,17 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -130,6 +127,17 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
max_num_tokens=self.moe.max_num_tokens, max_num_tokens=self.moe.max_num_tokens,
num_dispatchers=prepare_finalize.num_dispatchers(), num_dispatchers=prepare_finalize.num_dispatchers(),
) )
elif (
self.unquantized_backend == UnquantizedMoeBackend.AITER
and rocm_aiter_ops.is_fused_moe_enabled()
):
from .rocm_aiter_fused_moe import AiterExperts
logger.debug("AiterExperts %s", self.moe)
return AiterExperts(
moe_config=self.moe,
quant_config=self.moe_quant_config,
)
else: else:
logger.debug("TritonExperts %s", self.moe) logger.debug("TritonExperts %s", self.moe)
return TritonExperts( return TritonExperts(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment