Unverified commit 85486b6f authored by Kaixi Hou, committed by GitHub

[NVIDIA] Add Flashinfer MoE blockscale fp8 backend (#8036)

parent e34cf6ad
......@@ -47,12 +47,17 @@ from sglang.srt.utils import (
get_bool_env_var,
is_hip,
is_npu,
next_power_of_2,
)
_is_hip = is_hip()
_is_npu = is_npu()
_is_fp8_fnuz = is_fp8_fnuz()
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
use_flashinfer_trtllm_moe = (
global_server_args_dict["enable_flashinfer_trtllm_moe"]
and global_server_args_dict["enable_ep_moe"]
)
if not (_is_npu or _is_hip):
from sgl_kernel import silu_and_mul
......@@ -64,6 +69,13 @@ if _use_aiter:
from aiter.fused_moe import fused_moe
from aiter.ops.shuffle import shuffle_weight
if use_flashinfer_trtllm_moe:
try:
import flashinfer.fused_moe as fi_fused_moe
except ImportError:
fi_fused_moe = None
use_flashinfer_trtllm_moe = False
logger = logging.getLogger(__name__)
......@@ -140,6 +152,16 @@ class GroupedGemmRunner(torch.nn.Module):
return c
def _get_tile_tokens_dim(num_tokens, top_k, num_experts):
# Guess tokens per expert assuming perfect expert distribution first.
num_tokens_per_expert = (num_tokens * top_k) // num_experts
# And pad the number to the next power of 2.
tile_tokens_dim = next_power_of_2(num_tokens_per_expert)
# Cap to 8-64 tokens per CTA tile as it's the range supported by the kernel.
tile_tokens_dim = min(max(tile_tokens_dim, 8), 64)
return tile_tokens_dim
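# A standalone sanity check of the tile sizing above (illustrative only;
# it re-implements next_power_of_2 locally instead of importing sglang).
def _next_power_of_2_local(n: int) -> int:
    return 1 if n <= 0 else 1 << (n - 1).bit_length()

def _tile_tokens_dim_local(num_tokens: int, top_k: int, num_experts: int) -> int:
    num_tokens_per_expert = (num_tokens * top_k) // num_experts
    return min(max(_next_power_of_2_local(num_tokens_per_expert), 8), 64)

assert _tile_tokens_dim_local(512, 8, 256) == 16  # 16 tokens/expert -> tile of 16
assert _tile_tokens_dim_local(1, 8, 256) == 8     # tiny batch -> clamped up to 8
assert _tile_tokens_dim_local(8192, 8, 64) == 64  # large batch -> capped at 64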
class EPMoE(torch.nn.Module):
"""
MoE Expert Parallel Impl
......@@ -776,14 +798,20 @@ class EPMoE(torch.nn.Module):
)
return
if shard_id == "w2":
# Flashinfer assumes w31 format for w13_weight. Same for the scales.
if use_flashinfer_trtllm_moe:
actual_shard_id = {"w1": "w3", "w3": "w1", "w2": "w2"}[shard_id]
else:
actual_shard_id = shard_id
if actual_shard_id == "w2":
param.data[expert_id] = loaded_weight
elif shard_id == "w1":
elif actual_shard_id == "w1":
param.data[expert_id][: self.intermediate_size, :] = loaded_weight
elif shard_id == "w3":
elif actual_shard_id == "w3":
param.data[expert_id][self.intermediate_size :, :] = loaded_weight
else:
raise ValueError(f"Expected shard_id w1,w2 or w3 but got {shard_id}")
raise ValueError(f"Expected shard_id w1,w2 or w3 but got {actual_shard_id}")
def _load_fp8_scale(
self,
......@@ -820,12 +848,18 @@ class EPMoE(torch.nn.Module):
# Weight scales
elif "weight_scale" in weight_name:
if self.use_block_quant:
if use_flashinfer_trtllm_moe:
actual_shard_id = {"w1": "w3", "w3": "w1", "w2": "w2"}[shard_id]
else:
actual_shard_id = shard_id
block_n, block_k = self.block_shape[0], self.block_shape[1]
if shard_id == "w1":
if actual_shard_id == "w1":
param_data[expert_id][
: (self.intermediate_size + block_n - 1) // block_n, :
] = loaded_weight
elif shard_id == "w3":
elif actual_shard_id == "w3":
param_data[expert_id][
(self.intermediate_size + block_n - 1) // block_n :, :
] = loaded_weight
......@@ -1315,12 +1349,73 @@ class DeepEPMoE(EPMoE):
return down_output
class FlashInferEPMoE(EPMoE):
def __init__(self, *args, **kwargs):
renormalize = kwargs.pop("renormalize", True)
num_fused_shared_experts = kwargs.pop("num_fused_shared_experts", 0)
use_grouped_topk = kwargs.pop("use_grouped_topk", False)
num_expert_group = kwargs.pop("num_expert_group", None)
topk_group = kwargs.pop("topk_group", None)
correction_bias = kwargs.pop("correction_bias", None)
super().__init__(*args, **kwargs)
self.renormalize = renormalize
self.num_fused_shared_experts = num_fused_shared_experts
self.use_grouped_topk = use_grouped_topk
if self.use_grouped_topk:
assert num_expert_group is not None and topk_group is not None
self.num_expert_group = num_expert_group
self.topk_group = topk_group
self.correction_bias = correction_bias
def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
assert use_flashinfer_trtllm_moe
assert (
self.activation == "silu"
), "Only silu is supported for flashinfer blockscale fp8 moe"
assert (
self.renormalize
), "Renormalize is required for flashinfer blockscale fp8 moe"
assert (
self.num_fused_shared_experts == 0
), "Fused shared experts are not supported for flashinfer blockscale fp8 moe"
a_q, a_sf = sglang_per_token_group_quant_fp8(hidden_states, self.block_shape[1])
# NOTE: scales of hidden states have to be transposed!
a_sf_t = a_sf.t().contiguous()
assert fi_fused_moe is not None
return fi_fused_moe.trtllm_fp8_block_scale_moe(
routing_logits=router_logits.to(torch.float32),
routing_bias=self.correction_bias.to(hidden_states.dtype),
hidden_states=a_q,
hidden_states_scale=a_sf_t,
gemm1_weights=self.w13_weight,
gemm1_weights_scale=self.w13_weight_scale_inv,
gemm2_weights=self.w2_weight,
gemm2_weights_scale=self.w2_weight_scale_inv,
num_experts=self.num_experts,
top_k=self.top_k,
n_group=self.num_expert_group,
topk_group=self.topk_group,
intermediate_size=self.w2_weight.shape[2],
local_expert_offset=self.start_expert_id,
local_num_experts=self.num_experts_per_partition,
routed_scaling_factor=self.routed_scaling_factor,
tile_tokens_dim=_get_tile_tokens_dim(
hidden_states.shape[0], self.top_k, self.num_experts
),
routing_method_type=2, # DeepSeek-styled routing method
use_shuffled_weight=False,
)
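# Shape sketch for the transposed activation scales used above (assumption:
# per-token-group FP8 quantization yields one scale per [token, K-group], i.e.
# a [num_tokens, hidden_size // block_k] tensor, and the kernel consumes the
# transposed [hidden_size // block_k, num_tokens] layout, hence a_sf.t()).
import torch

num_tokens, hidden_size, block_k = 4, 256, 128
a_sf = torch.rand(num_tokens, hidden_size // block_k, dtype=torch.float32)
a_sf_t = a_sf.t().contiguous()
assert a_sf_t.shape == (hidden_size // block_k, num_tokens)
assert a_sf_t.is_contiguous()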
def get_moe_impl_class():
if global_server_args_dict["enable_deepep_moe"]:
return DeepEPMoE
if global_server_args_dict["enable_flashinfer_moe"]:
if global_server_args_dict["enable_flashinfer_cutlass_moe"]:
# Must come before EPMoE because FusedMoE also supports enable_ep_moe
return FusedMoE
if use_flashinfer_trtllm_moe:
# Must come before the EPMoE check because use_flashinfer_trtllm_moe also requires enable_ep_moe
return FlashInferEPMoE
if global_server_args_dict["enable_ep_moe"]:
return EPMoE
return FusedMoE
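# A minimal, self-contained sketch of the selection precedence above (assumed
# flag names mirror global_server_args_dict; no sglang imports). Note that in
# the real code use_flashinfer_trtllm_moe already folds in enable_ep_moe.
def _pick_moe_impl(flags: dict) -> str:
    if flags.get("enable_deepep_moe"):
        return "DeepEPMoE"
    if flags.get("enable_flashinfer_cutlass_moe"):
        return "FusedMoE"  # CUTLASS path; FusedMoE handles EP itself
    if flags.get("enable_flashinfer_trtllm_moe") and flags.get("enable_ep_moe"):
        return "FlashInferEPMoE"  # TRTLLM blockscale FP8 path
    if flags.get("enable_ep_moe"):
        return "EPMoE"
    return "FusedMoE"

assert _pick_moe_impl({"enable_ep_moe": True, "enable_flashinfer_trtllm_moe": True}) == "FlashInferEPMoE"
assert _pick_moe_impl({"enable_ep_moe": True}) == "EPMoE"
assert _pick_moe_impl({}) == "FusedMoE"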
......@@ -75,7 +75,7 @@ class FusedMoE(torch.nn.Module):
inplace: bool = True,
no_combine: bool = False,
routed_scaling_factor: Optional[float] = None,
enable_flashinfer_moe: Optional[bool] = False,
enable_flashinfer_cutlass_moe: Optional[bool] = False,
enable_ep_moe: Optional[bool] = False,
):
super().__init__()
......@@ -92,16 +92,16 @@ class FusedMoE(torch.nn.Module):
self.num_experts = num_experts
self.expert_map = None
if enable_flashinfer_moe and quant_config is None:
if enable_flashinfer_cutlass_moe and quant_config is None:
logger.warning("Disable flashinfer MoE when quantization config is None.")
enable_flashinfer_moe = False
enable_flashinfer_cutlass_moe = False
enable_ep_moe = False
self.enable_flashinfer_moe = enable_flashinfer_moe
self.enable_flashinfer_cutlass_moe = enable_flashinfer_cutlass_moe
if enable_ep_moe:
assert (
self.enable_flashinfer_moe
), "FusedMoE only supports EP with --enable-flashinfer-moe"
self.enable_flashinfer_cutlass_moe
), "FusedMoE only supports EP with --enable-flashinfer-cutlass-moe"
self.ep_size = self.tp_size
self.ep_rank = self.tp_rank
self.tp_size = 1
......@@ -141,7 +141,9 @@ class FusedMoE(torch.nn.Module):
else:
self.quant_method = quant_config.get_quant_method(self, prefix)
if self.quant_method.__class__.__name__ == "ModelOptNvFp4FusedMoEMethod":
self.quant_method.enable_flashinfer_moe = self.enable_flashinfer_moe
self.quant_method.enable_flashinfer_cutlass_moe = (
self.enable_flashinfer_cutlass_moe
)
assert self.quant_method is not None
self.quant_config = quant_config
......@@ -711,7 +711,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
" quantization. Please use Blackwell and"
" above."
)
self.enable_flashinfer_moe = False
self.enable_flashinfer_cutlass_moe = False
def create_weights(
self,
......@@ -865,7 +865,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
w13_weight_scale_2 = layer.w13_weight_scale_2[:, 0]
layer.w13_weight_scale_2 = Parameter(w13_weight_scale_2, requires_grad=False)
if self.enable_flashinfer_moe:
if self.enable_flashinfer_cutlass_moe:
w13_input_scale = layer.w13_input_scale.max().to(torch.float32)
else:
w13_input_scale = layer.w13_input_scale.max(dim=1).values.to(torch.float32)
......@@ -894,7 +894,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
layer.w13_weight = Parameter(layer.w13_weight.data, requires_grad=False)
# GEMM 2
if self.enable_flashinfer_moe:
if self.enable_flashinfer_cutlass_moe:
w2_input_scale = layer.w2_input_scale.max().to(torch.float32)
else:
w2_input_scale = layer.w2_input_scale
......@@ -934,7 +934,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
@property
def load_up_proj_weight_first(self) -> bool:
# FlashInfer CUTLASS kernel assumes [Up, Gate] Proj as W13
return self.enable_flashinfer_moe
return self.enable_flashinfer_cutlass_moe
def apply(
self,
......@@ -954,7 +954,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
) -> torch.Tensor:
assert activation == "silu", "Only SiLU activation is supported."
if self.enable_flashinfer_moe:
if self.enable_flashinfer_cutlass_moe:
assert (
not apply_router_weight_on_input
), "apply_router_weight_on_input is not supported for Flashinfer"
......@@ -88,7 +88,8 @@ GLOBAL_SERVER_ARGS_KEYS = [
"enable_deepep_moe",
"deepep_mode",
"enable_ep_moe",
"enable_flashinfer_moe",
"enable_flashinfer_cutlass_moe",
"enable_flashinfer_trtllm_moe",
"enable_flashinfer_allreduce_fusion",
"moe_dense_tp_size",
"ep_dispatch_algorithm",
......@@ -56,7 +56,11 @@ from sglang.srt.layers.linear import (
RowParallelLinear,
)
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, get_moe_impl_class
from sglang.srt.layers.moe.ep_moe.layer import (
DeepEPMoE,
get_moe_impl_class,
use_flashinfer_trtllm_moe,
)
from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
from sglang.srt.layers.moe.topk import TopK
from sglang.srt.layers.quantization import deep_gemm_wrapper
......@@ -302,15 +306,19 @@ class DeepseekV2MoE(nn.Module):
config=config, prefix=add_prefix("gate", prefix), is_nextn=is_nextn
)
self.topk = TopK(
top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
renormalize=config.norm_topk_prob,
use_grouped_topk=True,
num_expert_group=config.n_group,
num_fused_shared_experts=self.num_fused_shared_experts,
topk_group=config.topk_group,
correction_bias=self.gate.e_score_correction_bias,
routed_scaling_factor=self.routed_scaling_factor,
self.topk = (
TopK(
top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
renormalize=config.norm_topk_prob,
use_grouped_topk=True,
num_expert_group=config.n_group,
num_fused_shared_experts=self.num_fused_shared_experts,
topk_group=config.topk_group,
correction_bias=self.gate.e_score_correction_bias,
routed_scaling_factor=self.routed_scaling_factor,
)
if not use_flashinfer_trtllm_moe
else None
)
self.experts = get_moe_impl_class()(
......@@ -332,10 +340,22 @@ class DeepseekV2MoE(nn.Module):
# Additional args for FusedMoE
**(
dict(
enable_flashinfer_moe=True,
enable_flashinfer_cutlass_moe=True,
enable_ep_moe=global_server_args_dict["enable_ep_moe"],
)
if global_server_args_dict["enable_flashinfer_moe"]
if global_server_args_dict["enable_flashinfer_cutlass_moe"]
else {}
),
**(
dict(
renormalize=config.norm_topk_prob,
use_grouped_topk=True,
num_expert_group=config.n_group,
num_fused_shared_experts=self.num_fused_shared_experts,
topk_group=config.topk_group,
correction_bias=self.gate.e_score_correction_bias,
)
if use_flashinfer_trtllm_moe
else {}
),
)
......@@ -455,10 +475,12 @@ class DeepseekV2MoE(nn.Module):
with torch.cuda.stream(self.alt_stream):
# router_logits: (num_tokens, n_experts)
router_logits = self.gate(hidden_states)
topk_output = self.topk(hidden_states, router_logits)
final_hidden_states = self.experts(
hidden_states=hidden_states, topk_output=topk_output
)
kwargs = {"hidden_states": hidden_states}
if self.topk is not None:
kwargs["topk_output"] = self.topk(hidden_states, router_logits)
else:
kwargs["router_logits"] = router_logits
final_hidden_states = self.experts(**kwargs)
if not _is_cuda:
final_hidden_states *= self.routed_scaling_factor
current_stream.wait_stream(self.alt_stream)
......@@ -478,10 +500,12 @@ class DeepseekV2MoE(nn.Module):
shared_output = self._forward_shared_experts(hidden_states)
# router_logits: (num_tokens, n_experts)
router_logits = self.gate(hidden_states)
topk_output = self.topk(hidden_states, router_logits)
final_hidden_states = self.experts(
hidden_states=hidden_states, topk_output=topk_output
)
kwargs = {"hidden_states": hidden_states}
if self.topk is not None:
kwargs["topk_output"] = self.topk(hidden_states, router_logits)
else:
kwargs["router_logits"] = router_logits
final_hidden_states = self.experts(**kwargs)
if not _is_cuda and not _use_aiter:
# fused in biased_grouped_topk so we can skip here
final_hidden_states *= self.routed_scaling_factor
......@@ -147,10 +147,10 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
# Additional args for FusedMoE
**(
dict(
enable_flashinfer_moe=True,
enable_flashinfer_cutlass_moe=True,
enable_ep_moe=global_server_args_dict["enable_ep_moe"],
)
if global_server_args_dict["enable_flashinfer_moe"]
if global_server_args_dict["enable_flashinfer_cutlass_moe"]
else {}
),
)
......@@ -120,10 +120,10 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
# Additional args for FusedMoE
**(
dict(
enable_flashinfer_moe=True,
enable_flashinfer_cutlass_moe=True,
enable_ep_moe=global_server_args_dict["enable_ep_moe"],
)
if global_server_args_dict["enable_flashinfer_moe"]
if global_server_args_dict["enable_flashinfer_cutlass_moe"]
else {}
),
)
......@@ -169,7 +169,8 @@ class ServerArgs:
ep_size: int = 1
enable_ep_moe: bool = False
enable_deepep_moe: bool = False
enable_flashinfer_moe: bool = False
enable_flashinfer_cutlass_moe: bool = False
enable_flashinfer_trtllm_moe: bool = False
enable_flashinfer_allreduce_fusion: bool = False
deepep_mode: Optional[Literal["auto", "normal", "low_latency"]] = "auto"
ep_num_redundant_experts: int = 0
......@@ -428,12 +429,16 @@ class ServerArgs:
), "Please enable dp attention when setting enable_dp_lm_head. "
# MoE kernel
if self.enable_flashinfer_moe:
if self.enable_flashinfer_cutlass_moe:
assert (
self.quantization == "modelopt_fp4"
), "modelopt_fp4 quantization is required for Flashinfer MOE"
os.environ["TRTLLM_ENABLE_PDL"] = "1"
if self.enable_flashinfer_trtllm_moe:
assert self.enable_ep_moe, "EP MoE is required for Flashinfer TRTLLM MOE"
logger.warning(f"Flashinfer TRTLLM MoE is enabled.")
# DeepEP MoE
if self.enable_deepep_moe:
if self.deepep_mode == "normal":
......@@ -1293,10 +1298,15 @@ class ServerArgs:
help="Enabling expert parallelism for moe. The ep size is equal to the tp size.",
)
parser.add_argument(
"--enable-flashinfer-moe",
"--enable-flashinfer-cutlass-moe",
action="store_true",
help="Enable FlashInfer CUTLASS MoE backend for modelopt_fp4 quant on Blackwell. Supports MoE-EP with --enable-ep-moe",
)
parser.add_argument(
"--enable-flashinfer-trtllm-moe",
action="store_true",
help="Enable FlashInfer TRTLLM MoE backend on Blackwell. Supports BlockScale FP8 MoE-EP with --enable-ep-moe",
)
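# Sketch of the constraints the two backend flags carry (illustrative only,
# not sglang code): the CUTLASS backend requires modelopt_fp4 quantization,
# and the TRTLLM blockscale FP8 backend additionally requires --enable-ep-moe.
def _check_flashinfer_moe_flags(quantization=None, cutlass=False, trtllm=False, ep=False):
    if cutlass and quantization != "modelopt_fp4":
        raise ValueError("modelopt_fp4 quantization is required for the FlashInfer CUTLASS MoE backend")
    if trtllm and not ep:
        raise ValueError("--enable-ep-moe is required for the FlashInfer TRTLLM MoE backend")

_check_flashinfer_moe_flags(quantization="fp8", trtllm=True, ep=True)   # valid blockscale FP8 combo
_check_flashinfer_moe_flags(quantization="modelopt_fp4", cutlass=True)  # valid CUTLASS combo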
parser.add_argument(
"--enable-flashinfer-allreduce-fusion",
action="store_true",