OpenDAS / apex · Commits

Commit 834b1d01
Authored Mar 29, 2022 by Thor Johnsen

    Concatenate out1 with halos for backward

parent e5d0be82
Showing 3 changed files with 85 additions and 18 deletions (+85 -18)

    apex/contrib/bottleneck/bottleneck.py          +38  -18
    apex/contrib/bottleneck/halo_exchangers.py      +1   -0
    apex/contrib/csrc/bottleneck/bottleneck.cpp    +46   -0
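In short, the commit replaces the two separately saved halo rows with a single halo-padded copy of out1 that the backward pass can consume directly. A minimal sketch of the intended layout, assuming NHWC activations and hypothetical shapes (an illustration, not the apex API):

    import torch

    # Hypothetical per-rank slice of the conv1 output (NHWC).
    N, Hs, W, C = 2, 4, 8, 16
    out1 = torch.randn(N, Hs, W, C)             # produced by forward_out1 on this rank
    top_halo = torch.randn(N, 1, W, C)          # row received from the rank above
    btm_halo = torch.randn(N, 1, W, C)          # row received from the rank below

    # out1_pad = [top_halo | out1 | btm_halo] along H, so the backward weight gradient
    # can run once over a contiguous (Hs+2)-row input instead of patching halos in later.
    out1_pad = torch.empty(N, Hs + 2, W, C, dtype=out1.dtype)
    out1_pad[:, 1:Hs + 1] = out1
    out1_pad[:, :1] = top_halo
    out1_pad[:, Hs + 1:] = btm_halo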
apex/contrib/bottleneck/bottleneck.py   View file @ 834b1d01
@@ -220,10 +220,11 @@ class Bottleneck(torch.nn.Module):

 class SpatialBottleneckFunction(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, spatial_group_size, spatial_group_rank, spatial_communicator, spatial_halo_exchanger, spatial_stream, nhwc, stride_1x1, scale, bias, x, *conv):
+    def forward(ctx, spatial_group_size, spatial_group_rank, spatial_communicator, spatial_halo_exchanger, spatial_method, nhwc, stride_1x1, scale, bias, x, *conv):
         if spatial_group_size > 1:
             stream1 = spatial_halo_exchanger.stream1
             stream2 = spatial_halo_exchanger.stream2
+            stream3 = spatial_halo_exchanger.stream3

         # TODO: clean up order of tensors
         args = [x, *conv[0:3], *scale[0:3], *bias[0:3]]
@@ -239,13 +240,19 @@ class SpatialBottleneckFunction(torch.autograd.Function):

         outputs = fast_bottleneck.forward_init(nhwc, stride_1x1, args)
         fast_bottleneck.forward_out1(nhwc, stride_1x1, args, outputs)

+        # do halo exchange for outputs[0] (out1)
         if spatial_group_size > 1:
             out1 = outputs[0]
+            # TODO: This assumes explicit nhwc
             N, Hs, W, C = list(out1.shape)
+            out1_pad = torch.empty([N, Hs+2, W, C], dtype=out1.dtype, device='cuda')
             stream1.wait_stream(torch.cuda.current_stream())
+            stream3.wait_stream(torch.cuda.current_stream())
+            with torch.cuda.stream(stream3):
+                out1_pad[:,1:Hs+1,:,:].copy_(out1)
             with torch.cuda.stream(stream1):
-                top_out1_halo, btm_out1_halo = spatial_halo_exchanger.left_right_halo_exchange(out1[:,:1,:,:], out1[:,Hs-1:,:,:])
+                top_out1_halo = out1_pad[:,:1,:,:]
+                btm_out1_halo = out1_pad[:,Hs+1:Hs+2,:,:]
+                spatial_halo_exchanger.left_right_halo_exchange(out1[:,:1,:,:], out1[:,Hs-1:,:,:], top_out1_halo, btm_out1_halo)
             if spatial_group_rank < spatial_group_size-1:
                 stream2.wait_stream(stream1)
                 with torch.cuda.stream(stream2):
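Note that the new four-argument left_right_halo_exchange writes the received rows directly into views of out1_pad rather than returning fresh tensors. A rough torch.distributed sketch of that contract, assuming rank order top-to-bottom in the default process group (the actual HaloExchanger implementations in apex differ):

    import torch
    import torch.distributed as dist

    def left_right_halo_exchange_sketch(top_send, btm_send, top_recv, btm_recv):
        # Send own boundary rows to the neighbouring ranks and copy the rows
        # received from them into the caller-provided buffers (views of out1_pad).
        rank, world = dist.get_rank(), dist.get_world_size()
        tmp_top, tmp_btm = torch.empty_like(top_send), torch.empty_like(btm_send)
        ops = []
        if rank > 0:            # exchange with the rank above
            ops += [dist.P2POp(dist.isend, top_send.contiguous(), rank - 1),
                    dist.P2POp(dist.irecv, tmp_top, rank - 1)]
        if rank < world - 1:    # exchange with the rank below
            ops += [dist.P2POp(dist.isend, btm_send.contiguous(), rank + 1),
                    dist.P2POp(dist.irecv, tmp_btm, rank + 1)]
        if ops:
            for req in dist.batch_isend_irecv(ops):
                req.wait()
        if rank > 0:
            top_recv.copy_(tmp_top)
        if rank < world - 1:
            btm_recv.copy_(tmp_btm)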
@@ -253,12 +260,18 @@ class SpatialBottleneckFunction(torch.autograd.Function):

                     btm_fat_halo[:,0:2,:,:].copy_(out1[:,Hs-2:,:,:])
                     btm_fat_halo[:,2:,:,:].copy_(btm_out1_halo)
                     btm_out2 = fast_bottleneck.forward_out2_halo(nhwc, btm_fat_halo, args)
+            else:
+                with torch.cuda.stream(stream2):
+                    btm_out1_halo.zero_()
             if spatial_group_rank > 0:
                 with torch.cuda.stream(stream1):
                     top_fat_halo = torch.empty((N,3,W,C), dtype=out1.dtype, device=out1.device)
                     top_fat_halo[:,:1,:,:].copy_(top_out1_halo)
                     top_fat_halo[:,1:3,:,:].copy_(out1[:,:2,:,:])
                     top_out2 = fast_bottleneck.forward_out2_halo(nhwc, top_fat_halo, args)
+            else:
+                with torch.cuda.stream(stream1):
+                    top_out1_halo.zero_()
             inc.add_delay(10)

         fast_bottleneck.forward_out2(nhwc, stride_1x1, args, outputs)
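For reference, the "fat halo" fed to forward_out2_halo is three rows tall: the received neighbour row plus the two local boundary rows, which is exactly the receptive field needed for the single output row of the 3x3 conv that depends on the neighbour. A small shape sketch under the same NHWC assumption:

    import torch

    N, Hs, W, C = 2, 4, 8, 16                       # hypothetical local slice
    out1 = torch.randn(N, Hs, W, C)
    top_out1_halo = torch.randn(N, 1, W, C)         # row received from the rank above

    top_fat_halo = torch.empty(N, 3, W, C, dtype=out1.dtype)
    top_fat_halo[:, :1] = top_out1_halo             # neighbour row
    top_fat_halo[:, 1:3] = out1[:, :2]              # own first two rows
    # forward_out2_halo runs the 3x3 conv over these three rows to produce the one
    # output row that the local, unpadded convolution cannot compute on its own.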
@@ -272,11 +285,12 @@ class SpatialBottleneckFunction(torch.autograd.Function):

             if spatial_group_rank < spatial_group_size-1:
                 torch.cuda.current_stream().wait_stream(stream2)
                 out2[:,Hs-1:,:,:].copy_(btm_out2)
+            torch.cuda.current_stream().wait_stream(stream3)

         fast_bottleneck.forward_rest(nhwc, stride_1x1, args, outputs)
         # save halos for backward pass
         if spatial_group_size > 1:
-            ctx.save_for_backward(*(args+outputs+[top_out1_halo, btm_out1_halo]))
+            ctx.save_for_backward(*(args+outputs+[out1_pad,]))
         else:
             ctx.save_for_backward(*(args+outputs))
         # save relu outputs for drelu
@@ -286,8 +300,10 @@ class SpatialBottleneckFunction(torch.autograd.Function):

         if spatial_group_size > 1:
             ctx.spatial_group_rank = spatial_group_rank
             ctx.spatial_halo_exchanger = spatial_halo_exchanger
+            ctx.spatial_method = spatial_method
             ctx.stream1 = stream1
             ctx.stream2 = stream2
+            ctx.stream3 = stream3
         return outputs[2]

     # backward relu is not exposed, MUL with mask used now
@@ -295,9 +311,8 @@ class SpatialBottleneckFunction(torch.autograd.Function):

     @staticmethod
     def backward(ctx, grad_o):
         if ctx.spatial_group_size > 1:
-            top_out1_halo = ctx.saved_tensors[-2]
-            btm_out1_halo = ctx.saved_tensors[-1]
-            outputs = ctx.saved_tensors[-5:-2]
+            out1_pad = ctx.saved_tensors[-1]
+            outputs = ctx.saved_tensors[-4:-1]
         else:
             outputs = ctx.saved_tensors[-3:]
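The new slicing follows directly from the save_for_backward layout in the forward pass: args (x plus the three conv, scale and bias groups), then the outputs, then a single out1_pad trailer in place of the two halo tensors. A quick index check, assuming 10 args and 3 outputs as in the code above (names are placeholders):

    # Placeholder strings only; the point is the trailing-index bookkeeping.
    args = [f"arg{i}" for i in range(10)]      # x, *conv[0:3], *scale[0:3], *bias[0:3]
    outputs = ["out1", "out2", "out3"]

    saved_new = args + outputs + ["out1_pad"]
    assert saved_new[-1] == "out1_pad"         # out1_pad = ctx.saved_tensors[-1]
    assert saved_new[-4:-1] == outputs         # outputs  = ctx.saved_tensors[-4:-1]

    saved_old = args + outputs + ["top_out1_halo", "btm_out1_halo"]
    assert saved_old[-5:-2] == outputs         # previously ctx.saved_tensors[-5:-2]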
@@ -353,19 +368,24 @@ class SpatialBottleneckFunction(torch.autograd.Function):

                 top_grad_out1_halo = top_grad_out1_halo[:,1:2,:,:]
             inc.add_delay(10)

+        if ctx.spatial_group_size > 1:
+            wgrad2 = fast_bottleneck.backward_wgrad2_pad(ctx.nhwc, ctx.stride_1x1, t_list, grads, out1_pad, grad_out2)
+        else:
+            wgrad2 = fast_bottleneck.backward_wgrad2(ctx.nhwc, ctx.stride_1x1, t_list, grads, grad_out2)
         # compute wgrad2 for internal cells
-        wgrad2 = fast_bottleneck.backward_wgrad2(ctx.nhwc, ctx.stride_1x1, t_list, grads, grad_out2)
+        #wgrad2 = fast_bottleneck.backward_wgrad2(ctx.nhwc, ctx.stride_1x1, t_list, grads, grad_out2)

         # apply wgrad2 halos
-        if ctx.spatial_group_size > 1:
-            if ctx.spatial_group_rank > 0:
-                top_grad2_halo = grad_out2[:,:1,:,:]
-                top_wgrad2_halo = fast_bottleneck.backward_wgrad2_halo(ctx.nhwc, ctx.stride_1x1, t_list, grads, top_out1_halo, top_grad2_halo)
-                wgrad2[:,:1,:,:].add_(top_wgrad2_halo)
-            if ctx.spatial_group_rank < ctx.spatial_group_size-1:
-                btm_grad2_halo = grad_out2[:,-1:,:,:]
-                btm_wgrad2_halo = fast_bottleneck.backward_wgrad2_halo(ctx.nhwc, ctx.stride_1x1, t_list, grads, btm_out1_halo, btm_grad2_halo)
-                wgrad2[:,-1:,:,:].add_(btm_wgrad2_halo)
+        #if ctx.spatial_group_size > 1:
+        #    if ctx.spatial_group_rank > 0:
+        #        top_grad2_halo = grad_out2[:,:1,:,:]
+        #        top_wgrad2_halo = fast_bottleneck.backward_wgrad2_halo(ctx.nhwc, ctx.stride_1x1, t_list, grads, top_out1_halo, top_grad2_halo)
+        #        wgrad2[:,:1,:,:].add_(top_wgrad2_halo)
+        #    if ctx.spatial_group_rank < ctx.spatial_group_size-1:
+        #        btm_grad2_halo = grad_out2[:,-1:,:,:]
+        #        btm_wgrad2_halo = fast_bottleneck.backward_wgrad2_halo(ctx.nhwc, ctx.stride_1x1, t_list, grads, btm_out1_halo, btm_grad2_halo)
+        #        wgrad2[:,-1:,:,:].add_(btm_wgrad2_halo)

         # compute grad_out1 for internal cells
         grad_out1 = fast_bottleneck.backward_grad_out1(ctx.nhwc, ctx.stride_1x1, t_list, grads, grad_out2)
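Conceptually, backward_wgrad2_pad computes the full 3x3 weight gradient of conv2 in one call over the halo-padded input, instead of computing it on the interior and adding halo corrections afterwards (the path that is now commented out). A rough NCHW PyTorch analogue, assuming stride 1 and padding only along W because the H halos are already explicit in out1_pad; this is an illustration, not the cuDNN path used by the extension:

    import torch
    from torch.nn.grad import conv2d_weight

    # Hypothetical sizes; the real kernel runs in NHWC half precision through cuDNN.
    N, C_in, C_out, Hs, W = 2, 8, 8, 6, 10
    out1_pad = torch.randn(N, C_in, Hs + 2, W)      # input with explicit H halos
    grad_out2 = torch.randn(N, C_out, Hs, W)        # incoming gradient, unpadded

    # H padding is 0 (halos are baked into out1_pad); W padding stays 1, matching
    # the "// 0, 1" note on backward_state.padA2 in the C++ change below.
    wgrad2 = conv2d_weight(out1_pad, (C_out, C_in, 3, 3), grad_out2,
                           stride=1, padding=(0, 1))
    assert wgrad2.shape == (C_out, C_in, 3, 3)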
@@ -456,7 +476,7 @@ class SpatialBottleneck(torch.nn.Module):

         # spatial communicator
         if spatial_parallel_args is None:
-            self.spatial_parallel_args = (1, 0, None, None, None)
+            self.spatial_parallel_args = (1, 0, None, None, 0)
         else:
             self.spatial_parallel_args = spatial_parallel_args
         return
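The default tuple tracks the new forward signature: when no spatial parallelism is configured, group size 1, rank 0, no communicator or halo exchanger, and spatial method 0 (previously None in the spatial_stream slot). A hypothetical illustration of how the five fields line up with the autograd function's leading arguments (the actual call site is outside this diff):

    # Illustrative unpacking only; names follow the forward() parameters above.
    spatial_parallel_args = (1, 0, None, None, 0)
    (spatial_group_size, spatial_group_rank, spatial_communicator,
     spatial_halo_exchanger, spatial_method) = spatial_parallel_args
    assert spatial_group_size == 1 and spatial_method == 0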
apex/contrib/bottleneck/halo_exchangers.py   View file @ 834b1d01

@@ -12,6 +12,7 @@ class HaloExchanger(object):

     def __init__(self):
         self.stream1 = torch.cuda.Stream()
         self.stream2 = torch.cuda.Stream()
+        self.stream3 = torch.cuda.Stream()

 class HaloExchangerNoComm(HaloExchanger):
     def __init__(self, world_size, spatial_group_size, rank, comm):
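The third stream exists so that the bulk copy of out1 into the interior of out1_pad can overlap with the halo exchange and the halo convolutions on the other streams. A minimal sketch of that ordering, assuming a CUDA device and the same roles as in the forward code above:

    import torch

    stream1 = torch.cuda.Stream()   # halo exchange + halo convolutions
    stream3 = torch.cuda.Stream()   # interior copy: out1 -> out1_pad[:, 1:Hs+1]

    main = torch.cuda.current_stream()
    stream1.wait_stream(main)       # both side streams start only after out1 is ready
    stream3.wait_stream(main)
    # ... exchange work is queued on stream1, the interior copy on stream3 ...
    main.wait_stream(stream3)       # forward_rest may read out1_pad only after the copy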
apex/contrib/csrc/bottleneck/bottleneck.cpp   View file @ 834b1d01
@@ -1936,6 +1936,7 @@ struct bottleneck_backward_state {

   int axis[4];

   int64_t outdimA1[4];  // grad_out1
+  int64_t outdimA1b[4]; // out1_pad
   int64_t outdimA2[4];  // grad_out2
   int64_t outdimA3[4];
   int64_t outdimA1h[4]; // output: grad_out1 halo (H=3)
@@ -1953,6 +1954,7 @@ struct bottleneck_backward_state {

   int64_t filterdim2hh[4]; // Cin,1,3,Cout

   int64_t outdim1[4];
+  int64_t outdim1b[4];
   int64_t outdim2[4];
   int64_t outdim3[4];
   int64_t outdim1h[4];
@@ -2001,6 +2003,7 @@ struct bottleneck_backward_state {

     // output dim in n,c,h,w used by backend
     outdimA1[0] = outdimA1[1] = outdimA1[2] = outdimA1[3] = 0;
+    outdimA1b[0] = outdimA1b[1] = outdimA1b[2] = outdimA1b[3] = 0;
     outdimA2[0] = outdimA2[1] = outdimA2[2] = outdimA2[3] = 0;
     outdimA3[0] = outdimA3[1] = outdimA3[2] = outdimA3[3] = 0;
     outdimA1h[0] = outdimA1h[1] = outdimA1h[2] = outdimA1h[3] = 0;
@@ -2022,6 +2025,13 @@ struct bottleneck_backward_state {

     for (int dim = 0; dim < 2; dim++) {
       outdimA1[dim+2] = getFwdConvOutputDim(dimA[dim+2], padA[dim], filterdimA1[dim+2], convstride1X1[dim], dilationA[dim]);
     }
+    for (int dim = 0; dim < 4; dim++) {
+      if (dim == 2) {
+        outdimA1b[dim] = outdimA1[dim] + 2;
+      } else {
+        outdimA1b[dim] = outdimA1[dim];
+      }
+    }

     outdimA2[0] = outdimA1[0];
     outdimA2[1] = filterdimA2[0];
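The new loop just copies outdimA1 into outdimA1b and grows the H entry (index 2 in the n,c,h,w ordering noted above) by two for the top and bottom halo rows, for example:

    # Python rendering of the same bookkeeping, with hypothetical sizes.
    outdimA1 = [8, 64, 50, 100]                    # N, C, H, W
    outdimA1b = [d + 2 if i == 2 else d for i, d in enumerate(outdimA1)]
    assert outdimA1b == [8, 64, 52, 100]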
@@ -2051,6 +2061,7 @@ struct bottleneck_backward_state {

     // Create output tensor in the correct shape in pytorch's view
     outdim1[0] = outdim1[1] = outdim1[2] = outdim1[3] = 0;
+    outdim1b[0] = outdim1b[1] = outdim1b[2] = outdim1b[3] = 0;
     outdim2[0] = outdim2[1] = outdim2[2] = outdim2[3] = 0;
     outdim3[0] = outdim3[1] = outdim3[2] = outdim3[3] = 0;
     outdim1h[0] = outdim1h[1] = outdim1h[2] = outdim1h[3] = 0;
@@ -2063,6 +2074,7 @@ struct bottleneck_backward_state {

     }
     for (int dim = 0; dim < 4; dim++) {
       outdim1[dim] = outdimA1[axis[dim]];
+      outdim1b[dim] = outdimA1b[axis[dim]];
       outdim2[dim] = outdimA2[axis[dim]];
       outdim3[dim] = outdimA3[axis[dim]];
       outdim1h[dim] = outdimA1h[axis[dim]];
@@ -2234,6 +2246,39 @@ at::Tensor bottleneck_backward_grad_out1_halo(bool explicit_nhwc, int stride_1X1

   return grad_out1_halo;
 }

+at::Tensor bottleneck_backward_wgrad2_pad(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs, std::vector<at::Tensor> outputs, at::Tensor input, at::Tensor grad_out2) {
+
+  std::cout << std::fixed;
+
+  auto output_format = explicit_nhwc ? at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast;
+
+  // dgrad
+  at::Half* dy2 = grad_out2.data_ptr<at::Half>();
+
+  // dconv2+drelu1+dscale1
+  at::Half* conv_in = input.data_ptr<at::Half>();
+
+  // wgrad
+  auto wgrad2 = outputs[2];
+  at::Half* dw2 = wgrad2.data_ptr<at::Half>();
+
+  //printf("outdimA1b = (%d,%d,%d,%d)\n",backward_state.outdimA1b[0],backward_state.outdimA1b[1],backward_state.outdimA1b[2],backward_state.outdimA1b[3]);
+  //printf("backward_state.padA2 = {%d,%d}\n",backward_state.padA2[0],backward_state.padA2[1]);
+  run_dconv(backward_state.outdimA1b,    // conv_in.shape (including H halos)
+            backward_state.padA2,        // 0, 1
+            backward_state.convstrideA,
+            backward_state.dilationA,
+            backward_state.filterdimA2,  // dw2.shape
+            backward_state.outdimA2,     // dy2.shape
+            CUDNN_DATA_HALF,
+            conv_in,
+            dw2,
+            dy2,
+            CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR);
+
+  DEBUG_MSG("[DEBUG] new wgrad2 : " << wgrad2.to(at::kFloat).sum().item<float>());
+
+  return wgrad2;
+}
+
 at::Tensor bottleneck_backward_wgrad2(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs, std::vector<at::Tensor> outputs, at::Tensor grad_out2) {
   bool requires_grad = inputs[0].requires_grad();
@@ -2480,6 +2525,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {

   m.def("backward_grad_out2", &bottleneck_backward_grad_out2, "Bottleneck block backward");
   m.def("backward_grad_out1", &bottleneck_backward_grad_out1, "Bottleneck block backward");
   m.def("backward_grad_out1_halo", &bottleneck_backward_grad_out1_halo, "Bottleneck block backward");
+  m.def("backward_wgrad2_pad", &bottleneck_backward_wgrad2_pad, "Bottleneck block backward");
   m.def("backward_wgrad2", &bottleneck_backward_wgrad2, "Bottleneck block backward");
   m.def("backward_wgrad2_halo", &bottleneck_backward_wgrad2_halo, "Bottleneck block backward");
   m.def("backward_rest", &bottleneck_backward_rest, "Bottleneck block backward");