Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
apex
Commits
60000f73
"src/turbomind/utils/vscode:/vscode.git/clone" did not exist on "1f88baa5b7a9dde22b11200fd530fe1059e1facb"
Commit
60000f73
authored
Mar 31, 2022
by
Thor Johnsen
Browse files
Add halo correction using new cudnn masking feature
parent
9c16d945
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
790 additions
and
34 deletions
+790
-34
apex/contrib/bottleneck/bottleneck.py
apex/contrib/bottleneck/bottleneck.py
+56
-28
apex/contrib/csrc/bottleneck/bottleneck.cpp
apex/contrib/csrc/bottleneck/bottleneck.cpp
+734
-6
No files found.
apex/contrib/bottleneck/bottleneck.py
View file @
60000f73
...
@@ -220,7 +220,7 @@ class Bottleneck(torch.nn.Module):
...
@@ -220,7 +220,7 @@ class Bottleneck(torch.nn.Module):
class
SpatialBottleneckFunction
(
torch
.
autograd
.
Function
):
class
SpatialBottleneckFunction
(
torch
.
autograd
.
Function
):
@
staticmethod
@
staticmethod
def
forward
(
ctx
,
spatial_group_size
,
spatial_group_rank
,
spatial_communicator
,
spatial_halo_exchanger
,
spatial_method
,
explicit_nhwc
,
stride_1x1
,
scale
,
bias
,
x
,
*
conv
):
def
forward
(
ctx
,
spatial_group_size
,
spatial_group_rank
,
spatial_communicator
,
spatial_halo_exchanger
,
spatial_method
,
explicit_nhwc
,
stride_1x1
,
scale
,
bias
,
thresholdTop
,
thresholdBottom
,
x
,
*
conv
):
if
spatial_group_size
>
1
:
if
spatial_group_size
>
1
:
stream1
=
spatial_halo_exchanger
.
stream1
stream1
=
spatial_halo_exchanger
.
stream1
stream2
=
spatial_halo_exchanger
.
stream2
stream2
=
spatial_halo_exchanger
.
stream2
...
@@ -271,57 +271,75 @@ class SpatialBottleneckFunction(torch.autograd.Function):
...
@@ -271,57 +271,75 @@ class SpatialBottleneckFunction(torch.autograd.Function):
if
spatial_group_rank
<
spatial_group_size
-
1
:
if
spatial_group_rank
<
spatial_group_size
-
1
:
stream2
.
wait_stream
(
stream1
)
stream2
.
wait_stream
(
stream1
)
with
torch
.
cuda
.
stream
(
stream2
):
with
torch
.
cuda
.
stream
(
stream2
):
btm_fat_halo
=
torch
.
empty
((
N
,
3
,
W
,
C
),
dtype
=
out1
.
dtype
,
device
=
out1
.
device
)
if
explicit_nhwc
:
if
explicit_nhwc
:
btm_fat_halo
=
torch
.
empty
((
N
,
3
,
W
,
C
),
dtype
=
out1
.
dtype
,
device
=
out1
.
device
)
btm_fat_halo
[:,
0
:
2
,:,:].
copy_
(
out1
[:,
Hs
-
2
:,:,:])
btm_fat_halo
[:,
0
:
2
,:,:].
copy_
(
out1
[:,
Hs
-
2
:,:,:])
btm_fat_halo
[:,
2
:,:,:].
copy_
(
btm_out1_halo
)
btm_fat_halo
[:,
2
:,:,:].
copy_
(
btm_out1_halo
)
else
:
else
:
btm_fat_halo
=
torch
.
empty
((
N
,
C
,
3
,
W
),
dtype
=
out1
.
dtype
,
device
=
out1
.
device
)
btm_fat_halo
[:,:,
0
:
2
,:].
copy_
(
out1
[:,:,
Hs
-
2
:,:])
btm_fat_halo
[:,:,
0
:
2
,:].
copy_
(
out1
[:,:,
Hs
-
2
:,:])
btm_fat_halo
[:,:,
2
:,:].
copy_
(
btm_out1_halo
)
btm_fat_halo
[:,:,
2
:,:].
copy_
(
btm_out1_halo
)
btm_out2
=
fast_bottleneck
.
forward_out2_halo
(
explicit_nhwc
,
btm_fat_halo
,
args
)
btm_out2
=
fast_bottleneck
.
forward_out2_halo
(
explicit_nhwc
,
btm_fat_halo
,
args
)
if
spatial_group_rank
>
0
:
if
spatial_group_rank
>
0
:
with
torch
.
cuda
.
stream
(
stream1
):
with
torch
.
cuda
.
stream
(
stream1
):
top_fat_halo
=
torch
.
empty
((
N
,
3
,
W
,
C
),
dtype
=
out1
.
dtype
,
device
=
out1
.
device
)
if
explicit_nhwc
:
if
explicit_nhwc
:
top_fat_halo
=
torch
.
empty
((
N
,
3
,
W
,
C
),
dtype
=
out1
.
dtype
,
device
=
out1
.
device
)
top_fat_halo
[:,:
1
,:,:].
copy_
(
top_out1_halo
)
top_fat_halo
[:,:
1
,:,:].
copy_
(
top_out1_halo
)
top_fat_halo
[:,
1
:
3
,:,:].
copy_
(
out1
[:,:
2
,:,:])
top_fat_halo
[:,
1
:
3
,:,:].
copy_
(
out1
[:,:
2
,:,:])
else
:
else
:
top_fat_halo
=
torch
.
empty
((
N
,
C
,
3
,
W
),
dtype
=
out1
.
dtype
,
device
=
out1
.
device
)
top_fat_halo
[:,:,:
1
,:].
copy_
(
top_out1_halo
)
top_fat_halo
[:,:,:
1
,:].
copy_
(
top_out1_halo
)
top_fat_halo
[:,:,
1
:
3
,:].
copy_
(
out1
[:,:,:
2
,:])
top_fat_halo
[:,:,
1
:
3
,:].
copy_
(
out1
[:,:,:
2
,:])
top_out2
=
fast_bottleneck
.
forward_out2_halo
(
explicit_nhwc
,
top_fat_halo
,
args
)
top_out2
=
fast_bottleneck
.
forward_out2_halo
(
explicit_nhwc
,
top_fat_halo
,
args
)
inc
.
add_delay
(
10
)
inc
.
add_delay
(
10
)
elif
spatial_method
==
2
:
elif
spatial_method
!=
2
and
spatial_method
!=
3
:
# wait for halo transfer to finish before doing a full convolution of padded x
assert
(
False
),
"spatial_method must be 1, 2 or 3"
torch
.
cuda
.
current_stream
().
wait_stream
(
stream1
)
torch
.
cuda
.
current_stream
().
wait_stream
(
stream3
)
fast_bottleneck
.
forward_out2_pad
(
explicit_nhwc
,
stride_1x1
,
args
,
outputs
,
out1_pad
)
else
:
assert
(
False
),
"spatial_method must be 1 or 2"
if
spatial_group_size
<=
1
or
spatial_method
==
1
:
if
spatial_group_size
<=
1
:
fast_bottleneck
.
forward_out2
(
explicit_nhwc
,
stride_1x1
,
args
,
outputs
)
elif
spatial_method
==
1
:
fast_bottleneck
.
forward_out2
(
explicit_nhwc
,
stride_1x1
,
args
,
outputs
)
fast_bottleneck
.
forward_out2
(
explicit_nhwc
,
stride_1x1
,
args
,
outputs
)
elif
spatial_method
==
2
:
# wait for halo transfer to finish before doing a full convolution of padded x
torch
.
cuda
.
current_stream
().
wait_stream
(
stream1
)
torch
.
cuda
.
current_stream
().
wait_stream
(
stream3
)
fast_bottleneck
.
forward_out2_pad
(
explicit_nhwc
,
stride_1x1
,
args
,
outputs
,
out1_pad
)
elif
spatial_method
==
3
:
fast_bottleneck
.
forward_out2_mask
(
explicit_nhwc
,
stride_1x1
,
args
,
outputs
,
thresholdTop
,
thresholdBottom
)
# compute halo cells for outputs[1] (out2)
# compute halo cells for outputs[1] (out2)
if
spatial_group_size
>
1
and
spatial_method
==
1
:
if
spatial_group_size
>
1
:
out2
=
outputs
[
1
]
out2
=
outputs
[
1
]
if
spatial_group_rank
>
0
:
if
explicit_nhwc
:
torch
.
cuda
.
current_stream
().
wait_stream
(
stream1
)
top_out2_halo
=
out2
[:,:
1
,:,:]
if
explicit_nhwc
:
btm_out2_halo
=
out2
[:,
Hs
-
1
:,:,:]
out2
[:,:
1
,:,:].
copy_
(
top_out2
)
else
:
else
:
top_out2_halo
=
out2
[:,:,:
1
,:]
out2
[:,:,:
1
,:].
copy_
(
top_out2
)
btm_out2_halo
=
out2
[:,:,
Hs
-
1
:,:]
if
spatial_group_rank
<
spatial_group_size
-
1
:
if
spatial_method
==
1
:
torch
.
cuda
.
current_stream
().
wait_stream
(
stream2
)
if
spatial_group_rank
>
0
:
if
explicit_nhwc
:
torch
.
cuda
.
current_stream
().
wait_stream
(
stream1
)
out2
[:,
Hs
-
1
:,:,:].
copy_
(
btm_out2
)
top_out2_halo
.
copy_
(
top_out2
)
else
:
if
spatial_group_rank
<
spatial_group_size
-
1
:
out2
[:,:,
Hs
-
1
:,:].
copy_
(
btm_out2
)
torch
.
cuda
.
current_stream
().
wait_stream
(
stream2
)
torch
.
cuda
.
current_stream
().
wait_stream
(
stream3
)
btm_out2_halo
.
copy_
(
btm_out2
)
elif
spatial_method
==
3
:
if
spatial_group_rank
>
0
:
w1by3
=
args
[
2
][:,:,
2
:
3
,:].
contiguous
(
memory_format
=
torch
.
preserve
)
top_out1_halo
=
top_out1_halo
.
contiguous
(
memory_format
=
memory_format
)
top_out2
=
fast_bottleneck
.
forward_out2_halo_corr
(
explicit_nhwc
,
top_out1_halo
,
args
,
w1by3
,
top_out2_halo
.
contiguous
(
memory_format
=
memory_format
))
top_out2_halo
.
copy_
(
top_out2
)
if
spatial_group_rank
<
spatial_group_size
-
1
:
w1by3
=
args
[
2
][:,:,:
1
,:].
contiguous
(
memory_format
=
torch
.
preserve
)
btm_out1_halo
=
btm_out1_halo
.
contiguous
(
memory_format
=
memory_format
)
btm_out2
=
fast_bottleneck
.
forward_out2_halo_corr
(
explicit_nhwc
,
btm_out1_halo
,
args
,
w1by3
,
btm_out2_halo
.
contiguous
(
memory_format
=
memory_format
))
btm_out2_halo
.
copy_
(
btm_out2
)
fast_bottleneck
.
forward_rest
(
explicit_nhwc
,
stride_1x1
,
args
,
outputs
)
fast_bottleneck
.
forward_rest
(
explicit_nhwc
,
stride_1x1
,
args
,
outputs
)
# save halos for backward pass
# save halos for backward pass
if
spatial_group_size
>
1
:
if
spatial_group_size
>
1
:
torch
.
cuda
.
current_stream
().
wait_stream
(
stream3
)
if
spatial_method
!=
2
:
torch
.
cuda
.
current_stream
().
wait_stream
(
stream3
)
ctx
.
save_for_backward
(
*
(
args
+
outputs
+
[
out1_pad
,]))
ctx
.
save_for_backward
(
*
(
args
+
outputs
+
[
out1_pad
,]))
else
:
else
:
ctx
.
save_for_backward
(
*
(
args
+
outputs
))
ctx
.
save_for_backward
(
*
(
args
+
outputs
))
...
@@ -460,7 +478,7 @@ class SpatialBottleneckFunction(torch.autograd.Function):
...
@@ -460,7 +478,7 @@ class SpatialBottleneckFunction(torch.autograd.Function):
fast_bottleneck
.
backward_rest
(
ctx
.
explicit_nhwc
,
ctx
.
stride_1x1
,
t_list
,
grads
,
grad_out2
,
grad_out1
,
wgrad2
)
fast_bottleneck
.
backward_rest
(
ctx
.
explicit_nhwc
,
ctx
.
stride_1x1
,
t_list
,
grads
,
grad_out2
,
grad_out1
,
wgrad2
)
torch
.
cuda
.
current_stream
().
wait_stream
(
wgrad2_stream
)
torch
.
cuda
.
current_stream
().
wait_stream
(
wgrad2_stream
)
return
(
None
,
None
,
None
,
None
,
None
,
None
,
None
,
None
,
None
,
*
grads
)
return
(
None
,
None
,
None
,
None
,
None
,
None
,
None
,
None
,
None
,
None
,
None
,
*
grads
)
spatial_bottleneck_function
=
SpatialBottleneckFunction
.
apply
spatial_bottleneck_function
=
SpatialBottleneckFunction
.
apply
...
@@ -515,6 +533,8 @@ class SpatialBottleneck(torch.nn.Module):
...
@@ -515,6 +533,8 @@ class SpatialBottleneck(torch.nn.Module):
for
w
in
self
.
w_conv
:
for
w
in
self
.
w_conv
:
kaiming_uniform_
(
w
,
a
=
1
)
kaiming_uniform_
(
w
,
a
=
1
)
self
.
thresholdTop
,
self
.
thresholdBottom
=
None
,
None
# TODO: prevent unsupported case usage
# TODO: prevent unsupported case usage
# support cases
# support cases
# native cudnn
# native cudnn
...
@@ -536,6 +556,14 @@ class SpatialBottleneck(torch.nn.Module):
...
@@ -536,6 +556,14 @@ class SpatialBottleneck(torch.nn.Module):
def
forward
(
self
,
x
):
def
forward
(
self
,
x
):
if
self
.
use_cudnn
:
if
self
.
use_cudnn
:
if
self
.
thresholdTop
is
None
:
spatial_group_size
,
spatial_group_rank
,
_
,
_
,
_
=
self
.
spatial_parallel_args
if
self
.
explicit_nhwc
:
N
,
H
,
W
,
C
=
list
(
x
.
shape
)
else
:
N
,
C
,
H
,
W
=
list
(
x
.
shape
)
self
.
thresholdTop
=
torch
.
tensor
([
1
if
spatial_group_rank
>
0
else
0
],
dtype
=
torch
.
int32
,
device
=
'cuda'
)
self
.
thresholdBottom
=
torch
.
tensor
([
H
-
2
if
spatial_group_rank
<
spatial_group_size
-
1
else
H
-
1
],
dtype
=
torch
.
int32
,
device
=
'cuda'
)
# calculate scale/bias from registered buffers
# calculate scale/bias from registered buffers
# TODO: make this better
# TODO: make this better
s1
,
b1
=
self
.
bn1
.
get_scale_bias
(
self
.
explicit_nhwc
)
s1
,
b1
=
self
.
bn1
.
get_scale_bias
(
self
.
explicit_nhwc
)
...
@@ -548,7 +576,7 @@ class SpatialBottleneck(torch.nn.Module):
...
@@ -548,7 +576,7 @@ class SpatialBottleneck(torch.nn.Module):
w_scale
.
append
(
s4
)
w_scale
.
append
(
s4
)
w_bias
.
append
(
b4
)
w_bias
.
append
(
b4
)
out
=
spatial_bottleneck_function
(
*
self
.
spatial_parallel_args
,
self
.
explicit_nhwc
,
self
.
stride
,
w_scale
,
w_bias
,
x
,
*
self
.
w_conv
)
out
=
spatial_bottleneck_function
(
*
self
.
spatial_parallel_args
,
self
.
explicit_nhwc
,
self
.
stride
,
w_scale
,
w_bias
,
self
.
thresholdTop
,
self
.
thresholdBottom
,
x
,
*
self
.
w_conv
)
return
out
return
out
if
self
.
explicit_nhwc
:
if
self
.
explicit_nhwc
:
...
...
apex/contrib/csrc/bottleneck/bottleneck.cpp
View file @
60000f73
...
@@ -102,6 +102,13 @@ enum {
...
@@ -102,6 +102,13 @@ enum {
AFTERCONV_TENSOR
,
AFTERCONV_TENSOR
,
OPTIONAL
,
OPTIONAL
,
AFTEROPT_TENSOR
,
AFTEROPT_TENSOR
,
AFTERACT_TENSOR
,
GEN_INDEX_TENSOR
,
MASK_TOP_TENSOR
,
MASK_BOTTOM_TENSOR
,
MASK_TENSOR
,
THRESHOLD_TOP_TENSOR
,
THRESHOLD_BOTTOM_TENSOR
,
};
};
using
common_conv_descriptors
=
using
common_conv_descriptors
=
...
@@ -173,11 +180,11 @@ using common_convbias_descriptors = std::tuple<cudnn_frontend::Tensor,
...
@@ -173,11 +180,11 @@ using common_convbias_descriptors = std::tuple<cudnn_frontend::Tensor,
common_convbias_descriptors
common_convbias_descriptors
create_conv_bias_add_act_descriptors
(
int64_t
*
x_dim_padded
,
create_conv_bias_add_act_descriptors
(
int64_t
*
x_dim_padded
,
int64_t
*
padA
,
int64_t
*
padA
,
int64_t
*
convstrideA
,
int64_t
*
convstrideA
,
int64_t
*
dilationA
,
int64_t
*
dilationA
,
int64_t
*
w_dim_padded
,
int64_t
*
w_dim_padded
,
int64_t
*
y_dim_padded
,
int64_t
*
y_dim_padded
,
cudnnDataType_t
dataType
)
{
cudnnDataType_t
dataType
)
{
const
int
convDim
=
2
;
const
int
convDim
=
2
;
int64_t
b_dim_padded
[
4
];
int64_t
b_dim_padded
[
4
];
...
@@ -190,6 +197,7 @@ create_conv_bias_add_act_descriptors(int64_t* x_dim_padded,
...
@@ -190,6 +197,7 @@ create_conv_bias_add_act_descriptors(int64_t* x_dim_padded,
int64_t
y_stride_padded
[
4
];
int64_t
y_stride_padded
[
4
];
int64_t
w_stride_padded
[
4
];
int64_t
w_stride_padded
[
4
];
int64_t
b_stride_padded
[
4
];
int64_t
b_stride_padded
[
4
];
int64_t
threshold_stride
[
4
];
generateStrides
(
w_dim_padded
,
w_stride_padded
,
4
,
CUDNN_TENSOR_NHWC
);
generateStrides
(
w_dim_padded
,
w_stride_padded
,
4
,
CUDNN_TENSOR_NHWC
);
generateStrides
(
x_dim_padded
,
x_stride_padded
,
4
,
CUDNN_TENSOR_NHWC
);
generateStrides
(
x_dim_padded
,
x_stride_padded
,
4
,
CUDNN_TENSOR_NHWC
);
...
@@ -272,6 +280,183 @@ create_conv_bias_add_act_descriptors(int64_t* x_dim_padded,
...
@@ -272,6 +280,183 @@ create_conv_bias_add_act_descriptors(int64_t* x_dim_padded,
.
build
());
.
build
());
}
}
using
masked_convbias_descriptors
=
std
::
tuple
<
cudnn_frontend
::
Tensor
,
cudnn_frontend
::
Tensor
,
cudnn_frontend
::
Tensor
,
cudnn_frontend
::
Tensor
,
cudnn_frontend
::
Tensor
,
cudnn_frontend
::
Tensor
,
cudnn_frontend
::
Tensor
,
cudnn_frontend
::
Tensor
,
cudnn_frontend
::
Tensor
,
cudnn_frontend
::
Tensor
,
cudnn_frontend
::
Tensor
,
cudnn_frontend
::
Tensor
,
cudnn_frontend
::
Tensor
,
cudnn_frontend
::
Tensor
,
cudnn_frontend
::
Tensor
,
cudnn_frontend
::
Tensor
,
cudnn_frontend
::
Tensor
>
;
masked_convbias_descriptors
create_conv_bias_add_act_mask_descriptors
(
int64_t
*
x_dim_padded
,
int64_t
*
padA
,
int64_t
*
convstrideA
,
int64_t
*
dilationA
,
int64_t
*
w_dim_padded
,
int64_t
*
y_dim_padded
,
int64_t
*
threshold_dim
,
cudnnDataType_t
dataType
)
{
const
int
convDim
=
2
;
int64_t
b_dim_padded
[
4
];
b_dim_padded
[
0
]
=
1
;
b_dim_padded
[
1
]
=
y_dim_padded
[
1
];
b_dim_padded
[
2
]
=
1
;
b_dim_padded
[
3
]
=
1
;
int64_t
x_stride_padded
[
4
];
int64_t
y_stride_padded
[
4
];
int64_t
w_stride_padded
[
4
];
int64_t
b_stride_padded
[
4
];
int64_t
threshold_stride
[
4
];
generateStrides
(
w_dim_padded
,
w_stride_padded
,
4
,
CUDNN_TENSOR_NHWC
);
generateStrides
(
x_dim_padded
,
x_stride_padded
,
4
,
CUDNN_TENSOR_NHWC
);
generateStrides
(
y_dim_padded
,
y_stride_padded
,
4
,
CUDNN_TENSOR_NHWC
);
generateStrides
(
b_dim_padded
,
b_stride_padded
,
4
,
CUDNN_TENSOR_NHWC
);
generateStrides
(
threshold_dim
,
threshold_stride
,
4
,
CUDNN_TENSOR_NHWC
);
return
masked_convbias_descriptors
(
cudnn_frontend
::
TensorBuilder
()
.
setDim
(
4
,
x_dim_padded
)
.
setStrides
(
4
,
x_stride_padded
)
.
setId
(
'x'
)
.
setAlignment
(
16
)
.
setDataType
(
dataType
)
.
build
(),
cudnn_frontend
::
TensorBuilder
()
.
setDim
(
4
,
y_dim_padded
)
.
setStrides
(
4
,
y_stride_padded
)
.
setId
(
'y'
)
.
setAlignment
(
16
)
.
setDataType
(
dataType
)
.
build
(),
cudnn_frontend
::
TensorBuilder
()
.
setDim
(
4
,
w_dim_padded
)
.
setStrides
(
4
,
w_stride_padded
)
.
setId
(
'w'
)
.
setAlignment
(
16
)
.
setDataType
(
dataType
)
.
build
(),
cudnn_frontend
::
TensorBuilder
()
.
setDim
(
4
,
b_dim_padded
)
.
setStrides
(
4
,
b_stride_padded
)
.
setId
(
'z'
)
.
setAlignment
(
16
)
.
setDataType
(
dataType
)
.
build
(),
cudnn_frontend
::
TensorBuilder
()
.
setDim
(
4
,
b_dim_padded
)
.
setStrides
(
4
,
b_stride_padded
)
.
setId
(
'b'
)
.
setAlignment
(
16
)
.
setDataType
(
dataType
)
.
build
(),
cudnn_frontend
::
TensorBuilder
()
.
setDim
(
4
,
y_dim_padded
)
.
setStrides
(
4
,
y_stride_padded
)
.
setVirtual
()
.
setId
(
'A'
)
// after add
.
setAlignment
(
16
)
.
setDataType
(
CUDNN_DATA_FLOAT
)
.
build
(),
cudnn_frontend
::
TensorBuilder
()
.
setDim
(
4
,
y_dim_padded
)
.
setStrides
(
4
,
y_stride_padded
)
.
setVirtual
()
.
setId
(
'B'
)
// after bias
.
setAlignment
(
16
)
.
setDataType
(
CUDNN_DATA_FLOAT
)
.
build
(),
cudnn_frontend
::
TensorBuilder
()
.
setDim
(
4
,
y_dim_padded
)
.
setStrides
(
4
,
y_stride_padded
)
.
setId
(
'C'
)
// after conv
.
setAlignment
(
16
)
.
setVirtual
()
.
setDataType
(
CUDNN_DATA_FLOAT
)
.
build
(),
cudnn_frontend
::
TensorBuilder
()
.
setDim
(
4
,
y_dim_padded
)
.
setStrides
(
4
,
y_stride_padded
)
.
setId
(
'i'
)
.
setAlignment
(
16
)
.
setDataType
(
dataType
)
.
build
(),
cudnn_frontend
::
TensorBuilder
()
.
setDim
(
4
,
y_dim_padded
)
.
setStrides
(
4
,
y_stride_padded
)
.
setId
(
'D'
)
// after optional add
.
setAlignment
(
16
)
.
setVirtual
()
.
setDataType
(
CUDNN_DATA_FLOAT
)
.
build
(),
cudnn_frontend
::
TensorBuilder
()
.
setDim
(
4
,
y_dim_padded
)
.
setStrides
(
4
,
y_stride_padded
)
.
setId
(
'E'
)
// after act for masked
.
setAlignment
(
16
)
.
setVirtual
()
.
setDataType
(
CUDNN_DATA_FLOAT
)
.
build
(),
cudnn_frontend
::
TensorBuilder
()
.
setDim
(
4
,
y_dim_padded
)
.
setStrides
(
4
,
y_stride_padded
)
.
setId
(
'I'
)
// output of the gen index operation
.
setAlignment
(
16
)
.
setVirtual
()
.
setDataType
(
CUDNN_DATA_INT32
)
.
build
(),
cudnn_frontend
::
TensorBuilder
()
.
setDim
(
4
,
y_dim_padded
)
.
setStrides
(
4
,
y_stride_padded
)
.
setId
(
'm'
)
// top half of the mask created after the less than
.
setAlignment
(
16
)
.
setVirtual
()
.
setDataType
(
CUDNN_DATA_BOOLEAN
)
.
build
(),
cudnn_frontend
::
TensorBuilder
()
.
setDim
(
4
,
y_dim_padded
)
.
setStrides
(
4
,
y_stride_padded
)
.
setId
(
'n'
)
// bottom half of the mask
.
setAlignment
(
16
)
.
setVirtual
()
.
setDataType
(
CUDNN_DATA_BOOLEAN
)
.
build
(),
cudnn_frontend
::
TensorBuilder
()
.
setDim
(
4
,
y_dim_padded
)
.
setStrides
(
4
,
y_stride_padded
)
.
setId
(
'M'
)
// OR of the top and bottom masks
.
setAlignment
(
16
)
.
setVirtual
()
.
setDataType
(
CUDNN_DATA_BOOLEAN
)
.
build
(),
cudnn_frontend
::
TensorBuilder
()
.
setDim
(
4
,
threshold_dim
)
.
setStrides
(
4
,
threshold_stride
)
.
setId
(
't'
)
// threshold for creating the top mask
.
setAlignment
(
16
)
.
setDataType
(
CUDNN_DATA_INT32
)
.
build
(),
cudnn_frontend
::
TensorBuilder
()
.
setDim
(
4
,
threshold_dim
)
.
setStrides
(
4
,
threshold_stride
)
.
setId
(
'u'
)
// threshold for creating the bottom mask
.
setAlignment
(
16
)
.
setDataType
(
CUDNN_DATA_INT32
)
.
build
());
}
// tensor descriptors used for dgrad
// tensor descriptors used for dgrad
enum
{
enum
{
X_OR_DX_TENSOR
,
X_OR_DX_TENSOR
,
...
@@ -593,7 +778,7 @@ run_conv_scale_bias_add_activation(int64_t* x_dim_padded,
...
@@ -593,7 +778,7 @@ run_conv_scale_bias_add_activation(int64_t* x_dim_padded,
auto
opGraph
=
cudnn_frontend
::
OperationGraphBuilder
()
auto
opGraph
=
cudnn_frontend
::
OperationGraphBuilder
()
.
setHandle
(
handle_
)
.
setHandle
(
handle_
)
.
setOperationGraph
(
devPtrI
?
ops
.
size
()
:
4
,
ops
.
data
())
.
setOperationGraph
(
devPtrI
?
ops
.
size
()
:
ops
.
size
()
-
1
,
ops
.
data
())
.
build
();
.
build
();
// Create string encoding for plan caching
// Create string encoding for plan caching
...
@@ -627,6 +812,458 @@ run_conv_scale_bias_add_activation(int64_t* x_dim_padded,
...
@@ -627,6 +812,458 @@ run_conv_scale_bias_add_activation(int64_t* x_dim_padded,
}
}
}
}
void
run_conv_add_scale_bias_activation
(
int64_t
*
x_dim_padded
,
int64_t
*
pad
,
int64_t
*
convstride
,
int64_t
*
dilation
,
int64_t
*
w_dim_padded
,
int64_t
*
y_dim_padded
,
cudnnDataType_t
dataType
,
at
::
Half
*
devPtrX
,
at
::
Half
*
devPtrW
,
at
::
Half
*
devPtrY
,
at
::
Half
*
devPtrZ
,
at
::
Half
*
devPtrB
,
at
::
Half
*
devPtrI
)
{
cudnnHandle_t
handle_
=
torch
::
native
::
getCudnnHandle
();
std
::
stringstream
log_buf
;
try
{
int
convDim
=
2
;
// Creates the necessary tensor descriptors
common_convbias_descriptors
tensors
=
create_conv_bias_add_act_descriptors
(
x_dim_padded
,
pad
,
convstride
,
dilation
,
w_dim_padded
,
y_dim_padded
,
dataType
);
DEBUG_CUDNN_MSG
(
log_buf
,
std
::
get
<
X_TENSOR
>
(
tensors
).
describe
());
DEBUG_CUDNN_MSG
(
log_buf
,
std
::
get
<
Y_TENSOR
>
(
tensors
).
describe
());
DEBUG_CUDNN_MSG
(
log_buf
,
std
::
get
<
W_TENSOR
>
(
tensors
).
describe
());
DEBUG_CUDNN_MSG
(
log_buf
,
std
::
get
<
Z_TENSOR
>
(
tensors
).
describe
());
DEBUG_CUDNN_MSG
(
log_buf
,
std
::
get
<
B_TENSOR
>
(
tensors
).
describe
());
DEBUG_CUDNN_MSG
(
log_buf
,
std
::
get
<
AFTERADD_TENSOR
>
(
tensors
).
describe
());
DEBUG_CUDNN_MSG
(
log_buf
,
std
::
get
<
AFTERBIAS_TENSOR
>
(
tensors
).
describe
());
DEBUG_CUDNN_MSG
(
log_buf
,
std
::
get
<
AFTERCONV_TENSOR
>
(
tensors
).
describe
());
DEBUG_CUDNN_MSG
(
log_buf
,
std
::
get
<
OPTIONAL
>
(
tensors
).
describe
());
DEBUG_CUDNN_MSG
(
log_buf
,
std
::
get
<
AFTEROPT_TENSOR
>
(
tensors
).
describe
());
// Define the add operation
auto
scaleDesc
=
cudnn_frontend
::
PointWiseDescBuilder
()
.
setMode
(
CUDNN_POINTWISE_MUL
)
.
setMathPrecision
(
CUDNN_DATA_FLOAT
)
.
build
();
DEBUG_CUDNN_MSG
(
log_buf
,
scaleDesc
.
describe
());
// Define the bias operation
auto
biasDesc
=
cudnn_frontend
::
PointWiseDescBuilder
()
.
setMode
(
CUDNN_POINTWISE_ADD
)
.
setMathPrecision
(
CUDNN_DATA_FLOAT
)
.
build
();
DEBUG_CUDNN_MSG
(
log_buf
,
biasDesc
.
describe
());
// optional add
auto
addDesc
=
cudnn_frontend
::
PointWiseDescBuilder
()
.
setMode
(
CUDNN_POINTWISE_ADD
)
.
setMathPrecision
(
CUDNN_DATA_FLOAT
)
.
build
();
DEBUG_CUDNN_MSG
(
log_buf
,
addDesc
.
describe
());
// Define the activation operation
auto
actDesc
=
cudnn_frontend
::
PointWiseDescBuilder
()
.
setMode
(
CUDNN_POINTWISE_RELU_FWD
)
.
setMathPrecision
(
CUDNN_DATA_FLOAT
)
.
build
();
DEBUG_CUDNN_MSG
(
log_buf
,
actDesc
.
describe
());
// Define the convolution problem
auto
convDesc
=
cudnn_frontend
::
ConvDescBuilder
()
.
setDataType
(
CUDNN_DATA_FLOAT
)
.
setMathMode
(
CUDNN_CROSS_CORRELATION
)
.
setNDims
(
convDim
)
.
setStrides
(
convDim
,
convstride
)
.
setPrePadding
(
convDim
,
pad
)
.
setPostPadding
(
convDim
,
pad
)
.
setDilation
(
convDim
,
dilation
)
.
build
();
DEBUG_CUDNN_MSG
(
log_buf
,
convDesc
.
describe
());
float
alpha
=
1.0
f
;
float
beta
=
0.0
f
;
// Create a convolution Node
auto
conv_op
=
cudnn_frontend
::
OperationBuilder
(
CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR
)
.
setxDesc
(
std
::
get
<
X_TENSOR
>
(
tensors
))
.
setwDesc
(
std
::
get
<
W_TENSOR
>
(
tensors
))
.
setyDesc
(
std
::
get
<
AFTERCONV_TENSOR
>
(
tensors
))
.
setcDesc
(
convDesc
)
.
setAlpha
(
alpha
)
.
setBeta
(
beta
)
.
build
();
DEBUG_CUDNN_MSG
(
log_buf
,
conv_op
.
describe
());
// create an add node.
auto
add_op
=
cudnn_frontend
::
OperationBuilder
(
CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR
)
.
setxDesc
(
conv_op
.
getOutputTensor
())
.
setbDesc
(
std
::
get
<
OPTIONAL
>
(
tensors
))
.
setyDesc
(
std
::
get
<
AFTEROPT_TENSOR
>
(
tensors
))
.
setpwDesc
(
addDesc
)
.
build
();
DEBUG_CUDNN_MSG
(
log_buf
,
add_op
.
describe
());
// Create a Add Node with scaling parameters.
auto
scale_op
=
cudnn_frontend
::
OperationBuilder
(
CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR
)
.
setxDesc
(
add_op
.
getOutputTensor
())
.
setbDesc
(
std
::
get
<
Z_TENSOR
>
(
tensors
))
.
setyDesc
(
std
::
get
<
AFTERADD_TENSOR
>
(
tensors
))
.
setpwDesc
(
scaleDesc
)
.
build
();
DEBUG_CUDNN_MSG
(
log_buf
,
scale_op
.
describe
());
// Create a Bias Node.
auto
bias_op
=
cudnn_frontend
::
OperationBuilder
(
CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR
)
.
setxDesc
(
scale_op
.
getOutputTensor
())
.
setbDesc
(
std
::
get
<
B_TENSOR
>
(
tensors
))
.
setyDesc
(
std
::
get
<
AFTERBIAS_TENSOR
>
(
tensors
))
.
setpwDesc
(
biasDesc
)
.
build
();
DEBUG_CUDNN_MSG
(
log_buf
,
bias_op
.
describe
());
// Create an Activation Node.
auto
act_op
=
cudnn_frontend
::
OperationBuilder
(
CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR
)
.
setxDesc
(
bias_op
.
getOutputTensor
())
.
setyDesc
(
std
::
get
<
Y_TENSOR
>
(
tensors
))
.
setpwDesc
(
actDesc
)
.
build
();
DEBUG_CUDNN_MSG
(
log_buf
,
act_op
.
describe
());
// Create an Operation Graph. In this case it is convolution add bias activation
std
::
array
<
cudnn_frontend
::
Operation
const
*
,
5
>
ops
=
{
&
conv_op
,
&
add_op
,
&
scale_op
,
&
bias_op
,
&
act_op
};
auto
opGraph
=
cudnn_frontend
::
OperationGraphBuilder
()
.
setHandle
(
handle_
)
.
setOperationGraph
(
ops
.
size
(),
ops
.
data
())
.
build
();
// Create string encoding for plan caching
auto
cache_string
=
getConvFusionString
(
x_dim_padded
,
pad
,
convstride
,
dilation
,
w_dim_padded
,
dataType
,
opGraph
.
getTag
());
DEBUG_CUDNN_MSG
(
log_buf
,
"[convstring] "
<<
cache_string
);
auto
&
plan
=
getOrCreatePlan
(
handle_
,
log_buf
,
opGraph
,
cache_string
);
DEBUG_CUDNN_MSG
(
log_buf
,
"Plan tag: "
<<
plan
.
getTag
());
auto
workspace_size
=
plan
.
getWorkspaceSize
();
DEBUG_CUDNN_MSG
(
log_buf
,
plan
.
describe
()
<<
" requires workspace "
<<
workspace_size
);
void
*
workspace_ptr
=
nullptr
;
auto
workspace_tensor
=
at
::
empty
({(
workspace_size
+
3
)
/
4
},
at
::
TensorOptions
(
at
::
kCUDA
).
dtype
(
at
::
kFloat
));
if
(
workspace_size
>
0
)
{
workspace_ptr
=
workspace_tensor
.
data_ptr
<
float
>
();
}
void
*
data_ptrs
[]
=
{
devPtrX
,
devPtrY
,
devPtrW
,
devPtrZ
,
devPtrB
,
devPtrI
};
int64_t
uids
[]
=
{
'x'
,
'y'
,
'w'
,
'z'
,
'b'
,
'i'
};
auto
variantPack
=
cudnn_frontend
::
VariantPackBuilder
()
.
setWorkspacePointer
(
workspace_ptr
)
.
setDataPointers
(
6
,
data_ptrs
)
.
setUids
(
6
,
uids
)
.
build
();
DEBUG_CUDNN_MSG
(
log_buf
,
"variantPack "
<<
variantPack
.
describe
());
cudnnStatus_t
status
=
cudnnBackendExecute
(
handle_
,
plan
.
get_raw_desc
(),
variantPack
.
get_raw_desc
());
checkCudnnErr
(
status
);
cudnn_frontend
::
throw_if
([
status
]()
{
return
(
status
!=
CUDNN_STATUS_SUCCESS
);
},
"Plan execute error"
,
status
);
}
catch
(
cudnn_frontend
::
cudnnException
e
)
{
std
::
cout
<<
log_buf
.
str
()
<<
"[ERROR] Exception "
<<
e
.
what
()
<<
std
::
endl
;
}
}
void
run_conv_scale_bias_add_activation_mask
(
int64_t
*
x_dim_padded
,
int64_t
*
pad
,
int64_t
*
convstride
,
int64_t
*
dilation
,
int64_t
*
w_dim_padded
,
int64_t
*
y_dim_padded
,
int64_t
*
threshold_dim
,
cudnnDataType_t
dataType
,
at
::
Half
*
devPtrX
,
at
::
Half
*
devPtrW
,
at
::
Half
*
devPtrY
,
at
::
Half
*
devPtrZ
,
at
::
Half
*
devPtrB
,
at
::
Half
*
devPtrI
,
int
*
devPtrT
,
int
*
devPtrU
,
int
axis
)
{
cudnnHandle_t
handle_
=
torch
::
native
::
getCudnnHandle
();
std
::
stringstream
log_buf
;
try
{
int
convDim
=
2
;
// Creates the necessary tensor descriptors
masked_convbias_descriptors
tensors
=
create_conv_bias_add_act_mask_descriptors
(
x_dim_padded
,
pad
,
convstride
,
dilation
,
w_dim_padded
,
y_dim_padded
,
threshold_dim
,
dataType
);
DEBUG_CUDNN_MSG
(
log_buf
,
std
::
get
<
X_TENSOR
>
(
tensors
).
describe
());
DEBUG_CUDNN_MSG
(
log_buf
,
std
::
get
<
Y_TENSOR
>
(
tensors
).
describe
());
DEBUG_CUDNN_MSG
(
log_buf
,
std
::
get
<
W_TENSOR
>
(
tensors
).
describe
());
DEBUG_CUDNN_MSG
(
log_buf
,
std
::
get
<
Z_TENSOR
>
(
tensors
).
describe
());
DEBUG_CUDNN_MSG
(
log_buf
,
std
::
get
<
B_TENSOR
>
(
tensors
).
describe
());
DEBUG_CUDNN_MSG
(
log_buf
,
std
::
get
<
AFTERADD_TENSOR
>
(
tensors
).
describe
());
DEBUG_CUDNN_MSG
(
log_buf
,
std
::
get
<
AFTERBIAS_TENSOR
>
(
tensors
).
describe
());
DEBUG_CUDNN_MSG
(
log_buf
,
std
::
get
<
AFTERCONV_TENSOR
>
(
tensors
).
describe
());
DEBUG_CUDNN_MSG
(
log_buf
,
std
::
get
<
OPTIONAL
>
(
tensors
).
describe
());
DEBUG_CUDNN_MSG
(
log_buf
,
std
::
get
<
AFTERACT_TENSOR
>
(
tensors
).
describe
());
DEBUG_CUDNN_MSG
(
log_buf
,
std
::
get
<
GEN_INDEX_TENSOR
>
(
tensors
).
describe
());
DEBUG_CUDNN_MSG
(
log_buf
,
std
::
get
<
MASK_TOP_TENSOR
>
(
tensors
).
describe
());
DEBUG_CUDNN_MSG
(
log_buf
,
std
::
get
<
MASK_BOTTOM_TENSOR
>
(
tensors
).
describe
());
DEBUG_CUDNN_MSG
(
log_buf
,
std
::
get
<
MASK_TENSOR
>
(
tensors
).
describe
());
DEBUG_CUDNN_MSG
(
log_buf
,
std
::
get
<
THRESHOLD_TOP_TENSOR
>
(
tensors
).
describe
());
DEBUG_CUDNN_MSG
(
log_buf
,
std
::
get
<
THRESHOLD_BOTTOM_TENSOR
>
(
tensors
).
describe
());
// Define the add operation
auto
scaleDesc
=
cudnn_frontend
::
PointWiseDescBuilder
()
.
setMode
(
CUDNN_POINTWISE_MUL
)
.
setMathPrecision
(
CUDNN_DATA_FLOAT
)
.
build
();
DEBUG_CUDNN_MSG
(
log_buf
,
scaleDesc
.
describe
());
// Define the bias operation
auto
biasDesc
=
cudnn_frontend
::
PointWiseDescBuilder
()
.
setMode
(
CUDNN_POINTWISE_ADD
)
.
setMathPrecision
(
CUDNN_DATA_FLOAT
)
.
build
();
DEBUG_CUDNN_MSG
(
log_buf
,
biasDesc
.
describe
());
// optional add
auto
addDesc
=
cudnn_frontend
::
PointWiseDescBuilder
()
.
setMode
(
CUDNN_POINTWISE_ADD
)
.
setMathPrecision
(
CUDNN_DATA_FLOAT
)
.
build
();
DEBUG_CUDNN_MSG
(
log_buf
,
addDesc
.
describe
());
// Define the activation operation
auto
actDesc
=
cudnn_frontend
::
PointWiseDescBuilder
()
.
setMode
(
CUDNN_POINTWISE_RELU_FWD
)
.
setMathPrecision
(
CUDNN_DATA_FLOAT
)
.
build
();
DEBUG_CUDNN_MSG
(
log_buf
,
actDesc
.
describe
());
// Define the convolution problem
auto
convDesc
=
cudnn_frontend
::
ConvDescBuilder
()
.
setDataType
(
CUDNN_DATA_FLOAT
)
.
setMathMode
(
CUDNN_CROSS_CORRELATION
)
.
setNDims
(
convDim
)
.
setStrides
(
convDim
,
convstride
)
.
setPrePadding
(
convDim
,
pad
)
.
setPostPadding
(
convDim
,
pad
)
.
setDilation
(
convDim
,
dilation
)
.
build
();
DEBUG_CUDNN_MSG
(
log_buf
,
convDesc
.
describe
());
// Define the genIndex descriptor
auto
genIndexDesc
=
cudnn_frontend
::
PointWiseDescBuilder
()
.
setMode
(
CUDNN_POINTWISE_GEN_INDEX
)
.
setMathPrecision
(
CUDNN_DATA_FLOAT
)
.
setAxis
(
axis
)
.
build
();
DEBUG_CUDNN_MSG
(
log_buf
,
genIndexDesc
.
describe
());
// Define the lessThan descriptor
auto
lessThanDesc
=
cudnn_frontend
::
PointWiseDescBuilder
()
.
setMode
(
CUDNN_POINTWISE_CMP_LT
)
.
setMathPrecision
(
CUDNN_DATA_FLOAT
)
.
build
();
DEBUG_CUDNN_MSG
(
log_buf
,
lessThanDesc
.
describe
());
// Define the greaterThan descriptor
auto
greaterThanDesc
=
cudnn_frontend
::
PointWiseDescBuilder
()
.
setMode
(
CUDNN_POINTWISE_CMP_GT
)
.
setMathPrecision
(
CUDNN_DATA_FLOAT
)
.
build
();
DEBUG_CUDNN_MSG
(
log_buf
,
greaterThanDesc
.
describe
());
// Define the logical_or descriptor
auto
logicalOrDesc
=
cudnn_frontend
::
PointWiseDescBuilder
()
.
setMode
(
CUDNN_POINTWISE_LOGICAL_OR
)
.
setMathPrecision
(
CUDNN_DATA_BOOLEAN
)
.
build
();
DEBUG_CUDNN_MSG
(
log_buf
,
logicalOrDesc
.
describe
());
// Define the binary_selection descriptor
auto
selectionDesc
=
cudnn_frontend
::
PointWiseDescBuilder
()
.
setMode
(
CUDNN_POINTWISE_BINARY_SELECT
)
.
setMathPrecision
(
CUDNN_DATA_FLOAT
)
.
build
();
DEBUG_CUDNN_MSG
(
log_buf
,
selectionDesc
.
describe
());
float
alpha
=
1.0
f
;
float
beta
=
0.0
f
;
// Create a convolution Node
auto
conv_op
=
cudnn_frontend
::
OperationBuilder
(
CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR
)
.
setxDesc
(
std
::
get
<
X_TENSOR
>
(
tensors
))
.
setwDesc
(
std
::
get
<
W_TENSOR
>
(
tensors
))
.
setyDesc
(
std
::
get
<
AFTERCONV_TENSOR
>
(
tensors
))
.
setcDesc
(
convDesc
)
.
setAlpha
(
alpha
)
.
setBeta
(
beta
)
.
build
();
DEBUG_CUDNN_MSG
(
log_buf
,
conv_op
.
describe
());
// Create a Add Node with scaling parameters.
auto
scale_op
=
cudnn_frontend
::
OperationBuilder
(
CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR
)
.
setxDesc
(
conv_op
.
getOutputTensor
())
.
setbDesc
(
std
::
get
<
Z_TENSOR
>
(
tensors
))
.
setyDesc
(
std
::
get
<
AFTERADD_TENSOR
>
(
tensors
))
.
setpwDesc
(
scaleDesc
)
.
build
();
DEBUG_CUDNN_MSG
(
log_buf
,
scale_op
.
describe
());
// Create a Bias Node.
auto
bias_op
=
cudnn_frontend
::
OperationBuilder
(
CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR
)
.
setxDesc
(
scale_op
.
getOutputTensor
())
.
setbDesc
(
std
::
get
<
B_TENSOR
>
(
tensors
))
.
setyDesc
(
std
::
get
<
AFTERBIAS_TENSOR
>
(
tensors
))
.
setpwDesc
(
biasDesc
)
.
build
();
DEBUG_CUDNN_MSG
(
log_buf
,
bias_op
.
describe
());
// Create a optional add Node.
auto
add_op
=
cudnn_frontend
::
OperationBuilder
(
CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR
)
.
setxDesc
(
bias_op
.
getOutputTensor
())
.
setbDesc
(
std
::
get
<
OPTIONAL
>
(
tensors
))
.
setyDesc
(
std
::
get
<
AFTEROPT_TENSOR
>
(
tensors
))
.
setpwDesc
(
addDesc
)
.
build
();
DEBUG_CUDNN_MSG
(
log_buf
,
add_op
.
describe
());
// Create an Activation Node.
auto
act_op
=
cudnn_frontend
::
OperationBuilder
(
CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR
)
.
setxDesc
(
devPtrI
?
add_op
.
getOutputTensor
()
:
bias_op
.
getOutputTensor
())
.
setyDesc
(
std
::
get
<
AFTERACT_TENSOR
>
(
tensors
))
.
setpwDesc
(
actDesc
)
.
build
();
DEBUG_CUDNN_MSG
(
log_buf
,
act_op
.
describe
());
// Create a Gen_Index Node.
auto
genIndex_op
=
cudnn_frontend
::
OperationBuilder
(
CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR
)
.
setxDesc
(
std
::
get
<
AFTERACT_TENSOR
>
(
tensors
))
.
setyDesc
(
std
::
get
<
GEN_INDEX_TENSOR
>
(
tensors
))
.
setpwDesc
(
genIndexDesc
)
.
build
();
DEBUG_CUDNN_MSG
(
log_buf
,
genIndex_op
.
describe
());
// Create a LessThan Node.
auto
lessThan_op
=
cudnn_frontend
::
OperationBuilder
(
CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR
)
.
setxDesc
(
std
::
get
<
GEN_INDEX_TENSOR
>
(
tensors
))
.
setbDesc
(
std
::
get
<
THRESHOLD_TOP_TENSOR
>
(
tensors
))
.
setyDesc
(
std
::
get
<
MASK_TOP_TENSOR
>
(
tensors
))
.
setpwDesc
(
lessThanDesc
)
.
build
();
DEBUG_CUDNN_MSG
(
log_buf
,
lessThan_op
.
describe
());
// Create a GreaterThan Node.
auto
greaterThan_op
=
cudnn_frontend
::
OperationBuilder
(
CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR
)
.
setxDesc
(
std
::
get
<
GEN_INDEX_TENSOR
>
(
tensors
))
.
setbDesc
(
std
::
get
<
THRESHOLD_BOTTOM_TENSOR
>
(
tensors
))
.
setyDesc
(
std
::
get
<
MASK_BOTTOM_TENSOR
>
(
tensors
))
.
setpwDesc
(
greaterThanDesc
)
.
build
();
DEBUG_CUDNN_MSG
(
log_buf
,
greaterThan_op
.
describe
());
// Create a LogicalOr Node.
auto
logicalOr_op
=
cudnn_frontend
::
OperationBuilder
(
CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR
)
.
setxDesc
(
std
::
get
<
MASK_TOP_TENSOR
>
(
tensors
))
.
setbDesc
(
std
::
get
<
MASK_BOTTOM_TENSOR
>
(
tensors
))
.
setyDesc
(
std
::
get
<
MASK_TENSOR
>
(
tensors
))
.
setpwDesc
(
logicalOrDesc
)
.
build
();
DEBUG_CUDNN_MSG
(
log_buf
,
logicalOr_op
.
describe
());
// Create a Binary_Selection Node.
auto
selection_op
=
cudnn_frontend
::
OperationBuilder
(
CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR
)
.
setxDesc
(
std
::
get
<
AFTERCONV_TENSOR
>
(
tensors
))
.
setbDesc
(
std
::
get
<
AFTERACT_TENSOR
>
(
tensors
))
.
settDesc
(
std
::
get
<
MASK_TENSOR
>
(
tensors
))
.
setyDesc
(
std
::
get
<
Y_TENSOR
>
(
tensors
))
.
setpwDesc
(
selectionDesc
)
.
build
();
DEBUG_CUDNN_MSG
(
log_buf
,
selection_op
.
describe
());
// Create an Operation Graph. In this case it is convolution add bias activation
if
(
devPtrI
)
{
std
::
array
<
cudnn_frontend
::
Operation
const
*
,
10
>
ops
=
{
&
conv_op
,
&
scale_op
,
&
bias_op
,
&
add_op
,
&
act_op
,
&
genIndex_op
,
&
lessThan_op
,
&
greaterThan_op
,
&
logicalOr_op
,
&
selection_op
};
auto
opGraph
=
cudnn_frontend
::
OperationGraphBuilder
()
.
setHandle
(
handle_
)
.
setOperationGraph
(
ops
.
size
(),
ops
.
data
())
.
build
();
// Create string encoding for plan caching
auto
cache_string
=
getConvFusionString
(
x_dim_padded
,
pad
,
convstride
,
dilation
,
w_dim_padded
,
dataType
,
opGraph
.
getTag
());
DEBUG_CUDNN_MSG
(
log_buf
,
"[convstring] "
<<
cache_string
);
auto
&
plan
=
getOrCreatePlan
(
handle_
,
log_buf
,
opGraph
,
cache_string
);
DEBUG_CUDNN_MSG
(
log_buf
,
"Plan tag: "
<<
plan
.
getTag
());
auto
workspace_size
=
plan
.
getWorkspaceSize
();
DEBUG_CUDNN_MSG
(
log_buf
,
plan
.
describe
()
<<
" requires workspace "
<<
workspace_size
);
void
*
workspace_ptr
=
nullptr
;
auto
workspace_tensor
=
at
::
empty
({(
workspace_size
+
3
)
/
4
},
at
::
TensorOptions
(
at
::
kCUDA
).
dtype
(
at
::
kFloat
));
if
(
workspace_size
>
0
)
{
workspace_ptr
=
workspace_tensor
.
data_ptr
<
float
>
();
}
void
*
data_ptrs
[]
=
{
devPtrX
,
devPtrY
,
devPtrW
,
devPtrZ
,
devPtrB
,
devPtrI
,
devPtrT
,
devPtrU
};
int64_t
uids
[]
=
{
'x'
,
'y'
,
'w'
,
'z'
,
'b'
,
'i'
,
't'
,
'u'
};
auto
variantPack
=
cudnn_frontend
::
VariantPackBuilder
()
.
setWorkspacePointer
(
workspace_ptr
)
.
setDataPointers
(
8
,
data_ptrs
)
.
setUids
(
8
,
uids
)
.
build
();
DEBUG_CUDNN_MSG
(
log_buf
,
"variantPack "
<<
variantPack
.
describe
());
cudnnStatus_t
status
=
cudnnBackendExecute
(
handle_
,
plan
.
get_raw_desc
(),
variantPack
.
get_raw_desc
());
checkCudnnErr
(
status
);
cudnn_frontend
::
throw_if
([
status
]()
{
return
(
status
!=
CUDNN_STATUS_SUCCESS
);
},
"Plan execute error"
,
status
);
}
else
{
std
::
array
<
cudnn_frontend
::
Operation
const
*
,
9
>
ops
=
{
&
conv_op
,
&
scale_op
,
&
bias_op
,
&
act_op
,
&
genIndex_op
,
&
lessThan_op
,
&
greaterThan_op
,
&
logicalOr_op
,
&
selection_op
};
auto
opGraph
=
cudnn_frontend
::
OperationGraphBuilder
()
.
setHandle
(
handle_
)
.
setOperationGraph
(
ops
.
size
(),
ops
.
data
())
.
build
();
// Create string encoding for plan caching
auto
cache_string
=
getConvFusionString
(
x_dim_padded
,
pad
,
convstride
,
dilation
,
w_dim_padded
,
dataType
,
opGraph
.
getTag
());
DEBUG_CUDNN_MSG
(
log_buf
,
"[convstring] "
<<
cache_string
);
auto
&
plan
=
getOrCreatePlan
(
handle_
,
log_buf
,
opGraph
,
cache_string
);
DEBUG_CUDNN_MSG
(
log_buf
,
"Plan tag: "
<<
plan
.
getTag
());
auto
workspace_size
=
plan
.
getWorkspaceSize
();
DEBUG_CUDNN_MSG
(
log_buf
,
plan
.
describe
()
<<
" requires workspace "
<<
workspace_size
);
void
*
workspace_ptr
=
nullptr
;
auto
workspace_tensor
=
at
::
empty
({(
workspace_size
+
3
)
/
4
},
at
::
TensorOptions
(
at
::
kCUDA
).
dtype
(
at
::
kFloat
));
if
(
workspace_size
>
0
)
{
workspace_ptr
=
workspace_tensor
.
data_ptr
<
float
>
();
}
void
*
data_ptrs
[]
=
{
devPtrX
,
devPtrY
,
devPtrW
,
devPtrZ
,
devPtrB
,
devPtrT
,
devPtrU
};
int64_t
uids
[]
=
{
'x'
,
'y'
,
'w'
,
'z'
,
'b'
,
't'
,
'u'
};
auto
variantPack
=
cudnn_frontend
::
VariantPackBuilder
()
.
setWorkspacePointer
(
workspace_ptr
)
.
setDataPointers
(
7
,
data_ptrs
)
.
setUids
(
7
,
uids
)
.
build
();
DEBUG_CUDNN_MSG
(
log_buf
,
"variantPack "
<<
variantPack
.
describe
());
cudnnStatus_t
status
=
cudnnBackendExecute
(
handle_
,
plan
.
get_raw_desc
(),
variantPack
.
get_raw_desc
());
checkCudnnErr
(
status
);
cudnn_frontend
::
throw_if
([
status
]()
{
return
(
status
!=
CUDNN_STATUS_SUCCESS
);
},
"Plan execute error"
,
status
);
}
}
catch
(
cudnn_frontend
::
cudnnException
e
)
{
std
::
cout
<<
log_buf
.
str
()
<<
"[ERROR] Exception "
<<
e
.
what
()
<<
std
::
endl
;
}
}
void
void
run_conv_scale_bias
(
int64_t
*
x_dim_padded
,
run_conv_scale_bias
(
int64_t
*
x_dim_padded
,
int64_t
*
pad
,
int64_t
*
pad
,
...
@@ -1613,9 +2250,12 @@ struct bottleneck_forward_status {
...
@@ -1613,9 +2250,12 @@ struct bottleneck_forward_status {
int64_t
dimA
[
4
];
int64_t
dimA
[
4
];
int64_t
filterdimA1
[
4
];
int64_t
filterdimA1
[
4
];
int64_t
filterdimA2
[
4
];
int64_t
filterdimA2
[
4
];
int64_t
filterdimA2hh
[
4
];
int64_t
filterdimA3
[
4
];
int64_t
filterdimA3
[
4
];
int64_t
filterdimA4
[
4
];
int64_t
filterdimA4
[
4
];
int64_t
threshdim
[
4
];
int
axis
[
4
];
int
axis
[
4
];
int64_t
outdimA0
[
4
];
int64_t
outdimA0
[
4
];
...
@@ -1643,8 +2283,10 @@ struct bottleneck_forward_status {
...
@@ -1643,8 +2283,10 @@ struct bottleneck_forward_status {
dimA
[
0
]
=
dimA
[
1
]
=
dimA
[
2
]
=
dimA
[
3
]
=
0
;
dimA
[
0
]
=
dimA
[
1
]
=
dimA
[
2
]
=
dimA
[
3
]
=
0
;
filterdimA1
[
0
]
=
filterdimA1
[
1
]
=
filterdimA1
[
2
]
=
filterdimA1
[
3
]
=
0
;
filterdimA1
[
0
]
=
filterdimA1
[
1
]
=
filterdimA1
[
2
]
=
filterdimA1
[
3
]
=
0
;
filterdimA2
[
0
]
=
filterdimA2
[
1
]
=
filterdimA2
[
2
]
=
filterdimA2
[
3
]
=
0
;
filterdimA2
[
0
]
=
filterdimA2
[
1
]
=
filterdimA2
[
2
]
=
filterdimA2
[
3
]
=
0
;
filterdimA2hh
[
0
]
=
filterdimA2hh
[
1
]
=
filterdimA2hh
[
2
]
=
filterdimA2hh
[
3
]
=
0
;
filterdimA3
[
0
]
=
filterdimA3
[
1
]
=
filterdimA3
[
2
]
=
filterdimA3
[
3
]
=
0
;
filterdimA3
[
0
]
=
filterdimA3
[
1
]
=
filterdimA3
[
2
]
=
filterdimA3
[
3
]
=
0
;
filterdimA4
[
0
]
=
filterdimA4
[
1
]
=
filterdimA4
[
2
]
=
filterdimA4
[
3
]
=
0
;
filterdimA4
[
0
]
=
filterdimA4
[
1
]
=
filterdimA4
[
2
]
=
filterdimA4
[
3
]
=
0
;
threshdim
[
0
]
=
threshdim
[
1
]
=
threshdim
[
2
]
=
threshdim
[
3
]
=
1
;
// All dim calculation after this order of n,c,h,w
// All dim calculation after this order of n,c,h,w
if
(
explicit_nhwc
)
{
if
(
explicit_nhwc
)
{
...
@@ -1670,6 +2312,13 @@ struct bottleneck_forward_status {
...
@@ -1670,6 +2312,13 @@ struct bottleneck_forward_status {
filterdimA4
[
dim
]
=
inputs
[
10
].
size
(
axis
[
dim
]);
filterdimA4
[
dim
]
=
inputs
[
10
].
size
(
axis
[
dim
]);
}
}
}
}
for
(
int
dim
=
0
;
dim
<
4
;
dim
++
)
{
if
(
dim
==
2
)
{
filterdimA2hh
[
dim
]
=
1
;
}
else
{
filterdimA2hh
[
dim
]
=
filterdimA2
[
dim
];
}
}
// output dim in n,c,h,w used by backend
// output dim in n,c,h,w used by backend
outdimA0
[
0
]
=
outdimA0
[
1
]
=
outdimA0
[
2
]
=
outdimA0
[
3
]
=
0
;
outdimA0
[
0
]
=
outdimA0
[
1
]
=
outdimA0
[
2
]
=
outdimA0
[
3
]
=
0
;
...
@@ -1833,6 +2482,41 @@ at::Tensor bottleneck_forward_out2_halo(bool explicit_nhwc, at::Tensor fat_halo_
...
@@ -1833,6 +2482,41 @@ at::Tensor bottleneck_forward_out2_halo(bool explicit_nhwc, at::Tensor fat_halo_
return
halo_y2
;
return
halo_y2
;
}
}
// Compute the halo correction term (top or bottom) for out2 from a slim halo
// input that is a single pixel tall in H (N,C,1,W), fusing it with the
// previously computed partial out2 halo.
at::Tensor bottleneck_forward_out2_halo_corr(bool explicit_nhwc, at::Tensor slim_halo_y1, std::vector<at::Tensor> inputs, at::Tensor w1by3, at::Tensor out2_part_halo) {
  auto output_format = explicit_nhwc ? at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast;

  // Gather raw device pointers for the fused cudnn call.
  at::Half* halo_in = slim_halo_y1.data_ptr<at::Half>();
  at::Half* filt = w1by3.data_ptr<at::Half>();          // 1x3 filter slice — presumably (C,C,1,3); see filterdimA2hh
  at::Half* scale = inputs[5].data_ptr<at::Half>();
  at::Half* bias = inputs[8].data_ptr<at::Half>();
  at::Half* partial = out2_part_halo.data_ptr<at::Half>();

  // Allocate the correction output in the caller-requested memory format.
  auto halo_y2 = at::empty(forward_state.outdim4, inputs[0].type(), output_format);
  at::Half* halo_out = halo_y2.data_ptr<at::Half>();

  // Fused conv + add(partial out2) + scale + bias + activation over the slim
  // halo, using the 1-tall filter dims cached in forward_state.
  run_conv_add_scale_bias_activation(forward_state.outdimA4,
                                     forward_state.padA2,
                                     forward_state.convstrideA,
                                     forward_state.dilationA,
                                     forward_state.filterdimA2hh,
                                     forward_state.outdimA4,
                                     CUDNN_DATA_HALF,
                                     halo_in,
                                     filt,
                                     halo_out,
                                     scale,
                                     bias,
                                     partial);

  return halo_y2;
}
void
bottleneck_forward_out2
(
bool
explicit_nhwc
,
int
stride_1X1
,
std
::
vector
<
at
::
Tensor
>
inputs
,
std
::
vector
<
at
::
Tensor
>
outputs
)
{
void
bottleneck_forward_out2
(
bool
explicit_nhwc
,
int
stride_1X1
,
std
::
vector
<
at
::
Tensor
>
inputs
,
std
::
vector
<
at
::
Tensor
>
outputs
)
{
std
::
cout
<<
std
::
fixed
;
std
::
cout
<<
std
::
fixed
;
...
@@ -1871,6 +2555,48 @@ void bottleneck_forward_out2(bool explicit_nhwc, int stride_1X1, std::vector<at:
...
@@ -1871,6 +2555,48 @@ void bottleneck_forward_out2(bool explicit_nhwc, int stride_1X1, std::vector<at:
DEBUG_MSG
(
"[DEBUG] new relu2 : "
<<
out2
.
to
(
at
::
kFloat
).
sum
().
item
<
float
>
());
DEBUG_MSG
(
"[DEBUG] new relu2 : "
<<
out2
.
to
(
at
::
kFloat
).
sum
().
item
<
float
>
());
}
}
// Forward pass of the bottleneck block's middle (3x3) convolution with
// cudnn-side halo masking: per the fused graph in
// run_conv_scale_bias_add_activation_mask, H indices outside the
// [thresholdTop, thresholdBottom] band select the raw conv output instead of
// the fused scale+bias+activation result.
// NOTE(review): explicit_nhwc and stride_1X1 are unused here — all dims come
// from the cached forward_state; confirm this is intentional.
void bottleneck_forward_out2_mask(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs, std::vector<at::Tensor> outputs, at::Tensor thresholdTop, at::Tensor thresholdBottom) {
  std::cout << std::fixed;
  // from _out1 method
  at::Half* x = inputs[0].data_ptr<at::Half>();  // NOTE(review): x is unused in this method
  auto out1 = outputs[0];
  at::Half* y1 = out1.data_ptr<at::Half>();      // conv1 output — input to the 3x3 conv
  // run
  at::Half* w = inputs[2].data_ptr<at::Half>();  // 3x3 conv weights
  at::Half* z = inputs[5].data_ptr<at::Half>();  // scale tensor
  at::Half* b = inputs[8].data_ptr<at::Half>();  // bias tensor
  auto out2 = outputs[1];
  at::Half* y2 = out2.data_ptr<at::Half>();      // masked conv2 output
  //printf("forward_state.outdimA1 = {%d,%d,%d,%d}\n",forward_state.outdimA1[0],forward_state.outdimA1[1],forward_state.outdimA1[2],forward_state.outdimA1[3]);
  //printf("forward_state.padA1 = {%d,%d}\n",forward_state.padA1[0],forward_state.padA1[1]);
  //printf("forward_state.convstrideA = {%d,%d}\n",forward_state.convstrideA[0],forward_state.convstrideA[1]);
  //printf("forward_state.dilationA = {%d,%d}\n",forward_state.dilationA[0],forward_state.dilationA[1]);
  //printf("forward_state.filterdimA2 = {%d,%d,%d,%d}\n",forward_state.filterdimA2[0],forward_state.filterdimA2[1],forward_state.filterdimA2[2],forward_state.filterdimA2[3]);
  //printf("forward_state.outdimA2 = {%d,%d,%d,%d}\n",forward_state.outdimA2[0],forward_state.outdimA2[1],forward_state.outdimA2[2],forward_state.outdimA2[3]);
  // Fused conv + scale + bias + activation with gen-index/threshold masking.
  // nullptr: no optional residual add input. Threshold tensors supply the int
  // H-band limits for the top/bottom masks. Trailing literal 2 is presumably
  // the gen-index axis (H in NHWC) — TODO confirm against the masking kernel.
  run_conv_scale_bias_add_activation_mask(forward_state.outdimA1, forward_state.padA1, forward_state.convstrideA, forward_state.dilationA, forward_state.filterdimA2, forward_state.outdimA2, forward_state.threshdim, CUDNN_DATA_HALF, y1, w, y2, z, b, nullptr, thresholdTop.data_ptr<int>(), thresholdBottom.data_ptr<int>(), 2);  // axis == 1 -> Does this assume explicit NHWC?
  DEBUG_MSG("[DEBUG] new relu2 : " << out2.to(at::kFloat).sum().item<float>());
}
void
bottleneck_forward_out2_pad
(
bool
explicit_nhwc
,
int
stride_1X1
,
std
::
vector
<
at
::
Tensor
>
inputs
,
std
::
vector
<
at
::
Tensor
>
outputs
,
at
::
Tensor
out1_pad
)
{
void
bottleneck_forward_out2_pad
(
bool
explicit_nhwc
,
int
stride_1X1
,
std
::
vector
<
at
::
Tensor
>
inputs
,
std
::
vector
<
at
::
Tensor
>
outputs
,
at
::
Tensor
out1_pad
)
{
std
::
cout
<<
std
::
fixed
;
std
::
cout
<<
std
::
fixed
;
...
@@ -2569,7 +3295,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
...
@@ -2569,7 +3295,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m
.
def
(
"forward_init"
,
&
bottleneck_forward_init
,
"Bottleneck block init"
);
m
.
def
(
"forward_init"
,
&
bottleneck_forward_init
,
"Bottleneck block init"
);
m
.
def
(
"forward_out1"
,
&
bottleneck_forward_out1
,
"Bottleneck block forward"
);
m
.
def
(
"forward_out1"
,
&
bottleneck_forward_out1
,
"Bottleneck block forward"
);
m
.
def
(
"forward_out2"
,
&
bottleneck_forward_out2
,
"Bottleneck block forward"
);
m
.
def
(
"forward_out2"
,
&
bottleneck_forward_out2
,
"Bottleneck block forward"
);
m
.
def
(
"forward_out2_mask"
,
&
bottleneck_forward_out2_mask
,
"Bottleneck block forward"
);
m
.
def
(
"forward_out2_halo"
,
&
bottleneck_forward_out2_halo
,
"Bottleneck block forward"
);
m
.
def
(
"forward_out2_halo"
,
&
bottleneck_forward_out2_halo
,
"Bottleneck block forward"
);
m
.
def
(
"forward_out2_halo_corr"
,
&
bottleneck_forward_out2_halo_corr
,
"Bottleneck block forward"
);
m
.
def
(
"forward_out2_pad"
,
&
bottleneck_forward_out2_pad
,
"Bottleneck block forward"
);
m
.
def
(
"forward_out2_pad"
,
&
bottleneck_forward_out2_pad
,
"Bottleneck block forward"
);
m
.
def
(
"forward_rest"
,
&
bottleneck_forward_rest
,
"Bottleneck block forward"
);
m
.
def
(
"forward_rest"
,
&
bottleneck_forward_rest
,
"Bottleneck block forward"
);
m
.
def
(
"backward_init"
,
&
bottleneck_backward_init
,
"Bottleneck block backward init"
);
m
.
def
(
"backward_init"
,
&
bottleneck_backward_init
,
"Bottleneck block backward init"
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment