Unverified commit 85486b6f authored by Kaixi Hou, committed by GitHub

[NVIDIA] Add Flashinfer MoE blockscale fp8 backend (#8036)

parent e34cf6ad
@@ -47,12 +47,17 @@ from sglang.srt.utils import (
     get_bool_env_var,
     is_hip,
     is_npu,
+    next_power_of_2,
 )

 _is_hip = is_hip()
 _is_npu = is_npu()
 _is_fp8_fnuz = is_fp8_fnuz()
 _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
+use_flashinfer_trtllm_moe = (
+    global_server_args_dict["enable_flashinfer_trtllm_moe"]
+    and global_server_args_dict["enable_ep_moe"]
+)

 if not (_is_npu or _is_hip):
     from sgl_kernel import silu_and_mul
@@ -64,6 +69,13 @@ if _use_aiter:
     from aiter.fused_moe import fused_moe
     from aiter.ops.shuffle import shuffle_weight

+if use_flashinfer_trtllm_moe:
+    try:
+        import flashinfer.fused_moe as fi_fused_moe
+    except ImportError:
+        fi_fused_moe = None
+        use_flashinfer_trtllm_moe = False
+
 logger = logging.getLogger(__name__)
@@ -140,6 +152,16 @@ class GroupedGemmRunner(torch.nn.Module):
         return c


+def _get_tile_tokens_dim(num_tokens, top_k, num_experts):
+    # Guess tokens per expert assuming perfect expert distribution first.
+    num_tokens_per_expert = (num_tokens * top_k) // num_experts
+    # And pad the number to the next power of 2.
+    tile_tokens_dim = next_power_of_2(num_tokens_per_expert)
+    # Cap to 8-64 tokens per CTA tile as it's the range supported by the kernel.
+    tile_tokens_dim = min(max(tile_tokens_dim, 8), 64)
+    return tile_tokens_dim
+
+
 class EPMoE(torch.nn.Module):
     """
     MoE Expert Parallel Impl
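The helper added above sizes the kernel's CTA tile from the expected per-expert token count. A minimal standalone sketch of the same heuristic, using a local power-of-two helper in place of `next_power_of_2` from `sglang.srt.utils` (names below are illustrative):

```python
def _next_power_of_2(n: int) -> int:
    # Smallest power of two >= n (treat 0 as 1).
    return 1 if n <= 0 else 1 << (n - 1).bit_length()


def tile_tokens_dim(num_tokens: int, top_k: int, num_experts: int) -> int:
    # Expected tokens per expert under a perfectly uniform routing distribution.
    per_expert = (num_tokens * top_k) // num_experts
    # Round up to a power of two, then clamp to the kernel's 8-64 tile range.
    return min(max(_next_power_of_2(per_expert), 8), 64)


# e.g. 128 tokens, top_k=8, 256 experts -> 4 tokens/expert -> clamped up to 8;
#      4096 tokens, top_k=8, 256 experts -> 128 tokens/expert -> capped at 64.
assert tile_tokens_dim(128, 8, 256) == 8
assert tile_tokens_dim(4096, 8, 256) == 64
```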
@@ -776,14 +798,20 @@ class EPMoE(torch.nn.Module):
            )
            return

-        if shard_id == "w2":
+        # Flashinfer assumes w31 format for w13_weight. Same for the scales.
+        if use_flashinfer_trtllm_moe:
+            actual_shard_id = {"w1": "w3", "w3": "w1", "w2": "w2"}[shard_id]
+        else:
+            actual_shard_id = shard_id
+
+        if actual_shard_id == "w2":
             param.data[expert_id] = loaded_weight
-        elif shard_id == "w1":
+        elif actual_shard_id == "w1":
             param.data[expert_id][: self.intermediate_size, :] = loaded_weight
-        elif shard_id == "w3":
+        elif actual_shard_id == "w3":
             param.data[expert_id][self.intermediate_size :, :] = loaded_weight
         else:
-            raise ValueError(f"Expected shard_id w1,w2 or w3 but got {shard_id}")
+            raise ValueError(f"Expected shard_id w1,w2 or w3 but got {actual_shard_id}")

     def _load_fp8_scale(
         self,
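The shard remapping above exists because the FlashInfer TRTLLM kernel expects the fused `w13_weight` in "w31" order, i.e. the `w3` half first and the `w1` half second, while checkpoints ship shards named `w1` and `w3` in the usual order. A toy sketch of the index arithmetic (the buffer and helper below are illustrative, not the actual `EPMoE` parameters):

```python
import torch

intermediate_size = 4
# Toy stand-in for one expert's fused gate/up buffer.
w13 = torch.zeros(2 * intermediate_size, 8)


def load_w13_shard(shard_id: str, shard: torch.Tensor, flashinfer_w31: bool) -> None:
    # With the TRTLLM backend the halves are swapped, so the buffer ends up
    # holding [w3, w1] instead of the usual [w1, w3].
    slot = {"w1": "w3", "w3": "w1"}[shard_id] if flashinfer_w31 else shard_id
    if slot == "w1":
        w13[:intermediate_size] = shard
    else:  # slot == "w3"
        w13[intermediate_size:] = shard


load_w13_shard("w1", torch.ones(intermediate_size, 8), flashinfer_w31=True)
# The w1 shard now occupies the second half of the buffer.
assert w13[intermediate_size:].eq(1).all()
```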
@@ -820,12 +848,18 @@ class EPMoE(torch.nn.Module):
         # Weight scales
         elif "weight_scale" in weight_name:
             if self.use_block_quant:
+                if use_flashinfer_trtllm_moe:
+                    actual_shard_id = {"w1": "w3", "w3": "w1", "w2": "w2"}[shard_id]
+                else:
+                    actual_shard_id = shard_id
+
                 block_n, block_k = self.block_shape[0], self.block_shape[1]
-                if shard_id == "w1":
+
+                if actual_shard_id == "w1":
                     param_data[expert_id][
                         : (self.intermediate_size + block_n - 1) // block_n, :
                     ] = loaded_weight
-                elif shard_id == "w3":
+                elif actual_shard_id == "w3":
                     param_data[expert_id][
                         (self.intermediate_size + block_n - 1) // block_n :, :
                     ] = loaded_weight
@@ -1315,12 +1349,73 @@ class DeepEPMoE(EPMoE):
         return down_output


+class FlashInferEPMoE(EPMoE):
+    def __init__(self, *args, **kwargs):
+        renormalize = kwargs.pop("renormalize", True)
+        num_fused_shared_experts = kwargs.pop("num_fused_shared_experts", 0)
+        use_grouped_topk = kwargs.pop("use_grouped_topk", False)
+        num_expert_group = kwargs.pop("num_expert_group", None)
+        topk_group = kwargs.pop("topk_group", None)
+        correction_bias = kwargs.pop("correction_bias", None)
+        super().__init__(*args, **kwargs)
+        self.renormalize = renormalize
+        self.num_fused_shared_experts = num_fused_shared_experts
+        self.use_grouped_topk = use_grouped_topk
+        if self.use_grouped_topk:
+            assert num_expert_group is not None and topk_group is not None
+        self.num_expert_group = num_expert_group
+        self.topk_group = topk_group
+        self.correction_bias = correction_bias
+
+    def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
+        assert use_flashinfer_trtllm_moe
+        assert (
+            self.activation == "silu"
+        ), "Only silu is supported for flashinfer blockscale fp8 moe"
+        assert (
+            self.renormalize
+        ), "Renormalize is required for flashinfer blockscale fp8 moe"
+        assert (
+            self.num_fused_shared_experts == 0
+        ), "Fused shared experts are not supported for flashinfer blockscale fp8 moe"
+        a_q, a_sf = sglang_per_token_group_quant_fp8(hidden_states, self.block_shape[1])
+        # NOTE: scales of hidden states have to be transposed!
+        a_sf_t = a_sf.t().contiguous()
+        assert fi_fused_moe is not None
+        return fi_fused_moe.trtllm_fp8_block_scale_moe(
+            routing_logits=router_logits.to(torch.float32),
+            routing_bias=self.correction_bias.to(hidden_states.dtype),
+            hidden_states=a_q,
+            hidden_states_scale=a_sf_t,
+            gemm1_weights=self.w13_weight,
+            gemm1_weights_scale=self.w13_weight_scale_inv,
+            gemm2_weights=self.w2_weight,
+            gemm2_weights_scale=self.w2_weight_scale_inv,
+            num_experts=self.num_experts,
+            top_k=self.top_k,
+            n_group=self.num_expert_group,
+            topk_group=self.topk_group,
+            intermediate_size=self.w2_weight.shape[2],
+            local_expert_offset=self.start_expert_id,
+            local_num_experts=self.num_experts_per_partition,
+            routed_scaling_factor=self.routed_scaling_factor,
+            tile_tokens_dim=_get_tile_tokens_dim(
+                hidden_states.shape[0], self.top_k, self.num_experts
+            ),
+            routing_method_type=2,  # DeepSeek-styled routing method
+            use_shuffled_weight=False,
+        )
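Two details in `forward` are easy to miss: the activations are quantized per token group with group size `self.block_shape[1]`, and the resulting scale matrix must be transposed (and made contiguous) before it is handed to `trtllm_fp8_block_scale_moe`. A rough shape-only sketch with assumed sizes; the scaling formula is illustrative, not the exact kernel math:

```python
import torch

num_tokens, hidden_size, group_size = 16, 7168, 128  # illustrative sizes

x = torch.randn(num_tokens, hidden_size)
# Per-token-group quantization: one fp8 value per element, one scale per
# (token, group) pair -- mirroring what sglang_per_token_group_quant_fp8 returns.
groups = x.view(num_tokens, hidden_size // group_size, group_size)
scales = groups.abs().amax(dim=-1).clamp(min=1e-12) / 448.0  # [num_tokens, hidden_size // group_size]
x_q = (groups / scales.unsqueeze(-1)).to(torch.float8_e4m3fn).view(num_tokens, hidden_size)

# The kernel consumes the scales transposed (group dimension first) and contiguous.
scales_t = scales.t().contiguous()  # [hidden_size // group_size, num_tokens]
```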

 def get_moe_impl_class():
     if global_server_args_dict["enable_deepep_moe"]:
         return DeepEPMoE
-    if global_server_args_dict["enable_flashinfer_moe"]:
+    if global_server_args_dict["enable_flashinfer_cutlass_moe"]:
         # Must come before EPMoE because FusedMoE also supports enable_ep_moe
         return FusedMoE
+    if use_flashinfer_trtllm_moe:
+        # Must come before EPMoE because FusedMoE also supports enable_ep_moe
+        return FlashInferEPMoE
     if global_server_args_dict["enable_ep_moe"]:
         return EPMoE
     return FusedMoE
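Backend selection is therefore ordered: DeepEP first, then the FlashInfer CUTLASS path (served through FusedMoE), then the new FlashInfer TRTLLM path (FlashInferEPMoE), then plain EPMoE, with FusedMoE as the default. Because `use_flashinfer_trtllm_moe` is only true when both `enable_flashinfer_trtllm_moe` and `enable_ep_moe` are set, the FlashInferEPMoE branch must be checked before the bare `enable_ep_moe` branch, mirroring the existing comment on the CUTLASS path.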
@@ -75,7 +75,7 @@ class FusedMoE(torch.nn.Module):
         inplace: bool = True,
         no_combine: bool = False,
         routed_scaling_factor: Optional[float] = None,
-        enable_flashinfer_moe: Optional[bool] = False,
+        enable_flashinfer_cutlass_moe: Optional[bool] = False,
         enable_ep_moe: Optional[bool] = False,
     ):
         super().__init__()
@@ -92,16 +92,16 @@ class FusedMoE(torch.nn.Module):
         self.num_experts = num_experts
         self.expert_map = None

-        if enable_flashinfer_moe and quant_config is None:
+        if enable_flashinfer_cutlass_moe and quant_config is None:
             logger.warning("Disable flashinfer MoE when quantization config is None.")
-            enable_flashinfer_moe = False
+            enable_flashinfer_cutlass_moe = False
             enable_ep_moe = False

-        self.enable_flashinfer_moe = enable_flashinfer_moe
+        self.enable_flashinfer_cutlass_moe = enable_flashinfer_cutlass_moe
         if enable_ep_moe:
             assert (
-                self.enable_flashinfer_moe
-            ), "FusedMoE only supports EP with --enable-flashinfer-moe"
+                self.enable_flashinfer_cutlass_moe
+            ), "FusedMoE only supports EP with --enable-flashinfer-cutlass-moe"
             self.ep_size = self.tp_size
             self.ep_rank = self.tp_rank
             self.tp_size = 1
@@ -141,7 +141,9 @@ class FusedMoE(torch.nn.Module):
         else:
             self.quant_method = quant_config.get_quant_method(self, prefix)
             if self.quant_method.__class__.__name__ == "ModelOptNvFp4FusedMoEMethod":
-                self.quant_method.enable_flashinfer_moe = self.enable_flashinfer_moe
+                self.quant_method.enable_flashinfer_cutlass_moe = (
+                    self.enable_flashinfer_cutlass_moe
+                )
         assert self.quant_method is not None

         self.quant_config = quant_config
...
@@ -711,7 +711,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
                 " quantization. Please use Blackwell and"
                 " above."
             )
-        self.enable_flashinfer_moe = False
+        self.enable_flashinfer_cutlass_moe = False

     def create_weights(
         self,
@@ -865,7 +865,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
         w13_weight_scale_2 = layer.w13_weight_scale_2[:, 0]
         layer.w13_weight_scale_2 = Parameter(w13_weight_scale_2, requires_grad=False)

-        if self.enable_flashinfer_moe:
+        if self.enable_flashinfer_cutlass_moe:
             w13_input_scale = layer.w13_input_scale.max().to(torch.float32)
         else:
             w13_input_scale = layer.w13_input_scale.max(dim=1).values.to(torch.float32)
@@ -894,7 +894,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
         layer.w13_weight = Parameter(layer.w13_weight.data, requires_grad=False)

         # GEMM 2
-        if self.enable_flashinfer_moe:
+        if self.enable_flashinfer_cutlass_moe:
             w2_input_scale = layer.w2_input_scale.max().to(torch.float32)
         else:
             w2_input_scale = layer.w2_input_scale
@@ -934,7 +934,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
     @property
     def load_up_proj_weight_first(self) -> bool:
         # FlashInfer CUTLASS kernel assumes [Up, Gate] Proj as W13
-        return self.enable_flashinfer_moe
+        return self.enable_flashinfer_cutlass_moe

     def apply(
         self,
@@ -954,7 +954,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
     ) -> torch.Tensor:
         assert activation == "silu", "Only SiLU activation is supported."

-        if self.enable_flashinfer_moe:
+        if self.enable_flashinfer_cutlass_moe:
             assert (
                 not apply_router_weight_on_input
             ), "apply_router_weight_on_input is not supported for Flashinfer"
...
@@ -88,7 +88,8 @@ GLOBAL_SERVER_ARGS_KEYS = [
     "enable_deepep_moe",
     "deepep_mode",
     "enable_ep_moe",
-    "enable_flashinfer_moe",
+    "enable_flashinfer_cutlass_moe",
+    "enable_flashinfer_trtllm_moe",
     "enable_flashinfer_allreduce_fusion",
     "moe_dense_tp_size",
     "ep_dispatch_algorithm",
...
@@ -56,7 +56,11 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
-from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, get_moe_impl_class
+from sglang.srt.layers.moe.ep_moe.layer import (
+    DeepEPMoE,
+    get_moe_impl_class,
+    use_flashinfer_trtllm_moe,
+)
 from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
 from sglang.srt.layers.moe.topk import TopK
 from sglang.srt.layers.quantization import deep_gemm_wrapper
@@ -302,15 +306,19 @@ class DeepseekV2MoE(nn.Module):
             config=config, prefix=add_prefix("gate", prefix), is_nextn=is_nextn
         )

-        self.topk = TopK(
-            top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
-            renormalize=config.norm_topk_prob,
-            use_grouped_topk=True,
-            num_expert_group=config.n_group,
-            num_fused_shared_experts=self.num_fused_shared_experts,
-            topk_group=config.topk_group,
-            correction_bias=self.gate.e_score_correction_bias,
-            routed_scaling_factor=self.routed_scaling_factor,
+        self.topk = (
+            TopK(
+                top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
+                renormalize=config.norm_topk_prob,
+                use_grouped_topk=True,
+                num_expert_group=config.n_group,
+                num_fused_shared_experts=self.num_fused_shared_experts,
+                topk_group=config.topk_group,
+                correction_bias=self.gate.e_score_correction_bias,
+                routed_scaling_factor=self.routed_scaling_factor,
+            )
+            if not use_flashinfer_trtllm_moe
+            else None
         )

         self.experts = get_moe_impl_class()(
@@ -332,10 +340,22 @@ class DeepseekV2MoE(nn.Module):
             # Additional args for FusedMoE
             **(
                 dict(
-                    enable_flashinfer_moe=True,
+                    enable_flashinfer_cutlass_moe=True,
                     enable_ep_moe=global_server_args_dict["enable_ep_moe"],
                 )
-                if global_server_args_dict["enable_flashinfer_moe"]
+                if global_server_args_dict["enable_flashinfer_cutlass_moe"]
+                else {}
+            ),
+            **(
+                dict(
+                    renormalize=config.norm_topk_prob,
+                    use_grouped_topk=True,
+                    num_expert_group=config.n_group,
+                    num_fused_shared_experts=self.num_fused_shared_experts,
+                    topk_group=config.topk_group,
+                    correction_bias=self.gate.e_score_correction_bias,
+                )
+                if use_flashinfer_trtllm_moe
                 else {}
             ),
         )
@@ -455,10 +475,12 @@ class DeepseekV2MoE(nn.Module):
         with torch.cuda.stream(self.alt_stream):
             # router_logits: (num_tokens, n_experts)
             router_logits = self.gate(hidden_states)
-            topk_output = self.topk(hidden_states, router_logits)
-            final_hidden_states = self.experts(
-                hidden_states=hidden_states, topk_output=topk_output
-            )
+            kwargs = {"hidden_states": hidden_states}
+            if self.topk is not None:
+                kwargs["topk_output"] = self.topk(hidden_states, router_logits)
+            else:
+                kwargs["router_logits"] = router_logits
+            final_hidden_states = self.experts(**kwargs)
             if not _is_cuda:
                 final_hidden_states *= self.routed_scaling_factor
         current_stream.wait_stream(self.alt_stream)
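With the TRTLLM backend active, `self.topk` is `None`, so the raw router logits are passed straight to `FlashInferEPMoE.forward`, which performs the DeepSeek-style grouped top-k routing inside the fused kernel (`routing_method_type=2`). On every other path top-k is still computed in Python and handed to the experts as `topk_output`, exactly as before.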
@@ -478,10 +500,12 @@ class DeepseekV2MoE(nn.Module):
         shared_output = self._forward_shared_experts(hidden_states)
         # router_logits: (num_tokens, n_experts)
         router_logits = self.gate(hidden_states)
-        topk_output = self.topk(hidden_states, router_logits)
-        final_hidden_states = self.experts(
-            hidden_states=hidden_states, topk_output=topk_output
-        )
+        kwargs = {"hidden_states": hidden_states}
+        if self.topk is not None:
+            kwargs["topk_output"] = self.topk(hidden_states, router_logits)
+        else:
+            kwargs["router_logits"] = router_logits
+        final_hidden_states = self.experts(**kwargs)
         if not _is_cuda and not _use_aiter:
             # fused in biased_grouped_topk so we can skip here
             final_hidden_states *= self.routed_scaling_factor
...
@@ -147,10 +147,10 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
             # Additional args for FusedMoE
             **(
                 dict(
-                    enable_flashinfer_moe=True,
+                    enable_flashinfer_cutlass_moe=True,
                     enable_ep_moe=global_server_args_dict["enable_ep_moe"],
                 )
-                if global_server_args_dict["enable_flashinfer_moe"]
+                if global_server_args_dict["enable_flashinfer_cutlass_moe"]
                 else {}
             ),
         )
...
@@ -120,10 +120,10 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
             # Additional args for FusedMoE
             **(
                 dict(
-                    enable_flashinfer_moe=True,
+                    enable_flashinfer_cutlass_moe=True,
                     enable_ep_moe=global_server_args_dict["enable_ep_moe"],
                 )
-                if global_server_args_dict["enable_flashinfer_moe"]
+                if global_server_args_dict["enable_flashinfer_cutlass_moe"]
                 else {}
            ),
         )
...
@@ -169,7 +169,8 @@ class ServerArgs:
     ep_size: int = 1
     enable_ep_moe: bool = False
     enable_deepep_moe: bool = False
-    enable_flashinfer_moe: bool = False
+    enable_flashinfer_cutlass_moe: bool = False
+    enable_flashinfer_trtllm_moe: bool = False
     enable_flashinfer_allreduce_fusion: bool = False
     deepep_mode: Optional[Literal["auto", "normal", "low_latency"]] = "auto"
     ep_num_redundant_experts: int = 0
@@ -428,12 +429,16 @@ class ServerArgs:
         ), "Please enable dp attention when setting enable_dp_lm_head. "

         # MoE kernel
-        if self.enable_flashinfer_moe:
+        if self.enable_flashinfer_cutlass_moe:
             assert (
                 self.quantization == "modelopt_fp4"
             ), "modelopt_fp4 quantization is required for Flashinfer MOE"
             os.environ["TRTLLM_ENABLE_PDL"] = "1"

+        if self.enable_flashinfer_trtllm_moe:
+            assert self.enable_ep_moe, "EP MoE is required for Flashinfer TRTLLM MOE"
+            logger.warning(f"Flashinfer TRTLLM MoE is enabled.")
+
         # DeepEP MoE
         if self.enable_deepep_moe:
             if self.deepep_mode == "normal":
@@ -1293,10 +1298,15 @@ class ServerArgs:
             help="Enabling expert parallelism for moe. The ep size is equal to the tp size.",
         )
         parser.add_argument(
-            "--enable-flashinfer-moe",
+            "--enable-flashinfer-cutlass-moe",
             action="store_true",
             help="Enable FlashInfer CUTLASS MoE backend for modelopt_fp4 quant on Blackwell. Supports MoE-EP with --enable-ep-moe",
         )
+        parser.add_argument(
+            "--enable-flashinfer-trtllm-moe",
+            action="store_true",
+            help="Enable FlashInfer TRTLLM MoE backend on Blackwell. Supports BlockScale FP8 MoE-EP with --enable-ep-moe",
+        )
         parser.add_argument(
             "--enable-flashinfer-allreduce-fusion",
             action="store_true",
...
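Taken together, the new backend is opt-in: the server must be launched with both `--enable-flashinfer-trtllm-moe` and `--enable-ep-moe` (the module-level `use_flashinfer_trtllm_moe` flag requires both), it is aimed at block-scale FP8 checkpoints per the flag's help text, and `flashinfer.fused_moe` must be importable, otherwise the import fallback above reverts to the existing EPMoE/FusedMoE paths. Note also that the existing `--enable-flashinfer-moe` flag is renamed to `--enable-flashinfer-cutlass-moe` throughout, so launch scripts using the old spelling need updating.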