Commit 7d98f09b authored by Michael Goin's avatar Michael Goin Committed by Robert Shaw
Browse files

cherry pick


Signed-off-by: default avatarRobert Shaw <rshaw@neuralmagic.com>
parent daa2784b
...@@ -86,7 +86,23 @@ def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bo ...@@ -86,7 +86,23 @@ def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bo
return not moe_parallel_config.enable_eplb return not moe_parallel_config.enable_eplb
def is_supported_config_trtllm( def _supports_router_logits_dtype(
router_logits_dtype: torch.dtype | None,
routing_method: RoutingMethodType,
) -> bool:
"""
The FlashInfer TRTLLM FP8 kernel expects bfloat16 router_logits by default.
Only DeepSeekV3 routing supports float32 router_logits (which is converted
internally in the kernel).
"""
if router_logits_dtype == torch.float32:
# Only DeepSeekV3 routing handles float32 logits
# https://github.com/flashinfer-ai/flashinfer/issues/2469
return routing_method == RoutingMethodType.DeepSeekV3
return True
def is_supported_config_trtllm_fp8(
moe_config: FusedMoEConfig, moe_config: FusedMoEConfig,
weight_key: QuantKey | None, weight_key: QuantKey | None,
activation_key: QuantKey | None, activation_key: QuantKey | None,
...@@ -115,13 +131,17 @@ def is_supported_config_trtllm( ...@@ -115,13 +131,17 @@ def is_supported_config_trtllm(
return False, _make_reason("routing method") return False, _make_reason("routing method")
elif activation_format != mk.FusedMoEActivationFormat.Standard: elif activation_format != mk.FusedMoEActivationFormat.Standard:
return False, _make_reason("activation format") return False, _make_reason("activation format")
elif not _supports_router_logits_dtype(
moe_config.router_logits_dtype, moe_config.routing_method
):
return False, _make_reason("float32 router_logits with non-DeepSeekV3 routing")
return True, None return True, None
def flashinfer_fused_moe_blockscale_fp8( def flashinfer_fused_moe_blockscale_fp8(
routing_logits: torch.Tensor, routing_logits: torch.Tensor,
routing_bias: torch.Tensor, routing_bias: torch.Tensor | None,
x: torch.Tensor, x: torch.Tensor,
w13_weight: torch.Tensor, w13_weight: torch.Tensor,
w13_weight_scale_inv: torch.Tensor, w13_weight_scale_inv: torch.Tensor,
...@@ -135,7 +155,7 @@ def flashinfer_fused_moe_blockscale_fp8( ...@@ -135,7 +155,7 @@ def flashinfer_fused_moe_blockscale_fp8(
expert_offset: int, expert_offset: int,
local_num_experts: int, local_num_experts: int,
block_shape: list[int], block_shape: list[int],
routing_method_type: int = int(RoutingMethodType.DeepSeekV3), routing_method_type: int,
routed_scaling: float | None = 1.0, routed_scaling: float | None = 1.0,
) -> torch.Tensor: ) -> torch.Tensor:
from vllm.utils.flashinfer import flashinfer_trtllm_fp8_block_scale_moe from vllm.utils.flashinfer import flashinfer_trtllm_fp8_block_scale_moe
...@@ -148,6 +168,13 @@ def flashinfer_fused_moe_blockscale_fp8( ...@@ -148,6 +168,13 @@ def flashinfer_fused_moe_blockscale_fp8(
# Routing kernel expects #experts <= #threads 512 # Routing kernel expects #experts <= #threads 512
assert global_num_experts <= 512 assert global_num_experts <= 512
# The DeepSeekV3 routing method requires float32 router logits.
if routing_method_type == RoutingMethodType.DeepSeekV3:
routing_logits = routing_logits.to(torch.float32)
if routing_bias is not None:
routing_bias = routing_bias.to(x.dtype)
a_q, a_sf = per_token_group_quant_fp8(x, block_shape[1]) a_q, a_sf = per_token_group_quant_fp8(x, block_shape[1])
# NOTE: scales of hidden states have to be transposed! # NOTE: scales of hidden states have to be transposed!
a_sf_t = a_sf.t().contiguous() a_sf_t = a_sf.t().contiguous()
...@@ -175,7 +202,7 @@ def flashinfer_fused_moe_blockscale_fp8( ...@@ -175,7 +202,7 @@ def flashinfer_fused_moe_blockscale_fp8(
def flashinfer_fused_moe_blockscale_fp8_fake( def flashinfer_fused_moe_blockscale_fp8_fake(
routing_logits: torch.Tensor, routing_logits: torch.Tensor,
routing_bias: torch.Tensor, routing_bias: torch.Tensor | None,
x: torch.Tensor, x: torch.Tensor,
w13_weight: torch.Tensor, w13_weight: torch.Tensor,
w13_weight_scale_inv: torch.Tensor, w13_weight_scale_inv: torch.Tensor,
......
...@@ -15,7 +15,7 @@ from vllm.model_executor.layers.fused_moe.config import ( ...@@ -15,7 +15,7 @@ from vllm.model_executor.layers.fused_moe.config import (
fp8_w8a16_moe_quant_config, fp8_w8a16_moe_quant_config,
) )
from vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe import ( from vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe import (
is_supported_config_trtllm, is_supported_config_trtllm_fp8,
) )
from vllm.model_executor.layers.fused_moe.prepare_finalize import ( from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP, MoEPrepareAndFinalizeNoEP,
...@@ -212,7 +212,7 @@ def select_fp8_moe_backend( ...@@ -212,7 +212,7 @@ def select_fp8_moe_backend(
if fi_backend == FlashinferMoeBackend.TENSORRT_LLM: if fi_backend == FlashinferMoeBackend.TENSORRT_LLM:
backend = Fp8MoeBackend.FLASHINFER_TRTLLM backend = Fp8MoeBackend.FLASHINFER_TRTLLM
supported, reason = is_supported_config_trtllm( supported, reason = is_supported_config_trtllm_fp8(
config, weight_key, activation_key, activation_format config, weight_key, activation_key, activation_format
) )
if supported: if supported:
...@@ -239,7 +239,7 @@ def select_fp8_moe_backend( ...@@ -239,7 +239,7 @@ def select_fp8_moe_backend(
]: ]:
if backend == Fp8MoeBackend.FLASHINFER_TRTLLM: if backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
k_cls = None k_cls = None
supported, reason = is_supported_config_trtllm( supported, reason = is_supported_config_trtllm_fp8(
config, config,
weight_key, weight_key,
activation_key, activation_key,
...@@ -308,7 +308,7 @@ def select_fp8_moe_backend( ...@@ -308,7 +308,7 @@ def select_fp8_moe_backend(
for backend in AVAILABLE_BACKENDS: for backend in AVAILABLE_BACKENDS:
if backend == Fp8MoeBackend.FLASHINFER_TRTLLM: if backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
k_cls = None k_cls = None
supported, reason = is_supported_config_trtllm( supported, reason = is_supported_config_trtllm_fp8(
config, config,
weight_key, weight_key,
activation_key, activation_key,
......
...@@ -27,7 +27,6 @@ from vllm.model_executor.layers.fused_moe import ( ...@@ -27,7 +27,6 @@ from vllm.model_executor.layers.fused_moe import (
from vllm.model_executor.layers.fused_moe.config import ( from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig, FusedMoEConfig,
FusedMoEQuantConfig, FusedMoEQuantConfig,
RoutingMethodType,
int4_w4a16_moe_quant_config, int4_w4a16_moe_quant_config,
int4_w4afp8_moe_quant_config, int4_w4afp8_moe_quant_config,
int8_w8a8_moe_quant_config, int8_w8a8_moe_quant_config,
...@@ -1072,17 +1071,9 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): ...@@ -1072,17 +1071,9 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
if self.block_quant: if self.block_quant:
import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe # noqa: E501, F401 import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe # noqa: E501, F401
e_score_correction_bias = (
layer.e_score_correction_bias.to(x.dtype)
if layer.e_score_correction_bias is not None
else None
)
routing_method_type = layer.routing_method_type
return torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8( return torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8(
routing_logits=router_logits.to(torch.float32) routing_logits=router_logits,
if routing_method_type == RoutingMethodType.DeepSeekV3 routing_bias=layer.e_score_correction_bias,
else router_logits,
routing_bias=e_score_correction_bias,
x=x, x=x,
w13_weight=layer.w13_weight, w13_weight=layer.w13_weight,
w13_weight_scale_inv=layer.w13_weight_scale, w13_weight_scale_inv=layer.w13_weight_scale,
...@@ -1096,7 +1087,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): ...@@ -1096,7 +1087,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
expert_offset=layer.ep_rank * layer.local_num_experts, expert_offset=layer.ep_rank * layer.local_num_experts,
local_num_experts=layer.local_num_experts, local_num_experts=layer.local_num_experts,
block_shape=self.weight_block_size, block_shape=self.weight_block_size,
routing_method_type=routing_method_type, routing_method_type=layer.routing_method_type,
routed_scaling=layer.routed_scaling_factor, routed_scaling=layer.routed_scaling_factor,
) )
else: else:
......
...@@ -26,7 +26,6 @@ from vllm.model_executor.layers.fused_moe import ( ...@@ -26,7 +26,6 @@ from vllm.model_executor.layers.fused_moe import (
) )
from vllm.model_executor.layers.fused_moe.config import ( from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig, FusedMoEQuantConfig,
RoutingMethodType,
) )
from vllm.model_executor.layers.fused_moe.layer import UnquantizedFusedMoEMethod from vllm.model_executor.layers.fused_moe.layer import UnquantizedFusedMoEMethod
from vllm.model_executor.layers.fused_moe.oracle.fp8 import ( from vllm.model_executor.layers.fused_moe.oracle.fp8 import (
...@@ -990,17 +989,9 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -990,17 +989,9 @@ class Fp8MoEMethod(FusedMoEMethodBase):
if self.block_quant: if self.block_quant:
import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe # noqa: E501, F401 import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe # noqa: E501, F401
e_score_correction_bias = (
layer.e_score_correction_bias.to(x.dtype)
if layer.e_score_correction_bias is not None
else None
)
routing_method_type = layer.routing_method_type
return torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8( return torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8(
routing_logits=router_logits.to(torch.float32) routing_logits=router_logits,
if routing_method_type == RoutingMethodType.DeepSeekV3 routing_bias=layer.e_score_correction_bias,
else router_logits,
routing_bias=e_score_correction_bias,
x=x, x=x,
w13_weight=layer.w13_weight, w13_weight=layer.w13_weight,
w13_weight_scale_inv=layer.w13_weight_scale_inv, w13_weight_scale_inv=layer.w13_weight_scale_inv,
...@@ -1014,7 +1005,7 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -1014,7 +1005,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
expert_offset=layer.ep_rank * layer.local_num_experts, expert_offset=layer.ep_rank * layer.local_num_experts,
local_num_experts=layer.local_num_experts, local_num_experts=layer.local_num_experts,
block_shape=self.weight_block_size, block_shape=self.weight_block_size,
routing_method_type=routing_method_type, routing_method_type=layer.routing_method_type,
routed_scaling=layer.routed_scaling_factor, routed_scaling=layer.routed_scaling_factor,
) )
else: else:
......
...@@ -107,6 +107,7 @@ class MiniMaxM2MoE(nn.Module): ...@@ -107,6 +107,7 @@ class MiniMaxM2MoE(nn.Module):
renormalize=True, renormalize=True,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.experts", prefix=f"{prefix}.experts",
router_logits_dtype=torch.float32,
) )
self.gate = ReplicatedLinear( self.gate = ReplicatedLinear(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment