Unverified commit 0788ff0a, authored by haosdent, committed by GitHub
Browse files

[Bugfix] Gracefully disable AllReduceFusionPass on GPUs without multicast support (#35085)


Signed-off-by: haosdent <haosdent@gmail.com>
parent d72b0be3
...@@ -729,14 +729,26 @@ class AllReduceFusionPass(VllmPatternMatcherPass): ...@@ -729,14 +729,26 @@ class AllReduceFusionPass(VllmPatternMatcherPass):
scope="global", scope="global",
) )
self.workspace = flashinfer_comm.create_allreduce_fusion_workspace( try:
backend="trtllm", self.workspace = flashinfer_comm.create_allreduce_fusion_workspace(
world_size=self.tp_size, backend="trtllm",
rank=rank, world_size=self.tp_size,
max_token_num=self.max_token_num, rank=rank,
hidden_dim=self.hidden_dim, max_token_num=self.max_token_num,
dtype=self.model_dtype, hidden_dim=self.hidden_dim,
) dtype=self.model_dtype,
)
except RuntimeError as e:
if "multicast" not in str(e).lower():
raise
logger.warning_once(
"AllReduce fusion pass is disabled: flashinfer workspace "
"creation failed: %s. This is expected on GPUs without "
"NVSwitch (e.g., NVLink bridge-only or PCIe topologies). "
"Falling back to non-fused allreduce.",
str(e),
)
return
global _FI_WORKSPACE global _FI_WORKSPACE
_FI_WORKSPACE = self.workspace _FI_WORKSPACE = self.workspace
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment