Unverified commit 0904b655, authored by Wei Zhao, committed by GitHub
Browse files

Fix multi-node allreduce fusion (#38136)


Signed-off-by: wzhao18 <wzhao18.sz@gmail.com>
Co-authored-by: root <root@theia0053.lyris.clusters.nvidia.com>
parent f26fcdfb
...@@ -13,11 +13,13 @@ from torch.distributed import ProcessGroup ...@@ -13,11 +13,13 @@ from torch.distributed import ProcessGroup
import vllm.envs as envs import vllm.envs as envs
from vllm.config.compilation import PassConfig from vllm.config.compilation import PassConfig
from vllm.distributed.parallel_state import get_node_count
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import current_platform from vllm.platforms import current_platform
logger = init_logger(__name__) logger = init_logger(__name__)
fi_ar_available = False fi_ar_available = False
try: try:
import flashinfer.comm as flashinfer_comm # type: ignore[no-redef] import flashinfer.comm as flashinfer_comm # type: ignore[no-redef]
...@@ -87,6 +89,27 @@ def _create_workspace( ...@@ -87,6 +89,27 @@ def _create_workspace(
return workspace return workspace
def _resolve_fi_ar_backend() -> str:
    """Return the FlashInfer allreduce backend name to use.

    Reads ``envs.VLLM_FLASHINFER_ALLREDUCE_BACKEND``. Any value other than
    ``"auto"`` is treated as an explicit choice and returned unchanged.
    For ``"auto"``, the backend is picked from the node count:
    ``"mnnvl"`` when more than one node is in use, ``"trtllm"`` otherwise.
    The chosen backend is logged once in either case.
    """
    backend = envs.VLLM_FLASHINFER_ALLREDUCE_BACKEND

    if backend == "auto":
        # Multi-node: use mnnvl, since the trtllm backend does not support
        # multi-node allreduce.
        # Single-node: currently default to trtllm, since mnnvl has issues
        # with cudagraph:
        # https://github.com/vllm-project/vllm/issues/35772
        # Should switch back to auto when the issue is resolved.
        backend = "mnnvl" if get_node_count() > 1 else "trtllm"
        logger.info_once(f"Auto-selected flashinfer allreduce backend: {backend}")
        return backend

    # An explicit backend was requested; honor it as-is.
    logger.info_once(f"Using flashinfer allreduce backend: {backend}")
    return backend
def get_fi_ar_workspace( def get_fi_ar_workspace(
world_size: int, world_size: int,
rank: int, rank: int,
...@@ -106,7 +129,13 @@ def get_fi_ar_workspace( ...@@ -106,7 +129,13 @@ def get_fi_ar_workspace(
if _fi_ar_workspace is not None: if _fi_ar_workspace is not None:
return _fi_ar_workspace return _fi_ar_workspace
backend = envs.VLLM_FLASHINFER_ALLREDUCE_BACKEND backend = _resolve_fi_ar_backend()
if get_node_count() > 1 and backend == "trtllm":
raise ValueError(
"Flashinfer allreduce is not supported for multi-node allreduce with "
"'trtllm' backend. Please use 'mnnvl' backend instead."
)
# Reuse the quant workspace if it was already created with the same backend # Reuse the quant workspace if it was already created with the same backend
if _fi_ar_quant_workspace is not None and _fi_ar_quant_workspace.backend == backend: if _fi_ar_quant_workspace is not None and _fi_ar_quant_workspace.backend == backend:
...@@ -116,6 +145,17 @@ def get_fi_ar_workspace( ...@@ -116,6 +145,17 @@ def get_fi_ar_workspace(
_fi_ar_workspace = _create_workspace( _fi_ar_workspace = _create_workspace(
backend, world_size, rank, max_token_num, hidden_dim, dtype, group backend, world_size, rank, max_token_num, hidden_dim, dtype, group
) )
if _fi_ar_workspace is not None:
logger.info_once(
"Initialized FlashInfer Allreduce norm fusion workspace "
f"with backend={backend}"
)
else:
logger.warning_once(
"Failed to initialize FlashInfer Allreduce norm fusion workspace "
f"with backend={backend}"
)
return _fi_ar_workspace return _fi_ar_workspace
...@@ -131,12 +171,20 @@ def get_fi_ar_quant_workspace( ...@@ -131,12 +171,20 @@ def get_fi_ar_quant_workspace(
Return the allreduce workspace for quant patterns, initializing if needed. Return the allreduce workspace for quant patterns, initializing if needed.
Always uses trtllm backend as it is the only one supporting quantization Always uses trtllm backend as it is the only one supporting quantization
fusion (FP8/FP4). fusion (FP8/FP4). Returns None for multi-node setups since not supported
by trtllm backend.
""" """
global _fi_ar_quant_workspace global _fi_ar_quant_workspace
if _fi_ar_quant_workspace is not None: if _fi_ar_quant_workspace is not None:
return _fi_ar_quant_workspace return _fi_ar_quant_workspace
if get_node_count() > 1:
logger.warning_once(
"Flashinfer allreduce quantization fusion is not supported for "
"multi-node allreduce. Disabling quant fusion."
)
return None
# Reuse the non-quant workspace if it was already created with trtllm # Reuse the non-quant workspace if it was already created with trtllm
if _fi_ar_workspace is not None and _fi_ar_workspace.backend == "trtllm": if _fi_ar_workspace is not None and _fi_ar_workspace.backend == "trtllm":
_fi_ar_quant_workspace = _fi_ar_workspace _fi_ar_quant_workspace = _fi_ar_workspace
...@@ -145,6 +193,17 @@ def get_fi_ar_quant_workspace( ...@@ -145,6 +193,17 @@ def get_fi_ar_quant_workspace(
_fi_ar_quant_workspace = _create_workspace( _fi_ar_quant_workspace = _create_workspace(
"trtllm", world_size, rank, max_token_num, hidden_dim, dtype, group "trtllm", world_size, rank, max_token_num, hidden_dim, dtype, group
) )
if _fi_ar_quant_workspace is not None:
logger.info_once(
"Initialized FlashInfer Allreduce norm quantization "
"fusion workspace with backend=trtllm"
)
else:
logger.warning_once(
"Failed to initialize FlashInfer Allreduce norm quantization "
"fusion workspace with backend=trtllm"
)
return _fi_ar_quant_workspace return _fi_ar_quant_workspace
......
...@@ -169,7 +169,7 @@ if TYPE_CHECKING: ...@@ -169,7 +169,7 @@ if TYPE_CHECKING:
VLLM_FLASHINFER_MOE_BACKEND: Literal["throughput", "latency", "masked_gemm"] = ( VLLM_FLASHINFER_MOE_BACKEND: Literal["throughput", "latency", "masked_gemm"] = (
"latency" "latency"
) )
VLLM_FLASHINFER_ALLREDUCE_BACKEND: Literal["auto", "trtllm", "mnnvl"] = "trtllm" VLLM_FLASHINFER_ALLREDUCE_BACKEND: Literal["auto", "trtllm", "mnnvl"] = "auto"
VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE: int = 394 * 1024 * 1024 VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE: int = 394 * 1024 * 1024
VLLM_XGRAMMAR_CACHE_MB: int = 0 VLLM_XGRAMMAR_CACHE_MB: int = 0
VLLM_MSGPACK_ZERO_COPY_THRESHOLD: int = 256 VLLM_MSGPACK_ZERO_COPY_THRESHOLD: int = 256
...@@ -1305,14 +1305,9 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1305,14 +1305,9 @@ environment_variables: dict[str, Callable[[], Any]] = {
["throughput", "latency", "masked_gemm"], ["throughput", "latency", "masked_gemm"],
), ),
# Flashinfer fused allreduce backend. # Flashinfer fused allreduce backend.
# "auto" will default to "mnnvl", which performs mostly same/better than "trtllm".
# But "mnnvl" backend does not support fuse with quantization.
# TODO: Default is "trtllm" right now because "mnnvl" has issues with cudagraph:
# https://github.com/vllm-project/vllm/issues/35772
# Should switch back to "auto" if the issue is resolved.
"VLLM_FLASHINFER_ALLREDUCE_BACKEND": env_with_choices( "VLLM_FLASHINFER_ALLREDUCE_BACKEND": env_with_choices(
"VLLM_FLASHINFER_ALLREDUCE_BACKEND", "VLLM_FLASHINFER_ALLREDUCE_BACKEND",
"trtllm", "auto",
["auto", "trtllm", "mnnvl"], ["auto", "trtllm", "mnnvl"],
), ),
# Control the workspace buffer size for the FlashInfer backend. # Control the workspace buffer size for the FlashInfer backend.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment