Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
e9c5efc4
Commit
e9c5efc4
authored
Aug 06, 2020
by
Chao Liu
Browse files
add bwd-data-v5r1-nhwc, refactored bwd-data-v4r1-nchw, remove obsolete kernels
parent
fe7b2d9f
Changes
10
Hide whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
297 additions
and
1342 deletions
+297
-1342
composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw.hpp
...ution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw.hpp
+0
-268
composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v3r1_nchw_kcyx_nkhw.hpp
...ution_backward_data_implicit_gemm_v3r1_nchw_kcyx_nkhw.hpp
+0
-388
composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp
...ution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp
+27
-49
composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v5r1_nhwc_kyxc_nhwk.hpp
...ution_backward_data_implicit_gemm_v5r1_nhwc_kyxc_nhwk.hpp
+100
-141
composable_kernel/include/tensor_description/multi_index_transform.hpp
...rnel/include/tensor_description/multi_index_transform.hpp
+43
-0
driver/include/device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw.hpp
...ution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw.hpp
+0
-257
driver/include/device_convolution_backward_data_implicit_gemm_v3r1_nchw_kcyx_nkhw.hpp
...ution_backward_data_implicit_gemm_v3r1_nchw_kcyx_nkhw.hpp
+0
-196
driver/include/device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp
...ution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp
+32
-2
driver/include/device_convolution_backward_data_implicit_gemm_v5r1_nhwc_kyxc_nhwk.hpp
...ution_backward_data_implicit_gemm_v5r1_nhwc_kyxc_nhwk.hpp
+86
-26
driver/src/conv_bwd_data_driver.cpp
driver/src/conv_bwd_data_driver.cpp
+9
-15
No files found.
composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw.hpp
deleted
100644 → 0
View file @
fe7b2d9f
#ifndef CK_GRIDWISE_CONVOLUTION_BACKWARD_DATA_IMPLICIT_GEMM_V2R1_NCHW_KCYX_NKHW_HPP
#define CK_GRIDWISE_CONVOLUTION_BACKWARD_DATA_IMPLICIT_GEMM_V2R1_NCHW_KCYX_NKHW_HPP
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm.hpp"
namespace
ck
{
// GemmM = C * YTilda * XTilda;
// GemmN = N * HTildaSlice * WTildaSlice;
// GemmK = K * YDot * XDot;
template
<
index_t
GridSize
,
index_t
BlockSize
,
typename
Float
,
typename
AccFloat
,
typename
InGlobalDesc
,
typename
WeiGlobalDesc
,
typename
OutGlobalDesc
,
typename
ConvStrides
,
typename
ConvDilations
,
typename
InLeftPads
,
typename
InRightPads
,
index_t
GemmMPerBlock
,
index_t
GemmNPerBlock
,
index_t
GemmKPerBlock
,
index_t
GemmMPerThread
,
index_t
GemmNPerThread
,
index_t
GemmKPerThread
,
index_t
GemmMLevel0Cluster
,
index_t
GemmNLevel0Cluster
,
index_t
GemmMLevel1Cluster
,
index_t
GemmNLevel1Cluster
,
index_t
GemmThreadGemmDataPerReadM
,
index_t
GemmThreadGemmDataPerReadN
,
typename
GemmABlockCopyThreadSliceLengths_GemmK_GemmM
,
typename
GemmABlockCopyThreadClusterLengths_GemmK_GemmM
,
index_t
GemmABlockCopySrcDataPerRead_GemmM
,
index_t
GemmABlockCopyDstDataPerWrite_GemmM
,
typename
GemmBBlockCopyThreadSliceLengths_GemmK_GemmN
,
typename
GemmBBlockCopyThreadClusterLengths_GemmK_GemmN
,
index_t
GemmBBlockCopySrcDataPerRead_GemmN
,
index_t
GemmBBlockCopyDstDataPerWrite_GemmN
,
index_t
GemmCThreadCopyDstDataPerWrite_GemmN1
>
struct
GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw
{
__device__
void
Run
(
Float
*
__restrict__
p_in_global
,
const
Float
*
__restrict__
p_wei_global
,
const
Float
*
__restrict__
p_out_global
)
const
{
constexpr
auto
in_n_c_hi_wi_global_desc
=
InGlobalDesc
{};
constexpr
auto
wei_k_c_y_x_global_desc
=
WeiGlobalDesc
{};
constexpr
auto
out_n_k_ho_wo_global_desc
=
OutGlobalDesc
{};
constexpr
index_t
N
=
in_n_c_hi_wi_global_desc
.
GetLengths
()[
0
];
constexpr
index_t
C
=
in_n_c_hi_wi_global_desc
.
GetLengths
()[
1
];
constexpr
index_t
Hi
=
in_n_c_hi_wi_global_desc
.
GetLengths
()[
2
];
constexpr
index_t
Wi
=
in_n_c_hi_wi_global_desc
.
GetLengths
()[
3
];
constexpr
index_t
K
=
out_n_k_ho_wo_global_desc
.
GetLengths
()[
1
];
constexpr
index_t
Ho
=
out_n_k_ho_wo_global_desc
.
GetLengths
()[
2
];
constexpr
index_t
Wo
=
out_n_k_ho_wo_global_desc
.
GetLengths
()[
3
];
constexpr
index_t
Y
=
wei_k_c_y_x_global_desc
.
GetLengths
()[
2
];
constexpr
index_t
X
=
wei_k_c_y_x_global_desc
.
GetLengths
()[
3
];
constexpr
index_t
ConvStrideH
=
ConvStrides
{}[
0
];
constexpr
index_t
ConvStrideW
=
ConvStrides
{}[
1
];
constexpr
index_t
ConvDilationH
=
ConvDilations
{}[
0
];
constexpr
index_t
ConvDilationW
=
ConvDilations
{}[
1
];
#if 0 // debug
// sanity-check for vectorized memory load
// TODO: this logic may not be correct for bwd-data
static_assert(
(Wo == 1 || (ConvStrideW == 1 || GemmCThreadCopyDstDataPerWrite_GemmN1 == 1)) &&
(X == 1 || ConvDilationW % GemmCThreadCopyDstDataPerWrite_GemmN1 == 0),
"wrong! aligment requirement for vectorized global load of input tensor will "
"be violated");
#endif
constexpr
index_t
GcdStrideDilationH
=
math
::
gcd
(
ConvStrideH
,
ConvDilationH
);
constexpr
index_t
GcdStrideDilationW
=
math
::
gcd
(
ConvStrideW
,
ConvDilationW
);
constexpr
index_t
YTilda
=
ConvStrideH
/
GcdStrideDilationH
;
constexpr
index_t
XTilda
=
ConvStrideW
/
GcdStrideDilationW
;
constexpr
index_t
YDot
=
math
::
integer_divide_ceil
(
Y
,
YTilda
);
constexpr
index_t
XDot
=
math
::
integer_divide_ceil
(
X
,
XTilda
);
constexpr
index_t
HTilda
=
Ho
+
math
::
integer_divide_ceil
(
ConvDilationH
*
(
Y
-
1
),
ConvStrideH
);
constexpr
index_t
WTilda
=
Wo
+
math
::
integer_divide_ceil
(
ConvDilationW
*
(
X
-
1
),
ConvStrideW
);
constexpr
index_t
HTildaLeft
=
math
::
integer_divide_floor
(
math
::
max
(
0
,
InLeftPads
{}[
0
]
-
ConvDilationH
*
(
YTilda
-
1
)),
ConvStrides
{}[
0
]);
constexpr
index_t
WTildaLeft
=
math
::
integer_divide_floor
(
math
::
max
(
0
,
InLeftPads
{}[
1
]
-
ConvDilationW
*
(
XTilda
-
1
)),
ConvStrides
{}[
1
]);
constexpr
index_t
HTildaRight
=
math
::
min
(
HTilda
,
math
::
integer_divide_ceil
(
InLeftPads
{}[
0
]
+
Hi
-
1
,
ConvStrides
{}[
0
])
+
1
);
constexpr
index_t
WTildaRight
=
math
::
min
(
WTilda
,
math
::
integer_divide_ceil
(
InLeftPads
{}[
1
]
+
Wi
-
1
,
ConvStrides
{}[
1
])
+
1
);
constexpr
index_t
HTildaSlice
=
HTildaRight
-
HTildaLeft
;
constexpr
index_t
WTildaSlice
=
WTildaRight
-
WTildaLeft
;
// weight tensor
constexpr
auto
wei_k_c_ydot_ytilda_xdot_xtilda_global_desc
=
transform_tensor_descriptor
(
wei_k_c_y_x_global_desc
,
make_tuple
(
PassThrough
<
K
>
{},
PassThrough
<
C
>
{},
Embed
<
Y
,
Sequence
<
YDot
,
YTilda
>
,
Sequence
<
ConvStrideH
/
GcdStrideDilationH
,
1
,
0
>>
{},
Embed
<
X
,
Sequence
<
XDot
,
XTilda
>
,
Sequence
<
ConvStrideW
/
GcdStrideDilationW
,
1
,
0
>>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
,
3
>
{},
Sequence
<
4
,
5
>
{}));
constexpr
auto
wei_gemmk_gemmm_global_desc
=
transform_tensor_descriptor
(
wei_k_c_ydot_ytilda_xdot_xtilda_global_desc
,
make_tuple
(
Merge
<
Sequence
<
K
,
YDot
,
XDot
>>
{},
Merge
<
Sequence
<
C
,
YTilda
,
XTilda
>>
{}),
make_tuple
(
Sequence
<
0
,
2
,
4
>
{},
Sequence
<
1
,
3
,
5
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
// output tensor
constexpr
auto
out_n_k_ydot_htilda_xdot_wtilda_global_desc
=
transform_tensor_descriptor
(
out_n_k_ho_wo_global_desc
,
make_tuple
(
PassThrough
<
N
>
{},
PassThrough
<
K
>
{},
Embed
<
Ho
,
Sequence
<
YDot
,
HTilda
>
,
Sequence
<-
ConvDilationH
/
GcdStrideDilationH
,
1
,
0
>>
{},
Embed
<
Wo
,
Sequence
<
XDot
,
WTilda
>
,
Sequence
<-
ConvDilationW
/
GcdStrideDilationW
,
1
,
0
>>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
,
3
>
{},
Sequence
<
4
,
5
>
{}));
constexpr
auto
out_n_k_ydot_htildaslice_xdot_wtildaslice_global_desc
=
transform_tensor_descriptor
(
out_n_k_ydot_htilda_xdot_wtilda_global_desc
,
make_tuple
(
PassThrough
<
N
>
{},
PassThrough
<
K
>
{},
PassThrough
<
YTilda
>
{},
PassThrough
<
XTilda
>
{},
Slice
<
Sequence
<
HTilda
,
WTilda
>
,
Sequence
<
HTildaLeft
,
WTildaLeft
>
,
Sequence
<
HTildaRight
,
WTildaRight
>>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
4
>
{},
Sequence
<
3
,
5
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
4
>
{},
Sequence
<
3
,
5
>
{}));
constexpr
auto
out_gemmk_gemmn_global_desc
=
transform_tensor_descriptor
(
out_n_k_ydot_htildaslice_xdot_wtildaslice_global_desc
,
make_tuple
(
Merge
<
Sequence
<
K
,
YDot
,
XDot
>>
{},
Merge
<
Sequence
<
N
,
HTildaSlice
,
WTildaSlice
>>
{}),
make_tuple
(
Sequence
<
1
,
2
,
4
>
{},
Sequence
<
0
,
3
,
5
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
#if 1 // debug
constexpr
bool
in_skip_all_out_of_bound_check
=
false
;
#else
constexpr
bool
in_skip_all_out_of_bound_check
=
true
;
#endif
// input tensor
constexpr
auto
in_n_c_hip_wip_global_desc
=
transform_tensor_descriptor
(
in_n_c_hi_wi_global_desc
,
make_tuple
(
PassThrough
<
N
>
{},
PassThrough
<
C
>
{},
Pad
<
Sequence
<
Hi
,
Wi
>
,
InLeftPads
,
InRightPads
,
in_skip_all_out_of_bound_check
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
,
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
,
3
>
{}));
constexpr
index_t
Hip
=
in_n_c_hip_wip_global_desc
.
GetLengths
()[
2
];
constexpr
index_t
Wip
=
in_n_c_hip_wip_global_desc
.
GetLengths
()[
3
];
constexpr
auto
in_n_c_ytilda_htilda_xtilda_wtilda_global_desc
=
transform_tensor_descriptor
(
in_n_c_hip_wip_global_desc
,
make_tuple
(
PassThrough
<
N
>
{},
PassThrough
<
C
>
{},
Embed
<
Hip
,
Sequence
<
YTilda
,
HTilda
>
,
Sequence
<
ConvDilationH
,
ConvStrideH
,
0
>
,
in_skip_all_out_of_bound_check
>
{},
Embed
<
Wip
,
Sequence
<
XTilda
,
WTilda
>
,
Sequence
<
ConvDilationW
,
ConvStrideW
,
0
>
,
in_skip_all_out_of_bound_check
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
,
3
>
{},
Sequence
<
4
,
5
>
{}));
constexpr
auto
in_n_c_ytilda_htildaslice_xtilda_wtildaslice_global_desc
=
transform_tensor_descriptor
(
in_n_c_ytilda_htilda_xtilda_wtilda_global_desc
,
make_tuple
(
PassThrough
<
N
>
{},
PassThrough
<
C
>
{},
PassThrough
<
YTilda
>
{},
PassThrough
<
XTilda
>
{},
Slice
<
Sequence
<
HTilda
,
WTilda
>
,
Sequence
<
HTildaLeft
,
WTildaLeft
>
,
Sequence
<
HTildaRight
,
WTildaRight
>>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
4
>
{},
Sequence
<
3
,
5
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
4
>
{},
Sequence
<
3
,
5
>
{}));
constexpr
auto
in_gemmm_gemmn_global_desc
=
transform_tensor_descriptor
(
in_n_c_ytilda_htildaslice_xtilda_wtildaslice_global_desc
,
make_tuple
(
Merge
<
Sequence
<
C
,
YTilda
,
XTilda
>>
{},
Merge
<
Sequence
<
N
,
HTildaSlice
,
WTildaSlice
>>
{}),
make_tuple
(
Sequence
<
1
,
2
,
4
>
{},
Sequence
<
0
,
3
,
5
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
// GEMM
constexpr
auto
gridwise_gemm
=
GridwiseGemmTransposedANormalBNormalC_v1
<
GridSize
,
BlockSize
,
Float
,
AccFloat
,
decltype
(
wei_gemmk_gemmm_global_desc
),
decltype
(
out_gemmk_gemmn_global_desc
),
decltype
(
in_gemmm_gemmn_global_desc
),
InMemoryDataOperation
::
Set
,
GemmMPerBlock
,
GemmNPerBlock
,
GemmKPerBlock
,
GemmMPerThread
,
GemmNPerThread
,
GemmKPerThread
,
GemmMLevel0Cluster
,
GemmNLevel0Cluster
,
GemmMLevel1Cluster
,
GemmNLevel1Cluster
,
GemmThreadGemmDataPerReadM
,
GemmThreadGemmDataPerReadN
,
GemmABlockCopyThreadSliceLengths_GemmK_GemmM
,
GemmABlockCopyThreadClusterLengths_GemmK_GemmM
,
Sequence
<
0
,
1
>
,
Sequence
<
0
,
1
>
,
1
,
GemmABlockCopySrcDataPerRead_GemmM
,
GemmABlockCopyDstDataPerWrite_GemmM
,
GemmBBlockCopyThreadSliceLengths_GemmK_GemmN
,
GemmBBlockCopyThreadClusterLengths_GemmK_GemmN
,
Sequence
<
0
,
1
>
,
Sequence
<
0
,
1
>
,
1
,
GemmBBlockCopySrcDataPerRead_GemmN
,
GemmBBlockCopyDstDataPerWrite_GemmN
,
Sequence
<
0
,
1
,
2
,
3
>
,
3
,
GemmCThreadCopyDstDataPerWrite_GemmN1
>
{};
gridwise_gemm
.
Run
(
p_wei_global
,
p_out_global
,
p_in_global
);
}
};
}
// namespace ck
#endif
composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v3r1_nchw_kcyx_nkhw.hpp
deleted
100644 → 0
View file @
fe7b2d9f
#ifndef CK_GRIDWISE_CONVOLUTION_BACKWARD_DATA_IMPLICIT_GEMM_V3R1_NCHW_KCYX_NKHW_HPP
#define CK_GRIDWISE_CONVOLUTION_BACKWARD_DATA_IMPLICIT_GEMM_V3R1_NCHW_KCYX_NKHW_HPP
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm.hpp"
namespace
ck
{
// Number of GEMMs: YTilda * XTilda
// GemmM = C
// GemmN = N * HTildaSlice * WTildaSlice
// GemmK = K * YDotSlice * XDotSlice
template
<
index_t
GridSize
,
index_t
BlockSize
,
typename
Float
,
typename
AccFloat
,
typename
InGlobalDesc
,
typename
WeiGlobalDesc
,
typename
OutGlobalDesc
,
typename
ConvStrides
,
typename
ConvDilations
,
typename
InLeftPads
,
typename
InRightPads
,
index_t
GemmMPerBlock
,
index_t
GemmNPerBlock
,
index_t
GemmKPerBlock
,
index_t
GemmMPerThread
,
index_t
GemmNPerThread
,
index_t
GemmKPerThread
,
index_t
GemmMLevel0Cluster
,
index_t
GemmNLevel0Cluster
,
index_t
GemmMLevel1Cluster
,
index_t
GemmNLevel1Cluster
,
index_t
GemmThreadGemmDataPerReadM
,
index_t
GemmThreadGemmDataPerReadN
,
typename
GemmABlockCopyThreadSliceLengths_GemmK_GemmM
,
typename
GemmABlockCopyThreadClusterLengths_GemmK_GemmM
,
index_t
GemmABlockCopySrcDataPerRead_GemmM
,
index_t
GemmABlockCopyDstDataPerWrite_GemmM
,
typename
GemmBBlockCopyThreadSliceLengths_GemmK_GemmN
,
typename
GemmBBlockCopyThreadClusterLengths_GemmK_GemmN
,
index_t
GemmBBlockCopySrcDataPerRead_GemmN
,
index_t
GemmBBlockCopyDstDataPerWrite_GemmN
,
index_t
GemmCThreadCopyDstDataPerWrite_GemmN1
>
struct
GridwiseConvolutionBackwardDataImplicitGemm_v3r1_nchw_kcyx_nkhw
{
// this is a hack, should query this info from gridwise_gemm instead of duplicate its logic
__host__
__device__
static
constexpr
index_t
GetSharedMemoryNumberOfByte
()
{
constexpr
index_t
max_lds_align
=
math
::
lcm
(
GemmABlockCopyDstDataPerWrite_GemmM
,
GemmBBlockCopyDstDataPerWrite_GemmN
,
GemmThreadGemmDataPerReadM
,
GemmThreadGemmDataPerReadN
);
// A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr
auto
a_gemmk_gemmm_block_desc
=
make_native_tensor_descriptor_aligned
(
Sequence
<
GemmKPerBlock
,
GemmMPerBlock
>
{},
Number
<
max_lds_align
>
{});
// B matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr
auto
b_gemmk_gemmn_block_desc
=
make_native_tensor_descriptor_aligned
(
Sequence
<
GemmKPerBlock
,
GemmNPerBlock
>
{},
Number
<
max_lds_align
>
{});
// LDS allocation for A and B: be careful of alignment
constexpr
index_t
a_block_space
=
math
::
integer_least_multiple
(
a_gemmk_gemmm_block_desc
.
GetElementSpace
(),
max_lds_align
);
constexpr
index_t
b_block_space
=
math
::
integer_least_multiple
(
b_gemmk_gemmn_block_desc
.
GetElementSpace
(),
max_lds_align
);
return
2
*
(
a_block_space
+
b_block_space
)
*
sizeof
(
Float
);
}
__device__
void
Run
(
Float
*
__restrict__
p_in_global
,
const
Float
*
__restrict__
p_wei_global
,
const
Float
*
__restrict__
p_out_global
)
const
{
constexpr
auto
in_n_c_hi_wi_global_desc
=
InGlobalDesc
{};
constexpr
auto
wei_k_c_y_x_global_desc
=
WeiGlobalDesc
{};
constexpr
auto
out_n_k_ho_wo_global_desc
=
OutGlobalDesc
{};
constexpr
index_t
N
=
in_n_c_hi_wi_global_desc
.
GetLengths
()[
0
];
constexpr
index_t
C
=
in_n_c_hi_wi_global_desc
.
GetLengths
()[
1
];
constexpr
index_t
Hi
=
in_n_c_hi_wi_global_desc
.
GetLengths
()[
2
];
constexpr
index_t
Wi
=
in_n_c_hi_wi_global_desc
.
GetLengths
()[
3
];
constexpr
index_t
K
=
out_n_k_ho_wo_global_desc
.
GetLengths
()[
1
];
constexpr
index_t
Ho
=
out_n_k_ho_wo_global_desc
.
GetLengths
()[
2
];
constexpr
index_t
Wo
=
out_n_k_ho_wo_global_desc
.
GetLengths
()[
3
];
constexpr
index_t
Y
=
wei_k_c_y_x_global_desc
.
GetLengths
()[
2
];
constexpr
index_t
X
=
wei_k_c_y_x_global_desc
.
GetLengths
()[
3
];
constexpr
index_t
ConvStrideH
=
ConvStrides
{}[
0
];
constexpr
index_t
ConvStrideW
=
ConvStrides
{}[
1
];
constexpr
index_t
ConvDilationH
=
ConvDilations
{}[
0
];
constexpr
index_t
ConvDilationW
=
ConvDilations
{}[
1
];
#if 0 // debug
// sanity-check for vectorized memory load
// TODO: this logic may not be correct for bwd-data
static_assert(
(Wo == 1 || (ConvStrideW == 1 || GemmCThreadCopyDstDataPerWrite_GemmN1 == 1)) &&
(X == 1 || ConvDilationW % GemmCThreadCopyDstDataPerWrite_GemmN1 == 0),
"wrong! aligment requirement for vectorized global load of input tensor will "
"be violated");
#endif
constexpr
index_t
GcdStrideDilationH
=
math
::
gcd
(
ConvStrideH
,
ConvDilationH
);
constexpr
index_t
GcdStrideDilationW
=
math
::
gcd
(
ConvStrideW
,
ConvDilationW
);
constexpr
index_t
YTilda
=
ConvStrideH
/
GcdStrideDilationH
;
constexpr
index_t
XTilda
=
ConvStrideW
/
GcdStrideDilationW
;
constexpr
index_t
YDot
=
math
::
integer_divide_ceil
(
Y
,
YTilda
);
constexpr
index_t
XDot
=
math
::
integer_divide_ceil
(
X
,
XTilda
);
constexpr
index_t
HTilda
=
Ho
+
math
::
integer_divide_ceil
(
ConvDilationH
*
(
Y
-
1
),
ConvStrideH
);
constexpr
index_t
WTilda
=
Wo
+
math
::
integer_divide_ceil
(
ConvDilationW
*
(
X
-
1
),
ConvStrideW
);
constexpr
index_t
HTildaLeft
=
math
::
integer_divide_floor
(
math
::
max
(
0
,
InLeftPads
{}[
0
]
-
ConvDilationH
*
(
YTilda
-
1
)),
ConvStrides
{}[
0
]);
constexpr
index_t
WTildaLeft
=
math
::
integer_divide_floor
(
math
::
max
(
0
,
InLeftPads
{}[
1
]
-
ConvDilationW
*
(
XTilda
-
1
)),
ConvStrides
{}[
1
]);
constexpr
index_t
HTildaRight
=
math
::
min
(
HTilda
,
math
::
integer_divide_ceil
(
InLeftPads
{}[
0
]
+
Hi
-
1
,
ConvStrides
{}[
0
])
+
1
);
constexpr
index_t
WTildaRight
=
math
::
min
(
WTilda
,
math
::
integer_divide_ceil
(
InLeftPads
{}[
1
]
+
Wi
-
1
,
ConvStrides
{}[
1
])
+
1
);
constexpr
index_t
HTildaSlice
=
HTildaRight
-
HTildaLeft
;
constexpr
index_t
WTildaSlice
=
WTildaRight
-
WTildaLeft
;
constexpr
bool
wei_skip_all_out_of_bound_check
=
true
;
// weight tensor
constexpr
auto
wei_k_c_ydot_ytilda_xdot_xtilda_global_desc
=
transform_tensor_descriptor
(
wei_k_c_y_x_global_desc
,
make_tuple
(
PassThrough
<
K
>
{},
PassThrough
<
C
>
{},
Embed
<
Y
,
Sequence
<
YDot
,
YTilda
>
,
Sequence
<
ConvStrideH
/
GcdStrideDilationH
,
1
,
0
>
,
wei_skip_all_out_of_bound_check
>
{},
Embed
<
X
,
Sequence
<
XDot
,
XTilda
>
,
Sequence
<
ConvStrideW
/
GcdStrideDilationW
,
1
,
0
>
,
wei_skip_all_out_of_bound_check
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
,
3
>
{},
Sequence
<
4
,
5
>
{}));
#if 1 // debug
constexpr
bool
out_skip_all_out_of_bound_check
=
false
;
#else
constexpr
bool
out_skip_all_out_of_bound_check
=
true
;
#endif
// output tensor
constexpr
auto
out_n_k_ydot_htilda_xdot_wtilda_global_desc
=
transform_tensor_descriptor
(
out_n_k_ho_wo_global_desc
,
make_tuple
(
PassThrough
<
N
>
{},
PassThrough
<
K
>
{},
Embed
<
Ho
,
Sequence
<
YDot
,
HTilda
>
,
Sequence
<-
ConvDilationH
/
GcdStrideDilationH
,
1
,
0
>
,
out_skip_all_out_of_bound_check
>
{},
Embed
<
Wo
,
Sequence
<
XDot
,
WTilda
>
,
Sequence
<-
ConvDilationW
/
GcdStrideDilationW
,
1
,
0
>
,
out_skip_all_out_of_bound_check
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
,
3
>
{},
Sequence
<
4
,
5
>
{}));
constexpr
auto
out_n_k_ydot_htildaslice_xdot_wtildaslice_global_desc
=
transform_tensor_descriptor
(
out_n_k_ydot_htilda_xdot_wtilda_global_desc
,
make_tuple
(
PassThrough
<
N
>
{},
PassThrough
<
K
>
{},
PassThrough
<
YTilda
>
{},
PassThrough
<
XTilda
>
{},
Slice
<
Sequence
<
HTilda
,
WTilda
>
,
Sequence
<
HTildaLeft
,
WTildaLeft
>
,
Sequence
<
HTildaRight
,
WTildaRight
>>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
4
>
{},
Sequence
<
3
,
5
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
4
>
{},
Sequence
<
3
,
5
>
{}));
#if 1 // debug
constexpr
bool
in_skip_all_out_of_bound_check
=
false
;
#else
constexpr
bool
in_skip_all_out_of_bound_check
=
true
;
#endif
// input tensor
constexpr
auto
in_n_c_hip_wip_global_desc
=
transform_tensor_descriptor
(
in_n_c_hi_wi_global_desc
,
make_tuple
(
PassThrough
<
N
>
{},
PassThrough
<
C
>
{},
Pad
<
Sequence
<
Hi
,
Wi
>
,
InLeftPads
,
InRightPads
,
in_skip_all_out_of_bound_check
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
,
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
,
3
>
{}));
constexpr
index_t
Hip
=
in_n_c_hip_wip_global_desc
.
GetLengths
()[
2
];
constexpr
index_t
Wip
=
in_n_c_hip_wip_global_desc
.
GetLengths
()[
3
];
constexpr
auto
in_n_c_ytilda_htilda_xtilda_wtilda_global_desc
=
transform_tensor_descriptor
(
in_n_c_hip_wip_global_desc
,
make_tuple
(
PassThrough
<
N
>
{},
PassThrough
<
C
>
{},
Embed
<
Hip
,
Sequence
<
YTilda
,
HTilda
>
,
Sequence
<
ConvDilationH
,
ConvStrideH
,
0
>
,
in_skip_all_out_of_bound_check
>
{},
Embed
<
Wip
,
Sequence
<
XTilda
,
WTilda
>
,
Sequence
<
ConvDilationW
,
ConvStrideW
,
0
>
,
in_skip_all_out_of_bound_check
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
,
3
>
{},
Sequence
<
4
,
5
>
{}));
constexpr
auto
in_n_c_ytilda_htildaslice_xtilda_wtildaslice_global_desc
=
transform_tensor_descriptor
(
in_n_c_ytilda_htilda_xtilda_wtilda_global_desc
,
make_tuple
(
PassThrough
<
N
>
{},
PassThrough
<
C
>
{},
PassThrough
<
YTilda
>
{},
PassThrough
<
XTilda
>
{},
Slice
<
Sequence
<
HTilda
,
WTilda
>
,
Sequence
<
HTildaLeft
,
WTildaLeft
>
,
Sequence
<
HTildaRight
,
WTildaRight
>>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
4
>
{},
Sequence
<
3
,
5
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
4
>
{},
Sequence
<
3
,
5
>
{}));
// GEMMs
constexpr
index_t
shared_block_size
=
GetSharedMemoryNumberOfByte
()
/
sizeof
(
Float
);
__shared__
Float
p_shared_block
[
shared_block_size
];
static_for
<
0
,
YTilda
,
1
>
{}([
&
](
auto
iYTilda_
)
{
static_for
<
0
,
XTilda
,
1
>
{}([
&
](
auto
iXTilda_
)
{
constexpr
index_t
iYTilda
=
decltype
(
iYTilda_
){};
constexpr
index_t
iXTilda
=
decltype
(
iXTilda_
){};
constexpr
index_t
YDotSlice
=
(
iYTilda
+
1
)
*
YDot
<=
Y
?
YDot
:
Y
%
YDot
;
constexpr
index_t
XDotSlice
=
(
iXTilda
+
1
)
*
XDot
<=
X
?
XDot
:
X
%
XDot
;
// A matrix
constexpr
auto
wei_k_c_ydotslice_ytidaslice_xdotslice_xtildaslice_global_desc
=
transform_tensor_descriptor
(
wei_k_c_ydot_ytilda_xdot_xtilda_global_desc
,
make_tuple
(
PassThrough
<
K
>
{},
PassThrough
<
C
>
{},
Slice
<
Sequence
<
YDot
,
XDot
>
,
Sequence
<
0
,
0
>
,
Sequence
<
YDotSlice
,
XDotSlice
>>
{},
Slice
<
Sequence
<
YTilda
,
XTilda
>
,
Sequence
<
iYTilda
,
iXTilda
>
,
Sequence
<
iYTilda
+
1
,
iXTilda
+
1
>>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
,
4
>
{},
Sequence
<
3
,
5
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
,
4
>
{},
Sequence
<
3
,
5
>
{}));
constexpr
auto
wei_gemmk_gemmm_global_desc
=
transform_tensor_descriptor
(
wei_k_c_ydotslice_ytidaslice_xdotslice_xtildaslice_global_desc
,
make_tuple
(
Merge
<
Sequence
<
K
,
YDotSlice
,
XDotSlice
>>
{},
Merge
<
Sequence
<
C
,
1
,
1
>>
{}),
make_tuple
(
Sequence
<
0
,
2
,
4
>
{},
Sequence
<
1
,
3
,
5
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
// B matrix
constexpr
auto
out_n_k_ydotslice_htildaslice_xdotslice_wtildaslice_global_desc
=
transform_tensor_descriptor
(
out_n_k_ydot_htildaslice_xdot_wtildaslice_global_desc
,
make_tuple
(
PassThrough
<
N
>
{},
PassThrough
<
K
>
{},
PassThrough
<
HTildaSlice
>
{},
PassThrough
<
WTildaSlice
>
{},
Slice
<
Sequence
<
YDot
,
XDot
>
,
Sequence
<
0
,
0
>
,
Sequence
<
YDotSlice
,
XDotSlice
>>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
3
>
{},
Sequence
<
5
>
{},
Sequence
<
2
,
4
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
3
>
{},
Sequence
<
5
>
{},
Sequence
<
2
,
4
>
{}));
constexpr
auto
out_gemmk_gemmn_global_desc
=
transform_tensor_descriptor
(
out_n_k_ydotslice_htildaslice_xdotslice_wtildaslice_global_desc
,
make_tuple
(
Merge
<
Sequence
<
K
,
YDotSlice
,
XDotSlice
>>
{},
Merge
<
Sequence
<
N
,
HTildaSlice
,
WTildaSlice
>>
{}),
make_tuple
(
Sequence
<
1
,
2
,
4
>
{},
Sequence
<
0
,
3
,
5
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
// C matrix
constexpr
auto
in_n_c_ytildaslice_htildaslice_xtildaslice_wtildaslice_global_desc
=
transform_tensor_descriptor
(
in_n_c_ytilda_htildaslice_xtilda_wtildaslice_global_desc
,
make_tuple
(
PassThrough
<
N
>
{},
PassThrough
<
C
>
{},
PassThrough
<
HTildaSlice
>
{},
PassThrough
<
WTildaSlice
>
{},
Slice
<
Sequence
<
YTilda
,
XTilda
>
,
Sequence
<
iYTilda
,
iXTilda
>
,
Sequence
<
iYTilda
+
1
,
iXTilda
+
1
>>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
3
>
{},
Sequence
<
5
>
{},
Sequence
<
2
,
4
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
3
>
{},
Sequence
<
5
>
{},
Sequence
<
2
,
4
>
{}));
constexpr
auto
in_gemmm_gemmn_global_desc
=
transform_tensor_descriptor
(
in_n_c_ytildaslice_htildaslice_xtildaslice_wtildaslice_global_desc
,
make_tuple
(
Merge
<
Sequence
<
C
,
1
,
1
>>
{},
Merge
<
Sequence
<
N
,
HTildaSlice
,
WTildaSlice
>>
{}),
make_tuple
(
Sequence
<
1
,
2
,
4
>
{},
Sequence
<
0
,
3
,
5
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
constexpr
auto
gridwise_gemm
=
GridwiseGemmTransposedANormalBNormalC_v1
<
GridSize
,
BlockSize
,
Float
,
AccFloat
,
decltype
(
wei_gemmk_gemmm_global_desc
),
decltype
(
out_gemmk_gemmn_global_desc
),
decltype
(
in_gemmm_gemmn_global_desc
),
InMemoryDataOperation
::
Set
,
GemmMPerBlock
,
GemmNPerBlock
,
GemmKPerBlock
,
GemmMPerThread
,
GemmNPerThread
,
GemmKPerThread
,
GemmMLevel0Cluster
,
GemmNLevel0Cluster
,
GemmMLevel1Cluster
,
GemmNLevel1Cluster
,
GemmThreadGemmDataPerReadM
,
GemmThreadGemmDataPerReadN
,
GemmABlockCopyThreadSliceLengths_GemmK_GemmM
,
GemmABlockCopyThreadClusterLengths_GemmK_GemmM
,
Sequence
<
0
,
1
>
,
Sequence
<
0
,
1
>
,
1
,
GemmABlockCopySrcDataPerRead_GemmM
,
GemmABlockCopyDstDataPerWrite_GemmM
,
GemmBBlockCopyThreadSliceLengths_GemmK_GemmN
,
GemmBBlockCopyThreadClusterLengths_GemmK_GemmN
,
Sequence
<
0
,
1
>
,
Sequence
<
0
,
1
>
,
1
,
GemmBBlockCopySrcDataPerRead_GemmN
,
GemmBBlockCopyDstDataPerWrite_GemmN
,
Sequence
<
0
,
1
,
2
,
3
>
,
3
,
GemmCThreadCopyDstDataPerWrite_GemmN1
>
{};
gridwise_gemm
.
Run
(
p_wei_global
,
p_out_global
,
p_in_global
,
p_shared_block
);
// is synchronization necessary?
__syncthreads
();
});
});
}
};
}
// namespace ck
#endif
composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp
View file @
e9c5efc4
...
@@ -217,23 +217,20 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw
...
@@ -217,23 +217,20 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
,
3
>
{},
Sequence
<
4
,
5
>
{}));
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
,
3
>
{},
Sequence
<
4
,
5
>
{}));
constexpr
auto
wei_k_c_ydotslice_ytidaslice_xdotslice_xtildaslice_global_desc
=
constexpr
auto
wei_k_c_ydotslice_xdotslice_global_desc
=
transform_tensor_descriptor
(
transform_tensor_descriptor
(
wei_k_c_ydot_ytilda_xdot_xtilda_global_desc
,
wei_k_c_ydot_ytilda_xdot_xtilda_global_desc
,
make_tuple
(
make_tuple
(
PassThrough
<
K
>
{},
PassThrough
<
K
>
{},
PassThrough
<
C
>
{},
PassThrough
<
C
>
{},
Slice
<
Sequence
<
YDot
,
XDot
>
,
Sequence
<
0
,
0
>
,
Sequence
<
YDotSlice
,
XDotSlice
>>
{},
Slice
<
Sequence
<
YDot
,
XDot
>
,
Sequence
<
0
,
0
>
,
Sequence
<
YDotSlice
,
XDotSlice
>>
{},
Freeze
<
Sequence
<
YTilda
,
XTilda
>
,
Sequence
<
iYTilda
,
iXTilda
>>
{}),
Slice
<
Sequence
<
YTilda
,
XTilda
>
,
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
,
4
>
{},
Sequence
<
3
,
5
>
{}),
Sequence
<
iYTilda
,
iXTilda
>
,
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
,
3
>
{},
Sequence
<>
{}));
Sequence
<
iYTilda
+
1
,
iXTilda
+
1
>>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
,
4
>
{},
Sequence
<
3
,
5
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
,
4
>
{},
Sequence
<
3
,
5
>
{}));
constexpr
auto
wei_gemmk_gemmm_global_desc
=
transform_tensor_descriptor
(
constexpr
auto
wei_gemmk_gemmm_global_desc
=
transform_tensor_descriptor
(
wei_k_c_ydotslice_
ytidaslice_xdotslice_xtilda
slice_global_desc
,
wei_k_c_ydotslice_
xdot
slice_global_desc
,
make_tuple
(
Merge
<
Sequence
<
K
,
YDotSlice
,
XDotSlice
>>
{},
Merge
<
Sequence
<
C
,
1
,
1
>
>
{}),
make_tuple
(
Merge
<
Sequence
<
K
,
YDotSlice
,
XDotSlice
>>
{},
PassThrough
<
C
>
{}),
make_tuple
(
Sequence
<
0
,
2
,
4
>
{},
Sequence
<
1
,
3
,
5
>
{}),
make_tuple
(
Sequence
<
0
,
2
,
3
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
// B matrix: output tensor
// B matrix: output tensor
...
@@ -265,8 +262,8 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw
...
@@ -265,8 +262,8 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw
out_n_k_ydot_htilda_xdot_wtilda_global_desc
,
out_n_k_ydot_htilda_xdot_wtilda_global_desc
,
make_tuple
(
PassThrough
<
N
>
{},
make_tuple
(
PassThrough
<
N
>
{},
PassThrough
<
K
>
{},
PassThrough
<
K
>
{},
PassThrough
<
Y
Tilda
>
{},
PassThrough
<
Y
Dot
>
{},
PassThrough
<
X
Tilda
>
{},
PassThrough
<
X
Dot
>
{},
Slice
<
Sequence
<
HTilda
,
WTilda
>
,
Slice
<
Sequence
<
HTilda
,
WTilda
>
,
Sequence
<
iHTildaLeft
,
iWTildaLeft
>
,
Sequence
<
iHTildaLeft
,
iWTildaLeft
>
,
Sequence
<
iHTildaRight
,
iWTildaRight
>>
{}),
Sequence
<
iHTildaRight
,
iWTildaRight
>>
{}),
...
@@ -331,40 +328,21 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw
...
@@ -331,40 +328,21 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
,
3
>
{},
Sequence
<
4
,
5
>
{}));
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
,
3
>
{},
Sequence
<
4
,
5
>
{}));
constexpr
auto
in_n_c_ytilda_htildaslice_xtilda_wtildaslice_global_desc
=
constexpr
auto
in_n_c_htildaslice_wtildaslice_global_desc
=
transform_tensor_descriptor
(
transform_tensor_descriptor
(
in_n_c_ytilda_htilda_xtilda_wtilda_global_desc
,
in_n_c_ytilda_htilda_xtilda_wtilda_global_desc
,
make_tuple
(
PassThrough
<
N
>
{},
make_tuple
(
PassThrough
<
N
>
{},
PassThrough
<
C
>
{},
PassThrough
<
C
>
{},
Freeze
<
Sequence
<
YTilda
,
XTilda
>
,
Sequence
<
iYTilda
,
iXTilda
>>
{},
PassThrough
<
YTilda
>
{},
Slice
<
Sequence
<
HTilda
,
WTilda
>
,
PassThrough
<
XTilda
>
{},
Sequence
<
iHTildaLeft
,
iWTildaLeft
>
,
Slice
<
Sequence
<
HTilda
,
WTilda
>
,
Sequence
<
iHTildaRight
,
iWTildaRight
>>
{}),
Sequence
<
iHTildaLeft
,
iWTildaLeft
>
,
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
,
4
>
{},
Sequence
<
3
,
5
>
{}),
Sequence
<
iHTildaRight
,
iWTildaRight
>>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<>
{},
Sequence
<
2
,
3
>
{}));
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
4
>
{},
Sequence
<
3
,
5
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
4
>
{},
Sequence
<
3
,
5
>
{}));
constexpr
auto
in_n_c_ytildaslice_htildaslice_xtildaslice_wtildaslice_global_desc
=
transform_tensor_descriptor
(
in_n_c_ytilda_htildaslice_xtilda_wtildaslice_global_desc
,
make_tuple
(
PassThrough
<
N
>
{},
PassThrough
<
C
>
{},
PassThrough
<
HTildaSlice
>
{},
PassThrough
<
WTildaSlice
>
{},
Slice
<
Sequence
<
YTilda
,
XTilda
>
,
Sequence
<
iYTilda
,
iXTilda
>
,
Sequence
<
iYTilda
+
1
,
iXTilda
+
1
>>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
3
>
{},
Sequence
<
5
>
{},
Sequence
<
2
,
4
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
3
>
{},
Sequence
<
5
>
{},
Sequence
<
2
,
4
>
{}));
constexpr
auto
in_gemmm_gemmn_global_desc
=
transform_tensor_descriptor
(
constexpr
auto
in_gemmm_gemmn_global_desc
=
transform_tensor_descriptor
(
in_n_c_
ytildaslice_htildaslice_x
tildaslice_wtildaslice_global_desc
,
in_n_c_
h
tildaslice_wtildaslice_global_desc
,
make_tuple
(
Merge
<
Sequence
<
C
,
1
,
1
>
>
{},
Merge
<
Sequence
<
N
,
HTildaSlice
,
WTildaSlice
>>
{}),
make_tuple
(
PassThrough
<
C
>
{},
Merge
<
Sequence
<
N
,
HTildaSlice
,
WTildaSlice
>>
{}),
make_tuple
(
Sequence
<
1
,
2
,
4
>
{},
Sequence
<
0
,
3
,
5
>
{}),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
,
2
,
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
constexpr
auto
gridwise_gemm
=
constexpr
auto
gridwise_gemm
=
...
...
composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v5r1_n
c
hw_k
c
yx_n
k
hw.hpp
→
composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v5r1_nhw
c
_kyx
c
_nhw
k
.hpp
View file @
e9c5efc4
#ifndef CK_GRIDWISE_CONVOLUTION_BACKWARD_DATA_IMPLICIT_GEMM_V5R1_N
C
HW_K
C
YX_N
K
HW_HPP
#ifndef CK_GRIDWISE_CONVOLUTION_BACKWARD_DATA_IMPLICIT_GEMM_V5R1_NHW
C
_KYX
C
_NHW
K
_HPP
#define CK_GRIDWISE_CONVOLUTION_BACKWARD_DATA_IMPLICIT_GEMM_V5R1_N
C
HW_K
C
YX_N
K
HW_HPP
#define CK_GRIDWISE_CONVOLUTION_BACKWARD_DATA_IMPLICIT_GEMM_V5R1_NHW
C
_KYX
C
_NHW
K
_HPP
#include "common_header.hpp"
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor.hpp"
...
@@ -8,11 +8,12 @@
...
@@ -8,11 +8,12 @@
namespace
ck
{
namespace
ck
{
// Number of GEMM partition: YTilda * XTilda
// Number of GEMMs = YTilda * XTilda
// Number of GEMM iteration: YDotSlice * XDotSlice
// GemmM = C
// GemmM = C
// GemmN = N * HTildaSlice * WTildaSlice
// GemmN = N * HTildaSlice * WTildaSlice
// GemmK0 = YDotSlice
// GemmK = K
// GemmK1 = XDotSlice
// GemmK2 = K
template
<
index_t
GridSize
,
template
<
index_t
GridSize
,
index_t
BlockSize
,
index_t
BlockSize
,
typename
Float
,
typename
Float
,
...
@@ -42,10 +43,10 @@ template <index_t GridSize,
...
@@ -42,10 +43,10 @@ template <index_t GridSize,
index_t
GemmABlockCopyDstDataPerWrite_GemmM
,
index_t
GemmABlockCopyDstDataPerWrite_GemmM
,
typename
GemmBBlockCopyThreadSliceLengths_GemmK0_GemmK1_GemmK2_GemmN
,
typename
GemmBBlockCopyThreadSliceLengths_GemmK0_GemmK1_GemmK2_GemmN
,
typename
GemmBBlockCopyThreadClusterLengths_GemmK0_GemmK1_GemmK2_GemmN
,
typename
GemmBBlockCopyThreadClusterLengths_GemmK0_GemmK1_GemmK2_GemmN
,
index_t
GemmBBlockCopySrcDataPerRead_Gemm
N
,
index_t
GemmBBlockCopySrcDataPerRead_Gemm
K2
,
index_t
GemmBBlockCopyDstDataPerWrite_GemmN
,
index_t
GemmBBlockCopyDstDataPerWrite_GemmN
,
index_t
GemmCThreadCopyDstDataPerWrite_GemmN1
>
index_t
GemmCThreadCopyDstDataPerWrite_GemmN1
>
struct
GridwiseConvolutionBackwardDataImplicitGemm_v5r1_n
c
hw_k
c
yx_n
k
hw
struct
GridwiseConvolutionBackwardDataImplicitGemm_v5r1_nhw
c
_kyx
c
_nhw
k
{
{
__host__
__device__
static
constexpr
index_t
GetNumberOfGemm
()
__host__
__device__
static
constexpr
index_t
GetNumberOfGemm
()
{
{
...
@@ -67,16 +68,16 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v5r1_nchw_kcyx_nkhw
...
@@ -67,16 +68,16 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v5r1_nchw_kcyx_nkhw
__host__
__device__
static
constexpr
auto
GetGemmSizeImpl
(
index_t
iYTilda
,
index_t
iXTilda
)
__host__
__device__
static
constexpr
auto
GetGemmSizeImpl
(
index_t
iYTilda
,
index_t
iXTilda
)
{
{
constexpr
index_t
N
=
InGlobalDesc
::
GetLengths
()[
0
];
constexpr
index_t
N
=
InGlobalDesc
::
GetLengths
()[
0
];
constexpr
index_t
C
=
InGlobalDesc
::
GetLengths
()[
1
];
constexpr
index_t
Hi
=
InGlobalDesc
::
GetLengths
()[
1
];
constexpr
index_t
H
i
=
InGlobalDesc
::
GetLengths
()[
2
];
constexpr
index_t
W
i
=
InGlobalDesc
::
GetLengths
()[
2
];
constexpr
index_t
Wi
=
InGlobalDesc
::
GetLengths
()[
3
];
constexpr
index_t
C
=
InGlobalDesc
::
GetLengths
()[
3
];
constexpr
index_t
K
=
OutGlobalDesc
::
GetLengths
()[
1
];
constexpr
index_t
Ho
=
OutGlobalDesc
::
GetLengths
()[
1
];
constexpr
index_t
H
o
=
OutGlobalDesc
::
GetLengths
()[
2
];
constexpr
index_t
W
o
=
OutGlobalDesc
::
GetLengths
()[
2
];
constexpr
index_t
Wo
=
OutGlobalDesc
::
GetLengths
()[
3
];
constexpr
index_t
K
=
OutGlobalDesc
::
GetLengths
()[
3
];
constexpr
index_t
Y
=
WeiGlobalDesc
::
GetLengths
()[
2
];
constexpr
index_t
Y
=
WeiGlobalDesc
::
GetLengths
()[
1
];
constexpr
index_t
X
=
WeiGlobalDesc
::
GetLengths
()[
3
];
constexpr
index_t
X
=
WeiGlobalDesc
::
GetLengths
()[
2
];
constexpr
index_t
ConvStrideH
=
ConvStrides
{}[
0
];
constexpr
index_t
ConvStrideH
=
ConvStrides
{}[
0
];
constexpr
index_t
ConvStrideW
=
ConvStrides
{}[
1
];
constexpr
index_t
ConvStrideW
=
ConvStrides
{}[
1
];
...
@@ -120,9 +121,11 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v5r1_nchw_kcyx_nkhw
...
@@ -120,9 +121,11 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v5r1_nchw_kcyx_nkhw
index_t
YDotSlice
=
(
iYTilda
+
1
)
*
YDot
<=
Y
?
YDot
:
Y
%
YDot
;
index_t
YDotSlice
=
(
iYTilda
+
1
)
*
YDot
<=
Y
?
YDot
:
Y
%
YDot
;
index_t
XDotSlice
=
(
iXTilda
+
1
)
*
XDot
<=
X
?
XDot
:
X
%
XDot
;
index_t
XDotSlice
=
(
iXTilda
+
1
)
*
XDot
<=
X
?
XDot
:
X
%
XDot
;
index_t
GemmK
=
K
*
YDotSlice
*
XDotSlice
;
index_t
GemmK0
=
YDotSlice
;
index_t
GemmK1
=
XDotSlice
;
index_t
GemmK2
=
K
;
return
Array
<
index_t
,
3
>
{
GemmM
,
GemmN
,
GemmK
};
return
Array
<
index_t
,
5
>
{
GemmM
,
GemmN
,
GemmK
0
,
GemmK1
,
GemmK2
};
}
}
__host__
__device__
static
constexpr
auto
GetGemmSize
(
index_t
gemm_id
)
__host__
__device__
static
constexpr
auto
GetGemmSize
(
index_t
gemm_id
)
...
@@ -146,21 +149,21 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v5r1_nchw_kcyx_nkhw
...
@@ -146,21 +149,21 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v5r1_nchw_kcyx_nkhw
const
Float
*
__restrict__
p_wei_global
,
const
Float
*
__restrict__
p_wei_global
,
const
Float
*
__restrict__
p_out_global
)
const
Float
*
__restrict__
p_out_global
)
{
{
constexpr
auto
in_n_
c_
hi_wi_global_desc
=
InGlobalDesc
{};
constexpr
auto
in_n_hi_wi_
c_
global_desc
=
InGlobalDesc
{};
constexpr
auto
wei_k_
c_
y_x_global_desc
=
WeiGlobalDesc
{};
constexpr
auto
wei_k_y_x_
c_
global_desc
=
WeiGlobalDesc
{};
constexpr
auto
out_n_
k_
ho_wo_global_desc
=
OutGlobalDesc
{};
constexpr
auto
out_n_ho_wo_
k_
global_desc
=
OutGlobalDesc
{};
constexpr
index_t
N
=
in_n_
c_
hi_wi_global_desc
.
GetLengths
()[
0
];
constexpr
index_t
N
=
in_n_hi_wi_
c_
global_desc
.
GetLengths
()[
0
];
constexpr
index_t
C
=
in_n_
c_
hi_wi_global_desc
.
GetLengths
()[
1
];
constexpr
index_t
Hi
=
in_n_hi_wi_
c_
global_desc
.
GetLengths
()[
1
];
constexpr
index_t
H
i
=
in_n_
c_
hi_wi_global_desc
.
GetLengths
()[
2
];
constexpr
index_t
W
i
=
in_n_hi_wi_
c_
global_desc
.
GetLengths
()[
2
];
constexpr
index_t
Wi
=
in_n_
c_
hi_wi_global_desc
.
GetLengths
()[
3
];
constexpr
index_t
C
=
in_n_hi_wi_
c_
global_desc
.
GetLengths
()[
3
];
constexpr
index_t
K
=
out_n_
k_
ho_wo_global_desc
.
GetLengths
()[
1
];
constexpr
index_t
Ho
=
out_n_ho_wo_
k_
global_desc
.
GetLengths
()[
1
];
constexpr
index_t
H
o
=
out_n_
k_
ho_wo_global_desc
.
GetLengths
()[
2
];
constexpr
index_t
W
o
=
out_n_ho_wo_
k_
global_desc
.
GetLengths
()[
2
];
constexpr
index_t
Wo
=
out_n_
k_
ho_wo_global_desc
.
GetLengths
()[
3
];
constexpr
index_t
K
=
out_n_ho_wo_
k_
global_desc
.
GetLengths
()[
3
];
constexpr
index_t
Y
=
wei_k_
c_
y_x_global_desc
.
GetLengths
()[
2
];
constexpr
index_t
Y
=
wei_k_y_x_
c_
global_desc
.
GetLengths
()[
1
];
constexpr
index_t
X
=
wei_k_
c_
y_x_global_desc
.
GetLengths
()[
3
];
constexpr
index_t
X
=
wei_k_y_x_
c_
global_desc
.
GetLengths
()[
2
];
constexpr
index_t
ConvStrideH
=
ConvStrides
{}[
0
];
constexpr
index_t
ConvStrideH
=
ConvStrides
{}[
0
];
constexpr
index_t
ConvStrideW
=
ConvStrides
{}[
1
];
constexpr
index_t
ConvStrideW
=
ConvStrides
{}[
1
];
...
@@ -203,10 +206,9 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v5r1_nchw_kcyx_nkhw
...
@@ -203,10 +206,9 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v5r1_nchw_kcyx_nkhw
// weight out-of-bound check can be skipped
// weight out-of-bound check can be skipped
constexpr
bool
wei_skip_out_of_bound_check
=
true
;
constexpr
bool
wei_skip_out_of_bound_check
=
true
;
constexpr
auto
wei_k_
c_
ydot_ytilda_xdot_xtilda_global_desc
=
transform_tensor_descriptor
(
constexpr
auto
wei_k_ydot_ytilda_xdot_xtilda_
c_
global_desc
=
transform_tensor_descriptor
(
wei_k_
c_
y_x_global_desc
,
wei_k_y_x_
c_
global_desc
,
make_tuple
(
PassThrough
<
K
>
{},
make_tuple
(
PassThrough
<
K
>
{},
PassThrough
<
C
>
{},
Embed
<
Y
,
Embed
<
Y
,
Sequence
<
YDot
,
YTilda
>
,
Sequence
<
YDot
,
YTilda
>
,
Sequence
<
ConvStrideH
/
GcdStrideDilationH
,
1
,
0
>
,
Sequence
<
ConvStrideH
/
GcdStrideDilationH
,
1
,
0
>
,
...
@@ -214,31 +216,24 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v5r1_nchw_kcyx_nkhw
...
@@ -214,31 +216,24 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v5r1_nchw_kcyx_nkhw
Embed
<
X
,
Embed
<
X
,
Sequence
<
XDot
,
XTilda
>
,
Sequence
<
XDot
,
XTilda
>
,
Sequence
<
ConvStrideW
/
GcdStrideDilationW
,
1
,
0
>
,
Sequence
<
ConvStrideW
/
GcdStrideDilationW
,
1
,
0
>
,
wei_skip_out_of_bound_check
>
{}),
wei_skip_out_of_bound_check
>
{},
PassThrough
<
C
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
,
3
>
{},
Sequence
<
4
,
5
>
{}));
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{},
Sequence
<
3
,
4
>
{},
Sequence
<
5
>
{}));
constexpr
auto
wei_k_c_ydotslice_ytidaslice_xdotslice_xtildaslice_global_desc
=
constexpr
auto
wei_k_ydotslice_xdotslice_c_global_desc
=
transform_tensor_descriptor
(
transform_tensor_descriptor
(
wei_k_ydot_ytilda_xdot_xtilda_c_global_desc
,
wei_k_c_ydot_ytilda_xdot_xtilda_global_desc
,
make_tuple
(
make_tuple
(
PassThrough
<
K
>
{},
PassThrough
<
K
>
{},
Slice
<
Sequence
<
YDot
,
XDot
>
,
Sequence
<
0
,
0
>
,
Sequence
<
YDotSlice
,
XDotSlice
>>
{},
PassThrough
<
C
>
{},
Freeze
<
Sequence
<
YTilda
,
XTilda
>
,
Sequence
<
iYTilda
,
iXTilda
>>
{},
Slice
<
Sequence
<
YDot
,
XDot
>
,
Sequence
<
0
,
0
>
,
Sequence
<
YDotSlice
,
XDotSlice
>>
{},
PassThrough
<
C
>
{}),
Slice
<
Sequence
<
YTilda
,
XTilda
>
,
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
3
>
{},
Sequence
<
2
,
4
>
{},
Sequence
<
5
>
{}),
Sequence
<
iYTilda
,
iXTilda
>
,
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{},
Sequence
<>
{},
Sequence
<
3
>
{}));
Sequence
<
iYTilda
+
1
,
iXTilda
+
1
>>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
,
4
>
{},
Sequence
<
3
,
5
>
{}),
constexpr
auto
wei_gemmk0_gemmk1_gemmk2_gemmm_global_desc
=
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
,
4
>
{},
Sequence
<
3
,
5
>
{}));
reorder_tensor_descriptor_given_lower2upper
(
wei_k_ydotslice_xdotslice_c_global_desc
,
Sequence
<
2
,
0
,
1
,
3
>
{});
constexpr
auto
wei_gemmk0_gemmk1_gemmk2_gemmm_global_desc
=
transform_tensor_descriptor
(
wei_k_c_ydotslice_ytidaslice_xdotslice_xtildaslice_global_desc
,
make_tuple
(
PassThrough
<
YDotSlice
>
{},
PassThrough
<
XDotSlice
>
{},
PassThrough
<
K
>
{},
Merge
<
Sequence
<
C
,
1
,
1
>>
{}),
make_tuple
(
Sequence
<
2
>
{},
Sequence
<
4
>
{},
Sequence
<
0
>
{},
Sequence
<
1
,
3
,
5
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}));
// B matrix: output tensor
// B matrix: output tensor
// TODO sometimes output tensor out-of-bound check can be skipped, find out all such
// TODO sometimes output tensor out-of-bound check can be skipped, find out all such
...
@@ -249,10 +244,9 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v5r1_nchw_kcyx_nkhw
...
@@ -249,10 +244,9 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v5r1_nchw_kcyx_nkhw
constexpr
bool
out_skip_out_of_bound_check
=
true
;
constexpr
bool
out_skip_out_of_bound_check
=
true
;
#endif
#endif
constexpr
auto
out_n_
k_
ydot_htilda_xdot_wtilda_global_desc
=
transform_tensor_descriptor
(
constexpr
auto
out_n_ydot_htilda_xdot_wtilda_
k_
global_desc
=
transform_tensor_descriptor
(
out_n_
k_
ho_wo_global_desc
,
out_n_ho_wo_
k_
global_desc
,
make_tuple
(
PassThrough
<
N
>
{},
make_tuple
(
PassThrough
<
N
>
{},
PassThrough
<
K
>
{},
Embed
<
Ho
,
Embed
<
Ho
,
Sequence
<
YDot
,
HTilda
>
,
Sequence
<
YDot
,
HTilda
>
,
Sequence
<-
ConvDilationH
/
GcdStrideDilationH
,
1
,
0
>
,
Sequence
<-
ConvDilationH
/
GcdStrideDilationH
,
1
,
0
>
,
...
@@ -260,46 +254,31 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v5r1_nchw_kcyx_nkhw
...
@@ -260,46 +254,31 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v5r1_nchw_kcyx_nkhw
Embed
<
Wo
,
Embed
<
Wo
,
Sequence
<
XDot
,
WTilda
>
,
Sequence
<
XDot
,
WTilda
>
,
Sequence
<-
ConvDilationW
/
GcdStrideDilationW
,
1
,
0
>
,
Sequence
<-
ConvDilationW
/
GcdStrideDilationW
,
1
,
0
>
,
out_skip_out_of_bound_check
>
{}),
out_skip_out_of_bound_check
>
{},
PassThrough
<
K
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
,
3
>
{},
Sequence
<
4
,
5
>
{}));
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{},
Sequence
<
3
,
4
>
{},
Sequence
<
5
>
{}));
constexpr
auto
out_n_k_ydot_htildaslice_xdot_wtildaslice_global_desc
=
transform_tensor_descriptor
(
out_n_k_ydot_htilda_xdot_wtilda_global_desc
,
make_tuple
(
PassThrough
<
N
>
{},
PassThrough
<
K
>
{},
PassThrough
<
YTilda
>
{},
PassThrough
<
XTilda
>
{},
Slice
<
Sequence
<
HTilda
,
WTilda
>
,
Sequence
<
iHTildaLeft
,
iWTildaLeft
>
,
Sequence
<
iHTildaRight
,
iWTildaRight
>>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
4
>
{},
Sequence
<
3
,
5
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
4
>
{},
Sequence
<
3
,
5
>
{}));
constexpr
auto
out_n_
k_
ydotslice_htildaslice_xdotslice_wtildaslice_global_desc
=
constexpr
auto
out_n_ydotslice_htildaslice_xdotslice_wtildaslice_
k_
global_desc
=
transform_tensor_descriptor
(
transform_tensor_descriptor
(
out_n_
k_
ydot_htilda
slice
_xdot_wtilda
slice
_global_desc
,
out_n_ydot_htilda_xdot_wtilda
_k
_global_desc
,
make_tuple
(
make_tuple
(
PassThrough
<
N
>
{},
PassThrough
<
N
>
{},
PassThrough
<
K
>
{},
Slice
<
Sequence
<
YDot
,
XDot
>
,
Sequence
<
0
,
0
>
,
Sequence
<
YDotSlice
,
XDotSlice
>>
{},
PassThrough
<
HTildaSlice
>
{},
Slice
<
Sequence
<
HTilda
,
WTilda
>
,
PassThrough
<
WTildaSlice
>
{},
Sequence
<
iHTildaLeft
,
iWTildaLeft
>
,
Slice
<
Sequence
<
YDot
,
XDot
>
,
Sequence
<
0
,
0
>
,
Sequence
<
YDotSlice
,
XDotSlice
>>
{}),
Sequence
<
iHTildaRight
,
iWTildaRight
>>
{},
make_tuple
(
PassThrough
<
K
>
{}),
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
3
>
{},
Sequence
<
5
>
{},
Sequence
<
2
,
4
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
3
>
{},
Sequence
<
2
,
4
>
{},
Sequence
<
5
>
{}),
make_tuple
(
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
3
>
{},
Sequence
<
2
,
4
>
{},
Sequence
<
5
>
{}));
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
3
>
{},
Sequence
<
5
>
{},
Sequence
<
2
,
4
>
{}));
constexpr
auto
out_gemmk0_gemmk1_gemmk2_gemmn_global_desc
=
transform_tensor_descriptor
(
constexpr
auto
out_gemmk0_gemmk1_gemmk2_gemmn_global_desc
=
transform_tensor_descriptor
(
out_n_
k_
ydotslice_htildaslice_xdotslice_wtildaslice_global_desc
,
out_n_ydotslice_htildaslice_xdotslice_wtildaslice_
k_
global_desc
,
make_tuple
(
PassThrough
<
YDotSlice
>
{},
make_tuple
(
PassThrough
<
YDotSlice
>
{},
PassThrough
<
XDotSlice
>
{},
PassThrough
<
XDotSlice
>
{},
PassThrough
<
K
>
{},
PassThrough
<
K
>
{},
Merge
<
Sequence
<
N
,
HTildaSlice
,
WTildaSlice
>>
{}),
Merge
<
Sequence
<
N
,
HTildaSlice
,
WTildaSlice
>>
{}),
make_tuple
(
Sequence
<
2
>
{},
Sequence
<
4
>
{},
Sequence
<
1
>
{},
Sequence
<
0
,
3
,
5
>
{}),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
3
>
{},
Sequence
<
5
>
{},
Sequence
<
0
,
2
,
4
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}));
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}));
// C matrix: input tensor
// C matrix: input tensor
...
@@ -310,22 +289,20 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v5r1_nchw_kcyx_nkhw
...
@@ -310,22 +289,20 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v5r1_nchw_kcyx_nkhw
constexpr
bool
in_skip_out_of_bound_check
=
true
;
constexpr
bool
in_skip_out_of_bound_check
=
true
;
#endif
#endif
constexpr
auto
in_n_c_hip_wip_global_desc
=
transform_tensor_descriptor
(
constexpr
auto
in_n_hip_wip_c_global_desc
=
transform_tensor_descriptor
(
in_n_c_hi_wi_global_desc
,
in_n_hi_wi_c_global_desc
,
make_tuple
(
make_tuple
(
PassThrough
<
N
>
{},
PassThrough
<
N
>
{},
Pad
<
Sequence
<
Hi
,
Wi
>
,
InLeftPads
,
InRightPads
,
in_skip_out_of_bound_check
>
{},
PassThrough
<
C
>
{},
PassThrough
<
C
>
{}),
Pad
<
Sequence
<
Hi
,
Wi
>
,
InLeftPads
,
InRightPads
,
in_skip_out_of_bound_check
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
,
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{},
Sequence
<
3
>
{}));
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
,
3
>
{}));
constexpr
index_t
Hip
=
in_n_
c_
hip_wip_global_desc
.
GetLengths
()[
2
];
constexpr
index_t
Hip
=
in_n_hip_wip_
c_
global_desc
.
GetLengths
()[
1
];
constexpr
index_t
Wip
=
in_n_
c_
hip_wip_global_desc
.
GetLengths
()[
3
];
constexpr
index_t
Wip
=
in_n_hip_wip_
c_
global_desc
.
GetLengths
()[
2
];
constexpr
auto
in_n_
c_
ytilda_htilda_xtilda_wtilda_global_desc
=
transform_tensor_descriptor
(
constexpr
auto
in_n_ytilda_htilda_xtilda_wtilda_
c_
global_desc
=
transform_tensor_descriptor
(
in_n_
c_
hip_wip_global_desc
,
in_n_hip_wip_
c_
global_desc
,
make_tuple
(
PassThrough
<
N
>
{},
make_tuple
(
PassThrough
<
N
>
{},
PassThrough
<
C
>
{},
Embed
<
Hip
,
Embed
<
Hip
,
Sequence
<
YTilda
,
HTilda
>
,
Sequence
<
YTilda
,
HTilda
>
,
Sequence
<
ConvDilationH
,
ConvStrideH
,
0
>
,
Sequence
<
ConvDilationH
,
ConvStrideH
,
0
>
,
...
@@ -333,44 +310,26 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v5r1_nchw_kcyx_nkhw
...
@@ -333,44 +310,26 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v5r1_nchw_kcyx_nkhw
Embed
<
Wip
,
Embed
<
Wip
,
Sequence
<
XTilda
,
WTilda
>
,
Sequence
<
XTilda
,
WTilda
>
,
Sequence
<
ConvDilationW
,
ConvStrideW
,
0
>
,
Sequence
<
ConvDilationW
,
ConvStrideW
,
0
>
,
in_skip_out_of_bound_check
>
{}),
in_skip_out_of_bound_check
>
{},
PassThrough
<
C
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
,
3
>
{},
Sequence
<
4
,
5
>
{}));
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{},
Sequence
<
3
,
4
>
{},
Sequence
<
5
>
{}));
constexpr
auto
in_n_c_ytilda_htildaslice_xtilda_wtildaslice_global_desc
=
transform_tensor_descriptor
(
in_n_c_ytilda_htilda_xtilda_wtilda_global_desc
,
make_tuple
(
PassThrough
<
N
>
{},
PassThrough
<
C
>
{},
PassThrough
<
YTilda
>
{},
PassThrough
<
XTilda
>
{},
Slice
<
Sequence
<
HTilda
,
WTilda
>
,
Sequence
<
iHTildaLeft
,
iWTildaLeft
>
,
Sequence
<
iHTildaRight
,
iWTildaRight
>>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
4
>
{},
Sequence
<
3
,
5
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
4
>
{},
Sequence
<
3
,
5
>
{}));
constexpr
auto
in_n_c_ytildaslice_htildaslice_xtildaslice_wtildaslice_global_desc
=
constexpr
auto
in_n_htildaslice_wtildaslice_c_global_desc
=
transform_tensor_descriptor
(
transform_tensor_descriptor
(
in_n_ytilda_htilda_xtilda_wtilda_c_global_desc
,
in_n_c_ytilda_htildaslice_xtilda_wtildaslice_global_desc
,
make_tuple
(
PassThrough
<
N
>
{},
make_tuple
(
PassThrough
<
N
>
{},
Freeze
<
Sequence
<
YTilda
,
XTilda
>
,
Sequence
<
iYTilda
,
iXTilda
>>
{},
PassThrough
<
C
>
{},
Slice
<
Sequence
<
HTilda
,
WTilda
>
,
PassThrough
<
HTildaSlice
>
{},
Sequence
<
iHTildaLeft
,
iWTildaLeft
>
,
PassThrough
<
WTildaSlice
>
{},
Sequence
<
iHTildaRight
,
iWTildaRight
>>
{},
Slice
<
Sequence
<
YTilda
,
XTilda
>
,
PassThrough
<
C
>
{}),
Sequence
<
iYTilda
,
iXTilda
>
,
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
3
>
{},
Sequence
<
2
,
4
>
{},
Sequence
<
5
>
{}),
Sequence
<
iYTilda
+
1
,
iXTilda
+
1
>>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<>
{},
Sequence
<
1
,
2
>
{},
Sequence
<
3
>
{}));
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
3
>
{},
Sequence
<
5
>
{},
Sequence
<
2
,
4
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
3
>
{},
Sequence
<
5
>
{},
Sequence
<
2
,
4
>
{}));
constexpr
auto
in_gemmm_gemmn_global_desc
=
transform_tensor_descriptor
(
constexpr
auto
in_gemmm_gemmn_global_desc
=
transform_tensor_descriptor
(
in_n_
c_ytildaslice_htildaslice_x
tildaslice_wtildaslice_global_desc
,
in_n_
h
tildaslice_wtildaslice_
c_
global_desc
,
make_tuple
(
Merge
<
Sequence
<
C
,
1
,
1
>
>
{},
Merge
<
Sequence
<
N
,
HTildaSlice
,
WTildaSlice
>>
{}),
make_tuple
(
PassThrough
<
C
>
{},
Merge
<
Sequence
<
N
,
HTildaSlice
,
WTildaSlice
>>
{}),
make_tuple
(
Sequence
<
1
,
2
,
4
>
{},
Sequence
<
0
,
3
,
5
>
{}),
make_tuple
(
Sequence
<
3
>
{},
Sequence
<
0
,
1
,
2
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
// call GEMM
// call GEMM
...
@@ -404,12 +363,12 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v5r1_nchw_kcyx_nkhw
...
@@ -404,12 +363,12 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v5r1_nchw_kcyx_nkhw
GemmABlockCopyDstDataPerWrite_GemmM
,
GemmABlockCopyDstDataPerWrite_GemmM
,
GemmBBlockCopyThreadSliceLengths_GemmK0_GemmK1_GemmK2_GemmN
,
GemmBBlockCopyThreadSliceLengths_GemmK0_GemmK1_GemmK2_GemmN
,
GemmBBlockCopyThreadClusterLengths_GemmK0_GemmK1_GemmK2_GemmN
,
GemmBBlockCopyThreadClusterLengths_GemmK0_GemmK1_GemmK2_GemmN
,
Sequence
<
0
,
1
,
2
,
3
>
,
Sequence
<
0
,
1
,
3
,
2
>
,
Sequence
<
0
,
1
,
2
,
3
>
,
Sequence
<
0
,
1
,
3
,
2
>
,
3
,
2
,
GemmBBlockCopySrcDataPerRead_Gemm
N
,
GemmBBlockCopySrcDataPerRead_Gemm
K2
,
GemmBBlockCopyDstDataPerWrite_GemmN
,
GemmBBlockCopyDstDataPerWrite_GemmN
,
Sequence
<
0
,
1
,
2
,
3
>
,
Sequence
<
2
,
3
,
0
,
1
>
,
3
,
3
,
GemmCThreadCopyDstDataPerWrite_GemmN1
>
{};
GemmCThreadCopyDstDataPerWrite_GemmN1
>
{};
...
...
composable_kernel/include/tensor_description/multi_index_transform.hpp
View file @
e9c5efc4
...
@@ -488,6 +488,49 @@ struct Embed
...
@@ -488,6 +488,49 @@ struct Embed
}
}
};
};
// LowerLengths: Sequence<...>
// LowerFreezePoint: Sequence<...>
template
<
typename
LowerLengths
,
typename
LowerFreezePoint
>
struct
Freeze
{
static
constexpr
index_t
nDimLow
=
LowerLengths
::
Size
();
static
constexpr
index_t
nDimUp
=
0
;
using
LowerIndex
=
MultiIndex
<
nDimLow
>
;
using
UpperIndex
=
MultiIndex
<
nDimUp
>
;
__host__
__device__
explicit
constexpr
Freeze
()
{
// TODO: sanity check: LowerFreezePoint should be within range of LowerLengths
}
__host__
__device__
static
constexpr
auto
GetNumOfLowerDimension
()
{
return
Number
<
nDimLow
>
{};
}
__host__
__device__
static
constexpr
auto
GetNumOfUpperDimension
()
{
return
Number
<
0
>
{};
}
__host__
__device__
static
constexpr
auto
GetUpperLengths
()
{
return
Sequence
<>
{};
}
__host__
__device__
static
constexpr
auto
CalculateLowerIndex
(
const
UpperIndex
&
/*idx_up*/
)
{
return
to_array
(
LowerFreezePoint
{});
}
__host__
__device__
static
constexpr
auto
CalculateLowerIndexDiff
(
const
UpperIndex
&
/* idx_up_diff */
,
const
UpperIndex
&
/* idx_up_old */
,
const
LowerIndex
&
/* idx_low_old */
)
{
return
make_zero_array
<
index_t
,
nDimLow
>
();
}
__host__
__device__
static
constexpr
bool
IsLinearTransform
()
{
return
true
;
}
__host__
__device__
static
constexpr
bool
IsValidUpperIndexAlwaysMappedToValidLowerIndex
()
{
return
true
;
}
};
template
<
index_t
LowerLength
,
index_t
VectorSize
>
template
<
index_t
LowerLength
,
index_t
VectorSize
>
struct
Vectorize
struct
Vectorize
{
{
...
...
driver/include/device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw.hpp
deleted
100644 → 0
View file @
fe7b2d9f
#pragma once
#include <unistd.h>
#include "device.hpp"
#include "host_tensor.hpp"
#include "gridwise_operation_wrapper.hpp"
#include "gridwise_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw.hpp"
namespace
launcher
{
using
namespace
ck
;
template
<
typename
T
,
typename
InDesc
,
typename
WeiDesc
,
typename
OutDesc
,
typename
ConvStrides
,
typename
ConvDilations
,
typename
InLeftPads
,
typename
InRightPads
>
void
device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw
(
InDesc
in_nchw_desc
,
Tensor
<
T
>&
in_nchw
,
WeiDesc
wei_kcyx_desc
,
const
Tensor
<
T
>&
wei_kcyx
,
OutDesc
out_nkhw_desc
,
const
Tensor
<
T
>&
out_nkhw
,
ConvStrides
,
ConvDilations
,
InLeftPads
,
InRightPads
,
std
::
size_t
nrepeat
)
{
using
namespace
ck
;
constexpr
index_t
N
=
out_nkhw_desc
.
GetLengths
()[
0
];
constexpr
index_t
K
=
out_nkhw_desc
.
GetLengths
()[
1
];
constexpr
index_t
C
=
wei_kcyx_desc
.
GetLengths
()[
1
];
constexpr
index_t
Hi
=
in_nchw_desc
.
GetLengths
()[
2
];
constexpr
index_t
Wi
=
in_nchw_desc
.
GetLengths
()[
3
];
constexpr
index_t
Ho
=
out_nkhw_desc
.
GetLengths
()[
2
];
constexpr
index_t
Wo
=
out_nkhw_desc
.
GetLengths
()[
3
];
constexpr
index_t
Y
=
wei_kcyx_desc
.
GetLengths
()[
2
];
constexpr
index_t
X
=
wei_kcyx_desc
.
GetLengths
()[
3
];
constexpr
index_t
ConvStrideH
=
ConvStrides
{}[
0
];
constexpr
index_t
ConvStrideW
=
ConvStrides
{}[
1
];
constexpr
index_t
ConvDilationH
=
ConvDilations
{}[
0
];
constexpr
index_t
ConvDilationW
=
ConvDilations
{}[
1
];
std
::
size_t
data_sz
=
sizeof
(
T
);
DeviceMem
in_nchw_device_buf
(
data_sz
*
in_nchw
.
mDesc
.
GetElementSpace
());
DeviceMem
wei_kcyx_device_buf
(
data_sz
*
wei_kcyx
.
mDesc
.
GetElementSpace
());
DeviceMem
out_nkhw_device_buf
(
data_sz
*
out_nkhw
.
mDesc
.
GetElementSpace
());
in_nchw_device_buf
.
ToDevice
(
in_nchw
.
mData
.
data
());
wei_kcyx_device_buf
.
ToDevice
(
wei_kcyx
.
mData
.
data
());
out_nkhw_device_buf
.
ToDevice
(
out_nkhw
.
mData
.
data
());
#if 1
// BlockSize = 256, each thread hold 64 data
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
GemmMPerBlock
=
128
;
constexpr
index_t
GemmNPerBlock
=
128
;
constexpr
index_t
GemmKPerBlock
=
8
;
constexpr
index_t
GemmMPerThread
=
4
;
constexpr
index_t
GemmNPerThread
=
4
;
constexpr
index_t
GemmKPerThread
=
1
;
constexpr
index_t
GemmMLevel0Cluster
=
4
;
constexpr
index_t
GemmNLevel0Cluster
=
4
;
constexpr
index_t
GemmMLevel1Cluster
=
4
;
constexpr
index_t
GemmNLevel1Cluster
=
4
;
constexpr
index_t
GemmThreadGemmDataPerReadM
=
4
;
constexpr
index_t
GemmThreadGemmDataPerReadN
=
4
;
using
GemmABlockCopyThreadSliceLengths_GemmK_GemmM
=
Sequence
<
4
,
1
>
;
using
GemmABlockCopyThreadClusterLengths_GemmK_GemmM
=
Sequence
<
2
,
128
>
;
constexpr
index_t
GemmABlockCopySrcDataPerRead_GemmM
=
1
;
constexpr
index_t
GemmABlockCopyDstDataPerWrite_GemmM
=
1
;
using
GemmBBlockCopyThreadSliceLengths_GemmK_GemmN
=
Sequence
<
4
,
1
>
;
using
GemmBBlockCopyThreadClusterLengths_GemmK_GemmN
=
Sequence
<
2
,
128
>
;
constexpr
index_t
GemmBBlockCopySrcDataPerRead_GemmN
=
1
;
constexpr
index_t
GemmBBlockCopyDstDataPerWrite_GemmN
=
1
;
constexpr
index_t
GemmCThreadCopyDstDataPerWrite_GemmN1
=
1
;
#elif 1
// BlockSize = 256, each thread hold 64 data
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
GemmMPerBlock
=
128
;
constexpr
index_t
GemmNPerBlock
=
128
;
constexpr
index_t
GemmKPerBlock
=
8
;
constexpr
index_t
GemmMPerThread
=
4
;
constexpr
index_t
GemmNPerThread
=
4
;
constexpr
index_t
GemmKPerThread
=
1
;
constexpr
index_t
GemmMLevel0Cluster
=
4
;
constexpr
index_t
GemmNLevel0Cluster
=
4
;
constexpr
index_t
GemmMLevel1Cluster
=
4
;
constexpr
index_t
GemmNLevel1Cluster
=
4
;
constexpr
index_t
GemmThreadGemmDataPerReadM
=
4
;
constexpr
index_t
GemmThreadGemmDataPerReadN
=
4
;
using
GemmABlockCopyThreadSliceLengths_GemmK_GemmM
=
Sequence
<
1
,
4
>
;
using
GemmABlockCopyThreadClusterLengths_GemmK_GemmM
=
Sequence
<
8
,
32
>
;
constexpr
index_t
GemmABlockCopySrcDataPerRead_GemmM
=
1
;
constexpr
index_t
GemmABlockCopyDstDataPerWrite_GemmM
=
1
;
using
GemmBBlockCopyThreadSliceLengths_GemmK_GemmN
=
Sequence
<
1
,
4
>
;
using
GemmBBlockCopyThreadClusterLengths_GemmK_GemmN
=
Sequence
<
8
,
32
>
;
constexpr
index_t
GemmBBlockCopySrcDataPerRead_GemmN
=
1
;
constexpr
index_t
GemmBBlockCopyDstDataPerWrite_GemmN
=
1
;
constexpr
index_t
GemmCThreadCopyDstDataPerWrite_GemmN1
=
1
;
#elif 1
// BlockSize = 256, each thread hold 64 data
// for 1x1 weight, 8x8 input
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
GemmMPerBlock
=
128
;
constexpr
index_t
GemmNPerBlock
=
128
;
constexpr
index_t
GemmKPerBlock
=
8
;
constexpr
index_t
GemmMPerThread
=
4
;
constexpr
index_t
GemmNPerThread
=
4
;
constexpr
index_t
GemmKPerThread
=
1
;
constexpr
index_t
GemmMLevel0Cluster
=
4
;
constexpr
index_t
GemmNLevel0Cluster
=
4
;
constexpr
index_t
GemmMLevel1Cluster
=
4
;
constexpr
index_t
GemmNLevel1Cluster
=
4
;
constexpr
index_t
GemmThreadGemmDataPerReadM
=
4
;
constexpr
index_t
GemmThreadGemmDataPerReadN
=
4
;
using
GemmABlockCopyThreadSliceLengths_GemmK_GemmM
=
Sequence
<
1
,
4
>
;
using
GemmABlockCopyThreadClusterLengths_GemmK_GemmM
=
Sequence
<
8
,
32
>
;
constexpr
index_t
GemmABlockCopySrcDataPerRead_GemmM
=
4
;
constexpr
index_t
GemmABlockCopyDstDataPerWrite_GemmM
=
4
;
using
GemmBBlockCopyThreadSliceLengths_GemmK_GemmN
=
Sequence
<
1
,
4
>
;
using
GemmBBlockCopyThreadClusterLengths_GemmK_GemmN
=
Sequence
<
8
,
32
>
;
constexpr
index_t
GemmBBlockCopySrcDataPerRead_GemmN
=
4
;
constexpr
index_t
GemmBBlockCopyDstDataPerWrite_GemmN
=
4
;
constexpr
index_t
GemmCThreadCopyDstDataPerWrite_GemmN1
=
4
;
#endif
constexpr
index_t
GcdStrideDilationH
=
math
::
gcd
(
ConvStrideH
,
ConvDilationH
);
constexpr
index_t
GcdStrideDilationW
=
math
::
gcd
(
ConvStrideW
,
ConvDilationW
);
constexpr
index_t
YTilda
=
ConvStrideH
/
GcdStrideDilationH
;
constexpr
index_t
XTilda
=
ConvStrideW
/
GcdStrideDilationW
;
constexpr
index_t
YDot
=
math
::
integer_divide_ceil
(
Y
,
YTilda
);
constexpr
index_t
XDot
=
math
::
integer_divide_ceil
(
X
,
XTilda
);
constexpr
index_t
HTilda
=
Ho
+
math
::
integer_divide_ceil
(
ConvDilationH
*
(
Y
-
1
),
ConvStrideH
);
constexpr
index_t
WTilda
=
Wo
+
math
::
integer_divide_ceil
(
ConvDilationW
*
(
X
-
1
),
ConvStrideW
);
constexpr
index_t
HTildaLeft
=
math
::
integer_divide_floor
(
math
::
max
(
0
,
InLeftPads
{}[
0
]
-
ConvDilationH
*
(
YTilda
-
1
)),
ConvStrides
{}[
0
]);
constexpr
index_t
WTildaLeft
=
math
::
integer_divide_floor
(
math
::
max
(
0
,
InLeftPads
{}[
1
]
-
ConvDilationW
*
(
XTilda
-
1
)),
ConvStrides
{}[
1
]);
constexpr
index_t
HTildaRight
=
math
::
min
(
HTilda
,
math
::
integer_divide_ceil
(
InLeftPads
{}[
0
]
+
Hi
-
1
,
ConvStrides
{}[
0
])
+
1
);
constexpr
index_t
WTildaRight
=
math
::
min
(
WTilda
,
math
::
integer_divide_ceil
(
InLeftPads
{}[
1
]
+
Wi
-
1
,
ConvStrides
{}[
1
])
+
1
);
constexpr
index_t
HTildaSlice
=
HTildaRight
-
HTildaLeft
;
constexpr
index_t
WTildaSlice
=
WTildaRight
-
WTildaLeft
;
constexpr
index_t
GemmM
=
C
*
YTilda
*
XTilda
;
constexpr
index_t
GemmN
=
N
*
HTildaSlice
*
WTildaSlice
;
constexpr
index_t
GridSize
=
math
::
integer_divide_ceil
(
GemmM
,
GemmMPerBlock
)
*
math
::
integer_divide_ceil
(
GemmN
,
GemmNPerBlock
);
printf
(
"%s: BlockSize %u, GridSize %u
\n
"
,
__func__
,
BlockSize
,
GridSize
);
using
gridwise_conv_bwd_data
=
GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw
<
GridSize
,
BlockSize
,
T
,
T
,
decltype
(
in_nchw_desc
),
decltype
(
wei_kcyx_desc
),
decltype
(
out_nkhw_desc
),
ConvStrides
,
ConvDilations
,
InLeftPads
,
InRightPads
,
GemmMPerBlock
,
GemmNPerBlock
,
GemmKPerBlock
,
GemmMPerThread
,
GemmNPerThread
,
GemmKPerThread
,
GemmMLevel0Cluster
,
GemmNLevel0Cluster
,
GemmMLevel1Cluster
,
GemmNLevel1Cluster
,
GemmThreadGemmDataPerReadM
,
GemmThreadGemmDataPerReadN
,
GemmABlockCopyThreadSliceLengths_GemmK_GemmM
,
GemmABlockCopyThreadClusterLengths_GemmK_GemmM
,
GemmABlockCopySrcDataPerRead_GemmM
,
GemmABlockCopyDstDataPerWrite_GemmM
,
GemmBBlockCopyThreadSliceLengths_GemmK_GemmN
,
GemmBBlockCopyThreadClusterLengths_GemmK_GemmN
,
GemmBBlockCopySrcDataPerRead_GemmN
,
GemmBBlockCopyDstDataPerWrite_GemmN
,
GemmCThreadCopyDstDataPerWrite_GemmN1
>
;
for
(
index_t
i
=
0
;
i
<
5
;
++
i
)
{
std
::
cout
<<
"Start running "
<<
nrepeat
<<
" times..."
<<
std
::
endl
;
KernelTimer
timer
;
timer
.
Start
();
for
(
index_t
j
=
0
;
j
<
nrepeat
;
++
j
)
{
launch_kernel
(
run_gridwise_operation
<
gridwise_conv_bwd_data
,
T
*
const
__restrict__
,
const
T
*
const
__restrict__
,
const
T
*
const
__restrict__
>
,
dim3
(
GridSize
),
dim3
(
BlockSize
),
0
,
0
,
static_cast
<
T
*>
(
in_nchw_device_buf
.
GetDeviceBuffer
()),
static_cast
<
T
*>
(
wei_kcyx_device_buf
.
GetDeviceBuffer
()),
static_cast
<
T
*>
(
out_nkhw_device_buf
.
GetDeviceBuffer
()));
}
timer
.
End
();
float
ave_time
=
timer
.
GetElapsedTime
()
/
nrepeat
;
float
perf
=
(
float
)
calculate_convolution_flops
(
InDesc
{},
WeiDesc
{},
OutDesc
{})
/
(
std
::
size_t
(
1000
)
*
1000
*
1000
)
/
ave_time
;
std
::
cout
<<
"Average time : "
<<
ave_time
<<
" ms, "
<<
perf
<<
" TFlop/s"
<<
std
::
endl
;
}
in_nchw_device_buf
.
FromDevice
(
in_nchw
.
mData
.
data
());
}
}
// namespace launcher
driver/include/device_convolution_backward_data_implicit_gemm_v3r1_nchw_kcyx_nkhw.hpp
deleted
100644 → 0
View file @
fe7b2d9f
#pragma once
#include <unistd.h>
#include "device.hpp"
#include "host_tensor.hpp"
#include "gridwise_operation_wrapper.hpp"
#include "gridwise_convolution_backward_data_implicit_gemm_v3r1_nchw_kcyx_nkhw.hpp"

namespace launcher {

using namespace ck;

// Host-side driver for the v3r1 implicit-GEMM backward-data convolution
// (NCHW input, KCYX weight, NKHW output). It copies the host tensors to the
// device, instantiates the gridwise kernel with compile-time GEMM tuning
// parameters, launches it repeatedly for timing, and copies the computed
// input gradient back into in_nchw.
//
// Template parameters:
//   T             - element type of all tensors
//   InDesc/WeiDesc/OutDesc - compile-time tensor descriptors
//   ConvStrides/ConvDilations/InLeftPads/InRightPads - conv geometry (H, W)
//
// Parameters:
//   in_nchw   - (output) input-gradient tensor, overwritten on return
//   wei_kcyx  - weight tensor (read-only)
//   out_nkhw  - output-gradient tensor (read-only)
//   nrepeat   - number of kernel launches per timing round
template <typename T,
          typename InDesc,
          typename WeiDesc,
          typename OutDesc,
          typename ConvStrides,
          typename ConvDilations,
          typename InLeftPads,
          typename InRightPads>
void device_convolution_backward_data_implicit_gemm_v3r1_nchw_kcyx_nkhw(
    InDesc in_nchw_desc,
    Tensor<T>& in_nchw,
    WeiDesc wei_kcyx_desc,
    const Tensor<T>& wei_kcyx,
    OutDesc out_nkhw_desc,
    const Tensor<T>& out_nkhw,
    ConvStrides,
    ConvDilations,
    InLeftPads,
    InRightPads,
    std::size_t nrepeat)
{
    using namespace ck;

    // Problem sizes pulled from the compile-time descriptors:
    //   N = batch, K = output channels, C = input channels,
    //   Hi/Wi = input spatial size, Ho/Wo = output spatial size,
    //   Y/X = filter spatial size.
    constexpr index_t N = out_nkhw_desc.GetLengths()[0];
    constexpr index_t K = out_nkhw_desc.GetLengths()[1];
    constexpr index_t C = wei_kcyx_desc.GetLengths()[1];

    constexpr index_t Hi = in_nchw_desc.GetLengths()[2];
    constexpr index_t Wi = in_nchw_desc.GetLengths()[3];

    constexpr index_t Ho = out_nkhw_desc.GetLengths()[2];
    constexpr index_t Wo = out_nkhw_desc.GetLengths()[3];

    constexpr index_t Y = wei_kcyx_desc.GetLengths()[2];
    constexpr index_t X = wei_kcyx_desc.GetLengths()[3];

    constexpr index_t ConvStrideH = ConvStrides{}[0];
    constexpr index_t ConvStrideW = ConvStrides{}[1];

    constexpr index_t ConvDilationH = ConvDilations{}[0];
    constexpr index_t ConvDilationW = ConvDilations{}[1];

    // Allocate device buffers sized by element space (includes any padding
    // implied by the host tensor descriptors) and upload all three tensors.
    std::size_t data_sz = sizeof(T);
    DeviceMem in_nchw_device_buf(data_sz * in_nchw.mDesc.GetElementSpace());
    DeviceMem wei_kcyx_device_buf(data_sz * wei_kcyx.mDesc.GetElementSpace());
    DeviceMem out_nkhw_device_buf(data_sz * out_nkhw.mDesc.GetElementSpace());

    in_nchw_device_buf.ToDevice(in_nchw.mData.data());
    wei_kcyx_device_buf.ToDevice(wei_kcyx.mData.data());
    out_nkhw_device_buf.ToDevice(out_nkhw.mData.data());

#if 1
    // BlockSize = 256, each thread hold 64 data
    // Compile-time GEMM tuning configuration: 128x128 macro-tile with
    // GemmK = 8 per block, 4x4 per-thread micro-tile on a 4x4x4x4
    // two-level thread cluster (256 threads total).
    constexpr index_t BlockSize = 256;

    constexpr index_t GemmMPerBlock = 128;
    constexpr index_t GemmNPerBlock = 128;
    constexpr index_t GemmKPerBlock = 8;

    constexpr index_t GemmMPerThread = 4;
    constexpr index_t GemmNPerThread = 4;
    constexpr index_t GemmKPerThread = 1;

    constexpr index_t GemmMLevel0Cluster = 4;
    constexpr index_t GemmNLevel0Cluster = 4;
    constexpr index_t GemmMLevel1Cluster = 4;
    constexpr index_t GemmNLevel1Cluster = 4;

    constexpr index_t GemmThreadGemmDataPerReadM = 4;
    constexpr index_t GemmThreadGemmDataPerReadN = 4;

    // A/B block-copy layouts: 4x1 slice per thread on a 2x128 thread
    // cluster covers the 8x128 (GemmK x GemmM/N) block tile.
    using GemmABlockCopyThreadSliceLengths_GemmK_GemmM   = Sequence<4, 1>;
    using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<2, 128>;

    // Scalar (width-1) loads/stores; vectorization disabled for this config.
    constexpr index_t GemmABlockCopySrcDataPerRead_GemmM  = 1;
    constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1;

    using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN   = Sequence<4, 1>;
    using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<2, 128>;

    constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN  = 1;
    constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;

    constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
#endif

    // Backward-data "tilda" transform: decompose the filter along the
    // stride/dilation structure so the gradient can be expressed as GEMMs.
    constexpr index_t GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
    constexpr index_t GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);

    constexpr index_t YTilda = ConvStrideH / GcdStrideDilationH;
    constexpr index_t XTilda = ConvStrideW / GcdStrideDilationW;

    // NOTE(review): YDot/XDot are not referenced below in this driver;
    // presumably the same quantities are recomputed inside the kernel —
    // confirm before removing.
    constexpr index_t YDot = math::integer_divide_ceil(Y, YTilda);
    constexpr index_t XDot = math::integer_divide_ceil(X, XTilda);

    // Padded gradient-domain extent in H/W.
    constexpr index_t HTilda =
        Ho + math::integer_divide_ceil(ConvDilationH * (Y - 1), ConvStrideH);
    constexpr index_t WTilda =
        Wo + math::integer_divide_ceil(ConvDilationW * (X - 1), ConvStrideW);

    // Only the [HTildaLeft, HTildaRight) x [WTildaLeft, WTildaRight) window
    // can contribute to the (unpadded) input gradient; the rest is skipped.
    constexpr index_t HTildaLeft = math::integer_divide_floor(
        math::max(0, InLeftPads{}[0] - ConvDilationH * (YTilda - 1)), ConvStrides{}[0]);
    constexpr index_t WTildaLeft = math::integer_divide_floor(
        math::max(0, InLeftPads{}[1] - ConvDilationW * (XTilda - 1)), ConvStrides{}[1]);

    constexpr index_t HTildaRight = math::min(
        HTilda, math::integer_divide_ceil(InLeftPads{}[0] + Hi - 1, ConvStrides{}[0]) + 1);
    constexpr index_t WTildaRight = math::min(
        WTilda, math::integer_divide_ceil(InLeftPads{}[1] + Wi - 1, ConvStrides{}[1]) + 1);

    constexpr index_t HTildaSlice = HTildaRight - HTildaLeft;
    constexpr index_t WTildaSlice = WTildaRight - WTildaLeft;

    // GEMM view: M spans input channels only (v3r1 iterates the YTilda/XTilda
    // decomposition elsewhere), N spans batch x active gradient window.
    constexpr index_t GemmM = C;
    constexpr index_t GemmN = N * HTildaSlice * WTildaSlice;

    // One workgroup per 128x128 GEMM tile.
    constexpr index_t GridSize = math::integer_divide_ceil(GemmM, GemmMPerBlock) *
                                 math::integer_divide_ceil(GemmN, GemmNPerBlock);

    printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);

    // Bind all problem and tuning parameters into the gridwise kernel type.
    using gridwise_conv_bwd_data = GridwiseConvolutionBackwardDataImplicitGemm_v3r1_nchw_kcyx_nkhw<
        GridSize,
        BlockSize,
        T,
        T,
        decltype(in_nchw_desc),
        decltype(wei_kcyx_desc),
        decltype(out_nkhw_desc),
        ConvStrides,
        ConvDilations,
        InLeftPads,
        InRightPads,
        GemmMPerBlock,
        GemmNPerBlock,
        GemmKPerBlock,
        GemmMPerThread,
        GemmNPerThread,
        GemmKPerThread,
        GemmMLevel0Cluster,
        GemmNLevel0Cluster,
        GemmMLevel1Cluster,
        GemmNLevel1Cluster,
        GemmThreadGemmDataPerReadM,
        GemmThreadGemmDataPerReadN,
        GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
        GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
        GemmABlockCopySrcDataPerRead_GemmM,
        GemmABlockCopyDstDataPerWrite_GemmM,
        GemmBBlockCopyThreadSliceLengths_GemmK_GemmN,
        GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
        GemmBBlockCopySrcDataPerRead_GemmN,
        GemmBBlockCopyDstDataPerWrite_GemmN,
        GemmCThreadCopyDstDataPerWrite_GemmN1>;

    // Five independent timing rounds of nrepeat launches each; the average
    // over the nrepeat launches is reported per round.
    for(index_t i = 0; i < 5; ++i)
    {
        std::cout << "Start running " << nrepeat << " times..." << std::endl;

        KernelTimer timer;
        timer.Start();

        for(index_t j = 0; j < nrepeat; ++j)
        {
            launch_kernel(run_gridwise_operation<gridwise_conv_bwd_data,
                                                 T* const __restrict__,
                                                 const T* const __restrict__,
                                                 const T* const __restrict__>,
                          dim3(GridSize),
                          dim3(BlockSize),
                          0,
                          0,
                          static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer()),
                          static_cast<T*>(wei_kcyx_device_buf.GetDeviceBuffer()),
                          static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer()));
        }

        timer.End();

        float ave_time = timer.GetElapsedTime() / nrepeat;

        // TFlop/s = total conv FLOPs / 1e9 / milliseconds.
        float perf = (float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) /
                     (std::size_t(1000) * 1000 * 1000) / ave_time;

        std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl;
    }

    // Copy the computed input gradient back to the host tensor.
    in_nchw_device_buf.FromDevice(in_nchw.mData.data());
}

} // namespace launcher
driver/include/device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp
View file @
e9c5efc4
...
@@ -57,8 +57,8 @@ void device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc i
...
@@ -57,8 +57,8 @@ void device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc i
wei_kcyx_device_buf
.
ToDevice
(
wei_kcyx
.
mData
.
data
());
wei_kcyx_device_buf
.
ToDevice
(
wei_kcyx
.
mData
.
data
());
out_nkhw_device_buf
.
ToDevice
(
out_nkhw
.
mData
.
data
());
out_nkhw_device_buf
.
ToDevice
(
out_nkhw
.
mData
.
data
());
#if
1
#if
0
// BlockSize = 256,
each thread hold 64 data
//
cdata = 64,
BlockSize = 256,
128x128x8
constexpr index_t BlockSize = 256;
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmMPerBlock = 128;
...
@@ -86,6 +86,36 @@ void device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc i
...
@@ -86,6 +86,36 @@ void device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc i
constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1;
constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1;
constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;
constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;
constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
#elif
1
// cdata = 64, BlockSize = 256, 128x128x16
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
GemmMPerBlock
=
128
;
constexpr
index_t
GemmNPerBlock
=
128
;
constexpr
index_t
GemmKPerBlock
=
16
;
constexpr
index_t
GemmMPerThread
=
4
;
constexpr
index_t
GemmNPerThread
=
4
;
constexpr
index_t
GemmKPerThread
=
1
;
constexpr
index_t
GemmMLevel0Cluster
=
4
;
constexpr
index_t
GemmNLevel0Cluster
=
4
;
constexpr
index_t
GemmMLevel1Cluster
=
4
;
constexpr
index_t
GemmNLevel1Cluster
=
4
;
constexpr
index_t
GemmThreadGemmDataPerReadM
=
4
;
constexpr
index_t
GemmThreadGemmDataPerReadN
=
4
;
using
GemmABlockCopyThreadSliceLengths_GemmK_GemmM
=
Sequence
<
8
,
1
>
;
using
GemmABlockCopyThreadClusterLengths_GemmK_GemmM
=
Sequence
<
2
,
128
>
;
constexpr
index_t
GemmABlockCopySrcDataPerRead_GemmM
=
1
;
constexpr
index_t
GemmABlockCopyDstDataPerWrite_GemmM
=
1
;
using
GemmBBlockCopyThreadSliceLengths_GemmK_GemmN
=
Sequence
<
8
,
1
>
;
using
GemmBBlockCopyThreadClusterLengths_GemmK_GemmN
=
Sequence
<
2
,
128
>
;
constexpr
index_t
GemmBBlockCopySrcDataPerRead_GemmN
=
1
;
constexpr
index_t
GemmBBlockCopyDstDataPerWrite_GemmN
=
1
;
constexpr
index_t
GemmCThreadCopyDstDataPerWrite_GemmN1
=
1
;
constexpr
index_t
GemmCThreadCopyDstDataPerWrite_GemmN1
=
1
;
#endif
#endif
...
...
driver/include/device_convolution_backward_data_implicit_gemm_v5r1_n
c
hw_k
c
yx_n
k
hw.hpp
→
driver/include/device_convolution_backward_data_implicit_gemm_v5r1_nhw
c
_kyx
c
_nhw
k
.hpp
View file @
e9c5efc4
...
@@ -3,7 +3,7 @@
...
@@ -3,7 +3,7 @@
#include "device.hpp"
#include "device.hpp"
#include "host_tensor.hpp"
#include "host_tensor.hpp"
#include "gridwise_operation_wrapper.hpp"
#include "gridwise_operation_wrapper.hpp"
#include "gridwise_convolution_backward_data_implicit_gemm_v5r1_n
c
hw_k
c
yx_n
k
hw.hpp"
#include "gridwise_convolution_backward_data_implicit_gemm_v5r1_nhw
c
_kyx
c
_nhw
k
.hpp"
namespace
launcher
{
namespace
launcher
{
...
@@ -17,7 +17,7 @@ template <typename T,
...
@@ -17,7 +17,7 @@ template <typename T,
typename
ConvDilations
,
typename
ConvDilations
,
typename
InLeftPads
,
typename
InLeftPads
,
typename
InRightPads
>
typename
InRightPads
>
void
device_convolution_backward_data_implicit_gemm_v5r1_n
c
hw_k
c
yx_n
k
hw
(
InDesc
in_nchw_desc
,
void
device_convolution_backward_data_implicit_gemm_v5r1_nhw
c
_kyx
c
_nhw
k
(
InDesc
in_nchw_desc
,
Tensor
<
T
>&
in_nchw
,
Tensor
<
T
>&
in_nchw
,
WeiDesc
wei_kcyx_desc
,
WeiDesc
wei_kcyx_desc
,
const
Tensor
<
T
>&
wei_kcyx
,
const
Tensor
<
T
>&
wei_kcyx
,
...
@@ -48,17 +48,41 @@ void device_convolution_backward_data_implicit_gemm_v5r1_nchw_kcyx_nkhw(InDesc i
...
@@ -48,17 +48,41 @@ void device_convolution_backward_data_implicit_gemm_v5r1_nchw_kcyx_nkhw(InDesc i
constexpr
index_t
ConvDilationH
=
ConvDilations
{}[
0
];
constexpr
index_t
ConvDilationH
=
ConvDilations
{}[
0
];
constexpr
index_t
ConvDilationW
=
ConvDilations
{}[
1
];
constexpr
index_t
ConvDilationW
=
ConvDilations
{}[
1
];
constexpr
auto
in_nhwc_desc
=
make_native_tensor_descriptor_packed
(
Sequence
<
N
,
Hi
,
Wi
,
C
>
{});
constexpr
auto
wei_kyxc_desc
=
make_native_tensor_descriptor_packed
(
Sequence
<
K
,
Y
,
X
,
C
>
{});
constexpr
auto
out_nhwk_desc
=
make_native_tensor_descriptor_packed
(
Sequence
<
N
,
Ho
,
Wo
,
K
>
{});
Tensor
<
float
>
in_nhwc
(
make_HostTensorDescriptor
(
in_nhwc_desc
));
Tensor
<
float
>
wei_kyxc
(
make_HostTensorDescriptor
(
wei_kyxc_desc
));
Tensor
<
float
>
out_nhwk
(
make_HostTensorDescriptor
(
out_nhwk_desc
));
auto
f_nchw2nhwc
=
[
&
](
auto
n
,
auto
hi
,
auto
wi
,
auto
c
)
{
in_nhwc
(
n
,
hi
,
wi
,
c
)
=
in_nchw
(
n
,
c
,
hi
,
wi
);
};
auto
f_kcyx2kyxc
=
[
&
](
auto
k
,
auto
y
,
auto
x
,
auto
c
)
{
wei_kyxc
(
k
,
y
,
x
,
c
)
=
wei_kcyx
(
k
,
c
,
y
,
x
);
};
auto
f_nkhw2nhwk
=
[
&
](
auto
n
,
auto
ho
,
auto
wo
,
auto
k
)
{
out_nhwk
(
n
,
ho
,
wo
,
k
)
=
out_nkhw
(
n
,
k
,
ho
,
wo
);
};
make_ParallelTensorFunctor
(
f_nchw2nhwc
,
N
,
Hi
,
Wi
,
C
)(
std
::
thread
::
hardware_concurrency
());
make_ParallelTensorFunctor
(
f_kcyx2kyxc
,
K
,
Y
,
X
,
C
)(
std
::
thread
::
hardware_concurrency
());
make_ParallelTensorFunctor
(
f_nkhw2nhwk
,
N
,
Ho
,
Wo
,
K
)(
std
::
thread
::
hardware_concurrency
());
std
::
size_t
data_sz
=
sizeof
(
T
);
std
::
size_t
data_sz
=
sizeof
(
T
);
DeviceMem
in_n
c
hw_device_buf
(
data_sz
*
in_n
c
hw
.
mDesc
.
GetElementSpace
());
DeviceMem
in_nhw
c
_device_buf
(
data_sz
*
in_nhw
c
.
mDesc
.
GetElementSpace
());
DeviceMem
wei_k
c
yx_device_buf
(
data_sz
*
wei_k
c
yx
.
mDesc
.
GetElementSpace
());
DeviceMem
wei_kyx
c
_device_buf
(
data_sz
*
wei_kyx
c
.
mDesc
.
GetElementSpace
());
DeviceMem
out_n
k
hw_device_buf
(
data_sz
*
out_n
k
hw
.
mDesc
.
GetElementSpace
());
DeviceMem
out_nhw
k
_device_buf
(
data_sz
*
out_nhw
k
.
mDesc
.
GetElementSpace
());
in_n
c
hw_device_buf
.
ToDevice
(
in_n
c
hw
.
mData
.
data
());
in_nhw
c
_device_buf
.
ToDevice
(
in_nhw
c
.
mData
.
data
());
wei_k
c
yx_device_buf
.
ToDevice
(
wei_k
c
yx
.
mData
.
data
());
wei_kyx
c
_device_buf
.
ToDevice
(
wei_kyx
c
.
mData
.
data
());
out_n
k
hw_device_buf
.
ToDevice
(
out_n
k
hw
.
mData
.
data
());
out_nhw
k
_device_buf
.
ToDevice
(
out_nhw
k
.
mData
.
data
());
#if
1
#if
0
// BlockSize = 256,
each thread hold 64 data
//
cdata = 64,
BlockSize = 256,
128x128x8
constexpr index_t BlockSize = 256;
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmMPerBlock = 128;
...
@@ -74,16 +98,46 @@ void device_convolution_backward_data_implicit_gemm_v5r1_nchw_kcyx_nkhw(InDesc i
...
@@ -74,16 +98,46 @@ void device_convolution_backward_data_implicit_gemm_v5r1_nchw_kcyx_nkhw(InDesc i
constexpr index_t GemmThreadGemmDataPerReadM = 4;
constexpr index_t GemmThreadGemmDataPerReadM = 4;
constexpr index_t GemmThreadGemmDataPerReadN = 4;
constexpr index_t GemmThreadGemmDataPerReadN = 4;
using
GemmABlockCopyThreadSliceLengths_GemmK0_GemmK1_GemmK2_GemmM
=
Sequence
<
1
,
1
,
4
,
1
>
;
using GemmABlockCopyThreadSliceLengths_GemmK0_GemmK1_GemmK2_GemmM = Sequence<1, 1,
1
,
4
>;
using
GemmABlockCopyThreadClusterLengths_GemmK0_GemmK1_GemmK2_GemmM
=
Sequence
<
1
,
1
,
2
,
128
>
;
using GemmABlockCopyThreadClusterLengths_GemmK0_GemmK1_GemmK2_GemmM = Sequence<1, 1,
8
,
32
>;
constexpr
index_t
GemmABlockCopySrcDataPerRead_GemmM
=
1
;
constexpr index_t GemmABlockCopySrcDataPerRead_GemmM =
4
;
constexpr
index_t
GemmABlockCopyDstDataPerWrite_GemmM
=
1
;
constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM =
4
;
using GemmBBlockCopyThreadSliceLengths_GemmK0_GemmK1_GemmK2_GemmN = Sequence<1, 1, 4, 1>;
using GemmBBlockCopyThreadSliceLengths_GemmK0_GemmK1_GemmK2_GemmN = Sequence<1, 1, 4, 1>;
using GemmBBlockCopyThreadClusterLengths_GemmK0_GemmK1_GemmK2_GemmN = Sequence<1, 1, 2, 128>;
using GemmBBlockCopyThreadClusterLengths_GemmK0_GemmK1_GemmK2_GemmN = Sequence<1, 1, 2, 128>;
constexpr
index_t
GemmBBlockCopySrcDataPerRead_GemmN
=
1
;
constexpr index_t GemmBBlockCopySrcDataPerRead_GemmK2 = 4;
constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;
constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
#elif
1
// cdata = 64, BlockSize = 256, 128x128x16
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
GemmMPerBlock
=
128
;
constexpr
index_t
GemmNPerBlock
=
128
;
constexpr
index_t
GemmKPerBlock
=
16
;
constexpr
index_t
GemmMPerThread
=
4
;
constexpr
index_t
GemmNPerThread
=
4
;
constexpr
index_t
GemmKPerThread
=
1
;
constexpr
index_t
GemmMLevel0Cluster
=
4
;
constexpr
index_t
GemmNLevel0Cluster
=
4
;
constexpr
index_t
GemmMLevel1Cluster
=
4
;
constexpr
index_t
GemmNLevel1Cluster
=
4
;
constexpr
index_t
GemmThreadGemmDataPerReadM
=
4
;
constexpr
index_t
GemmThreadGemmDataPerReadN
=
4
;
using
GemmABlockCopyThreadSliceLengths_GemmK0_GemmK1_GemmK2_GemmM
=
Sequence
<
1
,
1
,
2
,
4
>
;
using
GemmABlockCopyThreadClusterLengths_GemmK0_GemmK1_GemmK2_GemmM
=
Sequence
<
1
,
1
,
8
,
32
>
;
constexpr
index_t
GemmABlockCopySrcDataPerRead_GemmM
=
4
;
constexpr
index_t
GemmABlockCopyDstDataPerWrite_GemmM
=
4
;
using
GemmBBlockCopyThreadSliceLengths_GemmK0_GemmK1_GemmK2_GemmN
=
Sequence
<
1
,
1
,
8
,
1
>
;
using
GemmBBlockCopyThreadClusterLengths_GemmK0_GemmK1_GemmK2_GemmN
=
Sequence
<
1
,
1
,
2
,
128
>
;
constexpr
index_t
GemmBBlockCopySrcDataPerRead_GemmK2
=
4
;
constexpr
index_t
GemmBBlockCopyDstDataPerWrite_GemmN
=
1
;
constexpr
index_t
GemmBBlockCopyDstDataPerWrite_GemmN
=
1
;
constexpr
index_t
GemmCThreadCopyDstDataPerWrite_GemmN1
=
1
;
constexpr
index_t
GemmCThreadCopyDstDataPerWrite_GemmN1
=
1
;
...
@@ -132,14 +186,14 @@ void device_convolution_backward_data_implicit_gemm_v5r1_nchw_kcyx_nkhw(InDesc i
...
@@ -132,14 +186,14 @@ void device_convolution_backward_data_implicit_gemm_v5r1_nchw_kcyx_nkhw(InDesc i
for
(
index_t
i
=
0
;
i
<
nrepeat
;
++
i
)
for
(
index_t
i
=
0
;
i
<
nrepeat
;
++
i
)
{
{
using
GridwiseConvBwdData
=
using
GridwiseConvBwdData
=
GridwiseConvolutionBackwardDataImplicitGemm_v5r1_n
c
hw_k
c
yx_n
k
hw
<
GridwiseConvolutionBackwardDataImplicitGemm_v5r1_nhw
c
_kyx
c
_nhw
k
<
GridSize
,
GridSize
,
BlockSize
,
BlockSize
,
T
,
T
,
T
,
T
,
decltype
(
in_n
c
hw_desc
),
decltype
(
in_nhw
c
_desc
),
decltype
(
wei_k
c
yx_desc
),
decltype
(
wei_kyx
c
_desc
),
decltype
(
out_n
k
hw_desc
),
decltype
(
out_nhw
k
_desc
),
ConvStrides
,
ConvStrides
,
ConvDilations
,
ConvDilations
,
InLeftPads
,
InLeftPads
,
...
@@ -162,14 +216,14 @@ void device_convolution_backward_data_implicit_gemm_v5r1_nchw_kcyx_nkhw(InDesc i
...
@@ -162,14 +216,14 @@ void device_convolution_backward_data_implicit_gemm_v5r1_nchw_kcyx_nkhw(InDesc i
GemmABlockCopyDstDataPerWrite_GemmM
,
GemmABlockCopyDstDataPerWrite_GemmM
,
GemmBBlockCopyThreadSliceLengths_GemmK0_GemmK1_GemmK2_GemmN
,
GemmBBlockCopyThreadSliceLengths_GemmK0_GemmK1_GemmK2_GemmN
,
GemmBBlockCopyThreadClusterLengths_GemmK0_GemmK1_GemmK2_GemmN
,
GemmBBlockCopyThreadClusterLengths_GemmK0_GemmK1_GemmK2_GemmN
,
GemmBBlockCopySrcDataPerRead_Gemm
N
,
GemmBBlockCopySrcDataPerRead_Gemm
K2
,
GemmBBlockCopyDstDataPerWrite_GemmN
,
GemmBBlockCopyDstDataPerWrite_GemmN
,
GemmCThreadCopyDstDataPerWrite_GemmN1
>
;
GemmCThreadCopyDstDataPerWrite_GemmN1
>
;
static_for
<
0
,
GridwiseConvBwdData
::
GetNumberOfGemm
(),
1
>
{}([
&
](
auto
gemm_id
)
{
static_for
<
0
,
GridwiseConvBwdData
::
GetNumberOfGemm
(),
1
>
{}([
&
](
auto
gemm_id
)
{
constexpr
auto
gemm_sizes
=
GridwiseConvBwdData
::
GetGemmSize
(
gemm_id
);
constexpr
auto
gemm_sizes
=
GridwiseConvBwdData
::
GetGemmSize
(
gemm_id
);
constexpr
index_t
gemm_k
=
gemm_sizes
.
At
(
2
);
constexpr
index_t
gemm_k
2
=
gemm_sizes
.
At
(
4
);
constexpr
bool
is_gemm_not_empty
=
gemm_k
>
0
;
constexpr
bool
is_gemm_not_empty
=
gemm_k
2
>
0
;
// only compile and run if GEMM is no empty
// only compile and run if GEMM is no empty
static_if
<
is_gemm_not_empty
>
{}([
&
](
auto
fwd
)
{
static_if
<
is_gemm_not_empty
>
{}([
&
](
auto
fwd
)
{
...
@@ -182,9 +236,9 @@ void device_convolution_backward_data_implicit_gemm_v5r1_nchw_kcyx_nkhw(InDesc i
...
@@ -182,9 +236,9 @@ void device_convolution_backward_data_implicit_gemm_v5r1_nchw_kcyx_nkhw(InDesc i
dim3
(
BlockSize
),
dim3
(
BlockSize
),
0
,
0
,
0
,
0
,
static_cast
<
T
*>
(
in_n
c
hw_device_buf
.
GetDeviceBuffer
()),
static_cast
<
T
*>
(
in_nhw
c
_device_buf
.
GetDeviceBuffer
()),
static_cast
<
T
*>
(
wei_k
c
yx_device_buf
.
GetDeviceBuffer
()),
static_cast
<
T
*>
(
wei_kyx
c
_device_buf
.
GetDeviceBuffer
()),
static_cast
<
T
*>
(
out_n
k
hw_device_buf
.
GetDeviceBuffer
()),
static_cast
<
T
*>
(
out_nhw
k
_device_buf
.
GetDeviceBuffer
()),
fwd
(
gemm_id
));
fwd
(
gemm_id
));
});
});
});
});
...
@@ -200,7 +254,13 @@ void device_convolution_backward_data_implicit_gemm_v5r1_nchw_kcyx_nkhw(InDesc i
...
@@ -200,7 +254,13 @@ void device_convolution_backward_data_implicit_gemm_v5r1_nchw_kcyx_nkhw(InDesc i
std
::
cout
<<
"Average time : "
<<
ave_time
<<
" ms, "
<<
perf
<<
" TFlop/s"
<<
std
::
endl
;
std
::
cout
<<
"Average time : "
<<
ave_time
<<
" ms, "
<<
perf
<<
" TFlop/s"
<<
std
::
endl
;
}
}
in_nchw_device_buf
.
FromDevice
(
in_nchw
.
mData
.
data
());
in_nhwc_device_buf
.
FromDevice
(
in_nhwc
.
mData
.
data
());
auto
f_nhwc2nchw
=
[
&
](
auto
n
,
auto
c
,
auto
hi
,
auto
wi
)
{
in_nchw
(
n
,
c
,
hi
,
wi
)
=
in_nhwc
(
n
,
hi
,
wi
,
c
);
};
make_ParallelTensorFunctor
(
f_nhwc2nchw
,
N
,
C
,
Hi
,
Wi
)(
std
::
thread
::
hardware_concurrency
());
}
}
}
// namespace launcher
}
// namespace launcher
driver/src/conv_bwd_data_driver.cpp
View file @
e9c5efc4
...
@@ -15,10 +15,8 @@
...
@@ -15,10 +15,8 @@
#include "host_conv_bwd_data.hpp"
#include "host_conv_bwd_data.hpp"
#include "device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw.hpp"
#include "device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw.hpp"
#include "device_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw.hpp"
#include "device_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw.hpp"
#include "device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw.hpp"
#include "device_convolution_backward_data_implicit_gemm_v3r1_nchw_kcyx_nkhw.hpp"
#include "device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp"
#include "device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp"
#include "device_convolution_backward_data_implicit_gemm_v5r1_n
c
hw_k
c
yx_n
k
hw.hpp"
#include "device_convolution_backward_data_implicit_gemm_v5r1_nhw
c
_kyx
c
_nhw
k
.hpp"
int
main
(
int
argc
,
char
*
argv
[])
int
main
(
int
argc
,
char
*
argv
[])
{
{
...
@@ -56,7 +54,7 @@ int main(int argc, char* argv[])
...
@@ -56,7 +54,7 @@ int main(int argc, char* argv[])
#elif 0
#elif 0
// 3x3, 28x28
// 3x3, 28x28
constexpr
index_t
N
=
128
;
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
1024
;
constexpr
index_t
C
=
256
;
constexpr
index_t
HI
=
28
;
constexpr
index_t
HI
=
28
;
constexpr
index_t
WI
=
28
;
constexpr
index_t
WI
=
28
;
constexpr
index_t
K
=
1024
;
constexpr
index_t
K
=
1024
;
...
@@ -161,7 +159,7 @@ int main(int argc, char* argv[])
...
@@ -161,7 +159,7 @@ int main(int argc, char* argv[])
#elif 0
#elif 0
// 1x7 filter, 0x3 pad, 17x17 input
// 1x7 filter, 0x3 pad, 17x17 input
constexpr
index_t
N
=
128
;
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
1024
;
constexpr
index_t
C
=
256
;
constexpr
index_t
HI
=
17
;
constexpr
index_t
HI
=
17
;
constexpr
index_t
WI
=
17
;
constexpr
index_t
WI
=
17
;
constexpr
index_t
K
=
1024
;
constexpr
index_t
K
=
1024
;
...
@@ -173,10 +171,10 @@ int main(int argc, char* argv[])
...
@@ -173,10 +171,10 @@ int main(int argc, char* argv[])
using
LeftPads
=
Sequence
<
0
,
3
>
;
using
LeftPads
=
Sequence
<
0
,
3
>
;
using
RightPads
=
Sequence
<
0
,
3
>
;
using
RightPads
=
Sequence
<
0
,
3
>
;
#elif
1
#elif
0
// 7x1 filter, 3x0 pad, 17x17 input
// 7x1 filter, 3x0 pad, 17x17 input
constexpr
index_t
N
=
128
;
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
1024
;
constexpr
index_t
C
=
256
;
constexpr
index_t
HI
=
17
;
constexpr
index_t
HI
=
17
;
constexpr
index_t
WI
=
17
;
constexpr
index_t
WI
=
17
;
constexpr
index_t
K
=
1024
;
constexpr
index_t
K
=
1024
;
...
@@ -188,13 +186,13 @@ int main(int argc, char* argv[])
...
@@ -188,13 +186,13 @@ int main(int argc, char* argv[])
using
LeftPads
=
Sequence
<
3
,
0
>
;
using
LeftPads
=
Sequence
<
3
,
0
>
;
using
RightPads
=
Sequence
<
3
,
0
>
;
using
RightPads
=
Sequence
<
3
,
0
>
;
#elif
0
#elif
1
// 3x3 filter, 2x2 stride, 35x35 input, 17x17 output
// 3x3 filter, 2x2 stride, 35x35 input, 17x17 output
constexpr
index_t
N
=
128
;
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
128
;
constexpr
index_t
C
=
256
;
constexpr
index_t
HI
=
35
;
constexpr
index_t
HI
=
35
;
constexpr
index_t
WI
=
35
;
constexpr
index_t
WI
=
35
;
constexpr
index_t
K
=
128
;
constexpr
index_t
K
=
128
0
;
constexpr
index_t
Y
=
3
;
constexpr
index_t
Y
=
3
;
constexpr
index_t
X
=
3
;
constexpr
index_t
X
=
3
;
...
@@ -251,13 +249,9 @@ int main(int argc, char* argv[])
...
@@ -251,13 +249,9 @@ int main(int argc, char* argv[])
#elif
0
#elif
0
device_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw
device_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw
#elif 0
#elif 0
device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw
#elif 0
device_convolution_backward_data_implicit_gemm_v3r1_nchw_kcyx_nkhw
#elif 1
device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw
device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw
#elif 1
#elif 1
device_convolution_backward_data_implicit_gemm_v5r1_n
c
hw_k
c
yx_n
k
hw
device_convolution_backward_data_implicit_gemm_v5r1_nhw
c
_kyx
c
_nhw
k
#endif
#endif
(
in_nchw_desc
,
(
in_nchw_desc
,
in_nchw_device
,
in_nchw_device
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment