"examples/vscode:/vscode.git/clone" did not exist on "8873fb29d9983a9719a6d3ad529731770f68f0ea"
Unverified commit 915140fd authored by azhurkevich, committed by GitHub

[NVIDIA] Add Low Latency NVFP4 decode kernels from Flashinfer (#8552)


Co-authored-by: Cheng Wan <cwan@x.ai>
parent 36fc9260
......@@ -14,13 +14,9 @@ from sglang.srt.layers.moe.ep_moe.kernels import (
silu_and_mul_masked_post_quant_fwd,
tma_align_input_scale,
)
from sglang.srt.layers.moe.fused_moe_triton.layer import (
FlashInferFusedMoE,
FusedMoE,
should_use_flashinfer_trtllm_moe,
)
from sglang.srt.layers.moe.fused_moe_triton.layer import FlashInferFusedMoE, FusedMoE
from sglang.srt.layers.moe.topk import TopKOutput
from sglang.srt.layers.moe.utils import DeepEPMode
from sglang.srt.layers.moe.utils import DeepEPMode, should_use_flashinfer_trtllm_moe
from sglang.srt.layers.quantization import deep_gemm_wrapper
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.quantization.fp8 import (
......@@ -48,7 +44,6 @@ _is_npu = is_npu()
_is_fp8_fnuz = is_fp8_fnuz()
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
if not (_is_npu or _is_hip):
from sgl_kernel import silu_and_mul
......@@ -741,6 +736,22 @@ class FlashInferEPMoE(EPMoE):
def get_moe_impl_class():
if global_server_args_dict["moe_a2a_backend"].is_deepep():
return DeepEPMoE
# NEW: Direct FP4 detection (bypasses EP requirements)
# Check for FP4 quantization with TRTLLM flag, regardless of EP
if global_server_args_dict.get("enable_flashinfer_trtllm_moe", False):
try:
# Check the quantization argument directly
quantization = global_server_args_dict.get("quantization")
if quantization == "modelopt_fp4":
from sglang.srt.layers.moe.fused_moe_triton.layer import (
FlashInferFP4MoE,
)
return FlashInferFP4MoE
except:
pass
if global_server_args_dict["enable_flashinfer_cutlass_moe"]:
return FusedMoE
if get_moe_expert_parallel_world_size() > 1:
......
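For context, a minimal sketch of how the reworked dispatch in get_moe_impl_class() resolves. The dicts below are hypothetical, flattened stand-ins for global_server_args_dict (which holds richer objects such as the A2A backend enum) and exist only to show branch order.

# Illustration only: simplified configurations, not real server-args contents.
fp4_trtllm_cfg = {
    "enable_flashinfer_trtllm_moe": True,
    "quantization": "modelopt_fp4",
}
# -> FlashInferFP4MoE: the new FP4 check fires before the Cutlass and
#    expert-parallel branches, so no EP setup is required.

cutlass_cfg = {
    "enable_flashinfer_trtllm_moe": False,
    "enable_flashinfer_cutlass_moe": True,
}
# -> FusedMoE (Cutlass path), unchanged by this commit.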
# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/model_executor/layers/fused_moe/layer.py
import importlib.util
import datetime
import glob
import logging
import os
import sys
from enum import Enum
from functools import lru_cache
from typing import List, Optional, Tuple
import torch
from packaging import version as pkg_version
from sglang.srt.distributed import (
get_moe_expert_parallel_rank,
......@@ -22,6 +23,7 @@ from sglang.srt.distributed.device_communicators.pynccl_allocator import (
)
from sglang.srt.eplb.expert_location import get_global_expert_location_metadata
from sglang.srt.layers.moe.topk import StandardTopKOutput
from sglang.srt.layers.moe.utils import should_use_flashinfer_trtllm_moe
from sglang.srt.layers.quantization.base_config import (
QuantizationConfig,
QuantizeMethodBase,
......@@ -29,22 +31,58 @@ from sglang.srt.layers.quantization.base_config import (
from sglang.srt.layers.quantization.unquant import UnquantizedFusedMoEMethod
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_loader.weight_utils import narrow_padded_param_and_loaded_weight
from sglang.srt.utils import cpu_has_amx_support, get_bool_env_var, is_cpu, is_hip
from sglang.srt.utils import (
cpu_has_amx_support,
get_bool_env_var,
is_cpu,
is_flashinfer_available,
is_hip,
next_power_of_2,
)
if is_flashinfer_available():
from flashinfer import (
RoutingMethodType,
fp4_quantize,
reorder_rows_for_gated_act_gemm,
shuffle_matrix_a,
shuffle_matrix_sf_a,
)
_is_hip = is_hip()
_is_cpu_amx_available = cpu_has_amx_support()
_is_cpu = is_cpu()
# Try to import FP4 TRTLLM function if flashinfer is available
trtllm_fp4_block_scale_moe = None
if should_use_flashinfer_trtllm_moe():
try:
from flashinfer.fused_moe import trtllm_fp4_block_scale_moe
except ImportError:
trtllm_fp4_block_scale_moe = None
logger = logging.getLogger(__name__)
@lru_cache(maxsize=1)
def should_use_flashinfer_trtllm_moe():
return global_server_args_dict["enable_flashinfer_trtllm_moe"] and (
not importlib.util.find_spec("flashinfer")
or pkg_version.parse(__import__("flashinfer").__version__)
>= pkg_version.parse("0.2.9rc1")
)
def _is_fp4_quantization_enabled():
"""Check if ModelOpt FP4 quantization is enabled."""
try:
# Use the same simple check that works for class selection
quantization = global_server_args_dict.get("quantization")
return quantization == "modelopt_fp4"
except:
return False
def _get_tile_tokens_dim(num_tokens, top_k, num_experts):
# Guess tokens per expert assuming perfect expert distribution first.
num_tokens_per_expert = (num_tokens * top_k) // num_experts
# And pad the number to the next power of 2.
tile_tokens_dim = next_power_of_2(num_tokens_per_expert)
# Cap to 8-64 tokens per CTA tile as it's the range supported by the kernel.
tile_tokens_dim = min(max(tile_tokens_dim, 8), 64)
return tile_tokens_dim
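A worked example of the tile sizing above, with illustrative batch shapes and assuming next_power_of_2 rounds up to the nearest power of two:

# _get_tile_tokens_dim(num_tokens=64, top_k=8, num_experts=256)
#   (64 * 8) // 256 = 2 tokens per expert -> next_power_of_2 = 2 -> clamped up to 8
# _get_tile_tokens_dim(num_tokens=4096, top_k=8, num_experts=256)
#   (4096 * 8) // 256 = 128               -> next_power_of_2 = 128 -> capped at 64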
class FusedMoeWeightScaleSupported(Enum):
......@@ -157,10 +195,6 @@ class FusedMoE(torch.nn.Module):
)
else:
self.quant_method = quant_config.get_quant_method(self, prefix)
if self.quant_method.__class__.__name__ == "ModelOptNvFp4FusedMoEMethod":
self.quant_method.enable_flashinfer_cutlass_moe = (
self.enable_flashinfer_cutlass_moe
)
assert self.quant_method is not None
self.quant_config = quant_config
......@@ -747,7 +781,130 @@ class FlashInferFusedMoE(FusedMoE):
routed_scaling_factor=self.routed_scaling_factor,
)
if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1):
if self.reduce_results and (self.moe_tp_size > 1 or self.moe_ep_size > 1):
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
return final_hidden_states
class FlashInferFP4MoE(FusedMoE):
"""FP4 TRTLLM MoE implementation using FlashInfer."""
def __init__(self, *args, **kwargs):
# Extract DeepSeek-specific parameters
renormalize = kwargs.pop("renormalize", True)
num_fused_shared_experts = kwargs.pop("num_fused_shared_experts", 0)
use_grouped_topk = kwargs.pop("use_grouped_topk", False)
num_expert_group = kwargs.pop("num_expert_group", None)
topk_group = kwargs.pop("topk_group", None)
correction_bias = kwargs.pop("correction_bias", None)
# Extract additional TopK parameters that were previously extracted in forward
routed_scaling_factor = kwargs.pop("routed_scaling_factor", None)
super().__init__(*args, **kwargs)
# Store DeepSeek parameters
self.renormalize = renormalize
self.num_fused_shared_experts = num_fused_shared_experts
self.use_grouped_topk = use_grouped_topk
self.num_expert_group = num_expert_group
self.topk_group = topk_group
self.correction_bias = correction_bias
self.routed_scaling_factor = routed_scaling_factor
# ---------------------------------------------------------------------
# Helper: quantize hidden states to FP4 each forward pass
# ---------------------------------------------------------------------
def _quantize_hidden_states_fp4(self, hidden_states: torch.Tensor):
"""
Quantize hidden states using global scale factor from quantization method.
Global scale factor is set by ModelOptNvFp4FusedMoEMethod during weight loading.
Only block scales are computed at runtime for efficiency.
Returns (packed_fp4_uint8, scale_float8_e4m3fn_runtime, global_scale_float32)
"""
# flashinfer.fp4_quantize returns (packed_uint8, scale_fp8)
# Only the block scales are computed at runtime
hs_fp4_bytes, hs_sf_bytes = fp4_quantize(
hidden_states,
self.w13_input_scale_quant,
16, # sf_vec_size
False, # use_ue8m0
False, # is_sf_swizzled_layout
)
hs_fp4 = hs_fp4_bytes.reshape(
hidden_states.shape[0], hidden_states.shape[1] // 2
)
hs_sf = hs_sf_bytes.view(torch.float8_e4m3fn).reshape(-1)
return hs_fp4, hs_sf
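A rough shape sketch for the helper above, assuming a hypothetical batch of 8 tokens with hidden size 7168; any hidden size divisible by the sf_vec_size of 16 behaves the same way.

# hidden_states : [8, 7168]  high-precision activations
# hs_fp4        : [8, 3584]  uint8, two FP4 values packed per byte
# hs_sf         : flat float8_e4m3fn tensor, one block scale per 16 elements
#                 (8 * 7168 / 16 = 3584 scales)
# w13_input_scale_quant is the global input scale precomputed by the
# quantization method during weight loading, so only block scales are
# computed per forward pass.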
def forward(self, hidden_states: torch.Tensor, topk_output):
"""Forward pass using FP4 TRTLLM kernel.
Args:
hidden_states: Input tensor
topk_output: Should be tuple of (TopK_config, router_logits) for TRTLLM mode
"""
# TRTLLM mode expects (TopK_config, router_logits) tuple
if not isinstance(topk_output, tuple) or len(topk_output) != 2:
raise ValueError(
f"FlashInferFP4MoE expects (TopK_config, router_logits) tuple, got {type(topk_output)}"
)
_, router_logits = topk_output
hs_fp4, hs_scale_linear = self._quantize_hidden_states_fp4(hidden_states)
router_logits = router_logits.to(torch.float32)
result = trtllm_fp4_block_scale_moe(
routing_logits=router_logits,
routing_bias=self.correction_bias.to(hidden_states.dtype),
hidden_states=hs_fp4,
hidden_states_scale=hs_scale_linear.view(torch.float8_e4m3fn).flatten(),
gemm1_weights=self.gemm1_weights_fp4_shuffled.data,
gemm1_weights_scale=self.gemm1_scales_fp4_shuffled.data.view(
torch.float8_e4m3fn
),
gemm2_weights=self.gemm2_weights_fp4_shuffled.data,
gemm2_weights_scale=self.gemm2_scales_fp4_shuffled.data.view(
torch.float8_e4m3fn
),
output1_scale_scalar=self.g1_scale_c.data,
output1_scale_gate_scalar=self.g1_alphas.data,
output2_scale_scalar=self.g2_alphas.data,
num_experts=self.num_experts,
top_k=self.top_k,
n_group=self.num_expert_group,
topk_group=self.topk_group,
intermediate_size=self.intermediate_size_per_partition,
local_expert_offset=self.moe_ep_rank * self.num_local_experts,
local_num_experts=self.num_local_experts,
routed_scaling_factor=self.routed_scaling_factor,
tile_tokens_dim=_get_tile_tokens_dim(
hidden_states.shape[0], self.top_k, self.num_local_experts
),
routing_method_type=RoutingMethodType.DeepSeekV3,
do_finalize=True,
)[0]
return result
def get_fused_moe_impl_class():
"""Factory function to get the appropriate FusedMoE implementation class."""
if should_use_flashinfer_trtllm_moe() and _is_fp4_quantization_enabled():
# Use FP4 variant when FP4 quantization is enabled
return FlashInferFP4MoE
elif should_use_flashinfer_trtllm_moe():
# Use regular FlashInfer variant for non-FP4 FlashInfer cases
return FlashInferFusedMoE
else:
# Default case
return FusedMoE
import importlib.util
from enum import Enum
from functools import lru_cache
from packaging import version as pkg_version
from sglang.srt.managers.schedule_batch import global_server_args_dict
@lru_cache(maxsize=1)
def should_use_flashinfer_trtllm_moe():
result = global_server_args_dict["enable_flashinfer_trtllm_moe"] and (
not importlib.util.find_spec("flashinfer")
or pkg_version.parse(__import__("flashinfer").__version__)
>= pkg_version.parse("0.2.9rc1")
)
return result
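A hedged reading of the gate above, since the short-circuit condition is easy to misread: the helper returns True only when the flag is set and flashinfer is either not installed or at least version 0.2.9rc1; whether the kernels are actually importable is checked separately in fused_moe_triton/layer.py.

# enable_flashinfer_trtllm_moe | flashinfer installed | version      | result
# False                        | any                  | any          | False
# True                         | no                   | -            | True
# True                         | yes                  | >= 0.2.9rc1  | True
# True                         | yes                  | <  0.2.9rc1  | False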
class MoeA2ABackend(Enum):
......
......@@ -51,7 +51,6 @@ from sglang.srt.disaggregation.decode_schedule_batch_mixin import (
ScheduleBatchDisaggregationDecodeMixin,
)
from sglang.srt.distributed.parallel_state import get_tensor_model_parallel_rank
from sglang.srt.layers.moe.utils import DeepEPMode, MoeA2ABackend
from sglang.srt.mem_cache.allocator import (
BaseTokenToKVPoolAllocator,
SWATokenToKVPoolAllocator,
......@@ -109,6 +108,7 @@ GLOBAL_SERVER_ARGS_KEYS = [
"enable_triton_kernel_moe",
"enable_multimodal",
"enable_symm_mem",
"quantization",
]
# Put some global args for easy access
......
......@@ -60,12 +60,9 @@ from sglang.srt.layers.linear import (
RowParallelLinear,
)
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe.ep_moe.layer import (
DeepEPMoE,
get_moe_impl_class,
should_use_flashinfer_trtllm_moe,
)
from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, get_moe_impl_class
from sglang.srt.layers.moe.topk import TopK
from sglang.srt.layers.moe.utils import should_use_flashinfer_trtllm_moe
from sglang.srt.layers.quantization import deep_gemm_wrapper
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.quantization.fp8_kernel import (
......@@ -307,19 +304,15 @@ class DeepseekV2MoE(nn.Module):
config=config, prefix=add_prefix("gate", prefix), is_nextn=is_nextn
)
self.topk = (
TopK(
top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
renormalize=config.norm_topk_prob,
use_grouped_topk=True,
num_expert_group=config.n_group,
num_fused_shared_experts=self.num_fused_shared_experts,
topk_group=config.topk_group,
correction_bias=self.gate.e_score_correction_bias,
routed_scaling_factor=self.routed_scaling_factor,
)
if not should_use_flashinfer_trtllm_moe()
else None
self.topk = TopK(
top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
renormalize=config.norm_topk_prob,
use_grouped_topk=True,
num_expert_group=config.n_group,
num_fused_shared_experts=self.num_fused_shared_experts,
topk_group=config.topk_group,
correction_bias=self.gate.e_score_correction_bias,
routed_scaling_factor=self.routed_scaling_factor,
)
self.experts = get_moe_impl_class()(
......@@ -476,10 +469,14 @@ class DeepseekV2MoE(nn.Module):
# router_logits: (num_tokens, n_experts)
router_logits = self.gate(hidden_states)
kwargs = {"hidden_states": hidden_states}
if self.topk is not None:
kwargs["topk_output"] = self.topk(hidden_states, router_logits)
# FlashInferFP4MoE (TRTLLM path) expects (TopK, router_logits) tuple
# Regular FusedMoE (CUTLASS path) expects StandardTopKOutput
if should_use_flashinfer_trtllm_moe():
kwargs["topk_output"] = (self.topk, router_logits)
else:
kwargs["router_logits"] = router_logits
kwargs["topk_output"] = self.topk(hidden_states, router_logits)
final_hidden_states = self.experts(**kwargs)
if not _is_cuda:
final_hidden_states *= self.routed_scaling_factor
......@@ -505,10 +502,14 @@ class DeepseekV2MoE(nn.Module):
# router_logits: (num_tokens, n_experts)
router_logits = self.gate(hidden_states)
kwargs = {"hidden_states": hidden_states}
if self.topk is not None:
kwargs["topk_output"] = self.topk(hidden_states, router_logits)
# FlashInferFP4MoE (TRTLLM path) expects (TopK, router_logits) tuple
# Regular FusedMoE (CUTLASS path) expects StandardTopKOutput
if should_use_flashinfer_trtllm_moe():
kwargs["topk_output"] = (self.topk, router_logits)
else:
kwargs["router_logits"] = router_logits
kwargs["topk_output"] = self.topk(hidden_states, router_logits)
final_hidden_states = self.experts(**kwargs)
if not _is_cuda and not _use_aiter:
# fused in biased_grouped_topk so we can skip here
......
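A condensed sketch of the two calling conventions the DeepseekV2MoE forward paths above now choose between; the assignments mirror the diff, the comments are explanatory only.

if should_use_flashinfer_trtllm_moe():
    # TRTLLM FP4 path: pass the raw router logits plus the TopK module;
    # routing runs inside trtllm_fp4_block_scale_moe (DeepSeekV3 routing),
    # and FlashInferFP4MoE.forward relies on the routing parameters captured
    # in its constructor rather than on the TopK module itself.
    kwargs["topk_output"] = (self.topk, router_logits)
else:
    # Default path: top-k routing runs in Python first and the experts
    # receive a StandardTopKOutput.
    kwargs["topk_output"] = self.topk(hidden_states, router_logits)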
......@@ -50,11 +50,9 @@ from sglang.srt.layers.linear import (
RowParallelLinear,
)
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe.ep_moe.layer import (
get_moe_impl_class,
should_use_flashinfer_trtllm_moe,
)
from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
from sglang.srt.layers.moe.topk import TopK
from sglang.srt.layers.moe.utils import should_use_flashinfer_trtllm_moe
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.quantization.fp8_kernel import (
is_fp8_fnuz,
......
......@@ -481,6 +481,13 @@ class ServerArgs:
self.tp_size,
], "The expert parallel size must be 1 or the same as the tensor parallel size"
if self.enable_flashinfer_trtllm_moe:
if not self.disable_shared_experts_fusion:
self.disable_shared_experts_fusion = True
logger.warning(
"FlashInfer TRTLLM MoE is enabled. --disable-shared-experts-fusion is automatically set."
)
# DeepEP MoE
if self.moe_a2a_backend == "deepep":
if self.deepep_mode == "normal":
......
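To illustrate the new validation, a hedged sketch of the relevant ServerArgs fields; attribute names come from this diff, while model_path and the construction itself are simplified placeholders.

# Hypothetical construction, illustration only:
args = ServerArgs(
    model_path="<nvfp4-checkpoint>",   # placeholder
    quantization="modelopt_fp4",
    enable_flashinfer_trtllm_moe=True,
)
# Once the check above runs, shared-experts fusion is forced off:
# args.disable_shared_experts_fusion is True, and a warning is logged.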