Unverified Commit 31fb6f43 authored by pkousha's avatar pkousha Committed by GitHub
Browse files

[Kernel][perf] optimize NCCL symm_mem vs custom_AR selection thresholds (#33839)



Signed-off-by: <>
Signed-off-by: default avatarpkousha <43781676+pkousha@users.noreply.github.com>
Co-authored-by: default avatarPouya Kousha <pkousha@login-eos01.eos.clusters.nvidia.com>
Co-authored-by: default avatargemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: default avatarMichael Goin <mgoin64@gmail.com>
parent eb19955c
...@@ -27,6 +27,7 @@ from vllm.utils.torch_utils import cuda_device_count_stateless ...@@ -27,6 +27,7 @@ from vllm.utils.torch_utils import cuda_device_count_stateless
logger = init_logger(__name__) logger = init_logger(__name__)
KiB = 1024
MiB = 1024 * 1024 MiB = 1024 * 1024
# Max size for each world size in case symmetric memory is available # Max size for each world size in case symmetric memory is available
# For different SM architectures # For different SM architectures
...@@ -60,17 +61,44 @@ SYMM_MEM_ALL_REDUCE_MAX_SIZES = { ...@@ -60,17 +61,44 @@ SYMM_MEM_ALL_REDUCE_MAX_SIZES = {
}, },
} }
# NCCL symmetric memory allreduce configuration based on H100 and GB200 benchmarks.
# PyNCCL-symm outperforms custom_AR for small and large tensor sizes,
# while custom_AR wins for mid-range sizes.
#
# Benchmark results (8 GPUs):
# 2K - 16K: PyNCCL-symm wins (1.35x - 1.48x faster)
# 32K - 64K: custom_AR wins
# 128K - 1G: PyNCCL-symm wins (1.12x - 6.14x faster)
#
# Benchmark results (4 GPUs):
# 2K - 16K: PyNCCL-symm wins (1.21x - 1.30x faster)
# 32K - 256K: custom_AR wins (1.07x - 1.35x faster)
# 512K - 1G: PyNCCL-symm wins (1.10x - 2.32x faster)
#
# The config defines ranges where custom_AR is preferred (symm_mem disabled).
NCCL_SYMM_MEM_ALL_REDUCE_CONFIG: dict[str, Any] = { NCCL_SYMM_MEM_ALL_REDUCE_CONFIG: dict[str, Any] = {
"min_world_size": 4, "min_world_size": 4,
"thresholds": { # Ranges where custom_AR outperforms NCCL symm_mem: (lower_bound, upper_bound)
4: 2 * MiB, # 2 MB # NCCL symm_mem will NOT be used for sizes in range: lower < size < upper
8: 1 * MiB, # 1 MB "custom_ar_preferred_ranges": {
4: (16 * KiB, 512 * KiB), # custom_AR wins for 32K-256K
8: (16 * KiB, 128 * KiB), # custom_AR wins for 32K-64K
}, },
"always_use_above_world_size": 8, # Always use symm mem for world_size > 8 "always_use_above_world_size": 8, # Always use symm mem for world_size > 8
} }
def should_nccl_symm_mem_allreduce(world_size: int, input_tensor: torch.Tensor) -> bool: def should_nccl_symm_mem_allreduce(world_size: int, input_tensor: torch.Tensor) -> bool:
"""
Determine if NCCL symmetric memory allreduce should be used.
Based on H100 and GB200 benchmarks, NCCL symm_mem is preferred for:
- Small tensors (≤16K): Lower latency than custom_AR
- Large tensors (≥128K for 8 GPUs, ≥512K for 4 GPUs): Better bandwidth
Custom_AR is preferred for mid-range sizes where its P2P approach
has lower overhead than the symm_mem copy-in/copy-out pattern.
"""
from vllm.distributed.device_communicators.pynccl_allocator import ( from vllm.distributed.device_communicators.pynccl_allocator import (
is_symmetric_memory_enabled, is_symmetric_memory_enabled,
) )
...@@ -80,11 +108,20 @@ def should_nccl_symm_mem_allreduce(world_size: int, input_tensor: torch.Tensor) ...@@ -80,11 +108,20 @@ def should_nccl_symm_mem_allreduce(world_size: int, input_tensor: torch.Tensor)
if not is_symmetric_memory_enabled(): if not is_symmetric_memory_enabled():
return False return False
if world_size < NCCL_SYMM_MEM_ALL_REDUCE_CONFIG["min_world_size"]: if world_size < NCCL_SYMM_MEM_ALL_REDUCE_CONFIG["min_world_size"]:
return False return False
threshold = NCCL_SYMM_MEM_ALL_REDUCE_CONFIG["thresholds"].get(world_size)
if threshold is not None and input_tensor.nbytes >= threshold: tensor_size = input_tensor.nbytes
return True custom_ar_range = NCCL_SYMM_MEM_ALL_REDUCE_CONFIG["custom_ar_preferred_ranges"].get(
world_size
)
if custom_ar_range is not None:
lower_bound, upper_bound = custom_ar_range
# Use symm_mem for small sizes (≤ lower_bound) and large sizes (≥ upper_bound)
# Use custom_AR (not symm_mem) for mid-range sizes
return tensor_size <= lower_bound or tensor_size >= upper_bound
return world_size > NCCL_SYMM_MEM_ALL_REDUCE_CONFIG["always_use_above_world_size"] return world_size > NCCL_SYMM_MEM_ALL_REDUCE_CONFIG["always_use_above_world_size"]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment