OpenDAS / apex · Commit 208d9670 (unverified)
Authored Jul 21, 2022 by Thor Johnsen; committed by GitHub on Jul 21, 2022
Parents: a29a698f, f687e7fa

Merge pull request #1429 from NVIDIA/update_spatial_bottleneck

Bug fixes, perf improvements
Showing 6 changed files with 139 additions and 138 deletions (+139 −138)
apex/contrib/bottleneck/bottleneck_module_test.py  +18 −14
apex/contrib/bottleneck/halo_exchangers.py         +38 −33
apex/contrib/csrc/nccl_p2p/nccl_p2p.cpp            +0 −2
apex/contrib/csrc/nccl_p2p/nccl_p2p_cuda.cu        +55 −65
apex/contrib/csrc/nccl_p2p/nccl_p2p_cuda.cuh       +6 −19
apex/contrib/peer_memory/peer_memory.py            +22 −5
apex/contrib/bottleneck/bottleneck_module_test.py

@@ -152,16 +152,6 @@ def spatial_parallel_bottleneck(C, dtype, explicit_nhwc, gt_bottleneck, spatial_
         sb[n].copy_(b)
     return spatial_bottleneck

-#class HaloExchangerNoComm(HaloExchanger):
-#    def __init__(self, world_size, spatial_group_size, rank, comm):
-#class HaloExchangerAllGather(HaloExchanger):
-#    def __init__(self, world_size, spatial_group_size, rank, comm):
-#class HaloExchangerSendRecv(HaloExchanger):
-#    def __init__(self, world_size, spatial_group_size, rank, comm):
-#class HaloExchangerPeer(HaloExchanger):
-#    def __init__(self, world_size, spatial_group_size, rank, comm, peer_pool, explicit_nhwc, numSM=1):
 def n_way_spatial(halex, gt_bottleneck, gt, explicit_nhwc, world_size, rank, fp32_reduce=False):
     assert(explicit_nhwc), "Only tested for explicit nhwc"

@@ -228,15 +218,29 @@ def main():
     #print_bottleneck_p_and_b(gt_bottleneck)
     #print_bottleneck_p_and_b(spatial_bottleneck)
+    group_size = world_size
+    group = rank // group_size
+    ranks = [group*group_size+i for i in range(group_size)]
+    rank_in_group = rank % group_size
     spatial_group_size = world_size
     spatial_communicator = None
-    peer_pool = PeerMemoryPool(rank, world_size, spatial_group_size, 64*1024*1024, 2*1024*1024)
+    peer_pool = PeerMemoryPool(64*1024*1024, 2*1024*1024, ranks)
+    #class HaloExchangerNoComm(HaloExchanger):
+    #    def __init__(self, ranks, rank_in_group):
+    #class HaloExchangerAllGather(HaloExchanger):
+    #    def __init__(self, ranks, rank_in_group, comm):
+    #class HaloExchangerSendRecv(HaloExchanger):
+    #    def __init__(self, ranks, rank_in_group):
+    #class HaloExchangerPeer(HaloExchanger):
+    #    def __init__(self, ranks, rank_in_group, peer_pool, explicit_nhwc, numSM=1):
-    #halex = HaloExchangerAllGather(world_size, spatial_group_size, rank, spatial_communicator)
+    #halex = HaloExchangerAllGather(ranks, rank_in_group)
-    #halex = HaloExchangerSendRecv(world_size, spatial_group_size, rank, spatial_communicator)
+    #halex = HaloExchangerSendRecv(ranks, rank_in_group)
-    halex = HaloExchangerPeer(world_size, spatial_group_size, rank, spatial_communicator, peer_pool, explicit_nhwc, numSM=1)
+    halex = HaloExchangerPeer(ranks, rank_in_group, peer_pool, explicit_nhwc, numSM=1)
     #print("halex.signals = %s" % (str(halex.signals)))
     # Make sure peer memory halo exchanger has finished initializing flags on all ranks before proceeding
     #torch.cuda.synchronize()
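Taken together, the test changes above show the new construction pattern: call sites stop passing (rank, world_size, spatial_group_size, comm) and instead hand over an explicit rank list plus the caller's index into it. A minimal sketch of that setup, assuming torch.distributed has already been initialized with one process per GPU and that the import paths below match your apex build:

    import torch
    from apex.contrib.peer_memory import PeerMemoryPool            # import path assumed
    from apex.contrib.bottleneck.halo_exchangers import HaloExchangerPeer

    def build_peer_halo_exchanger(explicit_nhwc=True):
        # One spatial group spanning all ranks, as in the updated test above.
        rank = torch.distributed.get_rank()
        world_size = torch.distributed.get_world_size()
        group_size = world_size
        group = rank // group_size
        ranks = [group * group_size + i for i in range(group_size)]
        rank_in_group = rank % group_size
        # PeerMemoryPool now only needs pool sizes plus the participating ranks.
        peer_pool = PeerMemoryPool(64 * 1024 * 1024, 2 * 1024 * 1024, ranks)
        # HaloExchangerPeer drops the world_size/rank/communicator arguments.
        return HaloExchangerPeer(ranks, rank_in_group, peer_pool, explicit_nhwc, numSM=1)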
apex/contrib/bottleneck/halo_exchangers.py

@@ -9,17 +9,23 @@ import peer_memory_cuda as pm
 # NB! This is only useful for performance testing.
 # NB! Do not use for actual production runs
 class HaloExchanger(object):
-    def __init__(self, spatial_group_size, rank):
+    def __init__(self, ranks, rank_in_group):
         self.stream1 = torch.cuda.Stream()
         self.stream2 = torch.cuda.Stream()
         self.stream3 = torch.cuda.Stream()
-        spatial_rank = rank % spatial_group_size
-        self.left_zero = True if spatial_rank == 0 else False
-        self.right_zero = True if spatial_rank == spatial_group_size - 1 else False
+        self.group_size = len(ranks)
+        self.ranks = ranks
+        self.rank_in_group = rank_in_group
+        self.wrap_around_left_rank_in_group = (rank_in_group + self.group_size - 1) % self.group_size
+        self.wrap_around_right_rank_in_group = (rank_in_group + 1) % self.group_size
+        self.left_rank = ranks[rank_in_group-1] if rank_in_group > 0 else -1
+        self.left_zero = True if rank_in_group == 0 else False
+        self.right_rank = ranks[rank_in_group+1] if rank_in_group < self.group_size - 1 else -1
+        self.right_zero = True if rank_in_group == self.group_size - 1 else False

 class HaloExchangerNoComm(HaloExchanger):
-    def __init__(self, world_size, spatial_group_size, rank, comm):
-        super(HaloExchangerNoComm, self).__init__(spatial_group_size, rank)
+    def __init__(self, ranks, rank_in_group):
+        super(HaloExchangerNoComm, self).__init__(ranks, rank_in_group)
     def left_right_halo_exchange(self, left_output_halo, right_output_halo, left_input_halo=None, right_input_halo=None):
         if left_input_halo is None:

@@ -29,10 +35,9 @@ class HaloExchangerNoComm(HaloExchanger):
             right_input_halo.copy_(left_output_halo)

 class HaloExchangerAllGather(HaloExchanger):
-    def __init__(self, world_size, spatial_group_size, rank, comm):
-        super(HaloExchangerAllGather, self).__init__(spatial_group_size, rank)
-        self.spatial_group_size = spatial_group_size
-        self.local_rank = rank % spatial_group_size
+    def __init__(self, ranks, rank_in_group, comm):
+        super(HaloExchangerAllGather, self).__init__(ranks, rank_in_group)
+        # self.comm must be NCCL process_group created with torch.distributed.new_group(ranks=ranks)
         self.comm = comm
     def left_right_halo_exchange(self, left_output_halo, right_output_halo, left_input_halo=None, right_input_halo=None):

@@ -40,11 +45,11 @@ class HaloExchangerAllGather(HaloExchanger):
         send_halos = torch.empty((N,2*Hh,W,C),dtype=left_output_halo.dtype,device=left_output_halo.device)
         send_halos[:,:Hh,:,:].copy_(left_output_halo)
         send_halos[:,Hh:,:,:].copy_(right_output_halo)
-        all_halos = torch.empty((N,2*Hh*self.spatial_group_size,W,C),dtype=left_output_halo.dtype,device=left_output_halo.device)
-        all_halos = [all_halos[:,i*2*Hh:(i+1)*2*Hh,:,:] for i in range(self.spatial_group_size)]
+        all_halos = torch.empty((N,2*Hh*self.group_size,W,C),dtype=left_output_halo.dtype,device=left_output_halo.device)
+        all_halos = [all_halos[:,i*2*Hh:(i+1)*2*Hh,:,:] for i in range(self.group_size)]
         torch.distributed.all_gather(all_halos,send_halos,group=self.comm,no_copy=True)
-        ag_left_input_halo = all_halos[(self.spatial_group_size+self.local_rank-1)%self.spatial_group_size][:,Hh:,:,:]
-        ag_right_input_halo = all_halos[(self.local_rank+1)%self.spatial_group_size][:,:Hh,:,:]
+        ag_left_input_halo = all_halos[self.wrap_around_left_rank_in_group][:,Hh:,:,:]
+        ag_right_input_halo = all_halos[self.wrap_around_right_rank_in_group][:,:Hh,:,:]
         if left_input_halo is None:
             if self.left_zero:
                 ag_left_input_halo.zero_()

@@ -62,35 +67,35 @@ class HaloExchangerAllGather(HaloExchanger):
             right_input_halo.copy_(ag_right_input_halo)

 class HaloExchangerSendRecv(HaloExchanger):
-    def __init__(self, world_size, spatial_group_size, rank, comm):
-        super(HaloExchangerSendRecv, self).__init__(spatial_group_size, rank)
-        self.world_size = world_size
-        self.spatial_group_size = spatial_group_size
+    def __init__(self, ranks, rank_in_group):
+        super(HaloExchangerSendRecv, self).__init__(ranks, rank_in_group)
         nccl_id = inc.get_unique_nccl_id(1).cuda()
         torch.distributed.broadcast(nccl_id, 0)
         nccl_id = nccl_id.cpu()
-        self.handle = inc.init_nccl_comm(nccl_id, rank, world_size)
+        print("%d :: nccl_id = %s" % (torch.distributed.get_rank(), str(nccl_id)))
+        # Create another global nccl communicator in addition to the one created by torch.distributed.init_process_group("nccl")
+        # This is unavoidable because the underlying NCCL communicator torch.distributed creates is a protected variable, hence
+        # it cannot be accessed from another class.
+        # TODO: Figure out a way to avoid creating a second global communicator
+        assert(torch.distributed.get_rank() == self.ranks[self.rank_in_group]), "ranks[%d](%d) != torch.distributed.get_rank()(%d)" % (self.rank_in_group, self.ranks[self.rank_in_group], torch.distributed.get_rank())
+        self.handle = inc.init_nccl_comm(nccl_id, torch.distributed.get_rank(), torch.distributed.get_world_size())
     def left_right_halo_exchange(self, left_output_halo, right_output_halo, left_input_halo=None, right_input_halo=None):
         if left_input_halo is None:
-            left_input_halo, right_input_halo = inc.left_right_halo_exchange(self.handle, self.left_zero, self.right_zero, left_output_halo, right_output_halo, self.spatial_group_size)
+            left_input_halo, right_input_halo = inc.left_right_halo_exchange(self.handle, self.left_rank, self.right_rank, left_output_halo, right_output_halo)
             return left_input_halo, right_input_halo
         else:
-            inc.left_right_halo_exchange_inplace(self.handle, self.left_zero, self.right_zero, left_output_halo, right_output_halo, left_input_halo, right_input_halo, self.spatial_group_size)
+            inc.left_right_halo_exchange_inplace(self.handle, self.left_rank, self.right_rank, left_output_halo, right_output_halo, left_input_halo, right_input_halo)

 class HaloExchangerPeer(HaloExchanger):
-    def __init__(self, world_size, spatial_group_size, rank, comm, peer_pool, explicit_nhwc, numSM=1):
-        super(HaloExchangerPeer, self).__init__(spatial_group_size, rank)
+    def __init__(self, ranks, rank_in_group, peer_pool, explicit_nhwc, numSM=1):
+        super(HaloExchangerPeer, self).__init__(ranks, rank_in_group)
         self.diagnostics = False
-        self.spatial_group_size = spatial_group_size
-        self.peer_rank = rank % spatial_group_size
-        self.left_neighbor = (self.peer_rank + self.spatial_group_size - 1) % self.spatial_group_size
-        self.right_neighbor = (self.peer_rank + 1) % self.spatial_group_size
-        self.peer_pool = peer_pool
-        self.signals = peer_pool.allocate_peer_tensors([2,4], torch.int32, False, False)
-        self.signals[self.peer_rank].zero_()
         self.explicit_nhwc = explicit_nhwc
         self.numSM = numSM
+        self.peer_pool = peer_pool
+        self.signals = peer_pool.allocate_peer_tensors([2,4], torch.int32, False, False)
+        self.signals[self.rank_in_group].zero_()
     def left_right_halo_exchange(self, left_output_halo, right_output_halo, left_input_halo=None, right_input_halo=None):
         inplace = False if left_input_halo is None and right_input_halo is None else True

@@ -102,9 +107,9 @@ class HaloExchangerPeer(HaloExchanger):
         right_tx = self.peer_pool.allocate_peer_tensors(list(right_output_halo.shape), right_output_halo.dtype, channels_last, True)
         pm.push_pull_halos_1d(
                 self.diagnostics, self.explicit_nhwc, self.numSM,
-                left_output_halo, left_tx[self.peer_rank], right_tx[self.left_neighbor], left_input_halo,
-                right_output_halo, right_tx[self.peer_rank], left_tx[self.right_neighbor], right_input_halo,
-                self.signals[self.left_neighbor], self.signals[self.right_neighbor], self.signals[self.peer_rank]
+                left_output_halo, left_tx[self.rank_in_group], right_tx[self.wrap_around_left_rank_in_group], left_input_halo,
+                right_output_halo, right_tx[self.rank_in_group], left_tx[self.wrap_around_right_rank_in_group], right_input_halo,
+                self.signals[self.wrap_around_left_rank_in_group], self.signals[self.wrap_around_right_rank_in_group], self.signals[self.rank_in_group]
                 )
         # TODO: Add to push_pull_halos_1d kernel
         if self.left_zero:
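The heart of the base-class change is the neighbor bookkeeping: each exchanger now derives its wrap-around peer indices and its real left/right ranks (with -1 marking an edge whose input halo gets zero-filled) from ranks and rank_in_group alone. A standalone sketch of that arithmetic, with hypothetical example values, for readers who don't want to trace it through the diff:

    def neighbor_info(ranks, rank_in_group):
        # Mirrors the bookkeeping in the new HaloExchanger.__init__ above (illustration only).
        group_size = len(ranks)
        wrap_left = (rank_in_group + group_size - 1) % group_size
        wrap_right = (rank_in_group + 1) % group_size
        # -1 means "no neighbor on that side"; the exchange then zero-fills that input halo.
        left_rank = ranks[rank_in_group - 1] if rank_in_group > 0 else -1
        right_rank = ranks[rank_in_group + 1] if rank_in_group < group_size - 1 else -1
        return wrap_left, wrap_right, left_rank, right_rank

    # Example with a hypothetical 4-rank spatial group made of global ranks 4..7:
    assert neighbor_info([4, 5, 6, 7], 0) == (3, 1, -1, 5)   # edge rank: left halo zeroed
    assert neighbor_info([4, 5, 6, 7], 2) == (1, 3, 5, 7)    # interior rank: both neighbors real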
apex/contrib/csrc/nccl_p2p/nccl_p2p.cpp

@@ -19,8 +19,6 @@
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
     m.def("get_unique_nccl_id", &apex::contrib::nccl_p2p::get_unique_nccl_id, "get_unique_nccl_id");
     m.def("init_nccl_comm", &apex::contrib::nccl_p2p::init_nccl_comm, "init_nccl_comm");
-    m.def("nccl_send", &apex::contrib::nccl_p2p::nccl_send, "nccl_send");
-    m.def("nccl_recv", &apex::contrib::nccl_p2p::nccl_recv, "nccl_recv");
     m.def("left_right_halo_exchange_inplace", &apex::contrib::nccl_p2p::left_right_halo_exchange_inplace, "left_right_halo_exchange_inplace");
     m.def("left_right_halo_exchange", &apex::contrib::nccl_p2p::left_right_halo_exchange, "left_right_halo_exchange");
     m.def("add_delay", &apex::contrib::nccl_p2p::add_delay, "add_delay");
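With nccl_send/nccl_recv gone, the extension's Python-facing surface is essentially communicator setup plus the two halo-exchange entry points. A hedged sketch of how HaloExchangerSendRecv in the diff above drives it, assuming the extension is importable as nccl_p2p_cuda (module name inferred from the file name, not confirmed here):

    import torch
    import nccl_p2p_cuda as inc   # module name assumed from the build

    # Rank 0 creates an NCCL id, all ranks agree on it, and each rank opens a second
    # global communicator managed inside the extension (returned as an int handle).
    nccl_id = inc.get_unique_nccl_id(1).cuda()
    torch.distributed.broadcast(nccl_id, 0)
    handle = inc.init_nccl_comm(nccl_id.cpu(), torch.distributed.get_rank(),
                                torch.distributed.get_world_size())

    # left_rank/right_rank are global ranks, or -1 at the edges (that halo is zero-filled):
    # left_in, right_in = inc.left_right_halo_exchange(handle, left_rank, right_rank,
    #                                                  left_out, right_out)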
apex/contrib/csrc/nccl_p2p/nccl_p2p_cuda.cu

@@ -80,75 +80,82 @@ class NcclCommWrapper
         ncclCommDestroy(comm);
     }
-    void send(at::Tensor input, int destination)
-    {
-        ncclDataType_t ncclType = get_nccl_type(input);
-        AT_DISPATCH_ALL_TYPES_AND3(at::ScalarType::Bool, at::ScalarType::BFloat16, at::ScalarType::Half, input.scalar_type(), "nccl_send", [&]() {
-            size_t count = sizeof(scalar_t) * torch::numel(input);
-            auto input_ptr = input.data_ptr<scalar_t>();
-            ncclSend(input_ptr, count, ncclType, destination, comm, at::cuda::getCurrentCUDAStream());
-            });
-    }
-    void recv(at::Tensor input, int sender)
-    {
-        ncclDataType_t ncclType = get_nccl_type(input);
-        AT_DISPATCH_ALL_TYPES_AND3(at::ScalarType::Bool, at::ScalarType::BFloat16, at::ScalarType::Half, input.scalar_type(), "nccl_send", [&]() {
-            size_t count = sizeof(scalar_t) * torch::numel(input);
-            auto input_ptr = input.data_ptr<scalar_t>();
-            ncclRecv(input_ptr, count, ncclType, sender, comm, at::cuda::getCurrentCUDAStream());
-            });
-    }
-    void left_right_halo_exchange_inplace(bool left_zero, bool right_zero, at::Tensor left_output_halo, at::Tensor right_output_halo, at::Tensor left_input_halo, at::Tensor right_input_halo, int group_size)
+    void left_right_halo_exchange_inplace(int left_rank, int right_rank, at::Tensor left_output_halo, at::Tensor right_output_halo, at::Tensor left_input_halo, at::Tensor right_input_halo)
     {
         auto stream = at::cuda::getCurrentCUDAStream();
         ncclGroupStart();
         ncclDataType_t ncclType = get_nccl_type(left_output_halo);
-        int group_rank = rank % group_size;
-        int group_index = rank / group_size;
-        int prev_rank = (group_rank + group_size - 1) % group_size;
-        int next_rank = (group_rank + 1) % group_size;
-        prev_rank = prev_rank + group_index * group_size;
-        next_rank = next_rank + group_index * group_size;
+        // we use wrap-around ranks, so left_input_halo of rank 0 has right_output_halo of rank world_size-1 after exchange etc.
+        // this is technically speaking wasteful, but there is no benefit in having the edge ranks do less work than internal ranks.
+        bool left_zero = (left_rank < 0);
+        bool right_zero = (right_rank < 0);
         size_t left_n = torch::numel(left_output_halo);
         size_t right_n = torch::numel(right_output_halo);
-        if (group_rank > 0) {
+        assert(left_n > 0 && left_n == right_n);
+        if (left_zero) {
+            left_input_halo.zero_();
+        } else {
             AT_DISPATCH_ALL_TYPES_AND3(at::ScalarType::Bool, at::ScalarType::BFloat16, at::ScalarType::Half, left_output_halo.scalar_type(), "left_halo_exch", [&]() {
                 // send left (to my_rank - 1)
-                ncclSend(left_output_halo.data_ptr<scalar_t>(), left_n, ncclType, prev_rank, comm, stream);
+                ncclSend(left_output_halo.data_ptr<scalar_t>(), left_n, ncclType, left_rank, comm, stream);
                 // receive left (from my_rank - 1)
-                ncclRecv(left_input_halo.data_ptr<scalar_t>(), right_n, ncclType, prev_rank, comm, stream);
+                ncclRecv(left_input_halo.data_ptr<scalar_t>(), right_n, ncclType, left_rank, comm, stream);
                 });
         }
-        if (group_rank < group_size - 1) {
+        if (right_zero) {
+            right_input_halo.zero_();
+        } else {
             AT_DISPATCH_ALL_TYPES_AND3(at::ScalarType::Bool, at::ScalarType::BFloat16, at::ScalarType::Half, right_output_halo.scalar_type(), "right_halo_exch", [&]() {
                 // send right (to my_rank + 1 )
-                ncclSend(right_output_halo.data_ptr<scalar_t>(), right_n, ncclType, next_rank, comm, stream);
+                ncclSend(right_output_halo.data_ptr<scalar_t>(), right_n, ncclType, right_rank, comm, stream);
                 // receive right (from my_rank + 1)
-                ncclRecv(right_input_halo.data_ptr<scalar_t>(), left_n, ncclType, next_rank, comm, stream);
+                ncclRecv(right_input_halo.data_ptr<scalar_t>(), left_n, ncclType, right_rank, comm, stream);
                 });
         }
         ncclGroupEnd();
-        if (left_zero) left_input_halo.zero_();
-        if (right_zero) right_input_halo.zero_();
     }
-    std::vector<at::Tensor> left_right_halo_exchange(bool left_zero, bool right_zero, at::Tensor left_output_halo, at::Tensor right_output_halo, int group_size)
+    std::vector<at::Tensor> left_right_halo_exchange(int left_rank, int right_rank, at::Tensor left_output_halo, at::Tensor right_output_halo)
     {
         // after halo exchange:
         // left_output_halo of rank+1 ends up in right_input_halo of rank
         // right_output_halo of rank-1 ends up in left_input_halo of rank
         auto right_input_halo = torch::empty_like(left_output_halo);
         auto left_input_halo = torch::empty_like(right_output_halo);
-        left_right_halo_exchange_inplace(left_zero, right_zero, left_output_halo, right_output_halo, left_input_halo, right_input_halo, group_size);
+        left_right_halo_exchange_inplace(left_rank, right_rank, left_output_halo, right_output_halo, left_input_halo, right_input_halo);
         return {left_input_halo, right_input_halo};
     }
 };
-std::vector<NcclCommWrapper> nccl_comms;
+class ManagedObjects
+{
+    public:
+        ManagedObjects() {}
+        ~ManagedObjects()
+        {
+            for (auto it = _nccl_comms.begin();  it != _nccl_comms.end();  ++it)
+            {
+                delete *it;
+            }
+        }
+        int add_comm(NcclCommWrapper* comm)
+        {
+            int handle = _nccl_comms.size();
+            _nccl_comms.push_back(comm);
+            return handle;
+        }
+        NcclCommWrapper& get_comm(int handle)
+        {
+            assert(handle >= 0 && handle < _nccl_comms.size());
+            return *_nccl_comms[handle];
+        }
+
+    private:
+        std::vector<NcclCommWrapper*> _nccl_comms;
+};
+class ManagedObjects mo;
 } // end anonymous namespace

@@ -158,7 +165,7 @@ at::Tensor get_unique_nccl_id(int n)
 {
     ncclUniqueId id;
     ncclGetUniqueId(&id);
-    auto id_tensor = torch::empty({n*(int)sizeof(ncclUniqueId)}, torch::dtype(torch::kUInt8).device(torch::kCPU).requires_grad(false));
+    auto id_tensor = torch::empty({n,(int)sizeof(ncclUniqueId)}, torch::dtype(torch::kUInt8).device(torch::kCPU).requires_grad(false));
     auto id_ptr = id_tensor.data_ptr<uint8_t>();
     size_t offset = 0;
     for (int i = 0;  i < n;  ++i)

@@ -177,38 +184,21 @@ int init_nccl_comm(at::Tensor unique_nccl_id, int my_rank, int num_ranks)
     auto unique_nccl_id_ptr = unique_nccl_id.data_ptr<uint8_t>();
     memcpy(&id, unique_nccl_id_ptr, sizeof(ncclUniqueId));
     NcclCommWrapper* comm = new NcclCommWrapper(id, my_rank, num_ranks);
-    int handle = nccl_comms.size();
-    nccl_comms.push_back(*comm);
+    int handle = mo.add_comm(comm);
     comm = 0L;
     return handle;
 }

-void nccl_send(int handle, at::Tensor input, int destination)
-{
-    assert(handle >= 0 && handle < nccl_comms.size());
-    class NcclCommWrapper communicator = nccl_comms[handle];
-    communicator.send(input, destination);
-}
-
-void nccl_recv(int handle, at::Tensor input, int sender)
-{
-    assert(handle >= 0 && handle < nccl_comms.size());
-    class NcclCommWrapper communicator = nccl_comms[handle];
-    communicator.recv(input, sender);
-}
-
-void left_right_halo_exchange_inplace(int handle, bool left_zero, bool right_zero, at::Tensor left_output_halo, at::Tensor right_output_halo, at::Tensor left_input_halo, at::Tensor right_input_halo, int group_size)
+void left_right_halo_exchange_inplace(int handle, int left_rank, int right_rank, at::Tensor left_output_halo, at::Tensor right_output_halo, at::Tensor left_input_halo, at::Tensor right_input_halo)
 {
-    assert(handle >= 0 && handle < nccl_comms.size());
-    class NcclCommWrapper &communicator = nccl_comms[handle];
-    return communicator.left_right_halo_exchange_inplace(left_zero, right_zero, left_output_halo, right_output_halo, left_input_halo, right_input_halo, group_size);
+    class NcclCommWrapper &communicator = mo.get_comm(handle);
+    return communicator.left_right_halo_exchange_inplace(left_rank, right_rank, left_output_halo, right_output_halo, left_input_halo, right_input_halo);
 }

-std::vector<at::Tensor> left_right_halo_exchange(int handle, bool left_zero, bool right_zero, at::Tensor left_output_halo, at::Tensor right_output_halo, int group_size)
+std::vector<at::Tensor> left_right_halo_exchange(int handle, int left_rank, int right_rank, at::Tensor left_output_halo, at::Tensor right_output_halo)
 {
-    assert(handle >= 0 && handle < nccl_comms.size());
-    class NcclCommWrapper &communicator = nccl_comms[handle];
-    return communicator.left_right_halo_exchange(left_zero, right_zero, left_output_halo, right_output_halo, group_size);
+    class NcclCommWrapper &communicator = mo.get_comm(handle);
+    return communicator.left_right_halo_exchange(left_rank, right_rank, left_output_halo, right_output_halo);
 }

 void add_delay(int delay)
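The comments in the new kernel-side code pin down the data movement: after an exchange, left_output_halo of rank+1 lands in right_input_halo of rank, right_output_halo of rank-1 lands in left_input_halo of rank, and a negative left_rank/right_rank means that input halo is simply zero-filled. A single-process reference emulation of that contract (plain torch, no NCCL), offered only as an illustration for checking shapes and edge behaviour:

    import torch

    def reference_left_right_exchange(left_out, right_out):
        # left_out[i] / right_out[i]: halos that rank i sends to its left / right neighbor.
        group_size = len(left_out)
        left_in, right_in = [], []
        for i in range(group_size):
            # right_output_halo of rank i-1 -> left_input_halo of rank i (zero at the left edge)
            left_in.append(right_out[i - 1].clone() if i > 0 else torch.zeros_like(right_out[i]))
            # left_output_halo of rank i+1 -> right_input_halo of rank i (zero at the right edge)
            right_in.append(left_out[i + 1].clone() if i < group_size - 1 else torch.zeros_like(left_out[i]))
        return left_in, right_in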
apex/contrib/csrc/nccl_p2p/nccl_p2p_cuda.cuh

@@ -26,33 +26,20 @@ int init_nccl_comm(
     int my_rank,
     int num_ranks
     );
-void nccl_send(int handle, at::Tensor input, int destination);
-void nccl_recv(int handle, at::Tensor input, int sender);
 void left_right_halo_exchange_inplace(
     int handle,
-    bool left_zero,
-    bool right_zero,
+    int left_rank,
+    int right_rank,
     at::Tensor left_output_halo,
     at::Tensor right_output_halo,
     at::Tensor left_input_halo,
-    at::Tensor right_input_halo,
-    int group_size);
+    at::Tensor right_input_halo);
 std::vector<at::Tensor> left_right_halo_exchange(
     int handle,
-    bool left_zero,
-    bool right_zero,
+    int left_rank,
+    int right_rank,
     at::Tensor left_output_halo,
-    at::Tensor right_output_halo,
-    int group_size);
+    at::Tensor right_output_halo);
 void add_delay(int delay);
 }}}
 #endif
apex/contrib/peer_memory/peer_memory.py

@@ -4,23 +4,40 @@ import peer_memory_cuda as pm
 class PeerMemoryPool(object):
-    def __init__(self, rank, world_size, peer_group_size, static_size, dynamic_size):
-        self.peer_group = rank // peer_group_size
-        self.peer_rank = rank % peer_group_size
-        self.peer_group_size = peer_group_size
+    def __init__(self, static_size, dynamic_size, peer_ranks=None):
+        rank = torch.distributed.get_rank()
+        world_size = torch.distributed.get_world_size()
+        ngpus = min(torch.cuda.device_count(), world_size)
+        peer_group_size = ngpus
+        peer_group = rank // ngpus
+        peer_rank_base = peer_group * ngpus
+        peer_rank = rank - peer_rank_base
+        if peer_ranks is None:
+            peer_ranks = [i+peer_rank_base for i in range(peer_group_size)]
+        peer_rank_start = peer_rank_base
+        peer_rank_end = peer_rank_start + peer_group_size - 1
+        for pr in peer_ranks:
+            assert(pr >= peer_rank_start and pr <= peer_rank_end), "%d :: peer_rank %d not on same node (ranks=[%d,%d])" % (rank, pr, peer_rank_start, peer_rank_end)
         self.alignment = 256
         self.static_size = ((static_size + self.alignment - 1) // self.alignment) * self.alignment
         self.dynamic_size = ((dynamic_size + self.alignment - 1) // self.alignment) * self.alignment

         # allocate giant pool of device memory
         self.raw = pm.allocate_raw(self.static_size + self.dynamic_size)

         # exchange peer pointers with nccl
         raw_ipc = pm.get_raw_ipc_address(self.raw).cuda()
         peer_raw_ipcs = [torch.empty_like(raw_ipc) for _ in range(world_size)]
         torch.distributed.all_gather(peer_raw_ipcs, raw_ipc)
         peer_raw_ipcs = torch.stack(peer_raw_ipcs).cpu()
-        self.peer_raw = pm.get_raw_peers(peer_raw_ipcs, self.peer_rank, self.raw)
+        # extract IPC pointers for ranks on same node
+        peer_raw = pm.get_raw_peers(peer_raw_ipcs[peer_rank_base:peer_rank_base+ngpus], peer_rank, self.raw)
+        self.peer_raw = [peer_raw[peer_rank-peer_rank_base] for peer_rank in peer_ranks]
         self.static_offset = 0
         self.dynamic_offset = 0
+        self.peer_ranks = peer_ranks

     def __del__(self):
         pm.free_raw(self.raw)
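PeerMemoryPool now discovers rank, world size and the node-local GPU count itself, defaults peer_ranks to every rank on the local node, and asserts that any explicitly passed ranks live on the same node. A small usage sketch under those assumptions (pool sizes and the import path are illustrative):

    import torch
    from apex.contrib.peer_memory import PeerMemoryPool   # import path assumed

    static_size = 64 * 1024 * 1024
    dynamic_size = 2 * 1024 * 1024

    # Default: peer with every rank on this node (rank/world size come from torch.distributed).
    pool = PeerMemoryPool(static_size, dynamic_size)

    # Or restrict peering to an explicit, node-local subset of ranks; the constructor
    # asserts that each listed rank falls inside this node's rank range.
    node_ranks = [torch.distributed.get_rank()]        # trivially node-local example
    pool = PeerMemoryPool(static_size, dynamic_size, peer_ranks=node_ranks)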