Unverified commit 48c1fa7b authored by jianan-gu, committed by GitHub

[CPU][Llama4] Fix Llama4 MoE inputs with "apply_router_weight_on_input" (#7889)

parent 8aa5ae6b
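For context, a minimal sketch of why the flag matters (pure torch; the silu stand-in for an expert MLP and the topk=1 shapes are illustrative assumptions, not taken from this diff). Llama4 sets apply_router_weight_on_input, which requires the router weight to be multiplied into the hidden states before the experts run rather than into their outputs; since torch.ops.sgl_kernel.fused_experts_cpu does not yet accept this flag, the commit pre-scales the inputs in Python and hands the kernel all-ones weights so its usual output weighting becomes a no-op.

import torch
import torch.nn.functional as F

x = torch.randn(4, 8)  # 4 tokens, hidden dim 8
w = torch.rand(4, 1)   # topk=1 router weights, broadcastable over the hidden dim

def expert(inputs):
    # Stand-in for a nonlinear expert MLP; the real kernel fuses this step.
    return F.silu(inputs)

# Weighting the *output* (what the kernel does by default) differs from
# weighting the *input* (what apply_router_weight_on_input requires):
assert not torch.allclose(expert(x) * w, expert(x * w))

# The commit's workaround: pre-scale x, then neutralize the weights so the
# kernel's output weighting multiplies by one and input semantics hold.
x_scaled = x * w.to(x.dtype)
w_ones = torch.ones_like(w, dtype=torch.float32)
assert torch.allclose(expert(x_scaled) * w_ones, expert(x * w))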
@@ -115,5 +115,7 @@ def adjust_config_with_unaligned_cpu_tp(
     model_config = update_intermediate_size(
         model_config, "intermediate_size", intermediate_padding_size
     )
+    model_config = update_intermediate_size(
+        model_config, "intermediate_size_mlp", intermediate_padding_size
+    )
     return model_config
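Llama4's config carries intermediate_size_mlp alongside intermediate_size, so both must now be padded for unaligned CPU tensor parallelism. As a rough illustration, a padding helper could look like the sketch below (the body is a hypothetical assumption; only the attribute names and the call shape come from the diff):

def update_intermediate_size(model_config, attr_name, padding_size):
    # Hypothetical sketch: round the named size up to a multiple of
    # padding_size so every CPU TP rank receives an aligned shard.
    hf_config = model_config.hf_config  # assumed attribute
    if hasattr(hf_config, attr_name):
        value = getattr(hf_config, attr_name)
        padded = -(-value // padding_size) * padding_size  # ceiling division
        setattr(hf_config, attr_name, padded)
    return model_config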
@@ -93,6 +93,19 @@ def fused_topk_cpu(
     return topk_weights, topk_ids


+def apply_topk_weights_cpu(need_apply, topk_weights, inputs):
+    if not need_apply:
+        return inputs, topk_weights
+    # TODO: fuse the processing below into the fused_experts_cpu kernel
+    inputs = inputs * topk_weights.to(inputs.dtype)
+    topk_weights = torch.ones_like(
+        topk_weights, dtype=torch.float32
+    )  # clear topk_weights, as they have already been applied
+    return inputs, topk_weights
+
+
 def fused_topk(
     hidden_states: torch.Tensor,
     gating_output: torch.Tensor,
......
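A quick usage sketch of the new helper (shapes are illustrative; with need_apply=False it is a pass-through):

import torch
from sglang.srt.layers.moe.topk import apply_topk_weights_cpu

hidden = torch.randn(4, 8)  # 4 tokens, hidden dim 8
weights = torch.rand(4, 1)  # topk=1 router weights

# No-op path: both tensors come back unchanged.
out, w = apply_topk_weights_cpu(False, weights, hidden)

# Apply path: hidden states are pre-scaled and the returned weights are all
# ones, so a downstream kernel cannot apply the router weights a second time.
out, w = apply_topk_weights_cpu(True, weights, hidden)
assert torch.all(w == 1.0)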
@@ -1005,6 +1005,12 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         )

         if use_intel_amx_backend(layer):
+            from sglang.srt.layers.moe.topk import apply_topk_weights_cpu
+
+            x, topk_weights = apply_topk_weights_cpu(
+                apply_router_weight_on_input, topk_weights, x
+            )
+
             return torch.ops.sgl_kernel.fused_experts_cpu(
                 x,
                 layer.w13_weight,
......
@@ -344,9 +344,12 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
     ) -> torch.Tensor:
         assert activation == "silu", f"activation = {activation} is not supported."

-        if use_intel_amx_backend(layer) and not apply_router_weight_on_input:
-            from sglang.srt.layers.moe.topk import select_experts
+        if use_intel_amx_backend(layer):
+            from sglang.srt.layers.moe.topk import (
+                select_experts,
+                apply_topk_weights_cpu,
+            )

             topk_weights, topk_ids = select_experts(
                 hidden_states=x,
@@ -361,8 +364,10 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
                 correction_bias=correction_bias,
                 routed_scaling_factor=routed_scaling_factor,
             )
+            x, topk_weights = apply_topk_weights_cpu(
+                apply_router_weight_on_input, topk_weights, x
+            )

-            # TODO: support apply_router_weight_on_input in the fused_experts_cpu kernel
             return torch.ops.sgl_kernel.fused_experts_cpu(
                 x,
                 layer.w13_weight,
......
@@ -497,6 +497,11 @@ class W8A8Int8MoEMethod(FusedMoEMethodBase):
         )

         if use_intel_amx_backend(layer):
+            from sglang.srt.layers.moe.topk import apply_topk_weights_cpu
+
+            x, topk_weights = apply_topk_weights_cpu(
+                apply_router_weight_on_input, topk_weights, x
+            )
             return torch.ops.sgl_kernel.fused_experts_cpu(
                 x,
                 layer.w13_weight,
......
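All three AMX-backed methods touched here (unquantized, FP8, W8A8 int8) now share the same two-step call-site pattern ahead of the kernel call. A condensed sketch of that shared shape (the wrapper name is hypothetical, and the kernel's remaining arguments are elided exactly as in the hunks above):

import torch
from sglang.srt.layers.moe.topk import apply_topk_weights_cpu

def run_amx_moe(layer, x, topk_weights, topk_ids, apply_router_weight_on_input):
    # Fold the router weights into x first; this is a no-op unless the
    # flag is set, in which case topk_weights comes back as all ones.
    x, topk_weights = apply_topk_weights_cpu(
        apply_router_weight_on_input, topk_weights, x
    )
    # The fused kernel still multiplies outputs by topk_weights, which is
    # now harmless; its full argument list is elided, as in the diff.
    ...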