OpenDAS / apex · Commits

Commit bec558b1 (parent 4aeb24cb)
Authored Apr 08, 2022 by Thor Johnsen

Add graphing, switch to peer mem exchanger as default
Showing 1 changed file with 34 additions and 17 deletions:

apex/contrib/bottleneck/bottleneck_module_test.py  (+34, -17)
apex/contrib/bottleneck/bottleneck_module_test.py @ bec558b1
@@ -49,16 +49,29 @@ def rel_diff(x1, x2):
     return rel_diff_t(x1, x2)
 
-def fprop_and_bprop(x, bottleneck, dy=None):
-    y = bottleneck(x)
-    if dy is None:
-        with torch.no_grad():
+def graph_it(bottleneck, x):
+    print("Graphing")
+    with torch.no_grad():
+        x = x.clone()
+        x.grad = None
+        x.requires_grad = True
+    return torch.cuda.make_graphed_callables(bottleneck, (x,))
+
+def clone_inputs(bottleneck, x, dy=None):
+    with torch.no_grad():
+        x = x.clone()
+        x.grad = None
+        x.requires_grad = True
+        if dy is None:
+            y = bottleneck(x)
             dy = torch.randn_like(y) / 1e2
             torch.distributed.broadcast(dy, 0)
+    return x, dy
+
+def fprop_and_bprop(bottleneck, x, dy):
+    y = bottleneck(x)
     y.backward(dy)
     dgrad = x.grad.detach()
     wgrad = {}
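The hunk above splits the old do-everything fprop_and_bprop(x, bottleneck, dy=None) into three helpers, and the remaining hunks of this commit update every call site to match. Below is a minimal sketch of the new calling convention with a plain nn.Conv2d standing in for the apex bottleneck module; the model, shapes and dtypes are placeholders, and whether fprop_and_bprop's weight-gradient collection (below the visible part of this hunk) handles arbitrary modules is not shown in the diff.

import torch
import torch.nn as nn

# Sketch only, not part of the commit. Passing dy explicitly to clone_inputs() skips its
# torch.distributed.broadcast path, so no process group is needed for this illustration.
# clone_inputs, graph_it and fprop_and_bprop are the helpers defined in the hunk above.
model = nn.Conv2d(64, 64, kernel_size=3, padding=1).cuda()
x0 = torch.randn(1, 64, 56, 56, device='cuda')
dy0 = torch.randn(1, 64, 56, 56, device='cuda') / 1e2

x, dy = clone_inputs(model, x0, dy0)  # fresh leaf input with requires_grad=True
model = graph_it(model, x)            # optional: capture forward/backward as CUDA graphs
gt = fprop_and_bprop(model, x, dy)    # later hunks unpack this as (x, y, dy, dgrad, wgrad)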
@@ -74,7 +87,8 @@ def ground_truth(N, C, H, W, dtype, memory_format, bottleneck):
         with torch.no_grad():
             x = torch.randn([N,H,W,C], dtype=dtype, device='cuda')
             torch.distributed.broadcast(x, 0)
-        return fprop_and_bprop(x, bottleneck)
+        x, dy = clone_inputs(bottleneck, x)
+        return fprop_and_bprop(bottleneck, x, dy)
     else:
         # 2 -> native nhwc
         # 3 -> nchw
@@ -92,11 +106,9 @@ def print_ground_truth(gt):
 
 def apply_to_different_bottleneck(gt, bottleneck):
     with torch.no_grad():
-        x, y, dy, dgrad, wgrad = gt
-        x = x.clone()
-        x.requires_grad = True
-        dy = dy.clone()
-    return fprop_and_bprop(x, bottleneck, dy)
+        x, _, dy, _, _ = gt
+        x, dy = clone_inputs(bottleneck, x, dy)
+    return fprop_and_bprop(bottleneck, x, dy)
 
 def compare_single_field(results, f1, f2, l0, l1, l2):
@@ -161,15 +173,20 @@ def n_way_spatial(halex, gt_bottleneck, gt, explicit_nhwc, world_size, rank, fp3
     spatial_group_rank = rank
     spatial_communicator = None
     spatial_halo_exchanger = halex
-    spatial_method = 3 # 1 -> overlap halo and main conv, 2 -> wait for halo, conv on padded x
-    spatial_parallel_args = (spatial_group_size, spatial_group_rank, spatial_communicator, spatial_halo_exchanger, spatial_method)
+    spatial_method = 1 # 1 -> overlap halo and main conv, 2 -> wait for halo, conv on padded x
+    use_delay_kernel = False
+    spatial_parallel_args = (spatial_group_size, spatial_group_rank, spatial_communicator, spatial_halo_exchanger, spatial_method, use_delay_kernel)
     spatial_bottleneck = spatial_parallel_bottleneck(C, dtype, explicit_nhwc, gt_bottleneck, spatial_parallel_args)
 
     with torch.no_grad():
         Hs = H // spatial_group_size
-        xs = x[:,spatial_group_rank*Hs:(spatial_group_rank+1)*Hs,:,:]
-        dys = dy[:,spatial_group_rank*Hs:(spatial_group_rank+1)*Hs,:,:]
-    _, y, _, dgrad, wgrad = fprop_and_bprop(xs, spatial_bottleneck, dys)
+        xs = x[:,spatial_group_rank*Hs:(spatial_group_rank+1)*Hs,:,:].clone()
+        dys = dy[:,spatial_group_rank*Hs:(spatial_group_rank+1)*Hs,:,:].clone()
+        xs.requires_grad = True
+    spatial_bottleneck = graph_it(spatial_bottleneck, xs)
+    _, y, _, dgrad, wgrad = fprop_and_bprop(spatial_bottleneck, xs, dys)
 
     # gather output pieces
     for n, p in wgrad.items():
         if fp32_reduce:
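graph_it(), introduced in the first hunk and called here on the spatial bottleneck, is a thin wrapper around torch.cuda.make_graphed_callables (PyTorch 1.10+), which records a module's forward and backward passes into CUDA graphs over a few warm-up iterations and returns a drop-in callable that replays them. Captured inputs must keep a fixed shape, dtype and device, which is presumably why this hunk clones the sliced xs/dys and sets xs.requires_grad before graphing. A standalone sketch of that PyTorch API on a toy module, independent of this test file:

import torch
import torch.nn as nn

# Illustrative use of torch.cuda.make_graphed_callables, the API that graph_it() wraps.
model = nn.Sequential(nn.Conv2d(64, 64, 3, padding=1), nn.ReLU()).cuda()
sample_x = torch.randn(1, 64, 56, 56, device='cuda', requires_grad=True)

# Runs a few warmup iterations internally, then captures forward and backward into CUDA graphs.
graphed_model = torch.cuda.make_graphed_callables(model, (sample_x,))

x = torch.randn_like(sample_x).requires_grad_(True)
y = graphed_model(x)              # replays the captured forward graph
y.backward(torch.randn_like(y))   # replays the captured backward graph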
@@ -217,9 +234,9 @@ def main():
     peer_pool = PeerMemoryPool(rank, world_size, spatial_group_size, 64*1024*1024, 2*1024*1024)
 
     #halex = HaloExchangerAllGather(world_size, spatial_group_size, rank, spatial_communicator)
-    halex = HaloExchangerSendRecv(world_size, spatial_group_size, rank, spatial_communicator)
-    #halex = HaloExchangerPeer(world_size, spatial_group_size, rank, spatial_communicator, peer_pool, explicit_nhwc, numSM=1)
+    #halex = HaloExchangerSendRecv(world_size, spatial_group_size, rank, spatial_communicator)
+    halex = HaloExchangerPeer(world_size, spatial_group_size, rank, spatial_communicator, peer_pool, explicit_nhwc, numSM=1)
     #print("halex.signals = %s" % (str(halex.signals)))
 
     # Make sure peer memory halo exchanger has finished initializing flags on all ranks before proceeding
     #torch.cuda.synchronize()
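This last hunk is the "switch to peer mem exchanger as default" from the commit title: HaloExchangerPeer, fed by the PeerMemoryPool built a few lines earlier, replaces HaloExchangerSendRecv as the exchanger the test constructs, and SendRecv joins AllGather as a commented-out alternative. A rough standalone sketch of the new wiring follows; the constructor calls are copied from the diff, but the import paths and placeholder values are assumptions and are not visible in this hunk.

import torch.distributed as dist
from apex.contrib.peer_memory import PeerMemoryPool    # import path assumed
from apex.contrib.bottleneck import HaloExchangerPeer  # import path assumed

# Placeholders; main() derives these elsewhere in the test file (not shown in this hunk).
rank, world_size = dist.get_rank(), dist.get_world_size()
spatial_group_size = world_size
spatial_communicator = None
explicit_nhwc = True

# Pool sizes (64*1024*1024 and 2*1024*1024) are taken verbatim from the diff.
peer_pool = PeerMemoryPool(rank, world_size, spatial_group_size, 64*1024*1024, 2*1024*1024)

# New default: exchange halos through peer GPU memory instead of send/recv.
halex = HaloExchangerPeer(world_size, spatial_group_size, rank, spatial_communicator,
                          peer_pool, explicit_nhwc, numSM=1)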