OpenDAS / apex · Commit 60000f73
Authored Mar 31, 2022 by Thor Johnsen
Add halo correction using new cudnn masking feature
parent 9c16d945

Showing 2 changed files with 790 additions and 34 deletions:
  apex/contrib/bottleneck/bottleneck.py         +56   -28
  apex/contrib/csrc/bottleneck/bottleneck.cpp   +734  -6
apex/contrib/bottleneck/bottleneck.py
...
...
@@ -220,7 +220,7 @@ class Bottleneck(torch.nn.Module):
 class SpatialBottleneckFunction(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, spatial_group_size, spatial_group_rank, spatial_communicator, spatial_halo_exchanger, spatial_method, explicit_nhwc, stride_1x1, scale, bias, x, *conv):
+    def forward(ctx, spatial_group_size, spatial_group_rank, spatial_communicator, spatial_halo_exchanger, spatial_method, explicit_nhwc, stride_1x1, scale, bias, thresholdTop, thresholdBottom, x, *conv):
         if spatial_group_size > 1:
             stream1 = spatial_halo_exchanger.stream1
             stream2 = spatial_halo_exchanger.stream2
...
...
@@ -271,57 +271,75 @@ class SpatialBottleneckFunction(torch.autograd.Function):
             if spatial_group_rank < spatial_group_size-1:
                 stream2.wait_stream(stream1)
                 with torch.cuda.stream(stream2):
-                    btm_fat_halo = torch.empty((N,3,W,C),dtype=out1.dtype,device=out1.device)
-                    btm_fat_halo[:,0:2,:,:].copy_(out1[:,Hs-2:,:,:])
-                    btm_fat_halo[:,2:,:,:].copy_(btm_out1_halo)
+                    if explicit_nhwc:
+                        btm_fat_halo = torch.empty((N,3,W,C),dtype=out1.dtype,device=out1.device)
+                        btm_fat_halo[:,0:2,:,:].copy_(out1[:,Hs-2:,:,:])
+                        btm_fat_halo[:,2:,:,:].copy_(btm_out1_halo)
+                    else:
+                        btm_fat_halo = torch.empty((N,C,3,W),dtype=out1.dtype,device=out1.device)
+                        btm_fat_halo[:,:,0:2,:].copy_(out1[:,:,Hs-2:,:])
+                        btm_fat_halo[:,:,2:,:].copy_(btm_out1_halo)
                     btm_out2 = fast_bottleneck.forward_out2_halo(explicit_nhwc, btm_fat_halo, args)
             if spatial_group_rank > 0:
                 with torch.cuda.stream(stream1):
-                    top_fat_halo = torch.empty((N,3,W,C),dtype=out1.dtype,device=out1.device)
-                    top_fat_halo[:,:1,:,:].copy_(top_out1_halo)
-                    top_fat_halo[:,1:3,:,:].copy_(out1[:,:2,:,:])
+                    if explicit_nhwc:
+                        top_fat_halo = torch.empty((N,3,W,C),dtype=out1.dtype,device=out1.device)
+                        top_fat_halo[:,:1,:,:].copy_(top_out1_halo)
+                        top_fat_halo[:,1:3,:,:].copy_(out1[:,:2,:,:])
+                    else:
+                        top_fat_halo = torch.empty((N,C,3,W),dtype=out1.dtype,device=out1.device)
+                        top_fat_halo[:,:,:1,:].copy_(top_out1_halo)
+                        top_fat_halo[:,:,1:3,:].copy_(out1[:,:,:2,:])
                     top_out2 = fast_bottleneck.forward_out2_halo(explicit_nhwc, top_fat_halo, args)
                 inc.add_delay(10)
-            elif spatial_method == 2:
-                # wait for halo transfer to finish before doing a full convolution of padded x
-                torch.cuda.current_stream().wait_stream(stream1)
-                torch.cuda.current_stream().wait_stream(stream3)
-                fast_bottleneck.forward_out2_pad(explicit_nhwc, stride_1x1, args, outputs, out1_pad)
-            else:
-                assert(False), "spatial_method must be 1 or 2"
+            elif spatial_method != 2 and spatial_method != 3:
+                assert(False), "spatial_method must be 1, 2 or 3"

-        if spatial_group_size <= 1 or spatial_method == 1:
+        if spatial_group_size <= 1:
             fast_bottleneck.forward_out2(explicit_nhwc, stride_1x1, args, outputs)
+        elif spatial_method == 1:
+            fast_bottleneck.forward_out2(explicit_nhwc, stride_1x1, args, outputs)
+        elif spatial_method == 2:
+            # wait for halo transfer to finish before doing a full convolution of padded x
+            torch.cuda.current_stream().wait_stream(stream1)
+            torch.cuda.current_stream().wait_stream(stream3)
+            fast_bottleneck.forward_out2_pad(explicit_nhwc, stride_1x1, args, outputs, out1_pad)
+        elif spatial_method == 3:
+            fast_bottleneck.forward_out2_mask(explicit_nhwc, stride_1x1, args, outputs, thresholdTop, thresholdBottom)

         # compute halo cells for outputs[1] (out2)
-        if spatial_group_size > 1 and spatial_method == 1:
+        if spatial_group_size > 1:
             out2 = outputs[1]
-            if spatial_group_rank > 0:
-                torch.cuda.current_stream().wait_stream(stream1)
-                if explicit_nhwc:
-                    out2[:,:1,:,:].copy_(top_out2)
-                else:
-                    out2[:,:,:1,:].copy_(top_out2)
-            if spatial_group_rank < spatial_group_size-1:
-                torch.cuda.current_stream().wait_stream(stream2)
-                if explicit_nhwc:
-                    out2[:,Hs-1:,:,:].copy_(btm_out2)
-                else:
-                    out2[:,:,Hs-1:,:].copy_(btm_out2)
-            torch.cuda.current_stream().wait_stream(stream3)
+            if explicit_nhwc:
+                top_out2_halo = out2[:,:1,:,:]
+                btm_out2_halo = out2[:,Hs-1:,:,:]
+            else:
+                top_out2_halo = out2[:,:,:1,:]
+                btm_out2_halo = out2[:,:,Hs-1:,:]
+            if spatial_method == 1:
+                if spatial_group_rank > 0:
+                    torch.cuda.current_stream().wait_stream(stream1)
+                    top_out2_halo.copy_(top_out2)
+                if spatial_group_rank < spatial_group_size-1:
+                    torch.cuda.current_stream().wait_stream(stream2)
+                    btm_out2_halo.copy_(btm_out2)
+            elif spatial_method == 3:
+                if spatial_group_rank > 0:
+                    w1by3 = args[2][:,:,2:3,:].contiguous(memory_format=torch.preserve)
+                    top_out1_halo = top_out1_halo.contiguous(memory_format=memory_format)
+                    top_out2 = fast_bottleneck.forward_out2_halo_corr(explicit_nhwc, top_out1_halo, args, w1by3, top_out2_halo.contiguous(memory_format=memory_format))
+                    top_out2_halo.copy_(top_out2)
+                if spatial_group_rank < spatial_group_size-1:
+                    w1by3 = args[2][:,:,:1,:].contiguous(memory_format=torch.preserve)
+                    btm_out1_halo = btm_out1_halo.contiguous(memory_format=memory_format)
+                    btm_out2 = fast_bottleneck.forward_out2_halo_corr(explicit_nhwc, btm_out1_halo, args, w1by3, btm_out2_halo.contiguous(memory_format=memory_format))
+                    btm_out2_halo.copy_(btm_out2)

         fast_bottleneck.forward_rest(explicit_nhwc, stride_1x1, args, outputs)
         # save halos for backward pass
         if spatial_group_size > 1:
-            torch.cuda.current_stream().wait_stream(stream3)
+            if spatial_method != 2:
+                torch.cuda.current_stream().wait_stream(stream3)
             ctx.save_for_backward(*(args+outputs+[out1_pad,]))
         else:
             ctx.save_for_backward(*(args+outputs))
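Note on the new spatial_method == 3 path: instead of the "fat halo" convolutions of method 1, the main fused convolution masks the top/bottom output rows (forward_out2_mask leaves them as raw convolution partial sums), and forward_out2_halo_corr later adds the missing contribution of the neighbour's single halo row. The sketch below is not part of the commit; it is a small PyTorch illustration with made-up shapes (names halo_row/w1by3 chosen to mirror the diff) of why one halo row convolved with a one-row slice of the 3x3 filter is exactly the missing term.

# Hedged illustration (not code from this commit): split a 3x3 convolution across two
# ranks along H and show that the boundary row is recovered by adding a 1-row halo term.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
N, C, H, W = 1, 4, 8, 8
x = torch.randn(N, C, 2 * H, W)            # full input, split into two "ranks" along H
w = torch.randn(C, C, 3, 3)                # 3x3 convolution weight

x_top, x_btm = x[:, :, :H, :], x[:, :, H:, :]

# Reference: convolution over the unsplit input, keeping the rows owned by the bottom rank.
ref = F.conv2d(x, w, padding=1)[:, :, H:, :]

# Local pass on the bottom rank only (zero padding): its first row is missing the
# contribution of the top rank's last row -- this is the row the cudnn mask singles out.
local = F.conv2d(x_btm, w, padding=1)

# Halo correction: the neighbour's last row convolved with the first (h index 0) row of
# the 3x3 filter; the commit passes an analogous 1-row weight slice (w1by3) to
# forward_out2_halo_corr.
halo_row = x_top[:, :, -1:, :]             # slim halo, 1 pixel wide in H
w1by3 = w[:, :, 0:1, :]                    # C,C,1,3 slice of the weight
corr = F.conv2d(halo_row, w1by3, padding=(0, 1))

fixed_top_row = local[:, :, 0:1, :] + corr
print(torch.allclose(fixed_top_row, ref[:, :, 0:1, :], atol=1e-4))   # True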
...
...
@@ -460,7 +478,7 @@ class SpatialBottleneckFunction(torch.autograd.Function):
         fast_bottleneck.backward_rest(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, grad_out2, grad_out1, wgrad2)
         torch.cuda.current_stream().wait_stream(wgrad2_stream)
-        return (None, None, None, None, None, None, None, None, None, *grads)
+        return (None, None, None, None, None, None, None, None, None, None, None, *grads)

 spatial_bottleneck_function = SpatialBottleneckFunction.apply
...
...
@@ -515,6 +533,8 @@ class SpatialBottleneck(torch.nn.Module):
         for w in self.w_conv:
             kaiming_uniform_(w, a=1)
+        self.thresholdTop, self.thresholdBottom = None, None
+
         # TODO: prevent unsupported case usage
         # support cases
         #                 native      cudnn
...
...
@@ -536,6 +556,14 @@ class SpatialBottleneck(torch.nn.Module):
     def forward(self, x):
         if self.use_cudnn:
+            if self.thresholdTop is None:
+                spatial_group_size, spatial_group_rank, _, _, _ = self.spatial_parallel_args
+                if self.explicit_nhwc:
+                    N,H,W,C = list(x.shape)
+                else:
+                    N,C,H,W = list(x.shape)
+                self.thresholdTop = torch.tensor([1 if spatial_group_rank > 0 else 0], dtype=torch.int32, device='cuda')
+                self.thresholdBottom = torch.tensor([H-2 if spatial_group_rank < spatial_group_size - 1 else H-1], dtype=torch.int32, device='cuda')
+
             # calculate scale/bias from registered buffers
             # TODO: make this better
             s1, b1 = self.bn1.get_scale_bias(self.explicit_nhwc)
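For reference, the thresholds computed above mark exactly the rows whose 3x3 window reaches into a neighbouring rank (row index < thresholdTop or > thresholdBottom). A small standalone sketch (not from the commit) of the values per rank:

# Hedged sketch: masked rows per rank for, e.g., 4 ranks and per-rank height H = 8.
H, spatial_group_size = 8, 4
for spatial_group_rank in range(spatial_group_size):
    top = 1 if spatial_group_rank > 0 else 0
    btm = H - 2 if spatial_group_rank < spatial_group_size - 1 else H - 1
    masked = [h for h in range(H) if h < top or h > btm]
    print(spatial_group_rank, top, btm, masked)
# rank 0 -> rows [7], ranks 1..2 -> rows [0, 7], rank 3 -> rows [0]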
...
...
@@ -548,7 +576,7 @@ class SpatialBottleneck(torch.nn.Module):
             w_scale.append(s4)
             w_bias.append(b4)
-            out = spatial_bottleneck_function(*self.spatial_parallel_args, self.explicit_nhwc, self.stride, w_scale, w_bias, x, *self.w_conv)
+            out = spatial_bottleneck_function(*self.spatial_parallel_args, self.explicit_nhwc, self.stride, w_scale, w_bias, self.thresholdTop, self.thresholdBottom, x, *self.w_conv)
             return out

         if self.explicit_nhwc:
...
...
apex/contrib/csrc/bottleneck/bottleneck.cpp
...
...
@@ -102,6 +102,13 @@ enum {
   AFTERCONV_TENSOR,
   OPTIONAL,
   AFTEROPT_TENSOR,
+  AFTERACT_TENSOR,
+  GEN_INDEX_TENSOR,
+  MASK_TOP_TENSOR,
+  MASK_BOTTOM_TENSOR,
+  MASK_TENSOR,
+  THRESHOLD_TOP_TENSOR,
+  THRESHOLD_BOTTOM_TENSOR,
 };

 using common_conv_descriptors =
...
...
@@ -173,11 +180,11 @@ using common_convbias_descriptors = std::tuple<cudnn_frontend::Tensor,
 common_convbias_descriptors
 create_conv_bias_add_act_descriptors(int64_t* x_dim_padded,
                                      int64_t* padA,
                                      int64_t* convstrideA,
                                      int64_t* dilationA,
                                      int64_t* w_dim_padded,
                                      int64_t* y_dim_padded,
                                      cudnnDataType_t dataType) {
   const int convDim = 2;

   int64_t b_dim_padded[4];
...
...
@@ -190,6 +197,7 @@ create_conv_bias_add_act_descriptors(int64_t* x_dim_padded,
   int64_t y_stride_padded[4];
   int64_t w_stride_padded[4];
   int64_t b_stride_padded[4];
+  int64_t threshold_stride[4];

   generateStrides(w_dim_padded, w_stride_padded, 4, CUDNN_TENSOR_NHWC);
   generateStrides(x_dim_padded, x_stride_padded, 4, CUDNN_TENSOR_NHWC);
...
...
@@ -272,6 +280,183 @@ create_conv_bias_add_act_descriptors(int64_t* x_dim_padded,
       .build());
 }

+using masked_convbias_descriptors = std::tuple<cudnn_frontend::Tensor, cudnn_frontend::Tensor, cudnn_frontend::Tensor, cudnn_frontend::Tensor,
+                                               cudnn_frontend::Tensor, cudnn_frontend::Tensor, cudnn_frontend::Tensor, cudnn_frontend::Tensor,
+                                               cudnn_frontend::Tensor, cudnn_frontend::Tensor, cudnn_frontend::Tensor, cudnn_frontend::Tensor,
+                                               cudnn_frontend::Tensor, cudnn_frontend::Tensor, cudnn_frontend::Tensor, cudnn_frontend::Tensor,
+                                               cudnn_frontend::Tensor>;
+
+masked_convbias_descriptors
+create_conv_bias_add_act_mask_descriptors(int64_t* x_dim_padded, int64_t* padA, int64_t* convstrideA, int64_t* dilationA,
+                                          int64_t* w_dim_padded, int64_t* y_dim_padded, int64_t* threshold_dim,
+                                          cudnnDataType_t dataType) {
+  const int convDim = 2;
+
+  int64_t b_dim_padded[4];
+  b_dim_padded[0] = 1;
+  b_dim_padded[1] = y_dim_padded[1];
+  b_dim_padded[2] = 1;
+  b_dim_padded[3] = 1;
+
+  int64_t x_stride_padded[4];
+  int64_t y_stride_padded[4];
+  int64_t w_stride_padded[4];
+  int64_t b_stride_padded[4];
+  int64_t threshold_stride[4];
+
+  generateStrides(w_dim_padded, w_stride_padded, 4, CUDNN_TENSOR_NHWC);
+  generateStrides(x_dim_padded, x_stride_padded, 4, CUDNN_TENSOR_NHWC);
+  generateStrides(y_dim_padded, y_stride_padded, 4, CUDNN_TENSOR_NHWC);
+  generateStrides(b_dim_padded, b_stride_padded, 4, CUDNN_TENSOR_NHWC);
+  generateStrides(threshold_dim, threshold_stride, 4, CUDNN_TENSOR_NHWC);
+
+  return masked_convbias_descriptors(
+      cudnn_frontend::TensorBuilder().setDim(4, x_dim_padded).setStrides(4, x_stride_padded).setId('x').setAlignment(16).setDataType(dataType).build(),
+      cudnn_frontend::TensorBuilder().setDim(4, y_dim_padded).setStrides(4, y_stride_padded).setId('y').setAlignment(16).setDataType(dataType).build(),
+      cudnn_frontend::TensorBuilder().setDim(4, w_dim_padded).setStrides(4, w_stride_padded).setId('w').setAlignment(16).setDataType(dataType).build(),
+      cudnn_frontend::TensorBuilder().setDim(4, b_dim_padded).setStrides(4, b_stride_padded).setId('z').setAlignment(16).setDataType(dataType).build(),
+      cudnn_frontend::TensorBuilder().setDim(4, b_dim_padded).setStrides(4, b_stride_padded).setId('b').setAlignment(16).setDataType(dataType).build(),
+      cudnn_frontend::TensorBuilder().setDim(4, y_dim_padded).setStrides(4, y_stride_padded).setVirtual().setId('A')  // after add
+          .setAlignment(16).setDataType(CUDNN_DATA_FLOAT).build(),
+      cudnn_frontend::TensorBuilder().setDim(4, y_dim_padded).setStrides(4, y_stride_padded).setVirtual().setId('B')  // after bias
+          .setAlignment(16).setDataType(CUDNN_DATA_FLOAT).build(),
+      cudnn_frontend::TensorBuilder().setDim(4, y_dim_padded).setStrides(4, y_stride_padded).setId('C')  // after conv
+          .setAlignment(16).setVirtual().setDataType(CUDNN_DATA_FLOAT).build(),
+      cudnn_frontend::TensorBuilder().setDim(4, y_dim_padded).setStrides(4, y_stride_padded).setId('i').setAlignment(16).setDataType(dataType).build(),
+      cudnn_frontend::TensorBuilder().setDim(4, y_dim_padded).setStrides(4, y_stride_padded).setId('D')  // after optional add
+          .setAlignment(16).setVirtual().setDataType(CUDNN_DATA_FLOAT).build(),
+      cudnn_frontend::TensorBuilder().setDim(4, y_dim_padded).setStrides(4, y_stride_padded).setId('E')  // after act for masked
+          .setAlignment(16).setVirtual().setDataType(CUDNN_DATA_FLOAT).build(),
+      cudnn_frontend::TensorBuilder().setDim(4, y_dim_padded).setStrides(4, y_stride_padded).setId('I')  // output of the gen index operation
+          .setAlignment(16).setVirtual().setDataType(CUDNN_DATA_INT32).build(),
+      cudnn_frontend::TensorBuilder().setDim(4, y_dim_padded).setStrides(4, y_stride_padded).setId('m')  // top half of the mask created after the less than
+          .setAlignment(16).setVirtual().setDataType(CUDNN_DATA_BOOLEAN).build(),
+      cudnn_frontend::TensorBuilder().setDim(4, y_dim_padded).setStrides(4, y_stride_padded).setId('n')  // bottom half of the mask
+          .setAlignment(16).setVirtual().setDataType(CUDNN_DATA_BOOLEAN).build(),
+      cudnn_frontend::TensorBuilder().setDim(4, y_dim_padded).setStrides(4, y_stride_padded).setId('M')  // OR of the top and bottom masks
+          .setAlignment(16).setVirtual().setDataType(CUDNN_DATA_BOOLEAN).build(),
+      cudnn_frontend::TensorBuilder().setDim(4, threshold_dim).setStrides(4, threshold_stride).setId('t')  // threshold for creating the top mask
+          .setAlignment(16).setDataType(CUDNN_DATA_INT32).build(),
+      cudnn_frontend::TensorBuilder().setDim(4, threshold_dim).setStrides(4, threshold_stride).setId('u')  // threshold for creating the bottom mask
+          .setAlignment(16).setDataType(CUDNN_DATA_INT32).build());
+}
+
 // tensor descriptors used for dgrad
 enum {
   X_OR_DX_TENSOR,
...
...
@@ -593,7 +778,7 @@ run_conv_scale_bias_add_activation(int64_t* x_dim_padded,
     auto opGraph = cudnn_frontend::OperationGraphBuilder()
                        .setHandle(handle_)
-                       .setOperationGraph(devPtrI ? ops.size() : 4, ops.data())
+                       .setOperationGraph(devPtrI ? ops.size() : ops.size() - 1, ops.data())
                        .build();

     // Create string encoding for plan caching
...
...
@@ -627,6 +812,458 @@ run_conv_scale_bias_add_activation(int64_t* x_dim_padded,
   }
 }

+void run_conv_add_scale_bias_activation(int64_t* x_dim_padded, int64_t* pad, int64_t* convstride, int64_t* dilation,
+                                        int64_t* w_dim_padded, int64_t* y_dim_padded, cudnnDataType_t dataType,
+                                        at::Half* devPtrX, at::Half* devPtrW, at::Half* devPtrY,
+                                        at::Half* devPtrZ, at::Half* devPtrB, at::Half* devPtrI) {
+  cudnnHandle_t handle_ = torch::native::getCudnnHandle();
+  std::stringstream log_buf;
+  try {
+    int convDim = 2;
+
+    // Creates the necessary tensor descriptors
+    common_convbias_descriptors tensors = create_conv_bias_add_act_descriptors(
+        x_dim_padded, pad, convstride, dilation, w_dim_padded, y_dim_padded, dataType);
+    DEBUG_CUDNN_MSG(log_buf, std::get<X_TENSOR>(tensors).describe());
+    DEBUG_CUDNN_MSG(log_buf, std::get<Y_TENSOR>(tensors).describe());
+    DEBUG_CUDNN_MSG(log_buf, std::get<W_TENSOR>(tensors).describe());
+    DEBUG_CUDNN_MSG(log_buf, std::get<Z_TENSOR>(tensors).describe());
+    DEBUG_CUDNN_MSG(log_buf, std::get<B_TENSOR>(tensors).describe());
+    DEBUG_CUDNN_MSG(log_buf, std::get<AFTERADD_TENSOR>(tensors).describe());
+    DEBUG_CUDNN_MSG(log_buf, std::get<AFTERBIAS_TENSOR>(tensors).describe());
+    DEBUG_CUDNN_MSG(log_buf, std::get<AFTERCONV_TENSOR>(tensors).describe());
+    DEBUG_CUDNN_MSG(log_buf, std::get<OPTIONAL>(tensors).describe());
+    DEBUG_CUDNN_MSG(log_buf, std::get<AFTEROPT_TENSOR>(tensors).describe());
+
+    // Define the add operation
+    auto scaleDesc = cudnn_frontend::PointWiseDescBuilder().setMode(CUDNN_POINTWISE_MUL).setMathPrecision(CUDNN_DATA_FLOAT).build();
+    DEBUG_CUDNN_MSG(log_buf, scaleDesc.describe());
+
+    // Define the bias operation
+    auto biasDesc = cudnn_frontend::PointWiseDescBuilder().setMode(CUDNN_POINTWISE_ADD).setMathPrecision(CUDNN_DATA_FLOAT).build();
+    DEBUG_CUDNN_MSG(log_buf, biasDesc.describe());
+
+    // optional add
+    auto addDesc = cudnn_frontend::PointWiseDescBuilder().setMode(CUDNN_POINTWISE_ADD).setMathPrecision(CUDNN_DATA_FLOAT).build();
+    DEBUG_CUDNN_MSG(log_buf, addDesc.describe());
+
+    // Define the activation operation
+    auto actDesc = cudnn_frontend::PointWiseDescBuilder().setMode(CUDNN_POINTWISE_RELU_FWD).setMathPrecision(CUDNN_DATA_FLOAT).build();
+    DEBUG_CUDNN_MSG(log_buf, actDesc.describe());
+
+    // Define the convolution problem
+    auto convDesc = cudnn_frontend::ConvDescBuilder()
+                        .setDataType(CUDNN_DATA_FLOAT).setMathMode(CUDNN_CROSS_CORRELATION).setNDims(convDim)
+                        .setStrides(convDim, convstride).setPrePadding(convDim, pad).setPostPadding(convDim, pad)
+                        .setDilation(convDim, dilation).build();
+    DEBUG_CUDNN_MSG(log_buf, convDesc.describe());
+
+    float alpha = 1.0f;
+    float beta = 0.0f;
+
+    // Create a convolution Node
+    auto conv_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR)
+                       .setxDesc(std::get<X_TENSOR>(tensors)).setwDesc(std::get<W_TENSOR>(tensors))
+                       .setyDesc(std::get<AFTERCONV_TENSOR>(tensors)).setcDesc(convDesc).setAlpha(alpha).setBeta(beta).build();
+    DEBUG_CUDNN_MSG(log_buf, conv_op.describe());
+
+    // create an add node.
+    auto add_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
+                      .setxDesc(conv_op.getOutputTensor()).setbDesc(std::get<OPTIONAL>(tensors))
+                      .setyDesc(std::get<AFTEROPT_TENSOR>(tensors)).setpwDesc(addDesc).build();
+    DEBUG_CUDNN_MSG(log_buf, add_op.describe());
+
+    // Create a Add Node with scaling parameters.
+    auto scale_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
+                        .setxDesc(add_op.getOutputTensor()).setbDesc(std::get<Z_TENSOR>(tensors))
+                        .setyDesc(std::get<AFTERADD_TENSOR>(tensors)).setpwDesc(scaleDesc).build();
+    DEBUG_CUDNN_MSG(log_buf, scale_op.describe());
+
+    // Create a Bias Node.
+    auto bias_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
+                       .setxDesc(scale_op.getOutputTensor()).setbDesc(std::get<B_TENSOR>(tensors))
+                       .setyDesc(std::get<AFTERBIAS_TENSOR>(tensors)).setpwDesc(biasDesc).build();
+    DEBUG_CUDNN_MSG(log_buf, bias_op.describe());
+
+    // Create an Activation Node.
+    auto act_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
+                      .setxDesc(bias_op.getOutputTensor()).setyDesc(std::get<Y_TENSOR>(tensors)).setpwDesc(actDesc).build();
+    DEBUG_CUDNN_MSG(log_buf, act_op.describe());
+
+    // Create an Operation Graph. In this case it is convolution add bias activation
+    std::array<cudnn_frontend::Operation const*, 5> ops = {&conv_op, &add_op, &scale_op, &bias_op, &act_op};
+
+    auto opGraph = cudnn_frontend::OperationGraphBuilder().setHandle(handle_).setOperationGraph(ops.size(), ops.data()).build();
+
+    // Create string encoding for plan caching
+    auto cache_string = getConvFusionString(x_dim_padded, pad, convstride, dilation, w_dim_padded, dataType, opGraph.getTag());
+    DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string);
+
+    auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string);
+    DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag());
+
+    auto workspace_size = plan.getWorkspaceSize();
+    DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size);
+
+    void* workspace_ptr = nullptr;
+    auto workspace_tensor = at::empty({(workspace_size + 3) / 4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat));
+    if (workspace_size > 0) {
+      workspace_ptr = workspace_tensor.data_ptr<float>();
+    }
+    void* data_ptrs[] = {devPtrX, devPtrY, devPtrW, devPtrZ, devPtrB, devPtrI};
+    int64_t uids[]    = {'x', 'y', 'w', 'z', 'b', 'i'};
+    auto variantPack = cudnn_frontend::VariantPackBuilder()
+                           .setWorkspacePointer(workspace_ptr).setDataPointers(6, data_ptrs).setUids(6, uids).build();
+    DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe());
+    cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc());
+    checkCudnnErr(status);
+    cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status);
+  } catch (cudnn_frontend::cudnnException e) {
+    std::cout << log_buf.str() << "[ERROR] Exception " << e.what() << std::endl;
+  }
+}
+
+void run_conv_scale_bias_add_activation_mask(int64_t* x_dim_padded, int64_t* pad, int64_t* convstride, int64_t* dilation,
+                                             int64_t* w_dim_padded, int64_t* y_dim_padded, int64_t* threshold_dim,
+                                             cudnnDataType_t dataType,
+                                             at::Half* devPtrX, at::Half* devPtrW, at::Half* devPtrY,
+                                             at::Half* devPtrZ, at::Half* devPtrB, at::Half* devPtrI,
+                                             int* devPtrT, int* devPtrU, int axis) {
+  cudnnHandle_t handle_ = torch::native::getCudnnHandle();
+  std::stringstream log_buf;
+  try {
+    int convDim = 2;
+
+    // Creates the necessary tensor descriptors
+    masked_convbias_descriptors tensors = create_conv_bias_add_act_mask_descriptors(
+        x_dim_padded, pad, convstride, dilation, w_dim_padded, y_dim_padded, threshold_dim, dataType);
+    DEBUG_CUDNN_MSG(log_buf, std::get<X_TENSOR>(tensors).describe());
+    DEBUG_CUDNN_MSG(log_buf, std::get<Y_TENSOR>(tensors).describe());
+    DEBUG_CUDNN_MSG(log_buf, std::get<W_TENSOR>(tensors).describe());
+    DEBUG_CUDNN_MSG(log_buf, std::get<Z_TENSOR>(tensors).describe());
+    DEBUG_CUDNN_MSG(log_buf, std::get<B_TENSOR>(tensors).describe());
+    DEBUG_CUDNN_MSG(log_buf, std::get<AFTERADD_TENSOR>(tensors).describe());
+    DEBUG_CUDNN_MSG(log_buf, std::get<AFTERBIAS_TENSOR>(tensors).describe());
+    DEBUG_CUDNN_MSG(log_buf, std::get<AFTERCONV_TENSOR>(tensors).describe());
+    DEBUG_CUDNN_MSG(log_buf, std::get<OPTIONAL>(tensors).describe());
+    DEBUG_CUDNN_MSG(log_buf, std::get<AFTERACT_TENSOR>(tensors).describe());
+    DEBUG_CUDNN_MSG(log_buf, std::get<GEN_INDEX_TENSOR>(tensors).describe());
+    DEBUG_CUDNN_MSG(log_buf, std::get<MASK_TOP_TENSOR>(tensors).describe());
+    DEBUG_CUDNN_MSG(log_buf, std::get<MASK_BOTTOM_TENSOR>(tensors).describe());
+    DEBUG_CUDNN_MSG(log_buf, std::get<MASK_TENSOR>(tensors).describe());
+    DEBUG_CUDNN_MSG(log_buf, std::get<THRESHOLD_TOP_TENSOR>(tensors).describe());
+    DEBUG_CUDNN_MSG(log_buf, std::get<THRESHOLD_BOTTOM_TENSOR>(tensors).describe());
+
+    // Define the add operation
+    auto scaleDesc = cudnn_frontend::PointWiseDescBuilder().setMode(CUDNN_POINTWISE_MUL).setMathPrecision(CUDNN_DATA_FLOAT).build();
+    DEBUG_CUDNN_MSG(log_buf, scaleDesc.describe());
+
+    // Define the bias operation
+    auto biasDesc = cudnn_frontend::PointWiseDescBuilder().setMode(CUDNN_POINTWISE_ADD).setMathPrecision(CUDNN_DATA_FLOAT).build();
+    DEBUG_CUDNN_MSG(log_buf, biasDesc.describe());
+
+    // optional add
+    auto addDesc = cudnn_frontend::PointWiseDescBuilder().setMode(CUDNN_POINTWISE_ADD).setMathPrecision(CUDNN_DATA_FLOAT).build();
+    DEBUG_CUDNN_MSG(log_buf, addDesc.describe());
+
+    // Define the activation operation
+    auto actDesc = cudnn_frontend::PointWiseDescBuilder().setMode(CUDNN_POINTWISE_RELU_FWD).setMathPrecision(CUDNN_DATA_FLOAT).build();
+    DEBUG_CUDNN_MSG(log_buf, actDesc.describe());
+
+    // Define the convolution problem
+    auto convDesc = cudnn_frontend::ConvDescBuilder()
+                        .setDataType(CUDNN_DATA_FLOAT).setMathMode(CUDNN_CROSS_CORRELATION).setNDims(convDim)
+                        .setStrides(convDim, convstride).setPrePadding(convDim, pad).setPostPadding(convDim, pad)
+                        .setDilation(convDim, dilation).build();
+    DEBUG_CUDNN_MSG(log_buf, convDesc.describe());
+
+    // Define the genIndex descriptor
+    auto genIndexDesc = cudnn_frontend::PointWiseDescBuilder()
+                            .setMode(CUDNN_POINTWISE_GEN_INDEX).setMathPrecision(CUDNN_DATA_FLOAT).setAxis(axis).build();
+    DEBUG_CUDNN_MSG(log_buf, genIndexDesc.describe());
+
+    // Define the lessThan descriptor
+    auto lessThanDesc = cudnn_frontend::PointWiseDescBuilder().setMode(CUDNN_POINTWISE_CMP_LT).setMathPrecision(CUDNN_DATA_FLOAT).build();
+    DEBUG_CUDNN_MSG(log_buf, lessThanDesc.describe());
+
+    // Define the greaterThan descriptor
+    auto greaterThanDesc = cudnn_frontend::PointWiseDescBuilder().setMode(CUDNN_POINTWISE_CMP_GT).setMathPrecision(CUDNN_DATA_FLOAT).build();
+    DEBUG_CUDNN_MSG(log_buf, greaterThanDesc.describe());
+
+    // Define the logical_or descriptor
+    auto logicalOrDesc = cudnn_frontend::PointWiseDescBuilder().setMode(CUDNN_POINTWISE_LOGICAL_OR).setMathPrecision(CUDNN_DATA_BOOLEAN).build();
+    DEBUG_CUDNN_MSG(log_buf, logicalOrDesc.describe());
+
+    // Define the binary_selection descriptor
+    auto selectionDesc = cudnn_frontend::PointWiseDescBuilder().setMode(CUDNN_POINTWISE_BINARY_SELECT).setMathPrecision(CUDNN_DATA_FLOAT).build();
+    DEBUG_CUDNN_MSG(log_buf, selectionDesc.describe());
+
+    float alpha = 1.0f;
+    float beta = 0.0f;
+
+    // Create a convolution Node
+    auto conv_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR)
+                       .setxDesc(std::get<X_TENSOR>(tensors)).setwDesc(std::get<W_TENSOR>(tensors))
+                       .setyDesc(std::get<AFTERCONV_TENSOR>(tensors)).setcDesc(convDesc).setAlpha(alpha).setBeta(beta).build();
+    DEBUG_CUDNN_MSG(log_buf, conv_op.describe());
+
+    // Create a Add Node with scaling parameters.
+    auto scale_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
+                        .setxDesc(conv_op.getOutputTensor()).setbDesc(std::get<Z_TENSOR>(tensors))
+                        .setyDesc(std::get<AFTERADD_TENSOR>(tensors)).setpwDesc(scaleDesc).build();
+    DEBUG_CUDNN_MSG(log_buf, scale_op.describe());
+
+    // Create a Bias Node.
+    auto bias_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
+                       .setxDesc(scale_op.getOutputTensor()).setbDesc(std::get<B_TENSOR>(tensors))
+                       .setyDesc(std::get<AFTERBIAS_TENSOR>(tensors)).setpwDesc(biasDesc).build();
+    DEBUG_CUDNN_MSG(log_buf, bias_op.describe());
+
+    // Create a optional add Node.
+    auto add_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
+                      .setxDesc(bias_op.getOutputTensor()).setbDesc(std::get<OPTIONAL>(tensors))
+                      .setyDesc(std::get<AFTEROPT_TENSOR>(tensors)).setpwDesc(addDesc).build();
+    DEBUG_CUDNN_MSG(log_buf, add_op.describe());
+
+    // Create an Activation Node.
+    auto act_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
+                      .setxDesc(devPtrI ? add_op.getOutputTensor() : bias_op.getOutputTensor())
+                      .setyDesc(std::get<AFTERACT_TENSOR>(tensors)).setpwDesc(actDesc).build();
+    DEBUG_CUDNN_MSG(log_buf, act_op.describe());
+
+    // Create a Gen_Index Node.
+    auto genIndex_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
+                           .setxDesc(std::get<AFTERACT_TENSOR>(tensors)).setyDesc(std::get<GEN_INDEX_TENSOR>(tensors))
+                           .setpwDesc(genIndexDesc).build();
+    DEBUG_CUDNN_MSG(log_buf, genIndex_op.describe());
+
+    // Create a LessThan Node.
+    auto lessThan_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
+                           .setxDesc(std::get<GEN_INDEX_TENSOR>(tensors)).setbDesc(std::get<THRESHOLD_TOP_TENSOR>(tensors))
+                           .setyDesc(std::get<MASK_TOP_TENSOR>(tensors)).setpwDesc(lessThanDesc).build();
+    DEBUG_CUDNN_MSG(log_buf, lessThan_op.describe());
+
+    // Create a GreaterThan Node.
+    auto greaterThan_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
+                              .setxDesc(std::get<GEN_INDEX_TENSOR>(tensors)).setbDesc(std::get<THRESHOLD_BOTTOM_TENSOR>(tensors))
+                              .setyDesc(std::get<MASK_BOTTOM_TENSOR>(tensors)).setpwDesc(greaterThanDesc).build();
+    DEBUG_CUDNN_MSG(log_buf, greaterThan_op.describe());
+
+    // Create a LogicalOr Node.
+    auto logicalOr_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
+                            .setxDesc(std::get<MASK_TOP_TENSOR>(tensors)).setbDesc(std::get<MASK_BOTTOM_TENSOR>(tensors))
+                            .setyDesc(std::get<MASK_TENSOR>(tensors)).setpwDesc(logicalOrDesc).build();
+    DEBUG_CUDNN_MSG(log_buf, logicalOr_op.describe());
+
+    // Create a Binary_Selection Node.
+    auto selection_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
+                            .setxDesc(std::get<AFTERCONV_TENSOR>(tensors)).setbDesc(std::get<AFTERACT_TENSOR>(tensors))
+                            .settDesc(std::get<MASK_TENSOR>(tensors)).setyDesc(std::get<Y_TENSOR>(tensors))
+                            .setpwDesc(selectionDesc).build();
+    DEBUG_CUDNN_MSG(log_buf, selection_op.describe());
+
+    // Create an Operation Graph. In this case it is convolution add bias activation
+    if (devPtrI) {
+      std::array<cudnn_frontend::Operation const*, 10> ops = {&conv_op, &scale_op, &bias_op, &add_op, &act_op,
+                                                              &genIndex_op, &lessThan_op, &greaterThan_op, &logicalOr_op, &selection_op};
+
+      auto opGraph = cudnn_frontend::OperationGraphBuilder().setHandle(handle_).setOperationGraph(ops.size(), ops.data()).build();
+
+      // Create string encoding for plan caching
+      auto cache_string = getConvFusionString(x_dim_padded, pad, convstride, dilation, w_dim_padded, dataType, opGraph.getTag());
+      DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string);
+
+      auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string);
+      DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag());
+
+      auto workspace_size = plan.getWorkspaceSize();
+      DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size);
+
+      void* workspace_ptr = nullptr;
+      auto workspace_tensor = at::empty({(workspace_size + 3) / 4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat));
+      if (workspace_size > 0) {
+        workspace_ptr = workspace_tensor.data_ptr<float>();
+      }
+      void* data_ptrs[] = {devPtrX, devPtrY, devPtrW, devPtrZ, devPtrB, devPtrI, devPtrT, devPtrU};
+      int64_t uids[]    = {'x', 'y', 'w', 'z', 'b', 'i', 't', 'u'};
+      auto variantPack = cudnn_frontend::VariantPackBuilder()
+                             .setWorkspacePointer(workspace_ptr).setDataPointers(8, data_ptrs).setUids(8, uids).build();
+      DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe());
+      cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc());
+      checkCudnnErr(status);
+      cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status);
+    } else {
+      std::array<cudnn_frontend::Operation const*, 9> ops = {&conv_op, &scale_op, &bias_op, &act_op,
+                                                             &genIndex_op, &lessThan_op, &greaterThan_op, &logicalOr_op, &selection_op};
+
+      auto opGraph = cudnn_frontend::OperationGraphBuilder().setHandle(handle_).setOperationGraph(ops.size(), ops.data()).build();
+
+      // Create string encoding for plan caching
+      auto cache_string = getConvFusionString(x_dim_padded, pad, convstride, dilation, w_dim_padded, dataType, opGraph.getTag());
+      DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string);
+
+      auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string);
+      DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag());
+
+      auto workspace_size = plan.getWorkspaceSize();
+      DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size);
+
+      void* workspace_ptr = nullptr;
+      auto workspace_tensor = at::empty({(workspace_size + 3) / 4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat));
+      if (workspace_size > 0) {
+        workspace_ptr = workspace_tensor.data_ptr<float>();
+      }
+      void* data_ptrs[] = {devPtrX, devPtrY, devPtrW, devPtrZ, devPtrB, devPtrT, devPtrU};
+      int64_t uids[]    = {'x', 'y', 'w', 'z', 'b', 't', 'u'};
+      auto variantPack = cudnn_frontend::VariantPackBuilder()
+                             .setWorkspacePointer(workspace_ptr).setDataPointers(7, data_ptrs).setUids(7, uids).build();
+      DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe());
+      cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc());
+      checkCudnnErr(status);
+      cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status);
+    }
+  } catch (cudnn_frontend::cudnnException e) {
+    std::cout << log_buf.str() << "[ERROR] Exception " << e.what() << std::endl;
+  }
+}
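As I read the graph above, the gen-index / compare / logical-or / binary-select chain flags the rows outside [thresholdTop, thresholdBottom] and stores the raw convolution output there, while all other rows get the fully fused conv + scale + bias + ReLU result. A rough PyTorch equivalent follows (illustrative only, not the cudnn code; it ignores the optional add input and assumes the select keeps the raw conv value on masked rows):

# Hedged PyTorch sketch of what the fused masked kernel is intended to compute.
import torch
import torch.nn.functional as F

def masked_conv_scale_bias_relu(x, w, scale, bias, threshold_top, threshold_bottom):
    conv = F.conv2d(x, w, padding=1)                      # 'C' in the graph (raw conv output)
    fused = F.relu(conv * scale + bias)                   # 'E' (per-channel scale/bias, then ReLU)
    h_index = torch.arange(conv.shape[2], device=x.device).view(1, 1, -1, 1)   # gen_index along H
    mask = (h_index < threshold_top) | (h_index > threshold_bottom)            # 'm' | 'n' -> 'M'
    return torch.where(mask, conv, fused)                 # binary select -> 'y'

x = torch.randn(1, 4, 8, 8)
w = torch.randn(4, 4, 3, 3)
scale = torch.randn(1, 4, 1, 1)
bias = torch.randn(1, 4, 1, 1)
y = masked_conv_scale_bias_relu(x, w, scale, bias, threshold_top=1, threshold_bottom=6)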
 void run_conv_scale_bias(int64_t* x_dim_padded,
                          int64_t* pad,
...
...
@@ -1613,9 +2250,12 @@ struct bottleneck_forward_status {
   int64_t dimA[4];
   int64_t filterdimA1[4];
   int64_t filterdimA2[4];
+  int64_t filterdimA2hh[4];
   int64_t filterdimA3[4];
   int64_t filterdimA4[4];
+  int64_t threshdim[4];

   int axis[4];

   int64_t outdimA0[4];
...
...
@@ -1643,8 +2283,10 @@ struct bottleneck_forward_status {
     dimA[0] = dimA[1] = dimA[2] = dimA[3] = 0;
     filterdimA1[0] = filterdimA1[1] = filterdimA1[2] = filterdimA1[3] = 0;
     filterdimA2[0] = filterdimA2[1] = filterdimA2[2] = filterdimA2[3] = 0;
+    filterdimA2hh[0] = filterdimA2hh[1] = filterdimA2hh[2] = filterdimA2hh[3] = 0;
     filterdimA3[0] = filterdimA3[1] = filterdimA3[2] = filterdimA3[3] = 0;
     filterdimA4[0] = filterdimA4[1] = filterdimA4[2] = filterdimA4[3] = 0;
+    threshdim[0] = threshdim[1] = threshdim[2] = threshdim[3] = 1;

     // All dim calculation after this order of n,c,h,w
     if (explicit_nhwc) {
...
...
@@ -1670,6 +2312,13 @@ struct bottleneck_forward_status {
         filterdimA4[dim] = inputs[10].size(axis[dim]);
       }
     }
+    for (int dim = 0; dim < 4; dim++) {
+      if (dim == 2) {
+        filterdimA2hh[dim] = 1;
+      } else {
+        filterdimA2hh[dim] = filterdimA2[dim];
+      }
+    }

     // output dim in n,c,h,w used by backend
     outdimA0[0] = outdimA0[1] = outdimA0[2] = outdimA0[3] = 0;
...
...
@@ -1833,6 +2482,41 @@ at::Tensor bottleneck_forward_out2_halo(bool explicit_nhwc, at::Tensor fat_halo_
   return halo_y2;
 }

+// compute halo correction term (top or bottom) from slim halo input (N,C,1,W).
+// slim halo input is 1 pixel wide in H.
+at::Tensor bottleneck_forward_out2_halo_corr(bool explicit_nhwc, at::Tensor slim_halo_y1, std::vector<at::Tensor> inputs, at::Tensor w1by3, at::Tensor out2_part_halo) {
+  auto output_format = explicit_nhwc ? at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast;
+
+  // run
+  at::Half* w = w1by3.data_ptr<at::Half>();  // C,C,1,3
+  at::Half* z = inputs[5].data_ptr<at::Half>();
+  at::Half* b = inputs[8].data_ptr<at::Half>();
+  at::Half* y1 = slim_halo_y1.data_ptr<at::Half>();
+  at::Half* prev_out2 = out2_part_halo.data_ptr<at::Half>();
+  auto halo_y2 = at::empty(forward_state.outdim4, inputs[0].type(), output_format);
+  at::Half* y2 = halo_y2.data_ptr<at::Half>();
+
+  run_conv_add_scale_bias_activation(forward_state.outdimA4,
+                                     forward_state.padA2,
+                                     forward_state.convstrideA,
+                                     forward_state.dilationA,
+                                     forward_state.filterdimA2hh,
+                                     forward_state.outdimA4,
+                                     CUDNN_DATA_HALF,
+                                     y1, w, y2, z, b, prev_out2);
+
+  return halo_y2;
+}
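bottleneck_forward_out2_halo_corr therefore finishes a boundary row in one fused call: convolve the neighbour's slim halo with the matching one-row weight slice, add the partial sum already sitting in out2, then apply scale, bias and ReLU. A hedged, torch-level restatement of that arithmetic (names are mine, not the extension's; NCHW shapes for readability):

# Hedged sketch of the per-row halo correction, not the extension code.
import torch
import torch.nn.functional as F

def halo_corr_row(raw_row, halo_row, w1by3, scale, bias):
    corr = F.conv2d(halo_row, w1by3, padding=(0, 1))     # conv of the slim (1-row) halo
    return F.relu((raw_row + corr) * scale + bias)        # add partial sum, then scale/bias/ReLU

C, W = 4, 8
y = halo_corr_row(torch.randn(1, C, 1, W), torch.randn(1, C, 1, W),
                  torch.randn(C, C, 1, 3), torch.rand(1, C, 1, 1), torch.randn(1, C, 1, 1))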
 void bottleneck_forward_out2(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs, std::vector<at::Tensor> outputs) {
   std::cout << std::fixed;
...
...
@@ -1871,6 +2555,48 @@ void bottleneck_forward_out2(bool explicit_nhwc, int stride_1X1, std::vector<at:
   DEBUG_MSG("[DEBUG] new relu2 : " << out2.to(at::kFloat).sum().item<float>());
 }

+void bottleneck_forward_out2_mask(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs, std::vector<at::Tensor> outputs, at::Tensor thresholdTop, at::Tensor thresholdBottom) {
+  std::cout << std::fixed;
+
+  // from _out1 method
+  at::Half* x = inputs[0].data_ptr<at::Half>();
+  auto out1 = outputs[0];
+  at::Half* y1 = out1.data_ptr<at::Half>();
+
+  // run
+  at::Half* w = inputs[2].data_ptr<at::Half>();
+  at::Half* z = inputs[5].data_ptr<at::Half>();
+  at::Half* b = inputs[8].data_ptr<at::Half>();
+  auto out2 = outputs[1];
+  at::Half* y2 = out2.data_ptr<at::Half>();
+
+  //printf("forward_state.outdimA1 = {%d,%d,%d,%d}\n",forward_state.outdimA1[0],forward_state.outdimA1[1],forward_state.outdimA1[2],forward_state.outdimA1[3]);
+  //printf("forward_state.padA1 = {%d,%d}\n",forward_state.padA1[0],forward_state.padA1[1]);
+  //printf("forward_state.convstrideA = {%d,%d}\n",forward_state.convstrideA[0],forward_state.convstrideA[1]);
+  //printf("forward_state.dilationA = {%d,%d}\n",forward_state.dilationA[0],forward_state.dilationA[1]);
+  //printf("forward_state.filterdimA2 = {%d,%d,%d,%d}\n",forward_state.filterdimA2[0],forward_state.filterdimA2[1],forward_state.filterdimA2[2],forward_state.filterdimA2[3]);
+  //printf("forward_state.outdimA2 = {%d,%d,%d,%d}\n",forward_state.outdimA2[0],forward_state.outdimA2[1],forward_state.outdimA2[2],forward_state.outdimA2[3]);
+  run_conv_scale_bias_add_activation_mask(forward_state.outdimA1,
+                                          forward_state.padA1,
+                                          forward_state.convstrideA,
+                                          forward_state.dilationA,
+                                          forward_state.filterdimA2,
+                                          forward_state.outdimA2,
+                                          forward_state.threshdim,
+                                          CUDNN_DATA_HALF,
+                                          y1, w, y2, z, b,
+                                          nullptr,
+                                          thresholdTop.data_ptr<int>(),
+                                          thresholdBottom.data_ptr<int>(),
+                                          2);  // axis == 1 -> Does this assume explicit NHWC?
+
+  DEBUG_MSG("[DEBUG] new relu2 : " << out2.to(at::kFloat).sum().item<float>());
+}
+
 void bottleneck_forward_out2_pad(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs, std::vector<at::Tensor> outputs, at::Tensor out1_pad) {
   std::cout << std::fixed;
...
...
@@ -2569,7 +3295,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("forward_init", &bottleneck_forward_init, "Bottleneck block init");
   m.def("forward_out1", &bottleneck_forward_out1, "Bottleneck block forward");
   m.def("forward_out2", &bottleneck_forward_out2, "Bottleneck block forward");
+  m.def("forward_out2_mask", &bottleneck_forward_out2_mask, "Bottleneck block forward");
   m.def("forward_out2_halo", &bottleneck_forward_out2_halo, "Bottleneck block forward");
+  m.def("forward_out2_halo_corr", &bottleneck_forward_out2_halo_corr, "Bottleneck block forward");
   m.def("forward_out2_pad", &bottleneck_forward_out2_pad, "Bottleneck block forward");
   m.def("forward_rest", &bottleneck_forward_rest, "Bottleneck block forward");
   m.def("backward_init", &bottleneck_backward_init, "Bottleneck block backward init");
...
...