Unverified commit 0788ff0a, authored by haosdent, committed by GitHub
Browse files

[Bugfix] Gracefully disable AllReduceFusionPass on GPUs without multicast support (#35085)


Signed-off-by: haosdent <haosdent@gmail.com>
parent d72b0be3
...@@ -729,6 +729,7 @@ class AllReduceFusionPass(VllmPatternMatcherPass): ...@@ -729,6 +729,7 @@ class AllReduceFusionPass(VllmPatternMatcherPass):
scope="global", scope="global",
) )
try:
self.workspace = flashinfer_comm.create_allreduce_fusion_workspace( self.workspace = flashinfer_comm.create_allreduce_fusion_workspace(
backend="trtllm", backend="trtllm",
world_size=self.tp_size, world_size=self.tp_size,
...@@ -737,6 +738,17 @@ class AllReduceFusionPass(VllmPatternMatcherPass): ...@@ -737,6 +738,17 @@ class AllReduceFusionPass(VllmPatternMatcherPass):
hidden_dim=self.hidden_dim, hidden_dim=self.hidden_dim,
dtype=self.model_dtype, dtype=self.model_dtype,
) )
except RuntimeError as e:
if "multicast" not in str(e).lower():
raise
logger.warning_once(
"AllReduce fusion pass is disabled: flashinfer workspace "
"creation failed: %s. This is expected on GPUs without "
"NVSwitch (e.g., NVLink bridge-only or PCIe topologies). "
"Falling back to non-fused allreduce.",
str(e),
)
return
global _FI_WORKSPACE global _FI_WORKSPACE
_FI_WORKSPACE = self.workspace _FI_WORKSPACE = self.workspace
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment