Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
apex
Commits
d934eca3
Unverified
Commit
d934eca3
authored
Sep 01, 2021
by
Thor Johnsen
Committed by
GitHub
Sep 01, 2021
Browse files
Merge pull request #1154 from NVIDIA/rework_spatial_bottleneck_code_split
Add functions to compute grad_out1, grad_out1_halo
parents
4d190db6
b6980a0d
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
144 additions
and
52 deletions
+144
-52
apex/contrib/bottleneck/bottleneck.py
apex/contrib/bottleneck/bottleneck.py
+27
-23
apex/contrib/csrc/bottleneck/bottleneck.cpp
apex/contrib/csrc/bottleneck/bottleneck.cpp
+117
-29
No files found.
apex/contrib/bottleneck/bottleneck.py
View file @
d934eca3
...
@@ -237,8 +237,6 @@ class SpatialBottleneckFunction(torch.autograd.Function):
...
@@ -237,8 +237,6 @@ class SpatialBottleneckFunction(torch.autograd.Function):
if
spatial_group_size
>
1
:
if
spatial_group_size
>
1
:
out1
=
outputs
[
0
]
out1
=
outputs
[
0
]
N
,
Hs
,
W
,
C
=
list
(
out1
.
shape
)
N
,
Hs
,
W
,
C
=
list
(
out1
.
shape
)
padded_out1
=
torch
.
empty
((
N
,
Hs
+
2
,
W
,
C
),
dtype
=
out1
.
dtype
,
device
=
out1
.
device
)
padded_out1
[:,
1
:
Hs
+
1
,:,:].
copy_
(
out1
)
stream1
.
wait_stream
(
torch
.
cuda
.
current_stream
())
stream1
.
wait_stream
(
torch
.
cuda
.
current_stream
())
with
torch
.
cuda
.
stream
(
stream1
):
with
torch
.
cuda
.
stream
(
stream1
):
# copy halos to send buffer
# copy halos to send buffer
...
@@ -248,22 +246,17 @@ class SpatialBottleneckFunction(torch.autograd.Function):
...
@@ -248,22 +246,17 @@ class SpatialBottleneckFunction(torch.autograd.Function):
all_halos
=
torch
.
empty
((
N
,
2
*
spatial_group_size
,
W
,
C
),
dtype
=
out1
.
dtype
,
device
=
out1
.
device
)
all_halos
=
torch
.
empty
((
N
,
2
*
spatial_group_size
,
W
,
C
),
dtype
=
out1
.
dtype
,
device
=
out1
.
device
)
all_halos
=
[
all_halos
[:,
i
*
2
:(
i
+
1
)
*
2
,:,:]
for
i
in
range
(
spatial_group_size
)]
all_halos
=
[
all_halos
[:,
i
*
2
:(
i
+
1
)
*
2
,:,:]
for
i
in
range
(
spatial_group_size
)]
dist
.
all_gather
(
all_halos
,
send_halos
)
dist
.
all_gather
(
all_halos
,
send_halos
)
padded_out1_top_halo
=
padded_out1
[:,:
1
,:,:]
fat_halo
=
torch
.
empty
((
N
,
3
,
W
,
C
),
dtype
=
out1
.
dtype
,
device
=
out1
.
device
)
if
local_rank
>
0
:
if
local_rank
>
0
:
top_halo
=
all_halos
[
local_rank
-
1
][:,
1
:,:,:]
top_halo
=
all_halos
[
local_rank
-
1
][:,
1
:,:,:]
padded_out1_top_halo
.
copy_
(
top_halo
)
fat_halo
[:,:
1
,:,:].
copy_
(
top_halo
)
fat_top_halo
=
padded_out1
[:,:
3
,:,:]
fat_halo
[:,
1
:
3
,:,:].
copy_
(
out1
[:,:
2
,:,:])
top_out2
=
fast_bottleneck
.
forward_out2_halo
(
nhwc
,
fat_top_halo
,
args
)
top_out2
=
fast_bottleneck
.
forward_out2_halo
(
nhwc
,
fat_halo
,
args
)
else
:
padded_out1_top_halo
.
zero_
()
padded_out1_btm_halo
=
padded_out1
[:,
Hs
+
1
:,:,:]
if
local_rank
<
spatial_group_size
-
1
:
if
local_rank
<
spatial_group_size
-
1
:
btm_halo
=
all_halos
[
local_rank
+
1
][:,:
1
,:,:]
btm_halo
=
all_halos
[
local_rank
+
1
][:,:
1
,:,:]
padded_out1_btm_halo
.
copy_
(
btm_halo
)
fat_halo
[:,
0
:
2
,:,:].
copy_
(
out1
[:,
Hs
-
2
:,:,:])
fat_btm_halo
=
padded_out1
[:,
Hs
-
1
:,:,:]
fat_halo
[:,
2
:,:,:].
copy_
(
btm_halo
)
btm_out2
=
fast_bottleneck
.
forward_out2_halo
(
nhwc
,
fat_btm_halo
,
args
)
btm_out2
=
fast_bottleneck
.
forward_out2_halo
(
nhwc
,
fat_halo
,
args
)
else
:
padded_out1_btm_halo
.
zero_
()
torch
.
cuda
.
current_stream
().
wait_stream
(
stream1
)
torch
.
cuda
.
current_stream
().
wait_stream
(
stream1
)
out2
=
outputs
[
1
]
out2
=
outputs
[
1
]
if
local_rank
>
0
:
if
local_rank
>
0
:
...
@@ -272,10 +265,8 @@ class SpatialBottleneckFunction(torch.autograd.Function):
...
@@ -272,10 +265,8 @@ class SpatialBottleneckFunction(torch.autograd.Function):
out2
[:,
Hs
-
1
:,:,:].
copy_
(
btm_out2
)
out2
[:,
Hs
-
1
:,:,:].
copy_
(
btm_out2
)
fast_bottleneck
.
forward_rest
(
nhwc
,
stride_1x1
,
args
,
outputs
)
fast_bottleneck
.
forward_rest
(
nhwc
,
stride_1x1
,
args
,
outputs
)
if
spatial_group_size
>
1
:
# TODO: save halos for backward pass
ctx
.
save_for_backward
(
*
(
args
+
outputs
+
[
padded_out1
]))
ctx
.
save_for_backward
(
*
(
args
+
outputs
))
else
:
ctx
.
save_for_backward
(
*
(
args
+
outputs
))
# save relu outputs for drelu
# save relu outputs for drelu
ctx
.
nhwc
=
nhwc
ctx
.
nhwc
=
nhwc
ctx
.
stride_1x1
=
stride_1x1
ctx
.
stride_1x1
=
stride_1x1
...
@@ -289,10 +280,7 @@ class SpatialBottleneckFunction(torch.autograd.Function):
...
@@ -289,10 +280,7 @@ class SpatialBottleneckFunction(torch.autograd.Function):
# only support dgrad
# only support dgrad
@
staticmethod
@
staticmethod
def
backward
(
ctx
,
grad_o
):
def
backward
(
ctx
,
grad_o
):
if
ctx
.
spatial_group_size
>
1
:
outputs
=
ctx
.
saved_tensors
[
-
3
:]
outputs
=
ctx
.
saved_tensors
[
-
4
:
-
1
]
else
:
outputs
=
ctx
.
saved_tensors
[
-
3
:]
if
ctx
.
downsample
:
if
ctx
.
downsample
:
grad_conv3
,
grad_conv4
=
drelu_dscale2
(
grad_o
,
outputs
[
2
],
ctx
.
saved_tensors
[
6
],
ctx
.
saved_tensors
[
11
])
grad_conv3
,
grad_conv4
=
drelu_dscale2
(
grad_o
,
outputs
[
2
],
ctx
.
saved_tensors
[
6
],
ctx
.
saved_tensors
[
11
])
...
@@ -315,7 +303,23 @@ class SpatialBottleneckFunction(torch.autograd.Function):
...
@@ -315,7 +303,23 @@ class SpatialBottleneckFunction(torch.autograd.Function):
grads
=
fast_bottleneck
.
backward_init
(
ctx
.
nhwc
,
ctx
.
stride_1x1
,
t_list
)
grads
=
fast_bottleneck
.
backward_init
(
ctx
.
nhwc
,
ctx
.
stride_1x1
,
t_list
)
grad_out2
=
fast_bottleneck
.
backward_grad_out2
(
ctx
.
nhwc
,
ctx
.
stride_1x1
,
t_list
,
grads
)
grad_out2
=
fast_bottleneck
.
backward_grad_out2
(
ctx
.
nhwc
,
ctx
.
stride_1x1
,
t_list
,
grads
)
# do halo exchange of grad_out2 here
# do halo exchange of grad_out2 here
fast_bottleneck
.
backward_rest
(
ctx
.
nhwc
,
ctx
.
stride_1x1
,
t_list
,
grads
,
grad_out2
)
# need fast_bottleneck.backward_grad_out2_halo
# testing
N
,
H
,
W
,
C
=
grad_out2
.
shape
grad_out2_halo
=
torch
.
empty
([
N
,
3
,
W
,
C
],
dtype
=
grad_out2
.
dtype
,
device
=
grad_out2
.
device
)
grad_out2_halo
[:,:
1
,:,:].
zero_
()
grad_out2_halo
[:,
1
:,:,:].
copy_
(
grad_out2
[:,:
2
,:,:])
grad_out1_halo
=
fast_bottleneck
.
backward_grad_out1_halo
(
ctx
.
nhwc
,
ctx
.
stride_1x1
,
t_list
,
grads
,
grad_out2_halo
)
# print("grad_out2_halo.shape = %s -> grad_out1_halo.shape = %s" % (str(list(grad_out2_halo.shape)), str(list(grad_out1_halo.shape))))
wgrad2
=
fast_bottleneck
.
backward_wgrad2
(
ctx
.
nhwc
,
ctx
.
stride_1x1
,
t_list
,
grads
,
grad_out2
)
# apply wgrad2 halos here
# no need for custom wgrad2_halo function, this is just a backwards data convolution
grad_out1
=
fast_bottleneck
.
backward_grad_out1
(
ctx
.
nhwc
,
ctx
.
stride_1x1
,
t_list
,
grads
,
grad_out2
)
# apply grad_out1 halos here
fast_bottleneck
.
backward_rest
(
ctx
.
nhwc
,
ctx
.
stride_1x1
,
t_list
,
grads
,
grad_out2
,
grad_out1
,
wgrad2
)
return
(
None
,
None
,
None
,
None
,
None
,
None
,
None
,
None
,
*
grads
)
return
(
None
,
None
,
None
,
None
,
None
,
None
,
None
,
None
,
*
grads
)
...
...
apex/contrib/csrc/bottleneck/bottleneck.cpp
View file @
d934eca3
...
@@ -1746,7 +1746,7 @@ std::vector<at::Tensor> bottleneck_forward_init(bool explicit_nhwc, int stride_1
...
@@ -1746,7 +1746,7 @@ std::vector<at::Tensor> bottleneck_forward_init(bool explicit_nhwc, int stride_1
std
::
vector
<
at
::
Tensor
>
outputs
;
std
::
vector
<
at
::
Tensor
>
outputs
;
auto
output_format
=
explicit_nhwc
?
at
::
MemoryFormat
::
Contiguous
:
at
::
MemoryFormat
::
ChannelsLast
;
auto
output_format
=
explicit_nhwc
?
at
::
MemoryFormat
::
Contiguous
:
at
::
MemoryFormat
::
ChannelsLast
;
printf
(
"outdim1 = (%d,%d,%d,%d)
\n
"
,
forward_state
.
outdim1
[
0
],
forward_state
.
outdim1
[
1
],
forward_state
.
outdim1
[
2
],
forward_state
.
outdim1
[
3
]);
//
printf("outdim1 = (%d,%d,%d,%d)\n",forward_state.outdim1[0],forward_state.outdim1[1],forward_state.outdim1[2],forward_state.outdim1[3]);
auto
out1
=
at
::
empty
(
forward_state
.
outdim1
,
inputs
[
0
].
type
(),
output_format
);
auto
out1
=
at
::
empty
(
forward_state
.
outdim1
,
inputs
[
0
].
type
(),
output_format
);
auto
out2
=
at
::
empty
(
forward_state
.
outdim2
,
inputs
[
0
].
type
(),
output_format
);
auto
out2
=
at
::
empty
(
forward_state
.
outdim2
,
inputs
[
0
].
type
(),
output_format
);
auto
out3
=
at
::
empty
(
forward_state
.
outdim3
,
inputs
[
0
].
type
(),
output_format
);
auto
out3
=
at
::
empty
(
forward_state
.
outdim3
,
inputs
[
0
].
type
(),
output_format
);
...
@@ -1837,12 +1837,12 @@ void bottleneck_forward_out2(bool explicit_nhwc, int stride_1X1, std::vector<at:
...
@@ -1837,12 +1837,12 @@ void bottleneck_forward_out2(bool explicit_nhwc, int stride_1X1, std::vector<at:
auto
out2
=
outputs
[
1
];
auto
out2
=
outputs
[
1
];
at
::
Half
*
y2
=
out2
.
data_ptr
<
at
::
Half
>
();
at
::
Half
*
y2
=
out2
.
data_ptr
<
at
::
Half
>
();
printf
(
"forward_state.outdimA1 = {%d,%d,%d,%d}
\n
"
,
forward_state
.
outdimA1
[
0
],
forward_state
.
outdimA1
[
1
],
forward_state
.
outdimA1
[
2
],
forward_state
.
outdimA1
[
3
]);
//
printf("forward_state.outdimA1 = {%d,%d,%d,%d}\n",forward_state.outdimA1[0],forward_state.outdimA1[1],forward_state.outdimA1[2],forward_state.outdimA1[3]);
printf
(
"forward_state.padA1 = {%d,%d}
\n
"
,
forward_state
.
padA1
[
0
],
forward_state
.
padA1
[
1
]);
//
printf("forward_state.padA1 = {%d,%d}\n",forward_state.padA1[0],forward_state.padA1[1]);
printf
(
"forward_state.convstrideA = {%d,%d}
\n
"
,
forward_state
.
convstrideA
[
0
],
forward_state
.
convstrideA
[
1
]);
//
printf("forward_state.convstrideA = {%d,%d}\n",forward_state.convstrideA[0],forward_state.convstrideA[1]);
printf
(
"forward_state.dilationA = {%d,%d}
\n
"
,
forward_state
.
dilationA
[
0
],
forward_state
.
dilationA
[
1
]);
//
printf("forward_state.dilationA = {%d,%d}\n",forward_state.dilationA[0],forward_state.dilationA[1]);
printf
(
"forward_state.filterdimA2 = {%d,%d,%d,%d}
\n
"
,
forward_state
.
filterdimA2
[
0
],
forward_state
.
filterdimA2
[
1
],
forward_state
.
filterdimA2
[
2
],
forward_state
.
filterdimA2
[
3
]);
//
printf("forward_state.filterdimA2 = {%d,%d,%d,%d}\n",forward_state.filterdimA2[0],forward_state.filterdimA2[1],forward_state.filterdimA2[2],forward_state.filterdimA2[3]);
printf
(
"forward_state.outdimA2 = {%d,%d,%d,%d}
\n
"
,
forward_state
.
outdimA2
[
0
],
forward_state
.
outdimA2
[
1
],
forward_state
.
outdimA2
[
2
],
forward_state
.
outdimA2
[
3
]);
//
printf("forward_state.outdimA2 = {%d,%d,%d,%d}\n",forward_state.outdimA2[0],forward_state.outdimA2[1],forward_state.outdimA2[2],forward_state.outdimA2[3]);
run_conv_scale_bias_add_activation
(
forward_state
.
outdimA1
,
run_conv_scale_bias_add_activation
(
forward_state
.
outdimA1
,
forward_state
.
padA1
,
forward_state
.
padA1
,
forward_state
.
convstrideA
,
forward_state
.
convstrideA
,
...
@@ -1934,12 +1934,15 @@ struct bottleneck_backward_state {
...
@@ -1934,12 +1934,15 @@ struct bottleneck_backward_state {
int
axis
[
4
];
int
axis
[
4
];
int64_t
outdimA1
[
4
];
int64_t
outdimA1
[
4
];
// grad_out1
int64_t
outdimA2
[
4
];
int64_t
outdimA2
[
4
];
// grad_out2
int64_t
outdimA3
[
4
];
int64_t
outdimA3
[
4
];
int64_t
outdimA1h
[
4
];
// output: grad_out1 halo (H=3)
int64_t
outdimA2h
[
4
];
// input : grad_out2 halo cells (H=3)
int64_t
padA
[
2
];
int64_t
padA
[
2
];
int64_t
padA1
[
2
];
int64_t
padA1
[
2
];
int64_t
padA2
[
2
];
int64_t
dilationA
[
2
];
int64_t
dilationA
[
2
];
int64_t
convstrideA
[
2
];
int64_t
convstrideA
[
2
];
int64_t
convstride1X1
[
2
];
int64_t
convstride1X1
[
2
];
...
@@ -1947,6 +1950,8 @@ struct bottleneck_backward_state {
...
@@ -1947,6 +1950,8 @@ struct bottleneck_backward_state {
int64_t
outdim1
[
4
];
int64_t
outdim1
[
4
];
int64_t
outdim2
[
4
];
int64_t
outdim2
[
4
];
int64_t
outdim3
[
4
];
int64_t
outdim3
[
4
];
int64_t
outdim1h
[
4
];
int64_t
outdim2hh
[
4
];
void
init
(
bool
explicit_nhwc
,
int
stride_1X1
,
std
::
vector
<
at
::
Tensor
>
inputs
)
{
void
init
(
bool
explicit_nhwc
,
int
stride_1X1
,
std
::
vector
<
at
::
Tensor
>
inputs
)
{
// setup dimensions
// setup dimensions
...
@@ -1985,6 +1990,8 @@ struct bottleneck_backward_state {
...
@@ -1985,6 +1990,8 @@ struct bottleneck_backward_state {
outdimA1
[
0
]
=
outdimA1
[
1
]
=
outdimA1
[
2
]
=
outdimA1
[
3
]
=
0
;
outdimA1
[
0
]
=
outdimA1
[
1
]
=
outdimA1
[
2
]
=
outdimA1
[
3
]
=
0
;
outdimA2
[
0
]
=
outdimA2
[
1
]
=
outdimA2
[
2
]
=
outdimA2
[
3
]
=
0
;
outdimA2
[
0
]
=
outdimA2
[
1
]
=
outdimA2
[
2
]
=
outdimA2
[
3
]
=
0
;
outdimA3
[
0
]
=
outdimA3
[
1
]
=
outdimA3
[
2
]
=
outdimA3
[
3
]
=
0
;
outdimA3
[
0
]
=
outdimA3
[
1
]
=
outdimA3
[
2
]
=
outdimA3
[
3
]
=
0
;
outdimA1h
[
0
]
=
outdimA1h
[
1
]
=
outdimA1h
[
2
]
=
outdimA1h
[
3
]
=
0
;
outdimA2h
[
0
]
=
outdimA2h
[
1
]
=
outdimA2h
[
2
]
=
outdimA2h
[
3
]
=
0
;
// use these fixed value for test run
// use these fixed value for test run
padA
[
0
]
=
0
;
padA
[
1
]
=
0
;
padA
[
0
]
=
0
;
padA
[
1
]
=
0
;
...
@@ -2012,10 +2019,21 @@ struct bottleneck_backward_state {
...
@@ -2012,10 +2019,21 @@ struct bottleneck_backward_state {
outdimA3
[
dim
+
2
]
=
getFwdConvOutputDim
(
outdimA2
[
dim
+
2
],
padA
[
dim
],
filterdimA3
[
dim
+
2
],
convstrideA
[
dim
],
dilationA
[
dim
]);
outdimA3
[
dim
+
2
]
=
getFwdConvOutputDim
(
outdimA2
[
dim
+
2
],
padA
[
dim
],
filterdimA3
[
dim
+
2
],
convstrideA
[
dim
],
dilationA
[
dim
]);
}
}
for
(
int
dim
=
0
;
dim
<
4
;
dim
++
)
{
if
(
dim
==
2
)
{
outdimA1h
[
dim
]
=
3
;
outdimA2h
[
dim
]
=
3
;
}
else
{
outdimA1h
[
dim
]
=
outdimA1
[
dim
];
outdimA2h
[
dim
]
=
outdimA2
[
dim
];
}
}
// Create output tensor in the correct shape in pytorch's view
// Create output tensor in the correct shape in pytorch's view
outdim1
[
0
]
=
outdim1
[
1
]
=
outdim1
[
2
]
=
outdim1
[
3
]
=
0
;
outdim1
[
0
]
=
outdim1
[
1
]
=
outdim1
[
2
]
=
outdim1
[
3
]
=
0
;
outdim2
[
0
]
=
outdim2
[
1
]
=
outdim2
[
2
]
=
outdim2
[
3
]
=
0
;
outdim2
[
0
]
=
outdim2
[
1
]
=
outdim2
[
2
]
=
outdim2
[
3
]
=
0
;
outdim3
[
0
]
=
outdim3
[
1
]
=
outdim3
[
2
]
=
outdim3
[
3
]
=
0
;
outdim3
[
0
]
=
outdim3
[
1
]
=
outdim3
[
2
]
=
outdim3
[
3
]
=
0
;
outdim1h
[
0
]
=
outdim1h
[
1
]
=
outdim1h
[
2
]
=
outdim1h
[
3
]
=
0
;
if
(
explicit_nhwc
)
{
if
(
explicit_nhwc
)
{
axis
[
0
]
=
0
;
axis
[
0
]
=
0
;
axis
[
1
]
=
2
;
axis
[
1
]
=
2
;
...
@@ -2026,6 +2044,7 @@ struct bottleneck_backward_state {
...
@@ -2026,6 +2044,7 @@ struct bottleneck_backward_state {
outdim1
[
dim
]
=
outdimA1
[
axis
[
dim
]];
outdim1
[
dim
]
=
outdimA1
[
axis
[
dim
]];
outdim2
[
dim
]
=
outdimA2
[
axis
[
dim
]];
outdim2
[
dim
]
=
outdimA2
[
axis
[
dim
]];
outdim3
[
dim
]
=
outdimA3
[
axis
[
dim
]];
outdim3
[
dim
]
=
outdimA3
[
axis
[
dim
]];
outdim1h
[
dim
]
=
outdimA1h
[
axis
[
dim
]];
}
}
}
}
};
};
...
@@ -2117,7 +2136,78 @@ at::Tensor bottleneck_backward_grad_out2(bool explicit_nhwc, int stride_1X1, std
...
@@ -2117,7 +2136,78 @@ at::Tensor bottleneck_backward_grad_out2(bool explicit_nhwc, int stride_1X1, std
return
grad_out2
;
return
grad_out2
;
}
}
void
bottleneck_backward_rest
(
bool
explicit_nhwc
,
int
stride_1X1
,
std
::
vector
<
at
::
Tensor
>
inputs
,
std
::
vector
<
at
::
Tensor
>
outputs
,
at
::
Tensor
grad_out2
)
{
at
::
Tensor
bottleneck_backward_grad_out1
(
bool
explicit_nhwc
,
int
stride_1X1
,
std
::
vector
<
at
::
Tensor
>
inputs
,
std
::
vector
<
at
::
Tensor
>
outputs
,
at
::
Tensor
grad_out2
)
{
bool
requires_grad
=
inputs
[
0
].
requires_grad
();
std
::
cout
<<
std
::
fixed
;
auto
output_format
=
explicit_nhwc
?
at
::
MemoryFormat
::
Contiguous
:
at
::
MemoryFormat
::
ChannelsLast
;
// dgrad
at
::
Half
*
dy2
=
grad_out2
.
data_ptr
<
at
::
Half
>
();
// dgrad
auto
grad_out1
=
at
::
empty
(
backward_state
.
outdim1
,
inputs
[
0
].
type
(),
output_format
);
at
::
Half
*
dy1
=
grad_out1
.
data_ptr
<
at
::
Half
>
();
at
::
Half
*
w
=
inputs
[
2
].
data_ptr
<
at
::
Half
>
();
at
::
Half
*
z
=
inputs
[
4
].
data_ptr
<
at
::
Half
>
();
at
::
Half
*
relu1
=
inputs
[
12
].
data_ptr
<
at
::
Half
>
();
// fused dgrad
run_dconv_drelu_dscale
(
backward_state
.
outdimA1
,
backward_state
.
padA1
,
backward_state
.
convstrideA
,
backward_state
.
dilationA
,
backward_state
.
filterdimA2
,
backward_state
.
outdimA2
,
CUDNN_DATA_HALF
,
dy1
,
w
,
dy2
,
z
,
relu1
);
return
grad_out1
;
}
// perform backward data 3x3 convolution (grad_out * w_rot180) on grad_out2 input of shape [N,3,W,C] with padding=(1,1) to produce output of shape [N,3,W,C]
at
::
Tensor
bottleneck_backward_grad_out1_halo
(
bool
explicit_nhwc
,
int
stride_1X1
,
std
::
vector
<
at
::
Tensor
>
inputs
,
std
::
vector
<
at
::
Tensor
>
outputs
,
at
::
Tensor
grad_out2_halo
)
{
bool
requires_grad
=
inputs
[
0
].
requires_grad
();
std
::
cout
<<
std
::
fixed
;
auto
output_format
=
explicit_nhwc
?
at
::
MemoryFormat
::
Contiguous
:
at
::
MemoryFormat
::
ChannelsLast
;
// dgrad
at
::
Half
*
dy2h
=
grad_out2_halo
.
data_ptr
<
at
::
Half
>
();
// dgrad
auto
grad_out1_halo
=
at
::
empty
(
backward_state
.
outdim1h
,
inputs
[
0
].
type
(),
output_format
);
at
::
Half
*
dy1h
=
grad_out1_halo
.
data_ptr
<
at
::
Half
>
();
at
::
Half
*
w
=
inputs
[
2
].
data_ptr
<
at
::
Half
>
();
at
::
Half
*
z
=
inputs
[
4
].
data_ptr
<
at
::
Half
>
();
at
::
Half
*
relu1
=
inputs
[
12
].
data_ptr
<
at
::
Half
>
();
// fused dgrad
//printf("backward_state.outdimA1h = {%d,%d,%d,%d}\n",backward_state.outdimA1h[0],backward_state.outdimA1h[1],backward_state.outdimA1h[2],backward_state.outdimA1h[3]);
//printf("backward_state.outdimA2h = {%d,%d,%d,%d}\n",backward_state.outdimA2h[0],backward_state.outdimA2h[1],backward_state.outdimA2h[2],backward_state.outdimA2h[3]);
run_dconv_drelu_dscale
(
backward_state
.
outdimA1h
,
backward_state
.
padA1
,
backward_state
.
convstrideA
,
backward_state
.
dilationA
,
backward_state
.
filterdimA2
,
backward_state
.
outdimA2h
,
CUDNN_DATA_HALF
,
dy1h
,
w
,
dy2h
,
z
,
relu1
);
return
grad_out1_halo
;
}
at
::
Tensor
bottleneck_backward_wgrad2
(
bool
explicit_nhwc
,
int
stride_1X1
,
std
::
vector
<
at
::
Tensor
>
inputs
,
std
::
vector
<
at
::
Tensor
>
outputs
,
at
::
Tensor
grad_out2
)
{
bool
requires_grad
=
inputs
[
0
].
requires_grad
();
bool
requires_grad
=
inputs
[
0
].
requires_grad
();
...
@@ -2134,7 +2224,7 @@ void bottleneck_backward_rest(bool explicit_nhwc, int stride_1X1, std::vector<at
...
@@ -2134,7 +2224,7 @@ void bottleneck_backward_rest(bool explicit_nhwc, int stride_1X1, std::vector<at
auto
wgrad2
=
outputs
[
2
];
auto
wgrad2
=
outputs
[
2
];
at
::
Half
*
dw2
=
wgrad2
.
data_ptr
<
at
::
Half
>
();
at
::
Half
*
dw2
=
wgrad2
.
data_ptr
<
at
::
Half
>
();
printf
(
"outdimA1 = (%d,%d,%d,%d)
\n
"
,
backward_state
.
outdimA1
[
0
],
backward_state
.
outdimA1
[
1
],
backward_state
.
outdimA1
[
2
],
backward_state
.
outdimA1
[
3
]);
//
printf("outdimA1 = (%d,%d,%d,%d)\n",backward_state.outdimA1[0],backward_state.outdimA1[1],backward_state.outdimA1[2],backward_state.outdimA1[3]);
run_dconv
(
backward_state
.
outdimA1
,
run_dconv
(
backward_state
.
outdimA1
,
backward_state
.
padA1
,
backward_state
.
padA1
,
backward_state
.
convstrideA
,
backward_state
.
convstrideA
,
...
@@ -2147,26 +2237,19 @@ void bottleneck_backward_rest(bool explicit_nhwc, int stride_1X1, std::vector<at
...
@@ -2147,26 +2237,19 @@ void bottleneck_backward_rest(bool explicit_nhwc, int stride_1X1, std::vector<at
dy2
,
dy2
,
CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR
);
CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR
);
return
wgrad2
;
}
void
bottleneck_backward_rest
(
bool
explicit_nhwc
,
int
stride_1X1
,
std
::
vector
<
at
::
Tensor
>
inputs
,
std
::
vector
<
at
::
Tensor
>
outputs
,
at
::
Tensor
grad_out2
,
at
::
Tensor
grad_out1
,
at
::
Tensor
wgrad2
)
{
bool
requires_grad
=
inputs
[
0
].
requires_grad
();
std
::
cout
<<
std
::
fixed
;
auto
output_format
=
explicit_nhwc
?
at
::
MemoryFormat
::
Contiguous
:
at
::
MemoryFormat
::
ChannelsLast
;
// dgrad
// dgrad
a
uto
grad_out1
=
at
::
empty
(
backward_state
.
outdim1
,
inputs
[
0
].
type
(),
output_format
);
a
t
::
Half
*
dy2
=
grad_out2
.
data_ptr
<
at
::
Half
>
(
);
at
::
Half
*
dy1
=
grad_out1
.
data_ptr
<
at
::
Half
>
();
at
::
Half
*
dy1
=
grad_out1
.
data_ptr
<
at
::
Half
>
();
at
::
Half
*
w
=
inputs
[
2
].
data_ptr
<
at
::
Half
>
();
at
::
Half
*
z
=
inputs
[
4
].
data_ptr
<
at
::
Half
>
();
at
::
Half
*
relu1
=
inputs
[
12
].
data_ptr
<
at
::
Half
>
();
// fused dgrad
run_dconv_drelu_dscale
(
backward_state
.
outdimA1
,
backward_state
.
padA1
,
backward_state
.
convstrideA
,
backward_state
.
dilationA
,
backward_state
.
filterdimA2
,
backward_state
.
outdimA2
,
CUDNN_DATA_HALF
,
dy1
,
w
,
dy2
,
z
,
relu1
);
/*
/*
// backward strided conv cannot be fused
// backward strided conv cannot be fused
...
@@ -2215,6 +2298,8 @@ void bottleneck_backward_rest(bool explicit_nhwc, int stride_1X1, std::vector<at
...
@@ -2215,6 +2298,8 @@ void bottleneck_backward_rest(bool explicit_nhwc, int stride_1X1, std::vector<at
// x used for dconv1 and dconv4 wgrad
// x used for dconv1 and dconv4 wgrad
at
::
Half
*
x
=
inputs
[
0
].
data_ptr
<
at
::
Half
>
();
at
::
Half
*
x
=
inputs
[
0
].
data_ptr
<
at
::
Half
>
();
at
::
Half
*
w
=
NULL
;
if
(
stride_1X1
!=
1
||
backward_state
.
filterdimA3
[
0
]
!=
backward_state
.
dimA
[
1
]){
if
(
stride_1X1
!=
1
||
backward_state
.
filterdimA3
[
0
]
!=
backward_state
.
dimA
[
1
]){
w
=
inputs
[
14
].
data_ptr
<
at
::
Half
>
();
w
=
inputs
[
14
].
data_ptr
<
at
::
Half
>
();
at
::
Half
*
dy_conv4
=
inputs
[
11
].
data_ptr
<
at
::
Half
>
();
at
::
Half
*
dy_conv4
=
inputs
[
11
].
data_ptr
<
at
::
Half
>
();
...
@@ -2327,5 +2412,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
...
@@ -2327,5 +2412,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m
.
def
(
"forward_rest"
,
&
bottleneck_forward_rest
,
"Bottleneck block forward"
);
m
.
def
(
"forward_rest"
,
&
bottleneck_forward_rest
,
"Bottleneck block forward"
);
m
.
def
(
"backward_init"
,
&
bottleneck_backward_init
,
"Bottleneck block backward init"
);
m
.
def
(
"backward_init"
,
&
bottleneck_backward_init
,
"Bottleneck block backward init"
);
m
.
def
(
"backward_grad_out2"
,
&
bottleneck_backward_grad_out2
,
"Bottleneck block backward"
);
m
.
def
(
"backward_grad_out2"
,
&
bottleneck_backward_grad_out2
,
"Bottleneck block backward"
);
m
.
def
(
"backward_grad_out1"
,
&
bottleneck_backward_grad_out1
,
"Bottleneck block backward"
);
m
.
def
(
"backward_grad_out1_halo"
,
&
bottleneck_backward_grad_out1_halo
,
"Bottleneck block backward"
);
m
.
def
(
"backward_wgrad2"
,
&
bottleneck_backward_wgrad2
,
"Bottleneck block backward"
);
m
.
def
(
"backward_rest"
,
&
bottleneck_backward_rest
,
"Bottleneck block backward"
);
m
.
def
(
"backward_rest"
,
&
bottleneck_backward_rest
,
"Bottleneck block backward"
);
}
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment