OpenDAS / apex — Commit f687e7fa

Commit f687e7fa, authored Jul 21, 2022 by Thor Johnsen
Bug fixes, perf improvements
Parent: a29a698f

Showing 6 changed files with 139 additions and 138 deletions:
apex/contrib/bottleneck/bottleneck_module_test.py    +18  -14
apex/contrib/bottleneck/halo_exchangers.py           +38  -33
apex/contrib/csrc/nccl_p2p/nccl_p2p.cpp               +0   -2
apex/contrib/csrc/nccl_p2p/nccl_p2p_cuda.cu          +55  -65
apex/contrib/csrc/nccl_p2p/nccl_p2p_cuda.cuh          +6  -19
apex/contrib/peer_memory/peer_memory.py              +22   -5
apex/contrib/bottleneck/bottleneck_module_test.py

@@ -152,16 +152,6 @@ def spatial_parallel_bottleneck(C, dtype, explicit_nhwc, gt_bottleneck, spatial_
         sb[n].copy_(b)
     return spatial_bottleneck
-#class HaloExchangerNoComm(HaloExchanger):
-#    def __init__(self, world_size, spatial_group_size, rank, comm):
-#class HaloExchangerAllGather(HaloExchanger):
-#    def __init__(self, world_size, spatial_group_size, rank, comm):
-#class HaloExchangerSendRecv(HaloExchanger):
-#    def __init__(self, world_size, spatial_group_size, rank, comm):
-#class HaloExchangerPeer(HaloExchanger):
-#    def __init__(self, world_size, spatial_group_size, rank, comm, peer_pool, explicit_nhwc, numSM=1):
 def n_way_spatial(halex, gt_bottleneck, gt, explicit_nhwc, world_size, rank, fp32_reduce=False):
     assert(explicit_nhwc), "Only tested for explicit nhwc"

@@ -228,15 +218,29 @@ def main():
     #print_bottleneck_p_and_b(gt_bottleneck)
     #print_bottleneck_p_and_b(spatial_bottleneck)
+    group_size = world_size
+    group = rank // group_size
+    ranks = [group*group_size+i for i in range(group_size)]
+    rank_in_group = rank % group_size
     spatial_group_size = world_size
     spatial_communicator = None
-    peer_pool = PeerMemoryPool(rank, world_size, spatial_group_size, 64*1024*1024, 2*1024*1024)
+    peer_pool = PeerMemoryPool(64*1024*1024, 2*1024*1024, ranks)
-    #halex = HaloExchangerAllGather(world_size, spatial_group_size, rank, spatial_communicator)
-    #halex = HaloExchangerSendRecv(world_size, spatial_group_size, rank, spatial_communicator)
-    halex = HaloExchangerPeer(world_size, spatial_group_size, rank, spatial_communicator, peer_pool, explicit_nhwc, numSM=1)
+    #class HaloExchangerNoComm(HaloExchanger):
+    #    def __init__(self, ranks, rank_in_group):
+    #class HaloExchangerAllGather(HaloExchanger):
+    #    def __init__(self, ranks, rank_in_group, comm):
+    #class HaloExchangerSendRecv(HaloExchanger):
+    #    def __init__(self, ranks, rank_in_group):
+    #class HaloExchangerPeer(HaloExchanger):
+    #    def __init__(self, ranks, rank_in_group, peer_pool, explicit_nhwc, numSM=1):
+    #halex = HaloExchangerAllGather(ranks, rank_in_group)
+    #halex = HaloExchangerSendRecv(ranks, rank_in_group)
+    halex = HaloExchangerPeer(ranks, rank_in_group, peer_pool, explicit_nhwc, numSM=1)
     #print("halex.signals = %s" % (str(halex.signals)))
     # Make sure peer memory halo exchanger has finished initializing flags on all ranks before proceeding
     #torch.cuda.synchronize()
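Note: the test driver now builds an explicit list of global ranks and an index into it, instead of passing (world_size, spatial_group_size, rank, comm). A minimal sketch of the updated calling convention (not part of the commit; import paths and explicit_nhwc=True are assumptions of this sketch, the 64 MB / 2 MB pool sizes are the ones used above):

    import torch
    from apex.contrib.peer_memory import PeerMemoryPool            # import paths may vary by install
    from apex.contrib.bottleneck.halo_exchangers import HaloExchangerPeer

    # Assumes torch.distributed is already initialized with the NCCL backend.
    rank = torch.distributed.get_rank()
    world_size = torch.distributed.get_world_size()

    # One spatial-parallel group spanning all ranks, as in main() above.
    group_size = world_size
    group = rank // group_size
    ranks = [group * group_size + i for i in range(group_size)]
    rank_in_group = rank % group_size

    # New signatures: PeerMemoryPool(static_size, dynamic_size, peer_ranks) and
    # HaloExchangerPeer(ranks, rank_in_group, peer_pool, explicit_nhwc, numSM).
    peer_pool = PeerMemoryPool(64 * 1024 * 1024, 2 * 1024 * 1024, ranks)
    halex = HaloExchangerPeer(ranks, rank_in_group, peer_pool, explicit_nhwc=True, numSM=1)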
apex/contrib/bottleneck/halo_exchangers.py

@@ -9,17 +9,23 @@ import peer_memory_cuda as pm
 # NB! This is only useful for performance testing.
 # NB! Do not use for actual production runs
 class HaloExchanger(object):
-    def __init__(self, spatial_group_size, rank):
+    def __init__(self, ranks, rank_in_group):
         self.stream1 = torch.cuda.Stream()
         self.stream2 = torch.cuda.Stream()
         self.stream3 = torch.cuda.Stream()
-        spatial_rank = rank % spatial_group_size
-        self.left_zero = True if spatial_rank == 0 else False
-        self.right_zero = True if spatial_rank == spatial_group_size - 1 else False
+        self.group_size = len(ranks)
+        self.ranks = ranks
+        self.rank_in_group = rank_in_group
+        self.wrap_around_left_rank_in_group = (rank_in_group + self.group_size - 1) % self.group_size
+        self.wrap_around_right_rank_in_group = (rank_in_group + 1) % self.group_size
+        self.left_rank = ranks[rank_in_group - 1] if rank_in_group > 0 else -1
+        self.left_zero = True if rank_in_group == 0 else False
+        self.right_rank = ranks[rank_in_group + 1] if rank_in_group < self.group_size - 1 else -1
+        self.right_zero = True if rank_in_group == self.group_size - 1 else False

 class HaloExchangerNoComm(HaloExchanger):
-    def __init__(self, world_size, spatial_group_size, rank, comm):
-        super(HaloExchangerNoComm, self).__init__(spatial_group_size, rank)
+    def __init__(self, ranks, rank_in_group):
+        super(HaloExchangerNoComm, self).__init__(ranks, rank_in_group)
     def left_right_halo_exchange(self, left_output_halo, right_output_halo, left_input_halo=None, right_input_halo=None):
         if left_input_halo is None:

@@ -29,10 +35,9 @@ class HaloExchangerNoComm(HaloExchanger):
             right_input_halo.copy_(left_output_halo)

 class HaloExchangerAllGather(HaloExchanger):
-    def __init__(self, world_size, spatial_group_size, rank, comm):
-        super(HaloExchangerAllGather, self).__init__(spatial_group_size, rank)
-        self.spatial_group_size = spatial_group_size
-        self.local_rank = rank % spatial_group_size
+    def __init__(self, ranks, rank_in_group, comm):
+        super(HaloExchangerAllGather, self).__init__(ranks, rank_in_group)
+        # self.comm must be NCCL process_group created with torch.distributed.new_group(ranks=ranks)
         self.comm = comm
     def left_right_halo_exchange(self, left_output_halo, right_output_halo, left_input_halo=None, right_input_halo=None):

@@ -40,11 +45,11 @@ class HaloExchangerAllGather(HaloExchanger):
         send_halos = torch.empty((N,2*Hh,W,C),dtype=left_output_halo.dtype,device=left_output_halo.device)
         send_halos[:,:Hh,:,:].copy_(left_output_halo)
         send_halos[:,Hh:,:,:].copy_(right_output_halo)
-        all_halos = torch.empty((N,2*Hh*self.spatial_group_size,W,C),dtype=left_output_halo.dtype,device=left_output_halo.device)
-        all_halos = [all_halos[:,i*2*Hh:(i+1)*2*Hh,:,:] for i in range(self.spatial_group_size)]
+        all_halos = torch.empty((N,2*Hh*self.group_size,W,C),dtype=left_output_halo.dtype,device=left_output_halo.device)
+        all_halos = [all_halos[:,i*2*Hh:(i+1)*2*Hh,:,:] for i in range(self.group_size)]
         torch.distributed.all_gather(all_halos,send_halos,group=self.comm,no_copy=True)
-        ag_left_input_halo = all_halos[(self.spatial_group_size+self.local_rank-1)%self.spatial_group_size][:,Hh:,:,:]
-        ag_right_input_halo = all_halos[(self.local_rank+1)%self.spatial_group_size][:,:Hh,:,:]
+        ag_left_input_halo = all_halos[self.wrap_around_left_rank_in_group][:,Hh:,:,:]
+        ag_right_input_halo = all_halos[self.wrap_around_right_rank_in_group][:,:Hh,:,:]
         if left_input_halo is None:
             if self.left_zero:
                 ag_left_input_halo.zero_()

@@ -62,35 +67,35 @@ class HaloExchangerAllGather(HaloExchanger):
             right_input_halo.copy_(ag_right_input_halo)

 class HaloExchangerSendRecv(HaloExchanger):
-    def __init__(self, world_size, spatial_group_size, rank, comm):
-        super(HaloExchangerSendRecv, self).__init__(spatial_group_size, rank)
-        self.world_size = world_size
-        self.spatial_group_size = spatial_group_size
+    def __init__(self, ranks, rank_in_group):
+        super(HaloExchangerSendRecv, self).__init__(ranks, rank_in_group)
         nccl_id = inc.get_unique_nccl_id(1).cuda()
         torch.distributed.broadcast(nccl_id, 0)
         nccl_id = nccl_id.cpu()
-        self.handle = inc.init_nccl_comm(nccl_id, rank, world_size)
+        print("%d :: nccl_id = %s" % (torch.distributed.get_rank(), str(nccl_id)))
+        # Create another global nccl communicator in addition to the one created by torch.distributed.init_process_group("nccl")
+        # This is unavoidable because the underlying NCCL communicator torch.distributed creates is a protected variable, hence
+        # it cannot be accessed from another class.
+        # TODO: Figure out a way to avoid creating a second global communicator
+        assert(torch.distributed.get_rank() == self.ranks[self.rank_in_group]), "ranks[%d](%d) != torch.distributed.get_rank()(%d)" % (self.rank_in_group, self.ranks[self.rank_in_group], torch.distributed.get_rank())
+        self.handle = inc.init_nccl_comm(nccl_id, torch.distributed.get_rank(), torch.distributed.get_world_size())

     def left_right_halo_exchange(self, left_output_halo, right_output_halo, left_input_halo=None, right_input_halo=None):
         if left_input_halo is None:
-            left_input_halo, right_input_halo = inc.left_right_halo_exchange(self.handle, self.left_zero, self.right_zero, left_output_halo, right_output_halo, self.spatial_group_size)
+            left_input_halo, right_input_halo = inc.left_right_halo_exchange(self.handle, self.left_rank, self.right_rank, left_output_halo, right_output_halo)
             return left_input_halo, right_input_halo
         else:
-            inc.left_right_halo_exchange_inplace(self.handle, self.left_zero, self.right_zero, left_output_halo, right_output_halo, left_input_halo, right_input_halo, self.spatial_group_size)
+            inc.left_right_halo_exchange_inplace(self.handle, self.left_rank, self.right_rank, left_output_halo, right_output_halo, left_input_halo, right_input_halo)

 class HaloExchangerPeer(HaloExchanger):
-    def __init__(self, world_size, spatial_group_size, rank, comm, peer_pool, explicit_nhwc, numSM=1):
-        super(HaloExchangerPeer, self).__init__(spatial_group_size, rank)
+    def __init__(self, ranks, rank_in_group, peer_pool, explicit_nhwc, numSM=1):
+        super(HaloExchangerPeer, self).__init__(ranks, rank_in_group)
         self.diagnostics = False
-        self.spatial_group_size = spatial_group_size
-        self.peer_rank = rank % spatial_group_size
-        self.left_neighbor = (self.peer_rank + self.spatial_group_size - 1) % self.spatial_group_size
-        self.right_neighbor = (self.peer_rank + 1) % self.spatial_group_size
-        self.peer_pool = peer_pool
-        self.signals = peer_pool.allocate_peer_tensors([2,4], torch.int32, False, False)
-        self.signals[self.peer_rank].zero_()
         self.explicit_nhwc = explicit_nhwc
         self.numSM = numSM
+        self.peer_pool = peer_pool
+        self.signals = peer_pool.allocate_peer_tensors([2,4], torch.int32, False, False)
+        self.signals[self.rank_in_group].zero_()

     def left_right_halo_exchange(self, left_output_halo, right_output_halo, left_input_halo=None, right_input_halo=None):
         inplace = False if left_input_halo is None and right_input_halo is None else True

@@ -102,9 +107,9 @@ class HaloExchangerPeer(HaloExchanger):
         right_tx = self.peer_pool.allocate_peer_tensors(list(right_output_halo.shape), right_output_halo.dtype, channels_last, True)
         pm.push_pull_halos_1d(
                 self.diagnostics, self.explicit_nhwc, self.numSM,
-                left_output_halo, left_tx[self.peer_rank], right_tx[self.left_neighbor], left_input_halo,
-                right_output_halo, right_tx[self.peer_rank], left_tx[self.right_neighbor], right_input_halo,
-                self.signals[self.left_neighbor], self.signals[self.right_neighbor], self.signals[self.peer_rank]
+                left_output_halo, left_tx[self.rank_in_group], right_tx[self.wrap_around_left_rank_in_group], left_input_halo,
+                right_output_halo, right_tx[self.rank_in_group], left_tx[self.wrap_around_right_rank_in_group], right_input_halo,
+                self.signals[self.wrap_around_left_rank_in_group], self.signals[self.wrap_around_right_rank_in_group], self.signals[self.rank_in_group]
                 )
         # TODO: Add to push_pull_halos_1d kernel
         if self.left_zero:
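Note: all exchangers now take (ranks, rank_in_group), and the base class derives the neighbor bookkeeping from them. A standalone sketch (illustrative only; it simply re-derives the fields computed in the reworked HaloExchanger.__init__, without the CUDA streams) for a hypothetical group ranks = [4, 5, 6, 7]:

    # First rank in the group: no left neighbor, wrap-around indices still defined.
    ranks = [4, 5, 6, 7]
    rank_in_group = 0
    group_size = len(ranks)

    wrap_around_left = (rank_in_group + group_size - 1) % group_size                  # 3
    wrap_around_right = (rank_in_group + 1) % group_size                              # 1
    left_rank = ranks[rank_in_group - 1] if rank_in_group > 0 else -1                 # -1 (edge)
    right_rank = ranks[rank_in_group + 1] if rank_in_group < group_size - 1 else -1   # 5
    left_zero = rank_in_group == 0                                                    # True
    right_zero = rank_in_group == group_size - 1                                      # False

    assert (wrap_around_left, wrap_around_right) == (3, 1)
    assert (left_rank, right_rank) == (-1, 5)
    assert (left_zero, right_zero) == (True, False)

The -1 sentinel in left_rank / right_rank is what the reworked nccl_p2p extension below keys off to zero a halo instead of exchanging it.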
apex/contrib/csrc/nccl_p2p/nccl_p2p.cpp

@@ -19,8 +19,6 @@
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
     m.def("get_unique_nccl_id", &apex::contrib::nccl_p2p::get_unique_nccl_id, "get_unique_nccl_id");
     m.def("init_nccl_comm", &apex::contrib::nccl_p2p::init_nccl_comm, "init_nccl_comm");
-    m.def("nccl_send", &apex::contrib::nccl_p2p::nccl_send, "nccl_send");
-    m.def("nccl_recv", &apex::contrib::nccl_p2p::nccl_recv, "nccl_recv");
     m.def("left_right_halo_exchange_inplace", &apex::contrib::nccl_p2p::left_right_halo_exchange_inplace, "left_right_halo_exchange_inplace");
     m.def("left_right_halo_exchange", &apex::contrib::nccl_p2p::left_right_halo_exchange, "left_right_halo_exchange");
     m.def("add_delay", &apex::contrib::nccl_p2p::add_delay, "add_delay");
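Note: with nccl_send / nccl_recv dropped from the bindings, the Python-visible surface of the extension is the handle-based halo-exchange API. A hedged sketch of driving it from Python (mirroring HaloExchangerSendRecv above; the module name nccl_p2p_cuda, the alias inc, and the tensor shapes are assumptions of this sketch):

    import torch
    import nccl_p2p_cuda as inc  # extension built from nccl_p2p.cpp / nccl_p2p_cuda.cu (assumed name)

    # Assumes torch.distributed is initialized with the NCCL backend.
    rank = torch.distributed.get_rank()
    world_size = torch.distributed.get_world_size()
    left_rank = rank - 1 if rank > 0 else -1
    right_rank = rank + 1 if rank < world_size - 1 else -1

    nccl_id = inc.get_unique_nccl_id(1).cuda()
    torch.distributed.broadcast(nccl_id, 0)  # every rank needs rank 0's id
    handle = inc.init_nccl_comm(nccl_id.cpu(), rank, world_size)

    # Neighbor ranks are passed directly; -1 marks an edge and yields a zeroed input halo.
    left_out = torch.ones(1, 3, 32, 64, dtype=torch.half, device="cuda")
    right_out = torch.ones(1, 3, 32, 64, dtype=torch.half, device="cuda")
    left_in, right_in = inc.left_right_halo_exchange(handle, left_rank, right_rank,
                                                     left_out, right_out)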
apex/contrib/csrc/nccl_p2p/nccl_p2p_cuda.cu

@@ -80,75 +80,82 @@ class NcclCommWrapper
         ncclCommDestroy(comm);
     }
-    void send(at::Tensor input, int destination)
-    {
-        ncclDataType_t ncclType = get_nccl_type(input);
-        AT_DISPATCH_ALL_TYPES_AND3(at::ScalarType::Bool, at::ScalarType::BFloat16, at::ScalarType::Half, input.scalar_type(), "nccl_send", [&]() {
-            size_t count = sizeof(scalar_t) * torch::numel(input);
-            auto input_ptr = input.data_ptr<scalar_t>();
-            ncclSend(input_ptr, count, ncclType, destination, comm, at::cuda::getCurrentCUDAStream());
-            });
-    }
-    void recv(at::Tensor input, int sender)
-    {
-        ncclDataType_t ncclType = get_nccl_type(input);
-        AT_DISPATCH_ALL_TYPES_AND3(at::ScalarType::Bool, at::ScalarType::BFloat16, at::ScalarType::Half, input.scalar_type(), "nccl_send", [&]() {
-            size_t count = sizeof(scalar_t) * torch::numel(input);
-            auto input_ptr = input.data_ptr<scalar_t>();
-            ncclRecv(input_ptr, count, ncclType, sender, comm, at::cuda::getCurrentCUDAStream());
-            });
-    }
-    void left_right_halo_exchange_inplace(bool left_zero, bool right_zero, at::Tensor left_output_halo, at::Tensor right_output_halo, at::Tensor left_input_halo, at::Tensor right_input_halo, int group_size)
+    void left_right_halo_exchange_inplace(int left_rank, int right_rank, at::Tensor left_output_halo, at::Tensor right_output_halo, at::Tensor left_input_halo, at::Tensor right_input_halo)
     {
         auto stream = at::cuda::getCurrentCUDAStream();
         ncclGroupStart();
         ncclDataType_t ncclType = get_nccl_type(left_output_halo);
+        // we use wrap-around ranks, so left_input_halo of rank 0 has right_output_halo of rank world_size-1 after exchange etc.
+        bool left_zero = (left_rank < 0);  // this is technically speaking wasteful, but there is no benefit in having the edge ranks do less work than internal ranks.
+        bool right_zero = (right_rank < 0);
-        int group_rank = rank % group_size;
-        int group_index = rank / group_size;
-        int prev_rank = (group_rank + group_size - 1) % group_size;
-        int next_rank = (group_rank + 1) % group_size;
-        prev_rank = prev_rank + group_index * group_size;
-        next_rank = next_rank + group_index * group_size;
         size_t left_n = torch::numel(left_output_halo);
         size_t right_n = torch::numel(right_output_halo);
-        if (group_rank > 0) {
+        assert(left_n > 0 && left_n == right_n);
+        if (left_zero) {
+            left_input_halo.zero_();
+        } else {
             AT_DISPATCH_ALL_TYPES_AND3(at::ScalarType::Bool, at::ScalarType::BFloat16, at::ScalarType::Half, left_output_halo.scalar_type(), "left_halo_exch", [&]() {
                 // send left (to my_rank - 1)
-                ncclSend(left_output_halo.data_ptr<scalar_t>(), left_n, ncclType, prev_rank, comm, stream);
+                ncclSend(left_output_halo.data_ptr<scalar_t>(), left_n, ncclType, left_rank, comm, stream);
                 // receive left (from my_rank - 1)
-                ncclRecv(left_input_halo.data_ptr<scalar_t>(), right_n, ncclType, prev_rank, comm, stream);
+                ncclRecv(left_input_halo.data_ptr<scalar_t>(), right_n, ncclType, left_rank, comm, stream);
                 });
         }
-        if (group_rank < group_size - 1) {
+        if (right_zero) {
+            right_input_halo.zero_();
+        } else {
             AT_DISPATCH_ALL_TYPES_AND3(at::ScalarType::Bool, at::ScalarType::BFloat16, at::ScalarType::Half, right_output_halo.scalar_type(), "right_halo_exch", [&]() {
                 // send right (to my_rank + 1 )
-                ncclSend(right_output_halo.data_ptr<scalar_t>(), right_n, ncclType, next_rank, comm, stream);
+                ncclSend(right_output_halo.data_ptr<scalar_t>(), right_n, ncclType, right_rank, comm, stream);
                 // receive right (from my_rank + 1)
-                ncclRecv(right_input_halo.data_ptr<scalar_t>(), left_n, ncclType, next_rank, comm, stream);
+                ncclRecv(right_input_halo.data_ptr<scalar_t>(), left_n, ncclType, right_rank, comm, stream);
                 });
         }
         ncclGroupEnd();
-        if (left_zero) left_input_halo.zero_();
-        if (right_zero) right_input_halo.zero_();
     }
-    std::vector<at::Tensor> left_right_halo_exchange(bool left_zero, bool right_zero, at::Tensor left_output_halo, at::Tensor right_output_halo, int group_size)
+    std::vector<at::Tensor> left_right_halo_exchange(int left_rank, int right_rank, at::Tensor left_output_halo, at::Tensor right_output_halo)
     {
         // after halo exchange:
         // left_output_halo of rank+1 ends up in right_input_halo of rank
         // right_output_halo of rank-1 ends up in left_input_halo of rank
         auto right_input_halo = torch::empty_like(left_output_halo);
         auto left_input_halo = torch::empty_like(right_output_halo);
-        left_right_halo_exchange_inplace(left_zero, right_zero, left_output_halo, right_output_halo, left_input_halo, right_input_halo, group_size);
+        left_right_halo_exchange_inplace(left_rank, right_rank, left_output_halo, right_output_halo, left_input_halo, right_input_halo);
         return {left_input_halo, right_input_halo};
     }
 };
-std::vector<NcclCommWrapper> nccl_comms;
+class ManagedObjects
+{
+public:
+    ManagedObjects() {}
+    ~ManagedObjects()
+    {
+        for (auto it = _nccl_comms.begin(); it != _nccl_comms.end(); ++it)
+        {
+            delete *it;
+        }
+    }
+    int add_comm(NcclCommWrapper* comm)
+    {
+        int handle = _nccl_comms.size();
+        _nccl_comms.push_back(comm);
+        return handle;
+    }
+    NcclCommWrapper& get_comm(int handle)
+    {
+        assert(handle >= 0 && handle < _nccl_comms.size());
+        return *_nccl_comms[handle];
+    }
+
+private:
+    std::vector<NcclCommWrapper*> _nccl_comms;
+};
+class ManagedObjects mo;
 }  // end anonymous namespace

@@ -158,7 +165,7 @@ at::Tensor get_unique_nccl_id(int n)
 {
     ncclUniqueId id;
     ncclGetUniqueId(&id);
-    auto id_tensor = torch::empty({n*(int)sizeof(ncclUniqueId)}, torch::dtype(torch::kUInt8).device(torch::kCPU).requires_grad(false));
+    auto id_tensor = torch::empty({n, (int)sizeof(ncclUniqueId)}, torch::dtype(torch::kUInt8).device(torch::kCPU).requires_grad(false));
     auto id_ptr = id_tensor.data_ptr<uint8_t>();
     size_t offset = 0;
     for (int i = 0; i < n; ++i)

@@ -177,38 +184,21 @@ int init_nccl_comm(at::Tensor unique_nccl_id, int my_rank, int num_ranks)
     auto unique_nccl_id_ptr = unique_nccl_id.data_ptr<uint8_t>();
     memcpy(&id, unique_nccl_id_ptr, sizeof(ncclUniqueId));
     NcclCommWrapper* comm = new NcclCommWrapper(id, my_rank, num_ranks);
-    int handle = nccl_comms.size();
-    nccl_comms.push_back(*comm);
+    int handle = mo.add_comm(comm);
     comm = 0L;
     return handle;
 }
-void nccl_send(int handle, at::Tensor input, int destination)
-{
-    assert(handle >= 0 && handle < nccl_comms.size());
-    class NcclCommWrapper communicator = nccl_comms[handle];
-    communicator.send(input, destination);
-}
-void nccl_recv(int handle, at::Tensor input, int sender)
-{
-    assert(handle >= 0 && handle < nccl_comms.size());
-    class NcclCommWrapper communicator = nccl_comms[handle];
-    communicator.recv(input, sender);
-}
-void left_right_halo_exchange_inplace(int handle, bool left_zero, bool right_zero, at::Tensor left_output_halo, at::Tensor right_output_halo, at::Tensor left_input_halo, at::Tensor right_input_halo, int group_size)
+void left_right_halo_exchange_inplace(int handle, int left_rank, int right_rank, at::Tensor left_output_halo, at::Tensor right_output_halo, at::Tensor left_input_halo, at::Tensor right_input_halo)
 {
-    assert(handle >= 0 && handle < nccl_comms.size());
-    class NcclCommWrapper& communicator = nccl_comms[handle];
-    return communicator.left_right_halo_exchange_inplace(left_zero, right_zero, left_output_halo, right_output_halo, left_input_halo, right_input_halo, group_size);
+    class NcclCommWrapper& communicator = mo.get_comm(handle);
+    return communicator.left_right_halo_exchange_inplace(left_rank, right_rank, left_output_halo, right_output_halo, left_input_halo, right_input_halo);
 }
-std::vector<at::Tensor> left_right_halo_exchange(int handle, bool left_zero, bool right_zero, at::Tensor left_output_halo, at::Tensor right_output_halo, int group_size)
+std::vector<at::Tensor> left_right_halo_exchange(int handle, int left_rank, int right_rank, at::Tensor left_output_halo, at::Tensor right_output_halo)
 {
-    assert(handle >= 0 && handle < nccl_comms.size());
-    class NcclCommWrapper& communicator = nccl_comms[handle];
-    return communicator.left_right_halo_exchange(left_zero, right_zero, left_output_halo, right_output_halo, group_size);
+    class NcclCommWrapper& communicator = mo.get_comm(handle);
+    return communicator.left_right_halo_exchange(left_rank, right_rank, left_output_halo, right_output_halo);
 }
 void add_delay(int delay)
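Note: the device-side exchange now takes neighbor ranks instead of (left_zero, right_zero, group_size); a negative rank means "no neighbor on that side" and that side's input halo is zeroed rather than exchanged. A CPU-only reference model of that contract (illustrative only; the function name and the dict-of-tensors stand-in for remote ranks are inventions of this sketch):

    import torch

    def reference_left_right_halo_exchange(my_rank, left_rank, right_rank,
                                           left_output_halos, right_output_halos):
        # left_output_halos / right_output_halos map rank -> the halo that rank sends.
        # After the exchange, my left input is my left neighbor's right output halo,
        # and my right input is my right neighbor's left output halo.
        left_input = (torch.zeros_like(right_output_halos[my_rank]) if left_rank < 0
                      else right_output_halos[left_rank].clone())
        right_input = (torch.zeros_like(left_output_halos[my_rank]) if right_rank < 0
                       else left_output_halos[right_rank].clone())
        return left_input, right_input

    # Two ranks: rank 0 has no left neighbor, rank 1 has no right neighbor.
    lo = {0: torch.full((2, 2), 10.), 1: torch.full((2, 2), 11.)}  # left output halos
    ro = {0: torch.full((2, 2), 20.), 1: torch.full((2, 2), 21.)}  # right output halos
    li0, ri0 = reference_left_right_halo_exchange(0, -1, 1, lo, ro)
    assert li0.eq(0).all() and ri0.eq(11.).all()
    li1, ri1 = reference_left_right_halo_exchange(1, 0, -1, lo, ro)
    assert li1.eq(20.).all() and ri1.eq(0).all()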
apex/contrib/csrc/nccl_p2p/nccl_p2p_cuda.cuh

@@ -26,33 +26,20 @@ int init_nccl_comm(
     int my_rank,
     int num_ranks
     );
-void nccl_send(int handle, at::Tensor input, int destination);
-void nccl_recv(int handle, at::Tensor input, int sender);
 void left_right_halo_exchange_inplace(
     int handle,
-    bool left_zero,
-    bool right_zero,
+    int left_rank,
+    int right_rank,
     at::Tensor left_output_halo,
     at::Tensor right_output_halo,
     at::Tensor left_input_halo,
-    at::Tensor right_input_halo,
-    int group_size);
+    at::Tensor right_input_halo);
 std::vector<at::Tensor> left_right_halo_exchange(
     int handle,
-    bool left_zero,
-    bool right_zero,
+    int left_rank,
+    int right_rank,
     at::Tensor left_output_halo,
-    at::Tensor right_output_halo,
-    int group_size);
+    at::Tensor right_output_halo);
 void add_delay(int delay);
 }}}
 #endif
apex/contrib/peer_memory/peer_memory.py

@@ -4,23 +4,40 @@ import peer_memory_cuda as pm
 class PeerMemoryPool(object):
-    def __init__(self, rank, world_size, peer_group_size, static_size, dynamic_size):
-        self.peer_group = rank // peer_group_size
-        self.peer_rank = rank % peer_group_size
-        self.peer_group_size = peer_group_size
+    def __init__(self, static_size, dynamic_size, peer_ranks=None):
+        rank = torch.distributed.get_rank()
+        world_size = torch.distributed.get_world_size()
+        ngpus = min(torch.cuda.device_count(), world_size)
+        peer_group_size = ngpus
+        peer_group = rank // ngpus
+        peer_rank_base = peer_group * ngpus
+        peer_rank = rank - peer_rank_base
+        if peer_ranks is None:
+            peer_ranks = [i+peer_rank_base for i in range(peer_group_size)]
+        peer_rank_start = peer_rank_base
+        peer_rank_end = peer_rank_start + peer_group_size - 1
+        for pr in peer_ranks:
+            assert(pr >= peer_rank_start and pr <= peer_rank_end), "%d :: peer_rank %d not on same node (ranks=[%d,%d])" % (rank, pr, peer_rank_start, peer_rank_end)
         self.alignment = 256
         self.static_size = ((static_size + self.alignment - 1) // self.alignment) * self.alignment
         self.dynamic_size = ((dynamic_size + self.alignment - 1) // self.alignment) * self.alignment

         # allocate giant pool of device memory
         self.raw = pm.allocate_raw(self.static_size + self.dynamic_size)

         # exchange peer pointers with nccl
         raw_ipc = pm.get_raw_ipc_address(self.raw).cuda()
         peer_raw_ipcs = [torch.empty_like(raw_ipc) for _ in range(world_size)]
         torch.distributed.all_gather(peer_raw_ipcs, raw_ipc)
         peer_raw_ipcs = torch.stack(peer_raw_ipcs).cpu()
-        self.peer_raw = pm.get_raw_peers(peer_raw_ipcs, self.peer_rank, self.raw)
+        # extract IPC pointers for ranks on same node
+        peer_raw = pm.get_raw_peers(peer_raw_ipcs[peer_rank_base:peer_rank_base+ngpus], peer_rank, self.raw)
+        self.peer_raw = [peer_raw[peer_rank-peer_rank_base] for peer_rank in peer_ranks]
         self.static_offset = 0
         self.dynamic_offset = 0
+        self.peer_ranks = peer_ranks

     def __del__(self):
         pm.free_raw(self.raw)
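Note: PeerMemoryPool no longer takes (rank, world_size, peer_group_size); it reads rank and world size from torch.distributed and defaults peer_ranks to all ranks on the caller's node. A short usage sketch (hedged: the import path and the two-GPU node are assumptions of this sketch; the 64 MB / 2 MB sizes are the ones used by bottleneck_module_test.py):

    import torch
    from apex.contrib.peer_memory import PeerMemoryPool  # import path may vary by install

    # Assumes torch.distributed is initialized (NCCL backend), one GPU per rank,
    # and at least two GPUs on this node.
    rank = torch.distributed.get_rank()
    ngpus = min(torch.cuda.device_count(), torch.distributed.get_world_size())
    node_base = (rank // ngpus) * ngpus

    # Default: peer with every rank on this node (what peer_ranks=None expands to).
    pool = PeerMemoryPool(64 * 1024 * 1024, 2 * 1024 * 1024)

    # Or restrict peering to an explicit subset; every entry must lie in
    # [node_base, node_base + ngpus - 1], otherwise the constructor's assert fires.
    pool = PeerMemoryPool(64 * 1024 * 1024, 2 * 1024 * 1024,
                          peer_ranks=[node_base, node_base + 1])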