Unverified Commit 8d4ed42a authored by Ke Bao, committed by GitHub

MoE torch compile (#1497)

parent 2854a5ea
python/sglang/srt/layers/fused_moe/patch.py (new file)

from typing import Optional

import torch
from torch.nn import functional as F


def fused_topk_native(
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
):
    assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
    M, _ = hidden_states.shape
    topk_weights = torch.empty(
        M, topk, dtype=torch.float32, device=hidden_states.device
    )
    topk_ids = torch.empty(M, topk, dtype=torch.int32, device=hidden_states.device)
    topk_weights = F.softmax(gating_output.float(), dim=-1)
    topk_weights, topk_ids = torch.topk(topk_weights, topk, dim=-1)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids


# This is used by the DeepSeek-V2 model
def grouped_topk(
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
    num_expert_group: int = 0,
    topk_group: int = 0,
):
    assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
    scores = torch.softmax(gating_output, dim=-1)
    num_token = scores.shape[0]
    group_scores = (
        scores.view(num_token, num_expert_group, -1).max(dim=-1).values
    )  # [n, n_group]
    group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[
        1
    ]  # [n, top_k_group]
    group_mask = torch.zeros_like(group_scores)  # [n, n_group]
    group_mask.scatter_(1, group_idx, 1)  # [n, n_group]
    score_mask = (
        group_mask.unsqueeze(-1)
        .expand(num_token, num_expert_group, scores.shape[-1] // num_expert_group)
        .reshape(num_token, -1)
    )  # [n, e]
    tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0)  # [n, e]
    topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids


def select_experts_native(
    hidden_states: torch.Tensor,
    router_logits: torch.Tensor,
    top_k: int,
    use_grouped_topk: bool,
    renormalize: bool,
    topk_group: Optional[int] = None,
    num_expert_group: Optional[int] = None,
):
    # DeepSeek-V2 uses grouped_topk
    if use_grouped_topk:
        assert topk_group is not None
        assert num_expert_group is not None
        topk_weights, topk_ids = grouped_topk(
            hidden_states=hidden_states,
            gating_output=router_logits,
            topk=top_k,
            renormalize=renormalize,
            num_expert_group=num_expert_group,
            topk_group=topk_group,
        )
    else:
        topk_weights, topk_ids = fused_topk_native(
            hidden_states=hidden_states,
            gating_output=router_logits,
            topk=top_k,
            renormalize=renormalize,
        )
    return topk_weights, topk_ids


def fused_moe_forward_native(
    layer: torch.nn.Module,
    x: torch.Tensor,
    use_grouped_topk: bool,
    top_k: int,
    router_logits: torch.Tensor,
    renormalize: bool,
    topk_group: Optional[int] = None,
    num_expert_group: Optional[int] = None,
) -> torch.Tensor:
    topk_weights, topk_ids = select_experts_native(
        hidden_states=x,
        router_logits=router_logits,
        use_grouped_topk=use_grouped_topk,
        top_k=top_k,
        renormalize=renormalize,
        topk_group=topk_group,
        num_expert_group=num_expert_group,
    )
    w13_weights = layer.w13_weight[topk_ids]
    w1_weights, w3_weights = torch.chunk(w13_weights, 2, dim=2)
    w2_weights = layer.w2_weight[topk_ids]
    x1 = F.silu(torch.einsum("ti,taoi -> tao", x, w1_weights))
    x3 = torch.einsum("ti, taoi -> tao", x, w3_weights)
    expert_outs = torch.einsum("tao, taio -> tai", (x1 * x3), w2_weights)
    return torch.einsum("tai,ta -> ti", expert_outs, topk_weights)
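A quick way to sanity-check the new native path (not part of the commit): the snippet below is a hypothetical smoke test that feeds random tensors through fused_moe_forward_native, using a SimpleNamespace stand-in for the FusedMoE layer with w13_weight of shape [num_experts, 2 * intermediate_size, hidden_size] and w2_weight of shape [num_experts, hidden_size, intermediate_size], which matches how the einsums above index them.

# Hypothetical smoke test, assuming the module above is importable as
# sglang.srt.layers.fused_moe.patch; "layer" is only a stand-in object here.
import types

import torch

from sglang.srt.layers.fused_moe.patch import fused_moe_forward_native

N_EXPERTS, HIDDEN, INTER, TOKENS, TOPK = 8, 16, 32, 4, 2
layer = types.SimpleNamespace(
    w13_weight=torch.randn(N_EXPERTS, 2 * INTER, HIDDEN),  # w1 and w3 packed along dim 1
    w2_weight=torch.randn(N_EXPERTS, HIDDEN, INTER),       # maps intermediate back to hidden
)
x = torch.randn(TOKENS, HIDDEN)
router_logits = torch.randn(TOKENS, N_EXPERTS)

out = fused_moe_forward_native(
    layer,
    x,
    use_grouped_topk=False,
    top_k=TOPK,
    router_logits=router_logits,
    renormalize=True,
)
print(out.shape)  # torch.Size([4, 16])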
@@ -25,6 +25,7 @@ import torch
 from vllm.distributed.parallel_state import graph_capture
 from vllm.model_executor.custom_op import CustomOp
 
+from sglang.srt.layers.fused_moe.patch import fused_moe_forward_native
 from sglang.srt.layers.logits_processor import (
     LogitsMetadata,
     LogitsProcessor,
@@ -41,12 +42,13 @@ if TYPE_CHECKING:
 def _to_torch(model: torch.nn.Module, reverse: bool = False):
     for sub in model._modules.values():
         if isinstance(sub, CustomOp):
-            # NOTE: FusedMoE torch native implementaiton is not efficient
-            if "FusedMoE" in sub.__class__.__name__:
-                continue
             if reverse:
                 sub._forward_method = sub.forward_cuda
                 setattr(sub, "is_torch_compile", False)
             else:
-                sub._forward_method = sub.forward_native
+                # NOTE: Temporarily workaround MoE
+                if "FusedMoE" in sub.__class__.__name__:
+                    sub._forward_method = fused_moe_forward_native
+                else:
+                    sub._forward_method = sub.forward_native
                 setattr(sub, "is_torch_compile", True)
@@ -67,7 +69,9 @@ def patch_model(
             monkey_patch_vllm_all_gather()
             backup_ca_comm = tp_group.ca_comm
             tp_group.ca_comm = None
-            yield torch.compile(model.forward, mode="max-autotune-no-cudagraphs")
+            yield torch.compile(
+                torch.no_grad()(model.forward), mode="max-autotune-no-cudagraphs"
+            )
         else:
             yield model.forward
     finally:
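For context on the _to_torch change, here is an illustrative, self-contained sketch of the dispatch pattern it relies on; ToyCustomOp is a made-up stand-in for vLLM's CustomOp, not the real class, and the compile call mirrors the patch_model change above.

# Illustrative only: a toy stand-in for the CustomOp dispatch that _to_torch patches.
import torch


class ToyCustomOp(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self._forward_method = self.forward_cuda  # default dispatch target

    def forward(self, *args, **kwargs):
        return self._forward_method(*args, **kwargs)

    def forward_cuda(self, x):
        return torch.relu(x)  # stand-in for a custom-kernel path

    def forward_native(self, x):
        return torch.clamp(x, min=0.0)  # pure-PyTorch path, friendlier to torch.compile


op = ToyCustomOp()
# What _to_torch does per CustomOp before compiling:
op._forward_method = op.forward_native
setattr(op, "is_torch_compile", True)
# Mirrors the patch_model change: compile the forward wrapped in no_grad.
compiled = torch.compile(torch.no_grad()(op.forward), mode="max-autotune-no-cudagraphs")
print(compiled(torch.randn(4)))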