GitLab: OpenDAS / apex

Commit e510b003, authored Mar 24, 2022 by Thor Johnsen
Parent: a61f0c25

    Sample 1d peer memory halo exchanger

Showing 3 changed files, with 70 additions and 63 deletions:

- apex/contrib/peer_memory/__init__.py (+1, -0)
- apex/contrib/peer_memory/peer_halo_exchange_module_tests.py (+8, -63)
- apex/contrib/peer_memory/peer_halo_exchanger_1d.py (+61, -0)
apex/contrib/peer_memory/__init__.py (view file @ e510b003)

```diff
 from .peer_memory import PeerMemoryPool
+from .peer_halo_exchanger_1d import PeerHaloExchanger1d
```
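This one-line change re-exports the new class at the package level, so the import used by the updated tests below resolves. A quick sanity check (assuming apex is installed with the peer_memory contrib extension built):

```python
# Quick check that the new export resolves (assumes apex with contrib built).
from apex.contrib.peer_memory import PeerMemoryPool, PeerHaloExchanger1d
print(PeerHaloExchanger1d.__name__)  # -> "PeerHaloExchanger1d"
```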
apex/contrib/peer_memory/peer_halo_exchange_module_tests.py (view file @ e510b003)

```diff
 import torch
-from apex.contrib.peer_memory import PeerMemoryPool
+from apex.contrib.peer_memory import PeerMemoryPool, PeerHaloExchanger1d
 import peer_memory as pm
 
-class HaloExchangerPeerMemory:
-    def __init__(self, rank, peer_group_size, peer_pool):
-        self.peer_group_size = peer_group_size
-        self.peer_rank = rank % peer_group_size
-        self.peer_pool = peer_pool
-        self.signals = peer_pool.allocate_peer_tensors([2,4], torch.int32, False, False)
-        self.signals[self.peer_rank].zero_()
-
-    def __call__(self, y, half_halo, H_split=True, explicit_nhwc=False, numSM=1):
-        channels_last = y.is_contiguous(memory_format=torch.channels_last)
-        if H_split:
-            if explicit_nhwc:
-                _,Hs,_,_ = list(y.shape)
-                H = Hs - 2*half_halo
-                top_out_halo = y[:,half_halo:2*half_halo,:,:]
-                top_tx = self.peer_pool.allocate_peer_tensors(list(top_out_halo.shape), top_out_halo.dtype, False, True)
-                top_inp_halo = y[:,:half_halo,:,:]
-                btm_out_halo = y[:,H:H+half_halo,:,:]
-                btm_tx = self.peer_pool.allocate_peer_tensors(list(btm_out_halo.shape), btm_out_halo.dtype, False, True)
-                btm_inp_halo = y[:,H+half_halo:H+2*half_halo,:,:]
-            else:
-                _,_,Hs,_ = list(y.shape)
-                H = Hs - 2*half_halo
-                top_out_halo = y[:,:,half_halo:2*half_halo,:]
-                top_tx = self.peer_pool.allocate_peer_tensors(list(top_out_halo.shape), top_out_halo.dtype, channels_last, True)
-                top_inp_halo = y[:,:,:half_halo,:]
-                btm_out_halo = y[:,:,H:H+half_halo,:]
-                btm_tx = self.peer_pool.allocate_peer_tensors(list(btm_out_halo.shape), btm_out_halo.dtype, channels_last, True)
-                btm_inp_halo = y[:,:,H+half_halo:H+2*half_halo,:]
-        else:
-            if explicit_nhwc:
-                _,_,Ws,_ = list(y.shape)
-                W = Ws - 2*half_halo
-                top_out_halo = y[:,:,half_halo:2*half_halo,:]
-                top_tx = self.peer_pool.allocate_peer_tensors(list(top_out_halo.shape), top_out_halo.dtype, False, True)
-                top_inp_halo = y[:,:,:half_halo,:]
-                btm_out_halo = y[:,:,W:W+half_halo,:]
-                btm_tx = self.peer_pool.allocate_peer_tensors(list(btm_out_halo.shape), btm_out_halo.dtype, False, True)
-                btm_inp_halo = y[:,:,W+half_halo:W+2*half_halo,:]
-            else:
-                _,_,_,Ws = list(y.shape)
-                W = Ws - 2*half_halo
-                top_out_halo = y[:,:,:,half_halo:2*half_halo]
-                top_tx = self.peer_pool.allocate_peer_tensors(list(top_out_halo.shape), top_out_halo.dtype, channels_last, True)
-                top_inp_halo = y[:,:,:,:half_halo]
-                btm_out_halo = y[:,:,:,W:W+half_halo]
-                btm_tx = self.peer_pool.allocate_peer_tensors(list(btm_out_halo.shape), btm_out_halo.dtype, channels_last, True)
-                btm_inp_halo = y[:,:,:,W+half_halo:W+2*half_halo]
-        top_neighbor = (self.peer_rank + self.peer_group_size - 1) % self.peer_group_size
-        btm_neighbor = (self.peer_rank + 1) % self.peer_group_size
-        pm.push_pull_halos_1d(
-                False, #True if self.peer_rank == 0 else False,
-                explicit_nhwc, numSM,
-                top_out_halo, top_tx[self.peer_rank], btm_tx[top_neighbor], top_inp_halo,
-                btm_out_halo, btm_tx[self.peer_rank], top_tx[btm_neighbor], btm_inp_halo,
-                self.signals[top_neighbor], self.signals[btm_neighbor], self.signals[self.peer_rank]
-                )
+# How to run:
+# torchrun --nproc_per_node <num-GPU> <this-python-prog>
+# <num-GPU> must be a power of 2 greater than 1.
 
+# Output of this function is used as ground truth in module tests.
 def nccl_halo_ex(peer_rank, peer_group_size, y, half_halo, explicit_nhwc, H_split):
     if explicit_nhwc:
         if H_split:
@@ -132,7 +78,7 @@ def single_test(peer_rank, peer_group_size, halo_ex, C, H, W, half_halo, dtype,
         y = y.to(memory_format=torch.channels_last)
     ym = y[:,:,:,half_halo:W+half_halo]
     y2 = y.clone()
-    halo_ex(y, half_halo, H_split, explicit_nhwc, numSM)
+    halo_ex(y, H_split, explicit_nhwc, numSM)
     nccl_halo_ex(peer_rank, peer_group_size, y2, half_halo, explicit_nhwc, H_split)
     is_equal = torch.all(torch.eq(y, y2))
     if peer_rank == 0:
@@ -184,12 +130,11 @@ def main():
     torch.cuda.set_device(rank)
     pool = PeerMemoryPool(rank, world_size, world_size, 64*1024, 2*1024*1024)
-    halo_ex = HaloExchangerPeerMemory(rank, world_size, pool)
     half_halo = 1
+    halo_ex = PeerHaloExchanger1d(rank, world_size, pool, half_halo)
     H_split_tests(1, 64, 336, 200, half_halo, rank, world_size, halo_ex)
-    W_split_tests(1, 64, 200, 336, half_halo, rank, world_size, halo_ex)
+    W_split_tests(1, 64, 200, 336, half_halo, world_size, halo_ex)
 
 if __name__ == "__main__":
```
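The body of main() is collapsed in this diff; under the torchrun launch line from the new "How to run" comments, a typical preamble reads the rank and world size from the environment. A minimal sketch of such a preamble (hypothetical, the actual apex test code may differ):

```python
# Hypothetical preamble for a torchrun-launched test; torchrun sets RANK,
# WORLD_SIZE, and LOCAL_RANK in each worker's environment.
import os
import torch
import torch.distributed as dist

def main():
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)
    # ... PeerMemoryPool / PeerHaloExchanger1d setup as in the diff above ...

if __name__ == "__main__":
    main()
```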
apex/contrib/peer_memory/peer_halo_exchanger_1d.py (new file @ e510b003, mode 0 → 100644)

```python
import torch
from apex.contrib.peer_memory import PeerMemoryPool
import peer_memory as pm

class PeerHaloExchanger1d:
    def __init__(self, rank, peer_group_size, peer_pool, half_halo):
        self.peer_group_size = peer_group_size
        self.peer_rank = rank % peer_group_size
        self.peer_pool = peer_pool
        # Per-rank signal tensors used by the push/pull kernels to synchronize.
        self.signals = peer_pool.allocate_peer_tensors([2,4], torch.int32, False, False)
        self.signals[self.peer_rank].zero_()
        self.half_halo = half_halo

    def __call__(self, y, H_split=True, explicit_nhwc=False, numSM=1, diagnostics=False):
        channels_last = y.is_contiguous(memory_format=torch.channels_last) and not explicit_nhwc
        if H_split:
            if explicit_nhwc:
                # Explicit NHWC layout: H is dimension 1, padded by half_halo on each side.
                _,Hs,_,_ = list(y.shape)
                H = Hs - 2*self.half_halo
                top_out_halo = y[:,self.half_halo:2*self.half_halo,:,:]
                top_tx = self.peer_pool.allocate_peer_tensors(list(top_out_halo.shape), top_out_halo.dtype, False, True)
                top_inp_halo = y[:,:self.half_halo,:,:]
                btm_out_halo = y[:,H:H+self.half_halo,:,:]
                btm_tx = self.peer_pool.allocate_peer_tensors(list(btm_out_halo.shape), btm_out_halo.dtype, False, True)
                btm_inp_halo = y[:,H+self.half_halo:H+2*self.half_halo,:,:]
            else:
                # NCHW layout: H is dimension 2.
                _,_,Hs,_ = list(y.shape)
                H = Hs - 2*self.half_halo
                top_out_halo = y[:,:,self.half_halo:2*self.half_halo,:]
                top_tx = self.peer_pool.allocate_peer_tensors(list(top_out_halo.shape), top_out_halo.dtype, channels_last, True)
                top_inp_halo = y[:,:,:self.half_halo,:]
                btm_out_halo = y[:,:,H:H+self.half_halo,:]
                btm_tx = self.peer_pool.allocate_peer_tensors(list(btm_out_halo.shape), btm_out_halo.dtype, channels_last, True)
                btm_inp_halo = y[:,:,H+self.half_halo:H+2*self.half_halo,:]
        else:
            if explicit_nhwc:
                # W split, explicit NHWC layout: W is dimension 2.
                _,_,Ws,_ = list(y.shape)
                W = Ws - 2*self.half_halo
                top_out_halo = y[:,:,self.half_halo:2*self.half_halo,:]
                top_tx = self.peer_pool.allocate_peer_tensors(list(top_out_halo.shape), top_out_halo.dtype, False, True)
                top_inp_halo = y[:,:,:self.half_halo,:]
                btm_out_halo = y[:,:,W:W+self.half_halo,:]
                btm_tx = self.peer_pool.allocate_peer_tensors(list(btm_out_halo.shape), btm_out_halo.dtype, False, True)
                btm_inp_halo = y[:,:,W+self.half_halo:W+2*self.half_halo,:]
            else:
                # W split, NCHW layout: W is dimension 3.
                _,_,_,Ws = list(y.shape)
                W = Ws - 2*self.half_halo
                top_out_halo = y[:,:,:,self.half_halo:2*self.half_halo]
                top_tx = self.peer_pool.allocate_peer_tensors(list(top_out_halo.shape), top_out_halo.dtype, channels_last, True)
                top_inp_halo = y[:,:,:,:self.half_halo]
                btm_out_halo = y[:,:,:,W:W+self.half_halo]
                btm_tx = self.peer_pool.allocate_peer_tensors(list(btm_out_halo.shape), btm_out_halo.dtype, channels_last, True)
                btm_inp_halo = y[:,:,:,W+self.half_halo:W+2*self.half_halo]
        # Ranks form a ring within the peer group.
        top_neighbor = (self.peer_rank + self.peer_group_size - 1) % self.peer_group_size
        btm_neighbor = (self.peer_rank + 1) % self.peer_group_size
        # Push outgoing halos into the peer transmit buffers and pull the
        # neighbors' halos into this rank's input-halo views of y, in place.
        pm.push_pull_halos_1d(
                diagnostics, explicit_nhwc, numSM,
                top_out_halo, top_tx[self.peer_rank], btm_tx[top_neighbor], top_inp_halo,
                btm_out_halo, btm_tx[self.peer_rank], top_tx[btm_neighbor], btm_inp_halo,
                self.signals[top_neighbor], self.signals[btm_neighbor], self.signals[self.peer_rank]
                )
```
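For orientation, a hedged usage sketch of the new class, assembled from the updated test above. The pool arguments, test shapes, and half_halo value are taken from the diff; the per-rank tensor shape and dtype are illustrative assumptions, and running it requires the peer_memory CUDA extension plus one process per GPU launched via torchrun:

```python
# Hedged usage sketch, not part of the commit: mirrors the pattern in
# peer_halo_exchange_module_tests.py above. rank/world_size would normally
# come from the torchrun environment; the shape and dtype are illustrative.
import torch
from apex.contrib.peer_memory import PeerMemoryPool, PeerHaloExchanger1d

rank, world_size = 0, 2          # assumption: 2 GPUs, this process is rank 0
half_halo = 1                    # one row of halo on each side, as in the tests

pool = PeerMemoryPool(rank, world_size, world_size, 64*1024, 2*1024*1024)
halo_ex = PeerHaloExchanger1d(rank, world_size, pool, half_halo)

# Each rank holds its H-slice of a (1, 64, 336, 200) activation, padded by
# half_halo rows top and bottom; the call overwrites the padding with the
# neighboring ranks' boundary rows (the ring wraps around at the ends).
y = torch.randn(1, 64, 336 // world_size + 2 * half_halo, 200,
                dtype=torch.float16, device="cuda")
halo_ex(y, H_split=True, explicit_nhwc=False, numSM=1)
```

Note the design change relative to the removed HaloExchangerPeerMemory test helper: half_halo moves from the call site into the constructor, and the hard-coded diagnostics flag becomes a keyword argument.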