OpenDAS / apex · Commits

Commit 34df0f79, authored Mar 31, 2022 by Thor Johnsen
Parent: 834b1d01

    wgrad2 in parallel stream, optional mode to wait for halo transfer
Showing 8 changed files with 166 additions and 60 deletions (+166 -60)
apex/contrib/bottleneck/bottleneck.py                          +34 -25
apex/contrib/bottleneck/bottleneck_module_test.py              +10 -4
apex/contrib/bottleneck/halo_exchangers.py                     +27 -9
apex/contrib/csrc/bottleneck/bottleneck.cpp                    +51 -0
apex/contrib/csrc/nccl_p2p/nccl_p2p_cuda.cu                    +9 -7
apex/contrib/csrc/nccl_p2p/nccl_p2p_cuda.cuh                   +4 -0
apex/contrib/csrc/peer_memory/peer_memory_cuda.cu              +1 -0
apex/contrib/peer_memory/peer_halo_exchange_module_tests.py    +30 -15
apex/contrib/bottleneck/bottleneck.py   View file @ 34df0f79

...
@@ -253,6 +253,8 @@ class SpatialBottleneckFunction(torch.autograd.Function):
                 top_out1_halo = out1_pad[:,:1,:,:]
                 btm_out1_halo = out1_pad[:,Hs+1:Hs+2,:,:]
                 spatial_halo_exchanger.left_right_halo_exchange(out1[:,:1,:,:], out1[:,Hs-1:,:,:], top_out1_halo, btm_out1_halo)
+            if spatial_method == 1:
+                # overlap mid convolution with halo transfer
                 if spatial_group_rank < spatial_group_size-1:
                     stream2.wait_stream(stream1)
                     with torch.cuda.stream(stream2):
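As an aside, the halo indexing above is easier to follow with concrete shapes. The sketch below is hypothetical (the sizes and variable names are illustrative, not taken from the diff): each rank keeps its local rows in the interior of a padded buffer, and the first and last rows of that buffer are reserved for the halos received from its neighbors.

import torch

# Hypothetical sizes for one spatial-parallel rank, explicit-NHWC layout.
N, Hs, W, C = 1, 4, 8, 16
out1 = torch.arange(N * Hs * W * C, dtype=torch.float32).reshape(N, Hs, W, C)

# Padded buffer: one extra row above and below the local slab.
out1_pad = torch.zeros(N, Hs + 2, W, C)
out1_pad[:, 1:Hs + 1, :, :] = out1                # interior = this rank's rows
top_halo_dst = out1_pad[:, :1, :, :]              # row 0: filled from rank-1 by the exchange
btm_halo_dst = out1_pad[:, Hs + 1:Hs + 2, :, :]   # row Hs+1: filled from rank+1

# The rows this rank sends to its neighbors are its first and last local rows.
send_up = out1[:, :1, :, :]
send_down = out1[:, Hs - 1:, :, :]
print(top_halo_dst.shape, btm_halo_dst.shape, send_up.shape, send_down.shape)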
...
@@ -260,24 +262,26 @@ class SpatialBottleneckFunction(torch.autograd.Function):
                         btm_fat_halo[:,0:2,:,:].copy_(out1[:,Hs-2:,:,:])
                         btm_fat_halo[:,2:,:,:].copy_(btm_out1_halo)
                         btm_out2 = fast_bottleneck.forward_out2_halo(nhwc, btm_fat_halo, args)
-                else:
-                    with torch.cuda.stream(stream2):
-                        btm_out1_halo.zero_()
                 if spatial_group_rank > 0:
                     with torch.cuda.stream(stream1):
                         top_fat_halo = torch.empty((N,3,W,C),dtype=out1.dtype,device=out1.device)
                         top_fat_halo[:,:1,:,:].copy_(top_out1_halo)
                         top_fat_halo[:,1:3,:,:].copy_(out1[:,:2,:,:])
                         top_out2 = fast_bottleneck.forward_out2_halo(nhwc, top_fat_halo, args)
-                else:
-                    with torch.cuda.stream(stream1):
-                        top_out1_halo.zero_()
                 inc.add_delay(10)
+            elif spatial_method == 2:
+                # wait for halo transfer to finish before doing a full convolution of padded x
+                torch.cuda.current_stream().wait_stream(stream1)
+                torch.cuda.current_stream().wait_stream(stream3)
+                fast_bottleneck.forward_out2_pad(nhwc, stride_1x1, args, outputs, out1_pad)
+            else:
+                assert(False), "spatial_method must be 1 or 2"
+        if spatial_group_size <= 1 or spatial_method == 1:
             fast_bottleneck.forward_out2(nhwc, stride_1x1, args, outputs)

         # compute halo cells for outputs[1] (out2)
-        if spatial_group_size > 1:
+        if spatial_group_size > 1 and spatial_method == 1:
             out2 = outputs[1]
             if spatial_group_rank > 0:
                 torch.cuda.current_stream().wait_stream(stream1)
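The two modes introduced above differ only in how the main 3x3 convolution is ordered relative to the halo transfer. The following is a rough, hypothetical sketch of that control flow; conv_full, conv_halo and conv_padded are placeholders for the fast_bottleneck calls, not real API, and the streams stand in for the halo exchanger's streams.

import torch

def forward_out2_sketch(spatial_method, out1, out1_pad, stream1, stream2,
                        conv_full, conv_halo, conv_padded):
    if spatial_method == 1:
        # Mode 1: launch the main convolution right away and compute the halo
        # output rows on a side stream once the transfer (stream1) has finished.
        stream2.wait_stream(stream1)
        with torch.cuda.stream(stream2):
            halo_rows = conv_halo(out1_pad)
        out2 = conv_full(out1)                      # overlaps with the halo work
        torch.cuda.current_stream().wait_stream(stream2)
        return out2, halo_rows
    elif spatial_method == 2:
        # Mode 2: block until the halos have arrived, then run one convolution
        # over the padded input so no separate halo fix-up is needed afterwards.
        torch.cuda.current_stream().wait_stream(stream1)
        return conv_padded(out1_pad), None
    else:
        raise ValueError("spatial_method must be 1 or 2")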
...
@@ -290,6 +294,7 @@ class SpatialBottleneckFunction(torch.autograd.Function):
         fast_bottleneck.forward_rest(nhwc, stride_1x1, args, outputs)

         # save halos for backward pass
         if spatial_group_size > 1:
+            torch.cuda.current_stream().wait_stream(stream3)
             ctx.save_for_backward(*(args+outputs+[out1_pad,]))
         else:
             ctx.save_for_backward(*(args+outputs))
...
@@ -368,6 +373,9 @@ class SpatialBottleneckFunction(torch.autograd.Function):
                 top_grad_out1_halo = top_grad_out1_halo[:,1:2,:,:]
             inc.add_delay(10)

+        wgrad2_stream = torch.cuda.Stream()
+        wgrad2_stream.wait_stream(torch.cuda.current_stream())
+        with torch.cuda.stream(wgrad2_stream):
             if ctx.spatial_group_size > 1:
                 wgrad2 = fast_bottleneck.backward_wgrad2_pad(ctx.nhwc, ctx.stride_1x1, t_list, grads, out1_pad, grad_out2)
             else:
...
@@ -406,6 +414,7 @@ class SpatialBottleneckFunction(torch.autograd.Function):
             #print("ctx.spatial_group_rank = %d, apply grad_out1 btm halo (grad_out1.shape = %s)" % (ctx.spatial_group_rank, str(list(grad_out1.shape))))
         fast_bottleneck.backward_rest(ctx.nhwc, ctx.stride_1x1, t_list, grads, grad_out2, grad_out1, wgrad2)
+        torch.cuda.current_stream().wait_stream(wgrad2_stream)

         return (None, None, None, None, None, None, None, None, None, *grads)
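The wgrad2 change above follows the usual fork/join pattern for CUDA streams: the side stream first waits on the current stream (so it observes grad_out2 and out1_pad), the weight-gradient kernel runs on it, and the current stream waits on it again before the gradients are returned. Below is a self-contained sketch of that pattern, with an ordinary matmul standing in for backward_wgrad2; it is an illustration of the stream handling, not the fast_bottleneck kernel.

import torch

def wgrad_in_side_stream(activation, grad_output):
    # Fork: the side stream must first observe everything already queued on the
    # current stream (here: activation and grad_output being ready).
    side = torch.cuda.Stream()
    side.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(side):
        # Placeholder for the real weight-gradient kernel (backward_wgrad2 / _pad).
        wgrad = grad_output.transpose(0, 1) @ activation
    # Join: the current stream must wait before anything consumes wgrad.
    torch.cuda.current_stream().wait_stream(side)
    return wgrad

if torch.cuda.is_available():
    a = torch.randn(1024, 512, device="cuda")
    g = torch.randn(1024, 256, device="cuda")
    w = wgrad_in_side_stream(a, g)  # shape (256, 512)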
apex/contrib/bottleneck/bottleneck_module_test.py   View file @ 34df0f79

...
@@ -161,8 +161,8 @@ def n_way_spatial(halex, gt_bottleneck, gt, explicit_nhwc, world_size, rank, fp3
     spatial_group_rank = rank
     spatial_communicator = None
     spatial_halo_exchanger = halex
-    spatial_stream = None # Not in use
-    spatial_parallel_args = (spatial_group_size, spatial_group_rank, spatial_communicator, spatial_halo_exchanger, spatial_stream)
+    spatial_method = 2 # 1 -> overlap halo and main conv, 2 -> wait for halo, conv on padded x
+    spatial_parallel_args = (spatial_group_size, spatial_group_rank, spatial_communicator, spatial_halo_exchanger, spatial_method)
     spatial_bottleneck = spatial_parallel_bottleneck(C, dtype, explicit_nhwc, gt_bottleneck, spatial_parallel_args)

     with torch.no_grad():
...
@@ -217,8 +217,14 @@ def main():
     peer_pool = PeerMemoryPool(rank, world_size, spatial_group_size, 64*1024*1024, 2*1024*1024)

     #halex = HaloExchangerAllGather(world_size, spatial_group_size, rank, spatial_communicator)
-    halex = HaloExchangerSendRecv(world_size, spatial_group_size, rank, spatial_communicator)
-    #halex = HaloExchangerPeer(world_size, spatial_group_size, rank, spatial_communicator, peer_pool, explicit_nhwc, numSM=1)
+    #halex = HaloExchangerSendRecv(world_size, spatial_group_size, rank, spatial_communicator)
+    halex = HaloExchangerPeer(world_size, spatial_group_size, rank, spatial_communicator, peer_pool, explicit_nhwc, numSM=1)
+    #print("halex.signals = %s" % (str(halex.signals)))

+    # Make sure peer memory halo exchanger has finished initializing flags on all ranks before proceeding
+    #torch.cuda.synchronize()
+    #torch.distributed.barrier()

     bt2 = n_way_spatial(halex, gt_bottleneck, gt, explicit_nhwc, world_size, rank, fp32_reduce=True)
     compare(gt, bt2)
apex/contrib/bottleneck/halo_exchangers.py   View file @ 34df0f79

...
@@ -9,14 +9,17 @@ import peer_memory as pm
 # NB! This is only useful for performance testing.
 # NB! Do not use for actual production runs
 class HaloExchanger(object):
-    def __init__(self):
+    def __init__(self, spatial_group_size, rank):
         self.stream1 = torch.cuda.Stream()
         self.stream2 = torch.cuda.Stream()
         self.stream3 = torch.cuda.Stream()
+        spatial_rank = rank % spatial_group_size
+        self.left_zero = True if spatial_rank == 0 else False
+        self.right_zero = True if spatial_rank == spatial_group_size - 1 else False

 class HaloExchangerNoComm(HaloExchanger):
     def __init__(self, world_size, spatial_group_size, rank, comm):
-        super(HaloExchangerNoComm, self).__init__()
+        super(HaloExchangerNoComm, self).__init__(spatial_group_size, rank)

     def left_right_halo_exchange(self, left_output_halo, right_output_halo, left_input_halo=None, right_input_halo=None):
         if left_input_halo is None:
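The base class now decides, once at construction time, whether a rank sits at either end of its spatial group; such a rank has no neighbor on that side, so the corresponding input halo is zeroed instead of exchanged. A small sketch of that bookkeeping in plain Python (the helper name is illustrative; no communicator is needed to see the logic):

def halo_zero_flags(rank, spatial_group_size):
    # First rank in the group has no left neighbor, last rank has no right neighbor.
    spatial_rank = rank % spatial_group_size
    left_zero = spatial_rank == 0
    right_zero = spatial_rank == spatial_group_size - 1
    return left_zero, right_zero

# Example: 8 GPUs split into groups of 4 -> global rank 4 is the first rank of group 1.
assert halo_zero_flags(4, 4) == (True, False)
assert halo_zero_flags(7, 4) == (False, True)
assert halo_zero_flags(5, 4) == (False, False)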
...
@@ -27,7 +30,7 @@ class HaloExchangerNoComm(HaloExchanger):
 class HaloExchangerAllGather(HaloExchanger):
     def __init__(self, world_size, spatial_group_size, rank, comm):
-        super(HaloExchangerAllGather, self).__init__()
+        super(HaloExchangerAllGather, self).__init__(spatial_group_size, rank)
         self.spatial_group_size = spatial_group_size
         self.local_rank = rank % spatial_group_size
         self.comm = comm
...
@@ -43,14 +46,24 @@ class HaloExchangerAllGather(HaloExchanger):
         ag_left_input_halo = all_halos[(self.spatial_group_size+self.local_rank-1)%self.spatial_group_size][:,Hh:,:,:]
         ag_right_input_halo = all_halos[(self.local_rank+1)%self.spatial_group_size][:,:Hh,:,:]
         if left_input_halo is None:
+            if self.left_zero:
+                ag_left_input_halo.zero_()
+            if self.right_zero:
+                ag_right_input_halo.zero_()
             return ag_left_input_halo, ag_right_input_halo
         else:
+            if self.left_zero:
+                left_input_halo.zero_()
+            else:
                 left_input_halo.copy_(ag_left_input_halo)
+            if self.right_zero:
+                right_input_halo.zero_()
+            else:
                 right_input_halo.copy_(ag_right_input_halo)

 class HaloExchangerSendRecv(HaloExchanger):
     def __init__(self, world_size, spatial_group_size, rank, comm):
-        super(HaloExchangerSendRecv, self).__init__()
+        super(HaloExchangerSendRecv, self).__init__(spatial_group_size, rank)
         self.world_size = world_size
         self.spatial_group_size = spatial_group_size
         nccl_id = inc.get_unique_nccl_id(1).cuda()
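In the all-gather variant every rank sees every rank's boundary rows and picks its neighbors' contributions by index; the modular arithmetic above wraps around, which is exactly why the ends of the group still need the left_zero/right_zero handling. A small sketch of that indexing with plain lists standing in for the gathered halo tensors (names are illustrative):

def pick_neighbor_halos(local_rank, group_size, all_halos):
    # all_halos[r] holds (top_row, bottom_row) contributed by rank r of the group.
    left = all_halos[(group_size + local_rank - 1) % group_size][1]   # bottom row of rank-1
    right = all_halos[(local_rank + 1) % group_size][0]               # top row of rank+1
    return left, right

all_halos = [("top%d" % r, "btm%d" % r) for r in range(4)]
assert pick_neighbor_halos(2, 4, all_halos) == ("btm1", "top3")
# Rank 0 wraps around to rank 3's bottom row, hence left_zero on the first rank.
assert pick_neighbor_halos(0, 4, all_halos) == ("btm3", "top1")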
...
@@ -60,14 +73,14 @@ class HaloExchangerSendRecv(HaloExchanger):
     def left_right_halo_exchange(self, left_output_halo, right_output_halo, left_input_halo=None, right_input_halo=None):
         if left_input_halo is None:
-            left_input_halo, right_input_halo = inc.left_right_halo_exchange(self.handle, left_output_halo, right_output_halo, self.spatial_group_size)
+            left_input_halo, right_input_halo = inc.left_right_halo_exchange(self.handle, self.left_zero, self.right_zero, left_output_halo, right_output_halo, self.spatial_group_size)
             return left_input_halo, right_input_halo
         else:
-            inc.left_right_halo_exchange_inplace(self.handle, left_output_halo, right_output_halo, left_input_halo, right_input_halo, self.spatial_group_size)
+            inc.left_right_halo_exchange_inplace(self.handle, self.left_zero, self.right_zero, left_output_halo, right_output_halo, left_input_halo, right_input_halo, self.spatial_group_size)

 class HaloExchangerPeer(HaloExchanger):
     def __init__(self, world_size, spatial_group_size, rank, comm, peer_pool, explicit_nhwc, numSM=1):
-        super(HaloExchangerPeer, self).__init__()
+        super(HaloExchangerPeer, self).__init__(spatial_group_size, rank)
         self.diagnostics = False
         self.spatial_group_size = spatial_group_size
         self.peer_rank = rank % spatial_group_size
...
@@ -93,6 +106,11 @@ class HaloExchangerPeer(HaloExchanger):
             right_output_halo, right_tx[self.peer_rank], left_tx[self.right_neighbor], right_input_halo,
             self.signals[self.left_neighbor], self.signals[self.right_neighbor], self.signals[self.peer_rank]
             )
+        # TODO: Add to push_pull_halos_1d kernel
+        if self.left_zero:
+            left_input_halo.zero_()
+        if self.right_zero:
+            right_input_halo.zero_()
         if not inplace:
             return left_input_halo, right_input_halo
apex/contrib/csrc/bottleneck/bottleneck.cpp   View file @ 34df0f79

...
@@ -1620,6 +1620,7 @@ struct bottleneck_forward_status {
   int64_t outdimA0[4];
   int64_t outdimA1[4];
+  int64_t outdimA1b[4]; // out1_pad
   int64_t outdimA2[4];
   int64_t outdimA3[4];
   int64_t outdimA4[4];
...
@@ -1633,6 +1634,7 @@ struct bottleneck_forward_status {
   int64_t outdim0[4]; // halo input shape
   int64_t outdim1[4];
+  int64_t outdim1b[4];
   int64_t outdim2[4];
   int64_t outdim3[4];
   int64_t outdim4[4]; // halo output shape
...
@@ -1672,6 +1674,7 @@ struct bottleneck_forward_status {
     // output dim in n,c,h,w used by backend
     outdimA0[0] = outdimA0[1] = outdimA0[2] = outdimA0[3] = 0;
     outdimA1[0] = outdimA1[1] = outdimA1[2] = outdimA1[3] = 0;
+    outdimA1b[0] = outdimA1b[1] = outdimA1b[2] = outdimA1b[3] = 0;
     outdimA2[0] = outdimA2[1] = outdimA2[2] = outdimA2[3] = 0;
     outdimA3[0] = outdimA3[1] = outdimA3[2] = outdimA3[3] = 0;
     outdimA4[0] = outdimA4[1] = outdimA4[2] = outdimA4[3] = 0;
...
@@ -1690,6 +1693,13 @@ struct bottleneck_forward_status {
     for (int dim=0;dim<2;dim++) {
       outdimA1[dim+2] = getFwdConvOutputDim(dimA[dim+2], padA[dim], filterdimA1[dim+2], convstride1X1[dim], dilationA[dim]);
     }
+    for (int dim=0;dim<4;dim++) {
+      if (dim == 2) {
+        outdimA1b[dim] = outdimA1[dim] + 2;
+      } else {
+        outdimA1b[dim] = outdimA1[dim];
+      }
+    }
     outdimA2[0] = outdimA1[0];
     outdimA2[1] = filterdimA2[0];
...
@@ -1715,6 +1725,7 @@ struct bottleneck_forward_status {
     // Create output tensor in the correct shape in pytorch's view
     outdim1[0] = outdim1[1] = outdim1[2] = outdim1[3] = 0;
+    outdim1b[0] = outdim1b[1] = outdim1b[2] = outdim1b[3] = 0;
     outdim2[0] = outdim2[1] = outdim2[2] = outdim2[3] = 0;
     outdim3[0] = outdim3[1] = outdim3[2] = outdim3[3] = 0;
     if (explicit_nhwc) {
...
@@ -1726,6 +1737,7 @@ struct bottleneck_forward_status {
       for (int dim=0;dim<4;dim++) {
         outdim0[dim] = outdimA0[axis[dim]];
         outdim1[dim] = outdimA1[axis[dim]];
+        outdim1b[dim] = outdimA1b[axis[dim]];
         outdim2[dim] = outdimA2[axis[dim]];
         outdim3[dim] = outdimA3[axis[dim]];
         outdim4[dim] = outdimA4[axis[dim]];
...
@@ -1859,6 +1871,44 @@ void bottleneck_forward_out2(bool explicit_nhwc, int stride_1X1, std::vector<at:
   DEBUG_MSG("[DEBUG] new relu2 : " << out2.to(at::kFloat).sum().item<float>());
 }

+void bottleneck_forward_out2_pad(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs, std::vector<at::Tensor> outputs, at::Tensor out1_pad) {
+
+  std::cout << std::fixed;
+
+  // from _out1 method
+  at::Half* x = inputs[0].data_ptr<at::Half>();
+  auto out1 = outputs[0];
+  at::Half* y1 = out1_pad.data_ptr<at::Half>();
+
+  // run
+  at::Half* w = inputs[2].data_ptr<at::Half>();
+  at::Half* z = inputs[5].data_ptr<at::Half>();
+  at::Half* b = inputs[8].data_ptr<at::Half>();
+  auto out2 = outputs[1];
+  at::Half* y2 = out2.data_ptr<at::Half>();
+
+  //printf("forward_state.outdimA1 = {%d,%d,%d,%d}\n",forward_state.outdimA1[0],forward_state.outdimA1[1],forward_state.outdimA1[2],forward_state.outdimA1[3]);
+  //printf("forward_state.padA1 = {%d,%d}\n",forward_state.padA1[0],forward_state.padA1[1]);
+  //printf("forward_state.convstrideA = {%d,%d}\n",forward_state.convstrideA[0],forward_state.convstrideA[1]);
+  //printf("forward_state.dilationA = {%d,%d}\n",forward_state.dilationA[0],forward_state.dilationA[1]);
+  //printf("forward_state.filterdimA2 = {%d,%d,%d,%d}\n",forward_state.filterdimA2[0],forward_state.filterdimA2[1],forward_state.filterdimA2[2],forward_state.filterdimA2[3]);
+  //printf("forward_state.outdimA2 = {%d,%d,%d,%d}\n",forward_state.outdimA2[0],forward_state.outdimA2[1],forward_state.outdimA2[2],forward_state.outdimA2[3]);
+
+  run_conv_scale_bias_add_activation(forward_state.outdimA1b,
+                                     forward_state.padA2,
+                                     forward_state.convstrideA,
+                                     forward_state.dilationA,
+                                     forward_state.filterdimA2,
+                                     forward_state.outdimA2,
+                                     CUDNN_DATA_HALF,
+                                     y1, w, y2, z, b, nullptr);
+
+  DEBUG_MSG("[DEBUG] new relu2 : " << out2.to(at::kFloat).sum().item<float>());
+}
+
 void bottleneck_forward_rest(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs, std::vector<at::Tensor> outputs) {

   std::cout << std::fixed;
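The new forward_out2_pad path runs the 3x3 convolution directly over out1_pad, whose height is H+2 (outdimA1b), with no padding in H: the two halo rows act as the padding, so the output height matches the unpadded path. A hedged PyTorch sketch of just the shape arithmetic (torch.nn.functional.conv2d standing in for the cuDNN call; NCHW float here purely for illustration, the real kernel is NHWC half):

import torch
import torch.nn.functional as F

N, C, H, W = 1, 16, 8, 8
out1_pad = torch.randn(N, C, H + 2, W)   # local rows plus one halo row above and below
weight = torch.randn(C, C, 3, 3)

# Pad only in W; the halo rows already provide the padding in H.
out2 = F.conv2d(out1_pad, weight, padding=(0, 1))
assert out2.shape == (N, C, H, W)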
...
@@ -2520,6 +2570,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("forward_out1", &bottleneck_forward_out1, "Bottleneck block forward");
   m.def("forward_out2", &bottleneck_forward_out2, "Bottleneck block forward");
   m.def("forward_out2_halo", &bottleneck_forward_out2_halo, "Bottleneck block forward");
+  m.def("forward_out2_pad", &bottleneck_forward_out2_pad, "Bottleneck block forward");
   m.def("forward_rest", &bottleneck_forward_rest, "Bottleneck block forward");
   m.def("backward_init", &bottleneck_backward_init, "Bottleneck block backward init");
   m.def("backward_grad_out2", &bottleneck_backward_grad_out2, "Bottleneck block backward");
apex/contrib/csrc/nccl_p2p/nccl_p2p_cuda.cu   View file @ 34df0f79

...
@@ -100,7 +100,7 @@ class NcclCommWrapper
             });
         }

-        void left_right_halo_exchange_inplace(at::Tensor left_output_halo, at::Tensor right_output_halo, at::Tensor left_input_halo, at::Tensor right_input_halo, int group_size)
+        void left_right_halo_exchange_inplace(bool left_zero, bool right_zero, at::Tensor left_output_halo, at::Tensor right_output_halo, at::Tensor left_input_halo, at::Tensor right_input_halo, int group_size)
         {
             auto stream = at::cuda::getCurrentCUDAStream();
             ncclGroupStart();
...
@@ -132,16 +132,18 @@ class NcclCommWrapper
                 });
             }
             ncclGroupEnd();
+            if (left_zero) left_input_halo.zero_();
+            if (right_zero) right_input_halo.zero_();
         }

-        std::vector<at::Tensor> left_right_halo_exchange(at::Tensor left_output_halo, at::Tensor right_output_halo, int group_size)
+        std::vector<at::Tensor> left_right_halo_exchange(bool left_zero, bool right_zero, at::Tensor left_output_halo, at::Tensor right_output_halo, int group_size)
         {
             // after halo exchange:
             // left_output_halo of rank+1 ends up in right_input_halo of rank
             // right_output_halo of rank-1 ends up in left_input_halo of rank
             auto right_input_halo = torch::empty_like(left_output_halo);
             auto left_input_halo = torch::empty_like(right_output_halo);
-            left_right_halo_exchange_inplace(left_output_halo, right_output_halo, left_input_halo, right_input_halo, group_size);
+            left_right_halo_exchange_inplace(left_zero, right_zero, left_output_halo, right_output_halo, left_input_halo, right_input_halo, group_size);
             return {left_input_halo, right_input_halo};
         }
 };
...
@@ -195,18 +197,18 @@ void nccl_recv(int handle, at::Tensor input, int sender)
     communicator.recv(input, sender);
 }

-void left_right_halo_exchange_inplace(int handle, at::Tensor left_output_halo, at::Tensor right_output_halo, at::Tensor left_input_halo, at::Tensor right_input_halo, int group_size)
+void left_right_halo_exchange_inplace(int handle, bool left_zero, bool right_zero, at::Tensor left_output_halo, at::Tensor right_output_halo, at::Tensor left_input_halo, at::Tensor right_input_halo, int group_size)
 {
     assert(handle >= 0 && handle < nccl_comms.size());
     class NcclCommWrapper& communicator = nccl_comms[handle];
-    return communicator.left_right_halo_exchange_inplace(left_output_halo, right_output_halo, left_input_halo, right_input_halo, group_size);
+    return communicator.left_right_halo_exchange_inplace(left_zero, right_zero, left_output_halo, right_output_halo, left_input_halo, right_input_halo, group_size);
 }

-std::vector<at::Tensor> left_right_halo_exchange(int handle, at::Tensor left_output_halo, at::Tensor right_output_halo, int group_size)
+std::vector<at::Tensor> left_right_halo_exchange(int handle, bool left_zero, bool right_zero, at::Tensor left_output_halo, at::Tensor right_output_halo, int group_size)
 {
     assert(handle >= 0 && handle < nccl_comms.size());
     class NcclCommWrapper& communicator = nccl_comms[handle];
-    return communicator.left_right_halo_exchange(left_output_halo, right_output_halo, group_size);
+    return communicator.left_right_halo_exchange(left_zero, right_zero, left_output_halo, right_output_halo, group_size);
 }

 void add_delay(int delay)
apex/contrib/csrc/nccl_p2p/nccl_p2p_cuda.cuh   View file @ 34df0f79

...
@@ -38,6 +38,8 @@ void nccl_recv(
     );
 void left_right_halo_exchange_inplace(
     int handle,
+    bool left_zero,
+    bool right_zero,
     at::Tensor left_output_halo,
     at::Tensor right_output_halo,
     at::Tensor left_input_halo,
...
@@ -45,6 +47,8 @@ void left_right_halo_exchange_inplace(
     int group_size
     );
 std::vector<at::Tensor> left_right_halo_exchange(
     int handle,
+    bool left_zero,
+    bool right_zero,
     at::Tensor left_output_halo,
     at::Tensor right_output_halo,
     int group_size
apex/contrib/csrc/peer_memory/peer_memory_cuda.cu   View file @ 34df0f79

...
@@ -239,6 +239,7 @@ int64_t allocate_raw(int64_t size)
 {
     float* ptr = 0L;
     cudaMalloc(&ptr, size);
+    cudaMemset(ptr, 0, size);
     return (int64_t)ptr;
 }
apex/contrib/peer_memory/peer_halo_exchange_module_tests.py   View file @ 34df0f79

...
@@ -53,7 +53,7 @@ def nccl_halo_ex(peer_rank, peer_group_size, y, half_halo, explicit_nhwc, H_spli
         btm_inp_halo.copy_(top_inp_halos[btm_rank])

-def single_test(peer_rank, peer_group_size, halo_ex, C, H, W, half_halo, dtype, memory_format, H_split, numSM=1):
+def single_test(peer_rank, peer_group_size, halo_ex, C, H, W, half_halo, dtype, memory_format, H_split, num_steps, numSM=1):
     if memory_format == 1:
         # 1 -> explicit nhwc
         explicit_nhwc = True
...
@@ -77,10 +77,23 @@ def single_test(peer_rank, peer_group_size, halo_ex, C, H, W, half_halo, dtype,
     if memory_format == 2:
         y = y.to(memory_format=torch.channels_last)
     ym = y[:,:,:,half_halo:W+half_halo]
-    y2 = y.clone()
+    y3 = y.clone()
+    list_y = []
+    for step in range(num_steps):
         halo_ex(y, H_split, explicit_nhwc, numSM)
+        list_y.append(y.clone())
+        y.copy_(y3)
+    halo_ex.peer_pool.reset()
+    torch.distributed.barrier()
+    y2 = y3.clone()
+    list_y2 = []
+    for step in range(num_steps):
         nccl_halo_ex(peer_rank, peer_group_size, y2, half_halo, explicit_nhwc, H_split)
-    is_equal = torch.all(torch.eq(y,y2))
+        list_y2.append(y2.clone())
+        y2.copy_(y3)
+    is_equal = [torch.all(torch.eq(yy,yy2)) for yy,yy2 in zip(list_y,list_y2)]
+    is_equal = torch.tensor(is_equal, dtype=torch.bool)
+    is_equal = torch.all(is_equal)
     if peer_rank == 0:
         if memory_format == 1:
             memory_format_str = "explicit_nhwc"
...
@@ -99,26 +112,26 @@ def single_test(peer_rank, peer_group_size, halo_ex, C, H, W, half_halo, dtype,
     torch.distributed.barrier()

-def H_split_tests(N, C, H, W, half_halo, rank, world_size, halo_ex):
+def H_split_tests(N, C, H, W, half_halo, rank, world_size, halo_ex, num_steps):
     Hr = 8*world_size
     Hp = ((H+Hr-1)//Hr)*8
     for i in range(4):
         div = int(pow(2,i))
-        single_test(rank, world_size, halo_ex, C*div, Hp//div, W//div, half_halo, torch.float16, 1, True)
-        single_test(rank, world_size, halo_ex, C*div, Hp//div, W//div, half_halo, torch.float16, 2, True)
-        single_test(rank, world_size, halo_ex, C*div, Hp//div, W//div, half_halo, torch.float16, 3, True)
+        single_test(rank, world_size, halo_ex, C*div, Hp//div, W//div, half_halo, torch.float16, 1, True, num_steps)
+        single_test(rank, world_size, halo_ex, C*div, Hp//div, W//div, half_halo, torch.float16, 2, True, num_steps)
+        single_test(rank, world_size, halo_ex, C*div, Hp//div, W//div, half_halo, torch.float16, 3, True, num_steps)

-def W_split_tests(N, C, H, W, half_halo, rank, world_size, halo_ex):
+def W_split_tests(N, C, H, W, half_halo, rank, world_size, halo_ex, num_steps):
     Wr = 8*world_size
     Wp = ((W+Wr-1)//Wr)*8
     for i in range(4):
         div = int(pow(2,i))
-        single_test(rank, world_size, halo_ex, C*div, H//div, Wp//div, half_halo, torch.float16, 1, False)
-        single_test(rank, world_size, halo_ex, C*div, H//div, Wp//div, half_halo, torch.float16, 2, False)
-        single_test(rank, world_size, halo_ex, C*div, H//div, Wp//div, half_halo, torch.float16, 3, False)
+        single_test(rank, world_size, halo_ex, C*div, H//div, Wp//div, half_halo, torch.float16, 1, False, num_steps)
+        single_test(rank, world_size, halo_ex, C*div, H//div, Wp//div, half_halo, torch.float16, 2, False, num_steps)
+        single_test(rank, world_size, halo_ex, C*div, H//div, Wp//div, half_halo, torch.float16, 3, False, num_steps)

 def main():
...
@@ -130,11 +143,13 @@ def main():
     torch.cuda.set_device(rank)

     pool = PeerMemoryPool(rank, world_size, world_size, 64*1024, 2*1024*1024)

+    num_steps = 100
     half_halo = 1
     halo_ex = PeerHaloExchanger1d(rank, world_size, pool, half_halo)

-    H_split_tests(1, 64, 336, 200, half_halo, rank, world_size, halo_ex)
-    W_split_tests(1, 64, 200, 336, half_halo, rank, world_size, halo_ex)
+    H_split_tests(1, 64, 336, 200, half_halo, rank, world_size, halo_ex, num_steps)
+    W_split_tests(1, 64, 200, 336, half_halo, rank, world_size, halo_ex, num_steps)

 if __name__ == "__main__":