Unverified commit 9c53dad8, authored by Jue WANG and committed by GitHub
Browse files

Fix MTP MoE weight loading with NVFP4 target model. (#10758)

parent 7ca1bea6
@@ -575,7 +575,10 @@ class FusedMoE(torch.nn.Module):
 )
 # Flashinfer assumes w31 format for w13_weight. Same for the scales.
-if should_use_flashinfer_trtllm_moe():
+if (
+    should_use_flashinfer_trtllm_moe()
+    and self.quant_method.__class__.__name__ == "ModelOptNvFp4FusedMoEMethod"
+):
     shard_id = {"w1": "w3", "w3": "w1", "w2": "w2"}[shard_id]
 WEIGHT_SCALE_SUPPORTED = [e.value for e in FusedMoeWeightScaleSupported]
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment