Unverified Commit 3fa62da7 authored by Cheng Wan, committed by GitHub

[7/N] MoE Refactor: the implementation of new framework (#9269)

parent dbb1235d
@@ -11,6 +11,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
+from __future__ import annotations
import logging
import math
import os
@@ -19,17 +22,19 @@ from abc import ABC
from collections import deque
from contextlib import contextmanager
from pathlib import Path
-from typing import Any, Dict, List, Literal, Optional, Tuple, Type
+from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Tuple, Type
import einops
import torch
import torch.distributed
-from sglang.srt.eplb.expert_location import ExpertLocationMetadata
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import Withable, get_bool_env_var
+if TYPE_CHECKING:
+from sglang.srt.eplb.expert_location import ExpertLocationMetadata
logger = logging.getLogger(__name__)
# --------------------------------------- Entrypoint -----------------------------------------
@@ -43,7 +48,7 @@ class ExpertDistributionRecorder(ABC):
@staticmethod
def init_new(
server_args: ServerArgs,
-expert_location_metadata: "ExpertLocationMetadata",
+expert_location_metadata: ExpertLocationMetadata,
rank: int,
):
if server_args.expert_distribution_recorder_mode is not None:
@@ -118,7 +123,7 @@ class _ExpertDistributionRecorderReal(ExpertDistributionRecorder):
def __init__(
self,
server_args: ServerArgs,
-expert_location_metadata: "ExpertLocationMetadata",
+expert_location_metadata: ExpertLocationMetadata,
rank: int,
):
self._server_args = server_args
@@ -279,7 +284,7 @@ class _SinglePassGatherer(ABC):
@staticmethod
def init_new(
server_args: ServerArgs,
-expert_location_metadata: "ExpertLocationMetadata",
+expert_location_metadata: ExpertLocationMetadata,
rank: int,
) -> "_SinglePassGatherer":
if server_args.expert_distribution_recorder_mode == "per_token":
@@ -307,7 +312,7 @@ class _SinglePassGatherer(ABC):
return _SelectExpertsSinglePassGatherer(expert_location_metadata, rank)
-def __init__(self, expert_location_metadata: "ExpertLocationMetadata", rank: int):
+def __init__(self, expert_location_metadata: ExpertLocationMetadata, rank: int):
self._expert_location_metadata = expert_location_metadata
self._rank = rank
@@ -346,7 +351,7 @@ class _DetailSinglePassGatherer(_SinglePassGatherer):
def __init__(
self,
server_args: ServerArgs,
-expert_location_metadata: "ExpertLocationMetadata",
+expert_location_metadata: ExpertLocationMetadata,
rank: int,
):
super().__init__(expert_location_metadata, rank)
@@ -561,7 +566,7 @@ class _Accumulator(ABC):
@staticmethod
def init_new(
server_args: ServerArgs,
-expert_location_metadata: "ExpertLocationMetadata",
+expert_location_metadata: ExpertLocationMetadata,
rank: int,
) -> "_Accumulator":
return _Accumulator.get_class(server_args)(
@@ -580,7 +585,7 @@ class _Accumulator(ABC):
def __init__(
self,
server_args: ServerArgs,
-expert_location_metadata: "ExpertLocationMetadata",
+expert_location_metadata: ExpertLocationMetadata,
rank: int,
):
self._server_args = server_args
...
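The recurring change in this file moves annotation-only imports behind TYPE_CHECKING and relies on `from __future__ import annotations`, so types such as ExpertLocationMetadata no longer need string quotes and the runtime import cycle is avoided. A minimal, self-contained sketch of the pattern (module and type names here are illustrative, not taken from this diff):

from __future__ import annotations  # all annotations are evaluated lazily

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Only seen by static type checkers; never imported at runtime,
    # so it cannot participate in an import cycle.
    from mypkg.heavy_module import HeavyType  # illustrative names


def describe(value: HeavyType) -> str:  # resolves for the type checker only
    return repr(value)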
@@ -11,21 +11,26 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
+from __future__ import annotations
import json
import logging
import random
from dataclasses import dataclass
from pathlib import Path
-from typing import List, Optional
+from typing import TYPE_CHECKING, List, Optional
import torch
import torch.distributed
import torch.nn.functional as F
-from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.eplb import eplb_algorithms
from sglang.srt.model_loader import get_model_architecture
-from sglang.srt.server_args import ServerArgs
+if TYPE_CHECKING:
+from sglang.srt.configs.model_config import ModelConfig
+from sglang.srt.server_args import ServerArgs
logger = logging.getLogger(__name__)
...
-from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
+from sglang.srt.layers.moe.moe_runner import MoeRunner, MoeRunnerConfig
from sglang.srt.layers.moe.utils import (
DeepEPMode,
MoeA2ABackend,
@@ -17,6 +17,7 @@ from sglang.srt.layers.moe.utils import (
__all__ = [
"DeepEPMode",
"MoeA2ABackend",
+"MoeRunner",
"MoeRunnerConfig",
"MoeRunnerBackend",
"initialize_moe_config",
...
@@ -8,16 +8,18 @@ from torch.nn import functional as F
from sglang.srt.layers.activation import GeluAndMul, SiluAndMul
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
+from sglang.srt.layers.moe.token_dispatcher import StandardDispatchOutput
from sglang.srt.layers.moe.topk import StandardTopKOutput
def fused_moe_forward_native(
layer: torch.nn.Module,
-x: torch.Tensor,
-topk_output: StandardTopKOutput,
-moe_runner_config: MoeRunnerConfig,
+dispatch_output: StandardDispatchOutput,
) -> torch.Tensor:
+x, topk_output = dispatch_output
+moe_runner_config = layer.moe_runner_config
if moe_runner_config.apply_router_weight_on_input:
raise NotImplementedError()
...
+# NOTE: this file will be separated into sglang/srt/layers/moe/moe_runner/triton_utils.py
# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/model_executor/layers/fused_moe/fused_moe.py
"""Fused MoE kernel."""
@@ -6,13 +7,12 @@ from __future__ import annotations
import functools
import os
-from typing import List, Optional
+from typing import TYPE_CHECKING, List, Optional
import torch
import triton.language as tl
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
-from sglang.srt.layers.moe.topk import StandardTopKOutput
from sglang.srt.utils import (
cpu_has_amx_support,
direct_register_custom_op,
@@ -26,6 +26,9 @@ from .fused_moe_triton_config import get_config_dtype_str, try_get_optimal_moe_c
from .fused_moe_triton_kernels import invoke_fused_moe_kernel, moe_sum_reduce_triton
from .moe_align_block_size import moe_align_block_size
+if TYPE_CHECKING:
+from sglang.srt.layers.moe.topk import StandardTopKOutput
_is_hip = is_hip()
_is_cuda = is_cuda()
_is_cpu_amx_available = cpu_has_amx_support()
...
@@ -23,8 +23,13 @@ from sglang.srt.layers.moe import (
get_moe_runner_backend,
should_use_flashinfer_trtllm_moe,
)
+from sglang.srt.layers.moe.token_dispatcher.standard import (
+CombineInput,
+StandardDispatcher,
+)
from sglang.srt.layers.moe.topk import TopKOutput, TopKOutputChecker
from sglang.srt.layers.quantization.base_config import (
+FusedMoEMethodBase,
QuantizationConfig,
QuantizeMethodBase,
)
@@ -152,16 +157,6 @@ class FusedMoE(torch.nn.Module):
self.expert_map_cpu = None
self.expert_map_gpu = None
-self.moe_runner_config = MoeRunnerConfig(
-activation=activation,
-apply_router_weight_on_input=apply_router_weight_on_input,
-inplace=inplace,
-no_combine=no_combine,
-routed_scaling_factor=routed_scaling_factor,
-gemm1_alpha=gemm1_alpha,
-gemm1_clamp_limit=gemm1_clamp_limit,
-)
enable_flashinfer_cutlass_moe = get_moe_runner_backend().is_flashinfer_cutlass()
if enable_flashinfer_cutlass_moe and quant_config is None:
@@ -196,13 +191,6 @@ class FusedMoE(torch.nn.Module):
self.use_presharded_weights = use_presharded_weights
self.use_triton_kernels = get_moe_runner_backend().is_triton_kernel()
-if quant_config is None:
-self.quant_method: Optional[QuantizeMethodBase] = UnquantizedFusedMoEMethod(
-self.use_triton_kernels
-)
-else:
-self.quant_method = quant_config.get_quant_method(self, prefix)
-assert self.quant_method is not None
self.quant_config = quant_config
self.use_flashinfer_mxfp4_moe = get_moe_runner_backend().is_flashinfer_mxfp4()
@@ -213,12 +201,40 @@ class FusedMoE(torch.nn.Module):
and self.use_flashinfer_mxfp4_moe
):
hidden_size = round_up(hidden_size, 256)
+self.hidden_size = hidden_size
+self.moe_runner_config = MoeRunnerConfig(
+num_experts=num_experts,
+num_local_experts=self.num_local_experts,
+hidden_size=hidden_size,
+intermediate_size_per_partition=self.intermediate_size_per_partition,
+layer_id=layer_id,
+top_k=top_k,
+num_fused_shared_experts=num_fused_shared_experts,
+params_dtype=params_dtype,
+activation=activation,
+apply_router_weight_on_input=apply_router_weight_on_input,
+inplace=inplace,
+no_combine=no_combine,
+routed_scaling_factor=routed_scaling_factor,
+gemm1_alpha=gemm1_alpha,
+gemm1_clamp_limit=gemm1_clamp_limit,
+)
+if quant_config is None:
+self.quant_method: FusedMoEMethodBase = UnquantizedFusedMoEMethod(
+self.use_triton_kernels
+)
+else:
+self.quant_method: FusedMoEMethodBase = quant_config.get_quant_method(
+self, prefix
+)
+assert self.quant_method is not None
self.quant_method.create_weights(
layer=self,
num_experts=self.num_local_experts,
hidden_size=hidden_size,
-# FIXME: figure out which intermediate_size to use
-intermediate_size=self.intermediate_size_per_partition,
intermediate_size_per_partition=self.intermediate_size_per_partition,
params_dtype=params_dtype,
weight_loader=(
@@ -229,6 +245,9 @@ class FusedMoE(torch.nn.Module):
with_bias=with_bias,
)
+self.quant_method.create_moe_runner(self, self.moe_runner_config)
+self.dispatcher = StandardDispatcher()
def _load_per_tensor_weight_scale(
self,
shard_id: str,
@@ -811,16 +830,17 @@ class FusedMoE(torch.nn.Module):
elif TopKOutputChecker.format_is_triton_kernel(topk_output):
raise NotImplementedError()
-# Matrix multiply.
-with use_symmetric_memory(get_tp_group()) as sm:
-final_hidden_states = self.quant_method.apply(
-layer=self,
-x=hidden_states,
-topk_output=topk_output,
-moe_runner_config=self.moe_runner_config,
-)
-sm.tag(final_hidden_states)
+dispatch_output = self.dispatcher.dispatch(
+hidden_states=hidden_states, topk_output=topk_output
+)
+# TODO: consider using symmetric memory
+combine_input = self.quant_method.apply(
+layer=self,
+dispatch_output=dispatch_output,
+)
+final_hidden_states = self.dispatcher.combine(combine_input)
final_hidden_states = final_hidden_states[
..., :origin_hidden_states_dim
@@ -955,7 +975,6 @@ class FlashInferFusedMoE(FusedMoE):
layer=self,
x=hidden_states,
topk_output=topk_output,
-moe_runner_config=self.moe_runner_config,
)
if self.reduce_results and (self.moe_tp_size > 1 or self.moe_ep_size > 1):
...
from sglang.srt.layers.moe.moe_runner.base import MoeRunnerConfig
+from sglang.srt.layers.moe.moe_runner.runner import MoeRunner
-__all__ = ["MoeRunnerConfig"]
+__all__ = ["MoeRunnerConfig", "MoeRunner"]
+from __future__ import annotations
+from abc import ABC, abstractmethod
from dataclasses import dataclass
-from typing import Optional
+from typing import TYPE_CHECKING, Callable, Optional, Tuple, TypeGuard
+import torch
+from sglang.srt.layers.moe.token_dispatcher import (
+CombineInput,
+CombineInputFormat,
+DispatchOutput,
+DispatchOutputFormat,
+)
+from sglang.srt.layers.moe.utils import MoeA2ABackend, MoeRunnerBackend
+if TYPE_CHECKING:
+from sglang.srt.layers.moe.moe_runner.triton import (
+TritonRunnerCore,
+TritonRunnerInput,
+TritonRunnerOutput,
+)
@dataclass
class MoeRunnerConfig:
+# MoE parameters
+num_experts: Optional[int] = None
+num_local_experts: Optional[int] = None
+hidden_size: Optional[int] = None
+intermediate_size_per_partition: Optional[int] = None
+layer_id: Optional[int] = None
+top_k: Optional[int] = None
+num_fused_shared_experts: Optional[int] = None
+params_dtype: Optional[torch.dtype] = None
+# Runner configuration
activation: str = "silu"
apply_router_weight_on_input: bool = False
inplace: bool = True
@@ -11,3 +43,254 @@ class MoeRunnerConfig:
routed_scaling_factor: Optional[float] = None
gemm1_alpha: Optional[float] = None
gemm1_clamp_limit: Optional[float] = None
@dataclass
class RunnerInput(ABC):
@property
@abstractmethod
def runner_backend(self) -> MoeRunnerBackend: ...
def runner_backend_is_triton(self) -> TypeGuard[TritonRunnerInput]:
return self.runner_backend == MoeRunnerBackend.TRITON
class RunnerOutput(ABC):
@property
@abstractmethod
def runner_backend(self) -> MoeRunnerBackend: ...
def runner_backend_is_triton(self) -> TypeGuard[TritonRunnerOutput]:
return self.runner_backend == MoeRunnerBackend.TRITON
@dataclass
class MoeQuantInfo(ABC):
"""Moe quantization data."""
pass
class MoeRunnerCore(ABC):
def __init__(self, config: MoeRunnerConfig):
self.config = config
@abstractmethod
def run(
self, runner_input: RunnerInput, quant_info: MoeQuantInfo, running_state: dict
) -> RunnerOutput:
pass
@property
@abstractmethod
def runner_backend(self) -> MoeRunnerBackend: ...
def runner_backend_is_triton(self) -> TypeGuard[TritonRunnerCore]:
return self.runner_backend == MoeRunnerBackend.TRITON
class FusedOpPool:
_fused_funcs: dict[str, Callable] = {}
@classmethod
def register_fused_func(
cls, a2a_backend_name: str, runner_backend_name: str, fused_func: Callable
):
key = (a2a_backend_name, runner_backend_name)
if key in cls._fused_funcs:
raise ValueError(
f"Fused function for {a2a_backend_name} to {runner_backend_name} is already registered."
)
assert MoeA2ABackend(
a2a_backend_name
), f"Invalid dispatch name: {a2a_backend_name}"
assert MoeRunnerBackend(
runner_backend_name
), f"Invalid runner name: {runner_backend_name}"
cls._fused_funcs[key] = fused_func
@classmethod
def get_fused_func(cls, dispatch_name: str, runner_name: str) -> Optional[Callable]:
key = (dispatch_name, runner_name)
fused_func = cls._fused_funcs.get(key)
return fused_func
class PermuteMethodPool:
_pre_permute_methods: dict[
Tuple[DispatchOutputFormat, MoeRunnerBackend], Callable
] = {}
_post_permute_methods: dict[
Tuple[MoeRunnerBackend, CombineInputFormat], Callable
] = {}
@classmethod
def register_pre_permute(
cls,
dispatch_output_name: str,
runner_backend_name: str,
permute_func: Callable,
):
"""
Register a customized pre-permute function for the given DispatchOutputFormat and MoeRunnerBackend.
:param dispatch_output_name: The DispatchOutputFormat name.
:param runner_backend_name: The MoeRunnerBackend name.
:param permute_func: The permute function to register.
"""
key = (dispatch_output_name, runner_backend_name)
if key in cls._pre_permute_methods:
raise ValueError(
f"Pre-permute method for {dispatch_output_name} to {runner_backend_name} is already registered."
)
assert DispatchOutputFormat(
dispatch_output_name
), f"Invalid dispatch output name: {dispatch_output_name}"
assert MoeRunnerBackend(
runner_backend_name
), f"Invalid runner backend name: {runner_backend_name}"
cls._pre_permute_methods[key] = permute_func
@classmethod
def register_post_permute(
cls,
runner_backend_name: str,
combine_input_name: str,
permute_func: Callable,
):
"""
Register a customized post-permute function for the given MoeRunnerBackend and CombineInputFormat.
:param runner_backend_name: The MoeRunnerBackend name.
:param combine_input_name: The CombineInputFormat name.
:param permute_func: The permute function to register.
"""
key = (runner_backend_name, combine_input_name)
if key in cls._post_permute_methods:
raise ValueError(
f"Post-permute method for {runner_backend_name} to {combine_input_name} is already registered."
)
assert MoeRunnerBackend(
runner_backend_name
), f"Invalid runner backend name: {runner_backend_name}"
assert CombineInputFormat(
combine_input_name
), f"Invalid combine input name: {combine_input_name}"
cls._post_permute_methods[key] = permute_func
@classmethod
def get_pre_permute(
cls,
dispatch_output_format: DispatchOutputFormat,
runner_input_format: MoeRunnerBackend,
) -> Callable:
"""
Retrieve the pre-permute function for the given DispatchOutputFormat and MoeRunnerBackend.
:param dispatch_output_format: The DispatchOutputFormat type.
:param runner_input_format: The MoeRunnerBackend type.
:return: The registered permute function or None if not found.
"""
key = (dispatch_output_format, runner_input_format)
pre_permute_func = cls._pre_permute_methods.get(key)
assert (
pre_permute_func is not None
), f"Pre-permute function for {dispatch_output_format} to {runner_input_format} is not registered"
return pre_permute_func
@classmethod
def get_post_permute(
cls,
runner_output_format: MoeRunnerBackend,
combine_input_format: CombineInputFormat,
) -> Callable:
"""
Retrieve the post-permute function for the given MoeRunnerBackend and CombineInputFormat.
:param runner_output_format: The MoeRunnerBackend type.
:param combine_input_format: The CombineInputFormat type.
:return: The registered permute function or None if not found.
"""
key = (runner_output_format, combine_input_format)
post_permute_func = cls._post_permute_methods.get(key)
assert (
post_permute_func is not None
), f"Post-permute function for {runner_output_format} to {combine_input_format} is not registered"
return post_permute_func
def register_fused_func(
a2a_backend_name: str,
runner_backend_name: str,
) -> Callable:
"""
Decorator to register a fused function for the given DispatchOutputFormat and MoeRunnerBackend.
:param a2a_backend_name: The A2A backend name.
:param runner_backend_name: The MoeRunnerBackend name.
:return: The decorator function.
"""
def decorator(fused_func: Callable):
FusedOpPool.register_fused_func(
a2a_backend_name, runner_backend_name, fused_func
)
return fused_func
return decorator
def register_pre_permute(
dispatch_output_name: str,
runner_backend_name: str,
) -> Callable:
"""
Decorator to register a pre-permute function for the given DispatchOutputFormat and MoeRunnerBackend.
:param dispatch_output_name: The DispatchOutputFormat name.
:param runner_backend_name: The MoeRunnerBackend name.
:return: The decorator function.
"""
def decorator(
permute_func: Callable[
[DispatchOutput, MoeQuantInfo, MoeRunnerConfig, dict], RunnerInput
]
) -> Callable:
PermuteMethodPool.register_pre_permute(
dispatch_output_name, runner_backend_name, permute_func
)
return permute_func
return decorator
def register_post_permute(
runner_backend_name: str,
combine_input_name: str,
) -> Callable:
"""
Decorator to register a post-permute function for the given MoeRunnerBackend and CombineInputFormat.
:param runner_backend_name: The MoeRunnerBackend name.
:param combine_input_name: The CombineInputFormat name.
:return: The decorator function.
"""
def decorator(
permute_func: Callable[
[RunnerOutput, MoeQuantInfo, MoeRunnerConfig, dict], CombineInput
]
) -> Callable:
PermuteMethodPool.register_post_permute(
runner_backend_name, combine_input_name, permute_func
)
return permute_func
return decorator
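The decorators above are how a backend plugs into the runner: either register one fused function that maps a dispatch output directly to a combine input, or register a pre/post permute pair around a MoeRunnerCore. A condensed sketch of the fused route, mirroring the ("none", "triton") registration that appears in moe_runner/triton.py further down this diff (the body is elided and purely illustrative):

from sglang.srt.layers.moe.moe_runner.base import MoeRunnerConfig, register_fused_func
from sglang.srt.layers.moe.token_dispatcher import (
    StandardCombineInput,
    StandardDispatchOutput,
)


@register_fused_func("none", "triton")  # both strings must be valid backend enum values
def fused_experts_none_to_triton(
    dispatch_output: StandardDispatchOutput,
    quant_info: "TritonMoeQuantInfo",  # defined in moe_runner/triton.py
    runner_config: MoeRunnerConfig,
) -> StandardCombineInput:
    # The real implementation runs fused_experts(...) on
    # dispatch_output.hidden_states; elided in this sketch.
    ...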
from __future__ import annotations
import logging
import os
from typing import TYPE_CHECKING
from sglang.srt.layers.moe.moe_runner.base import (
FusedOpPool,
MoeRunnerConfig,
PermuteMethodPool,
)
from sglang.srt.layers.moe.moe_runner.triton import TritonRunnerCore
from sglang.srt.layers.moe.token_dispatcher.base import (
CombineInput,
CombineInputFormat,
DispatchOutput,
)
from sglang.srt.layers.moe.utils import get_moe_a2a_backend
if TYPE_CHECKING:
from sglang.srt.layers.moe.moe_runner.base import MoeQuantInfo
from sglang.srt.layers.moe.utils import MoeRunnerBackend
logger = logging.getLogger(__name__)
class MoeRunner:
def __init__(self, runner_backend: MoeRunnerBackend, config: MoeRunnerConfig):
self.runner_backend = runner_backend
self.config = config
self.fused_func = None
if runner_backend.is_triton():
self.runner_core = TritonRunnerCore(config)
else:
raise NotImplementedError(f"Unsupported runner backend: {runner_backend}")
a2a_backend_name = get_moe_a2a_backend().value
runner_backend_name = runner_backend.value
self.fused_func = FusedOpPool.get_fused_func(
a2a_backend_name, runner_backend_name
)
SGLANG_CI_DISABLE_MOE_FUSED_FUNC = os.environ.get(
"SGLANG_CI_DISABLE_MOE_FUSED_FUNC", "0"
)
if SGLANG_CI_DISABLE_MOE_FUSED_FUNC == "1":
logger.info(
"SGLANG_CI_DISABLE_MOE_FUSED_FUNC is set to 1, disabling fused func"
)
self.fused_func = None
def run(
self, dispatch_output: DispatchOutput, quant_info: MoeQuantInfo
) -> CombineInput:
if self.fused_func is not None:
return self.fused_func(dispatch_output, quant_info, self.config)
dispatch_format = dispatch_output.format.value
runner_format = self.runner_core.runner_backend.value
self.pre_permute_func = PermuteMethodPool.get_pre_permute(
dispatch_format, runner_format
)
running_state = {}
runner_input = self.pre_permute_func(
dispatch_output, quant_info, self.config, running_state
)
runner_output = self.runner_core.run(runner_input, quant_info, running_state)
runner_format = self.runner_core.runner_backend.value
combine_format = dispatch_output.format.value
self.post_permute_func = PermuteMethodPool.get_post_permute(
runner_format, combine_format
)
combine_input = self.post_permute_func(
runner_output, quant_info, self.config, running_state
)
return combine_input
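Quantization methods are the callers of this runner: create_moe_runner builds it once per layer, and apply packs the layer's weights into a backend-specific MoeQuantInfo and delegates to MoeRunner.run, which takes either the fused path or the pre-permute/run/post-permute path. A minimal sketch of that contract, patterned on the BlockInt8MoEMethod change later in this diff (the class name and weight attributes are illustrative):

import torch

from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo


class ExampleMoEMethod:  # illustrative stand-in for a FusedMoEMethodBase subclass
    def create_moe_runner(
        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
    ):
        self.moe_runner_config = moe_runner_config
        self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)

    def apply(self, layer: torch.nn.Module, dispatch_output):
        quant_info = TritonMoeQuantInfo(
            w13_weight=layer.w13_weight,  # weight attributes created in create_weights
            w2_weight=layer.w2_weight,
        )
        return self.runner.run(dispatch_output, quant_info)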
from __future__ import annotations
import functools
import os
from dataclasses import dataclass
from typing import TYPE_CHECKING, List, Optional
import torch
import triton.language as tl
from sglang.srt.layers.moe.moe_runner.base import (
MoeQuantInfo,
MoeRunnerConfig,
MoeRunnerCore,
RunnerInput,
RunnerOutput,
register_fused_func,
register_post_permute,
register_pre_permute,
)
from sglang.srt.layers.moe.token_dispatcher import (
StandardCombineInput,
StandardDispatchOutput,
)
from sglang.srt.layers.moe.utils import MoeRunnerBackend
from sglang.srt.utils import cpu_has_amx_support, is_cpu, is_cuda, is_hip
_is_hip = is_hip()
_is_cuda = is_cuda()
_is_cpu_amx_available = cpu_has_amx_support()
_is_cpu = is_cpu()
_use_aiter = bool(int(os.getenv("SGLANG_MOE_USE_AITER", "0")))
_MOE_PADDING_SIZE = 128 if bool(int(os.getenv("SGLANG_MOE_PADDING", "0"))) else 0
if _is_cuda:
from sgl_kernel import gelu_and_mul, silu_and_mul
elif _is_cpu and _is_cpu_amx_available:
pass
elif _is_hip:
from vllm import _custom_ops as vllm_ops # gelu_and_mul, silu_and_mul
if _use_aiter:
try:
from aiter import moe_sum
except ImportError:
raise ImportError("aiter is required when SGLANG_USE_AITER is set to True")
if _is_cuda or _is_hip:
from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size
@dataclass
class TritonRunnerInput(RunnerInput):
hidden_states: torch.Tensor
topk_weights: torch.Tensor
topk_ids: torch.Tensor
sorted_token_ids: torch.Tensor
expert_ids: torch.Tensor
num_tokens_post_padded: torch.Tensor
@property
def runner_backend(self) -> MoeRunnerBackend:
return MoeRunnerBackend.TRITON
@dataclass
class TritonRunnerOutput(RunnerOutput):
hidden_states: torch.Tensor
@property
def runner_backend(self) -> MoeRunnerBackend:
return MoeRunnerBackend.TRITON
@dataclass
class TritonMoeQuantInfo(MoeQuantInfo):
w13_weight: torch.Tensor
w2_weight: torch.Tensor
b13: Optional[torch.Tensor] = None
b2: Optional[torch.Tensor] = None
use_fp8_w8a8: bool = False
use_int8_w8a8: bool = False
use_int8_w8a16: bool = False
use_int4_w4a16: bool = False
per_channel_quant: bool = False
w13_scale: Optional[torch.Tensor] = None
w2_scale: Optional[torch.Tensor] = None
w13_zp: Optional[torch.Tensor] = None
w2_zp: Optional[torch.Tensor] = None
a13_scale: Optional[torch.Tensor] = None
a2_scale: Optional[torch.Tensor] = None
block_shape: Optional[List[int]] = None
class TritonRunnerCore(MoeRunnerCore):
def __init__(self, config: MoeRunnerConfig):
super().__init__(config)
def run(
self,
runner_input: TritonRunnerInput,
quant_info: TritonMoeQuantInfo,
running_state: dict,
) -> TritonRunnerOutput:
# TODO: move these functions to the triton runner
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
invoke_fused_moe_kernel,
moe_sum_reduce_torch_compile,
moe_sum_reduce_triton,
swiglu_with_alpha_and_limit,
)
hidden_states = runner_input.hidden_states
topk_weights = runner_input.topk_weights
topk_ids = runner_input.topk_ids
sorted_token_ids = runner_input.sorted_token_ids
expert_ids = runner_input.expert_ids
num_tokens_post_padded = runner_input.num_tokens_post_padded
w13 = quant_info.w13_weight
w2 = quant_info.w2_weight
b13 = quant_info.b13
b2 = quant_info.b2
a13_scale = quant_info.a13_scale
a2_scale = quant_info.a2_scale
w13_scale = quant_info.w13_scale
w2_scale = quant_info.w2_scale
w13_zp = quant_info.w13_zp
w2_zp = quant_info.w2_zp
block_shape = quant_info.block_shape
per_channel_quant = quant_info.per_channel_quant
use_fp8_w8a8 = quant_info.use_fp8_w8a8
use_int8_w8a8 = quant_info.use_int8_w8a8
use_int8_w8a16 = quant_info.use_int8_w8a16
use_int4_w4a16 = quant_info.use_int4_w4a16
activation = self.config.activation
no_combine = self.config.no_combine
inplace = self.config.inplace
gemm1_alpha = self.config.gemm1_alpha
gemm1_limit = self.config.gemm1_clamp_limit
routed_scaling_factor = self.config.routed_scaling_factor
apply_router_weight_on_input = self.config.apply_router_weight_on_input
M = hidden_states.shape[0]
E, N, _ = w13.shape
compute_type = (
tl.bfloat16 if hidden_states.dtype == torch.bfloat16 else tl.float16
)
intermediate_cache1 = torch.empty(
(M, topk_ids.shape[1], N),
device=hidden_states.device,
dtype=hidden_states.dtype,
)
invoke_fused_moe_kernel(
hidden_states,
w13,
b13,
intermediate_cache1,
a13_scale,
w13_scale,
w13_zp,
topk_weights,
topk_ids,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
apply_router_weight_on_input,
topk_ids.shape[1],
running_state["config"],
compute_type=compute_type,
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16,
per_channel_quant=per_channel_quant,
block_shape=block_shape,
)
intermediate_cache2 = torch.empty(
(M * topk_ids.shape[1], N // 2),
device=hidden_states.device,
dtype=hidden_states.dtype,
)
if activation == "silu":
if gemm1_alpha is not None:
assert gemm1_limit is not None
intermediate_cache2 = swiglu_with_alpha_and_limit(
intermediate_cache1.view(-1, N),
gemm1_alpha,
gemm1_limit,
)
elif _is_cuda:
silu_and_mul(intermediate_cache1.view(-1, N), intermediate_cache2)
else:
vllm_ops.silu_and_mul(
intermediate_cache2, intermediate_cache1.view(-1, N)
)
elif activation == "gelu":
assert gemm1_alpha is None, "gemm1_alpha is not supported for gelu"
assert gemm1_limit is None, "gemm1_limit is not supported for gelu"
if _is_cuda:
gelu_and_mul(intermediate_cache1.view(-1, N), intermediate_cache2)
else:
vllm_ops.gelu_and_mul(
intermediate_cache2, intermediate_cache1.view(-1, N)
)
else:
raise ValueError(f"Unsupported activation: {activation=}")
intermediate_cache3 = torch.empty(
(M, topk_ids.shape[1], w2.shape[1]),
device=hidden_states.device,
dtype=hidden_states.dtype,
)
if no_combine:
assert not inplace
out_hidden_states = torch.empty(
(M, topk_ids.shape[1], w2.shape[1]),
device=hidden_states.device,
dtype=hidden_states.dtype,
)
elif inplace:
out_hidden_states = hidden_states
else:
out_hidden_states = torch.empty_like(hidden_states)
invoke_fused_moe_kernel(
intermediate_cache2,
w2,
b2,
(
intermediate_cache3
if not no_combine and topk_ids.shape[1] != 1
else out_hidden_states.unsqueeze(0)
),
a2_scale,
w2_scale,
w2_zp,
topk_weights,
topk_ids,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
not apply_router_weight_on_input,
1,
running_state["config"],
compute_type=compute_type,
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16,
per_channel_quant=per_channel_quant,
block_shape=block_shape,
)
if routed_scaling_factor is None:
routed_scaling_factor = 1.0
if no_combine:
pass
elif _is_cuda:
if topk_ids.shape[1] == 1 and routed_scaling_factor == 1.0:
pass # we write directly into out_hidden_states
elif topk_ids.shape[1] == 2 and routed_scaling_factor == 1.0:
torch.add(
intermediate_cache3[:, 0],
intermediate_cache3[:, 1],
out=out_hidden_states,
).squeeze(dim=1)
else:
# According to micro benchmark results, torch.compile can get better performance for small token.
if M <= 32:
moe_sum_reduce_torch_compile(
intermediate_cache3.view(*intermediate_cache3.shape),
out_hidden_states,
routed_scaling_factor,
)
else:
moe_sum_reduce_triton(
intermediate_cache3.view(*intermediate_cache3.shape),
out_hidden_states,
routed_scaling_factor,
)
elif _is_hip:
if _use_aiter:
moe_sum(
intermediate_cache3.view(*intermediate_cache3.shape),
out_hidden_states,
)
else:
vllm_ops.moe_sum(
intermediate_cache3.view(*intermediate_cache3.shape),
out_hidden_states,
)
else:
vllm_ops.moe_sum(
intermediate_cache3.view(*intermediate_cache3.shape),
out_hidden_states,
)
return TritonRunnerOutput(
hidden_states=out_hidden_states,
)
@property
def runner_backend(self) -> MoeRunnerBackend:
return MoeRunnerBackend.TRITON
@register_fused_func("none", "triton")
def fused_experts_none_to_triton(
dispatch_output: StandardDispatchOutput,
quant_info: TritonMoeQuantInfo,
runner_config: MoeRunnerConfig,
) -> StandardCombineInput:
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
output = fused_experts(
hidden_states=dispatch_output.hidden_states,
w1=quant_info.w13_weight,
w2=quant_info.w2_weight,
topk_output=dispatch_output.topk_output,
moe_runner_config=runner_config,
b1=quant_info.b13,
b2=quant_info.b2,
use_fp8_w8a8=quant_info.use_fp8_w8a8,
use_int8_w8a8=quant_info.use_int8_w8a8,
use_int8_w8a16=quant_info.use_int8_w8a16,
use_int4_w4a16=quant_info.use_int4_w4a16,
per_channel_quant=quant_info.per_channel_quant,
w1_scale=quant_info.w13_scale,
w2_scale=quant_info.w2_scale,
w1_zp=quant_info.w13_zp,
w2_zp=quant_info.w2_zp,
a1_scale=quant_info.a13_scale,
a2_scale=quant_info.a2_scale,
block_shape=quant_info.block_shape,
)
return StandardCombineInput(
hidden_states=output,
)
@register_pre_permute("standard", "triton")
def pre_permute_standard_to_triton(
dispatch_output: StandardDispatchOutput,
quant_info: TritonMoeQuantInfo,
runner_config: MoeRunnerConfig,
running_state: dict,
) -> TritonRunnerInput:
# NOTE: this is dead code as a fused func for standard format is registered.
# This is left here for testing and examples.
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
get_config_dtype_str,
moe_align_block_size,
try_get_optimal_moe_config,
)
from sglang.srt.layers.moe.topk import TopKOutputChecker
hidden_states, topk_output = dispatch_output
assert TopKOutputChecker.format_is_standard(topk_output)
num_tokens = hidden_states.shape[0]
num_local_experts = runner_config.num_local_experts
if (
not (quant_info.use_fp8_w8a8 or quant_info.use_int8_w8a8)
or quant_info.block_shape is not None
or _use_aiter
):
padding_size = 0
else:
padding_size = _MOE_PADDING_SIZE
config_dtype = get_config_dtype_str(
use_fp8_w8a8=quant_info.use_fp8_w8a8,
use_int8_w8a8=quant_info.use_int8_w8a8,
use_int8_w8a16=quant_info.use_int8_w8a16,
use_int4_w4a16=quant_info.use_int4_w4a16,
dtype=hidden_states.dtype,
)
get_config_func = functools.partial(
try_get_optimal_moe_config,
quant_info.w13_weight.shape,
(
num_local_experts,
quant_info.w2_weight.shape[1],
quant_info.w2_weight.shape[2] - padding_size,
),
topk_output.topk_ids.shape[1],
config_dtype,
block_shape=quant_info.block_shape,
)
config = get_config_func(num_tokens)
sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
topk_output.topk_ids, config["BLOCK_SIZE_M"], num_local_experts
)
running_state["config"] = config
return TritonRunnerInput(
hidden_states=hidden_states,
topk_weights=topk_output.topk_weights,
topk_ids=topk_output.topk_ids,
sorted_token_ids=sorted_token_ids,
expert_ids=expert_ids,
num_tokens_post_padded=num_tokens_post_padded,
)
@register_post_permute("triton", "standard")
def post_permute_triton_to_standard(
runner_output: TritonRunnerOutput,
quant_info: TritonMoeQuantInfo,
runner_config: MoeRunnerConfig,
running_state: dict,
) -> StandardCombineInput:
# NOTE: this is dead code as a fused func for standard format is registered.
# This is left here for testing and examples.
return StandardCombineInput(
hidden_states=runner_output.hidden_states,
)
-from sglang.srt.layers.moe.token_dispatcher.base_dispatcher import (
+from sglang.srt.layers.moe.token_dispatcher.base import (
BaseDispatcher,
BaseDispatcherConfig,
+CombineInput,
+CombineInputChecker,
+CombineInputFormat,
DispatchOutput,
DispatchOutputChecker,
DispatchOutputFormat,
@@ -9,21 +12,32 @@ from sglang.srt.layers.moe.token_dispatcher.deepep import (
AscendDeepEPLLOutput,
DeepEPConfig,
DeepEPDispatcher,
+DeepEPLLCombineInput,
DeepEPLLOutput,
+DeepEPNormalCombineInput,
DeepEPNormalOutput,
)
-from sglang.srt.layers.moe.token_dispatcher.standard import StandardDispatchOutput
+from sglang.srt.layers.moe.token_dispatcher.standard import (
+StandardCombineInput,
+StandardDispatchOutput,
+)
__all__ = [
"AscendDeepEPLLOutput",
"BaseDispatcher",
"BaseDispatcherConfig",
+"CombineInput",
+"CombineInputChecker",
+"CombineInputFormat",
"DispatchOutput",
"DispatchOutputFormat",
"DispatchOutputChecker",
"StandardDispatchOutput",
+"StandardCombineInput",
"DeepEPConfig",
"DeepEPDispatcher",
"DeepEPNormalOutput",
"DeepEPLLOutput",
+"DeepEPLLCombineInput",
+"DeepEPNormalCombineInput",
]
from __future__ import annotations
from abc import ABC, abstractmethod
-from enum import Enum, auto
+from enum import Enum
from typing import TYPE_CHECKING, Protocol, TypeGuard, Union, runtime_checkable
import torch
@@ -9,10 +9,16 @@ import torch
if TYPE_CHECKING:
from sglang.srt.layers.moe.token_dispatcher import (
AscendDeepEPLLOutput,
+DeepEPLLCombineInput,
DeepEPLLOutput,
+DeepEPNormalCombineInput,
DeepEPNormalOutput,
+StandardCombineInput,
StandardDispatchOutput,
)
+from sglang.srt.layers.moe.topk import TopKOutput
+# ------------------------------ Dispatch Output -------------------------------------
class DispatchOutputChecker:
@@ -50,10 +56,10 @@ class DispatchOutputChecker:
class DispatchOutputFormat(Enum):
-STANDARD = auto()
-DEEPEP_NORMAL = auto()
-DEEPEP_LL = auto()
-ASCENT_LL = auto()
+STANDARD = "standard"
+DEEPEP_NORMAL = "deepep_normal"
+DEEPEP_LL = "deepep_ll"
+ASCENT_LL = "ascent_ll"
def is_standard(self) -> bool:
return self == DispatchOutputFormat.STANDARD
@@ -78,10 +84,63 @@ class DispatchOutputFormat(Enum):
class DispatchOutput(Protocol):
"""Protocol for dispatch outputs in different formats."""
+# TODO: add hidden_states to the protocol
@property
def format(self) -> DispatchOutputFormat: ...
# ------------------------------ Combine Input -------------------------------------
class CombineInputChecker:
@staticmethod
def format_is_standard(
combine_input: CombineInput,
) -> TypeGuard[StandardCombineInput]:
return combine_input.format == CombineInputFormat.STANDARD
@staticmethod
def format_is_deepep_normal(
combine_input: CombineInput,
) -> TypeGuard[DeepEPNormalCombineInput]:
return combine_input.format == CombineInputFormat.DEEPEP_NORMAL
@staticmethod
def format_is_deepep_ll(
combine_input: CombineInput,
) -> TypeGuard[DeepEPLLCombineInput]:
return combine_input.format == CombineInputFormat.DEEPEP_LL
@staticmethod
def format_is_deepep(
combine_input: CombineInput,
) -> TypeGuard[Union[DeepEPNormalCombineInput, DeepEPLLCombineInput]]:
return combine_input.format in [
CombineInputFormat.DEEPEP_NORMAL,
CombineInputFormat.DEEPEP_LL,
]
class CombineInputFormat(Enum):
STANDARD = "standard"
DEEPEP_NORMAL = "deepep_normal"
DEEPEP_LL = "deepep_ll"
@runtime_checkable
class CombineInput(Protocol):
"""Protocol for combine inputs in different formats."""
# TODO: add hidden_states to the protocol
@property
def format(self) -> CombineInputFormat: ...
# ------------------------------ Base Dispatcher -------------------------------------
class BaseDispatcherConfig(ABC):
"""Base class for dispatcher configs."""
@@ -92,9 +151,11 @@ class BaseDispatcher(ABC):
"""Base class for dispatchers."""
@abstractmethod
-def dispatch(self, *args, **kwargs) -> DispatchOutput:
+def dispatch(
+self, hidden_states: torch.Tensor, topk_output: TopKOutput, **kwargs
+) -> DispatchOutput:
pass
@abstractmethod
-def combine(self, *args, **kwargs) -> torch.Tensor:
+def combine(self, combine_input: CombineInput, **kwargs) -> torch.Tensor:
pass
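The checker classes exist so call sites can branch on the runtime format while static type checkers narrow the protocol type via TypeGuard. A small sketch of how a caller might unwrap a combine input (the function name is illustrative):

import torch

from sglang.srt.layers.moe.token_dispatcher import CombineInput, CombineInputChecker


def unwrap_hidden_states(combine_input: CombineInput) -> torch.Tensor:
    if CombineInputChecker.format_is_standard(combine_input):
        # Narrowed to StandardCombineInput here, so .hidden_states is well-typed.
        return combine_input.hidden_states
    raise NotImplementedError(
        f"Unsupported combine input format: {combine_input.format}"
    )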
@@ -5,13 +5,15 @@ from dataclasses import dataclass
from typing import TYPE_CHECKING, List, NamedTuple, Optional, Tuple, Union
from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
-from sglang.srt.layers.moe import DeepEPMode, get_deepep_config, is_tbo_enabled
-from sglang.srt.layers.moe.token_dispatcher.base_dispatcher import (
+from sglang.srt.layers.moe.token_dispatcher.base import (
BaseDispatcher,
BaseDispatcherConfig,
+CombineInput,
+CombineInputFormat,
DispatchOutput,
DispatchOutputFormat,
)
+from sglang.srt.layers.moe.utils import DeepEPMode, get_deepep_config, is_tbo_enabled
from sglang.srt.layers.quantization import deep_gemm_wrapper
from sglang.srt.utils import (
get_bool_env_var,
@@ -56,6 +58,7 @@ class DeepEPNormalOutput(NamedTuple):
"""DeepEP normal dispatch output."""
hidden_states: torch.Tensor | Tuple[torch.Tensor, torch.Tensor]
+# hidden_states_scale
topk_idx: torch.Tensor
topk_weights: torch.Tensor
num_recv_tokens_per_expert: List[int]
@@ -99,6 +102,30 @@ assert isinstance(DeepEPLLOutput, DispatchOutput)
assert isinstance(AscendDeepEPLLOutput, DispatchOutput)
class DeepEPNormalCombineInput(NamedTuple):
"""DeepEP normal combine input."""
pass
@property
def format(self) -> CombineInputFormat:
return CombineInputFormat.DEEPEP_NORMAL
class DeepEPLLCombineInput(NamedTuple):
"""DeepEP low latency combine input."""
pass
@property
def format(self) -> CombineInputFormat:
return CombineInputFormat.DEEPEP_LL
assert isinstance(DeepEPNormalCombineInput, CombineInput)
assert isinstance(DeepEPLLCombineInput, CombineInput)
class DeepEPDispatchMode(IntEnum):
NORMAL = auto()
LOW_LATENCY = auto()
...
from __future__ import annotations
-from typing import NamedTuple
+from typing import TYPE_CHECKING, NamedTuple
+import torch
-from sglang.srt.layers.moe.token_dispatcher.base_dispatcher import (
+from sglang.srt.layers.moe.token_dispatcher.base import (
+BaseDispatcher,
+CombineInput,
+CombineInputFormat,
DispatchOutput,
DispatchOutputFormat,
)
+if TYPE_CHECKING:
+from sglang.srt.layers.moe.topk import TopKOutput
class StandardDispatchOutput(NamedTuple):
"""Standard dispatch output."""
+hidden_states: torch.Tensor
+topk_output: TopKOutput
@property
def format(self) -> DispatchOutputFormat:
return DispatchOutputFormat.STANDARD
assert isinstance(StandardDispatchOutput, DispatchOutput)
class StandardCombineInput(NamedTuple):
"""Standard combine input."""
hidden_states: torch.Tensor
@property
def format(self) -> CombineInputFormat:
return CombineInputFormat.STANDARD
assert isinstance(StandardCombineInput, CombineInput)
class StandardDispatcher(BaseDispatcher):
def dispatch(
self, hidden_states: torch.Tensor, topk_output: TopKOutput
) -> DispatchOutput:
return StandardDispatchOutput(
hidden_states=hidden_states, topk_output=topk_output
)
def combine(self, combine_input: CombineInput) -> torch.Tensor:
if isinstance(combine_input, StandardCombineInput):
return combine_input.hidden_states
else:
# TODO: this branch should be removed in the future
assert isinstance(combine_input, torch.Tensor)
return combine_input
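On the standard (non-expert-parallel) path the dispatcher is a thin pass-through that fixes the dispatch/combine interface used by FusedMoE.forward. A minimal usage sketch (the input tensors, topk_output, and expert_output are placeholders, not from this diff):

from sglang.srt.layers.moe.token_dispatcher.standard import (
    StandardCombineInput,
    StandardDispatcher,
)

dispatcher = StandardDispatcher()
# x is the token hidden states, topk_output comes from the router.
dispatch_output = dispatcher.dispatch(hidden_states=x, topk_output=topk_output)
# ... the quant method / MoeRunner turns the dispatch output into a combine input ...
out = dispatcher.combine(StandardCombineInput(hidden_states=expert_output))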
from __future__ import annotations
import importlib.util
+import logging
from enum import Enum
from functools import lru_cache
from typing import TYPE_CHECKING, Optional
@@ -12,11 +13,12 @@ from sglang.srt.layers.dp_attention import (
get_attention_dp_size,
is_dp_attention_enabled,
)
-from sglang.srt.utils import logger
if TYPE_CHECKING:
from sglang.srt.server_args import ServerArgs
+logger = logging.getLogger(__name__)
class MoeA2ABackend(Enum):
@@ -131,7 +133,7 @@ def get_moe_a2a_backend() -> MoeA2ABackend:
global MOE_A2A_BACKEND
if MOE_A2A_BACKEND is None:
logger.warning("MOE_A2A_BACKEND is not initialized, using default backend")
-MOE_A2A_BACKEND = MoeA2ABackend(None)
+MOE_A2A_BACKEND = MoeA2ABackend.NONE
return MOE_A2A_BACKEND
@@ -139,7 +141,7 @@ def get_moe_runner_backend() -> MoeRunnerBackend:
global MOE_RUNNER_BACKEND
if MOE_RUNNER_BACKEND is None:
logger.warning("MOE_RUNNER_BACKEND is not initialized, using triton backend")
-MOE_RUNNER_BACKEND = MoeRunnerBackend("triton")
+MOE_RUNNER_BACKEND = MoeRunnerBackend.AUTO
return MOE_RUNNER_BACKEND
@@ -147,7 +149,7 @@ def get_deepep_mode() -> DeepEPMode:
global DEEPEP_MODE
if DEEPEP_MODE is None:
logger.warning("DEEPEP_MODE is not initialized, using auto mode")
-DEEPEP_MODE = DeepEPMode("auto")
+DEEPEP_MODE = DeepEPMode.AUTO
return DEEPEP_MODE
...
@@ -34,7 +34,10 @@ from sglang.srt.layers.quantization.utils import get_scalar_types, replace_param
if TYPE_CHECKING:
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
-from sglang.srt.layers.moe.topk import StandardTopKOutput
+from sglang.srt.layers.moe.token_dispatcher import (
+StandardDispatchOutput,
+CombineInput,
+)
from sglang.srt.utils import is_cuda, is_hip
@@ -736,24 +739,32 @@ class AWQMoEMethod(FusedMoEMethodBase):
)
replace_parameter(layer, "w2_qzeros", marlin_w2_zp)
+def create_moe_runner(
+self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
+):
+self.moe_runner_config = moe_runner_config
def apply(
self,
layer: torch.nn.Module,
-x: torch.Tensor,
-topk_output: StandardTopKOutput,
-moe_runner_config: MoeRunnerConfig,
-) -> torch.Tensor:
+dispatch_output: StandardDispatchOutput,
+) -> CombineInput:
+from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
assert (
-moe_runner_config.activation == "silu"
+self.moe_runner_config.activation == "silu"
), "Only SiLU activation is supported."
# The input must currently be float16
+x = dispatch_output.hidden_states
+topk_output = dispatch_output.topk_output
orig_dtype = x.dtype
x = x.half()
topk_weights, topk_ids, router_logits = topk_output
-return fused_marlin_moe(
+output = fused_marlin_moe(
x,
layer.w13_qweight,
layer.w2_qweight,
@@ -768,3 +779,4 @@ class AWQMoEMethod(FusedMoEMethodBase):
w2_zeros=layer.w2_qzeros,
num_bits=self.quant_config.weight_bits,
).to(orig_dtype)
+return StandardCombineInput(hidden_states=output)
@@ -3,6 +3,7 @@ from __future__ import annotations
import inspect
from abc import ABC, abstractmethod
+from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type
import torch
@@ -10,7 +11,7 @@ from torch import nn
if TYPE_CHECKING:
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
-from sglang.srt.layers.moe.topk import TopKOutput
+from sglang.srt.layers.moe.token_dispatcher import CombineInput, DispatchOutput
class QuantizeMethodBase(ABC):
@@ -89,20 +90,24 @@ class FusedMoEMethodBase(QuantizeMethodBase):
layer: torch.nn.Module,
num_experts: int,
hidden_size: int,
-intermediate_size: int,
+intermediate_size_per_partition: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
raise NotImplementedError
+@abstractmethod
+def create_moe_runner(
+self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
+):
+raise NotImplementedError
@abstractmethod
def apply(
self,
layer: torch.nn.Module,
-x: torch.Tensor,
-topk_output: TopKOutput,
-moe_runner_config: MoeRunnerConfig,
-) -> torch.Tensor:
+dispatch_output: DispatchOutput,
+) -> CombineInput:
raise NotImplementedError
...
...@@ -9,6 +9,8 @@ import torch ...@@ -9,6 +9,8 @@ import torch
from torch.nn import Module from torch.nn import Module
from sglang.srt.distributed import get_tensor_model_parallel_world_size from sglang.srt.distributed import get_tensor_model_parallel_world_size
from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo
from sglang.srt.layers.parameter import BlockQuantScaleParameter, ModelWeightParameter from sglang.srt.layers.parameter import BlockQuantScaleParameter, ModelWeightParameter
from sglang.srt.layers.quantization.base_config import ( from sglang.srt.layers.quantization.base_config import (
FusedMoEMethodBase, FusedMoEMethodBase,
...@@ -22,8 +24,10 @@ from sglang.srt.layers.quantization.utils import is_layer_skipped ...@@ -22,8 +24,10 @@ from sglang.srt.layers.quantization.utils import is_layer_skipped
from sglang.srt.utils import set_weight_attrs from sglang.srt.utils import set_weight_attrs
if TYPE_CHECKING: if TYPE_CHECKING:
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig from sglang.srt.layers.moe.token_dispatcher import (
from sglang.srt.layers.moe.topk import TopKOutput CombineInput,
StandardDispatchOutput,
)
ACTIVATION_SCHEMES = ["static", "dynamic"] ACTIVATION_SCHEMES = ["static", "dynamic"]
...@@ -257,7 +261,7 @@ class BlockInt8MoEMethod(FusedMoEMethodBase): ...@@ -257,7 +261,7 @@ class BlockInt8MoEMethod(FusedMoEMethodBase):
layer: Module, layer: Module,
num_experts: int, num_experts: int,
hidden_size: int, hidden_size: int,
intermediate_size: int, intermediate_size_per_partition: int,
params_dtype: torch.dtype, params_dtype: torch.dtype,
**extra_weight_attrs, **extra_weight_attrs,
): ):
...@@ -273,25 +277,28 @@ class BlockInt8MoEMethod(FusedMoEMethodBase): ...@@ -273,25 +277,28 @@ class BlockInt8MoEMethod(FusedMoEMethodBase):
) )
# NOTE(HandH1998): To ensure proper alignment of the block-wise quantization scales, the output_size of the weights for both the gate and up layers must be divisible by block_n. # NOTE(HandH1998): To ensure proper alignment of the block-wise quantization scales, the output_size of the weights for both the gate and up layers must be divisible by block_n.
# Required by column parallel or enabling merged weights # Required by column parallel or enabling merged weights
if intermediate_size % block_n != 0: if intermediate_size_per_partition % block_n != 0:
raise ValueError( raise ValueError(
f"The output_size of gate's and up's weight = " f"The output_size of gate's and up's weight = "
f"{intermediate_size} is not divisible by " f"{intermediate_size_per_partition} is not divisible by "
f"weight quantization block_n = {block_n}." f"weight quantization block_n = {block_n}."
) )
if tp_size > 1: if tp_size > 1:
# Required by row parallel # Required by row parallel
if intermediate_size % block_k != 0: if intermediate_size_per_partition % block_k != 0:
raise ValueError( raise ValueError(
f"The input_size of down's weight = " f"The input_size of down's weight = "
f"{intermediate_size} is not divisible by " f"{intermediate_size_per_partition} is not divisible by "
f"weight quantization block_k = {block_k}." f"weight quantization block_k = {block_k}."
) )
# WEIGHTS # WEIGHTS
w13_weight = torch.nn.Parameter( w13_weight = torch.nn.Parameter(
torch.empty( torch.empty(
num_experts, 2 * intermediate_size, hidden_size, dtype=params_dtype num_experts,
2 * intermediate_size_per_partition,
hidden_size,
dtype=params_dtype,
), ),
requires_grad=False, requires_grad=False,
) )
...@@ -300,7 +307,10 @@ class BlockInt8MoEMethod(FusedMoEMethodBase): ...@@ -300,7 +307,10 @@ class BlockInt8MoEMethod(FusedMoEMethodBase):
w2_weight = torch.nn.Parameter( w2_weight = torch.nn.Parameter(
torch.empty( torch.empty(
num_experts, hidden_size, intermediate_size, dtype=params_dtype num_experts,
hidden_size,
intermediate_size_per_partition,
dtype=params_dtype,
), ),
requires_grad=False, requires_grad=False,
) )
...@@ -311,7 +321,7 @@ class BlockInt8MoEMethod(FusedMoEMethodBase): ...@@ -311,7 +321,7 @@ class BlockInt8MoEMethod(FusedMoEMethodBase):
w13_weight_scale = torch.nn.Parameter( w13_weight_scale = torch.nn.Parameter(
torch.ones( torch.ones(
num_experts, num_experts,
2 * ((intermediate_size + block_n - 1) // block_n), 2 * ((intermediate_size_per_partition + block_n - 1) // block_n),
(hidden_size + block_k - 1) // block_k, (hidden_size + block_k - 1) // block_k,
dtype=torch.float32, dtype=torch.float32,
), ),
...@@ -321,7 +331,7 @@ class BlockInt8MoEMethod(FusedMoEMethodBase): ...@@ -321,7 +331,7 @@ class BlockInt8MoEMethod(FusedMoEMethodBase):
torch.ones( torch.ones(
num_experts, num_experts,
(hidden_size + block_n - 1) // block_n, (hidden_size + block_n - 1) // block_n,
(intermediate_size + block_k - 1) // block_k, (intermediate_size_per_partition + block_k - 1) // block_k,
dtype=torch.float32, dtype=torch.float32,
), ),
requires_grad=False, requires_grad=False,
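To make the block_n/block_k alignment requirement and the resulting scale-tensor shapes above concrete, here is a small numeric check; the sizes are made-up examples, not values from any particular model in this diff.

# --- illustrative sketch, not part of the diff ---
num_experts = 8
hidden_size = 4096
intermediate_size_per_partition = 1536  # example per-TP-rank size
block_n, block_k = 128, 128

def ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b

# Alignment checks mirroring the ValueErrors above
assert intermediate_size_per_partition % block_n == 0  # gate/up output dim
assert intermediate_size_per_partition % block_k == 0  # down input dim when tp_size > 1

w13_scale_shape = (num_experts,
                   2 * ceil_div(intermediate_size_per_partition, block_n),
                   ceil_div(hidden_size, block_k))
w2_scale_shape = (num_experts,
                  ceil_div(hidden_size, block_n),
                  ceil_div(intermediate_size_per_partition, block_k))
print(w13_scale_shape, w2_scale_shape)  # (8, 24, 32) (8, 32, 12)
# --- end sketch ---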
...@@ -344,26 +354,27 @@ class BlockInt8MoEMethod(FusedMoEMethodBase): ...@@ -344,26 +354,27 @@ class BlockInt8MoEMethod(FusedMoEMethodBase):
# Block quant doesn't need to process weights after loading # Block quant doesn't need to process weights after loading
return return
def create_moe_runner(
self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
):
self.moe_runner_config = moe_runner_config
self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)
def apply( def apply(
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
x: torch.Tensor, dispatch_output: StandardDispatchOutput,
topk_output: TopKOutput, ) -> CombineInput:
moe_runner_config: MoeRunnerConfig,
) -> torch.Tensor: quant_info = TritonMoeQuantInfo(
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts w13_weight=layer.w13_weight,
w2_weight=layer.w2_weight,
# Expert fusion with INT8 quantization
return fused_experts(
x,
layer.w13_weight,
layer.w2_weight,
topk_output=topk_output,
moe_runner_config=moe_runner_config,
use_int8_w8a8=True, use_int8_w8a8=True,
w1_scale=(layer.w13_weight_scale_inv), w13_scale=layer.w13_weight_scale_inv,
w2_scale=(layer.w2_weight_scale_inv), w2_scale=layer.w2_weight_scale_inv,
a1_scale=layer.w13_input_scale, a13_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale, a2_scale=layer.w2_input_scale,
block_shape=self.quant_config.weight_block_size, block_shape=self.quant_config.weight_block_size,
) )
return self.runner.run(dispatch_output, quant_info)
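For anyone porting other quant methods to the new runner, the keyword renaming visible in the hunk above (old fused_experts kwargs versus TritonMoeQuantInfo fields) is summarised in this small sketch; it only covers the fields that appear in this diff.

# --- illustrative sketch, not part of the diff ---
# Mapping observed above: old fused_experts kwarg -> TritonMoeQuantInfo field.
# topk_output and moe_runner_config are no longer per-call arguments; they arrive via
# the dispatch output and create_moe_runner() respectively.
RENAMES = {"w1_scale": "w13_scale", "a1_scale": "a13_scale"}

def to_quant_info_kwargs(old_kwargs: dict) -> dict:
    return {RENAMES.get(name, name): value for name, value in old_kwargs.items()}

old = {"w1_scale": "w13_weight_scale_inv", "w2_scale": "w2_weight_scale_inv",
       "a1_scale": "w13_input_scale", "a2_scale": "w2_input_scale"}
print(to_quant_info_kwargs(old))
# {'w13_scale': 'w13_weight_scale_inv', 'w2_scale': 'w2_weight_scale_inv',
#  'a13_scale': 'w13_input_scale', 'a2_scale': 'w2_input_scale'}
# --- end sketch ---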
...@@ -11,6 +11,8 @@ import torch ...@@ -11,6 +11,8 @@ import torch
from compressed_tensors import CompressionFormat from compressed_tensors import CompressionFormat
from compressed_tensors.quantization import QuantizationStrategy from compressed_tensors.quantization import QuantizationStrategy
from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo
from sglang.srt.layers.quantization.base_config import FusedMoEMethodBase from sglang.srt.layers.quantization.base_config import FusedMoEMethodBase
from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz, scaled_fp8_quant from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz, scaled_fp8_quant
from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz
...@@ -30,8 +32,10 @@ from sglang.srt.utils import ( ...@@ -30,8 +32,10 @@ from sglang.srt.utils import (
if TYPE_CHECKING: if TYPE_CHECKING:
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig from sglang.srt.layers.moe.token_dispatcher import (
from sglang.srt.layers.moe.topk import TopKOutput CombineInput,
StandardDispatchOutput,
)
from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors import ( from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors import (
CompressedTensorsConfig, CompressedTensorsConfig,
) )
...@@ -293,14 +297,24 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): ...@@ -293,14 +297,24 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
) )
torch.cuda.empty_cache() torch.cuda.empty_cache()
def create_moe_runner(
self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
):
self.moe_runner_config = moe_runner_config
self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)
def apply( def apply(
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
x: torch.Tensor, dispatch_output: StandardDispatchOutput,
topk_output: TopKOutput, ) -> CombineInput:
moe_runner_config: MoeRunnerConfig,
) -> torch.Tensor: from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
from sglang.srt.layers.moe.fused_moe_triton import fused_experts
x = dispatch_output.hidden_states
topk_output = dispatch_output.topk_output
moe_runner_config = self.moe_runner_config
if ( if (
_use_aiter _use_aiter
...@@ -308,7 +322,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): ...@@ -308,7 +322,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
and moe_runner_config.apply_router_weight_on_input and moe_runner_config.apply_router_weight_on_input
): ):
topk_weights, topk_ids, _ = topk_output topk_weights, topk_ids, _ = topk_output
return rocm_fused_experts_tkw1( output = rocm_fused_experts_tkw1(
hidden_states=x, hidden_states=x,
w1=layer.w13_weight, w1=layer.w13_weight,
w2=layer.w2_weight, w2=layer.w2_weight,
...@@ -324,21 +338,20 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): ...@@ -324,21 +338,20 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
a1_scale=layer.w13_input_scale, a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale, a2_scale=layer.w2_input_scale,
) )
return StandardCombineInput(hidden_states=output)
else: else:
return fused_experts( quant_info = TritonMoeQuantInfo(
x, w13_weight=layer.w13_weight,
layer.w13_weight, w2_weight=layer.w2_weight,
layer.w2_weight,
topk_output=topk_output,
moe_runner_config=moe_runner_config,
use_fp8_w8a8=True, use_fp8_w8a8=True,
per_channel_quant=self.weight_quant.strategy per_channel_quant=self.weight_quant.strategy
== QuantizationStrategy.CHANNEL, == QuantizationStrategy.CHANNEL,
w1_scale=layer.w13_weight_scale, w13_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale, w2_scale=layer.w2_weight_scale,
a1_scale=layer.w13_input_scale, a13_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale, a2_scale=layer.w2_input_scale,
) )
return self.runner.run(dispatch_output, quant_info)
class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
...@@ -380,8 +393,6 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): ...@@ -380,8 +393,6 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
params_dtype == torch.float16 params_dtype == torch.float16
), "float16 is required for MoE compressed models. Set dtype=torch.float16" # noqa: E501 ), "float16 is required for MoE compressed models. Set dtype=torch.float16" # noqa: E501
intermediate_size_full = extra_weight_attrs.pop("intermediate_size_full")
# Will transpose the loaded weight along the # Will transpose the loaded weight along the
# intermediate and hidden dim sizes. Will # intermediate and hidden dim sizes. Will
# shard for TP along the transposed dims # shard for TP along the transposed dims
...@@ -415,13 +426,13 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): ...@@ -415,13 +426,13 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
# In the case where we have actorder/g_idx, # In the case where we have actorder/g_idx,
# we do not partition the w2 scales # we do not partition the w2 scales
load_full_w2 = self.actorder and self.group_size != -1 load_full_w2 = self.actorder and self.group_size != -1
w2_scales_size = (
intermediate_size_full if load_full_w2 else intermediate_size_per_partition
)
self.is_k_full = (not self.actorder) or ( if load_full_w2:
intermediate_size_per_partition == intermediate_size_full w2_scales_size = intermediate_size_per_partition * layer.moe_tp_size
) else:
w2_scales_size = intermediate_size_per_partition
self.is_k_full = (not self.actorder) or layer.moe_tp_size == 1
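The removed intermediate_size_full extra weight attribute is reconstructed here as intermediate_size_per_partition * layer.moe_tp_size, and is_k_full is rewritten in the same spirit. A quick sanity check of that equivalence, under the assumption that the full intermediate size equals the per-partition size times the MoE TP degree (values below are illustrative):

# --- illustrative sketch, not part of the diff ---
moe_tp_size = 4
intermediate_size_per_partition = 1536
intermediate_size_full = intermediate_size_per_partition * moe_tp_size  # assumed relation

actorder, group_size = True, 128
load_full_w2 = actorder and group_size != -1

# New formulation (this diff)
w2_scales_size = (intermediate_size_per_partition * moe_tp_size
                  if load_full_w2 else intermediate_size_per_partition)
is_k_full = (not actorder) or moe_tp_size == 1

# Old formulation (removed lines above)
assert w2_scales_size == (intermediate_size_full if load_full_w2
                          else intermediate_size_per_partition)
assert is_k_full == ((not actorder) or
                     intermediate_size_per_partition == intermediate_size_full)
print(w2_scales_size, is_k_full)  # 6144 False
# --- end sketch ---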
if self.strategy == "channel": if self.strategy == "channel":
num_groups_w2 = num_groups_w13 = 1 num_groups_w2 = num_groups_w13 = 1
...@@ -640,21 +651,29 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): ...@@ -640,21 +651,29 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
) )
replace_tensor("w2_weight_scale", marlin_w2_scales) replace_tensor("w2_weight_scale", marlin_w2_scales)
def create_moe_runner(
self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
):
self.moe_runner_config = moe_runner_config
def apply( def apply(
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
x: torch.Tensor, dispatch_output: StandardDispatchOutput,
topk_output: TopKOutput, ) -> CombineInput:
moe_runner_config: MoeRunnerConfig,
) -> torch.Tensor: from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
assert ( assert (
moe_runner_config.activation == "silu" self.moe_runner_config.activation == "silu"
), "Only SiLU activation is supported." ), "Only SiLU activation is supported."
x = dispatch_output.hidden_states
topk_output = dispatch_output.topk_output
topk_weights, topk_ids, router_logits = topk_output topk_weights, topk_ids, router_logits = topk_output
return torch.ops.vllm.fused_marlin_moe( output = torch.ops.vllm.fused_marlin_moe(
x, x,
layer.w13_weight_packed, layer.w13_weight_packed,
layer.w2_weight_packed, layer.w2_weight_packed,
...@@ -670,3 +689,4 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): ...@@ -670,3 +689,4 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
num_bits=self.num_bits, num_bits=self.num_bits,
is_k_full=self.is_k_full, is_k_full=self.is_k_full,
) )
return StandardCombineInput(hidden_states=output)
...@@ -30,6 +30,9 @@ except ImportError: ...@@ -30,6 +30,9 @@ except ImportError:
from sglang.srt.distributed import get_tensor_model_parallel_world_size from sglang.srt.distributed import get_tensor_model_parallel_world_size
from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading
from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo
from sglang.srt.layers.moe.token_dispatcher.base import DispatchOutputChecker
from sglang.srt.layers.parameter import ( from sglang.srt.layers.parameter import (
BlockQuantScaleParameter, BlockQuantScaleParameter,
ModelWeightParameter, ModelWeightParameter,
...@@ -81,7 +84,11 @@ from sglang.srt.utils import ( ...@@ -81,7 +84,11 @@ from sglang.srt.utils import (
) )
if TYPE_CHECKING: if TYPE_CHECKING:
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig from sglang.srt.layers.moe.token_dispatcher import (
CombineInput,
DispatchOutput,
StandardDispatchOutput,
)
from sglang.srt.layers.moe.topk import TopKOutput from sglang.srt.layers.moe.topk import TopKOutput
from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config
...@@ -527,7 +534,7 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -527,7 +534,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
layer: Module, layer: Module,
num_experts: int, num_experts: int,
hidden_size: int, hidden_size: int,
intermediate_size: int, intermediate_size_per_partition: int,
params_dtype: torch.dtype, params_dtype: torch.dtype,
**extra_weight_attrs, **extra_weight_attrs,
): ):
...@@ -543,18 +550,18 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -543,18 +550,18 @@ class Fp8MoEMethod(FusedMoEMethodBase):
) )
# NOTE(HandH1998): To ensure proper alignment of the block-wise quantization scales, the output_size of the weights for both the gate and up layers must be divisible by block_n. # NOTE(HandH1998): To ensure proper alignment of the block-wise quantization scales, the output_size of the weights for both the gate and up layers must be divisible by block_n.
# Required by column parallel or enabling merged weights # Required by column parallel or enabling merged weights
if intermediate_size % block_n != 0: if intermediate_size_per_partition % block_n != 0:
raise ValueError( raise ValueError(
f"The output_size of gate's and up's weight = " f"The output_size of gate's and up's weight = "
f"{intermediate_size} is not divisible by " f"{intermediate_size_per_partition} is not divisible by "
f"weight quantization block_n = {block_n}." f"weight quantization block_n = {block_n}."
) )
if tp_size > 1: if tp_size > 1:
# Required by row parallel # Required by row parallel
if intermediate_size % block_k != 0: if intermediate_size_per_partition % block_k != 0:
raise ValueError( raise ValueError(
f"The input_size of down's weight = " f"The input_size of down's weight = "
f"{intermediate_size} is not divisible by " f"{intermediate_size_per_partition} is not divisible by "
f"weight quantization block_k = {block_k}." f"weight quantization block_k = {block_k}."
) )
...@@ -564,7 +571,7 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -564,7 +571,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
w13_weight = torch.nn.Parameter( w13_weight = torch.nn.Parameter(
torch.empty( torch.empty(
num_experts, num_experts,
2 * intermediate_size, 2 * intermediate_size_per_partition,
hidden_size // 8, hidden_size // 8,
dtype=params_dtype, dtype=params_dtype,
), ),
...@@ -572,20 +579,29 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -572,20 +579,29 @@ class Fp8MoEMethod(FusedMoEMethodBase):
) )
w2_weight = torch.nn.Parameter( w2_weight = torch.nn.Parameter(
torch.empty( torch.empty(
num_experts, hidden_size, intermediate_size // 8, dtype=params_dtype num_experts,
hidden_size,
intermediate_size_per_partition // 8,
dtype=params_dtype,
), ),
requires_grad=False, requires_grad=False,
) )
else: else:
w13_weight = torch.nn.Parameter( w13_weight = torch.nn.Parameter(
torch.empty( torch.empty(
num_experts, 2 * intermediate_size, hidden_size, dtype=params_dtype num_experts,
2 * intermediate_size_per_partition,
hidden_size,
dtype=params_dtype,
), ),
requires_grad=False, requires_grad=False,
) )
w2_weight = torch.nn.Parameter( w2_weight = torch.nn.Parameter(
torch.empty( torch.empty(
num_experts, hidden_size, intermediate_size, dtype=params_dtype num_experts,
hidden_size,
intermediate_size_per_partition,
dtype=params_dtype,
), ),
requires_grad=False, requires_grad=False,
) )
...@@ -601,7 +617,7 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -601,7 +617,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
w13_weight_scale = torch.nn.Parameter( w13_weight_scale = torch.nn.Parameter(
torch.ones( torch.ones(
num_experts, num_experts,
2 * ((intermediate_size + block_n - 1) // block_n), 2 * ((intermediate_size_per_partition + block_n - 1) // block_n),
(hidden_size + block_k - 1) // block_k, (hidden_size + block_k - 1) // block_k,
dtype=torch.float32, dtype=torch.float32,
), ),
...@@ -611,7 +627,7 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -611,7 +627,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
torch.ones( torch.ones(
num_experts, num_experts,
(hidden_size + block_n - 1) // block_n, (hidden_size + block_n - 1) // block_n,
(intermediate_size + block_k - 1) // block_k, (intermediate_size_per_partition + block_k - 1) // block_k,
dtype=torch.float32, dtype=torch.float32,
), ),
requires_grad=False, requires_grad=False,
...@@ -632,19 +648,19 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -632,19 +648,19 @@ class Fp8MoEMethod(FusedMoEMethodBase):
) )
self.c_strides1 = torch.full( self.c_strides1 = torch.full(
(num_experts,), (num_experts,),
2 * intermediate_size, 2 * intermediate_size_per_partition,
device=w13_weight.device, device=w13_weight.device,
dtype=torch.int64, dtype=torch.int64,
) )
self.ab_strides2 = torch.full( self.ab_strides2 = torch.full(
(num_experts,), (num_experts,),
intermediate_size, intermediate_size_per_partition,
device=w2_weight.device, device=w2_weight.device,
dtype=torch.int64, dtype=torch.int64,
) )
self.c_strides2 = torch.full( self.c_strides2 = torch.full(
(num_experts,), (num_experts,),
hidden_size, intermediate_size_per_partition,
device=w2_weight.device, device=w2_weight.device,
dtype=torch.int64, dtype=torch.int64,
) )
...@@ -691,7 +707,11 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -691,7 +707,11 @@ class Fp8MoEMethod(FusedMoEMethodBase):
if _is_hip: # _use_aiter: TODO: add check back after triton kernel if _is_hip: # _use_aiter: TODO: add check back after triton kernel
# ROCm - using column scaling, duplicate scaling numbers in case per tensor scaling # ROCm - using column scaling, duplicate scaling numbers in case per tensor scaling
w13_weight_scale1 = torch.nn.Parameter( w13_weight_scale1 = torch.nn.Parameter(
torch.ones(num_experts, 2 * intermediate_size, dtype=torch.float32), torch.ones(
num_experts,
2 * intermediate_size_per_partition,
dtype=torch.float32,
),
requires_grad=False, requires_grad=False,
) )
w2_weight_scale1 = torch.nn.Parameter( w2_weight_scale1 = torch.nn.Parameter(
...@@ -984,14 +1004,23 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -984,14 +1004,23 @@ class Fp8MoEMethod(FusedMoEMethodBase):
) )
torch.cuda.empty_cache() torch.cuda.empty_cache()
def create_moe_runner(
self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
):
self.moe_runner_config = moe_runner_config
self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)
def apply( def apply(
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
x: torch.Tensor, dispatch_output: DispatchOutput,
topk_output: TopKOutput, ) -> CombineInput:
moe_runner_config: MoeRunnerConfig,
) -> torch.Tensor: from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
x = dispatch_output.hidden_states
topk_output = dispatch_output.topk_output
moe_runner_config = self.moe_runner_config
if use_intel_amx_backend(layer): if use_intel_amx_backend(layer):
from sglang.srt.layers.moe.topk import apply_topk_weights_cpu from sglang.srt.layers.moe.topk import apply_topk_weights_cpu
...@@ -1001,7 +1030,7 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -1001,7 +1030,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
moe_runner_config.apply_router_weight_on_input, topk_weights, x moe_runner_config.apply_router_weight_on_input, topk_weights, x
) )
return torch.ops.sgl_kernel.fused_experts_cpu( output = torch.ops.sgl_kernel.fused_experts_cpu(
x, x,
layer.w13_weight, layer.w13_weight,
layer.w2_weight, layer.w2_weight,
...@@ -1017,6 +1046,7 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -1017,6 +1046,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
None, # a2_scale None, # a2_scale
True, # is_vnni True, # is_vnni
) )
return StandardCombineInput(hidden_states=output)
if _is_hip: if _is_hip:
ret = self.maybe_apply_hip_fused_experts( ret = self.maybe_apply_hip_fused_experts(
...@@ -1027,7 +1057,7 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -1027,7 +1057,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
moe_runner_config.no_combine, moe_runner_config.no_combine,
) )
if ret is not None: if ret is not None:
return ret return StandardCombineInput(hidden_states=ret)
if self.use_cutlass_fused_experts_fp8: if self.use_cutlass_fused_experts_fp8:
from sglang.srt.layers.moe.cutlass_moe import cutlass_fused_experts_fp8 from sglang.srt.layers.moe.cutlass_moe import cutlass_fused_experts_fp8
...@@ -1056,17 +1086,13 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -1056,17 +1086,13 @@ class Fp8MoEMethod(FusedMoEMethodBase):
self.problem_sizes2, self.problem_sizes2,
use_fp8_blockscale=True, use_fp8_blockscale=True,
) )
# Scale by routed_scaling_factor is fused into select_experts. return StandardCombineInput(hidden_states=output)
return output
# Expert fusion with FP8 quantization quant_info = TritonMoeQuantInfo(
return fused_experts( w13_weight=layer.w13_weight,
x, w2_weight=layer.w2_weight,
layer.w13_weight,
layer.w2_weight,
topk_output=topk_output,
moe_runner_config=moe_runner_config,
use_fp8_w8a8=True, use_fp8_w8a8=True,
w1_scale=( w13_scale=(
layer.w13_weight_scale_inv layer.w13_weight_scale_inv
if self.block_quant if self.block_quant
else layer.w13_weight_scale else layer.w13_weight_scale
...@@ -1074,20 +1100,22 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -1074,20 +1100,22 @@ class Fp8MoEMethod(FusedMoEMethodBase):
w2_scale=( w2_scale=(
layer.w2_weight_scale_inv if self.block_quant else layer.w2_weight_scale layer.w2_weight_scale_inv if self.block_quant else layer.w2_weight_scale
), ),
a1_scale=layer.w13_input_scale, a13_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale, a2_scale=layer.w2_input_scale,
block_shape=self.quant_config.weight_block_size, block_shape=self.quant_config.weight_block_size,
) )
return self.runner.run(dispatch_output, quant_info)
def apply_with_router_logits( def apply_with_router_logits(
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
x: torch.Tensor, dispatch_output: StandardDispatchOutput,
topk_output: TopKOutput,
moe_runner_config: MoeRunnerConfig,
) -> torch.Tensor: ) -> torch.Tensor:
activation = moe_runner_config.activation x = dispatch_output.hidden_states
routed_scaling_factor = moe_runner_config.routed_scaling_factor topk_output = dispatch_output.topk_output
activation = self.moe_runner_config.activation
routed_scaling_factor = self.moe_runner_config.routed_scaling_factor
from flashinfer.fused_moe import trtllm_fp8_block_scale_moe from flashinfer.fused_moe import trtllm_fp8_block_scale_moe
......