Unverified commit aa4c66b5 authored by Kaixi Hou, committed by GitHub

[NVIDIA] Enable Flashinfer MoE blockscale fp8 backend for TP MoE (#8450)


Co-authored-by: kushanam <42385577+kushanam@users.noreply.github.com>
parent 39decec1
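For context, a hedged usage sketch (not part of the diff; the flag spelling is assumed from the ServerArgs change below): after this commit the flashinfer TRTLLM blockscale fp8 MoE backend can be enabled for a plain tensor-parallel deployment, without also passing --enable-ep-moe, e.g.

    python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3 --tp 8 --enable-flashinfer-trtllm-moe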
@@ -25,14 +25,22 @@ from sglang.srt.layers.moe.ep_moe.kernels import (
silu_and_mul_triton_kernel,
tma_align_input_scale,
)
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
from sglang.srt.layers.moe.fused_moe_triton.layer import (
FlashInferFusedMoE,
FusedMoE,
should_use_flashinfer_trtllm_moe,
)
from sglang.srt.layers.moe.topk import TopKOutput
from sglang.srt.layers.quantization import deep_gemm_wrapper
from sglang.srt.layers.quantization.base_config import (
QuantizationConfig,
QuantizeMethodBase,
)
from sglang.srt.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod
from sglang.srt.layers.quantization.fp8 import (
Fp8Config,
Fp8MoEMethod,
get_tile_tokens_dim,
)
from sglang.srt.layers.quantization.fp8_kernel import (
is_fp8_fnuz,
sglang_per_token_group_quant_fp8,
@@ -49,7 +57,6 @@ from sglang.srt.utils import (
get_bool_env_var,
is_hip,
is_npu,
next_power_of_2,
)
if TYPE_CHECKING:
@@ -63,10 +70,7 @@ _is_hip = is_hip()
_is_npu = is_npu()
_is_fp8_fnuz = is_fp8_fnuz()
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
use_flashinfer_trtllm_moe = (
global_server_args_dict["enable_flashinfer_trtllm_moe"]
and global_server_args_dict["enable_ep_moe"]
)
if not (_is_npu or _is_hip):
from sgl_kernel import silu_and_mul
@@ -76,26 +80,9 @@ if _use_aiter:
from aiter.fused_moe import fused_moe
from aiter.ops.shuffle import shuffle_weight
if use_flashinfer_trtllm_moe:
try:
import flashinfer.fused_moe as fi_fused_moe
except ImportError:
fi_fused_moe = None
use_flashinfer_trtllm_moe = False
logger = logging.getLogger(__name__)
def _get_tile_tokens_dim(num_tokens, top_k, num_experts):
# Guess tokens per expert assuming perfect expert distribution first.
num_tokens_per_expert = (num_tokens * top_k) // num_experts
# And pad the number to the next power of 2.
tile_tokens_dim = next_power_of_2(num_tokens_per_expert)
# Cap to 8-64 tokens per CTA tile as it's the range supported by the kernel.
tile_tokens_dim = min(max(tile_tokens_dim, 8), 64)
return tile_tokens_dim
class EPMoE(FusedMoE):
"""
MoE Expert Parallel Impl
@@ -731,10 +718,10 @@ class FlashInferEPMoE(EPMoE):
self.num_expert_group = num_expert_group
self.topk_group = topk_group
self.correction_bias = correction_bias
self.use_flashinfer_trtllm_moe = use_flashinfer_trtllm_moe
self.use_flashinfer_trtllm_moe = should_use_flashinfer_trtllm_moe()
def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
assert use_flashinfer_trtllm_moe
assert self.use_flashinfer_trtllm_moe
assert (
self.activation == "silu"
), "Only silu is supported for flashinfer blockscale fp8 moe"
@@ -747,8 +734,9 @@
a_q, a_sf = sglang_per_token_group_quant_fp8(hidden_states, self.block_shape[1])
# NOTE: scales of hidden states have to be transposed!
a_sf_t = a_sf.t().contiguous()
assert fi_fused_moe is not None
return fi_fused_moe.trtllm_fp8_block_scale_moe(
from flashinfer.fused_moe import trtllm_fp8_block_scale_moe
return trtllm_fp8_block_scale_moe(
routing_logits=router_logits.to(torch.float32),
routing_bias=self.correction_bias.to(hidden_states.dtype),
hidden_states=a_q,
@@ -765,7 +753,7 @@
local_expert_offset=self.start_expert_id,
local_num_experts=self.num_local_experts,
routed_scaling_factor=self.routed_scaling_factor,
tile_tokens_dim=_get_tile_tokens_dim(
tile_tokens_dim=get_tile_tokens_dim(
hidden_states.shape[0], self.top_k, self.num_experts
),
routing_method_type=2, # DeepSeek-styled routing method
@@ -779,9 +767,6 @@ def get_moe_impl_class():
if global_server_args_dict["enable_flashinfer_cutlass_moe"]:
# Must come before EPMoE because FusedMoE also supports enable_ep_moe
return FusedMoE
if use_flashinfer_trtllm_moe:
# Must come before EPMoE because FusedMoE also supports enable_ep_moe
return FlashInferEPMoE
if global_server_args_dict["enable_ep_moe"]:
return EPMoE
return FusedMoE
return FlashInferEPMoE if should_use_flashinfer_trtllm_moe() else EPMoE
return FlashInferFusedMoE if should_use_flashinfer_trtllm_moe() else FusedMoE
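As an aside (not part of the diff), a minimal sketch of the dispatch get_moe_impl_class performs after this change; class names are returned as strings purely for illustration, and the DeepEP branch above this hunk is omitted:

def pick_moe_impl(args: dict, trtllm_ok: bool) -> str:
    # Simplified mirror of get_moe_impl_class() after this commit.
    if args["enable_flashinfer_cutlass_moe"]:
        return "FusedMoE"
    if args["enable_ep_moe"]:
        return "FlashInferEPMoE" if trtllm_ok else "EPMoE"
    # New here: the TP-only fallthrough can also pick the flashinfer path.
    return "FlashInferFusedMoE" if trtllm_ok else "FusedMoE"

args = {"enable_flashinfer_cutlass_moe": False, "enable_ep_moe": False}
assert pick_moe_impl(args, trtllm_ok=True) == "FlashInferFusedMoE"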
# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/model_executor/layers/fused_moe/layer.py
import importlib.util
import logging
from enum import Enum
from functools import lru_cache
from typing import List, Optional, Tuple
import torch
from packaging import version as pkg_version
from sglang.srt.distributed import (
get_moe_expert_parallel_rank,
@@ -33,6 +36,15 @@ _is_cpu = is_cpu()
logger = logging.getLogger(__name__)
@lru_cache(maxsize=1)
def should_use_flashinfer_trtllm_moe():
return global_server_args_dict["enable_flashinfer_trtllm_moe"] and (
not importlib.util.find_spec("flashinfer")
or pkg_version.parse(__import__("flashinfer").__version__)
>= pkg_version.parse("0.2.9rc1")
)
class FusedMoeWeightScaleSupported(Enum):
TENSOR = "tensor"
CHANNEL = "channel"
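A small sketch (hypothetical helper, not from this commit) of how the should_use_flashinfer_trtllm_moe gate added above behaves; because of lru_cache(maxsize=1), the server flag and flashinfer version are evaluated once per process and the decision is reused afterwards:

from functools import lru_cache

from packaging import version as pkg_version

@lru_cache(maxsize=1)
def _trtllm_gate_example(flag_enabled: bool, flashinfer_version: str) -> bool:
    # Same idea as the real check: the flag must be set and flashinfer must be
    # at least 0.2.9rc1 for the TRTLLM blockscale fp8 kernels to be used.
    return flag_enabled and pkg_version.parse(
        flashinfer_version
    ) >= pkg_version.parse("0.2.9rc1")

assert _trtllm_gate_example(True, "0.2.9rc1") is True
assert _trtllm_gate_example(True, "0.2.8") is False
# In tests, the real function's cached result can be reset with
# should_use_flashinfer_trtllm_moe.cache_clear().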
@@ -455,7 +467,7 @@ class FusedMoE(torch.nn.Module):
)
# Flashinfer assumes w31 format for w13_weight. Same for the scales.
if getattr(self, "use_flashinfer_trtllm_moe", False):
if should_use_flashinfer_trtllm_moe():
shard_id = {"w1": "w3", "w3": "w1", "w2": "w2"}[shard_id]
WEIGHT_SCALE_SUPPORTED = [e.value for e in FusedMoeWeightScaleSupported]
@@ -687,3 +699,44 @@ class FusedMoE(torch.nn.Module):
for expert_id in range(num_experts)
for shard_id in ["w1", "w2", "w3"]
]
class FlashInferFusedMoE(FusedMoE):
def __init__(self, *args, **kwargs):
renormalize = kwargs.pop("renormalize", True)
num_fused_shared_experts = kwargs.pop("num_fused_shared_experts", 0)
use_grouped_topk = kwargs.pop("use_grouped_topk", False)
num_expert_group = kwargs.pop("num_expert_group", None)
topk_group = kwargs.pop("topk_group", None)
correction_bias = kwargs.pop("correction_bias", None)
super().__init__(*args, **kwargs)
self.renormalize = renormalize
self.num_fused_shared_experts = num_fused_shared_experts
self.use_grouped_topk = use_grouped_topk
if self.use_grouped_topk:
assert num_expert_group is not None and topk_group is not None
self.num_expert_group = num_expert_group
self.topk_group = topk_group
self.correction_bias = correction_bias
def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
assert self.quant_method is not None
assert (
self.renormalize
), "Renormalize is required for flashinfer blockscale fp8 moe"
assert (
self.num_fused_shared_experts == 0
), "Fused shared experts are not supported for flashinfer blockscale fp8 moe"
# Matrix multiply.
final_hidden_states = self.quant_method.apply_with_router_logits(
layer=self,
x=hidden_states,
router_logits=router_logits,
activation=self.activation,
routed_scaling_factor=self.routed_scaling_factor,
)
if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1):
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
return final_hidden_states
@@ -72,6 +72,7 @@ from sglang.srt.utils import (
is_hip,
is_npu,
log_info_on_rank0,
next_power_of_2,
print_warning_once,
set_weight_attrs,
use_intel_amx_backend,
@@ -490,6 +491,16 @@ class Fp8LinearMethod(LinearMethodBase):
)
def get_tile_tokens_dim(num_tokens, top_k, num_experts):
# Guess tokens per expert assuming perfect expert distribution first.
num_tokens_per_expert = (num_tokens * top_k) // num_experts
# And pad the number to the next power of 2.
tile_tokens_dim = next_power_of_2(num_tokens_per_expert)
# Cap to 8-64 tokens per CTA tile as it's the range supported by the kernel.
tile_tokens_dim = min(max(tile_tokens_dim, 8), 64)
return tile_tokens_dim
class Fp8MoEMethod(FusedMoEMethodBase):
"""MoE method for FP8.
Supports loading FP8 checkpoints with static weight scale and
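A quick worked example of the get_tile_tokens_dim helper added above (illustrative values; next_power_of_2 rounds up to the nearest power of two):

# 48 tokens, top_k=8, 256 experts: (48 * 8) // 256 = 1 token per expert,
# next power of two is 1, clamped into [8, 64] -> tile of 8.
assert get_tile_tokens_dim(48, 8, 256) == 8
# 4096 tokens, top_k=8, 256 experts: (4096 * 8) // 256 = 128 -> capped at 64.
assert get_tile_tokens_dim(4096, 8, 256) == 64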
@@ -1076,6 +1087,47 @@ class Fp8MoEMethod(FusedMoEMethodBase):
routed_scaling_factor=routed_scaling_factor,
)
def apply_with_router_logits(
self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
*,
activation: str = "silu",
routed_scaling_factor: Optional[float] = None,
) -> torch.Tensor:
assert (
activation == "silu"
), "Only silu is supported for flashinfer blockscale fp8 moe"
a_q, a_sf = per_token_group_quant_fp8(x, self.quant_config.weight_block_size[1])
# NOTE: scales of hidden states have to be transposed!
a_sf_t = a_sf.t().contiguous()
from flashinfer.fused_moe import trtllm_fp8_block_scale_moe
return trtllm_fp8_block_scale_moe(
routing_logits=router_logits.to(torch.float32),
routing_bias=layer.correction_bias.to(x.dtype),
hidden_states=a_q,
hidden_states_scale=a_sf_t,
gemm1_weights=layer.w13_weight,
gemm1_weights_scale=layer.w13_weight_scale_inv,
gemm2_weights=layer.w2_weight,
gemm2_weights_scale=layer.w2_weight_scale_inv,
num_experts=layer.num_experts,
top_k=layer.top_k,
n_group=layer.num_expert_group,
topk_group=layer.topk_group,
intermediate_size=layer.w2_weight.shape[2],
local_expert_offset=layer.moe_ep_rank * layer.num_local_experts,
local_num_experts=layer.num_local_experts,
routed_scaling_factor=routed_scaling_factor,
tile_tokens_dim=get_tile_tokens_dim(
x.shape[0], layer.top_k, layer.num_experts
),
routing_method_type=2, # DeepSeek-styled routing method
use_shuffled_weight=False,
)
def maybe_apply_hip_fused_experts(
self,
layer: torch.nn.Module,
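To make the "scales of hidden states have to be transposed" note in apply_with_router_logits concrete, a shape-only sketch in plain torch (simplified; the real path uses per_token_group_quant_fp8 and the quant config's 128-wide weight block):

import torch

num_tokens, hidden, group = 4, 7168, 128
x = torch.randn(num_tokens, hidden)

# One scale per (token, 128-wide group), roughly what per-token-group fp8
# quantization produces for the activations (448.0 is the e4m3 max value).
a_sf = x.view(num_tokens, hidden // group, group).abs().amax(dim=-1) / 448.0
print(a_sf.shape)    # torch.Size([4, 56])

# trtllm_fp8_block_scale_moe expects the activation scales transposed and contiguous.
a_sf_t = a_sf.t().contiguous()
print(a_sf_t.shape)  # torch.Size([56, 4])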
@@ -59,7 +59,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe.ep_moe.layer import (
DeepEPMoE,
get_moe_impl_class,
use_flashinfer_trtllm_moe,
should_use_flashinfer_trtllm_moe,
)
from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
from sglang.srt.layers.moe.topk import TopK
@@ -317,7 +317,7 @@ class DeepseekV2MoE(nn.Module):
correction_bias=self.gate.e_score_correction_bias,
routed_scaling_factor=self.routed_scaling_factor,
)
if not use_flashinfer_trtllm_moe
if not should_use_flashinfer_trtllm_moe()
else None
)
@@ -352,11 +352,10 @@
renormalize=config.norm_topk_prob,
use_grouped_topk=True,
num_expert_group=config.n_group,
num_fused_shared_experts=self.num_fused_shared_experts,
topk_group=config.topk_group,
correction_bias=self.gate.e_score_correction_bias,
)
if use_flashinfer_trtllm_moe
if should_use_flashinfer_trtllm_moe()
else {}
),
)
@@ -52,7 +52,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe.ep_moe.layer import (
DeepEPMoE,
get_moe_impl_class,
use_flashinfer_trtllm_moe,
should_use_flashinfer_trtllm_moe,
)
from sglang.srt.layers.moe.topk import TopK
from sglang.srt.layers.quantization.base_config import QuantizationConfig
@@ -426,7 +426,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
correction_bias=self.gate.e_score_correction_bias,
routed_scaling_factor=self.routed_scaling_factor,
)
if not use_flashinfer_trtllm_moe
if not should_use_flashinfer_trtllm_moe()
else None
)
@@ -465,7 +465,7 @@
topk_group=config.topk_group,
correction_bias=self.gate.e_score_correction_bias,
)
if use_flashinfer_trtllm_moe
if should_use_flashinfer_trtllm_moe()
else {}
),
)
@@ -460,10 +460,6 @@ class ServerArgs:
f"Flashinfer cutlass MoE and EP MoE are enabled. The expert parallel size is adjusted to be the same as the tensor parallel size[{self.tp_size}]."
)
if self.enable_flashinfer_trtllm_moe:
assert self.enable_ep_moe, "EP MoE is required for Flashinfer TRTLLM MOE"
logger.warning(f"Flashinfer TRTLLM MoE is enabled.")
# DeepEP MoE
if self.enable_deepep_moe:
if self.deepep_mode == "normal":