OpenDAS / apex / Commits

Commit b41c68b3, authored Mar 25, 2022 by Thor Johnsen
Parent: 778808eb

    Optional inplace halo exchange

Showing 4 changed files with 128 additions and 17 deletions (+128, -17).
apex/contrib/bottleneck/halo_exchangers.py      +101  -9
apex/contrib/csrc/nccl_p2p/nccl_p2p.cpp           +1  -0
apex/contrib/csrc/nccl_p2p/nccl_p2p_cuda.cu      +20  -8
apex/contrib/csrc/nccl_p2p/nccl_p2p_cuda.cuh      +6  -0
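The net effect of the commit: every HaloExchanger implementation's left_right_halo_exchange gains optional left_input_halo / right_input_halo arguments. When the caller passes preallocated buffers, the exchanged halos are written into them in place and nothing is returned; when they are omitted, the old allocate-and-return behavior is preserved. A minimal sketch of the dispatch pattern the diff applies to each exchanger (names from the diff, bodies elided):

    def left_right_halo_exchange(self, left_output_halo, right_output_halo,
                                 left_input_halo=None, right_input_halo=None):
        if left_input_halo is None:
            ...  # old path: allocate the input halos and return them
        else:
            ...  # new path: copy results into the caller-provided buffers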
apex/contrib/bottleneck/halo_exchangers.py
@@ -16,8 +16,12 @@ class HaloExchangerNoComm(HaloExchanger):
     def __init__(self, world_size, spatial_group_size, rank, comm):
         super(HaloExchangerNoComm, self).__init__()
-    def left_right_halo_exchange(self, left_output_halo, right_output_halo):
-        return right_output_halo, left_output_halo
+    def left_right_halo_exchange(self, left_output_halo, right_output_halo, left_input_halo=None, right_input_halo=None):
+        if left_input_halo is None:
+            return right_output_halo, left_output_halo
+        else:
+            left_input_halo.copy_(right_output_halo)
+            right_input_halo.copy_(left_output_halo)
 
 class HaloExchangerAllGather(HaloExchanger):
     def __init__(self, world_size, spatial_group_size, rank, comm):
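The hunk above shows both calling conventions on the no-comm exchanger. A hedged, single-process usage sketch (constructor arguments are placeholder values; assumes the HaloExchanger base constructor needs no extra setup):

    import torch

    ex = HaloExchangerNoComm(world_size=1, spatial_group_size=1, rank=0, comm=None)
    left_out = torch.randn(1, 2, 8, 4)
    right_out = torch.randn(1, 2, 8, 4)

    # Allocate-and-return mode (pre-existing behavior).
    left_in, right_in = ex.left_right_halo_exchange(left_out, right_out)

    # New in-place mode: results land in caller-owned buffers.
    left_buf = torch.empty_like(right_out)
    right_buf = torch.empty_like(left_out)
    ex.left_right_halo_exchange(left_out, right_out, left_buf, right_buf)
    assert torch.equal(left_buf, right_out) and torch.equal(right_buf, left_out)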
@@ -26,7 +30,7 @@ class HaloExchangerAllGather(HaloExchanger):
         self.local_rank = rank % spatial_group_size
         self.comm = comm
-    def left_right_halo_exchange(self, left_output_halo, right_output_halo):
+    def left_right_halo_exchange(self, left_output_halo, right_output_halo, left_input_halo=None, right_input_halo=None):
         N,Hh,W,C = list(left_output_halo.shape)
         send_halos = torch.empty((N,2*Hh,W,C),dtype=left_output_halo.dtype,device=left_output_halo.device)
         send_halos[:,:Hh,:,:].copy_(left_output_halo)
@@ -34,9 +38,13 @@ class HaloExchangerAllGather(HaloExchanger):
         all_halos = torch.empty((N,2*Hh*self.spatial_group_size,W,C),dtype=left_output_halo.dtype,device=left_output_halo.device)
         all_halos = [all_halos[:,i*2*Hh:(i+1)*2*Hh,:,:] for i in range(self.spatial_group_size)]
         torch.distributed.all_gather(all_halos,send_halos,group=self.comm,no_copy=True)
-        left_input_halo = all_halos[(self.spatial_group_size+self.local_rank-1)%self.spatial_group_size][:,Hh:,:,:]
-        right_input_halo = all_halos[(self.local_rank+1)%self.spatial_group_size][:,:Hh,:,:]
-        return left_input_halo, right_input_halo
+        ag_left_input_halo = all_halos[(self.spatial_group_size+self.local_rank-1)%self.spatial_group_size][:,Hh:,:,:]
+        ag_right_input_halo = all_halos[(self.local_rank+1)%self.spatial_group_size][:,:Hh,:,:]
+        if left_input_halo is None:
+            return ag_left_input_halo, ag_right_input_halo
+        else:
+            left_input_halo.copy_(ag_left_input_halo)
+            right_input_halo.copy_(ag_right_input_halo)
 
 class HaloExchangerSendRecv(HaloExchanger):
     def __init__(self, world_size, spatial_group_size, rank, comm):
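The wrap-around indexing that picks each neighbor's slice out of the gathered buffer is easy to sanity-check on its own; a small sketch with a hypothetical group size:

    spatial_group_size = 4
    for local_rank in range(spatial_group_size):
        # same expressions as in left_right_halo_exchange above
        left = (spatial_group_size + local_rank - 1) % spatial_group_size
        right = (local_rank + 1) % spatial_group_size
        print(local_rank, left, right)
    # rank 0's left neighbor wraps to 3; rank 3's right neighbor wraps to 0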
@@ -48,6 +56,90 @@ class HaloExchangerSendRecv(HaloExchanger):
         nccl_id = nccl_id.cpu()
         self.handle = inc.init_nccl_comm(nccl_id, rank, world_size)
-    def left_right_halo_exchange(self, left_output_halo, right_output_halo):
-        left_input_halo, right_input_halo = inc.left_right_halo_exchange(self.handle, left_output_halo, right_output_halo, self.spatial_group_size)
-        return left_input_halo, right_input_halo
+    def left_right_halo_exchange(self, left_output_halo, right_output_halo, left_input_halo=None, right_input_halo=None):
+        if left_input_halo is None:
+            left_input_halo, right_input_halo = inc.left_right_halo_exchange(self.handle, left_output_halo, right_output_halo, self.spatial_group_size)
+            return left_input_halo, right_input_halo
+        else:
+            inc.left_right_halo_exchange_inplace(self.handle, left_output_halo, right_output_halo, left_input_halo, right_input_halo, self.spatial_group_size)
+class HaloExchangerPeer(HaloExchanger):
+    def __init__(self, world_size, spatial_group_size, rank, comm, peer_pool, explicit_nhwc, numSM=1):
+        super(HaloExchangerPeer, self).__init__()
+        self.diagnostics = False
+        self.spatial_group_size = spatial_group_size
+        self.peer_rank = rank % spatial_group_size
+        self.left_neighbor = (self.peer_rank + self.spatial_group_size - 1) % self.spatial_group_size
+        self.right_neighbor = (self.peer_rank + 1) % self.spatial_group_size
+        self.peer_pool = peer_pool
+        self.signals = peer_pool.allocate_peer_tensors([2,4], torch.int32, False, False)
+        self.signals[self.peer_rank].zero_()
+        self.explicit_nhwc = explicit_nhwc
+        self.numSM = numSM
+    def left_right_halo_exchange(self, left_output_halo, right_output_halo, left_input_halo=None, right_input_halo=None):
+        channels_last = left_output_halo.is_contiguous(memory_format=torch.channels_last) and not self.explicit_nhwc
+        left_tx = self.peer_pool.allocate_peer_tensors(list(left_output_halo.shape), left_output_halo.dtype, channels_last, True)
+        right_tx = self.peer_pool.allocate_peer_tensors(list(right_output_halo.shape), right_output_halo.dtype, channels_last, True)
+        # push my output halos into the neighbors' transit buffers and pull
+        # theirs into my input halos, with wrap-around neighbor indices
+        pm.push_pull_halos_1d(
+                self.diagnostics, self.explicit_nhwc, self.numSM,
+                left_output_halo, left_tx[self.peer_rank], right_tx[self.left_neighbor], left_input_halo,
+                right_output_halo, right_tx[self.peer_rank], left_tx[self.right_neighbor], right_input_halo,
+                self.signals[self.left_neighbor], self.signals[self.right_neighbor], self.signals[self.peer_rank]
+                )
+
+# Class that combines input volume with halos from neighbors (1d).
+class HaloPadder:
+    def __init__(self, halo_ex):
+        self.halo_ex = halo_ex
+        self.stream1 = torch.cuda.Stream()
+        self.stream2 = torch.cuda.Stream()
+    def __call__(self, y, half_halo, explicit_nhwc, H_split):
+        channels_last = not explicit_nhwc and y.is_contiguous(memory_format=torch.channels_last)
+        if explicit_nhwc:
+            N,H,W,C = list(y.shape)
+            if H_split:
+                padded_shape = [N,H+2*half_halo,W,C]
+                ypad = torch.empty(padded_shape, dtype=y.dtype, device=y.device, memory_format=torch.contiguous_format)
+                yleft = ypad[:,:half_halo,:,:]
+                ymid = ypad[:,half_halo:H+half_halo,:,:]
+                yright = ypad[:,H+half_halo:H+2*half_halo,:,:]
+                oleft = y[:,:half_halo,:,:]
+                oright = y[:,H-half_halo:,:,:]
+            else:
+                padded_shape = [N,H,W+2*half_halo,C]
+                ypad = torch.empty(padded_shape, dtype=y.dtype, device=y.device, memory_format=torch.contiguous_format)
+                yleft = ypad[:,:,:half_halo,:]
+                ymid = ypad[:,:,half_halo:W+half_halo,:]
+                yright = ypad[:,:,W+half_halo:W+2*half_halo,:]
+                oleft = y[:,:,:half_halo,:]
+                oright = y[:,:,W-half_halo:,:]
+        else:
+            N,C,H,W = list(y.shape)
+            if H_split:
+                padded_shape = [N,C,H+2*half_halo,W]
+                ypad = torch.empty(padded_shape, dtype=y.dtype, device=y.device, memory_format=torch.channels_last)
+                yleft = ypad[:,:,:half_halo,:]
+                ymid = ypad[:,:,half_halo:H+half_halo,:]
+                yright = ypad[:,:,H+half_halo:H+2*half_halo,:]
+                oleft = y[:,:,:half_halo,:]
+                oright = y[:,:,H-half_halo:,:]
+            else:
+                padded_shape = [N,C,H,W+2*half_halo]
+                ypad = torch.empty(padded_shape, dtype=y.dtype, device=y.device, memory_format=torch.channels_last)
+                yleft = ypad[:,:,:,:half_halo]
+                ymid = ypad[:,:,:,half_halo:W+half_halo]
+                yright = ypad[:,:,:,W+half_halo:W+2*half_halo]
+                oleft = y[:,:,:,:half_halo]
+                oright = y[:,:,:,W-half_halo:]
+        # exchange the edge slices on one side stream while the bulk copy of
+        # y into the middle of ypad runs on the other
+        with torch.cuda.stream(self.stream1):
+            self.halo_ex(oleft, oright, yleft, yright)
+        with torch.cuda.stream(self.stream2):
+            ymid.copy_(y)
+        return ypad
+    def wait(self):
+        current_stream = torch.cuda.current_stream()
+        current_stream.wait_stream(self.stream1)
+        current_stream.wait_stream(self.stream2)
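A hedged usage sketch of the new HaloPadder on a single rank, assuming a CUDA device, the HaloExchangerNoComm above, and that HaloExchanger instances are callable with (left_out, right_out, left_in, right_in) as the self.halo_ex(...) call implies; shapes are hypothetical:

    import torch

    ex = HaloExchangerNoComm(world_size=1, spatial_group_size=1, rank=0, comm=None)
    padder = HaloPadder(ex)

    # NCHW activation, H-split: pad H by half_halo rows on each side.
    y = torch.randn(1, 64, 32, 32, device="cuda").contiguous(memory_format=torch.channels_last)
    ypad = padder(y, half_halo=1, explicit_nhwc=False, H_split=True)
    padder.wait()  # current stream waits on both helper streams before ypad is used
    assert list(ypad.shape) == [1, 64, 34, 32]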
apex/contrib/csrc/nccl_p2p/nccl_p2p.cpp
@@ -21,6 +21,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
     m.def("init_nccl_comm", &apex::contrib::nccl_p2p::init_nccl_comm, "init_nccl_comm");
     m.def("nccl_send", &apex::contrib::nccl_p2p::nccl_send, "nccl_send");
     m.def("nccl_recv", &apex::contrib::nccl_p2p::nccl_recv, "nccl_recv");
+    m.def("left_right_halo_exchange_inplace", &apex::contrib::nccl_p2p::left_right_halo_exchange_inplace, "left_right_halo_exchange_inplace");
     m.def("left_right_halo_exchange", &apex::contrib::nccl_p2p::left_right_halo_exchange, "left_right_halo_exchange");
     m.def("add_delay", &apex::contrib::nccl_p2p::add_delay, "add_delay");
 }
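Once the extension is rebuilt, the new binding sits alongside the existing ones. A sketch of the Python-side entry points, assuming the module is importable as nccl_p2p_cuda (the inc alias used in halo_exchangers.py) and that handle came from inc.init_nccl_comm:

    import nccl_p2p_cuda as inc  # module name is an assumption

    def exchange_alloc(handle, group_size, left_out, right_out):
        # Pre-existing binding: returns freshly allocated input halos.
        return inc.left_right_halo_exchange(handle, left_out, right_out, group_size)

    def exchange_inplace(handle, group_size, left_out, right_out, left_in, right_in):
        # New binding: writes into caller-provided buffers, allocates nothing.
        inc.left_right_halo_exchange_inplace(handle, left_out, right_out,
                                             left_in, right_in, group_size)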
apex/contrib/csrc/nccl_p2p/nccl_p2p_cuda.cu
@@ -100,14 +100,9 @@ class NcclCommWrapper
         });
     }
-    std::vector<at::Tensor> left_right_halo_exchange(at::Tensor left_output_halo, at::Tensor right_output_halo, int group_size)
-    {
-        // after halo exchange:
-        // left_output_halo of rank+1 ends up in right_input_halo of rank
-        // right_output_halo of rank-1 ends up in left_input_halo of rank
+    void left_right_halo_exchange_inplace(at::Tensor left_output_halo, at::Tensor right_output_halo, at::Tensor left_input_halo, at::Tensor right_input_halo, int group_size)
+    {
         auto stream = at::cuda::getCurrentCUDAStream();
-        auto right_input_halo = torch::empty_like(left_output_halo);
-        auto left_input_halo = torch::empty_like(right_output_halo);
         ncclGroupStart();
         ncclDataType_t ncclType = get_nccl_type(left_output_halo);
         // we use wrap-around ranks, so left_input_halo of rank 0 has right_output_halo of rank world_size-1 after exchange etc.
@@ -137,7 +132,17 @@ class NcclCommWrapper
             });
         }
         ncclGroupEnd();
-        return {left_input_halo, right_input_halo};
     }
+    std::vector<at::Tensor> left_right_halo_exchange(at::Tensor left_output_halo, at::Tensor right_output_halo, int group_size)
+    {
+        // after halo exchange:
+        // left_output_halo of rank+1 ends up in right_input_halo of rank
+        // right_output_halo of rank-1 ends up in left_input_halo of rank
+        auto right_input_halo = torch::empty_like(left_output_halo);
+        auto left_input_halo = torch::empty_like(right_output_halo);
+        left_right_halo_exchange_inplace(left_output_halo, right_output_halo, left_input_halo, right_input_halo, group_size);
+        return {left_input_halo, right_input_halo};
+    }
 };
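The comments above pin down the exchange semantics. For clarity, a reference sketch that mimics them with stock torch.distributed point-to-point ops instead of the extension's raw NCCL calls (assumes an initialized process group whose world is exactly one spatial group, and uniform halo shapes across ranks):

    import torch
    import torch.distributed as dist

    def left_right_halo_exchange_reference(left_out, right_out, group_size):
        # Wrap-around neighbors: my left_out becomes my left neighbor's
        # right_input_halo; my right_out becomes my right neighbor's
        # left_input_halo.
        rank = dist.get_rank() % group_size
        left = (rank + group_size - 1) % group_size
        right = (rank + 1) % group_size
        left_in = torch.empty_like(right_out)   # filled by left neighbor's right_out
        right_in = torch.empty_like(left_out)   # filled by right neighbor's left_out
        ops = [dist.P2POp(dist.isend, left_out, left),
               dist.P2POp(dist.isend, right_out, right),
               dist.P2POp(dist.irecv, left_in, left),
               dist.P2POp(dist.irecv, right_in, right)]
        for req in dist.batch_isend_irecv(ops):
            req.wait()
        return left_in, right_in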
@@ -190,6 +195,13 @@ void nccl_recv(int handle, at::Tensor input, int sender)
     communicator.recv(input, sender);
 }
+void left_right_halo_exchange_inplace(int handle, at::Tensor left_output_halo, at::Tensor right_output_halo, at::Tensor left_input_halo, at::Tensor right_input_halo, int group_size)
+{
+    assert(handle >= 0 && handle < nccl_comms.size());
+    class NcclCommWrapper& communicator = nccl_comms[handle];
+    communicator.left_right_halo_exchange_inplace(left_output_halo, right_output_halo, left_input_halo, right_input_halo, group_size);
+}
 std::vector<at::Tensor> left_right_halo_exchange(int handle, at::Tensor left_output_halo, at::Tensor right_output_halo, int group_size)
 {
     assert(handle >= 0 && handle < nccl_comms.size());
apex/contrib/csrc/nccl_p2p/nccl_p2p_cuda.cuh
@@ -36,6 +36,12 @@ void nccl_recv(
     at::Tensor input,
     int sender);
+void left_right_halo_exchange_inplace(
+    int handle,
+    at::Tensor left_output_halo,
+    at::Tensor right_output_halo,
+    at::Tensor left_input_halo,
+    at::Tensor right_input_halo,
+    int group_size);
 std::vector<at::Tensor> left_right_halo_exchange(
     int handle,
     at::Tensor left_output_halo,