Unverified commit 3fa62da7 authored by Cheng Wan, committed by GitHub

[7/N] MoE Refactor: the implementation of new framework (#9269)

parent dbb1235d
......@@ -11,6 +11,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import annotations
import logging
import math
import os
......@@ -19,17 +22,19 @@ from abc import ABC
from collections import deque
from contextlib import contextmanager
from pathlib import Path
from typing import Any, Dict, List, Literal, Optional, Tuple, Type
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Tuple, Type
import einops
import torch
import torch.distributed
from sglang.srt.eplb.expert_location import ExpertLocationMetadata
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import Withable, get_bool_env_var
if TYPE_CHECKING:
from sglang.srt.eplb.expert_location import ExpertLocationMetadata
logger = logging.getLogger(__name__)
# --------------------------------------- Entrypoint -----------------------------------------
......@@ -43,7 +48,7 @@ class ExpertDistributionRecorder(ABC):
@staticmethod
def init_new(
server_args: ServerArgs,
expert_location_metadata: "ExpertLocationMetadata",
expert_location_metadata: ExpertLocationMetadata,
rank: int,
):
if server_args.expert_distribution_recorder_mode is not None:
......@@ -118,7 +123,7 @@ class _ExpertDistributionRecorderReal(ExpertDistributionRecorder):
def __init__(
self,
server_args: ServerArgs,
expert_location_metadata: "ExpertLocationMetadata",
expert_location_metadata: ExpertLocationMetadata,
rank: int,
):
self._server_args = server_args
......@@ -279,7 +284,7 @@ class _SinglePassGatherer(ABC):
@staticmethod
def init_new(
server_args: ServerArgs,
expert_location_metadata: "ExpertLocationMetadata",
expert_location_metadata: ExpertLocationMetadata,
rank: int,
) -> "_SinglePassGatherer":
if server_args.expert_distribution_recorder_mode == "per_token":
......@@ -307,7 +312,7 @@ class _SinglePassGatherer(ABC):
return _SelectExpertsSinglePassGatherer(expert_location_metadata, rank)
def __init__(self, expert_location_metadata: "ExpertLocationMetadata", rank: int):
def __init__(self, expert_location_metadata: ExpertLocationMetadata, rank: int):
self._expert_location_metadata = expert_location_metadata
self._rank = rank
......@@ -346,7 +351,7 @@ class _DetailSinglePassGatherer(_SinglePassGatherer):
def __init__(
self,
server_args: ServerArgs,
expert_location_metadata: "ExpertLocationMetadata",
expert_location_metadata: ExpertLocationMetadata,
rank: int,
):
super().__init__(expert_location_metadata, rank)
......@@ -561,7 +566,7 @@ class _Accumulator(ABC):
@staticmethod
def init_new(
server_args: ServerArgs,
expert_location_metadata: "ExpertLocationMetadata",
expert_location_metadata: ExpertLocationMetadata,
rank: int,
) -> "_Accumulator":
return _Accumulator.get_class(server_args)(
......@@ -580,7 +585,7 @@ class _Accumulator(ABC):
def __init__(
self,
server_args: ServerArgs,
expert_location_metadata: "ExpertLocationMetadata",
expert_location_metadata: ExpertLocationMetadata,
rank: int,
):
self._server_args = server_args
......
......@@ -11,21 +11,26 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import annotations
import json
import logging
import random
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional
from typing import TYPE_CHECKING, List, Optional
import torch
import torch.distributed
import torch.nn.functional as F
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.eplb import eplb_algorithms
from sglang.srt.model_loader import get_model_architecture
from sglang.srt.server_args import ServerArgs
if TYPE_CHECKING:
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.server_args import ServerArgs
logger = logging.getLogger(__name__)
......
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
from sglang.srt.layers.moe.moe_runner import MoeRunner, MoeRunnerConfig
from sglang.srt.layers.moe.utils import (
DeepEPMode,
MoeA2ABackend,
......@@ -17,6 +17,7 @@ from sglang.srt.layers.moe.utils import (
__all__ = [
"DeepEPMode",
"MoeA2ABackend",
"MoeRunner",
"MoeRunnerConfig",
"MoeRunnerBackend",
"initialize_moe_config",
......
......@@ -8,16 +8,18 @@ from torch.nn import functional as F
from sglang.srt.layers.activation import GeluAndMul, SiluAndMul
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
from sglang.srt.layers.moe.token_dispatcher import StandardDispatchOutput
from sglang.srt.layers.moe.topk import StandardTopKOutput
def fused_moe_forward_native(
layer: torch.nn.Module,
x: torch.Tensor,
topk_output: StandardTopKOutput,
moe_runner_config: MoeRunnerConfig,
dispatch_output: StandardDispatchOutput,
) -> torch.Tensor:
x, topk_output = dispatch_output
moe_runner_config = layer.moe_runner_config
if moe_runner_config.apply_router_weight_on_input:
raise NotImplementedError()
......
# NOTE: this file will be separated into sglang/srt/layers/moe/moe_runner/triton_utils.py
# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/model_executor/layers/fused_moe/fused_moe.py
"""Fused MoE kernel."""
......@@ -6,13 +7,12 @@ from __future__ import annotations
import functools
import os
from typing import List, Optional
from typing import TYPE_CHECKING, List, Optional
import torch
import triton.language as tl
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
from sglang.srt.layers.moe.topk import StandardTopKOutput
from sglang.srt.utils import (
cpu_has_amx_support,
direct_register_custom_op,
......@@ -26,6 +26,9 @@ from .fused_moe_triton_config import get_config_dtype_str, try_get_optimal_moe_c
from .fused_moe_triton_kernels import invoke_fused_moe_kernel, moe_sum_reduce_triton
from .moe_align_block_size import moe_align_block_size
if TYPE_CHECKING:
from sglang.srt.layers.moe.topk import StandardTopKOutput
_is_hip = is_hip()
_is_cuda = is_cuda()
_is_cpu_amx_available = cpu_has_amx_support()
......
......@@ -23,8 +23,13 @@ from sglang.srt.layers.moe import (
get_moe_runner_backend,
should_use_flashinfer_trtllm_moe,
)
from sglang.srt.layers.moe.token_dispatcher.standard import (
CombineInput,
StandardDispatcher,
)
from sglang.srt.layers.moe.topk import TopKOutput, TopKOutputChecker
from sglang.srt.layers.quantization.base_config import (
FusedMoEMethodBase,
QuantizationConfig,
QuantizeMethodBase,
)
......@@ -152,16 +157,6 @@ class FusedMoE(torch.nn.Module):
self.expert_map_cpu = None
self.expert_map_gpu = None
self.moe_runner_config = MoeRunnerConfig(
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
inplace=inplace,
no_combine=no_combine,
routed_scaling_factor=routed_scaling_factor,
gemm1_alpha=gemm1_alpha,
gemm1_clamp_limit=gemm1_clamp_limit,
)
enable_flashinfer_cutlass_moe = get_moe_runner_backend().is_flashinfer_cutlass()
if enable_flashinfer_cutlass_moe and quant_config is None:
......@@ -196,13 +191,6 @@ class FusedMoE(torch.nn.Module):
self.use_presharded_weights = use_presharded_weights
self.use_triton_kernels = get_moe_runner_backend().is_triton_kernel()
if quant_config is None:
self.quant_method: Optional[QuantizeMethodBase] = UnquantizedFusedMoEMethod(
self.use_triton_kernels
)
else:
self.quant_method = quant_config.get_quant_method(self, prefix)
assert self.quant_method is not None
self.quant_config = quant_config
self.use_flashinfer_mxfp4_moe = get_moe_runner_backend().is_flashinfer_mxfp4()
......@@ -213,12 +201,40 @@ class FusedMoE(torch.nn.Module):
and self.use_flashinfer_mxfp4_moe
):
hidden_size = round_up(hidden_size, 256)
self.hidden_size = hidden_size
self.moe_runner_config = MoeRunnerConfig(
num_experts=num_experts,
num_local_experts=self.num_local_experts,
hidden_size=hidden_size,
intermediate_size_per_partition=self.intermediate_size_per_partition,
layer_id=layer_id,
top_k=top_k,
num_fused_shared_experts=num_fused_shared_experts,
params_dtype=params_dtype,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
inplace=inplace,
no_combine=no_combine,
routed_scaling_factor=routed_scaling_factor,
gemm1_alpha=gemm1_alpha,
gemm1_clamp_limit=gemm1_clamp_limit,
)
if quant_config is None:
self.quant_method: FusedMoEMethodBase = UnquantizedFusedMoEMethod(
self.use_triton_kernels
)
else:
self.quant_method: FusedMoEMethodBase = quant_config.get_quant_method(
self, prefix
)
assert self.quant_method is not None
self.quant_method.create_weights(
layer=self,
num_experts=self.num_local_experts,
hidden_size=hidden_size,
# FIXME: figure out which intermediate_size to use
intermediate_size=self.intermediate_size_per_partition,
intermediate_size_per_partition=self.intermediate_size_per_partition,
params_dtype=params_dtype,
weight_loader=(
......@@ -229,6 +245,9 @@ class FusedMoE(torch.nn.Module):
with_bias=with_bias,
)
self.quant_method.create_moe_runner(self, self.moe_runner_config)
self.dispatcher = StandardDispatcher()
def _load_per_tensor_weight_scale(
self,
shard_id: str,
......@@ -811,16 +830,17 @@ class FusedMoE(torch.nn.Module):
elif TopKOutputChecker.format_is_triton_kernel(topk_output):
raise NotImplementedError()
# Matrix multiply.
with use_symmetric_memory(get_tp_group()) as sm:
dispatch_output = self.dispatcher.dispatch(
hidden_states=hidden_states, topk_output=topk_output
)
final_hidden_states = self.quant_method.apply(
layer=self,
x=hidden_states,
topk_output=topk_output,
moe_runner_config=self.moe_runner_config,
)
sm.tag(final_hidden_states)
# TODO: consider using symmetric memory
combine_input = self.quant_method.apply(
layer=self,
dispatch_output=dispatch_output,
)
final_hidden_states = self.dispatcher.combine(combine_input)
final_hidden_states = final_hidden_states[
..., :origin_hidden_states_dim
......@@ -955,7 +975,6 @@ class FlashInferFusedMoE(FusedMoE):
layer=self,
x=hidden_states,
topk_output=topk_output,
moe_runner_config=self.moe_runner_config,
)
if self.reduce_results and (self.moe_tp_size > 1 or self.moe_ep_size > 1):
......
from sglang.srt.layers.moe.moe_runner.base import MoeRunnerConfig
from sglang.srt.layers.moe.moe_runner.runner import MoeRunner
__all__ = ["MoeRunnerConfig"]
__all__ = ["MoeRunnerConfig", "MoeRunner"]
from __future__ import annotations
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Optional
from typing import TYPE_CHECKING, Callable, Optional, Tuple, TypeGuard
import torch
from sglang.srt.layers.moe.token_dispatcher import (
CombineInput,
CombineInputFormat,
DispatchOutput,
DispatchOutputFormat,
)
from sglang.srt.layers.moe.utils import MoeA2ABackend, MoeRunnerBackend
if TYPE_CHECKING:
from sglang.srt.layers.moe.moe_runner.triton import (
TritonRunnerCore,
TritonRunnerInput,
TritonRunnerOutput,
)
@dataclass
class MoeRunnerConfig:
# MoE parameters
num_experts: Optional[int] = None
num_local_experts: Optional[int] = None
hidden_size: Optional[int] = None
intermediate_size_per_partition: Optional[int] = None
layer_id: Optional[int] = None
top_k: Optional[int] = None
num_fused_shared_experts: Optional[int] = None
params_dtype: Optional[torch.dtype] = None
# Runner configuration
activation: str = "silu"
apply_router_weight_on_input: bool = False
inplace: bool = True
......@@ -11,3 +43,254 @@ class MoeRunnerConfig:
routed_scaling_factor: Optional[float] = None
gemm1_alpha: Optional[float] = None
gemm1_clamp_limit: Optional[float] = None
@dataclass
class RunnerInput(ABC):
@property
@abstractmethod
def runner_backend(self) -> MoeRunnerBackend: ...
def runner_backend_is_triton(self) -> TypeGuard[TritonRunnerInput]:
return self.runner_backend == MoeRunnerBackend.TRITON
class RunnerOutput(ABC):
@property
@abstractmethod
def runner_backend(self) -> MoeRunnerBackend: ...
def runner_backend_is_triton(self) -> TypeGuard[TritonRunnerOutput]:
return self.runner_backend == MoeRunnerBackend.TRITON
@dataclass
class MoeQuantInfo(ABC):
"""Moe quantization data."""
pass
class MoeRunnerCore(ABC):
def __init__(self, config: MoeRunnerConfig):
self.config = config
@abstractmethod
def run(
self, runner_input: RunnerInput, quant_info: MoeQuantInfo, running_state: dict
) -> RunnerOutput:
pass
@property
@abstractmethod
def runner_backend(self) -> MoeRunnerBackend: ...
def runner_backend_is_triton(self) -> TypeGuard[TritonRunnerCore]:
return self.runner_backend == MoeRunnerBackend.TRITON
class FusedOpPool:
_fused_funcs: dict[Tuple[str, str], Callable] = {}
@classmethod
def register_fused_func(
cls, a2a_backend_name: str, runner_backend_name: str, fused_func: Callable
):
key = (a2a_backend_name, runner_backend_name)
if key in cls._fused_funcs:
raise ValueError(
f"Fused function for {a2a_backend_name} to {runner_backend_name} is already registered."
)
assert MoeA2ABackend(
a2a_backend_name
), f"Invalid dispatch name: {a2a_backend_name}"
assert MoeRunnerBackend(
runner_backend_name
), f"Invalid runner name: {runner_backend_name}"
cls._fused_funcs[key] = fused_func
@classmethod
def get_fused_func(cls, dispatch_name: str, runner_name: str) -> Optional[Callable]:
key = (dispatch_name, runner_name)
fused_func = cls._fused_funcs.get(key)
return fused_func
class PermuteMethodPool:
_pre_permute_methods: dict[
Tuple[DispatchOutputFormat, MoeRunnerBackend], Callable
] = {}
_post_permute_methods: dict[
Tuple[MoeRunnerBackend, CombineInputFormat], Callable
] = {}
@classmethod
def register_pre_permute(
cls,
dispatch_output_name: str,
runner_backend_name: str,
permute_func: Callable,
):
"""
Register a customized pre-permute function for the given DispatchOutputFormat and MoeRunnerBackend.
:param dispatch_output_name: The DispatchOutputFormat name.
:param runner_backend_name: The MoeRunnerBackend name.
:param permute_func: The permute function to register.
"""
key = (dispatch_output_name, runner_backend_name)
if key in cls._pre_permute_methods:
raise ValueError(
f"Pre-permute method for {dispatch_output_name} to {runner_backend_name} is already registered."
)
assert DispatchOutputFormat(
dispatch_output_name
), f"Invalid dispatch output name: {dispatch_output_name}"
assert MoeRunnerBackend(
runner_backend_name
), f"Invalid runner backend name: {runner_backend_name}"
cls._pre_permute_methods[key] = permute_func
@classmethod
def register_post_permute(
cls,
runner_backend_name: str,
combine_input_name: str,
permute_func: Callable,
):
"""
Register a customized post-permute function for the given MoeRunnerBackend and CombineInputFormat.
:param runner_backend_name: The MoeRunnerBackend name.
:param combine_input_name: The CombineInputFormat name.
:param permute_func: The permute function to register.
"""
key = (runner_backend_name, combine_input_name)
if key in cls._post_permute_methods:
raise ValueError(
f"Post-permute method for {runner_backend_name} to {combine_input_name} is already registered."
)
assert MoeRunnerBackend(
runner_backend_name
), f"Invalid runner backend name: {runner_backend_name}"
assert CombineInputFormat(
combine_input_name
), f"Invalid combine input name: {combine_input_name}"
cls._post_permute_methods[key] = permute_func
@classmethod
def get_pre_permute(
cls,
dispatch_output_format: DispatchOutputFormat,
runner_input_format: MoeRunnerBackend,
) -> Callable:
"""
Retrieve the pre-permute function for the given DispatchOutputFormat and MoeRunnerBackend.
:param dispatch_output_format: The DispatchOutputFormat type.
:param runner_input_format: The MoeRunnerBackend type.
:return: The registered permute function; asserts if none is registered.
"""
key = (dispatch_output_format, runner_input_format)
pre_permute_func = cls._pre_permute_methods.get(key)
assert (
pre_permute_func is not None
), f"Pre-permute function for {dispatch_output_format} to {runner_input_format} is not registered"
return pre_permute_func
@classmethod
def get_post_permute(
cls,
runner_output_format: MoeRunnerBackend,
combine_input_format: CombineInputFormat,
) -> Callable:
"""
Retrieve the post-permute function for the given MoeRunnerBackend and CombineInputFormat.
:param runner_output_format: The MoeRunnerBackend type.
:param combine_input_format: The CombineInputFormat type.
:return: The registered permute function; asserts if none is registered.
"""
key = (runner_output_format, combine_input_format)
post_permute_func = cls._post_permute_methods.get(key)
assert (
post_permute_func is not None
), f"Post-permute function for {runner_output_format} to {combine_input_format} is not registered"
return post_permute_func
def register_fused_func(
a2a_backend_name: str,
runner_backend_name: str,
) -> Callable:
"""
Decorator to register a fused function for the given MoeA2ABackend and MoeRunnerBackend.
:param a2a_backend_name: The A2A backend name.
:param runner_backend_name: The MoeRunnerBackend name.
:return: The decorator function.
"""
def decorator(fused_func: Callable):
FusedOpPool.register_fused_func(
a2a_backend_name, runner_backend_name, fused_func
)
return fused_func
return decorator
def register_pre_permute(
dispatch_output_name: str,
runner_backend_name: str,
) -> Callable:
"""
Decorator to register a pre-permute function for the given DispatchOutputFormat and MoeRunnerBackend.
:param dispatch_output_name: The DispatchOutputFormat name.
:param runner_backend_name: The MoeRunnerBackend name.
:return: The decorator function.
"""
def decorator(
permute_func: Callable[
[DispatchOutput, MoeQuantInfo, MoeRunnerConfig, dict], RunnerInput
]
) -> Callable:
PermuteMethodPool.register_pre_permute(
dispatch_output_name, runner_backend_name, permute_func
)
return permute_func
return decorator
def register_post_permute(
runner_backend_name: str,
combine_input_name: str,
) -> Callable:
"""
Decorator to register a post-permute function for the given MoeRunnerBackend and CombineInputFormat.
:param runner_backend_name: The MoeRunnerBackend name.
:param combine_input_name: The CombineInputFormat name.
:return: The decorator function.
"""
def decorator(
permute_func: Callable[
[RunnerOutput, MoeQuantInfo, MoeRunnerConfig, dict], CombineInput
]
) -> Callable:
PermuteMethodPool.register_post_permute(
runner_backend_name, combine_input_name, permute_func
)
return permute_func
return decorator
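For orientation, the following is a minimal, self-contained sketch of the registry-plus-decorator pattern that FusedOpPool and PermuteMethodPool implement above; the names and the trivial permute body are illustrative placeholders, not sglang code.

from typing import Callable, Dict, Tuple

# Simplified stand-in for PermuteMethodPool: keys are (dispatch_output_name,
# runner_backend_name) strings, values are the registered permute callables.
_PRE_PERMUTE_METHODS: Dict[Tuple[str, str], Callable] = {}

def register_pre_permute(dispatch_output_name: str, runner_backend_name: str) -> Callable:
    def decorator(func: Callable) -> Callable:
        key = (dispatch_output_name, runner_backend_name)
        if key in _PRE_PERMUTE_METHODS:
            raise ValueError(f"Pre-permute method for {key} is already registered.")
        _PRE_PERMUTE_METHODS[key] = func
        return func
    return decorator

@register_pre_permute("standard", "triton")
def standard_to_triton(dispatch_output, quant_info, config, running_state):
    # A real implementation would repack the dispatch output into the runner
    # backend's input format; here it is returned unchanged for illustration.
    return dispatch_output

# Lookup mirrors PermuteMethodPool.get_pre_permute.
pre_permute = _PRE_PERMUTE_METHODS[("standard", "triton")]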
from __future__ import annotations
import logging
import os
from typing import TYPE_CHECKING
from sglang.srt.layers.moe.moe_runner.base import (
FusedOpPool,
MoeRunnerConfig,
PermuteMethodPool,
)
from sglang.srt.layers.moe.moe_runner.triton import TritonRunnerCore
from sglang.srt.layers.moe.token_dispatcher.base import (
CombineInput,
CombineInputFormat,
DispatchOutput,
)
from sglang.srt.layers.moe.utils import get_moe_a2a_backend
if TYPE_CHECKING:
from sglang.srt.layers.moe.moe_runner.base import MoeQuantInfo
from sglang.srt.layers.moe.utils import MoeRunnerBackend
logger = logging.getLogger(__name__)
class MoeRunner:
def __init__(self, runner_backend: MoeRunnerBackend, config: MoeRunnerConfig):
self.runner_backend = runner_backend
self.config = config
self.fused_func = None
if runner_backend.is_triton():
self.runner_core = TritonRunnerCore(config)
else:
raise NotImplementedError(f"Unsupported runner backend: {runner_backend}")
a2a_backend_name = get_moe_a2a_backend().value
runner_backend_name = runner_backend.value
self.fused_func = FusedOpPool.get_fused_func(
a2a_backend_name, runner_backend_name
)
SGLANG_CI_DISABLE_MOE_FUSED_FUNC = os.environ.get(
"SGLANG_CI_DISABLE_MOE_FUSED_FUNC", "0"
)
if SGLANG_CI_DISABLE_MOE_FUSED_FUNC == "1":
logger.info(
"SGLANG_CI_DISABLE_MOE_FUSED_FUNC is set to 1, disabling fused func"
)
self.fused_func = None
def run(
self, dispatch_output: DispatchOutput, quant_info: MoeQuantInfo
) -> CombineInput:
if self.fused_func is not None:
return self.fused_func(dispatch_output, quant_info, self.config)
dispatch_format = dispatch_output.format.value
runner_format = self.runner_core.runner_backend.value
self.pre_permute_func = PermuteMethodPool.get_pre_permute(
dispatch_format, runner_format
)
running_state = {}
runner_input = self.pre_permute_func(
dispatch_output, quant_info, self.config, running_state
)
runner_output = self.runner_core.run(runner_input, quant_info, running_state)
runner_format = self.runner_core.runner_backend.value
combine_format = dispatch_output.format.value
self.post_permute_func = PermuteMethodPool.get_post_permute(
runner_format, combine_format
)
combine_input = self.post_permute_func(
runner_output, quant_info, self.config, running_state
)
return combine_input
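As a usage sketch (assuming an environment where sglang and its Triton kernels are importable), a quantization method is expected to drive MoeRunner roughly as the Fp8/Int8 methods later in this diff do; the example class, the layer object, and its weight attributes are placeholders.

from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo

class ExampleMoEMethod:
    def create_moe_runner(self, layer, moe_runner_config: MoeRunnerConfig):
        # Keep the config and build a backend-specific runner once per layer.
        self.moe_runner_config = moe_runner_config
        self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)

    def apply(self, layer, dispatch_output):
        # Package the layer's weights (and any scales) as backend quant info,
        # then let the runner choose the fused path or the permute/core path.
        quant_info = TritonMoeQuantInfo(
            w13_weight=layer.w13_weight,
            w2_weight=layer.w2_weight,
        )
        return self.runner.run(dispatch_output, quant_info)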
from __future__ import annotations
import functools
import os
from dataclasses import dataclass
from typing import TYPE_CHECKING, List, Optional
import torch
import triton.language as tl
from sglang.srt.layers.moe.moe_runner.base import (
MoeQuantInfo,
MoeRunnerConfig,
MoeRunnerCore,
RunnerInput,
RunnerOutput,
register_fused_func,
register_post_permute,
register_pre_permute,
)
from sglang.srt.layers.moe.token_dispatcher import (
StandardCombineInput,
StandardDispatchOutput,
)
from sglang.srt.layers.moe.utils import MoeRunnerBackend
from sglang.srt.utils import cpu_has_amx_support, is_cpu, is_cuda, is_hip
_is_hip = is_hip()
_is_cuda = is_cuda()
_is_cpu_amx_available = cpu_has_amx_support()
_is_cpu = is_cpu()
_use_aiter = bool(int(os.getenv("SGLANG_MOE_USE_AITER", "0")))
_MOE_PADDING_SIZE = 128 if bool(int(os.getenv("SGLANG_MOE_PADDING", "0"))) else 0
if _is_cuda:
from sgl_kernel import gelu_and_mul, silu_and_mul
elif _is_cpu and _is_cpu_amx_available:
pass
elif _is_hip:
from vllm import _custom_ops as vllm_ops # gelu_and_mul, silu_and_mul
if _use_aiter:
try:
from aiter import moe_sum
except ImportError:
raise ImportError("aiter is required when SGLANG_USE_AITER is set to True")
if _is_cuda or _is_hip:
from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size
@dataclass
class TritonRunnerInput(RunnerInput):
hidden_states: torch.Tensor
topk_weights: torch.Tensor
topk_ids: torch.Tensor
sorted_token_ids: torch.Tensor
expert_ids: torch.Tensor
num_tokens_post_padded: torch.Tensor
@property
def runner_backend(self) -> MoeRunnerBackend:
return MoeRunnerBackend.TRITON
@dataclass
class TritonRunnerOutput(RunnerOutput):
hidden_states: torch.Tensor
@property
def runner_backend(self) -> MoeRunnerBackend:
return MoeRunnerBackend.TRITON
@dataclass
class TritonMoeQuantInfo(MoeQuantInfo):
w13_weight: torch.Tensor
w2_weight: torch.Tensor
b13: Optional[torch.Tensor] = None
b2: Optional[torch.Tensor] = None
use_fp8_w8a8: bool = False
use_int8_w8a8: bool = False
use_int8_w8a16: bool = False
use_int4_w4a16: bool = False
per_channel_quant: bool = False
w13_scale: Optional[torch.Tensor] = None
w2_scale: Optional[torch.Tensor] = None
w13_zp: Optional[torch.Tensor] = None
w2_zp: Optional[torch.Tensor] = None
a13_scale: Optional[torch.Tensor] = None
a2_scale: Optional[torch.Tensor] = None
block_shape: Optional[List[int]] = None
class TritonRunnerCore(MoeRunnerCore):
def __init__(self, config: MoeRunnerConfig):
super().__init__(config)
def run(
self,
runner_input: TritonRunnerInput,
quant_info: TritonMoeQuantInfo,
running_state: dict,
) -> TritonRunnerOutput:
# TODO: move these functions to the triton runner
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
invoke_fused_moe_kernel,
moe_sum_reduce_torch_compile,
moe_sum_reduce_triton,
swiglu_with_alpha_and_limit,
)
hidden_states = runner_input.hidden_states
topk_weights = runner_input.topk_weights
topk_ids = runner_input.topk_ids
sorted_token_ids = runner_input.sorted_token_ids
expert_ids = runner_input.expert_ids
num_tokens_post_padded = runner_input.num_tokens_post_padded
w13 = quant_info.w13_weight
w2 = quant_info.w2_weight
b13 = quant_info.b13
b2 = quant_info.b2
a13_scale = quant_info.a13_scale
a2_scale = quant_info.a2_scale
w13_scale = quant_info.w13_scale
w2_scale = quant_info.w2_scale
w13_zp = quant_info.w13_zp
w2_zp = quant_info.w2_zp
block_shape = quant_info.block_shape
per_channel_quant = quant_info.per_channel_quant
use_fp8_w8a8 = quant_info.use_fp8_w8a8
use_int8_w8a8 = quant_info.use_int8_w8a8
use_int8_w8a16 = quant_info.use_int8_w8a16
use_int4_w4a16 = quant_info.use_int4_w4a16
activation = self.config.activation
no_combine = self.config.no_combine
inplace = self.config.inplace
gemm1_alpha = self.config.gemm1_alpha
gemm1_limit = self.config.gemm1_clamp_limit
routed_scaling_factor = self.config.routed_scaling_factor
apply_router_weight_on_input = self.config.apply_router_weight_on_input
M = hidden_states.shape[0]
E, N, _ = w13.shape
compute_type = (
tl.bfloat16 if hidden_states.dtype == torch.bfloat16 else tl.float16
)
intermediate_cache1 = torch.empty(
(M, topk_ids.shape[1], N),
device=hidden_states.device,
dtype=hidden_states.dtype,
)
invoke_fused_moe_kernel(
hidden_states,
w13,
b13,
intermediate_cache1,
a13_scale,
w13_scale,
w13_zp,
topk_weights,
topk_ids,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
apply_router_weight_on_input,
topk_ids.shape[1],
running_state["config"],
compute_type=compute_type,
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16,
per_channel_quant=per_channel_quant,
block_shape=block_shape,
)
intermediate_cache2 = torch.empty(
(M * topk_ids.shape[1], N // 2),
device=hidden_states.device,
dtype=hidden_states.dtype,
)
if activation == "silu":
if gemm1_alpha is not None:
assert gemm1_limit is not None
intermediate_cache2 = swiglu_with_alpha_and_limit(
intermediate_cache1.view(-1, N),
gemm1_alpha,
gemm1_limit,
)
elif _is_cuda:
silu_and_mul(intermediate_cache1.view(-1, N), intermediate_cache2)
else:
vllm_ops.silu_and_mul(
intermediate_cache2, intermediate_cache1.view(-1, N)
)
elif activation == "gelu":
assert gemm1_alpha is None, "gemm1_alpha is not supported for gelu"
assert gemm1_limit is None, "gemm1_limit is not supported for gelu"
if _is_cuda:
gelu_and_mul(intermediate_cache1.view(-1, N), intermediate_cache2)
else:
vllm_ops.gelu_and_mul(
intermediate_cache2, intermediate_cache1.view(-1, N)
)
else:
raise ValueError(f"Unsupported activation: {activation=}")
intermediate_cache3 = torch.empty(
(M, topk_ids.shape[1], w2.shape[1]),
device=hidden_states.device,
dtype=hidden_states.dtype,
)
if no_combine:
assert not inplace
out_hidden_states = torch.empty(
(M, topk_ids.shape[1], w2.shape[1]),
device=hidden_states.device,
dtype=hidden_states.dtype,
)
elif inplace:
out_hidden_states = hidden_states
else:
out_hidden_states = torch.empty_like(hidden_states)
invoke_fused_moe_kernel(
intermediate_cache2,
w2,
b2,
(
intermediate_cache3
if not no_combine and topk_ids.shape[1] != 1
else out_hidden_states.unsqueeze(0)
),
a2_scale,
w2_scale,
w2_zp,
topk_weights,
topk_ids,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
not apply_router_weight_on_input,
1,
running_state["config"],
compute_type=compute_type,
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16,
per_channel_quant=per_channel_quant,
block_shape=block_shape,
)
if routed_scaling_factor is None:
routed_scaling_factor = 1.0
if no_combine:
pass
elif _is_cuda:
if topk_ids.shape[1] == 1 and routed_scaling_factor == 1.0:
pass # we write directly into out_hidden_states
elif topk_ids.shape[1] == 2 and routed_scaling_factor == 1.0:
torch.add(
intermediate_cache3[:, 0],
intermediate_cache3[:, 1],
out=out_hidden_states,
).squeeze(dim=1)
else:
# According to micro benchmark results, torch.compile can get better performance for small token counts.
if M <= 32:
moe_sum_reduce_torch_compile(
intermediate_cache3.view(*intermediate_cache3.shape),
out_hidden_states,
routed_scaling_factor,
)
else:
moe_sum_reduce_triton(
intermediate_cache3.view(*intermediate_cache3.shape),
out_hidden_states,
routed_scaling_factor,
)
elif _is_hip:
if _use_aiter:
moe_sum(
intermediate_cache3.view(*intermediate_cache3.shape),
out_hidden_states,
)
else:
vllm_ops.moe_sum(
intermediate_cache3.view(*intermediate_cache3.shape),
out_hidden_states,
)
else:
vllm_ops.moe_sum(
intermediate_cache3.view(*intermediate_cache3.shape),
out_hidden_states,
)
return TritonRunnerOutput(
hidden_states=out_hidden_states,
)
@property
def runner_backend(self) -> MoeRunnerBackend:
return MoeRunnerBackend.TRITON
@register_fused_func("none", "triton")
def fused_experts_none_to_triton(
dispatch_output: StandardDispatchOutput,
quant_info: TritonMoeQuantInfo,
runner_config: MoeRunnerConfig,
) -> StandardCombineInput:
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
output = fused_experts(
hidden_states=dispatch_output.hidden_states,
w1=quant_info.w13_weight,
w2=quant_info.w2_weight,
topk_output=dispatch_output.topk_output,
moe_runner_config=runner_config,
b1=quant_info.b13,
b2=quant_info.b2,
use_fp8_w8a8=quant_info.use_fp8_w8a8,
use_int8_w8a8=quant_info.use_int8_w8a8,
use_int8_w8a16=quant_info.use_int8_w8a16,
use_int4_w4a16=quant_info.use_int4_w4a16,
per_channel_quant=quant_info.per_channel_quant,
w1_scale=quant_info.w13_scale,
w2_scale=quant_info.w2_scale,
w1_zp=quant_info.w13_zp,
w2_zp=quant_info.w2_zp,
a1_scale=quant_info.a13_scale,
a2_scale=quant_info.a2_scale,
block_shape=quant_info.block_shape,
)
return StandardCombineInput(
hidden_states=output,
)
@register_pre_permute("standard", "triton")
def pre_permute_standard_to_triton(
dispatch_output: StandardDispatchOutput,
quant_info: TritonMoeQuantInfo,
runner_config: MoeRunnerConfig,
running_state: dict,
) -> TritonRunnerInput:
# NOTE: this is currently dead code, because a fused function for the standard format is registered.
# It is kept here for testing and as an example.
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
get_config_dtype_str,
moe_align_block_size,
try_get_optimal_moe_config,
)
from sglang.srt.layers.moe.topk import TopKOutputChecker
hidden_states, topk_output = dispatch_output
assert TopKOutputChecker.format_is_standard(topk_output)
num_tokens = hidden_states.shape[0]
num_local_experts = runner_config.num_local_experts
if (
not (quant_info.use_fp8_w8a8 or quant_info.use_int8_w8a8)
or quant_info.block_shape is not None
or _use_aiter
):
padding_size = 0
else:
padding_size = _MOE_PADDING_SIZE
config_dtype = get_config_dtype_str(
use_fp8_w8a8=quant_info.use_fp8_w8a8,
use_int8_w8a8=quant_info.use_int8_w8a8,
use_int8_w8a16=quant_info.use_int8_w8a16,
use_int4_w4a16=quant_info.use_int4_w4a16,
dtype=hidden_states.dtype,
)
get_config_func = functools.partial(
try_get_optimal_moe_config,
quant_info.w13_weight.shape,
(
num_local_experts,
quant_info.w2_weight.shape[1],
quant_info.w2_weight.shape[2] - padding_size,
),
topk_output.topk_ids.shape[1],
config_dtype,
block_shape=quant_info.block_shape,
)
config = get_config_func(num_tokens)
sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
topk_output.topk_ids, config["BLOCK_SIZE_M"], num_local_experts
)
running_state["config"] = config
return TritonRunnerInput(
hidden_states=hidden_states,
topk_weights=topk_output.topk_weights,
topk_ids=topk_output.topk_ids,
sorted_token_ids=sorted_token_ids,
expert_ids=expert_ids,
num_tokens_post_padded=num_tokens_post_padded,
)
@register_post_permute("triton", "standard")
def post_permute_triton_to_standard(
runner_output: TritonRunnerOutput,
quant_info: TritonMoeQuantInfo,
runner_config: MoeRunnerConfig,
running_state: dict,
) -> StandardCombineInput:
# NOTE: this is currently dead code, because a fused function for the standard format is registered.
# It is kept here for testing and as an example.
return StandardCombineInput(
hidden_states=runner_output.hidden_states,
)
from sglang.srt.layers.moe.token_dispatcher.base_dispatcher import (
from sglang.srt.layers.moe.token_dispatcher.base import (
BaseDispatcher,
BaseDispatcherConfig,
CombineInput,
CombineInputChecker,
CombineInputFormat,
DispatchOutput,
DispatchOutputChecker,
DispatchOutputFormat,
......@@ -9,21 +12,32 @@ from sglang.srt.layers.moe.token_dispatcher.deepep import (
AscendDeepEPLLOutput,
DeepEPConfig,
DeepEPDispatcher,
DeepEPLLCombineInput,
DeepEPLLOutput,
DeepEPNormalCombineInput,
DeepEPNormalOutput,
)
from sglang.srt.layers.moe.token_dispatcher.standard import StandardDispatchOutput
from sglang.srt.layers.moe.token_dispatcher.standard import (
StandardCombineInput,
StandardDispatchOutput,
)
__all__ = [
"AscendDeepEPLLOutput",
"BaseDispatcher",
"BaseDispatcherConfig",
"CombineInput",
"CombineInputChecker",
"CombineInputFormat",
"DispatchOutput",
"DispatchOutputFormat",
"DispatchOutputChecker",
"StandardDispatchOutput",
"StandardCombineInput",
"DeepEPConfig",
"DeepEPDispatcher",
"DeepEPNormalOutput",
"DeepEPLLOutput",
"DeepEPLLCombineInput",
"DeepEPNormalCombineInput",
]
from __future__ import annotations
from abc import ABC, abstractmethod
from enum import Enum, auto
from enum import Enum
from typing import TYPE_CHECKING, Protocol, TypeGuard, Union, runtime_checkable
import torch
......@@ -9,10 +9,16 @@ import torch
if TYPE_CHECKING:
from sglang.srt.layers.moe.token_dispatcher import (
AscendDeepEPLLOutput,
DeepEPLLCombineInput,
DeepEPLLOutput,
DeepEPNormalCombineInput,
DeepEPNormalOutput,
StandardCombineInput,
StandardDispatchOutput,
)
from sglang.srt.layers.moe.topk import TopKOutput
# ------------------------------ Dispatch Output -------------------------------------
class DispatchOutputChecker:
......@@ -50,10 +56,10 @@ class DispatchOutputChecker:
class DispatchOutputFormat(Enum):
STANDARD = auto()
DEEPEP_NORMAL = auto()
DEEPEP_LL = auto()
ASCENT_LL = auto()
STANDARD = "standard"
DEEPEP_NORMAL = "deepep_normal"
DEEPEP_LL = "deepep_ll"
ASCENT_LL = "ascent_ll"
def is_standard(self) -> bool:
return self == DispatchOutputFormat.STANDARD
......@@ -78,10 +84,63 @@ class DispatchOutputFormat(Enum):
class DispatchOutput(Protocol):
"""Protocol for dispatch outputs in different formats."""
# TODO: add hidden_states to the protocol
@property
def format(self) -> DispatchOutputFormat: ...
# ------------------------------ Combine Input -------------------------------------
class CombineInputChecker:
@staticmethod
def format_is_standard(
combine_input: CombineInput,
) -> TypeGuard[StandardCombineInput]:
return combine_input.format == CombineInputFormat.STANDARD
@staticmethod
def format_is_deepep_normal(
combine_input: CombineInput,
) -> TypeGuard[DeepEPNormalCombineInput]:
return combine_input.format == CombineInputFormat.DEEPEP_NORMAL
@staticmethod
def format_is_deepep_ll(
combine_input: CombineInput,
) -> TypeGuard[DeepEPLLCombineInput]:
return combine_input.format == CombineInputFormat.DEEPEP_LL
@staticmethod
def format_is_deepep(
combine_input: CombineInput,
) -> TypeGuard[Union[DeepEPNormalCombineInput, DeepEPLLCombineInput]]:
return combine_input.format in [
CombineInputFormat.DEEPEP_NORMAL,
CombineInputFormat.DEEPEP_LL,
]
class CombineInputFormat(Enum):
STANDARD = "standard"
DEEPEP_NORMAL = "deepep_normal"
DEEPEP_LL = "deepep_ll"
@runtime_checkable
class CombineInput(Protocol):
"""Protocol for combine inputs in different formats."""
# TODO: add hidden_states to the protocol
@property
def format(self) -> CombineInputFormat: ...
# ------------------------------ Base Dispatcher -------------------------------------
class BaseDispatcherConfig(ABC):
"""Base class for dispatcher configs."""
......@@ -92,9 +151,11 @@ class BaseDispatcher(ABC):
"""Base class for dispatchers."""
@abstractmethod
def dispatch(self, *args, **kwargs) -> DispatchOutput:
def dispatch(
self, hidden_states: torch.Tensor, topk_output: TopKOutput, **kwargs
) -> DispatchOutput:
pass
@abstractmethod
def combine(self, *args, **kwargs) -> torch.Tensor:
def combine(self, combine_input: CombineInput, **kwargs) -> torch.Tensor:
pass
......@@ -5,13 +5,15 @@ from dataclasses import dataclass
from typing import TYPE_CHECKING, List, NamedTuple, Optional, Tuple, Union
from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
from sglang.srt.layers.moe import DeepEPMode, get_deepep_config, is_tbo_enabled
from sglang.srt.layers.moe.token_dispatcher.base_dispatcher import (
from sglang.srt.layers.moe.token_dispatcher.base import (
BaseDispatcher,
BaseDispatcherConfig,
CombineInput,
CombineInputFormat,
DispatchOutput,
DispatchOutputFormat,
)
from sglang.srt.layers.moe.utils import DeepEPMode, get_deepep_config, is_tbo_enabled
from sglang.srt.layers.quantization import deep_gemm_wrapper
from sglang.srt.utils import (
get_bool_env_var,
......@@ -56,6 +58,7 @@ class DeepEPNormalOutput(NamedTuple):
"""DeepEP normal dispatch output."""
hidden_states: torch.Tensor | Tuple[torch.Tensor, torch.Tensor]
# hidden_states_scale
topk_idx: torch.Tensor
topk_weights: torch.Tensor
num_recv_tokens_per_expert: List[int]
......@@ -99,6 +102,30 @@ assert isinstance(DeepEPLLOutput, DispatchOutput)
assert isinstance(AscendDeepEPLLOutput, DispatchOutput)
class DeepEPNormalCombineInput(NamedTuple):
"""DeepEP normal combine input."""
pass
@property
def format(self) -> CombineInputFormat:
return CombineInputFormat.DEEPEP_NORMAL
class DeepEPLLCombineInput(NamedTuple):
"""DeepEP low latency combine input."""
pass
@property
def format(self) -> CombineInputFormat:
return CombineInputFormat.DEEPEP_LL
assert isinstance(DeepEPNormalCombineInput, CombineInput)
assert isinstance(DeepEPLLCombineInput, CombineInput)
class DeepEPDispatchMode(IntEnum):
NORMAL = auto()
LOW_LATENCY = auto()
......
from __future__ import annotations
from typing import NamedTuple
from typing import TYPE_CHECKING, NamedTuple
from sglang.srt.layers.moe.token_dispatcher.base_dispatcher import (
import torch
from sglang.srt.layers.moe.token_dispatcher.base import (
BaseDispatcher,
CombineInput,
CombineInputFormat,
DispatchOutput,
DispatchOutputFormat,
)
if TYPE_CHECKING:
from sglang.srt.layers.moe.topk import TopKOutput
class StandardDispatchOutput(NamedTuple):
"""Standard dispatch output."""
hidden_states: torch.Tensor
topk_output: TopKOutput
@property
def format(self) -> DispatchOutputFormat:
return DispatchOutputFormat.STANDARD
assert isinstance(StandardDispatchOutput, DispatchOutput)
class StandardCombineInput(NamedTuple):
"""Standard combine input."""
hidden_states: torch.Tensor
@property
def format(self) -> CombineInputFormat:
return CombineInputFormat.STANDARD
assert isinstance(StandardCombineInput, CombineInput)
class StandardDispatcher(BaseDispatcher):
def dispatch(
self, hidden_states: torch.Tensor, topk_output: TopKOutput
) -> DispatchOutput:
return StandardDispatchOutput(
hidden_states=hidden_states, topk_output=topk_output
)
def combine(self, combine_input: CombineInput) -> torch.Tensor:
if isinstance(combine_input, StandardCombineInput):
return combine_input.hidden_states
else:
# TODO: this branch should be removed in the future
assert isinstance(combine_input, torch.Tensor)
return combine_input
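A small round-trip sketch of the StandardDispatcher defined above (assumes torch and sglang are importable; the None passed as topk_output is a placeholder standing in for a real TopKOutput):

import torch
from sglang.srt.layers.moe.token_dispatcher.standard import (
    StandardCombineInput,
    StandardDispatcher,
)

dispatcher = StandardDispatcher()
hidden = torch.randn(4, 16)

# dispatch() simply packs the tensors into a StandardDispatchOutput.
dispatch_output = dispatcher.dispatch(hidden_states=hidden, topk_output=None)
assert dispatch_output.format.value == "standard"

# combine() unwraps a StandardCombineInput back into a plain tensor.
out = dispatcher.combine(StandardCombineInput(hidden_states=hidden))
assert torch.equal(out, hidden)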
from __future__ import annotations
import importlib.util
import logging
from enum import Enum
from functools import lru_cache
from typing import TYPE_CHECKING, Optional
......@@ -12,11 +13,12 @@ from sglang.srt.layers.dp_attention import (
get_attention_dp_size,
is_dp_attention_enabled,
)
from sglang.srt.utils import logger
if TYPE_CHECKING:
from sglang.srt.server_args import ServerArgs
logger = logging.getLogger(__name__)
class MoeA2ABackend(Enum):
......@@ -131,7 +133,7 @@ def get_moe_a2a_backend() -> MoeA2ABackend:
global MOE_A2A_BACKEND
if MOE_A2A_BACKEND is None:
logger.warning("MOE_A2A_BACKEND is not initialized, using default backend")
MOE_A2A_BACKEND = MoeA2ABackend(None)
MOE_A2A_BACKEND = MoeA2ABackend.NONE
return MOE_A2A_BACKEND
......@@ -139,7 +141,7 @@ def get_moe_runner_backend() -> MoeRunnerBackend:
global MOE_RUNNER_BACKEND
if MOE_RUNNER_BACKEND is None:
logger.warning("MOE_RUNNER_BACKEND is not initialized, using triton backend")
MOE_RUNNER_BACKEND = MoeRunnerBackend("triton")
MOE_RUNNER_BACKEND = MoeRunnerBackend.AUTO
return MOE_RUNNER_BACKEND
......@@ -147,7 +149,7 @@ def get_deepep_mode() -> DeepEPMode:
global DEEPEP_MODE
if DEEPEP_MODE is None:
logger.warning("DEEPEP_MODE is not initialized, using auto mode")
DEEPEP_MODE = DeepEPMode("auto")
DEEPEP_MODE = DeepEPMode.AUTO
return DEEPEP_MODE
......
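The three fallbacks above switch from value-based enum construction to direct member access; the toy enum below illustrates the difference (its member values are assumptions for illustration, not sglang's real ones).

from enum import Enum

class Backend(Enum):
    NONE = "none"
    TRITON = "triton"

default = Backend.NONE       # member access: always valid, no value lookup
looked_up = Backend("none")  # value lookup: raises ValueError for unknown values
assert default is looked_up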
......@@ -34,7 +34,10 @@ from sglang.srt.layers.quantization.utils import get_scalar_types, replace_param
if TYPE_CHECKING:
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
from sglang.srt.layers.moe.topk import StandardTopKOutput
from sglang.srt.layers.moe.token_dispatcher import (
StandardDispatchOutput,
CombineInput,
)
from sglang.srt.utils import is_cuda, is_hip
......@@ -736,24 +739,32 @@ class AWQMoEMethod(FusedMoEMethodBase):
)
replace_parameter(layer, "w2_qzeros", marlin_w2_zp)
def create_moe_runner(
self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
):
self.moe_runner_config = moe_runner_config
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
topk_output: StandardTopKOutput,
moe_runner_config: MoeRunnerConfig,
) -> torch.Tensor:
dispatch_output: StandardDispatchOutput,
) -> CombineInput:
from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
assert (
moe_runner_config.activation == "silu"
self.moe_runner_config.activation == "silu"
), "Only SiLU activation is supported."
# The input must currently be float16
x = dispatch_output.hidden_states
topk_output = dispatch_output.topk_output
orig_dtype = x.dtype
x = x.half()
topk_weights, topk_ids, router_logits = topk_output
return fused_marlin_moe(
output = fused_marlin_moe(
x,
layer.w13_qweight,
layer.w2_qweight,
......@@ -768,3 +779,4 @@ class AWQMoEMethod(FusedMoEMethodBase):
w2_zeros=layer.w2_qzeros,
num_bits=self.quant_config.weight_bits,
).to(orig_dtype)
return StandardCombineInput(hidden_states=output)
......@@ -3,6 +3,7 @@ from __future__ import annotations
import inspect
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type
import torch
......@@ -10,7 +11,7 @@ from torch import nn
if TYPE_CHECKING:
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
from sglang.srt.layers.moe.topk import TopKOutput
from sglang.srt.layers.moe.token_dispatcher import CombineInput, DispatchOutput
class QuantizeMethodBase(ABC):
......@@ -89,20 +90,24 @@ class FusedMoEMethodBase(QuantizeMethodBase):
layer: torch.nn.Module,
num_experts: int,
hidden_size: int,
intermediate_size: int,
intermediate_size_per_partition: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
raise NotImplementedError
@abstractmethod
def create_moe_runner(
self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
):
raise NotImplementedError
@abstractmethod
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
topk_output: TopKOutput,
moe_runner_config: MoeRunnerConfig,
) -> torch.Tensor:
dispatch_output: DispatchOutput,
) -> CombineInput:
raise NotImplementedError
......
......@@ -9,6 +9,8 @@ import torch
from torch.nn import Module
from sglang.srt.distributed import get_tensor_model_parallel_world_size
from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo
from sglang.srt.layers.parameter import BlockQuantScaleParameter, ModelWeightParameter
from sglang.srt.layers.quantization.base_config import (
FusedMoEMethodBase,
......@@ -22,8 +24,10 @@ from sglang.srt.layers.quantization.utils import is_layer_skipped
from sglang.srt.utils import set_weight_attrs
if TYPE_CHECKING:
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
from sglang.srt.layers.moe.topk import TopKOutput
from sglang.srt.layers.moe.token_dispatcher import (
CombineInput,
StandardDispatchOutput,
)
ACTIVATION_SCHEMES = ["static", "dynamic"]
......@@ -257,7 +261,7 @@ class BlockInt8MoEMethod(FusedMoEMethodBase):
layer: Module,
num_experts: int,
hidden_size: int,
intermediate_size: int,
intermediate_size_per_partition: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
......@@ -273,25 +277,28 @@ class BlockInt8MoEMethod(FusedMoEMethodBase):
)
# NOTE(HandH1998): To ensure proper alignment of the block-wise quantization scales, the output_size of the weights for both the gate and up layers must be divisible by block_n.
# Required by column parallel or enabling merged weights
if intermediate_size % block_n != 0:
if intermediate_size_per_partition % block_n != 0:
raise ValueError(
f"The output_size of gate's and up's weight = "
f"{intermediate_size} is not divisible by "
f"{intermediate_size_per_partition} is not divisible by "
f"weight quantization block_n = {block_n}."
)
if tp_size > 1:
# Required by row parallel
if intermediate_size % block_k != 0:
if intermediate_size_per_partition % block_k != 0:
raise ValueError(
f"The input_size of down's weight = "
f"{intermediate_size} is not divisible by "
f"{intermediate_size_per_partition} is not divisible by "
f"weight quantization block_k = {block_k}."
)
# WEIGHTS
w13_weight = torch.nn.Parameter(
torch.empty(
num_experts, 2 * intermediate_size, hidden_size, dtype=params_dtype
num_experts,
2 * intermediate_size_per_partition,
hidden_size,
dtype=params_dtype,
),
requires_grad=False,
)
......@@ -300,7 +307,10 @@ class BlockInt8MoEMethod(FusedMoEMethodBase):
w2_weight = torch.nn.Parameter(
torch.empty(
num_experts, hidden_size, intermediate_size, dtype=params_dtype
num_experts,
hidden_size,
intermediate_size_per_partition,
dtype=params_dtype,
),
requires_grad=False,
)
......@@ -311,7 +321,7 @@ class BlockInt8MoEMethod(FusedMoEMethodBase):
w13_weight_scale = torch.nn.Parameter(
torch.ones(
num_experts,
2 * ((intermediate_size + block_n - 1) // block_n),
2 * ((intermediate_size_per_partition + block_n - 1) // block_n),
(hidden_size + block_k - 1) // block_k,
dtype=torch.float32,
),
......@@ -321,7 +331,7 @@ class BlockInt8MoEMethod(FusedMoEMethodBase):
torch.ones(
num_experts,
(hidden_size + block_n - 1) // block_n,
(intermediate_size + block_k - 1) // block_k,
(intermediate_size_per_partition + block_k - 1) // block_k,
dtype=torch.float32,
),
requires_grad=False,
......@@ -344,26 +354,27 @@ class BlockInt8MoEMethod(FusedMoEMethodBase):
# Block quant doesn't need to process weights after loading
return
def create_moe_runner(
self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
):
self.moe_runner_config = moe_runner_config
self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
topk_output: TopKOutput,
moe_runner_config: MoeRunnerConfig,
) -> torch.Tensor:
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
# Expert fusion with INT8 quantization
return fused_experts(
x,
layer.w13_weight,
layer.w2_weight,
topk_output=topk_output,
moe_runner_config=moe_runner_config,
dispatch_output: StandardDispatchOutput,
) -> CombineInput:
quant_info = TritonMoeQuantInfo(
w13_weight=layer.w13_weight,
w2_weight=layer.w2_weight,
use_int8_w8a8=True,
w1_scale=(layer.w13_weight_scale_inv),
w2_scale=(layer.w2_weight_scale_inv),
a1_scale=layer.w13_input_scale,
w13_scale=layer.w13_weight_scale_inv,
w2_scale=layer.w2_weight_scale_inv,
a13_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
block_shape=self.quant_config.weight_block_size,
)
return self.runner.run(dispatch_output, quant_info)
......@@ -11,6 +11,8 @@ import torch
from compressed_tensors import CompressionFormat
from compressed_tensors.quantization import QuantizationStrategy
from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo
from sglang.srt.layers.quantization.base_config import FusedMoEMethodBase
from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz, scaled_fp8_quant
from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz
......@@ -30,8 +32,10 @@ from sglang.srt.utils import (
if TYPE_CHECKING:
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
from sglang.srt.layers.moe.topk import TopKOutput
from sglang.srt.layers.moe.token_dispatcher import (
CombineInput,
StandardDispatchOutput,
)
from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors import (
CompressedTensorsConfig,
)
......@@ -293,14 +297,24 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
)
torch.cuda.empty_cache()
def create_moe_runner(
self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
):
self.moe_runner_config = moe_runner_config
self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
topk_output: TopKOutput,
moe_runner_config: MoeRunnerConfig,
) -> torch.Tensor:
from sglang.srt.layers.moe.fused_moe_triton import fused_experts
dispatch_output: StandardDispatchOutput,
) -> CombineInput:
from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
x = dispatch_output.hidden_states
topk_output = dispatch_output.topk_output
moe_runner_config = self.moe_runner_config
if (
_use_aiter
......@@ -308,7 +322,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
and moe_runner_config.apply_router_weight_on_input
):
topk_weights, topk_ids, _ = topk_output
return rocm_fused_experts_tkw1(
output = rocm_fused_experts_tkw1(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
......@@ -324,21 +338,20 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
)
return StandardCombineInput(hidden_states=output)
else:
return fused_experts(
x,
layer.w13_weight,
layer.w2_weight,
topk_output=topk_output,
moe_runner_config=moe_runner_config,
quant_info = TritonMoeQuantInfo(
w13_weight=layer.w13_weight,
w2_weight=layer.w2_weight,
use_fp8_w8a8=True,
per_channel_quant=self.weight_quant.strategy
== QuantizationStrategy.CHANNEL,
w1_scale=layer.w13_weight_scale,
w13_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
a1_scale=layer.w13_input_scale,
a13_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
)
return self.runner.run(dispatch_output, quant_info)
class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
......@@ -380,8 +393,6 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
params_dtype == torch.float16
), "float16 is required for MoE compressed models. Set dtype=torch.float16" # noqa: E501
intermediate_size_full = extra_weight_attrs.pop("intermediate_size_full")
# Will transpose the loaded weight along the
# intermediate and hidden dim sizes. Will
# shard for TP along the transposed dims
......@@ -415,13 +426,13 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
# In the case where we have actorder/g_idx,
# we do not partition the w2 scales
load_full_w2 = self.actorder and self.group_size != -1
w2_scales_size = (
intermediate_size_full if load_full_w2 else intermediate_size_per_partition
)
self.is_k_full = (not self.actorder) or (
intermediate_size_per_partition == intermediate_size_full
)
if load_full_w2:
w2_scales_size = intermediate_size_per_partition * layer.moe_tp_size
else:
w2_scales_size = intermediate_size_per_partition
self.is_k_full = (not self.actorder) or layer.moe_tp_size == 1
if self.strategy == "channel":
num_groups_w2 = num_groups_w13 = 1
......@@ -640,21 +651,29 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
)
replace_tensor("w2_weight_scale", marlin_w2_scales)
def create_moe_runner(
self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
):
self.moe_runner_config = moe_runner_config
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
topk_output: TopKOutput,
moe_runner_config: MoeRunnerConfig,
) -> torch.Tensor:
dispatch_output: StandardDispatchOutput,
) -> CombineInput:
from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
assert (
moe_runner_config.activation == "silu"
self.moe_runner_config.activation == "silu"
), "Only SiLU activation is supported."
x = dispatch_output.hidden_states
topk_output = dispatch_output.topk_output
topk_weights, topk_ids, router_logits = topk_output
return torch.ops.vllm.fused_marlin_moe(
output = torch.ops.vllm.fused_marlin_moe(
x,
layer.w13_weight_packed,
layer.w2_weight_packed,
......@@ -670,3 +689,4 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
num_bits=self.num_bits,
is_k_full=self.is_k_full,
)
return StandardCombineInput(hidden_states=output)
......@@ -30,6 +30,9 @@ except ImportError:
from sglang.srt.distributed import get_tensor_model_parallel_world_size
from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading
from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo
from sglang.srt.layers.moe.token_dispatcher.base import DispatchOutputChecker
from sglang.srt.layers.parameter import (
BlockQuantScaleParameter,
ModelWeightParameter,
......@@ -81,7 +84,11 @@ from sglang.srt.utils import (
)
if TYPE_CHECKING:
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
from sglang.srt.layers.moe.token_dispatcher import (
CombineInput,
DispatchOutput,
StandardDispatchOutput,
)
from sglang.srt.layers.moe.topk import TopKOutput
from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config
......@@ -527,7 +534,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
layer: Module,
num_experts: int,
hidden_size: int,
intermediate_size: int,
intermediate_size_per_partition: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
......@@ -543,18 +550,18 @@ class Fp8MoEMethod(FusedMoEMethodBase):
)
# NOTE(HandH1998): To ensure proper alignment of the block-wise quantization scales, the output_size of the weights for both the gate and up layers must be divisible by block_n.
# Required by column parallel or enabling merged weights
if intermediate_size % block_n != 0:
if intermediate_size_per_partition % block_n != 0:
raise ValueError(
f"The output_size of gate's and up's weight = "
f"{intermediate_size} is not divisible by "
f"{intermediate_size_per_partition} is not divisible by "
f"weight quantization block_n = {block_n}."
)
if tp_size > 1:
# Required by row parallel
if intermediate_size % block_k != 0:
if intermediate_size_per_partition % block_k != 0:
raise ValueError(
f"The input_size of down's weight = "
f"{intermediate_size} is not divisible by "
f"{intermediate_size_per_partition} is not divisible by "
f"weight quantization block_k = {block_k}."
)
......@@ -564,7 +571,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
w13_weight = torch.nn.Parameter(
torch.empty(
num_experts,
2 * intermediate_size,
2 * intermediate_size_per_partition,
hidden_size // 8,
dtype=params_dtype,
),
......@@ -572,20 +579,29 @@ class Fp8MoEMethod(FusedMoEMethodBase):
)
w2_weight = torch.nn.Parameter(
torch.empty(
num_experts, hidden_size, intermediate_size // 8, dtype=params_dtype
num_experts,
hidden_size,
intermediate_size_per_partition // 8,
dtype=params_dtype,
),
requires_grad=False,
)
else:
w13_weight = torch.nn.Parameter(
torch.empty(
num_experts, 2 * intermediate_size, hidden_size, dtype=params_dtype
num_experts,
2 * intermediate_size_per_partition,
hidden_size,
dtype=params_dtype,
),
requires_grad=False,
)
w2_weight = torch.nn.Parameter(
torch.empty(
num_experts, hidden_size, intermediate_size, dtype=params_dtype
num_experts,
hidden_size,
intermediate_size_per_partition,
dtype=params_dtype,
),
requires_grad=False,
)
......@@ -601,7 +617,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
w13_weight_scale = torch.nn.Parameter(
torch.ones(
num_experts,
2 * ((intermediate_size + block_n - 1) // block_n),
2 * ((intermediate_size_per_partition + block_n - 1) // block_n),
(hidden_size + block_k - 1) // block_k,
dtype=torch.float32,
),
......@@ -611,7 +627,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
torch.ones(
num_experts,
(hidden_size + block_n - 1) // block_n,
(intermediate_size + block_k - 1) // block_k,
(intermediate_size_per_partition + block_k - 1) // block_k,
dtype=torch.float32,
),
requires_grad=False,
......@@ -632,19 +648,19 @@ class Fp8MoEMethod(FusedMoEMethodBase):
)
self.c_strides1 = torch.full(
(num_experts,),
2 * intermediate_size,
2 * intermediate_size_per_partition,
device=w13_weight.device,
dtype=torch.int64,
)
self.ab_strides2 = torch.full(
(num_experts,),
intermediate_size,
intermediate_size_per_partition,
device=w2_weight.device,
dtype=torch.int64,
)
self.c_strides2 = torch.full(
(num_experts,),
hidden_size,
intermediate_size_per_partition,
device=w2_weight.device,
dtype=torch.int64,
)
......@@ -691,7 +707,11 @@ class Fp8MoEMethod(FusedMoEMethodBase):
if _is_hip: # _use_aiter: TODO: add check back after triton kernel
# ROCm - using column scaling, duplicate scaling numbers in case per tensor scaling
w13_weight_scale1 = torch.nn.Parameter(
torch.ones(num_experts, 2 * intermediate_size, dtype=torch.float32),
torch.ones(
num_experts,
2 * intermediate_size_per_partition,
dtype=torch.float32,
),
requires_grad=False,
)
w2_weight_scale1 = torch.nn.Parameter(
......@@ -984,14 +1004,23 @@ class Fp8MoEMethod(FusedMoEMethodBase):
)
torch.cuda.empty_cache()
def create_moe_runner(
self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
):
self.moe_runner_config = moe_runner_config
self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
topk_output: TopKOutput,
moe_runner_config: MoeRunnerConfig,
) -> torch.Tensor:
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
dispatch_output: DispatchOutput,
) -> CombineInput:
from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
x = dispatch_output.hidden_states
topk_output = dispatch_output.topk_output
moe_runner_config = self.moe_runner_config
if use_intel_amx_backend(layer):
from sglang.srt.layers.moe.topk import apply_topk_weights_cpu
......@@ -1001,7 +1030,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
moe_runner_config.apply_router_weight_on_input, topk_weights, x
)
return torch.ops.sgl_kernel.fused_experts_cpu(
output = torch.ops.sgl_kernel.fused_experts_cpu(
x,
layer.w13_weight,
layer.w2_weight,
......@@ -1017,6 +1046,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
None, # a2_scale
True, # is_vnni
)
return StandardCombineInput(hidden_states=output)
if _is_hip:
ret = self.maybe_apply_hip_fused_experts(
......@@ -1027,7 +1057,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
moe_runner_config.no_combine,
)
if ret is not None:
return ret
return StandardCombineInput(hidden_states=ret)
if self.use_cutlass_fused_experts_fp8:
from sglang.srt.layers.moe.cutlass_moe import cutlass_fused_experts_fp8
......@@ -1056,17 +1086,13 @@ class Fp8MoEMethod(FusedMoEMethodBase):
self.problem_sizes2,
use_fp8_blockscale=True,
)
# Scale by routed_scaling_factor is fused into select_experts.
return output
# Expert fusion with FP8 quantization
return fused_experts(
x,
layer.w13_weight,
layer.w2_weight,
topk_output=topk_output,
moe_runner_config=moe_runner_config,
return StandardCombineInput(hidden_states=output)
quant_info = TritonMoeQuantInfo(
w13_weight=layer.w13_weight,
w2_weight=layer.w2_weight,
use_fp8_w8a8=True,
w1_scale=(
w13_scale=(
layer.w13_weight_scale_inv
if self.block_quant
else layer.w13_weight_scale
......@@ -1074,20 +1100,22 @@ class Fp8MoEMethod(FusedMoEMethodBase):
w2_scale=(
layer.w2_weight_scale_inv if self.block_quant else layer.w2_weight_scale
),
a1_scale=layer.w13_input_scale,
a13_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
block_shape=self.quant_config.weight_block_size,
)
return self.runner.run(dispatch_output, quant_info)
def apply_with_router_logits(
self,
layer: torch.nn.Module,
x: torch.Tensor,
topk_output: TopKOutput,
moe_runner_config: MoeRunnerConfig,
dispatch_output: StandardDispatchOutput,
) -> torch.Tensor:
activation = moe_runner_config.activation
routed_scaling_factor = moe_runner_config.routed_scaling_factor
x = dispatch_output.hidden_states
topk_output = dispatch_output.topk_output
activation = self.moe_runner_config.activation
routed_scaling_factor = self.moe_runner_config.routed_scaling_factor
from flashinfer.fused_moe import trtllm_fp8_block_scale_moe
......