Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Fairseq
Commits
376c265f
Commit
376c265f
authored
Oct 02, 2017
by
Myle Ott
Browse files
Add support for NCCL v2
parent
8bafae2e
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
68 additions
and
52 deletions
+68
-52
fairseq/nccl.py
fairseq/nccl.py
+68
-52
No files found.
fairseq/nccl.py
View file @
376c265f
...
@@ -12,9 +12,10 @@ GPU separately.
...
@@ -12,9 +12,10 @@ GPU separately.
"""
"""
import
ctypes
import
ctypes
import
warnings
from
ctypes.util
import
find_library
lib
=
None
lib
=
None
nccl_2_0
=
None
_uid
=
None
_uid
=
None
_rank
=
None
_rank
=
None
_num_devices
=
None
_num_devices
=
None
...
@@ -22,48 +23,25 @@ _comm = None
...
@@ -22,48 +23,25 @@ _comm = None
__all__
=
[
'all_reduce'
,
'initialize'
,
'get_unique_id'
]
__all__
=
[
'all_reduce'
,
'initialize'
,
'get_unique_id'
]
def
_libnccl
():
global
lib
if
lib
is
None
:
lib
=
ctypes
.
cdll
.
LoadLibrary
(
None
)
if
hasattr
(
lib
,
'ncclCommDestroy'
):
lib
.
ncclCommDestroy
.
restype
=
None
lib
.
ncclGetErrorString
.
restype
=
ctypes
.
c_char_p
else
:
lib
=
None
return
lib
def
is_available
(
tensors
):
devices
=
set
()
for
tensor
in
tensors
:
if
not
tensor
.
is_contiguous
():
return
False
if
not
tensor
.
is_cuda
:
return
False
device
=
tensor
.
get_device
()
if
device
in
devices
:
return
False
devices
.
add
(
device
)
if
_libnccl
()
is
None
:
warnings
.
warn
(
'NCCL library not found. Check your LD_LIBRARY_PATH'
)
return
False
return
True
_communicators
=
{}
# ncclDataType_t
# ncclDataType_t
ncclChar
=
0
nccl_types
=
{
ncclInt
=
1
'torch.cuda.ByteTensor'
:
0
,
ncclHalf
=
2
'torch.cuda.CharTensor'
:
0
,
ncclFloat
=
3
'torch.cuda.IntTensor'
:
1
,
ncclDouble
=
4
'torch.cuda.HalfTensor'
:
2
,
ncclInt64
=
5
'torch.cuda.FloatTensor'
:
3
,
ncclUint64
=
6
'torch.cuda.DoubleTensor'
:
4
,
'torch.cuda.LongTensor'
:
5
,
}
nccl_types_2_0
=
{
'torch.cuda.ByteTensor'
:
0
,
'torch.cuda.CharTensor'
:
0
,
'torch.cuda.IntTensor'
:
2
,
'torch.cuda.HalfTensor'
:
6
,
'torch.cuda.FloatTensor'
:
7
,
'torch.cuda.DoubleTensor'
:
8
,
'torch.cuda.LongTensor'
:
4
,
}
# ncclRedOp_t
# ncclRedOp_t
SUM
=
0
SUM
=
0
...
@@ -71,21 +49,57 @@ PROD = 1
...
@@ -71,21 +49,57 @@ PROD = 1
MAX
=
2
MAX
=
2
MIN
=
3
MIN
=
3
nccl_types
=
{
status_codes_2_0
=
{
'torch.cuda.ByteTensor'
:
ncclChar
,
0
:
"Success"
,
'torch.cuda.CharTensor'
:
ncclChar
,
1
:
"Unhandled Cuda Error"
,
'torch.cuda.IntTensor'
:
ncclInt
,
2
:
"System Error"
,
'torch.cuda.HalfTensor'
:
ncclHalf
,
3
:
"Internal Error"
,
'torch.cuda.FloatTensor'
:
ncclFloat
,
4
:
"Invalid Argument Error"
,
'torch.cuda.DoubleTensor'
:
ncclDouble
,
5
:
"Invalid Usage Error"
,
'torch.cuda.LongTensor'
:
ncclInt64
,
}
status_codes
=
{
0
:
"Success"
,
1
:
"Unhandled Cuda Error"
,
2
:
"System Error"
,
3
:
"Internal Error"
,
4
:
"Invalid Device Pointer"
,
5
:
"Invalid Rank"
,
6
:
"Unsupported Device Count"
,
7
:
"Device Not Found"
,
8
:
"Invalid Device Index"
,
9
:
"Lib Wrapper Not Set"
,
10
:
"Cuda Malloc Failed"
,
11
:
"Rank Mismatch"
,
12
:
"Invalid Argument"
,
13
:
"Invalid Type"
,
14
:
"Invalid Operation"
,
}
}
def
_libnccl
():
global
nccl_2_0
global
lib
global
status_codes
global
nccl_types
if
lib
is
None
:
lib
=
ctypes
.
pydll
.
LoadLibrary
(
find_library
(
'nccl'
))
if
hasattr
(
lib
,
'ncclCommDestroy'
):
lib
.
ncclCommDestroy
.
restype
=
None
else
:
lib
=
None
if
hasattr
(
lib
,
'ncclGroupStart'
):
nccl_2_0
=
True
status_codes
=
status_codes_2_0
nccl_types
=
nccl_types_2_0
return
lib
class
NcclError
(
RuntimeError
):
class
NcclError
(
RuntimeError
):
def
__init__
(
self
,
status
):
def
__init__
(
self
,
status
):
self
.
status
=
status
self
.
status
=
status
msg
=
'{0} ({1})'
.
format
(
lib
.
ncclGetErrorString
(
status
),
status
)
msg
=
'{0} ({1})'
.
format
(
status_codes
.
get
(
status
),
status
)
super
(
NcclError
,
self
).
__init__
(
msg
)
super
(
NcclError
,
self
).
__init__
(
msg
)
...
@@ -134,10 +148,12 @@ def initialize(num_devices, uid, rank):
...
@@ -134,10 +148,12 @@ def initialize(num_devices, uid, rank):
def
communicator
():
def
communicator
():
global
_comm
global
_comm
if
_libnccl
()
is
None
:
raise
RuntimeError
(
'Unable to load NCCL library'
)
if
_uid
is
None
:
if
_uid
is
None
:
raise
RuntimeError
(
'NCCL not initialized'
)
raise
RuntimeError
(
'NCCL not initialized'
)
if
_comm
is
None
:
if
_comm
is
None
:
comm
=
ctypes
.
c_void_p
()
comm
=
NcclComm
()
check_error
(
lib
.
ncclCommInitRank
(
check_error
(
lib
.
ncclCommInitRank
(
ctypes
.
byref
(
comm
),
ctypes
.
byref
(
comm
),
ctypes
.
c_int
(
_num_devices
),
ctypes
.
c_int
(
_num_devices
),
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment