Unverified commit dbcf85b7 authored by Trevor Morris, committed by GitHub

Add --speculative-moe-runner-backend server arg (#10183)

parent 83804bc6
@@ -258,6 +258,7 @@ Please consult the documentation below and [server_args.py](https://github.com/s
 | `--speculative-accept-threshold-acc` | The accept probability of a draft token is raised from its target probability p to min(1, p / threshold_acc). | `1.0` | Type: float |
 | `--speculative-token-map` | The path of the draft model's small vocab table. | `None` | Type: str |
 | `--speculative-attention-mode` | Attention backend for speculative decoding operations (both target verify and draft extend). Can be one of 'prefill' (default) or 'decode'. | `prefill` | `prefill`, `decode` |
+| `--speculative-moe-runner-backend` | MoE runner backend for EAGLE speculative decoding (the draft model); see `--moe-runner-backend` for the available options. Defaults to the same backend as `--moe-runner-backend` when unset. | None |

 ## Ngram speculative decoding
 | Argument | Description | Defaults | Options |
...
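For illustration, a minimal launch sketch (not part of this commit) that exercises the new argument through the offline engine; the model paths and speculative settings below are placeholders, and the keyword arguments simply mirror the `ServerArgs` fields and CLI flags documented above.

```python
# Illustrative sketch, not from the commit: keyword arguments map 1:1 to the CLI
# flags. Paths and EAGLE settings are placeholders for a MoE target model.
import sglang as sgl

engine = sgl.Engine(
    model_path="path/to/moe-target-model",               # placeholder target model
    speculative_algorithm="EAGLE",
    speculative_draft_model_path="path/to/eagle-draft",  # placeholder draft model
    speculative_num_steps=3,
    speculative_eagle_topk=1,
    speculative_num_draft_tokens=4,
    moe_runner_backend="auto",                 # MoE backend for the target model
    speculative_moe_runner_backend="cutlass",  # MoE backend for the draft model (new flag)
)
print(engine.generate("Hello", {"max_new_tokens": 8}))
engine.shutdown()
```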
@@ -67,7 +67,7 @@ SGLang supports various environment variables that can be used to configure its
 | `SGLANG_FORCE_FP8_MARLIN` | Force using FP8 MARLIN kernels even if other FP8 kernels are available | `false` |
 | `SGLANG_ENABLE_FLASHINFER_GEMM` | Use flashinfer kernels when running blockwise fp8 GEMM on Blackwell GPUs | `false` |
 | `SGLANG_SUPPORT_CUTLASS_BLOCK_FP8` | Use Cutlass kernels when running blockwise fp8 GEMM on Hopper or Blackwell GPUs | `false` |
-| `SGLANG_CUTLASS_MOE` | Use Cutlass FP8 MoE kernel on Blackwell GPUs | `false` |
+| `SGLANG_CUTLASS_MOE` (deprecated) | Use Cutlass FP8 MoE kernel on Blackwell GPUs (deprecated, use --moe-runner-backend=cutlass) | `false` |

 ## Distributed Computing
...
@@ -11,7 +11,6 @@ from sglang.srt.layers.moe.utils import (
     initialize_moe_config,
     is_tbo_enabled,
     should_use_flashinfer_cutlass_moe_fp4_allgather,
-    should_use_flashinfer_trtllm_moe,
 )

 __all__ = [
@@ -24,7 +23,6 @@ __all__ = [
     "get_moe_a2a_backend",
     "get_moe_runner_backend",
     "get_deepep_mode",
-    "should_use_flashinfer_trtllm_moe",
     "should_use_flashinfer_cutlass_moe_fp4_allgather",
     "is_tbo_enabled",
     "get_tbo_token_distribution_threshold",
...
@@ -11,7 +11,6 @@ from sglang.srt.layers.moe import (
     get_deepep_mode,
     get_moe_a2a_backend,
     get_moe_runner_backend,
-    should_use_flashinfer_trtllm_moe,
 )
 from sglang.srt.layers.moe.fused_moe_triton.layer import FlashInferFusedMoE, FusedMoE
 from sglang.srt.layers.moe.token_dispatcher.deepep import (
@@ -505,7 +504,7 @@ def get_moe_impl_class(quant_config: Optional[QuantizationConfig]):
     except:
         pass
-    if should_use_flashinfer_trtllm_moe() and quant_config is not None:
+    if get_moe_runner_backend().is_flashinfer_trtllm() and quant_config is not None:
         # FIXME: FlashInferFusedMoE only supports fp8 quant now
         return FlashInferFusedMoE
     if get_moe_runner_backend().is_flashinfer_cutlass():
...
@@ -23,7 +23,6 @@ from sglang.srt.layers.moe import (
     get_deepep_mode,
     get_moe_a2a_backend,
     get_moe_runner_backend,
-    should_use_flashinfer_trtllm_moe,
 )
 from sglang.srt.layers.moe.token_dispatcher import CombineInput, DispatchOutput
 from sglang.srt.layers.moe.token_dispatcher.base import BaseDispatcher
@@ -60,7 +59,7 @@ if is_flashinfer_available():
 # Try to import FP4 TRTLLM function if flashinfer is available
 trtllm_fp4_block_scale_moe = None
-if should_use_flashinfer_trtllm_moe():
+if get_moe_runner_backend().is_flashinfer_trtllm():
     try:
         from flashinfer.fused_moe import trtllm_fp4_block_scale_moe
     except ImportError:
@@ -234,7 +233,7 @@ class FusedMoE(torch.nn.Module):
             self.quant_method, ModelOptNvFp4FusedMoEMethod
         ) or (
             isinstance(self.quant_method, Fp8MoEMethod)
-            and self.quant_method._should_use_cutlass_fused_experts()
+            and get_moe_runner_backend().is_cutlass()
         )

     def _load_per_tensor_weight_scale(
@@ -593,7 +592,7 @@ class FusedMoE(torch.nn.Module):
         )
         # Flashinfer assumes w31 format for w13_weight. Same for the scales.
-        if should_use_flashinfer_trtllm_moe() and (
+        if get_moe_runner_backend().is_flashinfer_trtllm() and (
             isinstance(self.quant_method, ModelOptNvFp4FusedMoEMethod)
             or isinstance(self.quant_method, Fp8MoEMethod)
         ):
@@ -961,10 +960,8 @@ class FusedMoE(torch.nn.Module):
 class FlashInferFusedMoE(FusedMoE):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        self.use_flashinfer_trtllm_moe = should_use_flashinfer_trtllm_moe()

     def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
-        assert self.use_flashinfer_trtllm_moe
         assert (
             self.moe_runner_config.activation == "silu"
         ), "Only silu is supported for flashinfer blockscale fp8 moe"
...
@@ -357,9 +357,9 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
     ):
         topk_weights, topk_ids = topk_output.topk_weights, topk_output.topk_ids
         topk_ids = topk_ids.to(torch.int64)
-        if (
-            deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
-            and not get_moe_runner_backend().is_cutlass()
+        if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and not (
+            get_moe_runner_backend().is_cutlass()
+            and self.quant_config.get_name() == "w4afp8"
         ):
             # TODO hard code 128 block quant,use fp8 communication
             hidden_states = sglang_per_token_group_quant_fp8(
...
@@ -38,10 +38,7 @@ from sglang.srt.eplb.expert_location_dispatch import (
     ExpertLocationDispatchInfo,
     topk_ids_logical_to_physical,
 )
-from sglang.srt.layers.moe import (
-    get_moe_runner_backend,
-    should_use_flashinfer_trtllm_moe,
-)
+from sglang.srt.layers.moe import get_moe_runner_backend
 from sglang.srt.utils import (
     cpu_has_amx_support,
     get_bool_env_var,
@@ -257,7 +254,7 @@ class TopK(CustomOp):
         elif get_moe_runner_backend().is_triton_kernels():
             output_format = TopKOutputFormat.TRITON_KERNEL
         elif (
-            should_use_flashinfer_trtllm_moe()
+            get_moe_runner_backend().is_flashinfer_trtllm()
             or get_moe_runner_backend().is_flashinfer_mxfp4()
         ):
             output_format = TopKOutputFormat.BYPASSED
...
 from __future__ import annotations

-import importlib.util
 import logging
+from contextlib import contextmanager
 from enum import Enum
 from functools import lru_cache
 from typing import TYPE_CHECKING, Optional

-from packaging import version as pkg_version

 from sglang.srt.distributed.parallel_state import get_moe_expert_parallel_world_size
 from sglang.srt.layers.dp_attention import (
     get_attention_dp_size,
@@ -119,6 +117,7 @@ class DeepEPMode(Enum):
 MOE_A2A_BACKEND: Optional[MoeA2ABackend] = None
 MOE_RUNNER_BACKEND: Optional[MoeRunnerBackend] = None
+SPECULATIVE_MOE_RUNNER_BACKEND: Optional[MoeRunnerBackend] = None
 DEEPEP_MODE: Optional[DeepEPMode] = None
 IS_TBO_ENABLED: Optional[bool] = None
 IS_SBO_ENABLED: Optional[bool] = None
@@ -130,6 +129,7 @@ DISABLE_FLASHINFER_CUTLASS_MOE_FP4_ALLGATHER: Optional[bool] = None
 def initialize_moe_config(server_args: ServerArgs):
     global MOE_A2A_BACKEND
     global MOE_RUNNER_BACKEND
+    global SPECULATIVE_MOE_RUNNER_BACKEND
     global DEEPEP_MODE
     global DEEPEP_CONFIG
     global IS_TBO_ENABLED
@@ -139,6 +139,11 @@ def initialize_moe_config(server_args: ServerArgs):
     MOE_A2A_BACKEND = MoeA2ABackend(server_args.moe_a2a_backend)
     MOE_RUNNER_BACKEND = MoeRunnerBackend(server_args.moe_runner_backend)
+    SPECULATIVE_MOE_RUNNER_BACKEND = (
+        MoeRunnerBackend(server_args.speculative_moe_runner_backend)
+        if server_args.speculative_moe_runner_backend is not None
+        else MOE_RUNNER_BACKEND
+    )
     DEEPEP_MODE = DeepEPMode(server_args.deepep_mode)
     DEEPEP_CONFIG = server_args.deepep_config or ""
     IS_TBO_ENABLED = server_args.enable_two_batch_overlap
@@ -167,6 +172,16 @@ def get_moe_runner_backend() -> MoeRunnerBackend:
     return MOE_RUNNER_BACKEND

+def get_speculative_moe_runner_backend() -> MoeRunnerBackend:
+    global SPECULATIVE_MOE_RUNNER_BACKEND
+    if SPECULATIVE_MOE_RUNNER_BACKEND is None:
+        logger.warning(
+            "SPECULATIVE_MOE_RUNNER_BACKEND is not initialized, using auto backend"
+        )
+        SPECULATIVE_MOE_RUNNER_BACKEND = MoeRunnerBackend.AUTO
+    return SPECULATIVE_MOE_RUNNER_BACKEND

 def get_deepep_mode() -> DeepEPMode:
     global DEEPEP_MODE
     if DEEPEP_MODE is None:
@@ -207,16 +222,6 @@ def get_tbo_token_distribution_threshold() -> float:
     return TBO_TOKEN_DISTRIBUTION_THRESHOLD

-@lru_cache(maxsize=1)
-def should_use_flashinfer_trtllm_moe():
-    result = get_moe_runner_backend().is_flashinfer_trtllm() and (
-        not importlib.util.find_spec("flashinfer")
-        or pkg_version.parse(__import__("flashinfer").__version__)
-        >= pkg_version.parse("0.2.9rc1")
-    )
-    return result

 @lru_cache(maxsize=1)
 def should_use_flashinfer_cutlass_moe_fp4_allgather():
     """
@@ -228,3 +233,18 @@ def should_use_flashinfer_cutlass_moe_fp4_allgather():
         and is_dp_attention_enabled()
         and get_moe_expert_parallel_world_size() == get_attention_dp_size()
     )
+
+@contextmanager
+def speculative_moe_backend_context():
+    """
+    Context manager to temporarily use the speculative MoE backend for draft model operations.
+    This ensures that draft models in speculative decoding use the configured speculative backend.
+    """
+    global MOE_RUNNER_BACKEND
+    original_backend = MOE_RUNNER_BACKEND
+    try:
+        MOE_RUNNER_BACKEND = get_speculative_moe_runner_backend()
+        yield
+    finally:
+        MOE_RUNNER_BACKEND = original_backend
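A minimal usage sketch (not part of the commit) of the new context manager, assuming `initialize_moe_config(server_args)` has already run; `init_draft_moe_modules` is a hypothetical helper name used only for illustration.

```python
# Illustrative sketch: construct draft-model MoE modules under the speculative backend.
from sglang.srt.layers.moe.utils import (
    get_moe_runner_backend,
    speculative_moe_backend_context,
)


def init_draft_moe_modules():  # hypothetical helper, for illustration only
    with speculative_moe_backend_context():
        # Inside the context, get_moe_runner_backend() resolves to the value of
        # --speculative-moe-runner-backend (or the regular backend when unset),
        # so any MoE layers built here pick up the draft backend.
        draft_backend = get_moe_runner_backend()
        print(f"draft MoE backend: {draft_backend}")
    # On exit, the target model's MOE_RUNNER_BACKEND is restored.
```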
@@ -528,7 +528,12 @@ class Fp8MoEMethod(FusedMoEMethodBase):
     def __init__(self, quant_config: Fp8Config):
         self.quant_config = quant_config
         self.block_quant = self.quant_config.weight_block_size is not None
-        self.cutlass_fp8_supported = cutlass_fp8_supported()
+        if get_moe_runner_backend().is_cutlass():
+            assert (
+                cutlass_fp8_supported()
+            ), "cutlass_fp8 MoE requires CUDA 12.0+ with SM90 or CUDA 12.4+ with SM89"
+            assert self.block_quant, "cutlass_fp8 MoE requires block quantization"
+            assert is_sm100_supported() or is_sm90_supported()

     def create_weights(
         self,
@@ -636,7 +641,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             layer.register_parameter("w13_weight_scale_inv", w13_weight_scale)
             layer.register_parameter("w2_weight_scale_inv", w2_weight_scale)
             assert self.quant_config.activation_scheme == "dynamic"
-            if self._should_use_cutlass_fused_experts():
+            if get_moe_runner_backend().is_cutlass():
                 self._ensure_cutlass_buffers_initialized(layer)
             else:
@@ -1025,7 +1030,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         if ret is not None:
             return StandardCombineInput(hidden_states=ret)

-        if self._should_use_cutlass_fused_experts():
+        if get_moe_runner_backend().is_cutlass():
             from sglang.srt.layers.moe.cutlass_moe import cutlass_fused_experts_fp8

             with use_symmetric_memory(get_tp_group()) as sm:
@@ -1122,22 +1127,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         return self.runner.run(dispatch_output, quant_info)

-    def _should_use_cutlass_fused_experts(self) -> bool:
-        """Decide whether to use Cutlass FP8 fused-experts path based on moe runner backend,
-        with env var override via `SGLANG_CUTLASS_MOE`.
-        """
-        backend = get_moe_runner_backend()
-        env_force = get_bool_env_var("SGLANG_CUTLASS_MOE")
-        # TODO: remove env var in the future, it should be handled by moe runner backend
-        if env_force:
-            return True
-        return (
-            backend.is_flashinfer_cutlass()
-            and self.cutlass_fp8_supported
-            and self.block_quant
-            and (is_sm100_supported() or is_sm90_supported())
-        )
-
     def _ensure_cutlass_buffers_initialized(self, layer: Module) -> None:
         if getattr(self, "_cutlass_buffers_ready", False):
             return
...
@@ -16,8 +16,8 @@ from sglang.srt.layers.moe import (
     MoeRunner,
     MoeRunnerBackend,
     MoeRunnerConfig,
+    get_moe_runner_backend,
     should_use_flashinfer_cutlass_moe_fp4_allgather,
-    should_use_flashinfer_trtllm_moe,
 )
 from sglang.srt.layers.moe.cutlass_moe_params import CutlassMoEParams, CutlassMoEType
 from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo
@@ -526,7 +526,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
         )
         # Align FP8 weights to FlashInfer per-tensor kernel layout if enabled
-        if should_use_flashinfer_trtllm_moe():
+        if get_moe_runner_backend().is_flashinfer_trtllm():
             from flashinfer import reorder_rows_for_gated_act_gemm, shuffle_matrix_a

             # 1) Swap W13 halves: [Up, Gate] -> [Gate, Up] expected by FI
@@ -568,7 +568,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
         )
         # Precompute and register per-expert output scaling factors for FI MoE
-        if should_use_flashinfer_trtllm_moe():
+        if get_moe_runner_backend().is_flashinfer_trtllm():
             # Note: w13_input_scale and w2_input_scale are scalar Parameters post-reduction
             assert (
                 hasattr(layer, "w13_input_scale") and layer.w13_input_scale is not None
@@ -620,8 +620,9 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
         # Fast path: TRT-LLM FP8 per-tensor MoE using BYPASSED TopK routing
         from sglang.srt.layers.moe.topk import TopKOutputChecker

-        if should_use_flashinfer_trtllm_moe() and TopKOutputChecker.format_is_bypassed(
-            topk_output
+        if (
+            get_moe_runner_backend().is_flashinfer_trtllm()
+            and TopKOutputChecker.format_is_bypassed(topk_output)
         ):
             router_logits = topk_output.router_logits
             topk_config = topk_output.topk_config
@@ -1079,7 +1080,9 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
                 " quantization. Please use Blackwell and"
                 " above."
             )
-        self.enable_flashinfer_trtllm_moe = should_use_flashinfer_trtllm_moe()
+        self.enable_flashinfer_trtllm_moe = (
+            get_moe_runner_backend().is_flashinfer_trtllm()
+        )
         self._cache_permute_indices = {}

     @property
...
@@ -74,8 +74,8 @@ from sglang.srt.layers.linear import (
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.moe import (
     get_moe_a2a_backend,
+    get_moe_runner_backend,
     should_use_flashinfer_cutlass_moe_fp4_allgather,
-    should_use_flashinfer_trtllm_moe,
 )
 from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, get_moe_impl_class
 from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
@@ -503,7 +503,7 @@ class MoEGate(nn.Module):
             torch.bfloat16
             if quant_config is not None
             and quant_config.get_name() == "modelopt_fp4"
-            and should_use_flashinfer_trtllm_moe()
+            and get_moe_runner_backend().is_flashinfer_trtllm()
             else torch.float32
         )
         self.e_score_correction_bias = nn.Parameter(
...
@@ -36,6 +36,7 @@ from sglang.srt.utils.common import (
     SUPPORTED_LORA_TARGET_MODULES,
     configure_ipv6,
     cpu_has_amx_support,
+    get_bool_env_var,
     get_device,
     get_device_memory_capacity,
     get_device_sm,
@@ -377,6 +378,7 @@ class ServerArgs:
     speculative_accept_threshold_acc: float = 1.0
     speculative_token_map: Optional[str] = None
     speculative_attention_mode: str = "prefill"
+    speculative_moe_runner_backend: Optional[str] = None

     # For ngram only
     speculative_ngram_min_match_window_size: int = 1
     speculative_ngram_max_match_window_size: int = 12
@@ -1379,6 +1381,19 @@ class ServerArgs:
                 "FlashInfer TRTLLM MoE is enabled. --disable-shared-experts-fusion is automatically set."
             )

+        if get_bool_env_var("SGLANG_CUTLASS_MOE"):
+            logger.warning(
+                "SGLANG_CUTLASS_MOE is deprecated, use --moe-runner-backend=cutlass and/or --speculative-moe-runner-backend=cutlass instead"
+            )
+            assert (
+                self.quantization == "fp8"
+            ), "cutlass MoE is only supported with fp8 quantization"
+            self.moe_runner_backend = "cutlass"
+
+        if self.moe_runner_backend == "cutlass" and self.quantization == "fp8":
+            assert (
+                self.ep_size == 1
+            ), "FP8 Cutlass MoE is only supported with ep_size == 1"
+
     def _handle_a2a_moe(self):
         if self.moe_a2a_backend == "deepep":
             if self.deepep_mode == "normal":
@@ -2722,6 +2737,13 @@ class ServerArgs:
             help="Attention backend for speculative decoding operations (both target verify and draft extend). Can be one of 'prefill' (default) or 'decode'.",
             default=ServerArgs.speculative_attention_mode,
         )
+        parser.add_argument(
+            "--speculative-moe-runner-backend",
+            type=str,
+            choices=MOE_RUNNER_BACKEND_CHOICES,
+            default=ServerArgs.speculative_moe_runner_backend,
+            help="Choose the runner backend for MoE in speculative decoding.",
+        )

         # Ngram speculative decoding
         parser.add_argument(
             "--speculative-ngram-min-match-window-size",
...
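A small migration sketch (not from the commit) contrasting the deprecated environment variable with the explicit flags handled above; both paths end up selecting the Cutlass FP8 MoE backend, but the environment variable now only emits a deprecation warning and rewrites the setting.

```python
# Illustrative sketch of the migration away from SGLANG_CUTLASS_MOE.
import os

# Deprecated: still honored, but warns and requires --quantization fp8.
os.environ["SGLANG_CUTLASS_MOE"] = "1"

# Preferred after this commit: select the backends explicitly, e.g.
#   --quantization fp8 --moe-runner-backend cutlass --speculative-moe-runner-backend cutlass
# Per the checks above, the FP8 Cutlass path also requires ep_size == 1.
```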
@@ -7,6 +7,7 @@ import torch
 from sglang.srt.distributed import get_tp_group
 from sglang.srt.layers.dp_attention import get_attention_tp_group
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
+from sglang.srt.layers.moe.utils import speculative_moe_backend_context
 from sglang.srt.layers.sampler import get_token_ids_logprobs, get_top_logprobs
 from sglang.srt.managers.schedule_batch import ScheduleBatch
 from sglang.srt.managers.scheduler import GenerationBatchResult
@@ -125,7 +126,7 @@ class EAGLEWorker(TpModelWorker):
             ctx = draft_tp_context(get_attention_tp_group())
         else:
             ctx = empty_context()
-        with ctx:
+        with ctx, speculative_moe_backend_context():
             super().__init__(
                 server_args=server_args,
                 gpu_id=gpu_id,
@@ -174,7 +175,9 @@ class EAGLEWorker(TpModelWorker):
         self.draft_tp_context = (
             draft_tp_context if server_args.enable_dp_attention else empty_context
         )
-        with self.draft_tp_context(self.draft_model_runner.tp_group):
+        with self.draft_tp_context(
+            self.draft_model_runner.tp_group
+        ), speculative_moe_backend_context():
             self.init_attention_backend()
             self.init_cuda_graphs()
@@ -259,7 +262,9 @@ class EAGLEWorker(TpModelWorker):
             logits_output, next_token_ids, seq_lens_cpu = self.forward_target_extend(
                 batch
             )
-            with self.draft_tp_context(self.draft_model_runner.tp_group):
+            with self.draft_tp_context(
+                self.draft_model_runner.tp_group
+            ), speculative_moe_backend_context():
                 self.forward_draft_extend(
                     batch, logits_output.hidden_states, next_token_ids, seq_lens_cpu
                 )
@@ -270,13 +275,17 @@ class EAGLEWorker(TpModelWorker):
                 can_run_cuda_graph=False,
             )
         else:
-            with self.draft_tp_context(self.draft_model_runner.tp_group):
+            with self.draft_tp_context(
+                self.draft_model_runner.tp_group
+            ), speculative_moe_backend_context():
                 spec_info = self.draft(batch)
             logits_output, verify_output, model_worker_batch, can_run_cuda_graph = (
                 self.verify(batch, spec_info)
             )
-            with self.draft_tp_context(self.draft_model_runner.tp_group):
+            with self.draft_tp_context(
+                self.draft_model_runner.tp_group
+            ), speculative_moe_backend_context():
                 # NOTE: We should use `check_forward_draft_extend_after_decode`
                 # when DP attention is enabled, but it is slow. Skip it for now.
                 if (
...
@@ -6,6 +6,7 @@ from typing import List, Optional, Tuple
 import torch

 from sglang.srt.environ import envs
+from sglang.srt.layers.moe.utils import speculative_moe_backend_context
 from sglang.srt.managers.schedule_batch import ModelWorkerBatch
 from sglang.srt.managers.scheduler import GenerationBatchResult
 from sglang.srt.managers.tp_worker import TpModelWorker
@@ -101,7 +102,7 @@ class EagleDraftWorker(BaseDraftWorker):
         self.req_to_token_pool, self.token_to_kv_pool_allocator = (
             target_worker.get_memory_pool()
         )
-        with empty_context():
+        with empty_context(), speculative_moe_backend_context():
             # Init draft worker
             self.draft_worker = TpModelWorker(
                 server_args=server_args,
@@ -127,7 +128,9 @@ class EagleDraftWorker(BaseDraftWorker):
         self.draft_tp_context = (
             draft_tp_context if server_args.enable_dp_attention else empty_context
         )
-        with self.draft_tp_context(self.draft_runner.tp_group):
+        with self.draft_tp_context(
+            self.draft_runner.tp_group
+        ), speculative_moe_backend_context():
             self.init_attention_backend()
             self.init_cuda_graphs()
...
@@ -3,6 +3,7 @@ from typing import Optional
 import torch

+from sglang.srt.layers.moe.utils import speculative_moe_backend_context
 from sglang.srt.managers.tp_worker import TpModelWorker
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.speculative.eagle_worker import EAGLEWorker
@@ -66,7 +67,7 @@ class StandaloneWorker(EAGLEWorker):
         self.hot_token_id = None

         # Init draft worker
-        with empty_context():
+        with empty_context(), speculative_moe_backend_context():
             TpModelWorker.__init__(
                 self,
                 server_args=server_args,
@@ -88,7 +89,9 @@ class StandaloneWorker(EAGLEWorker):
         self.draft_tp_context = (
             draft_tp_context if server_args.enable_dp_attention else empty_context
         )
-        with self.draft_tp_context(self.draft_model_runner.tp_group):
+        with self.draft_tp_context(
+            self.draft_model_runner.tp_group
+        ), speculative_moe_backend_context():
             self.init_attention_backend()
             self.init_cuda_graphs()
...