Unverified Commit 1aaecda0 authored by Kunshang Ji's avatar Kunshang Ji Committed by GitHub
Browse files

[XPU] Enable Expert parallel for MoE models (#28263)


Signed-off-by: default avatarYan Ma <yan.ma@intel.com>
Signed-off-by: default avatarKunshang Ji <kunshang.ji@intel.com>
parent 811df41e
...@@ -642,10 +642,12 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -642,10 +642,12 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
if current_platform.is_xpu(): if current_platform.is_xpu():
import intel_extension_for_pytorch as ipex import intel_extension_for_pytorch as ipex
ep_rank_start = self.moe.ep_rank * self.moe.num_local_experts
layer.ipex_fusion = ipex.llm.modules.GatedMLPMOE( layer.ipex_fusion = ipex.llm.modules.GatedMLPMOE(
layer.w13_weight, layer.w13_weight,
layer.w2_weight, layer.w2_weight,
use_prepack=True, use_prepack=True,
experts_start_id=ep_rank_start,
) )
elif current_platform.is_cpu(): elif current_platform.is_cpu():
from vllm.model_executor.layers.fused_moe import cpu_fused_moe from vllm.model_executor.layers.fused_moe import cpu_fused_moe
......
...@@ -399,6 +399,7 @@ class XPUFp8MoEMethod(FusedMoEMethodBase): ...@@ -399,6 +399,7 @@ class XPUFp8MoEMethod(FusedMoEMethodBase):
layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False) layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
import intel_extension_for_pytorch as ipex import intel_extension_for_pytorch as ipex
ep_rank_start = self.moe.ep_rank * self.moe.num_local_experts
layer.ipex_fusion = ipex.llm.modules.GatedMLPMOE( layer.ipex_fusion = ipex.llm.modules.GatedMLPMOE(
layer.w13_weight, layer.w13_weight,
layer.w2_weight, layer.w2_weight,
...@@ -407,6 +408,7 @@ class XPUFp8MoEMethod(FusedMoEMethodBase): ...@@ -407,6 +408,7 @@ class XPUFp8MoEMethod(FusedMoEMethodBase):
a1_scale_inv=layer.w13_input_scale, a1_scale_inv=layer.w13_input_scale,
a2_scale_inv=layer.w2_input_scale, a2_scale_inv=layer.w2_input_scale,
use_prepack=True, use_prepack=True,
experts_start_id=ep_rank_start,
) )
def get_fused_moe_quant_config( def get_fused_moe_quant_config(
......
...@@ -1113,6 +1113,7 @@ class IpexMxfp4MoEMethod(Mxfp4MoEMethod): ...@@ -1113,6 +1113,7 @@ class IpexMxfp4MoEMethod(Mxfp4MoEMethod):
layer.w13_weight.data = layer.w13_weight.data.view(torch.int32) layer.w13_weight.data = layer.w13_weight.data.view(torch.int32)
layer.w2_weight.data = layer.w2_weight.data.view(torch.int32) layer.w2_weight.data = layer.w2_weight.data.view(torch.int32)
ep_rank_start = self.moe_config.ep_rank * self.moe_config.num_local_experts
layer.ipex_fusion = ipex.llm.modules.GatedMLPMOE( layer.ipex_fusion = ipex.llm.modules.GatedMLPMOE(
layer.w13_weight, layer.w13_weight,
layer.w2_weight, layer.w2_weight,
...@@ -1121,6 +1122,7 @@ class IpexMxfp4MoEMethod(Mxfp4MoEMethod): ...@@ -1121,6 +1122,7 @@ class IpexMxfp4MoEMethod(Mxfp4MoEMethod):
w13_bias=layer.w13_bias, w13_bias=layer.w13_bias,
w2_bias=layer.w2_bias, w2_bias=layer.w2_bias,
is_mxfp4=True, is_mxfp4=True,
experts_start_id=ep_rank_start,
) )
def apply( def apply(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment