Unverified Commit 1bc183c6 authored by Alex Yang, committed by GitHub

Faster weight processing (trtllm-gen moe nvfp4) (#9162)

parent b87aacb5
@@ -737,6 +737,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
                 " above."
             )
         self.enable_flashinfer_trtllm_moe = should_use_flashinfer_trtllm_moe()
+        self._cache_permute_indices = {}
 
     @property
     def enable_flashinfer_cutlass_moe(self) -> bool:
@@ -900,10 +901,15 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
             e2m1_and_ufp8sf_scale_to_float,
             fp4_quantize,
             next_positive_power_of_2,
+            nvfp4_block_scale_interleave,
             reorder_rows_for_gated_act_gemm,
             shuffle_matrix_a,
             shuffle_matrix_sf_a,
         )
+        from flashinfer.fused_moe.core import (
+            _maybe_get_cached_w2_permute_indices,
+            _maybe_get_cached_w3_w1_permute_indices,
+        )
 
         """Prepare quantized weights for kernel (done offline with weights)."""
         epilogue_tile_m = 128  # FIXME: this depends on the kernel internals
@@ -927,50 +933,66 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
             num_experts, hidden_size, intermediate_size // 16
         )  # fp8 scaling factors
 
-        # Reorder rows of W1 and scales for fused gated activation
-        gemm1_weights_fp4_interleaved = []
-        gemm1_scales_fp4_interleaved = []
-        for i in range(num_experts):
-            gemm1_weights_fp4_interleaved.append(
-                reorder_rows_for_gated_act_gemm(gemm1_weights_fp4[i].clone())
-            )
-            gemm1_scales_fp4_interleaved.append(
-                reorder_rows_for_gated_act_gemm(gemm1_scales_linear_fp4[i].clone())
-            )
-
-        # Stack weights and scales for all experts
-        gemm1_weights_fp4_interleaved = torch.stack(
-            gemm1_weights_fp4_interleaved
-        ).reshape(num_experts, 2 * intermediate_size, hidden_size // 2)
-        gemm1_scales_fp4_interleaved = torch.stack(
-            gemm1_scales_fp4_interleaved
-        ).reshape(num_experts, 2 * intermediate_size, hidden_size // 16)
-
-        # Shuffle weights and scaling factors for transposed mma output
         gemm1_weights_fp4_shuffled = []
         gemm1_scales_fp4_shuffled = []
         gemm2_weights_fp4_shuffled = []
         gemm2_scales_fp4_shuffled = []
         for i in range(num_experts):
+            # Calculate the permute indices for the following:
+            # 1. Reorder rows of W1 and scales for fused gated activation
+            # 2. Shuffle weights and scaling factors for transposed mma output
+            # for both w3_w1 and w2 weights and scale factors
+            permute_indices = _maybe_get_cached_w3_w1_permute_indices(
+                self._cache_permute_indices,
+                gemm1_weights_fp4[i].view(torch.uint8),
+                epilogue_tile_m,
+            )
             gemm1_weights_fp4_shuffled.append(
-                shuffle_matrix_a(
-                    gemm1_weights_fp4_interleaved[i].view(torch.uint8), epilogue_tile_m
-                )
+                gemm1_weights_fp4[i]
+                .view(torch.uint8)[permute_indices.to(gemm1_weights_fp4.device)]
+                .contiguous()
             )
+            permute_sf_indices = _maybe_get_cached_w3_w1_permute_indices(
+                self._cache_permute_indices,
+                gemm1_scales_linear_fp4[i].view(torch.uint8),
+                epilogue_tile_m,
+                num_elts_per_sf=16,
+            )
             gemm1_scales_fp4_shuffled.append(
-                shuffle_matrix_sf_a(
-                    gemm1_scales_fp4_interleaved[i].view(torch.uint8), epilogue_tile_m
-                )
+                nvfp4_block_scale_interleave(
+                    gemm1_scales_linear_fp4[i]
+                    .view(torch.uint8)[
+                        permute_sf_indices.to(gemm1_scales_linear_fp4.device)
+                    ]
+                    .contiguous()
+                )
             )
+            permute_indices = _maybe_get_cached_w2_permute_indices(
+                self._cache_permute_indices,
+                gemm2_weights_fp4[i].view(torch.uint8),
+                epilogue_tile_m,
+            )
             gemm2_weights_fp4_shuffled.append(
-                shuffle_matrix_a(
-                    gemm2_weights_fp4[i].view(torch.uint8), epilogue_tile_m
-                )
+                gemm2_weights_fp4[i]
+                .view(torch.uint8)[permute_indices.to(gemm2_weights_fp4.device)]
+                .contiguous()
             )
+            permute_sf_indices = _maybe_get_cached_w2_permute_indices(
+                self._cache_permute_indices,
+                gemm2_scales_linear_fp4[i].view(torch.uint8),
+                epilogue_tile_m,
+                num_elts_per_sf=16,
+            )
             gemm2_scales_fp4_shuffled.append(
-                shuffle_matrix_sf_a(
-                    gemm2_scales_linear_fp4[i].view(torch.uint8), epilogue_tile_m
-                )
+                nvfp4_block_scale_interleave(
+                    gemm2_scales_linear_fp4[i]
+                    .view(torch.uint8)[
+                        permute_sf_indices.to(gemm2_scales_linear_fp4.device)
+                    ]
+                    .contiguous()
+                )
             )
...
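In short, the commit stops materializing interleaved per-expert copies (`reorder_rows_for_gated_act_gemm` + `torch.stack`) followed by per-expert `shuffle_matrix_a` / `shuffle_matrix_sf_a` calls, and instead fetches a combined permutation from flashinfer once, caches it in `self._cache_permute_indices`, and applies it to each expert as a plain gather plus `.contiguous()`. Since every expert's weight tensor has the same shape, the index computation is paid once rather than `num_experts` times. Below is a minimal sketch of that cache-by-shape pattern; `_row_interleave_indices`, `maybe_get_cached_permute_indices`, and the tuple cache key are illustrative stand-ins, not flashinfer's actual internals.

    # Minimal sketch of the cache-by-shape pattern (hypothetical helpers;
    # the real index computation lives inside flashinfer.fused_moe.core).
    from typing import Dict, Tuple

    import torch


    def _row_interleave_indices(num_rows: int) -> torch.Tensor:
        # Hypothetical stand-in for the real permutation: interleave the
        # two gated halves row-pairwise for the fused gated activation.
        half = num_rows // 2
        idx = torch.empty(num_rows, dtype=torch.long)
        idx[0::2] = torch.arange(half)            # rows of one half
        idx[1::2] = torch.arange(half, num_rows)  # rows of the other half
        return idx


    def maybe_get_cached_permute_indices(
        cache: Dict[Tuple[int, ...], torch.Tensor],
        weight: torch.Tensor,
    ) -> torch.Tensor:
        # The permutation depends only on the weight's shape, so one
        # cache entry serves every expert with that shape.
        key = tuple(weight.shape)
        if key not in cache:
            cache[key] = _row_interleave_indices(weight.shape[0])
        return cache[key]


    cache: Dict[Tuple[int, ...], torch.Tensor] = {}
    experts = [torch.randn(8, 4) for _ in range(16)]
    shuffled = [
        # The first call computes the indices; the other 15 reuse them.
        w[maybe_get_cached_permute_indices(cache, w)].contiguous()
        for w in experts
    ]
    assert len(cache) == 1  # one shape -> one cached index tensor

The diff applies the same idea twice per expert: once for the fp4 weights and once for the block scale factors (passing `num_elts_per_sf=16`), moving the cached indices to the weight's device before indexing.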