Unverified Commit e835a500 authored by Ke Bao, committed by GitHub

Reorg moe code (#2563)

parent 23e5e50f
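
For orientation, a minimal sketch of the import layout after this reorg, assembled only from module paths and names that appear in the hunks below; collecting them into one snippet is illustrative, it is not a file from the repo:

# Everything MoE-related now lives under sglang.srt.layers.moe
from sglang.srt.layers.moe.fused_moe_triton import override_config
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts, fused_moe
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE, FusedMoEMethodBase
from sglang.srt.layers.moe.ep_moe.kernels import grouped_gemm_triton
from sglang.srt.layers.moe.topk import select_experts  # shared top-k routing helper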
@@ -5,7 +5,9 @@ import triton
 from torch.nn import functional as F
 from transformers import AutoConfig

-from sglang.srt.layers.fused_moe_triton.fused_moe import fused_moe as fused_moe_triton
+from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
+    fused_moe as fused_moe_triton,
+)
 from sglang.srt.model_executor.cuda_graph_runner import set_torch_compile_config
...
@@ -5,7 +5,9 @@ import triton
 from transformers import AutoConfig
 from vllm.model_executor.layers.fused_moe.fused_moe import fused_moe as fused_moe_vllm

-from sglang.srt.layers.fused_moe_triton.fused_moe import fused_moe as fused_moe_sglang
+from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
+    fused_moe as fused_moe_sglang,
+)


 def get_model_config(model_name: str, tp_size: int):
...
@@ -11,7 +11,7 @@ import triton
 from ray.experimental.tqdm_ray import tqdm
 from transformers import AutoConfig

-from sglang.srt.layers.fused_moe_triton.fused_moe import (
+from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
     fused_moe,
     get_config_dtype_str,
     get_config_file_name,
@@ -97,7 +97,7 @@ def benchmark_config(
         input_gating.copy_(gating_output[i])

     def run():
-        from sglang.srt.layers.fused_moe_triton import override_config
+        from sglang.srt.layers.moe.fused_moe_triton import override_config

         with override_config(config):
             fused_moe(
...
@@ -12,15 +12,15 @@ from vllm.model_executor.custom_op import CustomOp
 from vllm.model_executor.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod

 from sglang.srt.layers.custom_op_util import register_custom_op
-from sglang.srt.layers.ep_moe.kernels import (
+from sglang.srt.layers.moe.ep_moe.kernels import (
     grouped_gemm_triton,
     post_reorder_triton_kernel,
     pre_reorder_triton_kernel,
     run_moe_ep_preproess,
     silu_and_mul_triton_kernel,
 )
-from sglang.srt.layers.fused_moe_triton.fused_moe import fused_topk, grouped_topk
-from sglang.srt.layers.fused_moe_triton.layer import FusedMoEMethodBase
+from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoEMethodBase
+from sglang.srt.layers.moe.topk import select_experts
 from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
@@ -113,6 +113,7 @@ class EPMoE(torch.nn.Module):
         quant_config: Optional[QuantizationConfig] = None,
         tp_size: Optional[int] = None,
         prefix: str = "",
+        correction_bias: Optional[torch.Tensor] = None,
     ):
         super().__init__()
@@ -138,6 +139,7 @@ class EPMoE(torch.nn.Module):
             assert num_expert_group is not None and topk_group is not None
         self.num_expert_group = num_expert_group
         self.topk_group = topk_group
+        self.correction_bias = correction_bias

         if quant_config is None:
             self.quant_method: Optional[QuantizeMethodBase] = UnquantizedEPMoEMethod()
@@ -170,13 +172,15 @@ class EPMoE(torch.nn.Module):
             hidden_states.device, use_flashinfer=False  # TODO: use flashinfer
         )

-        topk_weights, topk_ids = self.select_experts(
-            hidden_states,
-            router_logits,
-            self.top_k,
-            self.renormalize,
-            self.topk_group,
-            self.num_expert_group,
+        topk_weights, topk_ids = select_experts(
+            hidden_states=hidden_states,
+            router_logits=router_logits,
+            top_k=self.top_k,
+            use_grouped_topk=self.use_grouped_topk,
+            renormalize=self.renormalize,
+            topk_group=self.topk_group,
+            num_expert_group=self.num_expert_group,
+            correction_bias=self.correction_bias,
         )

         reorder_topk_ids, src2dst, seg_indptr = run_moe_ep_preproess(
@@ -297,35 +301,6 @@ class EPMoE(torch.nn.Module):
         )
         return output

-    def select_experts(
-        self,
-        hidden_states: torch.Tensor,
-        router_logits: torch.Tensor,
-        top_k: int,
-        renormalize: bool,
-        topk_group: Optional[int] = None,
-        num_expert_group: Optional[int] = None,
-    ):
-        if self.use_grouped_topk:
-            assert topk_group is not None
-            assert num_expert_group is not None
-            topk_weights, topk_ids = grouped_topk(
-                hidden_states=hidden_states,
-                gating_output=router_logits,
-                topk=top_k,
-                renormalize=renormalize,
-                num_expert_group=num_expert_group,
-                topk_group=topk_group,
-            )
-        else:
-            topk_weights, topk_ids = fused_topk(
-                hidden_states=hidden_states,
-                gating_output=router_logits,
-                topk=top_k,
-                renormalize=renormalize,
-            )
-        return topk_weights, topk_ids.to(torch.int32)
-
     @classmethod
     def make_expert_params_mapping(
         cls,
...
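
The per-class select_experts helper removed above is replaced by the shared routine in sglang.srt.layers.moe.topk. A minimal usage sketch follows; the keyword names are copied from the new call in EPMoE.forward, while the toy tensor shapes, dtypes, expert counts, and device placement are assumptions for illustration only:

import torch

from sglang.srt.layers.moe.topk import select_experts

# Hypothetical toy inputs: 8 tokens, hidden size 4096, 64 routed experts.
hidden_states = torch.randn(8, 4096, dtype=torch.float16, device="cuda")
router_logits = torch.randn(8, 64, dtype=torch.float32, device="cuda")

topk_weights, topk_ids = select_experts(
    hidden_states=hidden_states,
    router_logits=router_logits,
    top_k=2,
    use_grouped_topk=False,  # grouped top-k would also need topk_group / num_expert_group
    renormalize=True,
    topk_group=None,
    num_expert_group=None,
    correction_bias=None,    # forwarded from the new EPMoE constructor argument
)
# topk_weights: per-token routing weights; topk_ids: selected expert indices.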
"""
Torch-native implementation for FusedMoE. This is used for torch.compile.
It is based on https://github.com/pytorch-labs/gpt-fast/blob/32971d3129541c5bfb4f715abc33d1c5f408d204/mixtral-moe/model.py#L204
"""
from typing import Callable, Optional
import torch
from torch.nn import functional as F
from sglang.srt.layers.moe.topk import select_experts
def fused_moe_forward_native(
layer: torch.nn.Module,
x: torch.Tensor,
use_grouped_topk: bool,
top_k: int,
router_logits: torch.Tensor,
renormalize: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
correction_bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
topk_weights, topk_ids = select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
correction_bias=correction_bias,
torch_native=True,
)
w13_weights = layer.w13_weight[topk_ids]
w1_weights, w3_weights = torch.chunk(w13_weights, 2, dim=2)
w2_weights = layer.w2_weight[topk_ids]
x1 = torch.einsum("ti,taoi -> tao", x, w1_weights)
x1 = F.silu(x1)
x3 = torch.einsum("ti, taoi -> tao", x, w3_weights)
expert_outs = torch.einsum("tao, taio -> tai", (x1 * x3), w2_weights)
return torch.einsum("tai,ta -> ti", expert_outs, topk_weights.to(expert_outs.dtype))
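
The einsum chain at the end of fused_moe_forward_native can be checked in isolation. Below is a self-contained toy-shape sketch of the same math, with randomly initialized stand-ins for layer.w13_weight, layer.w2_weight, and the routing output; the names and sizes are assumptions for illustration, not the sglang API:

import torch
import torch.nn.functional as F

# Toy sizes (assumed): t tokens, a = top_k experts per token,
# i = hidden size, o = intermediate size, e = number of experts.
t, a, i, o, e = 4, 2, 8, 16, 6

x = torch.randn(t, i)
w13_weight = torch.randn(e, 2 * o, i)   # stand-in for layer.w13_weight (gate + up stacked)
w2_weight = torch.randn(e, i, o)        # stand-in for layer.w2_weight (down projection)
topk_ids = torch.randint(0, e, (t, a))
topk_weights = torch.softmax(torch.randn(t, a), dim=-1)

w13 = w13_weight[topk_ids]              # [t, a, 2o, i], gathered per token and expert
w1, w3 = torch.chunk(w13, 2, dim=2)     # gate and up projections, each [t, a, o, i]
w2 = w2_weight[topk_ids]                # [t, a, i, o]

x1 = F.silu(torch.einsum("ti,taoi->tao", x, w1))             # gated branch
x3 = torch.einsum("ti,taoi->tao", x, w3)                     # up branch
expert_outs = torch.einsum("tao,taio->tai", x1 * x3, w2)     # per-expert down projection
out = torch.einsum("tai,ta->ti", expert_outs, topk_weights)  # combine with routing weights
assert out.shape == (t, i)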
 from contextlib import contextmanager
 from typing import Any, Dict, Optional

-import sglang.srt.layers.fused_moe_triton.fused_moe  # noqa
-from sglang.srt.layers.fused_moe_triton.fused_moe import (
+import sglang.srt.layers.moe.fused_moe_triton.fused_moe  # noqa
+from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
     fused_experts,
-    fused_topk,
     get_config_file_name,
-    grouped_topk,
 )
-from sglang.srt.layers.fused_moe_triton.layer import (
+from sglang.srt.layers.moe.fused_moe_triton.layer import (
     FusedMoE,
     FusedMoEMethodBase,
     FusedMoeWeightScaleSupported,
@@ -37,8 +35,6 @@ __all__ = [
     "override_config",
     "get_config",
     "fused_moe",
-    "fused_topk",
     "fused_experts",
     "get_config_file_name",
-    "grouped_topk",
 ]