@@ -1048,6 +1055,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             self.problem_sizes1,
             self.problem_sizes2,
             use_fp8_blockscale=True,
+            output=symm_output,
         )
         return StandardCombineInput(hidden_states=output)
...
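For context, a minimal sketch of the pattern this hunk introduces: the caller pre-allocates the result buffer (symmetric-memory-backed in the real code, allocated in the elided part of the diff) and hands it to the grouped kernel via `output=`, so the kernel writes directly into the buffer the combine/all-reduce step will read. The function and the generic `kernel` parameter below are hypothetical; `torch.empty_like` stands in for the symmetric-memory allocator.

```python
import torch

def run_moe_kernel_sketch(kernel, a, problem_sizes1, problem_sizes2):
    # Real code allocates this from the symmetric-memory pool so the
    # downstream all-reduce can consume it in place; empty_like is a stand-in.
    symm_output = torch.empty_like(a)
    kernel(
        a,
        problem_sizes1,
        problem_sizes2,
        use_fp8_blockscale=True,
        output=symm_output,  # kernel writes its result here, no extra copy
    )
    return symm_output  # then wrapped as StandardCombineInput(hidden_states=...)
```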
@@ -1211,7 +1219,12 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                 else topk_config.correction_bias.to(x.dtype)
             )
-            return trtllm_fp8_block_scale_moe(
+            with use_symmetric_memory(get_tp_group()) as sm:
+                # FIXME: there is a bug in trtllm_fp8_block_scale_moe:
+                # it ignores the `output` argument. https://github.com/flashinfer-ai/flashinfer/blob/da01b1bd8f9f22aec8c0eea189ad54860b034947/flashinfer/fused_moe/core.py#L1323-L1325
+                # So we put the whole call under the `use_symmetric_memory` context manager.
+                # Once the bug is fixed, only the output tensor allocation needs to be under it.
+                output = trtllm_fp8_block_scale_moe(
                     routing_logits=router_logits.to(torch.float32),
                     routing_bias=correction_bias,
                     hidden_states=a_q,
...
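A hedged sketch of the workaround the FIXME describes, assuming allocations made while `use_symmetric_memory` is active come from a symmetric-memory pool; the stubs below are stand-ins, not the real helpers or kernel:

```python
import contextlib
import torch

@contextlib.contextmanager
def use_symmetric_memory_stub(group):
    # Stand-in for the real context manager, which routes CUDA allocations
    # made inside the block to a symmetric-memory pool registered for `group`.
    yield None

def buggy_moe_kernel(x, output=None):
    # Mimics the flashinfer bug: `output` is ignored and the result tensor
    # is allocated inside the kernel.
    return x * 2.0

x = torch.randn(4, 8)

# Current workaround: the whole call runs under the context manager so the
# kernel's internal allocation lands in symmetric memory.
with use_symmetric_memory_stub(group=None) as sm:
    out = buggy_moe_kernel(x)

# After the upstream fix, only the output allocation needs the context:
with use_symmetric_memory_stub(group=None) as sm:
    out = torch.empty_like(x)
out = buggy_moe_kernel(x, output=out)  # a fixed kernel writes into `out`
```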
@@ -1236,6 +1249,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
+            # FIXME: same issue as in trtllm_fp8_block_scale_moe above: the `output` argument is ignored, so the whole call sits under the `use_symmetric_memory` context manager. https://github.com/flashinfer-ai/flashinfer/blob/da01b1bd8f9f22aec8c0eea189ad54860b034947/flashinfer/fused_moe/core.py#L1323-L1325
+            # Once the bug is fixed, only the output tensor allocation needs to be under it.
             output = trtllm_fp8_per_tensor_scale_moe(
                 routing_logits=routing_logits_cast,
                 routing_bias=routing_bias_cast,
...
@@ -676,12 +684,15 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
help="Enable using torch symm mem for all-reduce kernel and fall back to NCCL. Only supports CUDA device SM90 and above. SM90 supports world size 4, 6, 8. SM10 supports world size 6, 8.",
help="Enable using torch symm mem for all-reduce kernel and fall back to NCCL. Only supports CUDA device SM90 and above. SM90 supports world size 4, 6, 8. SM100 supports world size 6, 8.",