yangql / composable_kernel-1 · Commits · 81c942cd
"vscode:/vscode.git/clone" did not exist on "121693b3d3b3148010f0756c5ab4741476620aba"
Commit 81c942cd (unverified), authored Jul 08, 2021 by Chao Liu, committed by GitHub on Jul 08, 2021
Deprecate static kernel (#42)

* deprecate static kernels

Parent: b8b2d0a6 · Changes: 55
Showing 20 changed files, with 0 additions and 5789 deletions (+0 −5789):

    CMakeLists.txt  (+0 −17)
    composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw.hpp  (+0 −172)
    composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw_lds_double_buffer.hpp  (+0 −450)
    composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp  (+0 −418)
    composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v5r1_nhwc_kyxc_nhwk.hpp  (+0 −406)
    composable_kernel/include/kernel_algorithm/gridwise_convolution_forward_implicit_gemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer.hpp  (+0 −454)
    composable_kernel/include/kernel_algorithm/gridwise_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp  (+0 −171)
    composable_kernel/include/kernel_algorithm/gridwise_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.hpp  (+0 −162)
    composable_kernel/include/tensor_description/ConstantMatrixDescriptor.hpp  (+0 −80)
    composable_kernel/include/tensor_description/cluster_descriptor.hpp  (+0 −41)
    composable_kernel/include/tensor_description/dimension.hpp  (+0 −17)
    composable_kernel/include/tensor_description/multi_index_transform.hpp  (+0 −523)
    composable_kernel/include/tensor_description/print_tensor_descriptor.hpp  (+0 −173)
    composable_kernel/include/tensor_description/tensor_coordinate.hpp  (+0 −289)
    composable_kernel/include/tensor_description/tensor_descriptor.hpp  (+0 −526)
    composable_kernel/include/tensor_description/tensor_descriptor_helper.hpp  (+0 −176)
    composable_kernel/include/tensor_operation/blockwise_batched_gemm.hpp  (+0 −406)
    composable_kernel/include/tensor_operation/blockwise_gemm.hpp  (+0 −334)
    composable_kernel/include/tensor_operation/blockwise_generic_tensor_slice_copy.hpp  (+0 −189)
    composable_kernel/include/tensor_operation/gridwise_gemm.hpp  (+0 −785)
CMakeLists.txt (view file @ 81c942cd)

@@ -38,8 +38,6 @@ link_libraries(${OpenMP_pthread_LIBRARY})
 #GPU backend
 if(DEVICE_BACKEND STREQUAL "AMD")
     find_package(HIP REQUIRED)
-elseif(DEVICE_BACKEND STREQUAL "NVIDIA")
-    enable_language(CUDA)
 endif()

 #
@@ -64,13 +62,7 @@ endif()
 if(DEVICE_BACKEND STREQUAL "AMD")
     configure_file("${PROJECT_SOURCE_DIR}/composable_kernel/include/utility/config.amd.hpp.in" "${PROJECT_BINARY_DIR}/composable_kernel/include/utility/config.hpp")
     configure_file("${PROJECT_SOURCE_DIR}/composable_kernel/include/utility/float_type.amd.hpp.in" "${PROJECT_BINARY_DIR}/composable_kernel/include/utility/float_type.hpp")
-    configure_file("${PROJECT_SOURCE_DIR}/composable_kernel/include/utility/in_memory_operation.amd.hpp.in" "${PROJECT_BINARY_DIR}/composable_kernel/include/utility/in_memory_operation.hpp")
     configure_file("${PROJECT_SOURCE_DIR}/composable_kernel/include/utility/synchronization.amd.hpp.in" "${PROJECT_BINARY_DIR}/composable_kernel/include/utility/synchronization.hpp")
-elseif(DEVICE_BACKEND STREQUAL "NVIDIA")
-    configure_file("${PROJECT_SOURCE_DIR}/composable_kernel/include/utility/config.nvidia.hpp.in" "${PROJECT_BINARY_DIR}/composable_kernel/include/utility/config.hpp")
-    configure_file("${PROJECT_SOURCE_DIR}/composable_kernel/include/utility/float_type.nvidia.hpp.in" "${PROJECT_BINARY_DIR}/composable_kernel/include/utility/float_type.hpp")
-    configure_file("${PROJECT_SOURCE_DIR}/composable_kernel/include/utility/in_memory_operation.nvidia.hpp.in" "${PROJECT_BINARY_DIR}/composable_kernel/include/utility/in_memory_operation.hpp")
-    configure_file("${PROJECT_SOURCE_DIR}/composable_kernel/include/utility/synchronization.nvidia.hpp.in" "${PROJECT_BINARY_DIR}/composable_kernel/include/utility/synchronization.hpp")
 endif()

 add_subdirectory(driver)
@@ -80,26 +72,17 @@ set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}")
 message("Compiling options for drivers: ${CMAKE_CXX_FLAGS}")

 if(DEVICE_BACKEND STREQUAL "AMD")
-    set(CONV_SOURCE driver/conv_driver.cpp)
-    set(CONV_BWD_DATA_SOURCE driver/conv_bwd_data_driver.cpp)
     set(CONV_V2_SOURCE driver/conv_driver_v2.cpp)
     set(CONV_BWD_DATA_V2_SOURCE driver/conv_bwd_data_driver_v2.cpp)
     set(CONV_V2_OLC_SOURCE driver/conv_driver_v2_olc.cpp)
-elseif(DEVICE_BACKEND STREQUAL "NVIDIA")
-    set(CONV_SOURCE driver/conv_driver.cu)
-    set(CONV_BWD_DATA_SOURCE driver/conv_bwd_data_driver.cu)
 endif()

-add_executable(conv_driver ${CONV_SOURCE})
-add_executable(conv_bwd_data_driver ${CONV_BWD_DATA_SOURCE})
 add_executable(conv_driver_v2 ${CONV_V2_SOURCE})
 add_executable(conv_bwd_data_driver_v2 ${CONV_BWD_DATA_V2_SOURCE})
 add_executable(conv_driver_v2_olc ${CONV_V2_OLC_SOURCE})

 target_include_directories(conv_driver_v2_olc PRIVATE driver/olCompiling/include/)

-target_link_libraries(conv_driver PRIVATE modConv)
-target_link_libraries(conv_bwd_data_driver PRIVATE modConv)
 target_link_libraries(conv_driver_v2 PRIVATE modConv)
 target_link_libraries(conv_bwd_data_driver_v2 PRIVATE modConv)
 target_link_libraries(conv_driver_v2_olc PRIVATE modConv)
composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw.hpp  deleted, 100644 → 0 (view file @ b8b2d0a6)

#ifndef CK_GRIDWISE_CONVOLUTION_BACKWARD_DATA_IMPLICIT_GEMM_V1R1_NCHW_KCYX_NKHW_HPP
#define CK_GRIDWISE_CONVOLUTION_BACKWARD_DATA_IMPLICIT_GEMM_V1R1_NCHW_KCYX_NKHW_HPP

#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm.hpp"

namespace ck {

// GemmM = C * Y * X
// GemmN = N * Ho * Wo
// GemmK = K
template <index_t GridSize, index_t BlockSize,
          typename Float, typename AccFloat,
          typename InGlobalDesc, typename WeiGlobalDesc, typename OutGlobalDesc,
          typename ConvStrides, typename ConvDilations,
          typename InLeftPads, typename InRightPads,
          index_t GemmMPerBlock, index_t GemmNPerBlock, index_t GemmKPerBlock,
          index_t GemmMPerThread, index_t GemmNPerThread, index_t GemmKPerThread,
          index_t GemmMLevel0Cluster, index_t GemmNLevel0Cluster,
          index_t GemmMLevel1Cluster, index_t GemmNLevel1Cluster,
          index_t ThreadGemmDataPerRead_GemmM, index_t ThreadGemmDataPerRead_GemmN,
          typename GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
          typename GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
          index_t GemmABlockCopySrcDataPerRead_GemmM, index_t GemmABlockCopyDstDataPerWrite_GemmM,
          typename GemmBBlockCopyThreadSliceLengths_GemmK_GemmN,
          typename GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
          index_t GemmBBlockCopySrcDataPerRead_GemmN, index_t GemmBBlockCopyDstDataPerWrite_GemmN,
          index_t GemmCThreadCopyDstDataPerWrite_GemmN1>
struct GridwiseConvolutionBackwardDataImplicitGemm_v1r1_nchw_kcyx_nkhw
{
    __device__ void Run(Float* __restrict__ p_in_global,
                        const Float* __restrict__ p_wei_global,
                        const Float* __restrict__ p_out_global) const
    {
        constexpr auto I1 = Number<1>{};
        constexpr auto I2 = Number<2>{};
        constexpr auto I3 = Number<3>{};

        constexpr auto in_n_c_hi_wi_global_desc  = InGlobalDesc{};
        constexpr auto wei_k_c_y_x_global_desc   = WeiGlobalDesc{};
        constexpr auto out_n_k_ho_wo_global_desc = OutGlobalDesc{};

        constexpr index_t N  = in_n_c_hi_wi_global_desc.GetLengths()[0];
        constexpr index_t C  = in_n_c_hi_wi_global_desc.GetLengths()[1];
        constexpr index_t Hi = in_n_c_hi_wi_global_desc.GetLengths()[2];
        constexpr index_t Wi = in_n_c_hi_wi_global_desc.GetLengths()[3];

        constexpr index_t K  = out_n_k_ho_wo_global_desc.GetLengths()[1];
        constexpr index_t Ho = out_n_k_ho_wo_global_desc.GetLengths()[2];
        constexpr index_t Wo = out_n_k_ho_wo_global_desc.GetLengths()[3];

        constexpr index_t Y = wei_k_c_y_x_global_desc.GetLengths()[2];
        constexpr index_t X = wei_k_c_y_x_global_desc.GetLengths()[3];

        constexpr index_t ConvStrideH = ConvStrides{}[0];
        constexpr index_t ConvStrideW = ConvStrides{}[1];

        constexpr index_t ConvDilationH = ConvDilations{}[0];
        constexpr index_t ConvDilationW = ConvDilations{}[1];

        //\todo static_assert for global vector load/store
        // statc_assert();

        // weight tensor
        constexpr auto wei_gemmk_gemmm_global_desc =
            unfold_tensor_descriptor(wei_k_c_y_x_global_desc, I1, I3);

        // input tensor
        constexpr auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor(
            in_n_c_hi_wi_global_desc,
            make_tuple(PassThrough<N>{}, PassThrough<C>{},
                       Pad<Sequence<Hi, Wi>, InLeftPads, InRightPads>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}));

        constexpr index_t Hip = in_n_c_hip_wip_global_desc.GetLengths()[2];
        constexpr index_t Wip = in_n_c_hip_wip_global_desc.GetLengths()[3];

        constexpr auto in_n_c_y_ho_x_wo_global_desc = transform_tensor_descriptor(
            in_n_c_hip_wip_global_desc,
            make_tuple(PassThrough<N>{}, PassThrough<C>{},
                       Embed<Hip, Sequence<Y, Ho>, Sequence<ConvDilationH, ConvStrideH, 0>>{},
                       Embed<Wip, Sequence<X, Wo>, Sequence<ConvDilationW, ConvStrideW, 0>>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));

        constexpr auto in_gemmm_gemmn_global_desc = transform_tensor_descriptor(
            in_n_c_y_ho_x_wo_global_desc,
            make_tuple(Merge<Sequence<C, Y, X>>{}, Merge<Sequence<N, Ho, Wo>>{}),
            make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}));

        // output tensor
        constexpr auto out_gemmk_gemmn_global_desc = transform_tensor_descriptor(
            unfold_tensor_descriptor(out_n_k_ho_wo_global_desc, I2, I3),
            make_tuple(PassThrough<K>{}, Merge<Sequence<N, Ho * Wo>>{}),
            make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}));

        // GEMM
        // \todo there are more combinations of Y, ConvDilationH and ConvStrideH that don't need
        // atomic, find out all of them
        constexpr bool not_need_atomic = (ConvStrideH >= ConvDilationH * (Y - 1) + 1) and
                                         (ConvStrideW >= ConvDilationW * (X - 1) + 1);

        constexpr auto in_memory_op =
            not_need_atomic ? InMemoryDataOperation::Set : InMemoryDataOperation::AtomicAdd;

        constexpr auto gridwise_gemm = GridwiseGemmTransposedANormalBNormalC_v1<
            GridSize, BlockSize, Float, AccFloat,
            decltype(wei_gemmk_gemmm_global_desc),
            decltype(out_gemmk_gemmn_global_desc),
            decltype(in_gemmm_gemmn_global_desc),
            in_memory_op,
            GemmMPerBlock, GemmNPerBlock, GemmKPerBlock,
            GemmMPerThread, GemmNPerThread, GemmKPerThread,
            GemmMLevel0Cluster, GemmNLevel0Cluster, GemmMLevel1Cluster, GemmNLevel1Cluster,
            ThreadGemmDataPerRead_GemmM, ThreadGemmDataPerRead_GemmN,
            GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
            GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
            Sequence<0, 1>, Sequence<0, 1>, 1,
            GemmABlockCopySrcDataPerRead_GemmM, GemmABlockCopyDstDataPerWrite_GemmM,
            GemmBBlockCopyThreadSliceLengths_GemmK_GemmN,
            GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
            Sequence<0, 1>, Sequence<0, 1>, 1,
            GemmBBlockCopySrcDataPerRead_GemmN, GemmBBlockCopyDstDataPerWrite_GemmN,
            Sequence<0, 1, 2, 3>, 3,
            GemmCThreadCopyDstDataPerWrite_GemmN1>{};

        gridwise_gemm.Run(p_wei_global, p_out_global, p_in_global);
    }
};

} // namespace ck
#endif
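The comments at the top of this deleted kernel give the implicit-GEMM view it used: GemmM = C * Y * X, GemmN = N * Ho * Wo, GemmK = K, with AtomicAdd chosen whenever different filter taps can write to the same input pixel. As a rough illustration of that sizing logic, here is a standalone host-side sketch (not composable_kernel code; the problem size and every name below are made up for this example):

// Illustrative host-side sketch only: reproduces the GemmM/GemmN/GemmK sizing and the
// Set-vs-AtomicAdd predicate used by the deleted v1r1 backward-data kernel.
#include <cstdio>

int main()
{
    // example backward-data problem (NCHW input, KCYX weight, NKHW output gradient)
    const int N = 128, C = 64, Hi = 28, Wi = 28; // input tensor
    const int K = 256, Y = 3, X = 3;             // weight tensor
    const int ConvStrideH = 1, ConvStrideW = 1;
    const int ConvDilationH = 1, ConvDilationW = 1;
    const int LeftPadH = 1, LeftPadW = 1, RightPadH = 1, RightPadW = 1;

    // output spatial size of the forward convolution
    const int Ho = (Hi + LeftPadH + RightPadH - ConvDilationH * (Y - 1) - 1) / ConvStrideH + 1;
    const int Wo = (Wi + LeftPadW + RightPadW - ConvDilationW * (X - 1) - 1) / ConvStrideW + 1;

    // implicit-GEMM view used by the v1r1 kernel
    const long GemmM = long(C) * Y * X;   // merged (C, Y, X) dimension of the weight
    const long GemmN = long(N) * Ho * Wo; // merged (N, Ho, Wo) dimension of the input gradient
    const long GemmK = K;                 // reduction over output channels

    // when every filter tap maps to a distinct input pixel, plain stores suffice;
    // otherwise overlapping writes require atomic adds
    const bool not_need_atomic = (ConvStrideH >= ConvDilationH * (Y - 1) + 1) &&
                                 (ConvStrideW >= ConvDilationW * (X - 1) + 1);

    std::printf("GemmM=%ld GemmN=%ld GemmK=%ld, in_memory_op=%s\n",
                GemmM, GemmN, GemmK, not_need_atomic ? "Set" : "AtomicAdd");
    return 0;
}

For the 3x3, stride-1 example above the predicate is false, so this kernel would have selected InMemoryDataOperation::AtomicAdd for its output writes.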
composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw_lds_double_buffer.hpp  deleted, 100644 → 0 (view file @ b8b2d0a6)

#ifndef CK_GRIDWISE_CONVOLUTION_BACKWARD_DATA_IMPLICIT_GEMM_V1R2_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER_HPP
#define CK_GRIDWISE_CONVOLUTION_BACKWARD_DATA_IMPLICIT_GEMM_V1R2_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER_HPP

#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "blockwise_generic_tensor_slice_copy.hpp"
#include "threadwise_generic_tensor_slice_copy.hpp"
#include "blockwise_gemm.hpp"

namespace ck {

template <index_t GridSize, index_t BlockSize,
          typename Float, typename AccFloat,
          typename InGlobalDesc, typename WeiGlobalDesc, typename OutGlobalDesc,
          typename ConvStrides, typename ConvDilations,
          typename LeftPads, typename RightPads,
          index_t EPerBlock, index_t BPerBlock, index_t KPerBlock,
          index_t GemmMPerThread, index_t GemmNPerThread, index_t GemmKPerThread,
          index_t GemmMLevel0Cluster, index_t GemmNLevel0Cluster,
          index_t GemmMLevel1Cluster, index_t GemmNLevel1Cluster,
          index_t GemmDataPerReadA, index_t GemmDataPerReadB,
          typename OutBlockCopySubLengths_K_B_N0, typename OutBlockCopyClusterLengths_K_B_N0,
          index_t OutBlockCopySrcDataPerRead_B, index_t OutBlockCopyDstDataPerWrite_N0,
          typename WeiBlockCopySubLengths_K_E_C0, typename WeiBlockCopyClusterLengths_K_E_C0,
          index_t WeiBlockCopySrcDataPerRead_E, index_t WeiBlockCopyDstDataPerWrite_C0,
          index_t InThreadCopyDstDataPerWrite_B>
struct GridwiseConvolutionBackwardDataImplicitGemm_v1r2_nchw_kcyx_nkhw_lds_double_buffer
{
    __device__ void Run(Float* const __restrict__ p_in_global,
                        const Float* const __restrict__ p_wei_global,
                        const Float* const __restrict__ p_out_global) const
    {
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};
        constexpr auto I2 = Number<2>{};
        constexpr auto I3 = Number<3>{};

        constexpr auto True = integral_constant<bool, true>{};

        constexpr auto in_n_c_hi_wi_global_desc  = InGlobalDesc{};
        constexpr auto wei_k_c_y_x_global_desc   = WeiGlobalDesc{};
        constexpr auto out_n_k_ho_wo_global_desc = OutGlobalDesc{};

        constexpr index_t N  = in_n_c_hi_wi_global_desc.GetLengths()[0];
        constexpr index_t C  = in_n_c_hi_wi_global_desc.GetLengths()[1];
        constexpr index_t Hi = in_n_c_hi_wi_global_desc.GetLengths()[2];
        constexpr index_t Wi = in_n_c_hi_wi_global_desc.GetLengths()[3];

        constexpr index_t K  = out_n_k_ho_wo_global_desc.GetLengths()[1];
        constexpr index_t Ho = out_n_k_ho_wo_global_desc.GetLengths()[2];
        constexpr index_t Wo = out_n_k_ho_wo_global_desc.GetLengths()[3];

        constexpr index_t Y = wei_k_c_y_x_global_desc.GetLengths()[2];
        constexpr index_t X = wei_k_c_y_x_global_desc.GetLengths()[3];

        constexpr index_t ConvStrideH = ConvStrides{}[0];
        constexpr index_t ConvStrideW = ConvStrides{}[1];

        constexpr index_t ConvDilationH = ConvDilations{}[0];
        constexpr index_t ConvDilationW = ConvDilations{}[1];

        constexpr index_t C0 = GemmMPerThread;
        constexpr index_t N0 = GemmNPerThread;

        static_assert(C % C0 == 0 && N % N0 == 0, "wrong!");

        constexpr index_t C1 = C / C0;
        constexpr index_t N1 = N / N0;

        constexpr index_t E = C1 * Y * X;
        constexpr index_t B = N1 * Ho * Wo;

        // sanity-check for vectorized memory load
        static_assert((Wo == 1 || (ConvStrideW == 1 || InThreadCopyDstDataPerWrite_B == 1)) &&
                          (X == 1 || ConvDilationW % InThreadCopyDstDataPerWrite_B == 0),
                      "wrong! aligment requirement for vectorized global load of input tensor will "
                      "be violated");

        // divide block work by [K, B]
        static_assert(E % EPerBlock == 0 && B % BPerBlock == 0 && K % KPerBlock == 0,
                      "wrong! cannot divide work evenly among block");

        constexpr index_t EBlockWork = E / EPerBlock;
        constexpr index_t BBlockWork = B / BPerBlock;

        constexpr auto block_work_desc =
            make_cluster_descriptor(Sequence<EBlockWork, BBlockWork>{});

        const auto block_work_id = block_work_desc.CalculateClusterIndex(get_block_1d_id());

        const index_t e_block_data_on_global = block_work_id[Number<0>{}] * EPerBlock;
        const index_t b_block_data_on_global = block_work_id[Number<1>{}] * BPerBlock;

        // output tensor
        //   global tensor in global memory, src of blockwise copy
        constexpr auto out_n_k_howo_global_desc =
            unfold_tensor_descriptor(out_n_k_ho_wo_global_desc, I2, I3);

        constexpr auto out_n0_n1_k_howo_global_desc = transform_tensor_descriptor(
            out_n_k_howo_global_desc,
            make_tuple(UnMerge<Sequence<N0, N1>>{}, PassThrough<K>{}, PassThrough<Ho * Wo>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
            make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{}));

        constexpr auto out_k_b_n0_global_desc = transform_tensor_descriptor(
            out_n0_n1_k_howo_global_desc,
            make_tuple(PassThrough<K>{}, Merge<Sequence<N1, Ho * Wo>>{}, PassThrough<N0>{}),
            make_tuple(Sequence<2>{}, Sequence<1, 3>{}, Sequence<0>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));

        //   block tensor in LDS memory, dst of blockwise copy
        //   be careful of LDS alignment
        constexpr auto out_k_b_n0_block_desc = make_native_tensor_descriptor_aligned(
            Sequence<KPerBlock, BPerBlock, N0>{}, Number<OutBlockCopyDstDataPerWrite_N0>{});

        // output tensor blockwise copy
        auto blockwise_out_copy =
            BlockwiseGenericTensorSliceCopy_v4<BlockSize,
                                               decltype(out_k_b_n0_global_desc),
                                               decltype(out_k_b_n0_block_desc),
                                               decltype(out_k_b_n0_block_desc.GetLengths()),
                                               OutBlockCopySubLengths_K_B_N0,
                                               OutBlockCopyClusterLengths_K_B_N0,
                                               Sequence<0, 1, 2>, Sequence<0, 1, 2>, Sequence<0, 1, 2>,
                                               1, 2,
                                               OutBlockCopySrcDataPerRead_B,
                                               OutBlockCopyDstDataPerWrite_N0,
                                               AddressSpace::Global,
                                               AddressSpace::Vgpr,
                                               AddressSpace::Lds,
                                               InMemoryDataOperation::Set>(
                make_multi_index(0, b_block_data_on_global, 0), make_multi_index(0, 0, 0));

        // weight tensor
        //   global tensor in global memory, src of blockwise copy
        constexpr auto wei_k_cyx_global_desc =
            unfold_tensor_descriptor(wei_k_c_y_x_global_desc, I1, I3);

        constexpr auto wei_k_c0_e_global_desc = transform_tensor_descriptor(
            wei_k_cyx_global_desc,
            make_tuple(PassThrough<K>{}, UnMerge<Sequence<C0, E>>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}),
            make_tuple(Sequence<0>{}, Sequence<1, 2>{}));

        constexpr auto wei_k_e_c0_global_desc =
            reorder_tensor_descriptor_given_lower2upper(wei_k_c0_e_global_desc, Sequence<0, 2, 1>{});

        //   block tensor in LDS memory, dst of blockwise copy
        //   be careful of LDS alignment
        constexpr auto wei_k_e_c0_block_desc = make_native_tensor_descriptor_aligned(
            Sequence<KPerBlock, EPerBlock, C0>{}, Number<WeiBlockCopyDstDataPerWrite_C0>{});

        // weight tensor blockwise copy
        auto blockwise_wei_copy =
            BlockwiseGenericTensorSliceCopy_v4<BlockSize,
                                               decltype(wei_k_e_c0_global_desc),
                                               decltype(wei_k_e_c0_block_desc),
                                               decltype(wei_k_e_c0_block_desc.GetLengths()),
                                               WeiBlockCopySubLengths_K_E_C0,
                                               WeiBlockCopyClusterLengths_K_E_C0,
                                               Sequence<0, 1, 2>, Sequence<0, 1, 2>, Sequence<0, 1, 2>,
                                               1, 2,
                                               WeiBlockCopySrcDataPerRead_E,
                                               WeiBlockCopyDstDataPerWrite_C0,
                                               AddressSpace::Global,
                                               AddressSpace::Vgpr,
                                               AddressSpace::Lds,
                                               InMemoryDataOperation::Set>(
                make_multi_index(0, e_block_data_on_global, 0), make_multi_index(0, 0, 0));

        // GEMM definition
        //   c_mtx += transpose(a_mtx) * b_mtx
        //     a_mtx[KPerBlock, EPerBlock*C0] is in LDS
        //     b_mtx[KPerBlocl, BPerBlock*N0] is in LDS
        //     c_mtx[EPerBlock*C0, BPerBlock*N0] is distributed among threads, and saved in
        //       register
        constexpr auto a_k_ec0_block_mtx_desc = make_ConstantMatrixDescriptor(
            wei_k_e_c0_block_desc.GetLength(I0),
            wei_k_e_c0_block_desc.GetLength(I1) * wei_k_e_c0_block_desc.GetLength(I2),
            wei_k_e_c0_block_desc.GetStride(I0));

        constexpr auto b_k_bn0_block_mtx_desc = make_ConstantMatrixDescriptor(
            out_k_b_n0_block_desc.GetLength(I0),
            out_k_b_n0_block_desc.GetLength(I1) * out_k_b_n0_block_desc.GetLength(I2),
            out_k_b_n0_block_desc.GetStride(I0));

        // sanity check alignment
        // TODO: this check is ad-hoc, should enforce it by enforcing alignment of
        // wei_k_e_c0_block_desc and out_k_b_n0_block_desc
        static_assert(a_k_ec0_block_mtx_desc.RowStride() % GemmDataPerReadB == 0, "wrong!");
        static_assert(b_k_bn0_block_mtx_desc.RowStride() % GemmDataPerReadA == 0, "wrong!");

        // sanity check
        static_assert(EPerBlock % (GemmMLevel0Cluster * GemmMLevel1Cluster) == 0 &&
                          BPerBlock % (GemmNLevel0Cluster * GemmNLevel1Cluster) == 0,
                      "wrong!");

        constexpr index_t GemmMRepeat = EPerBlock / (GemmMLevel0Cluster * GemmMLevel1Cluster);
        constexpr index_t GemmNRepeat = BPerBlock / (GemmNLevel0Cluster * GemmNLevel1Cluster);

        // c_thread_mtx definition: this is a mess
        // TODO:: more elegent way of defining c_thread_mtx
        constexpr auto c_e0e1c0_b0b1n0_thread_mtx_desc = make_ConstantMatrixDescriptor_packed(
            Number<GemmMRepeat * GemmMPerThread>{}, Number<GemmNRepeat * GemmNPerThread>{});

        const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2<
            BlockSize,
            decltype(a_k_ec0_block_mtx_desc),
            decltype(b_k_bn0_block_mtx_desc),
            decltype(c_e0e1c0_b0b1n0_thread_mtx_desc),
            GemmMPerThread, GemmNPerThread, GemmKPerThread,
            GemmMLevel0Cluster, GemmNLevel0Cluster, GemmMLevel1Cluster, GemmNLevel1Cluster,
            GemmDataPerReadA, GemmDataPerReadB>{};

        // LDS allocation for input and weight: be careful of alignment
        constexpr index_t max_lds_align = math::lcm(WeiBlockCopyDstDataPerWrite_C0,
                                                    OutBlockCopyDstDataPerWrite_N0,
                                                    GemmDataPerReadA,
                                                    GemmDataPerReadB);

        constexpr index_t out_block_space =
            math::integer_least_multiple(out_k_b_n0_block_desc.GetElementSpace(), max_lds_align);

        constexpr index_t wei_block_space =
            math::integer_least_multiple(wei_k_e_c0_block_desc.GetElementSpace(), max_lds_align);

        __shared__ Float p_out_block_double[2 * out_block_space];
        __shared__ Float p_wei_block_double[2 * wei_block_space];

        // register allocation for output
        AccFloat p_in_thread[c_e0e1c0_b0b1n0_thread_mtx_desc.GetElementSpace()];

        // zero out threadwise output
        threadwise_matrix_set_zero(c_e0e1c0_b0b1n0_thread_mtx_desc, p_in_thread);

        // LDS double buffer: preload data into LDS
        {
            blockwise_out_copy.Run(p_out_global, p_out_block_double);
            blockwise_wei_copy.Run(p_wei_global, p_wei_block_double);
        }

        // LDS double buffer: main body
        for(index_t k_block_data_begin = 0; k_block_data_begin + 2 * KPerBlock < K;
            k_block_data_begin += 2 * KPerBlock)
        {
#pragma unroll
            for(index_t iloop = 0; iloop < 2; ++iloop)
            {
                const bool even_loop = (iloop % 2 == 0);

                Float* p_out_block_now =
                    even_loop ? p_out_block_double : p_out_block_double + out_block_space;
                Float* p_wei_block_now =
                    even_loop ? p_wei_block_double : p_wei_block_double + wei_block_space;

                Float* p_out_block_next =
                    even_loop ? p_out_block_double + out_block_space : p_out_block_double;
                Float* p_wei_block_next =
                    even_loop ? p_wei_block_double + wei_block_space : p_wei_block_double;

                Float p_out_thread_buffer[blockwise_out_copy.GetThreadBufferSize()];
                Float p_wei_thread_buffer[blockwise_wei_copy.GetThreadBufferSize()];

                blockwise_out_copy.MoveSrcSliceWindow(Sequence<KPerBlock, 0, 0>{}, True);
                blockwise_wei_copy.MoveSrcSliceWindow(Sequence<KPerBlock, 0, 0>{}, True);

                __syncthreads();

                // LDS doubel buffer: load next data from device mem
                blockwise_out_copy.RunLoadThreadBuffer(p_out_global, p_out_thread_buffer);
                blockwise_wei_copy.RunLoadThreadBuffer(p_wei_global, p_wei_thread_buffer);

                // LDS double buffer: GEMM on current data
                blockwise_gemm.Run(p_wei_block_now, p_out_block_now, p_in_thread);

                // LDS double buffer: store next data to LDS
                blockwise_out_copy.RunStoreThreadBuffer(p_out_thread_buffer, p_out_block_next);
                blockwise_wei_copy.RunStoreThreadBuffer(p_wei_thread_buffer, p_wei_block_next);
            }
        }

        // LDS double buffer: tail
        {
            constexpr bool has_two_iteration_left = (K % (2 * KPerBlock) == 0);

            if(has_two_iteration_left) // if has 2 iteration left
            {
                Float p_out_thread_buffer[blockwise_out_copy.GetThreadBufferSize()];
                Float p_wei_thread_buffer[blockwise_wei_copy.GetThreadBufferSize()];

                blockwise_out_copy.MoveSrcSliceWindow(Sequence<KPerBlock, 0, 0>{}, True);
                blockwise_wei_copy.MoveSrcSliceWindow(Sequence<KPerBlock, 0, 0>{}, True);

                __syncthreads();

                // LDS double buffer: load last data from device mem
                blockwise_out_copy.RunLoadThreadBuffer(p_out_global, p_out_thread_buffer);
                blockwise_wei_copy.RunLoadThreadBuffer(p_wei_global, p_wei_thread_buffer);

                // LDS double buffer: GEMM on 2nd-last data
                blockwise_gemm.Run(p_wei_block_double, p_out_block_double, p_in_thread);

                // LDS double buffer: store last data to LDS
                blockwise_out_copy.RunStoreThreadBuffer(p_out_thread_buffer,
                                                        p_out_block_double + out_block_space);
                blockwise_wei_copy.RunStoreThreadBuffer(p_wei_thread_buffer,
                                                        p_wei_block_double + wei_block_space);

                __syncthreads();

                // LDS double buffer: GEMM on last data
                blockwise_gemm.Run(p_wei_block_double + wei_block_space,
                                   p_out_block_double + out_block_space,
                                   p_in_thread);
            }
            else // if has 1 iteration left
            {
                __syncthreads();

                // LDS double buffer: GEMM on last data
                blockwise_gemm.Run(p_wei_block_double, p_out_block_double, p_in_thread);
            }
        }

        {
#if 1 // debug
            // input: register to global memory, atomic add
            constexpr auto in_memory_op = (Y <= ConvStrideH && X <= ConvStrideW)
                                              ? InMemoryDataOperation::Set
                                              : InMemoryDataOperation::AtomicAdd;
#else
            constexpr auto in_memory_op = InMemoryDataOperation::AtomicAdd;
#endif

            constexpr index_t E1 = GemmMLevel0Cluster * GemmMLevel1Cluster;
            constexpr index_t E0 = E / E1;

            constexpr index_t B1 = GemmNLevel0Cluster * GemmNLevel1Cluster;
            constexpr index_t B0 = B / B1;

            // define input tensor descriptor for threadwise copy
            //   thread input tensor, src of threadwise copy
            constexpr auto in_e0_e1_c0_b0_b1_n0_thread_desc = make_native_tensor_descriptor_packed(
                Sequence<GemmMRepeat, 1, GemmMPerThread, GemmNRepeat, 1, GemmNPerThread>{});

            //   global input tensor, dst of threadwise copy
            constexpr auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor(
                in_n_c_hi_wi_global_desc,
                make_tuple(PassThrough<N>{}, PassThrough<C>{},
                           Pad<Sequence<Hi, Wi>, LeftPads, RightPads>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}));

            constexpr auto in_n0_n1_c0_c1_y_ho_x_wo_global_desc = transform_tensor_descriptor(
                in_n_c_hip_wip_global_desc,
                make_tuple(UnMerge<Sequence<N0, N1>>{},
                           UnMerge<Sequence<C0, C1>>{},
                           Embed<Hi + LeftPads::At(0) + RightPads::At(0),
                                 Sequence<Y, Ho>,
                                 Sequence<ConvDilationH, ConvStrideH, 0>>{},
                           Embed<Wi + LeftPads::At(1) + RightPads::At(1),
                                 Sequence<X, Wo>,
                                 Sequence<ConvDilationW, ConvStrideW, 0>>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
                make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}, Sequence<6, 7>{}));

            constexpr auto in_e_c0_b_n0_global_desc = transform_tensor_descriptor(
                in_n0_n1_c0_c1_y_ho_x_wo_global_desc,
                make_tuple(Merge<Sequence<C1, Y, X>>{},
                           PassThrough<C0>{},
                           Merge<Sequence<N1, Ho, Wo>>{},
                           PassThrough<N0>{}),
                make_tuple(Sequence<3, 4, 6>{}, Sequence<2>{}, Sequence<1, 5, 7>{}, Sequence<0>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));

            constexpr auto in_e0_e1_c0_b0_b1_n0_global_desc = transform_tensor_descriptor(
                in_e_c0_b_n0_global_desc,
                make_tuple(UnMerge<Sequence<E0, E1>>{},
                           PassThrough<C0>{},
                           UnMerge<Sequence<B0, B1>>{},
                           PassThrough<N0>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
                make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3, 4>{}, Sequence<5>{}));

            // calculate origin of thread input tensor on global memory
            //   blockwise GEMM c matrix starting index
            const auto c_thread_mtx_on_block =
                blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());

            const index_t e_thread_data_on_global =
                e_block_data_on_global + c_thread_mtx_on_block.row / GemmMPerThread;

            const index_t b_thread_data_on_global =
                b_block_data_on_global + c_thread_mtx_on_block.col / GemmNPerThread;

            ThreadwiseGenericTensorSliceCopy_v4r2<
                decltype(in_e0_e1_c0_b0_b1_n0_thread_desc),
                decltype(in_e0_e1_c0_b0_b1_n0_global_desc),
                decltype(in_e0_e1_c0_b0_b1_n0_thread_desc.GetLengths()),
                Sequence<0, 1, 2, 3, 4, 5>,
                4, 1,
                InThreadCopyDstDataPerWrite_B,
                AddressSpace::Vgpr,
                AddressSpace::Global,
                in_memory_op>(make_multi_index(0, 0, 0, 0, 0, 0),
                              make_multi_index(e_thread_data_on_global / E1,
                                               e_thread_data_on_global % E1,
                                               0,
                                               b_thread_data_on_global / B1,
                                               b_thread_data_on_global % B1,
                                               0))
                .Run(p_in_thread, p_in_global);
        }
    }
};

} // namespace ck
#endif
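The core of this deleted kernel is its LDS double-buffer pipeline: preload one K-slice of the weight and output tiles into LDS, then in each iteration fetch the next slice into the idle buffer while the blockwise GEMM consumes the current one, with a tail that handles the last one or two slices. A minimal CPU analogue of that schedule (illustrative only; the running-sum "compute" and every name below are invented for this sketch and are not composable_kernel code) looks like:

// CPU analogue of the LDS double-buffer schedule used by the deleted *_lds_double_buffer kernels:
// preload tile 0, then on each step load tile i+1 into the idle buffer while consuming tile i.
#include <cstdio>
#include <vector>

int main()
{
    const int num_tiles = 5, tile_size = 4;
    std::vector<float> global(num_tiles * tile_size);
    for(int i = 0; i < (int)global.size(); ++i) global[i] = float(i);

    // two "LDS" buffers used in ping-pong fashion
    std::vector<float> lds[2] = {std::vector<float>(tile_size), std::vector<float>(tile_size)};
    auto load = [&](int tile, int buf) {
        for(int j = 0; j < tile_size; ++j) lds[buf][j] = global[tile * tile_size + j];
    };

    float acc = 0.f; // stand-in for the per-thread accumulator (p_in_thread in the kernel)
    auto compute = [&](int buf) { for(float v : lds[buf]) acc += v; };

    load(0, 0); // preload, as in "LDS double buffer: preload data into LDS"
    for(int t = 0; t < num_tiles; ++t)
    {
        const int cur = t % 2, nxt = 1 - cur;
        if(t + 1 < num_tiles) load(t + 1, nxt); // fetch next tile while the current one is used
        compute(cur);                           // the kernel runs the blockwise GEMM here
    }
    std::printf("sum = %.1f (expected %.1f)\n", acc, 0.5f * global.size() * (global.size() - 1));
    return 0;
}

The deleted kernel splits the "load" stage further: RunLoadThreadBuffer moves global data into registers, the GEMM on the current buffer runs, and only then RunStoreThreadBuffer writes the registers into the other LDS buffer, so global-memory latency overlaps with the math.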
composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp  deleted, 100644 → 0 (view file @ b8b2d0a6)

#ifndef CK_GRIDWISE_CONVOLUTION_BACKWARD_DATA_IMPLICIT_GEMM_V4R1_NCHW_KCYX_NKHW_HPP
#define CK_GRIDWISE_CONVOLUTION_BACKWARD_DATA_IMPLICIT_GEMM_V4R1_NCHW_KCYX_NKHW_HPP

#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm.hpp"

namespace ck {

// Number of GEMMs: YTilda * XTilda
// GemmM = C
// GemmN = N * HTildaSlice * WTildaSlice
// GemmK = K * YDotSlice * XDotSlice
template <index_t GridSize, index_t BlockSize,
          typename Float, typename AccFloat,
          typename InGlobalDesc, typename WeiGlobalDesc, typename OutGlobalDesc,
          typename ConvStrides, typename ConvDilations,
          typename InLeftPads, typename InRightPads,
          index_t GemmMPerBlock, index_t GemmNPerBlock, index_t GemmKPerBlock,
          index_t GemmMPerThread, index_t GemmNPerThread, index_t GemmKPerThread,
          index_t GemmMLevel0Cluster, index_t GemmNLevel0Cluster,
          index_t GemmMLevel1Cluster, index_t GemmNLevel1Cluster,
          index_t ThreadGemmDataPerRead_GemmM, index_t ThreadGemmDataPerRead_GemmN,
          typename GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
          typename GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
          index_t GemmABlockCopySrcDataPerRead_GemmM, index_t GemmABlockCopyDstDataPerWrite_GemmM,
          typename GemmBBlockCopyThreadSliceLengths_GemmK_GemmN,
          typename GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
          index_t GemmBBlockCopySrcDataPerRead_GemmN, index_t GemmBBlockCopyDstDataPerWrite_GemmN,
          index_t GemmCThreadCopyDstDataPerWrite_GemmN1>
struct GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw
{
    __host__ __device__ static constexpr index_t GetNumberOfGemm()
    {
        constexpr index_t ConvStrideH = ConvStrides{}[0];
        constexpr index_t ConvStrideW = ConvStrides{}[1];

        constexpr index_t ConvDilationH = ConvDilations{}[0];
        constexpr index_t ConvDilationW = ConvDilations{}[1];

        constexpr index_t GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
        constexpr index_t GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);

        constexpr index_t YTilda = ConvStrideH / GcdStrideDilationH;
        constexpr index_t XTilda = ConvStrideW / GcdStrideDilationW;

        return YTilda * XTilda;
    }

    __host__ __device__ static constexpr auto GetGemmSizeImpl(index_t iYTilda, index_t iXTilda)
    {
        constexpr index_t N  = InGlobalDesc::GetLengths()[0];
        constexpr index_t C  = InGlobalDesc::GetLengths()[1];
        constexpr index_t Hi = InGlobalDesc::GetLengths()[2];
        constexpr index_t Wi = InGlobalDesc::GetLengths()[3];

        constexpr index_t K  = OutGlobalDesc::GetLengths()[1];
        constexpr index_t Ho = OutGlobalDesc::GetLengths()[2];
        constexpr index_t Wo = OutGlobalDesc::GetLengths()[3];

        constexpr index_t Y = WeiGlobalDesc::GetLengths()[2];
        constexpr index_t X = WeiGlobalDesc::GetLengths()[3];

        constexpr index_t ConvStrideH = ConvStrides{}[0];
        constexpr index_t ConvStrideW = ConvStrides{}[1];

        constexpr index_t ConvDilationH = ConvDilations{}[0];
        constexpr index_t ConvDilationW = ConvDilations{}[1];

        constexpr index_t GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
        constexpr index_t GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);

        constexpr index_t YTilda = ConvStrideH / GcdStrideDilationH;
        constexpr index_t XTilda = ConvStrideW / GcdStrideDilationW;

        constexpr index_t YDot = math::integer_divide_ceil(Y, YTilda);
        constexpr index_t XDot = math::integer_divide_ceil(X, XTilda);

        constexpr index_t HTilda = Ho + math::integer_divide_ceil(ConvDilationH * (Y - 1), ConvStrideH);
        constexpr index_t WTilda = Wo + math::integer_divide_ceil(ConvDilationW * (X - 1), ConvStrideW);

        // only work on HTilda and WTilda that contribute to non-padding area of input tensor
        constexpr index_t iHTildaLeft = math::integer_divide_floor(
            math::max(0, InLeftPads{}[0] - ConvDilationH * (YTilda - 1)), ConvStrides{}[0]);
        constexpr index_t iWTildaLeft = math::integer_divide_floor(
            math::max(0, InLeftPads{}[1] - ConvDilationW * (XTilda - 1)), ConvStrides{}[1]);

        constexpr index_t iHTildaRight = math::min(
            HTilda, math::integer_divide_ceil(InLeftPads{}[0] + Hi - 1, ConvStrides{}[0]) + 1);
        constexpr index_t iWTildaRight = math::min(
            WTilda, math::integer_divide_ceil(InLeftPads{}[1] + Wi - 1, ConvStrides{}[1]) + 1);

        constexpr index_t HTildaSlice = iHTildaRight - iHTildaLeft;
        constexpr index_t WTildaSlice = iWTildaRight - iWTildaLeft;

        // GemmM and GemmN
        constexpr index_t GemmM = C;
        constexpr index_t GemmN = N * HTildaSlice * WTildaSlice;

        // GemmK is different for each GEMM
        index_t YDotSlice = math::integer_divide_ceil(Y - iYTilda, YTilda);
        index_t XDotSlice = math::integer_divide_ceil(X - iXTilda, XTilda);

        index_t GemmK = K * YDotSlice * XDotSlice;

        return Array<index_t, 3>{GemmM, GemmN, GemmK};
    }

    __host__ __device__ static constexpr auto GetGemmSize(index_t gemm_id)
    {
        constexpr index_t ConvStrideW   = ConvStrides{}[1];
        constexpr index_t ConvDilationW = ConvDilations{}[1];

        constexpr index_t GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);

        constexpr index_t XTilda = ConvStrideW / GcdStrideDilationW;

        index_t iYTilda = gemm_id / XTilda;
        index_t iXTilda = gemm_id % XTilda;

        return GetGemmSizeImpl(iYTilda, iXTilda);
    }

    template <index_t iYTilda, index_t iXTilda>
    __device__ static void RunImpl(Float* __restrict__ p_in_global,
                                   const Float* __restrict__ p_wei_global,
                                   const Float* __restrict__ p_out_global)
    {
        constexpr auto in_n_c_hi_wi_global_desc  = InGlobalDesc{};
        constexpr auto wei_k_c_y_x_global_desc   = WeiGlobalDesc{};
        constexpr auto out_n_k_ho_wo_global_desc = OutGlobalDesc{};

        constexpr index_t N  = in_n_c_hi_wi_global_desc.GetLengths()[0];
        constexpr index_t C  = in_n_c_hi_wi_global_desc.GetLengths()[1];
        constexpr index_t Hi = in_n_c_hi_wi_global_desc.GetLengths()[2];
        constexpr index_t Wi = in_n_c_hi_wi_global_desc.GetLengths()[3];

        constexpr index_t K  = out_n_k_ho_wo_global_desc.GetLengths()[1];
        constexpr index_t Ho = out_n_k_ho_wo_global_desc.GetLengths()[2];
        constexpr index_t Wo = out_n_k_ho_wo_global_desc.GetLengths()[3];

        constexpr index_t Y = wei_k_c_y_x_global_desc.GetLengths()[2];
        constexpr index_t X = wei_k_c_y_x_global_desc.GetLengths()[3];

        constexpr index_t ConvStrideH = ConvStrides{}[0];
        constexpr index_t ConvStrideW = ConvStrides{}[1];

        constexpr index_t ConvDilationH = ConvDilations{}[0];
        constexpr index_t ConvDilationW = ConvDilations{}[1];

        constexpr index_t GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
        constexpr index_t GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);

        constexpr index_t YTilda = ConvStrideH / GcdStrideDilationH;
        constexpr index_t XTilda = ConvStrideW / GcdStrideDilationW;

        constexpr index_t YDot = math::integer_divide_ceil(Y, YTilda);
        constexpr index_t XDot = math::integer_divide_ceil(X, XTilda);

        constexpr index_t YDotSlice = math::integer_divide_ceil(Y - iYTilda, YTilda);
        constexpr index_t XDotSlice = math::integer_divide_ceil(X - iXTilda, XTilda);

        constexpr index_t HTilda = Ho + math::integer_divide_ceil(ConvDilationH * (Y - 1), ConvStrideH);
        constexpr index_t WTilda = Wo + math::integer_divide_ceil(ConvDilationW * (X - 1), ConvStrideW);

        // only work on HTilda and WTilda that contribute to non-padding area of input tensor
        constexpr index_t iHTildaLeft = math::integer_divide_floor(
            math::max(0, InLeftPads{}[0] - ConvDilationH * (YTilda - 1)), ConvStrides{}[0]);
        constexpr index_t iWTildaLeft = math::integer_divide_floor(
            math::max(0, InLeftPads{}[1] - ConvDilationW * (XTilda - 1)), ConvStrides{}[1]);

        constexpr index_t iHTildaRight = math::min(
            HTilda, math::integer_divide_ceil(InLeftPads{}[0] + Hi - 1, ConvStrides{}[0]) + 1);
        constexpr index_t iWTildaRight = math::min(
            WTilda, math::integer_divide_ceil(InLeftPads{}[1] + Wi - 1, ConvStrides{}[1]) + 1);

        constexpr index_t HTildaSlice = iHTildaRight - iHTildaLeft;
        constexpr index_t WTildaSlice = iWTildaRight - iWTildaLeft;

        // A matrix: weight
        //   weight out-of-bound check can be skipped
        constexpr bool wei_skip_out_of_bound_check = true;

        constexpr auto wei_k_c_ydot_ytilda_xdot_xtilda_global_desc = transform_tensor_descriptor(
            wei_k_c_y_x_global_desc,
            make_tuple(PassThrough<K>{},
                       PassThrough<C>{},
                       Embed<Y,
                             Sequence<YDot, YTilda>,
                             Sequence<ConvStrideH / GcdStrideDilationH, 1, 0>,
                             wei_skip_out_of_bound_check>{},
                       Embed<X,
                             Sequence<XDot, XTilda>,
                             Sequence<ConvStrideW / GcdStrideDilationW, 1, 0>,
                             wei_skip_out_of_bound_check>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));

        constexpr auto wei_k_c_ydotslice_xdotslice_global_desc = transform_tensor_descriptor(
            wei_k_c_ydot_ytilda_xdot_xtilda_global_desc,
            make_tuple(PassThrough<K>{},
                       PassThrough<C>{},
                       Slice<Sequence<YDot, XDot>, Sequence<0, 0>, Sequence<YDotSlice, XDotSlice>>{},
                       Freeze<Sequence<YTilda, XTilda>, Sequence<iYTilda, iXTilda>>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 4>{}, Sequence<3, 5>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<>{}));

        constexpr auto wei_gemmk_gemmm_global_desc = transform_tensor_descriptor(
            wei_k_c_ydotslice_xdotslice_global_desc,
            make_tuple(Merge<Sequence<K, YDotSlice, XDotSlice>>{}, PassThrough<C>{}),
            make_tuple(Sequence<0, 2, 3>{}, Sequence<1>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}));

        // B matrix: output tensor
        // TODO sometimes output tensor out-of-bound check can be skipped, find out all such
        // situations
#if !CK_EXPERIMENTAL_IMPLICIT_GEMM_BACKWARD_DATA_V4R1_OUTPUT_SKIP_OUT_OF_BOUND_CHECK
        constexpr bool out_skip_out_of_bound_check = false;
#else
        constexpr bool out_skip_out_of_bound_check = true;
#endif

        constexpr auto out_n_k_ydot_htilda_xdot_wtilda_global_desc = transform_tensor_descriptor(
            out_n_k_ho_wo_global_desc,
            make_tuple(PassThrough<N>{},
                       PassThrough<K>{},
                       Embed<Ho,
                             Sequence<YDot, HTilda>,
                             Sequence<-ConvDilationH / GcdStrideDilationH, 1, 0>,
                             out_skip_out_of_bound_check>{},
                       Embed<Wo,
                             Sequence<XDot, WTilda>,
                             Sequence<-ConvDilationW / GcdStrideDilationW, 1, 0>,
                             out_skip_out_of_bound_check>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));

        constexpr auto out_n_k_ydot_htildaslice_xdot_wtildaslice_global_desc =
            transform_tensor_descriptor(
                out_n_k_ydot_htilda_xdot_wtilda_global_desc,
                make_tuple(PassThrough<N>{},
                           PassThrough<K>{},
                           PassThrough<YDot>{},
                           PassThrough<XDot>{},
                           Slice<Sequence<HTilda, WTilda>,
                                 Sequence<iHTildaLeft, iWTildaLeft>,
                                 Sequence<iHTildaRight, iWTildaRight>>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}));

        constexpr auto out_n_k_ydotslice_htildaslice_xdotslice_wtildaslice_global_desc =
            transform_tensor_descriptor(
                out_n_k_ydot_htildaslice_xdot_wtildaslice_global_desc,
                make_tuple(PassThrough<N>{},
                           PassThrough<K>{},
                           PassThrough<HTildaSlice>{},
                           PassThrough<WTildaSlice>{},
                           Slice<Sequence<YDot, XDot>, Sequence<0, 0>, Sequence<YDotSlice, XDotSlice>>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<3>{}, Sequence<5>{}, Sequence<2, 4>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<3>{}, Sequence<5>{}, Sequence<2, 4>{}));

        constexpr auto out_gemmk_gemmn_global_desc = transform_tensor_descriptor(
            out_n_k_ydotslice_htildaslice_xdotslice_wtildaslice_global_desc,
            make_tuple(Merge<Sequence<K, YDotSlice, XDotSlice>>{},
                       Merge<Sequence<N, HTildaSlice, WTildaSlice>>{}),
            make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}));

        // C matrix: input tensor
        // TODO sometimes input out-of-bound check can be skipped, find out all such situations
#if !CK_EXPERIMENTAL_IMPLICIT_GEMM_BACKWARD_DATA_V4R1_INPUT_SKIP_OUT_OF_BOUND_CHECK
        constexpr bool in_skip_out_of_bound_check = false;
#else
        constexpr bool in_skip_out_of_bound_check = true;
#endif

        constexpr auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor(
            in_n_c_hi_wi_global_desc,
            make_tuple(PassThrough<N>{},
                       PassThrough<C>{},
                       Pad<Sequence<Hi, Wi>, InLeftPads, InRightPads, in_skip_out_of_bound_check>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}));

        constexpr index_t Hip = in_n_c_hip_wip_global_desc.GetLengths()[2];
        constexpr index_t Wip = in_n_c_hip_wip_global_desc.GetLengths()[3];

        constexpr auto in_n_c_ytilda_htilda_xtilda_wtilda_global_desc = transform_tensor_descriptor(
            in_n_c_hip_wip_global_desc,
            make_tuple(PassThrough<N>{},
                       PassThrough<C>{},
                       Embed<Hip,
                             Sequence<YTilda, HTilda>,
                             Sequence<ConvDilationH, ConvStrideH, 0>,
                             in_skip_out_of_bound_check>{},
                       Embed<Wip,
                             Sequence<XTilda, WTilda>,
                             Sequence<ConvDilationW, ConvStrideW, 0>,
                             in_skip_out_of_bound_check>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));

        constexpr auto in_n_c_htildaslice_wtildaslice_global_desc = transform_tensor_descriptor(
            in_n_c_ytilda_htilda_xtilda_wtilda_global_desc,
            make_tuple(PassThrough<N>{},
                       PassThrough<C>{},
                       Freeze<Sequence<YTilda, XTilda>, Sequence<iYTilda, iXTilda>>{},
                       Slice<Sequence<HTilda, WTilda>,
                             Sequence<iHTildaLeft, iWTildaLeft>,
                             Sequence<iHTildaRight, iWTildaRight>>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 4>{}, Sequence<3, 5>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<>{}, Sequence<2, 3>{}));

        constexpr auto in_gemmm_gemmn_global_desc = transform_tensor_descriptor(
            in_n_c_htildaslice_wtildaslice_global_desc,
            make_tuple(PassThrough<C>{}, Merge<Sequence<N, HTildaSlice, WTildaSlice>>{}),
            make_tuple(Sequence<1>{}, Sequence<0, 2, 3>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}));

        constexpr auto gridwise_gemm = GridwiseGemmTransposedANormalBNormalC_v1<
            GridSize, BlockSize, Float, AccFloat,
            decltype(wei_gemmk_gemmm_global_desc),
            decltype(out_gemmk_gemmn_global_desc),
            decltype(in_gemmm_gemmn_global_desc),
            InMemoryDataOperation::Set,
            GemmMPerBlock, GemmNPerBlock, GemmKPerBlock,
            GemmMPerThread, GemmNPerThread, GemmKPerThread,
            GemmMLevel0Cluster, GemmNLevel0Cluster, GemmMLevel1Cluster, GemmNLevel1Cluster,
            ThreadGemmDataPerRead_GemmM, ThreadGemmDataPerRead_GemmN,
            GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
            GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
            Sequence<0, 1>, Sequence<0, 1>, 1,
            GemmABlockCopySrcDataPerRead_GemmM, GemmABlockCopyDstDataPerWrite_GemmM,
            GemmBBlockCopyThreadSliceLengths_GemmK_GemmN,
            GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
            Sequence<0, 1>, Sequence<0, 1>, 1,
            GemmBBlockCopySrcDataPerRead_GemmN, GemmBBlockCopyDstDataPerWrite_GemmN,
            Sequence<0, 1, 2, 3>, 3,
            GemmCThreadCopyDstDataPerWrite_GemmN1>{};

        gridwise_gemm.Run(p_wei_global, p_out_global, p_in_global);
    }

    template <index_t GemmId>
    __device__ static void Run(Float* __restrict__ p_in_global,
                               const Float* __restrict__ p_wei_global,
                               const Float* __restrict__ p_out_global,
                               Number<GemmId>)
    {
        constexpr index_t ConvStrideH = ConvStrides{}[0];
        constexpr index_t ConvStrideW = ConvStrides{}[1];

        constexpr index_t ConvDilationH = ConvDilations{}[0];
        constexpr index_t ConvDilationW = ConvDilations{}[1];

        constexpr index_t GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
        constexpr index_t GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);

        constexpr index_t YTilda = ConvStrideH / GcdStrideDilationH;
        constexpr index_t XTilda = ConvStrideW / GcdStrideDilationW;

        constexpr index_t iYTilda = GemmId / XTilda;
        constexpr index_t iXTilda = GemmId % XTilda;

        static_assert(iYTilda < YTilda && iXTilda < XTilda, "wrong! iYtilda, iXtilda");

        RunImpl<iYTilda, iXTilda>(p_in_global, p_wei_global, p_out_global);
    }
};

} // namespace ck
#endif
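This deleted v4r1 kernel (and the v5r1 variant after it) handles strided backward-data convolution by splitting the filter into YTilda * XTilda interleaved sub-filters and running one GEMM per (iYTilda, iXTilda) pair, where YTilda = ConvStrideH / gcd(ConvStrideH, ConvDilationH) and each GEMM reduces over GemmK = K * YDotSlice * XDotSlice. A small standalone sketch of that decomposition (illustrative host code only, not part of the library; the 3x3, stride-2 problem below is an arbitrary example):

// Illustrative sketch of the filter decomposition used by the deleted v4r1/v5r1 backward-data
// kernels: stride and dilation are folded into YTilda * XTilda separate GEMMs.
#include <cstdio>

int gcd(int a, int b) { return b == 0 ? a : gcd(b, a % b); }
int ceil_div(int a, int b) { return (a + b - 1) / b; }

int main()
{
    const int Y = 3, X = 3;                     // filter size
    const int ConvStrideH = 2, ConvStrideW = 2; // strided backward-data case
    const int ConvDilationH = 1, ConvDilationW = 1;

    const int YTilda = ConvStrideH / gcd(ConvStrideH, ConvDilationH); // = 2
    const int XTilda = ConvStrideW / gcd(ConvStrideW, ConvDilationW); // = 2
    std::printf("number of GEMMs = YTilda * XTilda = %d\n", YTilda * XTilda);

    // per-GEMM reduction length along the filter: GemmK = K * YDotSlice * XDotSlice
    for(int iy = 0; iy < YTilda; ++iy)
        for(int ix = 0; ix < XTilda; ++ix)
        {
            const int YDotSlice = ceil_div(Y - iy, YTilda);
            const int XDotSlice = ceil_div(X - ix, XTilda);
            std::printf("gemm (%d,%d): YDotSlice=%d XDotSlice=%d\n", iy, ix, YDotSlice, XDotSlice);
        }
    return 0;
}

Note that the YDotSlice * XDotSlice products across the four GEMMs in this example sum to Y * X = 9, so every filter tap is visited exactly once even though stride and dilation interleave the taps.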
composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v5r1_nhwc_kyxc_nhwk.hpp
deleted
100644 → 0
View file @
b8b2d0a6
#ifndef CK_GRIDWISE_CONVOLUTION_BACKWARD_DATA_IMPLICIT_GEMM_V5R1_NHWC_KYXC_NHWK_HPP
#define CK_GRIDWISE_CONVOLUTION_BACKWARD_DATA_IMPLICIT_GEMM_V5R1_NHWC_KYXC_NHWK_HPP
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm.hpp"
namespace
ck
{
// Number of GEMMs = YTilda * XTilda
// GemmM = C
// GemmN = N * HTildaSlice * WTildaSlice
// GemmK0 = YDotSlice
// GemmK1 = XDotSlice
// GemmK2 = K
template
<
index_t
GridSize
,
index_t
BlockSize
,
typename
Float
,
typename
AccFloat
,
typename
InGlobalDesc
,
typename
WeiGlobalDesc
,
typename
OutGlobalDesc
,
typename
ConvStrides
,
typename
ConvDilations
,
typename
InLeftPads
,
typename
InRightPads
,
index_t
GemmMPerBlock
,
index_t
GemmNPerBlock
,
index_t
GemmKPerBlock
,
index_t
GemmMPerThread
,
index_t
GemmNPerThread
,
index_t
GemmKPerThread
,
index_t
GemmMLevel0Cluster
,
index_t
GemmNLevel0Cluster
,
index_t
GemmMLevel1Cluster
,
index_t
GemmNLevel1Cluster
,
index_t
ThreadGemmDataPerRead_GemmM
,
index_t
ThreadGemmDataPerRead_GemmN
,
typename
GemmABlockCopyThreadSliceLengths_GemmK0_GemmK1_GemmK2_GemmM
,
typename
GemmABlockCopyThreadClusterLengths_GemmK0_GemmK1_GemmK2_GemmM
,
index_t
GemmABlockCopySrcDataPerRead_GemmM
,
index_t
GemmABlockCopyDstDataPerWrite_GemmM
,
typename
GemmBBlockCopyThreadSliceLengths_GemmK0_GemmK1_GemmK2_GemmN
,
typename
GemmBBlockCopyThreadClusterLengths_GemmK0_GemmK1_GemmK2_GemmN
,
index_t
GemmBBlockCopySrcDataPerRead_GemmK2
,
index_t
GemmBBlockCopyDstDataPerWrite_GemmN
,
index_t
GemmCThreadCopyDstDataPerWrite_GemmN1
>
struct
GridwiseConvolutionBackwardDataImplicitGemm_v5r1_nhwc_kyxc_nhwk
{
__host__
__device__
static
constexpr
index_t
GetNumberOfGemm
()
{
constexpr
index_t
ConvStrideH
=
ConvStrides
{}[
0
];
constexpr
index_t
ConvStrideW
=
ConvStrides
{}[
1
];
constexpr
index_t
ConvDilationH
=
ConvDilations
{}[
0
];
constexpr
index_t
ConvDilationW
=
ConvDilations
{}[
1
];
constexpr
index_t
GcdStrideDilationH
=
math
::
gcd
(
ConvStrideH
,
ConvDilationH
);
constexpr
index_t
GcdStrideDilationW
=
math
::
gcd
(
ConvStrideW
,
ConvDilationW
);
constexpr
index_t
YTilda
=
ConvStrideH
/
GcdStrideDilationH
;
constexpr
index_t
XTilda
=
ConvStrideW
/
GcdStrideDilationW
;
return
YTilda
*
XTilda
;
}
__host__
__device__
static
constexpr
auto
GetGemmSizeImpl
(
index_t
iYTilda
,
index_t
iXTilda
)
{
constexpr
index_t
N
=
InGlobalDesc
::
GetLengths
()[
0
];
constexpr
index_t
Hi
=
InGlobalDesc
::
GetLengths
()[
1
];
constexpr
index_t
Wi
=
InGlobalDesc
::
GetLengths
()[
2
];
constexpr
index_t
C
=
InGlobalDesc
::
GetLengths
()[
3
];
constexpr
index_t
Ho
=
OutGlobalDesc
::
GetLengths
()[
1
];
constexpr
index_t
Wo
=
OutGlobalDesc
::
GetLengths
()[
2
];
constexpr
index_t
K
=
OutGlobalDesc
::
GetLengths
()[
3
];
constexpr
index_t
Y
=
WeiGlobalDesc
::
GetLengths
()[
1
];
constexpr
index_t
X
=
WeiGlobalDesc
::
GetLengths
()[
2
];
constexpr
index_t
ConvStrideH
=
ConvStrides
{}[
0
];
constexpr
index_t
ConvStrideW
=
ConvStrides
{}[
1
];
constexpr
index_t
ConvDilationH
=
ConvDilations
{}[
0
];
constexpr
index_t
ConvDilationW
=
ConvDilations
{}[
1
];
constexpr
index_t
GcdStrideDilationH
=
math
::
gcd
(
ConvStrideH
,
ConvDilationH
);
constexpr
index_t
GcdStrideDilationW
=
math
::
gcd
(
ConvStrideW
,
ConvDilationW
);
constexpr
index_t
YTilda
=
ConvStrideH
/
GcdStrideDilationH
;
constexpr
index_t
XTilda
=
ConvStrideW
/
GcdStrideDilationW
;
constexpr
index_t
YDot
=
math
::
integer_divide_ceil
(
Y
,
YTilda
);
constexpr
index_t
XDot
=
math
::
integer_divide_ceil
(
X
,
XTilda
);
constexpr
index_t
HTilda
=
Ho
+
math
::
integer_divide_ceil
(
ConvDilationH
*
(
Y
-
1
),
ConvStrideH
);
constexpr
index_t
WTilda
=
Wo
+
math
::
integer_divide_ceil
(
ConvDilationW
*
(
X
-
1
),
ConvStrideW
);
// only work on HTilda and WTilda that contribute to non-padding area of input tensor
constexpr
index_t
iHTildaLeft
=
math
::
integer_divide_floor
(
math
::
max
(
0
,
InLeftPads
{}[
0
]
-
ConvDilationH
*
(
YTilda
-
1
)),
ConvStrides
{}[
0
]);
constexpr
index_t
iWTildaLeft
=
math
::
integer_divide_floor
(
math
::
max
(
0
,
InLeftPads
{}[
1
]
-
ConvDilationW
*
(
XTilda
-
1
)),
ConvStrides
{}[
1
]);
constexpr
index_t
iHTildaRight
=
math
::
min
(
HTilda
,
math
::
integer_divide_ceil
(
InLeftPads
{}[
0
]
+
Hi
-
1
,
ConvStrides
{}[
0
])
+
1
);
constexpr
index_t
iWTildaRight
=
math
::
min
(
WTilda
,
math
::
integer_divide_ceil
(
InLeftPads
{}[
1
]
+
Wi
-
1
,
ConvStrides
{}[
1
])
+
1
);
constexpr
index_t
HTildaSlice
=
iHTildaRight
-
iHTildaLeft
;
constexpr
index_t
WTildaSlice
=
iWTildaRight
-
iWTildaLeft
;
// GemmM and GemmN
constexpr
index_t
GemmM
=
C
;
constexpr
index_t
GemmN
=
N
*
HTildaSlice
*
WTildaSlice
;
// GemmK is different for each GEMM
index_t
YDotSlice
=
math
::
integer_divide_ceil
(
Y
-
iYTilda
,
YTilda
);
index_t
XDotSlice
=
math
::
integer_divide_ceil
(
X
-
iXTilda
,
XTilda
);
index_t
GemmK0
=
YDotSlice
;
index_t
GemmK1
=
XDotSlice
;
index_t
GemmK2
=
K
;
return
make_multi_index
(
GemmM
,
GemmN
,
GemmK0
,
GemmK1
,
GemmK2
);
}
__host__
__device__
static
constexpr
auto
GetGemmSize
(
index_t
gemm_id
)
{
constexpr
index_t
ConvStrideW
=
ConvStrides
{}[
1
];
constexpr
index_t
ConvDilationW
=
ConvDilations
{}[
1
];
constexpr
index_t
GcdStrideDilationW
=
math
::
gcd
(
ConvStrideW
,
ConvDilationW
);
constexpr
index_t
XTilda
=
ConvStrideW
/
GcdStrideDilationW
;
index_t
iYTilda
=
gemm_id
/
XTilda
;
index_t
iXTilda
=
gemm_id
%
XTilda
;
return
GetGemmSizeImpl
(
iYTilda
,
iXTilda
);
}
template
<
index_t
iYTilda
,
index_t
iXTilda
>
__device__
static
void
RunImpl
(
Float
*
__restrict__
p_in_global
,
const
Float
*
__restrict__
p_wei_global
,
const
Float
*
__restrict__
p_out_global
)
{
constexpr
auto
in_n_hi_wi_c_global_desc
=
InGlobalDesc
{};
constexpr
auto
wei_k_y_x_c_global_desc
=
WeiGlobalDesc
{};
constexpr
auto
out_n_ho_wo_k_global_desc
=
OutGlobalDesc
{};
constexpr
index_t
N
=
in_n_hi_wi_c_global_desc
.
GetLengths
()[
0
];
constexpr
index_t
Hi
=
in_n_hi_wi_c_global_desc
.
GetLengths
()[
1
];
constexpr
index_t
Wi
=
in_n_hi_wi_c_global_desc
.
GetLengths
()[
2
];
constexpr
index_t
C
=
in_n_hi_wi_c_global_desc
.
GetLengths
()[
3
];
constexpr
index_t
Ho
=
out_n_ho_wo_k_global_desc
.
GetLengths
()[
1
];
constexpr
index_t
Wo
=
out_n_ho_wo_k_global_desc
.
GetLengths
()[
2
];
constexpr
index_t
K
=
out_n_ho_wo_k_global_desc
.
GetLengths
()[
3
];
constexpr
index_t
Y
=
wei_k_y_x_c_global_desc
.
GetLengths
()[
1
];
constexpr
index_t
X
=
wei_k_y_x_c_global_desc
.
GetLengths
()[
2
];
constexpr
index_t
ConvStrideH
=
ConvStrides
{}[
0
];
constexpr
index_t
ConvStrideW
=
ConvStrides
{}[
1
];
constexpr
index_t
ConvDilationH
=
ConvDilations
{}[
0
];
constexpr
index_t
ConvDilationW
=
ConvDilations
{}[
1
];
constexpr
index_t
GcdStrideDilationH
=
math
::
gcd
(
ConvStrideH
,
ConvDilationH
);
constexpr
index_t
GcdStrideDilationW
=
math
::
gcd
(
ConvStrideW
,
ConvDilationW
);
constexpr
index_t
YTilda
=
ConvStrideH
/
GcdStrideDilationH
;
constexpr
index_t
XTilda
=
ConvStrideW
/
GcdStrideDilationW
;
constexpr
index_t
YDot
=
math
::
integer_divide_ceil
(
Y
,
YTilda
);
constexpr
index_t
XDot
=
math
::
integer_divide_ceil
(
X
,
XTilda
);
constexpr
index_t
YDotSlice
=
math
::
integer_divide_ceil
(
Y
-
iYTilda
,
YTilda
);
constexpr
index_t
XDotSlice
=
math
::
integer_divide_ceil
(
X
-
iXTilda
,
XTilda
);
constexpr
index_t
HTilda
=
Ho
+
math
::
integer_divide_ceil
(
ConvDilationH
*
(
Y
-
1
),
ConvStrideH
);
constexpr
index_t
WTilda
=
Wo
+
math
::
integer_divide_ceil
(
ConvDilationW
*
(
X
-
1
),
ConvStrideW
);
// only work on HTilda and WTilda that contribute to non-padding area of input tensor
constexpr
index_t
iHTildaLeft
=
math
::
integer_divide_floor
(
math
::
max
(
0
,
InLeftPads
{}[
0
]
-
ConvDilationH
*
(
YTilda
-
1
)),
ConvStrides
{}[
0
]);
constexpr
index_t
iWTildaLeft
=
math
::
integer_divide_floor
(
math
::
max
(
0
,
InLeftPads
{}[
1
]
-
ConvDilationW
*
(
XTilda
-
1
)),
ConvStrides
{}[
1
]);
constexpr
index_t
iHTildaRight
=
math
::
min
(
HTilda
,
math
::
integer_divide_ceil
(
InLeftPads
{}[
0
]
+
Hi
-
1
,
ConvStrides
{}[
0
])
+
1
);
constexpr
index_t
iWTildaRight
=
math
::
min
(
WTilda
,
math
::
integer_divide_ceil
(
InLeftPads
{}[
1
]
+
Wi
-
1
,
ConvStrides
{}[
1
])
+
1
);
constexpr
index_t
HTildaSlice
=
iHTildaRight
-
iHTildaLeft
;
constexpr
index_t
WTildaSlice
=
iWTildaRight
-
iWTildaLeft
;
// A matrix: weight
// weight out-of-bound check can be skipped
constexpr
bool
wei_skip_out_of_bound_check
=
true
;
constexpr
auto
wei_k_ydot_ytilda_xdot_xtilda_c_global_desc
=
transform_tensor_descriptor
(
wei_k_y_x_c_global_desc
,
make_tuple
(
PassThrough
<
K
>
{},
Embed
<
Y
,
Sequence
<
YDot
,
YTilda
>
,
Sequence
<
ConvStrideH
/
GcdStrideDilationH
,
1
,
0
>
,
wei_skip_out_of_bound_check
>
{},
Embed
<
X
,
Sequence
<
XDot
,
XTilda
>
,
Sequence
<
ConvStrideW
/
GcdStrideDilationW
,
1
,
0
>
,
wei_skip_out_of_bound_check
>
{},
PassThrough
<
C
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{},
Sequence
<
3
,
4
>
{},
Sequence
<
5
>
{}));
constexpr
auto
wei_k_ydotslice_xdotslice_c_global_desc
=
transform_tensor_descriptor
(
wei_k_ydot_ytilda_xdot_xtilda_c_global_desc
,
make_tuple
(
PassThrough
<
K
>
{},
Slice
<
Sequence
<
YDot
,
XDot
>
,
Sequence
<
0
,
0
>
,
Sequence
<
YDotSlice
,
XDotSlice
>>
{},
Freeze
<
Sequence
<
YTilda
,
XTilda
>
,
Sequence
<
iYTilda
,
iXTilda
>>
{},
PassThrough
<
C
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
3
>
{},
Sequence
<
2
,
4
>
{},
Sequence
<
5
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{},
Sequence
<>
{},
Sequence
<
3
>
{}));
constexpr
auto
wei_gemmk0_gemmk1_gemmk2_gemmm_global_desc
=
reorder_tensor_descriptor_given_lower2upper
(
wei_k_ydotslice_xdotslice_c_global_desc
,
Sequence
<
2
,
0
,
1
,
3
>
{});
// B matrix: output tensor
// TODO: the output tensor out-of-bound check can sometimes be skipped; find out all such
// situations
#if !CK_EXPERIMENTAL_IMPLICIT_GEMM_BACKWARD_DATA_V4R1_OUTPUT_SKIP_OUT_OF_BOUND_CHECK
constexpr
bool
out_skip_out_of_bound_check
=
false
;
#else
constexpr
bool
out_skip_out_of_bound_check
=
true
;
#endif
constexpr
auto
out_n_ydot_htilda_xdot_wtilda_k_global_desc
=
transform_tensor_descriptor
(
out_n_ho_wo_k_global_desc
,
make_tuple
(
PassThrough
<
N
>
{},
Embed
<
Ho
,
Sequence
<
YDot
,
HTilda
>
,
Sequence
<-
ConvDilationH
/
GcdStrideDilationH
,
1
,
0
>
,
out_skip_out_of_bound_check
>
{},
Embed
<
Wo
,
Sequence
<
XDot
,
WTilda
>
,
Sequence
<-
ConvDilationW
/
GcdStrideDilationW
,
1
,
0
>
,
out_skip_out_of_bound_check
>
{},
PassThrough
<
K
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{},
Sequence
<
3
,
4
>
{},
Sequence
<
5
>
{}));
constexpr
auto
out_n_ydotslice_htildaslice_xdotslice_wtildaslice_k_global_desc
=
transform_tensor_descriptor
(
out_n_ydot_htilda_xdot_wtilda_k_global_desc
,
make_tuple
(
PassThrough
<
N
>
{},
Slice
<
Sequence
<
YDot
,
XDot
>
,
Sequence
<
0
,
0
>
,
Sequence
<
YDotSlice
,
XDotSlice
>>
{},
Slice
<
Sequence
<
HTilda
,
WTilda
>
,
Sequence
<
iHTildaLeft
,
iWTildaLeft
>
,
Sequence
<
iHTildaRight
,
iWTildaRight
>>
{},
PassThrough
<
K
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
3
>
{},
Sequence
<
2
,
4
>
{},
Sequence
<
5
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
3
>
{},
Sequence
<
2
,
4
>
{},
Sequence
<
5
>
{}));
constexpr
auto
out_gemmk0_gemmk1_gemmk2_gemmn_global_desc
=
transform_tensor_descriptor
(
out_n_ydotslice_htildaslice_xdotslice_wtildaslice_k_global_desc
,
make_tuple
(
PassThrough
<
YDotSlice
>
{},
PassThrough
<
XDotSlice
>
{},
PassThrough
<
K
>
{},
Merge
<
Sequence
<
N
,
HTildaSlice
,
WTildaSlice
>>
{}),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
3
>
{},
Sequence
<
5
>
{},
Sequence
<
0
,
2
,
4
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}));
// C matrix: input tensor
// TODO: the input out-of-bound check can sometimes be skipped; find out all such situations
#if !CK_EXPERIMENTAL_IMPLICIT_GEMM_BACKWARD_DATA_V4R1_INPUT_SKIP_OUT_OF_BOUND_CHECK
constexpr
bool
in_skip_out_of_bound_check
=
false
;
#else
constexpr
bool
in_skip_out_of_bound_check
=
true
;
#endif
constexpr
auto
in_n_hip_wip_c_global_desc
=
transform_tensor_descriptor
(
in_n_hi_wi_c_global_desc
,
make_tuple
(
PassThrough
<
N
>
{},
Pad
<
Sequence
<
Hi
,
Wi
>
,
InLeftPads
,
InRightPads
,
in_skip_out_of_bound_check
>
{},
PassThrough
<
C
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{},
Sequence
<
3
>
{}));
constexpr
index_t
Hip
=
in_n_hip_wip_c_global_desc
.
GetLengths
()[
1
];
constexpr
index_t
Wip
=
in_n_hip_wip_c_global_desc
.
GetLengths
()[
2
];
constexpr
auto
in_n_ytilda_htilda_xtilda_wtilda_c_global_desc
=
transform_tensor_descriptor
(
in_n_hip_wip_c_global_desc
,
make_tuple
(
PassThrough
<
N
>
{},
Embed
<
Hip
,
Sequence
<
YTilda
,
HTilda
>
,
Sequence
<
ConvDilationH
,
ConvStrideH
,
0
>
,
in_skip_out_of_bound_check
>
{},
Embed
<
Wip
,
Sequence
<
XTilda
,
WTilda
>
,
Sequence
<
ConvDilationW
,
ConvStrideW
,
0
>
,
in_skip_out_of_bound_check
>
{},
PassThrough
<
C
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{},
Sequence
<
3
,
4
>
{},
Sequence
<
5
>
{}));
constexpr
auto
in_n_htildaslice_wtildaslice_c_global_desc
=
transform_tensor_descriptor
(
in_n_ytilda_htilda_xtilda_wtilda_c_global_desc
,
make_tuple
(
PassThrough
<
N
>
{},
Freeze
<
Sequence
<
YTilda
,
XTilda
>
,
Sequence
<
iYTilda
,
iXTilda
>>
{},
Slice
<
Sequence
<
HTilda
,
WTilda
>
,
Sequence
<
iHTildaLeft
,
iWTildaLeft
>
,
Sequence
<
iHTildaRight
,
iWTildaRight
>>
{},
PassThrough
<
C
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
3
>
{},
Sequence
<
2
,
4
>
{},
Sequence
<
5
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<>
{},
Sequence
<
1
,
2
>
{},
Sequence
<
3
>
{}));
constexpr
auto
in_gemmm_gemmn_global_desc
=
transform_tensor_descriptor
(
in_n_htildaslice_wtildaslice_c_global_desc
,
make_tuple
(
PassThrough
<
C
>
{},
Merge
<
Sequence
<
N
,
HTildaSlice
,
WTildaSlice
>>
{}),
make_tuple
(
Sequence
<
3
>
{},
Sequence
<
0
,
1
,
2
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
// call GEMM
constexpr
auto
gridwise_gemm
=
GridwiseGemmTransposedANormalBNormalC_v2
<
GridSize
,
BlockSize
,
Float
,
AccFloat
,
decltype
(
wei_gemmk0_gemmk1_gemmk2_gemmm_global_desc
),
decltype
(
out_gemmk0_gemmk1_gemmk2_gemmn_global_desc
),
decltype
(
in_gemmm_gemmn_global_desc
),
InMemoryDataOperation
::
Set
,
GemmMPerBlock
,
GemmNPerBlock
,
GemmKPerBlock
,
GemmMPerThread
,
GemmNPerThread
,
GemmKPerThread
,
GemmMLevel0Cluster
,
GemmNLevel0Cluster
,
GemmMLevel1Cluster
,
GemmNLevel1Cluster
,
ThreadGemmDataPerRead_GemmM
,
ThreadGemmDataPerRead_GemmN
,
GemmABlockCopyThreadSliceLengths_GemmK0_GemmK1_GemmK2_GemmM
,
GemmABlockCopyThreadClusterLengths_GemmK0_GemmK1_GemmK2_GemmM
,
Sequence
<
0
,
1
,
2
,
3
>
,
Sequence
<
0
,
1
,
2
,
3
>
,
3
,
GemmABlockCopySrcDataPerRead_GemmM
,
GemmABlockCopyDstDataPerWrite_GemmM
,
GemmBBlockCopyThreadSliceLengths_GemmK0_GemmK1_GemmK2_GemmN
,
GemmBBlockCopyThreadClusterLengths_GemmK0_GemmK1_GemmK2_GemmN
,
Sequence
<
0
,
1
,
3
,
2
>
,
Sequence
<
0
,
1
,
3
,
2
>
,
2
,
GemmBBlockCopySrcDataPerRead_GemmK2
,
GemmBBlockCopyDstDataPerWrite_GemmN
,
Sequence
<
2
,
3
,
0
,
1
>
,
3
,
GemmCThreadCopyDstDataPerWrite_GemmN1
>
{};
gridwise_gemm
.
Run
(
p_wei_global
,
p_out_global
,
p_in_global
);
}
template
<
index_t
GemmId
>
__device__
static
void
Run
(
Float
*
__restrict__
p_in_global
,
const
Float
*
__restrict__
p_wei_global
,
const
Float
*
__restrict__
p_out_global
,
Number
<
GemmId
>
)
{
constexpr
index_t
ConvStrideH
=
ConvStrides
{}[
0
];
constexpr
index_t
ConvStrideW
=
ConvStrides
{}[
1
];
constexpr
index_t
ConvDilationH
=
ConvDilations
{}[
0
];
constexpr
index_t
ConvDilationW
=
ConvDilations
{}[
1
];
constexpr
index_t
GcdStrideDilationH
=
math
::
gcd
(
ConvStrideH
,
ConvDilationH
);
constexpr
index_t
GcdStrideDilationW
=
math
::
gcd
(
ConvStrideW
,
ConvDilationW
);
constexpr
index_t
YTilda
=
ConvStrideH
/
GcdStrideDilationH
;
constexpr
index_t
XTilda
=
ConvStrideW
/
GcdStrideDilationW
;
constexpr
index_t
iYTilda
=
GemmId
/
XTilda
;
constexpr
index_t
iXTilda
=
GemmId
%
XTilda
;
static_assert
(
iYTilda
<
YTilda
&&
iXTilda
<
XTilda
,
"wrong! iYtilda, iXtilda"
);
RunImpl
<
iYTilda
,
iXTilda
>
(
p_in_global
,
p_wei_global
,
p_out_global
);
}
};
}
// namespace ck
#endif
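The index arithmetic in GetGemmSizeImpl/RunImpl above is dense. The standalone host sketch below is illustrative only (it is not part of the deleted header; names such as conv_stride_h and the concrete sizes are assumptions). It reproduces the YTilda/HTilda tiling for one spatial dimension and the per-GEMM K extent, mirroring the formulas used above: GemmM = C, GemmN = N * HTildaSlice * WTildaSlice, and each gemm_id = iYTilda * XTilda + iXTilda gets GemmK0 = YDotSlice, GemmK1 = XDotSlice, GemmK2 = K.

// Host-side sketch (not part of the deleted kernel) of the backward-data v4r1
// tiling arithmetic for the H/Y dimension; all concrete numbers are examples.
#include <algorithm>
#include <cstdio>
#include <numeric>

static int integer_divide_ceil(int a, int b) { return (a + b - 1) / b; }

int main()
{
    const int Hi = 28, Ho = 14, Y = 3; // input/output/filter extents
    const int conv_stride_h = 2, conv_dilation_h = 1, in_left_pad_h = 1;

    const int gcd_sd = std::gcd(conv_stride_h, conv_dilation_h);
    const int YTilda = conv_stride_h / gcd_sd;
    const int YDot   = integer_divide_ceil(Y, YTilda);
    const int HTilda = Ho + integer_divide_ceil(conv_dilation_h * (Y - 1), conv_stride_h);

    // only the part of HTilda that touches the non-padding area of the input
    const int iHTildaLeft =
        std::max(0, in_left_pad_h - conv_dilation_h * (YTilda - 1)) / conv_stride_h;
    const int iHTildaRight =
        std::min(HTilda, integer_divide_ceil(in_left_pad_h + Hi - 1, conv_stride_h) + 1);

    printf("YTilda %d YDot %d HTilda %d HTildaSlice [%d, %d)\n",
           YTilda, YDot, HTilda, iHTildaLeft, iHTildaRight);

    // each gemm_id = iYTilda * XTilda + iXTilda gets a different K extent along Y
    for(int iYTilda = 0; iYTilda < YTilda; ++iYTilda)
        printf("iYTilda %d -> YDotSlice %d\n", iYTilda, integer_divide_ceil(Y - iYTilda, YTilda));

    return 0;
}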
composable_kernel/include/kernel_algorithm/gridwise_convolution_forward_implicit_gemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer.hpp
deleted
100644 → 0
View file @
b8b2d0a6
#ifndef CK_GRIDWISE_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V4R1_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER_HPP
#define CK_GRIDWISE_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V4R1_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER_HPP
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "blockwise_generic_tensor_slice_copy.hpp"
#include "threadwise_generic_tensor_slice_copy.hpp"
#include "blockwise_gemm.hpp"
namespace
ck
{
template
<
index_t
GridSize
,
index_t
BlockSize
,
typename
Float
,
typename
AccDataType
,
typename
InGlobalDesc
,
typename
WeiGlobalDesc
,
typename
OutGlobalDesc
,
typename
ConvStrides
,
typename
ConvDilations
,
typename
LeftPads
,
typename
RightPads
,
index_t
BPerBlock
,
index_t
KPerBlock
,
index_t
EPerBlock
,
index_t
GemmNRepeat
,
index_t
GemmMPerThread
,
index_t
GemmNPerThread
,
index_t
GemmKPerThread
,
index_t
GemmMLevel0Cluster
,
index_t
GemmNLevel0Cluster
,
index_t
GemmMLevel1Cluster
,
index_t
GemmNLevel1Cluster
,
index_t
GemmDataPerReadA
,
index_t
GemmDataPerReadB
,
typename
InBlockCopySubLengths_E_N1_B_N2
,
typename
InBlockCopyClusterLengths_E_N1_B_N2
,
typename
InBlockCopyThreadClusterArrangeOrder
,
typename
InBlockCopySrcAccessOrder
,
typename
InBlockCopyDstAccessOrder
,
index_t
InBlockCopySrcDataPerRead_B
,
index_t
InBlockCopyDstDataPerWrite_N2
,
typename
WeiBlockCopySubLengths_E_K
,
typename
WeiBlockCopyClusterLengths_E_K
,
typename
WeiBlockCopyThreadClusterArrangeOrder
,
typename
WeiBlockCopySrcAccessOrder
,
typename
WeiBlockCopyDstAccessOrder
,
index_t
WeiBlockCopySrcDataPerRead_E
,
index_t
WeiBlockCopyDstDataPerWrite_K
>
struct
GridwiseConvolutionForwardImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
{
__device__
void
Run
(
const
Float
*
const
__restrict__
p_in_global
,
const
Float
*
const
__restrict__
p_wei_global
,
Float
*
const
__restrict__
p_out_global
)
const
{
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I2
=
Number
<
2
>
{};
constexpr
auto
I3
=
Number
<
3
>
{};
constexpr
auto
True
=
integral_constant
<
bool
,
true
>
{};
// this is a mess
// TODO: find a more elegant way of specifying (or calculating) performance parameters
constexpr
index_t
N1
=
GemmNRepeat
;
constexpr
index_t
N2
=
GemmNPerThread
;
static_assert
(
(
N1
*
N2
*
BPerBlock
)
%
(
GemmNPerThread
*
GemmNLevel0Cluster
*
GemmNLevel1Cluster
)
==
0
,
"wrong!"
);
constexpr
auto
in_n_c_hi_wi_global_desc
=
InGlobalDesc
{};
constexpr
auto
wei_k_c_y_x_global_desc
=
WeiGlobalDesc
{};
constexpr
auto
out_n_k_ho_wo_global_desc
=
OutGlobalDesc
{};
constexpr
index_t
N
=
in_n_c_hi_wi_global_desc
.
GetLengths
()[
0
];
constexpr
index_t
C
=
in_n_c_hi_wi_global_desc
.
GetLengths
()[
1
];
constexpr
index_t
Hi
=
in_n_c_hi_wi_global_desc
.
GetLengths
()[
2
];
constexpr
index_t
Wi
=
in_n_c_hi_wi_global_desc
.
GetLengths
()[
3
];
constexpr
index_t
K
=
out_n_k_ho_wo_global_desc
.
GetLengths
()[
1
];
constexpr
index_t
Ho
=
out_n_k_ho_wo_global_desc
.
GetLengths
()[
2
];
constexpr
index_t
Wo
=
out_n_k_ho_wo_global_desc
.
GetLengths
()[
3
];
constexpr
index_t
Y
=
wei_k_c_y_x_global_desc
.
GetLengths
()[
2
];
constexpr
index_t
X
=
wei_k_c_y_x_global_desc
.
GetLengths
()[
3
];
constexpr
index_t
ConvStrideH
=
ConvStrides
{}[
0
];
constexpr
index_t
ConvStrideW
=
ConvStrides
{}[
1
];
constexpr
index_t
ConvDilationH
=
ConvDilations
{}[
0
];
constexpr
index_t
ConvDilationW
=
ConvDilations
{}[
1
];
static_assert
(
N
%
(
N1
*
N2
)
==
0
,
"wrong! cannot divice N evenly among thread"
);
constexpr
index_t
N0
=
N
/
(
N1
*
N2
);
constexpr
index_t
B
=
N0
*
Ho
*
Wo
;
constexpr
index_t
E
=
C
*
Y
*
X
;
// sanity-check for vectorized memory load
static_assert
((
Wo
==
1
||
(
ConvStrideW
==
1
||
InBlockCopySrcDataPerRead_B
==
1
))
&&
(
X
==
1
||
ConvDilationW
%
InBlockCopySrcDataPerRead_B
==
0
),
"wrong! aligment requirement for vectorized global load of input tensor will "
"be violated"
);
// divide block work by [K, B]
static_assert
(
K
%
KPerBlock
==
0
&&
B
%
BPerBlock
==
0
&&
E
%
EPerBlock
==
0
,
"wrong! cannot divide work evenly among block"
);
constexpr
index_t
KBlockWork
=
K
/
KPerBlock
;
constexpr
index_t
BBlockWork
=
B
/
BPerBlock
;
constexpr
auto
block_work_desc
=
make_cluster_descriptor
(
Sequence
<
KBlockWork
,
BBlockWork
>
{});
const
auto
block_work_id
=
block_work_desc
.
CalculateClusterIndex
(
get_block_1d_id
());
const
index_t
k_block_data_on_global
=
block_work_id
[
I0
]
*
KPerBlock
;
const
index_t
b_block_data_on_global
=
block_work_id
[
I1
]
*
BPerBlock
;
// input tensor
// global tensor in global memory
constexpr
auto
in_n_c_hip_wip_global_desc
=
transform_tensor_descriptor
(
in_n_c_hi_wi_global_desc
,
make_tuple
(
PassThrough
<
N
>
{},
PassThrough
<
C
>
{},
Pad
<
Sequence
<
Hi
,
Wi
>
,
LeftPads
,
RightPads
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
,
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
,
3
>
{}));
constexpr
index_t
Hip
=
in_n_c_hip_wip_global_desc
.
GetLengths
()[
2
];
constexpr
index_t
Wip
=
in_n_c_hip_wip_global_desc
.
GetLengths
()[
3
];
constexpr
auto
in_n0_n1_n2_c_y_ho_x_wo_global_desc
=
transform_tensor_descriptor
(
in_n_c_hip_wip_global_desc
,
make_tuple
(
UnMerge
<
Sequence
<
N0
,
N1
,
N2
>>
{},
PassThrough
<
C
>
{},
Embed
<
Hip
,
Sequence
<
Y
,
Ho
>
,
Sequence
<
ConvDilationH
,
ConvStrideH
,
0
>>
{},
Embed
<
Wip
,
Sequence
<
X
,
Wo
>
,
Sequence
<
ConvDilationW
,
ConvStrideW
,
0
>>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
,
1
,
2
>
{},
Sequence
<
3
>
{},
Sequence
<
4
,
5
>
{},
Sequence
<
6
,
7
>
{}));
// global tensor in global memory, src of blockwise copy
constexpr
auto
in_e_n1_b_n2_global_desc
=
transform_tensor_descriptor
(
in_n0_n1_n2_c_y_ho_x_wo_global_desc
,
make_tuple
(
Merge
<
Sequence
<
C
,
Y
,
X
>>
{},
PassThrough
<
N1
>
{},
Merge
<
Sequence
<
N0
,
Ho
,
Wo
>>
{},
PassThrough
<
N2
>
{}),
make_tuple
(
Sequence
<
3
,
4
,
6
>
{},
Sequence
<
1
>
{},
Sequence
<
0
,
5
,
7
>
{},
Sequence
<
2
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}));
// block tensor in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr
auto
in_e_n1_b_n2_block_desc
=
make_native_tensor_descriptor_aligned
(
Sequence
<
EPerBlock
,
N1
,
BPerBlock
,
N2
>
{},
Number
<
InBlockCopyDstDataPerWrite_N2
>
{});
// this check is ad-hoc
// TODO: need to properly implement tensor descriptor with multiple alignment
// requirements
static_assert
(
in_e_n1_b_n2_block_desc
.
GetStride
(
I1
)
%
GemmDataPerReadB
==
0
,
"GemmDataPerReadB alignment requirement is not satisfied"
);
// input tensor blockwise copy
auto
blockwise_in_copy
=
BlockwiseGenericTensorSliceCopy_v4
<
BlockSize
,
decltype
(
in_e_n1_b_n2_global_desc
),
decltype
(
in_e_n1_b_n2_block_desc
),
decltype
(
in_e_n1_b_n2_block_desc
.
GetLengths
()),
InBlockCopySubLengths_E_N1_B_N2
,
InBlockCopyClusterLengths_E_N1_B_N2
,
InBlockCopyThreadClusterArrangeOrder
,
InBlockCopySrcAccessOrder
,
InBlockCopyDstAccessOrder
,
2
,
3
,
InBlockCopySrcDataPerRead_B
,
InBlockCopyDstDataPerWrite_N2
,
AddressSpace
::
Global
,
AddressSpace
::
Vgpr
,
AddressSpace
::
Lds
,
InMemoryDataOperation
::
Set
>
(
make_multi_index
(
0
,
0
,
b_block_data_on_global
,
0
),
make_multi_index
(
0
,
0
,
0
,
0
));
// weight tensor
// global tensor in global memory, src of blockwise copy
// It is constructed differently depending on whether this is forward or backward-weight
// convolution
constexpr
auto
wei_e_k_global_desc
=
transform_tensor_descriptor
(
unfold_tensor_descriptor
(
wei_k_c_y_x_global_desc
,
I2
,
I3
),
make_tuple
(
Merge
<
Sequence
<
C
,
Y
*
X
>>
{},
PassThrough
<
K
>
{}),
make_tuple
(
Sequence
<
1
,
2
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
// block tensor in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr
auto
wei_e_k_block_desc
=
make_native_tensor_descriptor_aligned
(
Sequence
<
EPerBlock
,
KPerBlock
>
{},
Number
<
math
::
lcm
(
WeiBlockCopyDstDataPerWrite_K
,
GemmDataPerReadA
)
>
{});
// this check is ad-hoc
// TODO: need to properly implement tensor descriptor with multiple alignment
// requirements
static_assert
(
wei_e_k_block_desc
.
GetStride
(
I0
)
%
GemmDataPerReadA
==
0
,
"GemmDataPerReadA alignment requirement is not satisfied"
);
// weight tensor blockwise copy
auto
blockwise_wei_copy
=
BlockwiseGenericTensorSliceCopy_v4
<
BlockSize
,
decltype
(
wei_e_k_global_desc
),
decltype
(
wei_e_k_block_desc
),
decltype
(
wei_e_k_block_desc
.
GetLengths
()),
WeiBlockCopySubLengths_E_K
,
WeiBlockCopyClusterLengths_E_K
,
WeiBlockCopyThreadClusterArrangeOrder
,
WeiBlockCopySrcAccessOrder
,
WeiBlockCopyDstAccessOrder
,
0
,
1
,
WeiBlockCopySrcDataPerRead_E
,
WeiBlockCopyDstDataPerWrite_K
,
AddressSpace
::
Global
,
AddressSpace
::
Vgpr
,
AddressSpace
::
Lds
,
InMemoryDataOperation
::
Set
>
(
make_multi_index
(
0
,
k_block_data_on_global
),
make_multi_index
(
0
,
0
));
// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
// a_mtx[EPerBlock, KPerBlock] is in LDS
// b_mtx[EPerBlock, N1 * BPerBlock * N2] is in LDS
// c_mtx[KPerBlock, N1 * BPerBlock * N2] is distributed among threads, and saved in
// register
constexpr
auto
a_e_k_block_mtx_desc
=
make_ConstantMatrixDescriptor
(
wei_e_k_block_desc
);
constexpr
auto
b_e_n1bn2_block_mtx_desc
=
make_ConstantMatrixDescriptor
(
in_e_n1_b_n2_block_desc
.
GetLength
(
I0
),
in_e_n1_b_n2_block_desc
.
GetLength
(
I1
)
*
in_e_n1_b_n2_block_desc
.
GetLength
(
I2
)
*
in_e_n1_b_n2_block_desc
.
GetLength
(
I3
),
in_e_n1_b_n2_block_desc
.
GetStride
(
I0
));
// sanity check
static_assert
(
KPerBlock
%
(
GemmMPerThread
*
GemmMLevel0Cluster
*
GemmMLevel1Cluster
)
==
0
,
"wrong!"
);
constexpr
index_t
GemmMRepeat
=
KPerBlock
/
(
GemmMPerThread
*
GemmMLevel0Cluster
*
GemmMLevel1Cluster
);
// c_thread_mtx definition: this is a mess
// TODO: a more elegant way of defining c_thread_mtx
constexpr
auto
c_k0k1_n1n2_thread_mtx_desc
=
make_ConstantMatrixDescriptor_packed
(
Number
<
GemmMRepeat
*
GemmMPerThread
>
{},
Number
<
GemmNRepeat
*
GemmNPerThread
>
{});
const
auto
blockwise_gemm
=
BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
<
BlockSize
,
decltype
(
a_e_k_block_mtx_desc
),
decltype
(
b_e_n1bn2_block_mtx_desc
),
decltype
(
c_k0k1_n1n2_thread_mtx_desc
),
GemmMPerThread
,
GemmNPerThread
,
GemmKPerThread
,
GemmMLevel0Cluster
,
GemmNLevel0Cluster
,
GemmMLevel1Cluster
,
GemmNLevel1Cluster
,
GemmDataPerReadA
,
GemmDataPerReadB
>
{};
// LDS allocation for input and weight: be careful of alignment
constexpr
index_t
max_align
=
math
::
lcm
(
InBlockCopyDstDataPerWrite_N2
,
WeiBlockCopyDstDataPerWrite_K
,
GemmDataPerReadA
,
GemmDataPerReadB
);
constexpr
index_t
in_block_space
=
math
::
integer_least_multiple
(
in_e_n1_b_n2_block_desc
.
GetElementSpace
(),
max_align
);
constexpr
index_t
wei_block_space
=
math
::
integer_least_multiple
(
wei_e_k_block_desc
.
GetElementSpace
(),
max_align
);
__shared__
Float
p_in_block_double
[
2
*
in_block_space
];
__shared__
Float
p_wei_block_double
[
2
*
wei_block_space
];
// register allocation for output
AccDataType
p_out_thread
[
c_k0k1_n1n2_thread_mtx_desc
.
GetElementSpace
()];
// zero out threadwise output
threadwise_matrix_set_zero
(
c_k0k1_n1n2_thread_mtx_desc
,
p_out_thread
);
// LDS double buffer: preload data into LDS
{
blockwise_in_copy
.
Run
(
p_in_global
,
p_in_block_double
);
blockwise_wei_copy
.
Run
(
p_wei_global
,
p_wei_block_double
);
}
constexpr
auto
in_block_slice_copy_steps
=
Sequence
<
EPerBlock
,
0
,
0
,
0
>
{};
constexpr
auto
wei_block_slice_copy_steps
=
Sequence
<
EPerBlock
,
0
>
{};
// LDS double buffer: main body
for
(
index_t
e_block_data_begin
=
0
;
e_block_data_begin
+
2
*
EPerBlock
<
E
;
e_block_data_begin
+=
2
*
EPerBlock
)
{
#pragma unroll
for
(
index_t
iloop
=
0
;
iloop
<
2
;
++
iloop
)
{
const
bool
even_loop
=
(
iloop
%
2
==
0
);
Float
*
p_in_block_now
=
even_loop
?
p_in_block_double
:
p_in_block_double
+
in_block_space
;
Float
*
p_wei_block_now
=
even_loop
?
p_wei_block_double
:
p_wei_block_double
+
wei_block_space
;
Float
*
p_in_block_next
=
even_loop
?
p_in_block_double
+
in_block_space
:
p_in_block_double
;
Float
*
p_wei_block_next
=
even_loop
?
p_wei_block_double
+
wei_block_space
:
p_wei_block_double
;
Float
p_in_thread_buffer
[
blockwise_in_copy
.
GetThreadBufferSize
()];
Float
p_wei_thread_buffer
[
blockwise_wei_copy
.
GetThreadBufferSize
()];
blockwise_in_copy
.
MoveSrcSliceWindow
(
in_block_slice_copy_steps
,
True
);
blockwise_wei_copy
.
MoveSrcSliceWindow
(
wei_block_slice_copy_steps
,
True
);
__syncthreads
();
// LDS double buffer: load next data from device mem
blockwise_in_copy
.
RunLoadThreadBuffer
(
p_in_global
,
p_in_thread_buffer
);
blockwise_wei_copy
.
RunLoadThreadBuffer
(
p_wei_global
,
p_wei_thread_buffer
);
// LDS double buffer: GEMM on current data
blockwise_gemm
.
Run
(
p_wei_block_now
,
p_in_block_now
,
p_out_thread
);
// LDS double buffer: store next data to LDS
blockwise_in_copy
.
RunStoreThreadBuffer
(
p_in_thread_buffer
,
p_in_block_next
);
blockwise_wei_copy
.
RunStoreThreadBuffer
(
p_wei_thread_buffer
,
p_wei_block_next
);
}
}
// LDS double buffer: tail
{
constexpr
bool
has_two_iteration_left
=
(
E
%
(
2
*
EPerBlock
)
==
0
);
if
(
has_two_iteration_left
)
// two iterations are left
{
Float
p_in_thread_buffer
[
blockwise_in_copy
.
GetThreadBufferSize
()];
Float
p_wei_thread_buffer
[
blockwise_wei_copy
.
GetThreadBufferSize
()];
blockwise_in_copy
.
MoveSrcSliceWindow
(
in_block_slice_copy_steps
,
True
);
blockwise_wei_copy
.
MoveSrcSliceWindow
(
wei_block_slice_copy_steps
,
True
);
__syncthreads
();
// LDS double buffer: load last data from device mem
blockwise_in_copy
.
RunLoadThreadBuffer
(
p_in_global
,
p_in_thread_buffer
);
blockwise_wei_copy
.
RunLoadThreadBuffer
(
p_wei_global
,
p_wei_thread_buffer
);
// LDS double buffer: GEMM on 2nd-last data
blockwise_gemm
.
Run
(
p_wei_block_double
,
p_in_block_double
,
p_out_thread
);
// LDS double buffer: store last data to LDS
blockwise_in_copy
.
RunStoreThreadBuffer
(
p_in_thread_buffer
,
p_in_block_double
+
in_block_space
);
blockwise_wei_copy
.
RunStoreThreadBuffer
(
p_wei_thread_buffer
,
p_wei_block_double
+
wei_block_space
);
__syncthreads
();
// LDS double buffer: GEMM on last data
blockwise_gemm
.
Run
(
p_wei_block_double
+
wei_block_space
,
p_in_block_double
+
in_block_space
,
p_out_thread
);
}
else
// one iteration is left
{
__syncthreads
();
// LDS double buffer: GEMM on last data
blockwise_gemm
.
Run
(
p_wei_block_double
,
p_in_block_double
,
p_out_thread
);
}
}
// copy output: register to global memory
{
constexpr
index_t
K1
=
GemmMPerThread
*
GemmMLevel0Cluster
*
GemmMLevel1Cluster
;
constexpr
index_t
K0
=
K
/
K1
;
// define output tensor descriptor for threadwise copy
// thread output tensor, src of threadwise copy
constexpr
auto
out_k0_k1_n1_b_n2_thread_desc
=
make_native_tensor_descriptor_packed
(
Sequence
<
GemmMRepeat
,
GemmMPerThread
,
N1
,
1
,
N2
>
{});
// global output tensor
constexpr
auto
out_n0_n1_n2_k0_k1_ho_wo_global_desc
=
transform_tensor_descriptor
(
out_n_k_ho_wo_global_desc
,
make_tuple
(
UnMerge
<
Sequence
<
N0
,
N1
,
N2
>>
{},
UnMerge
<
Sequence
<
K0
,
K1
>>
{},
PassThrough
<
Ho
>
{},
PassThrough
<
Wo
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
,
1
,
2
>
{},
Sequence
<
3
,
4
>
{},
Sequence
<
5
>
{},
Sequence
<
6
>
{}));
// global output tensor, dst of threadwise copy
constexpr
auto
out_k0_k1_n1_b_n2_global_desc
=
transform_tensor_descriptor
(
out_n0_n1_n2_k0_k1_ho_wo_global_desc
,
make_tuple
(
PassThrough
<
K0
>
{},
PassThrough
<
K1
>
{},
PassThrough
<
N1
>
{},
Merge
<
Sequence
<
N0
,
Ho
,
Wo
>>
{},
PassThrough
<
N2
>
{}),
make_tuple
(
Sequence
<
3
>
{},
Sequence
<
4
>
{},
Sequence
<
1
>
{},
Sequence
<
0
,
5
,
6
>
{},
Sequence
<
2
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{},
Sequence
<
4
>
{}));
// calculate origin of thread output tensor on global memory
// blockwise GEMM c matrix starting index
const
auto
c_thread_mtx_on_block
=
blockwise_gemm
.
GetBeginOfThreadMatrixC
(
get_thread_local_1d_id
());
const
index_t
k_thread_data_on_global
=
k_block_data_on_global
+
c_thread_mtx_on_block
.
row
;
const
index_t
b_thread_data_on_global
=
b_block_data_on_global
+
c_thread_mtx_on_block
.
col
/
N2
;
ThreadwiseGenericTensorSliceCopy_v4r2
<
decltype
(
out_k0_k1_n1_b_n2_thread_desc
),
decltype
(
out_k0_k1_n1_b_n2_global_desc
),
decltype
(
out_k0_k1_n1_b_n2_thread_desc
.
GetLengths
()),
arithmetic_sequence_gen
<
0
,
5
,
1
>::
type
,
3
,
1
,
1
,
AddressSpace
::
Vgpr
,
AddressSpace
::
Global
,
InMemoryDataOperation
::
Set
>
(
make_multi_index
(
0
,
0
,
0
,
0
,
0
),
make_multi_index
(
k_thread_data_on_global
/
K1
,
k_thread_data_on_global
%
K1
,
0
,
b_thread_data_on_global
,
0
))
.
Run
(
p_out_thread
,
p_out_global
);
}
}
};
}
// namespace ck
#endif
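The LDS double-buffering control flow above (preload, a main body that handles two tiles per iteration, then a two- or one-tile tail) is easier to see with the tensors stripped away. The sketch below is a plain host C++ toy of the same schedule, with each LDS buffer reduced to a single integer tile id; it is an illustration under assumed sizes, not the kernel itself.

// Toy host-side reproduction of the ping-pong schedule used above: while the
// GEMM consumes the "now" buffer, the next tile is loaded and stored into the
// "next" buffer; the tail handles the last one or two tiles.
#include <cstdio>

int main()
{
    const int E = 8, EPerBlock = 2; // example sizes
    int buf[2]  = {-1, -1};         // stands in for the two LDS buffers

    buf[0] = 0;                     // preload first tile
    int e = 0;

    // main body: two tiles per outer iteration, alternating buffers
    for(; e + 2 * EPerBlock < E; e += 2 * EPerBlock)
    {
        for(int iloop = 0; iloop < 2; ++iloop)
        {
            const bool even = (iloop % 2 == 0);
            const int now   = even ? 0 : 1;
            const int next  = even ? 1 : 0;

            buf[next] = e + (iloop + 1) * EPerBlock; // "load next tile from global"
            printf("gemm on tile %d while loading tile %d\n", buf[now], buf[next]);
        }
    }

    // tail: mirrors the kernel's has_two_iteration_left branch
    if(E % (2 * EPerBlock) == 0)
    {
        buf[1] = buf[0] + EPerBlock; // "load last tile" while doing GEMM on the 2nd-last one
        printf("gemm on tile %d, then gemm on tile %d\n", buf[0], buf[1]);
    }
    else
    {
        printf("gemm on last tile %d\n", buf[0]);
    }
    return 0;
}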
composable_kernel/include/kernel_algorithm/gridwise_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp
deleted
100644 → 0
View file @
b8b2d0a6
#ifndef CK_GRIDWISE_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V4R4_NCHW_KCYX_NKHW_HPP
#define CK_GRIDWISE_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V4R4_NCHW_KCYX_NKHW_HPP
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm.hpp"
namespace
ck
{
// GemmM = K
// GemmN = N * Ho * Wo
// GemmK = C * Y * X
template
<
index_t
GridSize
,
index_t
BlockSize
,
typename
Float
,
typename
AccFloat
,
typename
InGlobalDesc
,
typename
WeiGlobalDesc
,
typename
OutGlobalDesc
,
typename
ConvStrides
,
typename
ConvDilations
,
typename
InLeftPads
,
typename
InRightPads
,
index_t
GemmMPerBlock
,
index_t
GemmNPerBlock
,
index_t
GemmKPerBlock
,
index_t
GemmMPerThread
,
index_t
GemmNPerThread
,
index_t
GemmKPerThread
,
index_t
GemmMLevel0Cluster
,
index_t
GemmNLevel0Cluster
,
index_t
GemmMLevel1Cluster
,
index_t
GemmNLevel1Cluster
,
index_t
ThreadGemmDataPerRead_GemmM
,
index_t
ThreadGemmDataPerRead_GemmN
,
typename
GemmABlockCopyThreadSliceLengths_GemmK_GemmM
,
typename
GemmABlockCopyThreadClusterLengths_GemmK_GemmM
,
index_t
GemmABlockCopySrcDataPerRead_GemmK
,
index_t
GemmABlockCopyDstDataPerWrite_GemmM
,
typename
GemmBBlockCopyThreadSliceLengths_GemmK_GemmN
,
typename
GemmBBlockCopyThreadClusterLengths_GemmK_GemmN
,
index_t
GemmBBlockCopySrcDataPerRead_GemmN
,
index_t
GemmBBlockCopyDstDataPerWrite_GemmN
,
index_t
GemmCThreadCopyDstDataPerWrite_GemmN1
>
struct
GridwiseConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw
{
__device__
void
Run
(
const
Float
*
const
__restrict__
p_in_global
,
const
Float
*
const
__restrict__
p_wei_global
,
Float
*
const
__restrict__
p_out_global
)
const
{
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I2
=
Number
<
2
>
{};
constexpr
auto
I3
=
Number
<
3
>
{};
constexpr
auto
in_n_c_hi_wi_global_desc
=
InGlobalDesc
{};
constexpr
auto
wei_k_c_y_x_global_desc
=
WeiGlobalDesc
{};
constexpr
auto
out_n_k_ho_wo_global_desc
=
OutGlobalDesc
{};
constexpr
index_t
N
=
in_n_c_hi_wi_global_desc
.
GetLengths
()[
0
];
constexpr
index_t
C
=
in_n_c_hi_wi_global_desc
.
GetLengths
()[
1
];
constexpr
index_t
Hi
=
in_n_c_hi_wi_global_desc
.
GetLengths
()[
2
];
constexpr
index_t
Wi
=
in_n_c_hi_wi_global_desc
.
GetLengths
()[
3
];
constexpr
index_t
K
=
out_n_k_ho_wo_global_desc
.
GetLengths
()[
1
];
constexpr
index_t
Ho
=
out_n_k_ho_wo_global_desc
.
GetLengths
()[
2
];
constexpr
index_t
Wo
=
out_n_k_ho_wo_global_desc
.
GetLengths
()[
3
];
constexpr
index_t
Y
=
wei_k_c_y_x_global_desc
.
GetLengths
()[
2
];
constexpr
index_t
X
=
wei_k_c_y_x_global_desc
.
GetLengths
()[
3
];
constexpr
index_t
ConvStrideH
=
ConvStrides
{}[
0
];
constexpr
index_t
ConvStrideW
=
ConvStrides
{}[
1
];
constexpr
index_t
ConvDilationH
=
ConvDilations
{}[
0
];
constexpr
index_t
ConvDilationW
=
ConvDilations
{}[
1
];
#if 0
// sanity-check for vectorized memory load
static_assert((Wo == 1 || (ConvStrideW == 1 || GemmBBlockCopySrcDataPerRead_GemmN == 1)) &&
(X == 1 || ConvDilationW % GemmBBlockCopySrcDataPerRead_GemmN == 0) &&
InLeftPads{}[1] % GemmBBlockCopySrcDataPerRead_GemmN == 0 &&
InRightPads{}[1] % GemmBBlockCopySrcDataPerRead_GemmN == 0,
"wrong! aligment requirement for vectorized global load of input tensor will "
"be violated");
#endif
// weight tensor
constexpr
auto
wei_gemmk_gemmm_global_desc
=
reorder_tensor_descriptor_given_upper2lower
(
unfold_tensor_descriptor
(
wei_k_c_y_x_global_desc
,
I1
,
I3
),
Sequence
<
1
,
0
>
{});
// input tensor
constexpr
auto
in_n_c_hip_wip_global_desc
=
transform_tensor_descriptor
(
in_n_c_hi_wi_global_desc
,
make_tuple
(
PassThrough
<
N
>
{},
PassThrough
<
C
>
{},
Pad
<
Sequence
<
Hi
,
Wi
>
,
InLeftPads
,
InRightPads
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
,
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
,
3
>
{}));
constexpr
index_t
Hip
=
in_n_c_hip_wip_global_desc
.
GetLengths
()[
2
];
constexpr
index_t
Wip
=
in_n_c_hip_wip_global_desc
.
GetLengths
()[
3
];
constexpr
auto
in_n_c_y_ho_x_wo_global_desc
=
transform_tensor_descriptor
(
in_n_c_hip_wip_global_desc
,
make_tuple
(
PassThrough
<
N
>
{},
PassThrough
<
C
>
{},
Embed
<
Hip
,
Sequence
<
Y
,
Ho
>
,
Sequence
<
ConvDilationH
,
ConvStrideH
,
0
>>
{},
Embed
<
Wip
,
Sequence
<
X
,
Wo
>
,
Sequence
<
ConvDilationW
,
ConvStrideW
,
0
>>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
,
3
>
{},
Sequence
<
4
,
5
>
{}));
constexpr
auto
in_gemmk_gemmn_global_desc
=
transform_tensor_descriptor
(
in_n_c_y_ho_x_wo_global_desc
,
make_tuple
(
Merge
<
Sequence
<
C
,
Y
,
X
>>
{},
Merge
<
Sequence
<
N
,
Ho
,
Wo
>>
{}),
make_tuple
(
Sequence
<
1
,
2
,
4
>
{},
Sequence
<
0
,
3
,
5
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
// output tensor
constexpr
auto
out_gemmm_gemmn_global_desc
=
transform_tensor_descriptor
(
unfold_tensor_descriptor
(
out_n_k_ho_wo_global_desc
,
I2
,
I3
),
make_tuple
(
PassThrough
<
K
>
{},
Merge
<
Sequence
<
N
,
Ho
*
Wo
>>
{}),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
,
2
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
// GEMM
constexpr
auto
gridwise_gemm
=
GridwiseGemmTransposedANormalBNormalC_v1
<
GridSize
,
BlockSize
,
Float
,
AccFloat
,
decltype
(
wei_gemmk_gemmm_global_desc
),
decltype
(
in_gemmk_gemmn_global_desc
),
decltype
(
out_gemmm_gemmn_global_desc
),
InMemoryDataOperation
::
Set
,
GemmMPerBlock
,
GemmNPerBlock
,
GemmKPerBlock
,
GemmMPerThread
,
GemmNPerThread
,
GemmKPerThread
,
GemmMLevel0Cluster
,
GemmNLevel0Cluster
,
GemmMLevel1Cluster
,
GemmNLevel1Cluster
,
ThreadGemmDataPerRead_GemmM
,
ThreadGemmDataPerRead_GemmN
,
GemmABlockCopyThreadSliceLengths_GemmK_GemmM
,
GemmABlockCopyThreadClusterLengths_GemmK_GemmM
,
Sequence
<
1
,
0
>
,
Sequence
<
1
,
0
>
,
0
,
GemmABlockCopySrcDataPerRead_GemmK
,
GemmABlockCopyDstDataPerWrite_GemmM
,
GemmBBlockCopyThreadSliceLengths_GemmK_GemmN
,
GemmBBlockCopyThreadClusterLengths_GemmK_GemmN
,
Sequence
<
0
,
1
>
,
Sequence
<
0
,
1
>
,
1
,
GemmBBlockCopySrcDataPerRead_GemmN
,
GemmBBlockCopyDstDataPerWrite_GemmN
,
Sequence
<
2
,
3
,
0
,
1
>
,
3
,
GemmCThreadCopyDstDataPerWrite_GemmN1
>
{};
gridwise_gemm
.
Run
(
p_wei_global
,
p_in_global
,
p_out_global
);
}
};
}
// namespace ck
#endif
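The comment at the top of this file states the implicit-GEMM mapping GemmM = K, GemmN = N * Ho * Wo, GemmK = C * Y * X. As a sanity reference, the naive host loop below (illustrative only, not part of the deleted header; sizes, stride, dilation, and padding are arbitrary example values) computes a forward convolution directly in that GEMM coordinate system, decoding (n, ho, wo) from GemmN and (c, y, x) from GemmK in the same way the Merge transforms above do.

// Host-side check of the GemmM/GemmN/GemmK mapping stated above:
// out[k, (n,ho,wo)] = sum over (c,y,x) of wei[k,(c,y,x)] * in[n,c,hi,wi]
// with hi = ho*stride + y*dilation - pad (and similarly for wi),
// i.e. a (K) x (C*Y*X) times (C*Y*X) x (N*Ho*Wo) GEMM.
#include <cstdio>
#include <vector>

int main()
{
    const int N = 1, C = 2, Hi = 5, Wi = 5, K = 3, Y = 3, X = 3;
    const int stride = 1, dilation = 1, pad = 1;
    const int Ho = (Hi + 2 * pad - dilation * (Y - 1) - 1) / stride + 1;
    const int Wo = (Wi + 2 * pad - dilation * (X - 1) - 1) / stride + 1;

    std::vector<float> in(N * C * Hi * Wi), wei(K * C * Y * X);
    for(size_t i = 0; i < in.size(); ++i)  in[i]  = 0.01f * i;
    for(size_t i = 0; i < wei.size(); ++i) wei[i] = 0.02f * i;

    const int GemmM = K, GemmN = N * Ho * Wo, GemmK = C * Y * X;
    std::vector<float> out(GemmM * GemmN, 0.f);

    for(int gm = 0; gm < GemmM; ++gm)          // gm == k
        for(int gn = 0; gn < GemmN; ++gn)      // gn == merged (n, ho, wo)
            for(int gk = 0; gk < GemmK; ++gk)  // gk == merged (c, y, x)
            {
                const int n = gn / (Ho * Wo), ho = (gn / Wo) % Ho, wo = gn % Wo;
                const int c = gk / (Y * X),   y  = (gk / X) % Y,   x  = gk % X;
                const int hi = ho * stride + y * dilation - pad;
                const int wi = wo * stride + x * dilation - pad;
                if(hi < 0 || hi >= Hi || wi < 0 || wi >= Wi) continue; // padding region
                out[gm * GemmN + gn] += wei[((gm * C + c) * Y + y) * X + x] *
                                        in[((n * C + c) * Hi + hi) * Wi + wi];
            }

    printf("GemmM %d GemmN %d GemmK %d, out[0] = %f\n", GemmM, GemmN, GemmK, out[0]);
    return 0;
}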
composable_kernel/include/kernel_algorithm/gridwise_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.hpp
deleted
100644 → 0
View file @
b8b2d0a6
#ifndef CK_GRIDWISE_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V4R4_NHWC_KYXC_NHWK_HPP
#define CK_GRIDWISE_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V4R4_NHWC_KYXC_NHWK_HPP
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm.hpp"
namespace
ck
{
// GemmM = K
// GemmN = N * Ho * Wo
// GemmK = C * Y * X
template
<
index_t
GridSize
,
index_t
BlockSize
,
typename
Float
,
typename
AccFloat
,
typename
InGlobalDesc
,
typename
WeiGlobalDesc
,
typename
OutGlobalDesc
,
typename
ConvStrides
,
typename
ConvDilations
,
typename
InLeftPads
,
typename
InRightPads
,
index_t
GemmMPerBlock
,
index_t
GemmNPerBlock
,
index_t
GemmKPerBlock
,
index_t
GemmMPerThread
,
index_t
GemmNPerThread
,
index_t
GemmKPerThread
,
index_t
GemmMLevel0Cluster
,
index_t
GemmNLevel0Cluster
,
index_t
GemmMLevel1Cluster
,
index_t
GemmNLevel1Cluster
,
index_t
ThreadGemmDataPerRead_GemmM
,
index_t
ThreadGemmDataPerRead_GemmN
,
typename
GemmABlockCopyThreadSliceLengths_GemmK_GemmM
,
typename
GemmABlockCopyThreadClusterLengths_GemmK_GemmM
,
index_t
GemmABlockCopySrcDataPerRead_GemmK
,
index_t
GemmABlockCopyDstDataPerWrite_GemmM
,
typename
GemmBBlockCopyThreadSliceLengths_GemmK_GemmN
,
typename
GemmBBlockCopyThreadClusterLengths_GemmK_GemmN
,
index_t
GemmBBlockCopySrcDataPerRead_GemmK
,
index_t
GemmBBlockCopyDstDataPerWrite_GemmN
,
index_t
GemmCThreadCopyDstDataPerWrite_GemmM1
>
struct
GridwiseConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk
{
__device__
void
Run
(
const
Float
*
const
__restrict__
p_in_global
,
const
Float
*
const
__restrict__
p_wei_global
,
Float
*
const
__restrict__
p_out_global
)
const
{
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I2
=
Number
<
2
>
{};
constexpr
auto
I3
=
Number
<
3
>
{};
constexpr
auto
in_n_hi_wi_c_global_desc
=
InGlobalDesc
{};
constexpr
auto
wei_k_y_x_c_global_desc
=
WeiGlobalDesc
{};
constexpr
auto
out_n_ho_wo_k_global_desc
=
OutGlobalDesc
{};
constexpr
index_t
N
=
in_n_hi_wi_c_global_desc
.
GetLengths
()[
I0
];
constexpr
index_t
Hi
=
in_n_hi_wi_c_global_desc
.
GetLengths
()[
I1
];
constexpr
index_t
Wi
=
in_n_hi_wi_c_global_desc
.
GetLengths
()[
I2
];
constexpr
index_t
C
=
in_n_hi_wi_c_global_desc
.
GetLengths
()[
I3
];
constexpr
index_t
K
=
out_n_ho_wo_k_global_desc
.
GetLengths
()[
I3
];
constexpr
index_t
Ho
=
out_n_ho_wo_k_global_desc
.
GetLengths
()[
I1
];
constexpr
index_t
Wo
=
out_n_ho_wo_k_global_desc
.
GetLengths
()[
I2
];
constexpr
index_t
Y
=
wei_k_y_x_c_global_desc
.
GetLengths
()[
I1
];
constexpr
index_t
X
=
wei_k_y_x_c_global_desc
.
GetLengths
()[
I2
];
constexpr
index_t
ConvStrideH
=
ConvStrides
{}[
I0
];
constexpr
index_t
ConvStrideW
=
ConvStrides
{}[
I1
];
constexpr
index_t
ConvDilationH
=
ConvDilations
{}[
I0
];
constexpr
index_t
ConvDilationW
=
ConvDilations
{}[
I1
];
// weight tensor
constexpr
auto
wei_gemmk_gemmm_global_desc
=
reorder_tensor_descriptor_given_upper2lower
(
unfold_tensor_descriptor
(
wei_k_y_x_c_global_desc
,
I1
,
I3
),
Sequence
<
1
,
0
>
{});
// input tensor
constexpr
auto
in_n_hip_wip_c_global_desc
=
transform_tensor_descriptor
(
in_n_hi_wi_c_global_desc
,
make_tuple
(
PassThrough
<
N
>
{},
Pad
<
Sequence
<
Hi
,
Wi
>
,
InLeftPads
,
InRightPads
>
{},
PassThrough
<
C
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{},
Sequence
<
3
>
{}));
constexpr
index_t
Hip
=
in_n_hip_wip_c_global_desc
.
GetLengths
()[
I1
];
constexpr
index_t
Wip
=
in_n_hip_wip_c_global_desc
.
GetLengths
()[
I2
];
constexpr
auto
in_n_y_ho_x_wo_c_global_desc
=
transform_tensor_descriptor
(
in_n_hip_wip_c_global_desc
,
make_tuple
(
PassThrough
<
N
>
{},
Embed
<
Hip
,
Sequence
<
Y
,
Ho
>
,
Sequence
<
ConvDilationH
,
ConvStrideH
,
0
>>
{},
Embed
<
Wip
,
Sequence
<
X
,
Wo
>
,
Sequence
<
ConvDilationW
,
ConvStrideW
,
0
>>
{},
PassThrough
<
C
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{},
Sequence
<
3
,
4
>
{},
Sequence
<
5
>
{}));
constexpr
auto
in_gemmk_gemmn_global_desc
=
transform_tensor_descriptor
(
in_n_y_ho_x_wo_c_global_desc
,
make_tuple
(
Merge
<
Sequence
<
Y
,
X
,
C
>>
{},
Merge
<
Sequence
<
N
,
Ho
,
Wo
>>
{}),
make_tuple
(
Sequence
<
1
,
3
,
5
>
{},
Sequence
<
0
,
2
,
4
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
// output tensor
constexpr
auto
out_gemmm_gemmn_global_desc
=
transform_tensor_descriptor
(
unfold_tensor_descriptor
(
out_n_ho_wo_k_global_desc
,
I0
,
I2
),
make_tuple
(
PassThrough
<
K
>
{},
Merge
<
Sequence
<
N
*
Ho
*
Wo
>>
{}),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
// GEMM
constexpr
auto
gridwise_gemm
=
GridwiseGemmTransposedANormalBNormalC_v1
<
GridSize
,
BlockSize
,
Float
,
AccFloat
,
decltype
(
wei_gemmk_gemmm_global_desc
),
decltype
(
in_gemmk_gemmn_global_desc
),
decltype
(
out_gemmm_gemmn_global_desc
),
InMemoryDataOperation
::
Set
,
GemmMPerBlock
,
GemmNPerBlock
,
GemmKPerBlock
,
GemmMPerThread
,
GemmNPerThread
,
GemmKPerThread
,
GemmMLevel0Cluster
,
GemmNLevel0Cluster
,
GemmMLevel1Cluster
,
GemmNLevel1Cluster
,
ThreadGemmDataPerRead_GemmM
,
ThreadGemmDataPerRead_GemmN
,
GemmABlockCopyThreadSliceLengths_GemmK_GemmM
,
GemmABlockCopyThreadClusterLengths_GemmK_GemmM
,
Sequence
<
1
,
0
>
,
Sequence
<
1
,
0
>
,
0
,
GemmABlockCopySrcDataPerRead_GemmK
,
GemmABlockCopyDstDataPerWrite_GemmM
,
GemmBBlockCopyThreadSliceLengths_GemmK_GemmN
,
GemmBBlockCopyThreadClusterLengths_GemmK_GemmN
,
Sequence
<
1
,
0
>
,
Sequence
<
1
,
0
>
,
0
,
GemmBBlockCopySrcDataPerRead_GemmK
,
GemmBBlockCopyDstDataPerWrite_GemmN
,
Sequence
<
2
,
3
,
0
,
1
>
,
1
,
GemmCThreadCopyDstDataPerWrite_GemmM1
>
{};
gridwise_gemm
.
Run
(
p_wei_global
,
p_in_global
,
p_out_global
);
}
};
}
// namespace ck
#endif
composable_kernel/include/tensor_description/ConstantMatrixDescriptor.hpp
deleted
100644 → 0
View file @
b8b2d0a6
#ifndef CK_CONSTANT_MATRIX_DESCRIPTOR_HPP
#define CK_CONSTANT_MATRIX_DESCRIPTOR_HPP
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
namespace
ck
{
template
<
index_t
NRow_
,
index_t
NCol_
,
index_t
RowStride_
>
struct
ConstantMatrixDescriptor
{
__host__
__device__
constexpr
ConstantMatrixDescriptor
()
{
static_assert
(
NCol_
<=
RowStride_
,
"wrong! NCol > RowStride!"
);
}
__host__
__device__
static
constexpr
index_t
NRow
()
{
return
NRow_
;
}
__host__
__device__
static
constexpr
index_t
NCol
()
{
return
NCol_
;
}
__host__
__device__
static
constexpr
index_t
RowStride
()
{
return
RowStride_
;
}
__host__
__device__
static
constexpr
auto
GetLengths
()
{
return
Sequence
<
NRow_
,
NCol_
>
{};
}
__host__
__device__
static
constexpr
index_t
GetElementSize
()
{
return
NRow_
*
NCol_
;
}
__host__
__device__
static
constexpr
index_t
GetElementSpace
()
{
return
NRow_
*
RowStride_
;
}
__host__
__device__
static
index_t
GetOffsetFromMultiIndex
(
index_t
irow
,
index_t
icol
)
{
return
irow
*
RowStride_
+
icol
;
}
__host__
__device__
static
index_t
CalculateOffset
(
index_t
irow
,
index_t
icol
)
{
return
GetOffsetFromMultiIndex
(
irow
,
icol
);
}
template
<
index_t
SubNRow
,
index_t
SubNCol
>
__host__
__device__
static
constexpr
auto
MakeSubMatrixDescriptor
(
Number
<
SubNRow
>
,
Number
<
SubNCol
>
)
{
return
ConstantMatrixDescriptor
<
SubNRow
,
SubNCol
,
RowStride_
>
{};
}
};
template
<
index_t
NRow
,
index_t
NCol
>
__host__
__device__
constexpr
auto
make_ConstantMatrixDescriptor_packed
(
Number
<
NRow
>
,
Number
<
NCol
>
)
{
return
ConstantMatrixDescriptor
<
NRow
,
NCol
,
NCol
>
{};
}
template
<
index_t
NRow
,
index_t
NCol
,
index_t
RowStride
>
__host__
__device__
constexpr
auto
make_ConstantMatrixDescriptor
(
Number
<
NRow
>
,
Number
<
NCol
>
,
Number
<
RowStride
>
)
{
return
ConstantMatrixDescriptor
<
NRow
,
NCol
,
RowStride
>
{};
}
template
<
typename
...
Ts
>
__host__
__device__
constexpr
auto
make_ConstantMatrixDescriptor
(
NativeTensorDescriptor
<
Ts
...
>
)
{
using
TDesc
=
NativeTensorDescriptor
<
Ts
...
>
;
static_assert
(
TDesc
::
GetNumOfDimension
()
==
2
,
"wrong"
);
static_assert
(
TDesc
::
GetStrides
()[
1
]
==
1
,
"wrong"
);
return
ConstantMatrixDescriptor
<
TDesc
::
GetLengths
()[
0
],
TDesc
::
GetLengths
()[
1
],
TDesc
::
GetStrides
()[
0
]
>
{};
}
template
<
typename
TDesc
>
__host__
__device__
void
print_ConstantMatrixDescriptor
(
TDesc
,
const
char
*
s
)
{
printf
(
"%s NRow %u NCol %u RowStride %u
\n
"
,
s
,
TDesc
::
NRow
(),
TDesc
::
NCol
(),
TDesc
::
RowStride
());
}
}
// namespace ck
#endif
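ConstantMatrixDescriptor above is essentially a compile-time (NRow, NCol, RowStride) triple with RowStride >= NCol, so element (irow, icol) lives at offset irow * RowStride + icol and the allocated space is NRow * RowStride elements. A minimal standalone analogue (names are illustrative, not the library's API) is:

// Illustrative host-side analogue of ConstantMatrixDescriptor: a row-major
// NRow x NCol view inside a buffer whose rows are RowStride elements apart.
#include <cstdio>

template <int NRow, int NCol, int RowStride>
struct matrix_desc
{
    static_assert(NCol <= RowStride, "NCol must not exceed RowStride");
    static constexpr int element_size  = NRow * NCol;       // logical elements
    static constexpr int element_space = NRow * RowStride;  // allocated elements
    static constexpr int offset(int irow, int icol) { return irow * RowStride + icol; }
};

int main()
{
    using desc = matrix_desc<4, 3, 8>; // a 4x3 matrix padded to 8 elements per row
    printf("size %d space %d offset(2,1) %d\n",
           desc::element_size, desc::element_space, desc::offset(2, 1));
    return 0;
}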
composable_kernel/include/tensor_description/cluster_descriptor.hpp
View file @
81c942cd
...
@@ -2,50 +2,10 @@
#define CK_CLUSTER_DESCRIPTOR_HPP
#include "common_header.hpp"
// TODO remove dependency on deprecated tensor descriptor
#include "tensor_descriptor.hpp"
#include "tensor_adaptor.hpp"
namespace ck {
// a cluster maps a 1-d index to an N-d index
template <typename Lengths, typename ArrangeOrder>
struct ClusterDescriptor
{
    static constexpr index_t nDim = Lengths::Size();

    static constexpr auto mDesc = transform_tensor_descriptor(
        make_native_tensor_descriptor_packed(Lengths{}),
        make_tuple(Merge<decltype(Lengths::ReorderGivenNew2Old(ArrangeOrder{}))>{}),
        make_tuple(ArrangeOrder{}),
        make_tuple(Sequence<0>{}));

    __host__ __device__ constexpr ClusterDescriptor()
    {
        static_assert(Lengths::Size() == nDim && ArrangeOrder::Size() == nDim,
                      "wrong! size not the same");

        static_assert(is_valid_sequence_map<ArrangeOrder>{}, "wrong! ArrangeOrder is wrong");
    }

    __host__ __device__ static constexpr index_t GetElementSize()
    {
        return mDesc.GetElementSize();
    }

    __host__ __device__ static constexpr auto CalculateClusterIndex(index_t idx_1d)
    {
        return mDesc.CalculateLowerIndex(MultiIndex<1>{idx_1d});
    }
};

template <typename Lengths,
          typename ArrangeOrder = typename arithmetic_sequence_gen<0, Lengths::Size(), 1>::type>
__host__ __device__ constexpr auto make_cluster_descriptor(
    Lengths, ArrangeOrder order = typename arithmetic_sequence_gen<0, Lengths::Size(), 1>::type{})
{
    return ClusterDescriptor<Lengths, decltype(order)>{};
}
#if 1
template <typename Lengths,
          typename ArrangeOrder = typename arithmetic_sequence_gen<0, Lengths::Size(), 1>::type>
__host__ __device__ constexpr auto make_cluster_descriptor_v2(
...
@@ -68,7 +28,6 @@ __host__ __device__ constexpr auto make_cluster_descriptor_v2(
    return make_single_stage_tensor_adaptor(
        make_tuple(transform), make_tuple(low_dim_old_top_ids), make_tuple(up_dim_new_top_ids));
}
#endif
} // namespace ck
#endif
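For orientation, ClusterDescriptor (with the default identity ArrangeOrder) simply decodes a flat block id into an N-d cluster index, e.g. splitting the grid into (KBlockWork, BBlockWork) as the kernels above do via make_cluster_descriptor. A trivial host sketch of that decode (example lengths only, not the library's code) is:

// Host-side sketch of what ClusterDescriptor computes with the identity order:
// a flat 1-d id is decoded into an N-d cluster index, last dimension fastest.
#include <cstdio>

int main()
{
    const int lengths[2] = {4, 8}; // e.g. KBlockWork x BBlockWork
    for(int id = 0; id < lengths[0] * lengths[1]; ++id)
    {
        const int i0 = id / lengths[1]; // slowest-varying dimension
        const int i1 = id % lengths[1]; // fastest-varying dimension
        if(id < 3) printf("block %d -> cluster (%d, %d)\n", id, i0, i1);
    }
    return 0;
}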
composable_kernel/include/tensor_description/dimension.hpp
deleted
100644 → 0
View file @
b8b2d0a6
#ifndef CK_DIMENSION_HPP
#define CK_DIMENSION_HPP

#include "common_header.hpp"

namespace ck {

template <index_t Length, index_t Stride>
struct NativeDimension
{
    __host__ __device__ static constexpr auto GetLength() { return Number<Length>{}; }

    __host__ __device__ static constexpr auto GetStride() { return Number<Stride>{}; }
};

} // namespace ck
#endif
composable_kernel/include/tensor_description/multi_index_transform.hpp
deleted
100644 → 0
View file @
b8b2d0a6
#ifndef CK_MULTI_INDEX_TRANSFORM_HPP
#define CK_MULTI_INDEX_TRANSFORM_HPP
#include "common_header.hpp"
#include "multi_index.hpp"
namespace
ck
{
template
<
index_t
Length
>
struct
PassThrough
{
using
LowerIndex
=
MultiIndex
<
1
>
;
using
UpperIndex
=
MultiIndex
<
1
>
;
__host__
__device__
static
constexpr
auto
GetNumOfLowerDimension
()
{
return
Number
<
1
>
{};
}
__host__
__device__
static
constexpr
auto
GetNumOfUpperDimension
()
{
return
Number
<
1
>
{};
}
__host__
__device__
static
constexpr
auto
GetUpperLengths
()
{
return
Sequence
<
Length
>
{};
}
__host__
__device__
static
constexpr
auto
CalculateLowerIndex
(
const
UpperIndex
&
idx_up
)
{
return
idx_up
;
}
__host__
__device__
static
constexpr
auto
CalculateLowerIndexDiff
(
const
UpperIndex
&
idx_up_diff
,
const
UpperIndex
&
/* idx_up_old */
,
const
LowerIndex
&
/* idx_low_old */
)
{
return
idx_up_diff
;
}
__host__
__device__
static
constexpr
bool
IsLinearTransform
()
{
return
true
;
}
__host__
__device__
static
constexpr
bool
IsValidUpperIndexAlwaysMappedToValidLowerIndex
()
{
return
true
;
}
};
// By default, this transform automatically judges whether the is-valid check for the
// upper-to-lower index mapping is necessary.
// However, the check is skipped if SkipIsValidCheck is set to true by the user
// LowerLengths: Sequence<...>
template
<
typename
LowerLengths
,
typename
LeftPads
,
typename
RightPads
,
bool
SkipIsValidCheck
=
false
>
struct
Pad
{
static
constexpr
index_t
nDim
=
LowerLengths
::
Size
();
using
LowerIndex
=
MultiIndex
<
nDim
>
;
using
UpperIndex
=
MultiIndex
<
nDim
>
;
__host__
__device__
constexpr
Pad
()
{
static_assert
(
LowerLengths
::
GetSize
()
==
nDim
&&
LeftPads
::
GetSize
()
==
nDim
&&
RightPads
::
GetSize
()
==
nDim
,
"wrong! # of dimensions not consistent"
);
}
__host__
__device__
static
constexpr
auto
GetNumOfLowerDimension
()
{
return
Number
<
nDim
>
{};
}
__host__
__device__
static
constexpr
auto
GetNumOfUpperDimension
()
{
return
Number
<
nDim
>
{};
}
__host__
__device__
static
constexpr
auto
GetUpperLengths
()
{
return
LowerLengths
{}
+
LeftPads
{}
+
RightPads
{};
}
__host__
__device__
static
constexpr
auto
CalculateLowerIndex
(
const
UpperIndex
&
idx_up
)
{
return
idx_up
-
LeftPads
{};
}
__host__
__device__
static
constexpr
auto
CalculateLowerIndexDiff
(
const
UpperIndex
&
idx_up_diff
,
const
UpperIndex
&
/* idx_up_old */
,
const
LowerIndex
&
/* idx_low_old */
)
{
return
idx_up_diff
;
}
__host__
__device__
static
constexpr
bool
IsLinearTransform
()
{
return
true
;
}
__host__
__device__
static
constexpr
bool
IsValidUpperIndexAlwaysMappedToValidLowerIndex
()
{
// skip the validity check if the user requests it
if
(
SkipIsValidCheck
)
{
return
true
;
}
bool
flag
=
true
;
for
(
index_t
i
=
0
;
i
<
nDim
;
++
i
)
{
flag
=
flag
&&
LeftPads
::
At
(
i
)
==
0
&&
RightPads
::
At
(
i
)
==
0
;
}
return
flag
;
}
};
// LowerLengths: Sequence<...>
// SliceBegins: Sequence<...>
// SliceEnds: Sequence<...>
template
<
typename
LowerLengths
,
typename
SliceBegins
,
typename
SliceEnds
>
struct
Slice
{
static
constexpr
index_t
nDim
=
LowerLengths
::
Size
();
using
LowerIndex
=
MultiIndex
<
nDim
>
;
using
UpperIndex
=
MultiIndex
<
nDim
>
;
__host__
__device__
constexpr
Slice
()
{
static_assert
(
LowerLengths
::
GetSize
()
==
nDim
&&
SliceBegins
::
GetSize
()
==
nDim
&&
SliceEnds
::
GetSize
()
==
nDim
,
"wrong! # of dimensions not consistent"
);
#if 0
// TODO: would not compile, error on constexpr
static_for<0, nDim, 1>{}([&](auto idim) {
static_assert(SliceBegins::At(idim) <= SliceEnds::At(idim) &&
SliceBegins::At(idim) >= 0 &&
SliceEnds::At(idim) <= LowerLengths::At(idim),
"wrong! Slice config is wrong");
});
#endif
}
__host__
__device__
static
constexpr
auto
GetNumOfLowerDimension
()
{
return
Number
<
nDim
>
{};
}
__host__
__device__
static
constexpr
auto
GetNumOfUpperDimension
()
{
return
Number
<
nDim
>
{};
}
__host__
__device__
static
constexpr
auto
GetUpperLengths
()
{
return
SliceEnds
{}
-
SliceBegins
{};
}
__host__
__device__
static
constexpr
auto
CalculateLowerIndex
(
const
UpperIndex
&
idx_up
)
{
return
idx_up
+
SliceBegins
{};
}
__host__
__device__
static
constexpr
auto
CalculateLowerIndexDiff
(
const
UpperIndex
&
idx_up_diff
,
const
UpperIndex
&
/* idx_up_old */
,
const
LowerIndex
&
/* idx_low_old */
)
{
return
idx_up_diff
;
}
__host__
__device__
static
constexpr
bool
IsLinearTransform
()
{
return
true
;
}
__host__
__device__
static
constexpr
bool
IsValidUpperIndexAlwaysMappedToValidLowerIndex
()
{
return
true
;
}
};
// LowerLengths: Sequence<...>
template
<
typename
LowerLengths
>
struct
Merge
{
static
constexpr
index_t
nDimLow
=
LowerLengths
::
Size
();
static
constexpr
index_t
nDimUp
=
1
;
using
LowerIndex
=
MultiIndex
<
nDimLow
>
;
using
UpperIndex
=
MultiIndex
<
nDimUp
>
;
__host__
__device__
static
constexpr
auto
GetNumOfLowerDimension
()
{
return
Number
<
nDimLow
>
{};
}
__host__
__device__
static
constexpr
auto
GetNumOfUpperDimension
()
{
return
Number
<
nDimUp
>
{};
}
__host__
__device__
static
constexpr
auto
GetUpperLengths
()
{
return
Sequence
<
reduce_on_sequence
(
LowerLengths
{},
math
::
multiplies
<
index_t
>
{},
Number
<
1
>
{})
>
{};
}
// emulate constexpr lambda
template
<
typename
PseudoLowStrides
>
struct
lambda_CalculateLowerIndex
{
index_t
&
itmp
;
LowerIndex
&
idx_low
;
__host__
__device__
constexpr
lambda_CalculateLowerIndex
(
index_t
&
itmp_
,
LowerIndex
&
idx_low_
)
:
itmp
(
itmp_
),
idx_low
(
idx_low_
)
{
}
template
<
typename
IDim
>
__host__
__device__
constexpr
void
operator
()(
IDim
idim
)
const
{
constexpr
index_t
stride
=
PseudoLowStrides
::
At
(
idim
);
idx_low
(
idim
)
=
itmp
/
stride
;
itmp
-=
idx_low
[
idim
]
*
stride
;
}
};
__host__
__device__
static
constexpr
auto
CalculateLowerIndex
(
const
UpperIndex
&
idx_up
)
{
LowerIndex
idx_low
;
index_t
itmp
=
idx_up
[
Number
<
0
>
{}];
constexpr
auto
pseudo_low_strides
=
reverse_inclusive_scan_sequence
(
LowerLengths
::
PopFront
(),
math
::
multiplies
<
index_t
>
{},
Number
<
1
>
{})
.
PushBack
(
Number
<
1
>
{});
static_for
<
0
,
nDimLow
-
1
,
1
>
{}(
lambda_CalculateLowerIndex
<
decltype
(
pseudo_low_strides
)
>
(
itmp
,
idx_low
));
idx_low
(
Number
<
nDimLow
-
1
>
{})
=
itmp
/
pseudo_low_strides
[
Number
<
nDimLow
-
1
>
{}];
return
idx_low
;
}
// idx_low_diff depends on idx_low_old, so idx_low needs to be up-to-date
// If idx_up_diff is known at compile-time, many calculations can be optimized
// away by compiler
// This function assumes idx_low_old is not out of bounds
__host__
__device__
static
constexpr
auto
CalculateLowerIndexDiff
(
const
UpperIndex
&
idx_up_diff
,
const
UpperIndex
&
/* idx_up_old */
,
const
LowerIndex
&
idx_low_old
)
{
if
(
idx_up_diff
[
Number
<
0
>
{}]
==
0
)
{
return
make_zero_multi_index
<
nDimLow
>
();
}
else
{
// CalculateLowerIndex(idx_up_diff) has multiple integer divisions.
// If idx_up_diff is known at compile-time, the calculation can
// be done at compile-time. However, if idx_up_diff is only known
// at run-time, then the calculation will also be computed at
// run-time, and can be very expensive.
LowerIndex
idx_low_diff_tmp
=
CalculateLowerIndex
(
idx_up_diff
);
// find out the last low dimension that changed
index_t
last_changed_low_dim
=
0
;
static_for
<
0
,
nDimLow
,
1
>
{}([
&
](
auto
i
)
{
if
(
idx_low_diff_tmp
[
i
]
!=
0
)
{
last_changed_low_dim
=
i
;
}
});
LowerIndex
idx_low_new
=
idx_low_old
+
idx_low_diff_tmp
;
if
(
idx_up_diff
[
Number
<
0
>
{}]
>
0
)
{
// do carry check on each low dimension in reversed order
// starting from the first digit that changed
// don't check the highest dimension
bool
carry
=
false
;
static_for
<
nDimLow
-
1
,
0
,
-
1
>
{}([
&
](
auto
i
)
{
if
(
i
<=
last_changed_low_dim
)
{
if
(
carry
)
{
++
idx_low_new
(
i
);
}
carry
=
false
;
if
(
idx_low_new
[
i
]
>=
LowerLengths
::
At
(
i
))
{
idx_low_new
(
i
)
-=
LowerLengths
::
At
(
i
);
carry
=
true
;
}
}
});
// highest dimension, no out-of-bound check
if
(
carry
)
{
++
idx_low_new
(
Number
<
0
>
{});
}
}
else
{
// do borrow check on each low dimension in reversed order
// starting from the first digit that changed
// don't check the highest dimension
bool
borrow
=
false
;
static_for
<
nDimLow
-
1
,
0
,
-
1
>
{}([
&
](
auto
i
)
{
if
(
i
<=
last_changed_low_dim
)
{
if
(
borrow
)
{
--
idx_low_new
(
i
);
}
borrow
=
false
;
if
(
idx_low_new
[
i
]
<
0
)
{
idx_low_new
(
i
)
+=
LowerLengths
::
At
(
i
);
borrow
=
true
;
}
}
});
// highest dimension, no out-of-bound check
if
(
borrow
)
{
--
idx_low_new
(
Number
<
0
>
{});
}
}
return
idx_low_new
-
idx_low_old
;
}
}
__host__
__device__
static
constexpr
bool
IsLinearTransform
()
{
return
false
;
}
__host__
__device__
static
constexpr
bool
IsValidUpperIndexAlwaysMappedToValidLowerIndex
()
{
return
true
;
}
};
// UpperLengths: Sequence<...>
template
<
typename
UpperLengths
>
struct
UnMerge
{
static
constexpr
index_t
nDimLow
=
1
;
static
constexpr
index_t
nDimUp
=
UpperLengths
::
Size
();
using
LowerIndex
=
MultiIndex
<
nDimLow
>
;
using
UpperIndex
=
MultiIndex
<
nDimUp
>
;
__host__
__device__
static
constexpr
auto
GetNumOfLowerDimension
()
{
return
Number
<
nDimLow
>
{};
}
__host__
__device__
static
constexpr
auto
GetNumOfUpperDimension
()
{
return
Number
<
nDimUp
>
{};
}
__host__
__device__
static
constexpr
auto
GetUpperLengths
()
{
return
UpperLengths
{};
}
__host__
__device__
static
constexpr
auto
CalculateLowerIndex
(
const
UpperIndex
&
idx_up
)
{
LowerIndex
idx_low
=
make_multi_index
(
0
);
constexpr
auto
pseudo_up_strides
=
reverse_inclusive_scan_sequence
(
UpperLengths
::
PopFront
(),
math
::
multiplies
<
index_t
>
{},
Number
<
1
>
{})
.
PushBack
(
Number
<
1
>
{});
static_for
<
0
,
nDimUp
,
1
>
{}(
[
&
](
auto
idim
)
{
idx_low
(
Number
<
0
>
{})
+=
idx_up
[
idim
]
*
pseudo_up_strides
[
idim
];
});
return
idx_low
;
}
__host__
__device__
static
constexpr
auto
CalculateLowerIndexDiff
(
const
UpperIndex
&
idx_up_diff
,
const
UpperIndex
&
/* idx_up_old */
,
const
LowerIndex
&
/* idx_low_old */
)
{
return
CalculateLowerIndex
(
idx_up_diff
);
}
__host__
__device__
static
constexpr
bool
IsLinearTransform
()
{
return
true
;
}
__host__
__device__
static
constexpr
bool
IsValidUpperIndexAlwaysMappedToValidLowerIndex
()
{
return
true
;
}
};
// By default, this transform automatically judges whether the is-valid check for the
// upper-to-lower index mapping is necessary.
// However, the check is skipped if SkipIsValidCheck is set to true by the user
// UpperLengths: Sequence<...>
// Coefficients: Sequence<...>
// idx_low = coefficients[0, ...nDimUp-1] * idx_up[0, ...nDimUp-1] + coefficients[nDimUp]
template
<
index_t
LowerLength
,
typename
UpperLengths
,
typename
Coefficients
,
bool
SkipIsValidCheck
=
false
>
struct
Embed
{
static
constexpr
index_t
nDimLow
=
1
;
static
constexpr
index_t
nDimUp
=
UpperLengths
::
Size
();
using
LowerIndex
=
MultiIndex
<
nDimLow
>
;
using
UpperIndex
=
MultiIndex
<
nDimUp
>
;
__host__
__device__
constexpr
Embed
()
{
static_assert
(
UpperLengths
::
GetSize
()
==
nDimUp
&&
Coefficients
::
GetSize
()
==
nDimUp
+
1
,
"wrong! # of dimensions not consistent"
);
}
__host__
__device__
static
constexpr
auto
GetNumOfUpperDimension
()
{
return
Number
<
nDimUp
>
{};
}
__host__
__device__
static
constexpr
auto
GetNumOfLowerDimension
()
{
return
Number
<
nDimLow
>
{};
}
__host__
__device__
static
constexpr
auto
GetUpperLengths
()
{
return
UpperLengths
{};
}
__host__
__device__
static
constexpr
auto
CalculateLowerIndex
(
const
UpperIndex
&
idx_up
)
{
LowerIndex
idx_low
=
make_multi_index
(
Coefficients
{}[
Number
<
nDimUp
>
{}]);
static_for
<
0
,
nDimUp
,
1
>
{}(
[
&
](
auto
i
)
{
idx_low
(
Number
<
0
>
{})
+=
idx_up
[
i
]
*
Coefficients
{}[
i
];
});
return
idx_low
;
}
__host__
__device__
static
constexpr
auto
CalculateLowerIndexDiff
(
const
UpperIndex
&
idx_up_diff
,
const
UpperIndex
&
/* idx_up_old */
,
const
LowerIndex
&
/* idx_low_old */
)
{
LowerIndex
idx_low_diff
=
make_multi_index
(
0
);
static_for
<
0
,
nDimUp
,
1
>
{}(
[
&
](
auto
i
)
{
idx_low_diff
(
Number
<
0
>
{})
+=
idx_up_diff
[
i
]
*
Coefficients
{}[
i
];
});
return
idx_low_diff
;
}
__host__
__device__
static
constexpr
bool
IsLinearTransform
()
{
return
true
;
}
__host__
__device__
static
constexpr
bool
IsValidUpperIndexAlwaysMappedToValidLowerIndex
()
{
// skip the validity check if the user requests it
if
(
SkipIsValidCheck
)
{
return
true
;
}
bool
flag
=
true
;
index_t
ncorner
=
1
;
for
(
index_t
idim
=
0
;
idim
<
nDimUp
;
++
idim
)
{
ncorner
*=
2
;
}
// loop over each corner of the upper tensor
for
(
index_t
icorner
=
0
;
icorner
<
ncorner
;
++
icorner
)
{
// generate upper index for each corner
auto
idx_up
=
make_zero_multi_index
<
nDimUp
>
();
index_t
itmp
=
icorner
;
static_for
<
nDimUp
,
0
,
-
1
>
{}([
&
](
auto
idim
)
{
auto
idim_m1
=
idim
-
Number
<
1
>
{};
idx_up
(
idim_m1
)
=
itmp
%
2
==
0
?
0
:
UpperLengths
::
At
(
idim_m1
)
-
1
;
itmp
/=
2
;
});
// calculate lower index
auto
idx_low
=
CalculateLowerIndex
(
idx_up
);
// judge if lower index is valid
flag
=
flag
&&
idx_low
[
Number
<
0
>
{}]
>=
0
&&
idx_low
[
Number
<
0
>
{}]
<
LowerLength
;
}
return
flag
;
}
};
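As an aside, the index arithmetic that Embed performs (the formula quoted in its comment) can be checked with a tiny stand-alone program. The lengths and coefficients below are made-up example values and this snippet is not part of the deleted header; it is only a sketch of the mapping idx_low = sum(coefficients[i] * idx_up[i]) + coefficients[nDimUp].

```cpp
#include <cstdio>

int main()
{
    // Suppose UpperLengths = Sequence<4, 3> and Coefficients = Sequence<6, 2, 1>,
    // i.e. idx_low = 6 * idx_up[0] + 2 * idx_up[1] + 1 (hypothetical numbers).
    const int coefficients[3] = {6, 2, 1};
    const int idx_up[2]       = {2, 1};

    int idx_low = coefficients[2]; // constant term coefficients[nDimUp]
    for(int i = 0; i < 2; ++i)
        idx_low += coefficients[i] * idx_up[i];

    std::printf("idx_low = %d\n", idx_low); // prints 15
    return 0;
}
```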
// LowerLengths: Sequence<...>
// LowerFreezePoint: Sequence<...>
template <typename LowerLengths, typename LowerFreezePoint>
struct Freeze
{
    static constexpr index_t nDimLow = LowerLengths::Size();
    static constexpr index_t nDimUp  = 0;

    using LowerIndex = MultiIndex<nDimLow>;
    using UpperIndex = MultiIndex<nDimUp>;

    __host__ __device__ constexpr Freeze()
    {
        // TODO: sanity check: LowerFreezePoint should be within range of LowerLengths
    }

    __host__ __device__ static constexpr auto GetNumOfLowerDimension() { return Number<nDimLow>{}; }

    __host__ __device__ static constexpr auto GetNumOfUpperDimension() { return Number<0>{}; }

    __host__ __device__ static constexpr auto GetUpperLengths() { return Sequence<>{}; }

    __host__ __device__ static constexpr auto CalculateLowerIndex(const UpperIndex& /* idx_up */)
    {
        return to_multi_index(LowerFreezePoint{});
    }

    __host__ __device__ static constexpr auto CalculateLowerIndexDiff(
        const UpperIndex& /* idx_up_diff */, const UpperIndex& /* idx_up_old */, const LowerIndex& /* idx_low_old */)
    {
        return make_zero_multi_index<nDimLow>();
    }

    __host__ __device__ static constexpr bool IsLinearTransform() { return true; }

    __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex() { return true; }
};

} // namespace ck
#endif
composable_kernel/include/tensor_description/print_tensor_descriptor.hpp deleted 100644 → 0 @ b8b2d0a6
#ifndef CK_PRINT_TENSOR_DESCRIPTOR_HPP
#define CK_PRINT_TENSOR_DESCRIPTOR_HPP

#include "common_header.hpp"
#include "tensor_descriptor.hpp"

namespace ck {

template <typename... NativeDimensions>
__host__ __device__ void
print_tensor_descriptor(const char* s, const NativeTensorDescriptor<NativeDimensions...>& desc)
{
    print_tensor_descriptor_impl(s, desc.GetLengths(), desc.GetStrides());
}

template <typename... Ts>
__host__ __device__ void
print_tensor_descriptor(const char* s, const TransformedTensorDescriptor<Ts...>& desc)
{
    print_tensor_descriptor_impl(s, desc.GetLengths());
}

template <index_t... Lengths, index_t... Strides>
__host__ __device__ void
print_tensor_descriptor_impl(const char* s, Sequence<Lengths...>, Sequence<Strides...>)
{
    constexpr index_t nDim = sizeof...(Lengths);

    static_assert(nDim > 0 && nDim <= 12, "wrong!");

    static_if<nDim == 1>{}([&](auto) {
        printf("%s dim %u, lengths {%u}, strides {%u}\n", s, nDim, Lengths..., Strides...);
    });

    static_if<nDim == 2>{}([&](auto) {
        printf("%s dim %u, lengths {%u %u}, strides {%u %u}\n", s, nDim, Lengths..., Strides...);
    });

    static_if<nDim == 3>{}([&](auto) {
        printf("%s dim %u, lengths {%u %u %u}, strides {%u %u %u}\n", s, nDim, Lengths..., Strides...);
    });

    static_if<nDim == 4>{}([&](auto) {
        printf("%s dim %u, lengths {%u %u %u %u}, strides {%u %u %u %u}\n", s, nDim, Lengths..., Strides...);
    });

    static_if<nDim == 5>{}([&](auto) {
        printf("%s dim %u, lengths {%u %u %u %u %u}, strides {%u %u %u %u %u}\n", s, nDim, Lengths..., Strides...);
    });

    static_if<nDim == 6>{}([&](auto) {
        printf("%s dim %u, lengths {%u %u %u %u %u %u}, strides {%u %u %u %u %u %u}\n", s, nDim, Lengths..., Strides...);
    });

    static_if<nDim == 7>{}([&](auto) {
        printf("%s dim %u, lengths {%u %u %u %u %u %u %u}, strides {%u %u %u %u %u %u %u}\n", s, nDim, Lengths..., Strides...);
    });

    static_if<nDim == 8>{}([&](auto) {
        printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u}, strides {%u %u %u %u %u %u %u %u}\n", s, nDim, Lengths..., Strides...);
    });

    static_if<nDim == 9>{}([&](auto) {
        printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u %u}, strides {%u %u %u %u %u %u %u %u %u}\n", s, nDim, Lengths..., Strides...);
    });

    static_if<nDim == 10>{}([&](auto) {
        printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u %u %u}, strides {%u %u %u %u %u %u %u %u %u %u}\n", s, nDim, Lengths..., Strides...);
    });

    static_if<nDim == 11>{}([&](auto) {
        printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u %u %u %u}, strides {%u %u %u %u %u %u %u %u %u %u %u}\n", s, nDim, Lengths..., Strides...);
    });

    static_if<nDim == 12>{}([&](auto) {
        printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u %u %u %u %u}, strides {%u %u %u %u %u %u %u %u %u %u %u %u}\n", s, nDim, Lengths..., Strides...);
    });
}

template <index_t... Lengths>
__host__ __device__ void print_tensor_descriptor_impl(const char* s, Sequence<Lengths...>)
{
    constexpr index_t nDim = sizeof...(Lengths);

    static_assert(nDim > 0 && nDim <= 12, "wrong!");

    static_if<nDim == 1>{}([&](auto) {
        printf("%s dim %u, lengths {%u}\n", s, nDim, Lengths...);
    });

    static_if<nDim == 2>{}([&](auto) {
        printf("%s dim %u, lengths {%u %u}\n", s, nDim, Lengths...);
    });

    static_if<nDim == 3>{}([&](auto) {
        printf("%s dim %u, lengths {%u %u %u}\n", s, nDim, Lengths...);
    });

    static_if<nDim == 4>{}([&](auto) {
        printf("%s dim %u, lengths {%u %u %u %u}\n", s, nDim, Lengths...);
    });

    static_if<nDim == 5>{}([&](auto) {
        printf("%s dim %u, lengths {%u %u %u %u %u}\n", s, nDim, Lengths...);
    });

    static_if<nDim == 6>{}([&](auto) {
        printf("%s dim %u, lengths {%u %u %u %u %u %u},\n", s, nDim, Lengths...);
    });

    static_if<nDim == 7>{}([&](auto) {
        printf("%s dim %u, lengths {%u %u %u %u %u %u %u}\n", s, nDim, Lengths...);
    });

    static_if<nDim == 8>{}([&](auto) {
        printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u}\n", s, nDim, Lengths...);
    });

    static_if<nDim == 9>{}([&](auto) {
        printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u %u}\n", s, nDim, Lengths...);
    });

    static_if<nDim == 10>{}([&](auto) {
        printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u %u %u}\n", s, nDim, Lengths...);
    });

    static_if<nDim == 11>{}([&](auto) {
        printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u %u %u %u}\n", s, nDim, Lengths...);
    });

    static_if<nDim == 12>{}([&](auto) {
        printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u %u %u %u %u}\n", s, nDim, Lengths...);
    });
}

} // namespace ck
#endif
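The print_tensor_descriptor_impl overloads above expand a non-type parameter pack straight into printf's variadic argument list, with static_if selecting the format string that matches the rank. A minimal stand-alone sketch of that expansion technique (plain C++17, with a hypothetical Sequence type; not part of the deleted header):

```cpp
#include <cstdio>

template <unsigned... Is>
struct Sequence
{
};

// Expand the pack directly into printf's variadic argument list.
template <unsigned... Lengths>
void print_lengths(const char* s, Sequence<Lengths...>)
{
    constexpr unsigned nDim = sizeof...(Lengths);
    if constexpr(nDim == 3)
    {
        std::printf("%s dim %u, lengths {%u %u %u}\n", s, nDim, Lengths...);
    }
}

int main()
{
    print_lengths("tensor", Sequence<8u, 4u, 2u>{}); // prints: tensor dim 3, lengths {8 4 2}
    return 0;
}
```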
composable_kernel/include/tensor_description/tensor_coordinate.hpp deleted 100644 → 0 @ b8b2d0a6
#ifndef CK_TENSOR_COORDINATE_HPP
#define CK_TENSOR_COORDINATE_HPP

#include "common_header.hpp"
#include "dimension.hpp"
#include "multi_index_transform.hpp"
#include "tensor_descriptor.hpp"

namespace ck {

// A "tensor coordinate" is an opaque object that represents a "point of location" inside a tensor.
// At the bare minimum, the user should be able to query the following information from a tensor
// coordinate:
//   1. The tensor descriptor
//   2. The location, represented in the form of a multi-index
//   3. The location, represented in the form of the offset to the origin of the tensor
//   4. Whether the location is inside an invalid area or not, i.e. the padding area of an
//      implicitly padded tensor is considered invalid, because the padding area doesn't have any
//      physical memory allocation
// A tensor coordinate also provides the following functionality:
//   1. Given a step size in each dimension, update itself, or return a new tensor coordinate, so
//      the user can freely move the "point of location" inside the tensor

// wrapper class for NativeTensorCoordinate and TransformedTensorCoordinate
template <typename TensorDesc>
struct TensorCoordinate;

// tensor coordinate for a native tensor
template <typename NativeTensorDesc>
struct NativeTensorCoordinate
{
    using type             = NativeTensorCoordinate;
    using tensor_desc_type = NativeTensorDesc;

    static constexpr index_t nDim = tensor_desc_type::GetNumOfDimension();

    using Index = MultiIndex<nDim>;

    __host__ __device__ constexpr NativeTensorCoordinate(Index idx)
        : mIndex(idx), mOffset(tensor_desc_type::CalculateOffset(idx))
    {
    }

    template <typename... Xs>
    __host__ __device__ constexpr NativeTensorCoordinate(Xs... xs)
        : NativeTensorCoordinate(make_multi_index(xs...))
    {
    }

    template <index_t... Xs>
    __host__ __device__ constexpr NativeTensorCoordinate(Sequence<Xs...>)
        : NativeTensorCoordinate(make_multi_index(Xs...))
    {
    }

    __host__ __device__ static constexpr auto GetTensorDescriptor() { return tensor_desc_type{}; }

    __host__ __device__ constexpr const Index& GetUpperIndex() const { return mIndex; }

    __host__ __device__ constexpr const Index& GetIndex() const { return mIndex; }

    __host__ __device__ constexpr const index_t& GetOffset() const { return mOffset; }

    __host__ __device__ constexpr type operator+=(const Index& idx_diff)
    {
        // mIndex is updated here, but some (or all) of its entries may never be used;
        // the compiler should remove those entries as dead code
        mIndex += idx_diff;

        mOffset += tensor_desc_type::CalculateOffsetDiff(idx_diff);

        return *this;
    }

    __host__ __device__ constexpr type operator-=(const Index& idx_diff)
    {
        // mIndex is updated here, but some (or all) of its entries may never be used;
        // the compiler should remove those entries as dead code
        mIndex -= idx_diff;

        mOffset -= tensor_desc_type::CalculateOffsetDiff(idx_diff);

        return *this;
    }

    __host__ __device__ constexpr type operator+(const Index& idx_diff) const
    {
        type coord = *this;
        coord += idx_diff;
        return coord;
    }

    __host__ __device__ constexpr type operator-(const Index& idx_diff) const
    {
        type coord = *this;
        coord -= idx_diff;
        return coord;
    }

    __host__ __device__ static constexpr index_t CalculateOffsetDiff(const Index& idx_diff)
    {
        return tensor_desc_type::CalculateOffsetDiff(idx_diff);
    }

    // evaluated at run-time
    __host__ __device__ constexpr bool IsUpperIndexValid() const
    {
        return tensor_desc_type::IsUpperIndexValid(GetUpperIndex());
    }

    // evaluated at run-time
    __host__ __device__ constexpr bool IsOffsetValid() const
    {
        // for a native tensor, the offset is valid if the upper index is valid
        return IsUpperIndexValid();
    }

    // evaluated at compile-time
    __host__ __device__ static constexpr bool IsOffsetValidAssumingUpperIndexIsValid() { return true; }

    private:
    // mIndex may be saved and updated, however, the value of some (or all) of its entries may
    // never be used. The compiler should be able to remove these entries, as well as their
    // calculation, as dead code.
    // TODO: make sure the compiler indeed removes this dead code
    Index mIndex;
    index_t mOffset;
};
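For a native (lengths plus strides) tensor, operator+= above keeps the cached offset consistent with the multi-index by adding only the dot product of the step and the strides, rather than recomputing the full offset. A small host-side sketch of that invariant, using hypothetical 3-D strides rather than the CK types:

```cpp
#include <cassert>

int main()
{
    const int stride[3] = {12, 4, 1}; // packed strides of a 2x3x4 tensor
    int idx[3]          = {1, 2, 3};

    // full computation: offset = sum(idx[i] * stride[i])
    int offset = idx[0] * stride[0] + idx[1] * stride[1] + idx[2] * stride[2];

    // move by a step and update the offset incrementally, as operator+= does
    const int step[3] = {0, 1, -2};
    for(int i = 0; i < 3; ++i)
    {
        idx[i] += step[i];
        offset += step[i] * stride[i];
    }

    // the incrementally maintained offset matches a fresh dot product
    assert(offset == idx[0] * stride[0] + idx[1] * stride[1] + idx[2] * stride[2]);
    return 0;
}
```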
// tensor coordinate for a transformed tensor
template <typename TransformedTensorDesc>
struct TransformedTensorCoordinate
{
    using tensor_desc_type = TransformedTensorDesc;

    using LowerCoord =
        typename TensorCoordinate<decltype(tensor_desc_type::GetLowerTensorDescriptor())>::type;
    using UpperCoord = TransformedTensorCoordinate;

    static constexpr index_t nDim = tensor_desc_type::GetNumOfDimension();

    using UpperIndex = MultiIndex<nDim>;

    __host__ __device__ constexpr TransformedTensorCoordinate(UpperIndex idx)
        : mIndexUp{idx}, mCoordLow{tensor_desc_type::CalculateLowerIndex(idx)}
    {
    }

    template <typename... Xs>
    __host__ __device__ constexpr TransformedTensorCoordinate(Xs... xs)
        : TransformedTensorCoordinate(UpperIndex{xs...})
    {
    }

    template <index_t... Xs>
    __host__ __device__ constexpr TransformedTensorCoordinate(Sequence<Xs...>)
        : TransformedTensorCoordinate(UpperIndex{Xs...})
    {
    }

    __host__ __device__ static constexpr auto GetTensorDescriptor() { return tensor_desc_type{}; }

    __host__ __device__ constexpr const LowerCoord& GetLowerCoordinate() const { return mCoordLow; }

    __host__ __device__ constexpr const UpperIndex& GetUpperIndex() const { return mIndexUp; }

    __host__ __device__ constexpr const UpperIndex& GetIndex() const { return GetUpperIndex(); }

    __host__ __device__ constexpr const index_t& GetOffset() const
    {
        return GetLowerCoordinate().GetOffset();
    }

    __host__ __device__ constexpr UpperCoord operator+=(const UpperIndex& idx_up_diff)
    {
        // For transformation of a multi-index difference, not all transformation functions need
        // to know the old lower index or the old upper index. We pass both of them to the
        // transformation function; the transformation function itself decides whether to use them.
        mCoordLow += tensor_desc_type::CalculateLowerIndexDiff(
            idx_up_diff, GetIndex(), GetLowerCoordinate().GetIndex());

        // mIndexUp is updated here, but some (or all) of its entries may never be used;
        // the compiler should remove those entries as dead code
        mIndexUp += idx_up_diff;

        return *this;
    }

    __host__ __device__ constexpr UpperCoord operator-=(const UpperIndex& idx_up_diff)
    {
        mCoordLow -= tensor_desc_type::CalculateLowerIndexDiff(
            idx_up_diff, GetIndex(), GetLowerCoordinate().GetIndex());

        // mIndexUp is updated here, but some (or all) of its entries may never be used;
        // the compiler should remove those entries as dead code
        mIndexUp -= idx_up_diff;

        return *this;
    }

    __host__ __device__ constexpr UpperCoord operator+(const UpperIndex& idx_up_diff) const
    {
        UpperCoord coord_up = *this;
        coord_up += idx_up_diff;
        return coord_up;
    }

    __host__ __device__ constexpr UpperCoord operator-(const UpperIndex& idx_up_diff) const
    {
        UpperCoord coord_up = *this;
        coord_up -= idx_up_diff;
        return coord_up;
    }

    // Calculate the offset diff without updating the tensor coordinate.
    // If idx_up_diff is known at compile time, and has only non-zero entries on linear dimensions,
    // then all calculation can be done at compile time.
    // TODO: this function is not compiled to the expected ISA
    __host__ __device__ constexpr index_t CalculateOffsetDiff(const UpperIndex& idx_up_diff) const
    {
        // For transformation of a multi-index difference, not all transformation functions need
        // to know the old lower index or the old upper index. We pass both of them to the
        // transformation function; the transformation function itself decides whether to use them.
        const auto idx_low_diff = tensor_desc_type::CalculateLowerIndexDiff(
            idx_up_diff, GetIndex(), GetLowerCoordinate().GetIndex());

        return GetLowerCoordinate().CalculateOffsetDiff(idx_low_diff);
    }

    // evaluated at run-time
    __host__ __device__ constexpr bool IsUpperIndexValid() const
    {
        return tensor_desc_type::IsUpperIndexValid(GetUpperIndex());
    }

    // evaluated at run-time
    __host__ __device__ constexpr bool IsOffsetValid() const
    {
        return IsUpperIndexValid() && GetLowerCoordinate().IsOffsetValid();
    }

    // most of the evaluation is done at compile-time
    __host__ __device__ constexpr bool IsLowerIndexValidAssumingUpperIndexIsValid() const
    {
        return tensor_desc_type::IsLowerIndexValidAssumingUpperIndexIsValid(
            GetLowerCoordinate().GetIndex());
    }

    // most of the evaluation is done at compile-time
    __host__ __device__ constexpr bool IsOffsetValidAssumingUpperIndexIsValid() const
    {
        return IsLowerIndexValidAssumingUpperIndexIsValid() &&
               GetLowerCoordinate().IsOffsetValidAssumingUpperIndexIsValid();
    }

    private:
    // mIndexUp may be calculated and updated, however, the value of some (or all) of its entries
    // may never be used. The compiler should be able to remove these entries, as well as their
    // calculation, as dead code.
    // TODO: make sure the compiler indeed removes this dead code
    UpperIndex mIndexUp;
    LowerCoord mCoordLow;
};

template <typename TensorDesc>
struct TensorCoordinate
{
    private:
    template <typename... Ts>
    __host__ __device__ static constexpr auto MakeDummyTensorCoordinate(NativeTensorDescriptor<Ts...>)
    {
        return NativeTensorCoordinate<NativeTensorDescriptor<Ts...>>(
            make_zero_multi_index<TensorDesc::GetNumOfDimension()>());
    }

    template <typename... Ts>
    __host__ __device__ static constexpr auto
    MakeDummyTensorCoordinate(TransformedTensorDescriptor<Ts...>)
    {
        return TransformedTensorCoordinate<TransformedTensorDescriptor<Ts...>>(
            make_zero_multi_index<TensorDesc::GetNumOfDimension()>());
    }

    public:
    using type = decltype(MakeDummyTensorCoordinate(TensorDesc{}));
};

} // namespace ck
#endif
composable_kernel/include/tensor_description/tensor_descriptor.hpp deleted 100644 → 0 @ b8b2d0a6
#ifndef CK_TENSOR_DESCRIPTOR_HPP
#define CK_TENSOR_DESCRIPTOR_HPP

#include "common_header.hpp"
#include "dimension.hpp"
#include "multi_index_transform.hpp"

namespace ck {

// tensor descriptor for a "native tensor"
// A "native tensor" is a "true" tensor that can be represented by Lengths and Strides
template <typename... NativeDimensions>
struct NativeTensorDescriptor
{
    using type = NativeTensorDescriptor;

    static constexpr index_t nDim = sizeof...(NativeDimensions);

    static constexpr auto mDimensions = make_tuple(NativeDimensions{}...);

    using Index = MultiIndex<nDim>;

    __host__ __device__ static constexpr auto GetNumOfDimension() { return Number<nDim>{}; }

    template <index_t IDim>
    __host__ __device__ static constexpr auto GetLength(Number<IDim>)
    {
        return mDimensions.At(Number<IDim>{}).GetLength();
    }

    template <index_t IDim>
    __host__ __device__ static constexpr auto GetStride(Number<IDim>)
    {
        return mDimensions.At(Number<IDim>{}).GetStride();
    }

    template <index_t... IDims>
    __host__ __device__ static constexpr auto GetLengths(Sequence<IDims...>)
    {
        return Sequence<GetLength(Number<IDims>{})...>{};
    }

    template <index_t... IDims>
    __host__ __device__ static constexpr auto GetStrides(Sequence<IDims...>)
    {
        return Sequence<GetStride(Number<IDims>{})...>{};
    }

    template <index_t IDim, index_t... IDims>
    __host__ __device__ static constexpr auto GetLengths(Number<IDim>, Number<IDims>...)
    {
        return GetLengths(Sequence<IDim, IDims...>{});
    }

    template <index_t IDim, index_t... IDims>
    __host__ __device__ static constexpr auto GetStrides(Number<IDim>, Number<IDims>...)
    {
        return GetStrides(Sequence<IDim, IDims...>{});
    }

    __host__ __device__ static constexpr auto GetLengths()
    {
        return GetLengths(typename arithmetic_sequence_gen<0, nDim, 1>::type{});
    }

    __host__ __device__ static constexpr auto GetStrides()
    {
        return GetStrides(typename arithmetic_sequence_gen<0, nDim, 1>::type{});
    }

    __host__ __device__ static constexpr index_t GetElementSize()
    {
        return reduce_on_sequence(GetLengths(), math::multiplies<index_t>{}, Number<1>{});
    }

    __host__ __device__ static constexpr index_t GetElementSpace()
    {
        return reduce_on_sequence(
            (GetLengths() - Number<1>{}) * GetStrides(), math::plus<index_t>{}, Number<1>{});
    }

    // TODO: this cannot return constexpr because of the use of a lambda
    __host__ __device__ static constexpr index_t CalculateOffset(const Index& idx)
    {
        index_t offset = 0;

        static_for<0, nDim, 1>{}([&](auto idim) { offset += idx[idim] * GetStride(idim); });

        return offset;
    }

    __host__ __device__ static constexpr index_t CalculateOffsetDiff(const Index& idx_diff)
    {
        index_t offset_diff = 0;

        static_for<0, nDim, 1>{}([&](auto idim) { offset_diff += idx_diff[idim] * GetStride(idim); });

        return offset_diff;
    }

    template <index_t IDim>
    __host__ __device__ static constexpr bool IsLinearDimension(Number<IDim>)
    {
        return true;
    }

    __host__ __device__ static constexpr auto GetLinearDimensionMask()
    {
        return typename uniform_sequence_gen<nDim, 1>::type{};
    }

    __host__ __device__ static constexpr auto GetNonLinearDimensionMask()
    {
        return typename uniform_sequence_gen<nDim, 0>::type{};
    }

    __host__ __device__ static constexpr auto GetNonLinearDimensions() { return Sequence<>{}; }

    __host__ __device__ static constexpr auto GetNonLinearIndependentDimensionGroups()
    {
        return Tuple<>{};
    }

    // a multi-index is valid if there is a corresponding point for it in the tensor
    __host__ __device__ static constexpr bool IsUpperIndexValid(const Index& idx)
    {
        bool flag = true;

        for(index_t i = 0; i < nDim; ++i)
        {
            flag = flag && idx[i] >= 0 && idx[i] < GetLengths()[i];
        }

        return flag;
    }
};
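GetElementSize and GetElementSpace above differ once the strides are not packed: the former is the product of the lengths, while the latter is the storage span up to and including the last addressable element, i.e. sum((length_i - 1) * stride_i) + 1. A quick stand-alone check of that formula, with made-up lengths and strides rather than CK types:

```cpp
#include <cstdio>

int main()
{
    // a 2x3 tensor whose rows are padded to a stride of 4 elements
    const int lengths[2] = {2, 3};
    const int strides[2] = {4, 1};

    int element_size  = 1; // number of logical elements
    int element_space = 1; // storage span, including the padding holes
    for(int i = 0; i < 2; ++i)
    {
        element_size *= lengths[i];
        element_space += (lengths[i] - 1) * strides[i];
    }

    std::printf("size = %d, space = %d\n", element_size, element_space); // size = 6, space = 7
    return 0;
}
```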
// tensor descriptor for a "transformed tensor"
template <typename LowTensorDescriptor, // NativeTensorDescriptor or TransformedTensorDescriptor
          typename Transforms,          // Tuple<MultIndexTransforms...>
          typename LowDimensionIds,     // Tuple<Sequence<...>>
          typename UpDimensionIds>      // Tuple<Sequence<...>>
struct TransformedTensorDescriptor
{
    using type = TransformedTensorDescriptor;

    static constexpr index_t nTransform = Transforms::Size();

    struct lambda_merge_sequences
    {
        template <typename... Seqs>
        __host__ __device__ constexpr auto operator()(Seqs... seqs) const
        {
            return merge_sequences(seqs...);
        }
    };

    __host__ __device__ static constexpr auto GetNumOfLowerDimension()
    {
        // Here, we assume all lower dimensions are active.
        // TODO: sanity-check that all lower dimensions are indeed active
        using duplicated_low_active_dims =
            decltype(unpack(lambda_merge_sequences{}, LowDimensionIds{}));

        using low_active_dims = typename sequence_unique_sort<duplicated_low_active_dims,
                                                              math::less<index_t>,
                                                              math::equal<index_t>>::type;

        return low_active_dims::Size();
    }

    __host__ __device__ static constexpr auto GetNumOfUpperDimension()
    {
        using duplicated_up_active_dims =
            decltype(unpack(lambda_merge_sequences{}, UpDimensionIds{}));

        using up_active_dims = typename sequence_unique_sort<duplicated_up_active_dims,
                                                             math::less<index_t>,
                                                             math::equal<index_t>>::type;

        return up_active_dims::Size();
    }

    static constexpr index_t nDimUp  = GetNumOfUpperDimension();
    static constexpr index_t nDimLow = GetNumOfLowerDimension();

    using UpperIndex = MultiIndex<nDimUp>;
    using LowerIndex = MultiIndex<nDimLow>;

    __host__ __device__ constexpr TransformedTensorDescriptor()
    {
        static_assert(nTransform == Transforms::Size() && nTransform == LowDimensionIds::Size() &&
                          nTransform == UpDimensionIds::Size(),
                      "wrong! # of transformations not the same");

        // sanity check:
        //   LowDimensionIds should include all low-dimensions,
        //   UpDimensionIds should include all up-dimensions
        using mingled_up_dimension_ids =
            decltype(unpack(lambda_merge_sequences{}, UpDimensionIds{}));

        using sorted_up_dimension_ids =
            typename sequence_sort<mingled_up_dimension_ids, math::less<index_t>>::type;

        static_assert(sorted_up_dimension_ids::Size() == nDimUp &&
                          is_valid_sequence_map<sorted_up_dimension_ids>{},
                      "wrong! UpDimensionIds is not configured correctly");

        using mingled_low_dimension_ids =
            decltype(unpack(lambda_merge_sequences{}, LowDimensionIds{}));

        using sorted_low_dimension_ids =
            typename sequence_sort<mingled_low_dimension_ids, math::less<index_t>>::type;

        static_assert(sorted_low_dimension_ids::Size() == nDimLow &&
                          is_valid_sequence_map<sorted_low_dimension_ids>{},
                      "wrong! LowDimensionIds is not configured correctly");

        // TODO: sanity check: while an up-dimension could be associated with multiple
        //   transformations, a low-dimension should be associated with only one transformation
        // TODO: sanity check: GetLowerLengths of each transform should be consistent with the
        //   lengths of the lower tensor descriptor
    }

    __host__ __device__ static constexpr auto GetNumOfDimension()
    {
        return GetNumOfUpperDimension();
    }

    __host__ __device__ static constexpr auto GetLowerTensorDescriptor()
    {
        return LowTensorDescriptor{};
    }

    struct lambda_GetUpperLengths
    {
        template <typename Transform>
        __host__ __device__ constexpr auto operator()(const Transform& tran) const
        {
            return tran.GetUpperLengths();
        }
    };

    __host__ __device__ static constexpr auto GetUpperLengths()
    {
        constexpr auto tuple_of_up_lengths =
            transform_tuples(lambda_GetUpperLengths{}, Transforms{});

        constexpr auto mingled_up_lengths = unpack(lambda_merge_sequences{}, tuple_of_up_lengths);

        constexpr auto mingled_up_dimension_ids =
            unpack(lambda_merge_sequences{}, UpDimensionIds{});

        // TODO: sanity-check that mingled_up_dimension_ids contains all upper dimensions
        // TODO: sanity-check that mingled_up_lengths has no conflicting upper length

        // sort by upper-dimension-ids
        using sort_up_dimension_ids = sequence_unique_sort<decltype(mingled_up_dimension_ids),
                                                           math::less<index_t>,
                                                           math::equal<index_t>>;

        // sanity check: the sorted upper-dimension-ids should be Sequence<0, 1, ..., nDimUp-1>
        static_assert(is_same<typename sort_up_dimension_ids::type,
                              typename arithmetic_sequence_gen<0, nDimUp, 1>::type>{},
                      "wrong! UpDimensionIds is not configured correctly");

        constexpr auto sorted2unsorted_map = typename sort_up_dimension_ids::sorted2unsorted_map{};

        constexpr auto sorted_up_lengths =
            pick_sequence_elements_by_ids(mingled_up_lengths, sorted2unsorted_map);

        return sorted_up_lengths;
    }

    __host__ __device__ static constexpr auto GetLengths() { return GetUpperLengths(); }

    template <index_t IDim>
    __host__ __device__ static constexpr auto GetLength(Number<IDim>)
    {
        return GetLengths()[IDim];
    }

    template <index_t... IDims>
    __host__ __device__ static constexpr auto GetLengths(Sequence<IDims...>)
    {
        return Sequence<GetLength(Number<IDims>{})...>{};
    }

    template <index_t IDim, index_t... IDims>
    __host__ __device__ static constexpr auto GetLengths(Number<IDim>, Number<IDims>...)
    {
        return GetLengths(Sequence<IDim, IDims...>{});
    }

    __host__ __device__ static constexpr index_t GetElementSize()
    {
        return reduce_on_sequence(GetLengths(), math::multiplies<index_t>{}, Number<1>{});
    }

    __host__ __device__ static constexpr index_t GetElementSpace()
    {
        // TODO: is this the correct definition for a transformed tensor?
        return GetLowerTensorDescriptor().GetElementSpace();
    }

    // TODO: right now the return value is not constexpr because of the use of a non-constexpr lambda
    __host__ __device__ static constexpr LowerIndex CalculateLowerIndex(const UpperIndex& idx_up)
    {
        LowerIndex idx_low;

        static_for<0, nTransform, 1>{}([&](auto itran) {
            constexpr auto tran = Transforms{}.At(itran);

            const auto idx_up_part = pick_container_element(idx_up, UpDimensionIds{}.At(itran));
            auto idx_low_part      = pick_container_element(idx_low, LowDimensionIds{}.At(itran));

            // this assumes each lower (single) index is only associated with one transformation,
            // which is required for index transformation, and has been checked during the
            // constructor of TransformedTensorDescriptor
            idx_low_part = tran.CalculateLowerIndex(to_multi_index(idx_up_part));
        });

        return idx_low;
    }

    // TODO: right now the return value is not constexpr because of the use of a non-constexpr lambda
    __host__ __device__ static constexpr LowerIndex CalculateLowerIndexDiff(
        const UpperIndex& idx_up_diff, const UpperIndex& idx_up_old, const LowerIndex& idx_low_old)
    {
        LowerIndex idx_low_diff;

        static_for<0, nTransform, 1>{}([&](auto itran) {
            constexpr auto tran = Transforms{}.At(itran);

            const auto idx_up_diff_part =
                pick_container_element(idx_up_diff, UpDimensionIds{}.At(itran));

            const auto idx_up_old_part =
                pick_container_element(idx_up_old, UpDimensionIds{}.At(itran));

            const auto idx_low_old_part =
                pick_container_element(idx_low_old, LowDimensionIds{}.At(itran));

            auto idx_low_diff_part =
                pick_container_element(idx_low_diff, LowDimensionIds{}.At(itran));

            // this assumes each lower (single) index is associated with only one transformation,
            // which is required for index transformation, and has been checked during the
            // constructor of TransformedTensorDescriptor
            idx_low_diff_part = tran.CalculateLowerIndexDiff(to_multi_index(idx_up_diff_part),
                                                             to_multi_index(idx_up_old_part),
                                                             to_multi_index(idx_low_old_part));
        });

        return idx_low_diff;
    }

    __host__ __device__ static constexpr index_t CalculateOffset(const UpperIndex& idx_up)
    {
        return GetLowerTensorDescriptor().CalculateOffset(CalculateLowerIndex(idx_up));
    }

    struct lambda_sequence_logical_and
    {
        template <typename... Seqs>
        __host__ __device__ constexpr auto operator()(Seqs...) const
        {
            return typename sequence_reduce<logical_and<index_t>, Seqs...>::type{};
        }
    };

    template <typename T>
    struct lambda_is_true
    {
        __host__ __device__ constexpr auto operator()(const T& x) const
        {
            // TODO: remove static_cast once Sequence can take bool as entries
            return static_cast<bool>(x) == true;
        }
    };

    struct lambda_get_linear_dimension_mask_of_single_tranform
    {
        // check only one transform at a time
        template <typename Transform, typename LowDimensionId, typename UpDimensionId>
        __host__ __device__ constexpr auto operator()(Transform, LowDimensionId, UpDimensionId) const
        {
            // judge if the transformation is linear
            constexpr bool is_linear_transform = Transform::IsLinearTransform();

            // judge if all lower dimensions are linear
            constexpr bool are_all_low_dim_linear = sequence_all_of(
                pick_sequence_elements_by_ids(GetLowerTensorDescriptor().GetLinearDimensionMask(),
                                              LowDimensionId{}),
                lambda_is_true<index_t>{});

            // create a linear mask for the upper dimensions
            constexpr bool are_up_dim_linear = is_linear_transform && are_all_low_dim_linear;

            constexpr auto mask_of_up_linear_dims = modify_sequence_elements_by_ids(
                typename uniform_sequence_gen<nDimUp, 1>::type{},
                typename uniform_sequence_gen<UpDimensionId::Size(), are_up_dim_linear>::type{},
                UpDimensionId{});

            return mask_of_up_linear_dims;
        }
    };

    // TODO: this is a hack; transform_tuples() doesn't compile, it complains about constexpr
    template <typename F, typename X, typename Y, typename Z, index_t... Is>
    __host__ __device__ static constexpr auto
    dummy_transform_tuples_impl(F f, X x, Y y, Z z, Sequence<Is...>)
    {
        return make_tuple(f(x.At(Number<Is>{}), y.At(Number<Is>{}), z.At(Number<Is>{}))...);
    }

    __host__ __device__ static constexpr auto GetLinearDimensionMask()
    {
#if 0
        // create tuple of linear dimension masks, for all transformations
        // TODO: this doesn't compile, because transform_tuples() complains about constexpr
        constexpr auto tuple_of_linear_dimension_mask =
            transform_tuples(lambda_get_linear_dimension_mask_of_single_tranform{},
                             Transforms{},
                             LowDimensionIds{},
                             UpDimensionIds{});
#else
        // create tuple of linear dimension masks, for all transformations
        // TODO: this is a hack
        constexpr auto tuple_of_linear_dimension_mask = dummy_transform_tuples_impl(
            lambda_get_linear_dimension_mask_of_single_tranform{},
            Transforms{},
            LowDimensionIds{},
            UpDimensionIds{},
            typename arithmetic_sequence_gen<0, Transforms::Size(), 1>::type{});
#endif

        // reduce the tuple of masks into one mask
        constexpr auto linear_dimension_mask =
            unpack(lambda_sequence_logical_and{}, tuple_of_linear_dimension_mask);

        return linear_dimension_mask;
    }

    __host__ __device__ static constexpr auto GetNonLinearDimensionMask()
    {
        return GetLinearDimensionMask().Transform(logical_not<index_t>{});
    }

    template <index_t IDim>
    __host__ __device__ static constexpr bool IsLinearDimension(Number<IDim>)
    {
        return GetLinearDimensionMask().At(Number<IDim>{});
    }

    __host__ __device__ static constexpr auto GetLinearDimensions()
    {
        constexpr auto linear_dimension_mask = GetLinearDimensionMask();

        return pick_sequence_elements_by_mask(
            typename arithmetic_sequence_gen<0, nDimUp, 1>::type{}, linear_dimension_mask);
    }

    __host__ __device__ static constexpr auto GetNonLinearDimensions()
    {
        constexpr auto nonlinear_dimension_mask = GetNonLinearDimensionMask();

        return pick_sequence_elements_by_mask(
            typename arithmetic_sequence_gen<0, nDimUp, 1>::type{}, nonlinear_dimension_mask);
    }

#if 0
    __host__ __device__ static constexpr auto GetNonLinearIndependentDimensionGroups()
    {
        // TODO: not implemented
    }
#endif

    // a multi-index is valid if there is a corresponding point for it in the tensor
    __host__ __device__ constexpr bool IsUpperIndexValid(const UpperIndex& idx_up) const
    {
        bool flag = true;

        for(index_t i = 0; i < nDimUp; ++i)
        {
            flag = flag && idx_up[i] >= 0 && idx_up[i] < GetLengths()[i];
        }

        return flag;
    }

    // This function is for optimization purposes; it is called by the tensor coordinate.
    // It tells you whether a lower index is valid, assuming the upper index is valid.
    __host__ __device__ static constexpr bool
    IsLowerIndexValidAssumingUpperIndexIsValid(const LowerIndex& idx_low)
    {
        bool flag = true;

        static_for<0, nTransform, 1>{}([&](auto itran) {
            constexpr auto tran = Transforms{}.At(itran);

            // check a transformation only if it does not always have a valid mapping
            constexpr bool is_valid_up_always_mapped_to_valid_low =
                decltype(tran)::IsValidUpperIndexAlwaysMappedToValidLowerIndex();

            if(!is_valid_up_always_mapped_to_valid_low)
            {
                constexpr auto low_dims_part = LowDimensionIds{}.At(itran);
                constexpr auto low_lengths_part =
                    GetLowerTensorDescriptor().GetLengths(low_dims_part);

                const auto idx_low_part =
                    to_multi_index(pick_container_element(idx_low, low_dims_part));

                static_for<0, decltype(low_dims_part)::Size(), 1>{}([&](auto i) {
                    flag = flag && idx_low_part[i] >= 0 && idx_low_part[i] < low_lengths_part[i];
                });
            }
        });

        return flag;
    }
};

} // namespace ck
#endif
composable_kernel/include/tensor_description/tensor_descriptor_helper.hpp deleted 100644 → 0 @ b8b2d0a6
#ifndef CK_TENSOR_DESCRIPTOR_HELPER_HPP
#define CK_TENSOR_DESCRIPTOR_HELPER_HPP

#include "common_header.hpp"
#include "tensor_descriptor.hpp"

namespace ck {

template <typename Lengths>
__host__ __device__ constexpr auto calculate_tensor_strides_packed(Lengths)
{
    return reverse_inclusive_scan_sequence(
               Lengths{}.PopFront(), math::multiplies<index_t>{}, Number<1>{})
        .PushBack(Number<1>{});
}
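calculate_tensor_strides_packed builds row-major strides by scanning the lengths from the right: each stride is the product of all lengths to its right, with the last stride fixed to 1. A small host-side sketch of the same scan, written in plain C++ with hypothetical lengths rather than the CK sequence machinery:

```cpp
#include <cstdio>

int main()
{
    const int lengths[4] = {2, 3, 4, 5};
    int strides[4];

    // reverse inclusive scan with multiplication, then the trailing stride of 1
    strides[3] = 1;
    for(int i = 2; i >= 0; --i)
        strides[i] = strides[i + 1] * lengths[i + 1];

    // prints: 60 20 5 1
    std::printf("%d %d %d %d\n", strides[0], strides[1], strides[2], strides[3]);
    return 0;
}
```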
template <typename Lengths, index_t Align>
__host__ __device__ constexpr auto calculate_tensor_strides_aligned(Lengths, Number<Align>)
{
    constexpr index_t L_back_align =
        Align * math::integer_divide_ceiler<index_t>{}(Lengths{}.Back(), Align);

    return calculate_tensor_strides_packed(
        Lengths{}.Modify(Number<Lengths{}.GetSize() - 1>{}, Number<L_back_align>{}));
}

template <index_t... Lengths, index_t... Strides>
__host__ __device__ constexpr auto make_native_tensor_descriptor(Sequence<Lengths...>,
                                                                 Sequence<Strides...>)
{
    return NativeTensorDescriptor<NativeDimension<Lengths, Strides>...>{};
}

template <typename Lengths>
__host__ __device__ constexpr auto make_native_tensor_descriptor_packed(Lengths)
{
    constexpr auto strides = calculate_tensor_strides_packed(Lengths{});

    return make_native_tensor_descriptor(Lengths{}, strides);
}

template <typename Lengths, index_t Align>
__host__ __device__ constexpr auto make_native_tensor_descriptor_aligned(Lengths, Number<Align>)
{
    constexpr auto strides = calculate_tensor_strides_aligned(Lengths{}, Number<Align>{});

    return make_native_tensor_descriptor(Lengths{}, strides);
}

template <typename LowTensorDescriptor,
          typename Transforms,
          typename LowDimensionIds,
          typename UpDimensionIds>
__host__ __device__ constexpr auto
transform_tensor_descriptor(LowTensorDescriptor, Transforms, LowDimensionIds, UpDimensionIds)
{
    return TransformedTensorDescriptor<LowTensorDescriptor,
                                       Transforms,
                                       LowDimensionIds,
                                       UpDimensionIds>{};
}

template <typename LowerTensorDescriptor,
          index_t... LowerLengths,
          index_t... LowerDimensionIds,
          index_t... UpperDimensionIds>
__host__ __device__ constexpr auto
reorder_transformed_tensor_descriptor_impl(LowerTensorDescriptor,
                                           Sequence<LowerLengths...>,
                                           Sequence<LowerDimensionIds...>,
                                           Sequence<UpperDimensionIds...>)
{
    return TransformedTensorDescriptor<LowerTensorDescriptor,
                                       Tuple<PassThrough<LowerLengths>...>,
                                       Tuple<Sequence<LowerDimensionIds>...>,
                                       Tuple<Sequence<UpperDimensionIds>...>>{};
}

// reorder a NativeTensorDescriptor
template <typename... Ts, typename MapLower2Upper>
__host__ __device__ constexpr auto
reorder_tensor_descriptor_given_lower2upper(NativeTensorDescriptor<Ts...>, MapLower2Upper)
{
    static_assert(is_valid_sequence_map<MapLower2Upper>{},
                  "wrong! MapLower2Upper is not a valid map");

    constexpr auto old_desc = NativeTensorDescriptor<Ts...>{};

    static_assert(old_desc.GetNumOfDimension() == MapLower2Upper::Size(), "wrong!");

    constexpr auto new_lengths = old_desc.GetLengths().ReorderGivenOld2New(MapLower2Upper{});
    constexpr auto new_strides = old_desc.GetStrides().ReorderGivenOld2New(MapLower2Upper{});

    return make_native_tensor_descriptor(new_lengths, new_strides);
}

// reorder a TransformedTensorDescriptor
template <typename... Ts, typename MapLower2Upper>
__host__ __device__ constexpr auto
reorder_tensor_descriptor_given_lower2upper(TransformedTensorDescriptor<Ts...>, MapLower2Upper)
{
    static_assert(is_valid_sequence_map<MapLower2Upper>{},
                  "wrong! MapLower2Upper is not a valid map");

    constexpr auto low_desc = TransformedTensorDescriptor<Ts...>{};

    static_assert(low_desc.GetNumOfDimension() == MapLower2Upper::Size(), "wrong!");

    return reorder_transformed_tensor_descriptor_impl(
        low_desc,
        low_desc.GetLengths(),
        typename arithmetic_sequence_gen<0, low_desc.GetNumOfDimension(), 1>::type{},
        MapLower2Upper{});
}

template <typename LowerTensorDescriptor, typename MapUpper2Lower>
__host__ __device__ constexpr auto
reorder_tensor_descriptor_given_upper2lower(LowerTensorDescriptor, MapUpper2Lower)
{
    return reorder_tensor_descriptor_given_lower2upper(
        LowerTensorDescriptor{}, typename sequence_map_inverse<MapUpper2Lower>::type{});
}

template <typename Lengths, typename Strides>
__host__ __device__ constexpr bool are_dimensions_unfoldable(Lengths, Strides)
{
    static_assert(Lengths::Size() == Strides::Size(), "wrong!");

    bool flag = true;

    for(index_t i = 0; i < Lengths::Size() - 1; ++i)
    {
        flag = flag && Strides::At(i) == Strides::At(i + 1) * Lengths::At(i + 1);
    }

    return flag;
}
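are_dimensions_unfoldable encodes the usual contiguity test: adjacent dimensions can be folded into one only when stride[i] == stride[i+1] * length[i+1], i.e. there is no padding gap between them. A tiny illustration with made-up numbers, not CK code:

```cpp
#include <cstdio>

// true when every adjacent pair of dimensions is contiguous
bool unfoldable(const int* lengths, const int* strides, int n)
{
    bool flag = true;
    for(int i = 0; i < n - 1; ++i)
        flag = flag && strides[i] == strides[i + 1] * lengths[i + 1];
    return flag;
}

int main()
{
    const int lengths[3] = {2, 3, 4};
    const int packed[3]  = {12, 4, 1}; // contiguous: 12 == 4*3 and 4 == 1*4
    const int padded[3]  = {16, 4, 1}; // outer dimension padded: 16 != 4*3

    std::printf("%d %d\n", unfoldable(lengths, packed, 3), unfoldable(lengths, padded, 3)); // 1 0
    return 0;
}
```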
// unfold only supports NativeTensorDescriptor, for now
template <index_t FirstUnfoldDim, index_t LastUnfoldDim, typename... Ts>
__host__ __device__ constexpr auto unfold_tensor_descriptor(NativeTensorDescriptor<Ts...> desc,
                                                            Number<FirstUnfoldDim>,
                                                            Number<LastUnfoldDim>)
{
    constexpr index_t nDim = desc.GetNumOfDimension();

    static_assert(FirstUnfoldDim >= 0 && LastUnfoldDim < nDim && FirstUnfoldDim <= LastUnfoldDim,
                  "wrong! should have FirstUnfoldDim <= LastUnfoldDim!");

    // left and right
    constexpr auto left = typename arithmetic_sequence_gen<0, FirstUnfoldDim, 1>::type{};
    constexpr auto middle =
        typename arithmetic_sequence_gen<FirstUnfoldDim, LastUnfoldDim + 1, 1>::type{};
    constexpr auto right = typename arithmetic_sequence_gen<LastUnfoldDim + 1, nDim, 1>::type{};

    // sanity-check if unfold-able
    static_assert(are_dimensions_unfoldable(desc.GetLengths(middle), desc.GetStrides(middle)),
                  "wrong! not unfold-able");

    // unfolded length, stride
    constexpr index_t unfold_length =
        reduce_on_sequence(desc.GetLengths(middle), math::multiplies<index_t>{}, Number<1>{});

    constexpr index_t unfold_stride = desc.GetStride(Number<LastUnfoldDim>{});

    // new lengths, strides
    constexpr auto new_lengths =
        desc.GetLengths(left).PushBack(Number<unfold_length>{}).PushBack(desc.GetLengths(right));

    constexpr auto new_strides =
        desc.GetStrides(left).PushBack(Number<unfold_stride>{}).PushBack(desc.GetStrides(right));

    return make_native_tensor_descriptor(new_lengths, new_strides);
}

} // namespace ck
#endif
composable_kernel/include/tensor_operation/blockwise_batched_gemm.hpp deleted 100644 → 0 @ b8b2d0a6
#ifndef CK_BLOCKWISE_BATCHED_GEMM_HPP
#define CK_BLOCKWISE_BATCHED_GEMM_HPP

#include "common_header.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "threadwise_gemm.hpp"

#ifndef CK_BLOCKWISE_GEMM_USE_AMD_INLINE_ASM
#define CK_BLOCKWISE_GEMM_USE_AMD_INLINE_ASM 1
#endif

namespace ck {

template <index_t BlockSize,
          class BlockMatrixA,
          class BlockMatrixB,
          class ThreadMatrixC,
          index_t BlockMatrixStrideA,
          index_t BlockMatrixStrideB,
          index_t ThreadMatrixStrideC,
          index_t BatchSize,
          index_t MPerThreadSubC,
          index_t NPerThreadSubC,
          index_t MLevel0Cluster,
          index_t NLevel0Cluster,
          index_t MLevel1Cluster,
          index_t NLevel1Cluster,
          index_t KPerThreadLoop,
          index_t BatchPerThread,
          index_t DataPerReadA,
          index_t DataPerReadB>
struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
{
    index_t mMyThreadOffsetA = 0;
    index_t mMyThreadOffsetB = 0;

    struct MatrixIndex
    {
        index_t batch;
        index_t row;
        index_t col;
    };

    __device__ BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2()
    {
        static_assert(BatchSize % BatchPerThread == 0,
                      "wrong! BatchSize is not dividable by BatchPerThread");

        constexpr index_t BatchThreadWork = BatchSize / BatchPerThread;

        constexpr index_t ThreadPerLevel1Cluster =
            MLevel0Cluster * NLevel0Cluster * MLevel1Cluster * NLevel1Cluster;

        static_assert(BlockSize == BatchThreadWork * ThreadPerLevel1Cluster,
                      "wrong! wrong blocksize\n");

        constexpr auto a_block_mtx  = BlockMatrixA{};
        constexpr auto b_block_mtx  = BlockMatrixB{};
        constexpr auto c_thread_mtx = ThreadMatrixC{};

        static_assert(a_block_mtx.NRow() == b_block_mtx.NRow(),
                      "wrong! K dimension not consistent\n");

        constexpr index_t M = a_block_mtx.NCol(); // A is transposed
        constexpr index_t N = b_block_mtx.NCol();

        constexpr index_t MPerThread = c_thread_mtx.NRow();
        constexpr index_t NPerThread = c_thread_mtx.NCol();

        static_assert((MPerThread % MPerThreadSubC == 0) && (NPerThread % NPerThreadSubC == 0),
                      "wrong! Cannot evenly divide thread work among repeat\n");

        constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
        constexpr index_t NRepeat = NPerThread / NPerThreadSubC;

        static_assert((M % MRepeat == 0) && (N % NRepeat == 0),
                      "wrong! Cannot evenly divide work among repeat\n");

        constexpr index_t MPerLevel1Cluster = M / MRepeat;
        constexpr index_t NPerLevel1Cluster = N / NRepeat;

        static_assert((MPerLevel1Cluster % MLevel1Cluster == 0) &&
                          (NPerLevel1Cluster % NLevel1Cluster == 0),
                      "wrong! Cannot evenly divide work among Level1Cluster\n");

        constexpr index_t MPerLevel0Cluster = MPerLevel1Cluster / MLevel1Cluster;
        constexpr index_t NPerLevel0Cluster = NPerLevel1Cluster / NLevel1Cluster;

        static_assert((MPerLevel0Cluster % MLevel0Cluster == 0) &&
                          (NPerLevel0Cluster % NLevel0Cluster == 0),
                      "wrong! Cannot evenly divide work among Level0Cluster\n");

        static_assert((MPerThreadSubC == MPerLevel0Cluster / MLevel0Cluster) &&
                          (NPerThreadSubC == NPerLevel0Cluster / NLevel0Cluster),
                      "wrong! thread work size is wrong\n");

        const auto c_thread_mtx_index = GetBeginOfThreadMatrixC(get_thread_local_1d_id());

        mMyThreadOffsetA = c_thread_mtx_index.batch * BlockMatrixStrideA +
                           a_block_mtx.GetOffsetFromMultiIndex(0, c_thread_mtx_index.row);

        mMyThreadOffsetB = c_thread_mtx_index.batch * BlockMatrixStrideB +
                           b_block_mtx.GetOffsetFromMultiIndex(0, c_thread_mtx_index.col);
    }

    __device__ MatrixIndex GetBeginOfThreadMatrixC(index_t thread_id) const
    {
        constexpr index_t ThreadPerLevel1Cluster =
            MLevel0Cluster * NLevel0Cluster * MLevel1Cluster * NLevel1Cluster;

        constexpr index_t ThreadPerLevel0Cluster = MLevel0Cluster * NLevel0Cluster;

        index_t batch_work_id = thread_id / ThreadPerLevel1Cluster;
        index_t cluster_id    = thread_id - batch_work_id * ThreadPerLevel1Cluster;

        index_t level1_id   = cluster_id / ThreadPerLevel0Cluster;
        index_t level1_m_id = level1_id / NLevel1Cluster;
        index_t level1_n_id = level1_id % NLevel1Cluster;

        index_t level0_id   = cluster_id % ThreadPerLevel0Cluster;
        index_t level0_m_id = level0_id / NLevel0Cluster;
        index_t level0_n_id = level0_id % NLevel0Cluster;

        constexpr index_t MPerLevel0Cluster = MPerThreadSubC * MLevel0Cluster;
        constexpr index_t NPerLevel0Cluster = NPerThreadSubC * NLevel0Cluster;

        return MatrixIndex{batch_work_id * BatchPerThread,
                           level1_m_id * MPerLevel0Cluster + level0_m_id * MPerThreadSubC,
                           level1_n_id * NPerLevel0Cluster + level0_n_id * NPerThreadSubC};
    }

    // this should be optimized away because the input will be known at compile time
    __device__ static MatrixIndex
    GetDistanceFromBeginOfThreadMatrixC(index_t batch_in_c, index_t m_in_c, index_t n_in_c)
    {
        constexpr auto c_thread_mtx = ThreadMatrixC{};

        constexpr index_t MPerThread = c_thread_mtx.NRow();
        constexpr index_t NPerThread = c_thread_mtx.NCol();

        constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
        constexpr index_t NRepeat = NPerThread / NPerThreadSubC;

        constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
        constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;

        index_t m_repeat = m_in_c / MPerThreadSubC;
        index_t n_repeat = n_in_c / NPerThreadSubC;

        index_t m_in_sub_c = m_in_c % MPerThreadSubC;
        index_t n_in_sub_c = n_in_c % NPerThreadSubC;

        return MatrixIndex{batch_in_c,
                           m_repeat * MPerLevel1Cluster + m_in_sub_c,
                           n_repeat * NPerLevel1Cluster + n_in_sub_c};
    }

    template <class FloatA, class FloatB, class FloatC>
    __device__ void Run_source(const FloatA* __restrict__ p_a_block,
                               const FloatB* __restrict__ p_b_block,
                               FloatC* __restrict__ p_c_thread) const
    {
        constexpr auto True  = integral_constant<bool, true>{};
        constexpr auto False = integral_constant<bool, false>{};

        constexpr auto a_block_mtx  = BlockMatrixA{};
        constexpr auto b_block_mtx  = BlockMatrixB{};
        constexpr auto c_thread_mtx = ThreadMatrixC{};

        constexpr index_t KPerBlock = a_block_mtx.NRow();

        // A is transposed
        constexpr index_t MPerThread = c_thread_mtx.NRow();
        constexpr index_t NPerThread = c_thread_mtx.NCol();

        // thread A, B for GEMM
        // A is transposed, b is not
        constexpr auto a_thread_mtx =
            make_ConstantMatrixDescriptor_packed(Number<KPerThreadLoop>{}, Number<MPerThread>{});

        constexpr auto b_thread_mtx =
            make_ConstantMatrixDescriptor_packed(Number<KPerThreadLoop>{}, Number<NPerThread>{});

        // thread A-sub, B-sub for copy
        constexpr auto a_thread_sub_mtx = make_ConstantMatrixDescriptor(
            Number<KPerThreadLoop>{}, Number<MPerThreadSubC>{}, Number<MPerThread>{});

        constexpr auto b_thread_sub_mtx = make_ConstantMatrixDescriptor(
            Number<KPerThreadLoop>{}, Number<NPerThreadSubC>{}, Number<NPerThread>{});

        FloatA p_a_thread[a_thread_mtx.GetElementSpace()];
        FloatB p_b_thread[b_thread_mtx.GetElementSpace()];

        constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
        constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;

        constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
        constexpr index_t NRepeat = NPerThread / NPerThreadSubC;

        // loop over k
#pragma unroll
        for(index_t k_begin = 0; k_begin < KPerBlock; k_begin += KPerThreadLoop)
        {
            // loop over batch
#pragma unroll
            for(index_t ib = 0; ib < BatchPerThread; ++ib)
            {
                // read next batch of a, b
                if(BlockMatrixStrideA != 0 or ib == 0)
                {
#pragma unroll
                    for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
                    {
                        threadwise_matrix_copy(
                            a_block_mtx,
                            p_a_block +
                                a_block_mtx.GetOffsetFromMultiIndex(k_begin,
                                                                    m_repeat * MPerLevel1Cluster) +
                                ib * BlockMatrixStrideA + mMyThreadOffsetA,
                            a_thread_mtx,
                            p_a_thread +
                                a_thread_mtx.GetOffsetFromMultiIndex(0, m_repeat * MPerThreadSubC),
                            a_thread_sub_mtx.GetLengths(),
                            Number<DataPerReadA>{});
                    }
                }

                if(BlockMatrixStrideB != 0 or ib == 0)
                {
#pragma unroll
                    for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
                    {
                        threadwise_matrix_copy(
                            b_block_mtx,
                            p_b_block +
                                b_block_mtx.GetOffsetFromMultiIndex(k_begin,
                                                                    n_repeat * NPerLevel1Cluster) +
                                ib * BlockMatrixStrideB + mMyThreadOffsetB,
                            b_thread_mtx,
                            p_b_thread +
                                b_thread_mtx.GetOffsetFromMultiIndex(0, n_repeat * NPerThreadSubC),
                            b_thread_sub_mtx.GetLengths(),
                            Number<DataPerReadB>{});
                    }
                }

                threadwise_gemm(a_thread_mtx,
                                True,
                                p_a_thread,
                                b_thread_mtx,
                                False,
                                p_b_thread,
                                c_thread_mtx,
                                False,
                                p_c_thread + ib * ThreadMatrixStrideC);
            }
        }
    }

#if CK_USE_AMD_INLINE_ASM
    template <class FloatA, class FloatB, class FloatC>
    __device__ void Run_amd_asm(const FloatA* __restrict__ p_a_block,
                                const FloatB* __restrict__ p_b_block,
                                FloatC* __restrict__ p_c_thread) const
    {
        constexpr auto a_block_mtx  = BlockMatrixA{};
        constexpr auto b_block_mtx  = BlockMatrixB{};
        constexpr auto c_thread_mtx = ThreadMatrixC{};

        constexpr index_t K = a_block_mtx.NRow();

        // A is transposed
        constexpr index_t MPerThread = c_thread_mtx.NRow();
        constexpr index_t NPerThread = c_thread_mtx.NCol();

        // thread A, B for GEMM
        // A is transposed, b is not
        constexpr auto a_thread_mtx =
            make_ConstantMatrixDescriptor_packed(Number<KPerThreadLoop>{}, Number<MPerThread>{});

        constexpr auto b_thread_mtx =
            make_ConstantMatrixDescriptor_packed(Number<KPerThreadLoop>{}, Number<NPerThread>{});

        // thread A-sub, B-sub for copy
        constexpr auto a_thread_sub_mtx = make_ConstantMatrixDescriptor(
            Number<KPerThreadLoop>{}, Number<MPerThreadSubC>{}, Number<MPerThread>{});

        constexpr auto b_thread_sub_mtx = make_ConstantMatrixDescriptor(
            Number<KPerThreadLoop>{}, Number<NPerThreadSubC>{}, Number<NPerThread>{});

        FloatA p_a_thread[a_thread_mtx.GetElementSpace()];
        FloatB p_b_thread[b_thread_mtx.GetElementSpace()];

        constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
        constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;

        // assertion for inline asm
        static_assert(is_same<FloatA, float>{} && is_same<FloatB, float>{} &&
                          is_same<FloatC, float>{},
                      "Run_amd_asm only deal with float\n");

        static_assert(MPerThreadSubC == 4 && NPerThreadSubC == 4 && KPerThreadLoop == 1 &&
                          MPerThread == 8 && NPerThread == 8,
                      "Run_amd_asm cannot deal with this GEMM shape yet\n");

        static_assert(DataPerReadA == 4 && DataPerReadB == 4, "Run_amd_asm only do float4 read\n");

        static_assert(BlockMatrixStrideA == 0 && BatchPerThread == 1,
                      "Run_amd_asm can only deal with BlockMatrixStrideA == 0 && BatchPerThread == "
                      "1 for now\n");

        using Float4 = vector_type<float, 4>::type;

        Float4* reg_a = (Float4*)(p_a_thread);
        Float4* reg_b = (Float4*)(p_b_thread);
        Float4* reg_c = (Float4*)(p_c_thread);

        reg_a[0] = *reinterpret_cast<const Float4*>(&p_a_block[mMyThreadOffsetA]);
        reg_b[0] = *reinterpret_cast<const Float4*>(&p_b_block[mMyThreadOffsetB]);
        reg_b[1] = *reinterpret_cast<const Float4*>(
            &p_b_block[b_block_mtx.GetOffsetFromMultiIndex(0, NPerLevel1Cluster) + mMyThreadOffsetB]);
        reg_a[1] = *reinterpret_cast<const Float4*>(
            &p_a_block[a_block_mtx.GetOffsetFromMultiIndex(0, MPerLevel1Cluster) + mMyThreadOffsetA]);
        outerProduct4x4(reg_a[0], reg_b[0], reg_c[0], reg_c[2], reg_c[4], reg_c[6]);
        outerProduct4x4(reg_a[0], reg_b[1], reg_c[1], reg_c[3], reg_c[5], reg_c[7]);

#pragma unroll
        for(index_t k = 1; k < K; ++k)
        {
            reg_a[0] = *reinterpret_cast<const Float4*>(
                &p_a_block[a_block_mtx.GetOffsetFromMultiIndex(k, 0) + mMyThreadOffsetA]);
            outerProduct4x4(reg_a[1], reg_b[0], reg_c[8], reg_c[10], reg_c[12], reg_c[14]);
            reg_b[0] = *reinterpret_cast<const Float4*>(
                &p_b_block[b_block_mtx.GetOffsetFromMultiIndex(k, 0) + mMyThreadOffsetB]);
            outerProduct4x4(reg_a[1], reg_b[1], reg_c[9], reg_c[11], reg_c[13], reg_c[15]);
            reg_b[1] = *reinterpret_cast<const Float4*>(
                &p_b_block[b_block_mtx.GetOffsetFromMultiIndex(k, NPerLevel1Cluster) +
                           mMyThreadOffsetB]);
            reg_a[1] = *reinterpret_cast<const Float4*>(
                &p_a_block[a_block_mtx.GetOffsetFromMultiIndex(k, MPerLevel1Cluster) +
                           mMyThreadOffsetA]);
            outerProduct4x4(reg_a[0], reg_b[0], reg_c[0], reg_c[2], reg_c[4], reg_c[6]);
            outerProduct4x4(reg_a[0], reg_b[1], reg_c[1], reg_c[3], reg_c[5], reg_c[7]);
        }

        outerProduct4x4(reg_a[1], reg_b[0], reg_c[8], reg_c[10], reg_c[12], reg_c[14]);
        outerProduct4x4(reg_a[1], reg_b[1], reg_c[9], reg_c[11], reg_c[13], reg_c[15]);
    }
#endif

    template <class FloatA, class FloatB, class FloatC>
    __device__ void Run(const FloatA* __restrict__ p_a_block,
                        const FloatB* __restrict__ p_b_block,
                        FloatC* __restrict__ p_c_thread) const
    {
#if CK_USE_AMD_INLINE_ASM && CK_BLOCKWISE_GEMM_USE_AMD_INLINE_ASM
        Run_amd_asm(p_a_block, p_b_block, p_c_thread);
#else
        Run_source(p_a_block, p_b_block, p_c_thread);
#endif
    }

    template <class BlockMatrixC, index_t BlockMatrixStrideC, class FloatC>
    __device__ void CopyThreadMatrixCToBlockMatrixC(const FloatC* __restrict__ p_c_thread,
                                                    FloatC* __restrict__ p_c_block) const
    {
        constexpr auto c_block_mtx  = BlockMatrixC{};
        constexpr auto c_thread_mtx = ThreadMatrixC{};

        constexpr index_t MPerThread = c_thread_mtx.NRow();
        constexpr index_t NPerThread = c_thread_mtx.NCol();

        constexpr auto c_thread_sub_mtx = make_ConstantMatrixDescriptor(
            Number<MPerThreadSubC>{}, Number<NPerThreadSubC>{}, Number<NPerThread>{});

        constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
        constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;

        constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
        constexpr index_t NRepeat = NPerThread / NPerThreadSubC;

        const auto c_thread_mtx_begin = GetBeginOfThreadMatrixC(get_thread_local_1d_id());

        const index_t c_thread_offset =
            c_thread_mtx_begin.batch * BlockMatrixStrideC +
            c_block_mtx.GetOffsetFromMultiIndex(c_thread_mtx_begin.row, c_thread_mtx_begin.col);

        for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
        {
            for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
            {
                threadwise_matrix_copy(
                    c_thread_sub_mtx,
                    p_c_thread + c_thread_sub_mtx.GetOffsetFromMultiIndex(
                                     m_repeat * MPerLevel1Cluster, n_repeat * NPerLevel1Cluster),
                    c_block_mtx,
                    p_c_block + c_block_mtx.GetOffsetFromMultiIndex(m_repeat * MPerLevel1Cluster,
                                                                    n_repeat * NPerLevel1Cluster) +
                        c_thread_offset,
                    c_thread_sub_mtx.GetLengths());
            }
        }
    }
};

} // namespace ck
#endif
composable_kernel/include/tensor_operation/blockwise_gemm.hpp deleted 100644 → 0 @ b8b2d0a6
#ifndef CK_BLOCKWISE_GEMM_HPP
#define CK_BLOCKWISE_GEMM_HPP

#include "common_header.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "threadwise_gemm.hpp"

namespace ck {

// blockwise GEMM: C += transpose(A) * B
// A and B are visible to the whole block, C is distributed among the threads.
// If the following numbers are powers of 2, the index calculation shall be greatly reduced:
//   MPerThreadSubC, NPerThreadSubC, MLevel0ThreadCluster, NLevel0ThreadCluster,
//   MLevel1ThreadCluster, NLevel1ThreadCluster
template <index_t BlockSize,
          typename BlockMatrixA,
          typename BlockMatrixB,
          typename ThreadMatrixC,
          index_t MPerThreadSubC,
          index_t NPerThreadSubC,
          index_t KPerThreadLoop,
          index_t MLevel0ThreadCluster,
          index_t NLevel0ThreadCluster,
          index_t MLevel1ThreadCluster,
          index_t NLevel1ThreadCluster,
          index_t ThreadGemmADataPerRead_M,
          index_t ThreadGemmBDataPerRead_N>
struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
{
    struct MatrixIndex
    {
        index_t row;
        index_t col;
    };

    index_t mMyThreadOffsetA;
    index_t mMyThreadOffsetB;

    __device__ BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2()
    {
        constexpr index_t ThreadPerLevel1Cluster = MLevel0ThreadCluster * NLevel0ThreadCluster *
                                                   MLevel1ThreadCluster * NLevel1ThreadCluster;

        static_assert(BlockSize == ThreadPerLevel1Cluster, "wrong! wrong blocksize\n");

        static_assert(BlockMatrixA::NRow() == BlockMatrixB::NRow(),
                      "wrong! K dimension not consistent\n");

        constexpr index_t M = BlockMatrixA::NCol(); // A is transposed
        constexpr index_t N = BlockMatrixB::NCol();

        static_assert(M % (MPerThreadSubC * MLevel0ThreadCluster * MLevel1ThreadCluster) == 0 &&
                          N % (NPerThreadSubC * NLevel0ThreadCluster * NLevel1ThreadCluster) == 0,
                      "wrong! Cannot evenly divide work among\n");

        static_assert(
            is_same<decltype(ThreadMatrixC::GetLengths()), decltype(GetThreadMatrixCLengths())>{},
            "wrong! ThreadMatrixC lengths is wrong");

        auto c_thread_mtx_index = GetBeginOfThreadMatrixC(get_thread_local_1d_id());

        mMyThreadOffsetA = BlockMatrixA::GetOffsetFromMultiIndex(0, c_thread_mtx_index.row);
        mMyThreadOffsetB = BlockMatrixB::GetOffsetFromMultiIndex(0, c_thread_mtx_index.col);
    }

    __device__ static constexpr auto GetThreadMatrixCLengths()
    {
        constexpr index_t M = BlockMatrixA::NCol(); // A is transposed
        constexpr index_t N = BlockMatrixB::NCol();

        constexpr index_t MRepeat =
            M / (MPerThreadSubC * MLevel0ThreadCluster * MLevel1ThreadCluster);
        constexpr index_t NRepeat =
            N / (NPerThreadSubC * NLevel0ThreadCluster * NLevel1ThreadCluster);

        return Sequence<MRepeat * MPerThreadSubC, NRepeat * NPerThreadSubC>{};
    }

    __device__ static MatrixIndex GetBeginOfThreadMatrixC(index_t thread_id)
    {
        constexpr index_t ThreadPerLevel0Cluster = MLevel0ThreadCluster * NLevel0ThreadCluster;

        index_t level1_id   = thread_id / ThreadPerLevel0Cluster;
        index_t level1_m_id = level1_id / NLevel1ThreadCluster;
        index_t level1_n_id = level1_id % NLevel1ThreadCluster;

        index_t level0_id   = thread_id % ThreadPerLevel0Cluster;
        index_t level0_m_id = level0_id / NLevel0ThreadCluster;
        index_t level0_n_id = level0_id % NLevel0ThreadCluster;

        constexpr index_t MPerLevel0Cluster = MPerThreadSubC * MLevel0ThreadCluster;
        constexpr index_t NPerLevel0Cluster = NPerThreadSubC * NLevel0ThreadCluster;

        return MatrixIndex{level1_m_id * MPerLevel0Cluster + level0_m_id * MPerThreadSubC,
                           level1_n_id * NPerLevel0Cluster + level0_n_id * NPerThreadSubC};
    }
template
<
typename
FloatA
,
typename
FloatB
,
typename
FloatC
>
__device__
void
Run_naive
(
const
FloatA
*
p_a_block
,
const
FloatB
*
p_b_block
,
FloatC
*
p_c_thread
)
const
{
constexpr
auto
a_block_mtx
=
BlockMatrixA
{};
constexpr
auto
b_block_mtx
=
BlockMatrixB
{};
constexpr
auto
c_thread_mtx
=
ThreadMatrixC
{};
constexpr
index_t
K
=
a_block_mtx
.
NRow
();
constexpr
index_t
MPerThread
=
c_thread_mtx
.
NRow
();
constexpr
index_t
NPerThread
=
c_thread_mtx
.
NCol
();
constexpr
index_t
MPerLevel1Cluster
=
MPerThreadSubC
*
MLevel0ThreadCluster
*
MLevel1ThreadCluster
;
constexpr
index_t
NPerLevel1Cluster
=
NPerThreadSubC
*
NLevel0ThreadCluster
*
NLevel1ThreadCluster
;
constexpr
index_t
MRepeat
=
MPerThread
/
MPerThreadSubC
;
constexpr
index_t
NRepeat
=
NPerThread
/
NPerThreadSubC
;
// thread A, B for GEMM
constexpr
auto
a_thread_mtx
=
make_ConstantMatrixDescriptor_packed
(
Number
<
KPerThreadLoop
>
{},
Number
<
MPerThread
>
{});
constexpr
auto
b_thread_mtx
=
make_ConstantMatrixDescriptor_packed
(
Number
<
KPerThreadLoop
>
{},
Number
<
NPerThread
>
{});
FloatA
p_a_thread
[
a_thread_mtx
.
GetElementSpace
()];
FloatB
p_b_thread
[
b_thread_mtx
.
GetElementSpace
()];
constexpr
auto
a_thread_copy
=
ThreadwiseMatrixSliceCopy
<
BlockMatrixA
,
decltype
(
a_thread_mtx
),
KPerThreadLoop
,
MPerThreadSubC
,
ThreadGemmADataPerRead_M
>
{};
constexpr
auto
b_thread_copy
=
ThreadwiseMatrixSliceCopy
<
BlockMatrixB
,
decltype
(
b_thread_mtx
),
KPerThreadLoop
,
NPerThreadSubC
,
ThreadGemmBDataPerRead_N
>
{};
constexpr
auto
threadwise_gemm
=
ThreadwiseGemmTransANormalBNormalC
<
decltype
(
a_thread_mtx
),
decltype
(
b_thread_mtx
),
decltype
(
c_thread_mtx
)
>
{};
#pragma unroll
// loop over k
for
(
index_t
k_begin
=
0
;
k_begin
<
K
;
k_begin
+=
KPerThreadLoop
)
{
#pragma unroll
// read A
for
(
index_t
m_repeat
=
0
;
m_repeat
<
MRepeat
;
++
m_repeat
)
{
a_thread_copy
.
Run
(
p_a_block
+
a_block_mtx
.
CalculateOffset
(
k_begin
,
m_repeat
*
MPerLevel1Cluster
)
+
mMyThreadOffsetA
,
p_a_thread
+
a_thread_mtx
.
CalculateOffset
(
0
,
m_repeat
*
MPerThreadSubC
));
}
#pragma unroll
// read B
for
(
index_t
n_repeat
=
0
;
n_repeat
<
NRepeat
;
++
n_repeat
)
{
b_thread_copy
.
Run
(
p_b_block
+
b_block_mtx
.
CalculateOffset
(
k_begin
,
n_repeat
*
NPerLevel1Cluster
)
+
mMyThreadOffsetB
,
p_b_thread
+
b_thread_mtx
.
CalculateOffset
(
0
,
n_repeat
*
NPerThreadSubC
));
}
// C += A * B
threadwise_gemm
.
Run
(
p_a_thread
,
p_b_thread
,
p_c_thread
);
}
}
template
<
typename
FloatA
,
typename
FloatB
,
typename
FloatC
>
__device__
void
Run_pipelined_2x2
(
const
FloatA
*
p_a_block
,
const
FloatB
*
p_b_block
,
FloatC
*
p_c_thread
)
const
{
constexpr
auto
a_block_mtx
=
BlockMatrixA
{};
constexpr
auto
b_block_mtx
=
BlockMatrixB
{};
constexpr
auto
c_thread_mtx
=
ThreadMatrixC
{};
constexpr
index_t
K
=
a_block_mtx
.
NRow
();
constexpr
index_t
MPerThread
=
c_thread_mtx
.
NRow
();
constexpr
index_t
NPerThread
=
c_thread_mtx
.
NCol
();
constexpr
index_t
MPerLevel1Cluster
=
MPerThreadSubC
*
MLevel0ThreadCluster
*
MLevel1ThreadCluster
;
constexpr
index_t
NPerLevel1Cluster
=
NPerThreadSubC
*
NLevel0ThreadCluster
*
NLevel1ThreadCluster
;
constexpr
index_t
MRepeat
=
MPerThread
/
MPerThreadSubC
;
constexpr
index_t
NRepeat
=
NPerThread
/
NPerThreadSubC
;
static_assert
(
MRepeat
==
2
&&
NRepeat
==
2
,
"wrong! inline asm cannot deal with this GEMM config yet"
);
// thread A, B
constexpr
auto
a_thread_mtx
=
make_ConstantMatrixDescriptor_packed
(
Number
<
KPerThreadLoop
>
{},
Number
<
MPerThread
>
{});
constexpr
auto
b_thread_mtx
=
make_ConstantMatrixDescriptor_packed
(
Number
<
KPerThreadLoop
>
{},
Number
<
NPerThread
>
{});
// thread A-sub, B-sub
constexpr
auto
a_thread_sub_mtx
=
a_thread_mtx
.
MakeSubMatrixDescriptor
(
Number
<
KPerThreadLoop
>
{},
Number
<
MPerThreadSubC
>
{});
constexpr
auto
b_thread_sub_mtx
=
b_thread_mtx
.
MakeSubMatrixDescriptor
(
Number
<
KPerThreadLoop
>
{},
Number
<
NPerThreadSubC
>
{});
// thread C-sub
constexpr
auto
c_thread_sub_mtx
=
ThreadMatrixC
::
MakeSubMatrixDescriptor
(
Number
<
MPerThreadSubC
>
{},
Number
<
NPerThreadSubC
>
{});
FloatA
p_a_thread
[
a_thread_mtx
.
GetElementSpace
()];
FloatB
p_b_thread
[
b_thread_mtx
.
GetElementSpace
()];
constexpr
auto
a_thread_copy
=
ThreadwiseMatrixSliceCopy
<
BlockMatrixA
,
decltype
(
a_thread_mtx
),
KPerThreadLoop
,
MPerThreadSubC
,
ThreadGemmADataPerRead_M
>
{};
constexpr
auto
b_thread_copy
=
ThreadwiseMatrixSliceCopy
<
BlockMatrixB
,
decltype
(
b_thread_mtx
),
KPerThreadLoop
,
NPerThreadSubC
,
ThreadGemmBDataPerRead_N
>
{};
constexpr
auto
threadwise_gemm
=
ThreadwiseGemmTransANormalBNormalC
<
decltype
(
a_thread_sub_mtx
),
decltype
(
b_thread_sub_mtx
),
decltype
(
c_thread_sub_mtx
)
>
{};
const
FloatA
*
p_a_block_off
=
p_a_block
+
mMyThreadOffsetA
;
const
FloatB
*
p_b_block_off
=
p_b_block
+
mMyThreadOffsetB
;
// read A_sub_0
a_thread_copy
.
Run
(
p_a_block_off
,
p_a_thread
);
// read B_sub_0
b_thread_copy
.
Run
(
p_b_block_off
,
p_b_thread
);
// read B_sub_1
b_thread_copy
.
Run
(
p_b_block_off
+
b_block_mtx
.
CalculateOffset
(
0
,
NPerLevel1Cluster
),
p_b_thread
+
b_thread_mtx
.
CalculateOffset
(
0
,
NPerThreadSubC
));
// read A_sub_1
a_thread_copy
.
Run
(
p_a_block_off
+
a_block_mtx
.
CalculateOffset
(
0
,
MPerLevel1Cluster
),
p_a_thread
+
a_thread_mtx
.
CalculateOffset
(
0
,
MPerThreadSubC
));
// C_sub_00 += transpose(A_sub_0) * B_sub_0
threadwise_gemm
.
Run
(
p_a_thread
,
p_b_thread
,
p_c_thread
);
// C_sub_01 += transpose(A_sub_0) * B_sub_1
threadwise_gemm
.
Run
(
p_a_thread
,
p_b_thread
+
b_thread_mtx
.
CalculateOffset
(
0
,
NPerThreadSubC
),
p_c_thread
+
ThreadMatrixC
::
CalculateOffset
(
0
,
NPerThreadSubC
));
#pragma unroll
// loop over rest of k
for
(
index_t
k
=
KPerThreadLoop
;
k
<
K
;
k
+=
KPerThreadLoop
)
{
// read A_sub_0
a_thread_copy
.
Run
(
p_a_block_off
+
a_block_mtx
.
CalculateOffset
(
k
,
0
),
p_a_thread
);
// C_sub_10 += transpose(A_sub_1) * B_sub_0
threadwise_gemm
.
Run
(
p_a_thread
+
a_thread_mtx
.
CalculateOffset
(
0
,
MPerThreadSubC
),
p_b_thread
,
p_c_thread
+
ThreadMatrixC
::
CalculateOffset
(
MPerThreadSubC
,
0
));
// read B_sub_0
b_thread_copy
.
Run
(
p_b_block_off
+
b_block_mtx
.
CalculateOffset
(
k
,
0
),
p_b_thread
);
// C_sub_11 += transpose(A_sub_1) * B_sub_1
threadwise_gemm
.
Run
(
p_a_thread
+
a_thread_mtx
.
CalculateOffset
(
0
,
MPerThreadSubC
),
p_b_thread
+
b_thread_mtx
.
CalculateOffset
(
0
,
NPerThreadSubC
),
p_c_thread
+
ThreadMatrixC
::
CalculateOffset
(
MPerThreadSubC
,
NPerThreadSubC
));
// read B_sub_1
b_thread_copy
.
Run
(
p_b_block_off
+
b_block_mtx
.
CalculateOffset
(
k
,
NPerLevel1Cluster
),
p_b_thread
+
b_thread_mtx
.
CalculateOffset
(
0
,
NPerThreadSubC
));
// read A_sub_1
a_thread_copy
.
Run
(
p_a_block_off
+
a_block_mtx
.
CalculateOffset
(
k
,
MPerLevel1Cluster
),
p_a_thread
+
a_thread_mtx
.
CalculateOffset
(
0
,
MPerThreadSubC
));
// C_sub_00 += transpose(A_sub_0) * B_sub_0
threadwise_gemm
.
Run
(
p_a_thread
,
p_b_thread
,
p_c_thread
);
// C_sub_01 += transpose(A_sub_0) * B_sub_1
threadwise_gemm
.
Run
(
p_a_thread
,
p_b_thread
+
b_thread_mtx
.
CalculateOffset
(
0
,
NPerThreadSubC
),
p_c_thread
+
ThreadMatrixC
::
CalculateOffset
(
0
,
NPerThreadSubC
));
}
// C_sub_10 += transpose(A_sub_1) * B_sub_0
threadwise_gemm
.
Run
(
p_a_thread
+
a_thread_mtx
.
CalculateOffset
(
0
,
MPerThreadSubC
),
p_b_thread
,
p_c_thread
+
ThreadMatrixC
::
CalculateOffset
(
MPerThreadSubC
,
0
));
// C_sub_11 += transpose(A_sub_1) * B_sub_1
threadwise_gemm
.
Run
(
p_a_thread
+
a_thread_mtx
.
CalculateOffset
(
0
,
MPerThreadSubC
),
p_b_thread
+
b_thread_mtx
.
CalculateOffset
(
0
,
NPerThreadSubC
),
p_c_thread
+
ThreadMatrixC
::
CalculateOffset
(
MPerThreadSubC
,
NPerThreadSubC
));
}
template
<
typename
FloatA
,
typename
FloatB
,
typename
FloatC
>
__device__
void
Run
(
const
FloatA
*
p_a_block
,
const
FloatB
*
p_b_block
,
FloatC
*
p_c_thread
)
const
{
#if CK_EXPERIMENTAL_BLOCKWISE_GEMM_USE_PIPELINE
constexpr
index_t
MPerThread
=
ThreadMatrixC
::
NRow
();
constexpr
index_t
NPerThread
=
ThreadMatrixC
::
NCol
();
constexpr
index_t
MRepeat
=
MPerThread
/
MPerThreadSubC
;
constexpr
index_t
NRepeat
=
NPerThread
/
NPerThreadSubC
;
if
constexpr
(
MRepeat
==
2
&&
NRepeat
==
2
)
{
Run_pipelined_2x2
(
p_a_block
,
p_b_block
,
p_c_thread
);
}
else
{
Run_naive
(
p_a_block
,
p_b_block
,
p_c_thread
);
}
#else
Run_naive
(
p_a_block
,
p_b_block
,
p_c_thread
);
#endif
}
};
}
// namespace ck
#endif
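GetBeginOfThreadMatrixC above decomposes a flat thread id into a level-1 and a level-0 cluster index and from those derives the origin of the C sub-tile that thread owns. The host-side sketch below mirrors that arithmetic so the mapping can be printed and inspected; it is not part of the deleted file, and the cluster sizes are hypothetical.

// Illustration only: thread id -> C sub-tile origin, mirroring GetBeginOfThreadMatrixC.
#include <cstdio>

int main()
{
    constexpr int MPerThreadSubC = 4, NPerThreadSubC = 4;           // hypothetical values
    constexpr int MLevel0ThreadCluster = 4, NLevel0ThreadCluster = 4;
    constexpr int MLevel1ThreadCluster = 2, NLevel1ThreadCluster = 2;

    constexpr int ThreadPerLevel0Cluster = MLevel0ThreadCluster * NLevel0ThreadCluster; // 16
    constexpr int MPerLevel0Cluster = MPerThreadSubC * MLevel0ThreadCluster;            // 16
    constexpr int NPerLevel0Cluster = NPerThreadSubC * NLevel0ThreadCluster;            // 16

    const int block_size =
        ThreadPerLevel0Cluster * MLevel1ThreadCluster * NLevel1ThreadCluster; // 64 threads

    for(int tid = 0; tid < block_size; ++tid)
    {
        // split the thread id into level-1 and level-0 cluster coordinates
        const int level1_id   = tid / ThreadPerLevel0Cluster;
        const int level1_m_id = level1_id / NLevel1ThreadCluster;
        const int level1_n_id = level1_id % NLevel1ThreadCluster;

        const int level0_id   = tid % ThreadPerLevel0Cluster;
        const int level0_m_id = level0_id / NLevel0ThreadCluster;
        const int level0_n_id = level0_id % NLevel0ThreadCluster;

        // origin of this thread's C sub-tile inside the block C matrix
        const int row = level1_m_id * MPerLevel0Cluster + level0_m_id * MPerThreadSubC;
        const int col = level1_n_id * NPerLevel0Cluster + level0_n_id * NPerThreadSubC;

        std::printf("thread %2d owns C sub-tile origin (row %3d, col %3d)\n", tid, row, col);
    }
    return 0;
}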
composable_kernel/include/tensor_operation/blockwise_generic_tensor_slice_copy.hpp (deleted, 100644 → 0)
#ifndef CK_BLOCKWISE_GENERIC_TENSOR_SLICE_COPY_HPP
#define CK_BLOCKWISE_GENERIC_TENSOR_SLICE_COPY_HPP

#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "tensor_coordinate.hpp"
#include "cluster_descriptor.hpp"
#include "threadwise_generic_tensor_slice_copy.hpp"

namespace ck {

// This blockwise copy allows vector access of src and dst.
// It allows the vector size to be different on src and dst.
// The dimension of vector access can be different for src and dst.
// The dimension access order can be different for src and dst.
// Will do valid mapping check on src data: read 0 if src data has an invalid mapping
// Will do valid mapping check on dst data: no write if dst data has an invalid mapping
// BlockSize can be equal to or larger than the ThreadCluster size, which means some threads may
//   not do threadwise copy
template <index_t BlockSize,
          typename BlockSrcDesc,
          typename BlockDstDesc,
          typename BlockSliceLengths,
          typename ThreadSliceLengths,
          typename ThreadClusterLengths,
          typename ThreadClusterArrangeOrder,
          typename SrcDimAccessOrder,
          typename DstDimAccessOrder,
          index_t SrcVectoReadDim,
          index_t DstVectorWriteDim,
          index_t SrcDataPerRead,
          index_t DstDataPerWrite,
          AddressSpace SrcAddressSpace          = AddressSpace::Generic,
          AddressSpace ThreadBufferAddressSpace = AddressSpace::Generic,
          AddressSpace DstAddressSpace          = AddressSpace::Generic,
          InMemoryDataOperation DstInMemOp      = InMemoryDataOperation::Set,
          index_t SrcDataStride                 = 1,
          index_t DstDataStride                 = 1>
struct BlockwiseGenericTensorSliceCopy_v4
{
    static constexpr index_t nDim = BlockSrcDesc::GetNumOfDimension();
    using Index                   = MultiIndex<nDim>;

    __device__ constexpr BlockwiseGenericTensorSliceCopy_v4(const Index& src_block_slice_origin,
                                                            const Index& dst_block_slice_origin)
    {
        static_assert(nDim == BlockSrcDesc::GetNumOfDimension() &&
                          nDim == BlockDstDesc::GetNumOfDimension() &&
                          nDim == BlockSliceLengths::Size() && nDim == ThreadSliceLengths::Size() &&
                          nDim == ThreadClusterLengths::Size() &&
                          nDim == ThreadClusterArrangeOrder::Size() &&
                          nDim == SrcDimAccessOrder::Size() && nDim == DstDimAccessOrder::Size(),
                      "wrong! nDim not consistent");

        static_assert(
            is_same<BlockSliceLengths, decltype(ThreadSliceLengths{} * ThreadClusterLengths{})>{},
            "wrong! threads should be mapped to cover entire slicing window");

        static_assert(BlockSize >= mThreadClusterDesc.GetElementSize(),
                      "wrong! BlockSize too small");

        if(BlockSize == mThreadClusterDesc.GetElementSize() or
           get_thread_local_1d_id() < mThreadClusterDesc.GetElementSize())
        {
            const auto thread_cluster_id =
                mThreadClusterDesc.CalculateClusterIndex(get_thread_local_1d_id());

            const auto thread_data_id_begin = thread_cluster_id * ThreadSliceLengths{};

            mThreadwiseLoad.SetSrcSliceOrigin(src_block_slice_origin + thread_data_id_begin);
            mThreadwiseLoad.SetDstSliceOrigin(make_zero_multi_index<nDim>());

            mThreadwiseStore.SetSrcSliceOrigin(make_zero_multi_index<nDim>());
            mThreadwiseStore.SetDstSliceOrigin(dst_block_slice_origin + thread_data_id_begin);
        }
    }

    __device__ static constexpr index_t GetThreadBufferSize()
    {
        return ThreadBufferDesc::GetElementSpace();
    }

    template <typename BlockSrcData, typename ThreadBufferData>
    __device__ void RunLoadThreadBuffer(const BlockSrcData* p_block_src,
                                        ThreadBufferData* p_thread_buffer) const
    {
        if(BlockSize == mThreadClusterDesc.GetElementSize() or
           get_thread_local_1d_id() < mThreadClusterDesc.GetElementSize())
        {
            mThreadwiseLoad.Run(p_block_src, p_thread_buffer);
        }
    }

    template <typename ThreadBufferData, typename BlockDstData>
    __device__ void RunStoreThreadBuffer(const ThreadBufferData* p_thread_buffer,
                                         BlockDstData* p_block_dst) const
    {
        if(BlockSize == mThreadClusterDesc.GetElementSize() or
           get_thread_local_1d_id() < mThreadClusterDesc.GetElementSize())
        {
            mThreadwiseStore.Run(p_thread_buffer, p_block_dst);
        }
    }

    template <typename BlockSrcData, typename BlockDstData>
    __device__ void Run(const BlockSrcData* p_block_src, BlockDstData* p_block_dst) const
    {
        static_assert(ThreadBufferAddressSpace == AddressSpace::Vgpr,
                      "wrong! This function uses vgpr as its thread "
                      "buffer. However, you have set RunLoadThreadBuffer and RunStoreThreadBuffer "
                      "to use ThreadBufferAddressSpace as their thread buffer, which is not vgpr. "
                      "Behavior may be different");

        BlockSrcData p_thread_buffer[GetThreadBufferSize()];

        if(BlockSize == mThreadClusterDesc.GetElementSize() or
           get_thread_local_1d_id() < mThreadClusterDesc.GetElementSize())
        {
            RunLoadThreadBuffer(p_block_src, p_thread_buffer);

            // if there is type conversion, it's done during store
            RunStoreThreadBuffer(p_thread_buffer, p_block_dst);
        }
    }

    template <typename T, bool PositiveDirection>
    __device__ void
    MoveSrcSliceWindow(const T& step_sizes,
                       integral_constant<bool, PositiveDirection> positive_direction)
    {
        if(BlockSize == mThreadClusterDesc.GetElementSize() or
           get_thread_local_1d_id() < mThreadClusterDesc.GetElementSize())
        {
            mThreadwiseLoad.MoveSrcSliceWindow(step_sizes, positive_direction);
        }
    }

    template <typename T, bool PositiveDirection>
    __device__ void
    MoveDstSliceWindow(const T& step_sizes,
                       integral_constant<bool, PositiveDirection> positive_direction)
    {
        if(BlockSize == mThreadClusterDesc.GetElementSize() or
           get_thread_local_1d_id() < mThreadClusterDesc.GetElementSize())
        {
            mThreadwiseStore.MoveDstSliceWindow(step_sizes, positive_direction);
        }
    }

    private:
    using ThreadBufferDesc = decltype(make_native_tensor_descriptor_packed(ThreadSliceLengths{}));

    using ThreadwiseLoad = ThreadwiseGenericTensorSliceCopy_v4r2<BlockSrcDesc,
                                                                 ThreadBufferDesc,
                                                                 ThreadSliceLengths,
                                                                 SrcDimAccessOrder,
                                                                 SrcVectoReadDim,
                                                                 SrcDataPerRead,
                                                                 1,
                                                                 SrcAddressSpace,
                                                                 ThreadBufferAddressSpace,
                                                                 InMemoryDataOperation::Set,
                                                                 SrcDataStride,
                                                                 1>;

    using ThreadwiseStore = ThreadwiseGenericTensorSliceCopy_v4r2<ThreadBufferDesc,
                                                                  BlockDstDesc,
                                                                  ThreadSliceLengths,
                                                                  DstDimAccessOrder,
                                                                  DstVectorWriteDim,
                                                                  1,
                                                                  DstDataPerWrite,
                                                                  ThreadBufferAddressSpace,
                                                                  DstAddressSpace,
                                                                  DstInMemOp,
                                                                  1,
                                                                  DstDataStride>;

    static constexpr auto mThreadClusterDesc =
        make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});

    ThreadwiseLoad mThreadwiseLoad;
    ThreadwiseStore mThreadwiseStore;
};

} // namespace ck
#endif
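The constructor above gives each participating thread a slice origin of thread_cluster_id * ThreadSliceLengths inside the block slice, which is why the static_assert requires BlockSliceLengths == ThreadSliceLengths * ThreadClusterLengths. A minimal host-side sketch of that split is below; it is not part of the deleted file, the 2-D lengths are hypothetical, and a plain row-major index stands in for CalculateClusterIndex (the real mapping also honors ThreadClusterArrangeOrder).

// Illustration only: how a thread cluster tiles a block slice into per-thread slices.
#include <cstdio>

int main()
{
    constexpr int thread_slice_lengths[2]   = {4, 1};  // hypothetical per-thread K x M slab
    constexpr int thread_cluster_lengths[2] = {2, 64}; // hypothetical threads along K and M
    // implied block slice lengths: {8, 64}

    const int cluster_size = thread_cluster_lengths[0] * thread_cluster_lengths[1];
    for(int tid = 0; tid < cluster_size; ++tid)
    {
        // simplified row-major cluster index (stand-in for CalculateClusterIndex)
        const int cid[2]    = {tid / thread_cluster_lengths[1], tid % thread_cluster_lengths[1]};
        const int origin[2] = {cid[0] * thread_slice_lengths[0], cid[1] * thread_slice_lengths[1]};

        if(tid < 4)
            std::printf("thread %d copies the slice starting at (%d, %d)\n",
                        tid, origin[0], origin[1]);
    }
    return 0;
}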
composable_kernel/include/tensor_operation/gridwise_gemm.hpp (deleted, 100644 → 0)
#ifndef CK_GRIDWISE_GEMM_HPP
#define CK_GRIDWISE_GEMM_HPP

#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "blockwise_generic_tensor_slice_copy.hpp"
#include "threadwise_generic_tensor_slice_copy.hpp"
#include "blockwise_gemm.hpp"

namespace ck {

template <index_t GridSize,
          index_t BlockSize,
          typename Float,
          typename AccFloat,
          typename AGlobalDesc,
          typename BGlobalDesc,
          typename CGlobalDesc,
          InMemoryDataOperation CGlobalMemoryDataOperation,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t KPerBlock,
          index_t MPerThread,
          index_t NPerThread,
          index_t KPerThread,
          index_t MLevel0Cluster,
          index_t NLevel0Cluster,
          index_t MLevel1Cluster,
          index_t NLevel1Cluster,
          index_t ThreadGemmAThreadCopySrcDataPerRead_M,
          index_t ThreadGemmBThreadCopySrcDataPerRead_N,
          typename ABlockCopyThreadSliceLengths_K_M,
          typename ABlockCopyThreadClusterLengths_K_M,
          typename ABlockCopyThreadClusterArrangeOrder,
          typename ABlockCopySrcAccessOrder,
          index_t ABlockCopySrcVectorReadDim,
          index_t ABlockCopySrcDataPerRead,
          index_t ABlockCopyDstDataPerWrite_M,
          typename BBlockCopyThreadSliceLengths_K_N,
          typename BBlockCopyThreadClusterLengths_K_N,
          typename BBlockCopyThreadClusterArrangeOrder,
          typename BBlockCopySrcAccessOrder,
          index_t BBlockCopySrcVectorReadDim,
          index_t BBlockCopySrcDataPerRead,
          index_t BBlockCopyDstDataPerWrite_N,
          typename CThreadCopySrcDstAccessOrder,
          index_t CThreadCopySrcDstVectorReadWriteDim,
          index_t CThreadCopyDstDataPerWrite>
struct GridwiseGemmTransposedANormalBNormalC_v1
{
    __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
    {
        constexpr index_t max_lds_align = math::lcm(ABlockCopyDstDataPerWrite_M,
                                                    BBlockCopyDstDataPerWrite_N,
                                                    ThreadGemmAThreadCopySrcDataPerRead_M,
                                                    ThreadGemmBThreadCopySrcDataPerRead_N);

        // A matrix in LDS memory, dst of blockwise copy
        //   be careful of LDS alignment
        constexpr auto a_k_m_block_desc = make_native_tensor_descriptor_aligned(
            Sequence<KPerBlock, MPerBlock>{}, Number<max_lds_align>{});

        // B matrix in LDS memory, dst of blockwise copy
        //   be careful of LDS alignment
        constexpr auto b_k_n_block_desc = make_native_tensor_descriptor_aligned(
            Sequence<KPerBlock, NPerBlock>{}, Number<max_lds_align>{});

        // LDS allocation for A and B: be careful of alignment
        constexpr index_t a_block_space_size =
            math::integer_least_multiple(a_k_m_block_desc.GetElementSpace(), max_lds_align);

        constexpr index_t b_block_space_size =
            math::integer_least_multiple(b_k_n_block_desc.GetElementSpace(), max_lds_align);

        return 2 * (a_block_space_size + b_block_space_size) * sizeof(Float);
    }

    __device__ void Run(const Float* __restrict__ p_a_global,
                        const Float* __restrict__ p_b_global,
                        Float* __restrict__ p_c_global,
                        Float* __restrict__ p_shared_block) const
    {
        constexpr auto True = integral_constant<bool, true>{};

        constexpr auto a_k_m_global_desc = AGlobalDesc{};
        constexpr auto b_k_n_global_desc = BGlobalDesc{};
        constexpr auto c_m_n_global_desc = CGlobalDesc{};

        constexpr auto K = a_k_m_global_desc.GetLengths()[0];
        constexpr auto M = a_k_m_global_desc.GetLengths()[1];
        constexpr auto N = b_k_n_global_desc.GetLengths()[1];

        // don't do anything if K == 0
        if(K == 0)
        {
            return;
        }

        // lds max alignment
        constexpr index_t max_lds_align = math::lcm(ABlockCopyDstDataPerWrite_M,
                                                    BBlockCopyDstDataPerWrite_N,
                                                    ThreadGemmAThreadCopySrcDataPerRead_M,
                                                    ThreadGemmBThreadCopySrcDataPerRead_N);

        // divide block work by [M, N]
        static_assert(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0,
                      "wrong! cannot divide work evenly among block");

        constexpr index_t MBlockWork = M / MPerBlock;
        constexpr index_t NBlockWork = N / NPerBlock;

        constexpr auto block_work_desc =
            make_cluster_descriptor(Sequence<MBlockWork, NBlockWork>{});

        const auto block_work_id = block_work_desc.CalculateClusterIndex(get_block_1d_id());

        const index_t m_block_data_on_global = block_work_id[Number<0>{}] * MPerBlock;
        const index_t n_block_data_on_global = block_work_id[Number<1>{}] * NPerBlock;

        // A matrix in LDS memory, dst of blockwise copy
        //   be careful of LDS alignment
        constexpr auto a_k_m_block_desc = make_native_tensor_descriptor_aligned(
            Sequence<KPerBlock, MPerBlock>{}, Number<max_lds_align>{});

        // A matrix blockwise copy
        auto a_blockwise_copy =
            BlockwiseGenericTensorSliceCopy_v4<BlockSize,
                                               decltype(a_k_m_global_desc),
                                               decltype(a_k_m_block_desc),
                                               decltype(a_k_m_block_desc.GetLengths()),
                                               ABlockCopyThreadSliceLengths_K_M,
                                               ABlockCopyThreadClusterLengths_K_M,
                                               ABlockCopyThreadClusterArrangeOrder,
                                               ABlockCopySrcAccessOrder,
                                               Sequence<0, 1>,
                                               ABlockCopySrcVectorReadDim,
                                               1,
                                               ABlockCopySrcDataPerRead,
                                               ABlockCopyDstDataPerWrite_M,
                                               AddressSpace::Global,
                                               AddressSpace::Vgpr,
                                               AddressSpace::Lds,
                                               InMemoryDataOperation::Set>(
                make_multi_index(0, m_block_data_on_global), make_multi_index(0, 0));

        // B matrix in LDS memory, dst of blockwise copy
        //   be careful of LDS alignment
        constexpr auto b_k_n_block_desc = make_native_tensor_descriptor_aligned(
            Sequence<KPerBlock, NPerBlock>{}, Number<max_lds_align>{});

        // B matrix blockwise copy
        auto b_blockwise_copy =
            BlockwiseGenericTensorSliceCopy_v4<BlockSize,
                                               decltype(b_k_n_global_desc),
                                               decltype(b_k_n_block_desc),
                                               decltype(b_k_n_block_desc.GetLengths()),
                                               BBlockCopyThreadSliceLengths_K_N,
                                               BBlockCopyThreadClusterLengths_K_N,
                                               BBlockCopyThreadClusterArrangeOrder,
                                               BBlockCopySrcAccessOrder,
                                               Sequence<0, 1>,
                                               BBlockCopySrcVectorReadDim,
                                               1,
                                               BBlockCopySrcDataPerRead,
                                               BBlockCopyDstDataPerWrite_N,
                                               AddressSpace::Global,
                                               AddressSpace::Vgpr,
                                               AddressSpace::Lds,
                                               InMemoryDataOperation::Set>(
                make_multi_index(0, n_block_data_on_global), make_multi_index(0, 0));

        // GEMM definition
        //   c_mtx += transpose(a_mtx) * b_mtx
        //     a_mtx[KPerBlock, MPerBlock] is in LDS
        //     b_mtx[KPerBlock, NPerBlock] is in LDS
        //     c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in register
        constexpr auto a_k_m_block_mtx_desc = make_ConstantMatrixDescriptor(a_k_m_block_desc);
        constexpr auto b_k_n_block_mtx_desc = make_ConstantMatrixDescriptor(b_k_n_block_desc);

        // sanity check
        static_assert(MPerBlock % (MPerThread * MLevel0Cluster * MLevel1Cluster) == 0 &&
                          NPerBlock % (NPerThread * NLevel0Cluster * NLevel1Cluster) == 0,
                      "wrong!");

        constexpr index_t GemmMRepeat = MPerBlock / (MPerThread * MLevel0Cluster * MLevel1Cluster);
        constexpr index_t GemmNRepeat = NPerBlock / (NPerThread * NLevel0Cluster * NLevel1Cluster);

        // c_thread_mtx definition: this is a mess
        // TODO: more elegant way of defining c_thread_mtx
        constexpr auto c_m0m1_n0n1_thread_mtx_desc = make_ConstantMatrixDescriptor_packed(
            Number<GemmMRepeat * MPerThread>{}, Number<GemmNRepeat * NPerThread>{});

        const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2<
            BlockSize,
            decltype(a_k_m_block_mtx_desc),
            decltype(b_k_n_block_mtx_desc),
            decltype(c_m0m1_n0n1_thread_mtx_desc),
            MPerThread,
            NPerThread,
            KPerThread,
            MLevel0Cluster,
            NLevel0Cluster,
            MLevel1Cluster,
            NLevel1Cluster,
            ThreadGemmAThreadCopySrcDataPerRead_M,
            ThreadGemmBThreadCopySrcDataPerRead_N>{};

        // LDS allocation for A and B: be careful of alignment
        constexpr index_t a_block_space_size =
            math::integer_least_multiple(a_k_m_block_desc.GetElementSpace(), max_lds_align);

        constexpr index_t b_block_space_size =
            math::integer_least_multiple(b_k_n_block_desc.GetElementSpace(), max_lds_align);

        Float* p_a_block_double = p_shared_block;
        Float* p_b_block_double = p_shared_block + 2 * a_block_space_size;

        // register allocation for output
        AccFloat p_c_thread[c_m0m1_n0n1_thread_mtx_desc.GetElementSpace()];

        // zero out threadwise output
        threadwise_matrix_set_zero(c_m0m1_n0n1_thread_mtx_desc, p_c_thread);

        // LDS double buffer: preload data into LDS
        {
            a_blockwise_copy.Run(p_a_global, p_a_block_double);
            b_blockwise_copy.Run(p_b_global, p_b_block_double);
        }

        constexpr auto a_block_slice_copy_step = Sequence<KPerBlock, 0>{};
        constexpr auto b_block_slice_copy_step = Sequence<KPerBlock, 0>{};

        Float* p_a_block_even = p_a_block_double;
        Float* p_b_block_even = p_b_block_double;

        Float* p_a_block_odd = p_a_block_double + a_block_space_size;
        Float* p_b_block_odd = p_b_block_double + b_block_space_size;

        // LDS double buffer: main body
        for(index_t k_block_data_begin = 0; k_block_data_begin < K - 2 * KPerBlock;
            k_block_data_begin += 2 * KPerBlock)
        {
            Float p_a_thread_buffer[a_blockwise_copy.GetThreadBufferSize()];
            Float p_b_thread_buffer[b_blockwise_copy.GetThreadBufferSize()];

            // even iteration
            a_blockwise_copy.MoveSrcSliceWindow(a_block_slice_copy_step, True);
            b_blockwise_copy.MoveSrcSliceWindow(b_block_slice_copy_step, True);

            __syncthreads();

            // LDS double buffer: load next data from device mem
            a_blockwise_copy.RunLoadThreadBuffer(p_a_global, p_a_thread_buffer);
            b_blockwise_copy.RunLoadThreadBuffer(p_b_global, p_b_thread_buffer);

            // LDS double buffer: GEMM on current data
            blockwise_gemm.Run(p_a_block_even, p_b_block_even, p_c_thread);

            // LDS double buffer: store next data to LDS
            a_blockwise_copy.RunStoreThreadBuffer(p_a_thread_buffer, p_a_block_odd);
            b_blockwise_copy.RunStoreThreadBuffer(p_b_thread_buffer, p_b_block_odd);

            // odd iteration
            a_blockwise_copy.MoveSrcSliceWindow(a_block_slice_copy_step, True);
            b_blockwise_copy.MoveSrcSliceWindow(b_block_slice_copy_step, True);

            __syncthreads();

            // LDS double buffer: load next data from device mem
            a_blockwise_copy.RunLoadThreadBuffer(p_a_global, p_a_thread_buffer);
            b_blockwise_copy.RunLoadThreadBuffer(p_b_global, p_b_thread_buffer);

            // LDS double buffer: GEMM on current data
            blockwise_gemm.Run(p_a_block_odd, p_b_block_odd, p_c_thread);

            // LDS double buffer: store next data to LDS
            a_blockwise_copy.RunStoreThreadBuffer(p_a_thread_buffer, p_a_block_even);
            b_blockwise_copy.RunStoreThreadBuffer(p_b_thread_buffer, p_b_block_even);
        }

        // LDS double buffer: tail
        {
            constexpr bool has_two_iteration_left = (K % (2 * KPerBlock) == 0);

            if(has_two_iteration_left) // if there are 2 iterations left
            {
                Float p_a_thread_buffer[a_blockwise_copy.GetThreadBufferSize()];
                Float p_b_thread_buffer[b_blockwise_copy.GetThreadBufferSize()];

                a_blockwise_copy.MoveSrcSliceWindow(a_block_slice_copy_step, True);
                b_blockwise_copy.MoveSrcSliceWindow(b_block_slice_copy_step, True);

                __syncthreads();

                // LDS double buffer: load last data from device mem
                a_blockwise_copy.RunLoadThreadBuffer(p_a_global, p_a_thread_buffer);
                b_blockwise_copy.RunLoadThreadBuffer(p_b_global, p_b_thread_buffer);

                // LDS double buffer: GEMM on 2nd-last data
                blockwise_gemm.Run(p_a_block_double, p_b_block_double, p_c_thread);

                // LDS double buffer: store last data to LDS
                a_blockwise_copy.RunStoreThreadBuffer(p_a_thread_buffer,
                                                      p_a_block_double + a_block_space_size);
                b_blockwise_copy.RunStoreThreadBuffer(p_b_thread_buffer,
                                                      p_b_block_double + b_block_space_size);

                __syncthreads();

                // LDS double buffer: GEMM on last data
                blockwise_gemm.Run(p_a_block_double + a_block_space_size,
                                   p_b_block_double + b_block_space_size,
                                   p_c_thread);
            }
            else // if there is 1 iteration left
            {
                __syncthreads();

                // LDS double buffer: GEMM on last data
                blockwise_gemm.Run(p_a_block_double, p_b_block_double, p_c_thread);
            }
        }

        // copy output: register to global memory
        {
            constexpr index_t M1 = MPerThread * MLevel0Cluster * MLevel1Cluster;
            constexpr index_t M0 = M / M1;

            constexpr index_t N1 = NPerThread * NLevel0Cluster * NLevel1Cluster;
            constexpr index_t N0 = N / N1;

            // define output tensor descriptor for threadwise copy
            //   thread output tensor, src of threadwise copy
            constexpr auto c_m0_m1_n0_n1_thread_desc = make_native_tensor_descriptor_packed(
                Sequence<GemmMRepeat, MPerThread, GemmNRepeat, NPerThread>{});

            constexpr auto c_m0_m1_n0_n1_global_desc = transform_tensor_descriptor(
                c_m_n_global_desc,
                make_tuple(UnMerge<Sequence<M0, M1>>{}, UnMerge<Sequence<N0, N1>>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));

            // calculate origin of thread output tensor on global memory
            //   blockwise GEMM c matrix starting index
            const auto c_thread_mtx_on_block =
                blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());

            const index_t m_thread_data_on_global =
                m_block_data_on_global + c_thread_mtx_on_block.row;

            const index_t n_thread_data_on_global =
                n_block_data_on_global + c_thread_mtx_on_block.col;

            ThreadwiseGenericTensorSliceCopy_v4r2<decltype(c_m0_m1_n0_n1_thread_desc),
                                                  decltype(c_m0_m1_n0_n1_global_desc),
                                                  decltype(c_m0_m1_n0_n1_thread_desc.GetLengths()),
                                                  CThreadCopySrcDstAccessOrder,
                                                  CThreadCopySrcDstVectorReadWriteDim,
                                                  1,
                                                  CThreadCopyDstDataPerWrite,
                                                  AddressSpace::Vgpr,
                                                  AddressSpace::Global,
                                                  CGlobalMemoryDataOperation>(
                make_multi_index(0, 0, 0, 0),
                make_multi_index(m_thread_data_on_global / M1,
                                 m_thread_data_on_global % M1,
                                 n_thread_data_on_global / N1,
                                 n_thread_data_on_global % N1))
                .Run(p_c_thread, p_c_global);
        }
    }

    __device__ void Run(const Float* __restrict__ p_a_global,
                        const Float* __restrict__ p_b_global,
                        Float* __restrict__ p_c_global) const
    {
        constexpr index_t shared_block_size = GetSharedMemoryNumberOfByte() / sizeof(Float);

        __shared__ Float p_shared_block[shared_block_size];

        Run(p_a_global, p_b_global, p_c_global, p_shared_block);
    }
};

template <index_t GridSize,
          index_t BlockSize,
          typename Float,
          typename AccFloat,
          typename AGlobalDesc,
          typename BGlobalDesc,
          typename CGlobalDesc,
          InMemoryDataOperation CGlobalMemoryDataOperation,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t KPerBlock,
          index_t MPerThread,
          index_t NPerThread,
          index_t KPerThread,
          index_t MLevel0Cluster,
          index_t NLevel0Cluster,
          index_t MLevel1Cluster,
          index_t NLevel1Cluster,
          index_t ThreadGemmAThreadCopySrcDataPerRead_M,
          index_t ThreadGemmBThreadCopySrcDataPerRead_N,
          typename ABlockCopyThreadSliceLengths_K0_K1_K2_M,
          typename ABlockCopyThreadClusterLengths_K0_K1_K2_M,
          typename ABlockCopyThreadClusterArrangeOrder,
          typename ABlockCopySrcAccessOrder,
          index_t ABlockCopySrcVectorReadDim,
          index_t ABlockCopySrcDataPerRead,
          index_t ABlockCopyDstDataPerWrite_M,
          typename BBlockCopyThreadSliceLengths_K0_K1_K2_N,
          typename BBlockCopyThreadClusterLengths_K0_K1_K2_N,
          typename BBlockCopyThreadClusterArrangeOrder,
          typename BBlockCopySrcAccessOrder,
          index_t BBlockCopySrcVectorReadDim,
          index_t BBlockCopySrcDataPerRead,
          index_t BBlockCopyDstDataPerWrite_N,
          typename CThreadCopySrcDstAccessOrder,
          index_t CThreadCopySrcDstVectorReadWriteDim,
          index_t CThreadCopyDstDataPerWrite>
struct GridwiseGemmTransposedANormalBNormalC_v2
{
    __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
    {
        constexpr index_t max_lds_align = math::lcm(ABlockCopyDstDataPerWrite_M,
                                                    BBlockCopyDstDataPerWrite_N,
                                                    ThreadGemmAThreadCopySrcDataPerRead_M,
                                                    ThreadGemmBThreadCopySrcDataPerRead_N);

        // A matrix in LDS memory, dst of blockwise copy
        //   be careful of LDS alignment
        constexpr auto a_k_m_block_desc = make_native_tensor_descriptor_aligned(
            Sequence<KPerBlock, MPerBlock>{}, Number<max_lds_align>{});

        // B matrix in LDS memory, dst of blockwise copy
        //   be careful of LDS alignment
        constexpr auto b_k_n_block_desc = make_native_tensor_descriptor_aligned(
            Sequence<KPerBlock, NPerBlock>{}, Number<max_lds_align>{});

        // LDS allocation for A and B: be careful of alignment
        constexpr index_t a_block_space_size =
            math::integer_least_multiple(a_k_m_block_desc.GetElementSpace(), max_lds_align);

        constexpr index_t b_block_space_size =
            math::integer_least_multiple(b_k_n_block_desc.GetElementSpace(), max_lds_align);

        return 2 * (a_block_space_size + b_block_space_size) * sizeof(Float);
    }

    __device__ void Run(const Float* __restrict__ p_a_global,
                        const Float* __restrict__ p_b_global,
                        Float* __restrict__ p_c_global,
                        Float* __restrict__ p_shared_block) const
    {
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};
        constexpr auto I2 = Number<2>{};
        constexpr auto I3 = Number<3>{};

        constexpr auto True  = integral_constant<bool, true>{};
        constexpr auto False = integral_constant<bool, false>{};

        constexpr auto a_k0_k1_k2_m_global_desc = AGlobalDesc{};
        constexpr auto b_k0_k1_k2_n_global_desc = BGlobalDesc{};
        constexpr auto c_m_n_global_desc        = CGlobalDesc{};

        constexpr auto K0 = a_k0_k1_k2_m_global_desc.GetLengths()[I0];
        constexpr auto K1 = a_k0_k1_k2_m_global_desc.GetLengths()[I1];
        constexpr auto K  = a_k0_k1_k2_m_global_desc.GetLengths()[I2];

        constexpr auto M = c_m_n_global_desc.GetLengths()[I0];
        constexpr auto N = c_m_n_global_desc.GetLengths()[I1];

        // don't do anything if K == 0
        if(K == 0)
        {
            return;
        }

        // lds max alignment
        constexpr index_t max_lds_align = math::lcm(ABlockCopyDstDataPerWrite_M,
                                                    BBlockCopyDstDataPerWrite_N,
                                                    ThreadGemmAThreadCopySrcDataPerRead_M,
                                                    ThreadGemmBThreadCopySrcDataPerRead_N);

        // divide block work by [M, N]
        static_assert(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0,
                      "wrong! cannot divide work evenly among block");

        constexpr index_t MBlockWork = M / MPerBlock;
        constexpr index_t NBlockWork = N / NPerBlock;

        constexpr auto block_work_desc =
            make_cluster_descriptor(Sequence<MBlockWork, NBlockWork>{});

        const auto block_work_id = block_work_desc.CalculateClusterIndex(get_block_1d_id());

        const index_t m_block_data_on_global = block_work_id[I0] * MPerBlock;
        const index_t n_block_data_on_global = block_work_id[I1] * NPerBlock;

        // A matrix in LDS memory, dst of blockwise copy
        //   be careful of LDS alignment
        constexpr auto a_k0_k1_k2_m_block_desc = make_native_tensor_descriptor_aligned(
            Sequence<1, 1, KPerBlock, MPerBlock>{}, Number<max_lds_align>{});

        // A matrix blockwise copy
        auto a_blockwise_copy =
            BlockwiseGenericTensorSliceCopy_v4<BlockSize,
                                               decltype(a_k0_k1_k2_m_global_desc),
                                               decltype(a_k0_k1_k2_m_block_desc),
                                               decltype(a_k0_k1_k2_m_block_desc.GetLengths()),
                                               ABlockCopyThreadSliceLengths_K0_K1_K2_M,
                                               ABlockCopyThreadClusterLengths_K0_K1_K2_M,
                                               ABlockCopyThreadClusterArrangeOrder,
                                               ABlockCopySrcAccessOrder,
                                               Sequence<0, 1, 2, 3>,
                                               ABlockCopySrcVectorReadDim,
                                               3,
                                               ABlockCopySrcDataPerRead,
                                               ABlockCopyDstDataPerWrite_M,
                                               AddressSpace::Global,
                                               AddressSpace::Vgpr,
                                               AddressSpace::Lds,
                                               InMemoryDataOperation::Set>(
                make_multi_index(0, 0, 0, m_block_data_on_global), make_multi_index(0, 0, 0, 0));

        // B matrix in LDS memory, dst of blockwise copy
        //   be careful of LDS alignment
        constexpr auto b_k0_k1_k2_n_block_desc = make_native_tensor_descriptor_aligned(
            Sequence<1, 1, KPerBlock, NPerBlock>{}, Number<max_lds_align>{});

        // B matrix blockwise copy
        auto b_blockwise_copy =
            BlockwiseGenericTensorSliceCopy_v4<BlockSize,
                                               decltype(b_k0_k1_k2_n_global_desc),
                                               decltype(b_k0_k1_k2_n_block_desc),
                                               decltype(b_k0_k1_k2_n_block_desc.GetLengths()),
                                               BBlockCopyThreadSliceLengths_K0_K1_K2_N,
                                               BBlockCopyThreadClusterLengths_K0_K1_K2_N,
                                               BBlockCopyThreadClusterArrangeOrder,
                                               BBlockCopySrcAccessOrder,
                                               Sequence<0, 1, 2, 3>,
                                               BBlockCopySrcVectorReadDim,
                                               3,
                                               BBlockCopySrcDataPerRead,
                                               BBlockCopyDstDataPerWrite_N,
                                               AddressSpace::Global,
                                               AddressSpace::Vgpr,
                                               AddressSpace::Lds,
                                               InMemoryDataOperation::Set>(
                make_multi_index(0, 0, 0, n_block_data_on_global), make_multi_index(0, 0, 0, 0));

        // GEMM definition
        //   c_mtx += transpose(a_mtx) * b_mtx
        //     a_mtx[KPerBlock, MPerBlock] is in LDS
        //     b_mtx[KPerBlock, NPerBlock] is in LDS
        //     c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in register
        constexpr auto a_k_m_block_mtx_desc = make_ConstantMatrixDescriptor(
            unfold_tensor_descriptor(a_k0_k1_k2_m_block_desc, I0, I2));
        constexpr auto b_k_n_block_mtx_desc = make_ConstantMatrixDescriptor(
            unfold_tensor_descriptor(b_k0_k1_k2_n_block_desc, I0, I2));

        // sanity check
        static_assert(MPerBlock % (MPerThread * MLevel0Cluster * MLevel1Cluster) == 0 &&
                          NPerBlock % (NPerThread * NLevel0Cluster * NLevel1Cluster) == 0,
                      "wrong!");

        constexpr index_t GemmMRepeat = MPerBlock / (MPerThread * MLevel0Cluster * MLevel1Cluster);
        constexpr index_t GemmNRepeat = NPerBlock / (NPerThread * NLevel0Cluster * NLevel1Cluster);

        // c_thread_mtx definition: this is a mess
        // TODO: more elegant way of defining c_thread_mtx
        constexpr auto c_m0m1_n0n1_thread_mtx_desc = make_ConstantMatrixDescriptor_packed(
            Number<GemmMRepeat * MPerThread>{}, Number<GemmNRepeat * NPerThread>{});

        const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2<
            BlockSize,
            decltype(a_k_m_block_mtx_desc),
            decltype(b_k_n_block_mtx_desc),
            decltype(c_m0m1_n0n1_thread_mtx_desc),
            MPerThread,
            NPerThread,
            KPerThread,
            MLevel0Cluster,
            NLevel0Cluster,
            MLevel1Cluster,
            NLevel1Cluster,
            ThreadGemmAThreadCopySrcDataPerRead_M,
            ThreadGemmBThreadCopySrcDataPerRead_N>{};

        // LDS allocation for A and B: be careful of alignment
        constexpr index_t a_block_space_size =
            math::integer_least_multiple(a_k0_k1_k2_m_block_desc.GetElementSpace(), max_lds_align);

        constexpr index_t b_block_space_size =
            math::integer_least_multiple(b_k0_k1_k2_n_block_desc.GetElementSpace(), max_lds_align);

        Float* p_a_block_double = p_shared_block;
        Float* p_b_block_double = p_shared_block + 2 * a_block_space_size;

        // register allocation for output
        AccFloat p_c_thread[c_m0m1_n0n1_thread_mtx_desc.GetElementSpace()];

        // zero out threadwise output
        threadwise_matrix_set_zero(c_m0m1_n0n1_thread_mtx_desc, p_c_thread);

        for(index_t k0 = 0; k0 < K0; ++k0)
        {
            for(index_t k1 = 0; k1 < K1; ++k1)
            {
                // LDS double buffer: preload data into LDS
                {
                    a_blockwise_copy.Run(p_a_global, p_a_block_double);
                    b_blockwise_copy.Run(p_b_global, p_b_block_double);
                }

                constexpr auto a_block_slice_copy_step = Sequence<0, 0, KPerBlock, 0>{};
                constexpr auto b_block_slice_copy_step = Sequence<0, 0, KPerBlock, 0>{};

                // LDS double buffer: main body
                for(index_t k_block_data_begin = 0; k_block_data_begin + 2 * KPerBlock < K;
                    k_block_data_begin += 2 * KPerBlock)
                {
#pragma unroll
                    for(index_t iloop = 0; iloop < 2; ++iloop)
                    {
                        const bool even_loop = (iloop % 2 == 0);

                        Float* p_a_block_now =
                            even_loop ? p_a_block_double : p_a_block_double + a_block_space_size;
                        Float* p_b_block_now =
                            even_loop ? p_b_block_double : p_b_block_double + b_block_space_size;

                        Float* p_a_block_next =
                            even_loop ? p_a_block_double + a_block_space_size : p_a_block_double;
                        Float* p_b_block_next =
                            even_loop ? p_b_block_double + b_block_space_size : p_b_block_double;

                        Float p_a_thread_buffer[a_blockwise_copy.GetThreadBufferSize()];
                        Float p_b_thread_buffer[b_blockwise_copy.GetThreadBufferSize()];

                        a_blockwise_copy.MoveSrcSliceWindow(a_block_slice_copy_step, True);
                        b_blockwise_copy.MoveSrcSliceWindow(b_block_slice_copy_step, True);

                        __syncthreads();

                        // LDS double buffer: load next data from device mem
                        a_blockwise_copy.RunLoadThreadBuffer(p_a_global, p_a_thread_buffer);
                        b_blockwise_copy.RunLoadThreadBuffer(p_b_global, p_b_thread_buffer);

                        // LDS double buffer: GEMM on current data
                        blockwise_gemm.Run(p_a_block_now, p_b_block_now, p_c_thread);

                        // LDS double buffer: store next data to LDS
                        a_blockwise_copy.RunStoreThreadBuffer(p_a_thread_buffer, p_a_block_next);
                        b_blockwise_copy.RunStoreThreadBuffer(p_b_thread_buffer, p_b_block_next);
                    }
                }

                // LDS double buffer: tail
                {
                    constexpr bool has_two_iteration_left = (K % (2 * KPerBlock) == 0);

                    if(has_two_iteration_left) // if there are 2 iterations left
                    {
                        Float p_a_thread_buffer[a_blockwise_copy.GetThreadBufferSize()];
                        Float p_b_thread_buffer[b_blockwise_copy.GetThreadBufferSize()];

                        a_blockwise_copy.MoveSrcSliceWindow(a_block_slice_copy_step, True);
                        b_blockwise_copy.MoveSrcSliceWindow(b_block_slice_copy_step, True);

                        __syncthreads();

                        // LDS double buffer: load last data from device mem
                        a_blockwise_copy.RunLoadThreadBuffer(p_a_global, p_a_thread_buffer);
                        b_blockwise_copy.RunLoadThreadBuffer(p_b_global, p_b_thread_buffer);

                        // LDS double buffer: GEMM on 2nd-last data
                        blockwise_gemm.Run(p_a_block_double, p_b_block_double, p_c_thread);

                        // LDS double buffer: store last data to LDS
                        a_blockwise_copy.RunStoreThreadBuffer(
                            p_a_thread_buffer, p_a_block_double + a_block_space_size);
                        b_blockwise_copy.RunStoreThreadBuffer(
                            p_b_thread_buffer, p_b_block_double + b_block_space_size);

                        __syncthreads();

                        // LDS double buffer: GEMM on last data
                        blockwise_gemm.Run(p_a_block_double + a_block_space_size,
                                           p_b_block_double + b_block_space_size,
                                           p_c_thread);
                    }
                    else // if there is 1 iteration left
                    {
                        __syncthreads();

                        // LDS double buffer: GEMM on last data
                        blockwise_gemm.Run(p_a_block_double, p_b_block_double, p_c_thread);
                    }
                }

                // reset slice window on K2 dimension, then move forward on K1 dimension
                a_blockwise_copy.MoveSrcSliceWindow(Sequence<0, 0, K - KPerBlock, 0>{}, False);
                b_blockwise_copy.MoveSrcSliceWindow(Sequence<0, 0, K - KPerBlock, 0>{}, False);

                a_blockwise_copy.MoveSrcSliceWindow(Sequence<0, 1, 0, 0>{}, True);
                b_blockwise_copy.MoveSrcSliceWindow(Sequence<0, 1, 0, 0>{}, True);
            }

            // reset slice window on K1 dimension, then move forward on K0 dimension
            a_blockwise_copy.MoveSrcSliceWindow(Sequence<0, K1, 0, 0>{}, False);
            b_blockwise_copy.MoveSrcSliceWindow(Sequence<0, K1, 0, 0>{}, False);

            a_blockwise_copy.MoveSrcSliceWindow(Sequence<1, 0, 0, 0>{}, True);
            b_blockwise_copy.MoveSrcSliceWindow(Sequence<1, 0, 0, 0>{}, True);
        }

        // copy output: register to global memory
        {
            constexpr index_t M1 = MPerThread * MLevel0Cluster * MLevel1Cluster;
            constexpr index_t M0 = M / M1;

            constexpr index_t N1 = NPerThread * NLevel0Cluster * NLevel1Cluster;
            constexpr index_t N0 = N / N1;

            // define output tensor descriptor for threadwise copy
            //   thread output tensor, src of threadwise copy
            constexpr auto c_m0_m1_n0_n1_thread_desc = make_native_tensor_descriptor_packed(
                Sequence<GemmMRepeat, MPerThread, GemmNRepeat, NPerThread>{});

            constexpr auto c_m0_m1_n0_n1_global_desc = transform_tensor_descriptor(
                c_m_n_global_desc,
                make_tuple(UnMerge<Sequence<M0, M1>>{}, UnMerge<Sequence<N0, N1>>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));

            // calculate origin of thread output tensor on global memory
            //   blockwise GEMM c matrix starting index
            const auto c_thread_mtx_on_block =
                blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());

            const index_t m_thread_data_on_global =
                m_block_data_on_global + c_thread_mtx_on_block.row;

            const index_t n_thread_data_on_global =
                n_block_data_on_global + c_thread_mtx_on_block.col;

            ThreadwiseGenericTensorSliceCopy_v4r2<decltype(c_m0_m1_n0_n1_thread_desc),
                                                  decltype(c_m0_m1_n0_n1_global_desc),
                                                  decltype(c_m0_m1_n0_n1_thread_desc.GetLengths()),
                                                  CThreadCopySrcDstAccessOrder,
                                                  CThreadCopySrcDstVectorReadWriteDim,
                                                  1,
                                                  CThreadCopyDstDataPerWrite,
                                                  AddressSpace::Vgpr,
                                                  AddressSpace::Global,
                                                  CGlobalMemoryDataOperation>(
                make_multi_index(0, 0, 0, 0),
                make_multi_index(m_thread_data_on_global / M1,
                                 m_thread_data_on_global % M1,
                                 n_thread_data_on_global / N1,
                                 n_thread_data_on_global % N1))
                .Run(p_c_thread, p_c_global);
        }
    }

    __device__ void Run(const Float* __restrict__ p_a_global,
                        const Float* __restrict__ p_b_global,
                        Float* __restrict__ p_c_global) const
    {
        constexpr index_t shared_block_size = GetSharedMemoryNumberOfByte() / sizeof(Float);

        __shared__ Float p_shared_block[shared_block_size];

        Run(p_a_global, p_b_global, p_c_global, p_shared_block);
    }
};

} // namespace ck
#endif
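Both gridwise GEMMs above rely on the same LDS double-buffer schedule: one K-chunk is preloaded, the main loop then consumes two chunks per iteration while ping-ponging between the two halves of LDS (GEMM on the current half overlaps with fetching the next chunk), and the tail handles the one or two chunks that remain. The host-side sketch below prints that schedule so the control flow can be checked in isolation; it is not part of the deleted file, and all names and sizes in it are illustrative.

// Illustration only: the ping-pong (double-buffer) schedule of the gridwise GEMM main loop.
#include <cstdio>

int main()
{
    constexpr int K = 64, KPerBlock = 8; // hypothetical sizes

    int buf = 0; // 0 = first half of LDS ("even"), 1 = second half ("odd")

    std::printf("preload chunk [0, %d) into LDS buffer %d\n", KPerBlock, buf);

    // main body: two K-chunks per iteration, alternating buffers
    int k = 0;
    for(; k < K - 2 * KPerBlock; k += 2 * KPerBlock)
    {
        for(int i = 0; i < 2; ++i)
        {
            const int next = k + (i + 1) * KPerBlock;
            std::printf("GEMM on buffer %d while loading chunk [%d, %d) into buffer %d\n",
                        buf, next, next + KPerBlock, 1 - buf);
            buf = 1 - buf;
        }
    }

    // tail: K % (2 * KPerBlock) == 0 leaves two chunks, otherwise one
    if(K % (2 * KPerBlock) == 0)
    {
        std::printf("GEMM on buffer %d while loading the last chunk into buffer %d\n",
                    buf, 1 - buf);
        std::printf("GEMM on buffer %d (last chunk)\n", 1 - buf);
    }
    else
    {
        std::printf("GEMM on buffer %d (last chunk)\n", buf);
    }
    return 0;
}

The __syncthreads() calls in the real kernels sit between the GEMM on the current buffer and the store of the next chunk, which is what makes the overlap safe; the sketch only shows the ordering of the work, not the synchronization.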