Unverified commit 52c03f16 authored by Lianmin Zheng, committed by GitHub

Add activation parameters to fused_moe (#3170)

parent 741fccd7
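For context, the new `activation` argument selects which gated activation the expert MLP applies between the fused gate/up projection (w13) and the down projection (w2). A minimal standalone sketch of that dispatch, using only torch; the helper name `gated_act` and the shapes are illustrative, not part of this commit:

import torch
import torch.nn.functional as F

def gated_act(gate_up: torch.Tensor, activation: str = "silu") -> torch.Tensor:
    # gate_up holds the concatenated [gate, up] projections produced by w13.
    gate, up = gate_up.chunk(2, dim=-1)
    if activation == "silu":
        return F.silu(gate) * up  # SiluAndMul, the existing default
    elif activation == "gelu":
        return F.gelu(gate) * up  # GeluAndMul, newly selectable (used by Grok-1 below)
    raise ValueError(f"Unsupported activation: {activation=}")

x = torch.randn(4, 2 * 512)
y = gated_act(x, activation="gelu")  # -> shape (4, 512)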
@@ -114,6 +114,7 @@ class EPMoE(torch.nn.Module):
         tp_size: Optional[int] = None,
         prefix: str = "",
         correction_bias: Optional[torch.Tensor] = None,
+        activation: str = "silu",
     ):
         super().__init__()
@@ -140,6 +141,7 @@ class EPMoE(torch.nn.Module):
         self.num_expert_group = num_expert_group
         self.topk_group = topk_group
         self.correction_bias = correction_bias
+        self.activation = activation

         if quant_config is None:
             self.quant_method: Optional[QuantizeMethodBase] = UnquantizedEPMoEMethod()
@@ -166,6 +168,7 @@ class EPMoE(torch.nn.Module):
     def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
         assert self.quant_method is not None
+        assert self.activation == "silu"

         if self.grouped_gemm_runner is None:
             self.grouped_gemm_runner = GroupedGemmRunner(
...
@@ -8,7 +8,7 @@ from typing import Callable, Optional
 import torch
 from torch.nn import functional as F

-from sglang.srt.layers.activation import SiluAndMul
+from sglang.srt.layers.activation import GeluAndMul, SiluAndMul
 from sglang.srt.layers.moe.topk import select_experts
@@ -23,6 +23,7 @@ def fused_moe_forward_native(
     num_expert_group: Optional[int] = None,
     custom_routing_function: Optional[Callable] = None,
     correction_bias: Optional[torch.Tensor] = None,
+    activation: str = "silu",
 ) -> torch.Tensor:
     topk_weights, topk_ids = select_experts(
         hidden_states=x,
@@ -41,7 +42,12 @@ def fused_moe_forward_native(
     w1_weights, w3_weights = torch.chunk(w13_weights, 2, dim=2)
     w2_weights = layer.w2_weight[topk_ids]
     x1 = torch.einsum("ti,taoi -> tao", x, w1_weights)
-    x1 = F.silu(x1)
+    if activation == "silu":
+        x1 = F.silu(x1)
+    elif activation == "gelu":
+        x1 = F.gelu(x1)
+    else:
+        raise ValueError(f"Unsupported activation: {activation=}")
     x3 = torch.einsum("ti, taoi -> tao", x, w3_weights)
     expert_outs = torch.einsum("tao, taio -> tai", (x1 * x3), w2_weights)
     return torch.einsum("tai,ta -> ti", expert_outs, topk_weights.to(expert_outs.dtype))
@@ -58,6 +64,7 @@ def moe_forward_native(
     num_expert_group: Optional[int] = None,
     custom_routing_function: Optional[Callable] = None,
     correction_bias: Optional[torch.Tensor] = None,
+    activation: str = "silu",
 ) -> torch.Tensor:
     topk_weights, topk_ids = select_experts(
@@ -84,6 +91,13 @@ def moe_forward_native(
     sorted_tokens = x[idxs // topk_ids.shape[1]]
     tokens_per_expert = tokens_per_expert.cpu().numpy()

+    if activation == "silu":
+        act = SiluAndMul()
+    elif activation == "gelu":
+        act = GeluAndMul()
+    else:
+        raise ValueError(f"Unsupported activation: {activation=}")
+
     outputs = []
     start_idx = 0
     for i, num_tokens in enumerate(tokens_per_expert):
@@ -96,7 +110,7 @@ def moe_forward_native(
         layer_w2_weight = layer.w2_weight[i]
         gate_up = F.linear(tokens_for_this_expert, layer_w13_weight)
-        gate_up = SiluAndMul()(gate_up)
+        gate_up = act(gate_up)
         expert_out = F.linear(gate_up, layer_w2_weight)
         outputs.append(expert_out)
         start_idx = end_idx
...
@@ -711,6 +711,7 @@ def inplace_fused_experts(
     w2: torch.Tensor,
     topk_weights: torch.Tensor,
     topk_ids: torch.Tensor,
+    activation: str = "silu",
     use_fp8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
     w1_scale: Optional[torch.Tensor] = None,
@@ -726,6 +727,7 @@ def inplace_fused_experts(
         topk_weights,
         topk_ids,
         True,
+        activation,
         use_fp8_w8a8,
         use_int8_w8a16,
         w1_scale,
@@ -742,6 +744,7 @@ def inplace_fused_experts_fake(
     w2: torch.Tensor,
     topk_weights: torch.Tensor,
     topk_ids: torch.Tensor,
+    activation: str = "silu",
     use_fp8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
     w1_scale: Optional[torch.Tensor] = None,
@@ -767,6 +770,7 @@ def outplace_fused_experts(
     w2: torch.Tensor,
     topk_weights: torch.Tensor,
     topk_ids: torch.Tensor,
+    activation: str = "silu",
     use_fp8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
     w1_scale: Optional[torch.Tensor] = None,
@@ -782,6 +786,7 @@ def outplace_fused_experts(
         topk_weights,
         topk_ids,
         False,
+        activation,
         use_fp8_w8a8,
         use_int8_w8a16,
         w1_scale,
@@ -798,6 +803,7 @@ def outplace_fused_experts_fake(
     w2: torch.Tensor,
     topk_weights: torch.Tensor,
     topk_ids: torch.Tensor,
+    activation: str = "silu",
     use_fp8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
     w1_scale: Optional[torch.Tensor] = None,
@@ -824,6 +830,7 @@ def fused_experts(
     topk_weights: torch.Tensor,
     topk_ids: torch.Tensor,
     inplace: bool = False,
+    activation: str = "silu",
     use_fp8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
     w1_scale: Optional[torch.Tensor] = None,
@@ -839,6 +846,7 @@ def fused_experts(
             w2,
             topk_weights,
             topk_ids,
+            activation,
             use_fp8_w8a8,
             use_int8_w8a16,
             w1_scale,
@@ -855,6 +863,7 @@ def fused_experts(
             w2,
             topk_weights,
             topk_ids,
+            activation,
             use_fp8_w8a8,
             use_int8_w8a16,
             w1_scale,
@@ -872,6 +881,7 @@ def fused_experts_impl(
     topk_weights: torch.Tensor,
     topk_ids: torch.Tensor,
     inplace: bool = False,
+    activation: str = "silu",
     use_fp8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
     w1_scale: Optional[torch.Tensor] = None,
@@ -986,7 +996,12 @@ def fused_experts_impl(
             block_shape=block_shape,
         )

-        ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
+        if activation == "silu":
+            ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
+        elif activation == "gelu":
+            ops.gelu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
+        else:
+            raise ValueError(f"Unsupported activation: {activation=}")

         invoke_fused_moe_kernel(
             intermediate_cache2,
@@ -1042,6 +1057,7 @@ def fused_moe(
     topk: int,
     renormalize: bool,
     inplace: bool = False,
+    activation: str = "silu",
     use_grouped_topk: bool = False,
     num_expert_group: Optional[int] = None,
     topk_group: Optional[int] = None,
@@ -1111,6 +1127,7 @@ def fused_moe(
         topk_weights,
         topk_ids,
         inplace=inplace,
+        activation=activation,
         use_fp8_w8a8=use_fp8_w8a8,
         use_int8_w8a16=use_int8_w8a16,
         w1_scale=w1_scale,
...
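For reference, a hedged usage sketch of the updated `fused_moe` entry point shown above. Only the `activation` keyword is introduced by this commit; the leading tensor arguments and weight layouts are assumptions based on the pre-existing signature (they are not shown in this diff), and running the call requires a CUDA device with Triton:

import torch
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe

E, T, H, I = 8, 4, 512, 1024  # experts, tokens, hidden size, intermediate size (illustrative)
hidden_states = torch.randn(T, H, dtype=torch.float16, device="cuda")
w13 = torch.randn(E, 2 * I, H, dtype=torch.float16, device="cuda")  # assumed fused gate/up weight layout
w2 = torch.randn(E, H, I, dtype=torch.float16, device="cuda")       # assumed down-projection weight layout
router_logits = torch.randn(T, E, dtype=torch.float16, device="cuda")

out = fused_moe(
    hidden_states,
    w13,
    w2,
    router_logits,
    topk=2,
    renormalize=True,
    inplace=False,
    activation="gelu",  # new keyword; "silu" remains the default
)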
@@ -126,6 +126,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         num_expert_group: Optional[int] = None,
         custom_routing_function: Optional[Callable] = None,
         correction_bias: Optional[torch.Tensor] = None,
+        activation: str = "silu",
     ) -> torch.Tensor:
         return self.forward(
             x=x,
@@ -138,6 +139,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
             num_expert_group=num_expert_group,
             custom_routing_function=custom_routing_function,
             correction_bias=correction_bias,
+            activation=activation,
         )

     def forward_cuda(
@@ -152,6 +154,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         num_expert_group: Optional[int] = None,
         custom_routing_function: Optional[Callable] = None,
         correction_bias: Optional[torch.Tensor] = None,
+        activation: str = "silu",
     ) -> torch.Tensor:
         topk_weights, topk_ids = select_experts(
             hidden_states=x,
@@ -169,6 +172,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
             import ater
             from ater.fused_moe import fused_experts_ck

+            assert activation == "silu", f"{activation=} is not supported."
+
             return fused_experts_ck(
                 hidden_states=x,
                 w1=layer.w13_weight,
@@ -184,6 +189,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
             topk_weights=topk_weights,
             topk_ids=topk_ids,
             inplace=True,
+            activation=activation,
         )

     def forward_cpu(
@@ -256,6 +262,7 @@ class FusedMoE(torch.nn.Module):
         prefix: str = "",
         custom_routing_function: Optional[Callable] = None,
         correction_bias: Optional[torch.Tensor] = None,
+        activation: str = "silu",
         use_presharded_weights: bool = False,
     ):
         super().__init__()
@@ -279,6 +286,7 @@ class FusedMoE(torch.nn.Module):
         self.topk_group = topk_group
         self.custom_routing_function = custom_routing_function
         self.correction_bias = correction_bias
+        self.activation = activation

         if quant_config is None:
             self.quant_method: Optional[QuantizeMethodBase] = (
@@ -589,6 +597,7 @@ class FusedMoE(torch.nn.Module):
             num_expert_group=self.num_expert_group,
             custom_routing_function=self.custom_routing_function,
             correction_bias=self.correction_bias,
+            activation=self.activation,
         )

         if self.reduce_results and self.tp_size > 1:
...
@@ -763,8 +763,8 @@ class Fp8MoEMethod:
         num_expert_group: Optional[int] = None,
         custom_routing_function: Optional[Callable] = None,
         correction_bias: Optional[torch.Tensor] = None,
+        activation: str = "silu",
     ) -> torch.Tensor:
-        from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
         from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
         from sglang.srt.layers.moe.topk import select_experts
@@ -785,6 +785,8 @@ class Fp8MoEMethod:
             import ater
             from ater.fused_moe import fused_experts_ck

+            assert activation == "silu", f"{activation=} is not supported."
+
             return fused_experts_ck(
                 x,
                 layer.w13_weight,
@@ -815,6 +817,7 @@ class Fp8MoEMethod:
             topk_weights=topk_weights,
             topk_ids=topk_ids,
             inplace=True,
+            activation=activation,
             use_fp8_w8a8=True,
             w1_scale=(
                 layer.w13_weight_scale_inv
...
@@ -133,6 +133,7 @@ class Grok1MoE(nn.Module):
             renormalize=False,
             quant_config=quant_config,
             tp_size=tp_size,
+            activation="gelu",
             use_presharded_weights=use_presharded_weights,
         )
...
@@ -2,8 +2,6 @@ import unittest
 import torch

-from sglang.srt.layers.activation import SiluAndMul
-from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
 from sglang.srt.layers.quantization.fp8_kernel import (
     per_token_group_quant_fp8,
     w8a8_block_fp8_matmul,
...