Commit 376c265f authored by Myle Ott's avatar Myle Ott
Browse files

Add support for NCCL v2

parent 8bafae2e
......@@ -12,9 +12,10 @@ GPU separately.
"""
import ctypes
import warnings
from ctypes.util import find_library
lib = None
nccl_2_0 = None
_uid = None
_rank = None
_num_devices = None
......@@ -22,48 +23,25 @@ _comm = None
__all__ = ['all_reduce', 'initialize', 'get_unique_id']
def _libnccl():
global lib
if lib is None:
lib = ctypes.cdll.LoadLibrary(None)
if hasattr(lib, 'ncclCommDestroy'):
lib.ncclCommDestroy.restype = None
lib.ncclGetErrorString.restype = ctypes.c_char_p
else:
lib = None
return lib
def is_available(tensors):
devices = set()
for tensor in tensors:
if not tensor.is_contiguous():
return False
if not tensor.is_cuda:
return False
device = tensor.get_device()
if device in devices:
return False
devices.add(device)
if _libnccl() is None:
warnings.warn('NCCL library not found. Check your LD_LIBRARY_PATH')
return False
return True
_communicators = {}
# ncclDataType_t
ncclChar = 0
ncclInt = 1
ncclHalf = 2
ncclFloat = 3
ncclDouble = 4
ncclInt64 = 5
ncclUint64 = 6
nccl_types = {
'torch.cuda.ByteTensor': 0,
'torch.cuda.CharTensor': 0,
'torch.cuda.IntTensor': 1,
'torch.cuda.HalfTensor': 2,
'torch.cuda.FloatTensor': 3,
'torch.cuda.DoubleTensor': 4,
'torch.cuda.LongTensor': 5,
}
nccl_types_2_0 = {
'torch.cuda.ByteTensor': 0,
'torch.cuda.CharTensor': 0,
'torch.cuda.IntTensor': 2,
'torch.cuda.HalfTensor': 6,
'torch.cuda.FloatTensor': 7,
'torch.cuda.DoubleTensor': 8,
'torch.cuda.LongTensor': 4,
}
# ncclRedOp_t
SUM = 0
......@@ -71,21 +49,57 @@ PROD = 1
MAX = 2
MIN = 3
nccl_types = {
'torch.cuda.ByteTensor': ncclChar,
'torch.cuda.CharTensor': ncclChar,
'torch.cuda.IntTensor': ncclInt,
'torch.cuda.HalfTensor': ncclHalf,
'torch.cuda.FloatTensor': ncclFloat,
'torch.cuda.DoubleTensor': ncclDouble,
'torch.cuda.LongTensor': ncclInt64,
status_codes_2_0 = {
0: "Success",
1: "Unhandled Cuda Error",
2: "System Error",
3: "Internal Error",
4: "Invalid Argument Error",
5: "Invalid Usage Error",
}
status_codes = {
0: "Success",
1: "Unhandled Cuda Error",
2: "System Error",
3: "Internal Error",
4: "Invalid Device Pointer",
5: "Invalid Rank",
6: "Unsupported Device Count",
7: "Device Not Found",
8: "Invalid Device Index",
9: "Lib Wrapper Not Set",
10: "Cuda Malloc Failed",
11: "Rank Mismatch",
12: "Invalid Argument",
13: "Invalid Type",
14: "Invalid Operation",
}
def _libnccl():
global nccl_2_0
global lib
global status_codes
global nccl_types
if lib is None:
lib = ctypes.pydll.LoadLibrary(find_library('nccl'))
if hasattr(lib, 'ncclCommDestroy'):
lib.ncclCommDestroy.restype = None
else:
lib = None
if hasattr(lib, 'ncclGroupStart'):
nccl_2_0 = True
status_codes = status_codes_2_0
nccl_types = nccl_types_2_0
return lib
class NcclError(RuntimeError):
def __init__(self, status):
self.status = status
msg = '{0} ({1})'.format(lib.ncclGetErrorString(status), status)
msg = '{0} ({1})'.format(status_codes.get(status), status)
super(NcclError, self).__init__(msg)
......@@ -134,10 +148,12 @@ def initialize(num_devices, uid, rank):
def communicator():
global _comm
if _libnccl() is None:
raise RuntimeError('Unable to load NCCL library')
if _uid is None:
raise RuntimeError('NCCL not initialized')
if _comm is None:
comm = ctypes.c_void_p()
comm = NcclComm()
check_error(lib.ncclCommInitRank(
ctypes.byref(comm),
ctypes.c_int(_num_devices),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment