Unverified Commit af8fd730 authored by Robert Shaw's avatar Robert Shaw Committed by GitHub
Browse files

[MoE Refactor][14/N] Clean Up FI Quant Config Smuggling (#31593)


Signed-off-by: default avatarRobert Shaw <robshaw@redhat.com>
Co-authored-by: default avatarRobert Shaw <robshaw@redhat.com>
parent d3e477c0
...@@ -15,7 +15,7 @@ from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts ...@@ -15,7 +15,7 @@ from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
apply_flashinfer_per_tensor_scale_fp8, apply_flashinfer_per_tensor_scale_fp8,
flashinfer_cutlass_moe_fp8, flashinfer_cutlass_moe_fp8,
register_moe_scaling_factors, register_scales_for_trtllm_fp8_per_tensor_moe,
rotate_flashinfer_fp8_moe_weights, rotate_flashinfer_fp8_moe_weights,
swap_w13_to_w31, swap_w13_to_w31,
) )
...@@ -85,7 +85,7 @@ class TestData: ...@@ -85,7 +85,7 @@ class TestData:
@staticmethod @staticmethod
def make_moe_tensors_8bit( def make_moe_tensors_8bit(
m: int, k: int, n: int, e: int, reorder: bool, activation: str = "silu" m: int, k: int, n: int, e: int, is_trtllm: bool, activation: str = "silu"
) -> "TestData": ) -> "TestData":
is_gated = activation != "relu2_no_mul" is_gated = activation != "relu2_no_mul"
...@@ -123,12 +123,17 @@ class TestData: ...@@ -123,12 +123,17 @@ class TestData:
all2all_backend="naive", all2all_backend="naive",
) )
register_moe_scaling_factors(layer)
# flashinfer expects swapped rows for w13 # flashinfer expects swapped rows for w13
layer.w13_weight.data = swap_w13_to_w31(layer.w13_weight.data) layer.w13_weight.data = swap_w13_to_w31(layer.w13_weight.data)
if reorder: if is_trtllm:
rotate_flashinfer_fp8_moe_weights(layer.w13_weight, layer.w2_weight) rotate_flashinfer_fp8_moe_weights(layer.w13_weight, layer.w2_weight)
register_scales_for_trtllm_fp8_per_tensor_moe(
layer,
layer.w13_weight_scale,
layer.w13_input_scale,
layer.w2_weight_scale,
layer.w2_input_scale,
)
layer.custom_routing_function = Llama4MoE.custom_routing_function layer.custom_routing_function = Llama4MoE.custom_routing_function
layer.intermediate_size_per_partition = n layer.intermediate_size_per_partition = n
layer.ep_rank = 0 layer.ep_rank = 0
...@@ -162,7 +167,7 @@ def test_flashinfer_per_tensor_moe_fp8_no_graph( ...@@ -162,7 +167,7 @@ def test_flashinfer_per_tensor_moe_fp8_no_graph(
set_random_seed(7) set_random_seed(7)
monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192") monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")
with set_current_vllm_config(vllm_config): with set_current_vllm_config(vllm_config):
td = TestData.make_moe_tensors_8bit(m, k, n, e, reorder=True) td = TestData.make_moe_tensors_8bit(m, k, n, e, is_trtllm=True)
score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16) score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16)
topk_weights, topk_ids = Llama4MoE.custom_routing_function( topk_weights, topk_ids = Llama4MoE.custom_routing_function(
...@@ -227,7 +232,7 @@ def test_flashinfer_cutlass_moe_fp8_no_graph( ...@@ -227,7 +232,7 @@ def test_flashinfer_cutlass_moe_fp8_no_graph(
monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192") monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")
with set_current_vllm_config(vllm_config): with set_current_vllm_config(vllm_config):
td = TestData.make_moe_tensors_8bit( td = TestData.make_moe_tensors_8bit(
m, k, n, e, reorder=False, activation=activation m, k, n, e, is_trtllm=False, activation=activation
) )
score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16) score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16)
......
...@@ -452,11 +452,14 @@ class FusedMoEQuantConfig: ...@@ -452,11 +452,14 @@ class FusedMoEQuantConfig:
- a1_scale: Optional scale to be used for a1. - a1_scale: Optional scale to be used for a1.
- a2_scale: Optional scale to be used for a2. - a2_scale: Optional scale to be used for a2.
- g1_alphas: Optional global quantization scales for w1 (for nvfp4). - g1_alphas: Optional global quantization scales for w1 (for nvfp4).
per-channel scales for w1 (for W4A8 FP8). Optional per-channel scales for w1 (for W4A8 FP8).
Optional dq scale i.e. w_scale * a_scale (for W8A8 fp8).
- g2_alphas: Optional global quantization scales for w2 (for nvfp4). - g2_alphas: Optional global quantization scales for w2 (for nvfp4).
per-channel scales for w2 (for W4A8 FP8). Optional per-channel scales for w2 (for W4A8 FP8).
- a1_gscale: Optional global quantization scales for a1 (for nvfp4). Optional dq scale i.e. w_scale * a_scale (for W8A8 fp8).
- a2_gscale: Optional global quantization scales for a2 (for nvfp4). - a1_gscale: Optional global quantization scales for a1 (1.0 / a1_scale).
- a2_gscale: Optional global quantization scales for a2 (1.0 / a2_scale).
- w1_bias: Optional biases for w1 (GPT OSS Triton). - w1_bias: Optional biases for w1 (GPT OSS Triton).
- w2_bias: Optional biases for w2 (GPT OSS Triton). - w2_bias: Optional biases for w2 (GPT OSS Triton).
- w1_zp: Optional w1 zero points for int4/int8 quantization. - w1_zp: Optional w1 zero points for int4/int8 quantization.
......
...@@ -165,10 +165,10 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -165,10 +165,10 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
): ):
# FP8 per-tensor path: use global alphas/scales; do not pass input_sf # FP8 per-tensor path: use global alphas/scales; do not pass input_sf
quant_scales = [ quant_scales = [
self.g1_alphas, self.g1_alphas, # w13_weight_scale * w13_input_scale
self.a2_gscale, self.a2_gscale, # 1.0 / w2_input_scale
self.g2_alphas, self.g2_alphas, # w2_weight_scale * w2_input_scale
self.a1_gscale, self.a1_scale,
] ]
a1q_scale = None # not passing input_sf in fp8 a1q_scale = None # not passing input_sf in fp8
......
...@@ -184,13 +184,14 @@ class FlashInferAllGatherMoEPrepareAndFinalize(FlashInferCutlassMoEPrepareAndFin ...@@ -184,13 +184,14 @@ class FlashInferAllGatherMoEPrepareAndFinalize(FlashInferCutlassMoEPrepareAndFin
self._apply_router_weight_on_input( self._apply_router_weight_on_input(
a1, topk_weights, topk_ids, apply_router_weight_on_input a1, topk_weights, topk_ids, apply_router_weight_on_input
) )
if not self.use_dp and quant_config.quant_dtype == "nvfp4": is_nvfp4 = quant_config.quant_dtype == "nvfp4"
if not self.use_dp and is_nvfp4:
return a1, None, None, topk_ids, topk_weights return a1, None, None, topk_ids, topk_weights
if not self.use_deepseek_fp8_block_scale: if not self.use_deepseek_fp8_block_scale:
a1q, a1q_scale = moe_kernel_quantize_input( a1q, a1q_scale = moe_kernel_quantize_input(
a1, a1,
quant_config.a1_gscale, quant_config.a1_gscale if is_nvfp4 else quant_config.a1_scale,
quant_config.quant_dtype, quant_config.quant_dtype,
quant_config.per_act_token_quant, quant_config.per_act_token_quant,
quant_config.block_shape, quant_config.block_shape,
...@@ -222,7 +223,7 @@ class FlashInferAllGatherMoEPrepareAndFinalize(FlashInferCutlassMoEPrepareAndFin ...@@ -222,7 +223,7 @@ class FlashInferAllGatherMoEPrepareAndFinalize(FlashInferCutlassMoEPrepareAndFin
topk_weights, topk_ids, a1q = gathered topk_weights, topk_ids, a1q = gathered
a1q_scale = None a1q_scale = None
if quant_config.quant_dtype == "nvfp4" and a1q_scale is not None: if is_nvfp4 and a1q_scale is not None:
a1q_scale = nvfp4_block_scale_interleave(a1q_scale) a1q_scale = nvfp4_block_scale_interleave(a1q_scale)
return a1q, a1q_scale, None, topk_ids, topk_weights return a1q, a1q_scale, None, topk_ids, topk_weights
......
...@@ -50,7 +50,8 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( ...@@ -50,7 +50,8 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
apply_flashinfer_per_tensor_scale_fp8, apply_flashinfer_per_tensor_scale_fp8,
build_flashinfer_fp8_cutlass_moe_prepare_finalize, build_flashinfer_fp8_cutlass_moe_prepare_finalize,
get_flashinfer_moe_backend, get_flashinfer_moe_backend,
register_moe_scaling_factors, make_fp8_moe_alpha_scales_for_fi,
register_scales_for_trtllm_fp8_per_tensor_moe,
rotate_flashinfer_fp8_moe_weights, rotate_flashinfer_fp8_moe_weights,
select_cutlass_fp8_gemm_impl, select_cutlass_fp8_gemm_impl,
swap_w13_to_w31, swap_w13_to_w31,
...@@ -774,6 +775,14 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -774,6 +775,14 @@ class Fp8MoEMethod(FusedMoEMethodBase):
"FlashInfer CUTLASS FP8 MoE backend only supports SiLU " "FlashInfer CUTLASS FP8 MoE backend only supports SiLU "
f"activation function, but got {layer.activation}." f"activation function, but got {layer.activation}."
) )
dynamic_per_token = (
not self.block_quant and self.quant_config.activation_scheme != "static"
)
if self.flashinfer_moe_backend is not None and dynamic_per_token:
raise NotImplementedError(
"FlashInfer FP8 MoE backend does not support dynamic per token "
"activation quantization."
)
def create_weights( def create_weights(
self, self,
...@@ -905,6 +914,8 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -905,6 +914,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
w2_weight: torch.Tensor, w2_weight: torch.Tensor,
w13_weight_scale: torch.Tensor, w13_weight_scale: torch.Tensor,
w2_weight_scale: torch.Tensor, w2_weight_scale: torch.Tensor,
w13_input_scale: torch.Tensor | None,
w2_input_scale: torch.Tensor | None,
) -> None: ) -> None:
if self.fp8_backend == Fp8MoeBackend.DEEPGEMM: if self.fp8_backend == Fp8MoeBackend.DEEPGEMM:
assert self.block_quant assert self.block_quant
...@@ -949,11 +960,16 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -949,11 +960,16 @@ class Fp8MoEMethod(FusedMoEMethodBase):
if self.block_quant: if self.block_quant:
w13_weight_scale = swap_w13_to_w31(w13_weight_scale) w13_weight_scale = swap_w13_to_w31(w13_weight_scale)
else: else:
# TODO(rob): this function is a hack that renames the scaling
# factors in the Module. This is a hack we should clean up.
register_moe_scaling_factors(layer)
if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM: if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
rotate_flashinfer_fp8_moe_weights(w13_weight, w2_weight) rotate_flashinfer_fp8_moe_weights(w13_weight, w2_weight)
register_scales_for_trtllm_fp8_per_tensor_moe(
layer=layer,
w13_weight_scale=w13_weight_scale,
w13_input_scale=w13_input_scale,
w2_weight_scale=w2_weight_scale,
w2_input_scale=w2_input_scale,
)
elif self.fp8_backend == Fp8MoeBackend.AITER: elif self.fp8_backend == Fp8MoeBackend.AITER:
w13_weight, w2_weight = rocm_aiter_ops.shuffle_weights( w13_weight, w2_weight = rocm_aiter_ops.shuffle_weights(
w13_weight, w2_weight w13_weight, w2_weight
...@@ -990,6 +1006,10 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -990,6 +1006,10 @@ class Fp8MoEMethod(FusedMoEMethodBase):
AiterExperts, AiterExperts,
) )
# Flashinfer TRTLLM does not use the modular kernel abstraction.
if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
return
self.moe_quant_config = self.get_fused_moe_quant_config(layer) self.moe_quant_config = self.get_fused_moe_quant_config(layer)
assert self.moe_quant_config is not None assert self.moe_quant_config is not None
self.use_inplace = True self.use_inplace = True
...@@ -1087,7 +1107,13 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -1087,7 +1107,13 @@ class Fp8MoEMethod(FusedMoEMethodBase):
# Shuffle weights into the runtime format. # Shuffle weights into the runtime format.
self._convert_weights_to_kernel_format( self._convert_weights_to_kernel_format(
layer, w13_weight, w2_weight, w13_weight_scale, w2_weight_scale layer=layer,
w13_weight=w13_weight,
w2_weight=w2_weight,
w13_weight_scale=w13_weight_scale,
w2_weight_scale=w2_weight_scale,
w13_input_scale=w13_input_scale,
w2_input_scale=w2_input_scale,
) )
# Setup modular kernel for TP case. # Setup modular kernel for TP case.
...@@ -1182,6 +1208,11 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -1182,6 +1208,11 @@ class Fp8MoEMethod(FusedMoEMethodBase):
def get_fused_moe_quant_config( def get_fused_moe_quant_config(
self, layer: torch.nn.Module self, layer: torch.nn.Module
) -> FusedMoEQuantConfig | None: ) -> FusedMoEQuantConfig | None:
# TRTLLM does not use Modular Kernel.
if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
return None
# MARLIN uses mixed precision W8A16 config.
if self.fp8_backend == Fp8MoeBackend.MARLIN: if self.fp8_backend == Fp8MoeBackend.MARLIN:
return fp8_w8a16_moe_quant_config( return fp8_w8a16_moe_quant_config(
w1_scale=getattr(layer, f"w13_{self.weight_scale_name}"), w1_scale=getattr(layer, f"w13_{self.weight_scale_name}"),
...@@ -1189,11 +1220,38 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -1189,11 +1220,38 @@ class Fp8MoEMethod(FusedMoEMethodBase):
block_shape=self.weight_block_size, block_shape=self.weight_block_size,
) )
w1_scale = getattr(layer, f"w13_{self.weight_scale_name}")
w2_scale = getattr(layer, f"w2_{self.weight_scale_name}")
a1_scale = layer.w13_input_scale
a2_scale = layer.w2_input_scale
# Flashinfer CUTLASS per-tensor uses single dq scale
# (alpha = w_scale * a_scale) and inverse a2 scale.
if (
self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS
and not self.block_quant
):
g1_alphas, g2_alphas = make_fp8_moe_alpha_scales_for_fi(
w1_scale,
a1_scale,
w2_scale,
a2_scale,
)
return fp8_w8a8_moe_quant_config(
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
a2_gscale=(1.0 / a2_scale),
g1_alphas=g1_alphas,
g2_alphas=g2_alphas,
)
# All other backends use normal config.
return fp8_w8a8_moe_quant_config( return fp8_w8a8_moe_quant_config(
w1_scale=getattr(layer, f"w13_{self.weight_scale_name}"), w1_scale=w1_scale,
w2_scale=getattr(layer, f"w2_{self.weight_scale_name}"), w2_scale=w2_scale,
a1_scale=layer.w13_input_scale, a1_scale=a1_scale,
a2_scale=layer.w2_input_scale, a2_scale=a2_scale,
block_shape=self.weight_block_size, block_shape=self.weight_block_size,
) )
...@@ -1414,7 +1472,13 @@ class Fp8OnlineMoEMethod(Fp8MoEMethod): ...@@ -1414,7 +1472,13 @@ class Fp8OnlineMoEMethod(Fp8MoEMethod):
# Shuffle weights into the runtime format. # Shuffle weights into the runtime format.
self._convert_weights_to_kernel_format( self._convert_weights_to_kernel_format(
layer, w13_weight, w2_weight, layer.w13_weight_scale, layer.w2_weight_scale layer=layer,
w13_weight=w13_weight,
w2_weight=w2_weight,
w13_weight_scale=layer.w13_weight_scale,
w2_weight_scale=layer.w2_weight_scale,
w13_input_scale=None,
w2_input_scale=None,
) )
# Setup modular kernel for TP case. # Setup modular kernel for TP case.
......
...@@ -50,7 +50,8 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( ...@@ -50,7 +50,8 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
flashinfer_cutlass_moe_fp8, flashinfer_cutlass_moe_fp8,
get_flashinfer_moe_backend, get_flashinfer_moe_backend,
is_flashinfer_supporting_global_sf, is_flashinfer_supporting_global_sf,
register_moe_scaling_factors, make_fp8_moe_alpha_scales_for_fi,
register_scales_for_trtllm_fp8_per_tensor_moe,
rotate_flashinfer_fp8_moe_weights, rotate_flashinfer_fp8_moe_weights,
select_cutlass_fp8_gemm_impl, select_cutlass_fp8_gemm_impl,
swap_w13_to_w31, swap_w13_to_w31,
...@@ -947,9 +948,18 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): ...@@ -947,9 +948,18 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
if self.flashinfer_moe_backend is not None: if self.flashinfer_moe_backend is not None:
if self.moe.is_act_and_mul: if self.moe.is_act_and_mul:
layer.w13_weight.data = swap_w13_to_w31(layer.w13_weight.data) layer.w13_weight.data = swap_w13_to_w31(layer.w13_weight.data)
# NOTE: this adds some attributes used by the trtllm kernel,
# which does not conform to the modular kernels abstraction (yet).
if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM: if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
rotate_flashinfer_fp8_moe_weights(layer.w13_weight, layer.w2_weight) rotate_flashinfer_fp8_moe_weights(layer.w13_weight, layer.w2_weight)
register_moe_scaling_factors(layer) register_scales_for_trtllm_fp8_per_tensor_moe(
layer=layer,
w13_weight_scale=layer.w13_weight_scale,
w13_input_scale=layer.w13_input_scale,
w2_weight_scale=layer.w2_weight_scale,
w2_input_scale=layer.w2_input_scale,
)
def _maybe_pad_intermediate_for_flashinfer(self, layer: torch.nn.Module) -> None: def _maybe_pad_intermediate_for_flashinfer(self, layer: torch.nn.Module) -> None:
"""Pad intermediate size so FlashInfer kernels' alignment constraints hold. """Pad intermediate size so FlashInfer kernels' alignment constraints hold.
...@@ -999,19 +1009,34 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): ...@@ -999,19 +1009,34 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
self, layer: torch.nn.Module self, layer: torch.nn.Module
) -> FusedMoEQuantConfig | None: ) -> FusedMoEQuantConfig | None:
if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM: if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
# TRTLLM does not use modular kernels
return None return None
return fp8_w8a8_moe_quant_config( elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
w1_scale=layer.w13_weight_scale, g1_alphas, g2_alphas = make_fp8_moe_alpha_scales_for_fi(
g1_alphas=layer.output1_scales_gate_scalar.squeeze(), layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale, layer.w13_input_scale,
g2_alphas=layer.output2_scales_scalar.squeeze(), layer.w2_weight_scale,
a1_scale=layer.w13_input_scale, layer.w2_input_scale,
a1_gscale=layer.w13_input_scale, )
a2_scale=layer.w2_input_scale, return fp8_w8a8_moe_quant_config(
a2_gscale=layer.w2_input_scale_inv, w1_scale=layer.w13_weight_scale,
per_act_token_quant=False, w2_scale=layer.w2_weight_scale,
) a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
a1_gscale=(1.0 / layer.w13_input_scale),
a2_gscale=(1.0 / layer.w2_input_scale),
g1_alphas=g1_alphas,
g2_alphas=g2_alphas,
)
else:
assert self.flashinfer_moe_backend is None
return fp8_w8a8_moe_quant_config(
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
)
def apply( def apply(
self, self,
......
...@@ -103,6 +103,26 @@ def rotate_flashinfer_fp8_moe_weights( ...@@ -103,6 +103,26 @@ def rotate_flashinfer_fp8_moe_weights(
) )
def register_scales_for_trtllm_fp8_per_tensor_moe(
    layer: torch.nn.Module,
    w13_weight_scale: torch.Tensor,
    w13_input_scale: torch.Tensor,
    w2_weight_scale: torch.Tensor,
    w2_input_scale: torch.Tensor,
) -> None:
    """Attach the per-tensor FP8 scale attributes used by the TRTLLM MoE path.

    The following attributes are set directly on ``layer``:

    - ``w2_input_scale_inv``: ``1.0 / w2_input_scale``.
    - ``output1_scales_gate_scalar``: the first alpha returned by
      ``make_fp8_moe_alpha_scales_for_fi``.
    - ``output1_scales_scalar``: that first alpha multiplied by
      ``w2_input_scale_inv``.
    - ``output2_scales_scalar``: the second alpha returned by
      ``make_fp8_moe_alpha_scales_for_fi``.

    Args:
        layer: Module that receives the attributes.
        w13_weight_scale: Weight scale for the w13 projection.
        w13_input_scale: Input (activation) scale for the w13 projection.
        w2_weight_scale: Weight scale for the w2 projection.
        w2_input_scale: Input (activation) scale for the w2 projection.
    """
    w13_alpha, w2_alpha = make_fp8_moe_alpha_scales_for_fi(
        w13_weight_scale,
        w13_input_scale,
        w2_weight_scale,
        w2_input_scale,
    )

    # Compute the reciprocal once; it is both stored on the layer and folded
    # into the non-gate output scale of the first GEMM.
    inv_w2_input_scale = 1.0 / w2_input_scale

    layer.w2_input_scale_inv = inv_w2_input_scale
    layer.output1_scales_gate_scalar = w13_alpha
    layer.output1_scales_scalar = w13_alpha * inv_w2_input_scale
    layer.output2_scales_scalar = w2_alpha
def apply_flashinfer_per_tensor_scale_fp8( def apply_flashinfer_per_tensor_scale_fp8(
layer: torch.nn.Module, layer: torch.nn.Module,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
...@@ -117,18 +137,13 @@ def apply_flashinfer_per_tensor_scale_fp8( ...@@ -117,18 +137,13 @@ def apply_flashinfer_per_tensor_scale_fp8(
from flashinfer.fused_moe import RoutingMethodType from flashinfer.fused_moe import RoutingMethodType
import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe # noqa: E501, F401 import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe # noqa: E501, F401
from vllm.model_executor.models.llama4 import Llama4MoE
assert layer.output1_scales_scalar is not None, ( assert (
"Expected output1_scales_scalar to be initialized" hasattr(layer, "output1_scales_scalar")
) and hasattr(layer, "output1_scales_gate_scalar")
assert layer.output1_scales_scalar is not None, ( and hasattr(layer, "output2_scales_scalar")
"Expected output1_scales_gate_scalar to be initialized"
) )
assert layer.output1_scales_scalar is not None, (
"Expected output2_scales_scalar to be initialized"
)
from vllm.model_executor.models.llama4 import Llama4MoE
assert layer.custom_routing_function == Llama4MoE.custom_routing_function, ( assert layer.custom_routing_function == Llama4MoE.custom_routing_function, (
"FusedMoE flashinfer kernels are only supported for Llama4" "FusedMoE flashinfer kernels are only supported for Llama4"
...@@ -155,40 +170,16 @@ def apply_flashinfer_per_tensor_scale_fp8( ...@@ -155,40 +170,16 @@ def apply_flashinfer_per_tensor_scale_fp8(
) )
def get_moe_scaling_factors( def make_fp8_moe_alpha_scales_for_fi(
input_scale: torch.Tensor, w13_scale: torch.Tensor,
gemm1_weights_scale: torch.Tensor, w13_input_scale: torch.Tensor,
activation_scale: torch.Tensor, w2_scale: torch.Tensor,
gemm2_weights_scale: torch.Tensor, w2_input_scale: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
output1_scales_scalar = gemm1_weights_scale * input_scale * (1.0 / activation_scale) g1_alphas = (w13_scale * w13_input_scale).squeeze()
output1_scales_gate_scalar = gemm1_weights_scale * input_scale g2_alphas = (w2_scale * w2_input_scale).squeeze()
output2_scales_scalar = activation_scale * gemm2_weights_scale
return output1_scales_scalar, output1_scales_gate_scalar, output2_scales_scalar
return g1_alphas, g2_alphas
def register_moe_scaling_factors(layer: torch.nn.Module) -> None:
    """Compute the MoE scaling factors from ``layer`` and register them on it.

    The three factors returned by ``get_moe_scaling_factors`` (called with the
    layer's w13 input/weight scales and w2 input/weight scales) plus the
    reciprocal of ``layer.w2_input_scale`` are registered as non-trainable
    parameters named ``output1_scales_scalar``,
    ``output1_scales_gate_scalar``, ``output2_scales_scalar`` and
    ``w2_input_scale_inv``, in that order.
    """
    gemm1_scales, gemm1_gate_scales, gemm2_scales = get_moe_scaling_factors(
        layer.w13_input_scale,
        layer.w13_weight_scale,
        layer.w2_input_scale,
        layer.w2_weight_scale,
    )

    # Dict insertion order fixes the registration order.
    derived_scales = {
        "output1_scales_scalar": gemm1_scales,
        "output1_scales_gate_scalar": gemm1_gate_scales,
        "output2_scales_scalar": gemm2_scales,
        "w2_input_scale_inv": 1.0 / layer.w2_input_scale,
    }
    for param_name, values in derived_scales.items():
        layer.register_parameter(
            param_name, torch.nn.Parameter(values, requires_grad=False)
        )
def build_flashinfer_fp8_cutlass_moe_prepare_finalize( def build_flashinfer_fp8_cutlass_moe_prepare_finalize(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment