Unverified Commit 0140eafb authored by Wei Zhao's avatar Wei Zhao Committed by GitHub
Browse files

[Bug] Fix FlashInfer allreduce fusion workspace uninitialized error (#37461)


Signed-off-by: default avatarroot <root@prenyx0169.a51.clusters.nvidia.com>
Signed-off-by: default avatarwzhao18 <wzhao18.sz@gmail.com>
Signed-off-by: <>
Co-authored-by: default avatarroot <root@prenyx0169.a51.clusters.nvidia.com>
Co-authored-by: default avatarroot <root@prenyx0042.a51.clusters.nvidia.com>
parent bdf6a0a5
...@@ -86,8 +86,6 @@ if flashinfer_comm is not None: ...@@ -86,8 +86,6 @@ if flashinfer_comm is not None:
destroy_fi_ar_workspace, destroy_fi_ar_workspace,
get_fi_ar_quant_workspace, get_fi_ar_quant_workspace,
get_fi_ar_workspace, get_fi_ar_workspace,
initialize_fi_ar_quant_workspace,
initialize_fi_ar_workspace,
) )
ar_fusion_patterns = flashinfer_comm.AllReduceFusionPattern ar_fusion_patterns = flashinfer_comm.AllReduceFusionPattern
...@@ -133,15 +131,23 @@ if flashinfer_comm is not None: ...@@ -133,15 +131,23 @@ if flashinfer_comm is not None:
# Select workspace based on pattern: quant patterns use the # Select workspace based on pattern: quant patterns use the
# trtllm quant workspace, non-quant patterns use the primary workspace. # trtllm quant workspace, non-quant patterns use the primary workspace.
if pattern_code in ( is_quant_pattern = pattern_code in (
ar_fusion_patterns.kARResidualRMSNormFP8Quant, ar_fusion_patterns.kARResidualRMSNormFP8Quant,
ar_fusion_patterns.kARResidualRMSNormFP4Quant, ar_fusion_patterns.kARResidualRMSNormFP4Quant,
): )
workspace = get_fi_ar_quant_workspace() get_workspace_fn = (
else: get_fi_ar_quant_workspace if is_quant_pattern else get_fi_ar_workspace
workspace = get_fi_ar_workspace() )
workspace = get_workspace_fn(
world_size=world_size,
rank=get_tensor_model_parallel_rank(),
max_token_num=max_token_num,
hidden_dim=hidden_size,
dtype=allreduce_in.dtype,
group=get_tp_group().device_group,
)
assert workspace is not None, ( assert workspace is not None, (
"Flashinfer workspace must be initialized when using flashinfer" "Flashinfer allreduce workspace must be initialized when using flashinfer"
) )
assert flashinfer_comm is not None assert flashinfer_comm is not None
if norm_out is None: if norm_out is None:
...@@ -753,12 +759,7 @@ class AllReduceFusionPass(VllmPatternMatcherPass): ...@@ -753,12 +759,7 @@ class AllReduceFusionPass(VllmPatternMatcherPass):
scope="global", scope="global",
) )
for workspace_init_fn in [ workspace_kwargs = dict(
initialize_fi_ar_workspace,
initialize_fi_ar_quant_workspace,
]:
try:
workspace_init_fn(
world_size=self.tp_size, world_size=self.tp_size,
rank=rank, rank=rank,
max_token_num=self.max_token_num, max_token_num=self.max_token_num,
...@@ -766,23 +767,22 @@ class AllReduceFusionPass(VllmPatternMatcherPass): ...@@ -766,23 +767,22 @@ class AllReduceFusionPass(VllmPatternMatcherPass):
dtype=self.model_dtype, dtype=self.model_dtype,
group=self.group, group=self.group,
) )
except Exception as e: if get_fi_ar_workspace(**workspace_kwargs) is None:
if "multicast" in str(e).lower(): logger.warning_once(
logger.warning( "Failed to initialize Flashinfer allreduce workspace. "
"AllReduce fusion pass is disabled: flashinfer workspace " "Flashinfer allreduce-norm fusion will be disabled."
"creation failed: %s. This is expected on GPUs without "
"NVSwitch (e.g., NVLink bridge-only or PCIe topologies). "
"Falling back to non-fused allreduce.",
str(e),
)
else:
logger.warning(
"Failed to initialize FlashInfer All Reduce workspace: %s. "
"AllReduce fusion pass will be disabled.",
e,
) )
return return
self.supports_quant_fusion = (
get_fi_ar_quant_workspace(**workspace_kwargs) is not None
)
if not self.supports_quant_fusion:
logger.warning_once(
"Failed to initialize Flashinfer allreduce workspace. "
"Flashinfer allreduce-norm-quant fusion will be disabled."
)
self.allreduce_params = FlashInferFusedAllReduceParams( self.allreduce_params = FlashInferFusedAllReduceParams(
world_size=self.tp_size, world_size=self.tp_size,
max_token_num=self.max_token_num, max_token_num=self.max_token_num,
...@@ -793,9 +793,8 @@ class AllReduceFusionPass(VllmPatternMatcherPass): ...@@ -793,9 +793,8 @@ class AllReduceFusionPass(VllmPatternMatcherPass):
@enable_fake_mode @enable_fake_mode
def register_patterns(self) -> None: def register_patterns(self) -> None:
supports_quantization = get_fi_ar_quant_workspace() is not None
for epsilon in [1e-5, 1e-6]: for epsilon in [1e-5, 1e-6]:
if supports_quantization: if self.supports_quant_fusion:
AllReduceFusedRMSNormStaticQuantFP8Pattern( AllReduceFusedRMSNormStaticQuantFP8Pattern(
epsilon, epsilon,
self.model_dtype, self.model_dtype,
......
...@@ -29,50 +29,27 @@ try: ...@@ -29,50 +29,27 @@ try:
except ImportError: except ImportError:
pass pass
# Global workspace for standalone allreduce and non-quant ar+rms fusion # Workspace for standalone allreduce and non-quant ar+rms fusion
_fi_ar_workspace = None _fi_ar_workspace = None
# Extra workspace for quant fusion patterns (only supported by trtllm backend) # Extra workspace for quant fusion patterns (only supported by trtllm backend)
# Only created if primary workspace is not already trtllm
_fi_ar_quant_workspace = None _fi_ar_quant_workspace = None
def get_fi_ar_workspace(): def _create_workspace(
return _fi_ar_workspace backend: str,
def get_fi_ar_quant_workspace():
return _fi_ar_quant_workspace
def initialize_fi_ar_workspace(
world_size: int, world_size: int,
rank: int, rank: int,
max_token_num: int, max_token_num: int,
hidden_dim: int, hidden_dim: int,
dtype: torch.dtype, dtype: torch.dtype,
group: ProcessGroup, group: ProcessGroup,
) -> None: ):
""" """Create a flashinfer allreduce workspace, returning None on failure."""
Initialize the workspace if not already initialized.
Currently, this function is called by either the AllReduceFusionPass
or the FlashInferAllReduce backend for standalone allreduce.
If the fusion pass is enabled via
--compilation-config.pass_config.fuse_allreduce_rms=true,
it will create the workspace first, and the standalone backend
will reuse the workspace. Otherwise, the standalone backend will
create the workspace.
"""
global _fi_ar_workspace
if _fi_ar_workspace is not None:
return
backend = envs.VLLM_FLASHINFER_ALLREDUCE_BACKEND
comm_backend = TorchDistBackend(group=group) comm_backend = TorchDistBackend(group=group)
rng_state = random.getstate() rng_state = random.getstate()
try: try:
random.seed(int.from_bytes(os.urandom(16), byteorder="big")) random.seed(int.from_bytes(os.urandom(16), byteorder="big"))
_fi_ar_workspace = flashinfer_comm.create_allreduce_fusion_workspace( workspace = flashinfer_comm.create_allreduce_fusion_workspace(
backend=backend, backend=backend,
world_size=world_size, world_size=world_size,
rank=rank, rank=rank,
...@@ -81,9 +58,22 @@ def initialize_fi_ar_workspace( ...@@ -81,9 +58,22 @@ def initialize_fi_ar_workspace(
dtype=dtype, dtype=dtype,
comm_backend=comm_backend, comm_backend=comm_backend,
) )
except Exception as e:
if "multicast" in str(e).lower():
logger.warning_once(
"Failed to initialize FlashInfer All Reduce workspace: %s. "
"This is expected on GPUs without NVSwitch (e.g., NVLink "
"bridge-only or PCIe topologies).",
e,
)
else:
logger.warning_once(
"Failed to initialize FlashInfer All Reduce workspace: %s.",
e,
)
return None
finally: finally:
random.setstate(rng_state) random.setstate(rng_state)
assert _fi_ar_workspace is not None
logger.debug( logger.debug(
"Initialized FlashInfer All Reduce workspace: backend=%s, " "Initialized FlashInfer All Reduce workspace: backend=%s, "
"world_size=%d, rank=%d, max_token_num=%d, hidden_dim=%d, dtype=%s", "world_size=%d, rank=%d, max_token_num=%d, hidden_dim=%d, dtype=%s",
...@@ -94,70 +84,84 @@ def initialize_fi_ar_workspace( ...@@ -94,70 +84,84 @@ def initialize_fi_ar_workspace(
hidden_dim, hidden_dim,
dtype, dtype,
) )
return workspace
def initialize_fi_ar_quant_workspace( def get_fi_ar_workspace(
world_size: int, world_size: int,
rank: int, rank: int,
max_token_num: int, max_token_num: int,
hidden_dim: int, hidden_dim: int,
dtype: torch.dtype, dtype: torch.dtype,
group: ProcessGroup, group: ProcessGroup,
) -> None: ):
""" """
Initialize the workspace used by quantization fusion patterns. Return the allreduce workspace for non-quant patterns, initializing if needed.
Currently this always creates a workspace for trtllm backend as only it Used by AllReduceFusionPass (non-quant patterns) and FlashInferAllReduce
supports quantization fusion (FP8/FP4). If the primary workspace for standalone allreduce. Backend is controlled by
is already trtllm, the quant workspace aliases to it. VLLM_FLASHINFER_ALLREDUCE_BACKEND env var.
"""
global _fi_ar_workspace
if _fi_ar_workspace is not None:
return _fi_ar_workspace
backend = envs.VLLM_FLASHINFER_ALLREDUCE_BACKEND
# Reuse the quant workspace if it was already created with the same backend
if _fi_ar_quant_workspace is not None and _fi_ar_quant_workspace.backend == backend:
_fi_ar_workspace = _fi_ar_quant_workspace
return _fi_ar_workspace
_fi_ar_workspace = _create_workspace(
backend, world_size, rank, max_token_num, hidden_dim, dtype, group
)
return _fi_ar_workspace
def get_fi_ar_quant_workspace(
world_size: int,
rank: int,
max_token_num: int,
hidden_dim: int,
dtype: torch.dtype,
group: ProcessGroup,
):
"""
Return the allreduce workspace for quant patterns, initializing if needed.
Always uses trtllm backend as it is the only one supporting quantization
fusion (FP8/FP4).
""" """
global _fi_ar_quant_workspace global _fi_ar_quant_workspace
if _fi_ar_quant_workspace is not None: if _fi_ar_quant_workspace is not None:
return return _fi_ar_quant_workspace
# If primary workspace is already trtllm, reuse it # Reuse the non-quant workspace if it was already created with trtllm
if _fi_ar_workspace is not None and _fi_ar_workspace.backend == "trtllm": if _fi_ar_workspace is not None and _fi_ar_workspace.backend == "trtllm":
_fi_ar_quant_workspace = _fi_ar_workspace _fi_ar_quant_workspace = _fi_ar_workspace
return return _fi_ar_quant_workspace
comm_backend = TorchDistBackend(group=group) _fi_ar_quant_workspace = _create_workspace(
_fi_ar_quant_workspace = flashinfer_comm.create_allreduce_fusion_workspace( "trtllm", world_size, rank, max_token_num, hidden_dim, dtype, group
backend="trtllm",
world_size=world_size,
rank=rank,
max_token_num=max_token_num,
hidden_dim=hidden_dim,
dtype=dtype,
comm_backend=comm_backend,
)
assert _fi_ar_quant_workspace is not None
logger.debug(
"Initialized FlashInfer All Reduce workspace: backend=trtllm, "
"world_size=%d, rank=%d, max_token_num=%d, hidden_dim=%d, dtype=%s",
world_size,
rank,
max_token_num,
hidden_dim,
dtype,
) )
return _fi_ar_quant_workspace
_fi_ar_workspace_lock = threading.Lock() _fi_ar_workspace_lock = threading.Lock()
def destroy_fi_ar_workspace(): def destroy_fi_ar_workspace():
global _fi_ar_workspace global _fi_ar_workspace, _fi_ar_quant_workspace
global _fi_ar_quant_workspace
with _fi_ar_workspace_lock: with _fi_ar_workspace_lock:
if ( is_alias = _fi_ar_workspace is _fi_ar_quant_workspace
_fi_ar_quant_workspace is not None
and _fi_ar_quant_workspace is not _fi_ar_workspace
):
_fi_ar_quant_workspace.destroy()
_fi_ar_quant_workspace = None
if _fi_ar_workspace is not None: if _fi_ar_workspace is not None:
_fi_ar_workspace.destroy() _fi_ar_workspace.destroy()
_fi_ar_workspace = None if _fi_ar_quant_workspace is not None and not is_alias:
_fi_ar_quant_workspace.destroy()
_fi_ar_workspace = _fi_ar_quant_workspace = None
atexit.register(destroy_fi_ar_workspace) atexit.register(destroy_fi_ar_workspace)
...@@ -209,13 +213,10 @@ class FlashInferAllReduce: ...@@ -209,13 +213,10 @@ class FlashInferAllReduce:
def _ensure_workspace(self, hidden_dim: int, dtype: torch.dtype) -> bool: def _ensure_workspace(self, hidden_dim: int, dtype: torch.dtype) -> bool:
"""Ensure the all reduce workspace is initialized.""" """Ensure the all reduce workspace is initialized."""
if get_fi_ar_workspace() is not None:
return True
if self.max_num_tokens == 0: if self.max_num_tokens == 0:
element_size = torch.tensor([], dtype=dtype, device="cpu").element_size() element_size = torch.tensor([], dtype=dtype, device="cpu").element_size()
self.max_num_tokens = self.max_workspace_size // (hidden_dim * element_size) self.max_num_tokens = self.max_workspace_size // (hidden_dim * element_size)
try: workspace = get_fi_ar_workspace(
initialize_fi_ar_workspace(
world_size=self.world_size, world_size=self.world_size,
rank=self.rank, rank=self.rank,
max_token_num=self.max_num_tokens, max_token_num=self.max_num_tokens,
...@@ -223,15 +224,10 @@ class FlashInferAllReduce: ...@@ -223,15 +224,10 @@ class FlashInferAllReduce:
dtype=dtype, dtype=dtype,
group=self.group, group=self.group,
) )
return True if workspace is None:
except Exception as e:
logger.warning(
"Failed to initialize FlashInfer All Reduce workspace: %s. "
"FlashInfer All Reduce will be disabled.",
e,
)
self.disabled = True self.disabled = True
return False return False
return True
def should_use_fi_ar(self, input_tensor: torch.Tensor) -> bool: def should_use_fi_ar(self, input_tensor: torch.Tensor) -> bool:
if self.disabled: if self.disabled:
...@@ -257,7 +253,15 @@ class FlashInferAllReduce: ...@@ -257,7 +253,15 @@ class FlashInferAllReduce:
return self._ensure_workspace(hidden_dim, input_tensor.dtype) return self._ensure_workspace(hidden_dim, input_tensor.dtype)
def all_reduce(self, input_tensor: torch.Tensor) -> torch.Tensor: def all_reduce(self, input_tensor: torch.Tensor) -> torch.Tensor:
workspace = get_fi_ar_workspace() _, hidden_dim = input_tensor.shape
workspace = get_fi_ar_workspace(
world_size=self.world_size,
rank=self.rank,
max_token_num=self.max_num_tokens,
hidden_dim=hidden_dim,
dtype=input_tensor.dtype,
group=self.group,
)
return flashinfer_comm.allreduce_fusion( return flashinfer_comm.allreduce_fusion(
input=input_tensor, input=input_tensor,
workspace=workspace, workspace=workspace,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment