"vllm/vscode:/vscode.git/clone" did not exist on "c530e2cfe3b3d7e60130ff817cee7f3a395af232"
Unverified commit d816834c, authored by Jaewon, committed by GitHub
Browse files

[MoE] Add RoutingMethodType.Simulated to TRT-LLM FP8/NVFP4 kernel allowlists (#38329)


Signed-off-by: Jaewon Lee <jaewon@meta.com>
parent 92f0db57
......@@ -256,13 +256,18 @@ class TrtLlmFp8ExpertsMonolithic(TrtLlmFp8ExpertsBase, mk.FusedMoEExpertsMonolit
) -> bool:
"""
The FlashInfer TRTLLM FP8 kernel expects bfloat16 router_logits by default.
Only DeepSeekV3 routing supports float32 router_logits (which is converted
internally in the kernel).
DeepSeekV3 routing supports float32 router_logits (converted internally).
Simulated routing generates synthetic decisions and is agnostic to dtype.
"""
if router_logits_dtype == torch.float32:
# Only DeepSeekV3 routing handles float32 logits
# DeepSeekV3 routing handles float32 logits internally.
# Simulated routing generates synthetic decisions, so the
# kernel doesn't care about the actual logits dtype.
# https://github.com/flashinfer-ai/flashinfer/issues/2469
return routing_method == RoutingMethodType.DeepSeekV3
return routing_method in (
RoutingMethodType.DeepSeekV3,
RoutingMethodType.Simulated,
)
return True
@staticmethod
......@@ -288,12 +293,14 @@ class TrtLlmFp8ExpertsMonolithic(TrtLlmFp8ExpertsBase, mk.FusedMoEExpertsMonolit
# NOTE(rob): potentially allow others here. This is a conservative list.
return routing_method in [
RoutingMethodType.DeepSeekV3,
RoutingMethodType.Simulated,
]
elif (weight_key, activation_key) == (kFp8StaticTensorSym, kFp8StaticTensorSym):
# NOTE(dbari): as above, potentially allow others here.
return routing_method in [
RoutingMethodType.DeepSeekV3,
RoutingMethodType.Llama4,
RoutingMethodType.Simulated,
]
else:
raise ValueError("Unsupported quantization scheme.")
......
......@@ -255,6 +255,7 @@ class TrtLlmNvFp4ExpertsMonolithic(
RoutingMethodType.Renormalize,
RoutingMethodType.RenormalizeNaive,
RoutingMethodType.Llama4,
RoutingMethodType.Simulated,
]
@staticmethod
......@@ -264,13 +265,18 @@ class TrtLlmNvFp4ExpertsMonolithic(
) -> bool:
"""
The FlashInfer TRTLLM NvFp4 kernel expects bfloat16 router_logits by default.
Only DeepSeekV3 routing supports float32 router_logits (which is converted
internally in the kernel).
DeepSeekV3 routing supports float32 router_logits (converted internally).
Simulated routing generates synthetic decisions and is agnostic to dtype.
"""
if router_logits_dtype == torch.float32:
# Only DeepSeekV3 routing handles float32 logits
# DeepSeekV3 routing handles float32 logits internally.
# Simulated routing generates synthetic decisions, so the
# kernel doesn't care about the actual logits dtype.
# https://github.com/flashinfer-ai/flashinfer/issues/2469
return routing_method == RoutingMethodType.DeepSeekV3
return routing_method in (
RoutingMethodType.DeepSeekV3,
RoutingMethodType.Simulated,
)
return True
def apply(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment