Unverified Commit f58675bf authored by TianyuLi0's avatar TianyuLi0 Committed by GitHub
Browse files

[CPU] add cpu fused moe pytorch native implementation (#23146)


Signed-off-by: default avatarTianyu Li <tianyu.li@arm.com>
Co-authored-by: default avatarLi, Jiang <jiang1.li@intel.com>
parent 7c04779a
...@@ -3,61 +3,17 @@ ...@@ -3,61 +3,17 @@
from typing import Callable, Optional from typing import Callable, Optional
import torch import torch
from torch.nn import functional as F
from vllm import envs from vllm import envs
class IPEXFusedMOE: def silu_and_mul(x: torch.Tensor) -> torch.Tensor:
d = x.shape[-1] // 2
def __init__(self, layer: torch.nn.Module) -> None: return F.silu(x[..., :d]) * x[..., d:]
import intel_extension_for_pytorch as ipex
layer.ipex_fusion = ipex.llm.modules.GatedMLPMOE(
layer.w13_weight,
layer.w2_weight,
use_prepack=envs.VLLM_CPU_MOE_PREPACK,
)
def __call__(
self,
layer: torch.nn.Module,
x: torch.Tensor,
use_grouped_topk: bool,
top_k: int,
router_logits: torch.Tensor,
renormalize: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
) -> torch.Tensor:
assert activation == "silu", f"{activation} is not supported."
assert not apply_router_weight_on_input
return layer.ipex_fusion(
x,
use_grouped_topk,
top_k,
router_logits,
renormalize,
topk_group,
num_expert_group,
custom_routing_function,
scoring_func,
e_score_correction_bias,
)
class SGLFusedMOE:
def __init__(self, layer: torch.nn.Module) -> None:
pass
@staticmethod def grouped_topk(
def _grouped_topk(
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
gating_output: torch.Tensor, gating_output: torch.Tensor,
topk: int, topk: int,
...@@ -66,7 +22,7 @@ class SGLFusedMOE: ...@@ -66,7 +22,7 @@ class SGLFusedMOE:
topk_group: int = 0, topk_group: int = 0,
scoring_func: str = "softmax", scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None e_score_correction_bias: Optional[torch.Tensor] = None
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
assert hidden_states.shape[0] == gating_output.shape[0], ( assert hidden_states.shape[0] == gating_output.shape[0], (
"Number of tokens mismatch") "Number of tokens mismatch")
...@@ -80,9 +36,6 @@ class SGLFusedMOE: ...@@ -80,9 +36,6 @@ class SGLFusedMOE:
num_token = scores.shape[0] num_token = scores.shape[0]
if e_score_correction_bias is not None: if e_score_correction_bias is not None:
# Store original scores before applying correction bias. We use
# biased scores for expert selection but original scores for
# routing weights
original_scores = scores original_scores = scores
scores = scores + e_score_correction_bias.unsqueeze(0) scores = scores + e_score_correction_bias.unsqueeze(0)
group_scores = (scores.view(num_token, num_expert_group, group_scores = (scores.view(num_token, num_expert_group,
...@@ -90,22 +43,18 @@ class SGLFusedMOE: ...@@ -90,22 +43,18 @@ class SGLFusedMOE:
else: else:
group_scores = scores.view(num_token, num_expert_group, group_scores = scores.view(num_token, num_expert_group,
-1).max(dim=-1).values # [n, n_group] -1).max(dim=-1).values # [n, n_group]
group_idx = torch.topk(group_scores, group_idx = torch.topk(group_scores, k=topk_group, dim=-1,
k=topk_group,
dim=-1,
sorted=False)[1] # [n, top_k_group] sorted=False)[1] # [n, top_k_group]
group_mask = torch.zeros_like(group_scores) # [n, n_group] group_mask = torch.zeros_like(group_scores) # [n, n_group]
group_mask.scatter_(1, group_idx, 1) # [n, n_group] group_mask.scatter_(1, group_idx, 1) # [n, n_group]
score_mask = group_mask.unsqueeze(-1).expand( score_mask = group_mask.unsqueeze(-1).expand(
num_token, num_expert_group, num_token, num_expert_group,
scores.shape[-1] // num_expert_group).reshape(num_token, scores.shape[-1] // num_expert_group).reshape(num_token, -1) # [n, e]
-1) # [n, e]
tmp_scores = scores.masked_fill(~score_mask.bool(), tmp_scores = scores.masked_fill(~score_mask.bool(),
float("-inf")) # [n, e] float("-inf")) # [n, e]
if e_score_correction_bias is not None: if e_score_correction_bias is not None:
topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)[1] topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)[1]
# Use original unbiased scores for the routing weights
topk_weights = original_scores.gather(1, topk_ids) topk_weights = original_scores.gather(1, topk_ids)
else: else:
topk_weights, topk_ids = torch.topk(tmp_scores, topk_weights, topk_ids = torch.topk(tmp_scores,
...@@ -114,13 +63,12 @@ class SGLFusedMOE: ...@@ -114,13 +63,12 @@ class SGLFusedMOE:
sorted=False) sorted=False)
if renormalize: if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1, topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
keepdim=True)
return topk_weights, topk_ids.to(torch.int32) return topk_weights, topk_ids.to(torch.int32)
@staticmethod
def _select_experts( def select_experts(
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
top_k: int, top_k: int,
...@@ -131,13 +79,11 @@ class SGLFusedMOE: ...@@ -131,13 +79,11 @@ class SGLFusedMOE:
custom_routing_function: Optional[Callable] = None, custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax", scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None, e_score_correction_bias: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
# DeekSeekv2 uses grouped_top_k
if use_grouped_topk: if use_grouped_topk:
assert topk_group is not None assert topk_group is not None
assert num_expert_group is not None assert num_expert_group is not None
topk_weights, topk_ids = SGLFusedMOE._grouped_topk( return grouped_topk(hidden_states=hidden_states,
hidden_states=hidden_states,
gating_output=router_logits, gating_output=router_logits,
topk=top_k, topk=top_k,
renormalize=renormalize, renormalize=renormalize,
...@@ -153,15 +99,62 @@ class SGLFusedMOE: ...@@ -153,15 +99,62 @@ class SGLFusedMOE:
topk_weights, topk_ids = torch.topk(topk_weights, top_k, dim=-1) topk_weights, topk_ids = torch.topk(topk_weights, top_k, dim=-1)
if renormalize: if renormalize:
topk_weights /= topk_weights.sum(dim=-1, keepdim=True) topk_weights /= topk_weights.sum(dim=-1, keepdim=True)
topk_ids = topk_ids.to(torch.int32) return topk_weights, topk_ids.to(torch.int32)
else: else:
topk_weights, topk_ids = custom_routing_function( return custom_routing_function(hidden_states=hidden_states,
hidden_states=hidden_states,
gating_output=router_logits, gating_output=router_logits,
topk=top_k, topk=top_k,
renormalize=renormalize) renormalize=renormalize)
return topk_weights, topk_ids
class IPEXFusedMOE:
def __init__(self, layer: torch.nn.Module) -> None:
import intel_extension_for_pytorch as ipex
layer.ipex_fusion = ipex.llm.modules.GatedMLPMOE(
layer.w13_weight,
layer.w2_weight,
use_prepack=envs.VLLM_CPU_MOE_PREPACK,
)
def __call__(
self,
layer: torch.nn.Module,
x: torch.Tensor,
use_grouped_topk: bool,
top_k: int,
router_logits: torch.Tensor,
renormalize: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
) -> torch.Tensor:
assert activation == "silu", f"{activation} is not supported."
assert not apply_router_weight_on_input
return layer.ipex_fusion(
x,
use_grouped_topk,
top_k,
router_logits,
renormalize,
topk_group,
num_expert_group,
custom_routing_function,
scoring_func,
e_score_correction_bias,
)
class SGLFusedMOE:
def __init__(self, layer: torch.nn.Module) -> None:
pass
def __call__( def __call__(
self, self,
...@@ -183,7 +176,7 @@ class SGLFusedMOE: ...@@ -183,7 +176,7 @@ class SGLFusedMOE:
) -> torch.Tensor: ) -> torch.Tensor:
assert activation == "silu", f"{activation} is not supported." assert activation == "silu", f"{activation} is not supported."
assert not apply_router_weight_on_input assert not apply_router_weight_on_input
topk_weights, topk_ids = SGLFusedMOE._select_experts( topk_weights, topk_ids = select_experts(
hidden_states=x, hidden_states=x,
router_logits=router_logits, router_logits=router_logits,
use_grouped_topk=use_grouped_topk, use_grouped_topk=use_grouped_topk,
...@@ -213,3 +206,80 @@ class SGLFusedMOE: ...@@ -213,3 +206,80 @@ class SGLFusedMOE:
True, True,
) )
return x return x
class CPUFusedMOE:
def __init__(self, layer: torch.nn.Module) -> None:
pass
def __call__(
self,
layer: torch.nn.Module,
x: torch.Tensor,
use_grouped_topk: bool,
top_k: int,
router_logits: torch.Tensor,
renormalize: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
) -> torch.Tensor:
assert activation == "silu", f"{activation} is not supported."
assert not apply_router_weight_on_input
topk_weights, topk_ids = select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
)
# Ref code from https://github.com/sgl-project/sglang/blob/716e682721397df103f347d22da8bd46c6016dab/python/sglang/srt/layers/moe/fused_moe_native.py#L53
len_experts = global_num_experts
cnts = topk_ids.new_zeros((topk_ids.shape[0], len_experts))
cnts.scatter_(1, topk_ids.to(torch.int64), 1)
tokens_per_expert = cnts.sum(dim=0)
idxs = topk_ids.view(-1).argsort()
sorted_tokens = x[idxs // topk_ids.shape[1]]
tokens_per_expert = tokens_per_expert.cpu().numpy()
outputs = []
start_idx = 0
for i, num_tokens in enumerate(tokens_per_expert):
end_idx = start_idx + num_tokens
if num_tokens == 0:
continue
tokens_for_this_expert = sorted_tokens[start_idx:end_idx]
layer_w13_weight = layer.w13_weight[i]
layer_w2_weight = layer.w2_weight[i]
gate_up = F.linear(tokens_for_this_expert, layer_w13_weight)
gate_up = silu_and_mul(gate_up)
expert_out = F.linear(gate_up, layer_w2_weight)
outputs.append(expert_out)
start_idx = end_idx
outs = torch.cat(outputs,
dim=0) if len(outputs) else sorted_tokens.new_empty(0)
new_x = torch.empty_like(outs)
new_x[idxs] = outs
final_out = (new_x.view(
*topk_ids.shape, -1).type(topk_weights.dtype).mul_(
topk_weights.unsqueeze(dim=-1)).sum(dim=1).type(new_x.dtype))
return final_out
...@@ -358,8 +358,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -358,8 +358,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
use_prepack=True, use_prepack=True,
) )
elif current_platform.is_cpu(): elif current_platform.is_cpu():
if current_platform.get_cpu_architecture() == CpuArchEnum.X86:
from vllm.model_executor.layers.fused_moe import cpu_fused_moe from vllm.model_executor.layers.fused_moe import cpu_fused_moe
if current_platform.get_cpu_architecture() == CpuArchEnum.X86:
from vllm.model_executor.layers.utils import ( from vllm.model_executor.layers.utils import (
check_cpu_sgl_kernel) check_cpu_sgl_kernel)
dtype_w13 = layer.w13_weight.dtype dtype_w13 = layer.w13_weight.dtype
...@@ -382,7 +382,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -382,7 +382,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
else: else:
layer.cpu_fused_moe = cpu_fused_moe.IPEXFusedMOE(layer) layer.cpu_fused_moe = cpu_fused_moe.IPEXFusedMOE(layer)
else: else:
raise NotImplementedError("CPU MOE only supports x86 arch.") layer.cpu_fused_moe = cpu_fused_moe.CPUFusedMOE(layer)
def apply( def apply(
self, self,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment