gaoqiong / composable_kernel · commit 84239246

Commit 84239246, authored Aug 03, 2020 by Chao Liu

    add bwd_data-v5r1

Parent: 6b165b9b
Showing 5 changed files with 1056 additions and 6 deletions.
composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v5r1_nchw_kcyx_nkhw.hpp   +447 -0
composable_kernel/include/tensor_operation/gridwise_gemm.hpp   +395 -0
driver/CMakeLists.txt   +1 -2
driver/include/device_convolution_backward_data_implicit_gemm_v5r1_nchw_kcyx_nkhw.hpp   +206 -0
driver/src/conv_bwd_data_driver.cpp   +7 -4
composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v5r1_nchw_kcyx_nkhw.hpp (new file, 0 → 100644)
#ifndef CK_GRIDWISE_CONVOLUTION_BACKWARD_DATA_IMPLICIT_GEMM_V5R1_NCHW_KCYX_NKHW_HPP
#define CK_GRIDWISE_CONVOLUTION_BACKWARD_DATA_IMPLICIT_GEMM_V5R1_NCHW_KCYX_NKHW_HPP

#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm.hpp"

namespace ck {

// Number of GEMM partition: YTilda * XTilda
// Number of GEMM iteration: YDotSlice * XDotSlice
// GemmM = C
// GemmN = N * HTildaSlice * WTildaSlice
// GemmK = K
template <index_t GridSize,
          index_t BlockSize,
          typename Float,
          typename AccFloat,
          typename InGlobalDesc,
          typename WeiGlobalDesc,
          typename OutGlobalDesc,
          typename ConvStrides,
          typename ConvDilations,
          typename InLeftPads,
          typename InRightPads,
          index_t GemmMPerBlock,
          index_t GemmNPerBlock,
          index_t GemmKPerBlock,
          index_t GemmMPerThread,
          index_t GemmNPerThread,
          index_t GemmKPerThread,
          index_t GemmMLevel0Cluster,
          index_t GemmNLevel0Cluster,
          index_t GemmMLevel1Cluster,
          index_t GemmNLevel1Cluster,
          index_t ThreadGemmDataPerRead_GemmM,
          index_t ThreadGemmDataPerRead_GemmN,
          typename GemmABlockCopyThreadSliceLengths_GemmK0_GemmK1_GemmK2_GemmM,
          typename GemmABlockCopyThreadClusterLengths_GemmK0_GemmK1_GemmK2_GemmM,
          index_t GemmABlockCopySrcDataPerRead_GemmM,
          index_t GemmABlockCopyDstDataPerWrite_GemmM,
          typename GemmBBlockCopyThreadSliceLengths_GemmK0_GemmK1_GemmK2_GemmN,
          typename GemmBBlockCopyThreadClusterLengths_GemmK0_GemmK1_GemmK2_GemmN,
          index_t GemmBBlockCopySrcDataPerRead_GemmN,
          index_t GemmBBlockCopyDstDataPerWrite_GemmN,
          index_t GemmCThreadCopyDstDataPerWrite_GemmN1>
struct GridwiseConvolutionBackwardDataImplicitGemm_v5r1_nchw_kcyx_nkhw
{
    __host__ __device__ static constexpr index_t GetNumberOfGemm()
    {
        constexpr index_t ConvStrideH = ConvStrides{}[0];
        constexpr index_t ConvStrideW = ConvStrides{}[1];

        constexpr index_t ConvDilationH = ConvDilations{}[0];
        constexpr index_t ConvDilationW = ConvDilations{}[1];

        constexpr index_t GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
        constexpr index_t GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);

        constexpr index_t YTilda = ConvStrideH / GcdStrideDilationH;
        constexpr index_t XTilda = ConvStrideW / GcdStrideDilationW;

        return YTilda * XTilda;
    }

    __host__ __device__ static constexpr auto GetGemmSizeImpl(index_t iYTilda, index_t iXTilda)
    {
        constexpr index_t N  = InGlobalDesc::GetLengths()[0];
        constexpr index_t C  = InGlobalDesc::GetLengths()[1];
        constexpr index_t Hi = InGlobalDesc::GetLengths()[2];
        constexpr index_t Wi = InGlobalDesc::GetLengths()[3];

        constexpr index_t K  = OutGlobalDesc::GetLengths()[1];
        constexpr index_t Ho = OutGlobalDesc::GetLengths()[2];
        constexpr index_t Wo = OutGlobalDesc::GetLengths()[3];

        constexpr index_t Y = WeiGlobalDesc::GetLengths()[2];
        constexpr index_t X = WeiGlobalDesc::GetLengths()[3];

        constexpr index_t ConvStrideH = ConvStrides{}[0];
        constexpr index_t ConvStrideW = ConvStrides{}[1];

        constexpr index_t ConvDilationH = ConvDilations{}[0];
        constexpr index_t ConvDilationW = ConvDilations{}[1];

        constexpr index_t GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
        constexpr index_t GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);

        constexpr index_t YTilda = ConvStrideH / GcdStrideDilationH;
        constexpr index_t XTilda = ConvStrideW / GcdStrideDilationW;

        constexpr index_t YDot = math::integer_divide_ceil(Y, YTilda);
        constexpr index_t XDot = math::integer_divide_ceil(X, XTilda);

        constexpr index_t HTilda =
            Ho + math::integer_divide_ceil(ConvDilationH * (Y - 1), ConvStrideH);
        constexpr index_t WTilda =
            Wo + math::integer_divide_ceil(ConvDilationW * (X - 1), ConvStrideW);

        // only work on HTilda and WTilda that contribute to non-padding area of input tensor
        constexpr index_t iHTildaLeft = math::integer_divide_floor(
            math::max(0, InLeftPads{}[0] - ConvDilationH * (YTilda - 1)), ConvStrides{}[0]);
        constexpr index_t iWTildaLeft = math::integer_divide_floor(
            math::max(0, InLeftPads{}[1] - ConvDilationW * (XTilda - 1)), ConvStrides{}[1]);

        constexpr index_t iHTildaRight = math::min(
            HTilda, math::integer_divide_ceil(InLeftPads{}[0] + Hi - 1, ConvStrides{}[0]) + 1);
        constexpr index_t iWTildaRight = math::min(
            WTilda, math::integer_divide_ceil(InLeftPads{}[1] + Wi - 1, ConvStrides{}[1]) + 1);

        constexpr index_t HTildaSlice = iHTildaRight - iHTildaLeft;
        constexpr index_t WTildaSlice = iWTildaRight - iWTildaLeft;

        // GemmM and GemmN
        constexpr index_t GemmM = C;
        constexpr index_t GemmN = N * HTildaSlice * WTildaSlice;

        // GemmK is different for each GEMM
        index_t YDotSlice = (iYTilda + 1) * YDot <= Y ? YDot : Y % YDot;
        index_t XDotSlice = (iXTilda + 1) * XDot <= X ? XDot : X % XDot;

        index_t GemmK = K * YDotSlice * XDotSlice;

        return Array<index_t, 3>{GemmM, GemmN, GemmK};
    }

    __host__ __device__ static constexpr auto GetGemmSize(index_t gemm_id)
    {
        constexpr index_t ConvStrideW        = ConvStrides{}[1];
        constexpr index_t ConvDilationW      = ConvDilations{}[1];
        constexpr index_t GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
        constexpr index_t XTilda             = ConvStrideW / GcdStrideDilationW;

        index_t iYTilda = gemm_id / XTilda;
        index_t iXTilda = gemm_id % XTilda;

        return GetGemmSizeImpl(iYTilda, iXTilda);
    }

    template <index_t iYTilda, index_t iXTilda>
    __device__ static void RunImpl(Float* __restrict__ p_in_global,
                                   const Float* __restrict__ p_wei_global,
                                   const Float* __restrict__ p_out_global)
    {
        constexpr auto in_n_c_hi_wi_global_desc  = InGlobalDesc{};
        constexpr auto wei_k_c_y_x_global_desc   = WeiGlobalDesc{};
        constexpr auto out_n_k_ho_wo_global_desc = OutGlobalDesc{};

        constexpr index_t N  = in_n_c_hi_wi_global_desc.GetLengths()[0];
        constexpr index_t C  = in_n_c_hi_wi_global_desc.GetLengths()[1];
        constexpr index_t Hi = in_n_c_hi_wi_global_desc.GetLengths()[2];
        constexpr index_t Wi = in_n_c_hi_wi_global_desc.GetLengths()[3];

        constexpr index_t K  = out_n_k_ho_wo_global_desc.GetLengths()[1];
        constexpr index_t Ho = out_n_k_ho_wo_global_desc.GetLengths()[2];
        constexpr index_t Wo = out_n_k_ho_wo_global_desc.GetLengths()[3];

        constexpr index_t Y = wei_k_c_y_x_global_desc.GetLengths()[2];
        constexpr index_t X = wei_k_c_y_x_global_desc.GetLengths()[3];

        constexpr index_t ConvStrideH = ConvStrides{}[0];
        constexpr index_t ConvStrideW = ConvStrides{}[1];

        constexpr index_t ConvDilationH = ConvDilations{}[0];
        constexpr index_t ConvDilationW = ConvDilations{}[1];

        constexpr index_t GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
        constexpr index_t GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);

        constexpr index_t YTilda = ConvStrideH / GcdStrideDilationH;
        constexpr index_t XTilda = ConvStrideW / GcdStrideDilationW;

        constexpr index_t YDot = math::integer_divide_ceil(Y, YTilda);
        constexpr index_t XDot = math::integer_divide_ceil(X, XTilda);

        constexpr index_t YDotSlice = (iYTilda + 1) * YDot <= Y ? YDot : Y % YDot;
        constexpr index_t XDotSlice = (iXTilda + 1) * XDot <= X ? XDot : X % XDot;

        constexpr index_t HTilda =
            Ho + math::integer_divide_ceil(ConvDilationH * (Y - 1), ConvStrideH);
        constexpr index_t WTilda =
            Wo + math::integer_divide_ceil(ConvDilationW * (X - 1), ConvStrideW);

        // only work on HTilda and WTilda that contribute to non-padding area of input tensor
        constexpr index_t iHTildaLeft = math::integer_divide_floor(
            math::max(0, InLeftPads{}[0] - ConvDilationH * (YTilda - 1)), ConvStrides{}[0]);
        constexpr index_t iWTildaLeft = math::integer_divide_floor(
            math::max(0, InLeftPads{}[1] - ConvDilationW * (XTilda - 1)), ConvStrides{}[1]);

        constexpr index_t iHTildaRight = math::min(
            HTilda, math::integer_divide_ceil(InLeftPads{}[0] + Hi - 1, ConvStrides{}[0]) + 1);
        constexpr index_t iWTildaRight = math::min(
            WTilda, math::integer_divide_ceil(InLeftPads{}[1] + Wi - 1, ConvStrides{}[1]) + 1);

        constexpr index_t HTildaSlice = iHTildaRight - iHTildaLeft;
        constexpr index_t WTildaSlice = iWTildaRight - iWTildaLeft;

        // A matrix: weight
        //   weight out-of-bound check can be skipped
        constexpr bool wei_skip_out_of_bound_check = true;

        constexpr auto wei_k_c_ydot_ytilda_xdot_xtilda_global_desc = transform_tensor_descriptor(
            wei_k_c_y_x_global_desc,
            make_tuple(PassThrough<K>{},
                       PassThrough<C>{},
                       Embed<Y,
                             Sequence<YDot, YTilda>,
                             Sequence<ConvStrideH / GcdStrideDilationH, 1, 0>,
                             wei_skip_out_of_bound_check>{},
                       Embed<X,
                             Sequence<XDot, XTilda>,
                             Sequence<ConvStrideW / GcdStrideDilationW, 1, 0>,
                             wei_skip_out_of_bound_check>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));

        constexpr auto wei_k_c_ydotslice_ytidaslice_xdotslice_xtildaslice_global_desc =
            transform_tensor_descriptor(
                wei_k_c_ydot_ytilda_xdot_xtilda_global_desc,
                make_tuple(
                    PassThrough<K>{},
                    PassThrough<C>{},
                    Slice<Sequence<YDot, XDot>, Sequence<0, 0>, Sequence<YDotSlice, XDotSlice>>{},
                    Slice<Sequence<YTilda, XTilda>,
                          Sequence<iYTilda, iXTilda>,
                          Sequence<iYTilda + 1, iXTilda + 1>>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 4>{}, Sequence<3, 5>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 4>{}, Sequence<3, 5>{}));

        constexpr auto wei_gemmk0_gemmk1_gemmk2_gemmm_global_desc = transform_tensor_descriptor(
            wei_k_c_ydotslice_ytidaslice_xdotslice_xtildaslice_global_desc,
            make_tuple(PassThrough<YDotSlice>{},
                       PassThrough<XDotSlice>{},
                       PassThrough<K>{},
                       Merge<Sequence<C, 1, 1>>{}),
            make_tuple(Sequence<2>{}, Sequence<4>{}, Sequence<0>{}, Sequence<1, 3, 5>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));

        // B matrix: output tensor
        // TODO sometimes output tensor out-of-bound check can be skipped, find out all such
        // situations
#if !CK_EXPERIMENTAL_IMPLICIT_GEMM_BACKWARD_DATA_V4R1_OUTPUT_SKIP_OUT_OF_BOUND_CHECK
        constexpr bool out_skip_out_of_bound_check = false;
#else
        constexpr bool out_skip_out_of_bound_check = true;
#endif

        constexpr auto out_n_k_ydot_htilda_xdot_wtilda_global_desc = transform_tensor_descriptor(
            out_n_k_ho_wo_global_desc,
            make_tuple(PassThrough<N>{},
                       PassThrough<K>{},
                       Embed<Ho,
                             Sequence<YDot, HTilda>,
                             Sequence<-ConvDilationH / GcdStrideDilationH, 1, 0>,
                             out_skip_out_of_bound_check>{},
                       Embed<Wo,
                             Sequence<XDot, WTilda>,
                             Sequence<-ConvDilationW / GcdStrideDilationW, 1, 0>,
                             out_skip_out_of_bound_check>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));

        constexpr auto out_n_k_ydot_htildaslice_xdot_wtildaslice_global_desc =
            transform_tensor_descriptor(
                out_n_k_ydot_htilda_xdot_wtilda_global_desc,
                make_tuple(PassThrough<N>{},
                           PassThrough<K>{},
                           PassThrough<YTilda>{},
                           PassThrough<XTilda>{},
                           Slice<Sequence<HTilda, WTilda>,
                                 Sequence<iHTildaLeft, iWTildaLeft>,
                                 Sequence<iHTildaRight, iWTildaRight>>{}),
                make_tuple(
                    Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}),
                make_tuple(
                    Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}));

        constexpr auto out_n_k_ydotslice_htildaslice_xdotslice_wtildaslice_global_desc =
            transform_tensor_descriptor(
                out_n_k_ydot_htildaslice_xdot_wtildaslice_global_desc,
                make_tuple(
                    PassThrough<N>{},
                    PassThrough<K>{},
                    PassThrough<HTildaSlice>{},
                    PassThrough<WTildaSlice>{},
                    Slice<Sequence<YDot, XDot>, Sequence<0, 0>, Sequence<YDotSlice, XDotSlice>>{}),
                make_tuple(
                    Sequence<0>{}, Sequence<1>{}, Sequence<3>{}, Sequence<5>{}, Sequence<2, 4>{}),
                make_tuple(
                    Sequence<0>{}, Sequence<1>{}, Sequence<3>{}, Sequence<5>{}, Sequence<2, 4>{}));

        constexpr auto out_gemmk0_gemmk1_gemmk2_gemmn_global_desc = transform_tensor_descriptor(
            out_n_k_ydotslice_htildaslice_xdotslice_wtildaslice_global_desc,
            make_tuple(PassThrough<YDotSlice>{},
                       PassThrough<XDotSlice>{},
                       PassThrough<K>{},
                       Merge<Sequence<N, HTildaSlice, WTildaSlice>>{}),
            make_tuple(Sequence<2>{}, Sequence<4>{}, Sequence<1>{}, Sequence<0, 3, 5>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));

        // C matrix: input tensor
        // TODO sometimes input out-of-bound check can be skipped, find out all such situations
#if !CK_EXPERIMENTAL_IMPLICIT_GEMM_BACKWARD_DATA_V4R1_INPUT_SKIP_OUT_OF_BOUND_CHECK
        constexpr bool in_skip_out_of_bound_check = false;
#else
        constexpr bool in_skip_out_of_bound_check = true;
#endif

        constexpr auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor(
            in_n_c_hi_wi_global_desc,
            make_tuple(PassThrough<N>{},
                       PassThrough<C>{},
                       Pad<Sequence<Hi, Wi>, InLeftPads, InRightPads, in_skip_out_of_bound_check>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}));

        constexpr index_t Hip = in_n_c_hip_wip_global_desc.GetLengths()[2];
        constexpr index_t Wip = in_n_c_hip_wip_global_desc.GetLengths()[3];

        constexpr auto in_n_c_ytilda_htilda_xtilda_wtilda_global_desc = transform_tensor_descriptor(
            in_n_c_hip_wip_global_desc,
            make_tuple(PassThrough<N>{},
                       PassThrough<C>{},
                       Embed<Hip,
                             Sequence<YTilda, HTilda>,
                             Sequence<ConvDilationH, ConvStrideH, 0>,
                             in_skip_out_of_bound_check>{},
                       Embed<Wip,
                             Sequence<XTilda, WTilda>,
                             Sequence<ConvDilationW, ConvStrideW, 0>,
                             in_skip_out_of_bound_check>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));

        constexpr auto in_n_c_ytilda_htildaslice_xtilda_wtildaslice_global_desc =
            transform_tensor_descriptor(
                in_n_c_ytilda_htilda_xtilda_wtilda_global_desc,
                make_tuple(PassThrough<N>{},
                           PassThrough<C>{},
                           PassThrough<YTilda>{},
                           PassThrough<XTilda>{},
                           Slice<Sequence<HTilda, WTilda>,
                                 Sequence<iHTildaLeft, iWTildaLeft>,
                                 Sequence<iHTildaRight, iWTildaRight>>{}),
                make_tuple(
                    Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}),
                make_tuple(
                    Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}));

        constexpr auto in_n_c_ytildaslice_htildaslice_xtildaslice_wtildaslice_global_desc =
            transform_tensor_descriptor(
                in_n_c_ytilda_htildaslice_xtilda_wtildaslice_global_desc,
                make_tuple(PassThrough<N>{},
                           PassThrough<C>{},
                           PassThrough<HTildaSlice>{},
                           PassThrough<WTildaSlice>{},
                           Slice<Sequence<YTilda, XTilda>,
                                 Sequence<iYTilda, iXTilda>,
                                 Sequence<iYTilda + 1, iXTilda + 1>>{}),
                make_tuple(
                    Sequence<0>{}, Sequence<1>{}, Sequence<3>{}, Sequence<5>{}, Sequence<2, 4>{}),
                make_tuple(
                    Sequence<0>{}, Sequence<1>{}, Sequence<3>{}, Sequence<5>{}, Sequence<2, 4>{}));

        constexpr auto in_gemmm_gemmn_global_desc = transform_tensor_descriptor(
            in_n_c_ytildaslice_htildaslice_xtildaslice_wtildaslice_global_desc,
            make_tuple(Merge<Sequence<C, 1, 1>>{}, Merge<Sequence<N, HTildaSlice, WTildaSlice>>{}),
            make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}));

        // call GEMM
        constexpr auto gridwise_gemm = GridwiseGemmTransposedANormalBNormalC_v2<
            GridSize,
            BlockSize,
            Float,
            AccFloat,
            decltype(wei_gemmk0_gemmk1_gemmk2_gemmm_global_desc),
            decltype(out_gemmk0_gemmk1_gemmk2_gemmn_global_desc),
            decltype(in_gemmm_gemmn_global_desc),
            InMemoryDataOperation::Set,
            GemmMPerBlock,
            GemmNPerBlock,
            GemmKPerBlock,
            GemmMPerThread,
            GemmNPerThread,
            GemmKPerThread,
            GemmMLevel0Cluster,
            GemmNLevel0Cluster,
            GemmMLevel1Cluster,
            GemmNLevel1Cluster,
            ThreadGemmDataPerRead_GemmM,
            ThreadGemmDataPerRead_GemmN,
            GemmABlockCopyThreadSliceLengths_GemmK0_GemmK1_GemmK2_GemmM,
            GemmABlockCopyThreadClusterLengths_GemmK0_GemmK1_GemmK2_GemmM,
            Sequence<0, 1, 2, 3>,
            Sequence<0, 1, 2, 3>,
            3,
            GemmABlockCopySrcDataPerRead_GemmM,
            GemmABlockCopyDstDataPerWrite_GemmM,
            GemmBBlockCopyThreadSliceLengths_GemmK0_GemmK1_GemmK2_GemmN,
            GemmBBlockCopyThreadClusterLengths_GemmK0_GemmK1_GemmK2_GemmN,
            Sequence<0, 1, 2, 3>,
            Sequence<0, 1, 2, 3>,
            3,
            GemmBBlockCopySrcDataPerRead_GemmN,
            GemmBBlockCopyDstDataPerWrite_GemmN,
            Sequence<0, 1, 2, 3>,
            3,
            GemmCThreadCopyDstDataPerWrite_GemmN1>{};

        gridwise_gemm.Run(p_wei_global, p_out_global, p_in_global);
    }

    template <index_t GemmId>
    __device__ static void Run(Float* __restrict__ p_in_global,
                               const Float* __restrict__ p_wei_global,
                               const Float* __restrict__ p_out_global,
                               Number<GemmId>)
    {
        constexpr index_t ConvStrideH = ConvStrides{}[0];
        constexpr index_t ConvStrideW = ConvStrides{}[1];

        constexpr index_t ConvDilationH = ConvDilations{}[0];
        constexpr index_t ConvDilationW = ConvDilations{}[1];

        constexpr index_t GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
        constexpr index_t GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);

        constexpr index_t YTilda = ConvStrideH / GcdStrideDilationH;
        constexpr index_t XTilda = ConvStrideW / GcdStrideDilationW;

        constexpr index_t iYTilda = GemmId / XTilda;
        constexpr index_t iXTilda = GemmId % XTilda;

        static_assert(iYTilda < YTilda && iXTilda < XTilda, "wrong! iYtilda, iXtilda");

        RunImpl<iYTilda, iXTilda>(p_in_global, p_wei_global, p_out_global);
    }
};

} // namespace ck
#endif
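For orientation, the partitioning that GetNumberOfGemm() and GetGemmSizeImpl() above compute can be checked with a small standalone sketch. The numbers below (2x2 stride, 1x1 dilation, 3x3 filter) are illustrative only and are not taken from this commit; ck::math::gcd and integer_divide_ceil are replaced by standard-library / inline equivalents.

// Illustrative only: recomputes the partition counts the kernel derives at compile time.
// Assumes stride 2x2, dilation 1x1, 3x3 filter (hypothetical values, not from this commit).
#include <cstdio>
#include <numeric> // std::gcd (C++17)

int main()
{
    const int ConvStrideH = 2, ConvStrideW = 2;
    const int ConvDilationH = 1, ConvDilationW = 1;
    const int Y = 3, X = 3; // filter size

    const int YTilda = ConvStrideH / std::gcd(ConvStrideH, ConvDilationH); // 2
    const int XTilda = ConvStrideW / std::gcd(ConvStrideW, ConvDilationW); // 2

    const int YDot = (Y + YTilda - 1) / YTilda; // ceil(3/2) = 2
    const int XDot = (X + XTilda - 1) / XTilda; // ceil(3/2) = 2

    // GetNumberOfGemm() == YTilda * XTilda: one GEMM per (iYTilda, iXTilda) pair
    std::printf("number of GEMMs: %d\n", YTilda * XTilda); // 4

    // The last partition in each direction gets the remainder of the filter taps
    for(int iy = 0; iy < YTilda; ++iy)
    {
        const int YDotSlice = ((iy + 1) * YDot <= Y) ? YDot : Y % YDot;
        std::printf("iYTilda %d -> YDotSlice %d\n", iy, YDotSlice); // 2, then 1
    }
    return 0;
}

Each (iYTilda, iXTilda) pair owns the filter taps assigned to it by the Embed transform, which is why the last partition can end up with a shorter YDotSlice/XDotSlice, and why the driver below skips GEMMs whose GemmK is zero.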
composable_kernel/include/tensor_operation/gridwise_gemm.hpp

@@ -376,5 +376,400 @@ struct GridwiseGemmTransposedANormalBNormalC_v1
    }
};

template <index_t GridSize,
          index_t BlockSize,
          typename Float,
          typename AccFloat,
          typename AGlobalDesc,
          typename BGlobalDesc,
          typename CGlobalDesc,
          InMemoryDataOperation CGlobalMemoryDataOperation,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t KPerBlock,
          index_t MPerThread,
          index_t NPerThread,
          index_t KPerThread,
          index_t MLevel0Cluster,
          index_t NLevel0Cluster,
          index_t MLevel1Cluster,
          index_t NLevel1Cluster,
          index_t ThreadGemmAThreadCopySrcDataPerRead_M,
          index_t ThreadGemmBThreadCopySrcDataPerRead_N,
          typename ABlockCopyThreadSliceLengths_K0_K1_K2_M,
          typename ABlockCopyThreadClusterLengths_K0_K1_K2_M,
          typename ABlockCopyThreadClusterArrangeOrder,
          typename ABlockCopySrcAccessOrder,
          index_t ABlockCopySrcVectorReadDim,
          index_t ABlockCopySrcDataPerRead,
          index_t ABlockCopyDstDataPerWrite_M,
          typename BBlockCopyThreadSliceLengths_K0_K1_K2_N,
          typename BBlockCopyThreadClusterLengths_K0_K1_K2_N,
          typename BBlockCopyThreadClusterArrangeOrder,
          typename BBlockCopySrcAccessOrder,
          index_t BBlockCopySrcVectorReadDim,
          index_t BBlockCopySrcDataPerRead,
          index_t BBlockCopyDstDataPerWrite_N,
          typename CThreadCopySrcDstAccessOrder,
          index_t CThreadCopySrcDstVectorReadWriteDim,
          index_t CThreadCopyDstDataPerWrite>
struct GridwiseGemmTransposedANormalBNormalC_v2
{
    __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
    {
        constexpr index_t max_lds_align = math::lcm(ABlockCopyDstDataPerWrite_M,
                                                    BBlockCopyDstDataPerWrite_N,
                                                    ThreadGemmAThreadCopySrcDataPerRead_M,
                                                    ThreadGemmBThreadCopySrcDataPerRead_N);

        // A matrix in LDS memory, dst of blockwise copy
        //   be careful of LDS alignment
        constexpr auto a_k_m_block_desc = make_native_tensor_descriptor_aligned(
            Sequence<KPerBlock, MPerBlock>{}, Number<max_lds_align>{});

        // B matrix in LDS memory, dst of blockwise copy
        //   be careful of LDS alignment
        constexpr auto b_k_n_block_desc = make_native_tensor_descriptor_aligned(
            Sequence<KPerBlock, NPerBlock>{}, Number<max_lds_align>{});

        // LDS allocation for A and B: be careful of alignment
        constexpr index_t a_block_space =
            math::integer_least_multiple(a_k_m_block_desc.GetElementSpace(), max_lds_align);
        constexpr index_t b_block_space =
            math::integer_least_multiple(b_k_n_block_desc.GetElementSpace(), max_lds_align);

        return 2 * (a_block_space + b_block_space) * sizeof(Float);
    }

    __device__ void Run(const Float* __restrict__ p_a_global,
                        const Float* __restrict__ p_b_global,
                        Float* __restrict__ p_c_global,
                        Float* __restrict__ p_shared_block) const
    {
        constexpr auto True  = integral_constant<bool, true>{};
        constexpr auto False = integral_constant<bool, false>{};

        constexpr auto I0 = Number<0>{};
        constexpr auto I2 = Number<2>{};

        constexpr auto a_k0_k1_k2_m_global_desc = AGlobalDesc{};
        constexpr auto b_k0_k1_k2_n_global_desc = BGlobalDesc{};
        constexpr auto c_m_n_global_desc        = CGlobalDesc{};

        constexpr auto K0 = a_k0_k1_k2_m_global_desc.GetLengths()[0];
        constexpr auto K1 = a_k0_k1_k2_m_global_desc.GetLengths()[1];
        constexpr auto K  = a_k0_k1_k2_m_global_desc.GetLengths()[2];

        constexpr auto M = c_m_n_global_desc.GetLengths()[0];
        constexpr auto N = c_m_n_global_desc.GetLengths()[1];

        // don't do anything if K == 0
        if(K == 0)
        {
            return;
        }

        // LDS max alignment
        constexpr index_t max_lds_align = math::lcm(ABlockCopyDstDataPerWrite_M,
                                                    BBlockCopyDstDataPerWrite_N,
                                                    ThreadGemmAThreadCopySrcDataPerRead_M,
                                                    ThreadGemmBThreadCopySrcDataPerRead_N);

        // divide block work by [M, N]
        static_assert(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0,
                      "wrong! cannot divide work evenly among block");

        constexpr index_t MBlockWork = M / MPerBlock;
        constexpr index_t NBlockWork = N / NPerBlock;

        constexpr auto block_work_desc =
            make_cluster_descriptor(Sequence<MBlockWork, NBlockWork>{});

        const auto block_work_id = block_work_desc.CalculateClusterIndex(get_block_1d_id());

        const index_t m_block_data_on_global = block_work_id[0] * MPerBlock;
        const index_t n_block_data_on_global = block_work_id[1] * NPerBlock;

        // A matrix in LDS memory, dst of blockwise copy
        //   be careful of LDS alignment
        constexpr auto a_k0_k1_k2_m_block_desc = make_native_tensor_descriptor_aligned(
            Sequence<1, 1, KPerBlock, MPerBlock>{}, Number<max_lds_align>{});

        // A matrix blockwise copy
        auto a_blockwise_copy =
            BlockwiseGenericTensorSliceCopy_v4<BlockSize,
                                               decltype(a_k0_k1_k2_m_global_desc),
                                               decltype(a_k0_k1_k2_m_block_desc),
                                               decltype(a_k0_k1_k2_m_block_desc.GetLengths()),
                                               ABlockCopyThreadSliceLengths_K0_K1_K2_M,
                                               ABlockCopyThreadClusterLengths_K0_K1_K2_M,
                                               ABlockCopyThreadClusterArrangeOrder,
                                               ABlockCopySrcAccessOrder,
                                               Sequence<0, 1, 2, 3>,
                                               ABlockCopySrcVectorReadDim,
                                               3,
                                               ABlockCopySrcDataPerRead,
                                               ABlockCopyDstDataPerWrite_M,
                                               AddressSpace::Global,
                                               AddressSpace::Vgpr,
                                               AddressSpace::Lds,
                                               InMemoryDataOperation::Set>(
                {0, 0, 0, m_block_data_on_global}, {0, 0, 0, 0});

        // B matrix in LDS memory, dst of blockwise copy
        //   be careful of LDS alignment
        constexpr auto b_k0_k1_k2_n_block_desc = make_native_tensor_descriptor_aligned(
            Sequence<1, 1, KPerBlock, NPerBlock>{}, Number<max_lds_align>{});

        // B matrix blockwise copy
        auto b_blockwise_copy =
            BlockwiseGenericTensorSliceCopy_v4<BlockSize,
                                               decltype(b_k0_k1_k2_n_global_desc),
                                               decltype(b_k0_k1_k2_n_block_desc),
                                               decltype(b_k0_k1_k2_n_block_desc.GetLengths()),
                                               BBlockCopyThreadSliceLengths_K0_K1_K2_N,
                                               BBlockCopyThreadClusterLengths_K0_K1_K2_N,
                                               BBlockCopyThreadClusterArrangeOrder,
                                               BBlockCopySrcAccessOrder,
                                               Sequence<0, 1, 2, 3>,
                                               BBlockCopySrcVectorReadDim,
                                               3,
                                               BBlockCopySrcDataPerRead,
                                               BBlockCopyDstDataPerWrite_N,
                                               AddressSpace::Global,
                                               AddressSpace::Vgpr,
                                               AddressSpace::Lds,
                                               InMemoryDataOperation::Set>(
                {0, 0, 0, n_block_data_on_global}, {0, 0, 0, 0});

        // GEMM definition
        //   c_mtx += transpose(a_mtx) * b_mtx
        //   a_mtx[KPerBlock, MPerBlock] is in LDS
        //   b_mtx[KPerBlock, NPerBlock] is in LDS
        //   c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in register
        constexpr auto a_k_m_block_mtx_desc =
            make_ConstantMatrixDescriptor(unfold_tensor_descriptor(a_k0_k1_k2_m_block_desc, I0, I2));
        constexpr auto b_k_n_block_mtx_desc =
            make_ConstantMatrixDescriptor(unfold_tensor_descriptor(b_k0_k1_k2_n_block_desc, I0, I2));

        // sanity check
        static_assert(MPerBlock % (MPerThread * MLevel0Cluster * MLevel1Cluster) == 0 &&
                          NPerBlock % (NPerThread * NLevel0Cluster * NLevel1Cluster) == 0,
                      "wrong!");

        constexpr index_t GemmMRepeat = MPerBlock / (MPerThread * MLevel0Cluster * MLevel1Cluster);
        constexpr index_t GemmNRepeat = NPerBlock / (NPerThread * NLevel0Cluster * NLevel1Cluster);

        // c_thread_mtx definition: this is a mess
        // TODO: more elegant way of defining c_thread_mtx
        constexpr auto c_m0m1_n0n1_thread_mtx_desc = make_ConstantMatrixDescriptor_packed(
            Number<GemmMRepeat * MPerThread>{}, Number<GemmNRepeat * NPerThread>{});

        const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2<
            BlockSize,
            decltype(a_k_m_block_mtx_desc),
            decltype(b_k_n_block_mtx_desc),
            decltype(c_m0m1_n0n1_thread_mtx_desc),
            MPerThread,
            NPerThread,
            KPerThread,
            MLevel0Cluster,
            NLevel0Cluster,
            MLevel1Cluster,
            NLevel1Cluster,
            ThreadGemmAThreadCopySrcDataPerRead_M,
            ThreadGemmBThreadCopySrcDataPerRead_N>{};

        // LDS allocation for A and B: be careful of alignment
        constexpr index_t a_block_space =
            math::integer_least_multiple(a_k0_k1_k2_m_block_desc.GetElementSpace(), max_lds_align);
        constexpr index_t b_block_space =
            math::integer_least_multiple(b_k0_k1_k2_n_block_desc.GetElementSpace(), max_lds_align);

        Float* p_a_block_double = p_shared_block;
        Float* p_b_block_double = p_shared_block + 2 * a_block_space;

        // register allocation for output
        AccFloat p_c_thread[c_m0m1_n0n1_thread_mtx_desc.GetElementSpace()];

        // zero out threadwise output
        threadwise_matrix_set_zero(c_m0m1_n0n1_thread_mtx_desc, p_c_thread);

        for(index_t k0 = 0; k0 < K0; ++k0)
        {
            for(index_t k1 = 0; k1 < K1; ++k1)
            {
                // LDS double buffer: preload data into LDS
                {
                    a_blockwise_copy.Run(p_a_global, p_a_block_double);
                    b_blockwise_copy.Run(p_b_global, p_b_block_double);
                }

                constexpr auto a_block_slice_copy_steps = Sequence<0, 0, KPerBlock, 0>{};
                constexpr auto b_block_slice_copy_steps = Sequence<0, 0, KPerBlock, 0>{};

                // LDS double buffer: main body
                for(index_t k_block_data_begin = 0; k_block_data_begin + 2 * KPerBlock < K;
                    k_block_data_begin += 2 * KPerBlock)
                {
#pragma unroll
                    for(index_t iloop = 0; iloop < 2; ++iloop)
                    {
                        const bool even_loop = (iloop % 2 == 0);

                        Float* p_a_block_now =
                            even_loop ? p_a_block_double : p_a_block_double + a_block_space;
                        Float* p_b_block_now =
                            even_loop ? p_b_block_double : p_b_block_double + b_block_space;

                        Float* p_a_block_next =
                            even_loop ? p_a_block_double + a_block_space : p_a_block_double;
                        Float* p_b_block_next =
                            even_loop ? p_b_block_double + b_block_space : p_b_block_double;

                        Float p_a_thread_buffer[a_blockwise_copy.GetThreadBufferSize()];
                        Float p_b_thread_buffer[b_blockwise_copy.GetThreadBufferSize()];

                        a_blockwise_copy.MoveSrcSliceWindow(a_block_slice_copy_steps, True);
                        b_blockwise_copy.MoveSrcSliceWindow(b_block_slice_copy_steps, True);

                        __syncthreads();

                        // LDS double buffer: load next data from device mem
                        a_blockwise_copy.RunLoadThreadBuffer(p_a_global, p_a_thread_buffer);
                        b_blockwise_copy.RunLoadThreadBuffer(p_b_global, p_b_thread_buffer);

                        // LDS double buffer: GEMM on current data
                        blockwise_gemm.Run(p_a_block_now, p_b_block_now, p_c_thread);

                        // LDS double buffer: store next data to LDS
                        a_blockwise_copy.RunStoreThreadBuffer(p_a_thread_buffer, p_a_block_next);
                        b_blockwise_copy.RunStoreThreadBuffer(p_b_thread_buffer, p_b_block_next);
                    }
                }

                // LDS double buffer: tail
                {
                    constexpr bool has_two_iteration_left = (K % (2 * KPerBlock) == 0);

                    if(has_two_iteration_left) // 2 iterations left
                    {
                        Float p_a_thread_buffer[a_blockwise_copy.GetThreadBufferSize()];
                        Float p_b_thread_buffer[b_blockwise_copy.GetThreadBufferSize()];

                        a_blockwise_copy.MoveSrcSliceWindow(a_block_slice_copy_steps, True);
                        b_blockwise_copy.MoveSrcSliceWindow(b_block_slice_copy_steps, True);

                        __syncthreads();

                        // LDS double buffer: load last data from device mem
                        a_blockwise_copy.RunLoadThreadBuffer(p_a_global, p_a_thread_buffer);
                        b_blockwise_copy.RunLoadThreadBuffer(p_b_global, p_b_thread_buffer);

                        // LDS double buffer: GEMM on 2nd-last data
                        blockwise_gemm.Run(p_a_block_double, p_b_block_double, p_c_thread);

                        // LDS double buffer: store last data to LDS
                        a_blockwise_copy.RunStoreThreadBuffer(p_a_thread_buffer,
                                                              p_a_block_double + a_block_space);
                        b_blockwise_copy.RunStoreThreadBuffer(p_b_thread_buffer,
                                                              p_b_block_double + b_block_space);

                        __syncthreads();

                        // LDS double buffer: GEMM on last data
                        blockwise_gemm.Run(p_a_block_double + a_block_space,
                                           p_b_block_double + b_block_space,
                                           p_c_thread);
                    }
                    else // 1 iteration left
                    {
                        __syncthreads();

                        // LDS double buffer: GEMM on last data
                        blockwise_gemm.Run(p_a_block_double, p_b_block_double, p_c_thread);
                    }
                }

                // reset slice window on K2 dimension, then move forward on K1 dimension
                a_blockwise_copy.MoveSrcSliceWindow(Sequence<0, 0, K - KPerBlock, 0>{}, False);
                b_blockwise_copy.MoveSrcSliceWindow(Sequence<0, 0, K - KPerBlock, 0>{}, False);

                a_blockwise_copy.MoveSrcSliceWindow(Sequence<0, 1, 0, 0>{}, True);
                b_blockwise_copy.MoveSrcSliceWindow(Sequence<0, 1, 0, 0>{}, True);
            }

            // reset slice window on K1 dimension, then move forward on K0 dimension
            a_blockwise_copy.MoveSrcSliceWindow(Sequence<0, K1, 0, 0>{}, False);
            b_blockwise_copy.MoveSrcSliceWindow(Sequence<0, K1, 0, 0>{}, False);

            a_blockwise_copy.MoveSrcSliceWindow(Sequence<1, 0, 0, 0>{}, True);
            b_blockwise_copy.MoveSrcSliceWindow(Sequence<1, 0, 0, 0>{}, True);
        }

        // input: register to global memory
        {
            constexpr index_t M1 = MPerThread * MLevel0Cluster * MLevel1Cluster;
            constexpr index_t M0 = M / M1;

            constexpr index_t N1 = NPerThread * NLevel0Cluster * NLevel1Cluster;
            constexpr index_t N0 = N / N1;

            // define input tensor descriptor for threadwise copy
            //   thread input tensor, src of threadwise copy
            constexpr auto c_m0_m1_n0_n1_thread_desc = make_native_tensor_descriptor_packed(
                Sequence<GemmMRepeat, MPerThread, GemmNRepeat, NPerThread>{});

            constexpr auto c_m0_m1_n0_n1_global_desc = transform_tensor_descriptor(
                c_m_n_global_desc,
                make_tuple(UnMerge<Sequence<M0, M1>>{}, UnMerge<Sequence<N0, N1>>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));

            // calculate origin of thread input tensor on global memory
            //   blockwise GEMM c matrix starting index
            const auto c_thread_mtx_on_block =
                blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());

            const index_t m_thread_data_on_global =
                m_block_data_on_global + c_thread_mtx_on_block.row;
            const index_t n_thread_data_on_global =
                n_block_data_on_global + c_thread_mtx_on_block.col;

            ThreadwiseGenericTensorSliceCopy_v4r2<decltype(c_m0_m1_n0_n1_thread_desc),
                                                  decltype(c_m0_m1_n0_n1_global_desc),
                                                  decltype(c_m0_m1_n0_n1_thread_desc.GetLengths()),
                                                  CThreadCopySrcDstAccessOrder,
                                                  CThreadCopySrcDstVectorReadWriteDim,
                                                  1,
                                                  CThreadCopyDstDataPerWrite,
                                                  AddressSpace::Vgpr,
                                                  AddressSpace::Global,
                                                  CGlobalMemoryDataOperation>(
                {0, 0, 0, 0},
                {m_thread_data_on_global / M1,
                 m_thread_data_on_global % M1,
                 n_thread_data_on_global / N1,
                 n_thread_data_on_global % N1})
                .Run(p_c_thread, p_c_global);
        }
    }

    __device__ void Run(const Float* __restrict__ p_a_global,
                        const Float* __restrict__ p_b_global,
                        Float* __restrict__ p_c_global) const
    {
        constexpr index_t shared_block_size = GetSharedMemoryNumberOfByte() / sizeof(Float);

        __shared__ Float p_shared_block[shared_block_size];

        Run(p_a_global, p_b_global, p_c_global, p_shared_block);
    }
};

} // namespace ck
#endif
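The Run() body above walks K as a K0 x K1 x K2 loop nest over a double-buffered LDS tile: while the blockwise GEMM consumes one half of the allocation, the next KPerBlock slice is prefetched into registers and stored into the other half, which is why GetSharedMemoryNumberOfByte() reserves two copies of each tile. A quick standalone check of that sizing, assuming the driver's tile shape below (KPerBlock = 8, MPerBlock = NPerBlock = 128) and Float = float, and ignoring the lcm-based alignment padding:

// Back-of-the-envelope check of GetSharedMemoryNumberOfByte(), ignoring alignment
// padding, so this is a lower bound. Tile sizes are the driver's defaults (assumed).
#include <cstddef>
#include <cstdio>

int main()
{
    const int KPerBlock = 8, MPerBlock = 128, NPerBlock = 128;
    const int a_block_space = KPerBlock * MPerBlock; // A tile elements in LDS
    const int b_block_space = KPerBlock * NPerBlock; // B tile elements in LDS

    // Factor 2 is the double buffer: one tile is consumed while the next is stored.
    const std::size_t lds_bytes = 2u * (a_block_space + b_block_space) * sizeof(float);
    std::printf("LDS per workgroup (unaligned lower bound): %zu bytes\n", lds_bytes); // 16384
    return 0;
}

The real function rounds each tile up to max_lds_align, so the actual figure can be slightly larger than this lower bound.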
driver/CMakeLists.txt

@@ -16,15 +16,14 @@ install(TARGETS host LIBRARY DESTINATION lib)

if(DEVICE_BACKEND STREQUAL "AMD")
    set(CONV_SOURCE src/conv_driver.cpp)
    set(COL2IM_SOURCE src/col2im_driver.cpp)
    set(CONV_BWD_DATA_SOURCE src/conv_bwd_data_driver.cpp)
elseif(DEVICE_BACKEND STREQUAL "NVIDIA")
    set(CONV_SOURCE src/conv_driver.cu)
    set(COL2IM_SOURCE src/col2im_driver.cu)
    set(CONV_BWD_DATA_SOURCE src/conv_bwd_data_driver.cu)
endif()

add_executable(conv_driver ${CONV_SOURCE})
add_executable(conv_bwd_data_driver ${CONV_BWD_DATA_SOURCE})

target_link_libraries(conv_driver PRIVATE host)
target_link_libraries(conv_bwd_data_driver PRIVATE host)
driver/include/device_convolution_backward_data_implicit_gemm_v5r1_nchw_kcyx_nkhw.hpp (new file, 0 → 100644)

#pragma once
#include <unistd.h>
#include "device.hpp"
#include "host_tensor.hpp"
#include "gridwise_operation_wrapper.hpp"
#include "gridwise_convolution_backward_data_implicit_gemm_v5r1_nchw_kcyx_nkhw.hpp"

namespace launcher {

using namespace ck;

template <typename T,
          typename InDesc,
          typename WeiDesc,
          typename OutDesc,
          typename ConvStrides,
          typename ConvDilations,
          typename InLeftPads,
          typename InRightPads>
void device_convolution_backward_data_implicit_gemm_v5r1_nchw_kcyx_nkhw(InDesc in_nchw_desc,
                                                                        Tensor<T>& in_nchw,
                                                                        WeiDesc wei_kcyx_desc,
                                                                        const Tensor<T>& wei_kcyx,
                                                                        OutDesc out_nkhw_desc,
                                                                        const Tensor<T>& out_nkhw,
                                                                        ConvStrides,
                                                                        ConvDilations,
                                                                        InLeftPads,
                                                                        InRightPads,
                                                                        std::size_t nrepeat)
{
    constexpr index_t N = out_nkhw_desc.GetLengths()[0];
    constexpr index_t K = out_nkhw_desc.GetLengths()[1];
    constexpr index_t C = wei_kcyx_desc.GetLengths()[1];

    constexpr index_t Hi = in_nchw_desc.GetLengths()[2];
    constexpr index_t Wi = in_nchw_desc.GetLengths()[3];

    constexpr index_t Ho = out_nkhw_desc.GetLengths()[2];
    constexpr index_t Wo = out_nkhw_desc.GetLengths()[3];

    constexpr index_t Y = wei_kcyx_desc.GetLengths()[2];
    constexpr index_t X = wei_kcyx_desc.GetLengths()[3];

    constexpr index_t ConvStrideH = ConvStrides{}[0];
    constexpr index_t ConvStrideW = ConvStrides{}[1];

    constexpr index_t ConvDilationH = ConvDilations{}[0];
    constexpr index_t ConvDilationW = ConvDilations{}[1];

    std::size_t data_sz = sizeof(T);

    DeviceMem in_nchw_device_buf(data_sz * in_nchw.mDesc.GetElementSpace());
    DeviceMem wei_kcyx_device_buf(data_sz * wei_kcyx.mDesc.GetElementSpace());
    DeviceMem out_nkhw_device_buf(data_sz * out_nkhw.mDesc.GetElementSpace());

    in_nchw_device_buf.ToDevice(in_nchw.mData.data());
    wei_kcyx_device_buf.ToDevice(wei_kcyx.mData.data());
    out_nkhw_device_buf.ToDevice(out_nkhw.mData.data());

#if 1
    // BlockSize = 256, each thread holds 64 data
    constexpr index_t BlockSize = 256;

    constexpr index_t GemmMPerBlock = 128;
    constexpr index_t GemmNPerBlock = 128;
    constexpr index_t GemmKPerBlock = 8;

    constexpr index_t GemmMPerThread = 4;
    constexpr index_t GemmNPerThread = 4;
    constexpr index_t GemmKPerThread = 1;

    constexpr index_t GemmMLevel0Cluster = 4;
    constexpr index_t GemmNLevel0Cluster = 4;
    constexpr index_t GemmMLevel1Cluster = 4;
    constexpr index_t GemmNLevel1Cluster = 4;

    constexpr index_t GemmThreadGemmDataPerReadM = 4;
    constexpr index_t GemmThreadGemmDataPerReadN = 4;

    using GemmABlockCopyThreadSliceLengths_GemmK0_GemmK1_GemmK2_GemmM   = Sequence<1, 1, 4, 1>;
    using GemmABlockCopyThreadClusterLengths_GemmK0_GemmK1_GemmK2_GemmM = Sequence<1, 1, 2, 128>;

    constexpr index_t GemmABlockCopySrcDataPerRead_GemmM  = 1;
    constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1;

    using GemmBBlockCopyThreadSliceLengths_GemmK0_GemmK1_GemmK2_GemmN   = Sequence<1, 1, 4, 1>;
    using GemmBBlockCopyThreadClusterLengths_GemmK0_GemmK1_GemmK2_GemmN = Sequence<1, 1, 2, 128>;

    constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN  = 1;
    constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;

    constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
#endif

    constexpr index_t GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
    constexpr index_t GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);

    constexpr index_t YTilda = ConvStrideH / GcdStrideDilationH;
    constexpr index_t XTilda = ConvStrideW / GcdStrideDilationW;

    constexpr index_t YDot = math::integer_divide_ceil(Y, YTilda);
    constexpr index_t XDot = math::integer_divide_ceil(X, XTilda);

    constexpr index_t HTilda = Ho + math::integer_divide_ceil(ConvDilationH * (Y - 1), ConvStrideH);
    constexpr index_t WTilda = Wo + math::integer_divide_ceil(ConvDilationW * (X - 1), ConvStrideW);

    constexpr index_t HTildaLeft = math::integer_divide_floor(
        math::max(0, InLeftPads{}[0] - ConvDilationH * (YTilda - 1)), ConvStrides{}[0]);
    constexpr index_t WTildaLeft = math::integer_divide_floor(
        math::max(0, InLeftPads{}[1] - ConvDilationW * (XTilda - 1)), ConvStrides{}[1]);

    constexpr index_t HTildaRight = math::min(
        HTilda, math::integer_divide_ceil(InLeftPads{}[0] + Hi - 1, ConvStrides{}[0]) + 1);
    constexpr index_t WTildaRight = math::min(
        WTilda, math::integer_divide_ceil(InLeftPads{}[1] + Wi - 1, ConvStrides{}[1]) + 1);

    constexpr index_t HTildaSlice = HTildaRight - HTildaLeft;
    constexpr index_t WTildaSlice = WTildaRight - WTildaLeft;

    constexpr index_t GemmM = C;
    constexpr index_t GemmN = N * HTildaSlice * WTildaSlice;

    constexpr index_t GridSize = math::integer_divide_ceil(GemmM, GemmMPerBlock) *
                                 math::integer_divide_ceil(GemmN, GemmNPerBlock);

    printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);

    for(index_t i = 0; i < 5; ++i)
    {
        std::cout << "Start running " << nrepeat << " times..." << std::endl;

        KernelTimer timer;
        timer.Start();

        for(index_t i = 0; i < nrepeat; ++i)
        {
            using GridwiseConvBwdData =
                GridwiseConvolutionBackwardDataImplicitGemm_v5r1_nchw_kcyx_nkhw<
                    GridSize,
                    BlockSize,
                    T,
                    T,
                    decltype(in_nchw_desc),
                    decltype(wei_kcyx_desc),
                    decltype(out_nkhw_desc),
                    ConvStrides,
                    ConvDilations,
                    InLeftPads,
                    InRightPads,
                    GemmMPerBlock,
                    GemmNPerBlock,
                    GemmKPerBlock,
                    GemmMPerThread,
                    GemmNPerThread,
                    GemmKPerThread,
                    GemmMLevel0Cluster,
                    GemmNLevel0Cluster,
                    GemmMLevel1Cluster,
                    GemmNLevel1Cluster,
                    GemmThreadGemmDataPerReadM,
                    GemmThreadGemmDataPerReadN,
                    GemmABlockCopyThreadSliceLengths_GemmK0_GemmK1_GemmK2_GemmM,
                    GemmABlockCopyThreadClusterLengths_GemmK0_GemmK1_GemmK2_GemmM,
                    GemmABlockCopySrcDataPerRead_GemmM,
                    GemmABlockCopyDstDataPerWrite_GemmM,
                    GemmBBlockCopyThreadSliceLengths_GemmK0_GemmK1_GemmK2_GemmN,
                    GemmBBlockCopyThreadClusterLengths_GemmK0_GemmK1_GemmK2_GemmN,
                    GemmBBlockCopySrcDataPerRead_GemmN,
                    GemmBBlockCopyDstDataPerWrite_GemmN,
                    GemmCThreadCopyDstDataPerWrite_GemmN1>;

            static_for<0, GridwiseConvBwdData::GetNumberOfGemm(), 1>{}([&](auto gemm_id) {
                constexpr auto gemm_sizes = GridwiseConvBwdData::GetGemmSize(gemm_id);

                constexpr index_t gemm_k = gemm_sizes.At(2);

                constexpr bool is_gemm_not_empty = gemm_k > 0;

                // only compile and run if GEMM is not empty
                static_if<is_gemm_not_empty>{}([&](auto fwd) {
                    launch_kernel(run_gridwise_operation<GridwiseConvBwdData,
                                                         T* const __restrict__,
                                                         const T* const __restrict__,
                                                         const T* const __restrict__,
                                                         decltype(gemm_id)>,
                                  dim3(GridSize),
                                  dim3(BlockSize),
                                  0,
                                  0,
                                  static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer()),
                                  static_cast<T*>(wei_kcyx_device_buf.GetDeviceBuffer()),
                                  static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer()),
                                  fwd(gemm_id));
                });
            });
        }

        timer.End();

        float ave_time = timer.GetElapsedTime() / nrepeat;

        float perf = (float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) /
                     (std::size_t(1000) * 1000 * 1000) / ave_time;

        std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl;
    }

    in_nchw_device_buf.FromDevice(in_nchw.mData.data());
}

} // namespace launcher
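To make the GridSize arithmetic above concrete: with this file's tuning parameters (GemmMPerBlock = GemmNPerBlock = 128) and a hypothetical problem with C = 1024, N = 128 and a 17 x 17 non-padding slice (illustrative values, not this driver's exact configuration), the launcher would dispatch the following grid:

// Illustrative GridSize computation mirroring the launcher above.
// C, N, HTildaSlice and WTildaSlice are hypothetical values.
#include <cstdio>

constexpr int integer_divide_ceil(int a, int b) { return (a + b - 1) / b; }

int main()
{
    const int GemmMPerBlock = 128, GemmNPerBlock = 128;
    const int C = 1024, N = 128, HTildaSlice = 17, WTildaSlice = 17;

    const int GemmM = C;                             // 1024
    const int GemmN = N * HTildaSlice * WTildaSlice; // 36992

    const int GridSize = integer_divide_ceil(GemmM, GemmMPerBlock) *
                         integer_divide_ceil(GemmN, GemmNPerBlock);
    std::printf("GridSize = %d workgroups\n", GridSize); // 8 * 289 = 2312
    return 0;
}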
driver/src/conv_bwd_data_driver.cpp

@@ -18,6 +18,7 @@
 #include "device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw.hpp"
 #include "device_convolution_backward_data_implicit_gemm_v3r1_nchw_kcyx_nkhw.hpp"
 #include "device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp"
+#include "device_convolution_backward_data_implicit_gemm_v5r1_nchw_kcyx_nkhw.hpp"

 int main(int argc, char* argv[])
 {
@@ -172,7 +173,7 @@ int main(int argc, char* argv[])
     using LeftPads  = Sequence<0, 3>;
     using RightPads = Sequence<0, 3>;
-#elif 0
+#elif 1
 // 7x1 filter, 3x0 pad, 17x17 input
     constexpr index_t N = 128;
     constexpr index_t C = 1024;
@@ -187,13 +188,13 @@ int main(int argc, char* argv[])
     using LeftPads  = Sequence<3, 0>;
     using RightPads = Sequence<3, 0>;
-#elif 1
+#elif 0
 // 3x3 filter, 2x2 stride, 35x35 input, 17x17 output
     constexpr index_t N = 128;
-    constexpr index_t C = 1024;
+    constexpr index_t C = 128;
     constexpr index_t HI = 35;
     constexpr index_t WI = 35;
-    constexpr index_t K = 1024;
+    constexpr index_t K = 128;
     constexpr index_t Y = 3;
     constexpr index_t X = 3;
@@ -255,6 +256,8 @@ int main(int argc, char* argv[])
     device_convolution_backward_data_implicit_gemm_v3r1_nchw_kcyx_nkhw
 #elif 1
     device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw
+#elif 1
+    device_convolution_backward_data_implicit_gemm_v5r1_nchw_kcyx_nkhw
 #endif
     (in_nchw_desc,
      in_nchw_device,
      ...