"src/diffusers/models/controlnets/controlnet_flax.py" did not exist on "4d1e4e24e54d00b2a1aff17410a9a86594ae8b8a"
Unverified Commit 29589512, authored by Cheng Wan, committed by GitHub

[6/N] MoE Refactor: Cleanup MoE-related configs (#8849)

parent 584e1ab2
......@@ -11,6 +11,7 @@ import triton
from ray.experimental.tqdm_ray import tqdm
from transformers import AutoConfig
from sglang.srt.layers.moe.fused_moe_triton import override_config
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
fused_moe,
get_config_dtype_str,
......@@ -18,7 +19,8 @@ from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
get_default_config,
get_moe_configs,
)
from sglang.srt.layers.moe.topk import select_experts
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
from sglang.srt.layers.moe.topk import TopKConfig, select_experts
from sglang.srt.utils import is_hip
_is_hip = is_hip()
......@@ -117,17 +119,23 @@ def benchmark_config(
w2 = w2.to(torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn)
input_gating = torch.randn(num_tokens, num_experts, dtype=torch.float32)
topk_output = select_experts(x, input_gating, topk, renormalize=True)
topk_config = TopKConfig(
top_k=topk,
renormalize=True,
)
topk_output = select_experts(x, input_gating, topk_config)
def prepare(i: int):
input_gating = gating_output[i]
new_topk_output = select_experts(x, input_gating, topk, renormalize=True)
new_topk_output = select_experts(x, input_gating, topk_config)
topk_output.topk_weights.copy_(new_topk_output.topk_weights)
topk_output.topk_ids.copy_(new_topk_output.topk_ids)
topk_output.router_logits.copy_(new_topk_output.router_logits)
def run():
from sglang.srt.layers.moe.fused_moe_triton import override_config
moe_runner_config = MoeRunnerConfig(
inplace=True,
)
with override_config(config):
fused_moe(
......@@ -135,7 +143,7 @@ def benchmark_config(
w1,
w2,
topk_output,
inplace=True,
moe_runner_config=moe_runner_config,
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16,
......
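Taken together, the benchmark hunks above replace loose keyword arguments with two config objects: routing options move into `TopKConfig` and runner options into `MoeRunnerConfig`. A minimal sketch of the new calling convention, assuming the import paths shown in this diff; the tensor shapes and dtypes are illustrative only, and a CUDA device with Triton is required:

```python
import torch

from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
from sglang.srt.layers.moe.topk import TopKConfig, select_experts

# Illustrative sizes; real benchmarks derive these from the model config.
num_tokens, hidden, inter, num_experts, topk = 8, 128, 256, 4, 2

x = torch.randn(num_tokens, hidden, dtype=torch.bfloat16, device="cuda")
w1 = torch.randn(num_experts, 2 * inter, hidden, dtype=torch.bfloat16, device="cuda")
w2 = torch.randn(num_experts, hidden, inter, dtype=torch.bfloat16, device="cuda")
router_logits = torch.randn(num_tokens, num_experts, dtype=torch.float32, device="cuda")

# Routing options now travel in TopKConfig instead of ad-hoc kwargs.
topk_config = TopKConfig(top_k=topk, renormalize=True)
topk_output = select_experts(x, router_logits, topk_config)

# Runner options (inplace, activation, ...) now travel in MoeRunnerConfig.
moe_runner_config = MoeRunnerConfig(inplace=True)
out = fused_moe(x, w1, w2, topk_output, moe_runner_config=moe_runner_config)
```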
......@@ -213,12 +213,11 @@ Please consult the documentation below and [server_args.py](https://github.com/s
| Arguments | Description | Defaults |
|-----------|-------------|----------|
| `--ep-size` | The expert parallelism size. | 1 |
| `--moe-a2a-backend` | Select the backend for all-to-all communication for expert parallelism. | None |
| `--enable-flashinfer-cutlass-moe` | Enabling Flashinfer Cutlass MoE implementation for high throughput. | False |
| `--enable-flashinfer-trtllm-moe` | Enabling Flashinfer Trtllm MoE implementation for low latency. | False |
| `--moe-a2a-backend` | Select the backend for all-to-all communication for expert parallelism. | none |
| `--moe-runner-backend` | Select the runner backend for MoE. | 'triton' |
| `--deepep-mode` | Select the mode when enabling DeepEP MoE; can be `normal`, `low_latency`, or `auto`. Default is `auto`, which means `low_latency` for decode batches and `normal` for prefill batches. | auto |
| `--ep-num-redundant-experts` | Allocate this number of redundant experts in expert parallel. | 0 |
| `--ep-dispatch-algorithm` | The algorithm to choose ranks for redundant experts in expert parallel. | None |
| `--ep-dispatch-algorithm` | The algorithm to choose ranks for redundant experts in EPLB. | None |
| `--init-expert-location` | Initial location of EP experts. | trivial |
| `--enable-eplb` | Enable EPLB algorithm. | False |
| `--eplb-algorithm` | Chosen EPLB algorithm. | auto |
......@@ -280,7 +279,6 @@ Please consult the documentation below and [server_args.py](https://github.com/s
| `--disable-chunked-prefix-cache` | Disable chunked prefix cache. | False |
| `--disable-fast-image-processor` | Disable fast image processor. | False |
| `--enable-return-hidden-states` | Enable returning hidden states. | False |
| `--enable-triton-kernel-moe` | Enable Triton kernel for MoE. | False |
## Debug tensor dumps
......
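For reference, the consolidated flags in the server-argument tables above map onto enums added later in this commit (`sglang/srt/layers/moe/utils.py`). A small sketch of that mapping, using only the enum values visible further down in this diff:

```python
from sglang.srt.layers.moe import DeepEPMode, MoeA2ABackend, MoeRunnerBackend

# --moe-a2a-backend: "none" (default) or "deepep".
assert MoeA2ABackend("deepep").is_deepep()
assert MoeA2ABackend("none").is_none()

# --moe-runner-backend: one value replaces the old per-backend booleans
# (--enable-flashinfer-cutlass-moe, --enable-flashinfer-trtllm-moe,
#  --enable-triton-kernel-moe).
assert MoeRunnerBackend("flashinfer_cutlass").is_flashinfer_cutlass()
assert MoeRunnerBackend("flashinfer_trtllm").is_flashinfer_trtllm()
assert MoeRunnerBackend("triton_kernel").is_triton_kernel()

# --deepep-mode is unchanged: per the table above, "auto" resolves to
# normal for prefill (extend) batches and low_latency for decode batches.
assert DeepEPMode("auto").resolve(is_extend_in_batch=True) == DeepEPMode.NORMAL
```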
......@@ -61,7 +61,6 @@ from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.distributed.parallel_state import destroy_distributed_environment
from sglang.srt.entrypoints.engine import _set_envs_and_config
from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.layers.moe.utils import DeepEPMode, MoeA2ABackend
from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
from sglang.srt.managers.scheduler import Scheduler
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
......@@ -300,11 +299,6 @@ def _maybe_prepare_mlp_sync_batch(batch: ScheduleBatch, model_runner):
disable_cuda_graph=model_runner.server_args.disable_cuda_graph,
spec_algorithm=SpeculativeAlgorithm.NONE,
speculative_num_draft_tokens=None,
enable_two_batch_overlap=model_runner.server_args.enable_two_batch_overlap,
enable_deepep_moe=MoeA2ABackend(
model_runner.server_args.moe_a2a_backend
).is_deepep(),
deepep_mode=DeepEPMode(model_runner.server_args.deepep_mode),
require_mlp_tp_gather=require_mlp_tp_gather(model_runner.server_args),
disable_overlap_schedule=model_runner.server_args.disable_overlap_schedule,
)
......
......@@ -25,7 +25,6 @@ import torch
import torch.distributed
from sglang.srt.eplb.expert_location import ExpertLocationMetadata
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import Withable, get_bool_env_var
......@@ -288,14 +287,14 @@ class _SinglePassGatherer(ABC):
)
if server_args.expert_distribution_recorder_mode == "stat_approx":
if server_args.moe_a2a_backend is not None and (
if server_args.moe_a2a_backend != "none" and (
server_args.deepep_mode == "normal"
):
return _DeepepNormalSinglePassGatherer(expert_location_metadata, rank)
else:
raise NotImplementedError
if server_args.moe_a2a_backend is not None:
if server_args.moe_a2a_backend != "none":
if server_args.deepep_mode == "normal":
return _SelectExpertsSinglePassGatherer(expert_location_metadata, rank)
elif server_args.deepep_mode == "low_latency":
......
......@@ -17,7 +17,7 @@ from enum import Enum, auto
from functools import partial
from typing import Dict, Optional
import torch.distributed
import torch
from sglang.srt.distributed import (
get_tensor_model_parallel_world_size,
......@@ -35,6 +35,7 @@ from sglang.srt.layers.dp_attention import (
get_global_dp_buffer,
get_local_dp_buffer,
)
from sglang.srt.layers.moe import get_moe_a2a_backend
from sglang.srt.layers.utils import is_sm100_supported
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
......@@ -111,7 +112,7 @@ class LayerScatterModes:
if context.is_layer_sparse:
return (
ScatterMode.SCATTERED
if not global_server_args_dict["moe_a2a_backend"].is_standard()
if not get_moe_a2a_backend().is_none()
else ScatterMode.FULL
)
else:
......
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
from sglang.srt.layers.moe.utils import (
DeepEPMode,
MoeA2ABackend,
MoeRunnerBackend,
get_deepep_config,
get_deepep_mode,
get_moe_a2a_backend,
get_moe_runner_backend,
get_tbo_token_distribution_threshold,
initialize_moe_config,
is_tbo_enabled,
should_use_flashinfer_trtllm_moe,
)
__all__ = [
"DeepEPMode",
"MoeA2ABackend",
"MoeRunnerConfig",
"MoeRunnerBackend",
"initialize_moe_config",
"get_moe_a2a_backend",
"get_moe_runner_backend",
"get_deepep_mode",
"should_use_flashinfer_trtllm_moe",
"is_tbo_enabled",
"get_tbo_token_distribution_threshold",
"get_deepep_config",
]
from __future__ import annotations
import logging
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Optional, Union
import torch
from sglang.srt.distributed.parallel_state import get_moe_expert_parallel_world_size
from sglang.srt.layers.moe import (
get_deepep_mode,
get_moe_a2a_backend,
get_moe_runner_backend,
should_use_flashinfer_trtllm_moe,
)
from sglang.srt.layers.moe.ep_moe.kernels import (
ep_gather,
ep_scatter,
......@@ -16,14 +22,9 @@ from sglang.srt.layers.moe.ep_moe.kernels import (
)
from sglang.srt.layers.moe.fused_moe_triton.layer import FlashInferFusedMoE, FusedMoE
from sglang.srt.layers.moe.topk import TopKOutput
from sglang.srt.layers.moe.utils import DeepEPMode, should_use_flashinfer_trtllm_moe
from sglang.srt.layers.quantization import deep_gemm_wrapper
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.quantization.fp8 import (
Fp8Config,
Fp8MoEMethod,
get_tile_tokens_dim,
)
from sglang.srt.layers.quantization.fp8 import Fp8Config
from sglang.srt.layers.quantization.fp8_kernel import (
is_fp8_fnuz,
sglang_per_token_group_quant_fp8,
......@@ -89,12 +90,11 @@ class EPMoE(FusedMoE):
num_fused_shared_experts: int = 0,
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None,
tp_size: Optional[int] = None,
prefix: str = "",
activation: str = "silu",
routed_scaling_factor: Optional[float] = None,
activation_alpha: Optional[float] = None,
swiglu_limit: Optional[float] = None,
gemm1_alpha: Optional[float] = None,
gemm1_clamp_limit: Optional[float] = None,
with_bias: bool = False,
):
super().__init__(
......@@ -106,13 +106,12 @@ class EPMoE(FusedMoE):
top_k=top_k,
params_dtype=params_dtype,
quant_config=quant_config,
tp_size=tp_size,
prefix=prefix,
activation=activation,
# apply_router_weight_on_input=apply_router_weight_on_input,
routed_scaling_factor=routed_scaling_factor,
activation_alpha=activation_alpha,
swiglu_limit=swiglu_limit,
gemm1_alpha=gemm1_alpha,
gemm1_clamp_limit=gemm1_clamp_limit,
with_bias=with_bias,
)
......@@ -163,7 +162,8 @@ class EPMoE(FusedMoE):
)
assert self.quant_method is not None
assert self.activation == "silu"
assert self.moe_runner_config.activation == "silu"
hidden_states_shape = hidden_states.shape
hidden_states_dtype = hidden_states.dtype
hidden_states_device = hidden_states.device
......@@ -327,8 +327,8 @@ class EPMoE(FusedMoE):
m_max * self.start_expert_id,
BLOCK_SIZE=512,
)
if self.routed_scaling_factor is not None:
output *= self.routed_scaling_factor
if self.moe_runner_config.routed_scaling_factor is not None:
output *= self.moe_runner_config.routed_scaling_factor
return output
......@@ -349,11 +349,9 @@ class DeepEPMoE(EPMoE):
num_fused_shared_experts: int = 0,
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None,
tp_size: Optional[int] = None,
prefix: str = "",
activation: str = "silu",
routed_scaling_factor: Optional[float] = None,
deepep_mode: DeepEPMode = DeepEPMode.AUTO,
):
super().__init__(
num_experts=num_experts,
......@@ -364,12 +362,11 @@ class DeepEPMoE(EPMoE):
num_fused_shared_experts=num_fused_shared_experts,
params_dtype=params_dtype,
quant_config=quant_config,
tp_size=tp_size,
prefix=prefix,
activation=activation,
routed_scaling_factor=routed_scaling_factor,
)
self.deepep_mode = deepep_mode
self.deepep_mode = get_deepep_mode()
# TODO: move to the beginning of the file
from sglang.srt.distributed.parallel_state import get_tp_group
......@@ -383,7 +380,7 @@ class DeepEPMoE(EPMoE):
num_local_experts=self.num_local_experts,
hidden_size=hidden_size,
params_dtype=params_dtype,
deepep_mode=deepep_mode,
deepep_mode=self.deepep_mode,
async_finish=True, # TODO
return_recv_hook=True,
)
......@@ -458,15 +455,19 @@ class DeepEPMoE(EPMoE):
)
def moe_impl(self, dispatch_output: DispatchOutput):
from sglang.srt.layers.moe.token_dispatcher import DispatchOutputChecker
if _use_aiter:
assert DispatchOutputChecker.format_is_deepep(dispatch_output)
# in forward_aiter, we skip token permutation and unpermutation, which have been fused inside aiter kernel
return self.forward_aiter(dispatch_output)
if _is_npu:
assert DispatchOutputChecker.format_is_ascent_ll(dispatch_output)
return self.forward_npu(dispatch_output)
if dispatch_output.format.is_deepep_normal():
if DispatchOutputChecker.format_is_deepep_normal(dispatch_output):
assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8
return self.forward_deepgemm_contiguous(dispatch_output)
elif dispatch_output.format.is_deepep_ll():
elif DispatchOutputChecker.format_is_deepep_ll(dispatch_output):
assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8
return self.forward_deepgemm_masked(dispatch_output)
else:
......@@ -490,7 +491,7 @@ class DeepEPMoE(EPMoE):
def forward_aiter(
self,
dispatch_output: DeepEPNormalOutput,
dispatch_output: Union[DeepEPNormalOutput, DeepEPLLOutput],
):
hidden_states, topk_idx, topk_weights = (
dispatch_output.hidden_states,
......@@ -516,7 +517,7 @@ class DeepEPMoE(EPMoE):
quant_type=QuantType.per_128x128,
activation=(
ActivationType.Silu
if self.activation == "silu"
if self.moe_runner_config.activation == "silu"
else ActivationType.Gelu
),
expert_mask=self.expert_mask,
......@@ -531,7 +532,7 @@ class DeepEPMoE(EPMoE):
)
hidden_states_fp8, hidden_states_scale = hidden_states_fp8
assert self.quant_method is not None
assert self.activation == "silu"
assert self.moe_runner_config.activation == "silu"
if num_recv_tokens_per_expert is None:
return hidden_states_fp8.bfloat16()
all_tokens = sum(num_recv_tokens_per_expert)
......@@ -652,7 +653,7 @@ class DeepEPMoE(EPMoE):
):
hidden_states_fp8, _, _, masked_m, expected_m = dispatch_output
assert self.quant_method is not None
assert self.activation == "silu"
assert self.moe_runner_config.activation == "silu"
# GroupGemm-0
num_groups, m, k = hidden_states_fp8[0].size()
......@@ -783,12 +784,12 @@ class DeepEPMoE(EPMoE):
def get_moe_impl_class():
if global_server_args_dict["moe_a2a_backend"].is_deepep():
if get_moe_a2a_backend().is_deepep():
return DeepEPMoE
# NEW: Direct FP4 detection (bypasses EP requirements)
# Check for FP4 quantization with TRTLLM flag, regardless of EP
if global_server_args_dict.get("enable_flashinfer_trtllm_moe", False):
if get_moe_runner_backend().is_flashinfer_trtllm():
try:
# Check the quantization argument directly
quantization = global_server_args_dict.get("quantization")
......@@ -803,7 +804,7 @@ def get_moe_impl_class():
if should_use_flashinfer_trtllm_moe():
return FlashInferFusedMoE
if global_server_args_dict["enable_flashinfer_cutlass_moe"]:
if get_moe_runner_backend().is_flashinfer_cutlass():
return FusedMoE
if get_moe_expert_parallel_world_size() > 1:
return EPMoE
......
......@@ -3,28 +3,22 @@ Torch-native implementation for FusedMoE. This is used for torch.compile.
It is based on https://github.com/pytorch-labs/gpt-fast/blob/32971d3129541c5bfb4f715abc33d1c5f408d204/mixtral-moe/model.py#L204
"""
from typing import Callable, Optional
import torch
from torch.nn import functional as F
from sglang.srt.layers.activation import GeluAndMul, SiluAndMul
from sglang.srt.layers.moe.topk import TopKOutput
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
from sglang.srt.layers.moe.topk import StandardTopKOutput
def fused_moe_forward_native(
layer: torch.nn.Module,
x: torch.Tensor,
topk_output: TopKOutput,
*,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
inplace: bool = True,
no_combine: bool = False,
routed_scaling_factor: Optional[float] = None,
topk_output: StandardTopKOutput,
moe_runner_config: MoeRunnerConfig,
) -> torch.Tensor:
if apply_router_weight_on_input:
if moe_runner_config.apply_router_weight_on_input:
raise NotImplementedError()
topk_weights, topk_ids, _ = topk_output
......@@ -33,12 +27,12 @@ def fused_moe_forward_native(
w1_weights, w3_weights = torch.chunk(w13_weights, 2, dim=2)
w2_weights = layer.w2_weight[topk_ids]
x1 = torch.einsum("ti,taoi -> tao", x, w1_weights)
if activation == "silu":
if moe_runner_config.activation == "silu":
x1 = F.silu(x1)
elif activation == "gelu":
elif moe_runner_config.activation == "gelu":
x1 = F.gelu(x1)
else:
raise ValueError(f"Unsupported activation: {activation=}")
raise ValueError(f"Unsupported activation: {moe_runner_config.activation=}")
x3 = torch.einsum("ti, taoi -> tao", x, w3_weights)
expert_outs = torch.einsum("tao, taio -> tai", (x1 * x3), w2_weights)
return torch.einsum("tai,ta -> ti", expert_outs, topk_weights.to(expert_outs.dtype))
......@@ -47,16 +41,11 @@ def fused_moe_forward_native(
def moe_forward_native(
layer: torch.nn.Module,
x: torch.Tensor,
topk_output: TopKOutput,
*,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
inplace: bool = True,
no_combine: bool = False,
routed_scaling_factor: Optional[float] = None,
topk_output: StandardTopKOutput,
moe_runner_config: MoeRunnerConfig,
) -> torch.Tensor:
if apply_router_weight_on_input:
if moe_runner_config.apply_router_weight_on_input:
raise NotImplementedError()
topk_weights, topk_ids, _ = topk_output
......@@ -72,12 +61,12 @@ def moe_forward_native(
sorted_tokens = x[idxs // topk_ids.shape[1]]
tokens_per_expert = tokens_per_expert.cpu().numpy()
if activation == "silu":
if moe_runner_config.activation == "silu":
act = SiluAndMul()
elif activation == "gelu":
elif moe_runner_config.activation == "gelu":
act = GeluAndMul()
else:
raise ValueError(f"Unsupported activation: {activation=}")
raise ValueError(f"Unsupported activation: {moe_runner_config.activation=}")
outputs = []
start_idx = 0
......
......@@ -2,17 +2,20 @@
"""Fused MoE kernel."""
from __future__ import annotations
import functools
import json
import logging
import os
from typing import Any, Dict, List, Optional, Tuple
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
import torch
import triton
import triton.language as tl
from sglang.srt.layers.moe.topk import TopKOutput
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
from sglang.srt.layers.moe.topk import StandardTopKOutput
from sglang.srt.layers.quantization.fp8_kernel import (
per_token_group_quant_fp8,
scaled_fp8_quant,
......@@ -1025,8 +1028,8 @@ def inplace_fused_experts(
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None,
routed_scaling_factor: Optional[float] = None,
activation_alpha: Optional[float] = None,
swiglu_limit: Optional[float] = None,
gemm1_alpha: Optional[float] = None,
gemm1_limit: Optional[float] = None,
) -> None:
fused_experts_impl(
hidden_states,
......@@ -1053,8 +1056,8 @@ def inplace_fused_experts(
block_shape,
False,
routed_scaling_factor,
activation_alpha,
swiglu_limit,
gemm1_alpha,
gemm1_limit,
)
......@@ -1081,8 +1084,8 @@ def inplace_fused_experts_fake(
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None,
routed_scaling_factor: Optional[float] = None,
activation_alpha: Optional[float] = None,
swiglu_limit: Optional[float] = None,
gemm1_alpha: Optional[float] = None,
gemm1_limit: Optional[float] = None,
) -> None:
pass
......@@ -1119,8 +1122,8 @@ def outplace_fused_experts(
block_shape: Optional[List[int]] = None,
no_combine: bool = False,
routed_scaling_factor: Optional[float] = None,
activation_alpha: Optional[float] = None,
swiglu_limit: Optional[float] = None,
gemm1_alpha: Optional[float] = None,
gemm1_limit: Optional[float] = None,
) -> torch.Tensor:
return fused_experts_impl(
hidden_states,
......@@ -1147,8 +1150,8 @@ def outplace_fused_experts(
block_shape,
no_combine=no_combine,
routed_scaling_factor=routed_scaling_factor,
activation_alpha=activation_alpha,
swiglu_limit=swiglu_limit,
gemm1_alpha=gemm1_alpha,
gemm1_limit=gemm1_limit,
)
......@@ -1176,8 +1179,8 @@ def outplace_fused_experts_fake(
block_shape: Optional[List[int]] = None,
no_combine: bool = False,
routed_scaling_factor: Optional[float] = None,
activation_alpha: Optional[float] = None,
swiglu_limit: Optional[float] = None,
gemm1_alpha: Optional[float] = None,
gemm1_limit: Optional[float] = None,
) -> torch.Tensor:
return torch.empty_like(hidden_states)
......@@ -1194,12 +1197,10 @@ def fused_experts(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_output: TopKOutput,
topk_output: StandardTopKOutput,
moe_runner_config: MoeRunnerConfig,
b1: Optional[torch.Tensor] = None,
b2: Optional[torch.Tensor] = None,
inplace: bool = False,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False,
......@@ -1212,14 +1213,10 @@ def fused_experts(
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None,
no_combine: bool = False,
routed_scaling_factor: Optional[float] = None,
activation_alpha: Optional[float] = None,
swiglu_limit: Optional[float] = None,
):
topk_weights, topk_ids, _ = topk_output
if inplace:
assert not no_combine, "no combine + inplace makes no sense"
if moe_runner_config.inplace:
assert not moe_runner_config.no_combine, "no combine + inplace makes no sense"
torch.ops.sglang.inplace_fused_experts(
hidden_states,
w1,
......@@ -1228,8 +1225,8 @@ def fused_experts(
topk_ids,
b1,
b2,
activation,
apply_router_weight_on_input,
moe_runner_config.activation,
moe_runner_config.apply_router_weight_on_input,
use_fp8_w8a8,
use_int8_w8a8,
use_int8_w8a16,
......@@ -1242,9 +1239,9 @@ def fused_experts(
a1_scale,
a2_scale,
block_shape,
routed_scaling_factor,
activation_alpha,
swiglu_limit,
moe_runner_config.routed_scaling_factor,
moe_runner_config.gemm1_alpha,
moe_runner_config.gemm1_clamp_limit,
)
return hidden_states
else:
......@@ -1256,8 +1253,8 @@ def fused_experts(
topk_ids,
b1,
b2,
activation,
apply_router_weight_on_input,
moe_runner_config.activation,
moe_runner_config.apply_router_weight_on_input,
use_fp8_w8a8,
use_int8_w8a8,
use_int8_w8a16,
......@@ -1270,10 +1267,10 @@ def fused_experts(
a1_scale,
a2_scale,
block_shape,
no_combine=no_combine,
routed_scaling_factor=routed_scaling_factor,
activation_alpha=activation_alpha,
swiglu_limit=swiglu_limit,
no_combine=moe_runner_config.no_combine,
routed_scaling_factor=moe_runner_config.routed_scaling_factor,
gemm1_alpha=moe_runner_config.gemm1_alpha,
gemm1_limit=moe_runner_config.gemm1_clamp_limit,
)
......@@ -1370,11 +1367,11 @@ def moe_sum_reduce_torch_compile(x, out, routed_scaling_factor):
@torch.compile
def swiglu_with_alpha_and_limit(x, alpha, limit):
def swiglu_with_alpha_and_limit(x, gemm1_alpha, gemm1_limit):
gate, up = x[..., ::2], x[..., 1::2]
gate = gate.clamp(min=None, max=limit)
up = up.clamp(min=-limit, max=limit)
return gate * torch.sigmoid(gate * alpha) * (up + 1)
gate = gate.clamp(min=None, max=gemm1_limit)
up = up.clamp(min=-gemm1_limit, max=gemm1_limit)
return gate * torch.sigmoid(gate * gemm1_alpha) * (up + 1)
def fused_experts_impl(
......@@ -1402,8 +1399,8 @@ def fused_experts_impl(
block_shape: Optional[List[int]] = None,
no_combine: bool = False,
routed_scaling_factor: Optional[float] = None,
activation_alpha: Optional[float] = None,
swiglu_limit: Optional[float] = None,
gemm1_alpha: Optional[float] = None,
gemm1_limit: Optional[float] = None,
):
padded_size = padding_size
if not (use_fp8_w8a8 or use_int8_w8a8) or block_shape is not None or _use_aiter:
......@@ -1533,12 +1530,12 @@ def fused_experts_impl(
block_shape=block_shape,
)
if activation == "silu":
if activation_alpha is not None:
assert swiglu_limit is not None
if gemm1_alpha is not None:
assert gemm1_limit is not None
intermediate_cache2 = swiglu_with_alpha_and_limit(
intermediate_cache1.view(-1, N),
activation_alpha,
swiglu_limit,
gemm1_alpha,
gemm1_limit,
)
elif _is_cuda:
silu_and_mul(intermediate_cache1.view(-1, N), intermediate_cache2)
......@@ -1547,10 +1544,8 @@ def fused_experts_impl(
intermediate_cache2, intermediate_cache1.view(-1, N)
)
elif activation == "gelu":
assert (
activation_alpha is None
), "activation_alpha is not supported for gelu"
assert swiglu_limit is None, "swiglu_limit is not supported for gelu"
assert gemm1_alpha is None, "gemm1_alpha is not supported for gelu"
assert gemm1_limit is None, "gemm1_limit is not supported for gelu"
if _is_cuda:
gelu_and_mul(intermediate_cache1.view(-1, N), intermediate_cache2)
else:
......@@ -1641,12 +1636,10 @@ def fused_moe(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_output: TopKOutput,
topk_output: StandardTopKOutput,
moe_runner_config: MoeRunnerConfig = MoeRunnerConfig(),
b1: Optional[torch.Tensor] = None,
b2: Optional[torch.Tensor] = None,
inplace: bool = False,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False,
......@@ -1659,10 +1652,6 @@ def fused_moe(
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None,
no_combine: bool = False,
routed_scaling_factor: Optional[float] = None,
activation_alpha: Optional[float] = None,
swiglu_limit: Optional[float] = None,
) -> torch.Tensor:
"""
This function computes a Mixture of Experts (MoE) layer using two sets of
......@@ -1672,11 +1661,10 @@ def fused_moe(
- hidden_states (torch.Tensor): The input tensor to the MoE layer.
- w1 (torch.Tensor): The first set of expert weights.
- w2 (torch.Tensor): The second set of expert weights.
- topk_output (TopKOutput): The top-k output of the experts.
- topk_output (StandardTopKOutput): The top-k output of the experts.
- moe_runner_config (MoeRunnerConfig): The configuration for the MoE runner.
- b1 (Optional[torch.Tensor]): Optional bias for w1.
- b2 (Optional[torch.Tensor]): Optional bias for w2.
- inplace (bool): If True, perform the operation in-place.
Defaults to False.
- use_fp8_w8a8 (bool): If True, use fp8 arithmetic to compute the inner
products for w1 and w2. Defaults to False.
- use_int8_w8a8 (bool): If True, use int8 arithmetic to compute the inner
......@@ -1696,9 +1684,9 @@ def fused_moe(
a2.
- block_shape: (Optional[List[int]]): Optional block size for block-wise
quantization.
- activation_alpha (Optional[float]): Optional alpha for the activation
- gemm1_alpha (Optional[float]): Optional gemm1_alpha for the activation
function.
- swiglu_limit (Optional[float]): Optional limit for the swiglu activation
- gemm1_limit (Optional[float]): Optional gemm1_limit for the swiglu activation
function.
Returns:
......@@ -1710,11 +1698,9 @@ def fused_moe(
w1,
w2,
topk_output,
moe_runner_config=moe_runner_config,
b1=b1,
b2=b2,
inplace=inplace,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16,
......@@ -1727,8 +1713,4 @@ def fused_moe(
a1_scale=a1_scale,
a2_scale=a2_scale,
block_shape=block_shape,
no_combine=no_combine,
routed_scaling_factor=routed_scaling_factor,
activation_alpha=activation_alpha,
swiglu_limit=swiglu_limit,
)
# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/model_executor/layers/fused_moe/layer.py
import datetime
import glob
import logging
import os
import sys
from enum import Enum
from typing import List, Optional, Tuple
......@@ -22,8 +18,12 @@ from sglang.srt.distributed.device_communicators.pynccl_allocator import (
use_symmetric_memory,
)
from sglang.srt.eplb.expert_location import get_global_expert_location_metadata
from sglang.srt.layers.moe.topk import StandardTopKOutput
from sglang.srt.layers.moe.utils import should_use_flashinfer_trtllm_moe
from sglang.srt.layers.moe import (
MoeRunnerConfig,
get_moe_runner_backend,
should_use_flashinfer_trtllm_moe,
)
from sglang.srt.layers.moe.topk import TopKOutput, TopKOutputChecker
from sglang.srt.layers.quantization.base_config import (
QuantizationConfig,
QuantizeMethodBase,
......@@ -126,7 +126,6 @@ class FusedMoE(torch.nn.Module):
params_dtype: Optional[torch.dtype] = None,
reduce_results: bool = False,
quant_config: Optional[QuantizationConfig] = None,
tp_size: Optional[int] = None,
prefix: str = "",
activation: str = "silu",
apply_router_weight_on_input: bool = False,
......@@ -134,9 +133,8 @@ class FusedMoE(torch.nn.Module):
inplace: bool = True,
no_combine: bool = False,
routed_scaling_factor: Optional[float] = None,
enable_flashinfer_cutlass_moe: Optional[bool] = False,
activation_alpha: Optional[float] = None,
swiglu_limit: Optional[float] = None,
gemm1_alpha: Optional[float] = None,
gemm1_clamp_limit: Optional[float] = None,
use_weight_loader_fused: bool = False,
with_bias=False,
):
......@@ -153,9 +151,17 @@ class FusedMoE(torch.nn.Module):
self.expert_map_cpu = None
self.expert_map_gpu = None
# For activation
self.activation_alpha = activation_alpha
self.swiglu_limit = swiglu_limit
self.moe_runner_config = MoeRunnerConfig(
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
inplace=inplace,
no_combine=no_combine,
routed_scaling_factor=routed_scaling_factor,
gemm1_alpha=gemm1_alpha,
gemm1_clamp_limit=gemm1_clamp_limit,
)
enable_flashinfer_cutlass_moe = get_moe_runner_backend().is_flashinfer_cutlass()
if enable_flashinfer_cutlass_moe and quant_config is None:
logger.warning("Disable flashinfer MoE when quantization config is None.")
......@@ -184,20 +190,12 @@ class FusedMoE(torch.nn.Module):
* self.num_local_experts
] = torch.arange(0, self.num_local_experts, dtype=torch.int32, device="cpu")
self.routed_scaling_factor = routed_scaling_factor
assert intermediate_size % self.moe_tp_size == 0
self.intermediate_size_per_partition = intermediate_size // self.moe_tp_size
self.reduce_results = reduce_results
self.activation = activation
self.apply_router_weight_on_input = apply_router_weight_on_input
self.use_presharded_weights = use_presharded_weights
self.inplace = inplace
self.no_combine = no_combine
self.use_triton_kernels = (
not _is_cpu and global_server_args_dict["enable_triton_kernel_moe"]
)
self.use_triton_kernels = get_moe_runner_backend().is_triton_kernel()
if quant_config is None:
self.quant_method: Optional[QuantizeMethodBase] = UnquantizedFusedMoEMethod(
self.use_triton_kernels
......@@ -207,14 +205,12 @@ class FusedMoE(torch.nn.Module):
assert self.quant_method is not None
self.quant_config = quant_config
self.use_enable_flashinfer_mxfp4_moe = global_server_args_dict.get(
"enable_flashinfer_mxfp4_moe", False
)
self.use_flashinfer_mxfp4_moe = get_moe_runner_backend().is_flashinfer_mxfp4()
# TODO maybe we should remove this `if`, since `Mxfp4MoEMethod` does another round-up logic
if (
self.quant_config is not None
and self.quant_config.get_name() == "mxfp4"
and self.use_enable_flashinfer_mxfp4_moe
and self.use_flashinfer_mxfp4_moe
):
hidden_size = round_up(hidden_size, 256)
self.quant_method.create_weights(
......@@ -794,7 +790,7 @@ class FusedMoE(torch.nn.Module):
f"Unsupported weight_name {weight_name} for FusedMoE weight_loader_fused. Nothing is loaded."
)
def forward(self, hidden_states: torch.Tensor, topk_output: StandardTopKOutput):
def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
origin_hidden_states_dim = hidden_states.shape[-1]
assert self.quant_method is not None
......@@ -803,40 +799,22 @@ class FusedMoE(torch.nn.Module):
# If we are in EP mode, we need to move the expert map to GPU.
self.expert_map_gpu = self.expert_map_cpu.to(device="cuda")
if self.expert_map_gpu is not None and isinstance(
topk_output, StandardTopKOutput
):
topk_output = topk_output._replace(
topk_ids=self.expert_map_gpu[topk_output.topk_ids]
)
if self.expert_map_gpu is not None:
if TopKOutputChecker.format_is_standard(topk_output):
topk_output = topk_output._replace(
topk_ids=self.expert_map_gpu[topk_output.topk_ids]
)
elif TopKOutputChecker.format_is_triton_kernel(topk_output):
raise NotImplementedError()
# Matrix multiply.
with use_symmetric_memory(get_tp_group()) as sm:
kwargs = {}
if self.activation_alpha is not None:
kwargs["activation_alpha"] = self.activation_alpha
if self.swiglu_limit is not None:
kwargs["swiglu_limit"] = self.swiglu_limit
final_hidden_states = self.quant_method.apply(
layer=self,
x=hidden_states,
topk_output=topk_output,
activation=self.activation,
apply_router_weight_on_input=self.apply_router_weight_on_input,
routed_scaling_factor=self.routed_scaling_factor,
**(
dict(
tp_rank=self.moe_tp_rank,
tp_size=self.moe_tp_size,
ep_rank=self.moe_ep_rank,
ep_size=self.moe_ep_size,
)
if self.quant_method.__class__.__name__
== "ModelOptNvFp4FusedMoEMethod"
else {}
),
**kwargs,
moe_runner_config=self.moe_runner_config,
)
sm.tag(final_hidden_states)
......@@ -944,24 +922,10 @@ class FusedMoE(torch.nn.Module):
class FlashInferFusedMoE(FusedMoE):
def __init__(self, *args, **kwargs):
renormalize = kwargs.pop("renormalize", True)
num_fused_shared_experts = kwargs.pop("num_fused_shared_experts", 0)
use_grouped_topk = kwargs.pop("use_grouped_topk", False)
num_expert_group = kwargs.pop("num_expert_group", None)
topk_group = kwargs.pop("topk_group", None)
correction_bias = kwargs.pop("correction_bias", None)
super().__init__(*args, **kwargs)
self.renormalize = renormalize
self.num_fused_shared_experts = num_fused_shared_experts
self.use_grouped_topk = use_grouped_topk
if self.use_grouped_topk:
assert num_expert_group is not None and topk_group is not None
self.num_expert_group = num_expert_group
self.topk_group = topk_group
self.correction_bias = correction_bias
self.use_flashinfer_trtllm_moe = should_use_flashinfer_trtllm_moe()
def forward(self, hidden_states: torch.Tensor, topk_output: tuple):
def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
assert self.use_flashinfer_trtllm_moe
assert (
self.activation == "silu"
......@@ -974,20 +938,14 @@ class FlashInferFusedMoE(FusedMoE):
self.num_fused_shared_experts == 0
), "Fused shared experts are not supported for flashinfer blockscale fp8 moe"
# TRTLLM mode expects (TopK_config, router_logits) tuple
if not isinstance(topk_output, tuple) or len(topk_output) != 2:
raise ValueError(
f"FlashInferFusedMoE expects (TopK_config, router_logits) tuple, got {type(topk_output)}"
)
_, router_logits = topk_output
assert TopKOutputChecker.format_is_bypassed(topk_output)
# Matrix multiply.
final_hidden_states = self.quant_method.apply_with_router_logits(
layer=self,
x=hidden_states,
router_logits=router_logits,
activation=self.activation,
routed_scaling_factor=self.routed_scaling_factor,
topk_output=topk_output,
moe_runner_config=self.moe_runner_config,
)
if self.reduce_results and (self.moe_tp_size > 1 or self.moe_ep_size > 1):
......@@ -1000,28 +958,8 @@ class FlashInferFP4MoE(FusedMoE):
"""FP4 TRTLLM MoE implementation using FlashInfer."""
def __init__(self, *args, **kwargs):
# Extract DeepSeek-specific parameters
renormalize = kwargs.pop("renormalize", True)
num_fused_shared_experts = kwargs.pop("num_fused_shared_experts", 0)
use_grouped_topk = kwargs.pop("use_grouped_topk", False)
num_expert_group = kwargs.pop("num_expert_group", None)
topk_group = kwargs.pop("topk_group", None)
correction_bias = kwargs.pop("correction_bias", None)
# Extract additional TopK parameters that were previously extracted in forward
routed_scaling_factor = kwargs.pop("routed_scaling_factor", None)
super().__init__(*args, **kwargs)
# Store DeepSeek parameters
self.renormalize = renormalize
self.num_fused_shared_experts = num_fused_shared_experts
self.use_grouped_topk = use_grouped_topk
self.num_expert_group = num_expert_group
self.topk_group = topk_group
self.correction_bias = correction_bias
self.routed_scaling_factor = routed_scaling_factor
# ---------------------------------------------------------------------
# Helper: quantize hidden states to FP4 each forward pass
# ---------------------------------------------------------------------
......@@ -1052,21 +990,17 @@ class FlashInferFP4MoE(FusedMoE):
return hs_fp4, hs_sf
def forward(self, hidden_states: torch.Tensor, topk_output):
def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
"""Forward pass using FP4 TRTLLM kernel.
Args:
hidden_states: Input tensor
topk_output: Should be tuple of (TopK_config, router_logits) for TRTLLM mode
topk_output: TopKOutput object with Bypassed format
"""
assert TopKOutputChecker.format_is_bypassed(topk_output)
# TRTLLM mode expects (TopK_config, router_logits) tuple
if not isinstance(topk_output, tuple) or len(topk_output) != 2:
raise ValueError(
f"FlashInferFP4MoE expects (TopK_config, router_logits) tuple, got {type(topk_output)}"
)
_, router_logits = topk_output
router_logits = topk_output.router_logits
topk_config = topk_output.topk_config
hs_fp4, hs_scale_linear = self._quantize_hidden_states_fp4(hidden_states)
......@@ -1074,7 +1008,7 @@ class FlashInferFP4MoE(FusedMoE):
result = trtllm_fp4_block_scale_moe(
routing_logits=router_logits,
routing_bias=self.correction_bias.to(hidden_states.dtype),
routing_bias=topk_config.correction_bias.to(hidden_states.dtype),
hidden_states=hs_fp4,
hidden_states_scale=hs_scale_linear.view(torch.float8_e4m3fn).flatten(),
gemm1_weights=self.gemm1_weights_fp4_shuffled.data,
......@@ -1094,15 +1028,15 @@ class FlashInferFP4MoE(FusedMoE):
output1_scale_gate_scalar=self.g1_alphas.data,
output2_scale_scalar=self.g2_alphas.data,
num_experts=self.num_experts,
top_k=self.top_k,
n_group=self.num_expert_group,
topk_group=self.topk_group,
top_k=topk_config.top_k,
n_group=topk_config.num_expert_group,
topk_group=topk_config.topk_group,
intermediate_size=self.intermediate_size_per_partition,
local_expert_offset=self.moe_ep_rank * self.num_local_experts,
local_num_experts=self.num_local_experts,
routed_scaling_factor=self.routed_scaling_factor,
routed_scaling_factor=self.moe_runner_config.routed_scaling_factor,
tile_tokens_dim=_get_tile_tokens_dim(
hidden_states.shape[0], self.top_k, self.num_local_experts
hidden_states.shape[0], topk_config.top_k, self.num_local_experts
),
routing_method_type=RoutingMethodType.DeepSeekV3,
do_finalize=True,
......
......@@ -18,6 +18,7 @@ from triton_kernels.routing import GatherIndx, RoutingData, ScatterIndx
from triton_kernels.swiglu import swiglu_fn
if TYPE_CHECKING:
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
from sglang.srt.layers.moe.topk import TopKOutput
......@@ -55,8 +56,7 @@ def triton_kernel_moe_forward(
w1: torch.Tensor,
w2: torch.Tensor,
topk_output: TopKOutput,
inplace: bool = False,
activation: str = "silu",
moe_runner_config: MoeRunnerConfig,
apply_router_weight_on_input: bool = False,
use_fp8_w8a8: bool = False,
per_channel_quant: bool = False,
......@@ -69,7 +69,10 @@ def triton_kernel_moe_forward(
block_shape: Optional[list[int]] = None,
) -> torch.Tensor:
assert topk_output.format.is_triton_kernel()
from sglang.srt.layers.moe.topk import TopKOutputChecker
assert TopKOutputChecker.format_is_triton_kernel(topk_output)
routing_data, gather_idx, scatter_idx = topk_output
return triton_kernel_fused_experts(
......@@ -79,8 +82,8 @@ def triton_kernel_moe_forward(
routing_data,
gather_idx,
scatter_idx,
inplace=inplace,
activation=activation,
inplace=False, # triton kernel doesn't support inplace
activation=moe_runner_config.activation,
apply_router_weight_on_input=apply_router_weight_on_input,
use_fp8_w8a8=use_fp8_w8a8,
per_channel_quant=per_channel_quant,
......@@ -192,8 +195,7 @@ def triton_kernel_moe_with_bias_forward(
w2_pcg,
b2: torch.Tensor,
topk_output: TopKOutput,
inplace: bool = False,
activation: str = "silu",
moe_runner_config: MoeRunnerConfig,
use_fp8_w8a8: bool = False,
per_channel_quant: bool = False,
global_num_experts: int = -1,
......@@ -203,10 +205,11 @@ def triton_kernel_moe_with_bias_forward(
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[list[int]] = None,
activation_alpha: Optional[float] = None,
swiglu_limit: Optional[int] = None,
) -> torch.Tensor:
assert topk_output.format.is_triton_kernel()
from sglang.srt.layers.moe.topk import TopKOutputChecker
assert TopKOutputChecker.format_is_triton_kernel(topk_output)
routing_data, gather_idx, scatter_idx = topk_output
return triton_kernel_fused_experts_with_bias(
......@@ -220,8 +223,8 @@ def triton_kernel_moe_with_bias_forward(
routing_data=routing_data,
gather_indx=gather_idx,
scatter_indx=scatter_idx,
inplace=inplace,
activation=activation,
inplace=False, # triton kernel doesn't support inplace
activation=moe_runner_config.activation,
use_fp8_w8a8=use_fp8_w8a8,
per_channel_quant=per_channel_quant,
global_num_experts=global_num_experts,
......@@ -231,8 +234,8 @@ def triton_kernel_moe_with_bias_forward(
a1_scale=a1_scale,
a2_scale=a2_scale,
block_shape=block_shape,
activation_alpha=activation_alpha,
swiglu_limit=swiglu_limit,
gemm1_alpha=moe_runner_config.gemm1_alpha,
gemm1_clamp_limit=moe_runner_config.gemm1_clamp_limit,
)
......@@ -258,10 +261,9 @@ def triton_kernel_fused_experts_with_bias(
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[list[int]] = None,
activation_alpha: Optional[float] = None,
swiglu_limit: Optional[int] = None,
gemm1_alpha: Optional[float] = None,
gemm1_clamp_limit: Optional[float] = None,
) -> torch.Tensor:
# print(f"here in triton moe with bias", b1.shape, b1.dtype, b2.shape, b2.dtype)
assert use_fp8_w8a8 == False, "use_fp8_w8a8 is not supported"
assert per_channel_quant == False, "per_channel_quant is not supported"
assert expert_map == None, "expert_map is not supported"
......@@ -307,7 +309,7 @@ def triton_kernel_fused_experts_with_bias(
act = FusedActivation(
FnSpecs("swiglu", swiglu_fn, ("alpha", "limit")),
(activation_alpha, swiglu_limit),
(gemm1_alpha, gemm1_clamp_limit),
2,
)
......
from sglang.srt.layers.moe.moe_runner.base import MoeRunnerConfig
__all__ = ["MoeRunnerConfig"]
from dataclasses import dataclass
from typing import Optional
@dataclass
class MoeRunnerConfig:
activation: str = "silu"
apply_router_weight_on_input: bool = False
inplace: bool = True
no_combine: bool = False
routed_scaling_factor: Optional[float] = None
gemm1_alpha: Optional[float] = None
gemm1_clamp_limit: Optional[float] = None
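A short, hedged sketch of how downstream code consumes this dataclass, mirroring the `routed_scaling_factor` handling in the EPMoE hunk above; the tensor and values here are illustrative only:

```python
import torch

from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig

# Options that used to be passed as separate kwargs are now bundled once.
cfg = MoeRunnerConfig(activation="silu", inplace=False, routed_scaling_factor=2.5)

# Downstream code reads them off the config instead of layer attributes,
# e.g. the post-combine scaling applied in EPMoE.forward above.
output = torch.ones(4, 8)
if cfg.routed_scaling_factor is not None:
    output *= cfg.routed_scaling_factor
```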
......@@ -2,20 +2,26 @@ from sglang.srt.layers.moe.token_dispatcher.base_dispatcher import (
BaseDispatcher,
BaseDispatcherConfig,
DispatchOutput,
DispatchOutputChecker,
DispatchOutputFormat,
)
from sglang.srt.layers.moe.token_dispatcher.deepep import (
AscendDeepEPLLOutput,
DeepEPConfig,
DeepEPDispatcher,
DeepEPLLOutput,
DeepEPNormalOutput,
)
from sglang.srt.layers.moe.token_dispatcher.standard import StandardDispatchOutput
__all__ = [
"AscendDeepEPLLOutput",
"BaseDispatcher",
"BaseDispatcherConfig",
"DispatchOutput",
"DispatchOutputFormat",
"DispatchOutputChecker",
"StandardDispatchOutput",
"DeepEPConfig",
"DeepEPDispatcher",
"DeepEPNormalOutput",
......
......@@ -2,35 +2,76 @@ from __future__ import annotations
from abc import ABC, abstractmethod
from enum import Enum, auto
from typing import Protocol, runtime_checkable
from typing import TYPE_CHECKING, Protocol, TypeGuard, Union, runtime_checkable
import torch
if TYPE_CHECKING:
from sglang.srt.layers.moe.token_dispatcher import (
AscendDeepEPLLOutput,
DeepEPLLOutput,
DeepEPNormalOutput,
StandardDispatchOutput,
)
class MoEA2ABackend(Enum):
none = "none"
deepep = "deepep"
def is_none(self):
return self == MoEA2ABackend.none
class DispatchOutputChecker:
def is_deepep(self):
return self == MoEA2ABackend.deepep
@staticmethod
def format_is_standard(
dispatch_output: DispatchOutput,
) -> TypeGuard[StandardDispatchOutput]:
return dispatch_output.format.is_standard()
@staticmethod
def format_is_deepep_normal(
dispatch_output: DispatchOutput,
) -> TypeGuard[DeepEPNormalOutput]:
return dispatch_output.format.is_deepep_normal()
@staticmethod
def format_is_deepep_ll(
dispatch_output: DispatchOutput,
) -> TypeGuard[DeepEPLLOutput]:
return dispatch_output.format.is_deepep_ll()
@staticmethod
def format_is_deepep(
dispatch_output: DispatchOutput,
) -> TypeGuard[Union[DeepEPNormalOutput, DeepEPLLOutput]]:
return dispatch_output.format.is_deepep()
@staticmethod
def format_is_ascent_ll(
dispatch_output: DispatchOutput,
) -> TypeGuard[AscendDeepEPLLOutput]:
return dispatch_output.format.is_ascent_ll()
class DispatchOutputFormat(Enum):
standard = auto()
deepep_normal = auto()
deepep_ll = auto()
STANDARD = auto()
DEEPEP_NORMAL = auto()
DEEPEP_LL = auto()
ASCENT_LL = auto()
def is_standard(self) -> bool:
return self == DispatchOutputFormat.standard
return self == DispatchOutputFormat.STANDARD
def is_deepep_normal(self) -> bool:
return self == DispatchOutputFormat.deepep_normal
return self == DispatchOutputFormat.DEEPEP_NORMAL
def is_deepep_ll(self) -> bool:
return self == DispatchOutputFormat.deepep_ll
return self == DispatchOutputFormat.DEEPEP_LL
def is_deepep(self) -> bool:
return self in [
DispatchOutputFormat.DEEPEP_NORMAL,
DispatchOutputFormat.DEEPEP_LL,
]
def is_ascent_ll(self) -> bool:
return self == DispatchOutputFormat.ASCENT_LL
@runtime_checkable
......
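The checker class exists so call sites can narrow the `DispatchOutput` union statically. A minimal sketch of the pattern, mirroring `DeepEPMoE.moe_impl` in this diff; the `route` function is purely illustrative:

```python
from sglang.srt.layers.moe.token_dispatcher import (
    DispatchOutput,
    DispatchOutputChecker,
)


def route(dispatch_output: DispatchOutput) -> str:
    # Each checker is a TypeGuard, so static type checkers narrow the union
    # to the concrete NamedTuple inside the branch.
    if DispatchOutputChecker.format_is_deepep_normal(dispatch_output):
        return "deepep normal (contiguous) path"
    if DispatchOutputChecker.format_is_deepep_ll(dispatch_output):
        return "deepep low-latency (masked) path"
    if DispatchOutputChecker.format_is_standard(dispatch_output):
        return "standard path"
    raise NotImplementedError(f"unsupported format: {dispatch_output.format}")
```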
......@@ -2,27 +2,17 @@ from __future__ import annotations
import logging
from dataclasses import dataclass
from typing import (
TYPE_CHECKING,
List,
NamedTuple,
Optional,
Protocol,
Tuple,
Union,
runtime_checkable,
)
from typing import TYPE_CHECKING, List, NamedTuple, Optional, Tuple, Union
from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
from sglang.srt.layers.moe import DeepEPMode, get_deepep_config, is_tbo_enabled
from sglang.srt.layers.moe.token_dispatcher.base_dispatcher import (
BaseDispatcher,
BaseDispatcherConfig,
DispatchOutput,
DispatchOutputFormat,
)
from sglang.srt.layers.moe.utils import DeepEPMode
from sglang.srt.layers.quantization import deep_gemm_wrapper
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.utils import (
get_bool_env_var,
get_int_env_var,
......@@ -72,7 +62,7 @@ class DeepEPNormalOutput(NamedTuple):
@property
def format(self) -> DispatchOutputFormat:
return DispatchOutputFormat.deepep_normal
return DispatchOutputFormat.DEEPEP_NORMAL
class DeepEPLLOutput(NamedTuple):
......@@ -86,7 +76,7 @@ class DeepEPLLOutput(NamedTuple):
@property
def format(self) -> DispatchOutputFormat:
return DispatchOutputFormat.deepep_ll
return DispatchOutputFormat.DEEPEP_LL
class AscendDeepEPLLOutput(NamedTuple):
......@@ -101,7 +91,7 @@ class AscendDeepEPLLOutput(NamedTuple):
@property
def format(self) -> DispatchOutputFormat:
return DispatchOutputFormat.deepep_ll
return DispatchOutputFormat.ASCENT_LL
assert isinstance(DeepEPNormalOutput, DispatchOutput)
......@@ -128,8 +118,8 @@ class DeepEPBuffer:
hidden_size: int,
param_bytes: int,
deepep_mode: DeepEPMode,
num_max_dispatch_tokens_per_rank: int = None,
num_experts: int = None,
num_max_dispatch_tokens_per_rank: int = -1,
num_experts: int = -1,
):
if cls._buffer is not None:
return cls._buffer
......@@ -156,8 +146,8 @@ class DeepEPBuffer:
num_rdma_bytes,
)
if deepep_mode.enable_low_latency():
assert num_max_dispatch_tokens_per_rank is not None
assert num_experts is not None and num_experts % group.size() == 0
assert num_max_dispatch_tokens_per_rank != -1
assert num_experts != -1 and num_experts % group.size() == 0
num_rdma_bytes = max(
Buffer.get_low_latency_rdma_size_hint(
num_max_dispatch_tokens_per_rank,
......@@ -181,7 +171,7 @@ class DeepEPBuffer:
).multi_processor_count
if (
(deepep_mode != DeepEPMode.LOW_LATENCY)
and not global_server_args_dict["enable_two_batch_overlap"]
and not is_tbo_enabled()
and (DeepEPConfig.get_instance().num_sms < total_num_sms // 2)
):
logger.warning(
......@@ -226,7 +216,7 @@ class DeepEPConfig(BaseDispatcherConfig):
_instance = None
def __init__(self):
config_str = global_server_args_dict["deepep_config"]
config_str = get_deepep_config()
if config_str:
config_parsed = load_json_config(config_str)
if torch.distributed.get_rank() == 0:
......
......@@ -13,7 +13,7 @@ class StandardDispatchOutput(NamedTuple):
@property
def format(self) -> DispatchOutputFormat:
return DispatchOutputFormat.standard
return DispatchOutputFormat.STANDARD
assert isinstance(StandardDispatchOutput, DispatchOutput)
......@@ -14,9 +14,18 @@
from __future__ import annotations
import logging
import math
from dataclasses import dataclass
from enum import Enum, auto
from typing import Callable, NamedTuple, Optional, Protocol, runtime_checkable
from typing import (
Callable,
NamedTuple,
Optional,
Protocol,
TypeGuard,
runtime_checkable,
)
import torch
import torch.nn.functional as F
......@@ -28,7 +37,10 @@ from sglang.srt.eplb.expert_location_dispatch import (
ExpertLocationDispatchInfo,
topk_ids_logical_to_physical,
)
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.layers.moe import (
get_moe_runner_backend,
should_use_flashinfer_trtllm_moe,
)
from sglang.srt.utils import (
cpu_has_amx_support,
get_bool_env_var,
......@@ -43,6 +55,7 @@ try:
from triton_kernels.routing import GatherIndx, RoutingData, ScatterIndx, routing
except ImportError:
pass
logger = logging.getLogger(__name__)
_is_cuda = is_cuda()
......@@ -65,13 +78,48 @@ if _use_aiter:
if _is_npu:
import torch_npu
# -------------------------------- TopKConfig ---------------------------------------
@dataclass
class TopKConfig:
top_k: int
use_grouped_topk: bool = False
topk_group: int = 0
num_expert_group: int = 0
renormalize: bool = True
num_fused_shared_experts: int = 0
custom_routing_function: Optional[Callable] = None
correction_bias: Optional[torch.Tensor] = None
torch_native: bool = False
routed_scaling_factor: Optional[float] = None
apply_routed_scaling_factor_on_output: bool = False
# -------------------------------- TopKOutput ---------------------------------------
class TopKOutputChecker:
@staticmethod
def format_is_standard(topk_output: TopKOutput) -> TypeGuard[StandardTopKOutput]:
return topk_output.format.is_standard()
@staticmethod
def format_is_triton_kernel(
topk_output: TopKOutput,
) -> TypeGuard[TritonKernelTopKOutput]:
return topk_output.format.is_triton_kernel()
@staticmethod
def format_is_bypassed(topk_output: TopKOutput) -> TypeGuard[BypassedTopKOutput]:
return topk_output.format.is_bypassed()
class TopKOutputFormat(Enum):
STANDARD = auto()
TRITON_KERNEL = auto()
BYPASSED = auto()
def is_standard(self) -> bool:
return self == TopKOutputFormat.STANDARD
......@@ -79,6 +127,9 @@ class TopKOutputFormat(Enum):
def is_triton_kernel(self) -> bool:
return self == TopKOutputFormat.TRITON_KERNEL
def is_bypassed(self) -> bool:
return self == TopKOutputFormat.BYPASSED
@runtime_checkable
class TopKOutput(Protocol):
......@@ -114,6 +165,20 @@ class TritonKernelTopKOutput(NamedTuple):
return TopKOutputFormat.TRITON_KERNEL
class BypassedTopKOutput(NamedTuple):
"""Bypassed top-k output format."""
hidden_states: torch.Tensor
router_logits: torch.Tensor
topk_config: TopKConfig
num_token_non_padded: Optional[torch.Tensor] = None
expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None
@property
def format(self) -> TopKOutputFormat:
return TopKOutputFormat.BYPASSED
# -------------------------------- TopK ---------------------------------------
......@@ -124,8 +189,8 @@ class TopK(CustomOp):
top_k: int,
*,
use_grouped_topk: bool = False,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
topk_group: int = 0,
num_expert_group: int = 0,
renormalize: bool = True,
num_fused_shared_experts: int = 0,
custom_routing_function: Optional[Callable] = None,
......@@ -136,19 +201,23 @@ class TopK(CustomOp):
# NOTE: scoring_func is not used for now, but we keep it for future use
# see https://github.com/sgl-project/sglang/pull/4505 for more details
super().__init__()
if use_grouped_topk:
assert num_expert_group is not None and topk_group is not None
self.top_k = top_k
self.use_grouped_topk = use_grouped_topk
self.renormalize = renormalize
self.topk_group = topk_group
self.num_expert_group = num_expert_group
self.num_fused_shared_experts = num_fused_shared_experts
self.custom_routing_function = custom_routing_function
self.correction_bias = correction_bias
self.routed_scaling_factor = routed_scaling_factor
self.use_triton_kernels = global_server_args_dict["enable_triton_kernel_moe"]
self.topk_config = TopKConfig(
top_k=top_k,
use_grouped_topk=use_grouped_topk,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
num_fused_shared_experts=num_fused_shared_experts,
custom_routing_function=custom_routing_function,
correction_bias=correction_bias,
routed_scaling_factor=routed_scaling_factor,
)
self.use_triton_kernels = get_moe_runner_backend().is_triton_kernel()
def forward_native(
self,
......@@ -158,20 +227,11 @@ class TopK(CustomOp):
num_token_non_padded: Optional[torch.Tensor] = None,
expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
) -> TopKOutput:
torch_native = True
self.topk_config.torch_native = True
return select_experts(
hidden_states=hidden_states,
router_logits=router_logits,
top_k=self.top_k,
use_grouped_topk=self.use_grouped_topk,
renormalize=self.renormalize,
topk_group=self.topk_group,
num_expert_group=self.num_expert_group,
num_fused_shared_experts=self.num_fused_shared_experts,
custom_routing_function=self.custom_routing_function,
correction_bias=self.correction_bias,
torch_native=torch_native,
routed_scaling_factor=self.routed_scaling_factor,
topk_config=self.topk_config,
num_token_non_padded=num_token_non_padded,
expert_location_dispatch_info=expert_location_dispatch_info,
)
......@@ -187,24 +247,28 @@ class TopK(CustomOp):
if self.use_triton_kernels:
# renormalize=True is equivalent to sm_first=False
routing_data, gather_idx, scatter_idx = routing(
router_logits, self.top_k, sm_first=not self.renormalize
router_logits,
self.topk_config.top_k,
sm_first=not self.topk_config.renormalize,
)
return TritonKernelTopKOutput(routing_data, gather_idx, scatter_idx)
elif (
should_use_flashinfer_trtllm_moe()
or get_moe_runner_backend().is_flashinfer_mxfp4()
):
return BypassedTopKOutput(
hidden_states=hidden_states,
router_logits=router_logits,
topk_config=self.topk_config,
num_token_non_padded=num_token_non_padded,
expert_location_dispatch_info=expert_location_dispatch_info,
)
else:
torch_native = False
self.topk_config.torch_native = False
return select_experts(
hidden_states=hidden_states,
router_logits=router_logits,
top_k=self.top_k,
use_grouped_topk=self.use_grouped_topk,
renormalize=self.renormalize,
topk_group=self.topk_group,
num_expert_group=self.num_expert_group,
num_fused_shared_experts=self.num_fused_shared_experts,
custom_routing_function=self.custom_routing_function,
correction_bias=self.correction_bias,
torch_native=torch_native,
routed_scaling_factor=self.routed_scaling_factor,
topk_config=self.topk_config,
num_token_non_padded=num_token_non_padded,
expert_location_dispatch_info=expert_location_dispatch_info,
)
......@@ -220,15 +284,7 @@ class TopK(CustomOp):
return select_experts(
hidden_states=hidden_states,
router_logits=router_logits,
top_k=self.top_k,
use_grouped_topk=self.use_grouped_topk,
renormalize=self.renormalize,
topk_group=self.topk_group,
num_expert_group=self.num_expert_group,
num_fused_shared_experts=self.num_fused_shared_experts,
custom_routing_function=self.custom_routing_function,
correction_bias=self.correction_bias,
routed_scaling_factor=self.routed_scaling_factor,
topk_config=self.topk_config,
num_token_non_padded=num_token_non_padded,
expert_location_dispatch_info=expert_location_dispatch_info,
)
......@@ -244,35 +300,29 @@ class TopK(CustomOp):
global_num_experts = router_logits.shape[-1]
# NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
if global_num_experts == 256:
if global_num_experts == 256 and self.topk_config.renormalize is False:
routed_scaling_factor = self.topk_config.routed_scaling_factor or 1
router_logits = router_logits.to(torch.float32)
return torch_npu.npu_moe_gating_top_k(
router_logits,
k=self.top_k,
bias=self.correction_bias.to(torch.float32),
k_group=self.topk_group,
group_count=self.num_expert_group,
k=self.topk_config.top_k,
bias=self.topk_config.correction_bias.to(torch.float32),
k_group=self.topk_config.topk_group,
group_count=self.topk_config.num_expert_group,
group_select_mode=1,
renorm=0,
norm_type=1,
routed_scaling_factor=1,
routed_scaling_factor=routed_scaling_factor,
eps=float(1e-20),
)
else:
torch_native = True
self.topk_config.torch_native = True
return select_experts(
hidden_states=hidden_states,
router_logits=router_logits,
top_k=self.top_k,
use_grouped_topk=self.use_grouped_topk,
renormalize=self.renormalize,
topk_group=self.topk_group,
num_expert_group=self.num_expert_group,
num_fused_shared_experts=self.num_fused_shared_experts,
custom_routing_function=self.custom_routing_function,
correction_bias=self.correction_bias,
torch_native=torch_native,
routed_scaling_factor=self.routed_scaling_factor,
topk_config=self.topk_config,
num_token_non_padded=num_token_non_padded,
expert_location_dispatch_info=expert_location_dispatch_info,
)
......@@ -670,20 +720,23 @@ else:
def select_experts(
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
topk_config: TopKConfig,
*,
use_grouped_topk: bool = False,
renormalize: bool = False,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
num_fused_shared_experts: int = 0,
custom_routing_function: Optional[Callable] = None,
correction_bias: Optional[torch.Tensor] = None,
torch_native: bool = False,
routed_scaling_factor: Optional[float] = None,
num_token_non_padded: Optional[torch.Tensor] = None,
expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
) -> TopKOutput:
) -> StandardTopKOutput:
top_k = topk_config.top_k
use_grouped_topk = topk_config.use_grouped_topk
topk_group = topk_config.topk_group
num_expert_group = topk_config.num_expert_group
renormalize = topk_config.renormalize
num_fused_shared_experts = topk_config.num_fused_shared_experts
custom_routing_function = topk_config.custom_routing_function
correction_bias = topk_config.correction_bias
torch_native = topk_config.torch_native
routed_scaling_factor = topk_config.routed_scaling_factor
router_logits, correction_bias = (
expert_location_dispatch.transform_select_experts_inputs(
router_logits=router_logits,
......
from __future__ import annotations
import importlib.util
from enum import Enum
from functools import lru_cache
from typing import TYPE_CHECKING, Optional
from packaging import version as pkg_version
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.utils import logger
@lru_cache(maxsize=1)
def should_use_flashinfer_trtllm_moe():
result = global_server_args_dict["enable_flashinfer_trtllm_moe"] and (
not importlib.util.find_spec("flashinfer")
or pkg_version.parse(__import__("flashinfer").__version__)
>= pkg_version.parse("0.2.9rc1")
)
return result
if TYPE_CHECKING:
from sglang.srt.server_args import ServerArgs
class MoeA2ABackend(Enum):
STANDARD = ("standard", "none")
NONE = "none"
DEEPEP = "deepep"
@classmethod
def _missing_(cls, value):
if value is None:
return cls.STANDARD
return cls.NONE
for member in cls:
if value in member.value:
if value == member.value:
return member
raise ValueError(f"No {cls.__name__} member for value {value}")
def is_none(self):
return self == MoeA2ABackend.NONE
def is_deepep(self):
return self == MoeA2ABackend.DEEPEP
def is_standard(self):
return self == MoeA2ABackend.STANDARD
class MoeRunnerBackend(Enum):
AUTO = "auto"
TRITON = "triton"
TRITON_KERNEL = "triton_kernel"
FLASHINFER = "flashinfer_trtllm"
FLASHINFER_CUTLASS = "flashinfer_cutlass"
FLASHINFER_MXFP4 = "flashinfer_mxfp4"
def is_auto(self):
return self == MoeRunnerBackend.AUTO
def is_triton(self):
return self == MoeRunnerBackend.TRITON
def is_triton_kernel(self):
return self == MoeRunnerBackend.TRITON_KERNEL
def is_flashinfer_trtllm(self):
return self == MoeRunnerBackend.FLASHINFER
def is_flashinfer_cutlass(self):
return self == MoeRunnerBackend.FLASHINFER_CUTLASS
def is_flashinfer_mxfp4(self):
return self == MoeRunnerBackend.FLASHINFER_MXFP4
class DeepEPMode(Enum):
NORMAL = "normal"
LOW_LATENCY = "low_latency"
AUTO = "auto"
def enable_normal(self):
def enable_normal(self) -> bool:
return self in [DeepEPMode.NORMAL, DeepEPMode.AUTO]
def enable_low_latency(self):
def enable_low_latency(self) -> bool:
return self in [DeepEPMode.LOW_LATENCY, DeepEPMode.AUTO]
def resolve(self, is_extend_in_batch: bool):
def resolve(self, is_extend_in_batch: bool) -> DeepEPMode:
if self != DeepEPMode.AUTO:
return self
......@@ -57,3 +82,96 @@ class DeepEPMode(Enum):
return DeepEPMode.NORMAL
else:
return DeepEPMode.LOW_LATENCY
def is_normal(self) -> bool:
return self == DeepEPMode.NORMAL
def is_low_latency(self) -> bool:
return self == DeepEPMode.LOW_LATENCY
def is_auto(self) -> bool:
return self == DeepEPMode.AUTO
MOE_A2A_BACKEND: Optional[MoeA2ABackend] = None
MOE_RUNNER_BACKEND: Optional[MoeRunnerBackend] = None
DEEPEP_MODE: Optional[DeepEPMode] = None
IS_TBO_ENABLED: Optional[bool] = None
TBO_TOKEN_DISTRIBUTION_THRESHOLD: Optional[float] = None
DEEPEP_CONFIG: Optional[str] = None
def initialize_moe_config(server_args: ServerArgs):
global MOE_A2A_BACKEND
global MOE_RUNNER_BACKEND
global DEEPEP_MODE
global DEEPEP_CONFIG
global IS_TBO_ENABLED
global TBO_TOKEN_DISTRIBUTION_THRESHOLD
MOE_A2A_BACKEND = MoeA2ABackend(server_args.moe_a2a_backend)
MOE_RUNNER_BACKEND = MoeRunnerBackend(server_args.moe_runner_backend)
DEEPEP_MODE = DeepEPMode(server_args.deepep_mode)
DEEPEP_CONFIG = server_args.deepep_config or ""
IS_TBO_ENABLED = server_args.enable_two_batch_overlap
TBO_TOKEN_DISTRIBUTION_THRESHOLD = server_args.tbo_token_distribution_threshold
def get_moe_a2a_backend() -> MoeA2ABackend:
global MOE_A2A_BACKEND
if MOE_A2A_BACKEND is None:
logger.warning("MOE_A2A_BACKEND is not initialized, using default backend")
MOE_A2A_BACKEND = MoeA2ABackend(None)
return MOE_A2A_BACKEND
def get_moe_runner_backend() -> MoeRunnerBackend:
global MOE_RUNNER_BACKEND
if MOE_RUNNER_BACKEND is None:
logger.warning("MOE_RUNNER_BACKEND is not initialized, using triton backend")
MOE_RUNNER_BACKEND = MoeRunnerBackend("triton")
return MOE_RUNNER_BACKEND
def get_deepep_mode() -> DeepEPMode:
global DEEPEP_MODE
if DEEPEP_MODE is None:
logger.warning("DEEPEP_MODE is not initialized, using auto mode")
DEEPEP_MODE = DeepEPMode("auto")
return DEEPEP_MODE
def get_deepep_config() -> str:
global DEEPEP_CONFIG
if DEEPEP_CONFIG is None:
logger.warning("DEEPEP_CONFIG is not initialized, using default config")
DEEPEP_CONFIG = ""
return DEEPEP_CONFIG
def is_tbo_enabled() -> bool:
global IS_TBO_ENABLED
if IS_TBO_ENABLED is None:
logger.warning("IS_TBO_ENABLED is not initialized, using False")
IS_TBO_ENABLED = False
return IS_TBO_ENABLED
def get_tbo_token_distribution_threshold() -> float:
global TBO_TOKEN_DISTRIBUTION_THRESHOLD
if TBO_TOKEN_DISTRIBUTION_THRESHOLD is None:
logger.warning(
"TBO_TOKEN_DISTRIBUTION_THRESHOLD is not initialized, using 0.48"
)
TBO_TOKEN_DISTRIBUTION_THRESHOLD = 0.48
return TBO_TOKEN_DISTRIBUTION_THRESHOLD
@lru_cache(maxsize=1)
def should_use_flashinfer_trtllm_moe():
result = get_moe_runner_backend().is_flashinfer_trtllm() and (
not importlib.util.find_spec("flashinfer")
or pkg_version.parse(__import__("flashinfer").__version__)
>= pkg_version.parse("0.2.9rc1")
)
return result
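These module-level getters replace reads from `global_server_args_dict`. A hedged sketch of the intended initialization order, assuming `server_args` is an already-constructed `ServerArgs`; the `setup_moe` helper is hypothetical:

```python
from sglang.srt.layers.moe import (
    get_deepep_mode,
    get_moe_a2a_backend,
    get_moe_runner_backend,
    initialize_moe_config,
    is_tbo_enabled,
)


def setup_moe(server_args) -> None:
    # Call once during engine startup, before any MoE layer is constructed.
    initialize_moe_config(server_args)

    # Afterwards, layers query the process-wide MoE config directly.
    if get_moe_a2a_backend().is_deepep():
        print("DeepEP all-to-all, mode:", get_deepep_mode())
    print("runner backend:", get_moe_runner_backend())
    print("two-batch overlap:", is_tbo_enabled())
```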
......@@ -33,7 +33,8 @@ from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
from sglang.srt.layers.quantization.utils import get_scalar_types, replace_parameter
if TYPE_CHECKING:
from sglang.srt.layers.moe.topk import TopKOutput
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
from sglang.srt.layers.moe.topk import StandardTopKOutput
from sglang.srt.utils import is_cuda, is_hip
......@@ -739,13 +740,12 @@ class AWQMoEMethod(FusedMoEMethodBase):
self,
layer: torch.nn.Module,
x: torch.Tensor,
topk_output: TopKOutput,
*,
activation: str = "silu",
**kwargs,
topk_output: StandardTopKOutput,
moe_runner_config: MoeRunnerConfig,
) -> torch.Tensor:
assert activation == "silu", "Only SiLU activation is supported."
assert (
moe_runner_config.activation == "silu"
), "Only SiLU activation is supported."
# The input must currently be float16
orig_dtype = x.dtype
......
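The same pattern repeats across quantization backends: `apply` now receives the whole `MoeRunnerConfig` instead of unpacked keyword arguments. A hedged sketch of a minimal method after this change; the `DemoMoEMethod` class is hypothetical and elides the actual kernel call:

```python
import torch

from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
from sglang.srt.layers.moe.topk import StandardTopKOutput


class DemoMoEMethod:
    """Hypothetical quant method illustrating the new apply() signature."""

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        topk_output: StandardTopKOutput,
        moe_runner_config: MoeRunnerConfig,
    ) -> torch.Tensor:
        assert (
            moe_runner_config.activation == "silu"
        ), "Only SiLU activation is supported."
        topk_weights, topk_ids, _ = topk_output
        # ... dispatch to the real fused expert kernel here ...
        return x
```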