Unverified Commit f120bd42 authored by amitz-nv's avatar amitz-nv Committed by GitHub
Browse files

[Kernel] Support Flashinfer trtllm fused MoE non gated FP8 & NVFP4 (#33506)


Signed-off-by: default avataramitz-nv <203509407+amitz-nv@users.noreply.github.com>
parent fac4e969
...@@ -71,7 +71,8 @@ def quant_fp8_per_tensor_batches(a): ...@@ -71,7 +71,8 @@ def quant_fp8_per_tensor_batches(a):
for i in range(num_batches): for i in range(num_batches):
a_fp8, a_global_sf = input_to_float8(a[i]) a_fp8, a_global_sf = input_to_float8(a[i])
a_global_sf = 1.0 / a_global_sf if a_global_sf.numel() == 1:
a_global_sf = a_global_sf.view(1, 1)
a_quant.append(a_fp8) a_quant.append(a_fp8)
a_scales.append(a_global_sf) a_scales.append(a_global_sf)
...@@ -81,6 +82,20 @@ def quant_fp8_per_tensor_batches(a): ...@@ -81,6 +82,20 @@ def quant_fp8_per_tensor_batches(a):
return result_a_quant, result_a_scales return result_a_quant, result_a_scales
def check_accuracy(ref_output, actual_output, atol=0.1, rtol=0.85, percent=0.925):
close = torch.isclose(ref_output, actual_output, atol=atol, rtol=rtol)
match_ratio = close.float().mean()
assert match_ratio >= percent, (
f"Match ratio {match_ratio:.4f} is below the threshold {percent:.4f}"
)
mismatch_percent = 1.0 - match_ratio.item()
assert mismatch_percent <= 1 - percent, (
f"Mismatch percentage {mismatch_percent:.4f} is above the threshold "
f"{1 - percent:.4f}"
)
@dataclass @dataclass
class TestData: class TestData:
hidden_states: torch.Tensor hidden_states: torch.Tensor
...@@ -104,14 +119,16 @@ class TestData: ...@@ -104,14 +119,16 @@ class TestData:
is_gated = activation.is_gated is_gated = activation.is_gated
hidden_states = torch.randn((m, k), device="cuda", dtype=torch.bfloat16) / 10 hidden_states = torch.randn((m, k), device="cuda", dtype=torch.bfloat16) / 10
w13 = torch.randn( w13 = (
(e, (2 * n) if is_gated else n, k), device="cuda", dtype=torch.bfloat16 torch.randn(
(e, (2 * n) if is_gated else n, k), device="cuda", dtype=torch.bfloat16
)
/ 10
) )
w2 = torch.randn((e, k, n), device="cuda", dtype=torch.bfloat16) w2 = torch.randn((e, k, n), device="cuda", dtype=torch.bfloat16) / 10
# Scale to fp8 # Scale to fp8
_, a1_scale = input_to_float8(hidden_states) _, a1_scale = input_to_float8(hidden_states)
a1_scale = 1.0 / a1_scale
a2_scale = torch.scalar_tensor(1.0).to(device="cuda").to(dtype=torch.float32) a2_scale = torch.scalar_tensor(1.0).to(device="cuda").to(dtype=torch.float32)
w13_quantized, w13_weight_scale = quant_fp8_per_tensor_batches(w13) w13_quantized, w13_weight_scale = quant_fp8_per_tensor_batches(w13)
w2_quantized, w2_weight_scale = quant_fp8_per_tensor_batches(w2) w2_quantized, w2_weight_scale = quant_fp8_per_tensor_batches(w2)
...@@ -124,14 +141,16 @@ class TestData: ...@@ -124,14 +141,16 @@ class TestData:
layer.w2_input_scale = a2_scale layer.w2_input_scale = a2_scale
layer.w13_weight_scale = w13_weight_scale layer.w13_weight_scale = w13_weight_scale
layer.w2_weight_scale = w2_weight_scale layer.w2_weight_scale = w2_weight_scale
layer.activation = activation
# Setup dummy config. # Setup dummy config.
layer.moe_parallel_config = mk.FusedMoEParallelConfig.make_no_parallel() layer.moe_parallel_config = mk.FusedMoEParallelConfig.make_no_parallel()
# flashinfer expects swapped rows for w13 # flashinfer expects swapped rows for w13
layer.w13_weight.data = swap_w13_to_w31(layer.w13_weight.data) if is_gated:
layer.w13_weight.data = swap_w13_to_w31(layer.w13_weight.data)
if is_trtllm: if is_trtllm:
rotate_weights_for_fi_trtllm_fp8_per_tensor_moe( rotate_weights_for_fi_trtllm_fp8_per_tensor_moe(
layer.w13_weight, layer.w2_weight layer.w13_weight, layer.w2_weight, is_gated
) )
register_scales_for_trtllm_fp8_per_tensor_moe( register_scales_for_trtllm_fp8_per_tensor_moe(
layer, layer,
...@@ -162,12 +181,14 @@ class TestData: ...@@ -162,12 +181,14 @@ class TestData:
@pytest.mark.parametrize("m,n,k", MNK_FACTORS) @pytest.mark.parametrize("m,n,k", MNK_FACTORS)
@pytest.mark.parametrize("e", NUM_EXPERTS) @pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS) @pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("activation", [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL])
def test_flashinfer_per_tensor_moe_fp8_no_graph( def test_flashinfer_per_tensor_moe_fp8_no_graph(
m: int, m: int,
n: int, n: int,
k: int, k: int,
e: int, e: int,
topk: int, topk: int,
activation: MoEActivation,
monkeypatch, monkeypatch,
): ):
if not current_platform.has_device_capability(100): if not current_platform.has_device_capability(100):
...@@ -175,7 +196,9 @@ def test_flashinfer_per_tensor_moe_fp8_no_graph( ...@@ -175,7 +196,9 @@ def test_flashinfer_per_tensor_moe_fp8_no_graph(
set_random_seed(7) set_random_seed(7)
monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192") monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")
with set_current_vllm_config(vllm_config): with set_current_vllm_config(vllm_config):
td = TestData.make_moe_tensors_8bit(m, k, n, e, is_trtllm=True) td = TestData.make_moe_tensors_8bit(
m, k, n, e, is_trtllm=True, activation=activation
)
score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16) score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16)
topk_weights, topk_ids = Llama4MoE.custom_routing_function( topk_weights, topk_ids = Llama4MoE.custom_routing_function(
...@@ -200,7 +223,7 @@ def test_flashinfer_per_tensor_moe_fp8_no_graph( ...@@ -200,7 +223,7 @@ def test_flashinfer_per_tensor_moe_fp8_no_graph(
topk_weights=topk_weights, topk_weights=topk_weights,
topk_ids=topk_ids, topk_ids=topk_ids,
inplace=False, inplace=False,
activation=MoEActivation.SILU, activation=activation,
global_num_experts=e, global_num_experts=e,
expert_map=None, expert_map=None,
apply_router_weight_on_input=True, apply_router_weight_on_input=True,
...@@ -219,7 +242,13 @@ def test_flashinfer_per_tensor_moe_fp8_no_graph( ...@@ -219,7 +242,13 @@ def test_flashinfer_per_tensor_moe_fp8_no_graph(
apply_router_weight_on_input=True, apply_router_weight_on_input=True,
) )
torch.testing.assert_close(output, flashinfer_output, atol=5.5e-2, rtol=1e-2) check_accuracy(
ref_output=output,
actual_output=flashinfer_output,
atol=0.1,
rtol=0.85,
percent=0.925,
)
@pytest.mark.parametrize("m,n,k", MNK_FACTORS) @pytest.mark.parametrize("m,n,k", MNK_FACTORS)
...@@ -320,8 +349,13 @@ def test_flashinfer_cutlass_moe_fp8_no_graph( ...@@ -320,8 +349,13 @@ def test_flashinfer_cutlass_moe_fp8_no_graph(
expert_map=None, expert_map=None,
apply_router_weight_on_input=True, apply_router_weight_on_input=True,
) )
torch.testing.assert_close(
output, flashinfer_cutlass_output, atol=5.5e-2, rtol=1e-2 check_accuracy(
ref_output=output,
actual_output=flashinfer_cutlass_output,
atol=0.1,
rtol=0.85,
percent=0.925,
) )
......
...@@ -35,8 +35,8 @@ def _supports_current_device() -> bool: ...@@ -35,8 +35,8 @@ def _supports_current_device() -> bool:
def _supports_no_act_and_mul() -> bool: def _supports_no_act_and_mul() -> bool:
"""Does not support non-gated MoE (i.e. Nanotron-Mini).""" """Supports non-gated MoE."""
return False return True
def _supports_quant_scheme( def _supports_quant_scheme(
...@@ -52,8 +52,7 @@ def _supports_quant_scheme( ...@@ -52,8 +52,7 @@ def _supports_quant_scheme(
def _supports_activation(activation: MoEActivation) -> bool: def _supports_activation(activation: MoEActivation) -> bool:
"""Supports silu activation only.""" return activation in [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL]
return activation == MoEActivation.SILU
def _supports_routing_method( def _supports_routing_method(
...@@ -74,6 +73,7 @@ def _supports_routing_method( ...@@ -74,6 +73,7 @@ def _supports_routing_method(
elif (weight_key, activation_key) == (kFp8StaticTensorSym, kFp8StaticTensorSym): elif (weight_key, activation_key) == (kFp8StaticTensorSym, kFp8StaticTensorSym):
# NOTE(dbari): as above, potentially allow others here. # NOTE(dbari): as above, potentially allow others here.
return routing_method in [ return routing_method in [
RoutingMethodType.DeepSeekV3,
RoutingMethodType.Llama4, RoutingMethodType.Llama4,
RoutingMethodType.Renormalize, RoutingMethodType.Renormalize,
RoutingMethodType.RenormalizeNaive, RoutingMethodType.RenormalizeNaive,
...@@ -291,6 +291,7 @@ def fi_trtllm_fp8_per_tensor_moe( ...@@ -291,6 +291,7 @@ def fi_trtllm_fp8_per_tensor_moe(
local_num_experts: int, local_num_experts: int,
use_routing_scales_on_input: bool, use_routing_scales_on_input: bool,
routing_method_type: int, routing_method_type: int,
activation_type: int,
routed_scaling_factor: float = 1.0, routed_scaling_factor: float = 1.0,
) -> torch.Tensor: ) -> torch.Tensor:
num_expert_group = num_expert_group if num_expert_group is not None else 0 num_expert_group = num_expert_group if num_expert_group is not None else 0
...@@ -326,9 +327,9 @@ def fi_trtllm_fp8_per_tensor_moe( ...@@ -326,9 +327,9 @@ def fi_trtllm_fp8_per_tensor_moe(
routed_scaling_factor=routed_scaling_factor, routed_scaling_factor=routed_scaling_factor,
use_routing_scales_on_input=use_routing_scales_on_input, use_routing_scales_on_input=use_routing_scales_on_input,
routing_method_type=routing_method_type, routing_method_type=routing_method_type,
# TODO: Required for flashinfer==0.6.3, remove with update # TODO: enum type Required for flashinfer==0.6.3, remove with update
# https://github.com/flashinfer-ai/flashinfer/pull/2508 # https://github.com/flashinfer-ai/flashinfer/pull/2508
activation_type=ActivationType.Swiglu, activation_type=ActivationType(activation_type),
) )
...@@ -351,6 +352,7 @@ def fi_trtllm_fp8_per_tensor_moe_fake( ...@@ -351,6 +352,7 @@ def fi_trtllm_fp8_per_tensor_moe_fake(
local_num_experts: int, local_num_experts: int,
use_routing_scales_on_input: bool, use_routing_scales_on_input: bool,
routing_method_type: int, routing_method_type: int,
activation_type: int,
routed_scaling_factor: float = 1.0, routed_scaling_factor: float = 1.0,
) -> torch.Tensor: ) -> torch.Tensor:
return torch.empty_like(hidden_states) return torch.empty_like(hidden_states)
......
...@@ -937,10 +937,11 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): ...@@ -937,10 +937,11 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
) )
# TODO(rob): this validation should happen at kernel selection # TODO(rob): this validation should happen at kernel selection
# time in the oracle rather than here. # time in the oracle rather than here.
assert layer.activation == MoEActivation.SILU, ( SUPPORTED_ACTIVATIONS = [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL]
f"Expected 'silu' activation but got {layer.activation}" assert layer.activation in SUPPORTED_ACTIVATIONS, (
f"Only {SUPPORTED_ACTIVATIONS} activations are supported for FlashInfer "
f"TRTLLM FP4 MoE, {layer.activation} found instead."
) )
assert not layer.renormalize
return apply_fi_trtllm_fp8_per_tensor_moe( return apply_fi_trtllm_fp8_per_tensor_moe(
layer=layer, layer=layer,
hidden_states=x, hidden_states=x,
......
...@@ -15,6 +15,10 @@ from vllm.model_executor.layers.fused_moe.config import ( ...@@ -15,6 +15,10 @@ from vllm.model_executor.layers.fused_moe.config import (
FusedMoEParallelConfig, FusedMoEParallelConfig,
RoutingMethodType, RoutingMethodType,
) )
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
activation_to_flashinfer_int,
align_fp4_moe_weights_for_fi,
)
from vllm.model_executor.layers.quantization.utils.nvfp4_utils import ( from vllm.model_executor.layers.quantization.utils.nvfp4_utils import (
swizzle_blockscale, swizzle_blockscale,
) )
...@@ -50,8 +54,8 @@ def _supports_current_device() -> bool: ...@@ -50,8 +54,8 @@ def _supports_current_device() -> bool:
def _supports_no_act_and_mul() -> bool: def _supports_no_act_and_mul() -> bool:
"""Does not support non-gated MoE (i.e. Nemotron-Nano).""" """Supports non-gated MoE."""
return False return True
def _supports_quant_scheme( def _supports_quant_scheme(
...@@ -66,8 +70,7 @@ def _supports_quant_scheme( ...@@ -66,8 +70,7 @@ def _supports_quant_scheme(
def _supports_activation(activation: MoEActivation) -> bool: def _supports_activation(activation: MoEActivation) -> bool:
"""Supports silu activation only.""" return activation in [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL]
return activation in [MoEActivation.SILU]
def _supports_routing_method( def _supports_routing_method(
...@@ -150,6 +153,7 @@ def prepare_static_weights_for_trtllm_fp4_moe( ...@@ -150,6 +153,7 @@ def prepare_static_weights_for_trtllm_fp4_moe(
hidden_size, hidden_size,
intermediate_size, intermediate_size,
num_experts, num_experts,
is_gated_activation: bool,
): ):
from flashinfer import nvfp4_block_scale_interleave from flashinfer import nvfp4_block_scale_interleave
from flashinfer.fused_moe.core import ( from flashinfer.fused_moe.core import (
...@@ -160,15 +164,18 @@ def prepare_static_weights_for_trtllm_fp4_moe( ...@@ -160,15 +164,18 @@ def prepare_static_weights_for_trtllm_fp4_moe(
_cache_permute_indices: dict[torch.Size, torch.Tensor] = {} _cache_permute_indices: dict[torch.Size, torch.Tensor] = {}
"""Prepare quantized weights for kernel (done offline with weights).""" """Prepare quantized weights for kernel (done offline with weights)."""
epilogue_tile_m = 128 # FIXME: this depends on the kernel internals epilogue_tile_m = 128 # FIXME: this depends on the kernel internals
gemm1_intermediate_size = (
2 * intermediate_size if is_gated_activation else intermediate_size
)
# Convert quantized weights to proper formats # Convert quantized weights to proper formats
gemm1_weights_fp4 = gemm1_weights.view(torch.float8_e4m3fn).reshape( gemm1_weights_fp4 = gemm1_weights.view(torch.float8_e4m3fn).reshape(
num_experts, 2 * intermediate_size, hidden_size // 2 num_experts, gemm1_intermediate_size, hidden_size // 2
) # packed fp4 ) # packed fp4
gemm1_scales_linear_fp4 = gemm1_scales_linear_fp4_bytes.view( gemm1_scales_linear_fp4 = gemm1_scales_linear_fp4_bytes.view(
torch.float8_e4m3fn torch.float8_e4m3fn
).reshape( ).reshape(
num_experts, 2 * intermediate_size, hidden_size // 16 num_experts, gemm1_intermediate_size, hidden_size // 16
) # fp8 scaling factors ) # fp8 scaling factors
gemm2_weights_fp4 = gemm2_weights.view(torch.float8_e4m3fn).reshape( gemm2_weights_fp4 = gemm2_weights.view(torch.float8_e4m3fn).reshape(
...@@ -191,6 +198,7 @@ def prepare_static_weights_for_trtllm_fp4_moe( ...@@ -191,6 +198,7 @@ def prepare_static_weights_for_trtllm_fp4_moe(
_cache_permute_indices, _cache_permute_indices,
gemm1_weights_fp4[i].view(torch.uint8), gemm1_weights_fp4[i].view(torch.uint8),
epilogue_tile_m, epilogue_tile_m,
is_gated_act_gemm=is_gated_activation,
) )
gemm1_weights_fp4_shuffled.append( gemm1_weights_fp4_shuffled.append(
gemm1_weights_fp4[i] gemm1_weights_fp4[i]
...@@ -203,6 +211,7 @@ def prepare_static_weights_for_trtllm_fp4_moe( ...@@ -203,6 +211,7 @@ def prepare_static_weights_for_trtllm_fp4_moe(
gemm1_scales_linear_fp4[i].view(torch.uint8), gemm1_scales_linear_fp4[i].view(torch.uint8),
epilogue_tile_m, epilogue_tile_m,
num_elts_per_sf=16, num_elts_per_sf=16,
is_gated_act_gemm=is_gated_activation,
) )
gemm1_scales_fp4_shuffled.append( gemm1_scales_fp4_shuffled.append(
nvfp4_block_scale_interleave( nvfp4_block_scale_interleave(
...@@ -246,7 +255,7 @@ def prepare_static_weights_for_trtllm_fp4_moe( ...@@ -246,7 +255,7 @@ def prepare_static_weights_for_trtllm_fp4_moe(
gemm1_scales_fp4_shuffled = ( gemm1_scales_fp4_shuffled = (
torch.stack(gemm1_scales_fp4_shuffled) torch.stack(gemm1_scales_fp4_shuffled)
.view(torch.float8_e4m3fn) .view(torch.float8_e4m3fn)
.reshape(num_experts, 2 * intermediate_size, hidden_size // 16) .reshape(num_experts, gemm1_intermediate_size, hidden_size // 16)
) )
gemm2_weights_fp4_shuffled = torch.stack(gemm2_weights_fp4_shuffled) gemm2_weights_fp4_shuffled = torch.stack(gemm2_weights_fp4_shuffled)
...@@ -297,10 +306,10 @@ def flashinfer_trtllm_fp4_moe( ...@@ -297,10 +306,10 @@ def flashinfer_trtllm_fp4_moe(
from vllm.model_executor.models.llama4 import Llama4MoE from vllm.model_executor.models.llama4 import Llama4MoE
# https://github.com/flashinfer-ai/flashinfer/blob/f0277fd1bff90e309e5c19cab36c5dae056d685d/flashinfer/fused_moe/core.py#L2404 SUPPORTED_ACTIVATIONS = [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL]
assert activation == MoEActivation.SILU, ( assert activation in SUPPORTED_ACTIVATIONS, (
"Only SiLU activation is supported for FlashInfer TRTLLM FP4 MoE. " f"Only {SUPPORTED_ACTIVATIONS} activations are supported for FlashInfer "
f"{activation} found instead." f"TRTLLM FP4 MoE, {activation} found instead."
) )
# Quantize input to FP4 # Quantize input to FP4
...@@ -325,6 +334,9 @@ def flashinfer_trtllm_fp4_moe( ...@@ -325,6 +334,9 @@ def flashinfer_trtllm_fp4_moe(
else router_logits else router_logits
) )
# Determine activation type
activation_type = activation_to_flashinfer_int(layer.activation)
# Call TRT-LLM FP4 block-scale MoE kernel # Call TRT-LLM FP4 block-scale MoE kernel
out = flashinfer.fused_moe.trtllm_fp4_block_scale_moe( out = flashinfer.fused_moe.trtllm_fp4_block_scale_moe(
routing_logits=router_logits, routing_logits=router_logits,
...@@ -355,6 +367,7 @@ def flashinfer_trtllm_fp4_moe( ...@@ -355,6 +367,7 @@ def flashinfer_trtllm_fp4_moe(
routed_scaling_factor=None, routed_scaling_factor=None,
routing_method_type=routing_method_type, routing_method_type=routing_method_type,
do_finalize=True, do_finalize=True,
activation_type=activation_type,
)[0] )[0]
return out return out
...@@ -479,10 +492,16 @@ def prepare_nvfp4_moe_layer_for_fi_or_cutlass( ...@@ -479,10 +492,16 @@ def prepare_nvfp4_moe_layer_for_fi_or_cutlass(
] ]
# Reorder [w1, w3] to [w3, w1] for FI NVFP4 MoE kernels. # Reorder [w1, w3] to [w3, w1] for FI NVFP4 MoE kernels.
if is_act_and_mul and backend in [ is_gated = layer.activation.is_gated
NvFp4MoeBackend.FLASHINFER_CUTLASS, if (
NvFp4MoeBackend.FLASHINFER_TRTLLM, is_gated
]: and is_act_and_mul
and backend
in [
NvFp4MoeBackend.FLASHINFER_CUTLASS,
NvFp4MoeBackend.FLASHINFER_TRTLLM,
]
):
w13, w13_scale = reorder_w1w3_to_w3w1(w13, w13_scale) w13, w13_scale = reorder_w1w3_to_w3w1(w13, w13_scale)
# For some FI kernels, the input scales are shared by all experts. # For some FI kernels, the input scales are shared by all experts.
...@@ -495,19 +514,32 @@ def prepare_nvfp4_moe_layer_for_fi_or_cutlass( ...@@ -495,19 +514,32 @@ def prepare_nvfp4_moe_layer_for_fi_or_cutlass(
# Shuffle weights and scales for FI TRTLLM NVFP4 MoE kernels. # Shuffle weights and scales for FI TRTLLM NVFP4 MoE kernels.
if backend == NvFp4MoeBackend.FLASHINFER_TRTLLM: if backend == NvFp4MoeBackend.FLASHINFER_TRTLLM:
# Align weights for FI NVFP4 MoE kernels.
min_alignment = 16 if is_gated else 128
w13, w13_scale, w2, w2_scale, padded_intermediate = (
align_fp4_moe_weights_for_fi(
w13, w13_scale, w2, w2_scale, is_act_and_mul, min_alignment
)
)
layer.intermediate_size_per_partition = padded_intermediate
w13, w13_scale, w2, w2_scale = prepare_static_weights_for_trtllm_fp4_moe( w13, w13_scale, w2, w2_scale = prepare_static_weights_for_trtllm_fp4_moe(
w13, w13,
w2, w2,
w13_scale, w13_scale,
w2_scale, w2_scale,
w2.size(-2), # hidden_size hidden_size=w2.size(-2),
w13.size(-2) // 2, # intermediate_size intermediate_size=w13.size(-2) // 2 if is_gated else w13.size(-2),
w13.size(0), # num_experts num_experts=w13.size(0),
is_gated_activation=is_gated,
) )
# We do not need to make this a parameter, because # We do not need to make this a parameter, because
# it is not used during the weight (re)-loading process. # it is not used during the weight (re)-loading process.
layer.g1_scale_c = a13_scale * w13_scale_2 / a2_scale if is_gated:
layer.g1_scale_c = a13_scale * w13_scale_2 / a2_scale
else:
layer.g1_scale_c = torch.ones_like(a13_scale) / a2_scale
layer.a1_gscale = 1.0 / a13_scale layer.a1_gscale = 1.0 / a13_scale
layer.g1_alphas = a13_scale * w13_scale_2 layer.g1_alphas = a13_scale * w13_scale_2
layer.g2_alphas = a2_scale * w2_scale_2 layer.g2_alphas = a2_scale * w2_scale_2
......
...@@ -6,6 +6,7 @@ import torch ...@@ -6,6 +6,7 @@ import torch
from vllm import envs from vllm import envs
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.math_utils import round_up from vllm.utils.math_utils import round_up
...@@ -18,6 +19,20 @@ class FlashinferMoeBackend(Enum): ...@@ -18,6 +19,20 @@ class FlashinferMoeBackend(Enum):
CUTEDSL = "CUTEDSL" CUTEDSL = "CUTEDSL"
def activation_to_flashinfer_int(activation: MoEActivation) -> int:
from flashinfer.fused_moe.core import ActivationType
# silu and gelu are mapped to their gated versions SwiGLU and GeGLU respectively
ACTIVATION_TO_FI_ACTIVATION = {
MoEActivation.SILU_NO_MUL: ActivationType.Silu,
MoEActivation.GELU_NO_MUL: ActivationType.Gelu,
MoEActivation.SILU: ActivationType.Swiglu,
MoEActivation.GELU: ActivationType.Geglu,
MoEActivation.RELU2_NO_MUL: ActivationType.Relu2,
}
return ACTIVATION_TO_FI_ACTIVATION[activation].value
def swap_w13_to_w31(x: torch.Tensor) -> torch.Tensor: def swap_w13_to_w31(x: torch.Tensor) -> torch.Tensor:
return ( return (
x.reshape(-1, 2, x.shape[-2] // 2, x.shape[-1]).flip(dims=[1]).reshape(x.shape) x.reshape(-1, 2, x.shape[-2] // 2, x.shape[-1]).flip(dims=[1]).reshape(x.shape)
...@@ -25,7 +40,7 @@ def swap_w13_to_w31(x: torch.Tensor) -> torch.Tensor: ...@@ -25,7 +40,7 @@ def swap_w13_to_w31(x: torch.Tensor) -> torch.Tensor:
def rotate_weights_for_fi_trtllm_fp8_per_tensor_moe( def rotate_weights_for_fi_trtllm_fp8_per_tensor_moe(
gemm1_weights: torch.Tensor, gemm2_weights: torch.Tensor gemm1_weights: torch.Tensor, gemm2_weights: torch.Tensor, is_gated_activation: bool
): ):
"""Shuffle weights for for FI TRT-LLM Format""" """Shuffle weights for for FI TRT-LLM Format"""
from flashinfer import reorder_rows_for_gated_act_gemm, shuffle_matrix_a from flashinfer import reorder_rows_for_gated_act_gemm, shuffle_matrix_a
...@@ -40,6 +55,8 @@ def rotate_weights_for_fi_trtllm_fp8_per_tensor_moe( ...@@ -40,6 +55,8 @@ def rotate_weights_for_fi_trtllm_fp8_per_tensor_moe(
for i in range(num_experts): for i in range(num_experts):
gemm1_weights_fp8_interleaved.append( gemm1_weights_fp8_interleaved.append(
reorder_rows_for_gated_act_gemm(gemm1_weights[i]) reorder_rows_for_gated_act_gemm(gemm1_weights[i])
if is_gated_activation
else gemm1_weights[i]
) )
# Stack weights and scales for all experts # Stack weights and scales for all experts
...@@ -86,7 +103,13 @@ def register_scales_for_trtllm_fp8_per_tensor_moe( ...@@ -86,7 +103,13 @@ def register_scales_for_trtllm_fp8_per_tensor_moe(
) )
layer.w2_input_scale_inv = 1.0 / w2_input_scale layer.w2_input_scale_inv = 1.0 / w2_input_scale
layer.output1_scales_gate_scalar = g1_alphas layer.output1_scales_gate_scalar = g1_alphas
layer.output1_scales_scalar = g1_alphas * layer.w2_input_scale_inv
if layer.activation.is_gated:
layer.output1_scales_scalar = g1_alphas * layer.w2_input_scale_inv
else:
layer.output1_scales_scalar = (
torch.ones_like(g1_alphas) * layer.w2_input_scale_inv
)
layer.output2_scales_scalar = g2_alphas layer.output2_scales_scalar = g2_alphas
...@@ -125,6 +148,7 @@ def apply_fi_trtllm_fp8_per_tensor_moe( ...@@ -125,6 +148,7 @@ def apply_fi_trtllm_fp8_per_tensor_moe(
assert layer.custom_routing_function is None, ( assert layer.custom_routing_function is None, (
"Custom routing function is only supported for Llama4" "Custom routing function is only supported for Llama4"
) )
activation_type = activation_to_flashinfer_int(layer.activation)
return torch.ops.vllm.fi_trtllm_fp8_per_tensor_moe( return torch.ops.vllm.fi_trtllm_fp8_per_tensor_moe(
routing_logits=router_logits, routing_logits=router_logits,
...@@ -145,6 +169,7 @@ def apply_fi_trtllm_fp8_per_tensor_moe( ...@@ -145,6 +169,7 @@ def apply_fi_trtllm_fp8_per_tensor_moe(
local_num_experts=layer.local_num_experts, local_num_experts=layer.local_num_experts,
use_routing_scales_on_input=apply_router_weight_on_input, use_routing_scales_on_input=apply_router_weight_on_input,
routing_method_type=layer.routing_method_type, routing_method_type=layer.routing_method_type,
activation_type=activation_type,
) )
...@@ -274,8 +299,64 @@ def convert_moe_weights_to_flashinfer_trtllm_block_layout( ...@@ -274,8 +299,64 @@ def convert_moe_weights_to_flashinfer_trtllm_block_layout(
return w13_weights_shuffled_tensor, w2_weights_shuffled_tensor return w13_weights_shuffled_tensor, w2_weights_shuffled_tensor
def align_fp4_moe_weights_for_fi(
w13: torch.Tensor,
w13_scale: torch.Tensor,
w2: torch.Tensor,
w2_scale: torch.Tensor,
is_act_and_mul: bool,
min_alignment: int = 16,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, int]:
"""Pad intermediate size so FlashInfer kernels' alignment constraints hold.
Some FlashInfer FP4 MoE kernels require the intermediate size
used for GEMM to be divisible by a small alignment value. When this is
not satisfied (e.g. with certain tensor-parallel sizes), we pad the
gate/up and down projection weights along the intermediate dim.
"""
# Current local intermediate size (per partition) is the K dimension of
# the down projection.
num_experts, hidden_size, intermediate = w2.shape
intermediate *= 2 # because of packed FP4
padded_intermediate = round_up(intermediate, min_alignment)
if padded_intermediate == intermediate:
return w13, w13_scale, w2, w2_scale, intermediate
logger.info_once(
"Padding intermediate size from %d to %d for up/down projection weights.",
intermediate,
padded_intermediate,
scope="local",
)
up_mult = 2 if is_act_and_mul else 1
padded_gate_up_dim = up_mult * padded_intermediate
# Pad w13 and w2 along its intermediate dimension.
padded_w13 = w13.new_zeros((num_experts, padded_gate_up_dim, hidden_size // 2))
padded_w13[:, : w13.shape[1], :] = w13
padded_w2 = w2.new_zeros((num_experts, hidden_size, padded_intermediate // 2))
padded_w2[:, :, : w2.shape[2]] = w2
padded_w13_scale = w13_scale.new_zeros(
(num_experts, padded_gate_up_dim, hidden_size // 16)
)
padded_w13_scale[:, : w13_scale.shape[1], :] = w13_scale
padded_w2_scale = w2_scale.new_zeros(
(num_experts, hidden_size, padded_intermediate // 16)
)
padded_w2_scale[:, :, : w2_scale.shape[2]] = w2_scale
return padded_w13, padded_w13_scale, padded_w2, padded_w2_scale, padded_intermediate
def align_fp8_moe_weights_for_fi( def align_fp8_moe_weights_for_fi(
w13: torch.Tensor, w2: torch.Tensor, is_act_and_mul: bool w13: torch.Tensor, w2: torch.Tensor, is_act_and_mul: bool, min_alignment: int = 16
) -> tuple[torch.Tensor, torch.Tensor, int]: ) -> tuple[torch.Tensor, torch.Tensor, int]:
"""Pad intermediate size so FlashInfer kernels' alignment constraints hold. """Pad intermediate size so FlashInfer kernels' alignment constraints hold.
...@@ -289,7 +370,6 @@ def align_fp8_moe_weights_for_fi( ...@@ -289,7 +370,6 @@ def align_fp8_moe_weights_for_fi(
# the down projection. # the down projection.
num_experts, hidden_size, intermediate = w2.shape num_experts, hidden_size, intermediate = w2.shape
min_alignment = 16
padded_intermediate = round_up(intermediate, min_alignment) padded_intermediate = round_up(intermediate, min_alignment)
if padded_intermediate == intermediate: if padded_intermediate == intermediate:
...@@ -342,11 +422,14 @@ def prepare_fp8_moe_layer_for_fi( ...@@ -342,11 +422,14 @@ def prepare_fp8_moe_layer_for_fi(
# Some FI MoE kernels require internal alignment of 16 # Some FI MoE kernels require internal alignment of 16
# for the gate-up proj. Pad the weights to respect this. # for the gate-up proj. Pad the weights to respect this.
is_gated = layer.activation.is_gated
if not block_quant: if not block_quant:
min_alignment = 16 if is_gated else 128
w13, w2, new_intermediate = align_fp8_moe_weights_for_fi( w13, w2, new_intermediate = align_fp8_moe_weights_for_fi(
w13, w13,
w2, w2,
layer.moe_config.is_act_and_mul, layer.moe_config.is_act_and_mul,
min_alignment,
) )
layer.intermediate_size_per_partition = new_intermediate layer.intermediate_size_per_partition = new_intermediate
...@@ -363,7 +446,7 @@ def prepare_fp8_moe_layer_for_fi( ...@@ -363,7 +446,7 @@ def prepare_fp8_moe_layer_for_fi(
assert w13_input_scale is not None assert w13_input_scale is not None
assert w2_input_scale is not None assert w2_input_scale is not None
rotate_weights_for_fi_trtllm_fp8_per_tensor_moe(w13, w2) rotate_weights_for_fi_trtllm_fp8_per_tensor_moe(w13, w2, is_gated)
register_scales_for_trtllm_fp8_per_tensor_moe( register_scales_for_trtllm_fp8_per_tensor_moe(
layer, layer,
w13_scale=w13_scale, w13_scale=w13_scale,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment