gaoqiong / composable_kernel / Commits / ea8aa63f
"git@developer.sourcefind.cn:modelzoo/resnet50_tensorflow.git" did not exist on "4c11b84b1c7360aecff7c4a679d7e05076ffc19d"
Commit ea8aa63f, authored Jan 14, 2020 by Chao Liu

    adding bwd data v4r1 (multiple kernel launch)

Parent: ea484457
Changes: 23. Showing 20 changed files with 687 additions and 28 deletions (+687 −28).
composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v3r1_nchw_kcyx_nkhw.hpp  +4 −1
composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp  +333 −0
driver/include/device.hpp  +50 −5
driver/include/device_col2im_eb_nchw.hpp  +2 −1
driver/include/device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw.hpp  +3 −1
driver/include/device_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw.hpp  +2 −1
driver/include/device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw.hpp  +2 −1
driver/include/device_convolution_backward_data_implicit_gemm_v3r1_nchw_kcyx_nkhw.hpp  +32 −1
driver/include/device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp  +231 −0
driver/include/device_convolution_direct_v2_nchw_kcyx_nkhw.hpp  +7 −7
driver/include/device_convolution_implicit_gemm_v1_chwn_cyxk_khwn.hpp  +2 −1
driver/include/device_convolution_implicit_gemm_v1_chwn_cyxk_khwn_padded.hpp  +2 −1
driver/include/device_convolution_implicit_gemm_v1_nchw_cyxk_nkhw.hpp  +2 −1
driver/include/device_convolution_implicit_gemm_v2_chwn_cyxk_khwn.hpp  +2 −1
driver/include/device_convolution_implicit_gemm_v3_nchw_cyxk_nkhw.hpp  +2 −1
driver/include/device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp  +2 −1
driver/include/device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_deprecated.hpp  +2 −1
driver/include/device_convolution_implicit_gemm_v4r2_nchw_kcyx_nkhw.hpp  +2 −1
driver/include/device_convolution_implicit_gemm_v4r3_nchw_kcyx_nkhw.hpp  +2 −1
driver/include/device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp  +3 −1
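The headline addition is the v4r1 backward-data kernel. Where v3r1 iterates over all (ytilda, xtilda) sub-problems inside a single kernel, v4r1 makes each sub-problem its own gridwise GEMM and launches one kernel per (ytilda, xtilda) pair, which is the "multiple kernel launch" of the commit title. The sketch below is a minimal standalone check of the index arithmetic used by the new gridwise header further down; math::hcf is the highest common factor (gcd), and the 3x3 filter / stride-2 / dilation-1 / 4x4-output numbers are an assumed example, not values from the commit.

```cpp
// Standalone check of the tilde/dot factorization used by v4r1,
// mirroring the formulas in the new gridwise header (std::gcd == math::hcf).
#include <cstdio>
#include <numeric> // std::gcd (C++17)

int main()
{
    // Assumed sample config: 3x3 filter, stride 2, dilation 1, 4x4 output.
    const int Y = 3, X = 3;
    const int ConvStrideH = 2, ConvStrideW = 2;
    const int ConvDilationH = 1, ConvDilationW = 1;
    const int Ho = 4, Wo = 4;

    const int hcf_h = std::gcd(ConvStrideH, ConvDilationH); // 1
    const int hcf_w = std::gcd(ConvStrideW, ConvDilationW); // 1

    const int Ytilda = ConvStrideH / hcf_h; // 2: launches along H
    const int Xtilda = ConvStrideW / hcf_w; // 2: launches along W

    const int Ydot = (Y + Ytilda - 1) / Ytilda; // 2: filter taps folded into GemmK
    const int Xdot = (X + Xtilda - 1) / Xtilda; // 2

    const int Htilda = Ho + (ConvDilationH * (Y - 1) + ConvStrideH - 1) / ConvStrideH; // 5
    const int Wtilda = Wo + (ConvDilationW * (X - 1) + ConvStrideW - 1) / ConvStrideW; // 5

    printf("Ytilda %d, Xtilda %d -> %d kernel launches\n", Ytilda, Xtilda, Ytilda * Xtilda);
    printf("Ydot %d, Xdot %d, Htilda %d, Wtilda %d\n", Ydot, Xdot, Htilda, Wtilda);
}
```

With this configuration the driver issues Ytilda * Xtilda = 4 kernel launches, each covering a disjoint subset of the filter taps.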
composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v3r1_nchw_kcyx_nkhw.hpp (+4 −1)

```diff
@@ -195,7 +195,7 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v3r1_nchw_kcyx_nkhw
 #if 1 // debug
         constexpr bool in_skip_all_out_of_bound_check = false;
 #else
         constexpr bool in_skip_all_out_of_bound_check = true;
 #endif
         // input tensor
@@ -381,6 +381,9 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v3r1_nchw_kcyx_nkhw
                         GemmCThreadCopyDstDataPerWrite_GemmN1>{};

                 gridwise_gemm.Run(p_wei_global, p_out_global, p_in_global, p_shared_block);
+
+                // is synchronization necessary?
+                __syncthreads();
             });
         });
     }
```
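For context, the barrier added above sits inside v3r1's in-kernel loop over (ytilda, xtilda) sub-GEMMs, which share one LDS workspace (p_shared_block). The commit's own comment hedges on whether it is needed; the usual hazard such a trailing barrier guards against is sketched below with a hypothetical kernel (not from the commit): without it, a fast thread can begin the next iteration's shared-memory writes while slower threads are still reading.

```cpp
// Minimal sketch (hypothetical, not from the commit) of why a trailing
// barrier is needed when one __shared__ buffer is reused across iterations.
__global__ void reuse_lds(const float* in, float* out, int niter)
{
    __shared__ float buf[256];

    for(int it = 0; it < niter; ++it)
    {
        buf[threadIdx.x] = in[it * 256 + threadIdx.x]; // write LDS
        __syncthreads();                               // make writes visible

        // every thread reads a neighbour's element
        float v = buf[(threadIdx.x + 1) % 256];
        out[it * 256 + threadIdx.x] = v;

        // without this barrier, a fast thread could start the next
        // iteration's write while a slow thread still reads buf
        __syncthreads();
    }
}
```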
composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp (new file, +333 −0)

```cpp
#ifndef CK_GRIDWISE_CONVOLUTION_BACKWARD_DATA_IMPLICIT_GEMM_V4R1_NCHW_KCYX_NKHW_HPP
#define CK_GRIDWISE_CONVOLUTION_BACKWARD_DATA_IMPLICIT_GEMM_V4R1_NCHW_KCYX_NKHW_HPP

#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm.hpp"

namespace ck {

// GemmM = C
// GemmN = N * Htilda * Wtilda
// GemmK = K * YdotNonZero * XdotNonZero
template <index_t GridSize,
          index_t BlockSize,
          typename Float,
          typename AccFloat,
          typename InGlobalDesc,
          typename WeiGlobalDesc,
          typename OutGlobalDesc,
          typename ConvStrides,
          typename ConvDilations,
          typename InLeftPads,
          typename InRightPads,
          index_t Iter_ytilda,
          index_t Iter_xtilda,
          index_t GemmMPerBlock,
          index_t GemmNPerBlock,
          index_t GemmKPerBlock,
          index_t GemmMPerThreadSubC,
          index_t GemmNPerThreadSubC,
          index_t GemmMLevel0Cluster,
          index_t GemmNLevel0Cluster,
          index_t GemmMLevel1Cluster,
          index_t GemmNLevel1Cluster,
          index_t GemmKPerThreadLoop,
          index_t GemmThreadGemmDataPerReadM,
          index_t GemmThreadGemmDataPerReadN,
          typename GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
          typename GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
          index_t GemmABlockCopySrcDataPerRead_GemmM,
          index_t GemmABlockCopyDstDataPerWrite_GemmM,
          typename GemmBBlockCopyThreadSliceLengths_GemmK_GemmN,
          typename GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
          index_t GemmBBlockCopySrcDataPerRead_GemmN,
          index_t GemmBBlockCopyDstDataPerWrite_GemmN,
          index_t GemmCThreadCopyDstDataPerWrite_GemmN1>
struct GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw
{
    __device__ void Run(Float* __restrict__ p_in_global,
                        const Float* __restrict__ p_wei_global,
                        const Float* __restrict__ p_out_global) const
    {
        constexpr auto in_n_c_hi_wi_global_desc  = InGlobalDesc{};
        constexpr auto wei_k_c_y_x_global_desc   = WeiGlobalDesc{};
        constexpr auto out_n_k_ho_wo_global_desc = OutGlobalDesc{};

        constexpr index_t N  = in_n_c_hi_wi_global_desc.GetLengths()[0];
        constexpr index_t C  = in_n_c_hi_wi_global_desc.GetLengths()[1];
        constexpr index_t Hi = in_n_c_hi_wi_global_desc.GetLengths()[2];
        constexpr index_t Wi = in_n_c_hi_wi_global_desc.GetLengths()[3];

        constexpr index_t K  = out_n_k_ho_wo_global_desc.GetLengths()[1];
        constexpr index_t Ho = out_n_k_ho_wo_global_desc.GetLengths()[2];
        constexpr index_t Wo = out_n_k_ho_wo_global_desc.GetLengths()[3];

        constexpr index_t Y = wei_k_c_y_x_global_desc.GetLengths()[2];
        constexpr index_t X = wei_k_c_y_x_global_desc.GetLengths()[3];

        constexpr index_t ConvStrideH = ConvStrides{}[0];
        constexpr index_t ConvStrideW = ConvStrides{}[1];

        constexpr index_t ConvDilationH = ConvDilations{}[0];
        constexpr index_t ConvDilationW = ConvDilations{}[1];

#if 0 // debug
        // sanity-check for vectorized memory load
        // TODO: this logic may not be correct for bwd-data
        static_assert(
            (Wo == 1 || (ConvStrideW == 1 || GemmCThreadCopyDstDataPerWrite_GemmN1 == 1)) &&
                (X == 1 || ConvDilationW % GemmCThreadCopyDstDataPerWrite_GemmN1 == 0),
            "wrong! aligment requirement for vectorized global load of input tensor will "
            "be violated");
#endif

        constexpr index_t hcf_stride_dilation_h = math::hcf(ConvStrideH, ConvDilationH);
        constexpr index_t hcf_stride_dilation_w = math::hcf(ConvStrideW, ConvDilationW);

        constexpr index_t Ytilda = ConvStrideH / hcf_stride_dilation_h;
        constexpr index_t Xtilda = ConvStrideW / hcf_stride_dilation_w;

        constexpr index_t Ydot = math::integer_divide_ceil(Y, Ytilda);
        constexpr index_t Xdot = math::integer_divide_ceil(X, Xtilda);

        constexpr index_t Htilda =
            Ho + math::integer_divide_ceil(ConvDilationH * (Y - 1), ConvStrideH);
        constexpr index_t Wtilda =
            Wo + math::integer_divide_ceil(ConvDilationW * (X - 1), ConvStrideW);

        constexpr index_t HtildaLeft = math::integer_divide_floor(
            math::max(0, InLeftPads{}[0] - ConvDilationH * (Ytilda - 1)), ConvStrides{}[0]);
        constexpr index_t WtildaLeft = math::integer_divide_floor(
            math::max(0, InLeftPads{}[1] - ConvDilationW * (Xtilda - 1)), ConvStrides{}[1]);

        constexpr index_t HtildaRight = math::min(
            Htilda, math::integer_divide_ceil(InLeftPads{}[0] + Hi - 1, ConvStrides{}[0]) + 1);
        constexpr index_t WtildaRight = math::min(
            Wtilda, math::integer_divide_ceil(InLeftPads{}[1] + Wi - 1, ConvStrides{}[1]) + 1);

        constexpr index_t HtildaTrim = HtildaRight - HtildaLeft;
        constexpr index_t WtildaTrim = WtildaRight - WtildaLeft;

        constexpr bool wei_skip_all_out_of_bound_check = true;

        // weight tensor
        constexpr auto wei_k_c_ydot_ytilda_xdot_xtilda_global_desc = transform_tensor_descriptor(
            wei_k_c_y_x_global_desc,
            make_tuple(PassThrough<K>{},
                       PassThrough<C>{},
                       Embed<Y,
                             Sequence<Ydot, Ytilda>,
                             Sequence<ConvStrideH / hcf_stride_dilation_h, 1, 0>,
                             wei_skip_all_out_of_bound_check>{},
                       Embed<X,
                             Sequence<Xdot, Xtilda>,
                             Sequence<ConvStrideW / hcf_stride_dilation_w, 1, 0>,
                             wei_skip_all_out_of_bound_check>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));

#if 1 // debug
        constexpr bool out_skip_all_out_of_bound_check = false;
#else
        constexpr bool out_skip_all_out_of_bound_check = true;
#endif

        // output tensor
        constexpr auto out_n_k_ydot_htilda_xdot_wtilda_global_desc = transform_tensor_descriptor(
            out_n_k_ho_wo_global_desc,
            make_tuple(PassThrough<N>{},
                       PassThrough<K>{},
                       Embed<Ho,
                             Sequence<Ydot, Htilda>,
                             Sequence<-ConvDilationH / hcf_stride_dilation_h, 1, 0>,
                             out_skip_all_out_of_bound_check>{},
                       Embed<Wo,
                             Sequence<Xdot, Wtilda>,
                             Sequence<-ConvDilationW / hcf_stride_dilation_w, 1, 0>,
                             out_skip_all_out_of_bound_check>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));

        constexpr auto out_n_k_ydot_htildatrim_xdot_wtildatrim_global_desc =
            transform_tensor_descriptor(
                out_n_k_ydot_htilda_xdot_wtilda_global_desc,
                make_tuple(PassThrough<N>{},
                           PassThrough<K>{},
                           PassThrough<Ytilda>{},
                           PassThrough<Xtilda>{},
                           Trim<Sequence<Htilda, Wtilda>,
                                Sequence<HtildaLeft, WtildaLeft>,
                                Sequence<Htilda - HtildaRight, Wtilda - WtildaRight>>{}),
                make_tuple(
                    Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}),
                make_tuple(
                    Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}));

#if 1 // debug
        constexpr bool in_skip_all_out_of_bound_check = false;
#else
        constexpr bool in_skip_all_out_of_bound_check = true;
#endif

        // input tensor
        constexpr auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor(
            in_n_c_hi_wi_global_desc,
            make_tuple(
                PassThrough<N>{},
                PassThrough<C>{},
                Pad<Sequence<Hi, Wi>, InLeftPads, InRightPads, in_skip_all_out_of_bound_check>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}));

        constexpr index_t Hip = in_n_c_hip_wip_global_desc.GetLengths()[2];
        constexpr index_t Wip = in_n_c_hip_wip_global_desc.GetLengths()[3];

        constexpr auto in_n_c_ytilda_htilda_xtilda_wtilda_global_desc =
            transform_tensor_descriptor(
                in_n_c_hip_wip_global_desc,
                make_tuple(PassThrough<N>{},
                           PassThrough<C>{},
                           Embed<Hip,
                                 Sequence<Ytilda, Htilda>,
                                 Sequence<ConvDilationH, ConvStrideH, 0>,
                                 in_skip_all_out_of_bound_check>{},
                           Embed<Wip,
                                 Sequence<Xtilda, Wtilda>,
                                 Sequence<ConvDilationW, ConvStrideW, 0>,
                                 in_skip_all_out_of_bound_check>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));

        constexpr auto in_n_c_ytilda_htildatrim_xtilda_wtildatrim_global_desc =
            transform_tensor_descriptor(
                in_n_c_ytilda_htilda_xtilda_wtilda_global_desc,
                make_tuple(PassThrough<N>{},
                           PassThrough<C>{},
                           PassThrough<Ytilda>{},
                           PassThrough<Xtilda>{},
                           Trim<Sequence<Htilda, Wtilda>,
                                Sequence<HtildaLeft, WtildaLeft>,
                                Sequence<Htilda - HtildaRight, Wtilda - WtildaRight>>{}),
                make_tuple(
                    Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}),
                make_tuple(
                    Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}));

        // GEMM
        constexpr index_t ytilda = Iter_ytilda;
        constexpr index_t xtilda = Iter_xtilda;

        constexpr index_t YdotNonZero = (ytilda + 1) * Ydot <= Y ? Ydot : Y % Ydot;
        constexpr index_t XdotNonZero = (xtilda + 1) * Xdot <= X ? Xdot : X % Xdot;

        // A matrix
        constexpr auto wei_k_c_YdotNonZero_1_XdotNonZero_1_global_desc =
            transform_tensor_descriptor(
                wei_k_c_ydot_ytilda_xdot_xtilda_global_desc,
                make_tuple(PassThrough<K>{},
                           PassThrough<C>{},
                           Trim<Sequence<Ydot, Xdot>,
                                Sequence<0, 0>,
                                Sequence<Ydot - YdotNonZero, Xdot - XdotNonZero>>{},
                           Trim<Sequence<Ytilda, Xtilda>,
                                Sequence<ytilda, xtilda>,
                                Sequence<Ytilda - ytilda - 1, Xtilda - xtilda - 1>>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 4>{}, Sequence<3, 5>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 4>{}, Sequence<3, 5>{}));

        constexpr auto wei_gemmk_gemmm_global_desc = transform_tensor_descriptor(
            wei_k_c_YdotNonZero_1_XdotNonZero_1_global_desc,
            make_tuple(Merge<Sequence<K, YdotNonZero, XdotNonZero>>{}, Merge<Sequence<C, 1, 1>>{}),
            make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}));

        // B matrix
        constexpr auto out_n_k_YdotNonZero_htildatrim_XdotNonZero_wtildatrim_global_desc =
            transform_tensor_descriptor(
                out_n_k_ydot_htildatrim_xdot_wtildatrim_global_desc,
                make_tuple(PassThrough<N>{},
                           PassThrough<K>{},
                           PassThrough<HtildaTrim>{},
                           PassThrough<WtildaTrim>{},
                           Trim<Sequence<Ydot, Xdot>,
                                Sequence<0, 0>,
                                Sequence<Ydot - YdotNonZero, Xdot - XdotNonZero>>{}),
                make_tuple(
                    Sequence<0>{}, Sequence<1>{}, Sequence<3>{}, Sequence<5>{}, Sequence<2, 4>{}),
                make_tuple(
                    Sequence<0>{}, Sequence<1>{}, Sequence<3>{}, Sequence<5>{}, Sequence<2, 4>{}));

        constexpr auto out_gemmk_gemmn_global_desc = transform_tensor_descriptor(
            out_n_k_YdotNonZero_htildatrim_XdotNonZero_wtildatrim_global_desc,
            make_tuple(Merge<Sequence<K, YdotNonZero, XdotNonZero>>{},
                       Merge<Sequence<N, HtildaTrim, WtildaTrim>>{}),
            make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}));

        // C matrix
        constexpr auto in_n_c_1_htildatrim_1_wtildatrim_global_desc =
            transform_tensor_descriptor(
                in_n_c_ytilda_htildatrim_xtilda_wtildatrim_global_desc,
                make_tuple(PassThrough<N>{},
                           PassThrough<C>{},
                           PassThrough<HtildaTrim>{},
                           PassThrough<WtildaTrim>{},
                           Trim<Sequence<Ytilda, Xtilda>,
                                Sequence<ytilda, xtilda>,
                                Sequence<Ytilda - ytilda - 1, Xtilda - xtilda - 1>>{}),
                make_tuple(
                    Sequence<0>{}, Sequence<1>{}, Sequence<3>{}, Sequence<5>{}, Sequence<2, 4>{}),
                make_tuple(
                    Sequence<0>{}, Sequence<1>{}, Sequence<3>{}, Sequence<5>{}, Sequence<2, 4>{}));

        constexpr auto in_gemmm_gemmn_global_desc = transform_tensor_descriptor(
            in_n_c_1_htildatrim_1_wtildatrim_global_desc,
            make_tuple(Merge<Sequence<C, 1, 1>>{}, Merge<Sequence<N, HtildaTrim, WtildaTrim>>{}),
            make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}));

        constexpr auto gridwise_gemm = GridwiseGemmTransposedANormalBNormalC_v1<
            GridSize,
            BlockSize,
            Float,
            AccFloat,
            decltype(wei_gemmk_gemmm_global_desc),
            decltype(out_gemmk_gemmn_global_desc),
            decltype(in_gemmm_gemmn_global_desc),
            InMemoryDataOperation::none,
            GemmMPerBlock,
            GemmNPerBlock,
            GemmKPerBlock,
            GemmMPerThreadSubC,
            GemmNPerThreadSubC,
            GemmMLevel0Cluster,
            GemmNLevel0Cluster,
            GemmMLevel1Cluster,
            GemmNLevel1Cluster,
            GemmKPerThreadLoop,
            GemmThreadGemmDataPerReadM,
            GemmThreadGemmDataPerReadN,
            GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
            GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
            Sequence<0, 1>,
            Sequence<0, 1>,
            1,
            GemmABlockCopySrcDataPerRead_GemmM,
            GemmABlockCopyDstDataPerWrite_GemmM,
            GemmBBlockCopyThreadSliceLengths_GemmK_GemmN,
            GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
            Sequence<0, 1>,
            Sequence<0, 1>,
            1,
            GemmBBlockCopySrcDataPerRead_GemmN,
            GemmBBlockCopyDstDataPerWrite_GemmN,
            Sequence<0, 1, 2, 3>,
            3,
            GemmCThreadCopyDstDataPerWrite_GemmN1>{};

        gridwise_gemm.Run(p_wei_global, p_out_global, p_in_global);
    }
};

} // namespace ck
#endif
```
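A subtlety in the header above: when Ytilda does not divide Y (or Xtilda does not divide X), the last slices cover fewer filter taps than Ydot * Xdot, which the YdotNonZero/XdotNonZero trims account for. Below is a small standalone check of that rule, reusing the assumed Y = 3, Ytilda = 2 example from earlier (hypothetical values, for illustration only).

```cpp
// Check that the YdotNonZero rule partitions all Y filter taps across the
// Ytilda slices without overlap (assumed example: Y = 3, Ytilda = 2).
#include <cassert>

int main()
{
    const int Y = 3, Ytilda = 2;
    const int Ydot = (Y + Ytilda - 1) / Ytilda; // 2

    int total_taps = 0;
    for(int ytilda = 0; ytilda < Ytilda; ++ytilda)
    {
        const int YdotNonZero = (ytilda + 1) * Ydot <= Y ? Ydot : Y % Ydot;
        total_taps += YdotNonZero; // slice 0 -> 2 taps, slice 1 -> 1 tap
    }
    assert(total_taps == Y); // every tap handled by exactly one launch
}
```

Slice 0 handles two taps and slice 1 handles one, so the Ytilda launches together cover all Y taps exactly once.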
driver/include/device.hpp (+50 −5)

New version of the launch helpers (hunk @@ -30,33 +30,78 @@ struct KernelTimer):

```cpp
    std::unique_ptr<KernelTimerImpl> impl;
};

#if CK_DEVICE_BACKEND_AMD
using device_stream_t = hipStream_t;

template <typename... Args, typename F>
void launch_kernel(F kernel,
                   dim3 grid_dim,
                   dim3 block_dim,
                   std::size_t lds_byte,
                   hipStream_t stream_id,
                   Args... args)
{
    hipLaunchKernelGGL(kernel, grid_dim, block_dim, lds_byte, stream_id, args...);
}

template <typename... Args, typename F>
float launch_and_time_kernel(F kernel,
                             dim3 grid_dim,
                             dim3 block_dim,
                             std::size_t lds_byte,
                             hipStream_t stream_id,
                             Args... args)
{
    KernelTimer timer;

    timer.Start();
    hipLaunchKernelGGL(kernel, grid_dim, block_dim, lds_byte, stream_id, args...);
    timer.End();

    hipGetErrorString(hipGetLastError());

    return timer.GetElapsedTime();
}
#elif CK_DEVICE_BACKEND_NVIDIA
using device_stream_t = cudaStream_t;

template <typename... Args, typename F>
void launch_kernel(F kernel,
                   dim3 grid_dim,
                   dim3 block_dim,
                   std::size_t lds_byte,
                   cudaStream_t stream_id,
                   Args... args)
{
    // note: the committed version calls cudaLaunchKernel(f, ..., p_args, ...)
    // without defining f or p_args; reconstructed here with the same casts
    // the timing overload uses so the function actually compiles
    const void* f  = reinterpret_cast<const void*>(kernel);
    void* p_args[] = {&args...};

    cudaLaunchKernel(f, grid_dim, block_dim, p_args, lds_byte, stream_id);
}

template <typename... Args, typename F>
float launch_and_time_kernel(F kernel,
                             dim3 grid_dim,
                             dim3 block_dim,
                             std::size_t lds_byte,
                             cudaStream_t stream_id,
                             Args... args)
{
    KernelTimer timer;

    const void* f  = reinterpret_cast<const void*>(kernel);
    void* p_args[] = {&args...};

    timer.Start();
    cudaError_t error = cudaLaunchKernel(f, grid_dim, block_dim, p_args, lds_byte, stream_id);
    timer.End();

    checkCudaErrors(error);

    return timer.GetElapsedTime();
}
#endif
```
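The net effect of the device.hpp change is an API split: launch_kernel now just enqueues on an explicit stream, while launch_and_time_kernel wraps the same launch in a KernelTimer. That lets the v4r1 driver time a whole batch of launches itself instead of timing each one. A hedged usage sketch follows; my_kernel and its arguments are hypothetical, not part of the commit.

```cpp
#include <cstdio>

__global__ void my_kernel(float* p, int n); // hypothetical example kernel

void example(float* p_dev, int n, device_stream_t stream)
{
    // enqueue on the given stream without timing
    launch_kernel(my_kernel, dim3(64), dim3(256), 0, stream, p_dev, n);

    // enqueue and return the measured elapsed time in milliseconds
    float ms = launch_and_time_kernel(my_kernel, dim3(64), dim3(256), 0, stream, p_dev, n);
    printf("my_kernel: %f ms\n", ms);
}
```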
driver/include/device_col2im_eb_nchw.hpp (+2 −1)

```diff
@@ -88,7 +88,8 @@ void device_col2im_eb_nchw(ColDesc,
     for(index_t i = 0; i < nrepeat; ++i)
     {
-        float time = launch_kernel(run_gridwise_operation<decltype(gridwise_col2im),
+        float time = launch_and_time_kernel(run_gridwise_operation<decltype(gridwise_col2im),
                                             const T* const __restrict__,
                                             T* const __restrict__>,
                                             dim3(GridSize),
```
driver/include/device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw.hpp (+3 −1)

```diff
@@ -121,13 +121,15 @@ void device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw(InDesc i
     for(index_t i = 0; i < nrepeat; ++i)
     {
-        float time = launch_kernel(run_gridwise_operation<decltype(gridwise_conv),
+        float time = launch_and_time_kernel(run_gridwise_operation<decltype(gridwise_conv),
                                             T* const __restrict__,
                                             const T* const __restrict__,
                                             const T* const __restrict__>,
                                             dim3(GridSize),
                                             dim3(BlockSize),
                                             0,
+                                            0,
                                             gridwise_conv,
                                             const_cast<T* const __restrict__>(
                                                 static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer())),
```
driver/include/device_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw.hpp (+2 −1)

```diff
@@ -129,7 +129,8 @@ void device_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw(InDesc i
     for(index_t i = 0; i < nrepeat; ++i)
     {
-        float time = launch_kernel(run_gridwise_operation<decltype(gridwise_conv),
+        float time = launch_and_time_kernel(run_gridwise_operation<decltype(gridwise_conv),
                                             T* const __restrict__,
                                             const T* const __restrict__,
                                             const T* const __restrict__>,
```
driver/include/device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw.hpp (+2 −1)

```diff
@@ -217,7 +217,8 @@ void device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw(InDesc i
     for(index_t i = 0; i < nrepeat; ++i)
     {
-        float time = launch_kernel(run_gridwise_operation<decltype(gridwise_conv),
+        float time = launch_and_time_kernel(run_gridwise_operation<decltype(gridwise_conv),
                                             T* const __restrict__,
                                             const T* const __restrict__,
                                             const T* const __restrict__>,
```
driver/include/device_convolution_backward_data_implicit_gemm_v3r1_nchw_kcyx_nkhw.hpp (+32 −1)

```diff
@@ -84,6 +84,36 @@ void device_convolution_backward_data_implicit_gemm_v3r1_nchw_kcyx_nkhw(InDesc i
     constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN  = 1;
     constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;
+
+    constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
+#elif 1
+    // BlockSize = 256, each thread hold 64 data
+    constexpr index_t BlockSize = 256;
+
+    constexpr index_t GemmMPerBlock = 128;
+    constexpr index_t GemmNPerBlock = 128;
+    constexpr index_t GemmKPerBlock = 4;
+
+    constexpr index_t GemmMPerThreadSubC = 4;
+    constexpr index_t GemmNPerThreadSubC = 4;
+    constexpr index_t GemmMLevel0Cluster = 4;
+    constexpr index_t GemmNLevel0Cluster = 4;
+    constexpr index_t GemmMLevel1Cluster = 4;
+    constexpr index_t GemmNLevel1Cluster = 4;
+    constexpr index_t GemmKPerThreadLoop = 1;
+
+    constexpr index_t GemmThreadGemmDataPerReadM = 4;
+    constexpr index_t GemmThreadGemmDataPerReadN = 4;
+
+    using GemmABlockCopyThreadSliceLengths_GemmK_GemmM   = Sequence<2, 1>;
+    using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<2, 128>;
+
+    constexpr index_t GemmABlockCopySrcDataPerRead_GemmM  = 1;
+    constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1;
+
+    using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN   = Sequence<2, 1>;
+    using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<2, 128>;
+
+    constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN  = 1;
+    constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;

     constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
 #endif
@@ -156,7 +186,8 @@ void device_convolution_backward_data_implicit_gemm_v3r1_nchw_kcyx_nkhw(InDesc i
     for(index_t i = 0; i < nrepeat; ++i)
     {
-        float time = launch_kernel(run_gridwise_operation<decltype(gridwise_conv),
+        float time = launch_and_time_kernel(run_gridwise_operation<decltype(gridwise_conv),
                                             T* const __restrict__,
                                             const T* const __restrict__,
                                             const T* const __restrict__>,
```
driver/include/device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp (new file, +231 −0)

```cpp
#pragma once
#include <unistd.h>
#include "device.hpp"
#include "tensor.hpp"
#include "gridwise_operation_wrapper.hpp"
#include "gridwise_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp"

template <typename T,
          typename InDesc,
          typename WeiDesc,
          typename OutDesc,
          typename ConvStrides,
          typename ConvDilations,
          typename InLeftPads,
          typename InRightPads>
void device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc in_nchw_desc,
                                                                        Tensor<T>& in_nchw,
                                                                        WeiDesc wei_kcyx_desc,
                                                                        const Tensor<T>& wei_kcyx,
                                                                        OutDesc out_nkhw_desc,
                                                                        const Tensor<T>& out_nkhw,
                                                                        ConvStrides,
                                                                        ConvDilations,
                                                                        InLeftPads,
                                                                        InRightPads,
                                                                        std::size_t nrepeat)
{
    using namespace ck;

    constexpr index_t N  = out_nkhw_desc.GetLengths()[0];
    constexpr index_t K  = out_nkhw_desc.GetLengths()[1];
    constexpr index_t C  = wei_kcyx_desc.GetLengths()[1];
    constexpr index_t Hi = in_nchw_desc.GetLengths()[2];
    constexpr index_t Wi = in_nchw_desc.GetLengths()[3];
    constexpr index_t Ho = out_nkhw_desc.GetLengths()[2];
    constexpr index_t Wo = out_nkhw_desc.GetLengths()[3];
    constexpr index_t Y  = wei_kcyx_desc.GetLengths()[2];
    constexpr index_t X  = wei_kcyx_desc.GetLengths()[3];

    constexpr index_t ConvStrideH = ConvStrides{}[0];
    constexpr index_t ConvStrideW = ConvStrides{}[1];

    constexpr index_t ConvDilationH = ConvDilations{}[0];
    constexpr index_t ConvDilationW = ConvDilations{}[1];

    std::size_t data_sz = sizeof(T);
    DeviceMem in_nchw_device_buf(data_sz * in_nchw.mDesc.GetElementSpace());
    DeviceMem wei_kcyx_device_buf(data_sz * wei_kcyx.mDesc.GetElementSpace());
    DeviceMem out_nkhw_device_buf(data_sz * out_nkhw.mDesc.GetElementSpace());

    in_nchw_device_buf.ToDevice(in_nchw.mData.data());
    wei_kcyx_device_buf.ToDevice(wei_kcyx.mData.data());
    out_nkhw_device_buf.ToDevice(out_nkhw.mData.data());

#if 1
    // BlockSize = 256, each thread holds 64 data
    constexpr index_t BlockSize = 256;

    constexpr index_t GemmMPerBlock = 128;
    constexpr index_t GemmNPerBlock = 128;
    constexpr index_t GemmKPerBlock = 8;

    constexpr index_t GemmMPerThreadSubC = 4;
    constexpr index_t GemmNPerThreadSubC = 4;
    constexpr index_t GemmMLevel0Cluster = 4;
    constexpr index_t GemmNLevel0Cluster = 4;
    constexpr index_t GemmMLevel1Cluster = 4;
    constexpr index_t GemmNLevel1Cluster = 4;
    constexpr index_t GemmKPerThreadLoop = 1;

    constexpr index_t GemmThreadGemmDataPerReadM = 4;
    constexpr index_t GemmThreadGemmDataPerReadN = 4;

    using GemmABlockCopyThreadSliceLengths_GemmK_GemmM   = Sequence<4, 1>;
    using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<2, 128>;

    constexpr index_t GemmABlockCopySrcDataPerRead_GemmM  = 1;
    constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1;

    using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN   = Sequence<4, 1>;
    using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<2, 128>;

    constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN  = 1;
    constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;

    constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
#elif 1
    // BlockSize = 256, each thread holds 64 data
    constexpr index_t BlockSize = 256;

    constexpr index_t GemmMPerBlock = 128;
    constexpr index_t GemmNPerBlock = 128;
    constexpr index_t GemmKPerBlock = 16;

    constexpr index_t GemmMPerThreadSubC = 4;
    constexpr index_t GemmNPerThreadSubC = 4;
    constexpr index_t GemmMLevel0Cluster = 4;
    constexpr index_t GemmNLevel0Cluster = 4;
    constexpr index_t GemmMLevel1Cluster = 4;
    constexpr index_t GemmNLevel1Cluster = 4;
    constexpr index_t GemmKPerThreadLoop = 1;

    constexpr index_t GemmThreadGemmDataPerReadM = 4;
    constexpr index_t GemmThreadGemmDataPerReadN = 4;

    using GemmABlockCopyThreadSliceLengths_GemmK_GemmM   = Sequence<8, 1>;
    using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<2, 128>;

    constexpr index_t GemmABlockCopySrcDataPerRead_GemmM  = 1;
    constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1;

    using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN   = Sequence<8, 1>;
    using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<2, 128>;

    constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN  = 1;
    constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;

    constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
#endif

    constexpr index_t hcf_stride_dilation_h = math::hcf(ConvStrideH, ConvDilationH);
    constexpr index_t hcf_stride_dilation_w = math::hcf(ConvStrideW, ConvDilationW);

    constexpr index_t Ytilda = ConvStrideH / hcf_stride_dilation_h;
    constexpr index_t Xtilda = ConvStrideW / hcf_stride_dilation_w;

    constexpr index_t Ydot = math::integer_divide_ceil(Y, Ytilda);
    constexpr index_t Xdot = math::integer_divide_ceil(X, Xtilda);

    constexpr index_t Htilda = Ho + math::integer_divide_ceil(ConvDilationH * (Y - 1), ConvStrideH);
    constexpr index_t Wtilda = Wo + math::integer_divide_ceil(ConvDilationW * (X - 1), ConvStrideW);

    constexpr index_t HtildaLeft = math::integer_divide_floor(
        math::max(0, InLeftPads{}[0] - ConvDilationH * (Ytilda - 1)), ConvStrides{}[0]);
    constexpr index_t WtildaLeft = math::integer_divide_floor(
        math::max(0, InLeftPads{}[1] - ConvDilationW * (Xtilda - 1)), ConvStrides{}[1]);

    constexpr index_t HtildaRight = math::min(
        Htilda, math::integer_divide_ceil(InLeftPads{}[0] + Hi - 1, ConvStrides{}[0]) + 1);
    constexpr index_t WtildaRight = math::min(
        Wtilda, math::integer_divide_ceil(InLeftPads{}[1] + Wi - 1, ConvStrides{}[1]) + 1);

    constexpr index_t HtildaTrim = HtildaRight - HtildaLeft;
    constexpr index_t WtildaTrim = WtildaRight - WtildaLeft;

    constexpr index_t GemmM = C;
    constexpr index_t GemmN = N * HtildaTrim * WtildaTrim;

    constexpr index_t GridSize = math::integer_divide_ceil(GemmM, GemmMPerBlock) *
                                 math::integer_divide_ceil(GemmN, GemmNPerBlock);

    printf("%s: BlockSize %u, GridSize %u\n", __func__, BlockSize, GridSize);

    for(index_t i = 0; i < nrepeat; ++i)
    {
        KernelTimer timer;
        timer.Start();

        static_for<0, Ytilda, 1>{}([&](auto ytilda_) {
            static_for<0, Xtilda, 1>{}([&](auto xtilda_) {
                constexpr index_t ytilda = decltype(ytilda_){};
                constexpr index_t xtilda = decltype(xtilda_){};

                constexpr auto gridwise_conv =
                    GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw<
                        GridSize,
                        BlockSize,
                        T,
                        T,
                        decltype(in_nchw_desc),
                        decltype(wei_kcyx_desc),
                        decltype(out_nkhw_desc),
                        ConvStrides,
                        ConvDilations,
                        InLeftPads,
                        InRightPads,
                        ytilda,
                        xtilda,
                        GemmMPerBlock,
                        GemmNPerBlock,
                        GemmKPerBlock,
                        GemmMPerThreadSubC,
                        GemmNPerThreadSubC,
                        GemmMLevel0Cluster,
                        GemmNLevel0Cluster,
                        GemmMLevel1Cluster,
                        GemmNLevel1Cluster,
                        GemmKPerThreadLoop,
                        GemmThreadGemmDataPerReadM,
                        GemmThreadGemmDataPerReadN,
                        GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
                        GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
                        GemmABlockCopySrcDataPerRead_GemmM,
                        GemmABlockCopyDstDataPerWrite_GemmM,
                        GemmBBlockCopyThreadSliceLengths_GemmK_GemmN,
                        GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
                        GemmBBlockCopySrcDataPerRead_GemmN,
                        GemmBBlockCopyDstDataPerWrite_GemmN,
                        GemmCThreadCopyDstDataPerWrite_GemmN1>{};

                launch_and_time_kernel(run_gridwise_operation<decltype(gridwise_conv),
                                                              T* const __restrict__,
                                                              const T* const __restrict__,
                                                              const T* const __restrict__>,
                                       dim3(GridSize),
                                       dim3(BlockSize),
                                       0,
                                       0,
                                       gridwise_conv,
                                       const_cast<T* const __restrict__>(
                                           static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer())),
                                       const_cast<const T* const __restrict__>(
                                           static_cast<T*>(wei_kcyx_device_buf.GetDeviceBuffer())),
                                       const_cast<const T* const __restrict__>(
                                           static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer())));
            });
        });

        timer.End();

        float time = timer.GetElapsedTime();

        printf("Elapsed time : %f ms, %f TFlop/s\n",
               time,
               (float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) /
                   (std::size_t(1000) * 1000 * 1000) / time);

        usleep(std::min(time * 1000, float(10000)));
    }

    in_nchw_device_buf.FromDevice(in_nchw.mData.data());
}
```
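The driver unrolls the (ytilda, xtilda) iteration with ck::static_for rather than a runtime loop because each sub-problem needs ytilda and xtilda as compile-time template arguments of the gridwise struct; note also that the single KernelTimer brackets all Ytilda * Xtilda launches, so the printed time covers the complete input gradient, not one sub-GEMM. A self-contained sketch of the compile-time loop pattern follows; this static_for is a hypothetical stand-in, not ck's implementation.

```cpp
#include <cstdio>
#include <utility>

// Hypothetical stand-in for ck::static_for: invoke f with each index as a
// compile-time constant, so the index can be used as a template argument.
template <int Begin, int End, typename F>
void static_for(F f)
{
    if constexpr(Begin < End)
    {
        f(std::integral_constant<int, Begin>{});
        static_for<Begin + 1, End>(f);
    }
}

int main()
{
    constexpr int Ytilda = 2, Xtilda = 2;

    static_for<0, Ytilda>([&](auto ytilda_) {
        static_for<0, Xtilda>([&](auto xtilda_) {
            // ytilda/xtilda are constexpr here, as the v4r1 driver requires
            constexpr int ytilda = decltype(ytilda_)::value;
            constexpr int xtilda = decltype(xtilda_)::value;
            printf("launch (ytilda, xtilda) = (%d, %d)\n", ytilda, xtilda);
        });
    });
}
```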
driver/include/device_convolution_direct_v2_nchw_kcyx_nkhw.hpp (+7 −7)

```diff
@@ -82,13 +82,13 @@ void device_convolution_direct_v2_nchw_kcyx_nkhw(InDesc,
                                                      WoPerThread,
                                                      InBlockCopyDataPerRead,
                                                      WeiBlockCopyDataPerRead>;

-    float time = launch_kernel(run_gridwise_convolution_kernel<gridwise_conv, T>,
-                               dim3(GridSize),
-                               dim3(BlockSize),
-                               0,
-                               static_cast<T*>(in_device_buf.GetDeviceBuffer()),
-                               static_cast<T*>(wei_device_buf.GetDeviceBuffer()),
-                               static_cast<T*>(out_device_buf.GetDeviceBuffer()));
+    float time = launch_and_time_kernel(run_gridwise_convolution_kernel<gridwise_conv, T>,
+                                        dim3(GridSize),
+                                        dim3(BlockSize),
+                                        0,
+                                        0,
+                                        static_cast<T*>(in_device_buf.GetDeviceBuffer()),
+                                        static_cast<T*>(wei_device_buf.GetDeviceBuffer()),
+                                        static_cast<T*>(out_device_buf.GetDeviceBuffer()));

     printf("Elapsed time : %f ms\n", time);
     usleep(std::min(time * 1000, float(10000)));
```
driver/include/device_convolution_implicit_gemm_v1_chwn_cyxk_khwn.hpp (+2 −1)

```diff
@@ -458,7 +458,8 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
     for(index_t i = 0; i < nrepeat; ++i)
     {
-        float time = launch_kernel(run_gridwise_convolution_kernel<decltype(gridwise_conv), T>,
+        float time = launch_and_time_kernel(run_gridwise_convolution_kernel<decltype(gridwise_conv), T>,
                                    dim3(GridSize),
                                    dim3(BlockSize),
                                    0,
+                                   0,
```
driver/include/device_convolution_implicit_gemm_v1_chwn_cyxk_khwn_padded.hpp (+2 −1)

```diff
@@ -161,7 +161,8 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn_padded(InDesc,
     for(index_t i = 0; i < nrepeat; ++i)
     {
-        float time = launch_kernel(run_gridwise_convolution_kernel<decltype(gridwise_conv), T>,
+        float time = launch_and_time_kernel(run_gridwise_convolution_kernel<decltype(gridwise_conv), T>,
                                    dim3(GridSize),
                                    dim3(BlockSize),
                                    0,
+                                   0,
```
driver/include/device_convolution_implicit_gemm_v1_nchw_cyxk_nkhw.hpp (+2 −1)

```diff
@@ -354,7 +354,8 @@ void device_convolution_implicit_gemm_v1_nchw_cyxk_nkhw(InDesc,
                                                  WeiBlockCopyDataPerRead_K,
                                                  OutThreadCopyDataPerWrite_W>{};

-    float time = launch_kernel(run_gridwise_convolution_kernel<decltype(gridwise_conv), T>,
+    float time = launch_and_time_kernel(run_gridwise_convolution_kernel<decltype(gridwise_conv), T>,
                                dim3(GridSize),
                                dim3(BlockSize),
                                0,
+                               0,
```
driver/include/device_convolution_implicit_gemm_v2_chwn_cyxk_khwn.hpp (+2 −1)

```diff
@@ -306,7 +306,8 @@ void device_convolution_implicit_gemm_v2_chwn_cyxk_khwn(InDesc,
                                                  WeiBlockCopyDataPerRead,
                                                  OutThreadCopyDataPerWrite>{};

-    float time = launch_kernel(run_gridwise_convolution_kernel<decltype(gridwise_conv), T>,
+    float time = launch_and_time_kernel(run_gridwise_convolution_kernel<decltype(gridwise_conv), T>,
                                dim3(GridSize),
                                dim3(BlockSize),
                                0,
+                               0,
```
driver/include/device_convolution_implicit_gemm_v3_nchw_cyxk_nkhw.hpp (+2 −1)

```diff
@@ -135,7 +135,8 @@ void device_convolution_implicit_gemm_v3_nchw_cyxk_nkhw(InDesc,
                                                  WeiBlockCopyClusterLengths_C_K,
                                                  WeiBlockCopyDataPerAccess_K>{};

-    float time = launch_kernel(run_gridwise_convolution_kernel<decltype(gridwise_conv), T>,
+    float time = launch_and_time_kernel(run_gridwise_convolution_kernel<decltype(gridwise_conv), T>,
                                dim3(GridSize),
                                dim3(BlockSize),
                                0,
+                               0,
```
driver/include/device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp (+2 −1)

```diff
@@ -258,7 +258,8 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
     for(index_t i = 0; i < nrepeat; ++i)
     {
-        float time = launch_kernel(run_gridwise_operation<decltype(gridwise_conv),
+        float time = launch_and_time_kernel(run_gridwise_operation<decltype(gridwise_conv),
                                             const T* const __restrict__,
                                             const T* const __restrict__,
                                             T* const __restrict__>,
```
driver/include/device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_deprecated.hpp (+2 −1)

```diff
@@ -247,7 +247,8 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_deprecated(InDesc,
     for(index_t i = 0; i < nrepeat; ++i)
     {
-        float time = launch_kernel(run_gridwise_convolution_kernel<decltype(gridwise_conv), T>,
+        float time = launch_and_time_kernel(run_gridwise_convolution_kernel<decltype(gridwise_conv), T>,
                                    dim3(GridSize),
                                    dim3(BlockSize),
                                    0,
+                                   0,
```
driver/include/device_convolution_implicit_gemm_v4r2_nchw_kcyx_nkhw.hpp (+2 −1)

```diff
@@ -200,7 +200,8 @@ void device_convolution_implicit_gemm_v4r2_nchw_kcyx_nkhw(InDesc,
                                                  WeiBlockCopySrcDataPerRead_E,
                                                  WeiBlockCopyDstDataPerWrite_K>{};

-    float time = launch_kernel(run_gridwise_convolution_kernel<decltype(gridwise_conv), T>,
+    float time = launch_and_time_kernel(run_gridwise_convolution_kernel<decltype(gridwise_conv), T>,
                                dim3(GridSize),
                                dim3(BlockSize),
                                0,
+                               0,
```
driver/include/device_convolution_implicit_gemm_v4r3_nchw_kcyx_nkhw.hpp (+2 −1)

```diff
@@ -158,7 +158,8 @@ void device_convolution_implicit_gemm_v4r3_nchw_kcyx_nkhw(InDesc,
                                                  WeiBlockCopySrcDataPerRead_E,
                                                  WeiBlockCopyDstDataPerWrite_K>{};

-    float time = launch_kernel(run_gridwise_convolution_kernel<decltype(gridwise_conv), T>,
+    float time = launch_and_time_kernel(run_gridwise_convolution_kernel<decltype(gridwise_conv), T>,
                                dim3(GridSize),
                                dim3(BlockSize),
                                0,
+                               0,
```
driver/include/device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp (+3 −1)

```diff
@@ -225,10 +225,12 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
     for(index_t i = 0; i < nrepeat; ++i)
     {
-        float time = launch_kernel(run_gridwise_convolution_kernel<decltype(gridwise_conv), T>,
+        float time = launch_and_time_kernel(run_gridwise_convolution_kernel<decltype(gridwise_conv), T>,
                                    dim3(GridSize),
                                    dim3(BlockSize),
                                    0,
+                                   0,
                                    static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer()),
                                    static_cast<T*>(wei_kcyx_device_buf.GetDeviceBuffer()),
                                    static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer()));
```