"examples/pytorch/vscode:/vscode.git/clone" did not exist on "701b746b82210a23a8db7b87af080a3a9ec28493"
Unverified Commit 48c1fa7b authored by jianan-gu's avatar jianan-gu Committed by GitHub
Browse files

[CPU][Llama4] Fix Llama4 MoE inputs with "apply_router_weight_on_input" (#7889)

parent 8aa5ae6b
...@@ -115,5 +115,7 @@ def adjust_config_with_unaligned_cpu_tp( ...@@ -115,5 +115,7 @@ def adjust_config_with_unaligned_cpu_tp(
model_config = update_intermediate_size( model_config = update_intermediate_size(
model_config, "intermediate_size", intermediate_padding_size model_config, "intermediate_size", intermediate_padding_size
) )
model_config = update_intermediate_size(
model_config, "intermediate_size_mlp", intermediate_padding_size
)
return model_config return model_config
...@@ -93,6 +93,19 @@ def fused_topk_cpu( ...@@ -93,6 +93,19 @@ def fused_topk_cpu(
return topk_weights, topk_ids return topk_weights, topk_ids
def apply_topk_weights_cpu(need_apply, topk_weights, inputs):
    """Pre-apply router (top-k) weights to MoE inputs on the CPU path.

    Used for models (e.g. Llama4) where ``apply_router_weight_on_input`` is
    set: the routing weight is multiplied into the hidden states *before* the
    expert computation, and the weights handed to the fused kernel are reset
    to 1.0 so they are not applied a second time.

    Args:
        need_apply: Whether router weights should be applied to the inputs
            (i.e. ``apply_router_weight_on_input``). If False, both arguments
            are returned unchanged.
        topk_weights: Router weights of shape ``(num_tokens, topk)``.
        inputs: Hidden states of shape ``(num_tokens, hidden_size)``.

    Returns:
        Tuple of ``(inputs, topk_weights)``; when applied, ``inputs`` is the
        weighted hidden states and ``topk_weights`` is an all-ones float32
        tensor of the same shape as the original weights.

    Raises:
        ValueError: If ``need_apply`` is True but ``topk > 1`` — the
            elementwise multiply below broadcasts ``(tokens, topk)`` against
            ``(tokens, hidden)``, which is only well-defined for ``topk == 1``
            (and would be silently wrong if ``topk`` happened to equal
            ``hidden_size``).
    """
    if not need_apply:
        return inputs, topk_weights

    if topk_weights.size(-1) != 1:
        raise ValueError(
            "apply_router_weight_on_input requires top-k == 1, "
            f"got top-k = {topk_weights.size(-1)}"
        )

    # TODO: fuse below processing in fused_experts_cpu kernel
    inputs = inputs * topk_weights.to(inputs.dtype)
    # Clear topk_weights (set to 1.0) as they are already applied above.
    topk_weights = torch.ones_like(topk_weights, dtype=torch.float32)
    return inputs, topk_weights
def fused_topk( def fused_topk(
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
gating_output: torch.Tensor, gating_output: torch.Tensor,
......
...@@ -1005,6 +1005,12 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -1005,6 +1005,12 @@ class Fp8MoEMethod(FusedMoEMethodBase):
) )
if use_intel_amx_backend(layer): if use_intel_amx_backend(layer):
from sglang.srt.layers.moe.topk import apply_topk_weights_cpu
x, topk_weights = apply_topk_weights_cpu(
apply_router_weight_on_input, topk_weights, x
)
return torch.ops.sgl_kernel.fused_experts_cpu( return torch.ops.sgl_kernel.fused_experts_cpu(
x, x,
layer.w13_weight, layer.w13_weight,
......
...@@ -344,9 +344,12 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -344,9 +344,12 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
) -> torch.Tensor: ) -> torch.Tensor:
assert activation == "silu", f"activation = {activation} is not supported." assert activation == "silu", f"activation = {activation} is not supported."
if use_intel_amx_backend(layer) and not apply_router_weight_on_input: if use_intel_amx_backend(layer):
from sglang.srt.layers.moe.topk import select_experts from sglang.srt.layers.moe.topk import (
select_experts,
apply_topk_weights_cpu,
)
topk_weights, topk_ids = select_experts( topk_weights, topk_ids = select_experts(
hidden_states=x, hidden_states=x,
...@@ -361,8 +364,10 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -361,8 +364,10 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
correction_bias=correction_bias, correction_bias=correction_bias,
routed_scaling_factor=routed_scaling_factor, routed_scaling_factor=routed_scaling_factor,
) )
x, topk_weights = apply_topk_weights_cpu(
apply_router_weight_on_input, topk_weights, x
)
# TODO: support apply_router_weight_on_input in the fused_experts_cpu kernel
return torch.ops.sgl_kernel.fused_experts_cpu( return torch.ops.sgl_kernel.fused_experts_cpu(
x, x,
layer.w13_weight, layer.w13_weight,
......
...@@ -497,6 +497,11 @@ class W8A8Int8MoEMethod(FusedMoEMethodBase): ...@@ -497,6 +497,11 @@ class W8A8Int8MoEMethod(FusedMoEMethodBase):
) )
if use_intel_amx_backend(layer): if use_intel_amx_backend(layer):
from sglang.srt.layers.moe.topk import apply_topk_weights_cpu
x, topk_weights = apply_topk_weights_cpu(
apply_router_weight_on_input, topk_weights, x
)
return torch.ops.sgl_kernel.fused_experts_cpu( return torch.ops.sgl_kernel.fused_experts_cpu(
x, x,
layer.w13_weight, layer.w13_weight,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment