Unverified Commit 240f2636 authored by danielafrimi's avatar danielafrimi Committed by GitHub
Browse files

[Kernel] Support TRTLLM GEN NVFP4 MoE for non-512-aligned hidden dims via weight padding (#39510)


Signed-off-by: default avatarroot <root@lyris0017.lyris.clusters.nvidia.com>
Signed-off-by: default avatarDaniel Afrimi <dafrimi@nvidia.com>
Co-authored-by: default avatarroot <root@lyris0017.lyris.clusters.nvidia.com>
parent dc8df110
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
align_trtllm_fp4_moe_hidden_dim_for_fi,
)
def test_align_trtllm_fp4_moe_hidden_dim_noop():
    """Aligned weights must come back as the very same tensor objects.

    w13 has a packed last dim of 256, which the helper interprets as a
    hidden size of 512; that is already a multiple of the default alignment,
    so no padding (and no copy) is expected.
    """

    def _byte_ramp(*shape: int) -> torch.Tensor:
        # uint8 ramp reshaped to ``shape``; only object identity matters here.
        count = torch.Size(shape).numel()
        return torch.arange(count, dtype=torch.uint8).reshape(*shape)

    given = (
        _byte_ramp(2, 8, 256),  # w13
        _byte_ramp(2, 8, 32),  # w13_scale
        _byte_ramp(2, 512, 4),  # w2
        _byte_ramp(2, 512, 1),  # w2_scale
    )

    got_w13, got_w13_scale, got_w2, got_w2_scale, reported_hidden = (
        align_trtllm_fp4_moe_hidden_dim_for_fi(*given)
    )

    assert reported_hidden == 512
    returned = (got_w13, got_w13_scale, got_w2, got_w2_scale)
    for result, source in zip(returned, given):
        # Identity, not equality: the no-op path must not allocate.
        assert result is source
def test_align_trtllm_fp4_moe_hidden_dim_pads_to_256_multiple():
    """A hidden size of 2688 must be zero-padded up to 2816 (11 * 256).

    Checks, in order: the reported hidden size and output shapes, that the
    original data survives in the leading slice of each tensor, and that
    everything past the original extent is zero.
    """
    hidden_dim = 2688
    padded_hidden_dim = 2816

    def _byte_ramp(*shape: int) -> torch.Tensor:
        # uint8 ramp reshaped to ``shape`` so padding is distinguishable
        # from (mostly non-zero) payload.
        count = torch.Size(shape).numel()
        return torch.arange(count, dtype=torch.uint8).reshape(*shape)

    w13 = _byte_ramp(2, 12, hidden_dim // 2)
    w13_scale = _byte_ramp(2, 12, hidden_dim // 16)
    w2 = _byte_ramp(2, hidden_dim, 6)
    w2_scale = _byte_ramp(2, hidden_dim, 2)

    out_w13, out_w13_scale, out_w2, out_w2_scale, out_hidden_dim = (
        align_trtllm_fp4_moe_hidden_dim_for_fi(w13, w13_scale, w2, w2_scale)
    )

    assert out_hidden_dim == padded_hidden_dim

    # (input, output, padded axis, original extent, padded extent)
    cases = (
        (w13, out_w13, 2, hidden_dim // 2, padded_hidden_dim // 2),
        (w13_scale, out_w13_scale, 2, hidden_dim // 16, padded_hidden_dim // 16),
        (w2, out_w2, 1, hidden_dim, padded_hidden_dim),
        (w2_scale, out_w2_scale, 1, hidden_dim, padded_hidden_dim),
    )

    # Pass 1: only the padded axis may grow.
    for source, result, axis, _, new_extent in cases:
        want_shape = list(source.shape)
        want_shape[axis] = new_extent
        assert result.shape == tuple(want_shape)

    # Pass 2: the leading slice carries the original values unchanged.
    for source, result, axis, old_extent, _ in cases:
        torch.testing.assert_close(result.narrow(axis, 0, old_extent), source)

    # Pass 3: the appended region is all zeros.
    for _, result, axis, old_extent, _ in cases:
        tail = result.narrow(axis, old_extent, result.shape[axis] - old_extent)
        assert torch.count_nonzero(tail) == 0
...@@ -50,6 +50,9 @@ class TrtLlmNvFp4ExpertsBase: ...@@ -50,6 +50,9 @@ class TrtLlmNvFp4ExpertsBase:
moe_config.intermediate_size_per_partition moe_config.intermediate_size_per_partition
) )
self.hidden_dim = moe_config.hidden_dim self.hidden_dim = moe_config.hidden_dim
self.hidden_dim_unpadded = (
moe_config.hidden_dim_unpadded or moe_config.hidden_dim
)
self.local_num_experts = moe_config.num_local_experts self.local_num_experts = moe_config.num_local_experts
self.ep_rank = moe_config.moe_parallel_config.ep_rank self.ep_rank = moe_config.moe_parallel_config.ep_rank
...@@ -114,8 +117,12 @@ class TrtLlmNvFp4ExpertsBase: ...@@ -114,8 +117,12 @@ class TrtLlmNvFp4ExpertsBase:
@staticmethod @staticmethod
def _supports_shape(hidden_dim: int) -> bool: def _supports_shape(hidden_dim: int) -> bool:
"""Requires hidden dim to be multiple of 512.""" # Weights are zero-padded to 256-alignment at load time and the MoE
return hidden_dim % 512 == 0 # runner pads activations via _maybe_pad_hidden_states, so any
# hidden_dim is accepted.
# NOTE: non-256-aligned dims will trigger a warning log and may
# cause performance degradation due to activation slicing.
return True
@staticmethod @staticmethod
def activation_format() -> mk.FusedMoEActivationFormat: def activation_format() -> mk.FusedMoEActivationFormat:
...@@ -194,7 +201,7 @@ class TrtLlmNvFp4ExpertsModular(TrtLlmNvFp4ExpertsBase, mk.FusedMoEExpertsModula ...@@ -194,7 +201,7 @@ class TrtLlmNvFp4ExpertsModular(TrtLlmNvFp4ExpertsBase, mk.FusedMoEExpertsModula
import vllm.utils.flashinfer as fi_utils import vllm.utils.flashinfer as fi_utils
if fi_utils._is_fi_autotuning: if fi_utils._is_fi_autotuning:
return hidden_states return
# Invoke kernel. # Invoke kernel.
flashinfer.fused_moe.trtllm_fp4_block_scale_routed_moe( flashinfer.fused_moe.trtllm_fp4_block_scale_routed_moe(
...@@ -324,6 +331,8 @@ class TrtLlmNvFp4ExpertsMonolithic( ...@@ -324,6 +331,8 @@ class TrtLlmNvFp4ExpertsMonolithic(
e_score_correction_bias = e_score_correction_bias.to(torch.bfloat16) e_score_correction_bias = e_score_correction_bias.to(torch.bfloat16)
# Invoke kernel. # Invoke kernel.
# NOTE: Activation padding and output
# truncation are handled by the MoE runner.
return flashinfer.fused_moe.trtllm_fp4_block_scale_moe( return flashinfer.fused_moe.trtllm_fp4_block_scale_moe(
routing_logits=router_logits, routing_logits=router_logits,
routing_bias=e_score_correction_bias, routing_bias=e_score_correction_bias,
......
...@@ -10,6 +10,7 @@ import vllm.envs as envs ...@@ -10,6 +10,7 @@ import vllm.envs as envs
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
align_fp4_moe_weights_for_fi, align_fp4_moe_weights_for_fi,
align_trtllm_fp4_moe_hidden_dim_for_fi,
) )
from vllm.model_executor.layers.quantization.utils.nvfp4_utils import ( from vllm.model_executor.layers.quantization.utils.nvfp4_utils import (
swizzle_blockscale, swizzle_blockscale,
...@@ -341,6 +342,13 @@ def prepare_nvfp4_moe_layer_for_fi_or_cutlass( ...@@ -341,6 +342,13 @@ def prepare_nvfp4_moe_layer_for_fi_or_cutlass(
# Shuffle weights and scales for FI TRTLLM NVFP4 MoE kernels. # Shuffle weights and scales for FI TRTLLM NVFP4 MoE kernels.
if backend == NvFp4MoeBackend.FLASHINFER_TRTLLM: if backend == NvFp4MoeBackend.FLASHINFER_TRTLLM:
w13, w13_scale, w2, w2_scale, padded_hidden = (
align_trtllm_fp4_moe_hidden_dim_for_fi(w13, w13_scale, w2, w2_scale)
)
if layer.moe_config.hidden_dim_unpadded is None:
layer.moe_config.hidden_dim_unpadded = layer.moe_config.hidden_dim
layer.moe_config.hidden_dim = padded_hidden
# Align weights for FI NVFP4 MoE kernels. # Align weights for FI NVFP4 MoE kernels.
min_alignment = 16 if is_gated else 128 min_alignment = 16 if is_gated else 128
w13, w13_scale, w2, w2_scale, padded_intermediate = ( w13, w13_scale, w2, w2_scale, padded_intermediate = (
......
...@@ -265,6 +265,48 @@ def align_fp4_moe_weights_for_fi( ...@@ -265,6 +265,48 @@ def align_fp4_moe_weights_for_fi(
return padded_w13, padded_w13_scale, padded_w2, padded_w2_scale, padded_intermediate return padded_w13, padded_w13_scale, padded_w2, padded_w2_scale, padded_intermediate
def align_trtllm_fp4_moe_hidden_dim_for_fi(
    w13: torch.Tensor,
    w13_scale: torch.Tensor,
    w2: torch.Tensor,
    w2_scale: torch.Tensor,
    min_alignment: int = 256,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, int]:
    """Zero-pad the hidden dimension of TRTLLM NVFP4 MoE weights and scales.

    The hidden size is taken to be twice the last dimension of ``w13`` and is
    rounded up to a multiple of ``min_alignment``. If it is already aligned,
    the four input tensors are returned as-is (no copies). Otherwise a warning
    is logged once and new zero-initialised tensors are allocated, with the
    original contents copied into their leading slice.

    Args:
        w13: 3-D tensor ``(num_experts, gate_up_dim, packed_hidden)``; padded
            along its last dim to ``padded_hidden // 2``.
        w13_scale: 3-D scale tensor for ``w13``; padded along its last dim to
            ``padded_hidden // 16``. Its leading dims are presumably
            ``(num_experts, gate_up_dim)`` -- the output is allocated with
            those sizes taken from ``w13``.
        w2: 3-D tensor whose dim 1 is the hidden axis; padded along dim 1 to
            ``padded_hidden``, last dim preserved.
        w2_scale: 3-D scale tensor for ``w2``; padded along dim 1 to
            ``padded_hidden``, last dim preserved.
        min_alignment: Multiple to which the hidden size is rounded up.

    Returns:
        ``(w13, w13_scale, w2, w2_scale, hidden_size)`` where the tensors are
        either the inputs themselves or their padded replacements, and
        ``hidden_size`` is the (possibly padded) hidden size.
    """
    num_experts, gate_up_dim, packed_hidden_size = w13.shape
    hidden_size = 2 * packed_hidden_size
    padded_hidden_size = round_up(hidden_size, min_alignment)

    needs_padding = padded_hidden_size != hidden_size
    if not needs_padding:
        # Fast path: hand back the caller's tensors untouched.
        return w13, w13_scale, w2, w2_scale, hidden_size

    logger.warning_once(
        "Padding hidden size from %d to %d for TRTLLM NVFP4 MoE weights. "
        "This requires activation slicing at runtime and may cause "
        "performance degradation.",
        hidden_size,
        padded_hidden_size,
        scope="local",
    )

    def _zero_extend(
        src: torch.Tensor, shape: tuple[int, int, int], axis: int
    ) -> torch.Tensor:
        # Allocate zeros with src's dtype/device and copy src into the
        # leading ``src.shape[axis]`` entries along ``axis``.
        dst = src.new_zeros(shape)
        window = [slice(None)] * 3
        window[axis] = slice(None, src.shape[axis])
        dst[tuple(window)] = src
        return dst

    # w13 and its scales carry the hidden axis last.
    padded_w13 = _zero_extend(
        w13, (num_experts, gate_up_dim, padded_hidden_size // 2), axis=2
    )
    padded_w13_scale = _zero_extend(
        w13_scale, (num_experts, gate_up_dim, padded_hidden_size // 16), axis=2
    )

    # w2 and its scales carry the hidden axis in the middle.
    padded_w2 = _zero_extend(
        w2, (num_experts, padded_hidden_size, w2.shape[2]), axis=1
    )
    padded_w2_scale = _zero_extend(
        w2_scale, (num_experts, padded_hidden_size, w2_scale.shape[2]), axis=1
    )

    return padded_w13, padded_w13_scale, padded_w2, padded_w2_scale, padded_hidden_size
def align_fp8_moe_weights_for_fi( def align_fp8_moe_weights_for_fi(
w13: torch.Tensor, w2: torch.Tensor, is_act_and_mul: bool, min_alignment: int = 16 w13: torch.Tensor, w2: torch.Tensor, is_act_and_mul: bool, min_alignment: int = 16
) -> tuple[torch.Tensor, torch.Tensor, int]: ) -> tuple[torch.Tensor, torch.Tensor, int]:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment