Unverified Commit 1bc183c6 authored by Alex Yang's avatar Alex Yang Committed by GitHub
Browse files

Faster weight processing (trtllm-gen moe nvfp4) (#9162)

parent b87aacb5
......@@ -737,6 +737,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
" above."
)
self.enable_flashinfer_trtllm_moe = should_use_flashinfer_trtllm_moe()
self._cache_permute_indices = {}
@property
def enable_flashinfer_cutlass_moe(self) -> bool:
......@@ -900,10 +901,15 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
e2m1_and_ufp8sf_scale_to_float,
fp4_quantize,
next_positive_power_of_2,
nvfp4_block_scale_interleave,
reorder_rows_for_gated_act_gemm,
shuffle_matrix_a,
shuffle_matrix_sf_a,
)
from flashinfer.fused_moe.core import (
_maybe_get_cached_w2_permute_indices,
_maybe_get_cached_w3_w1_permute_indices,
)
"""Prepare quantized weights for kernel (done offline with weights)."""
epilogue_tile_m = 128 # FIXME: this depends on the kernel internals
......@@ -927,50 +933,66 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
num_experts, hidden_size, intermediate_size // 16
) # fp8 scaling factors
# Reorder rows of W1 and scales for fused gated activation
gemm1_weights_fp4_interleaved = []
gemm1_scales_fp4_interleaved = []
for i in range(num_experts):
gemm1_weights_fp4_interleaved.append(
reorder_rows_for_gated_act_gemm(gemm1_weights_fp4[i].clone())
)
gemm1_scales_fp4_interleaved.append(
reorder_rows_for_gated_act_gemm(gemm1_scales_linear_fp4[i].clone())
)
# Stack weights and scales for all experts
gemm1_weights_fp4_interleaved = torch.stack(
gemm1_weights_fp4_interleaved
).reshape(num_experts, 2 * intermediate_size, hidden_size // 2)
gemm1_scales_fp4_interleaved = torch.stack(
gemm1_scales_fp4_interleaved
).reshape(num_experts, 2 * intermediate_size, hidden_size // 16)
# Shuffle weights and scaling factors for transposed mma output
gemm1_weights_fp4_shuffled = []
gemm1_scales_fp4_shuffled = []
gemm2_weights_fp4_shuffled = []
gemm2_scales_fp4_shuffled = []
for i in range(num_experts):
# Calculate the permute indices for the following:
# 1. Reorder rows of W1 and scales for fused gated activation
# 2. Shuffle weights and scaling factors for transposed mma output
# for both w3_w1 and w2 weights and scale factors
permute_indices = _maybe_get_cached_w3_w1_permute_indices(
self._cache_permute_indices,
gemm1_weights_fp4[i].view(torch.uint8),
epilogue_tile_m,
)
gemm1_weights_fp4_shuffled.append(
shuffle_matrix_a(
gemm1_weights_fp4_interleaved[i].view(torch.uint8), epilogue_tile_m
)
gemm1_weights_fp4[i]
.view(torch.uint8)[permute_indices.to(gemm1_weights_fp4.device)]
.contiguous()
)
permute_sf_indices = _maybe_get_cached_w3_w1_permute_indices(
self._cache_permute_indices,
gemm1_scales_linear_fp4[i].view(torch.uint8),
epilogue_tile_m,
num_elts_per_sf=16,
)
gemm1_scales_fp4_shuffled.append(
shuffle_matrix_sf_a(
gemm1_scales_fp4_interleaved[i].view(torch.uint8), epilogue_tile_m
nvfp4_block_scale_interleave(
gemm1_scales_linear_fp4[i]
.view(torch.uint8)[
permute_sf_indices.to(gemm1_scales_linear_fp4.device)
]
.contiguous()
)
)
permute_indices = _maybe_get_cached_w2_permute_indices(
self._cache_permute_indices,
gemm2_weights_fp4[i].view(torch.uint8),
epilogue_tile_m,
)
gemm2_weights_fp4_shuffled.append(
shuffle_matrix_a(
gemm2_weights_fp4[i].view(torch.uint8), epilogue_tile_m
)
gemm2_weights_fp4[i]
.view(torch.uint8)[permute_indices.to(gemm2_weights_fp4.device)]
.contiguous()
)
permute_sf_indices = _maybe_get_cached_w2_permute_indices(
self._cache_permute_indices,
gemm2_scales_linear_fp4[i].view(torch.uint8),
epilogue_tile_m,
num_elts_per_sf=16,
)
gemm2_scales_fp4_shuffled.append(
shuffle_matrix_sf_a(
gemm2_scales_linear_fp4[i].view(torch.uint8), epilogue_tile_m
nvfp4_block_scale_interleave(
gemm2_scales_linear_fp4[i]
.view(torch.uint8)[
permute_sf_indices.to(gemm2_scales_linear_fp4.device)
]
.contiguous()
)
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment