Unverified Commit 747b1a71 authored by youkaichao's avatar youkaichao Committed by GitHub
Browse files

[Core][Distributed] fix _is_full_nvlink detection (#4233)

parent 95e5b087
import os
from contextlib import contextmanager
from typing import Optional
from typing import List, Optional
import torch
import torch.distributed as dist
......@@ -53,14 +54,20 @@ def init_custom_ar() -> None:
return False
# test nvlink first, this will filter out most of the cases
# where custom allreduce is not supported
full_nvlink = _is_full_nvlink(rank, world_size)
if "CUDA_VISIBLE_DEVICES" in os.environ:
device_ids = list(
map(int, os.environ["CUDA_VISIBLE_DEVICES"].split(",")))
else:
device_ids = list(range(num_dev))
# this checks hardware and driver support for NVLink
full_nvlink = _is_full_nvlink(device_ids)
if world_size > 2 and not full_nvlink:
logger.warn(
"Custom allreduce is disabled because it's not supported on more"
" than two PCIe-only GPUs. To silence this warning, specify"
" disable_custom_all_reduce=True explicitly.")
return
# test P2P capability
# test P2P capability, this checks software/cudaruntime support
# this is expensive to compute at the first time
# then we cache the result
if not _can_p2p(rank, world_size):
......@@ -138,22 +145,27 @@ def _nvml():
pynvml.nvmlShutdown()
# query if the set of gpus are fully connected by nvlink (1 hop)
@_nvml()
def _is_full_nvlink(rank, world_size):
handle = pynvml.nvmlDeviceGetHandleByIndex(rank)
for i in range(world_size):
if i != rank:
def _is_full_nvlink(device_ids: List[int]) -> bool:
"""
query if the set of gpus are fully connected by nvlink (1 hop)
Note that `pynvml` is not affected by `CUDA_VISIBLE_DEVICES`,
so it works on real physical device ids.
"""
handles = [pynvml.nvmlDeviceGetHandleByIndex(i) for i in device_ids]
for i, handle in enumerate(handles):
for j, peer_handle in enumerate(handles):
if i < j:
try:
peer_handle = pynvml.nvmlDeviceGetHandleByIndex(i)
p2p_status = pynvml.nvmlDeviceGetP2PStatus(
handle, peer_handle, pynvml.NVML_P2P_CAPS_INDEX_NVLINK)
if p2p_status != pynvml.NVML_P2P_STATUS_OK:
return False
except pynvml.NVMLError as error:
logger.info(
f"NVLink detection failed with message \"{str(error)}\". "
"This is normal if your machine has no NVLink equipped")
logger.error(
"NVLink detection failed. This is normal if your"
" machine has no NVLink equipped.",
exc_info=error)
return False
return True
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment