Unverified Commit 915140fd authored by azhurkevich, committed by GitHub

[NVIDIA] Add Low Latency NVFP4 decode kernels from Flashinfer (#8552)


Co-authored-by: Cheng Wan <cwan@x.ai>
parent 36fc9260
@@ -14,13 +14,9 @@ from sglang.srt.layers.moe.ep_moe.kernels import (
silu_and_mul_masked_post_quant_fwd,
tma_align_input_scale,
)
-from sglang.srt.layers.moe.fused_moe_triton.layer import (
-FlashInferFusedMoE,
-FusedMoE,
-should_use_flashinfer_trtllm_moe,
-)
from sglang.srt.layers.moe.fused_moe_triton.layer import FlashInferFusedMoE, FusedMoE
from sglang.srt.layers.moe.topk import TopKOutput
-from sglang.srt.layers.moe.utils import DeepEPMode
from sglang.srt.layers.moe.utils import DeepEPMode, should_use_flashinfer_trtllm_moe
from sglang.srt.layers.quantization import deep_gemm_wrapper
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.quantization.fp8 import (
@@ -48,7 +44,6 @@ _is_npu = is_npu()
_is_fp8_fnuz = is_fp8_fnuz()
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
if not (_is_npu or _is_hip):
from sgl_kernel import silu_and_mul
@@ -741,6 +736,22 @@ class FlashInferEPMoE(EPMoE):
def get_moe_impl_class():
if global_server_args_dict["moe_a2a_backend"].is_deepep():
return DeepEPMoE
# NEW: Direct FP4 detection (bypasses EP requirements)
# Check for FP4 quantization with TRTLLM flag, regardless of EP
if global_server_args_dict.get("enable_flashinfer_trtllm_moe", False):
try:
# Check the quantization argument directly
quantization = global_server_args_dict.get("quantization")
if quantization == "modelopt_fp4":
from sglang.srt.layers.moe.fused_moe_triton.layer import (
FlashInferFP4MoE,
)
return FlashInferFP4MoE
except:
pass
if global_server_args_dict["enable_flashinfer_cutlass_moe"]: if global_server_args_dict["enable_flashinfer_cutlass_moe"]:
return FusedMoE return FusedMoE
if get_moe_expert_parallel_world_size() > 1: if get_moe_expert_parallel_world_size() > 1:
......
# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/model_executor/layers/fused_moe/layer.py
-import importlib.util
import datetime
import glob
import logging
import os
import sys
from enum import Enum
-from functools import lru_cache
from typing import List, Optional, Tuple
import torch
-from packaging import version as pkg_version
from sglang.srt.distributed import (
get_moe_expert_parallel_rank,
@@ -22,6 +23,7 @@ from sglang.srt.distributed.device_communicators.pynccl_allocator import (
)
from sglang.srt.eplb.expert_location import get_global_expert_location_metadata
from sglang.srt.layers.moe.topk import StandardTopKOutput
from sglang.srt.layers.moe.utils import should_use_flashinfer_trtllm_moe
from sglang.srt.layers.quantization.base_config import (
QuantizationConfig,
QuantizeMethodBase,
@@ -29,22 +31,58 @@ from sglang.srt.layers.quantization.base_config import (
from sglang.srt.layers.quantization.unquant import UnquantizedFusedMoEMethod
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_loader.weight_utils import narrow_padded_param_and_loaded_weight
-from sglang.srt.utils import cpu_has_amx_support, get_bool_env_var, is_cpu, is_hip
from sglang.srt.utils import (
cpu_has_amx_support,
get_bool_env_var,
is_cpu,
is_flashinfer_available,
is_hip,
next_power_of_2,
)
if is_flashinfer_available():
from flashinfer import (
RoutingMethodType,
fp4_quantize,
reorder_rows_for_gated_act_gemm,
shuffle_matrix_a,
shuffle_matrix_sf_a,
)
_is_hip = is_hip()
_is_cpu_amx_available = cpu_has_amx_support()
_is_cpu = is_cpu()
# Try to import FP4 TRTLLM function if flashinfer is available
trtllm_fp4_block_scale_moe = None
if should_use_flashinfer_trtllm_moe():
try:
from flashinfer.fused_moe import trtllm_fp4_block_scale_moe
except ImportError:
trtllm_fp4_block_scale_moe = None
logger = logging.getLogger(__name__)
-@lru_cache(maxsize=1)
-def should_use_flashinfer_trtllm_moe():
-return global_server_args_dict["enable_flashinfer_trtllm_moe"] and (
-not importlib.util.find_spec("flashinfer")
-or pkg_version.parse(__import__("flashinfer").__version__)
->= pkg_version.parse("0.2.9rc1")
-)
def _is_fp4_quantization_enabled():
"""Check if ModelOpt FP4 quantization is enabled."""
try:
# Use the same simple check that works for class selection
quantization = global_server_args_dict.get("quantization")
return quantization == "modelopt_fp4"
except:
return False
def _get_tile_tokens_dim(num_tokens, top_k, num_experts):
# Guess tokens per expert assuming perfect expert distribution first.
num_tokens_per_expert = (num_tokens * top_k) // num_experts
# And pad the number to the next power of 2.
tile_tokens_dim = next_power_of_2(num_tokens_per_expert)
# Cap to 8-64 tokens per CTA tile as it's the range supported by the kernel.
tile_tokens_dim = min(max(tile_tokens_dim, 8), 64)
return tile_tokens_dim
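A quick worked example of the tile sizing above. This is a standalone sketch: the local next_power_of_2 helper stands in for sglang.srt.utils.next_power_of_2 and is an assumption, not the library code.
# Standalone sketch of the tile sizing logic in _get_tile_tokens_dim.
def next_power_of_2(n: int) -> int:
    # Stand-in for sglang.srt.utils.next_power_of_2.
    return 1 if n <= 1 else 1 << (n - 1).bit_length()
def tile_tokens_dim(num_tokens: int, top_k: int, num_experts: int) -> int:
    # Guess tokens per expert assuming a perfectly uniform expert distribution.
    num_tokens_per_expert = (num_tokens * top_k) // num_experts
    # Pad to the next power of 2, then clamp to the kernel-supported 8-64 range.
    return min(max(next_power_of_2(num_tokens_per_expert), 8), 64)
# Low-latency decode batches land on the minimum tile of 8.
print(tile_tokens_dim(num_tokens=8, top_k=8, num_experts=256))     # 8
# Very large batches are capped at 64 tokens per CTA tile.
print(tile_tokens_dim(num_tokens=8192, top_k=8, num_experts=256))  # 64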
class FusedMoeWeightScaleSupported(Enum):
@@ -157,10 +195,6 @@ class FusedMoE(torch.nn.Module):
)
else:
self.quant_method = quant_config.get_quant_method(self, prefix)
-if self.quant_method.__class__.__name__ == "ModelOptNvFp4FusedMoEMethod":
-self.quant_method.enable_flashinfer_cutlass_moe = (
-self.enable_flashinfer_cutlass_moe
-)
assert self.quant_method is not None
self.quant_config = quant_config
@@ -747,7 +781,130 @@ class FlashInferFusedMoE(FusedMoE):
routed_scaling_factor=self.routed_scaling_factor,
)
-if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1):
if self.reduce_results and (self.moe_tp_size > 1 or self.moe_ep_size > 1):
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
return final_hidden_states
class FlashInferFP4MoE(FusedMoE):
"""FP4 TRTLLM MoE implementation using FlashInfer."""
def __init__(self, *args, **kwargs):
# Extract DeepSeek-specific parameters
renormalize = kwargs.pop("renormalize", True)
num_fused_shared_experts = kwargs.pop("num_fused_shared_experts", 0)
use_grouped_topk = kwargs.pop("use_grouped_topk", False)
num_expert_group = kwargs.pop("num_expert_group", None)
topk_group = kwargs.pop("topk_group", None)
correction_bias = kwargs.pop("correction_bias", None)
# Extract additional TopK parameters that were previously extracted in forward
routed_scaling_factor = kwargs.pop("routed_scaling_factor", None)
super().__init__(*args, **kwargs)
# Store DeepSeek parameters
self.renormalize = renormalize
self.num_fused_shared_experts = num_fused_shared_experts
self.use_grouped_topk = use_grouped_topk
self.num_expert_group = num_expert_group
self.topk_group = topk_group
self.correction_bias = correction_bias
self.routed_scaling_factor = routed_scaling_factor
# ---------------------------------------------------------------------
# Helper: quantize hidden states to FP4 each forward pass
# ---------------------------------------------------------------------
def _quantize_hidden_states_fp4(self, hidden_states: torch.Tensor):
"""
Quantize hidden states using global scale factor from quantization method.
Global scale factor is set by ModelOptNvFp4FusedMoEMethod during weight loading.
Only block scales are computed at runtime for efficiency.
Returns (packed_fp4_uint8, block_scales_float8_e4m3fn)
"""
# flashinfer.fp4_quantize returns (packed_uint8, scale_fp8)
# Only the block scales are computed at runtime
hs_fp4_bytes, hs_sf_bytes = fp4_quantize(
hidden_states,
self.w13_input_scale_quant,
16, # sf_vec_size
False, # use_ue8m0
False, # is_sf_swizzled_layout
)
hs_fp4 = hs_fp4_bytes.reshape(
hidden_states.shape[0], hidden_states.shape[1] // 2
)
hs_sf = hs_sf_bytes.view(torch.float8_e4m3fn).reshape(-1)
return hs_fp4, hs_sf
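For orientation, a minimal shape sketch of what the quantization helper above produces. This is pure arithmetic under the sf_vec_size of 16 used in the call; it does not invoke flashinfer, and the hidden size of 7168 is only an example value.
# Shape sketch for the FP4 activation quantization above (no flashinfer call).
# Two FP4 values are packed per uint8 byte; one FP8 block scale covers 16 elements.
def nvfp4_activation_shapes(num_tokens: int, hidden_size: int, sf_vec_size: int = 16):
    packed_fp4_shape = (num_tokens, hidden_size // 2)           # torch.uint8
    num_block_scales = num_tokens * hidden_size // sf_vec_size  # torch.float8_e4m3fn, flattened
    return packed_fp4_shape, num_block_scales
print(nvfp4_activation_shapes(num_tokens=4, hidden_size=7168))
# ((4, 3584), 1792)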
def forward(self, hidden_states: torch.Tensor, topk_output):
"""Forward pass using FP4 TRTLLM kernel.
Args:
hidden_states: Input tensor
topk_output: Should be tuple of (TopK_config, router_logits) for TRTLLM mode
"""
# TRTLLM mode expects (TopK_config, router_logits) tuple
if not isinstance(topk_output, tuple) or len(topk_output) != 2:
raise ValueError(
f"FlashInferFP4MoE expects (TopK_config, router_logits) tuple, got {type(topk_output)}"
)
_, router_logits = topk_output
hs_fp4, hs_scale_linear = self._quantize_hidden_states_fp4(hidden_states)
router_logits = router_logits.to(torch.float32)
result = trtllm_fp4_block_scale_moe(
routing_logits=router_logits,
routing_bias=self.correction_bias.to(hidden_states.dtype),
hidden_states=hs_fp4,
hidden_states_scale=hs_scale_linear.view(torch.float8_e4m3fn).flatten(),
gemm1_weights=self.gemm1_weights_fp4_shuffled.data,
gemm1_weights_scale=self.gemm1_scales_fp4_shuffled.data.view(
torch.float8_e4m3fn
),
gemm2_weights=self.gemm2_weights_fp4_shuffled.data,
gemm2_weights_scale=self.gemm2_scales_fp4_shuffled.data.view(
torch.float8_e4m3fn
),
output1_scale_scalar=self.g1_scale_c.data,
output1_scale_gate_scalar=self.g1_alphas.data,
output2_scale_scalar=self.g2_alphas.data,
num_experts=self.num_experts,
top_k=self.top_k,
n_group=self.num_expert_group,
topk_group=self.topk_group,
intermediate_size=self.intermediate_size_per_partition,
local_expert_offset=self.moe_ep_rank * self.num_local_experts,
local_num_experts=self.num_local_experts,
routed_scaling_factor=self.routed_scaling_factor,
tile_tokens_dim=_get_tile_tokens_dim(
hidden_states.shape[0], self.top_k, self.num_local_experts
),
routing_method_type=RoutingMethodType.DeepSeekV3,
do_finalize=True,
)[0]
return result
def get_fused_moe_impl_class():
"""Factory function to get the appropriate FusedMoE implementation class."""
if should_use_flashinfer_trtllm_moe() and _is_fp4_quantization_enabled():
# Use FP4 variant when FP4 quantization is enabled
return FlashInferFP4MoE
elif should_use_flashinfer_trtllm_moe():
# Use regular FlashInfer variant for non-FP4 FlashInfer cases
return FlashInferFusedMoE
else:
# Default case
return FusedMoE
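A standalone sketch of the selection order implemented by the factory above. The string return values are illustrative labels only; the real function returns the classes themselves.
# Illustrative mirror of get_fused_moe_impl_class (no sglang imports).
def select_moe_impl(enable_trtllm: bool, fp4: bool) -> str:
    if enable_trtllm and fp4:
        return "FlashInferFP4MoE"    # FP4 TRTLLM decode kernels
    if enable_trtllm:
        return "FlashInferFusedMoE"  # non-FP4 FlashInfer TRTLLM path
    return "FusedMoE"                # default path
print(select_moe_impl(enable_trtllm=True, fp4=True))   # FlashInferFP4MoE
print(select_moe_impl(enable_trtllm=False, fp4=True))  # FusedMoE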
import importlib.util
from enum import Enum
from functools import lru_cache
from packaging import version as pkg_version
from sglang.srt.managers.schedule_batch import global_server_args_dict
@lru_cache(maxsize=1)
def should_use_flashinfer_trtllm_moe():
result = global_server_args_dict["enable_flashinfer_trtllm_moe"] and (
not importlib.util.find_spec("flashinfer")
or pkg_version.parse(__import__("flashinfer").__version__)
>= pkg_version.parse("0.2.9rc1")
)
return result
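A standalone sketch of the gate above, for clarity: the server flag must be set, and if flashinfer is importable its version must be at least 0.2.9rc1 (an absent flashinfer passes the version check and is expected to fail later at import time). The real function also caches its result with lru_cache, so the decision is fixed for the lifetime of the process. This helper is illustrative only.
from typing import Optional
from packaging import version as pkg_version
def trtllm_moe_gate(flag: bool, flashinfer_version: Optional[str]) -> bool:
    # None models "flashinfer not installed", which the original check lets through.
    return flag and (
        flashinfer_version is None
        or pkg_version.parse(flashinfer_version) >= pkg_version.parse("0.2.9rc1")
    )
print(trtllm_moe_gate(True, "0.2.9"))   # True
print(trtllm_moe_gate(True, "0.2.8"))   # False
print(trtllm_moe_gate(False, "0.3.0"))  # False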
class MoeA2ABackend(Enum):
...
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/modelopt.py
from __future__ import annotations
import importlib.util
import logging
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
import torch
from torch.nn.parameter import Parameter
from sglang.srt.layers.moe.cutlass_moe_params import CutlassMoEParams, CutlassMoEType
from sglang.srt.layers.moe.utils import should_use_flashinfer_trtllm_moe
from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter
from sglang.srt.layers.quantization.base_config import (
FusedMoEMethodBase,
@@ -29,6 +31,7 @@ from sglang.srt.layers.quantization.utils import (
requantize_with_max_scale,
)
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.utils import is_cuda, next_power_of_2
if TYPE_CHECKING:
@@ -39,6 +42,11 @@ if is_cuda():
try:
from flashinfer import mm_fp4 as fp4_gemm
from flashinfer import (
reorder_rows_for_gated_act_gemm,
shuffle_matrix_a,
shuffle_matrix_sf_a,
)
enable_flashinfer_fp4_gemm = True
except ImportError:
@@ -47,6 +55,9 @@ except ImportError:
else:
fp4_gemm = None
enable_flashinfer_fp4_gemm = False
reorder_rows_for_gated_act_gemm = None
shuffle_matrix_a = None
shuffle_matrix_sf_a = None
try:
from flashinfer.fused_moe import cutlass_fused_moe as flashinfer_cutlass_fused_moe
@@ -527,6 +538,7 @@ class ModelOptFp4Config(QuantizationConfig):
) -> Optional[QuantizeMethodBase]:
from sglang.srt.layers.linear import LinearBase
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
from sglang.srt.layers.moe.fused_moe_triton.layer import FlashInferFP4MoE
if isinstance(layer, LinearBase):
if is_layer_skipped(prefix, self.exclude_modules) or self.is_layer_excluded(
@@ -536,6 +548,9 @@ class ModelOptFp4Config(QuantizationConfig):
return ModelOptFp4LinearMethod(self)
if self.kv_cache_quant_algo and isinstance(layer, RadixAttention):
return ModelOptFp8KVCacheMethod(self)
elif isinstance(layer, FlashInferFP4MoE):
# FlashInferFP4MoE needs the same quantization method but with compatible attribute handling
return ModelOptNvFp4FusedMoEMethod(self)
elif isinstance(layer, FusedMoE):
return ModelOptNvFp4FusedMoEMethod(self)
return None
@@ -726,7 +741,12 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
" quantization. Please use Blackwell and"
" above."
)
-self.enable_flashinfer_cutlass_moe = False
self.enable_flashinfer_trtllm_moe = should_use_flashinfer_trtllm_moe()
@property
def enable_flashinfer_cutlass_moe(self) -> bool:
"""Access the global enable_flashinfer_cutlass_moe setting."""
return global_server_args_dict.get("enable_flashinfer_cutlass_moe", False)
def create_weights(
self,
@@ -743,16 +763,20 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
" dynamic quantization is not supported."
)
# TODO(ch-wan): check if this is needed
layer.num_experts = num_experts
layer.num_local_experts = num_experts
layer.intermediate_size_per_partition = intermediate_size_per_partition
layer.params_dtype = params_dtype
layer.quant_config = self.quant_config
weight_dtype = torch.uint8
weight_scale_dtype = torch.float8_e4m3fn
weight_loader = extra_weight_attrs.get("weight_loader")
# GEMM 1
w13_weight = ModelWeightParameter(
data=torch.empty(
-num_experts,
layer.local_num_experts,
2 * intermediate_size_per_partition,
# 2 fp4 items are packed in the input dimension
hidden_size // 2,
@@ -767,7 +791,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
# GEMM 2
w2_weight = ModelWeightParameter(
data=torch.empty(
-num_experts,
layer.num_local_experts,
hidden_size,
# 2 fp4 items are packed in the input dimension
intermediate_size_per_partition // 2,
@@ -781,7 +805,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
w13_weight_scale = ModelWeightParameter(
data=torch.empty(
-num_experts,
layer.num_local_experts,
2 * intermediate_size_per_partition,
# 2 fp4 items are packed in the input dimension
hidden_size // self.quant_config.group_size,
@@ -795,7 +819,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
w2_weight_scale = ModelWeightParameter(
data=torch.empty(
-num_experts,
layer.num_local_experts,
hidden_size,
# 2 fp4 items are packed in the input dimension
intermediate_size_per_partition // self.quant_config.group_size,
@@ -814,13 +838,13 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
)
w13_weight_scale_2 = PerTensorScaleParameter(
-data=torch.empty(num_experts, 2, dtype=torch.float32),
data=torch.empty(layer.num_local_experts, 2, dtype=torch.float32),
weight_loader=weight_loader,
)
layer.register_parameter("w13_weight_scale_2", w13_weight_scale_2)
w2_weight_scale_2 = PerTensorScaleParameter(
-data=torch.empty(num_experts, dtype=torch.float32),
data=torch.empty(layer.num_local_experts, dtype=torch.float32),
weight_loader=weight_loader,
)
layer.register_parameter("w2_weight_scale_2", w2_weight_scale_2)
@@ -830,18 +854,18 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
)
w13_input_scale = PerTensorScaleParameter(
-data=torch.empty(num_experts, 2, dtype=torch.float32),
data=torch.empty(layer.num_local_experts, 2, dtype=torch.float32),
weight_loader=weight_loader,
)
layer.register_parameter("w13_input_scale", w13_input_scale)
w2_input_scale = PerTensorScaleParameter(
-data=torch.empty(num_experts, dtype=torch.float32),
data=torch.empty(layer.num_local_experts, dtype=torch.float32),
weight_loader=weight_loader,
)
layer.register_parameter("w2_input_scale", w2_input_scale)
-def swizzle_blockscale(self, scale: torch.tensor):
def swizzle_blockscale(self, scale: torch.Tensor):
assert scale.dtype == torch.float8_e4m3fn
# Pad and blockwise interleave weight_scale
scale_ndim = scale.ndim
@@ -866,9 +890,125 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
else swizzled_scale.reshape(B, M, K)
)
def prepare_static_weights_for_kernel(
self,
# args_dequant,
# args,
gemm1_weights,
gemm2_weights,
gemm1_scales_linear_fp4_bytes,
gemm2_scales_linear_fp4_bytes,
hidden_size,
intermediate_size,
num_experts,
):
from flashinfer import (
RoutingMethodType,
e2m1_and_ufp8sf_scale_to_float,
fp4_quantize,
next_positive_power_of_2,
reorder_rows_for_gated_act_gemm,
shuffle_matrix_a,
shuffle_matrix_sf_a,
)
"""Prepare quantized weights for kernel (done offline with weights)."""
epilogue_tile_m = 128 # FIXME: this depends on the kernel internals
# Convert quantized weights to proper formats
gemm1_weights_fp4 = gemm1_weights.view(torch.float8_e4m3fn).reshape(
num_experts, 2 * intermediate_size, hidden_size // 2
) # packed fp4
gemm1_scales_linear_fp4 = gemm1_scales_linear_fp4_bytes.view(
torch.float8_e4m3fn
).reshape(
num_experts, 2 * intermediate_size, hidden_size // 16
) # fp8 scaling factors
gemm2_weights_fp4 = gemm2_weights.view(torch.float8_e4m3fn).reshape(
num_experts, hidden_size, intermediate_size // 2
) # packed fp4
gemm2_scales_linear_fp4 = gemm2_scales_linear_fp4_bytes.view(
torch.float8_e4m3fn
).reshape(
num_experts, hidden_size, intermediate_size // 16
) # fp8 scaling factors
# Reorder rows of W1 and scales for fused gated activation
gemm1_weights_fp4_interleaved = []
gemm1_scales_fp4_interleaved = []
for i in range(num_experts):
gemm1_weights_fp4_interleaved.append(
reorder_rows_for_gated_act_gemm(gemm1_weights_fp4[i].clone())
)
gemm1_scales_fp4_interleaved.append(
reorder_rows_for_gated_act_gemm(gemm1_scales_linear_fp4[i].clone())
)
# Stack weights and scales for all experts
gemm1_weights_fp4_interleaved = torch.stack(
gemm1_weights_fp4_interleaved
).reshape(num_experts, 2 * intermediate_size, hidden_size // 2)
gemm1_scales_fp4_interleaved = torch.stack(
gemm1_scales_fp4_interleaved
).reshape(num_experts, 2 * intermediate_size, hidden_size // 16)
# Shuffle weights and scaling factors for transposed mma output
gemm1_weights_fp4_shuffled = []
gemm1_scales_fp4_shuffled = []
gemm2_weights_fp4_shuffled = []
gemm2_scales_fp4_shuffled = []
for i in range(num_experts):
gemm1_weights_fp4_shuffled.append(
shuffle_matrix_a(
gemm1_weights_fp4_interleaved[i].view(torch.uint8), epilogue_tile_m
)
)
gemm1_scales_fp4_shuffled.append(
shuffle_matrix_sf_a(
gemm1_scales_fp4_interleaved[i].view(torch.uint8), epilogue_tile_m
)
)
gemm2_weights_fp4_shuffled.append(
shuffle_matrix_a(
gemm2_weights_fp4[i].view(torch.uint8), epilogue_tile_m
)
)
gemm2_scales_fp4_shuffled.append(
shuffle_matrix_sf_a(
gemm2_scales_linear_fp4[i].view(torch.uint8), epilogue_tile_m
)
)
# Stack weights for all experts
gemm1_weights_fp4_shuffled = torch.stack(gemm1_weights_fp4_shuffled)
gemm1_scales_fp4_shuffled = (
torch.stack(gemm1_scales_fp4_shuffled)
.view(torch.float8_e4m3fn)
.reshape(num_experts, 2 * intermediate_size, hidden_size // 16)
)
gemm2_weights_fp4_shuffled = torch.stack(gemm2_weights_fp4_shuffled)
gemm2_scales_fp4_shuffled = (
torch.stack(gemm2_scales_fp4_shuffled)
.view(torch.float8_e4m3fn)
.reshape(num_experts, hidden_size, intermediate_size // 16)
)
return (
gemm1_weights_fp4_shuffled,
gemm1_scales_fp4_shuffled,
gemm2_weights_fp4_shuffled,
gemm2_scales_fp4_shuffled,
)
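For orientation, a shape sketch of what the weight preparation above returns. The scale shapes come directly from the reshape calls in the code; the weight shapes assume the shuffle preserves the 2-D layout, and the E/H/I values below are arbitrary examples.
# Shape sketch for prepare_static_weights_for_kernel (pure arithmetic, no flashinfer).
# E = local experts, H = hidden size, I = intermediate size per partition.
def prepared_fp4_weight_shapes(E: int, H: int, I: int):
    return {
        "gemm1_weights_fp4_shuffled": (E, 2 * I, H // 2),   # packed fp4 in uint8
        "gemm1_scales_fp4_shuffled": (E, 2 * I, H // 16),   # fp8 block scales
        "gemm2_weights_fp4_shuffled": (E, H, I // 2),
        "gemm2_scales_fp4_shuffled": (E, H, I // 16),
    }
print(prepared_fp4_weight_shapes(E=32, H=7168, I=2048))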
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
"""Process FP4 MoE weights after loading from serialized checkpoint.
Only supports pre-quantized checkpoints with FP8 weights and scales.
"""
-# GEMM 1
# GEMM 1 scale processing
if not torch.allclose(
layer.w13_weight_scale_2[:, 0], layer.w13_weight_scale_2[:, 1]
):
@@ -880,65 +1020,115 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
w13_weight_scale_2 = layer.w13_weight_scale_2[:, 0]
layer.w13_weight_scale_2 = Parameter(w13_weight_scale_2, requires_grad=False)
-if self.enable_flashinfer_cutlass_moe:
# Calculate input scales based on strategy
if self.enable_flashinfer_cutlass_moe or self.enable_flashinfer_trtllm_moe:
w13_input_scale = layer.w13_input_scale.max().to(torch.float32)
w2_input_scale = layer.w2_input_scale.max().to(torch.float32)
else:
w13_input_scale = layer.w13_input_scale.max(dim=1).values.to(torch.float32)
w2_input_scale = layer.w2_input_scale
# Create shared parameters
layer.g1_alphas = Parameter(
(w13_input_scale * w13_weight_scale_2).to(torch.float32),
requires_grad=False,
)
layer.g2_alphas = Parameter(
(w2_input_scale * layer.w2_weight_scale_2).to(torch.float32),
requires_grad=False,
)
layer.w13_input_scale_quant = Parameter(
(1 / w13_input_scale).to(torch.float32), requires_grad=False
)
layer.w2_input_scale_quant = Parameter(
(1 / w2_input_scale).to(torch.float32), requires_grad=False
)
# Validate weight scales
for name, weight_scale in [
("w13", layer.w13_weight_scale),
("w2", layer.w2_weight_scale),
]:
assert (
-layer.w13_weight_scale.shape[2] % 16 == 0
weight_scale.shape[2] % 16 == 0
-), "Expected weight_scale.dim(1) to be divisible by 16"
), f"Expected {name}_weight_scale.dim(2) to be divisible by 16"
assert (
-layer.w13_weight_scale.dtype == torch.float8_e4m3fn
weight_scale.dtype == torch.float8_e4m3fn
-), "Weight Blockscale must be represented as FP8-E4M3"
), f"{name} Weight Blockscale must be represented as FP8-E4M3"
-w13_blockscale_swizzled = self.swizzle_blockscale(layer.w13_weight_scale)
# Weight processing based on strategy
if (
self.enable_flashinfer_trtllm_moe
and reorder_rows_for_gated_act_gemm is not None
and shuffle_matrix_sf_a is not None
):
# FlashInfer TRTLLM processing - handles both w13 and w2
(
gemm1_weights_fp4_shuffled,
gemm1_scales_fp4_shuffled,
gemm2_weights_fp4_shuffled,
gemm2_scales_fp4_shuffled,
) = self.prepare_static_weights_for_kernel(
layer.w13_weight,
layer.w2_weight,
layer.w13_weight_scale,
layer.w2_weight_scale,
layer.w2_weight.size(-2), # hidden_size
layer.w13_weight.size(-2) // 2, # intermediate_size
layer.w13_weight.size(0), # num_experts
)
-layer.w13_blockscale_swizzled = Parameter(
-w13_blockscale_swizzled, requires_grad=False
# Set flashinfer parameters
layer.gemm1_weights_fp4_shuffled = Parameter(
gemm1_weights_fp4_shuffled, requires_grad=False
)
layer.gemm2_weights_fp4_shuffled = Parameter(
gemm2_weights_fp4_shuffled, requires_grad=False
)
layer.gemm1_scales_fp4_shuffled = Parameter(
gemm1_scales_fp4_shuffled, requires_grad=False
)
layer.gemm2_scales_fp4_shuffled = Parameter(
gemm2_scales_fp4_shuffled, requires_grad=False
)
-)
-del layer.w13_weight_scale
-# This is for quantization, so we need to invert it.
-layer.w13_input_scale_quant = Parameter(
-(1 / w13_input_scale).to(torch.float32), requires_grad=False
-)
-layer.w13_weight = Parameter(layer.w13_weight.data, requires_grad=False)
# Additional parameter needed for TRT-LLM
layer.g1_scale_c = Parameter(
(layer.w2_input_scale_quant * layer.g1_alphas).to(torch.float32),
requires_grad=False,
)
# Clean up weights that won't be used by TRT-LLM
del (
layer.w2_weight,
layer.w2_weight_scale,
layer.w13_weight,
layer.w13_weight_scale,
)
print("Applied flashinfer weight processing for both w13 and w2")
-# GEMM 2
-if self.enable_flashinfer_cutlass_moe:
-w2_input_scale = layer.w2_input_scale.max().to(torch.float32)
-else:
-w2_input_scale = layer.w2_input_scale
-layer.g2_alphas = Parameter(
-(w2_input_scale * layer.w2_weight_scale_2).to(torch.float32),
-requires_grad=False,
-)
-# This is for quantization, so we need to invert it.
-layer.w2_input_scale_quant = Parameter(
-(1 / w2_input_scale).to(torch.float32), requires_grad=False
-)
-assert (
-layer.w2_weight_scale.shape[2] % 16 == 0
-), "Expected weight_scale.dim(1) to be divisible by 16"
-assert (
-layer.w2_weight_scale.dtype == torch.float8_e4m3fn
-), "Weight Blockscale must be represented as FP8-E4M3"
else:
# CUTLASS processing - handle w13 and w2 separately
# Process w13 weights
w13_blockscale_swizzled = self.swizzle_blockscale(layer.w13_weight_scale)
layer.w13_blockscale_swizzled = Parameter(
w13_blockscale_swizzled, requires_grad=False
)
layer.w13_weight = Parameter(layer.w13_weight.data, requires_grad=False)
# Process w2 weights
w2_blockscale_swizzled = self.swizzle_blockscale(layer.w2_weight_scale)
layer.w2_blockscale_swizzled = Parameter(
w2_blockscale_swizzled, requires_grad=False
)
-del layer.w2_weight_scale
layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False)
# Both flashinfer cutlass and regular cutlass use same processing for w2
print("Applied weight processing for both w13 and w2")
# Set up CUTLASS MoE parameters
device = layer.w13_weight.device
layer.cutlass_moe_params = CutlassMoEParams(
CutlassMoEType.BlockscaledFP4,
@@ -971,13 +1161,20 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
) -> torch.Tensor:
assert activation == "silu", "Only SiLU activation is supported."
# Check if this is a FlashInferFP4MoE layer that should handle its own forward
if hasattr(layer, "gemm1_weights_fp4_shuffled"):
# This layer was processed with flashinfer TRTLLM - delegate to its own forward
return layer.forward(x, topk_output)
if self.enable_flashinfer_cutlass_moe:
assert (
not apply_router_weight_on_input
), "apply_router_weight_on_input is not supported for Flashinfer"
# TRTLLM Cutlass moe takes in activations in BF16/Half/nvfp4 precision
# and fp4 quantized weights loaded from the checkpoint
-topk_weights, topk_ids, _ = topk_output
topk_weights, topk_ids = topk_output.topk_weights, topk_output.topk_ids
output = flashinfer_cutlass_fused_moe(
x,
topk_ids.to(torch.int),
@@ -1005,7 +1202,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
from sglang.srt.layers.moe.cutlass_moe import cutlass_moe_fp4
-topk_weights, topk_ids, _ = topk_output
topk_weights, topk_ids = topk_output.topk_weights, topk_output.topk_ids
output = cutlass_moe_fp4(
a=x,
a1_gscale=layer.w13_input_scale_quant,
...
@@ -51,7 +51,6 @@ from sglang.srt.disaggregation.decode_schedule_batch_mixin import (
ScheduleBatchDisaggregationDecodeMixin,
)
from sglang.srt.distributed.parallel_state import get_tensor_model_parallel_rank
-from sglang.srt.layers.moe.utils import DeepEPMode, MoeA2ABackend
from sglang.srt.mem_cache.allocator import (
BaseTokenToKVPoolAllocator,
SWATokenToKVPoolAllocator,
@@ -109,6 +108,7 @@ GLOBAL_SERVER_ARGS_KEYS = [
"enable_triton_kernel_moe",
"enable_multimodal",
"enable_symm_mem",
"quantization",
]
# Put some global args for easy access
...
@@ -60,12 +60,9 @@ from sglang.srt.layers.linear import (
RowParallelLinear,
)
from sglang.srt.layers.logits_processor import LogitsProcessor
-from sglang.srt.layers.moe.ep_moe.layer import (
-DeepEPMoE,
-get_moe_impl_class,
-should_use_flashinfer_trtllm_moe,
-)
from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, get_moe_impl_class
from sglang.srt.layers.moe.topk import TopK
from sglang.srt.layers.moe.utils import should_use_flashinfer_trtllm_moe
from sglang.srt.layers.quantization import deep_gemm_wrapper
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.quantization.fp8_kernel import (
@@ -307,8 +304,7 @@ class DeepseekV2MoE(nn.Module):
config=config, prefix=add_prefix("gate", prefix), is_nextn=is_nextn
)
-self.topk = (
-TopK(
self.topk = TopK(
top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
renormalize=config.norm_topk_prob,
use_grouped_topk=True,
@@ -318,9 +314,6 @@ class DeepseekV2MoE(nn.Module):
correction_bias=self.gate.e_score_correction_bias,
routed_scaling_factor=self.routed_scaling_factor,
)
-if not should_use_flashinfer_trtllm_moe()
-else None
-)
self.experts = get_moe_impl_class()(
num_experts=config.n_routed_experts
@@ -476,10 +469,14 @@ class DeepseekV2MoE(nn.Module):
# router_logits: (num_tokens, n_experts)
router_logits = self.gate(hidden_states)
kwargs = {"hidden_states": hidden_states}
-if self.topk is not None:
-kwargs["topk_output"] = self.topk(hidden_states, router_logits)
# FlashInferFP4MoE (TRTLLM path) expects (TopK, router_logits) tuple
# Regular FusedMoE (CUTLASS path) expects StandardTopKOutput
if should_use_flashinfer_trtllm_moe():
kwargs["topk_output"] = (self.topk, router_logits)
else:
-kwargs["router_logits"] = router_logits
kwargs["topk_output"] = self.topk(hidden_states, router_logits)
final_hidden_states = self.experts(**kwargs)
if not _is_cuda:
final_hidden_states *= self.routed_scaling_factor
@@ -505,10 +502,14 @@ class DeepseekV2MoE(nn.Module):
# router_logits: (num_tokens, n_experts)
router_logits = self.gate(hidden_states)
kwargs = {"hidden_states": hidden_states}
-if self.topk is not None:
-kwargs["topk_output"] = self.topk(hidden_states, router_logits)
# FlashInferFP4MoE (TRTLLM path) expects (TopK, router_logits) tuple
# Regular FusedMoE (CUTLASS path) expects StandardTopKOutput
if should_use_flashinfer_trtllm_moe():
kwargs["topk_output"] = (self.topk, router_logits)
else:
-kwargs["router_logits"] = router_logits
kwargs["topk_output"] = self.topk(hidden_states, router_logits)
final_hidden_states = self.experts(**kwargs)
if not _is_cuda and not _use_aiter:
# fused in biased_grouped_topk so we can skip here
...
@@ -50,11 +50,9 @@ from sglang.srt.layers.linear import (
RowParallelLinear,
)
from sglang.srt.layers.logits_processor import LogitsProcessor
-from sglang.srt.layers.moe.ep_moe.layer import (
-get_moe_impl_class,
-should_use_flashinfer_trtllm_moe,
-)
from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
from sglang.srt.layers.moe.topk import TopK
from sglang.srt.layers.moe.utils import should_use_flashinfer_trtllm_moe
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.quantization.fp8_kernel import (
is_fp8_fnuz,
...
@@ -481,6 +481,13 @@ class ServerArgs:
self.tp_size,
], "The expert parallel size must be 1 or the same as the tensor parallel size"
if self.enable_flashinfer_trtllm_moe:
if not self.disable_shared_experts_fusion:
self.disable_shared_experts_fusion = True
logger.warning(
"FlashInfer TRTLLM MoE is enabled. --disable-shared-experts-fusion is automatically set."
)
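A standalone sketch of the flag interaction enforced above: enabling the FlashInfer TRTLLM MoE path forces shared-experts fusion off. The function below is illustrative only; the real logic lives in ServerArgs and emits the warning shown in the code.
def resolve_shared_experts_fusion(enable_flashinfer_trtllm_moe: bool,
                                  disable_shared_experts_fusion: bool) -> bool:
    if enable_flashinfer_trtllm_moe and not disable_shared_experts_fusion:
        disable_shared_experts_fusion = True  # logged as a warning in ServerArgs
    return disable_shared_experts_fusion
print(resolve_shared_experts_fusion(True, False))   # True: fusion disabled automatically
print(resolve_shared_experts_fusion(False, False))  # False: user setting kept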
# DeepEP MoE
if self.moe_a2a_backend == "deepep":
if self.deepep_mode == "normal":
...