Unverified Commit 15ad6c90 authored by Cheng Wan, committed by GitHub

[1/N] MoE Refactor: refactor `select_experts` (#7966)

parent cfab0ff6
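
At a high level, this refactor moves expert selection out of the MoE layers and quantization backends into a standalone `TopK` op, and everything downstream now consumes its `TopKOutput` named tuple. A minimal sketch of the resulting call pattern, assuming an initialized sglang runtime (tensor-parallel groups, CUDA build); the sizes and tensors below are illustrative, not taken from this diff:

import torch
from sglang.srt.layers.moe.topk import TopK
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE

# Routing hyperparameters (renormalize, grouped top-k, correction bias, ...) now live on TopK.
topk = TopK(top_k=2, renormalize=True)
moe = FusedMoE(num_experts=8, top_k=2, hidden_size=4096,
               intermediate_size=14336, layer_id=0)

hidden_states = torch.randn(16, 4096, dtype=torch.bfloat16, device="cuda")
router_logits = torch.randn(16, 8, dtype=torch.float32, device="cuda")

topk_output = topk(hidden_states, router_logits)   # TopKOutput(topk_weights, topk_ids, router_logits)
out = moe(hidden_states, topk_output)               # FusedMoE.forward now takes the TopKOutput directly
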
......@@ -29,15 +29,18 @@ class CustomOp(nn.Module):
self._original_forward_method = self._forward_method
# NOTE: Temporarily workaround MoE
# The performance of torch.compile on this layer is not always good when bs > 1,
# so we decide to only use torch.compile when bs=1
if "FusedMoE" in self.__class__.__name__:
if num_tokens == 1:
from sglang.srt.layers.moe.fused_moe_native import (
fused_moe_forward_native,
)
# The performance of torch.compile on this layer is not always good when bs > 1,
# so we decide to only use torch.compile when bs=1
self._forward_method = fused_moe_forward_native
elif "TopK" in self.__class__.__name__:
if num_tokens == 1:
self._forward_method = self.forward_native
else:
self._forward_method = self.forward_native
self.is_torch_compile = True
......
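
The hunk above keeps the existing rule that the compile-friendly native forward is only swapped in for `FusedMoE` (and now `TopK`) when the batch holds a single token, since torch.compile is not consistently faster for these layers at bs > 1. A toy sketch of that method-swapping pattern; `ToyOp` is a hypothetical stand-in for the real `CustomOp`:

import torch


class ToyOp(torch.nn.Module):
    """Hypothetical stand-in for CustomOp's dispatch mechanism."""

    def __init__(self):
        super().__init__()
        self._forward_method = self.forward_cuda   # default backend

    def enter_torch_compile(self, num_tokens: int):
        # Mirror the rule above: only use the native path when bs == 1.
        if num_tokens == 1:
            self._forward_method = self.forward_native

    def forward(self, x):
        return self._forward_method(x)

    def forward_native(self, x):
        return x + 1    # placeholder computation

    def forward_cuda(self, x):
        return x + 1    # placeholder computation
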
......@@ -756,7 +756,7 @@ class QKVParallelLinear(ColumnParallelLinear):
bias: bool = True,
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional["QuantizationConfig"] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
tp_rank: Optional[int] = None,
tp_size: Optional[int] = None,
......
import logging
from typing import Callable, List, Optional, Tuple
from typing import List, Optional, Tuple
import einops
import torch
from torch.nn import Module
from sglang.srt.custom_op import CustomOp
from sglang.srt.distributed import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
from sglang.srt.eplb.expert_location import get_global_expert_location_metadata
from sglang.srt.eplb.expert_location_dispatch import ExpertLocationDispatchInfo
from sglang.srt.layers.moe.ep_moe.kernels import (
ep_gather,
ep_scatter,
......@@ -28,7 +24,7 @@ from sglang.srt.layers.moe.ep_moe.kernels import (
tma_align_input_scale,
)
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
from sglang.srt.layers.moe.topk import select_experts
from sglang.srt.layers.moe.topk import TopKOutput
from sglang.srt.layers.quantization import deep_gemm_wrapper
from sglang.srt.layers.quantization.base_config import (
QuantizationConfig,
......@@ -162,16 +158,9 @@ class EPMoE(torch.nn.Module):
intermediate_size: int,
layer_id: int,
params_dtype: Optional[torch.dtype] = None,
renormalize: bool = True,
use_grouped_topk: bool = False,
num_expert_group: Optional[int] = None,
num_fused_shared_experts: int = 0,
topk_group: Optional[int] = None,
quant_config: Optional[QuantizationConfig] = None,
tp_size: Optional[int] = None,
prefix: str = "",
correction_bias: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None,
activation: str = "silu",
routed_scaling_factor: Optional[float] = None,
use_per_token_if_dynamic: bool = True,
......@@ -189,24 +178,12 @@ class EPMoE(torch.nn.Module):
self.layer_id = layer_id
self.num_experts = num_experts
assert self.num_experts % self.tp_size == 0
assert (
num_fused_shared_experts == 0
), "num_fused_shared_experts is not supported in EP"
self.num_fused_shared_experts = num_fused_shared_experts
self.num_experts_per_partition, self.expert_map = self.determine_expert_map()
self.start_expert_id = self.tp_rank * self.num_experts_per_partition
self.end_expert_id = self.start_expert_id + self.num_experts_per_partition - 1
self.top_k = top_k
self.intermediate_size = intermediate_size
self.renormalize = renormalize
self.use_grouped_topk = use_grouped_topk
if self.use_grouped_topk:
assert num_expert_group is not None and topk_group is not None
self.num_expert_group = num_expert_group
self.topk_group = topk_group
self.correction_bias = correction_bias
self.custom_routing_function = custom_routing_function
self.activation = activation
self.routed_scaling_factor = routed_scaling_factor
self.use_per_token_if_dynamic = use_per_token_if_dynamic
......@@ -311,33 +288,24 @@ class EPMoE(torch.nn.Module):
)
return (local_num_experts, expert_map)
def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8:
return self.forward_deepgemm(hidden_states, router_logits)
return self.forward_deepgemm(hidden_states, topk_output)
else:
return self.forward_normal(hidden_states, router_logits)
return self.forward_normal(hidden_states, topk_output)
def forward_deepgemm(
self, hidden_states: torch.Tensor, router_logits: torch.Tensor
self,
hidden_states: torch.Tensor,
topk_output: TopKOutput,
):
assert self.quant_method is not None
assert self.activation == "silu"
hidden_states_shape = hidden_states.shape
hidden_states_dtype = hidden_states.dtype
hidden_states_device = hidden_states.device
topk_weights, topk_ids = select_experts(
hidden_states=hidden_states,
router_logits=router_logits,
top_k=self.top_k,
use_grouped_topk=self.use_grouped_topk,
renormalize=self.renormalize,
topk_group=self.topk_group,
num_expert_group=self.num_expert_group,
num_fused_shared_experts=self.num_fused_shared_experts,
correction_bias=self.correction_bias,
custom_routing_function=self.custom_routing_function,
routed_scaling_factor=self.routed_scaling_factor,
)
topk_weights, topk_ids, _ = topk_output
if not self.use_block_quant:
# Convert per-tensor quant to per-block quant by repeating scales for forward_deepgemm
......@@ -469,8 +437,10 @@ class EPMoE(torch.nn.Module):
)
return output
def forward_normal(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
def forward_normal(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
assert self.quant_method is not None
topk_weights, topk_ids, _ = topk_output
hidden_states_shape = hidden_states.shape
hidden_states_dtype = hidden_states.dtype
hidden_states_device = hidden_states.device
......@@ -481,23 +451,6 @@ class EPMoE(torch.nn.Module):
use_per_token_if_dynamic=self.use_per_token_if_dynamic,
)
topk_weights, topk_ids = select_experts(
hidden_states=hidden_states,
router_logits=router_logits,
top_k=self.top_k,
use_grouped_topk=self.use_grouped_topk,
renormalize=self.renormalize,
topk_group=self.topk_group,
num_expert_group=self.num_expert_group,
num_fused_shared_experts=self.num_fused_shared_experts,
correction_bias=self.correction_bias,
custom_routing_function=self.custom_routing_function,
routed_scaling_factor=self.routed_scaling_factor,
expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new(
layer_id=self.layer_id,
),
)
if self.use_w4afp8:
local_topk_ids = topk_ids
if self.expert_map is not None:
......@@ -916,16 +869,9 @@ class DeepEPMoE(EPMoE):
intermediate_size: int,
layer_id: int,
params_dtype: Optional[torch.dtype] = None,
renormalize: bool = True,
use_grouped_topk: bool = False,
num_expert_group: Optional[int] = None,
num_fused_shared_experts: int = 0,
topk_group: Optional[int] = None,
quant_config: Optional[QuantizationConfig] = None,
tp_size: Optional[int] = None,
prefix: str = "",
correction_bias: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None,
activation: str = "silu",
routed_scaling_factor: Optional[float] = None,
deepep_mode: DeepEPMode = DeepEPMode.auto,
......@@ -937,16 +883,9 @@ class DeepEPMoE(EPMoE):
intermediate_size=intermediate_size,
layer_id=layer_id,
params_dtype=params_dtype,
renormalize=renormalize,
use_grouped_topk=use_grouped_topk,
num_expert_group=num_expert_group,
num_fused_shared_experts=num_fused_shared_experts,
topk_group=topk_group,
quant_config=quant_config,
tp_size=tp_size,
prefix=prefix,
correction_bias=correction_bias,
custom_routing_function=custom_routing_function,
activation=activation,
routed_scaling_factor=routed_scaling_factor,
)
......
......@@ -9,21 +9,14 @@ import torch
from torch.nn import functional as F
from sglang.srt.layers.activation import GeluAndMul, SiluAndMul
from sglang.srt.layers.moe.topk import select_experts
from sglang.srt.layers.moe.topk import TopKOutput
def fused_moe_forward_native(
layer: torch.nn.Module,
x: torch.Tensor,
use_grouped_topk: bool,
top_k: int,
router_logits: torch.Tensor,
renormalize: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
num_fused_shared_experts: int = 0,
custom_routing_function: Optional[Callable] = None,
correction_bias: Optional[torch.Tensor] = None,
topk_output: TopKOutput,
*,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
inplace: bool = True,
......@@ -34,20 +27,7 @@ def fused_moe_forward_native(
if apply_router_weight_on_input:
raise NotImplementedError()
topk_weights, topk_ids = select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
num_fused_shared_experts=num_fused_shared_experts,
custom_routing_function=custom_routing_function,
correction_bias=correction_bias,
routed_scaling_factor=routed_scaling_factor,
torch_native=True,
)
topk_weights, topk_ids, _ = topk_output
w13_weights = layer.w13_weight[topk_ids]
w1_weights, w3_weights = torch.chunk(w13_weights, 2, dim=2)
......@@ -67,15 +47,8 @@ def fused_moe_forward_native(
def moe_forward_native(
layer: torch.nn.Module,
x: torch.Tensor,
use_grouped_topk: bool,
top_k: int,
router_logits: torch.Tensor,
renormalize: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
num_fused_shared_experts: int = 0,
custom_routing_function: Optional[Callable] = None,
correction_bias: Optional[torch.Tensor] = None,
topk_output: TopKOutput,
*,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
inplace: bool = True,
......@@ -86,20 +59,7 @@ def moe_forward_native(
if apply_router_weight_on_input:
raise NotImplementedError()
topk_weights, topk_ids = select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
num_fused_shared_experts=num_fused_shared_experts,
custom_routing_function=custom_routing_function,
correction_bias=correction_bias,
torch_native=True,
routed_scaling_factor=routed_scaling_factor,
)
topk_weights, topk_ids, _ = topk_output
# Ref code from https://huggingface.co/deepseek-ai/DeepSeek-V2/blob/e0828e3cc0a03408724b80c3cc92c8e072db8d01/modeling_deepseek.py#L589
len_experts = layer.num_experts
......
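
Both native helpers above now unpack the caller-provided `TopKOutput` and gather per-token expert weights by indexing with the top-k ids, as in the surviving body of `fused_moe_forward_native`. A self-contained illustration of that gather; the shapes are made up for the example:

import torch

E, H, I, T, K = 4, 8, 16, 3, 2                    # hypothetical: experts, hidden, intermediate, tokens, top-k
w13_weight = torch.randn(E, 2 * I, H)             # stacked gate/up projections, one slab per expert
topk_ids = torch.randint(0, E, (T, K))
topk_weights = torch.softmax(torch.randn(T, K), dim=-1)

w13_weights = w13_weight[topk_ids]                # (T, K, 2*I, H): per-token expert weights, as in the diff
w1_weights, w3_weights = torch.chunk(w13_weights, 2, dim=2)
print(w1_weights.shape, w3_weights.shape)         # both torch.Size([3, 2, 16, 8])
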
......@@ -6,13 +6,13 @@ import functools
import json
import logging
import os
from typing import Any, Callable, Dict, List, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple
import torch
import triton
import triton.language as tl
from sglang.srt.layers.moe.topk import select_experts
from sglang.srt.layers.moe.topk import TopKOutput
from sglang.srt.layers.quantization.fp8_kernel import (
per_token_group_quant_fp8,
scaled_fp8_quant,
......@@ -1328,8 +1328,7 @@ def fused_experts(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
topk_output: TopKOutput,
inplace: bool = False,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
......@@ -1348,7 +1347,7 @@ def fused_experts(
no_combine: bool = False,
routed_scaling_factor: Optional[float] = None,
):
topk_weights, topk_ids, _ = topk_output
if inplace:
assert not no_combine, "no combine + inplace makes no sense"
torch.ops.sglang.inplace_fused_experts(
......@@ -1732,17 +1731,10 @@ def fused_moe(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
gating_output: torch.Tensor,
topk: int,
renormalize: bool,
topk_output: TopKOutput,
inplace: bool = False,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
use_grouped_topk: bool = False,
num_expert_group: Optional[int] = None,
num_fused_shared_experts: int = 0,
topk_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False,
......@@ -1766,16 +1758,9 @@ def fused_moe(
- hidden_states (torch.Tensor): The input tensor to the MoE layer.
- w1 (torch.Tensor): The first set of expert weights.
- w2 (torch.Tensor): The second set of expert weights.
- gating_output (torch.Tensor): The output of the gating operation
(before softmax).
- topk (int): The number of top-k experts to select.
- renormalize (bool): If True, renormalize the top-k weights to sum to 1.
- topk_output (TopKOutput): The precomputed expert-selection result
(top-k weights, top-k expert ids, and router logits).
- inplace (bool): If True, perform the operation in-place.
Defaults to False.
- num_expert_group: Optional[int]: additional parameter for grouped_topk
- topk_group: Optional[int]: additional parameter for grouped_topk
- use_grouped_topk: If True, use grouped_topk instead of fused_topk
note: Deepseek V2/V3/R1 series models use grouped_topk
- use_fp8_w8a8 (bool): If True, use fp8 arithmetic to compute the inner
products for w1 and w2. Defaults to False.
- use_int8_w8a8 (bool): If True, use int8 arithmetic to compute the inner
......@@ -1799,28 +1784,12 @@ def fused_moe(
Returns:
- torch.Tensor: The output tensor after applying the MoE layer.
"""
# Check constraints.
assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"
topk_weights, topk_ids = select_experts(
hidden_states=hidden_states,
router_logits=gating_output,
use_grouped_topk=use_grouped_topk,
top_k=topk,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
num_fused_shared_experts=num_fused_shared_experts,
custom_routing_function=custom_routing_function,
routed_scaling_factor=routed_scaling_factor,
)
return fused_experts(
hidden_states,
w1,
w2,
topk_weights,
topk_ids,
topk_output,
inplace=inplace,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
......
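
As the updated `fused_moe` docstring notes, routing is no longer performed inside the Triton wrapper; callers pass a precomputed `TopKOutput`. A hedged sketch of the new call shape, assuming a CUDA sglang build; the tensor shapes, dtype, and device are illustrative assumptions:

import torch
from sglang.srt.layers.moe.topk import select_experts
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe

E, H, I, T, K = 8, 4096, 14336, 16, 2             # hypothetical sizes
hidden_states = torch.randn(T, H, dtype=torch.bfloat16, device="cuda")
router_logits = torch.randn(T, E, dtype=torch.float32, device="cuda")
w1 = torch.randn(E, 2 * I, H, dtype=torch.bfloat16, device="cuda")   # gate+up projections
w2 = torch.randn(E, H, I, dtype=torch.bfloat16, device="cuda")       # down projection

# Expert selection happens once, outside the kernel wrapper.
topk_output = select_experts(hidden_states, router_logits, top_k=K, renormalize=True)

out = fused_moe(hidden_states, w1, w2, topk_output, inplace=False, activation="silu")
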
......@@ -2,7 +2,7 @@
import logging
from enum import Enum
from typing import Callable, List, Optional, Tuple
from typing import List, Optional, Tuple
import torch
......@@ -11,6 +11,7 @@ from sglang.srt.distributed import (
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce,
)
from sglang.srt.layers.moe.topk import TopKOutput
from sglang.srt.layers.quantization.base_config import (
QuantizationConfig,
QuantizeMethodBase,
......@@ -59,22 +60,15 @@ class FusedMoE(torch.nn.Module):
def __init__(
self,
num_experts: int,
top_k: int,
hidden_size: int,
intermediate_size: int,
top_k: Optional[int] = None,
layer_id: Optional[int] = None,
params_dtype: Optional[torch.dtype] = None,
reduce_results: bool = False,
renormalize: bool = True,
use_grouped_topk: bool = False,
num_expert_group: Optional[int] = None,
num_fused_shared_experts: int = 0,
topk_group: Optional[int] = None,
quant_config: Optional[QuantizationConfig] = None,
tp_size: Optional[int] = None,
prefix: str = "",
custom_routing_function: Optional[Callable] = None,
correction_bias: Optional[torch.Tensor] = None,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
use_presharded_weights: bool = False,
......@@ -89,6 +83,7 @@ class FusedMoE(torch.nn.Module):
if params_dtype is None:
params_dtype = torch.get_default_dtype()
self.top_k = top_k
self.hidden_size = hidden_size
self.tp_size = (
tp_size if tp_size is not None else get_tensor_model_parallel_world_size()
......@@ -126,19 +121,9 @@ class FusedMoE(torch.nn.Module):
self.ep_rank = 0
self.local_num_experts = num_experts
self.routed_scaling_factor = routed_scaling_factor
self.top_k = top_k
assert intermediate_size % self.tp_size == 0
self.intermediate_size_per_partition = intermediate_size // self.tp_size
self.reduce_results = reduce_results
self.renormalize = renormalize
self.use_grouped_topk = use_grouped_topk
if self.use_grouped_topk:
assert num_expert_group is not None and topk_group is not None
self.num_expert_group = num_expert_group
self.num_fused_shared_experts = num_fused_shared_experts
self.topk_group = topk_group
self.custom_routing_function = custom_routing_function
self.correction_bias = correction_bias
self.activation = activation
self.apply_router_weight_on_input = apply_router_weight_on_input
self.use_presharded_weights = use_presharded_weights
......@@ -562,22 +547,14 @@ class FusedMoE(torch.nn.Module):
)
return
def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
assert self.quant_method is not None
# Matrix multiply.
final_hidden_states = self.quant_method.apply(
layer=self,
x=hidden_states,
router_logits=router_logits,
top_k=self.top_k,
renormalize=self.renormalize,
use_grouped_topk=self.use_grouped_topk,
topk_group=self.topk_group,
num_expert_group=self.num_expert_group,
num_fused_shared_experts=self.num_fused_shared_experts,
custom_routing_function=self.custom_routing_function,
correction_bias=self.correction_bias,
topk_output=topk_output,
activation=self.activation,
apply_router_weight_on_input=self.apply_router_weight_on_input,
routed_scaling_factor=self.routed_scaling_factor,
......
......@@ -12,12 +12,15 @@
# limitations under the License.
# ==============================================================================
from __future__ import annotations
import math
from typing import Callable, Optional
from typing import TYPE_CHECKING, Callable, NamedTuple, Optional
import torch
import torch.nn.functional as F
from sglang.srt.custom_op import CustomOp
from sglang.srt.eplb import expert_location_dispatch
from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
from sglang.srt.eplb.expert_location_dispatch import (
......@@ -52,6 +55,168 @@ if _use_aiter:
except ImportError:
raise ImportError("aiter is required when SGLANG_USE_AITER is set to True")
if _is_npu:
import torch_npu
class TopKOutput(NamedTuple):
topk_weights: torch.Tensor
topk_ids: torch.Tensor
router_logits: torch.Tensor
class TopK(CustomOp):
# TODO(ch-wan): support triton_kernels
def __init__(
self,
top_k: int,
*,
use_grouped_topk: bool = False,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
renormalize: bool = True,
num_fused_shared_experts: int = 0,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
correction_bias: Optional[torch.Tensor] = None,
routed_scaling_factor: Optional[float] = None,
):
# NOTE: scoring_func is not used for now, but we keep it for future use
# see https://github.com/sgl-project/sglang/pull/4505 for more details
super().__init__()
if use_grouped_topk:
assert num_expert_group is not None and topk_group is not None
self.top_k = top_k
self.use_grouped_topk = use_grouped_topk
self.renormalize = renormalize
self.topk_group = topk_group
self.num_expert_group = num_expert_group
self.num_fused_shared_experts = num_fused_shared_experts
self.custom_routing_function = custom_routing_function
self.correction_bias = correction_bias
self.routed_scaling_factor = routed_scaling_factor
def forward_native(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
*,
num_token_non_padded: Optional[torch.Tensor] = None,
expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
) -> TopKOutput:
torch_native = True
return select_experts(
hidden_states=hidden_states,
router_logits=router_logits,
top_k=self.top_k,
use_grouped_topk=self.use_grouped_topk,
renormalize=self.renormalize,
topk_group=self.topk_group,
num_expert_group=self.num_expert_group,
num_fused_shared_experts=self.num_fused_shared_experts,
custom_routing_function=self.custom_routing_function,
correction_bias=self.correction_bias,
torch_native=torch_native,
routed_scaling_factor=self.routed_scaling_factor,
num_token_non_padded=num_token_non_padded,
expert_location_dispatch_info=expert_location_dispatch_info,
)
def forward_cuda(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
*,
num_token_non_padded: Optional[torch.Tensor] = None,
expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
) -> TopKOutput:
torch_native = False
return select_experts(
hidden_states=hidden_states,
router_logits=router_logits,
top_k=self.top_k,
use_grouped_topk=self.use_grouped_topk,
renormalize=self.renormalize,
topk_group=self.topk_group,
num_expert_group=self.num_expert_group,
num_fused_shared_experts=self.num_fused_shared_experts,
custom_routing_function=self.custom_routing_function,
correction_bias=self.correction_bias,
torch_native=torch_native,
routed_scaling_factor=self.routed_scaling_factor,
num_token_non_padded=num_token_non_padded,
expert_location_dispatch_info=expert_location_dispatch_info,
)
def forward_cpu(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
*,
num_token_non_padded: Optional[torch.Tensor] = None,
expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
) -> TopKOutput:
return select_experts(
hidden_states=hidden_states,
router_logits=router_logits,
top_k=self.top_k,
use_grouped_topk=self.use_grouped_topk,
renormalize=self.renormalize,
topk_group=self.topk_group,
num_expert_group=self.num_expert_group,
num_fused_shared_experts=self.num_fused_shared_experts,
custom_routing_function=self.custom_routing_function,
correction_bias=self.correction_bias,
routed_scaling_factor=self.routed_scaling_factor,
num_token_non_padded=num_token_non_padded,
expert_location_dispatch_info=expert_location_dispatch_info,
)
def forward_npu(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
*,
num_token_non_padded: Optional[torch.Tensor] = None,
expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
) -> TopKOutput:
global_num_experts = router_logits.shape[-1]
# NOTE: npu_moe_gating_top_k currently only supports the `group_count=256` pattern
if global_num_experts == 256:
return torch_npu.npu_moe_gating_top_k(
router_logits,
k=self.top_k,
bias=self.correction_bias,
k_group=self.topk_group,
group_count=self.num_expert_group,
group_select_mode=1,
renorm=0,
norm_type=1,
routed_scaling_factor=1,
eps=float(1e-20),
)
else:
torch_native = True
return select_experts(
hidden_states=hidden_states,
router_logits=router_logits,
top_k=self.top_k,
use_grouped_topk=self.use_grouped_topk,
renormalize=self.renormalize,
topk_group=self.topk_group,
num_expert_group=self.num_expert_group,
num_fused_shared_experts=self.num_fused_shared_experts,
custom_routing_function=self.custom_routing_function,
correction_bias=self.correction_bias,
torch_native=torch_native,
routed_scaling_factor=self.routed_scaling_factor,
num_token_non_padded=num_token_non_padded,
expert_location_dispatch_info=expert_location_dispatch_info,
)
def fused_topk_torch_native(
hidden_states: torch.Tensor,
......@@ -436,8 +601,9 @@ def select_experts(
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
use_grouped_topk: bool,
renormalize: bool,
*,
use_grouped_topk: bool = False,
renormalize: bool = False,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
num_fused_shared_experts: int = 0,
......@@ -447,7 +613,7 @@ def select_experts(
routed_scaling_factor: Optional[float] = None,
num_token_non_padded: Optional[torch.Tensor] = None,
expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
):
) -> TopKOutput:
router_logits, correction_bias = (
expert_location_dispatch.transform_select_experts_inputs(
router_logits=router_logits,
......@@ -522,4 +688,4 @@ def select_experts(
get_global_expert_distribution_recorder().on_select_experts(topk_ids=topk_ids)
return topk_weights, topk_ids
return TopKOutput(topk_weights, topk_ids, router_logits)
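
With this change `select_experts` returns the `TopKOutput` named tuple instead of a `(topk_weights, topk_ids)` pair; the router logits travel with it for backends (e.g. the fused Marlin MoE paths) that still consume them. A small hedged example of both access styles, using made-up CPU tensors and the torch-native routing path:

import torch
from sglang.srt.layers.moe.topk import select_experts

hidden_states = torch.randn(4, 64)                 # hypothetical (tokens, hidden) activations
router_logits = torch.randn(4, 8)                  # hypothetical logits over 8 experts

topk_output = select_experts(
    hidden_states, router_logits, top_k=2, renormalize=True, torch_native=True
)

topk_weights, topk_ids, _ = topk_output            # tuple-style unpacking, as used throughout this PR
logits = topk_output.router_logits                 # field access also works; kept for Marlin-style callers
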
# Adapted from https://raw.githubusercontent.com/vllm-project/vllm/v0.5.5/vllm/model_executor/layers/quantization/__init__.py
from __future__ import annotations
import builtins
import inspect
from typing import Callable, Dict, Optional, Type, Union
from typing import TYPE_CHECKING, Callable, Dict, Optional, Type, Union
import torch
......@@ -65,6 +67,9 @@ from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config
from sglang.srt.layers.quantization.w8a8_fp8 import W8A8Fp8Config
from sglang.srt.layers.quantization.w8a8_int8 import W8A8Int8Config
if TYPE_CHECKING:
from sglang.srt.layers.moe.topk import TopKOutput
# Base quantization methods that don't depend on vllm
BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
"fp8": Fp8Config,
......@@ -186,15 +191,8 @@ def monkey_patch_moe_apply(class_obj: "FusedMoEMethodBase"):
self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
num_fused_shared_experts: int = 0,
custom_routing_function: Optional[Callable] = None,
correction_bias: Optional[torch.Tensor] = None,
topk_output: TopKOutput,
*,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
inplace: bool = True,
......@@ -208,20 +206,8 @@ def monkey_patch_moe_apply(class_obj: "FusedMoEMethodBase"):
"self": self,
"layer": layer,
"x": x,
"router_logits": router_logits,
"top_k": top_k,
"renormalize": renormalize,
"use_grouped_topk": use_grouped_topk,
"topk_group": topk_group,
"num_expert_group": num_expert_group,
"custom_routing_function": custom_routing_function,
"topk_output": topk_output,
}
if correction_bias is not None:
if not has_correction_bias:
raise ValueError(
"Please increase the version of your vllm. Try `pip install vllm==0.9.0.1`"
)
kwargs["e_score_correction_bias"] = correction_bias
return original_apply(**kwargs)
setattr(class_obj, "apply", new_apply)
......
......@@ -3,7 +3,7 @@ from __future__ import annotations
import logging
import warnings
from typing import Any, Callable, Dict, List, Optional
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional
import torch
......@@ -33,6 +33,9 @@ from sglang.srt.layers.quantization.scalar_type import scalar_types
from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
from sglang.srt.layers.quantization.utils import replace_parameter
if TYPE_CHECKING:
from sglang.srt.layers.moe.topk import TopKOutput
try:
from vllm import _custom_ops as ops
......@@ -737,45 +740,19 @@ class AWQMoEMethod(FusedMoEMethodBase):
self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
num_fused_shared_experts: int = 0,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
topk_output: TopKOutput,
*,
activation: str = "silu",
routed_scaling_factor: Optional[float] = None,
**kwargs,
) -> torch.Tensor:
# Delay the import to avoid circular dependency
from sglang.srt.layers.moe.topk import select_experts
assert activation == "silu", "Only SiLU activation is supported."
assert (
scoring_func == "softmax"
), "Only softmax score func is supported for now."
# The input must currently be float16
orig_dtype = x.dtype
x = x.half()
topk_weights, topk_ids = select_experts(
hidden_states=x,
router_logits=router_logits,
top_k=top_k,
use_grouped_topk=use_grouped_topk,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
num_fused_shared_experts=num_fused_shared_experts,
custom_routing_function=custom_routing_function,
correction_bias=correction_bias,
routed_scaling_factor=routed_scaling_factor,
)
topk_weights, topk_ids, router_logits = topk_output
return fused_marlin_moe(
x,
......
# Adapted from https://raw.githubusercontent.com/vllm-project/vllm/v0.5.5/vllm/model_executor/layers/quantization/base_config.py
from __future__ import annotations
import inspect
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Type
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type
import torch
from torch import nn
if TYPE_CHECKING:
from sglang.srt.layers.moe.topk import TopKOutput
class QuantizeMethodBase(ABC):
"""Base class for different quantized methods."""
......@@ -88,19 +92,22 @@ class FusedMoEMethodBase(QuantizeMethodBase):
params_dtype: torch.dtype,
**extra_weight_attrs,
):
raise NotImplementedError()
raise NotImplementedError
@abstractmethod
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool,
topk_output: TopKOutput,
*,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
inplace: bool = True,
no_combine: bool = False,
routed_scaling_factor: Optional[float] = None,
) -> torch.Tensor:
raise NotImplementedError()
raise NotImplementedError
class QuantizationConfig(ABC):
......
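
For backend authors, the abstract `apply` contract above now takes the layer input plus one `TopKOutput`, with everything else keyword-only. A hedged sketch of a minimal subclass conforming to the new signature (the class itself and its no-op body are illustrative, not part of this diff):

import torch
from sglang.srt.layers.moe.topk import TopKOutput
from sglang.srt.layers.quantization.base_config import FusedMoEMethodBase


class NoopMoEMethod(FusedMoEMethodBase):           # hypothetical backend for illustration only
    def create_weights(self, layer, *args, **kwargs):
        pass                                       # real backends register w13_weight / w2_weight here

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        topk_output: TopKOutput,
        *,
        activation: str = "silu",
        apply_router_weight_on_input: bool = False,
        inplace: bool = True,
        no_combine: bool = False,
        routed_scaling_factor=None,
    ) -> torch.Tensor:
        topk_weights, topk_ids, _ = topk_output    # routing was already done by the TopK op
        # A real backend would dispatch to a fused kernel here; returning x keeps the sketch trivial.
        return x
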
......@@ -3,7 +3,7 @@
from __future__ import annotations
import logging
from typing import Any, Callable, Dict, List, Optional
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional
import torch
from torch.nn import Module
......@@ -21,6 +21,9 @@ from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
from sglang.srt.layers.quantization.utils import is_layer_skipped
from sglang.srt.utils import set_weight_attrs
if TYPE_CHECKING:
from sglang.srt.layers.moe.topk import TopKOutput
ACTIVATION_SCHEMES = ["static", "dynamic"]
logger = logging.getLogger(__name__)
......@@ -344,15 +347,8 @@ class BlockInt8MoEMethod(FusedMoEMethodBase):
self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
num_fused_shared_experts: int = 0,
custom_routing_function: Optional[Callable] = None,
correction_bias: Optional[torch.Tensor] = None,
topk_output: TopKOutput,
*,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
inplace: bool = True,
......@@ -360,30 +356,13 @@ class BlockInt8MoEMethod(FusedMoEMethodBase):
routed_scaling_factor: Optional[float] = None,
) -> torch.Tensor:
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
from sglang.srt.layers.moe.topk import select_experts
# Expert selection
topk_weights, topk_ids = select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
num_fused_shared_experts=num_fused_shared_experts,
custom_routing_function=custom_routing_function,
correction_bias=correction_bias,
routed_scaling_factor=routed_scaling_factor,
)
# Expert fusion with INT8 quantization
return fused_experts(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
topk_output=topk_output,
inplace=inplace,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
......
# Adapted from https://github.com/vllm-project/vllm/tree/v0.8.2/vllm/model_executor/layers/quantization/compressed_tensors
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
import enum
import logging
from enum import Enum
from typing import Callable, List, Optional
from typing import TYPE_CHECKING, List, Optional
import torch
from compressed_tensors import CompressionFormat
from compressed_tensors.quantization import QuantizationStrategy
from sglang.srt.layers.quantization.base_config import FusedMoEMethodBase
from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz, scaled_fp8_quant
from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz
from sglang.srt.layers.quantization.utils import (
......@@ -20,6 +22,12 @@ from sglang.srt.layers.quantization.utils import (
)
from sglang.srt.utils import is_cpu, is_cuda, is_npu, set_weight_attrs
if TYPE_CHECKING:
from sglang.srt.layers.moe.topk import TopKOutput
from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors import (
CompressedTensorsConfig,
)
_is_cuda = is_cuda()
_is_npu = is_npu()
_is_cpu_amx_available = cpu_has_amx_support()
......@@ -51,7 +59,7 @@ __all__ = [
]
class CompressedTensorsMoEMethod:
class CompressedTensorsMoEMethod(FusedMoEMethodBase):
def __new__(cls, *args, **kwargs):
if cls is CompressedTensorsMoEMethod:
return super().__new__(cls)
......@@ -59,7 +67,7 @@ class CompressedTensorsMoEMethod:
@staticmethod
def get_moe_method(
quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501
quant_config: CompressedTensorsConfig,
) -> "CompressedTensorsMoEMethod":
# TODO: @dsikka: refactor this to use schemes as other kernels
# are supported + check if the layer is being ignored.
......@@ -82,9 +90,7 @@ class CompressedTensorsMoEMethod:
class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
def __init__(
self, quant_config: "CompressedTensorsConfig" # type: ignore # noqa E501
):
def __init__(self, quant_config: CompressedTensorsConfig):
self.quant_config = quant_config
self.weight_quant = self.quant_config.target_scheme_map["Linear"].get("weights")
self.input_quant = self.quant_config.target_scheme_map["Linear"].get(
......@@ -270,47 +276,21 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
num_fused_shared_experts: int = 0,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
correction_bias: Optional[torch.Tensor] = None,
topk_output: TopKOutput,
*,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
inplace: bool = True,
no_combine: bool = False,
apply_router_weight_on_input: bool = False,
routed_scaling_factor: Optional[float] = None,
) -> torch.Tensor:
from sglang.srt.layers.moe.fused_moe_triton import fused_experts
from sglang.srt.layers.moe.topk import select_experts
topk_weights, topk_ids = select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
num_fused_shared_experts=num_fused_shared_experts,
custom_routing_function=custom_routing_function,
correction_bias=correction_bias,
routed_scaling_factor=routed_scaling_factor,
)
return fused_experts(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
topk_output=topk_output,
inplace=inplace,
activation=activation,
use_fp8_w8a8=True,
......@@ -327,9 +307,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
def __init__(
self, quant_config: "CompressedTensorsConfig" # type: ignore # noqa E501
):
def __init__(self, quant_config: CompressedTensorsConfig):
self.quant_config = quant_config
# TODO: @dsikka: refactor this to use schemes as other kernels
# are supported + check if the layer is being ignored.
......@@ -628,43 +606,15 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
num_fused_shared_experts: int = 0,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
correction_bias: Optional[torch.Tensor] = None,
topk_output: TopKOutput,
*,
activation: str = "silu",
routed_scaling_factor: Optional[float] = None,
**kwargs,
) -> torch.Tensor:
from sglang.srt.layers.moe.topk import select_experts
assert activation == "silu", "Only SiLU activation is supported."
if expert_map is not None:
raise NotImplementedError(
"Expert Parallelism is not supported for " "fused Marlin MoE method."
)
topk_weights, topk_ids = select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
num_fused_shared_experts=num_fused_shared_experts,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
correction_bias=correction_bias,
routed_scaling_factor=routed_scaling_factor,
)
topk_weights, topk_ids, router_logits = topk_output
return torch.ops.vllm.fused_marlin_moe(
x,
......
......@@ -3,7 +3,7 @@
from __future__ import annotations
import logging
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
import torch
import torch.nn.functional as F
......@@ -78,6 +78,7 @@ from sglang.srt.utils import (
)
if TYPE_CHECKING:
from sglang.srt.layers.moe.topk import TopKOutput
from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config
_is_hip = is_hip()
......@@ -971,15 +972,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
num_fused_shared_experts: int = 0,
custom_routing_function: Optional[Callable] = None,
correction_bias: Optional[torch.Tensor] = None,
topk_output: TopKOutput,
*,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
inplace: bool = True,
......@@ -987,26 +981,11 @@ class Fp8MoEMethod(FusedMoEMethodBase):
routed_scaling_factor: Optional[float] = None,
) -> torch.Tensor:
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
from sglang.srt.layers.moe.topk import select_experts
# Expert selection
topk_weights, topk_ids = select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
num_fused_shared_experts=num_fused_shared_experts,
custom_routing_function=custom_routing_function,
correction_bias=correction_bias,
routed_scaling_factor=routed_scaling_factor,
)
if use_intel_amx_backend(layer):
from sglang.srt.layers.moe.topk import apply_topk_weights_cpu
topk_weights, topk_ids, _ = topk_output
x, topk_weights = apply_topk_weights_cpu(
apply_router_weight_on_input, topk_weights, x
)
......@@ -1032,8 +1011,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
ret = self.maybe_apply_hip_fused_experts(
layer,
x,
topk_weights,
topk_ids,
topk_output,
activation,
no_combine,
)
......@@ -1048,6 +1026,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
):
from sglang.srt.layers.moe.cutlass_moe import cutlass_fused_experts_fp8
topk_weights, topk_ids, _ = topk_output
return cutlass_fused_experts_fp8(
x,
layer.w13_weight.transpose(1, 2),
......@@ -1076,8 +1055,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
x,
layer.w13_weight,
layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
topk_output=topk_output,
inplace=inplace and not no_combine,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
......@@ -1101,11 +1079,11 @@ class Fp8MoEMethod(FusedMoEMethodBase):
self,
layer: torch.nn.Module,
x: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
topk_output: TopKOutput,
activation: str = "silu",
no_combine: bool = False,
) -> Optional[torch.Tensor]:
topk_weights, topk_ids, _ = topk_output
if _use_hip_int4:
# TODO: add triton kernel and add check _use_aiter
assert not no_combine, f"{no_combine=} is not supported."
......@@ -1397,14 +1375,8 @@ class Fp8EPMoEMethod(Fp8MoEMethod):
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
hidden_states: torch.Tensor,
topk_output: TopKOutput,
) -> torch.Tensor:
raise NotImplementedError
......
......@@ -3,7 +3,7 @@ from __future__ import annotations
import logging
from dataclasses import dataclass
from fractions import Fraction
from typing import Any, Callable, Dict, List, Optional, Union
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
import torch
......@@ -43,6 +43,9 @@ from sglang.srt.layers.quantization.utils import (
unpack_cols,
)
if TYPE_CHECKING:
from sglang.srt.layers.moe.topk import TopKOutput
try:
from vllm import _custom_ops as ops
except ImportError:
......@@ -1057,42 +1060,20 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
topk_output: TopKOutput,
*,
activation: str = "silu",
**kwargs,
) -> torch.Tensor:
# Delay the import to avoid circular dependency
from sglang.srt.layers.moe.topk import select_experts
assert activation == "silu", "Only SiLU activation is supported."
assert (
scoring_func == "softmax"
), "Only softmax score func is supported for now."
# The input must currently be float16
orig_dtype = x.dtype
x = x.half()
topk_weights, topk_ids = select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
correction_bias=e_score_correction_bias,
)
topk_weights, topk_ids, router_logits = topk_output
return fused_marlin_moe(
x,
......
......@@ -2,7 +2,7 @@
from __future__ import annotations
import logging
from typing import Any, Callable, Dict, List, Optional
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional
import torch
from torch.nn.parameter import Parameter
......@@ -31,6 +31,9 @@ from sglang.srt.layers.quantization.utils import (
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.utils import is_cuda, next_power_of_2
if TYPE_CHECKING:
from sglang.srt.layers.moe.topk import TopKOutput
if is_cuda():
from sgl_kernel import cutlass_scaled_fp4_mm, scaled_fp4_quant
......@@ -402,15 +405,8 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
num_fused_shared_experts: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
correction_bias: Optional[torch.Tensor] = None,
topk_output: TopKOutput,
*,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
inplace: bool = True,
......@@ -418,29 +414,12 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
routed_scaling_factor: Optional[float] = None,
) -> torch.Tensor:
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
from sglang.srt.layers.moe.topk import select_experts
# Expert selection
topk_weights, topk_ids = select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
num_fused_shared_experts=num_fused_shared_experts,
custom_routing_function=custom_routing_function,
correction_bias=correction_bias,
routed_scaling_factor=routed_scaling_factor,
)
return fused_experts(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
topk_output=topk_output,
inplace=inplace,
activation=activation,
use_fp8_w8a8=True,
......@@ -961,15 +940,8 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
num_fused_shared_experts: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
correction_bias: Optional[torch.Tensor] = None,
topk_output: TopKOutput,
*,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
inplace: bool = True,
......@@ -982,21 +954,6 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
) -> torch.Tensor:
assert activation == "silu", "Only SiLU activation is supported."
from sglang.srt.layers.moe.topk import select_experts
topk_weights, topk_ids = select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
num_fused_shared_experts=num_fused_shared_experts,
custom_routing_function=custom_routing_function,
correction_bias=correction_bias,
routed_scaling_factor=routed_scaling_factor,
)
if self.enable_flashinfer_moe:
assert (
......@@ -1004,6 +961,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
), "apply_router_weight_on_input is not supported for Flashinfer"
# TRTLLM Cutlass moe takes in activations in BF16/Half/nvfp4 precision
# and fp4 quantized weights loaded from the checkpoint
topk_weights, topk_ids, _ = topk_output
output = flashinfer_cutlass_fused_moe(
x,
topk_ids.to(torch.int),
......@@ -1029,6 +987,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
from sglang.srt.layers.moe.cutlass_moe import cutlass_moe_fp4
topk_weights, topk_ids, _ = topk_output
return cutlass_moe_fp4(
a=x,
a1_gscale=layer.w13_input_scale_quant,
......
......@@ -2,8 +2,9 @@
from __future__ import annotations
import logging
from typing import Any, Callable, Dict, List, Optional
from typing import TYPE_CHECKING, Any, Dict, List, Optional
import numpy as np
import torch
from sglang.srt.distributed import get_tensor_model_parallel_rank
......@@ -20,6 +21,9 @@ from sglang.srt.utils import get_device_capability, set_weight_attrs
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from sglang.srt.layers.moe.topk import TopKOutput
def get_weight_perm(num_bits: int):
perm_list: List[int] = []
......@@ -348,15 +352,8 @@ class MoeWNA16Method(FusedMoEMethodBase):
self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
num_fused_shared_experts: int = 0,
custom_routing_function: Optional[Callable] = None,
correction_bias: Optional[torch.Tensor] = None,
topk_output: TopKOutput,
*,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
inplace: bool = True,
......@@ -365,22 +362,8 @@ class MoeWNA16Method(FusedMoEMethodBase):
) -> torch.Tensor:
# avoid circular import
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
from sglang.srt.layers.moe.topk import select_experts
assert activation == "silu", "Only SiLU activation is supported."
topk_weights, topk_ids = select_experts(
hidden_states=x,
router_logits=router_logits,
top_k=top_k,
use_grouped_topk=use_grouped_topk,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
num_fused_shared_experts=num_fused_shared_experts,
custom_routing_function=custom_routing_function,
correction_bias=correction_bias,
routed_scaling_factor=routed_scaling_factor,
)
weight_bits = self.quant_config.weight_bits
has_zp = self.quant_config.has_zp
......@@ -389,8 +372,7 @@ class MoeWNA16Method(FusedMoEMethodBase):
x,
layer.w13_qweight,
layer.w2_qweight,
topk_weights=topk_weights,
topk_ids=topk_ids,
topk_output=topk_output,
inplace=inplace,
apply_router_weight_on_input=apply_router_weight_on_input,
use_int4_w4a16=weight_bits == 4,
......
from __future__ import annotations
import importlib
from typing import Callable, List, Optional
from typing import TYPE_CHECKING, Callable, List, Optional
import torch
import torch.nn.functional as F
......@@ -21,6 +23,9 @@ from sglang.srt.utils import (
use_intel_amx_backend,
)
if TYPE_CHECKING:
from sglang.srt.layers.moe.topk import TopKOutput
has_triton_kernels = importlib.util.find_spec("triton_kernels") is not None
......@@ -125,25 +130,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
super().__init__()
self.use_triton_kernels = use_triton_kernels
from sglang.srt.layers.moe.fused_moe_native import moe_forward_native
if torch.cuda.is_available():
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
if has_triton_kernels:
from sglang.srt.layers.moe.fused_moe_triton.triton_kernels_moe import (
triton_kernel_moe_forward,
)
else:
triton_kernel_moe_forward = None
else:
fused_experts = None # type: ignore
triton_kernel_moe_forward = None
self.moe_forward_native = moe_forward_native
self.fused_experts = fused_experts
self.triton_kernel_moe_forward = triton_kernel_moe_forward
def create_weights(
self,
layer: torch.nn.Module,
......@@ -201,34 +187,18 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
num_fused_shared_experts: int = 0,
custom_routing_function: Optional[Callable] = None,
correction_bias: Optional[torch.Tensor] = None,
topk_output: TopKOutput,
*,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
inplace: bool = True,
no_combine: bool = False,
routed_scaling_factor: Optional[float] = None,
) -> torch.Tensor:
return self.forward(
x=x,
layer=layer,
router_logits=router_logits,
top_k=top_k,
renormalize=renormalize,
use_grouped_topk=use_grouped_topk,
topk_group=topk_group,
num_expert_group=num_expert_group,
num_fused_shared_experts=num_fused_shared_experts,
custom_routing_function=custom_routing_function,
correction_bias=correction_bias,
topk_output=topk_output,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
inplace=inplace,
......@@ -240,15 +210,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
self,
layer: torch.nn.Module,
x: torch.Tensor,
use_grouped_topk: bool,
top_k: int,
router_logits: torch.Tensor,
renormalize: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
num_fused_shared_experts: int = 0,
custom_routing_function: Optional[Callable] = None,
correction_bias: Optional[torch.Tensor] = None,
topk_output: TopKOutput,
*,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
inplace: bool = True,
......@@ -257,33 +220,20 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
) -> torch.Tensor:
if self.use_triton_kernels:
return self.triton_kernel_moe_forward(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
gating_output=router_logits,
topk=top_k,
renormalize=renormalize,
)
# TODO(ch-wan): re-enable the Triton kernel
raise NotImplementedError("The Triton kernel is temporarily disabled.")
# return triton_kernel_moe_forward(
# hidden_states=x,
# w1=layer.w13_weight,
# w2=layer.w2_weight,
# gating_output=router_logits,
# topk=top_k,
# renormalize=renormalize,
# )
else:
from sglang.srt.layers.moe.topk import select_experts
topk_weights, topk_ids = select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
num_fused_shared_experts=num_fused_shared_experts,
custom_routing_function=custom_routing_function,
correction_bias=correction_bias,
routed_scaling_factor=routed_scaling_factor,
)
if _use_aiter:
assert not no_combine, "unsupported"
topk_weights, topk_ids, _ = topk_output
if apply_router_weight_on_input:
assert (
topk_weights.dim() == 2
......@@ -296,7 +246,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
topk_weights = torch.ones_like(
topk_weights, dtype=torch.float32
) # topk_weights must be FP32 (float32)
return fused_moe(
x,
layer.w13_weight,
......@@ -310,12 +259,15 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
),
)
else:
return self.fused_experts(
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
fused_experts,
)
return fused_experts(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
topk_output=topk_output,
inplace=inplace and not no_combine,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
......@@ -327,15 +279,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
self,
layer: torch.nn.Module,
x: torch.Tensor,
use_grouped_topk: bool,
top_k: int,
router_logits: torch.Tensor,
renormalize: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
num_fused_shared_experts: int = 0,
custom_routing_function: Optional[Callable] = None,
correction_bias: Optional[torch.Tensor] = None,
topk_output: TopKOutput,
*,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
inplace: bool = True,
......@@ -344,30 +289,13 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
) -> torch.Tensor:
assert activation == "silu", f"activation = {activation} is not supported."
if use_intel_amx_backend(layer):
if use_intel_amx_backend(layer) and not apply_router_weight_on_input:
from sglang.srt.layers.moe.topk import apply_topk_weights_cpu
from sglang.srt.layers.moe.topk import (
apply_topk_weights_cpu,
select_experts,
)
topk_weights, topk_ids = select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
num_fused_shared_experts=num_fused_shared_experts,
custom_routing_function=custom_routing_function,
correction_bias=correction_bias,
routed_scaling_factor=routed_scaling_factor,
)
topk_weights, topk_ids, _ = topk_output
x, topk_weights = apply_topk_weights_cpu(
apply_router_weight_on_input, topk_weights, x
)
return torch.ops.sgl_kernel.fused_experts_cpu(
x,
layer.w13_weight,
......@@ -385,61 +313,42 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
True, # is_vnni
)
else:
return self.moe_forward_native(
from sglang.srt.layers.moe.fused_moe_native import moe_forward_native
return moe_forward_native(
layer,
x,
use_grouped_topk,
top_k,
router_logits,
renormalize,
topk_group,
num_expert_group,
num_fused_shared_experts,
custom_routing_function,
correction_bias,
activation,
apply_router_weight_on_input,
inplace,
no_combine,
routed_scaling_factor,
topk_output,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
inplace=inplace,
no_combine=no_combine,
routed_scaling_factor=routed_scaling_factor,
)
def forward_npu(
self,
layer: torch.nn.Module,
x: torch.Tensor,
use_grouped_topk: bool,
top_k: int,
router_logits: torch.Tensor,
renormalize: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
num_fused_shared_experts: int = 0,
custom_routing_function: Optional[Callable] = None,
correction_bias: Optional[torch.Tensor] = None,
topk_output: TopKOutput,
*,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
inplace: bool = True,
no_combine: bool = False,
routed_scaling_factor: Optional[float] = None,
) -> torch.Tensor:
return self.moe_forward_native(
from sglang.srt.layers.moe.fused_moe_native import moe_forward_native
return moe_forward_native(
layer,
x,
use_grouped_topk,
top_k,
router_logits,
renormalize,
topk_group,
num_expert_group,
num_fused_shared_experts,
custom_routing_function,
correction_bias,
activation,
apply_router_weight_on_input,
inplace,
no_combine,
routed_scaling_factor,
topk_output,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
inplace=inplace,
no_combine=no_combine,
routed_scaling_factor=routed_scaling_factor,
)
def forward_tpu(self, *args, **kwargs) -> torch.Tensor:
......@@ -508,13 +417,7 @@ class UnquantizedEPMoEMethod(FusedMoEMethodBase, CustomOp):
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
hidden_states: torch.Tensor,
topk_output: TopKOutput,
) -> torch.Tensor:
raise NotImplementedError
from __future__ import annotations
from typing import Any, Callable, Dict, List, Optional
from typing import TYPE_CHECKING, Any, Dict, List, Optional
import torch
from torch.nn.parameter import Parameter
......@@ -25,6 +25,9 @@ from sglang.srt.layers.quantization.fp8_utils import (
)
from sglang.srt.utils import set_weight_attrs
if TYPE_CHECKING:
from sglang.srt.layers.moe.topk import TopKOutput
_is_fp8_fnuz = is_fp8_fnuz()
......@@ -266,45 +269,23 @@ class W8A8FP8MoEMethod(FusedMoEMethodBase):
self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
num_fused_shared_experts: int = 0,
custom_routing_function: Optional[Callable] = None,
correction_bias: Optional[torch.Tensor] = None,
topk_output: TopKOutput,
*,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
inplace: bool = True,
no_combine: bool = False,
routed_scaling_factor: Optional[float] = None,
) -> torch.Tensor:
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
from sglang.srt.layers.moe.topk import select_experts
# Expert selection
topk_weights, topk_ids = select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
num_fused_shared_experts=num_fused_shared_experts,
custom_routing_function=custom_routing_function,
correction_bias=correction_bias,
routed_scaling_factor=routed_scaling_factor,
)
return fused_experts(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
topk_output=topk_output,
inplace=inplace,
apply_router_weight_on_input=apply_router_weight_on_input,
activation=activation,
use_fp8_w8a8=True,
per_channel_quant=True,
......
......@@ -3,7 +3,7 @@ from __future__ import annotations
import importlib
import sys
from types import MappingProxyType
from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Union, cast
from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple, Union, cast
import torch
from torch.nn.parameter import Parameter
......@@ -37,6 +37,9 @@ from sglang.srt.utils import (
use_intel_amx_backend,
)
if TYPE_CHECKING:
from sglang.srt.layers.moe.topk import TopKOutput
_is_cuda = is_cuda()
_is_cpu_amx_available = cpu_has_amx_support()
_is_cpu = is_cpu()
......@@ -239,7 +242,7 @@ class W8A8Int8Config(QuantizationConfig):
layer: torch.nn.Module,
prefix: str,
) -> Optional[QuantizeMethodBase]:
from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod
from sglang.srt.layers.linear import LinearBase
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
if _is_npu:
......@@ -469,15 +472,8 @@ class W8A8Int8MoEMethod(FusedMoEMethodBase):
self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
num_fused_shared_experts: int = 0,
custom_routing_function: Optional[Callable] = None,
correction_bias: Optional[torch.Tensor] = None,
topk_output: TopKOutput,
*,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
inplace: bool = True,
......@@ -485,26 +481,11 @@ class W8A8Int8MoEMethod(FusedMoEMethodBase):
routed_scaling_factor: Optional[float] = None,
) -> torch.Tensor:
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
from sglang.srt.layers.moe.topk import select_experts
# Expert selection
topk_weights, topk_ids = select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
num_fused_shared_experts=num_fused_shared_experts,
custom_routing_function=custom_routing_function,
correction_bias=correction_bias,
routed_scaling_factor=routed_scaling_factor,
)
if use_intel_amx_backend(layer):
from sglang.srt.layers.moe.topk import apply_topk_weights_cpu
topk_weights, topk_ids, _ = topk_output
x, topk_weights = apply_topk_weights_cpu(
apply_router_weight_on_input, topk_weights, x
)
......@@ -529,8 +510,7 @@ class W8A8Int8MoEMethod(FusedMoEMethodBase):
x,
layer.w13_weight,
layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
topk_output=topk_output,
inplace=inplace,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
......@@ -907,7 +887,7 @@ class NPU_W8A8MoEMethod(FusedMoEMethodBase):
layer: torch.nn.Module,
num_experts: int,
hidden_size: int,
intermediate_size: List[int],
intermediate_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
) -> None:
......@@ -984,52 +964,11 @@ class NPU_W8A8MoEMethod(FusedMoEMethodBase):
self,
layer,
x,
router_logits,
top_k,
renormalize,
use_grouped_topk,
topk_group,
num_expert_group,
num_fused_shared_experts,
custom_routing_function,
correction_bias,
activation,
apply_router_weight_on_input,
routed_scaling_factor,
topk_output: TopKOutput,
**kwargs,
) -> torch.Tensor:
from sglang.srt.layers.moe.topk import select_experts
global_num_experts = router_logits.shape[-1]
# NOTE: npu_moe_gating_top_k currently only supports the `group_count=256` pattern
if global_num_experts == 256:
topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
router_logits,
k=top_k,
bias=correction_bias,
k_group=topk_group,
group_count=num_expert_group,
group_select_mode=1,
renorm=0,
norm_type=1,
routed_scaling_factor=1,
eps=float(1e-20),
)
else:
topk_weights, topk_ids = select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
num_fused_shared_experts=num_fused_shared_experts,
custom_routing_function=custom_routing_function,
correction_bias=correction_bias,
torch_native=True,
routed_scaling_factor=routed_scaling_factor,
)
topk_weights, topk_ids, _ = topk_output
topk_ids = topk_ids.to(torch.int32)
topk_weights = topk_weights.to(x.dtype)
return npu_fused_experts(
......@@ -1040,5 +979,5 @@ class NPU_W8A8MoEMethod(FusedMoEMethodBase):
w2_scale=layer.w2_weight_scale,
topk_weights=topk_weights,
topk_ids=topk_ids,
top_k=top_k,
top_k=topk_ids.shape[1],
)
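Since `top_k` is no longer threaded through the call, the NPU path recovers it from the shape of `topk_ids`, which is `[num_tokens, top_k]`. A tiny sketch confirming the shape relation that `top_k=topk_ids.shape[1]` relies on:

```python
# topk_ids has shape [num_tokens, top_k], so top_k can be read off its second
# dimension instead of being passed as a separate argument.
import torch

router_logits = torch.randn(5, 16)            # 5 tokens, 16 experts
top_k = 4
_, topk_ids = torch.topk(router_logits, k=top_k, dim=-1)
assert topk_ids.shape == (5, top_k)
assert topk_ids.shape[1] == top_k
```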
......@@ -37,6 +37,7 @@ from sglang.srt.layers.linear import (
)
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe.fused_moe_triton import fused_moe
from sglang.srt.layers.moe.topk import TopK
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.rotary_embedding import get_rope
......@@ -109,7 +110,10 @@ class DeepseekMoE(nn.Module):
f"Tensor parallel size {self.tp_size} is greater than "
f"the number of experts {self.n_routed_experts}."
)
self.topk = TopK(
top_k=self.top_k,
renormalize=config.norm_topk_prob,
)
self.experts = nn.ModuleList(
[
DeepseekMLP(
......@@ -170,13 +174,12 @@ class DeepseekMoE(nn.Module):
shared_output = self.shared_experts(hidden_states)
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states)
topk_output = self.topk(hidden_states, router_logits)
final_hidden_states = fused_moe.fused_moe(
hidden_states,
self.w1,
self.w2,
router_logits,
self.top_k,
renormalize=self.config.norm_topk_prob,
w1=self.w1,
w2=self.w2,
topk_output=topk_output,
inplace=True,
)
......
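In the DeepseekMoE hunk above, routing moves out of `fused_moe` and into a dedicated `TopK` module constructed with `top_k` and `renormalize=config.norm_topk_prob`; its output is then handed to `fused_moe` as `topk_output`. The sketch below is a minimal stand-in for what that gating step computes, assuming plain softmax gating with optional renormalization of the selected weights; the real layer supports more modes (grouped top-k, correction bias, custom routing functions, ...).

```python
# Self-contained sketch of the gating step factored out of fused_moe.
import torch


def naive_topk_gating(router_logits: torch.Tensor, top_k: int, renormalize: bool):
    probs = torch.softmax(router_logits, dim=-1, dtype=torch.float32)
    topk_weights, topk_ids = torch.topk(probs, k=top_k, dim=-1)
    if renormalize:
        # Make the selected weights sum to 1 per token (norm_topk_prob=True).
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids


if __name__ == "__main__":
    logits = torch.randn(3, 64)         # 3 tokens, 64 routed experts
    w, ids = naive_topk_gating(logits, top_k=6, renormalize=True)
    print(w.shape, ids.shape)           # torch.Size([3, 6]) torch.Size([3, 6])
    print(w.sum(dim=-1))                # ~1.0 per token
```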