Commit 2f164a2a authored Aug 31, 2021 by Thor Johnsen

First release

parent d6b5ae5d
Showing 3 changed files with 953 additions and 1 deletion:

  apex/contrib/bottleneck/__init__.py           +1    -1
  apex/contrib/bottleneck/bottleneck.py         +233  -0
  apex/contrib/csrc/bottleneck/bottleneck.cpp   +719  -0
apex/contrib/bottleneck/__init__.py  View file @ 2f164a2a

-from .bottleneck import Bottleneck
+from .bottleneck import Bottleneck, SpatialBottleneck
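Usage note (not part of the diff): with this change both block variants are importable from the package. A minimal sketch, assuming apex was built with the contrib bottleneck extension; the constructor argument values are illustrative only and mirror the SpatialBottleneck signature added below.

    from apex.contrib.bottleneck import Bottleneck, SpatialBottleneck

    # Single-GPU bottleneck block (pre-existing module).
    block = Bottleneck(in_channels=256, bottleneck_channels=64, out_channels=256,
                       stride=1, use_cudnn=True, explicit_nhwc=True)

    # Spatially-parallel variant added by this commit. spatial_group_size=1 behaves
    # like the single-GPU path; values > 1 require torch.distributed to be initialized.
    spatial_block = SpatialBottleneck(in_channels=256, bottleneck_channels=64, out_channels=256,
                                      stride=1, use_cudnn=True, explicit_nhwc=True,
                                      spatial_group_size=1)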
apex/contrib/bottleneck/bottleneck.py  View file @ 2f164a2a

import torch
import torch.distributed as dist
from torch import nn

import fast_bottleneck

...
@@ -212,3 +213,235 @@ class Bottleneck(torch.nn.Module):
        out = self.relu(out)

        return out

class SpatialBottleneckFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, spatial_group_size, local_rank, comm, stream1, nhwc, stride_1x1, scale, bias, x, *conv):
        # TODO: clean up order of tensors
        args = [x, *conv[0:3], *scale[0:3], *bias[0:3]]
        ctx.downsample = len(conv) > 3
        if ctx.downsample:
            args.append(conv[3])
            args.append(scale[3])
            args.append(bias[3])

        # weight buffers are always in nhwc while shape can be nhwc or channels_last
        # here we pass in flag and let c++ handle it
        # alternatively, we can put all sizes into a fixed format and pass it in
        outputs = fast_bottleneck.forward_init(nhwc, stride_1x1, args)
        fast_bottleneck.forward_out1(nhwc, stride_1x1, args, outputs)
        fast_bottleneck.forward_out2(nhwc, stride_1x1, args, outputs)

        # do halo exchange for outputs[0] (out1)
        if spatial_group_size > 1:
            out1 = outputs[0]
            N, Hs, W, C = list(out1.shape)
            padded_out1 = torch.empty((N, Hs+2, W, C), dtype=out1.dtype, device=out1.device)
            padded_out1[:, 1:Hs+1, :, :].copy_(out1)
            stream1.wait_stream(torch.cuda.current_stream())
            with torch.cuda.stream(stream1):
                # copy halos to send buffer
                send_halos = torch.empty((N, 2, W, C), dtype=out1.dtype, device=out1.device)
                send_halos[:, :1, :, :].copy_(out1[:, :1, :, :])
                send_halos[:, 1:, :, :].copy_(out1[:, Hs-1:, :, :])
                all_halos = torch.empty((N, 2*spatial_group_size, W, C), dtype=out1.dtype, device=out1.device)
                all_halos = [all_halos[:, i*2:(i+1)*2, :, :] for i in range(spatial_group_size)]
                dist.all_gather(all_halos, send_halos)
                padded_out1_top_halo = padded_out1[:, :1, :, :]
                if local_rank > 0:
                    top_halo = all_halos[local_rank-1][:, 1:, :, :]
                    padded_out1_top_halo.copy_(top_halo)
                    fat_top_halo = padded_out1[:, :3, :, :]
                    top_out2 = fast_bottleneck.forward_out2_halo(nhwc, fat_top_halo, args)
                else:
                    padded_out1_top_halo.zero_()
                padded_out1_btm_halo = padded_out1[:, Hs+1:, :, :]
                if local_rank < spatial_group_size-1:
                    btm_halo = all_halos[local_rank+1][:, :1, :, :]
                    padded_out1_btm_halo.copy_(btm_halo)
                    fat_btm_halo = padded_out1[:, Hs-1:, :, :]
                    btm_out2 = fast_bottleneck.forward_out2_halo(nhwc, fat_btm_halo, args)
                else:
                    padded_out1_btm_halo.zero_()
            torch.cuda.current_stream().wait_stream(stream1)
            out2 = outputs[1]
            if local_rank > 0:
                out2[:, :1, :, :].copy_(top_out2)
            if local_rank < spatial_group_size-1:
                out2[:, Hs-1:, :, :].copy_(btm_out2)

        fast_bottleneck.forward_rest(nhwc, stride_1x1, args, outputs)
        if spatial_group_size > 1:
            ctx.save_for_backward(*(args+outputs+[padded_out1]))
        else:
            ctx.save_for_backward(*(args+outputs))
        # save relu outputs for drelu
        ctx.nhwc = nhwc
        ctx.stride_1x1 = stride_1x1
        ctx.spatial_group_size = spatial_group_size
        ctx.local_rank = local_rank
        ctx.comm = comm
        ctx.stream1 = stream1
        return outputs[2]

    # backward relu is not exposed, MUL with mask used now
    # only support dgrad
    @staticmethod
    def backward(ctx, grad_o):
        if ctx.spatial_group_size > 1:
            outputs = ctx.saved_tensors[-4:-1]
        else:
            outputs = ctx.saved_tensors[-3:]

        if ctx.downsample:
            grad_conv3, grad_conv4 = drelu_dscale2(grad_o, outputs[2], ctx.saved_tensors[6], ctx.saved_tensors[11])
        else:
            grad_conv3, grad_conv4 = drelu_dscale1(grad_o, outputs[2], ctx.saved_tensors[6])

        # create input vector for backward
        t_list = [*ctx.saved_tensors[0:10]]
        t_list.append(grad_conv3)
        t_list.append(grad_conv4)

        # outputs used for wgrad and generating drelu mask
        t_list.append(outputs[0])
        t_list.append(outputs[1])

        # in case there is downsample
        if ctx.downsample:
            t_list.append(ctx.saved_tensors[10])

        grads = fast_bottleneck.backward_init(ctx.nhwc, ctx.stride_1x1, t_list)
        grad_out2 = fast_bottleneck.backward_grad_out2(ctx.nhwc, ctx.stride_1x1, t_list, grads)
        # do halo exchange of grad_out2 here
        fast_bottleneck.backward_rest(ctx.nhwc, ctx.stride_1x1, t_list, grads, grad_out2)

        return (None, None, None, None, None, None, None, None, *grads)

spatial_bottleneck_function = SpatialBottleneckFunction.apply

class SpatialBottleneck(torch.nn.Module):
    # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution (self.conv2)
    # while the original implementation places the stride at the first 1x1 convolution (self.conv1)
    # according to "Deep residual learning for image recognition" https://arxiv.org/abs/1512.03385.
    # This variant is also known as ResNet V1.5 and improves accuracy according to
    # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
    # here we put it at 1x1

    def __init__(self, in_channels, bottleneck_channels, out_channels, stride=1, groups=1,
                 dilation=1, norm_func=None, use_cudnn=False, explicit_nhwc=False,
                 spatial_group_size=1):
        super(SpatialBottleneck, self).__init__()
        if groups != 1:
            raise RuntimeError('Only support groups == 1')
        if dilation != 1:
            raise RuntimeError('Only support dilation == 1')
        if norm_func == None:
            norm_func = FrozenBatchNorm2d
        else:
            raise RuntimeError('Only support frozen BN now.')

        if stride != 1 or in_channels != out_channels:
            self.downsample = nn.Sequential(
                conv1x1(in_channels, out_channels, stride),
                norm_func(out_channels),
            )
        else:
            self.downsample = None

        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv1x1(in_channels, bottleneck_channels, stride)
        self.conv2 = conv3x3(bottleneck_channels, bottleneck_channels)
        self.conv3 = conv1x1(bottleneck_channels, out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.stride = stride

        self.bn1 = norm_func(bottleneck_channels)
        self.bn2 = norm_func(bottleneck_channels)
        self.bn3 = norm_func(out_channels)

        self.use_cudnn = use_cudnn

        # setup conv weights
        self.w_conv = [self.conv1.weight, self.conv2.weight, self.conv3.weight]
        if self.downsample is not None:
            self.w_conv.append(self.downsample[0].weight)

        # init weight in nchw format before possible transpose
        for w in self.w_conv:
            kaiming_uniform_(w, a=1)

        # TODO: prevent unsupported case usage
        # support cases
        #                 native   cudnn
        # normal          yes      no
        # channel_last    yes      yes
        # explicit_nhwc   no       yes
        self.explicit_nhwc = explicit_nhwc
        if self.explicit_nhwc:
            for p in self.parameters():
                with torch.no_grad():
                    p.data = p.data.permute(0, 2, 3, 1).contiguous()

        # spatial communicator
        self.spatial_group_size = spatial_group_size
        if spatial_group_size > 1:
            world_size = dist.get_world_size()
            num_groups = world_size // spatial_group_size
            assert(num_groups*spatial_group_size == world_size), "torch.distributed.get_world_size() must be multiple of group_size"
            rank = dist.get_rank()
            self.local_rank = rank % spatial_group_size
            for group in range(num_groups):
                ranks = list(range(group*spatial_group_size, (group+1)*spatial_group_size))
                comm = torch.distributed.new_group(ranks=ranks)
                if rank in ranks:
                    self.communicator = comm
            self.stream1 = torch.cuda.Stream()
            self.spatial_args = self.spatial_group_size, self.local_rank, self.communicator, self.stream1
        else:
            self.spatial_args = 1, 0, None, None

        return

    def forward(self, x):
        if self.use_cudnn:
            # calculate scale/bias from registered buffers
            # TODO: make this better
            s1, b1 = self.bn1.get_scale_bias(self.explicit_nhwc)
            s2, b2 = self.bn2.get_scale_bias(self.explicit_nhwc)
            s3, b3 = self.bn3.get_scale_bias(self.explicit_nhwc)
            w_scale = [s1, s2, s3]
            w_bias = [b1, b2, b3]
            if self.downsample is not None:
                s4, b4 = self.downsample[1].get_scale_bias(self.explicit_nhwc)
                w_scale.append(s4)
                w_bias.append(b4)

            out = spatial_bottleneck_function(*self.spatial_args, self.explicit_nhwc, self.stride, w_scale, w_bias, x, *self.w_conv)
            return out

        if self.explicit_nhwc:
            raise RuntimeError('explicit nhwc with native ops is not supported.')

        # fallback to native ops
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out
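For reference, the halo exchange performed in SpatialBottleneckFunction.forward above can be summarized by the following standalone sketch (a hypothetical helper, not part of the commit): each rank holds an (N, Hs, W, C) slice of the activation split along H, contributes its first and last rows to an all_gather, and pads its local slice with the neighbouring ranks' rows. It assumes torch.distributed is already initialized.

    import torch
    import torch.distributed as dist

    def exchange_halos(out1, spatial_group_size, local_rank):
        # out1: this rank's (N, Hs, W, C) slice of the activation, split along H.
        N, Hs, W, C = out1.shape
        # Local slice padded with one extra row above and below.
        padded = torch.empty((N, Hs + 2, W, C), dtype=out1.dtype, device=out1.device)
        padded[:, 1:Hs + 1].copy_(out1)
        # Each rank contributes its first and last row.
        send = torch.cat([out1[:, :1], out1[:, Hs - 1:]], dim=1)
        gathered = [torch.empty_like(send) for _ in range(spatial_group_size)]
        dist.all_gather(gathered, send)
        # Top halo comes from the previous rank's last row, bottom halo from the next rank's first row.
        if local_rank > 0:
            padded[:, :1].copy_(gathered[local_rank - 1][:, 1:])
        else:
            padded[:, :1].zero_()
        if local_rank < spatial_group_size - 1:
            padded[:, Hs + 1:].copy_(gathered[local_rank + 1][:, :1])
        else:
            padded[:, Hs + 1:].zero_()
        return padded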
apex/contrib/csrc/bottleneck/bottleneck.cpp  View file @ 2f164a2a

...
@@ -1606,7 +1606,726 @@ std::vector<at::Tensor> bottleneck_backward(bool explicit_nhwc, int stride_1X1,
  return outputs;
}

namespace {

struct bottleneck_forward_status {

  int64_t dimA[4];
  int64_t filterdimA1[4];
  int64_t filterdimA2[4];
  int64_t filterdimA3[4];
  int64_t filterdimA4[4];

  int axis[4];

  int64_t outdimA0[4];
  int64_t outdimA1[4];
  int64_t outdimA2[4];
  int64_t outdimA3[4];
  int64_t outdimA4[4];

  int64_t padA[2];
  int64_t padA1[2];
  int64_t padA2[2];         // halo padding
  int64_t dilationA[2];
  int64_t convstrideA[2];
  int64_t convstride1X1[2];

  int64_t outdim0[4];       // halo input shape
  int64_t outdim1[4];
  int64_t outdim2[4];
  int64_t outdim3[4];
  int64_t outdim4[4];       // halo output shape

  void init(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs) {

    dimA[0] = dimA[1] = dimA[2] = dimA[3] = 0;
    filterdimA1[0] = filterdimA1[1] = filterdimA1[2] = filterdimA1[3] = 0;
    filterdimA2[0] = filterdimA2[1] = filterdimA2[2] = filterdimA2[3] = 0;
    filterdimA3[0] = filterdimA3[1] = filterdimA3[2] = filterdimA3[3] = 0;
    filterdimA4[0] = filterdimA4[1] = filterdimA4[2] = filterdimA4[3] = 0;

    // All dim calculation after this order of n,c,h,w
    if (explicit_nhwc) {
      axis[0] = 0;
      axis[1] = 3;
      axis[2] = 1;
      axis[3] = 2;
    } else {
      axis[0] = 0;
      axis[1] = 1;
      axis[2] = 2;
      axis[3] = 3;
    }
    for (int dim = 0; dim < 4; dim++) {
      dimA[dim] = inputs[0].size(axis[dim]);
      filterdimA1[dim] = inputs[1].size(axis[dim]);
      filterdimA2[dim] = inputs[2].size(axis[dim]);
      filterdimA3[dim] = inputs[3].size(axis[dim]);
    }
    if (stride_1X1 != 1 || filterdimA3[0] != dimA[1]) {
      for (int dim = 0; dim < 4; dim++) {
        filterdimA4[dim] = inputs[10].size(axis[dim]);
      }
    }

    // output dim in n,c,h,w used by backend
    outdimA0[0] = outdimA0[1] = outdimA0[2] = outdimA0[3] = 0;
    outdimA1[0] = outdimA1[1] = outdimA1[2] = outdimA1[3] = 0;
    outdimA2[0] = outdimA2[1] = outdimA2[2] = outdimA2[3] = 0;
    outdimA3[0] = outdimA3[1] = outdimA3[2] = outdimA3[3] = 0;
    outdimA4[0] = outdimA4[1] = outdimA4[2] = outdimA4[3] = 0;

    // use these fixed value for test run
    padA[0] = 0; padA[1] = 0;
    padA1[0] = 1; padA1[1] = 1;
    padA2[0] = 0; padA2[1] = 1;
    dilationA[0] = 1; dilationA[1] = 1;
    convstrideA[0] = 1; convstrideA[1] = 1;
    convstride1X1[0] = stride_1X1; convstride1X1[1] = stride_1X1;

    // compute output from pad/stride/dilation
    outdimA1[0] = dimA[0];
    outdimA1[1] = filterdimA1[0];
    for (int dim = 0; dim < 2; dim++) {
      outdimA1[dim + 2] = getFwdConvOutputDim(dimA[dim + 2], padA[dim], filterdimA1[dim + 2], convstride1X1[dim], dilationA[dim]);
    }

    outdimA2[0] = outdimA1[0];
    outdimA2[1] = filterdimA2[0];
    for (int dim = 0; dim < 2; dim++) {
      outdimA2[dim + 2] = getFwdConvOutputDim(outdimA1[dim + 2], padA1[dim], filterdimA2[dim + 2], convstrideA[dim], dilationA[dim]);
    }

    for (int dim = 0; dim < 4; dim++) {
      if (dim == 2) {
        outdimA0[dim] = 3;
        outdimA4[dim] = 1;
      } else {
        outdimA0[dim] = outdimA1[dim];
        outdimA4[dim] = outdimA2[dim];
      }
    }

    outdimA3[0] = outdimA2[0];
    outdimA3[1] = filterdimA3[0];
    for (int dim = 0; dim < 2; dim++) {
      outdimA3[dim + 2] = getFwdConvOutputDim(outdimA2[dim + 2], padA[dim], filterdimA3[dim + 2], convstrideA[dim], dilationA[dim]);
    }

    // Create output tensor in the correct shape in pytorch's view
    outdim1[0] = outdim1[1] = outdim1[2] = outdim1[3] = 0;
    outdim2[0] = outdim2[1] = outdim2[2] = outdim2[3] = 0;
    outdim3[0] = outdim3[1] = outdim3[2] = outdim3[3] = 0;
    if (explicit_nhwc) {
      axis[0] = 0;
      axis[1] = 2;
      axis[2] = 3;
      axis[3] = 1;
    }
    for (int dim = 0; dim < 4; dim++) {
      outdim0[dim] = outdimA0[axis[dim]];
      outdim1[dim] = outdimA1[axis[dim]];
      outdim2[dim] = outdimA2[axis[dim]];
      outdim3[dim] = outdimA3[axis[dim]];
      outdim4[dim] = outdimA4[axis[dim]];
    }
  }
};

bottleneck_forward_status forward_state;

} // end of anonymous namespace

std::vector<at::Tensor> bottleneck_forward_init(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs) {

  // NB! Bottleneck_forward and bottleneck_backward are NOT thread safe method.
  // NB! We use a global object to store state.
  forward_state.init(explicit_nhwc, stride_1X1, inputs);

  // create output vector
  std::vector<at::Tensor> outputs;
  auto output_format = explicit_nhwc ? at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast;

  printf("outdim1 = (%d,%d,%d,%d)\n", forward_state.outdim1[0], forward_state.outdim1[1], forward_state.outdim1[2], forward_state.outdim1[3]);
  auto out1 = at::empty(forward_state.outdim1, inputs[0].type(), output_format);
  auto out2 = at::empty(forward_state.outdim2, inputs[0].type(), output_format);
  auto out3 = at::empty(forward_state.outdim3, inputs[0].type(), output_format);

  outputs.push_back(out1);
  outputs.push_back(out2);
  outputs.push_back(out3);

  return outputs;
}

// inputs contains x,w,z,b,(i)
void bottleneck_forward_out1(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs, std::vector<at::Tensor> outputs) {

  std::cout << std::fixed;

  // run
  at::Half* x = inputs[0].data_ptr<at::Half>();
  at::Half* w = inputs[1].data_ptr<at::Half>();
  at::Half* z = inputs[4].data_ptr<at::Half>();
  at::Half* b = inputs[7].data_ptr<at::Half>();
  auto out1 = outputs[0];
  at::Half* y1 = out1.data_ptr<at::Half>();

  run_conv_scale_bias_add_activation(forward_state.dimA,
                                     forward_state.padA,
                                     forward_state.convstride1X1,
                                     forward_state.dilationA,
                                     forward_state.filterdimA1,
                                     forward_state.outdimA1,
                                     CUDNN_DATA_HALF,
                                     x, w, y1, z, b, nullptr);

  DEBUG_MSG("[DEBUG] new relu1 : " << out1.to(at::kFloat).sum().item<float>());
}

// computes halo (top or bottom) from fat halo input.
// fat halo input is 3 pixels wide in H.
at::Tensor bottleneck_forward_out2_halo(bool explicit_nhwc, at::Tensor fat_halo_y1, std::vector<at::Tensor> inputs) {

  auto output_format = explicit_nhwc ? at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast;

  // run
  at::Half* w = inputs[2].data_ptr<at::Half>();
  at::Half* z = inputs[5].data_ptr<at::Half>();
  at::Half* b = inputs[8].data_ptr<at::Half>();

  at::Half* y1 = fat_halo_y1.data_ptr<at::Half>();

  auto halo_y2 = at::empty(forward_state.outdim4, inputs[0].type(), output_format);
  at::Half* y2 = halo_y2.data_ptr<at::Half>();

  run_conv_scale_bias_add_activation(forward_state.outdimA0,
                                     forward_state.padA2,
                                     forward_state.convstrideA,
                                     forward_state.dilationA,
                                     forward_state.filterdimA2,
                                     forward_state.outdimA4,
                                     CUDNN_DATA_HALF,
                                     y1, w, y2, z, b, nullptr);

  return halo_y2;
}

void bottleneck_forward_out2(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs, std::vector<at::Tensor> outputs) {

  std::cout << std::fixed;

  // from _out1 method
  at::Half* x = inputs[0].data_ptr<at::Half>();
  auto out1 = outputs[0];
  at::Half* y1 = out1.data_ptr<at::Half>();

  // run
  at::Half* w = inputs[2].data_ptr<at::Half>();
  at::Half* z = inputs[5].data_ptr<at::Half>();
  at::Half* b = inputs[8].data_ptr<at::Half>();
  auto out2 = outputs[1];
  at::Half* y2 = out2.data_ptr<at::Half>();

  printf("forward_state.outdimA1 = {%d,%d,%d,%d}\n", forward_state.outdimA1[0], forward_state.outdimA1[1], forward_state.outdimA1[2], forward_state.outdimA1[3]);
  printf("forward_state.padA1 = {%d,%d}\n", forward_state.padA1[0], forward_state.padA1[1]);
  printf("forward_state.convstrideA = {%d,%d}\n", forward_state.convstrideA[0], forward_state.convstrideA[1]);
  printf("forward_state.dilationA = {%d,%d}\n", forward_state.dilationA[0], forward_state.dilationA[1]);
  printf("forward_state.filterdimA2 = {%d,%d,%d,%d}\n", forward_state.filterdimA2[0], forward_state.filterdimA2[1], forward_state.filterdimA2[2], forward_state.filterdimA2[3]);
  printf("forward_state.outdimA2 = {%d,%d,%d,%d}\n", forward_state.outdimA2[0], forward_state.outdimA2[1], forward_state.outdimA2[2], forward_state.outdimA2[3]);
  run_conv_scale_bias_add_activation(forward_state.outdimA1,
                                     forward_state.padA1,
                                     forward_state.convstrideA,
                                     forward_state.dilationA,
                                     forward_state.filterdimA2,
                                     forward_state.outdimA2,
                                     CUDNN_DATA_HALF,
                                     y1, w, y2, z, b, nullptr);
  DEBUG_MSG("[DEBUG] new relu2 : " << out2.to(at::kFloat).sum().item<float>());
}

void bottleneck_forward_rest(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs, std::vector<at::Tensor> outputs) {

  std::cout << std::fixed;

  // from _out1 method
  at::Half* x = inputs[0].data_ptr<at::Half>();

  // create output of conv3
  auto out3 = outputs[2];
  at::Half* y3 = out3.data_ptr<at::Half>();

  // create output of conv4 that may exist
  auto identity = at::empty_like(out3);
  at::Half* yi = identity.data_ptr<at::Half>();

  at::Half *w, *z, *b;

  if (stride_1X1 != 1 || forward_state.filterdimA3[0] != forward_state.dimA[1]) {

    w = inputs[10].data_ptr<at::Half>();
    z = inputs[11].data_ptr<at::Half>();
    b = inputs[12].data_ptr<at::Half>();
    run_conv_scale_bias(forward_state.dimA,
                        forward_state.padA,
                        forward_state.convstride1X1,
                        forward_state.dilationA,
                        forward_state.filterdimA4,
                        forward_state.outdimA3,
                        CUDNN_DATA_HALF,
                        x, w, yi, z, b);
    DEBUG_MSG("[DEBUG] new downsample : " << identity.to(at::kFloat).sum().item<float>());
  }
  else {
    yi = x;
  }

  auto out2 = outputs[1];
  at::Half* y2 = out2.data_ptr<at::Half>();

  w = inputs[3].data_ptr<at::Half>();
  z = inputs[6].data_ptr<at::Half>();
  b = inputs[9].data_ptr<at::Half>();

  run_conv_scale_bias_add_activation(forward_state.outdimA2,
                                     forward_state.padA,
                                     forward_state.convstrideA,
                                     forward_state.dilationA,
                                     forward_state.filterdimA3,
                                     forward_state.outdimA3,
                                     CUDNN_DATA_HALF,
                                     y2, w, y3, z, b, yi);
  DEBUG_MSG("[DEBUG] new relu3 : " << out3.to(at::kFloat).sum().item<float>());
}

namespace {

struct bottleneck_backward_state {

  int64_t dimA[4];
  int64_t filterdimA1[4];
  int64_t filterdimA2[4];
  int64_t filterdimA3[4];
  int64_t filterdimA4[4];

  int axis[4];

  int64_t outdimA1[4];
  int64_t outdimA2[4];
  int64_t outdimA3[4];

  int64_t padA[2];
  int64_t padA1[2];
  int64_t dilationA[2];
  int64_t convstrideA[2];
  int64_t convstride1X1[2];

  int64_t outdim1[4];
  int64_t outdim2[4];
  int64_t outdim3[4];

  void init(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs) {

    // setup dimensions
    dimA[0] = dimA[1] = dimA[2] = dimA[3] = 0;
    filterdimA1[0] = filterdimA1[1] = filterdimA1[2] = filterdimA1[3] = 0;
    filterdimA2[0] = filterdimA2[1] = filterdimA2[2] = filterdimA2[3] = 0;
    filterdimA3[0] = filterdimA3[1] = filterdimA3[2] = filterdimA3[3] = 0;
    filterdimA4[0] = filterdimA4[1] = filterdimA4[2] = filterdimA4[3] = 0;

    // All dim calculation after this order of n,c,h,w
    if (explicit_nhwc) {
      axis[0] = 0;
      axis[1] = 3;
      axis[2] = 1;
      axis[3] = 2;
    } else {
      axis[0] = 0;
      axis[1] = 1;
      axis[2] = 2;
      axis[3] = 3;
    }
    for (int dim = 0; dim < 4; dim++) {
      dimA[dim] = inputs[0].size(axis[dim]);
      filterdimA1[dim] = inputs[1].size(axis[dim]);
      filterdimA2[dim] = inputs[2].size(axis[dim]);
      filterdimA3[dim] = inputs[3].size(axis[dim]);
    }
    if (stride_1X1 != 1 || filterdimA3[0] != dimA[1]) {
      for (int dim = 0; dim < 4; dim++) {
        filterdimA4[dim] = inputs[14].size(axis[dim]);
      }
    }

    // output dim in n,c,h,w used by backend
    outdimA1[0] = outdimA1[1] = outdimA1[2] = outdimA1[3] = 0;
    outdimA2[0] = outdimA2[1] = outdimA2[2] = outdimA2[3] = 0;
    outdimA3[0] = outdimA3[1] = outdimA3[2] = outdimA3[3] = 0;

    // use these fixed value for test run
    padA[0] = 0; padA[1] = 0;
    padA1[0] = 1; padA1[1] = 1;
    dilationA[0] = 1; dilationA[1] = 1;
    convstrideA[0] = 1; convstrideA[1] = 1;
    convstride1X1[0] = stride_1X1; convstride1X1[1] = stride_1X1;

    // compute output from pad/stride/dilation
    outdimA1[0] = dimA[0];
    outdimA1[1] = filterdimA1[0];
    for (int dim = 0; dim < 2; dim++) {
      outdimA1[dim + 2] = getFwdConvOutputDim(dimA[dim + 2], padA[dim], filterdimA1[dim + 2], convstride1X1[dim], dilationA[dim]);
    }

    outdimA2[0] = outdimA1[0];
    outdimA2[1] = filterdimA2[0];
    for (int dim = 0; dim < 2; dim++) {
      outdimA2[dim + 2] = getFwdConvOutputDim(outdimA1[dim + 2], padA1[dim], filterdimA2[dim + 2], convstrideA[dim], dilationA[dim]);
    }

    outdimA3[0] = outdimA2[0];
    outdimA3[1] = filterdimA3[0];
    for (int dim = 0; dim < 2; dim++) {
      outdimA3[dim + 2] = getFwdConvOutputDim(outdimA2[dim + 2], padA[dim], filterdimA3[dim + 2], convstrideA[dim], dilationA[dim]);
    }

    // Create output tensor in the correct shape in pytorch's view
    outdim1[0] = outdim1[1] = outdim1[2] = outdim1[3] = 0;
    outdim2[0] = outdim2[1] = outdim2[2] = outdim2[3] = 0;
    outdim3[0] = outdim3[1] = outdim3[2] = outdim3[3] = 0;
    if (explicit_nhwc) {
      axis[0] = 0;
      axis[1] = 2;
      axis[2] = 3;
      axis[3] = 1;
    }
    for (int dim = 0; dim < 4; dim++) {
      outdim1[dim] = outdimA1[axis[dim]];
      outdim2[dim] = outdimA2[axis[dim]];
      outdim3[dim] = outdimA3[axis[dim]];
    }
  }
};

bottleneck_backward_state backward_state;

}

std::vector<at::Tensor> bottleneck_backward_init(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs) {

  std::cout << std::fixed;

  backward_state.init(explicit_nhwc, stride_1X1, inputs);

  // create output vector
  std::vector<at::Tensor> outputs;
  auto output_format = explicit_nhwc ? at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast;

  auto grad_x = at::empty_like(inputs[0]);
  auto wgrad1 = at::empty_like(inputs[1]);
  auto wgrad2 = at::empty_like(inputs[2]);
  auto wgrad3 = at::empty_like(inputs[3]);

  outputs.push_back(grad_x);
  outputs.push_back(wgrad1);
  outputs.push_back(wgrad2);
  outputs.push_back(wgrad3);

  if (stride_1X1 != 1 || backward_state.filterdimA3[0] != backward_state.dimA[1]) {
    auto wgrad4 = at::empty_like(inputs[14]);
    outputs.push_back(wgrad4);
  }

  return outputs;
}

at::Tensor bottleneck_backward_grad_out2(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs, std::vector<at::Tensor> outputs) {

  bool requires_grad = inputs[0].requires_grad();

  std::cout << std::fixed;

  auto output_format = explicit_nhwc ? at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast;

  // dconv3+drelu2+dscale2
  at::Half* conv_in = inputs[13].data_ptr<at::Half>();
  at::Half* dy3 = inputs[10].data_ptr<at::Half>();

  DEBUG_MSG("[DEBUG] new dconv3 : " << inputs[10].to(at::kFloat).sum().item<float>());

  // wgrad
  auto wgrad3 = outputs[3];
  at::Half* dw3 = wgrad3.data_ptr<at::Half>();
  run_dconv(backward_state.outdimA2,
            backward_state.padA,
            backward_state.convstrideA,
            backward_state.dilationA,
            backward_state.filterdimA3,
            backward_state.outdimA3,
            CUDNN_DATA_HALF,
            conv_in, dw3, dy3,
            CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR);

  // dgrad
  auto grad_out2 = at::empty(backward_state.outdim2, inputs[0].type(), output_format);
  at::Half* dy2 = grad_out2.data_ptr<at::Half>();
  at::Half* w = inputs[3].data_ptr<at::Half>();
  at::Half* z = inputs[5].data_ptr<at::Half>();

  at::Half* relu2 = inputs[13].data_ptr<at::Half>();

  run_dconv_drelu_dscale(backward_state.outdimA2,
                         backward_state.padA,
                         backward_state.convstrideA,
                         backward_state.dilationA,
                         backward_state.filterdimA3,
                         backward_state.outdimA3,
                         CUDNN_DATA_HALF,
                         dy2, w, dy3, z, relu2);

  // do halo exchange of dy2 here

  DEBUG_MSG("[DEBUG] new dconv2 : " << grad_out2.to(at::kFloat).sum().item<float>());

  return grad_out2;
}

void bottleneck_backward_rest(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs, std::vector<at::Tensor> outputs, at::Tensor grad_out2) {

  bool requires_grad = inputs[0].requires_grad();

  std::cout << std::fixed;

  auto output_format = explicit_nhwc ? at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast;

  // dgrad
  at::Half* dy2 = grad_out2.data_ptr<at::Half>();

  // dconv2+drelu1+dscale1
  at::Half* conv_in = inputs[12].data_ptr<at::Half>();

  // wgrad
  auto wgrad2 = outputs[2];
  at::Half* dw2 = wgrad2.data_ptr<at::Half>();

  printf("outdimA1 = (%d,%d,%d,%d)\n", backward_state.outdimA1[0], backward_state.outdimA1[1], backward_state.outdimA1[2], backward_state.outdimA1[3]);
  run_dconv(backward_state.outdimA1,
            backward_state.padA1,
            backward_state.convstrideA,
            backward_state.dilationA,
            backward_state.filterdimA2,
            backward_state.outdimA2,
            CUDNN_DATA_HALF,
            conv_in, dw2, dy2,
            CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR);

  // dgrad
  auto grad_out1 = at::empty(backward_state.outdim1, inputs[0].type(), output_format);
  at::Half* dy1 = grad_out1.data_ptr<at::Half>();
  at::Half* w = inputs[2].data_ptr<at::Half>();
  at::Half* z = inputs[4].data_ptr<at::Half>();

  at::Half* relu1 = inputs[12].data_ptr<at::Half>();

  // fused dgrad
  run_dconv_drelu_dscale(backward_state.outdimA1,
                         backward_state.padA1,
                         backward_state.convstrideA,
                         backward_state.dilationA,
                         backward_state.filterdimA2,
                         backward_state.outdimA2,
                         CUDNN_DATA_HALF,
                         dy1, w, dy2, z, relu1);

/*
  // backward strided conv cannot be fused
  // if stride == 1 but channel changes, we can fuse here
  if (stride_1X1 != 1){
    // dgrad
    run_dconv(outdimA1,
              padA1,
              convstride1X1,
              dilationA,
              filterdimA2,
              outdimA2,
              CUDNN_DATA_HALF,
              dy1,
              w,
              dy2,
              CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR);

    // mul fused mask
    grad_out1.mul_(inputs[15]);
  }
  else {
    at::Half* relu1 = inputs[12].data_ptr<at::Half>();
    // fused dgrad
    run_dconv_drelu_dscale(outdimA1,
                           padA1,
                           convstride1X1,
                           dilationA,
                           filterdimA2,
                           outdimA2,
                           CUDNN_DATA_HALF,
                           dy1,
                           w,
                           dy2,
                           z,
                           relu1);
  }
*/
  DEBUG_MSG("[DEBUG] new dconv1 : " << grad_out1.to(at::kFloat).sum().item<float>());

  // create grads of conv4 that may exist
  auto grad_x_conv4 = at::empty_like(inputs[0]);
  at::Half* dx_conv4 = grad_x_conv4.data_ptr<at::Half>();
  at::Tensor wgrad4;

  // x used for dconv1 and dconv4 wgrad
  at::Half* x = inputs[0].data_ptr<at::Half>();

  if (stride_1X1 != 1 || backward_state.filterdimA3[0] != backward_state.dimA[1]) {
    w = inputs[14].data_ptr<at::Half>();
    at::Half* dy_conv4 = inputs[11].data_ptr<at::Half>();
    if (requires_grad) {
      run_dconv(backward_state.dimA,
                backward_state.padA,
                backward_state.convstride1X1,
                backward_state.dilationA,
                backward_state.filterdimA4,
                backward_state.outdimA3,
                CUDNN_DATA_HALF,
                dx_conv4, w, dy_conv4,
                CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR);
      // we don't print here since we can't hook out this grad in pytorch alone to compare, due to addition with dx
      // DEBUG_MSG("[DEBUG] new dx_identity : " << grad_x_conv4.to(at::kFloat).sum().item<float>());
    }
    // wgrad
    wgrad4 = outputs[4];
    at::Half* dw4 = wgrad4.data_ptr<at::Half>();
    run_dconv(backward_state.dimA,
              backward_state.padA,
              backward_state.convstride1X1,
              backward_state.dilationA,
              backward_state.filterdimA4,
              backward_state.outdimA3,
              CUDNN_DATA_HALF,
              x, dw4, dy_conv4,
              CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR);
  }
  else {
    // if there is no downsample, dx_conv4 is fork of drelu3
    dx_conv4 = inputs[11].data_ptr<at::Half>();
  }

  // dconv1+add
  // wgrad
  auto wgrad1 = outputs[1];
  at::Half* dw1 = wgrad1.data_ptr<at::Half>();
  run_dconv(backward_state.dimA,
            backward_state.padA,
            backward_state.convstride1X1,
            backward_state.dilationA,
            backward_state.filterdimA1,
            backward_state.outdimA1,
            CUDNN_DATA_HALF,
            x, dw1, dy1,
            CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR);

  // dgrad
  w = inputs[1].data_ptr<at::Half>();
  auto grad_x = outputs[0];
  at::Half* dx = grad_x.data_ptr<at::Half>();

  // backward strided conv cannot be fused
  // if stride == 1 but channel changes, we can fuse here
  if (requires_grad) {
    if (stride_1X1 != 1) {
      run_dconv(backward_state.dimA,
                backward_state.padA,
                backward_state.convstride1X1,
                backward_state.dilationA,
                backward_state.filterdimA1,
                backward_state.outdimA1,
                CUDNN_DATA_HALF,
                dx, w, dy1,
                CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR);
      // add 2 together
      grad_x.add_(grad_x_conv4);
    }
    else {
      run_dconv_add(backward_state.dimA,
                    backward_state.padA,
                    backward_state.convstride1X1,
                    backward_state.dilationA,
                    backward_state.filterdimA1,
                    backward_state.outdimA1,
                    CUDNN_DATA_HALF,
                    dx, w, dy1, dx_conv4);
    }
  }

  DEBUG_MSG("[DEBUG] new dx : " << grad_x.to(at::kFloat).sum().item<float>());
  DEBUG_MSG("[DEBUG] new wgrad1 : " << wgrad1.to(at::kFloat).sum().item<float>());
  DEBUG_MSG("[DEBUG] new wgrad2 : " << wgrad2.to(at::kFloat).sum().item<float>());
  DEBUG_MSG("[DEBUG] new wgrad3 : " << wgrad3.to(at::kFloat).sum().item<float>());
  if (stride_1X1 != 1 || backward_state.filterdimA3[0] != backward_state.dimA[1]) {
    DEBUG_MSG("[DEBUG] new wgrad4 : " << wgrad4.to(at::kFloat).sum().item<float>());
  }
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward", &bottleneck_forward, "Bottleneck block forward");
  m.def("backward", &bottleneck_backward, "Bottleneck block backward");
  m.def("forward_init", &bottleneck_forward_init, "Bottleneck block init");
  m.def("forward_out1", &bottleneck_forward_out1, "Bottleneck block forward");
  m.def("forward_out2", &bottleneck_forward_out2, "Bottleneck block forward");
  m.def("forward_out2_halo", &bottleneck_forward_out2_halo, "Bottleneck block forward");
  m.def("forward_rest", &bottleneck_forward_rest, "Bottleneck block forward");
  m.def("backward_init", &bottleneck_backward_init, "Bottleneck block backward init");
  m.def("backward_grad_out2", &bottleneck_backward_grad_out2, "Bottleneck block backward");
  m.def("backward_rest", &bottleneck_backward_rest, "Bottleneck block backward");
}
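The new pybind11 entry points are meant to be driven in the order used by SpatialBottleneckFunction in bottleneck.py. A minimal sketch of that call sequence (assuming the extension is built under the name fast_bottleneck, and that args / t_list are the flat tensor lists assembled in forward and backward above; helper function names here are illustrative only):

    import fast_bottleneck  # extension defined by the PYBIND11_MODULE block above

    def run_forward(nhwc, stride_1x1, args):
        outputs = fast_bottleneck.forward_init(nhwc, stride_1x1, args)   # allocates out1/out2/out3
        fast_bottleneck.forward_out1(nhwc, stride_1x1, args, outputs)    # conv1 + scale/bias + relu
        fast_bottleneck.forward_out2(nhwc, stride_1x1, args, outputs)    # conv2 + scale/bias + relu
        # when the activation is split across ranks along H, the boundary rows of out2
        # are recomputed from exchanged halos via forward_out2_halo at this point
        fast_bottleneck.forward_rest(nhwc, stride_1x1, args, outputs)    # conv3 (+ downsample) + add + relu
        return outputs

    def run_backward(nhwc, stride_1x1, t_list):
        grads = fast_bottleneck.backward_init(nhwc, stride_1x1, t_list)
        grad_out2 = fast_bottleneck.backward_grad_out2(nhwc, stride_1x1, t_list, grads)
        # halo exchange of grad_out2 would happen here (marked as a TODO in both the
        # Python and C++ code of this commit)
        fast_bottleneck.backward_rest(nhwc, stride_1x1, t_list, grads, grad_out2)
        return grads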