Unverified commit 9e23ad96, authored by Shu Wang, committed by GitHub
Browse files

Update fp4 quantize API (#21327)


Signed-off-by: Shu Wang <shuw@nvidia.com>
parent e69a92a1
...@@ -181,12 +181,12 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -181,12 +181,12 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
g2_alphas, g2_alphas,
] ]
_ = flashinfer_cutlass_fused_moe( _ = flashinfer_cutlass_fused_moe(
hidden_states, input=hidden_states,
topk_ids.to(torch.int), token_selected_experts=topk_ids.to(torch.int),
topk_weights, token_final_scales=topk_weights,
# FlashInfer API requires weight to be long for nvfp4 # FlashInfer API requires weight to be long for nvfp4
w1.view(torch.long), fc1_expert_weights=w1.view(torch.long),
w2.view(torch.long), fc2_expert_weights=w2.view(torch.long),
output_dtype=out_dtype, output_dtype=out_dtype,
quant_scales=quant_scales, quant_scales=quant_scales,
input_sf=a1q_scale, input_sf=a1q_scale,
......
...@@ -11,7 +11,7 @@ from vllm.forward_context import get_forward_context ...@@ -11,7 +11,7 @@ from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.utils import ( from vllm.model_executor.layers.fused_moe.utils import (
extract_required_args, moe_kernel_quantize_input) extract_required_args, moe_kernel_quantize_input)
from vllm.utils.flashinfer import fp4_swizzle_blockscale from vllm.utils.flashinfer import block_scale_interleave
def get_local_sizes(local_tokens): def get_local_sizes(local_tokens):
...@@ -92,7 +92,7 @@ class FlashInferCutlassMoEPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -92,7 +92,7 @@ class FlashInferCutlassMoEPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
dim=0, dim=0,
sizes=get_local_sizes(local_tokens)) sizes=get_local_sizes(local_tokens))
a1_m, a1_n = a1q.shape a1_m, a1_n = a1q.shape
a1q_scale = fp4_swizzle_blockscale(a1q_scale, a1_m, a1_n * 2) a1q_scale = block_scale_interleave(a1q_scale)
return a1q, a1q_scale, None, topk_ids, topk_weights return a1q, a1q_scale, None, topk_ids, topk_weights
......
...@@ -69,8 +69,8 @@ flashinfer_trtllm_fp8_block_scale_moe = _lazy_import_wrapper( ...@@ -69,8 +69,8 @@ flashinfer_trtllm_fp8_block_scale_moe = _lazy_import_wrapper(
flashinfer_cutlass_fused_moe = _lazy_import_wrapper("flashinfer.fused_moe", flashinfer_cutlass_fused_moe = _lazy_import_wrapper("flashinfer.fused_moe",
"cutlass_fused_moe") "cutlass_fused_moe")
fp4_quantize = _lazy_import_wrapper("flashinfer", "fp4_quantize") fp4_quantize = _lazy_import_wrapper("flashinfer", "fp4_quantize")
fp4_swizzle_blockscale = _lazy_import_wrapper("flashinfer", block_scale_interleave = _lazy_import_wrapper("flashinfer",
"fp4_swizzle_blockscale") "block_scale_interleave")
# Special case for autotune since it returns a context manager # Special case for autotune since it returns a context manager
autotune = _lazy_import_wrapper( autotune = _lazy_import_wrapper(
...@@ -95,7 +95,7 @@ def has_flashinfer_cutlass_fused_moe() -> bool: ...@@ -95,7 +95,7 @@ def has_flashinfer_cutlass_fused_moe() -> bool:
required_functions = [ required_functions = [
("flashinfer.fused_moe", "cutlass_fused_moe"), ("flashinfer.fused_moe", "cutlass_fused_moe"),
("flashinfer", "fp4_quantize"), ("flashinfer", "fp4_quantize"),
("flashinfer", "fp4_swizzle_blockscale"), ("flashinfer", "block_scale_interleave"),
] ]
for module_name, attr_name in required_functions: for module_name, attr_name in required_functions:
...@@ -110,7 +110,7 @@ __all__ = [ ...@@ -110,7 +110,7 @@ __all__ = [
"flashinfer_trtllm_fp8_block_scale_moe", "flashinfer_trtllm_fp8_block_scale_moe",
"flashinfer_cutlass_fused_moe", "flashinfer_cutlass_fused_moe",
"fp4_quantize", "fp4_quantize",
"fp4_swizzle_blockscale", "block_scale_interleave",
"autotune", "autotune",
"has_flashinfer_moe", "has_flashinfer_moe",
"has_flashinfer_cutlass_fused_moe", "has_flashinfer_cutlass_fused_moe",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment