Unverified commit d816834c, authored by Jaewon and committed by GitHub
Browse files

[MoE] Add RoutingMethodType.Simulated to TRT-LLM FP8/NVFP4 kernel allowlists (#38329)


Signed-off-by: Jaewon Lee <jaewon@meta.com>
parent 92f0db57
...@@ -256,13 +256,18 @@ class TrtLlmFp8ExpertsMonolithic(TrtLlmFp8ExpertsBase, mk.FusedMoEExpertsMonolit ...@@ -256,13 +256,18 @@ class TrtLlmFp8ExpertsMonolithic(TrtLlmFp8ExpertsBase, mk.FusedMoEExpertsMonolit
) -> bool: ) -> bool:
""" """
The FlashInfer TRTLLM FP8 kernel expects bfloat16 router_logits by default. The FlashInfer TRTLLM FP8 kernel expects bfloat16 router_logits by default.
Only DeepSeekV3 routing supports float32 router_logits (which is converted DeepSeekV3 routing supports float32 router_logits (converted internally).
internally in the kernel). Simulated routing generates synthetic decisions and is agnostic to dtype.
""" """
if router_logits_dtype == torch.float32: if router_logits_dtype == torch.float32:
# Only DeepSeekV3 routing handles float32 logits # DeepSeekV3 routing handles float32 logits internally.
# Simulated routing generates synthetic decisions, so the
# kernel doesn't care about the actual logits dtype.
# https://github.com/flashinfer-ai/flashinfer/issues/2469 # https://github.com/flashinfer-ai/flashinfer/issues/2469
return routing_method == RoutingMethodType.DeepSeekV3 return routing_method in (
RoutingMethodType.DeepSeekV3,
RoutingMethodType.Simulated,
)
return True return True
@staticmethod @staticmethod
...@@ -288,12 +293,14 @@ class TrtLlmFp8ExpertsMonolithic(TrtLlmFp8ExpertsBase, mk.FusedMoEExpertsMonolit ...@@ -288,12 +293,14 @@ class TrtLlmFp8ExpertsMonolithic(TrtLlmFp8ExpertsBase, mk.FusedMoEExpertsMonolit
# NOTE(rob): potentially allow others here. This is a conservative list. # NOTE(rob): potentially allow others here. This is a conservative list.
return routing_method in [ return routing_method in [
RoutingMethodType.DeepSeekV3, RoutingMethodType.DeepSeekV3,
RoutingMethodType.Simulated,
] ]
elif (weight_key, activation_key) == (kFp8StaticTensorSym, kFp8StaticTensorSym): elif (weight_key, activation_key) == (kFp8StaticTensorSym, kFp8StaticTensorSym):
# NOTE(dbari): as above, potentially allow others here. # NOTE(dbari): as above, potentially allow others here.
return routing_method in [ return routing_method in [
RoutingMethodType.DeepSeekV3, RoutingMethodType.DeepSeekV3,
RoutingMethodType.Llama4, RoutingMethodType.Llama4,
RoutingMethodType.Simulated,
] ]
else: else:
raise ValueError("Unsupported quantization scheme.") raise ValueError("Unsupported quantization scheme.")
......
...@@ -255,6 +255,7 @@ class TrtLlmNvFp4ExpertsMonolithic( ...@@ -255,6 +255,7 @@ class TrtLlmNvFp4ExpertsMonolithic(
RoutingMethodType.Renormalize, RoutingMethodType.Renormalize,
RoutingMethodType.RenormalizeNaive, RoutingMethodType.RenormalizeNaive,
RoutingMethodType.Llama4, RoutingMethodType.Llama4,
RoutingMethodType.Simulated,
] ]
@staticmethod @staticmethod
...@@ -264,13 +265,18 @@ class TrtLlmNvFp4ExpertsMonolithic( ...@@ -264,13 +265,18 @@ class TrtLlmNvFp4ExpertsMonolithic(
) -> bool: ) -> bool:
""" """
The FlashInfer TRTLLM NvFp4 kernel expects bfloat16 router_logits by default. The FlashInfer TRTLLM NvFp4 kernel expects bfloat16 router_logits by default.
Only DeepSeekV3 routing supports float32 router_logits (which is converted DeepSeekV3 routing supports float32 router_logits (converted internally).
internally in the kernel). Simulated routing generates synthetic decisions and is agnostic to dtype.
""" """
if router_logits_dtype == torch.float32: if router_logits_dtype == torch.float32:
# Only DeepSeekV3 routing handles float32 logits # DeepSeekV3 routing handles float32 logits internally.
# Simulated routing generates synthetic decisions, so the
# kernel doesn't care about the actual logits dtype.
# https://github.com/flashinfer-ai/flashinfer/issues/2469 # https://github.com/flashinfer-ai/flashinfer/issues/2469
return routing_method == RoutingMethodType.DeepSeekV3 return routing_method in (
RoutingMethodType.DeepSeekV3,
RoutingMethodType.Simulated,
)
return True return True
def apply( def apply(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment