Unverified Commit 4373df55 authored by Xiaoyu Zhang's avatar Xiaoyu Zhang Committed by GitHub
Browse files

add flashinfer mxfp4 (#8847)

parent c0e84297
......@@ -38,6 +38,7 @@ from sglang.srt.utils import (
is_flashinfer_available,
is_hip,
next_power_of_2,
round_up,
)
if is_flashinfer_available():
......@@ -146,7 +147,6 @@ class FusedMoE(torch.nn.Module):
self.layer_id = layer_id
self.top_k = top_k
self.hidden_size = hidden_size
self.num_experts = num_experts
self.num_fused_shared_experts = num_fused_shared_experts
self.expert_map_cpu = None
......@@ -206,6 +206,16 @@ class FusedMoE(torch.nn.Module):
assert self.quant_method is not None
self.quant_config = quant_config
if (
self.quant_config is not None
and self.quant_config.get_name() == "mxfp4"
and (
get_bool_env_var("SGLANG_USE_FLASHINFER_MXFP4_MOE")
or get_bool_env_var("SGLANG_USE_FLASHINFER_MXFP4_BF16_MOE")
)
):
hidden_size = round_up(hidden_size, 256)
self.hidden_size = hidden_size
self.quant_method.create_weights(
layer=self,
num_experts=self.num_local_experts,
......@@ -784,6 +794,14 @@ class FusedMoE(torch.nn.Module):
)
def forward(self, hidden_states: torch.Tensor, topk_output: StandardTopKOutput):
origin_hidden_states_dim = hidden_states.shape[-1]
if self.hidden_size != origin_hidden_states_dim:
hidden_states = torch.nn.functional.pad(
hidden_states,
(0, self.hidden_size - origin_hidden_states_dim),
mode="constant",
value=0.0,
)
assert self.quant_method is not None
if self.moe_ep_size > 1 and not self.enable_flashinfer_cutlass_moe:
......@@ -829,7 +847,7 @@ class FusedMoE(torch.nn.Module):
if self.reduce_results and (self.moe_tp_size > 1 or self.moe_ep_size > 1):
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
return final_hidden_states
return final_hidden_states[..., :origin_hidden_states_dim].contiguous()
@classmethod
def make_expert_params_mapping(
......
......@@ -21,6 +21,7 @@ from sglang.srt.layers.quantization.base_config import (
from sglang.srt.layers.quantization.utils import is_layer_skipped
from sglang.srt.utils import (
direct_register_custom_op,
get_bool_env_var,
is_cuda,
is_flashinfer_available,
is_hip,
......@@ -31,6 +32,12 @@ from sglang.srt.utils import (
has_triton_kernels = importlib.util.find_spec("triton_kernels") is not None
# Environment variables for FlashInfer MXFP4 MoE backend
USE_FLASHINFER_MXFP4_MOE = get_bool_env_var("SGLANG_USE_FLASHINFER_MXFP4_MOE", "false")
USE_FLASHINFER_MXFP4_BF16_MOE = get_bool_env_var(
"SGLANG_USE_FLASHINFER_MXFP4_BF16_MOE", "false"
)
if is_flashinfer_available():
# from flashinfer.fused_moe import cutlass_fused_moe
from flashinfer import (
......@@ -228,16 +235,28 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
self.num_experts = num_experts
weight_dtype = torch.uint8
scale_dtype = torch.uint8
intermediate_size *= 2
mxfp4_block = 32
self.intermediate_size = intermediate_size
# pad the intermediate size to be a multiple of 2 * mxfp4_block
# for to hold non-uniform sharded tensor as well as swizzling
if USE_FLASHINFER_MXFP4_MOE or USE_FLASHINFER_MXFP4_BF16_MOE:
intermediate_size_per_partition_after_pad = round_up(intermediate_size, 256)
hidden_size = round_up(hidden_size, 256)
elif is_hip():
intermediate_size_per_partition_after_pad = round_up(intermediate_size, 128)
else:
intermediate_size_per_partition_after_pad = round_up(intermediate_size, 64)
self.intermediate_size = intermediate_size_per_partition_after_pad
self.hidden_size = hidden_size
# Fused gate_up_proj (column parallel)
w13_weight = torch.nn.Parameter(
torch.zeros(
num_experts, 2 * intermediate_size, hidden_size // 2, dtype=weight_dtype
num_experts,
2 * intermediate_size_per_partition_after_pad,
hidden_size // 2,
dtype=weight_dtype,
),
requires_grad=False,
)
......@@ -247,7 +266,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
w13_weight_scale = torch.nn.Parameter(
torch.zeros(
num_experts,
2 * intermediate_size,
2 * intermediate_size_per_partition_after_pad,
hidden_size // mxfp4_block,
dtype=scale_dtype,
),
......@@ -257,7 +276,11 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
w13_weight_bias = torch.nn.Parameter(
torch.zeros(num_experts, 2 * intermediate_size, dtype=torch.bfloat16),
torch.zeros(
num_experts,
2 * intermediate_size_per_partition_after_pad,
dtype=torch.bfloat16,
),
requires_grad=False,
)
layer.register_parameter("w13_weight_bias", w13_weight_bias)
......@@ -266,7 +289,10 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
# down_proj (row parallel)
w2_weight = torch.nn.Parameter(
torch.zeros(
num_experts, hidden_size, intermediate_size // 2, dtype=weight_dtype
num_experts,
hidden_size,
intermediate_size_per_partition_after_pad // 2,
dtype=weight_dtype,
),
requires_grad=False,
)
......@@ -277,7 +303,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
torch.zeros(
num_experts,
hidden_size,
intermediate_size // mxfp4_block,
intermediate_size_per_partition_after_pad // mxfp4_block,
dtype=scale_dtype,
),
requires_grad=False,
......@@ -293,6 +319,158 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
set_weight_attrs(w2_weight_bias, extra_weight_attrs)
def process_weights_after_loading(self, layer):
if USE_FLASHINFER_MXFP4_MOE or USE_FLASHINFER_MXFP4_BF16_MOE:
logger.info(
"Shuffling MoE weights for FlashInfer, it might take a while..."
)
layer.gemm1_alpha = Parameter(
torch.tensor([1.702] * self.num_experts, dtype=torch.float32).cuda(),
requires_grad=False,
)
layer.gemm1_beta = Parameter(
torch.tensor([1.0] * self.num_experts, dtype=torch.float32).cuda(),
requires_grad=False,
)
layer.gemm1_clamp_limit = Parameter(
torch.tensor([7.0] * self.num_experts, dtype=torch.float32).cuda(),
requires_grad=False,
)
sf_block_size = 32 # mxfp4 block size
assert (
layer.w13_weight.dim() == 3
and layer.w13_weight.shape[0] == self.num_experts
and layer.w13_weight.shape[1] == self.intermediate_size * 2
and layer.w13_weight.shape[2] == self.hidden_size // 2
)
assert (
layer.w13_weight_scale.dim() == 3
and layer.w13_weight_scale.shape[0] == self.num_experts
and layer.w13_weight_scale.shape[1] == self.intermediate_size * 2
and layer.w13_weight_scale.shape[2] == self.hidden_size // sf_block_size
)
assert (
layer.w2_weight.dim() == 3
and layer.w2_weight.shape[0] == self.num_experts
and layer.w2_weight.shape[1] == self.hidden_size
and layer.w2_weight.shape[2] == self.intermediate_size // 2
)
assert (
layer.w2_weight_scale.dim() == 3
and layer.w2_weight_scale.shape[1] == self.hidden_size
and layer.w2_weight_scale.shape[2]
== self.intermediate_size // sf_block_size
)
assert (
layer.w13_weight_bias.dim() == 2
and layer.w13_weight_bias.shape[0] == self.num_experts
and layer.w13_weight_bias.shape[1] == self.intermediate_size * 2
)
assert (
layer.w2_weight_bias.dim() == 2
and layer.w2_weight_bias.shape[0] == self.num_experts
and layer.w2_weight_bias.shape[1] == self.hidden_size
)
w13_weight_scale = layer.w13_weight_scale.data
w2_weight_scale = layer.w2_weight_scale.data
w13_weight = layer.w13_weight.data
w2_weight = layer.w2_weight.data
w13_bias = layer.w13_weight_bias.data.to(torch.float32)
w2_bias = layer.w2_weight_bias.data.to(torch.float32)
# Swap w1 and w3 as the definition of
# swiglu is different in the trtllm-gen
def swap_every_two_rows(x, axis=-1):
shape = x.shape
if axis < 0:
axis = len(shape) + axis
# Create a new shape with pairs swapped along specified axis
new_shape = list(shape)
new_shape[axis] = shape[axis] // 2
new_shape.insert(axis + 1, 2)
# Reshape to expose pairs, swap them, and reshape back
x = x.reshape(*new_shape)
x = x.flip(axis + 1)
new_shape = list(shape)
return x.reshape(*new_shape)
w13_weight_scale = swap_every_two_rows(w13_weight_scale, -2)
w13_weight = swap_every_two_rows(w13_weight, -2)
w13_bias = swap_every_two_rows(w13_bias, -1)
# Shuffle weights and scaling factors for transposed mma output
gemm1_weights_mxfp4_shuffled = []
gemm1_scales_mxfp4_shuffled = []
gemm2_weights_mxfp4_shuffled = []
gemm2_scales_mxfp4_shuffled = []
gemm1_bias_shuffled = []
gemm2_bias_shuffled = []
epilogue_tile_m = 128 # FIXME: this depends on the kernel internals
for i in range(self.num_experts):
gemm1_weights_mxfp4_shuffled.append(
shuffle_matrix_a(w13_weight[i].view(torch.uint8), epilogue_tile_m)
)
gemm1_scales_mxfp4_shuffled.append(
shuffle_matrix_sf_a(
w13_weight_scale[i].view(torch.uint8), epilogue_tile_m
)
)
gemm1_bias_shuffled.append(
shuffle_matrix_a(
w13_bias[i].clone().reshape(-1, 1), epilogue_tile_m
)
)
gemm2_weights_mxfp4_shuffled.append(
shuffle_matrix_a(w2_weight[i].view(torch.uint8), epilogue_tile_m)
)
gemm2_scales_mxfp4_shuffled.append(
shuffle_matrix_sf_a(
w2_weight_scale[i].view(torch.uint8), epilogue_tile_m
)
)
gemm2_bias_shuffled.append(
shuffle_matrix_a(w2_bias[i].clone().reshape(-1, 1), epilogue_tile_m)
)
w13_weight = torch.stack(gemm1_weights_mxfp4_shuffled)
w13_weight_scale = (
torch.stack(gemm1_scales_mxfp4_shuffled)
.reshape(
self.num_experts,
2 * self.intermediate_size,
self.hidden_size // sf_block_size,
)
.view(torch.float8_e4m3fn)
)
w2_weight = torch.stack(gemm2_weights_mxfp4_shuffled)
w2_weight_scale = (
torch.stack(gemm2_scales_mxfp4_shuffled)
.reshape(
self.num_experts,
self.hidden_size,
self.intermediate_size // sf_block_size,
)
.view(torch.float8_e4m3fn)
)
layer.w13_weight = Parameter(w13_weight, requires_grad=False)
layer.w13_weight_scale = Parameter(w13_weight_scale, requires_grad=False)
layer.w2_weight = Parameter(w2_weight, requires_grad=False)
layer.w2_weight_scale = Parameter(w2_weight_scale, requires_grad=False)
layer.w13_weight_bias = Parameter(
torch.stack(gemm1_bias_shuffled).reshape(self.num_experts, -1),
requires_grad=False,
)
layer.w2_weight_bias = Parameter(
torch.stack(gemm2_bias_shuffled).reshape(self.num_experts, -1),
requires_grad=False,
)
return
from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig
......@@ -366,22 +544,21 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
activation_alpha: Optional[float] = None,
swiglu_limit: Optional[float] = None,
) -> torch.Tensor:
# avoid import error when triton_kernel is not installed
# from vllm.model_executor.layers.fused_moe.triton_kernels_moe import (
# triton_kernel_moe_forward)
"""
if (envs.VLLM_USE_FLASHINFER_MXFP4_MOE
or envs.VLLM_USE_FLASHINFER_MXFP4_BF16_MOE):
assert not self.moe.use_ep, (
"EP is not supported for flashinfer mxfp4 moe backend yet.")
if envs.VLLM_USE_FLASHINFER_MXFP4_BF16_MOE:
if USE_FLASHINFER_MXFP4_MOE or USE_FLASHINFER_MXFP4_BF16_MOE:
# When USE_FLASHINFER_MXFP4_BF16_MOE is enabled, we don't need to quantize the input,
# TRT-LLM automatically handles quantization in the kernel implementation and pipelines it with GEMM operations,
# which can theoretically improve performance
if USE_FLASHINFER_MXFP4_BF16_MOE:
assert x.dtype == torch.bfloat16
x_quant = x
x_scale = None
else:
x_quant, x_scale = mxfp8_quantize(x, False) # to mxfp8
x_scale = x_scale.view(torch.float8_e4m3fn).reshape(-1)
topk_weights, topk_ids, router_logits = topk_output
top_k = topk_weights.shape[-1]
trtllm_gen_output = trtllm_fp4_block_scale_moe(
router_logits.to(torch.bfloat16),
None, # routing_bias
......@@ -412,7 +589,6 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
True, # do finalize
)[0]
return trtllm_gen_output
"""
if self.use_triton_kernels:
if self.with_bias:
......
......@@ -464,7 +464,21 @@ class ServerArgs:
model_arch = self.get_hf_config().architectures[0]
if model_arch in ["GptOssForCausalLM"]:
self.attention_backend = "triton"
self.enable_triton_kernel_moe = True
# Check if FlashInfer MXFP4 MoE is enabled
from sglang.srt.utils import get_bool_env_var
USE_FLASHINFER_MXFP4_MOE = get_bool_env_var(
"SGLANG_USE_FLASHINFER_MXFP4_MOE", "false"
)
USE_FLASHINFER_MXFP4_BF16_MOE = get_bool_env_var(
"SGLANG_USE_FLASHINFER_MXFP4_BF16_MOE", "false"
)
# Only enable Triton kernel MoE if FlashInfer is not enabled
if not (USE_FLASHINFER_MXFP4_MOE or USE_FLASHINFER_MXFP4_BF16_MOE):
self.enable_triton_kernel_moe = True
self.disable_hybrid_swa_memory = True
quantization_config = getattr(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment