Unverified Commit e3ab93c8 authored by Li, Jiang's avatar Li, Jiang Committed by GitHub
Browse files

[CPU] Refactor CPU fused MOE (#30531)


Signed-off-by: default avatarjiang1.li <jiang1.li@intel.com>
parent fc2ae6d6
...@@ -2919,6 +2919,42 @@ def cpu_gemm_wna16( ...@@ -2919,6 +2919,42 @@ def cpu_gemm_wna16(
return output return output
def cpu_prepack_moe_weight(
weight: torch.Tensor,
isa: str,
) -> torch.Tensor:
output = torch.empty_like(weight)
torch.ops._C.prepack_moe_weight(weight, output, isa)
return output
def cpu_fused_moe(
input: torch.Tensor,
w13: torch.Tensor,
w2: torch.Tensor,
w13_bias: torch.Tensor | None,
w2_bias: torch.Tensor | None,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
act: str,
isa: str,
) -> torch.Tensor:
output = torch.empty_like(input)
torch.ops._C.cpu_fused_moe(
output,
input,
w13,
w2,
w13_bias,
w2_bias,
topk_weights,
topk_ids,
act,
isa,
)
return output
if hasattr(torch.ops._qutlass_C, "matmul_mxf4_bf16_tn"): if hasattr(torch.ops._qutlass_C, "matmul_mxf4_bf16_tn"):
@register_fake("_qutlass_C::matmul_mxf4_bf16_tn") @register_fake("_qutlass_C::matmul_mxf4_bf16_tn")
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import weakref
from collections.abc import Callable from collections.abc import Callable
import torch import torch
from torch.nn import functional as F from torch.nn import functional as F
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm._custom_ops import cpu_fused_moe, cpu_prepack_moe_weight
from vllm.model_executor.layers.activation import SiluAndMul, SwigluOAIAndMul from vllm.model_executor.layers.activation import SiluAndMul, SwigluOAIAndMul
from vllm.model_executor.layers.quantization.utils.layer_utils import replace_parameter
from vllm.utils.torch_utils import direct_register_custom_op
_CPU_MOE_LAYER_CACHE = {}
_CPU_MOE_ACT = {
"silu": SiluAndMul(),
"swigluoai": SwigluOAIAndMul(),
}
def grouped_topk( def grouped_topk(
...@@ -174,8 +184,105 @@ class SGLFusedMOE: ...@@ -174,8 +184,105 @@ class SGLFusedMOE:
class CPUFusedMOE: class CPUFusedMOE:
def __init__(self, layer: torch.nn.Module) -> None: def __init__(self, layer: torch.nn.Module) -> None:
use_onednn_mm = ops._supports_onednn and ops.is_onednn_acl_supported() use_grouped_gemm, isa = self.check_grouped_gemm(layer)
self.isa = isa
if use_grouped_gemm:
self.forward_method = self.forward_grouped_gemm
self.init_moe_grouped_gemm(layer=layer)
else:
self.forward_method = self.forward_torch
self.init_moe_torch(layer=layer)
def __call__(
self,
layer: torch.nn.Module,
x: torch.Tensor,
use_grouped_topk: bool,
top_k: int,
router_logits: torch.Tensor,
renormalize: bool,
topk_group: int | None = None,
num_expert_group: int | None = None,
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
) -> torch.Tensor:
assert activation in _CPU_MOE_ACT, f"{activation} is not supported."
assert not apply_router_weight_on_input
topk_weights, topk_ids = select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor,
e_score_correction_bias=e_score_correction_bias,
)
return self.forward_method(
layer,
x,
topk_weights,
topk_ids,
activation,
global_num_experts,
)
def check_grouped_gemm(
self,
layer: torch.nn.Module,
) -> tuple[bool, str]:
if not hasattr(torch.ops._C, "prepack_moe_weight"):
return False, "none"
dtype = layer.w13_weight.dtype
w13_input_size = layer.w13_weight.size(2)
w13_output_size = layer.w13_weight.size(1)
w2_input_size = layer.w2_weight.size(2)
w2_output_size = layer.w2_weight.size(1)
if not (w13_output_size % 32 == 0 and w2_output_size % 32 == 0):
return False, "none"
supports_amx = torch._C._cpu._is_amx_tile_supported()
if (
supports_amx
and dtype == torch.bfloat16
and w13_input_size % 32 == 0
and w2_input_size % 32 == 0
):
return True, "amx"
if supports_amx:
return False, "none"
return True, "vec"
def init_moe_grouped_gemm(
self,
layer: torch.nn.Module,
) -> None:
new_w13 = cpu_prepack_moe_weight(layer.w13_weight, self.isa)
replace_parameter(layer, "w13_weight", new_w13)
new_w2 = cpu_prepack_moe_weight(layer.w2_weight, self.isa)
replace_parameter(layer, "w2_weight", new_w2)
def init_moe_torch(
self,
layer: torch.nn.Module,
) -> None:
use_onednn_mm = ops._supports_onednn and ops.is_onednn_acl_supported()
num_experts = layer.w13_weight.size(0) num_experts = layer.w13_weight.size(0)
has_w13_bias = hasattr(layer, "w13_bias") has_w13_bias = hasattr(layer, "w13_bias")
has_w2_bias = hasattr(layer, "w2_bias") has_w2_bias = hasattr(layer, "w2_bias")
...@@ -208,49 +315,69 @@ class CPUFusedMOE: ...@@ -208,49 +315,69 @@ class CPUFusedMOE:
layer.down_linear.append( layer.down_linear.append(
lambda x, w=layer_w2_weight, b=layer_w2_bias: F.linear(x, w, b) lambda x, w=layer_w2_weight, b=layer_w2_bias: F.linear(x, w, b)
) )
if use_onednn_mm: # remove weight if use_onednn_mm: # remove weight
layer.w13_weight = torch.nn.Parameter(torch.empty(0), requires_grad=False) layer.w13_weight = torch.nn.Parameter(torch.empty(0), requires_grad=False)
layer.w2_weight = torch.nn.Parameter(torch.empty(0), requires_grad=False) layer.w2_weight = torch.nn.Parameter(torch.empty(0), requires_grad=False)
self.act_to_impl = { _CPU_MOE_LAYER_CACHE[id(layer)] = weakref.ref(layer)
"silu": SiluAndMul(),
"swigluoai": SwigluOAIAndMul(),
}
def __call__( def forward_grouped_gemm(
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
x: torch.Tensor, input: torch.Tensor,
use_grouped_topk: bool, topk_weights: torch.Tensor,
top_k: int, topk_ids: torch.Tensor,
router_logits: torch.Tensor, activation: str,
renormalize: bool,
topk_group: int | None = None,
num_expert_group: int | None = None,
global_num_experts: int = -1, global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
) -> torch.Tensor: ) -> torch.Tensor:
assert activation in self.act_to_impl, f"{activation} is not supported." output = cpu_fused_moe(
assert not apply_router_weight_on_input input,
topk_weights, topk_ids = select_experts( layer.w13_weight,
hidden_states=x, layer.w2_weight,
router_logits=router_logits, getattr(layer, "w13_bias", None),
use_grouped_topk=use_grouped_topk, getattr(layer, "w2_bias", None),
top_k=top_k, topk_weights,
renormalize=renormalize, topk_ids,
topk_group=topk_group, activation,
num_expert_group=num_expert_group, self.isa,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor,
e_score_correction_bias=e_score_correction_bias,
) )
return output
def forward_torch(
self,
layer: torch.nn.Module,
input: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str,
global_num_experts: int = -1,
) -> torch.Tensor:
output = torch.empty_like(input)
layer_id = id(layer)
torch.ops.vllm.cpu_fused_moe_torch(
layer_id,
output,
input,
topk_weights,
topk_ids,
activation,
global_num_experts,
)
return output
def cpu_fused_moe_torch(
layer_id: int,
output: torch.Tensor,
input: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str,
global_num_experts: int = -1,
) -> None:
layer = _CPU_MOE_LAYER_CACHE[layer_id]()
# Ref code from https://github.com/sgl-project/sglang/blob/716e682721397df103f347d22da8bd46c6016dab/python/sglang/srt/layers/moe/fused_moe_native.py#L53 # Ref code from https://github.com/sgl-project/sglang/blob/716e682721397df103f347d22da8bd46c6016dab/python/sglang/srt/layers/moe/fused_moe_native.py#L53
len_experts = global_num_experts len_experts = global_num_experts
...@@ -260,7 +387,7 @@ class CPUFusedMOE: ...@@ -260,7 +387,7 @@ class CPUFusedMOE:
tokens_per_expert = cnts.sum(dim=0) tokens_per_expert = cnts.sum(dim=0)
idxs = topk_ids.view(-1).argsort() idxs = topk_ids.view(-1).argsort()
sorted_tokens = x[idxs // topk_ids.shape[1]] sorted_tokens = input[idxs // topk_ids.shape[1]]
tokens_per_expert = tokens_per_expert.cpu().numpy() tokens_per_expert = tokens_per_expert.cpu().numpy()
outputs = [] outputs = []
...@@ -272,9 +399,9 @@ class CPUFusedMOE: ...@@ -272,9 +399,9 @@ class CPUFusedMOE:
continue continue
tokens_for_this_expert = sorted_tokens[start_idx:end_idx] tokens_for_this_expert = sorted_tokens[start_idx:end_idx]
gate_up = layer.gate_up_linear[i](tokens_for_this_expert) gate_up = layer.gate_up_linear[i](tokens_for_this_expert) # type: ignore
gate_up = self.act_to_impl[activation].forward_native(gate_up) gate_up = _CPU_MOE_ACT[activation].forward_native(gate_up)
expert_out = layer.down_linear[i](gate_up) expert_out = layer.down_linear[i](gate_up) # type: ignore
outputs.append(expert_out) outputs.append(expert_out)
start_idx = end_idx start_idx = end_idx
...@@ -289,4 +416,11 @@ class CPUFusedMOE: ...@@ -289,4 +416,11 @@ class CPUFusedMOE:
.sum(dim=1) .sum(dim=1)
.type(new_x.dtype) .type(new_x.dtype)
) )
return final_out output.copy_(final_out)
direct_register_custom_op(
op_name="cpu_fused_moe_torch",
op_func=cpu_fused_moe_torch,
mutates_args=["output"],
)
...@@ -1726,9 +1726,10 @@ class FusedMoE(CustomOp): ...@@ -1726,9 +1726,10 @@ class FusedMoE(CustomOp):
return states return states
if self.shared_experts is None: if self.shared_experts is None:
if current_platform.is_tpu(): if current_platform.is_tpu() or current_platform.is_cpu():
# TODO: Once the OOM issue for the TPU backend is resolved, we # TODO: Once the OOM issue for the TPU backend is resolved, we
# will switch to using the moe_forward custom op. # will switch to using the moe_forward custom op.
# Note: CPU doesn't require wrapped forward_impl.
fused_output = self.forward_impl(hidden_states, router_logits) fused_output = self.forward_impl(hidden_states, router_logits)
assert not isinstance(fused_output, tuple) assert not isinstance(fused_output, tuple)
else: else:
...@@ -1744,9 +1745,10 @@ class FusedMoE(CustomOp): ...@@ -1744,9 +1745,10 @@ class FusedMoE(CustomOp):
else: else:
return reduce_output(fused_output)[..., :og_hidden_states] return reduce_output(fused_output)[..., :og_hidden_states]
else: else:
if current_platform.is_tpu(): if current_platform.is_tpu() or current_platform.is_cpu():
# TODO: Once the OOM issue for the TPU backend is resolved, we # TODO: Once the OOM issue for the TPU backend is resolved, we
# will switch to using the moe_forward custom op. # will switch to using the moe_forward custom op.
# Note: CPU doesn't require wrapped forward_impl.
shared_output, fused_output = self.forward_impl( shared_output, fused_output = self.forward_impl(
hidden_states, router_logits hidden_states, router_logits
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment