OpenDAS / apex · Commits · ed713c84

Commit ed713c84 (unverified), authored Aug 31, 2021 by Thor Johnsen, committed by GitHub on Aug 31, 2021.

Merge pull request #1151 from NVIDIA/spatial_fast_bottleneck

Spatially Distributed Fast Bottleneck block

Parents: d6b5ae5d, bbc95c0a

Showing 4 changed files with 1151 additions and 1 deletion.
  apex/contrib/bottleneck/__init__.py                 +1    -1
  apex/contrib/bottleneck/bottleneck.py               +233  -0
  apex/contrib/bottleneck/bottleneck_module_test.py   +198  -0
  apex/contrib/csrc/bottleneck/bottleneck.cpp         +719  -0
apex/contrib/bottleneck/__init__.py  (view file @ ed713c84)

-from .bottleneck import Bottleneck
+from .bottleneck import Bottleneck, SpatialBottleneck
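
As a quick orientation (not part of the commit), here is a minimal sketch of how the newly exported SpatialBottleneck might be constructed. The channel and shape values are illustrative assumptions borrowed from the test file below, and the sketch assumes the apex contrib bottleneck extension is built and a CUDA device with fp16 support is available.

# Hypothetical usage sketch, not from the commit: single-rank case (spatial_group_size=1).
# With spatial_group_size > 1 the module expects torch.distributed to be initialized
# and splits the H dimension of the input across the ranks of each spatial group.
import torch
from apex.contrib.bottleneck import SpatialBottleneck

block = SpatialBottleneck(in_channels=256, bottleneck_channels=64, out_channels=256,
                          stride=1, use_cudnn=True, explicit_nhwc=True,
                          spatial_group_size=1).to(dtype=torch.float16, device='cuda')
x = torch.randn(1, 200, 336, 256, dtype=torch.float16, device='cuda')  # explicit NHWC layout
out = block(x)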
apex/contrib/bottleneck/bottleneck.py  (view file @ ed713c84)

import torch
import torch.distributed as dist
from torch import nn
import fast_bottleneck
...
@@ -212,3 +213,235 @@ class Bottleneck(torch.nn.Module):
        out = self.relu(out)

        return out


class SpatialBottleneckFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, spatial_group_size, local_rank, comm, stream1, nhwc, stride_1x1, scale, bias, x, *conv):
        # TODO: clean up order of tensors
        args = [x, *conv[0:3], *scale[0:3], *bias[0:3]]
        ctx.downsample = len(conv) > 3
        if ctx.downsample:
            args.append(conv[3])
            args.append(scale[3])
            args.append(bias[3])

        # weight buffers are always in nhwc while shape can be nhwc or channels_last
        # here we pass in flag and let c++ handle it
        # alternatively, we can put all sizes into a fixed format and pass it in
        outputs = fast_bottleneck.forward_init(nhwc, stride_1x1, args)
        fast_bottleneck.forward_out1(nhwc, stride_1x1, args, outputs)
        fast_bottleneck.forward_out2(nhwc, stride_1x1, args, outputs)

        # do halo exchange for outputs[0] (out1)
        if spatial_group_size > 1:
            out1 = outputs[0]
            N, Hs, W, C = list(out1.shape)
            padded_out1 = torch.empty((N, Hs+2, W, C), dtype=out1.dtype, device=out1.device)
            padded_out1[:, 1:Hs+1, :, :].copy_(out1)
            stream1.wait_stream(torch.cuda.current_stream())
            with torch.cuda.stream(stream1):
                # copy halos to send buffer
                send_halos = torch.empty((N, 2, W, C), dtype=out1.dtype, device=out1.device)
                send_halos[:, :1, :, :].copy_(out1[:, :1, :, :])
                send_halos[:, 1:, :, :].copy_(out1[:, Hs-1:, :, :])
                all_halos = torch.empty((N, 2*spatial_group_size, W, C), dtype=out1.dtype, device=out1.device)
                all_halos = [all_halos[:, i*2:(i+1)*2, :, :] for i in range(spatial_group_size)]
                dist.all_gather(all_halos, send_halos)
                padded_out1_top_halo = padded_out1[:, :1, :, :]
                if local_rank > 0:
                    top_halo = all_halos[local_rank-1][:, 1:, :, :]
                    padded_out1_top_halo.copy_(top_halo)
                    fat_top_halo = padded_out1[:, :3, :, :]
                    top_out2 = fast_bottleneck.forward_out2_halo(nhwc, fat_top_halo, args)
                else:
                    padded_out1_top_halo.zero_()
                padded_out1_btm_halo = padded_out1[:, Hs+1:, :, :]
                if local_rank < spatial_group_size-1:
                    btm_halo = all_halos[local_rank+1][:, :1, :, :]
                    padded_out1_btm_halo.copy_(btm_halo)
                    fat_btm_halo = padded_out1[:, Hs-1:, :, :]
                    btm_out2 = fast_bottleneck.forward_out2_halo(nhwc, fat_btm_halo, args)
                else:
                    padded_out1_btm_halo.zero_()
            torch.cuda.current_stream().wait_stream(stream1)
            out2 = outputs[1]
            if local_rank > 0:
                out2[:, :1, :, :].copy_(top_out2)
            if local_rank < spatial_group_size-1:
                out2[:, Hs-1:, :, :].copy_(btm_out2)

        fast_bottleneck.forward_rest(nhwc, stride_1x1, args, outputs)
        if spatial_group_size > 1:
            ctx.save_for_backward(*(args + outputs + [padded_out1]))
        else:
            ctx.save_for_backward(*(args + outputs))
        # save relu outputs for drelu
        ctx.nhwc = nhwc
        ctx.stride_1x1 = stride_1x1
        ctx.spatial_group_size = spatial_group_size
        ctx.local_rank = local_rank
        ctx.comm = comm
        ctx.stream1 = stream1
        return outputs[2]

    # backward relu is not exposed, MUL with mask used now
    # only support dgrad
    @staticmethod
    def backward(ctx, grad_o):
        if ctx.spatial_group_size > 1:
            outputs = ctx.saved_tensors[-4:-1]
        else:
            outputs = ctx.saved_tensors[-3:]

        if ctx.downsample:
            grad_conv3, grad_conv4 = drelu_dscale2(grad_o, outputs[2], ctx.saved_tensors[6], ctx.saved_tensors[11])
        else:
            grad_conv3, grad_conv4 = drelu_dscale1(grad_o, outputs[2], ctx.saved_tensors[6])

        # create input vector for backward
        t_list = [*ctx.saved_tensors[0:10]]
        t_list.append(grad_conv3)
        t_list.append(grad_conv4)

        # outputs used for wgrad and generating drelu mask
        t_list.append(outputs[0])
        t_list.append(outputs[1])

        # in case there is downsample
        if ctx.downsample:
            t_list.append(ctx.saved_tensors[10])

        grads = fast_bottleneck.backward_init(ctx.nhwc, ctx.stride_1x1, t_list)
        grad_out2 = fast_bottleneck.backward_grad_out2(ctx.nhwc, ctx.stride_1x1, t_list, grads)
        # do halo exchange of grad_out2 here
        fast_bottleneck.backward_rest(ctx.nhwc, ctx.stride_1x1, t_list, grads, grad_out2)

        return (None, None, None, None, None, None, None, None, *grads)

spatial_bottleneck_function = SpatialBottleneckFunction.apply

class SpatialBottleneck(torch.nn.Module):
    # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
    # while original implementation places the stride at the first 1x1 convolution(self.conv1)
    # according to "Deep residual learning for image recognition" https://arxiv.org/abs/1512.03385.
    # This variant is also known as ResNet V1.5 and improves accuracy according to
    # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
    # here we put it at 1x1

    def __init__(self, in_channels, bottleneck_channels, out_channels, stride=1, groups=1,
                 dilation=1, norm_func=None, use_cudnn=False, explicit_nhwc=False,
                 spatial_group_size=1):
        super(SpatialBottleneck, self).__init__()
        if groups != 1:
            raise RuntimeError('Only support groups == 1')

        if dilation != 1:
            raise RuntimeError('Only support dilation == 1')

        if norm_func == None:
            norm_func = FrozenBatchNorm2d
        else:
            raise RuntimeError('Only support frozen BN now.')

        if stride != 1 or in_channels != out_channels:
            self.downsample = nn.Sequential(
                conv1x1(in_channels, out_channels, stride),
                norm_func(out_channels),
            )
        else:
            self.downsample = None

        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv1x1(in_channels, bottleneck_channels, stride)
        self.conv2 = conv3x3(bottleneck_channels, bottleneck_channels)
        self.conv3 = conv1x1(bottleneck_channels, out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.stride = stride

        self.bn1 = norm_func(bottleneck_channels)
        self.bn2 = norm_func(bottleneck_channels)
        self.bn3 = norm_func(out_channels)

        self.use_cudnn = use_cudnn

        # setup conv weights
        self.w_conv = [self.conv1.weight, self.conv2.weight, self.conv3.weight]
        if self.downsample is not None:
            self.w_conv.append(self.downsample[0].weight)

        # init weight in nchw format before possible transpose
        for w in self.w_conv:
            kaiming_uniform_(w, a=1)

        # TODO: prevent unsupported case usage
        # support cases
        #                 native      cudnn
        # normal             yes         no
        # channel_last       yes        yes
        # explicit_nhwc       no        yes
        self.explicit_nhwc = explicit_nhwc
        if self.explicit_nhwc:
            for p in self.parameters():
                with torch.no_grad():
                    p.data = p.data.permute(0, 2, 3, 1).contiguous()

        # spatial communicator
        self.spatial_group_size = spatial_group_size
        if spatial_group_size > 1:
            world_size = dist.get_world_size()
            num_groups = world_size // spatial_group_size
            assert(num_groups*spatial_group_size == world_size), "torch.distributed.get_world_size() must be multiple of group_size"
            rank = dist.get_rank()
            self.local_rank = rank % spatial_group_size
            for group in range(num_groups):
                ranks = list(range(group*spatial_group_size, (group+1)*spatial_group_size))
                comm = torch.distributed.new_group(ranks=ranks)
                if rank in ranks:
                    self.communicator = comm
            self.stream1 = torch.cuda.Stream()
            self.spatial_args = self.spatial_group_size, self.local_rank, self.communicator, self.stream1
        else:
            self.spatial_args = 1, 0, None, None

        return

    def forward(self, x):
        if self.use_cudnn:
            # calculate scale/bias from registered buffers
            # TODO: make this better
            s1, b1 = self.bn1.get_scale_bias(self.explicit_nhwc)
            s2, b2 = self.bn2.get_scale_bias(self.explicit_nhwc)
            s3, b3 = self.bn3.get_scale_bias(self.explicit_nhwc)
            w_scale = [s1, s2, s3]
            w_bias = [b1, b2, b3]
            if self.downsample is not None:
                s4, b4 = self.downsample[1].get_scale_bias(self.explicit_nhwc)
                w_scale.append(s4)
                w_bias.append(b4)

            out = spatial_bottleneck_function(*self.spatial_args, self.explicit_nhwc, self.stride, w_scale, w_bias, x, *self.w_conv)
            return out

        if self.explicit_nhwc:
            raise RuntimeError('explicit nhwc with native ops is not supported.')

        # fallback to native ops
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out
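
The stream-overlapped halo exchange in SpatialBottleneckFunction.forward is easiest to see with the distributed machinery stripped away. The following is a self-contained illustration of the same row bookkeeping (my own sketch, not code from the commit): each rank owns an H-slice of out1, contributes its first and last rows, and builds an (Hs+2)-row padded slice whose edge rows come from its neighbours, with zeros at the outer boundaries.

import torch

def pad_with_halos(slices):
    # slices: list of per-"rank" tensors of shape (N, Hs, W, C); emulates the
    # send_halos / all_gather / padded_out1 steps of SpatialBottleneckFunction.forward.
    group = len(slices)
    padded = []
    for rank, out1 in enumerate(slices):
        N, Hs, W, C = out1.shape
        p = torch.zeros((N, Hs + 2, W, C), dtype=out1.dtype)
        p[:, 1:Hs + 1] = out1                        # own rows in the middle
        if rank > 0:
            p[:, :1] = slices[rank - 1][:, Hs - 1:]  # top halo = neighbour's last row
        if rank < group - 1:
            p[:, Hs + 1:] = slices[rank + 1][:, :1]  # bottom halo = neighbour's first row
        padded.append(p)
    return padded

full = torch.arange(8.0).reshape(1, 8, 1, 1)              # H = 8 split over 2 "ranks"
halos = pad_with_halos([full[:, :4], full[:, 4:]])
print(halos[0][:, :, 0, 0], halos[1][:, :, 0, 0])         # [0,0,1,2,3,4] and [3,4,5,6,7,0]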
apex/contrib/bottleneck/bottleneck_module_test.py  (new file, 0 → 100644, view file @ ed713c84)

import os
import torch
from maskrcnn_benchmark.modeling.backbone.resnet import Bottleneck
from maskrcnn_benchmark.layers.nhwc import nhwc_to_nchw_transform, nchw_to_nhwc_transform
from maskrcnn_benchmark.layers.nhwc.batch_norm import FrozenBatchNorm2d_NHWC
from apex.contrib.bottleneck import Bottleneck as FastBottleneck


def single_module_test(ref, rank, world_size, numtype, device, shape, fast, spatial_group_size,
                       in_channels, bottleneck_channels, out_channels, num_groups,
                       stride_in_1x1, stride, dilation, norm_func, nhwc):
    # inputs + modules
    with torch.no_grad():
        input_shape = [1, in_channels] + list(shape)
        x = torch.randn(input_shape, dtype=numtype, device=device)
        if nhwc:
            x = nchw_to_nhwc_transform(x).contiguous()
        x.requires_grad = True
    print(x.shape, x.stride())
    #if spatial_group_size > 1:
    #    fast = False # hack so fast bottleneck can be run against distributed bottleneck
    #if spatial_group_size == 1:
    #    fast = False
    if fast:
        bottleneck = FastBottleneck(
                in_channels=in_channels,
                bottleneck_channels=bottleneck_channels,
                out_channels=out_channels,
                stride=stride,
                dilation=dilation,
                explicit_nhwc=nhwc,
                use_cudnn=True)
        if spatial_group_size > 1:
            print("WARNING! spatial_group_size ignored by FastBottleneck")
    else:
        bottleneck = Bottleneck(
                in_channels,
                bottleneck_channels,
                out_channels,
                num_groups,
                stride_in_1x1,
                stride,
                dilation,
                norm_func,
                nhwc,
                spatial_group_size)
    bottleneck = bottleneck.to(dtype=numtype, device=device)
    weights = dict(bottleneck.named_parameters())
    if ref is not None:
        ref_x, _, ref_weights = ref
        Hs, H = x.shape[1], ref_x.shape[1]
        assert(Hs*spatial_group_size == H), "Hs not a multiple of H"
        ref_x = ref_x[:, rank*Hs:(rank+1)*Hs, :, :]
        x.copy_(ref_x)
        assert(len(weights) == len(ref_weights)), "Reference weights and weights don't match"
        for k in weights.keys():
            weights[k].copy_(ref_weights[k])

    # forward
    out = bottleneck(x)

    # gradient output
    with torch.no_grad():
        grad_out = torch.randn_like(out)
        if ref is not None:
            _, ref_grad_out, _ = ref
            Hs, H = grad_out.shape[1], ref_grad_out.shape[1]
            assert(Hs*spatial_group_size == H), "Hs not a multiple of H"
            ref_grad_out = ref_grad_out[:, rank*Hs:(rank+1)*Hs, :, :]
            grad_out.copy_(ref_grad_out)

    # backward
    out.backward(grad_out)

    with torch.no_grad():
        dgrad = x.grad.detach()
        wgrad = {}
        for n, p in bottleneck.named_parameters():
            wgrad[n] = p.grad.detach()

    if world_size > 1:
        if spatial_group_size == 1:
            # broadcast x, grad_out and weights from rank 0
            with torch.no_grad():
                torch.distributed.broadcast(x, 0)
                torch.distributed.broadcast(grad_out, 0)
                for k in weights.keys():
                    torch.distributed.broadcast(weights[k], 0)
        else:
            # gather dgrad (x.grad), sum wgrad (weights)
            N, Hs, W, C = dgrad.shape
            H = Hs * spatial_group_size
            dgrad_gathered = torch.empty((N, H, W, C), dtype=dgrad.dtype, device=dgrad.device)
            dgrad_tensors = [dgrad_gathered[:, i*Hs:(i+1)*Hs, :, :] for i in range(spatial_group_size)]
            torch.distributed.all_gather(dgrad_tensors, dgrad)
            dgrad = dgrad_gathered
            for k in wgrad.keys():
                torch.distributed.all_reduce(wgrad[k])

    return x, out, grad_out, weights, dgrad, wgrad


def module_tests(rank, world_size, numtype, device, fast, spatial_group_sizes, init_args):
    r = []
    for ia in init_args:
        shape = ia[0:4]
        args = ia[4:]
        rr = []
        ref = None
        for spatial_group_size in spatial_group_sizes:
            N, H, W, C = shape
            H = H // spatial_group_size
            x, out, grad_out, weights, dgrad, wgrad = single_module_test(ref, rank, world_size, numtype, device, [H, W], fast, spatial_group_size, *args)
            if ref is None:
                assert(spatial_group_size == 1), "Wrong reference weights"
                ref = x, grad_out, weights
            if rank == 0:
                rr.append( (out, dgrad, wgrad) )
            torch.distributed.barrier()
        r.append(rr)
    return r


def main():
    total_num_gpus = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1
    distributed = total_num_gpus > 1
    ngpus = torch.cuda.device_count()

    if distributed:
        torch.distributed.init_process_group("nccl")
        rank, world_size = torch.distributed.get_rank(), torch.distributed.get_world_size()
        is_master = True if rank == 0 else False
        local_rank = rank % ngpus
        torch.cuda.set_device(local_rank)
        spatial_group_size = total_num_gpus
    else:
        rank, local_rank, is_master, world_size, spatial_group_size = 0, 0, True, 1, 1

    #torch.use_deterministic_algorithms(True)
    torch.backends.cudnn.benchmark = True
    #torch.backends.cudnn.deterministic = True
    #torch.backends.cuda.matmul.allow_tf32 = False
    #torch.backends.cudnn.allow_tf32 = False

    norm_func = FrozenBatchNorm2d_NHWC

    init_args = [
            (1, 200, 336, 64, 64, 64, 256, 1, True, 1, 1, norm_func, True),
            (1, 200, 336, 256, 256, 64, 256, 1, True, 1, 1, norm_func, True),
            (1, 200, 336, 256, 256, 128, 512, 1, True, 2, 1, norm_func, True),
            (1, 100, 168, 512, 512, 128, 512, 1, True, 1, 1, norm_func, True),
            (1, 100, 168, 512, 512, 256, 1024, 1, True, 2, 1, norm_func, True),
            (1, 50, 84, 1024, 1024, 256, 1024, 1, True, 1, 1, norm_func, True),
            (1, 50, 84, 1024, 1024, 512, 2048, 1, True, 2, 1, norm_func, True),
            (1, 25, 42, 2048, 2048, 512, 2048, 1, True, 1, 1, norm_func, True),
            (1, 336, 200, 64, 64, 64, 256, 1, True, 1, 1, norm_func, True),
            (1, 336, 200, 256, 256, 64, 256, 1, True, 1, 1, norm_func, True),
            (1, 336, 200, 256, 256, 128, 512, 1, True, 2, 1, norm_func, True),
            (1, 168, 100, 512, 512, 128, 512, 1, True, 1, 1, norm_func, True),
            (1, 168, 100, 512, 512, 256, 1024, 1, True, 2, 1, norm_func, True),
            (1, 84, 50, 1024, 1024, 256, 1024, 1, True, 1, 1, norm_func, True),
            (1, 84, 50, 1024, 1024, 512, 2048, 1, True, 2, 1, norm_func, True),
            (1, 42, 25, 2048, 2048, 512, 2048, 1, True, 1, 1, norm_func, True),
            ]

    # pad H to account for spatial distribution
    padded_init_args = []
    for ia in init_args:
        N, H, W, C = ia[0:4]
        m = spatial_group_size * H // (25 if H < W else 42)
        H = ((H + m - 1) // m) * m
        args = tuple( [N, H, W, C] + list(ia[4:]) )
        padded_init_args.append(args)
    init_args = padded_init_args
    if rank == 0:
        for ia in init_args:
            print(ia)

    spatial_group_sizes = [1]
    if spatial_group_size > 1:
        spatial_group_sizes.append(spatial_group_size)

    numtype, device, fast = torch.float16, 'cuda', False
    r = module_tests(rank, world_size, numtype, device, fast, spatial_group_sizes, init_args)
    torch.distributed.barrier()
    if rank == 0:
        for rr in r:
            print("***")
            for out, dgrad, wgrad in rr:
                gr = [("dgrad", dgrad.norm(p=2, dtype=torch.float64).item())] + [(k+".wgrad", wgrad[k].norm(p=2, dtype=torch.float64).item()) for k in wgrad.keys()]
                print(gr)
    torch.distributed.barrier()


if __name__ == "__main__":
    main()
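
The `# pad H to account for spatial distribution` step in main() rounds H up so that every rank in the spatial group receives an equal slice of rows. A worked example of that arithmetic, using the first init_args entry and an assumed spatial_group_size of 8 (my own illustration, not part of the commit):

# For (N, H, W, C) = (1, 200, 336, 64) and spatial_group_size = 8:
spatial_group_size, H, W = 8, 200, 336
m = spatial_group_size * H // (25 if H < W else 42)   # 8 * 200 // 25 = 64
H_padded = ((H + m - 1) // m) * m                     # ceil(200 / 64) * 64 = 256
assert H_padded % spatial_group_size == 0             # 256 / 8 = 32 rows per rank
print(m, H_padded)                                    # 64 256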
apex/contrib/csrc/bottleneck/bottleneck.cpp  (view file @ ed713c84)

...
@@ -1606,7 +1606,726 @@ std::vector<at::Tensor> bottleneck_backward(bool explicit_nhwc, int stride_1X1,
  return outputs;
}

namespace {

struct bottleneck_forward_status {

  int64_t dimA[4];
  int64_t filterdimA1[4];
  int64_t filterdimA2[4];
  int64_t filterdimA3[4];
  int64_t filterdimA4[4];

  int axis[4];

  int64_t outdimA0[4];
  int64_t outdimA1[4];
  int64_t outdimA2[4];
  int64_t outdimA3[4];
  int64_t outdimA4[4];

  int64_t padA[2];
  int64_t padA1[2];
  int64_t padA2[2];       // halo padding
  int64_t dilationA[2];
  int64_t convstrideA[2];
  int64_t convstride1X1[2];

  int64_t outdim0[4];     // halo input shape
  int64_t outdim1[4];
  int64_t outdim2[4];
  int64_t outdim3[4];
  int64_t outdim4[4];     // halo output shape

  void init(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs) {

    dimA[0] = dimA[1] = dimA[2] = dimA[3] = 0;
    filterdimA1[0] = filterdimA1[1] = filterdimA1[2] = filterdimA1[3] = 0;
    filterdimA2[0] = filterdimA2[1] = filterdimA2[2] = filterdimA2[3] = 0;
    filterdimA3[0] = filterdimA3[1] = filterdimA3[2] = filterdimA3[3] = 0;
    filterdimA4[0] = filterdimA4[1] = filterdimA4[2] = filterdimA4[3] = 0;

    // All dim calculation after this order of n,c,h,w
    if (explicit_nhwc) {
      axis[0] = 0; axis[1] = 3; axis[2] = 1; axis[3] = 2;
    } else {
      axis[0] = 0; axis[1] = 1; axis[2] = 2; axis[3] = 3;
    }

    for (int dim=0;dim<4;dim++) {
      dimA[dim] = inputs[0].size(axis[dim]);
      filterdimA1[dim] = inputs[1].size(axis[dim]);
      filterdimA2[dim] = inputs[2].size(axis[dim]);
      filterdimA3[dim] = inputs[3].size(axis[dim]);
    }
    if (stride_1X1 != 1 || filterdimA3[0] != dimA[1]) {
      for (int dim=0;dim<4;dim++) {
        filterdimA4[dim] = inputs[10].size(axis[dim]);
      }
    }

    // output dim in n,c,h,w used by backend
    outdimA0[0] = outdimA0[1] = outdimA0[2] = outdimA0[3] = 0;
    outdimA1[0] = outdimA1[1] = outdimA1[2] = outdimA1[3] = 0;
    outdimA2[0] = outdimA2[1] = outdimA2[2] = outdimA2[3] = 0;
    outdimA3[0] = outdimA3[1] = outdimA3[2] = outdimA3[3] = 0;
    outdimA4[0] = outdimA4[1] = outdimA4[2] = outdimA4[3] = 0;

    // use these fixed value for test run
    padA[0] = 0; padA[1] = 0;
    padA1[0] = 1; padA1[1] = 1;
    padA2[0] = 0; padA2[1] = 1;
    dilationA[0] = 1; dilationA[1] = 1;
    convstrideA[0] = 1; convstrideA[1] = 1;
    convstride1X1[0] = stride_1X1; convstride1X1[1] = stride_1X1;

    // compute output from pad/stride/dilation
    outdimA1[0] = dimA[0];
    outdimA1[1] = filterdimA1[0];
    for (int dim = 0; dim < 2; dim++) {
      outdimA1[dim + 2] = getFwdConvOutputDim(dimA[dim + 2], padA[dim], filterdimA1[dim + 2], convstride1X1[dim], dilationA[dim]);
    }

    outdimA2[0] = outdimA1[0];
    outdimA2[1] = filterdimA2[0];
    for (int dim = 0; dim < 2; dim++) {
      outdimA2[dim + 2] = getFwdConvOutputDim(outdimA1[dim + 2], padA1[dim], filterdimA2[dim + 2], convstrideA[dim], dilationA[dim]);
    }

    for (int dim = 0; dim < 4; dim++) {
      if (dim == 2) {
        outdimA0[dim] = 3;
        outdimA4[dim] = 1;
      } else {
        outdimA0[dim] = outdimA1[dim];
        outdimA4[dim] = outdimA2[dim];
      }
    }

    outdimA3[0] = outdimA2[0];
    outdimA3[1] = filterdimA3[0];
    for (int dim = 0; dim < 2; dim++) {
      outdimA3[dim + 2] = getFwdConvOutputDim(outdimA2[dim + 2], padA[dim], filterdimA3[dim + 2], convstrideA[dim], dilationA[dim]);
    }

    // Create output tensor in the correct shape in pytorch's view
    outdim1[0] = outdim1[1] = outdim1[2] = outdim1[3] = 0;
    outdim2[0] = outdim2[1] = outdim2[2] = outdim2[3] = 0;
    outdim3[0] = outdim3[1] = outdim3[2] = outdim3[3] = 0;
    if (explicit_nhwc) {
      axis[0] = 0; axis[1] = 2; axis[2] = 3; axis[3] = 1;
    }
    for (int dim=0;dim<4;dim++) {
      outdim0[dim] = outdimA0[axis[dim]];
      outdim1[dim] = outdimA1[axis[dim]];
      outdim2[dim] = outdimA2[axis[dim]];
      outdim3[dim] = outdimA3[axis[dim]];
      outdim4[dim] = outdimA4[axis[dim]];
    }
  }
};

bottleneck_forward_status forward_state;

}  // end of anonymous namespace

std::vector<at::Tensor> bottleneck_forward_init(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs) {

  // NB! Bottleneck_forward and bottleneck_backward are NOT thread safe method.
  // NB! We use a global object to store state.
  forward_state.init(explicit_nhwc, stride_1X1, inputs);

  // create output vector
  std::vector<at::Tensor> outputs;
  auto output_format = explicit_nhwc ? at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast;

  printf("outdim1 = (%d,%d,%d,%d)\n", forward_state.outdim1[0], forward_state.outdim1[1], forward_state.outdim1[2], forward_state.outdim1[3]);
  auto out1 = at::empty(forward_state.outdim1, inputs[0].type(), output_format);
  auto out2 = at::empty(forward_state.outdim2, inputs[0].type(), output_format);
  auto out3 = at::empty(forward_state.outdim3, inputs[0].type(), output_format);

  outputs.push_back(out1);
  outputs.push_back(out2);
  outputs.push_back(out3);

  return outputs;
}

// inputs contains x,w,z,b,(i)
void bottleneck_forward_out1(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs, std::vector<at::Tensor> outputs) {

  std::cout << std::fixed;

  // run
  at::Half* x = inputs[0].data_ptr<at::Half>();
  at::Half* w = inputs[1].data_ptr<at::Half>();
  at::Half* z = inputs[4].data_ptr<at::Half>();
  at::Half* b = inputs[7].data_ptr<at::Half>();
  auto out1 = outputs[0];
  at::Half* y1 = out1.data_ptr<at::Half>();

  run_conv_scale_bias_add_activation(forward_state.dimA, forward_state.padA, forward_state.convstride1X1, forward_state.dilationA,
                                     forward_state.filterdimA1, forward_state.outdimA1, CUDNN_DATA_HALF, x, w, y1, z, b, nullptr);

  DEBUG_MSG("[DEBUG] new relu1 : " << out1.to(at::kFloat).sum().item<float>());
}

// computes halo (top or bottom) from fat halo input.
// fat halo input is 3 pixels wide in H.
at::Tensor bottleneck_forward_out2_halo(bool explicit_nhwc, at::Tensor fat_halo_y1, std::vector<at::Tensor> inputs) {

  auto output_format = explicit_nhwc ? at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast;

  // run
  at::Half* w = inputs[2].data_ptr<at::Half>();
  at::Half* z = inputs[5].data_ptr<at::Half>();
  at::Half* b = inputs[8].data_ptr<at::Half>();

  at::Half* y1 = fat_halo_y1.data_ptr<at::Half>();

  auto halo_y2 = at::empty(forward_state.outdim4, inputs[0].type(), output_format);
  at::Half* y2 = halo_y2.data_ptr<at::Half>();

  run_conv_scale_bias_add_activation(forward_state.outdimA0, forward_state.padA2, forward_state.convstrideA, forward_state.dilationA,
                                     forward_state.filterdimA2, forward_state.outdimA4, CUDNN_DATA_HALF, y1, w, y2, z, b, nullptr);

  return halo_y2;
}

void bottleneck_forward_out2(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs, std::vector<at::Tensor> outputs) {

  std::cout << std::fixed;

  // from _out1 method
  at::Half* x = inputs[0].data_ptr<at::Half>();
  auto out1 = outputs[0];
  at::Half* y1 = out1.data_ptr<at::Half>();

  // run
  at::Half* w = inputs[2].data_ptr<at::Half>();
  at::Half* z = inputs[5].data_ptr<at::Half>();
  at::Half* b = inputs[8].data_ptr<at::Half>();
  auto out2 = outputs[1];
  at::Half* y2 = out2.data_ptr<at::Half>();

  printf("forward_state.outdimA1 = {%d,%d,%d,%d}\n", forward_state.outdimA1[0], forward_state.outdimA1[1], forward_state.outdimA1[2], forward_state.outdimA1[3]);
  printf("forward_state.padA1 = {%d,%d}\n", forward_state.padA1[0], forward_state.padA1[1]);
  printf("forward_state.convstrideA = {%d,%d}\n", forward_state.convstrideA[0], forward_state.convstrideA[1]);
  printf("forward_state.dilationA = {%d,%d}\n", forward_state.dilationA[0], forward_state.dilationA[1]);
  printf("forward_state.filterdimA2 = {%d,%d,%d,%d}\n", forward_state.filterdimA2[0], forward_state.filterdimA2[1], forward_state.filterdimA2[2], forward_state.filterdimA2[3]);
  printf("forward_state.outdimA2 = {%d,%d,%d,%d}\n", forward_state.outdimA2[0], forward_state.outdimA2[1], forward_state.outdimA2[2], forward_state.outdimA2[3]);
  run_conv_scale_bias_add_activation(forward_state.outdimA1, forward_state.padA1, forward_state.convstrideA, forward_state.dilationA,
                                     forward_state.filterdimA2, forward_state.outdimA2, CUDNN_DATA_HALF, y1, w, y2, z, b, nullptr);
  DEBUG_MSG("[DEBUG] new relu2 : " << out2.to(at::kFloat).sum().item<float>());
}

void bottleneck_forward_rest(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs, std::vector<at::Tensor> outputs) {

  std::cout << std::fixed;

  // from _out1 method
  at::Half* x = inputs[0].data_ptr<at::Half>();

  // create output of conv3
  auto out3 = outputs[2];
  at::Half* y3 = out3.data_ptr<at::Half>();

  // create output of conv4 that may exist
  auto identity = at::empty_like(out3);
  at::Half* yi = identity.data_ptr<at::Half>();

  at::Half *w, *z, *b;

  if (stride_1X1 != 1 || forward_state.filterdimA3[0] != forward_state.dimA[1]){

    w = inputs[10].data_ptr<at::Half>();
    z = inputs[11].data_ptr<at::Half>();
    b = inputs[12].data_ptr<at::Half>();
    run_conv_scale_bias(forward_state.dimA, forward_state.padA, forward_state.convstride1X1, forward_state.dilationA,
                        forward_state.filterdimA4, forward_state.outdimA3, CUDNN_DATA_HALF, x, w, yi, z, b);
    DEBUG_MSG("[DEBUG] new downsample : " << identity.to(at::kFloat).sum().item<float>());
  }
  else {
    yi = x;
  }

  auto out2 = outputs[1];
  at::Half* y2 = out2.data_ptr<at::Half>();

  w = inputs[3].data_ptr<at::Half>();
  z = inputs[6].data_ptr<at::Half>();
  b = inputs[9].data_ptr<at::Half>();

  run_conv_scale_bias_add_activation(forward_state.outdimA2, forward_state.padA, forward_state.convstrideA, forward_state.dilationA,
                                     forward_state.filterdimA3, forward_state.outdimA3, CUDNN_DATA_HALF, y2, w, y3, z, b, yi);
  DEBUG_MSG("[DEBUG] new relu3 : " << out3.to(at::kFloat).sum().item<float>());
}

namespace {

struct bottleneck_backward_state {

  int64_t dimA[4];
  int64_t filterdimA1[4];
  int64_t filterdimA2[4];
  int64_t filterdimA3[4];
  int64_t filterdimA4[4];

  int axis[4];

  int64_t outdimA1[4];
  int64_t outdimA2[4];
  int64_t outdimA3[4];

  int64_t padA[2];
  int64_t padA1[2];
  int64_t dilationA[2];
  int64_t convstrideA[2];
  int64_t convstride1X1[2];

  int64_t outdim1[4];
  int64_t outdim2[4];
  int64_t outdim3[4];

  void init(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs) {

    // setup dimensions
    dimA[0] = dimA[1] = dimA[2] = dimA[3] = 0;
    filterdimA1[0] = filterdimA1[1] = filterdimA1[2] = filterdimA1[3] = 0;
    filterdimA2[0] = filterdimA2[1] = filterdimA2[2] = filterdimA2[3] = 0;
    filterdimA3[0] = filterdimA3[1] = filterdimA3[2] = filterdimA3[3] = 0;
    filterdimA4[0] = filterdimA4[1] = filterdimA4[2] = filterdimA4[3] = 0;

    // All dim calculation after this order of n,c,h,w
    if (explicit_nhwc) {
      axis[0] = 0; axis[1] = 3; axis[2] = 1; axis[3] = 2;
    } else {
      axis[0] = 0; axis[1] = 1; axis[2] = 2; axis[3] = 3;
    }

    for (int dim=0;dim<4;dim++) {
      dimA[dim] = inputs[0].size(axis[dim]);
      filterdimA1[dim] = inputs[1].size(axis[dim]);
      filterdimA2[dim] = inputs[2].size(axis[dim]);
      filterdimA3[dim] = inputs[3].size(axis[dim]);
    }
    if (stride_1X1 != 1 || filterdimA3[0] != dimA[1]) {
      for (int dim=0;dim<4;dim++) {
        filterdimA4[dim] = inputs[14].size(axis[dim]);
      }
    }

    // output dim in n,c,h,w used by backend
    outdimA1[0] = outdimA1[1] = outdimA1[2] = outdimA1[3] = 0;
    outdimA2[0] = outdimA2[1] = outdimA2[2] = outdimA2[3] = 0;
    outdimA3[0] = outdimA3[1] = outdimA3[2] = outdimA3[3] = 0;

    // use these fixed value for test run
    padA[0] = 0; padA[1] = 0;
    padA1[0] = 1; padA1[1] = 1;
    dilationA[0] = 1; dilationA[1] = 1;
    convstrideA[0] = 1; convstrideA[1] = 1;
    convstride1X1[0] = stride_1X1; convstride1X1[1] = stride_1X1;

    // compute output from pad/stride/dilation
    outdimA1[0] = dimA[0];
    outdimA1[1] = filterdimA1[0];
    for (int dim = 0; dim < 2; dim++) {
      outdimA1[dim + 2] = getFwdConvOutputDim(dimA[dim + 2], padA[dim], filterdimA1[dim + 2], convstride1X1[dim], dilationA[dim]);
    }

    outdimA2[0] = outdimA1[0];
    outdimA2[1] = filterdimA2[0];
    for (int dim = 0; dim < 2; dim++) {
      outdimA2[dim + 2] = getFwdConvOutputDim(outdimA1[dim + 2], padA1[dim], filterdimA2[dim + 2], convstrideA[dim], dilationA[dim]);
    }

    outdimA3[0] = outdimA2[0];
    outdimA3[1] = filterdimA3[0];
    for (int dim = 0; dim < 2; dim++) {
      outdimA3[dim + 2] = getFwdConvOutputDim(outdimA2[dim + 2], padA[dim], filterdimA3[dim + 2], convstrideA[dim], dilationA[dim]);
    }

    // Create output tensor in the correct shape in pytorch's view
    outdim1[0] = outdim1[1] = outdim1[2] = outdim1[3] = 0;
    outdim2[0] = outdim2[1] = outdim2[2] = outdim2[3] = 0;
    outdim3[0] = outdim3[1] = outdim3[2] = outdim3[3] = 0;
    if (explicit_nhwc) {
      axis[0] = 0; axis[1] = 2; axis[2] = 3; axis[3] = 1;
    }
    for (int dim=0;dim<4;dim++) {
      outdim1[dim] = outdimA1[axis[dim]];
      outdim2[dim] = outdimA2[axis[dim]];
      outdim3[dim] = outdimA3[axis[dim]];
    }
  }
};

bottleneck_backward_state backward_state;

}

std::vector<at::Tensor> bottleneck_backward_init(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs) {

  std::cout << std::fixed;

  backward_state.init(explicit_nhwc, stride_1X1, inputs);

  // create output vector
  std::vector<at::Tensor> outputs;
  auto output_format = explicit_nhwc ? at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast;

  auto grad_x = at::empty_like(inputs[0]);
  auto wgrad1 = at::empty_like(inputs[1]);
  auto wgrad2 = at::empty_like(inputs[2]);
  auto wgrad3 = at::empty_like(inputs[3]);

  outputs.push_back(grad_x);
  outputs.push_back(wgrad1);
  outputs.push_back(wgrad2);
  outputs.push_back(wgrad3);

  if (stride_1X1 != 1 || backward_state.filterdimA3[0] != backward_state.dimA[1]) {
    auto wgrad4 = at::empty_like(inputs[14]);
    outputs.push_back(wgrad4);
  }

  return outputs;
}

at::Tensor bottleneck_backward_grad_out2(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs, std::vector<at::Tensor> outputs) {

  bool requires_grad = inputs[0].requires_grad();

  std::cout << std::fixed;

  auto output_format = explicit_nhwc ? at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast;

  // dconv3+drelu2+dscale2
  at::Half* conv_in = inputs[13].data_ptr<at::Half>();
  at::Half* dy3 = inputs[10].data_ptr<at::Half>();

  DEBUG_MSG("[DEBUG] new dconv3 : " << inputs[10].to(at::kFloat).sum().item<float>());

  // wgrad
  auto wgrad3 = outputs[3];
  at::Half* dw3 = wgrad3.data_ptr<at::Half>();
  run_dconv(backward_state.outdimA2, backward_state.padA, backward_state.convstrideA, backward_state.dilationA,
            backward_state.filterdimA3, backward_state.outdimA3, CUDNN_DATA_HALF, conv_in, dw3, dy3,
            CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR);

  // dgrad
  auto grad_out2 = at::empty(backward_state.outdim2, inputs[0].type(), output_format);
  at::Half* dy2 = grad_out2.data_ptr<at::Half>();
  at::Half* w = inputs[3].data_ptr<at::Half>();
  at::Half* z = inputs[5].data_ptr<at::Half>();

  at::Half* relu2 = inputs[13].data_ptr<at::Half>();

  run_dconv_drelu_dscale(backward_state.outdimA2, backward_state.padA, backward_state.convstrideA, backward_state.dilationA,
                         backward_state.filterdimA3, backward_state.outdimA3, CUDNN_DATA_HALF, dy2, w, dy3, z, relu2);

  // do halo exchange of dy2 here

  DEBUG_MSG("[DEBUG] new dconv2 : " << grad_out2.to(at::kFloat).sum().item<float>());

  return grad_out2;
}

void bottleneck_backward_rest(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs, std::vector<at::Tensor> outputs, at::Tensor grad_out2) {

  bool requires_grad = inputs[0].requires_grad();

  std::cout << std::fixed;

  auto output_format = explicit_nhwc ? at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast;

  // dgrad
  at::Half* dy2 = grad_out2.data_ptr<at::Half>();

  // dconv2+drelu1+dscale1
  at::Half* conv_in = inputs[12].data_ptr<at::Half>();

  // wgrad
  auto wgrad2 = outputs[2];
  at::Half* dw2 = wgrad2.data_ptr<at::Half>();

  printf("outdimA1 = (%d,%d,%d,%d)\n", backward_state.outdimA1[0], backward_state.outdimA1[1], backward_state.outdimA1[2], backward_state.outdimA1[3]);
  run_dconv(backward_state.outdimA1, backward_state.padA1, backward_state.convstrideA, backward_state.dilationA,
            backward_state.filterdimA2, backward_state.outdimA2, CUDNN_DATA_HALF, conv_in, dw2, dy2,
            CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR);

  // dgrad
  auto grad_out1 = at::empty(backward_state.outdim1, inputs[0].type(), output_format);
  at::Half* dy1 = grad_out1.data_ptr<at::Half>();
  at::Half* w = inputs[2].data_ptr<at::Half>();
  at::Half* z = inputs[4].data_ptr<at::Half>();

  at::Half* relu1 = inputs[12].data_ptr<at::Half>();
  // fused dgrad
  run_dconv_drelu_dscale(backward_state.outdimA1, backward_state.padA1, backward_state.convstrideA, backward_state.dilationA,
                         backward_state.filterdimA2, backward_state.outdimA2, CUDNN_DATA_HALF, dy1, w, dy2, z, relu1);

/*
  // backward strided conv cannot be fused
  // if stride == 1 but channel changes, we can fuse here
  if (stride_1X1 != 1){
    // dgrad
    run_dconv(outdimA1,
              padA1,
              convstride1X1,
              dilationA,
              filterdimA2,
              outdimA2,
              CUDNN_DATA_HALF,
              dy1,
              w,
              dy2,
              CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR);

    // mul fused mask
    grad_out1.mul_(inputs[15]);
  }
  else {
    at::Half* relu1 = inputs[12].data_ptr<at::Half>();
    // fused dgrad
    run_dconv_drelu_dscale(outdimA1,
                           padA1,
                           convstride1X1,
                           dilationA,
                           filterdimA2,
                           outdimA2,
                           CUDNN_DATA_HALF,
                           dy1,
                           w,
                           dy2,
                           z,
                           relu1);
  }
*/
  DEBUG_MSG("[DEBUG] new dconv1 : " << grad_out1.to(at::kFloat).sum().item<float>());

  // create grads of conv4 that may exist
  auto grad_x_conv4 = at::empty_like(inputs[0]);
  at::Half* dx_conv4 = grad_x_conv4.data_ptr<at::Half>();
  at::Tensor wgrad4;

  // x used for dconv1 and dconv4 wgrad
  at::Half* x = inputs[0].data_ptr<at::Half>();

  if (stride_1X1 != 1 || backward_state.filterdimA3[0] != backward_state.dimA[1]){
    w = inputs[14].data_ptr<at::Half>();
    at::Half* dy_conv4 = inputs[11].data_ptr<at::Half>();
    if (requires_grad) {
      run_dconv(backward_state.dimA, backward_state.padA, backward_state.convstride1X1, backward_state.dilationA,
                backward_state.filterdimA4, backward_state.outdimA3, CUDNN_DATA_HALF, dx_conv4, w, dy_conv4,
                CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR);
      // we don't print here since we can't hook out this grad in pytorch alone to compare, due to addition with dx
      // DEBUG_MSG("[DEBUG] new dx_identity : " << grad_x_conv4.to(at::kFloat).sum().item<float>());
    }
    // wgrad
    wgrad4 = outputs[4];
    at::Half* dw4 = wgrad4.data_ptr<at::Half>();
    run_dconv(backward_state.dimA, backward_state.padA, backward_state.convstride1X1, backward_state.dilationA,
              backward_state.filterdimA4, backward_state.outdimA3, CUDNN_DATA_HALF, x, dw4, dy_conv4,
              CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR);
  }
  else {
    // if there is no downsample, dx_conv4 is fork of drelu3
    dx_conv4 = inputs[11].data_ptr<at::Half>();
  }

  // dconv1+add
  // wgrad
  auto wgrad1 = outputs[1];
  at::Half* dw1 = wgrad1.data_ptr<at::Half>();
  run_dconv(backward_state.dimA, backward_state.padA, backward_state.convstride1X1, backward_state.dilationA,
            backward_state.filterdimA1, backward_state.outdimA1, CUDNN_DATA_HALF, x, dw1, dy1,
            CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR);

  // dgrad
  w = inputs[1].data_ptr<at::Half>();
  auto grad_x = outputs[0];
  at::Half* dx = grad_x.data_ptr<at::Half>();

  // backward strided conv cannot be fused
  // if stride == 1 but channel changes, we can fuse here
  if (requires_grad){
    if (stride_1X1 != 1){
      run_dconv(backward_state.dimA, backward_state.padA, backward_state.convstride1X1, backward_state.dilationA,
                backward_state.filterdimA1, backward_state.outdimA1, CUDNN_DATA_HALF, dx, w, dy1,
                CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR);
      // add 2 together
      grad_x.add_(grad_x_conv4);
    }
    else {
      run_dconv_add(backward_state.dimA, backward_state.padA, backward_state.convstride1X1, backward_state.dilationA,
                    backward_state.filterdimA1, backward_state.outdimA1, CUDNN_DATA_HALF, dx, w, dy1, dx_conv4);
    }
  }

  DEBUG_MSG("[DEBUG] new dx : " << grad_x.to(at::kFloat).sum().item<float>());
  DEBUG_MSG("[DEBUG] new wgrad1 : " << wgrad1.to(at::kFloat).sum().item<float>());
  DEBUG_MSG("[DEBUG] new wgrad2 : " << wgrad2.to(at::kFloat).sum().item<float>());
  DEBUG_MSG("[DEBUG] new wgrad3 : " << wgrad3.to(at::kFloat).sum().item<float>());
  if (stride_1X1 != 1 || backward_state.filterdimA3[0] != backward_state.dimA[1]) {
    DEBUG_MSG("[DEBUG] new wgrad4 : " << wgrad4.to(at::kFloat).sum().item<float>());
  }
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward", &bottleneck_forward, "Bottleneck block forward");
  m.def("backward", &bottleneck_backward, "Bottleneck block backward");
  m.def("forward_init", &bottleneck_forward_init, "Bottleneck block init");
  m.def("forward_out1", &bottleneck_forward_out1, "Bottleneck block forward");
  m.def("forward_out2", &bottleneck_forward_out2, "Bottleneck block forward");
  m.def("forward_out2_halo", &bottleneck_forward_out2_halo, "Bottleneck block forward");
  m.def("forward_rest", &bottleneck_forward_rest, "Bottleneck block forward");
  m.def("backward_init", &bottleneck_backward_init, "Bottleneck block backward init");
  m.def("backward_grad_out2", &bottleneck_backward_grad_out2, "Bottleneck block backward");
  m.def("backward_rest", &bottleneck_backward_rest, "Bottleneck block backward");
}
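
Taken together, the new bindings are intended to be driven in the same order that SpatialBottleneckFunction.forward uses in bottleneck.py above. A condensed Python sketch of that call sequence (my own summary of the diff, with the halo exchange between out2 and forward_rest elided):

import fast_bottleneck  # the extension module defined by PYBIND11_MODULE above

def spatial_bottleneck_forward(nhwc, stride_1x1, args):
    """args = [x, conv1..3 weights, scale1..3, bias1..3] (+ conv4/scale4/bias4 when downsampling)."""
    outputs = fast_bottleneck.forward_init(nhwc, stride_1x1, args)   # allocates out1, out2, out3
    fast_bottleneck.forward_out1(nhwc, stride_1x1, args, outputs)    # conv1 + scale/bias + ReLU
    fast_bottleneck.forward_out2(nhwc, stride_1x1, args, outputs)    # conv2 + scale/bias + ReLU
    # halo exchange patches the first/last rows of outputs[1] here when spatial_group_size > 1
    fast_bottleneck.forward_rest(nhwc, stride_1x1, args, outputs)    # downsample (if any) + conv3 + add + ReLU
    return outputs[2]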