Unverified Commit 0140eafb authored by Wei Zhao's avatar Wei Zhao Committed by GitHub
Browse files

[Bug] Fix FlashInfer allreduce fusion workspace uninitialized error (#37461)


Signed-off-by: default avatarroot <root@prenyx0169.a51.clusters.nvidia.com>
Signed-off-by: default avatarwzhao18 <wzhao18.sz@gmail.com>
Signed-off-by: <>
Co-authored-by: default avatarroot <root@prenyx0169.a51.clusters.nvidia.com>
Co-authored-by: default avatarroot <root@prenyx0042.a51.clusters.nvidia.com>
parent bdf6a0a5
......@@ -86,8 +86,6 @@ if flashinfer_comm is not None:
destroy_fi_ar_workspace,
get_fi_ar_quant_workspace,
get_fi_ar_workspace,
initialize_fi_ar_quant_workspace,
initialize_fi_ar_workspace,
)
ar_fusion_patterns = flashinfer_comm.AllReduceFusionPattern
......@@ -133,15 +131,23 @@ if flashinfer_comm is not None:
# Select workspace based on pattern: quant patterns use the
# trtllm quant workspace, non-quant patterns use the primary workspace.
if pattern_code in (
is_quant_pattern = pattern_code in (
ar_fusion_patterns.kARResidualRMSNormFP8Quant,
ar_fusion_patterns.kARResidualRMSNormFP4Quant,
):
workspace = get_fi_ar_quant_workspace()
else:
workspace = get_fi_ar_workspace()
)
get_workspace_fn = (
get_fi_ar_quant_workspace if is_quant_pattern else get_fi_ar_workspace
)
workspace = get_workspace_fn(
world_size=world_size,
rank=get_tensor_model_parallel_rank(),
max_token_num=max_token_num,
hidden_dim=hidden_size,
dtype=allreduce_in.dtype,
group=get_tp_group().device_group,
)
assert workspace is not None, (
"Flashinfer workspace must be initialized when using flashinfer"
"Flashinfer allreduce workspace must be initialized when using flashinfer"
)
assert flashinfer_comm is not None
if norm_out is None:
......@@ -753,35 +759,29 @@ class AllReduceFusionPass(VllmPatternMatcherPass):
scope="global",
)
for workspace_init_fn in [
initialize_fi_ar_workspace,
initialize_fi_ar_quant_workspace,
]:
try:
workspace_init_fn(
world_size=self.tp_size,
rank=rank,
max_token_num=self.max_token_num,
hidden_dim=self.hidden_dim,
dtype=self.model_dtype,
group=self.group,
)
except Exception as e:
if "multicast" in str(e).lower():
logger.warning(
"AllReduce fusion pass is disabled: flashinfer workspace "
"creation failed: %s. This is expected on GPUs without "
"NVSwitch (e.g., NVLink bridge-only or PCIe topologies). "
"Falling back to non-fused allreduce.",
str(e),
)
else:
logger.warning(
"Failed to initialize FlashInfer All Reduce workspace: %s. "
"AllReduce fusion pass will be disabled.",
e,
)
return
workspace_kwargs = dict(
world_size=self.tp_size,
rank=rank,
max_token_num=self.max_token_num,
hidden_dim=self.hidden_dim,
dtype=self.model_dtype,
group=self.group,
)
if get_fi_ar_workspace(**workspace_kwargs) is None:
logger.warning_once(
"Failed to initialize Flashinfer allreduce workspace. "
"Flashinfer allreduce-norm fusion will be disabled."
)
return
self.supports_quant_fusion = (
get_fi_ar_quant_workspace(**workspace_kwargs) is not None
)
if not self.supports_quant_fusion:
logger.warning_once(
"Failed to initialize Flashinfer allreduce workspace. "
"Flashinfer allreduce-norm-quant fusion will be disabled."
)
self.allreduce_params = FlashInferFusedAllReduceParams(
world_size=self.tp_size,
......@@ -793,9 +793,8 @@ class AllReduceFusionPass(VllmPatternMatcherPass):
@enable_fake_mode
def register_patterns(self) -> None:
supports_quantization = get_fi_ar_quant_workspace() is not None
for epsilon in [1e-5, 1e-6]:
if supports_quantization:
if self.supports_quant_fusion:
AllReduceFusedRMSNormStaticQuantFP8Pattern(
epsilon,
self.model_dtype,
......
......@@ -29,50 +29,27 @@ try:
except ImportError:
pass
# Global workspace for standalone allreduce and non-quant ar+rms fusion
# Workspace for standalone allreduce and non-quant ar+rms fusion
_fi_ar_workspace = None
# Extra workspace for quant fusion patterns (only supported by trtllm backend)
# Only created if primary workspace is not already trtllm
_fi_ar_quant_workspace = None
def get_fi_ar_workspace():
return _fi_ar_workspace
def get_fi_ar_quant_workspace():
return _fi_ar_quant_workspace
def initialize_fi_ar_workspace(
def _create_workspace(
backend: str,
world_size: int,
rank: int,
max_token_num: int,
hidden_dim: int,
dtype: torch.dtype,
group: ProcessGroup,
) -> None:
"""
Initialize the workspace if not already initialized.
Currently, this function is called by either the AllReduceFusionPass
or the FlashInferAllReduce backend for standalone allreduce.
If the fusion pass is enabled via
--compilation-config.pass_config.fuse_allreduce_rms=true,
it will create the workspace first, and the standalone backend
will reuse the workspace. Otherwise, the standalone backend will
create the workspace.
"""
global _fi_ar_workspace
if _fi_ar_workspace is not None:
return
backend = envs.VLLM_FLASHINFER_ALLREDUCE_BACKEND
):
"""Create a flashinfer allreduce workspace, returning None on failure."""
comm_backend = TorchDistBackend(group=group)
rng_state = random.getstate()
try:
random.seed(int.from_bytes(os.urandom(16), byteorder="big"))
_fi_ar_workspace = flashinfer_comm.create_allreduce_fusion_workspace(
workspace = flashinfer_comm.create_allreduce_fusion_workspace(
backend=backend,
world_size=world_size,
rank=rank,
......@@ -81,9 +58,22 @@ def initialize_fi_ar_workspace(
dtype=dtype,
comm_backend=comm_backend,
)
except Exception as e:
if "multicast" in str(e).lower():
logger.warning_once(
"Failed to initialize FlashInfer All Reduce workspace: %s. "
"This is expected on GPUs without NVSwitch (e.g., NVLink "
"bridge-only or PCIe topologies).",
e,
)
else:
logger.warning_once(
"Failed to initialize FlashInfer All Reduce workspace: %s.",
e,
)
return None
finally:
random.setstate(rng_state)
assert _fi_ar_workspace is not None
logger.debug(
"Initialized FlashInfer All Reduce workspace: backend=%s, "
"world_size=%d, rank=%d, max_token_num=%d, hidden_dim=%d, dtype=%s",
......@@ -94,70 +84,84 @@ def initialize_fi_ar_workspace(
hidden_dim,
dtype,
)
return workspace
def get_fi_ar_workspace(
world_size: int,
rank: int,
max_token_num: int,
hidden_dim: int,
dtype: torch.dtype,
group: ProcessGroup,
):
"""
Return the allreduce workspace for non-quant patterns, initializing if needed.
Used by AllReduceFusionPass (non-quant patterns) and FlashInferAllReduce
for standalone allreduce. Backend is controlled by
VLLM_FLASHINFER_ALLREDUCE_BACKEND env var.
"""
global _fi_ar_workspace
if _fi_ar_workspace is not None:
return _fi_ar_workspace
backend = envs.VLLM_FLASHINFER_ALLREDUCE_BACKEND
# Reuse the quant workspace if it was already created with the same backend
if _fi_ar_quant_workspace is not None and _fi_ar_quant_workspace.backend == backend:
_fi_ar_workspace = _fi_ar_quant_workspace
return _fi_ar_workspace
_fi_ar_workspace = _create_workspace(
backend, world_size, rank, max_token_num, hidden_dim, dtype, group
)
return _fi_ar_workspace
def initialize_fi_ar_quant_workspace(
def get_fi_ar_quant_workspace(
world_size: int,
rank: int,
max_token_num: int,
hidden_dim: int,
dtype: torch.dtype,
group: ProcessGroup,
) -> None:
):
"""
Initialize the workspace used by quantization fusion patterns.
Return the allreduce workspace for quant patterns, initializing if needed.
Currently this always creates a workspace for trtllm backend as only it
supports quantization fusion (FP8/FP4). If the primary workspace
is already trtllm, the quant workspace aliases to it.
Always uses trtllm backend as it is the only one supporting quantization
fusion (FP8/FP4).
"""
global _fi_ar_quant_workspace
if _fi_ar_quant_workspace is not None:
return
return _fi_ar_quant_workspace
# If primary workspace is already trtllm, reuse it
# Reuse the non-quant workspace if it was already created with trtllm
if _fi_ar_workspace is not None and _fi_ar_workspace.backend == "trtllm":
_fi_ar_quant_workspace = _fi_ar_workspace
return
return _fi_ar_quant_workspace
comm_backend = TorchDistBackend(group=group)
_fi_ar_quant_workspace = flashinfer_comm.create_allreduce_fusion_workspace(
backend="trtllm",
world_size=world_size,
rank=rank,
max_token_num=max_token_num,
hidden_dim=hidden_dim,
dtype=dtype,
comm_backend=comm_backend,
)
assert _fi_ar_quant_workspace is not None
logger.debug(
"Initialized FlashInfer All Reduce workspace: backend=trtllm, "
"world_size=%d, rank=%d, max_token_num=%d, hidden_dim=%d, dtype=%s",
world_size,
rank,
max_token_num,
hidden_dim,
dtype,
_fi_ar_quant_workspace = _create_workspace(
"trtllm", world_size, rank, max_token_num, hidden_dim, dtype, group
)
return _fi_ar_quant_workspace
_fi_ar_workspace_lock = threading.Lock()
def destroy_fi_ar_workspace():
global _fi_ar_workspace
global _fi_ar_quant_workspace
global _fi_ar_workspace, _fi_ar_quant_workspace
with _fi_ar_workspace_lock:
if (
_fi_ar_quant_workspace is not None
and _fi_ar_quant_workspace is not _fi_ar_workspace
):
_fi_ar_quant_workspace.destroy()
_fi_ar_quant_workspace = None
is_alias = _fi_ar_workspace is _fi_ar_quant_workspace
if _fi_ar_workspace is not None:
_fi_ar_workspace.destroy()
_fi_ar_workspace = None
if _fi_ar_quant_workspace is not None and not is_alias:
_fi_ar_quant_workspace.destroy()
_fi_ar_workspace = _fi_ar_quant_workspace = None
atexit.register(destroy_fi_ar_workspace)
......@@ -209,29 +213,21 @@ class FlashInferAllReduce:
def _ensure_workspace(self, hidden_dim: int, dtype: torch.dtype) -> bool:
"""Ensure the all reduce workspace is initialized."""
if get_fi_ar_workspace() is not None:
return True
if self.max_num_tokens == 0:
element_size = torch.tensor([], dtype=dtype, device="cpu").element_size()
self.max_num_tokens = self.max_workspace_size // (hidden_dim * element_size)
try:
initialize_fi_ar_workspace(
world_size=self.world_size,
rank=self.rank,
max_token_num=self.max_num_tokens,
hidden_dim=hidden_dim,
dtype=dtype,
group=self.group,
)
return True
except Exception as e:
logger.warning(
"Failed to initialize FlashInfer All Reduce workspace: %s. "
"FlashInfer All Reduce will be disabled.",
e,
)
workspace = get_fi_ar_workspace(
world_size=self.world_size,
rank=self.rank,
max_token_num=self.max_num_tokens,
hidden_dim=hidden_dim,
dtype=dtype,
group=self.group,
)
if workspace is None:
self.disabled = True
return False
return True
def should_use_fi_ar(self, input_tensor: torch.Tensor) -> bool:
if self.disabled:
......@@ -257,7 +253,15 @@ class FlashInferAllReduce:
return self._ensure_workspace(hidden_dim, input_tensor.dtype)
def all_reduce(self, input_tensor: torch.Tensor) -> torch.Tensor:
workspace = get_fi_ar_workspace()
_, hidden_dim = input_tensor.shape
workspace = get_fi_ar_workspace(
world_size=self.world_size,
rank=self.rank,
max_token_num=self.max_num_tokens,
hidden_dim=hidden_dim,
dtype=input_tensor.dtype,
group=self.group,
)
return flashinfer_comm.allreduce_fusion(
input=input_tensor,
workspace=workspace,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment