Unverified commit 15ad6c90 authored by Cheng Wan, committed by GitHub

[1/N] MoE Refactor: refactor `select_experts` (#7966)

parent cfab0ff6
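For orientation, the sketch below illustrates the direction of this refactor as it appears in the hunks that follow: expert selection moves out of the individual MoE layers and quantization kernels into a dedicated `TopK` module, which returns a `TopKOutput` named tuple (`topk_weights`, `topk_ids`, `router_logits`) that the MoE layers then consume. The module wiring shown here (class name, construction arguments, tensor shapes) is illustrative only, not taken verbatim from the patch.

# Illustrative sketch of the refactored call flow (not part of the patch).
import torch

from sglang.srt.layers.moe.topk import TopK
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE

class SparseMoeBlockSketch(torch.nn.Module):  # hypothetical wrapper, for illustration
    def __init__(self, num_experts: int, top_k: int, hidden_size: int, intermediate_size: int):
        super().__init__()
        self.gate = torch.nn.Linear(hidden_size, num_experts, bias=False)
        # Routing configuration (renormalize, grouped top-k, correction bias, ...)
        # now lives on TopK instead of being passed through FusedMoE.
        self.topk = TopK(top_k=top_k, renormalize=True)
        self.experts = FusedMoE(
            num_experts=num_experts,
            top_k=top_k,
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        router_logits = self.gate(hidden_states)
        # TopK is a CustomOp; its forward_* variants return a
        # TopKOutput(topk_weights, topk_ids, router_logits).
        topk_output = self.topk(hidden_states, router_logits)
        # FusedMoE.forward now takes the precomputed TopKOutput instead of router_logits.
        return self.experts(hidden_states, topk_output)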
@@ -29,15 +29,18 @@ class CustomOp(nn.Module):
self._original_forward_method = self._forward_method
# NOTE: Temporarily workaround MoE
# The performance of torch.compile on this layer is not always good when bs > 1,
# so we decide to only use torch.compile when bs=1
if "FusedMoE" in self.__class__.__name__: if "FusedMoE" in self.__class__.__name__:
if num_tokens == 1: if num_tokens == 1:
from sglang.srt.layers.moe.fused_moe_native import ( from sglang.srt.layers.moe.fused_moe_native import (
fused_moe_forward_native, fused_moe_forward_native,
) )
# The performance of torch.compile on this layer is not always good when bs > 1,
# so we decide to only use torch.compile when bs =1
self._forward_method = fused_moe_forward_native
elif "TopK" in self.__class__.__name__:
if num_tokens == 1:
self._forward_method = self.forward_native
else:
self._forward_method = self.forward_native
self.is_torch_compile = True
...
@@ -756,7 +756,7 @@ class QKVParallelLinear(ColumnParallelLinear):
bias: bool = True,
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional["QuantizationConfig"] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
tp_rank: Optional[int] = None,
tp_size: Optional[int] = None,
...
import logging
from typing import Callable, List, Optional, Tuple
from typing import List, Optional, Tuple
import einops
import torch
from torch.nn import Module
from sglang.srt.custom_op import CustomOp
from sglang.srt.distributed import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
from sglang.srt.eplb.expert_location import get_global_expert_location_metadata
from sglang.srt.eplb.expert_location_dispatch import ExpertLocationDispatchInfo
from sglang.srt.layers.moe.ep_moe.kernels import (
ep_gather,
ep_scatter,
@@ -28,7 +24,7 @@ from sglang.srt.layers.moe.ep_moe.kernels import (
tma_align_input_scale,
)
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
from sglang.srt.layers.moe.topk import select_experts
from sglang.srt.layers.moe.topk import TopKOutput
from sglang.srt.layers.quantization import deep_gemm_wrapper
from sglang.srt.layers.quantization.base_config import (
QuantizationConfig,
@@ -162,16 +158,9 @@ class EPMoE(torch.nn.Module):
intermediate_size: int,
layer_id: int,
params_dtype: Optional[torch.dtype] = None,
renormalize: bool = True,
use_grouped_topk: bool = False,
num_expert_group: Optional[int] = None,
num_fused_shared_experts: int = 0,
topk_group: Optional[int] = None,
quant_config: Optional[QuantizationConfig] = None,
tp_size: Optional[int] = None,
prefix: str = "",
correction_bias: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None,
activation: str = "silu", activation: str = "silu",
routed_scaling_factor: Optional[float] = None, routed_scaling_factor: Optional[float] = None,
use_per_token_if_dynamic: bool = True, use_per_token_if_dynamic: bool = True,
...@@ -189,24 +178,12 @@ class EPMoE(torch.nn.Module): ...@@ -189,24 +178,12 @@ class EPMoE(torch.nn.Module):
self.layer_id = layer_id self.layer_id = layer_id
self.num_experts = num_experts self.num_experts = num_experts
assert self.num_experts % self.tp_size == 0 assert self.num_experts % self.tp_size == 0
assert (
num_fused_shared_experts == 0
), "num_fused_shared_experts is not supported in EP"
self.num_fused_shared_experts = num_fused_shared_experts
self.num_experts_per_partition, self.expert_map = self.determine_expert_map()
self.start_expert_id = self.tp_rank * self.num_experts_per_partition
self.end_expert_id = self.start_expert_id + self.num_experts_per_partition - 1
self.top_k = top_k
self.intermediate_size = intermediate_size
self.renormalize = renormalize
self.use_grouped_topk = use_grouped_topk
if self.use_grouped_topk:
assert num_expert_group is not None and topk_group is not None
self.num_expert_group = num_expert_group
self.topk_group = topk_group
self.correction_bias = correction_bias
self.custom_routing_function = custom_routing_function
self.activation = activation
self.routed_scaling_factor = routed_scaling_factor
self.use_per_token_if_dynamic = use_per_token_if_dynamic
@@ -311,33 +288,24 @@ class EPMoE(torch.nn.Module):
)
return (local_num_experts, expert_map)
def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8:
return self.forward_deepgemm(hidden_states, router_logits)
return self.forward_deepgemm(hidden_states, topk_output)
else:
return self.forward_normal(hidden_states, router_logits)
return self.forward_normal(hidden_states, topk_output)
def forward_deepgemm(
self, hidden_states: torch.Tensor, router_logits: torch.Tensor
self,
hidden_states: torch.Tensor,
topk_output: TopKOutput,
):
assert self.quant_method is not None
assert self.activation == "silu"
hidden_states_shape = hidden_states.shape
hidden_states_dtype = hidden_states.dtype
hidden_states_device = hidden_states.device
topk_weights, topk_ids = select_experts(
hidden_states=hidden_states,
topk_weights, topk_ids, _ = topk_output
router_logits=router_logits,
top_k=self.top_k,
use_grouped_topk=self.use_grouped_topk,
renormalize=self.renormalize,
topk_group=self.topk_group,
num_expert_group=self.num_expert_group,
num_fused_shared_experts=self.num_fused_shared_experts,
correction_bias=self.correction_bias,
custom_routing_function=self.custom_routing_function,
routed_scaling_factor=self.routed_scaling_factor,
)
if not self.use_block_quant:
# Convert per-tensor quant to per-block quant by repeating scales for forward_deepgemm
@@ -469,8 +437,10 @@ class EPMoE(torch.nn.Module):
)
return output
def forward_normal(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
def forward_normal(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
assert self.quant_method is not None
topk_weights, topk_ids, _ = topk_output
hidden_states_shape = hidden_states.shape
hidden_states_dtype = hidden_states.dtype
hidden_states_device = hidden_states.device
@@ -481,23 +451,6 @@ class EPMoE(torch.nn.Module):
use_per_token_if_dynamic=self.use_per_token_if_dynamic,
)
topk_weights, topk_ids = select_experts(
hidden_states=hidden_states,
router_logits=router_logits,
top_k=self.top_k,
use_grouped_topk=self.use_grouped_topk,
renormalize=self.renormalize,
topk_group=self.topk_group,
num_expert_group=self.num_expert_group,
num_fused_shared_experts=self.num_fused_shared_experts,
correction_bias=self.correction_bias,
custom_routing_function=self.custom_routing_function,
routed_scaling_factor=self.routed_scaling_factor,
expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new(
layer_id=self.layer_id,
),
)
if self.use_w4afp8:
local_topk_ids = topk_ids
if self.expert_map is not None:
@@ -916,16 +869,9 @@ class DeepEPMoE(EPMoE):
intermediate_size: int,
layer_id: int,
params_dtype: Optional[torch.dtype] = None,
renormalize: bool = True,
use_grouped_topk: bool = False,
num_expert_group: Optional[int] = None,
num_fused_shared_experts: int = 0,
topk_group: Optional[int] = None,
quant_config: Optional[QuantizationConfig] = None,
tp_size: Optional[int] = None,
prefix: str = "",
correction_bias: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None,
activation: str = "silu", activation: str = "silu",
routed_scaling_factor: Optional[float] = None, routed_scaling_factor: Optional[float] = None,
deepep_mode: DeepEPMode = DeepEPMode.auto, deepep_mode: DeepEPMode = DeepEPMode.auto,
...@@ -937,16 +883,9 @@ class DeepEPMoE(EPMoE): ...@@ -937,16 +883,9 @@ class DeepEPMoE(EPMoE):
intermediate_size=intermediate_size, intermediate_size=intermediate_size,
layer_id=layer_id, layer_id=layer_id,
params_dtype=params_dtype, params_dtype=params_dtype,
renormalize=renormalize,
use_grouped_topk=use_grouped_topk,
num_expert_group=num_expert_group,
num_fused_shared_experts=num_fused_shared_experts,
topk_group=topk_group,
quant_config=quant_config,
tp_size=tp_size,
prefix=prefix,
correction_bias=correction_bias,
custom_routing_function=custom_routing_function,
activation=activation,
routed_scaling_factor=routed_scaling_factor,
)
...
@@ -9,21 +9,14 @@ import torch
from torch.nn import functional as F
from sglang.srt.layers.activation import GeluAndMul, SiluAndMul
from sglang.srt.layers.moe.topk import select_experts
from sglang.srt.layers.moe.topk import TopKOutput
def fused_moe_forward_native(
layer: torch.nn.Module,
x: torch.Tensor,
use_grouped_topk: bool,
topk_output: TopKOutput,
top_k: int,
*,
router_logits: torch.Tensor,
renormalize: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
num_fused_shared_experts: int = 0,
custom_routing_function: Optional[Callable] = None,
correction_bias: Optional[torch.Tensor] = None,
activation: str = "silu", activation: str = "silu",
apply_router_weight_on_input: bool = False, apply_router_weight_on_input: bool = False,
inplace: bool = True, inplace: bool = True,
...@@ -34,20 +27,7 @@ def fused_moe_forward_native( ...@@ -34,20 +27,7 @@ def fused_moe_forward_native(
if apply_router_weight_on_input: if apply_router_weight_on_input:
raise NotImplementedError() raise NotImplementedError()
topk_weights, topk_ids = select_experts( topk_weights, topk_ids, _ = topk_output
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
num_fused_shared_experts=num_fused_shared_experts,
custom_routing_function=custom_routing_function,
correction_bias=correction_bias,
routed_scaling_factor=routed_scaling_factor,
torch_native=True,
)
w13_weights = layer.w13_weight[topk_ids]
w1_weights, w3_weights = torch.chunk(w13_weights, 2, dim=2)
@@ -67,15 +47,8 @@ def moe_forward_native(
def moe_forward_native(
layer: torch.nn.Module,
x: torch.Tensor,
use_grouped_topk: bool,
topk_output: TopKOutput,
top_k: int,
*,
router_logits: torch.Tensor,
renormalize: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
num_fused_shared_experts: int = 0,
custom_routing_function: Optional[Callable] = None,
correction_bias: Optional[torch.Tensor] = None,
activation: str = "silu", activation: str = "silu",
apply_router_weight_on_input: bool = False, apply_router_weight_on_input: bool = False,
inplace: bool = True, inplace: bool = True,
...@@ -86,20 +59,7 @@ def moe_forward_native( ...@@ -86,20 +59,7 @@ def moe_forward_native(
if apply_router_weight_on_input: if apply_router_weight_on_input:
raise NotImplementedError() raise NotImplementedError()
topk_weights, topk_ids = select_experts( topk_weights, topk_ids, _ = topk_output
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
num_fused_shared_experts=num_fused_shared_experts,
custom_routing_function=custom_routing_function,
correction_bias=correction_bias,
torch_native=True,
routed_scaling_factor=routed_scaling_factor,
)
# Ref code from https://huggingface.co/deepseek-ai/DeepSeek-V2/blob/e0828e3cc0a03408724b80c3cc92c8e072db8d01/modeling_deepseek.py#L589
len_experts = layer.num_experts
...
@@ -6,13 +6,13 @@ import functools
import json
import logging
import os
from typing import Any, Callable, Dict, List, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple
import torch
import triton
import triton.language as tl
from sglang.srt.layers.moe.topk import select_experts
from sglang.srt.layers.moe.topk import TopKOutput
from sglang.srt.layers.quantization.fp8_kernel import (
per_token_group_quant_fp8,
scaled_fp8_quant,
@@ -1328,8 +1328,7 @@ def fused_experts(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_output: TopKOutput,
topk_ids: torch.Tensor,
inplace: bool = False,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
@@ -1348,7 +1347,7 @@ def fused_experts(
no_combine: bool = False,
routed_scaling_factor: Optional[float] = None,
):
topk_weights, topk_ids, _ = topk_output
if inplace:
assert not no_combine, "no combine + inplace makes no sense"
torch.ops.sglang.inplace_fused_experts(
@@ -1732,17 +1731,10 @@ def fused_moe(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
gating_output: torch.Tensor,
topk_output: TopKOutput,
topk: int,
renormalize: bool,
inplace: bool = False,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
use_grouped_topk: bool = False,
num_expert_group: Optional[int] = None,
num_fused_shared_experts: int = 0,
topk_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False,
@@ -1766,16 +1758,9 @@ def fused_moe(
- hidden_states (torch.Tensor): The input tensor to the MoE layer.
- w1 (torch.Tensor): The first set of expert weights.
- w2 (torch.Tensor): The second set of expert weights.
- gating_output (torch.Tensor): The output of the gating operation
- topk_output (TopKOutput): The top-k output of the experts.
(before softmax).
- topk (int): The number of top-k experts to select.
- renormalize (bool): If True, renormalize the top-k weights to sum to 1.
- inplace (bool): If True, perform the operation in-place.
Defaults to False.
- num_expert_group: Optional[int]: additional parameter for grouped_topk
- topk_group: Optional[int]: additional parameter for grouped_topk
- use_grouped_topk: If True, use grouped_topk instead of fused_topk
note: Deepseek V2/V3/R1 series models use grouped_topk
- use_fp8_w8a8 (bool): If True, use fp8 arithmetic to compute the inner
products for w1 and w2. Defaults to False.
- use_int8_w8a8 (bool): If True, use int8 arithmetic to compute the inner
@@ -1799,28 +1784,12 @@ def fused_moe(
Returns:
- torch.Tensor: The output tensor after applying the MoE layer.
"""
# Check constraints.
assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"
topk_weights, topk_ids = select_experts(
hidden_states=hidden_states,
router_logits=gating_output,
use_grouped_topk=use_grouped_topk,
top_k=topk,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
num_fused_shared_experts=num_fused_shared_experts,
custom_routing_function=custom_routing_function,
routed_scaling_factor=routed_scaling_factor,
)
return fused_experts(
hidden_states,
w1,
w2,
topk_weights,
topk_output,
topk_ids,
inplace=inplace,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
...
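As the hunks above show, the Triton `fused_experts` and `fused_moe` entry points now accept a precomputed `TopKOutput` instead of `gating_output`, `topk`, `renormalize`, and the other routing arguments. Below is a minimal sketch of calling the updated functional API; the tensor shapes and dummy values are assumptions for illustration, not taken from the patch.

# Minimal sketch of the refactored functional API (assumed shapes and values).
import torch

from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
from sglang.srt.layers.moe.topk import select_experts

num_tokens, hidden, inter, num_experts, top_k = 4, 128, 256, 8, 2
x = torch.randn(num_tokens, hidden, device="cuda", dtype=torch.bfloat16)
w1 = torch.randn(num_experts, 2 * inter, hidden, device="cuda", dtype=torch.bfloat16)
w2 = torch.randn(num_experts, hidden, inter, device="cuda", dtype=torch.bfloat16)
router_logits = torch.randn(num_tokens, num_experts, device="cuda", dtype=torch.float32)

# Routing happens once, up front; select_experts now returns a TopKOutput named tuple.
topk_output = select_experts(
    hidden_states=x,
    router_logits=router_logits,
    top_k=top_k,
    renormalize=True,
)
# The expert computation consumes the whole TopKOutput rather than
# separate topk_weights / topk_ids tensors.
out = fused_experts(x, w1, w2, topk_output, inplace=False)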
@@ -2,7 +2,7 @@
import logging
from enum import Enum
from typing import Callable, List, Optional, Tuple
from typing import List, Optional, Tuple
import torch
@@ -11,6 +11,7 @@ from sglang.srt.distributed import (
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce,
)
from sglang.srt.layers.moe.topk import TopKOutput
from sglang.srt.layers.quantization.base_config import (
QuantizationConfig,
QuantizeMethodBase,
@@ -59,22 +60,15 @@ class FusedMoE(torch.nn.Module):
def __init__(
self,
num_experts: int,
top_k: int,
hidden_size: int,
intermediate_size: int,
top_k: Optional[int] = None,
layer_id: Optional[int] = None,
params_dtype: Optional[torch.dtype] = None,
reduce_results: bool = False,
renormalize: bool = True,
use_grouped_topk: bool = False,
num_expert_group: Optional[int] = None,
num_fused_shared_experts: int = 0,
topk_group: Optional[int] = None,
quant_config: Optional[QuantizationConfig] = None,
tp_size: Optional[int] = None,
prefix: str = "",
custom_routing_function: Optional[Callable] = None,
correction_bias: Optional[torch.Tensor] = None,
activation: str = "silu", activation: str = "silu",
apply_router_weight_on_input: bool = False, apply_router_weight_on_input: bool = False,
use_presharded_weights: bool = False, use_presharded_weights: bool = False,
...@@ -89,6 +83,7 @@ class FusedMoE(torch.nn.Module): ...@@ -89,6 +83,7 @@ class FusedMoE(torch.nn.Module):
if params_dtype is None: if params_dtype is None:
params_dtype = torch.get_default_dtype() params_dtype = torch.get_default_dtype()
self.top_k = top_k
self.hidden_size = hidden_size
self.tp_size = (
tp_size if tp_size is not None else get_tensor_model_parallel_world_size()
@@ -126,19 +121,9 @@ class FusedMoE(torch.nn.Module):
self.ep_rank = 0
self.local_num_experts = num_experts
self.routed_scaling_factor = routed_scaling_factor
self.top_k = top_k
assert intermediate_size % self.tp_size == 0
self.intermediate_size_per_partition = intermediate_size // self.tp_size
self.reduce_results = reduce_results
self.renormalize = renormalize
self.use_grouped_topk = use_grouped_topk
if self.use_grouped_topk:
assert num_expert_group is not None and topk_group is not None
self.num_expert_group = num_expert_group
self.num_fused_shared_experts = num_fused_shared_experts
self.topk_group = topk_group
self.custom_routing_function = custom_routing_function
self.correction_bias = correction_bias
self.activation = activation
self.apply_router_weight_on_input = apply_router_weight_on_input
self.use_presharded_weights = use_presharded_weights
@@ -562,22 +547,14 @@ class FusedMoE(torch.nn.Module):
)
return
def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
assert self.quant_method is not None
# Matrix multiply.
final_hidden_states = self.quant_method.apply(
layer=self,
x=hidden_states,
router_logits=router_logits,
topk_output=topk_output,
top_k=self.top_k,
renormalize=self.renormalize,
use_grouped_topk=self.use_grouped_topk,
topk_group=self.topk_group,
num_expert_group=self.num_expert_group,
num_fused_shared_experts=self.num_fused_shared_experts,
custom_routing_function=self.custom_routing_function,
correction_bias=self.correction_bias,
activation=self.activation,
apply_router_weight_on_input=self.apply_router_weight_on_input,
routed_scaling_factor=self.routed_scaling_factor,
...
@@ -12,12 +12,15 @@
# limitations under the License.
# ==============================================================================
from __future__ import annotations
import math
from typing import Callable, Optional
from typing import TYPE_CHECKING, Callable, NamedTuple, Optional
import torch
import torch.nn.functional as F
from sglang.srt.custom_op import CustomOp
from sglang.srt.eplb import expert_location_dispatch
from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
from sglang.srt.eplb.expert_location_dispatch import (
@@ -52,6 +55,168 @@ if _use_aiter:
except ImportError:
raise ImportError("aiter is required when SGLANG_USE_AITER is set to True")
if _is_npu:
import torch_npu
class TopKOutput(NamedTuple):
topk_weights: torch.Tensor
topk_ids: torch.Tensor
router_logits: torch.Tensor
class TopK(CustomOp):
# TODO(ch-wan): support triton_kernels
def __init__(
self,
top_k: int,
*,
use_grouped_topk: bool = False,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
renormalize: bool = True,
num_fused_shared_experts: int = 0,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
correction_bias: Optional[torch.Tensor] = None,
routed_scaling_factor: Optional[float] = None,
):
# NOTE: scoring_func is not used for now, but we keep it for future use
# see https://github.com/sgl-project/sglang/pull/4505 for more details
super().__init__()
if use_grouped_topk:
assert num_expert_group is not None and topk_group is not None
self.top_k = top_k
self.use_grouped_topk = use_grouped_topk
self.renormalize = renormalize
self.topk_group = topk_group
self.num_expert_group = num_expert_group
self.num_fused_shared_experts = num_fused_shared_experts
self.custom_routing_function = custom_routing_function
self.correction_bias = correction_bias
self.routed_scaling_factor = routed_scaling_factor
def forward_native(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
*,
num_token_non_padded: Optional[torch.Tensor] = None,
expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
) -> TopKOutput:
torch_native = True
return select_experts(
hidden_states=hidden_states,
router_logits=router_logits,
top_k=self.top_k,
use_grouped_topk=self.use_grouped_topk,
renormalize=self.renormalize,
topk_group=self.topk_group,
num_expert_group=self.num_expert_group,
num_fused_shared_experts=self.num_fused_shared_experts,
custom_routing_function=self.custom_routing_function,
correction_bias=self.correction_bias,
torch_native=torch_native,
routed_scaling_factor=self.routed_scaling_factor,
num_token_non_padded=num_token_non_padded,
expert_location_dispatch_info=expert_location_dispatch_info,
)
def forward_cuda(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
*,
num_token_non_padded: Optional[torch.Tensor] = None,
expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
) -> TopKOutput:
torch_native = False
return select_experts(
hidden_states=hidden_states,
router_logits=router_logits,
top_k=self.top_k,
use_grouped_topk=self.use_grouped_topk,
renormalize=self.renormalize,
topk_group=self.topk_group,
num_expert_group=self.num_expert_group,
num_fused_shared_experts=self.num_fused_shared_experts,
custom_routing_function=self.custom_routing_function,
correction_bias=self.correction_bias,
torch_native=torch_native,
routed_scaling_factor=self.routed_scaling_factor,
num_token_non_padded=num_token_non_padded,
expert_location_dispatch_info=expert_location_dispatch_info,
)
def forward_cpu(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
*,
num_token_non_padded: Optional[torch.Tensor] = None,
expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
) -> TopKOutput:
return select_experts(
hidden_states=hidden_states,
router_logits=router_logits,
top_k=self.top_k,
use_grouped_topk=self.use_grouped_topk,
renormalize=self.renormalize,
topk_group=self.topk_group,
num_expert_group=self.num_expert_group,
num_fused_shared_experts=self.num_fused_shared_experts,
custom_routing_function=self.custom_routing_function,
correction_bias=self.correction_bias,
routed_scaling_factor=self.routed_scaling_factor,
num_token_non_padded=num_token_non_padded,
expert_location_dispatch_info=expert_location_dispatch_info,
)
def forward_npu(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
*,
num_token_non_padded: Optional[torch.Tensor] = None,
expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
) -> TopKOutput:
global_num_experts = router_logits.shape[-1]
# NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
if global_num_experts == 256:
return torch_npu.npu_moe_gating_top_k(
router_logits,
k=self.top_k,
bias=self.correction_bias,
k_group=self.topk_group,
group_count=self.num_expert_group,
group_select_mode=1,
renorm=0,
norm_type=1,
routed_scaling_factor=1,
eps=float(1e-20),
)
else:
torch_native = True
return select_experts(
hidden_states=hidden_states,
router_logits=router_logits,
top_k=self.top_k,
use_grouped_topk=self.use_grouped_topk,
renormalize=self.renormalize,
topk_group=self.topk_group,
num_expert_group=self.num_expert_group,
num_fused_shared_experts=self.num_fused_shared_experts,
custom_routing_function=self.custom_routing_function,
correction_bias=self.correction_bias,
torch_native=torch_native,
routed_scaling_factor=self.routed_scaling_factor,
num_token_non_padded=num_token_non_padded,
expert_location_dispatch_info=expert_location_dispatch_info,
)
def fused_topk_torch_native(
hidden_states: torch.Tensor,
@@ -436,8 +601,9 @@ def select_experts(
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
use_grouped_topk: bool,
renormalize: bool,
*,
use_grouped_topk: bool = False,
renormalize: bool = False,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
num_fused_shared_experts: int = 0,
@@ -447,7 +613,7 @@ def select_experts(
routed_scaling_factor: Optional[float] = None,
num_token_non_padded: Optional[torch.Tensor] = None,
expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
):
) -> TopKOutput:
router_logits, correction_bias = (
expert_location_dispatch.transform_select_experts_inputs(
router_logits=router_logits,
@@ -522,4 +688,4 @@ def select_experts(
get_global_expert_distribution_recorder().on_select_experts(topk_ids=topk_ids)
return topk_weights, topk_ids
return TopKOutput(topk_weights, topk_ids, router_logits)
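The net effect of the topk.py changes: `select_experts` keeps its routing logic, but its tuning knobs become keyword-only and it now returns a `TopKOutput` named tuple, so callers can either unpack it or pass it around whole. A small hedged example of both patterns (tensor contents are made up for illustration; `torch_native=True` keeps it on the pure PyTorch path):

# Sketch of consuming TopKOutput (illustrative values, not from the patch).
import torch

from sglang.srt.layers.moe.topk import TopKOutput, select_experts

hidden_states = torch.randn(4, 64)
router_logits = torch.randn(4, 8)

topk_output: TopKOutput = select_experts(
    hidden_states=hidden_states,
    router_logits=router_logits,
    top_k=2,
    renormalize=True,
    torch_native=True,
)

# Pattern used throughout the refactored call sites: unpack and ignore the logits.
topk_weights, topk_ids, _ = topk_output

# NamedTuple fields are also addressable by name when the logits are still needed,
# e.g. by the Marlin MoE paths later in this diff.
assert topk_output.router_logits.shape == router_logits.shape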
# Adapted from https://raw.githubusercontent.com/vllm-project/vllm/v0.5.5/vllm/model_executor/layers/quantization/__init__.py
from __future__ import annotations
import builtins
import inspect
from typing import Callable, Dict, Optional, Type, Union
from typing import TYPE_CHECKING, Callable, Dict, Optional, Type, Union
import torch
@@ -65,6 +67,9 @@ from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config
from sglang.srt.layers.quantization.w8a8_fp8 import W8A8Fp8Config
from sglang.srt.layers.quantization.w8a8_int8 import W8A8Int8Config
if TYPE_CHECKING:
from sglang.srt.layers.moe.topk import TopKOutput
# Base quantization methods that don't depend on vllm
BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
"fp8": Fp8Config,
@@ -186,15 +191,8 @@ def monkey_patch_moe_apply(class_obj: "FusedMoEMethodBase"):
self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
topk_output: TopKOutput,
top_k: int,
*,
renormalize: bool,
use_grouped_topk: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
num_fused_shared_experts: int = 0,
custom_routing_function: Optional[Callable] = None,
correction_bias: Optional[torch.Tensor] = None,
activation: str = "silu", activation: str = "silu",
apply_router_weight_on_input: bool = False, apply_router_weight_on_input: bool = False,
inplace: bool = True, inplace: bool = True,
...@@ -208,20 +206,8 @@ def monkey_patch_moe_apply(class_obj: "FusedMoEMethodBase"): ...@@ -208,20 +206,8 @@ def monkey_patch_moe_apply(class_obj: "FusedMoEMethodBase"):
"self": self, "self": self,
"layer": layer, "layer": layer,
"x": x, "x": x,
"router_logits": router_logits, "topk_output": topk_output,
"top_k": top_k,
"renormalize": renormalize,
"use_grouped_topk": use_grouped_topk,
"topk_group": topk_group,
"num_expert_group": num_expert_group,
"custom_routing_function": custom_routing_function,
} }
if correction_bias is not None:
if not has_correction_bias:
raise ValueError(
"Please increase the version of your vllm. Try `pip install vllm==0.9.0.1`"
)
kwargs["e_score_correction_bias"] = correction_bias
return original_apply(**kwargs)
setattr(class_obj, "apply", new_apply)
...
@@ -3,7 +3,7 @@ from __future__ import annotations
import logging
import warnings
from typing import Any, Callable, Dict, List, Optional
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional
import torch
@@ -33,6 +33,9 @@ from sglang.srt.layers.quantization.scalar_type import scalar_types
from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
from sglang.srt.layers.quantization.utils import replace_parameter
if TYPE_CHECKING:
from sglang.srt.layers.moe.topk import TopKOutput
try:
from vllm import _custom_ops as ops
@@ -737,45 +740,19 @@ class AWQMoEMethod(FusedMoEMethodBase):
self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
topk_output: TopKOutput,
top_k: int,
*,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
num_fused_shared_experts: int = 0,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu", activation: str = "silu",
routed_scaling_factor: Optional[float] = None, **kwargs,
) -> torch.Tensor: ) -> torch.Tensor:
# Delay the import to avoid circular dependency
from sglang.srt.layers.moe.topk import select_experts
assert activation == "silu", "Only SiLU activation is supported." assert activation == "silu", "Only SiLU activation is supported."
assert (
scoring_func == "softmax"
), "Only softmax score func is supported for now."
# The input must currently be float16
orig_dtype = x.dtype
x = x.half()
topk_weights, topk_ids = select_experts(
topk_weights, topk_ids, router_logits = topk_output
hidden_states=x,
router_logits=router_logits,
top_k=top_k,
use_grouped_topk=use_grouped_topk,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
num_fused_shared_experts=num_fused_shared_experts,
custom_routing_function=custom_routing_function,
correction_bias=correction_bias,
routed_scaling_factor=routed_scaling_factor,
)
return fused_marlin_moe(
x,
...
# Adapted from https://raw.githubusercontent.com/vllm-project/vllm/v0.5.5/vllm/model_executor/layers/quantization/base_config.py # Adapted from https://raw.githubusercontent.com/vllm-project/vllm/v0.5.5/vllm/model_executor/layers/quantization/base_config.py
from __future__ import annotations
import inspect
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Type
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type
import torch
from torch import nn
if TYPE_CHECKING:
from sglang.srt.layers.moe.topk import TopKOutput
class QuantizeMethodBase(ABC):
"""Base class for different quantized methods."""
@@ -88,19 +92,22 @@ class FusedMoEMethodBase(QuantizeMethodBase):
params_dtype: torch.dtype,
**extra_weight_attrs,
):
raise NotImplementedError()
raise NotImplementedError
@abstractmethod
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool,
topk_output: TopKOutput,
*,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
inplace: bool = True,
no_combine: bool = False,
routed_scaling_factor: Optional[float] = None,
) -> torch.Tensor:
raise NotImplementedError()
raise NotImplementedError
class QuantizationConfig(ABC):
...
@@ -3,7 +3,7 @@
from __future__ import annotations
import logging
from typing import Any, Callable, Dict, List, Optional
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional
import torch
from torch.nn import Module
@@ -21,6 +21,9 @@ from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
from sglang.srt.layers.quantization.utils import is_layer_skipped
from sglang.srt.utils import set_weight_attrs
if TYPE_CHECKING:
from sglang.srt.layers.moe.topk import TopKOutput
ACTIVATION_SCHEMES = ["static", "dynamic"] ACTIVATION_SCHEMES = ["static", "dynamic"]
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -344,15 +347,8 @@ class BlockInt8MoEMethod(FusedMoEMethodBase): ...@@ -344,15 +347,8 @@ class BlockInt8MoEMethod(FusedMoEMethodBase):
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, topk_output: TopKOutput,
top_k: int, *,
renormalize: bool,
use_grouped_topk: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
num_fused_shared_experts: int = 0,
custom_routing_function: Optional[Callable] = None,
correction_bias: Optional[torch.Tensor] = None,
activation: str = "silu", activation: str = "silu",
apply_router_weight_on_input: bool = False, apply_router_weight_on_input: bool = False,
inplace: bool = True, inplace: bool = True,
...@@ -360,30 +356,13 @@ class BlockInt8MoEMethod(FusedMoEMethodBase): ...@@ -360,30 +356,13 @@ class BlockInt8MoEMethod(FusedMoEMethodBase):
routed_scaling_factor: Optional[float] = None, routed_scaling_factor: Optional[float] = None,
) -> torch.Tensor: ) -> torch.Tensor:
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
from sglang.srt.layers.moe.topk import select_experts
# Expert selection
topk_weights, topk_ids = select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
num_fused_shared_experts=num_fused_shared_experts,
custom_routing_function=custom_routing_function,
correction_bias=correction_bias,
routed_scaling_factor=routed_scaling_factor,
)
# Expert fusion with INT8 quantization
return fused_experts(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights=topk_weights,
topk_output=topk_output,
topk_ids=topk_ids,
inplace=inplace,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
...
# Adapted from https://github.com/vllm-project/vllm/tree/v0.8.2/vllm/model_executor/layers/quantization/compressed_tensors
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
import enum
import logging
from enum import Enum
from typing import Callable, List, Optional
from typing import TYPE_CHECKING, List, Optional
import torch
from compressed_tensors import CompressionFormat
from compressed_tensors.quantization import QuantizationStrategy
from sglang.srt.layers.quantization.base_config import FusedMoEMethodBase
from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz, scaled_fp8_quant
from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz
from sglang.srt.layers.quantization.utils import (
@@ -20,6 +22,12 @@ from sglang.srt.layers.quantization.utils import (
)
from sglang.srt.utils import is_cpu, is_cuda, is_npu, set_weight_attrs
if TYPE_CHECKING:
from sglang.srt.layers.moe.topk import TopKOutput
from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors import (
CompressedTensorsConfig,
)
_is_cuda = is_cuda()
_is_npu = is_npu()
_is_cpu_amx_available = cpu_has_amx_support()
@@ -51,7 +59,7 @@ __all__ = [
]
class CompressedTensorsMoEMethod:
class CompressedTensorsMoEMethod(FusedMoEMethodBase):
def __new__(cls, *args, **kwargs):
if cls is CompressedTensorsMoEMethod:
return super().__new__(cls)
@@ -59,7 +67,7 @@ class CompressedTensorsMoEMethod:
@staticmethod
def get_moe_method(
quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501
quant_config: CompressedTensorsConfig,
) -> "CompressedTensorsMoEMethod":
# TODO: @dsikka: refactor this to use schemes as other kernels
# are supported + check if the layer is being ignored.
@@ -82,9 +90,7 @@ class CompressedTensorsMoEMethod:
class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
def __init__(
def __init__(self, quant_config: CompressedTensorsConfig):
self, quant_config: "CompressedTensorsConfig" # type: ignore # noqa E501
):
self.quant_config = quant_config
self.weight_quant = self.quant_config.target_scheme_map["Linear"].get("weights")
self.input_quant = self.quant_config.target_scheme_map["Linear"].get(
@@ -270,47 +276,21 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
topk_output: TopKOutput,
top_k: int,
*,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
num_fused_shared_experts: int = 0,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
correction_bias: Optional[torch.Tensor] = None,
activation: str = "silu", activation: str = "silu",
apply_router_weight_on_input: bool = False,
inplace: bool = True, inplace: bool = True,
no_combine: bool = False, no_combine: bool = False,
apply_router_weight_on_input: bool = False,
routed_scaling_factor: Optional[float] = None, routed_scaling_factor: Optional[float] = None,
) -> torch.Tensor: ) -> torch.Tensor:
from sglang.srt.layers.moe.fused_moe_triton import fused_experts from sglang.srt.layers.moe.fused_moe_triton import fused_experts
from sglang.srt.layers.moe.topk import select_experts
topk_weights, topk_ids = select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
num_fused_shared_experts=num_fused_shared_experts,
custom_routing_function=custom_routing_function,
correction_bias=correction_bias,
routed_scaling_factor=routed_scaling_factor,
)
return fused_experts(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights=topk_weights,
topk_output=topk_output,
topk_ids=topk_ids,
inplace=inplace,
activation=activation,
use_fp8_w8a8=True,
@@ -327,9 +307,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
def __init__(
def __init__(self, quant_config: CompressedTensorsConfig):
self, quant_config: "CompressedTensorsConfig" # type: ignore # noqa E501
):
self.quant_config = quant_config
# TODO: @dsikka: refactor this to use schemes as other kernels
# are supported + check if the layer is being ignored.
@@ -628,43 +606,15 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
topk_output: TopKOutput,
top_k: int,
*,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
num_fused_shared_experts: int = 0,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
correction_bias: Optional[torch.Tensor] = None,
activation: str = "silu", activation: str = "silu",
routed_scaling_factor: Optional[float] = None, **kwargs,
) -> torch.Tensor: ) -> torch.Tensor:
from sglang.srt.layers.moe.topk import select_experts
assert activation == "silu", "Only SiLU activation is supported." assert activation == "silu", "Only SiLU activation is supported."
if expert_map is not None:
raise NotImplementedError(
"Expert Parallelism is not supported for " "fused Marlin MoE method."
)
topk_weights, topk_ids, router_logits = topk_output
topk_weights, topk_ids = select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
num_fused_shared_experts=num_fused_shared_experts,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
correction_bias=correction_bias,
routed_scaling_factor=routed_scaling_factor,
)
return torch.ops.vllm.fused_marlin_moe(
x,
...
@@ -3,7 +3,7 @@
from __future__ import annotations
import logging
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
import torch
import torch.nn.functional as F
@@ -78,6 +78,7 @@ from sglang.srt.utils import (
)
if TYPE_CHECKING:
from sglang.srt.layers.moe.topk import TopKOutput
from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config
_is_hip = is_hip()
@@ -971,15 +972,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
topk_output: TopKOutput,
top_k: int,
*,
renormalize: bool,
use_grouped_topk: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
num_fused_shared_experts: int = 0,
custom_routing_function: Optional[Callable] = None,
correction_bias: Optional[torch.Tensor] = None,
activation: str = "silu", activation: str = "silu",
apply_router_weight_on_input: bool = False, apply_router_weight_on_input: bool = False,
inplace: bool = True, inplace: bool = True,
...@@ -987,26 +981,11 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -987,26 +981,11 @@ class Fp8MoEMethod(FusedMoEMethodBase):
routed_scaling_factor: Optional[float] = None, routed_scaling_factor: Optional[float] = None,
) -> torch.Tensor: ) -> torch.Tensor:
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
from sglang.srt.layers.moe.topk import select_experts
# Expert selection
topk_weights, topk_ids = select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
num_fused_shared_experts=num_fused_shared_experts,
custom_routing_function=custom_routing_function,
correction_bias=correction_bias,
routed_scaling_factor=routed_scaling_factor,
)
if use_intel_amx_backend(layer):
from sglang.srt.layers.moe.topk import apply_topk_weights_cpu
topk_weights, topk_ids, _ = topk_output
x, topk_weights = apply_topk_weights_cpu(
apply_router_weight_on_input, topk_weights, x
)
@@ -1032,8 +1011,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
ret = self.maybe_apply_hip_fused_experts(
layer,
x,
topk_weights,
topk_output,
topk_ids,
activation,
no_combine,
)
@@ -1048,6 +1026,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
):
from sglang.srt.layers.moe.cutlass_moe import cutlass_fused_experts_fp8
topk_weights, topk_ids, _ = topk_output
return cutlass_fused_experts_fp8(
x,
layer.w13_weight.transpose(1, 2),
@@ -1076,8 +1055,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
x,
layer.w13_weight,
layer.w2_weight,
topk_weights=topk_weights,
topk_output=topk_output,
topk_ids=topk_ids,
inplace=inplace and not no_combine,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
@@ -1101,11 +1079,11 @@ class Fp8MoEMethod(FusedMoEMethodBase):
self,
layer: torch.nn.Module,
x: torch.Tensor,
topk_weights: torch.Tensor,
topk_output: TopKOutput,
topk_ids: torch.Tensor,
activation: str = "silu", activation: str = "silu",
no_combine: bool = False, no_combine: bool = False,
) -> Optional[torch.Tensor]: ) -> Optional[torch.Tensor]:
topk_weights, topk_ids, _ = topk_output
if _use_hip_int4: if _use_hip_int4:
# TODO: add triton kernel and add check _use_aiter # TODO: add triton kernel and add check _use_aiter
assert not no_combine, f"{no_combine=} is not supported." assert not no_combine, f"{no_combine=} is not supported."
...@@ -1397,14 +1375,8 @@ class Fp8EPMoEMethod(Fp8MoEMethod): ...@@ -1397,14 +1375,8 @@ class Fp8EPMoEMethod(Fp8MoEMethod):
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
topk_output: TopKOutput,
top_k: int,
renormalize: bool,
use_grouped_topk: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
) -> torch.Tensor:
raise NotImplementedError
...
@@ -3,7 +3,7 @@ from __future__ import annotations
import logging
from dataclasses import dataclass
from fractions import Fraction
from typing import Any, Callable, Dict, List, Optional, Union
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
import torch import torch
...@@ -43,6 +43,9 @@ from sglang.srt.layers.quantization.utils import ( ...@@ -43,6 +43,9 @@ from sglang.srt.layers.quantization.utils import (
unpack_cols, unpack_cols,
) )
if TYPE_CHECKING:
from sglang.srt.layers.moe.topk import TopKOutput
try: try:
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
except ImportError: except ImportError:
...@@ -1057,42 +1060,20 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): ...@@ -1057,42 +1060,20 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, topk_output: TopKOutput,
top_k: int, *,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
activation: str = "silu", activation: str = "silu",
**kwargs,
) -> torch.Tensor: ) -> torch.Tensor:
# Delay the import to avoid circular dependency # Delay the import to avoid circular dependency
from sglang.srt.layers.moe.topk import select_experts
assert activation == "silu", "Only SiLU activation is supported." assert activation == "silu", "Only SiLU activation is supported."
assert (
scoring_func == "softmax"
), "Only softmax score func is supported for now."
# The input must currently be float16 # The input must currently be float16
orig_dtype = x.dtype orig_dtype = x.dtype
x = x.half() x = x.half()
topk_weights, topk_ids = select_experts( topk_weights, topk_ids, router_logits = topk_output
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
correction_bias=e_score_correction_bias,
)
return fused_marlin_moe( return fused_marlin_moe(
x, x,
......
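Each quantization file that now annotates `apply` with `TopKOutput` adds the type-only import shown in these hunks. A minimal sketch of the pattern, assuming only that the goal is to avoid importing `sglang.srt.layers.moe.topk` at runtime: with `from __future__ import annotations`, the annotation stays a string, so the import is needed only by type checkers and the circular dependency between the MoE and quantization packages never materializes at import time.

from __future__ import annotations

from typing import TYPE_CHECKING

import torch

if TYPE_CHECKING:
    # Evaluated only by static type checkers, never at runtime.
    from sglang.srt.layers.moe.topk import TopKOutput

def apply(layer: torch.nn.Module, x: torch.Tensor, topk_output: TopKOutput) -> torch.Tensor:
    # Hypothetical body for the sketch; real methods dispatch to fused kernels.
    topk_weights, topk_ids, _ = topk_output
    return x * topk_weights.sum()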
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
from __future__ import annotations from __future__ import annotations
import logging import logging
from typing import Any, Callable, Dict, List, Optional from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional
import torch import torch
from torch.nn.parameter import Parameter from torch.nn.parameter import Parameter
...@@ -31,6 +31,9 @@ from sglang.srt.layers.quantization.utils import ( ...@@ -31,6 +31,9 @@ from sglang.srt.layers.quantization.utils import (
from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.utils import is_cuda, next_power_of_2 from sglang.srt.utils import is_cuda, next_power_of_2
if TYPE_CHECKING:
from sglang.srt.layers.moe.topk import TopKOutput
if is_cuda(): if is_cuda():
from sgl_kernel import cutlass_scaled_fp4_mm, scaled_fp4_quant from sgl_kernel import cutlass_scaled_fp4_mm, scaled_fp4_quant
...@@ -402,15 +405,8 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): ...@@ -402,15 +405,8 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, topk_output: TopKOutput,
top_k: int, *,
renormalize: bool,
use_grouped_topk: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
num_fused_shared_experts: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
correction_bias: Optional[torch.Tensor] = None,
activation: str = "silu", activation: str = "silu",
apply_router_weight_on_input: bool = False, apply_router_weight_on_input: bool = False,
inplace: bool = True, inplace: bool = True,
...@@ -418,29 +414,12 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): ...@@ -418,29 +414,12 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
routed_scaling_factor: Optional[float] = None, routed_scaling_factor: Optional[float] = None,
) -> torch.Tensor: ) -> torch.Tensor:
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
from sglang.srt.layers.moe.topk import select_experts
# Expert selection
topk_weights, topk_ids = select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
num_fused_shared_experts=num_fused_shared_experts,
custom_routing_function=custom_routing_function,
correction_bias=correction_bias,
routed_scaling_factor=routed_scaling_factor,
)
return fused_experts( return fused_experts(
x, x,
layer.w13_weight, layer.w13_weight,
layer.w2_weight, layer.w2_weight,
topk_weights=topk_weights, topk_output=topk_output,
topk_ids=topk_ids,
inplace=inplace, inplace=inplace,
activation=activation, activation=activation,
use_fp8_w8a8=True, use_fp8_w8a8=True,
...@@ -961,15 +940,8 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase): ...@@ -961,15 +940,8 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, topk_output: TopKOutput,
top_k: int, *,
renormalize: bool,
use_grouped_topk: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
num_fused_shared_experts: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
correction_bias: Optional[torch.Tensor] = None,
activation: str = "silu", activation: str = "silu",
apply_router_weight_on_input: bool = False, apply_router_weight_on_input: bool = False,
inplace: bool = True, inplace: bool = True,
...@@ -982,21 +954,6 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase): ...@@ -982,21 +954,6 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
) -> torch.Tensor: ) -> torch.Tensor:
assert activation == "silu", "Only SiLU activation is supported." assert activation == "silu", "Only SiLU activation is supported."
from sglang.srt.layers.moe.topk import select_experts
topk_weights, topk_ids = select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
num_fused_shared_experts=num_fused_shared_experts,
custom_routing_function=custom_routing_function,
correction_bias=correction_bias,
routed_scaling_factor=routed_scaling_factor,
)
if self.enable_flashinfer_moe: if self.enable_flashinfer_moe:
assert ( assert (
...@@ -1004,6 +961,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase): ...@@ -1004,6 +961,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
), "apply_router_weight_on_input is not supported for Flashinfer" ), "apply_router_weight_on_input is not supported for Flashinfer"
# TRTLLM Cutlass moe takes in activations in BF16/Half/nvfp4 precision # TRTLLM Cutlass moe takes in activations in BF16/Half/nvfp4 precision
# and fp4 quantized weights loaded from the checkpoint # and fp4 quantized weights loaded from the checkpoint
topk_weights, topk_ids, _ = topk_output
output = flashinfer_cutlass_fused_moe( output = flashinfer_cutlass_fused_moe(
x, x,
topk_ids.to(torch.int), topk_ids.to(torch.int),
...@@ -1029,6 +987,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase): ...@@ -1029,6 +987,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
from sglang.srt.layers.moe.cutlass_moe import cutlass_moe_fp4 from sglang.srt.layers.moe.cutlass_moe import cutlass_moe_fp4
topk_weights, topk_ids, _ = topk_output
return cutlass_moe_fp4( return cutlass_moe_fp4(
a=x, a=x,
a1_gscale=layer.w13_input_scale_quant, a1_gscale=layer.w13_input_scale_quant,
......
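The rewritten signatures place a bare `*` immediately after `topk_output`, so the long tail of tuning flags (`activation`, `inplace`, `no_combine`, ...) becomes keyword-only. A small, self-contained illustration of what that enforces; the function below is hypothetical, not part of SGLang.

def apply_demo(x, topk_output, *, activation="silu", inplace=True, no_combine=False):
    # Arguments after the bare * can only be passed by keyword, so call sites
    # cannot silently misalign the optional flags when one is added or removed.
    return activation, inplace, no_combine

apply_demo("x", "topk", activation="gelu")   # OK
# apply_demo("x", "topk", "gelu")            # would raise TypeError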
...@@ -2,8 +2,9 @@ ...@@ -2,8 +2,9 @@
from __future__ import annotations from __future__ import annotations
import logging import logging
from typing import Any, Callable, Dict, List, Optional from typing import TYPE_CHECKING, Any, Dict, List, Optional
import numpy as np
import torch import torch
from sglang.srt.distributed import get_tensor_model_parallel_rank from sglang.srt.distributed import get_tensor_model_parallel_rank
...@@ -20,6 +21,9 @@ from sglang.srt.utils import get_device_capability, set_weight_attrs ...@@ -20,6 +21,9 @@ from sglang.srt.utils import get_device_capability, set_weight_attrs
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from sglang.srt.layers.moe.topk import TopKOutput
def get_weight_perm(num_bits: int): def get_weight_perm(num_bits: int):
perm_list: List[int] = [] perm_list: List[int] = []
...@@ -348,15 +352,8 @@ class MoeWNA16Method(FusedMoEMethodBase): ...@@ -348,15 +352,8 @@ class MoeWNA16Method(FusedMoEMethodBase):
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, topk_output: TopKOutput,
top_k: int, *,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
num_fused_shared_experts: int = 0,
custom_routing_function: Optional[Callable] = None,
correction_bias: Optional[torch.Tensor] = None,
activation: str = "silu", activation: str = "silu",
apply_router_weight_on_input: bool = False, apply_router_weight_on_input: bool = False,
inplace: bool = True, inplace: bool = True,
...@@ -365,22 +362,8 @@ class MoeWNA16Method(FusedMoEMethodBase): ...@@ -365,22 +362,8 @@ class MoeWNA16Method(FusedMoEMethodBase):
) -> torch.Tensor: ) -> torch.Tensor:
# avoid circular import # avoid circular import
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
from sglang.srt.layers.moe.topk import select_experts
assert activation == "silu", "Only SiLU activation is supported." assert activation == "silu", "Only SiLU activation is supported."
topk_weights, topk_ids = select_experts(
hidden_states=x,
router_logits=router_logits,
top_k=top_k,
use_grouped_topk=use_grouped_topk,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
num_fused_shared_experts=num_fused_shared_experts,
custom_routing_function=custom_routing_function,
correction_bias=correction_bias,
routed_scaling_factor=routed_scaling_factor,
)
weight_bits = self.quant_config.weight_bits weight_bits = self.quant_config.weight_bits
has_zp = self.quant_config.has_zp has_zp = self.quant_config.has_zp
...@@ -389,8 +372,7 @@ class MoeWNA16Method(FusedMoEMethodBase): ...@@ -389,8 +372,7 @@ class MoeWNA16Method(FusedMoEMethodBase):
x, x,
layer.w13_qweight, layer.w13_qweight,
layer.w2_qweight, layer.w2_qweight,
topk_weights=topk_weights, topk_output=topk_output,
topk_ids=topk_ids,
inplace=inplace, inplace=inplace,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=apply_router_weight_on_input,
use_int4_w4a16=weight_bits == 4, use_int4_w4a16=weight_bits == 4,
......
from __future__ import annotations
import importlib import importlib
from typing import Callable, List, Optional from typing import TYPE_CHECKING, Callable, List, Optional
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
...@@ -21,6 +23,9 @@ from sglang.srt.utils import ( ...@@ -21,6 +23,9 @@ from sglang.srt.utils import (
use_intel_amx_backend, use_intel_amx_backend,
) )
if TYPE_CHECKING:
from sglang.srt.layers.moe.topk import TopKOutput
has_triton_kernels = importlib.util.find_spec("triton_kernels") is not None has_triton_kernels = importlib.util.find_spec("triton_kernels") is not None
...@@ -125,25 +130,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -125,25 +130,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
super().__init__() super().__init__()
self.use_triton_kernels = use_triton_kernels self.use_triton_kernels = use_triton_kernels
from sglang.srt.layers.moe.fused_moe_native import moe_forward_native
if torch.cuda.is_available():
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
if has_triton_kernels:
from sglang.srt.layers.moe.fused_moe_triton.triton_kernels_moe import (
triton_kernel_moe_forward,
)
else:
triton_kernel_moe_forward = None
else:
fused_experts = None # type: ignore
triton_kernel_moe_forward = None
self.moe_forward_native = moe_forward_native
self.fused_experts = fused_experts
self.triton_kernel_moe_forward = triton_kernel_moe_forward
def create_weights( def create_weights(
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
...@@ -201,34 +187,18 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -201,34 +187,18 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, topk_output: TopKOutput,
top_k: int, *,
renormalize: bool,
use_grouped_topk: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
num_fused_shared_experts: int = 0,
custom_routing_function: Optional[Callable] = None,
correction_bias: Optional[torch.Tensor] = None,
activation: str = "silu", activation: str = "silu",
apply_router_weight_on_input: bool = False, apply_router_weight_on_input: bool = False,
inplace: bool = True, inplace: bool = True,
no_combine: bool = False, no_combine: bool = False,
routed_scaling_factor: Optional[float] = None, routed_scaling_factor: Optional[float] = None,
) -> torch.Tensor: ) -> torch.Tensor:
return self.forward( return self.forward(
x=x, x=x,
layer=layer, layer=layer,
router_logits=router_logits, topk_output=topk_output,
top_k=top_k,
renormalize=renormalize,
use_grouped_topk=use_grouped_topk,
topk_group=topk_group,
num_expert_group=num_expert_group,
num_fused_shared_experts=num_fused_shared_experts,
custom_routing_function=custom_routing_function,
correction_bias=correction_bias,
activation=activation, activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=apply_router_weight_on_input,
inplace=inplace, inplace=inplace,
...@@ -240,15 +210,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -240,15 +210,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
x: torch.Tensor, x: torch.Tensor,
use_grouped_topk: bool, topk_output: TopKOutput,
top_k: int, *,
router_logits: torch.Tensor,
renormalize: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
num_fused_shared_experts: int = 0,
custom_routing_function: Optional[Callable] = None,
correction_bias: Optional[torch.Tensor] = None,
activation: str = "silu", activation: str = "silu",
apply_router_weight_on_input: bool = False, apply_router_weight_on_input: bool = False,
inplace: bool = True, inplace: bool = True,
...@@ -257,33 +220,20 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -257,33 +220,20 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
) -> torch.Tensor: ) -> torch.Tensor:
if self.use_triton_kernels: if self.use_triton_kernels:
return self.triton_kernel_moe_forward( # TODO(ch-wan): re-enable the Triton kernel
hidden_states=x, raise NotImplementedError("The Triton kernel is temporarily disabled.")
w1=layer.w13_weight, # return triton_kernel_moe_forward(
w2=layer.w2_weight, # hidden_states=x,
gating_output=router_logits, # w1=layer.w13_weight,
topk=top_k, # w2=layer.w2_weight,
renormalize=renormalize, # gating_output=router_logits,
) # topk=top_k,
# renormalize=renormalize,
# )
else: else:
from sglang.srt.layers.moe.topk import select_experts
topk_weights, topk_ids = select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
num_fused_shared_experts=num_fused_shared_experts,
custom_routing_function=custom_routing_function,
correction_bias=correction_bias,
routed_scaling_factor=routed_scaling_factor,
)
if _use_aiter: if _use_aiter:
assert not no_combine, "unsupported" assert not no_combine, "unsupported"
topk_weights, topk_ids, _ = topk_output
if apply_router_weight_on_input: if apply_router_weight_on_input:
assert ( assert (
topk_weights.dim() == 2 topk_weights.dim() == 2
...@@ -296,7 +246,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -296,7 +246,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
topk_weights = torch.ones_like( topk_weights = torch.ones_like(
topk_weights, dtype=torch.float32 topk_weights, dtype=torch.float32
) # topk_weights must be FP32 (float32) ) # topk_weights must be FP32 (float32)
return fused_moe( return fused_moe(
x, x,
layer.w13_weight, layer.w13_weight,
...@@ -310,12 +259,15 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -310,12 +259,15 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
), ),
) )
else: else:
return self.fused_experts( from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
fused_experts,
)
return fused_experts(
hidden_states=x, hidden_states=x,
w1=layer.w13_weight, w1=layer.w13_weight,
w2=layer.w2_weight, w2=layer.w2_weight,
topk_weights=topk_weights, topk_output=topk_output,
topk_ids=topk_ids,
inplace=inplace and not no_combine, inplace=inplace and not no_combine,
activation=activation, activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=apply_router_weight_on_input,
...@@ -327,15 +279,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -327,15 +279,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
x: torch.Tensor, x: torch.Tensor,
use_grouped_topk: bool, topk_output: TopKOutput,
top_k: int, *,
router_logits: torch.Tensor,
renormalize: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
num_fused_shared_experts: int = 0,
custom_routing_function: Optional[Callable] = None,
correction_bias: Optional[torch.Tensor] = None,
activation: str = "silu", activation: str = "silu",
apply_router_weight_on_input: bool = False, apply_router_weight_on_input: bool = False,
inplace: bool = True, inplace: bool = True,
...@@ -344,30 +289,13 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -344,30 +289,13 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
) -> torch.Tensor: ) -> torch.Tensor:
assert activation == "silu", f"activation = {activation} is not supported." assert activation == "silu", f"activation = {activation} is not supported."
if use_intel_amx_backend(layer): if use_intel_amx_backend(layer) and not apply_router_weight_on_input:
from sglang.srt.layers.moe.topk import apply_topk_weights_cpu
from sglang.srt.layers.moe.topk import (
apply_topk_weights_cpu,
select_experts,
)
topk_weights, topk_ids = select_experts( topk_weights, topk_ids, _ = topk_output
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
num_fused_shared_experts=num_fused_shared_experts,
custom_routing_function=custom_routing_function,
correction_bias=correction_bias,
routed_scaling_factor=routed_scaling_factor,
)
x, topk_weights = apply_topk_weights_cpu( x, topk_weights = apply_topk_weights_cpu(
apply_router_weight_on_input, topk_weights, x apply_router_weight_on_input, topk_weights, x
) )
return torch.ops.sgl_kernel.fused_experts_cpu( return torch.ops.sgl_kernel.fused_experts_cpu(
x, x,
layer.w13_weight, layer.w13_weight,
...@@ -385,61 +313,42 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -385,61 +313,42 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
True, # is_vnni True, # is_vnni
) )
else: else:
return self.moe_forward_native( from sglang.srt.layers.moe.fused_moe_native import moe_forward_native
return moe_forward_native(
layer, layer,
x, x,
use_grouped_topk, topk_output,
top_k, activation=activation,
router_logits, apply_router_weight_on_input=apply_router_weight_on_input,
renormalize, inplace=inplace,
topk_group, no_combine=no_combine,
num_expert_group, routed_scaling_factor=routed_scaling_factor,
num_fused_shared_experts,
custom_routing_function,
correction_bias,
activation,
apply_router_weight_on_input,
inplace,
no_combine,
routed_scaling_factor,
) )
def forward_npu( def forward_npu(
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
x: torch.Tensor, x: torch.Tensor,
use_grouped_topk: bool, topk_output: TopKOutput,
top_k: int, *,
router_logits: torch.Tensor,
renormalize: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
num_fused_shared_experts: int = 0,
custom_routing_function: Optional[Callable] = None,
correction_bias: Optional[torch.Tensor] = None,
activation: str = "silu", activation: str = "silu",
apply_router_weight_on_input: bool = False, apply_router_weight_on_input: bool = False,
inplace: bool = True, inplace: bool = True,
no_combine: bool = False, no_combine: bool = False,
routed_scaling_factor: Optional[float] = None, routed_scaling_factor: Optional[float] = None,
) -> torch.Tensor: ) -> torch.Tensor:
return self.moe_forward_native( from sglang.srt.layers.moe.fused_moe_native import moe_forward_native
return moe_forward_native(
layer, layer,
x, x,
use_grouped_topk, topk_output,
top_k, activation=activation,
router_logits, apply_router_weight_on_input=apply_router_weight_on_input,
renormalize, inplace=inplace,
topk_group, no_combine=no_combine,
num_expert_group, routed_scaling_factor=routed_scaling_factor,
num_fused_shared_experts,
custom_routing_function,
correction_bias,
activation,
apply_router_weight_on_input,
inplace,
no_combine,
routed_scaling_factor,
) )
def forward_tpu(self, *args, **kwargs) -> torch.Tensor: def forward_tpu(self, *args, **kwargs) -> torch.Tensor:
...@@ -508,13 +417,7 @@ class UnquantizedEPMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -508,13 +417,7 @@ class UnquantizedEPMoEMethod(FusedMoEMethodBase, CustomOp):
def apply( def apply(
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
x: torch.Tensor, hidden_states: torch.Tensor,
router_logits: torch.Tensor, topk_output: TopKOutput,
top_k: int,
renormalize: bool,
use_grouped_topk: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
) -> torch.Tensor: ) -> torch.Tensor:
raise NotImplementedError raise NotImplementedError
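In `UnquantizedFusedMoEMethod`, the constructor no longer caches `fused_experts` / `moe_forward_native` / the Triton kernel entry point; each `forward_*` now imports its backend at call time, so importing the layer does not pull in CUDA-only modules. A minimal sketch of that lazy-import style, with a standard-library module standing in for the real backend so the snippet runs anywhere.

class MoEForwardDemo:
    # Hypothetical stand-in; in SGLang the CUDA path imports
    # fused_moe_triton.fused_moe.fused_experts inside forward_cuda instead.
    def forward_cpu(self, values):
        import statistics  # deferred: the import cost is paid only when this path runs
        return statistics.fmean(values)

print(MoEForwardDemo().forward_cpu([1.0, 2.0, 3.0]))  # 2.0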
from __future__ import annotations from __future__ import annotations
from typing import Any, Callable, Dict, List, Optional from typing import TYPE_CHECKING, Any, Dict, List, Optional
import torch import torch
from torch.nn.parameter import Parameter from torch.nn.parameter import Parameter
...@@ -25,6 +25,9 @@ from sglang.srt.layers.quantization.fp8_utils import ( ...@@ -25,6 +25,9 @@ from sglang.srt.layers.quantization.fp8_utils import (
) )
from sglang.srt.utils import set_weight_attrs from sglang.srt.utils import set_weight_attrs
if TYPE_CHECKING:
from sglang.srt.layers.moe.topk import TopKOutput
_is_fp8_fnuz = is_fp8_fnuz() _is_fp8_fnuz = is_fp8_fnuz()
...@@ -266,45 +269,23 @@ class W8A8FP8MoEMethod(FusedMoEMethodBase): ...@@ -266,45 +269,23 @@ class W8A8FP8MoEMethod(FusedMoEMethodBase):
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, topk_output: TopKOutput,
top_k: int, *,
renormalize: bool,
use_grouped_topk: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
num_fused_shared_experts: int = 0,
custom_routing_function: Optional[Callable] = None,
correction_bias: Optional[torch.Tensor] = None,
activation: str = "silu", activation: str = "silu",
apply_router_weight_on_input: bool = False,
inplace: bool = True, inplace: bool = True,
no_combine: bool = False, no_combine: bool = False,
routed_scaling_factor: Optional[float] = None, routed_scaling_factor: Optional[float] = None,
) -> torch.Tensor: ) -> torch.Tensor:
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
from sglang.srt.layers.moe.topk import select_experts
# Expert selection
topk_weights, topk_ids = select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
num_fused_shared_experts=num_fused_shared_experts,
custom_routing_function=custom_routing_function,
correction_bias=correction_bias,
routed_scaling_factor=routed_scaling_factor,
)
return fused_experts( return fused_experts(
x, x,
layer.w13_weight, layer.w13_weight,
layer.w2_weight, layer.w2_weight,
topk_weights=topk_weights, topk_output=topk_output,
topk_ids=topk_ids,
inplace=inplace, inplace=inplace,
apply_router_weight_on_input=apply_router_weight_on_input,
activation=activation, activation=activation,
use_fp8_w8a8=True, use_fp8_w8a8=True,
per_channel_quant=True, per_channel_quant=True,
......
...@@ -3,7 +3,7 @@ from __future__ import annotations ...@@ -3,7 +3,7 @@ from __future__ import annotations
import importlib import importlib
import sys import sys
from types import MappingProxyType from types import MappingProxyType
from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Union, cast from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple, Union, cast
import torch import torch
from torch.nn.parameter import Parameter from torch.nn.parameter import Parameter
...@@ -37,6 +37,9 @@ from sglang.srt.utils import ( ...@@ -37,6 +37,9 @@ from sglang.srt.utils import (
use_intel_amx_backend, use_intel_amx_backend,
) )
if TYPE_CHECKING:
from sglang.srt.layers.moe.topk import TopKOutput
_is_cuda = is_cuda() _is_cuda = is_cuda()
_is_cpu_amx_available = cpu_has_amx_support() _is_cpu_amx_available = cpu_has_amx_support()
_is_cpu = is_cpu() _is_cpu = is_cpu()
...@@ -239,7 +242,7 @@ class W8A8Int8Config(QuantizationConfig): ...@@ -239,7 +242,7 @@ class W8A8Int8Config(QuantizationConfig):
layer: torch.nn.Module, layer: torch.nn.Module,
prefix: str, prefix: str,
) -> Optional[QuantizeMethodBase]: ) -> Optional[QuantizeMethodBase]:
from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod from sglang.srt.layers.linear import LinearBase
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
if _is_npu: if _is_npu:
...@@ -469,15 +472,8 @@ class W8A8Int8MoEMethod(FusedMoEMethodBase): ...@@ -469,15 +472,8 @@ class W8A8Int8MoEMethod(FusedMoEMethodBase):
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, topk_output: TopKOutput,
top_k: int, *,
renormalize: bool,
use_grouped_topk: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
num_fused_shared_experts: int = 0,
custom_routing_function: Optional[Callable] = None,
correction_bias: Optional[torch.Tensor] = None,
activation: str = "silu", activation: str = "silu",
apply_router_weight_on_input: bool = False, apply_router_weight_on_input: bool = False,
inplace: bool = True, inplace: bool = True,
...@@ -485,26 +481,11 @@ class W8A8Int8MoEMethod(FusedMoEMethodBase): ...@@ -485,26 +481,11 @@ class W8A8Int8MoEMethod(FusedMoEMethodBase):
routed_scaling_factor: Optional[float] = None, routed_scaling_factor: Optional[float] = None,
) -> torch.Tensor: ) -> torch.Tensor:
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
from sglang.srt.layers.moe.topk import select_experts
# Expert selection
topk_weights, topk_ids = select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
num_fused_shared_experts=num_fused_shared_experts,
custom_routing_function=custom_routing_function,
correction_bias=correction_bias,
routed_scaling_factor=routed_scaling_factor,
)
if use_intel_amx_backend(layer): if use_intel_amx_backend(layer):
from sglang.srt.layers.moe.topk import apply_topk_weights_cpu from sglang.srt.layers.moe.topk import apply_topk_weights_cpu
topk_weights, topk_ids, _ = topk_output
x, topk_weights = apply_topk_weights_cpu( x, topk_weights = apply_topk_weights_cpu(
apply_router_weight_on_input, topk_weights, x apply_router_weight_on_input, topk_weights, x
) )
...@@ -529,8 +510,7 @@ class W8A8Int8MoEMethod(FusedMoEMethodBase): ...@@ -529,8 +510,7 @@ class W8A8Int8MoEMethod(FusedMoEMethodBase):
x, x,
layer.w13_weight, layer.w13_weight,
layer.w2_weight, layer.w2_weight,
topk_weights=topk_weights, topk_output=topk_output,
topk_ids=topk_ids,
inplace=inplace, inplace=inplace,
activation=activation, activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=apply_router_weight_on_input,
...@@ -907,7 +887,7 @@ class NPU_W8A8MoEMethod(FusedMoEMethodBase): ...@@ -907,7 +887,7 @@ class NPU_W8A8MoEMethod(FusedMoEMethodBase):
layer: torch.nn.Module, layer: torch.nn.Module,
num_experts: int, num_experts: int,
hidden_size: int, hidden_size: int,
intermediate_size: List[int], intermediate_size: int,
params_dtype: torch.dtype, params_dtype: torch.dtype,
**extra_weight_attrs, **extra_weight_attrs,
) -> None: ) -> None:
...@@ -984,52 +964,11 @@ class NPU_W8A8MoEMethod(FusedMoEMethodBase): ...@@ -984,52 +964,11 @@ class NPU_W8A8MoEMethod(FusedMoEMethodBase):
self, self,
layer, layer,
x, x,
router_logits, topk_output: TopKOutput,
top_k,
renormalize,
use_grouped_topk,
topk_group,
num_expert_group,
num_fused_shared_experts,
custom_routing_function,
correction_bias,
activation,
apply_router_weight_on_input,
routed_scaling_factor,
**kwargs, **kwargs,
) -> torch.Tensor: ) -> torch.Tensor:
from sglang.srt.layers.moe.topk import select_experts
topk_weights, topk_ids, _ = topk_output
global_num_experts = router_logits.shape[-1]
# NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
if global_num_experts == 256:
topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
router_logits,
k=top_k,
bias=correction_bias,
k_group=topk_group,
group_count=num_expert_group,
group_select_mode=1,
renorm=0,
norm_type=1,
routed_scaling_factor=1,
eps=float(1e-20),
)
else:
topk_weights, topk_ids = select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
num_fused_shared_experts=num_fused_shared_experts,
custom_routing_function=custom_routing_function,
correction_bias=correction_bias,
torch_native=True,
routed_scaling_factor=routed_scaling_factor,
)
topk_ids = topk_ids.to(torch.int32) topk_ids = topk_ids.to(torch.int32)
topk_weights = topk_weights.to(x.dtype) topk_weights = topk_weights.to(x.dtype)
return npu_fused_experts( return npu_fused_experts(
...@@ -1040,5 +979,5 @@ class NPU_W8A8MoEMethod(FusedMoEMethodBase): ...@@ -1040,5 +979,5 @@ class NPU_W8A8MoEMethod(FusedMoEMethodBase):
w2_scale=layer.w2_weight_scale, w2_scale=layer.w2_weight_scale,
topk_weights=topk_weights, topk_weights=topk_weights,
topk_ids=topk_ids, topk_ids=topk_ids,
top_k=top_k, top_k=topk_ids.shape[1],
) )
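On the NPU path the routing tuple fully replaces the old argument list: `top_k` is recovered from the tensor shape and dtypes are normalized before calling `npu_fused_experts`. A small sketch of just that unpacking step; the shapes are illustrative only.

import torch

topk_weights = torch.rand(4, 2)            # (num_tokens, top_k)
topk_ids = torch.randint(0, 8, (4, 2))     # expert indices
x = torch.randn(4, 16, dtype=torch.float16)

topk_ids = topk_ids.to(torch.int32)
topk_weights = topk_weights.to(x.dtype)
top_k = topk_ids.shape[1]                  # 2: no separate top_k argument needed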
...@@ -37,6 +37,7 @@ from sglang.srt.layers.linear import ( ...@@ -37,6 +37,7 @@ from sglang.srt.layers.linear import (
) )
from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe.fused_moe_triton import fused_moe from sglang.srt.layers.moe.fused_moe_triton import fused_moe
from sglang.srt.layers.moe.topk import TopK
from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.rotary_embedding import get_rope from sglang.srt.layers.rotary_embedding import get_rope
...@@ -109,7 +110,10 @@ class DeepseekMoE(nn.Module): ...@@ -109,7 +110,10 @@ class DeepseekMoE(nn.Module):
f"Tensor parallel size {self.tp_size} is greater than " f"Tensor parallel size {self.tp_size} is greater than "
f"the number of experts {self.n_routed_experts}." f"the number of experts {self.n_routed_experts}."
) )
self.topk = TopK(
top_k=self.top_k,
renormalize=config.norm_topk_prob,
)
self.experts = nn.ModuleList( self.experts = nn.ModuleList(
[ [
DeepseekMLP( DeepseekMLP(
...@@ -170,13 +174,12 @@ class DeepseekMoE(nn.Module): ...@@ -170,13 +174,12 @@ class DeepseekMoE(nn.Module):
shared_output = self.shared_experts(hidden_states) shared_output = self.shared_experts(hidden_states)
# router_logits: (num_tokens, n_experts) # router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states) router_logits, _ = self.gate(hidden_states)
topk_output = self.topk(hidden_states, router_logits)
final_hidden_states = fused_moe.fused_moe( final_hidden_states = fused_moe.fused_moe(
hidden_states, hidden_states,
self.w1, w1=self.w1,
self.w2, w2=self.w2,
router_logits, topk_output=topk_output,
self.top_k,
renormalize=self.config.norm_topk_prob,
inplace=True, inplace=True,
) )
......
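On the model side, `DeepseekMoE` now owns a `TopK` module: the gate produces logits, `self.topk(hidden_states, router_logits)` produces the routing tuple, and `fused_moe` receives it whole. A runnable approximation of that flow, with a plain softmax + topk standing in for the real `TopK` module and the fused kernel call left as a comment.

import torch
from torch import nn

class MoEGateDemo(nn.Module):
    def __init__(self, hidden=16, n_experts=8, top_k=2):
        super().__init__()
        self.gate = nn.Linear(hidden, n_experts, bias=False)
        self.top_k = top_k

    def forward(self, hidden_states):
        router_logits = self.gate(hidden_states)              # (tokens, n_experts)
        # Real code: topk_output = self.topk(hidden_states, router_logits)
        probs = torch.softmax(router_logits, dim=-1)
        topk_weights, topk_ids = torch.topk(probs, self.top_k, dim=-1)
        topk_output = (topk_weights, topk_ids, router_logits)
        # Real code then calls:
        # fused_moe.fused_moe(hidden_states, w1=self.w1, w2=self.w2,
        #                     topk_output=topk_output, inplace=True)
        return topk_output

weights, ids, logits = MoEGateDemo()(torch.randn(4, 16))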