Unverified Commit 02cabff2 authored by TJian's avatar TJian Committed by GitHub
Browse files

[V1] [ROCm] Enable EP with AITER Fused MoE (#20270)


Signed-off-by: default avatartjtanaa <tunjian.tan@embeddedllm.com>
parent 3d19d47d
......@@ -646,13 +646,13 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
indices_type=self.topk_indices_dtype)
if self.rocm_aiter_moe_enabled:
assert expert_map is None
return self.rocm_aiter_fused_experts(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
expert_map=expert_map,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input)
else:
......
......@@ -315,7 +315,8 @@ def rocm_aiter_fused_experts(
w2_scale: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[list[int]] = None) -> torch.Tensor:
block_shape: Optional[list[int]] = None,
expert_map: Optional[torch.Tensor] = None) -> torch.Tensor:
activation_method = (ActivationMethod.SILU
if activation == "silu" else ActivationMethod.GELU)
......@@ -323,6 +324,11 @@ def rocm_aiter_fused_experts(
topk_weights = topk_weights.to(torch.float32)
topk_ids = topk_ids.to(torch.int32)
if expert_map is not None:
expert_mask = (expert_map > -1).to(torch.int32)
else:
expert_mask = None
# w8a8 per-channel quantization
if per_channel_quant and apply_router_weight_on_input and use_fp8_w8a8:
# AITER tkw1 kernel for FP8 models with `apply_router_weight_on_input`
......@@ -346,7 +352,7 @@ def rocm_aiter_fused_experts(
fc2_smooth_scale=None,
a16=False,
per_tensor_quant_scale=None,
expert_mask=None,
expert_mask=expert_mask,
activation_method=activation_method)
else:
......@@ -378,6 +384,7 @@ def rocm_aiter_fused_experts(
w2,
topk_weights,
topk_ids,
expert_mask=expert_mask,
quant_method=quant_method,
activation_method=activation_method,
w1_scale=w1_scale,
......
......@@ -633,7 +633,8 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale)
a2_scale=layer.w2_input_scale,
expert_map=expert_map)
if self.use_marlin:
assert activation == "silu", (
f"{activation} not supported for Marlin MoE.")
......
......@@ -442,6 +442,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
"""
def __init__(self, quant_config: Fp8Config):
from vllm.model_executor.layers.fused_moe import fused_experts
self.quant_config = quant_config
self.block_quant = self.quant_config.weight_block_size is not None
......@@ -879,7 +880,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
if self.block_quant else layer.w2_weight_scale),
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
block_shape=self.quant_config.weight_block_size)
block_shape=self.quant_config.weight_block_size,
expert_map=expert_map)
elif self.use_marlin:
assert activation == "silu", (
f"{activation} not supported for Marlin MoE.")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment