Unverified Commit 52c03f16 authored by Lianmin Zheng's avatar Lianmin Zheng Committed by GitHub
Browse files

Add activation parameters to fused_moe (#3170)

parent 741fccd7
......@@ -114,6 +114,7 @@ class EPMoE(torch.nn.Module):
tp_size: Optional[int] = None,
prefix: str = "",
correction_bias: Optional[torch.Tensor] = None,
activation: str = "silu",
):
super().__init__()
......@@ -140,6 +141,7 @@ class EPMoE(torch.nn.Module):
self.num_expert_group = num_expert_group
self.topk_group = topk_group
self.correction_bias = correction_bias
self.activation = activation
if quant_config is None:
self.quant_method: Optional[QuantizeMethodBase] = UnquantizedEPMoEMethod()
......@@ -166,6 +168,7 @@ class EPMoE(torch.nn.Module):
def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
assert self.quant_method is not None
assert self.activation == "silu"
if self.grouped_gemm_runner is None:
self.grouped_gemm_runner = GroupedGemmRunner(
......
......@@ -8,7 +8,7 @@ from typing import Callable, Optional
import torch
from torch.nn import functional as F
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.activation import GeluAndMul, SiluAndMul
from sglang.srt.layers.moe.topk import select_experts
......@@ -23,6 +23,7 @@ def fused_moe_forward_native(
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
correction_bias: Optional[torch.Tensor] = None,
activation: str = "silu",
) -> torch.Tensor:
topk_weights, topk_ids = select_experts(
hidden_states=x,
......@@ -41,7 +42,12 @@ def fused_moe_forward_native(
w1_weights, w3_weights = torch.chunk(w13_weights, 2, dim=2)
w2_weights = layer.w2_weight[topk_ids]
x1 = torch.einsum("ti,taoi -> tao", x, w1_weights)
x1 = F.silu(x1)
if activation == "silu":
x1 = F.silu(x1)
elif activation == "gelu":
x1 = F.gelu(x1)
else:
raise ValueError(f"Unsupported activation: {activation=}")
x3 = torch.einsum("ti, taoi -> tao", x, w3_weights)
expert_outs = torch.einsum("tao, taio -> tai", (x1 * x3), w2_weights)
return torch.einsum("tai,ta -> ti", expert_outs, topk_weights.to(expert_outs.dtype))
......@@ -58,6 +64,7 @@ def moe_forward_native(
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
correction_bias: Optional[torch.Tensor] = None,
activation: str = "silu",
) -> torch.Tensor:
topk_weights, topk_ids = select_experts(
......@@ -84,6 +91,13 @@ def moe_forward_native(
sorted_tokens = x[idxs // topk_ids.shape[1]]
tokens_per_expert = tokens_per_expert.cpu().numpy()
if activation == "silu":
act = SiluAndMul()
elif activation == "gelu":
act = GeluAndMul()
else:
raise ValueError(f"Unsupported activation: {activation=}")
outputs = []
start_idx = 0
for i, num_tokens in enumerate(tokens_per_expert):
......@@ -96,7 +110,7 @@ def moe_forward_native(
layer_w2_weight = layer.w2_weight[i]
gate_up = F.linear(tokens_for_this_expert, layer_w13_weight)
gate_up = SiluAndMul()(gate_up)
gate_up = act(gate_up)
expert_out = F.linear(gate_up, layer_w2_weight)
outputs.append(expert_out)
start_idx = end_idx
......
......@@ -711,6 +711,7 @@ def inplace_fused_experts(
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str = "silu",
use_fp8_w8a8: bool = False,
use_int8_w8a16: bool = False,
w1_scale: Optional[torch.Tensor] = None,
......@@ -726,6 +727,7 @@ def inplace_fused_experts(
topk_weights,
topk_ids,
True,
activation,
use_fp8_w8a8,
use_int8_w8a16,
w1_scale,
......@@ -742,6 +744,7 @@ def inplace_fused_experts_fake(
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str = "silu",
use_fp8_w8a8: bool = False,
use_int8_w8a16: bool = False,
w1_scale: Optional[torch.Tensor] = None,
......@@ -767,6 +770,7 @@ def outplace_fused_experts(
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str = "silu",
use_fp8_w8a8: bool = False,
use_int8_w8a16: bool = False,
w1_scale: Optional[torch.Tensor] = None,
......@@ -782,6 +786,7 @@ def outplace_fused_experts(
topk_weights,
topk_ids,
False,
activation,
use_fp8_w8a8,
use_int8_w8a16,
w1_scale,
......@@ -798,6 +803,7 @@ def outplace_fused_experts_fake(
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str = "silu",
use_fp8_w8a8: bool = False,
use_int8_w8a16: bool = False,
w1_scale: Optional[torch.Tensor] = None,
......@@ -824,6 +830,7 @@ def fused_experts(
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
inplace: bool = False,
activation: str = "silu",
use_fp8_w8a8: bool = False,
use_int8_w8a16: bool = False,
w1_scale: Optional[torch.Tensor] = None,
......@@ -839,6 +846,7 @@ def fused_experts(
w2,
topk_weights,
topk_ids,
activation,
use_fp8_w8a8,
use_int8_w8a16,
w1_scale,
......@@ -855,6 +863,7 @@ def fused_experts(
w2,
topk_weights,
topk_ids,
activation,
use_fp8_w8a8,
use_int8_w8a16,
w1_scale,
......@@ -872,6 +881,7 @@ def fused_experts_impl(
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
inplace: bool = False,
activation: str = "silu",
use_fp8_w8a8: bool = False,
use_int8_w8a16: bool = False,
w1_scale: Optional[torch.Tensor] = None,
......@@ -986,7 +996,12 @@ def fused_experts_impl(
block_shape=block_shape,
)
ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
if activation == "silu":
ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
elif activation == "gelu":
ops.gelu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
else:
raise ValueError(f"Unsupported activation: {activation=}")
invoke_fused_moe_kernel(
intermediate_cache2,
......@@ -1042,6 +1057,7 @@ def fused_moe(
topk: int,
renormalize: bool,
inplace: bool = False,
activation: str = "silu",
use_grouped_topk: bool = False,
num_expert_group: Optional[int] = None,
topk_group: Optional[int] = None,
......@@ -1111,6 +1127,7 @@ def fused_moe(
topk_weights,
topk_ids,
inplace=inplace,
activation=activation,
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a16=use_int8_w8a16,
w1_scale=w1_scale,
......
......@@ -126,6 +126,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
correction_bias: Optional[torch.Tensor] = None,
activation: str = "silu",
) -> torch.Tensor:
return self.forward(
x=x,
......@@ -138,6 +139,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
correction_bias=correction_bias,
activation=activation,
)
def forward_cuda(
......@@ -152,6 +154,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
correction_bias: Optional[torch.Tensor] = None,
activation: str = "silu",
) -> torch.Tensor:
topk_weights, topk_ids = select_experts(
hidden_states=x,
......@@ -169,6 +172,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
import ater
from ater.fused_moe import fused_experts_ck
assert activation == "silu", f"{activation=} is not supported."
return fused_experts_ck(
hidden_states=x,
w1=layer.w13_weight,
......@@ -184,6 +189,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
activation=activation,
)
def forward_cpu(
......@@ -256,6 +262,7 @@ class FusedMoE(torch.nn.Module):
prefix: str = "",
custom_routing_function: Optional[Callable] = None,
correction_bias: Optional[torch.Tensor] = None,
activation: str = "silu",
use_presharded_weights: bool = False,
):
super().__init__()
......@@ -279,6 +286,7 @@ class FusedMoE(torch.nn.Module):
self.topk_group = topk_group
self.custom_routing_function = custom_routing_function
self.correction_bias = correction_bias
self.activation = activation
if quant_config is None:
self.quant_method: Optional[QuantizeMethodBase] = (
......@@ -589,6 +597,7 @@ class FusedMoE(torch.nn.Module):
num_expert_group=self.num_expert_group,
custom_routing_function=self.custom_routing_function,
correction_bias=self.correction_bias,
activation=self.activation,
)
if self.reduce_results and self.tp_size > 1:
......
......@@ -763,8 +763,8 @@ class Fp8MoEMethod:
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
correction_bias: Optional[torch.Tensor] = None,
activation: str = "silu",
) -> torch.Tensor:
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
from sglang.srt.layers.moe.topk import select_experts
......@@ -785,6 +785,8 @@ class Fp8MoEMethod:
import ater
from ater.fused_moe import fused_experts_ck
assert activation == "silu", f"{activation=} is not supported."
return fused_experts_ck(
x,
layer.w13_weight,
......@@ -815,6 +817,7 @@ class Fp8MoEMethod:
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
activation=activation,
use_fp8_w8a8=True,
w1_scale=(
layer.w13_weight_scale_inv
......
......@@ -133,6 +133,7 @@ class Grok1MoE(nn.Module):
renormalize=False,
quant_config=quant_config,
tp_size=tp_size,
activation="gelu",
use_presharded_weights=use_presharded_weights,
)
......
......@@ -2,8 +2,6 @@ import unittest
import torch
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
from sglang.srt.layers.quantization.fp8_kernel import (
per_token_group_quant_fp8,
w8a8_block_fp8_matmul,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment