Unverified Commit 52c3fe05 authored by Chao Liu, committed by GitHub

Refactor for MIOpen integration (#4)

Refactor so that multi-index transformation and padding support can be brought into MIOpen.
parent 9aaeacc8
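
For readers coming from the MIOpen side, the core of this refactor is that input/weight/output layouts are now expressed as composable multi-index transformations (PassThrough, Pad, Embed, Merge/UnMerge) applied via transform_tensor_descriptor, instead of the older StridedSlice/Fold/Extract descriptor chains. Below is a minimal, self-contained sketch of the index arithmetic those transforms encode; the numbers and variable names are illustrative only and are not taken from this commit.

#include <cstdio>

int main()
{
    const int C = 3, Y = 3, X = 3;              // hypothetical filter dimensions
    const int stride = 1, dilation = 1, left_pad = 1;

    const int e = 17;                           // one index of the merged E = C*Y*X dimension
    const int c = e / (Y * X);                  // Merge<C, Y, X>: split e back into (c, y, x)
    const int y = (e / X) % Y;
    const int x = e % X;

    const int ho = 0, wo = 0;                   // one output pixel
    const int hip = y * dilation + ho * stride; // Embed<Y, Ho>: padded-input row index
    const int wip = x * dilation + wo * stride; // Embed<X, Wo>: padded-input column index
    const int hi  = hip - left_pad;             // Pad: map back to the unpadded input tensor
    const int wi  = wip - left_pad;

    std::printf("e=%d -> (c,y,x)=(%d,%d,%d), reads input (hi,wi)=(%d,%d)%s\n",
                e, c, y, x, hi, wi,
                (hi < 0 || wi < 0) ? " [padding region]" : "");
    return 0;
}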
@@ -47,14 +47,17 @@ include_directories(BEFORE
     ${PROJECT_SOURCE_DIR}/composable_kernel/include/tensor_description
     ${PROJECT_SOURCE_DIR}/composable_kernel/include/tensor_operation
     ${PROJECT_SOURCE_DIR}/composable_kernel/include/kernel_algorithm
+    ${PROJECT_SOURCE_DIR}/external/include
     ${PROJECT_SOURCE_DIR}/driver/include
     ${PROJECT_BINARY_DIR}/composable_kernel/include/utility
 )

 if(DEVICE_BACKEND STREQUAL "AMD")
-    configure_file("${PROJECT_SOURCE_DIR}/composable_kernel/include/utility/config_amd.hpp.in" "${PROJECT_BINARY_DIR}/composable_kernel/include/utility/config.hpp")
+    configure_file("${PROJECT_SOURCE_DIR}/composable_kernel/include/utility/config.amd.hpp.in" "${PROJECT_BINARY_DIR}/composable_kernel/include/utility/config.hpp")
+    configure_file("${PROJECT_SOURCE_DIR}/composable_kernel/include/utility/float_type.amd.hpp.in" "${PROJECT_BINARY_DIR}/composable_kernel/include/utility/float_type.hpp")
 elseif(DEVICE_BACKEND STREQUAL "NVIDIA")
-    configure_file("${PROJECT_SOURCE_DIR}/composable_kernel/include/utility/config_nvidia.hpp.in" "${PROJECT_BINARY_DIR}/composable_kernel/include/utility/config.hpp")
+    configure_file("${PROJECT_SOURCE_DIR}/composable_kernel/include/utility/config.nvidia.hpp.in" "${PROJECT_BINARY_DIR}/composable_kernel/include/utility/config.hpp")
+    configure_file("${PROJECT_SOURCE_DIR}/composable_kernel/include/utility/float_type.nvidia.hpp.in" "${PROJECT_BINARY_DIR}/composable_kernel/include/utility/float_type.hpp")
 endif()

 add_subdirectory(driver)
#ifndef CK_CONVOLUTION_COMMON_HPP
#define CK_CONVOLUTION_COMMON_HPP
namespace ck {
enum ConvolutionDirection
{
Forward,
BackwardData,
BackwardWeight
};
} // namespace ck
#endif
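
The new ConvolutionDirection enum above is used as a non-type template parameter, so forward and backward-weight variants are selected at compile time by template specialization (see make_wei_e_k_global_desc_v4r1 later in this diff). A hedged, standalone sketch of that dispatch pattern, with illustrative names only:

#include <cstdio>

enum ConvolutionDirection { Forward, BackwardData, BackwardWeight };

// primary template intentionally left undefined; only supported directions specialize it
template <ConvolutionDirection Dir>
struct weight_layout_tag;

template <>
struct weight_layout_tag<Forward>
{
    static const char* describe() { return "forward: [E, K] view built by unfold + reorder"; }
};

template <>
struct weight_layout_tag<BackwardWeight>
{
    static const char* describe() { return "backward-weight: [E, K] view built by merge(C, Y*X)"; }
};

int main()
{
    // the direction is a template argument, so the choice is resolved at compile time
    std::printf("%s\n", weight_layout_tag<Forward>::describe());
    std::printf("%s\n", weight_layout_tag<BackwardWeight>::describe());
    return 0;
}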
@@ -2,7 +2,7 @@
 #define CK_GRIDWISE_CONVOLUTION_DIRECT_V2_NCHW_KCYX_NKHW

 #include "common_header.hpp"
-#include "ConstantTensorDescriptor.hpp"
+#include "ConstantTensorDescriptor_deprecated.hpp"
 #include "blockwise_2d_tensor_op.hpp"
 #include "blockwise_4d_tensor_op.hpp"
 #include "threadwise_tensor_slice_copy.hpp"
......
@@ -2,7 +2,7 @@
 #define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V1R1_CHWN_CYXK_KHWN

 #include "common_header.hpp"
-#include "ConstantTensorDescriptor.hpp"
+#include "ConstantTensorDescriptor_deprecated.hpp"
 #include "ConstantMatrixDescriptor.hpp"
 #include "blockwise_4d_tensor_op.hpp"
 #include "blockwise_2d_tensor_op.hpp"
......
@@ -2,7 +2,7 @@
 #define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V1R2_CHWN_CYXK_KHWN

 #include "common_header.hpp"
-#include "ConstantTensorDescriptor.hpp"
+#include "ConstantTensorDescriptor_deprecated.hpp"
 #include "ConstantMatrixDescriptor.hpp"
 #include "blockwise_2d_tensor_op.hpp"
 #include "blockwise_3d_tensor_op.hpp"
......
@@ -2,7 +2,7 @@
 #define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V1R3_CHWN_CYXK_KHWN_HPP

 #include "common_header.hpp"
-#include "ConstantTensorDescriptor.hpp"
+#include "ConstantTensorDescriptor_deprecated.hpp"
 #include "ConstantMatrixDescriptor.hpp"
 #include "blockwise_generic_tensor_slice_copy.hpp"
 #include "threadwise_generic_tensor_slice_copy.hpp"
@@ -125,38 +125,38 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn
     // blockwise copy
     // input: format is [C, Hi, Wi, N]
-    auto blockwise_in_copy =
-        BlockwiseGenericTensorSliceCopy_v1<BlockSize,
+    auto blockwise_in_copy = BlockwiseGenericTensorSliceCopy_v1_deprecated<
+        BlockSize,
         decltype(in_c_h_w_n_global_desc),
         decltype(in_c_h_w_n_block_desc),
         decltype(in_c_h_w_n_block_desc.GetLengths()),
         InBlockCopySubLengths_CHWN,
         InBlockCopyClusterLengths_CHWN,
         Sequence<0, 1, 2, 3>,
         Sequence<0, 1, 2, 3>,
         Sequence<0, 1, 2, 3>,
         3,
         3,
         InBlockCopyDataPerAccess_N,
-        InBlockCopyDataPerAccess_N>({0, 0, 0, 0},
-                                    {0, 0, 0, 0});
+        InBlockCopyDataPerAccess_N>({0, 0, 0, 0}, {0, 0, 0, 0});

     // blockwise wei copy
     // format is [CPerBlock, X * KPerBlock]
     const auto blockwise_wei_copy =
-        BlockwiseGenericTensorSliceCopy_v1<BlockSize,
+        BlockwiseGenericTensorSliceCopy_v1_deprecated<BlockSize,
         decltype(wei_c_k_global_desc),
         decltype(wei_c_k_block_desc),
         decltype(wei_c_k_block_desc.GetLengths()),
         WeiBlockCopySubLengths_CK,
         WeiBlockCopyClusterLengths_CK,
         Sequence<0, 1>,
         Sequence<0, 1>,
         Sequence<0, 1>,
         1,
         1,
         WeiBlockCopyDataPerAccess_K,
-        WeiBlockCopyDataPerAccess_K>({0, 0}, {0, 0});
+        WeiBlockCopyDataPerAccess_K>({0, 0},
+                                     {0, 0});

     // a series of blockwise batched GEMM
     // C_matrix += transpose(A_matrix) * B_matrix
@@ -318,14 +318,15 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn
         n_block_data_begin + n_thread_data_begin);

 #if 1
-    ThreadwiseGenericTensorSliceCopy_v1r2<decltype(out_10d_thread_desc),
+    ThreadwiseGenericTensorSliceCopy_v1r2_deprecated<
+        decltype(out_10d_thread_desc),
         decltype(out_10d_global_desc),
         decltype(out_10d_thread_desc.GetLengths()),
         arithmetic_sequence_gen<0, 10, 1>::type,
         9,
         OutThreadCopyDataPerAccess_N,
-        OutThreadCopyDataPerAccess_N>(
-        make_zero_array<index_t, 10>(), make_zero_array<index_t, 10>())
+        OutThreadCopyDataPerAccess_N>(make_zero_array<index_t, 10>(),
+                                      make_zero_array<index_t, 10>())
         .Run(p_out_thread, p_out_thread_on_global);
 #elif 0
     ThreadwiseGenericTensorSliceCopy_v1r1<decltype(out_10d_thread_desc),
@@ -388,14 +389,15 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn
         n_block_data_begin + n_thread_data_begin);

 #if 1
-    ThreadwiseGenericTensorSliceCopy_v1r2<decltype(out_10d_thread_desc),
+    ThreadwiseGenericTensorSliceCopy_v1r2_deprecated<
+        decltype(out_10d_thread_desc),
         decltype(out_10d_global_desc),
         decltype(out_10d_thread_desc.GetLengths()),
         arithmetic_sequence_gen<0, 10, 1>::type,
         9,
         OutThreadCopyDataPerAccess_N,
-        OutThreadCopyDataPerAccess_N>(
-        make_zero_array<index_t, 10>(), make_zero_array<index_t, 10>())
+        OutThreadCopyDataPerAccess_N>(make_zero_array<index_t, 10>(),
+                                      make_zero_array<index_t, 10>())
         .Run(p_out_thread, p_out_thread_on_global);
 #elif 0
     ThreadwiseGenericTensorSliceCopy_v1r1<decltype(out_10d_thread_desc),
......
@@ -2,7 +2,7 @@
 #define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V1R3_CHWN_CYXK_KHWN_LDS_DOUBLE_BUFFER_HPP

 #include "common_header.hpp"
-#include "ConstantTensorDescriptor.hpp"
+#include "ConstantTensorDescriptor_deprecated.hpp"
 #include "ConstantMatrixDescriptor.hpp"
 #include "blockwise_generic_tensor_slice_copy.hpp"
 #include "threadwise_generic_tensor_slice_copy.hpp"
@@ -127,9 +127,9 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_lds_double_buffer
     // input: format is [C, Hi, Wi, N]
     auto blockwise_in_copy =
 #if 0
-        BlockwiseGenericTensorSliceCopy_v1
+        BlockwiseGenericTensorSliceCopy_v1_deprecated
 #else
-        BlockwiseGenericTensorSliceCopy_v2
+        BlockwiseGenericTensorSliceCopy_v2_deprecated
 #endif
         <BlockSize,
          decltype(in_c_h_w_n_global_desc),
@@ -149,9 +149,9 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_lds_double_buffer
     // format is [CPerBlock, X * KPerBlock]
     const auto blockwise_wei_copy =
 #if 0
-        BlockwiseGenericTensorSliceCopy_v1
+        BlockwiseGenericTensorSliceCopy_v1_deprecated
 #else
-        BlockwiseGenericTensorSliceCopy_v2
+        BlockwiseGenericTensorSliceCopy_v2_deprecated
 #endif
         <BlockSize,
          decltype(wei_c_k_global_desc),
@@ -406,14 +406,15 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_lds_double_buffer
         n_block_data_begin + n_thread_data_begin);

 #if 1
-    ThreadwiseGenericTensorSliceCopy_v1r2<decltype(out_10d_thread_desc),
+    ThreadwiseGenericTensorSliceCopy_v1r2_deprecated<
+        decltype(out_10d_thread_desc),
         decltype(out_10d_global_desc),
         decltype(out_10d_thread_desc.GetLengths()),
         arithmetic_sequence_gen<0, 10, 1>::type,
         9,
         OutThreadCopyDataPerAccess_N,
-        OutThreadCopyDataPerAccess_N>(
-        make_zero_array<index_t, 10>(), make_zero_array<index_t, 10>())
+        OutThreadCopyDataPerAccess_N>(make_zero_array<index_t, 10>(),
+                                      make_zero_array<index_t, 10>())
         .Run(p_out_thread, p_out_thread_on_global);
 #elif 0
     ThreadwiseGenericTensorSliceCopy_v1r1<decltype(out_10d_thread_desc),
@@ -476,14 +477,15 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_lds_double_buffer
         n_block_data_begin + n_thread_data_begin);

 #if 1
-    ThreadwiseGenericTensorSliceCopy_v1r2<decltype(out_10d_thread_desc),
+    ThreadwiseGenericTensorSliceCopy_v1r2_deprecated<
+        decltype(out_10d_thread_desc),
         decltype(out_10d_global_desc),
         decltype(out_10d_thread_desc.GetLengths()),
         arithmetic_sequence_gen<0, 10, 1>::type,
         9,
         OutThreadCopyDataPerAccess_N,
-        OutThreadCopyDataPerAccess_N>(
-        make_zero_array<index_t, 10>(), make_zero_array<index_t, 10>())
+        OutThreadCopyDataPerAccess_N>(make_zero_array<index_t, 10>(),
+                                      make_zero_array<index_t, 10>())
         .Run(p_out_thread, p_out_thread_on_global);
 #elif 0
     ThreadwiseGenericTensorSliceCopy_v1r1<decltype(out_10d_thread_desc),
......
@@ -2,7 +2,7 @@
 #define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V1R3_CHWN_CYXK_KHWN_PADDED_HPP

 #include "common_header.hpp"
-#include "ConstantTensorDescriptor.hpp"
+#include "ConstantTensorDescriptor_deprecated.hpp"
 #include "ConstantMatrixDescriptor.hpp"
 #include "tensor_descriptor.hpp"
 #include "tensor_descriptor_helper.hpp"
......
@@ -2,7 +2,7 @@
 #define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V1R3_NCHW_CYXK_NKHW

 #include "common_header.hpp"
-#include "ConstantTensorDescriptor.hpp"
+#include "ConstantTensorDescriptor_deprecated.hpp"
 #include "ConstantMatrixDescriptor.hpp"
 #include "blockwise_2d_tensor_op.hpp"
 #include "blockwise_tensor_slice_copy.hpp"
......
@@ -2,7 +2,7 @@
 #define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V1R3_NCHW_CYXK_NKHW_LDS_DOUBLE_BUFFER

 #include "common_header.hpp"
-#include "ConstantTensorDescriptor.hpp"
+#include "ConstantTensorDescriptor_deprecated.hpp"
 #include "ConstantMatrixDescriptor.hpp"
 #include "blockwise_2d_tensor_op.hpp"
 #include "blockwise_tensor_slice_copy.hpp"
......
@@ -2,7 +2,7 @@
 #define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V2_CHWN_CYXK_KHWN

 #include "common_header.hpp"
-#include "ConstantTensorDescriptor.hpp"
+#include "ConstantTensorDescriptor_deprecated.hpp"
 #include "ConstantMatrixDescriptor.hpp"
 #include "blockwise_4d_tensor_op.hpp"
 #include "blockwise_2d_tensor_op.hpp"
......
@@ -2,7 +2,7 @@
 #define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V2_CHWN_CYXK_KHWN_LDS_DOUBLE_BUFFER

 #include "common_header.hpp"
-#include "ConstantTensorDescriptor.hpp"
+#include "ConstantTensorDescriptor_deprecated.hpp"
 #include "ConstantMatrixDescriptor.hpp"
 #include "blockwise_4d_tensor_op.hpp"
 #include "blockwise_2d_tensor_op.hpp"
......
@@ -2,8 +2,8 @@
 #define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V3_NCHW_CYXK_NKHW

 #include "common_header.hpp"
-#include "ConstantTensorDescriptor.hpp"
-#include "ConstantMergedTensorDescriptor.hpp"
+#include "ConstantTensorDescriptor_deprecated.hpp"
+#include "ConstantMergedTensorDescriptor_deprecated.hpp"
 #include "ConstantMatrixDescriptor.hpp"
 #include "blockwise_generic_tensor_slice_copy.hpp"
 #include "blockwise_gemm.hpp"
@@ -128,7 +128,7 @@ struct GridwiseConvolutionImplicitGemm_v3_nchw_cyxk_nkhw
     // input blockwise copy
    // slice a merged tensor, reorder and copy to a normal tensor
     // this copy operator already has blockwise offset built-in
-    auto blockwise_in_copy = BlockwiseGenericTensorSliceCopy_v1<
+    auto blockwise_in_copy = BlockwiseGenericTensorSliceCopy_v1_deprecated<
         BlockSize,
         Float,
         decltype(in_c_n1_b_n2_global_merged_desc),
@@ -155,20 +155,19 @@ struct GridwiseConvolutionImplicitGemm_v3_nchw_cyxk_nkhw
     // operator for blockwise copy of weight into LDS
     // slice a tensor, and copy it into another tensor
     // this copy operator already have blockwise offset built-in
-    auto blockwise_wei_copy =
-        BlockwiseGenericTensorSliceCopy_v1<BlockSize,
+    auto blockwise_wei_copy = BlockwiseGenericTensorSliceCopy_v1_deprecated<
+        BlockSize,
         Float,
         decltype(wei_c_k_global_desc),
         decltype(wei_c_k_block_desc),
         decltype(wei_c_k_block_desc.GetLengths()),
         WeiBlockCopySubLengths_C_K,
         WeiBlockCopyClusterLengths_C_K,
         Sequence<0, 1>, // thread_arrange_order [C, K]
         Sequence<0, 1>, // src_access_order [C, K]
         Sequence<0, 1>, // dst_access_order [C, K]
         WeiBlockCopyDataPerAccess_K,
-        WeiBlockCopyDataPerAccess_K>(
-        {0, k_block_data_on_global}, {0, 0});
+        WeiBlockCopyDataPerAccess_K>({0, k_block_data_on_global}, {0, 0});

     // GEMM definition
     // c_mtx += transpose(a_mtx) * b_mtx
......
@@ -2,8 +2,8 @@
 #define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V3_NCHW_CYXK_NKHW_LDS_DOUBLE_BUFFER

 #include "common_header.hpp"
-#include "ConstantTensorDescriptor.hpp"
-#include "ConstantMergedTensorDescriptor.hpp"
+#include "ConstantTensorDescriptor_deprecated.hpp"
+#include "ConstantMergedTensorDescriptor_deprecated.hpp"
 #include "ConstantMatrixDescriptor.hpp"
 #include "blockwise_generic_tensor_slice_copy.hpp"
 #include "blockwise_gemm.hpp"
@@ -125,7 +125,7 @@ struct GridwiseConvolutionImplicitGemm_v3_nchw_cyxk_nkhw_lds_double_buffer
     // input blockwise copy
     // slice a merged tensor, reorder and copy to a normal tensor
     // this copy operator already has blockwise offset built-in
-    const auto blockwise_in_copy = BlockwiseGenericTensorSliceCopy_v1<
+    const auto blockwise_in_copy = BlockwiseGenericTensorSliceCopy_v1_deprecated<
         BlockSize,
         Float,
         decltype(in_c_n1_b_n2_global_merged_desc),
@@ -152,20 +152,19 @@ struct GridwiseConvolutionImplicitGemm_v3_nchw_cyxk_nkhw_lds_double_buffer
     // operator for blockwise copy of weight into LDS
     // slice a tensor, and copy it into another tensor
     // this copy operator already have blockwise offset built-in
-    const auto blockwise_wei_copy =
-        BlockwiseGenericTensorSliceCopy_v1<BlockSize,
+    const auto blockwise_wei_copy = BlockwiseGenericTensorSliceCopy_v1_deprecated<
+        BlockSize,
         Float,
         decltype(wei_c_k_global_desc),
         decltype(wei_c_k_block_desc),
         decltype(wei_c_k_block_desc.GetLengths()),
         WeiBlockCopySubLengths_C_K,
         WeiBlockCopyClusterLengths_C_K,
         Sequence<0, 1>, // thread_arrange_order [C, K]
         Sequence<0, 1>, // src_access_order [C, K]
         Sequence<0, 1>, // dst_access_order [C, K]
         WeiBlockCopyDataPerAccess_K,
-        WeiBlockCopyDataPerAccess_K>(
-        {0, k_block_data_on_global}, {0, 0});
+        WeiBlockCopyDataPerAccess_K>({0, k_block_data_on_global}, {0, 0});

     // GEMM definition
     // c_mtx += transpose(a_mtx) * b_mtx
......
@@ -2,24 +2,71 @@
 #define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R1_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER_HPP

 #include "common_header.hpp"
-#include "ConstantTensorDescriptor.hpp"
-#include "ConstantMergedTensorDescriptor.hpp"
+#include "tensor_descriptor.hpp"
+#include "tensor_descriptor_helper.hpp"
 #include "ConstantMatrixDescriptor.hpp"
-#include "blockwise_generic_tensor_slice_copy_deprecated.hpp"
+#include "blockwise_generic_tensor_slice_copy.hpp"
+#include "threadwise_generic_tensor_slice_copy.hpp"
 #include "blockwise_gemm.hpp"
-#include "threadwise_generic_tensor_slice_copy_deprecated.hpp"
+#include "convolution_common.hpp"

 namespace ck {

-// define B = merge(N0, Ho, Wo)
+template <ConvolutionDirection>
struct make_wei_e_k_global_desc_v4r1;
template <>
struct make_wei_e_k_global_desc_v4r1<ConvolutionDirection::Forward>
{
template <typename WeiDesc>
__device__ constexpr auto operator()(WeiDesc) const
{
constexpr auto I1 = Number<1>{};
constexpr auto I3 = Number<3>{};
return reorder_tensor_descriptor_given_upper2lower(
unfold_tensor_descriptor(WeiDesc{}, I1, I3), Sequence<1, 0>{});
}
};
template <>
struct make_wei_e_k_global_desc_v4r1<ConvolutionDirection::BackwardWeight>
{
template <typename WeiDesc>
__device__ constexpr auto operator()(WeiDesc) const
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto wei_k_c_y_x_global_desc = WeiDesc{};
constexpr index_t K = wei_k_c_y_x_global_desc.GetLength(I0);
constexpr index_t C = wei_k_c_y_x_global_desc.GetLength(I1);
constexpr index_t Y = wei_k_c_y_x_global_desc.GetLength(I2);
constexpr index_t X = wei_k_c_y_x_global_desc.GetLength(I3);
return transform_tensor_descriptor(
unfold_tensor_descriptor(wei_k_c_y_x_global_desc, I2, I3),
make_tuple(Merge<Sequence<C, Y * X>>{}, PassThrough<K>{}),
make_tuple(Sequence<1, 2>{}, Sequence<0>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
};
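// Note: the kernel body further down selects one of the two specializations above
// purely through its ConvDirection template parameter, e.g.
//
//   constexpr auto wei_e_k_global_desc =
//       make_wei_e_k_global_desc_v4r1<ConvDirection>{}(wei_k_c_y_x_global_desc);
//
// Forward unfolds [K, C, Y, X] into an [E, K] = [C*Y*X, K] view; BackwardWeight
// builds the same [E, K] view by merging (C, Y*X) and passing K through.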
 template <index_t GridSize,
           index_t BlockSize,
-          class Float,
-          class InGlobalDesc,
-          class WeiGlobalDesc,
-          class OutGlobalDesc,
-          class ConvStrides,
-          class ConvDilations,
+          typename Float,
+          typename AccDataType,
+          typename InGlobalDesc,
+          typename WeiGlobalDesc,
+          typename OutGlobalDesc,
+          typename ConvStrides,
+          typename ConvDilations,
+          typename LeftPads,
+          typename RightPads,
+          ConvolutionDirection ConvDirection,
           index_t BPerBlock,
           index_t KPerBlock,
           index_t EPerBlock,
@@ -33,18 +80,18 @@ template <index_t GridSize,
           index_t GemmKPerThreadLoop,
           index_t GemmDataPerReadA,
           index_t GemmDataPerReadB,
-          class InBlockCopySubLengths_E_N1_B_N2,
-          class InBlockCopyClusterLengths_E_N1_B_N2,
-          class InBlockCopyThreadClusterArrangeOrder,
-          class InBlockCopySrcAccessOrder,
-          class InBlockCopyDstAccessOrder,
+          typename InBlockCopySubLengths_E_N1_B_N2,
+          typename InBlockCopyClusterLengths_E_N1_B_N2,
+          typename InBlockCopyThreadClusterArrangeOrder,
+          typename InBlockCopySrcAccessOrder,
+          typename InBlockCopyDstAccessOrder,
           index_t InBlockCopySrcDataPerRead_B,
           index_t InBlockCopyDstDataPerWrite_N2,
-          class WeiBlockCopySubLengths_E_K,
-          class WeiBlockCopyClusterLengths_E_K,
-          class WeiBlockCopyThreadClusterArrangeOrder,
-          class WeiBlockCopySrcAccessOrder,
-          class WeiBlockCopyDstAccessOrder,
+          typename WeiBlockCopySubLengths_E_K,
+          typename WeiBlockCopyClusterLengths_E_K,
+          typename WeiBlockCopyThreadClusterArrangeOrder,
+          typename WeiBlockCopySrcAccessOrder,
+          typename WeiBlockCopyDstAccessOrder,
           index_t WeiBlockCopySrcDataPerRead_E,
           index_t WeiBlockCopyDstDataPerWrite_K>
 struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
@@ -53,6 +100,22 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
                         const Float* const __restrict__ p_wei_global,
                         Float* const __restrict__ p_out_global) const
     {
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto True = integral_constant<bool, true>{};
constexpr auto generic_address_space =
integral_constant<AddressSpace, AddressSpace::generic>{};
constexpr auto global_address_space =
integral_constant<AddressSpace, AddressSpace::global>{};
static_assert(ConvDirection == ConvolutionDirection::Forward ||
ConvDirection == ConvolutionDirection::BackwardWeight,
"wrong! this kernel only support convolution forward and backward-weight");
         // this is a mess
         // TODO: find more elegent way of specifying (or calculating) performance parameters
         constexpr index_t N1 = GemmNRepeat;
@@ -63,24 +126,18 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
                       0,
                       "wrong!");

-        constexpr auto I0 = Number<0>{};
-        constexpr auto I1 = Number<1>{};
-        constexpr auto I2 = Number<2>{};
-        constexpr auto I3 = Number<3>{};
-        constexpr auto I5 = Number<5>{};
-        constexpr auto True = integral_constant<bool, true>{};
-        constexpr auto in_n_c_h_w_global_desc  = InGlobalDesc{};
-        constexpr auto wei_k_c_y_x_global_desc = WeiGlobalDesc{};
-        constexpr auto out_n_k_h_w_global_desc = OutGlobalDesc{};
-        constexpr index_t N = in_n_c_h_w_global_desc.GetLength(I0);
-        constexpr index_t C = in_n_c_h_w_global_desc.GetLength(I1);
-        constexpr index_t K  = out_n_k_h_w_global_desc.GetLength(I1);
-        constexpr index_t Ho = out_n_k_h_w_global_desc.GetLength(I2);
-        constexpr index_t Wo = out_n_k_h_w_global_desc.GetLength(I3);
+        constexpr auto in_n_c_hi_wi_global_desc  = InGlobalDesc{};
+        constexpr auto wei_k_c_y_x_global_desc   = WeiGlobalDesc{};
+        constexpr auto out_n_k_ho_wo_global_desc = OutGlobalDesc{};
+        constexpr index_t N  = in_n_c_hi_wi_global_desc.GetLength(I0);
+        constexpr index_t C  = in_n_c_hi_wi_global_desc.GetLength(I1);
+        constexpr index_t Hi = in_n_c_hi_wi_global_desc.GetLength(I2);
+        constexpr index_t Wi = in_n_c_hi_wi_global_desc.GetLength(I3);
+        constexpr index_t K  = out_n_k_ho_wo_global_desc.GetLength(I1);
+        constexpr index_t Ho = out_n_k_ho_wo_global_desc.GetLength(I2);
+        constexpr index_t Wo = out_n_k_ho_wo_global_desc.GetLength(I3);

         constexpr index_t Y = wei_k_c_y_x_global_desc.GetLength(I2);
         constexpr index_t X = wei_k_c_y_x_global_desc.GetLength(I3);
@@ -106,46 +163,51 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
                       "be violated");

         // divide block work by [K, B]
-        static_assert(K % KPerBlock == 0 && B % BPerBlock == 0 && E % (2 * EPerBlock) == 0,
+        static_assert(K % KPerBlock == 0 && B % BPerBlock == 0 && E % EPerBlock == 0,
                       "wrong! cannot divide work evenly among block");

         constexpr index_t KBlockWork = K / KPerBlock;
         constexpr index_t BBlockWork = B / BPerBlock;

         constexpr auto block_work_desc =
-            make_ConstantTensorDescriptor_packed(Sequence<KBlockWork, BBlockWork>{});
+            make_cluster_descriptor(Sequence<KBlockWork, BBlockWork>{});

-        const auto block_work_multi_id =
-            block_work_desc.GetMultiIndexFrom1dIndex(get_block_1d_id());
+        const auto block_work_id = block_work_desc.CalculateClusterIndex(get_block_1d_id());

-        const index_t k_block_data_on_global = block_work_multi_id[0] * KPerBlock;
-        const index_t b_block_data_on_global = block_work_multi_id[1] * BPerBlock;
+        const index_t k_block_data_on_global = block_work_id[0] * KPerBlock;
+        const index_t b_block_data_on_global = block_work_id[1] * BPerBlock;
         // input tensor
-        // tensor descriptor in device memory [N0, N1, N2, Ho, Wo]
-        constexpr auto in_n0_n1_n2_h_w_global_desc =
-            in_n_c_h_w_global_desc.StridedSlice(I2, Number<Ho>{}, Number<ConvStrideH>{})
-                .StridedSlice(I3, Number<Wo>{}, Number<ConvStrideW>{})
-                .Fold(I0, Number<N1>{}, Number<N2>{})
-                .Extract(Sequence<0, 1, 2, 4, 5>{});
-        // batch descritpor for device memory
-        constexpr auto in_c_y_x_global_desc =
-            in_n_c_h_w_global_desc.StridedSlice(I2, Number<Y>{}, Number<ConvDilationH>{})
-                .StridedSlice(I3, Number<X>{}, Number<ConvDilationW>{})
-                .Extract(Sequence<1, 2, 3>{});
-        // merged tensor descriptor in device memory [E, N1, B, N2], src of blockwise copy
-        constexpr auto in_e_n1_b_n2_global_merged_desc = make_ConstantMergedTensorDescriptor(
-            in_c_y_x_global_desc.Embed(in_n0_n1_n2_h_w_global_desc),
-            Sequence<0, 1, 2>{},
-            Sequence<4>{},
-            Sequence<3, 6, 7>{},
-            Sequence<5>{});
-        // memory layout descriptor in LDS [E, N1, B, N2], dst of blockwise copy
+        // global tensor in global memory
+        constexpr auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor(
+            in_n_c_hi_wi_global_desc,
+            make_tuple(
+                PassThrough<N>{}, PassThrough<C>{}, Pad<Sequence<Hi, Wi>, LeftPads, RightPads>{}),
+            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}),
+            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}));
+
+        constexpr auto in_n0_n1_n2_c_y_ho_x_wo_global_desc = transform_tensor_descriptor(
+            in_n_c_hip_wip_global_desc,
+            make_tuple(UnMerge<Sequence<N0, N1, N2>>{},
+                       PassThrough<C>{},
+                       Embed<Sequence<Y, Ho>, Sequence<ConvDilationH, ConvStrideH, 0>>{},
+                       Embed<Sequence<X, Wo>, Sequence<ConvDilationW, ConvStrideW, 0>>{}),
+            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
+            make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}, Sequence<4, 5>{}, Sequence<6, 7>{}));
+
+        // global tensor in global memory, src of blockwise copy
+        constexpr auto in_e_n1_b_n2_global_desc = transform_tensor_descriptor(
+            in_n0_n1_n2_c_y_ho_x_wo_global_desc,
+            make_tuple(Merge<Sequence<C, Y, X>>{},
+                       PassThrough<N1>{},
+                       Merge<Sequence<N0, Ho, Wo>>{},
+                       PassThrough<N2>{}),
+            make_tuple(Sequence<3, 4, 6>{}, Sequence<1>{}, Sequence<0, 5, 7>{}, Sequence<2>{}),
+            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
+
+        // block tensor in LDS memory, dst of blockwise copy
         // be careful of LDS alignment
-        constexpr auto in_e_n1_b_n2_block_desc = make_ConstantTensorDescriptor_aligned(
+        constexpr auto in_e_n1_b_n2_block_desc = make_native_tensor_descriptor_aligned(
             Sequence<EPerBlock, N1, BPerBlock, N2>{}, Number<InBlockCopyDstDataPerWrite_N2>{});
         // this check is ad-hoc
@@ -154,12 +216,10 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
         static_assert(in_e_n1_b_n2_block_desc.GetStride(I1) % GemmDataPerReadB == 0,
                       "GemmDataPerReadB alignment requirement is not satisfied");

-        // input blockwise copy
-        // slice a merged tensor, reorder and copy to a normal tensor
-        // this copy operator already has blockwise offset built-in
+        // input tensor blockwise copy
         auto blockwise_in_copy =
-            BlockwiseGenericTensorSliceCopy_v2<BlockSize,
-                decltype(in_e_n1_b_n2_global_merged_desc),
+            BlockwiseGenericTensorSliceCopy_v4<BlockSize,
+                decltype(in_e_n1_b_n2_global_desc),
                 decltype(in_e_n1_b_n2_block_desc),
                 decltype(in_e_n1_b_n2_block_desc.GetLengths()),
                 InBlockCopySubLengths_E_N1_B_N2,
@@ -174,21 +234,27 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
                 {0, 0, b_block_data_on_global, 0}, {0, 0, 0, 0});
         // weight tensor
-        // tensor descriptor in device memory, src of blockwise copy
+        // global tensor in global memory, src of blockwise copy
+        // It is constructed differently, depending on whether forward or backward weight
+        // convolution
         constexpr auto wei_e_k_global_desc =
-            wei_k_c_y_x_global_desc.Unfold(I1, I3).ReorderGivenNew2Old(Sequence<1, 0>{});
+            make_wei_e_k_global_desc_v4r1<ConvDirection>{}(wei_k_c_y_x_global_desc);

-        // tensor descriptor in LDS, dst of blockwise copy
+        // block tensor in LDS memory, dst of blockwise copy
         // be careful of LDS alignment
-        constexpr auto wei_e_k_block_desc = make_ConstantTensorDescriptor_aligned(
+        constexpr auto wei_e_k_block_desc = make_native_tensor_descriptor_aligned(
             Sequence<EPerBlock, KPerBlock>{},
             Number<math::lcm(WeiBlockCopyDstDataPerWrite_K, GemmDataPerReadA)>{});

-        // operator for blockwise copy of weight into LDS
-        // slice a tensor, and copy it into another tensor
-        // this copy operator already have blockwise offset built-in
+        // this check is ad-hoc
+        // TODO: need to properly implement tensor descriptor with multiple alignment
+        // requirements
+        static_assert(wei_e_k_block_desc.GetStride(I0) % GemmDataPerReadA == 0,
+                      "GemmDataPerReadA alignment requirement is not satisfied");
+
+        // weight tensor blockwise copy
         auto blockwise_wei_copy =
-            BlockwiseGenericTensorSliceCopy_v2<BlockSize,
+            BlockwiseGenericTensorSliceCopy_v4<BlockSize,
                 decltype(wei_e_k_global_desc),
                 decltype(wei_e_k_block_desc),
                 decltype(wei_e_k_block_desc.GetLengths()),
@@ -204,15 +270,18 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
                 {0, k_block_data_on_global}, {0, 0});
         // GEMM definition
         // c_mtx += transpose(a_mtx) * b_mtx
         // a_mtx[EPerBlock, KPerBlock] is in LDS
         // b_mtx[EPerBlocl, N1 * BPerBlock * N2] is in LDS
         // c_mtx[KPerBlock, N1 * BPerBlock * N2] is distributed among threads, and saved in
         // register
         constexpr auto a_e_k_block_mtx_desc = make_ConstantMatrixDescriptor(wei_e_k_block_desc);

-        constexpr auto b_e_n1bn2_block_mtx_desc =
-            make_ConstantMatrixDescriptor(in_e_n1_b_n2_block_desc.Unfold(I1, I3));
+        constexpr auto b_e_n1bn2_block_mtx_desc = make_ConstantMatrixDescriptor(
+            in_e_n1_b_n2_block_desc.GetLength(I0),
+            in_e_n1_b_n2_block_desc.GetLength(I1) * in_e_n1_b_n2_block_desc.GetLength(I2) *
+                in_e_n1_b_n2_block_desc.GetLength(I3),
+            in_e_n1_b_n2_block_desc.GetStride(I0));

         // sanity check
         static_assert(KPerBlock % (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster) ==
@@ -258,17 +327,17 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
         __shared__ Float p_wei_block_double[2 * wei_block_space];

         // register allocation for output
-        Float p_out_thread[c_k0k1_n1n2_thread_mtx_desc.GetElementSpace()];
+        AccDataType p_out_thread[c_k0k1_n1n2_thread_mtx_desc.GetElementSpace()];

         // zero out threadwise output
         threadwise_matrix_set_zero(c_k0k1_n1n2_thread_mtx_desc, p_out_thread);

         // LDS double buffer: preload data into LDS
         {
-            blockwise_in_copy.template Run<Float, Float, address_space_t::global>(
-                p_in_global, p_in_block_double);
-            blockwise_wei_copy.template Run<Float, Float, address_space_t::global>(
-                p_wei_global, p_wei_block_double);
+            blockwise_in_copy.Run(
+                p_in_global, p_in_block_double, global_address_space, generic_address_space);
+            blockwise_wei_copy.Run(
+                p_wei_global, p_wei_block_double, global_address_space, generic_address_space);
         }

         // LDS double buffer: main body
@@ -299,12 +368,10 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
             __syncthreads();

             // LDS doubel buffer: load next data from device mem
-            blockwise_in_copy
-                .template RunLoadThreadBuffer<Float, Float, address_space_t::global>(
-                    p_in_global, p_in_thread_buffer);
-            blockwise_wei_copy
-                .template RunLoadThreadBuffer<Float, Float, address_space_t::global>(
-                    p_wei_global, p_wei_thread_buffer);
+            blockwise_in_copy.RunLoadThreadBuffer(
+                p_in_global, p_in_thread_buffer, global_address_space, generic_address_space);
+            blockwise_wei_copy.RunLoadThreadBuffer(
+                p_wei_global, p_wei_thread_buffer, global_address_space, generic_address_space);

             // LDS double buffer: GEMM on current data
             blockwise_gemm.Run(p_wei_block_now, p_in_block_now, p_out_thread);
@@ -317,60 +384,84 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
         // LDS double buffer: tail
         {
-            // even iteration
-            Float p_in_thread_buffer[blockwise_in_copy.GetThreadBufferSize()];
-            Float p_wei_thread_buffer[blockwise_wei_copy.GetThreadBufferSize()];
-            blockwise_in_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0, 0, 0>{}, True);
-            blockwise_wei_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0>{}, True);
-            __syncthreads();
-            // LDS doubel buffer: load next data from device mem
-            blockwise_in_copy.template RunLoadThreadBuffer<Float, Float, address_space_t::global>(
-                p_in_global, p_in_thread_buffer);
-            blockwise_wei_copy.template RunLoadThreadBuffer<Float, Float, address_space_t::global>(
-                p_wei_global, p_wei_thread_buffer);
-            // LDS double buffer: GEMM on current data
-            blockwise_gemm.Run(p_wei_block_double, p_in_block_double, p_out_thread);
-            // LDS double buffer: store next data to LDS
-            blockwise_in_copy.RunStoreThreadBuffer(p_in_thread_buffer,
-                                                   p_in_block_double + in_block_space);
-            blockwise_wei_copy.RunStoreThreadBuffer(p_wei_thread_buffer,
-                                                    p_wei_block_double + wei_block_space);
-            // odd iteration
-            __syncthreads();
-            // LDS double buffer: GEMM on current data
-            blockwise_gemm.Run(p_wei_block_double + wei_block_space,
-                               p_in_block_double + in_block_space,
-                               p_out_thread);
+            constexpr bool has_two_iteration_left = (E % (2 * EPerBlock) == 0);
+
+            if(has_two_iteration_left) // if has 2 iteration left
+            {
+                Float p_in_thread_buffer[blockwise_in_copy.GetThreadBufferSize()];
+                Float p_wei_thread_buffer[blockwise_wei_copy.GetThreadBufferSize()];
+                blockwise_in_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0, 0, 0>{}, True);
+                blockwise_wei_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0>{}, True);
+                __syncthreads();
+                // LDS double buffer: load last data from device mem
+                blockwise_in_copy.RunLoadThreadBuffer(
+                    p_in_global, p_in_thread_buffer, global_address_space, generic_address_space);
+                blockwise_wei_copy.RunLoadThreadBuffer(
+                    p_wei_global, p_wei_thread_buffer, global_address_space, generic_address_space);
+                // LDS double buffer: GEMM on 2nd-last data
+                blockwise_gemm.Run(p_wei_block_double, p_in_block_double, p_out_thread);
+                // LDS double buffer: store last data to LDS
+                blockwise_in_copy.RunStoreThreadBuffer(p_in_thread_buffer,
+                                                       p_in_block_double + in_block_space);
+                blockwise_wei_copy.RunStoreThreadBuffer(p_wei_thread_buffer,
+                                                        p_wei_block_double + wei_block_space);
+                __syncthreads();
+                // LDS double buffer: GEMM on last data
+                blockwise_gemm.Run(p_wei_block_double + wei_block_space,
+                                   p_in_block_double + in_block_space,
+                                   p_out_thread);
+            }
+            else // if has 1 iteration left
+            {
+                __syncthreads();
+                // LDS double buffer: GEMM on last data
+                blockwise_gemm.Run(p_wei_block_double, p_in_block_double, p_out_thread);
+            }
         }
         // copy output: register to global memory
         {
             constexpr index_t K1 = GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster;
+            constexpr index_t K0 = K / K1;

-            // define tensor descriptor for threadwise copy
-            // output memory layout descriptor in register, src of threadwise copy
-            constexpr auto out_k0_k1_n1_b_n2_thread_mem_desc = make_ConstantTensorDescriptor_packed(
+            // define output tensor descriptor for threadwise copy
+            // thread output tensor, src of threadwise copy
+            constexpr auto out_k0_k1_n1_b_n2_thread_desc = make_native_tensor_descriptor_packed(
                 Sequence<GemmMRepeat, GemmMPerThreadSubC, N1, 1, N2>{});

-            // output memory layout descriptor in device memory
-            constexpr auto out_n0_n1_n2_k0_k1_h_w_global_mem_desc =
-                out_n_k_h_w_global_desc.Fold(I1, Number<K1>{}).Fold(I0, Number<N1>{}, Number<N2>{});
-            // output merged global tensor descriptor, dst of threadwise copy
-            constexpr auto out_k0_k1_n1_b_n2_global_merged_desc =
-                make_ConstantMergedTensorDescriptor(out_n0_n1_n2_k0_k1_h_w_global_mem_desc,
-                                                    Sequence<3>{},
-                                                    Sequence<4>{},
-                                                    Sequence<1>{},
-                                                    Sequence<0, 5, 6>{},
-                                                    Sequence<2>{});
+            // global output tensor
+            constexpr auto out_n0_n1_n2_k0_k1_ho_wo_global_desc = transform_tensor_descriptor(
+                out_n_k_ho_wo_global_desc,
+                make_tuple(UnMerge<Sequence<N0, N1, N2>>{},
+                           UnMerge<Sequence<K0, K1>>{},
+                           PassThrough<Ho>{},
+                           PassThrough<Wo>{}),
+                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
+                make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}, Sequence<6>{}));
+
+            // global output tensor, dst of threadwise copy
+            constexpr auto out_k0_k1_n1_b_n2_global_desc = transform_tensor_descriptor(
+                out_n0_n1_n2_k0_k1_ho_wo_global_desc,
+                make_tuple(PassThrough<K0>{},
+                           PassThrough<K1>{},
+                           PassThrough<N1>{},
+                           Merge<Sequence<N0, Ho, Wo>>{},
+                           PassThrough<N2>{}),
+                make_tuple(Sequence<3>{}, Sequence<4>{}, Sequence<1>{}, Sequence<0, 5, 6>{}, Sequence<2>{}),
+                make_tuple(
+                    Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));

             // calculate origin of thread output tensor on global memory
             // blockwise GEMM c matrix starting index
@@ -383,26 +474,23 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
             const index_t b_thread_data_on_global =
                 b_block_data_on_global + c_thread_mtx_on_block.col / N2;

-            ThreadwiseGenericTensorSliceCopy_v2r1<
-                decltype(out_k0_k1_n1_b_n2_thread_mem_desc),
-                decltype(out_k0_k1_n1_b_n2_global_merged_desc),
-                decltype(out_k0_k1_n1_b_n2_thread_mem_desc.GetLengths()),
-                arithmetic_sequence_gen<0, 5, 1>::type,
-                arithmetic_sequence_gen<0, 5, 1>::type,
-                3,
-                3,
-                1,
-                1>({0, 0, 0, 0, 0},
-                   {k_thread_data_on_global / K1,
-                    k_thread_data_on_global % K1,
-                    0,
-                    b_thread_data_on_global,
-                    0})
-                .template Run<Float, Float, address_space_t::generic, address_space_t::global>(
-                    p_out_thread, p_out_global);
+            ThreadwiseGenericTensorSliceCopy_v4r2<
+                decltype(out_k0_k1_n1_b_n2_thread_desc),
+                decltype(out_k0_k1_n1_b_n2_global_desc),
+                decltype(out_k0_k1_n1_b_n2_thread_desc.GetLengths()),
+                arithmetic_sequence_gen<0, 5, 1>::type,
+                3,
+                1,
+                1>({0, 0, 0, 0, 0},
+                   {k_thread_data_on_global / K1,
+                    k_thread_data_on_global % K1,
+                    0,
+                    b_thread_data_on_global,
+                    0})
+                .Run(p_out_thread, p_out_global, generic_address_space, global_address_space);
         }
     }
 };

 } // namespace ck
-#endif // CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R1_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER_HPP
+#endif
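
One behavioral change in the refactored kernel above that is easy to miss: the old code effectively assumed E % (2 * EPerBlock) == 0 (its static_assert required it, so the tail always had a ping and a pong chunk left), while the new static_assert only requires E % EPerBlock == 0 and the tail branches on has_two_iteration_left. A hedged, standalone illustration of that condition (values are made up; only the arithmetic mirrors the source):

#include <cstdio>

int main()
{
    const int EPerBlock = 8;        // hypothetical tuning parameter
    const int Es[]      = {32, 40}; // hypothetical reduction extents (E = C * Y * X)

    for (int E : Es)
    {
        // same test as the kernel's tail: are there two EPerBlock chunks left?
        const bool has_two_iteration_left = (E % (2 * EPerBlock) == 0);
        std::printf("E=%d, EPerBlock=%d -> tail handles %s\n",
                    E,
                    EPerBlock,
                    has_two_iteration_left ? "2 chunks (load + two GEMMs)"
                                           : "1 chunk (single GEMM)");
    }
    return 0;
}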
-#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R1_NCHW_KCYX_NKHW_PADDED_LDS_DOUBLE_BUFFER_HPP
-#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R1_NCHW_KCYX_NKHW_PADDED_LDS_DOUBLE_BUFFER_HPP
+#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R1_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER_DEPRECATED_HPP
+#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R1_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER_DEPRECATED_HPP

 #include "common_header.hpp"
-#include "tensor_descriptor.hpp"
-#include "tensor_descriptor_helper.hpp"
+#include "ConstantTensorDescriptor_deprecated.hpp"
+#include "ConstantMergedTensorDescriptor_deprecated.hpp"
 #include "ConstantMatrixDescriptor.hpp"
-#include "blockwise_generic_tensor_slice_copy.hpp"
-#include "threadwise_generic_tensor_slice_copy.hpp"
+#include "blockwise_generic_tensor_slice_copy_deprecated.hpp"
 #include "blockwise_gemm.hpp"
+#include "threadwise_generic_tensor_slice_copy_deprecated.hpp"
+#include "convolution_common.hpp"

 namespace ck {
template <ConvolutionDirection>
struct make_wei_e_k_global_desc_v4r1_deprecated;
template <>
struct make_wei_e_k_global_desc_v4r1_deprecated<ConvolutionDirection::Forward>
{
template <typename WeiDesc>
__device__ constexpr auto operator()(WeiDesc) const
{
constexpr auto I1 = Number<1>{};
constexpr auto I3 = Number<3>{};
return WeiDesc::Unfold(I1, I3).ReorderGivenNew2Old(Sequence<1, 0>{});
}
};
template <>
struct make_wei_e_k_global_desc_v4r1_deprecated<ConvolutionDirection::BackwardWeight>
{
template <typename WeiDesc>
__device__ constexpr auto operator()(WeiDesc) const
{
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
return make_ConstantMergedTensorDescriptor(
WeiDesc::Unfold(I2, I3), Sequence<1, 2>{}, Sequence<0>{});
}
};
 // define B = merge(N0, Ho, Wo)
 template <index_t GridSize,
           index_t BlockSize,
-          typename Float,
-          typename InGlobalDesc,
-          typename WeiGlobalDesc,
-          typename OutGlobalDesc,
-          typename ConvStrides,
-          typename ConvDilations,
-          typename LeftPads,
-          typename RightPads,
+          class Float,
+          class AccDataType,
+          class InGlobalDesc,
+          class WeiGlobalDesc,
+          class OutGlobalDesc,
+          class ConvStrides,
+          class ConvDilations,
+          ConvolutionDirection ConvDirection,
           index_t BPerBlock,
           index_t KPerBlock,
           index_t EPerBlock,
@@ -35,26 +66,42 @@ template <index_t GridSize,
           index_t GemmKPerThreadLoop,
           index_t GemmDataPerReadA,
           index_t GemmDataPerReadB,
-          typename InBlockCopySubLengths_E_N1_B_N2,
-          typename InBlockCopyClusterLengths_E_N1_B_N2,
-          typename InBlockCopyThreadClusterArrangeOrder,
-          typename InBlockCopySrcAccessOrder,
-          typename InBlockCopyDstAccessOrder,
+          class InBlockCopySubLengths_E_N1_B_N2,
+          class InBlockCopyClusterLengths_E_N1_B_N2,
+          class InBlockCopyThreadClusterArrangeOrder,
+          class InBlockCopySrcAccessOrder,
+          class InBlockCopyDstAccessOrder,
           index_t InBlockCopySrcDataPerRead_B,
           index_t InBlockCopyDstDataPerWrite_N2,
-          typename WeiBlockCopySubLengths_E_K,
-          typename WeiBlockCopyClusterLengths_E_K,
-          typename WeiBlockCopyThreadClusterArrangeOrder,
-          typename WeiBlockCopySrcAccessOrder,
-          typename WeiBlockCopyDstAccessOrder,
+          class WeiBlockCopySubLengths_E_K,
+          class WeiBlockCopyClusterLengths_E_K,
+          class WeiBlockCopyThreadClusterArrangeOrder,
+          class WeiBlockCopySrcAccessOrder,
+          class WeiBlockCopyDstAccessOrder,
           index_t WeiBlockCopySrcDataPerRead_E,
           index_t WeiBlockCopyDstDataPerWrite_K>
-struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_padded_lds_double_buffer
+struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer_deprecated
 {
     __device__ void Run(const Float* const __restrict__ p_in_global,
                         const Float* const __restrict__ p_wei_global,
                         Float* const __restrict__ p_out_global) const
     {
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto True = integral_constant<bool, true>{};
constexpr auto generic_address_space =
integral_constant<AddressSpace, AddressSpace::generic>{};
constexpr auto global_address_space =
integral_constant<AddressSpace, AddressSpace::global>{};
static_assert(ConvDirection == ConvolutionDirection::Forward ||
ConvDirection == ConvolutionDirection::BackwardWeight,
"wrong! this kernel only support convolution forward and backward-weight");
         // this is a mess
         // TODO: find more elegent way of specifying (or calculating) performance parameters
         constexpr index_t N1 = GemmNRepeat;
@@ -65,25 +112,16 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_padded_lds_double_buf
                       0,
                       "wrong!");

-        constexpr auto I0 = Number<0>{};
-        constexpr auto I1 = Number<1>{};
-        constexpr auto I2 = Number<2>{};
-        constexpr auto I3 = Number<3>{};
-        constexpr auto True = integral_constant<bool, true>{};
-        constexpr auto in_n_c_hi_wi_global_desc  = InGlobalDesc{};
-        constexpr auto wei_k_c_y_x_global_desc   = WeiGlobalDesc{};
-        constexpr auto out_n_k_ho_wo_global_desc = OutGlobalDesc{};
-        constexpr index_t N  = in_n_c_hi_wi_global_desc.GetLength(I0);
-        constexpr index_t C  = in_n_c_hi_wi_global_desc.GetLength(I1);
-        constexpr index_t Hi = in_n_c_hi_wi_global_desc.GetLength(I2);
-        constexpr index_t Wi = in_n_c_hi_wi_global_desc.GetLength(I3);
-        constexpr index_t K  = out_n_k_ho_wo_global_desc.GetLength(I1);
-        constexpr index_t Ho = out_n_k_ho_wo_global_desc.GetLength(I2);
-        constexpr index_t Wo = out_n_k_ho_wo_global_desc.GetLength(I3);
+        constexpr auto in_n_c_h_w_global_desc  = InGlobalDesc{};
+        constexpr auto wei_k_c_y_x_global_desc = WeiGlobalDesc{};
+        constexpr auto out_n_k_h_w_global_desc = OutGlobalDesc{};
+        constexpr index_t N  = in_n_c_h_w_global_desc.GetLength(I0);
+        constexpr index_t C  = in_n_c_h_w_global_desc.GetLength(I1);
+        constexpr index_t K  = out_n_k_h_w_global_desc.GetLength(I1);
+        constexpr index_t Ho = out_n_k_h_w_global_desc.GetLength(I2);
+        constexpr index_t Wo = out_n_k_h_w_global_desc.GetLength(I3);

         constexpr index_t Y = wei_k_c_y_x_global_desc.GetLength(I2);
         constexpr index_t X = wei_k_c_y_x_global_desc.GetLength(I3);
...@@ -116,43 +154,39 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_padded_lds_double_buf ...@@ -116,43 +154,39 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_padded_lds_double_buf
constexpr index_t BBlockWork = B / BPerBlock; constexpr index_t BBlockWork = B / BPerBlock;
constexpr auto block_work_desc = constexpr auto block_work_desc =
make_cluster_descriptor(Sequence<KBlockWork, BBlockWork>{}); make_ConstantTensorDescriptor_packed(Sequence<KBlockWork, BBlockWork>{});
const auto block_work_id = block_work_desc.CalculateClusterIndex(get_block_1d_id()); const auto block_work_multi_id =
block_work_desc.GetMultiIndexFrom1dIndex(get_block_1d_id());
const index_t k_block_data_on_global = block_work_id[0] * KPerBlock; const index_t k_block_data_on_global = block_work_multi_id[0] * KPerBlock;
const index_t b_block_data_on_global = block_work_id[1] * BPerBlock; const index_t b_block_data_on_global = block_work_multi_id[1] * BPerBlock;
// input tensor // input tensor
// global memory // tensor descriptor in device memory [N0, N1, N2, Ho, Wo]
constexpr auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor( constexpr auto in_n0_n1_n2_h_w_global_desc =
in_n_c_hi_wi_global_desc, in_n_c_h_w_global_desc.StridedSlice(I2, Number<Ho>{}, Number<ConvStrideH>{})
make_tuple( .StridedSlice(I3, Number<Wo>{}, Number<ConvStrideW>{})
PassThrough<N>{}, PassThrough<C>{}, Pad<Sequence<Hi, Wi>, LeftPads, RightPads>{}), .Fold(I0, Number<N1>{}, Number<N2>{})
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}), .Extract(Sequence<0, 1, 2, 4, 5>{});
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}));
// batch descritpor for device memory
constexpr auto in_n0_n1_n2_c_y_ho_x_wo_global_desc = transform_tensor_descriptor( constexpr auto in_c_y_x_global_desc =
in_n_c_hip_wip_global_desc, in_n_c_h_w_global_desc.StridedSlice(I2, Number<Y>{}, Number<ConvDilationH>{})
make_tuple(UnMerge<Sequence<N0, N1, N2>>{}, .StridedSlice(I3, Number<X>{}, Number<ConvDilationW>{})
PassThrough<C>{}, .Extract(Sequence<1, 2, 3>{});
Embed<Sequence<Y, Ho>, Sequence<ConvDilationH, ConvStrideH, 0>>{},
Embed<Sequence<X, Wo>, Sequence<ConvDilationW, ConvStrideW, 0>>{}), // merged tensor descriptor in device memory [E, N1, B, N2], src of blockwise copy
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), constexpr auto in_e_n1_b_n2_global_merged_desc = make_ConstantMergedTensorDescriptor(
make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}, Sequence<4, 5>{}, Sequence<6, 7>{})); in_c_y_x_global_desc.Embed(in_n0_n1_n2_h_w_global_desc),
Sequence<0, 1, 2>{},
constexpr auto in_e_n1_b_n2_global_desc = transform_tensor_descriptor( Sequence<4>{},
in_n0_n1_n2_c_y_ho_x_wo_global_desc, Sequence<3, 6, 7>{},
make_tuple(Merge<Sequence<C, Y, X>>{}, Sequence<5>{});
PassThrough<N1>{},
Merge<Sequence<N0, Ho, Wo>>{},
PassThrough<N2>{}),
make_tuple(Sequence<3, 4, 6>{}, Sequence<1>{}, Sequence<0, 5, 7>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
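
// A scalar sketch of what the descriptor chain above (Pad + Embed + Merge) computes
// per element, ignoring the N0/N1/N2 split of the batch dimension and assuming a
// packed NCHW input with the padding region read as zeros. The helper name and
// signature are invented for illustration, not library code.
#include <cstdio>

float gather_input(const float* in, // packed [N, C, Hi, Wi]
                   int n, int e, int ho, int wo,
                   int C, int Y, int X, int Hi, int Wi,
                   int stride_h, int stride_w, int dilation_h, int dilation_w,
                   int left_pad_h, int left_pad_w)
{
    // undo Merge<C, Y, X>: C is the slowest of the three merged dimensions
    const int c = e / (Y * X);
    const int y = (e / X) % Y;
    const int x = e % X;

    // Embed over the padded frame, then shift back by the left padding
    const int hi = y * dilation_h + ho * stride_h - left_pad_h;
    const int wi = x * dilation_w + wo * stride_w - left_pad_w;

    if(hi < 0 || hi >= Hi || wi < 0 || wi >= Wi)
        return 0.0f; // element lies in the padding region

    return in[((n * C + c) * Hi + hi) * Wi + wi];
}

int main()
{
    float in[9] = {1, 2, 3, 4, 5, 6, 7, 8, 9}; // one 3x3 feature map
    // 3x3 filter, stride 1, dilation 1, pad 1: e=0 hits the padded corner (0),
    // e=4 (c=0, y=1, x=1) hits in[0][0] (1)
    std::printf("%g %g\n",
                gather_input(in, 0, 0, 0, 0, 1, 3, 3, 3, 3, 1, 1, 1, 1, 1, 1),
                gather_input(in, 0, 4, 0, 0, 1, 3, 3, 3, 3, 1, 1, 1, 1, 1, 1));
}
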
// memory layout descriptor in LDS [E, N1, B, N2], dst of blockwise copy // memory layout descriptor in LDS [E, N1, B, N2], dst of blockwise copy
// be careful of LDS alignment // be careful of LDS alignment
constexpr auto in_e_n1_b_n2_block_desc = make_native_tensor_descriptor_aligned( constexpr auto in_e_n1_b_n2_block_desc = make_ConstantTensorDescriptor_aligned(
Sequence<EPerBlock, N1, BPerBlock, N2>{}, Number<InBlockCopyDstDataPerWrite_N2>{}); Sequence<EPerBlock, N1, BPerBlock, N2>{}, Number<InBlockCopyDstDataPerWrite_N2>{});
// this check is ad-hoc // this check is ad-hoc
...@@ -164,56 +198,51 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_padded_lds_double_buf ...@@ -164,56 +198,51 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_padded_lds_double_buf
// input blockwise copy // input blockwise copy
// slice a merged tensor, reorder and copy to a normal tensor // slice a merged tensor, reorder and copy to a normal tensor
// this copy operator already has blockwise offset built-in // this copy operator already has blockwise offset built-in
auto blockwise_in_copy = auto blockwise_in_copy = BlockwiseGenericTensorSliceCopy_v2_deprecated<
BlockwiseGenericTensorSliceCopy_v4<BlockSize, BlockSize,
decltype(in_e_n1_b_n2_global_desc), decltype(in_e_n1_b_n2_global_merged_desc),
decltype(in_e_n1_b_n2_block_desc), decltype(in_e_n1_b_n2_block_desc),
decltype(in_e_n1_b_n2_block_desc.GetLengths()), decltype(in_e_n1_b_n2_block_desc.GetLengths()),
InBlockCopySubLengths_E_N1_B_N2, InBlockCopySubLengths_E_N1_B_N2,
InBlockCopyClusterLengths_E_N1_B_N2, InBlockCopyClusterLengths_E_N1_B_N2,
InBlockCopyThreadClusterArrangeOrder, InBlockCopyThreadClusterArrangeOrder,
InBlockCopySrcAccessOrder, InBlockCopySrcAccessOrder,
InBlockCopyDstAccessOrder, InBlockCopyDstAccessOrder,
2, 2,
3, 3,
InBlockCopySrcDataPerRead_B, InBlockCopySrcDataPerRead_B,
InBlockCopyDstDataPerWrite_N2>( InBlockCopyDstDataPerWrite_N2>({0, 0, b_block_data_on_global, 0}, {0, 0, 0, 0});
{0, 0, b_block_data_on_global, 0}, {0, 0, 0, 0});
// weight tensor // weight tensor
// tensor descriptor in device memory, src of blockwise copy // Tensor descriptor in device memory, src of blockwise copy
constexpr auto wei_e_k_global_desc = reorder_tensor_descriptor_given_upper2lower( // It is constructed differently, depending on whether forward or backward weight
unfold_tensor_descriptor(wei_k_c_y_x_global_desc, I1, I3), Sequence<1, 0>{}); // convolution
constexpr auto wei_e_k_global_desc =
make_wei_e_k_global_desc_v4r1_deprecated<ConvDirection>{}(wei_k_c_y_x_global_desc);
// tensor descriptor in LDS, dst of blockwise copy // tensor descriptor in LDS, dst of blockwise copy
// be careful of LDS alignment // be careful of LDS alignment
constexpr auto wei_e_k_block_desc = make_native_tensor_descriptor_aligned( constexpr auto wei_e_k_block_desc = make_ConstantTensorDescriptor_aligned(
Sequence<EPerBlock, KPerBlock>{}, Sequence<EPerBlock, KPerBlock>{},
Number<math::lcm(WeiBlockCopyDstDataPerWrite_K, GemmDataPerReadA)>{}); Number<math::lcm(WeiBlockCopyDstDataPerWrite_K, GemmDataPerReadA)>{});
// this check is ad-hoc
// TODO: need to properly implement tensor descriptor with multiple alignment
// requirements
static_assert(wei_e_k_block_desc.GetStride(I0) % GemmDataPerReadA == 0,
"GemmDataPerReadA alignment requirement is not satisfied");
// operator for blockwise copy of weight into LDS // operator for blockwise copy of weight into LDS
// slice a tensor, and copy it into another tensor // slice a tensor, and copy it into another tensor
// this copy operator already has blockwise offset built-in // this copy operator already has blockwise offset built-in
auto blockwise_wei_copy = auto blockwise_wei_copy =
BlockwiseGenericTensorSliceCopy_v4<BlockSize, BlockwiseGenericTensorSliceCopy_v2_deprecated<BlockSize,
decltype(wei_e_k_global_desc), decltype(wei_e_k_global_desc),
decltype(wei_e_k_block_desc), decltype(wei_e_k_block_desc),
decltype(wei_e_k_block_desc.GetLengths()), decltype(wei_e_k_block_desc.GetLengths()),
WeiBlockCopySubLengths_E_K, WeiBlockCopySubLengths_E_K,
WeiBlockCopyClusterLengths_E_K, WeiBlockCopyClusterLengths_E_K,
WeiBlockCopyThreadClusterArrangeOrder, WeiBlockCopyThreadClusterArrangeOrder,
WeiBlockCopySrcAccessOrder, WeiBlockCopySrcAccessOrder,
WeiBlockCopyDstAccessOrder, WeiBlockCopyDstAccessOrder,
0, 0,
1, 1,
WeiBlockCopySrcDataPerRead_E, WeiBlockCopySrcDataPerRead_E,
WeiBlockCopyDstDataPerWrite_K>( WeiBlockCopyDstDataPerWrite_K>(
{0, k_block_data_on_global}, {0, 0}); {0, k_block_data_on_global}, {0, 0});
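
// A scalar model of the wei_e_k view used above: the packed KCYX weight tensor read
// as a 2D [E, K] matrix (matrix A of the block GEMM) with E = C*Y*X. This assumes
// the forward-convolution layout; the deprecated path on the right builds the
// descriptor differently for backward-weight. Illustrative helper only.
float read_weight_e_k(const float* wei, // packed [K, C, Y, X]
                      int e, int k, int C, int Y, int X)
{
    const int c = e / (Y * X); // undo the unfold of (C, Y, X) into E
    const int y = (e / X) % Y;
    const int x = e % X;
    return wei[((k * C + c) * Y + y) * X + x];
}
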
// GEMM definition // GEMM definition
...@@ -224,11 +253,8 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_padded_lds_double_buf ...@@ -224,11 +253,8 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_padded_lds_double_buf
// register // register
constexpr auto a_e_k_block_mtx_desc = make_ConstantMatrixDescriptor(wei_e_k_block_desc); constexpr auto a_e_k_block_mtx_desc = make_ConstantMatrixDescriptor(wei_e_k_block_desc);
constexpr auto b_e_n1bn2_block_mtx_desc = make_ConstantMatrixDescriptor( constexpr auto b_e_n1bn2_block_mtx_desc =
in_e_n1_b_n2_block_desc.GetLength(I0), make_ConstantMatrixDescriptor(in_e_n1_b_n2_block_desc.Unfold(I1, I3));
in_e_n1_b_n2_block_desc.GetLength(I1) * in_e_n1_b_n2_block_desc.GetLength(I2) *
in_e_n1_b_n2_block_desc.GetLength(I3),
in_e_n1_b_n2_block_desc.GetStride(I0));
// sanity check // sanity check
static_assert(KPerBlock % (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster) == static_assert(KPerBlock % (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster) ==
...@@ -240,14 +266,14 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_padded_lds_double_buf ...@@ -240,14 +266,14 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_padded_lds_double_buf
// c_thread_mtx definition: this is a mess // c_thread_mtx definition: this is a mess
// TODO: more elegant way of defining c_thread_mtx // TODO: more elegant way of defining c_thread_mtx
constexpr auto c_k0k2_n1n2_thread_mtx_desc = make_ConstantMatrixDescriptor_packed( constexpr auto c_k0k1_n1n2_thread_mtx_desc = make_ConstantMatrixDescriptor_packed(
Number<GemmMRepeat * GemmMPerThreadSubC>{}, Number<N1 * N2>{}); Number<GemmMRepeat * GemmMPerThreadSubC>{}, Number<GemmNRepeat * GemmNPerThreadSubC>{});
const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2< const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2<
BlockSize, BlockSize,
decltype(a_e_k_block_mtx_desc), decltype(a_e_k_block_mtx_desc),
decltype(b_e_n1bn2_block_mtx_desc), decltype(b_e_n1bn2_block_mtx_desc),
decltype(c_k0k2_n1n2_thread_mtx_desc), decltype(c_k0k1_n1n2_thread_mtx_desc),
GemmMPerThreadSubC, GemmMPerThreadSubC,
GemmNPerThreadSubC, GemmNPerThreadSubC,
GemmMLevel0Cluster, GemmMLevel0Cluster,
...@@ -274,17 +300,17 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_padded_lds_double_buf ...@@ -274,17 +300,17 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_padded_lds_double_buf
__shared__ Float p_wei_block_double[2 * wei_block_space]; __shared__ Float p_wei_block_double[2 * wei_block_space];
// register allocation for output // register allocation for output
Float p_out_thread[c_k0k2_n1n2_thread_mtx_desc.GetElementSpace()]; AccDataType p_out_thread[c_k0k1_n1n2_thread_mtx_desc.GetElementSpace()];
// zero out threadwise output // zero out threadwise output
threadwise_matrix_set_zero(c_k0k2_n1n2_thread_mtx_desc, p_out_thread); threadwise_matrix_set_zero(c_k0k1_n1n2_thread_mtx_desc, p_out_thread);
// LDS double buffer: preload data into LDS // LDS double buffer: preload data into LDS
{ {
blockwise_in_copy.template Run<Float, Float, address_space_t::global>( blockwise_in_copy.Run(
p_in_global, p_in_block_double); p_in_global, p_in_block_double, global_address_space, generic_address_space);
blockwise_wei_copy.template Run<Float, Float, address_space_t::global>( blockwise_wei_copy.Run(
p_wei_global, p_wei_block_double); p_wei_global, p_wei_block_double, global_address_space, generic_address_space);
} }
// LDS double buffer: main body // LDS double buffer: main body
...@@ -315,12 +341,10 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_padded_lds_double_buf ...@@ -315,12 +341,10 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_padded_lds_double_buf
__syncthreads(); __syncthreads();
// LDS double buffer: load next data from device mem // LDS double buffer: load next data from device mem
blockwise_in_copy blockwise_in_copy.RunLoadThreadBuffer(
.template RunLoadThreadBuffer<Float, Float, address_space_t::global>( p_in_global, p_in_thread_buffer, global_address_space, generic_address_space);
p_in_global, p_in_thread_buffer); blockwise_wei_copy.RunLoadThreadBuffer(
blockwise_wei_copy p_wei_global, p_wei_thread_buffer, global_address_space, generic_address_space);
.template RunLoadThreadBuffer<Float, Float, address_space_t::global>(
p_wei_global, p_wei_thread_buffer);
// LDS double buffer: GEMM on current data // LDS double buffer: GEMM on current data
blockwise_gemm.Run(p_wei_block_now, p_in_block_now, p_out_thread); blockwise_gemm.Run(p_wei_block_now, p_in_block_now, p_out_thread);
...@@ -343,10 +367,10 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_padded_lds_double_buf ...@@ -343,10 +367,10 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_padded_lds_double_buf
__syncthreads(); __syncthreads();
// LDS double buffer: load next data from device mem // LDS double buffer: load next data from device mem
blockwise_in_copy.template RunLoadThreadBuffer<Float, Float, address_space_t::global>( blockwise_in_copy.RunLoadThreadBuffer(
p_in_global, p_in_thread_buffer); p_in_global, p_in_thread_buffer, global_address_space, generic_address_space);
blockwise_wei_copy.template RunLoadThreadBuffer<Float, Float, address_space_t::global>( blockwise_wei_copy.RunLoadThreadBuffer(
p_wei_global, p_wei_thread_buffer); p_wei_global, p_wei_thread_buffer, global_address_space, generic_address_space);
// LDS double buffer: GEMM on current data // LDS double buffer: GEMM on current data
blockwise_gemm.Run(p_wei_block_double, p_in_block_double, p_out_thread); blockwise_gemm.Run(p_wei_block_double, p_in_block_double, p_out_thread);
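
// A single-threaded sketch of the ping-pong schedule that the main body and tail
// above implement: while the GEMM consumes the E-slice already resident in one LDS
// buffer, the next slice is prefetched into registers and then committed to the
// other buffer. The callables stand in for the blockwise copy and blockwise GEMM
// objects; synchronization and slice-window moves are omitted on purpose.
#include <cstdio>

template <typename Load, typename Store, typename Gemm>
void double_buffered_loop(int num_slices, Load load, Store store, Gemm gemm)
{
    load(0);     // preload slice 0 from global memory
    store(0);    // ... and place it in LDS buffer 0
    int now = 0; // buffer currently holding valid data

    for(int s = 0; s + 1 < num_slices; ++s)
    {
        const int next = 1 - now;
        load(s + 1); // prefetch the next slice into a register staging buffer
        gemm(now);   // GEMM on the slice already in LDS
        store(next); // commit the staged slice to the other LDS buffer
        now = next;
    }
    gemm(now); // tail: GEMM on the last resident slice
}

int main()
{
    auto load  = [](int s) { std::printf("load slice %d\n", s); };
    auto store = [](int b) { std::printf("store into LDS buffer %d\n", b); };
    auto gemm  = [](int b) { std::printf("gemm on LDS buffer %d\n", b); };
    double_buffered_loop(4, load, store, gemm); // e.g. E / EPerBlock == 4
}
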
...@@ -369,38 +393,24 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_padded_lds_double_buf ...@@ -369,38 +393,24 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_padded_lds_double_buf
// copy output: register to global memory // copy output: register to global memory
{ {
constexpr index_t K1 = GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster; constexpr index_t K1 = GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster;
constexpr index_t K0 = K / K1;
// define tensor descriptor for threadwise copy // define tensor descriptor for threadwise copy
// output memory layout descriptor in register, src of threadwise copy // output memory layout descriptor in register, src of threadwise copy
constexpr auto out_k0_k1_n1_b_n2_thread_desc = make_native_tensor_descriptor_packed( constexpr auto out_k0_k1_n1_b_n2_thread_mem_desc = make_ConstantTensorDescriptor_packed(
Sequence<GemmMRepeat, GemmMPerThreadSubC, N1, 1, N2>{}); Sequence<GemmMRepeat, GemmMPerThreadSubC, N1, 1, N2>{});
// output memory layout descriptor in device memory // output memory layout descriptor in device memory
constexpr auto out_n0_n1_n2_k0_k1_ho_wo_global_desc = transform_tensor_descriptor( constexpr auto out_n0_n1_n2_k0_k1_h_w_global_mem_desc =
out_n_k_ho_wo_global_desc, out_n_k_h_w_global_desc.Fold(I1, Number<K1>{}).Fold(I0, Number<N1>{}, Number<N2>{});
make_tuple(UnMerge<Sequence<N0, N1, N2>>{},
UnMerge<Sequence<K0, K1>>{},
PassThrough<Ho>{},
PassThrough<Wo>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}, Sequence<6>{}));
// output merged global tensor descriptor, dst of threadwise copy // output merged global tensor descriptor, dst of threadwise copy
constexpr auto out_k0_k1_n1_b_n2_global_desc = transform_tensor_descriptor( constexpr auto out_k0_k1_n1_b_n2_global_merged_desc =
out_n0_n1_n2_k0_k1_ho_wo_global_desc, make_ConstantMergedTensorDescriptor(out_n0_n1_n2_k0_k1_h_w_global_mem_desc,
make_tuple(PassThrough<K0>{}, Sequence<3>{},
PassThrough<K1>{}, Sequence<4>{},
PassThrough<N1>{}, Sequence<1>{},
Merge<Sequence<N0, Ho, Wo>>{}, Sequence<0, 5, 6>{},
PassThrough<N2>{}), Sequence<2>{});
make_tuple(Sequence<3>{},
Sequence<4>{},
Sequence<1>{},
Sequence<0, 5, 6>{},
Sequence<2>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
// calculate origin of thread output tensor on global memory // calculate origin of thread output tensor on global memory
// blockwise GEMM c matrix starting index // blockwise GEMM c matrix starting index
...@@ -413,31 +423,25 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_padded_lds_double_buf ...@@ -413,31 +423,25 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_padded_lds_double_buf
const index_t b_thread_data_on_global = const index_t b_thread_data_on_global =
b_block_data_on_global + c_thread_mtx_on_block.col / N2; b_block_data_on_global + c_thread_mtx_on_block.col / N2;
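
// Worked arithmetic for the origin computed above, with example values (nothing
// here is taken from a real launch): the blockwise GEMM reports a (row, col)
// position inside the block's C matrix, and the divisions by K1 and N2 recover the
// folded coordinates handed to the threadwise copy as its destination origin.
#include <cstdio>

int main()
{
    const int K1 = 4 * 4 * 4; // GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster
    const int N2 = 4;

    const int row = 72, col = 132;                     // assumed c_thread_mtx_on_block
    const int k_thread_data_on_global = 128 + row;     // k_block_data_on_global + row
    const int b_thread_data_on_global = 64 + col / N2; // b_block_data_on_global + col / N2

    std::printf("dst origin = {%d, %d, 0, %d, 0}\n",
                k_thread_data_on_global / K1, // K0 coordinate
                k_thread_data_on_global % K1, // K1 coordinate
                b_thread_data_on_global);     // merged B coordinate
}
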
ThreadwiseGenericTensorSliceCopy_v4r2<decltype(out_k0_k1_n1_b_n2_thread_desc), ThreadwiseGenericTensorSliceCopy_v2r1_deprecated<
decltype(out_k0_k1_n1_b_n2_global_desc), decltype(out_k0_k1_n1_b_n2_thread_mem_desc),
decltype( decltype(out_k0_k1_n1_b_n2_global_merged_desc),
out_k0_k1_n1_b_n2_thread_desc.GetLengths()), decltype(out_k0_k1_n1_b_n2_thread_mem_desc.GetLengths()),
arithmetic_sequence_gen<0, 5, 1>::type, arithmetic_sequence_gen<0, 5, 1>::type,
3, arithmetic_sequence_gen<0, 5, 1>::type,
1, 3,
1>({0, 0, 0, 0, 0}, 3,
{k_thread_data_on_global / K1, 1,
k_thread_data_on_global % K1, 1>({0, 0, 0, 0, 0},
0, {k_thread_data_on_global / K1,
b_thread_data_on_global, k_thread_data_on_global % K1,
0}) 0,
#if 1 b_thread_data_on_global,
.template Run<Float, Float, address_space_t::generic, address_space_t::global> 0})
#else // tweaking .Run(p_out_thread, p_out_global, generic_address_space, global_address_space);
.template Run_optimized_dst_address_calculation<Float,
Float,
address_space_t::generic,
address_space_t::global>
#endif
(p_out_thread, p_out_global);
} }
} }
}; };
} // namespace ck } // namespace ck
#endif #endif // CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER_DEPRECATED_HPP
...@@ -2,8 +2,8 @@ ...@@ -2,8 +2,8 @@
#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R2_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER #define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R2_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER
#include "common_header.hpp" #include "common_header.hpp"
#include "ConstantTensorDescriptor.hpp" #include "ConstantTensorDescriptor_deprecated.hpp"
#include "ConstantMergedTensorDescriptor.hpp" #include "ConstantMergedTensorDescriptor_deprecated.hpp"
#include "ConstantMatrixDescriptor.hpp" #include "ConstantMatrixDescriptor.hpp"
#include "blockwise_generic_tensor_slice_copy.hpp" #include "blockwise_generic_tensor_slice_copy.hpp"
#include "blockwise_gemm.hpp" #include "blockwise_gemm.hpp"
...@@ -166,7 +166,7 @@ struct GridwiseConvolutionImplicitGemm_v4r2_nchw_kcyx_nkhw_lds_double_buffer ...@@ -166,7 +166,7 @@ struct GridwiseConvolutionImplicitGemm_v4r2_nchw_kcyx_nkhw_lds_double_buffer
// input blockwise copy // input blockwise copy
// slice a merged tensor, reorder and copy to a normal tensor // slice a merged tensor, reorder and copy to a normal tensor
// this copy operator already has blockwise offset built-in // this copy operator already has blockwise offset built-in
auto blockwise_in_copy = BlockwiseGenericTensorSliceCopy_v1< auto blockwise_in_copy = BlockwiseGenericTensorSliceCopy_v1_deprecated<
BlockSize, BlockSize,
Float, Float,
decltype(in_e_n0_ho0_wo0_b_n2_ho2_wo2_global_merged_desc), decltype(in_e_n0_ho0_wo0_b_n2_ho2_wo2_global_merged_desc),
...@@ -196,18 +196,18 @@ struct GridwiseConvolutionImplicitGemm_v4r2_nchw_kcyx_nkhw_lds_double_buffer ...@@ -196,18 +196,18 @@ struct GridwiseConvolutionImplicitGemm_v4r2_nchw_kcyx_nkhw_lds_double_buffer
// slice a tensor, and copy it into another tensor // slice a tensor, and copy it into another tensor
// this copy operator already has blockwise offset built-in // this copy operator already has blockwise offset built-in
auto blockwise_wei_copy = auto blockwise_wei_copy =
BlockwiseGenericTensorSliceCopy_v1<BlockSize, BlockwiseGenericTensorSliceCopy_v1_deprecated<BlockSize,
Float, Float,
decltype(wei_e_k_global_desc), decltype(wei_e_k_global_desc),
decltype(wei_e_k_block_desc), decltype(wei_e_k_block_desc),
decltype(wei_e_k_block_desc.GetLengths()), decltype(wei_e_k_block_desc.GetLengths()),
WeiBlockCopySubLengths_E_K, WeiBlockCopySubLengths_E_K,
WeiBlockCopyClusterLengths_E_K, WeiBlockCopyClusterLengths_E_K,
WeiBlockCopyThreadClusterArrangeOrder, WeiBlockCopyThreadClusterArrangeOrder,
WeiBlockCopySrcAccessOrder, WeiBlockCopySrcAccessOrder,
WeiBlockCopyDstAccessOrder, WeiBlockCopyDstAccessOrder,
WeiBlockCopySrcDataPerRead_E, WeiBlockCopySrcDataPerRead_E,
WeiBlockCopyDstDataPerWrite_K>( WeiBlockCopyDstDataPerWrite_K>(
{0, k_block_data_on_global}, {0, 0}); {0, k_block_data_on_global}, {0, 0});
// GEMM definition // GEMM definition
......
...@@ -2,8 +2,8 @@ ...@@ -2,8 +2,8 @@
#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R3_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER #define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R3_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER
#include "common_header.hpp" #include "common_header.hpp"
#include "ConstantTensorDescriptor.hpp" #include "ConstantTensorDescriptor_deprecated.hpp"
#include "ConstantMergedTensorDescriptor.hpp" #include "ConstantMergedTensorDescriptor_deprecated.hpp"
#include "ConstantMatrixDescriptor.hpp" #include "ConstantMatrixDescriptor.hpp"
#include "blockwise_generic_tensor_slice_copy.hpp" #include "blockwise_generic_tensor_slice_copy.hpp"
#include "blockwise_gemm.hpp" #include "blockwise_gemm.hpp"
...@@ -165,7 +165,7 @@ struct GridwiseConvolutionImplicitGemm_v4r3_nchw_kcyx_nkhw_lds_double_buffer ...@@ -165,7 +165,7 @@ struct GridwiseConvolutionImplicitGemm_v4r3_nchw_kcyx_nkhw_lds_double_buffer
// input blockwise copy // input blockwise copy
// slice a merged tensor, reorder and copy to a normal tensor // slice a merged tensor, reorder and copy to a normal tensor
// this copy operator already has blockwise offset built-in // this copy operator already has blockwise offset built-in
auto blockwise_in_copy = BlockwiseGenericTensorSliceCopy_v1< auto blockwise_in_copy = BlockwiseGenericTensorSliceCopy_v1_deprecated<
BlockSize, BlockSize,
Float, Float,
decltype(in_e_n1_ho1_wo1_b_n2_ho2_wo2_global_merged_desc), decltype(in_e_n1_ho1_wo1_b_n2_ho2_wo2_global_merged_desc),
...@@ -195,18 +195,18 @@ struct GridwiseConvolutionImplicitGemm_v4r3_nchw_kcyx_nkhw_lds_double_buffer ...@@ -195,18 +195,18 @@ struct GridwiseConvolutionImplicitGemm_v4r3_nchw_kcyx_nkhw_lds_double_buffer
// slice a tensor, and copy it into another tensor // slice a tensor, and copy it into another tensor
// this copy operator already has blockwise offset built-in // this copy operator already has blockwise offset built-in
auto blockwise_wei_copy = auto blockwise_wei_copy =
BlockwiseGenericTensorSliceCopy_v1<BlockSize, BlockwiseGenericTensorSliceCopy_v1_deprecated<BlockSize,
Float, Float,
decltype(wei_e_k_global_desc), decltype(wei_e_k_global_desc),
decltype(wei_e_k_block_desc), decltype(wei_e_k_block_desc),
decltype(wei_e_k_block_desc.GetLengths()), decltype(wei_e_k_block_desc.GetLengths()),
WeiBlockCopySubLengths_E_K, WeiBlockCopySubLengths_E_K,
WeiBlockCopyClusterLengths_E_K, WeiBlockCopyClusterLengths_E_K,
WeiBlockCopyThreadClusterArrangeOrder, WeiBlockCopyThreadClusterArrangeOrder,
WeiBlockCopySrcAccessOrder, WeiBlockCopySrcAccessOrder,
WeiBlockCopyDstAccessOrder, WeiBlockCopyDstAccessOrder,
WeiBlockCopySrcDataPerRead_E, WeiBlockCopySrcDataPerRead_E,
WeiBlockCopyDstDataPerWrite_K>( WeiBlockCopyDstDataPerWrite_K>(
{0, k_block_data_on_global}, {0, 0}); {0, k_block_data_on_global}, {0, 0});
#if 0 #if 0
......
#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R4_NCHW_KCYX_NKHW_HPP_LDS_DOUBLE_BUFFER_HPP #ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R4_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER_HPP
#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R4_NCHW_KCYX_NKHW_HPP_LDS_DOUBLE_BUFFER_HPP #define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R4_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER_HPP
#include "common_header.hpp" #include "common_header.hpp"
#include "ConstantTensorDescriptor.hpp" #include "tensor_descriptor.hpp"
#include "ConstantMergedTensorDescriptor.hpp" #include "tensor_descriptor_helper.hpp"
#include "ConstantMatrixDescriptor.hpp" #include "ConstantMatrixDescriptor.hpp"
#include "blockwise_generic_tensor_slice_copy_deprecated.hpp" #include "blockwise_generic_tensor_slice_copy.hpp"
#include "threadwise_generic_tensor_slice_copy.hpp"
#include "blockwise_gemm.hpp" #include "blockwise_gemm.hpp"
#include "threadwise_generic_tensor_slice_copy_deprecated.hpp"
namespace ck { namespace ck {
// B = merge(N, Ho, Wo) // B = merge(N, Ho, Wo)
template <index_t GridSize, template <index_t GridSize,
index_t BlockSize, index_t BlockSize,
class Float, typename Float,
class InGlobalDesc, typename InGlobalDesc,
class WeiGlobalDesc, typename WeiGlobalDesc,
class OutGlobalDesc, typename OutGlobalDesc,
class ConvStrides, typename ConvStrides,
class ConvDilations, typename ConvDilations,
typename LeftPads,
typename RightPads,
index_t BPerBlock, index_t BPerBlock,
index_t KPerBlock, index_t KPerBlock,
index_t EPerBlock, index_t EPerBlock,
...@@ -32,17 +34,17 @@ template <index_t GridSize, ...@@ -32,17 +34,17 @@ template <index_t GridSize,
index_t GemmKPerThreadLoop, index_t GemmKPerThreadLoop,
index_t GemmDataPerReadA, index_t GemmDataPerReadA,
index_t GemmDataPerReadB, index_t GemmDataPerReadB,
class InBlockCopySubLengths_E_B, typename InBlockCopySubLengths_E_B,
class InBlockCopyClusterLengths_E_B, typename InBlockCopyClusterLengths_E_B,
class InBlockCopyThreadClusterArrangeOrder, typename InBlockCopyThreadClusterArrangeOrder,
class InBlockCopySrcAccessOrder, typename InBlockCopySrcAccessOrder,
class InBlockCopyDstAccessOrder, typename InBlockCopyDstAccessOrder,
index_t InBlockCopyDataPerAccess_B, index_t InBlockCopyDataPerAccess_B,
class WeiBlockCopySubLengths_E_K, typename WeiBlockCopySubLengths_E_K,
class WeiBlockCopyClusterLengths_E_K, typename WeiBlockCopyClusterLengths_E_K,
class WeiBlockCopyThreadClusterArrangeOrder, typename WeiBlockCopyThreadClusterArrangeOrder,
class WeiBlockCopySrcAccessOrder, typename WeiBlockCopySrcAccessOrder,
class WeiBlockCopyDstAccessOrder, typename WeiBlockCopyDstAccessOrder,
index_t WeiBlockCopySrcDataPerRead_E, index_t WeiBlockCopySrcDataPerRead_E,
index_t WeiBlockCopyDstDataPerWrite_K, index_t WeiBlockCopyDstDataPerWrite_K,
index_t OutThreadCopyDataPerAccess_B> index_t OutThreadCopyDataPerAccess_B>
...@@ -56,23 +58,32 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer ...@@ -56,23 +58,32 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer
constexpr auto I1 = Number<1>{}; constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{}; constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{}; constexpr auto I3 = Number<3>{};
constexpr auto I5 = Number<5>{};
constexpr auto True = integral_constant<bool, true>{}; constexpr auto True = integral_constant<bool, true>{};
constexpr auto in_n_c_h_w_global_desc = InGlobalDesc{}; constexpr auto generic_address_space =
constexpr auto wei_k_c_y_x_global_desc = WeiGlobalDesc{}; integral_constant<AddressSpace, AddressSpace::generic>{};
constexpr auto out_n_k_h_w_global_desc = OutGlobalDesc{}; constexpr auto global_address_space =
integral_constant<AddressSpace, AddressSpace::global>{};
constexpr index_t N = in_n_c_h_w_global_desc.GetLengths()[0]; constexpr auto in_n_c_hi_wi_global_desc =
constexpr index_t C = in_n_c_h_w_global_desc.GetLengths()[1]; make_native_tensor_descriptor(InGlobalDesc::GetLengths(), InGlobalDesc::GetStrides());
constexpr auto wei_k_c_y_x_global_desc =
make_native_tensor_descriptor(WeiGlobalDesc::GetLengths(), WeiGlobalDesc::GetStrides());
constexpr auto out_n_k_ho_wo_global_desc =
make_native_tensor_descriptor(OutGlobalDesc::GetLengths(), OutGlobalDesc::GetStrides());
constexpr index_t K = out_n_k_h_w_global_desc.GetLengths()[1]; constexpr index_t N = in_n_c_hi_wi_global_desc.GetLength(I0);
constexpr index_t Ho = out_n_k_h_w_global_desc.GetLengths()[2]; constexpr index_t C = in_n_c_hi_wi_global_desc.GetLength(I1);
constexpr index_t Wo = out_n_k_h_w_global_desc.GetLengths()[3]; constexpr index_t Hi = in_n_c_hi_wi_global_desc.GetLength(I2);
constexpr index_t Wi = in_n_c_hi_wi_global_desc.GetLength(I3);
constexpr index_t Y = wei_k_c_y_x_global_desc.GetLengths()[2]; constexpr index_t K = out_n_k_ho_wo_global_desc.GetLength(I1);
constexpr index_t X = wei_k_c_y_x_global_desc.GetLengths()[3]; constexpr index_t Ho = out_n_k_ho_wo_global_desc.GetLength(I2);
constexpr index_t Wo = out_n_k_ho_wo_global_desc.GetLength(I3);
constexpr index_t Y = wei_k_c_y_x_global_desc.GetLength(I2);
constexpr index_t X = wei_k_c_y_x_global_desc.GetLength(I3);
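
// A minimal picture of what the native descriptors built above carry: lengths plus
// strides, with an element offset being the inner product of a multi-index and the
// strides. Sketch only; the real class also provides the compile-time coordinate
// machinery that the later transform_tensor_descriptor calls build on.
#include <array>
#include <cstdio>

int offset(const std::array<int, 4>& idx, const std::array<int, 4>& strides)
{
    int off = 0;
    for(int d = 0; d < 4; ++d)
        off += idx[d] * strides[d];
    return off;
}

int main()
{
    // packed NCHW lengths {2, 3, 4, 5} imply strides {60, 20, 5, 1}
    std::printf("%d\n", offset({1, 2, 3, 4}, {60, 20, 5, 1})); // 119, the last element
}
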
constexpr index_t ConvStrideH = ConvStrides{}[0]; constexpr index_t ConvStrideH = ConvStrides{}[0];
constexpr index_t ConvStrideW = ConvStrides{}[1]; constexpr index_t ConvStrideW = ConvStrides{}[1];
...@@ -90,50 +101,52 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer ...@@ -90,50 +101,52 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer
"be violated"); "be violated");
// divide block work by [K, B] // divide block work by [K, B]
static_assert(K % KPerBlock == 0 && B % BPerBlock == 0 && E % (2 * EPerBlock) == 0, static_assert(K % KPerBlock == 0 && B % BPerBlock == 0 && E % EPerBlock == 0,
"wrong! cannot divide work evenly among block"); "wrong! cannot divide work evenly among block");
constexpr index_t KBlockWork = K / KPerBlock; constexpr index_t KBlockWork = K / KPerBlock;
constexpr index_t BBlockWork = B / BPerBlock; constexpr index_t BBlockWork = B / BPerBlock;
constexpr auto block_work_desc = constexpr auto block_work_desc =
make_ConstantTensorDescriptor_packed(Sequence<KBlockWork, BBlockWork>{}); make_cluster_descriptor(Sequence<KBlockWork, BBlockWork>{});
const auto block_work_multi_id = const auto block_work_id = block_work_desc.CalculateClusterIndex(get_block_1d_id());
block_work_desc.GetMultiIndexFrom1dIndex(get_block_1d_id());
const index_t k_block_data_on_global = block_work_multi_id[0] * KPerBlock; const index_t k_block_data_on_global = block_work_id[0] * KPerBlock;
const index_t b_block_data_on_global = block_work_multi_id[1] * BPerBlock; const index_t b_block_data_on_global = block_work_id[1] * BPerBlock;
// input tensor // input tensor
// tensor descriptor in device memory [N, Ho, Wo] // global mem
constexpr auto in_n_ho_wo_global_desc = constexpr auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor(
in_n_c_h_w_global_desc.Extract(I0, I2, I3) in_n_c_hi_wi_global_desc,
.StridedSlice(I1, Number<Ho>{}, Number<ConvStrideH>{}) make_tuple(
.StridedSlice(I2, Number<Wo>{}, Number<ConvStrideW>{}); PassThrough<N>{}, PassThrough<C>{}, Pad<Sequence<Hi, Wi>, LeftPads, RightPads>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}),
// batch descriptor for device memory make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}));
constexpr auto in_c_y_x_global_desc =
in_n_c_h_w_global_desc.StridedSlice(I2, Number<Y>{}, Number<ConvDilationH>{}) constexpr auto in_n_c_y_ho_x_wo_global_desc = transform_tensor_descriptor(
.StridedSlice(I3, Number<X>{}, Number<ConvDilationW>{}) in_n_c_hip_wip_global_desc,
.Extract(Sequence<1, 2, 3>{}); make_tuple(PassThrough<N>{},
PassThrough<C>{},
// merged tensor descriptor in device memory [E, B], src of blockwise copy Embed<Sequence<Y, Ho>, Sequence<ConvDilationH, ConvStrideH, 0>>{},
constexpr auto in_e_b_global_desc = Embed<Sequence<X, Wo>, Sequence<ConvDilationW, ConvStrideW, 0>>{}),
make_ConstantMergedTensorDescriptor(in_c_y_x_global_desc.Embed(in_n_ho_wo_global_desc), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
Sequence<0, 1, 2>{}, make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
Sequence<3, 4, 5>{});
constexpr auto in_e_b_global_desc = transform_tensor_descriptor(
// memory layout descriptor in LDS [E, B], dst of blockwise copy in_n_c_y_ho_x_wo_global_desc,
make_tuple(Merge<Sequence<C, Y, X>>{}, Merge<Sequence<N, Ho, Wo>>{}),
make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
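
// Worked example (sizes invented) of the GEMM shape implied by the [E, B] view
// built above: E = C*Y*X is the reduction (GemmK) dimension and B = N*Ho*Wo the
// GemmN dimension, with Ho/Wo following the usual convolution output-size formula.
#include <cstdio>

int out_len(int in_len, int window, int left_pad, int right_pad, int stride, int dilation)
{
    const int effective_window = dilation * (window - 1) + 1;
    return (in_len + left_pad + right_pad - effective_window) / stride + 1;
}

int main()
{
    const int N = 64, C = 128, Hi = 28, Wi = 28, Y = 3, X = 3;
    const int Ho = out_len(Hi, Y, 1, 1, 1, 1); // 28: 3x3 filter, pad 1, stride 1
    const int Wo = out_len(Wi, X, 1, 1, 1, 1); // 28
    std::printf("E = %d, B = %d\n", C * Y * X, N * Ho * Wo); // 1152, 50176
}
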
// LDS mem
// be careful of LDS alignment // be careful of LDS alignment
constexpr auto in_e_b_block_desc = constexpr auto in_e_b_block_desc =
make_ConstantTensorDescriptor_packed(Sequence<EPerBlock, BPerBlock>{}); make_native_tensor_descriptor_packed(Sequence<EPerBlock, BPerBlock>{});
// input blockwise copy // input blockwise copy
// slice a merged tensor, reorder and copy to a normal tensor
// this copy operator already has blockwise offset built-in
auto blockwise_in_copy = auto blockwise_in_copy =
BlockwiseGenericTensorSliceCopy_v2<BlockSize, BlockwiseGenericTensorSliceCopy_v4<BlockSize,
decltype(in_e_b_global_desc), decltype(in_e_b_global_desc),
decltype(in_e_b_block_desc), decltype(in_e_b_block_desc),
decltype(in_e_b_block_desc.GetLengths()), decltype(in_e_b_block_desc.GetLengths()),
...@@ -149,13 +162,13 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer ...@@ -149,13 +162,13 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer
{0, b_block_data_on_global}, {0, 0}); {0, b_block_data_on_global}, {0, 0});
// weight tensor // weight tensor
// tensor descriptor in device memory, src of blockwise copy // global mem
constexpr auto wei_e_k_global_desc = constexpr auto wei_e_k_global_desc = reorder_tensor_descriptor_given_upper2lower(
wei_k_c_y_x_global_desc.Unfold(I1, I3).ReorderGivenNew2Old(Sequence<1, 0>{}); unfold_tensor_descriptor(wei_k_c_y_x_global_desc, I1, I3), Sequence<1, 0>{});
// tensor descriptor in LDS, dst of blockwise copy // LDS
// be careful of LDS alignment // be careful of LDS alignment
constexpr auto wei_e_k_block_desc = make_ConstantTensorDescriptor_aligned( constexpr auto wei_e_k_block_desc = make_native_tensor_descriptor_aligned(
Sequence<EPerBlock, KPerBlock>{}, Sequence<EPerBlock, KPerBlock>{},
Number<math::lcm(WeiBlockCopyDstDataPerWrite_K, GemmDataPerReadA)>{}); Number<math::lcm(WeiBlockCopyDstDataPerWrite_K, GemmDataPerReadA)>{});
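
// A small model of the alignment computed above, under the assumption that the
// "aligned" LDS descriptor rounds the E-row stride of the [EPerBlock, KPerBlock]
// tile up to a multiple of lcm(WeiBlockCopyDstDataPerWrite_K, GemmDataPerReadA),
// so both the vectorized LDS write and the GEMM read stay aligned. The assumption
// is stated here, not taken from the library source.
#include <numeric> // std::lcm (C++17)

constexpr int aligned_row_stride(int k_per_block, int dst_write_width, int gemm_read_width)
{
    const int align = std::lcm(dst_write_width, gemm_read_width);
    return (k_per_block + align - 1) / align * align; // round up to a multiple of align
}

static_assert(aligned_row_stride(128, 1, 4) == 128, "already aligned");
static_assert(aligned_row_stride(130, 2, 4) == 132, "padded up to the next multiple");
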
...@@ -165,11 +178,9 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer ...@@ -165,11 +178,9 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer
static_assert(wei_e_k_block_desc.GetStride(I0) % GemmDataPerReadA == 0, static_assert(wei_e_k_block_desc.GetStride(I0) % GemmDataPerReadA == 0,
"GemmDataPerReadA alignment requirement is not satisfied"); "GemmDataPerReadA alignment requirement is not satisfied");
// operator for blockwise copy of weight into LDS // weight blockwise copy
// slice a tensor, and copy it into another tensor
// this copy operator already has blockwise offset built-in
auto blockwise_wei_copy = auto blockwise_wei_copy =
BlockwiseGenericTensorSliceCopy_v2<BlockSize, BlockwiseGenericTensorSliceCopy_v4<BlockSize,
decltype(wei_e_k_global_desc), decltype(wei_e_k_global_desc),
decltype(wei_e_k_block_desc), decltype(wei_e_k_block_desc),
decltype(wei_e_k_block_desc.GetLengths()), decltype(wei_e_k_block_desc.GetLengths()),
...@@ -247,14 +258,12 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer ...@@ -247,14 +258,12 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer
// zero out threadwise output // zero out threadwise output
threadwise_matrix_set_zero(c_k0k1_b0b1_thread_mtx_desc, p_out_thread); threadwise_matrix_set_zero(c_k0k1_b0b1_thread_mtx_desc, p_out_thread);
const Float* p_wei_block_on_global = p_wei_global;
// LDS double buffer: preload data into LDS // LDS double buffer: preload data into LDS
{ {
blockwise_in_copy.template Run<Float, address_space_t::global>(p_in_global, blockwise_in_copy.Run(
p_in_block_double); p_in_global, p_in_block_double, global_address_space, generic_address_space);
blockwise_wei_copy.template Run<Float, address_space_t::global>(p_wei_global, blockwise_wei_copy.Run(
p_wei_block_double); p_wei_global, p_wei_block_double, global_address_space, generic_address_space);
} }
// LDS double buffer: main body // LDS double buffer: main body
...@@ -285,10 +294,10 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer ...@@ -285,10 +294,10 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer
__syncthreads(); __syncthreads();
// LDS double buffer: load next data from device mem // LDS double buffer: load next data from device mem
blockwise_in_copy.template RunLoadThreadBuffer<Float, address_space_t::global>( blockwise_in_copy.RunLoadThreadBuffer(
p_in_global, p_in_thread_buffer); p_in_global, p_in_thread_buffer, global_address_space, generic_address_space);
blockwise_wei_copy.template RunLoadThreadBuffer<Float, address_space_t::global>( blockwise_wei_copy.RunLoadThreadBuffer(
p_wei_global, p_wei_thread_buffer); p_wei_global, p_wei_thread_buffer, global_address_space, generic_address_space);
// LDS double buffer: GEMM on current data // LDS double buffer: GEMM on current data
blockwise_gemm.Run(p_wei_block_now, p_in_block_now, p_out_thread); blockwise_gemm.Run(p_wei_block_now, p_in_block_now, p_out_thread);
...@@ -301,50 +310,51 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer ...@@ -301,50 +310,51 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer
// LDS double buffer: tail // LDS double buffer: tail
{ {
Float p_in_thread_buffer[blockwise_in_copy.GetThreadBufferSize()]; constexpr bool has_two_iteration_left = (E % (2 * EPerBlock) == 0);
Float p_wei_thread_buffer[blockwise_wei_copy.GetThreadBufferSize()];
// even iteration if(has_two_iteration_left) // if has 2 iteration left
blockwise_in_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0>{}, True); {
blockwise_wei_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0>{}, True); Float p_in_thread_buffer[blockwise_in_copy.GetThreadBufferSize()];
Float p_wei_thread_buffer[blockwise_wei_copy.GetThreadBufferSize()];
__syncthreads(); blockwise_in_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0>{}, True);
blockwise_wei_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0>{}, True);
// LDS double buffer: load next data from device mem __syncthreads();
blockwise_in_copy.template RunLoadThreadBuffer<Float, address_space_t::global>(
p_in_global, p_in_thread_buffer); // LDS double buffer: load last data from device mem
blockwise_wei_copy.template RunLoadThreadBuffer<Float, address_space_t::global>( blockwise_in_copy.RunLoadThreadBuffer(
p_wei_global, p_wei_thread_buffer); p_in_global, p_in_thread_buffer, global_address_space, generic_address_space);
blockwise_wei_copy.RunLoadThreadBuffer(
p_wei_global, p_wei_thread_buffer, global_address_space, generic_address_space);
// LDS double buffer: GEMM on 2nd-last data
blockwise_gemm.Run(p_wei_block_double, p_in_block_double, p_out_thread);
// LDS double buffer: GEMM on current data // LDS double buffer: store last data to LDS
blockwise_gemm.Run(p_wei_block_double, p_in_block_double, p_out_thread); blockwise_in_copy.RunStoreThreadBuffer(p_in_thread_buffer,
p_in_block_double + in_block_space);
blockwise_wei_copy.RunStoreThreadBuffer(p_wei_thread_buffer,
p_wei_block_double + wei_block_space);
// LDS double buffer: store next data to LDS __syncthreads();
blockwise_in_copy.RunStoreThreadBuffer(p_in_thread_buffer,
p_in_block_double + in_block_space);
blockwise_wei_copy.RunStoreThreadBuffer(p_wei_thread_buffer,
p_wei_block_double + wei_block_space);
// odd iteration // LDS double buffer: GEMM on current data
__syncthreads(); blockwise_gemm.Run(p_wei_block_double + wei_block_space,
p_in_block_double + in_block_space,
p_out_thread);
}
else // if has 1 iteration left
{
__syncthreads();
// LDS double buffer: GEMM on current data // LDS double buffer: GEMM on last data
blockwise_gemm.Run(p_wei_block_double + wei_block_space, blockwise_gemm.Run(p_wei_block_double, p_in_block_double, p_out_thread);
p_in_block_double + in_block_space, }
p_out_thread);
} }
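
// The tail above only needs to know whether the main loop left one or two E-slices
// unconsumed: the loop advances two slices per iteration, so an even slice count
// ends with a two-slice tail and an odd count with a single-slice tail. Example
// numbers only.
#include <cstdio>

int main()
{
    const int EPerBlock = 8;
    const int examples[2] = {32, 40}; // 4 slices vs. 5 slices
    for(int E : examples)
    {
        const bool has_two_iteration_left = (E % (2 * EPerBlock) == 0);
        std::printf("E=%d -> %d slice(s) left for the tail\n",
                    E, has_two_iteration_left ? 2 : 1);
    }
}
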
// copy output: register to global memory // copy output: register to global memory
{ {
constexpr index_t K1 = GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster;
constexpr index_t B1 = GemmNPerThreadSubC * GemmNLevel0Cluster * GemmNLevel1Cluster;
// define tensor descriptor for threadwise copy
// output global descriptor, for calculating origin of thread tensor
// in global memory
constexpr auto out_k_b_global_desc = make_ConstantMergedTensorDescriptor(
out_n_k_h_w_global_desc, Sequence<1>{}, Sequence<0, 2, 3>{});
// calculate origin of thread output tensor on global memory // calculate origin of thread output tensor on global memory
// blockwise GEMM c matrix starting index // blockwise GEMM c matrix starting index
const auto c_thread_mtx_on_block = const auto c_thread_mtx_on_block =
...@@ -356,47 +366,48 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer ...@@ -356,47 +366,48 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer
const index_t b_thread_data_on_global = const index_t b_thread_data_on_global =
b_block_data_on_global + c_thread_mtx_on_block.col; b_block_data_on_global + c_thread_mtx_on_block.col;
// This is a hack, because slicing a merged dimension is not supported yet. // src descriptor
// This should be replaced with logic above, once slicing a merged dimension support constexpr auto out_k0_k1_b0_b1_thread_desc = make_native_tensor_descriptor_packed(
// become available Sequence<GemmMRepeat, GemmMPerThreadSubC, GemmNRepeat, GemmNPerThreadSubC>{});
// dst descriptor
constexpr auto out_k0_k1_b_global_desc =
make_ConstantMergedTensorDescriptor(out_n_k_h_w_global_desc.Fold(I1, Number<K1>{}),
Sequence<1>{},
Sequence<2>{},
Sequence<0, 3, 4>{});
// src descriptor
constexpr auto out_k0_k1_b_thread_desc = make_ConstantTensorDescriptor_packed(
Sequence<GemmMRepeat, GemmMPerThreadSubC, GemmNRepeat * GemmNPerThreadSubC>{});
using OutThreadCopySliceLengths =
Sequence<GemmMRepeat, GemmMPerThreadSubC, GemmNPerThreadSubC>;
auto threadwise_out_copy =
ThreadwiseGenericTensorSliceCopy_v2r1<decltype(out_k0_k1_b_thread_desc),
decltype(out_k0_k1_b_global_desc),
OutThreadCopySliceLengths,
arithmetic_sequence_gen<0, 3, 1>::type,
arithmetic_sequence_gen<0, 3, 1>::type,
2,
2,
OutThreadCopyDataPerAccess_B,
OutThreadCopyDataPerAccess_B>(
{0, 0, 0},
{k_thread_data_on_global / K1,
k_thread_data_on_global % K1,
b_thread_data_on_global});
for(index_t nrepeat = 0; nrepeat < GemmNRepeat; ++nrepeat)
{
threadwise_out_copy
.template Run<Float, address_space_t::generic, address_space_t::global>(
p_out_thread, p_out_global);
threadwise_out_copy.MoveSrcSliceWindow(Sequence<0, 0, GemmNPerThreadSubC>{}, True); // dst descriptor
threadwise_out_copy.MoveDstSliceWindow(Sequence<0, 0, B1>{}, True); constexpr index_t K1 = GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster;
} constexpr index_t B1 = GemmNPerThreadSubC * GemmNLevel0Cluster * GemmNLevel1Cluster;
constexpr index_t K0 = K / K1;
constexpr index_t B0 = B / B1;
constexpr auto out_k_b_global_desc = transform_tensor_descriptor(
out_n_k_ho_wo_global_desc,
make_tuple(PassThrough<K>{}, Merge<Sequence<N, Ho, Wo>>{}),
make_tuple(Sequence<1>{}, Sequence<0, 2, 3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
constexpr auto out_k0_k1_b0_b1_global_desc = transform_tensor_descriptor(
out_k_b_global_desc,
make_tuple(UnMerge<Sequence<K0, K1>>{}, UnMerge<Sequence<B0, B1>>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
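
// A scalar model of the destination descriptor chain above: a (k0, k1, b0, b1)
// coordinate is re-merged into (k, b), and b is split back into (n, ho, wo) to
// address the packed NKHW output. Helper and packed-stride assumption are for
// illustration only.
int out_offset(int k0, int k1, int b0, int b1,
               int K1, int B1, int K, int Ho, int Wo)
{
    const int k  = k0 * K1 + k1;  // undo UnMerge<K0, K1>
    const int b  = b0 * B1 + b1;  // undo UnMerge<B0, B1>
    const int n  = b / (Ho * Wo); // undo Merge<N, Ho, Wo>
    const int ho = (b / Wo) % Ho;
    const int wo = b % Wo;
    return ((n * K + k) * Ho + ho) * Wo + wo; // packed NKHW element offset
}
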
// output threadwise copy
ThreadwiseGenericTensorSliceCopy_v4r2<
decltype(out_k0_k1_b0_b1_thread_desc),
decltype(out_k0_k1_b0_b1_global_desc),
decltype(out_k0_k1_b0_b1_thread_desc.GetLengths()),
arithmetic_sequence_gen<0, 4, 1>::type,
3,
OutThreadCopyDataPerAccess_B,
OutThreadCopyDataPerAccess_B>({0, 0, 0, 0},
{k_thread_data_on_global / K1,
k_thread_data_on_global % K1,
b_thread_data_on_global / B1,
b_thread_data_on_global % B1})
#if 1
.Run(p_out_thread, p_out_global, generic_address_space, global_address_space);
#else // tweaking
.Run_optimized_dst_address_calculation(
p_out_thread, p_out_global, generic_address_space, global_address_space);
#endif
} }
} }
}; };
......
#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R4_NCHW_KCYX_NKHW_PADDED_LDS_DOUBLE_BUFFER_HPP #ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R4_NCHW_KCYX_NKHW_HPP_LDS_DOUBLE_BUFFER_DEPRECATED_HPP
#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R4_NCHW_KCYX_NKHW_PADDED_LDS_DOUBLE_BUFFER_HPP #define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R4_NCHW_KCYX_NKHW_HPP_LDS_DOUBLE_BUFFER_DEPRECATED_HPP
#include "common_header.hpp" #include "common_header.hpp"
#include "tensor_descriptor.hpp" #include "ConstantTensorDescriptor_deprecated.hpp"
#include "tensor_descriptor_helper.hpp" #include "ConstantMergedTensorDescriptor_deprecated.hpp"
#include "ConstantMatrixDescriptor.hpp" #include "ConstantMatrixDescriptor.hpp"
#include "blockwise_generic_tensor_slice_copy.hpp" #include "blockwise_generic_tensor_slice_copy_deprecated.hpp"
#include "threadwise_generic_tensor_slice_copy.hpp"
#include "blockwise_gemm.hpp" #include "blockwise_gemm.hpp"
#include "threadwise_generic_tensor_slice_copy_deprecated.hpp"
namespace ck { namespace ck {
// B = merge(N, Ho, Wo) // B = merge(N, Ho, Wo)
template <index_t GridSize, template <index_t GridSize,
index_t BlockSize, index_t BlockSize,
typename Float, class Float,
typename InGlobalDesc, class InGlobalDesc,
typename WeiGlobalDesc, class WeiGlobalDesc,
typename OutGlobalDesc, class OutGlobalDesc,
typename ConvStrides, class ConvStrides,
typename ConvDilations, class ConvDilations,
typename LeftPads,
typename RightPads,
index_t BPerBlock, index_t BPerBlock,
index_t KPerBlock, index_t KPerBlock,
index_t EPerBlock, index_t EPerBlock,
...@@ -34,21 +32,21 @@ template <index_t GridSize, ...@@ -34,21 +32,21 @@ template <index_t GridSize,
index_t GemmKPerThreadLoop, index_t GemmKPerThreadLoop,
index_t GemmDataPerReadA, index_t GemmDataPerReadA,
index_t GemmDataPerReadB, index_t GemmDataPerReadB,
typename InBlockCopySubLengths_E_B, class InBlockCopySubLengths_E_B,
typename InBlockCopyClusterLengths_E_B, class InBlockCopyClusterLengths_E_B,
typename InBlockCopyThreadClusterArrangeOrder, class InBlockCopyThreadClusterArrangeOrder,
typename InBlockCopySrcAccessOrder, class InBlockCopySrcAccessOrder,
typename InBlockCopyDstAccessOrder, class InBlockCopyDstAccessOrder,
index_t InBlockCopyDataPerAccess_B, index_t InBlockCopyDataPerAccess_B,
typename WeiBlockCopySubLengths_E_K, class WeiBlockCopySubLengths_E_K,
typename WeiBlockCopyClusterLengths_E_K, class WeiBlockCopyClusterLengths_E_K,
typename WeiBlockCopyThreadClusterArrangeOrder, class WeiBlockCopyThreadClusterArrangeOrder,
typename WeiBlockCopySrcAccessOrder, class WeiBlockCopySrcAccessOrder,
typename WeiBlockCopyDstAccessOrder, class WeiBlockCopyDstAccessOrder,
index_t WeiBlockCopySrcDataPerRead_E, index_t WeiBlockCopySrcDataPerRead_E,
index_t WeiBlockCopyDstDataPerWrite_K, index_t WeiBlockCopyDstDataPerWrite_K,
index_t OutThreadCopyDataPerAccess_B> index_t OutThreadCopyDataPerAccess_B>
struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_padded_lds_double_buffer struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer_deprecated
{ {
__device__ void Run(const Float* const __restrict__ p_in_global, __device__ void Run(const Float* const __restrict__ p_in_global,
const Float* const __restrict__ p_wei_global, const Float* const __restrict__ p_wei_global,
...@@ -58,27 +56,23 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_padded_lds_double_buf ...@@ -58,27 +56,23 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_padded_lds_double_buf
constexpr auto I1 = Number<1>{}; constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{}; constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{}; constexpr auto I3 = Number<3>{};
constexpr auto I5 = Number<5>{};
constexpr auto True = integral_constant<bool, true>{}; constexpr auto True = integral_constant<bool, true>{};
constexpr auto in_n_c_hi_wi_global_desc = constexpr auto in_n_c_h_w_global_desc = InGlobalDesc{};
make_native_tensor_descriptor(InGlobalDesc::GetLengths(), InGlobalDesc::GetStrides()); constexpr auto wei_k_c_y_x_global_desc = WeiGlobalDesc{};
constexpr auto wei_k_c_y_x_global_desc = constexpr auto out_n_k_h_w_global_desc = OutGlobalDesc{};
make_native_tensor_descriptor(WeiGlobalDesc::GetLengths(), WeiGlobalDesc::GetStrides());
constexpr auto out_n_k_ho_wo_global_desc =
make_native_tensor_descriptor(OutGlobalDesc::GetLengths(), OutGlobalDesc::GetStrides());
constexpr index_t N = in_n_c_hi_wi_global_desc.GetLength(I0); constexpr index_t N = in_n_c_h_w_global_desc.GetLengths()[0];
constexpr index_t C = in_n_c_hi_wi_global_desc.GetLength(I1); constexpr index_t C = in_n_c_h_w_global_desc.GetLengths()[1];
constexpr index_t Hi = in_n_c_hi_wi_global_desc.GetLength(I2);
constexpr index_t Wi = in_n_c_hi_wi_global_desc.GetLength(I3);
constexpr index_t K = out_n_k_ho_wo_global_desc.GetLength(I1); constexpr index_t K = out_n_k_h_w_global_desc.GetLengths()[1];
constexpr index_t Ho = out_n_k_ho_wo_global_desc.GetLength(I2); constexpr index_t Ho = out_n_k_h_w_global_desc.GetLengths()[2];
constexpr index_t Wo = out_n_k_ho_wo_global_desc.GetLength(I3); constexpr index_t Wo = out_n_k_h_w_global_desc.GetLengths()[3];
constexpr index_t Y = wei_k_c_y_x_global_desc.GetLength(I2); constexpr index_t Y = wei_k_c_y_x_global_desc.GetLengths()[2];
constexpr index_t X = wei_k_c_y_x_global_desc.GetLength(I3); constexpr index_t X = wei_k_c_y_x_global_desc.GetLengths()[3];
constexpr index_t ConvStrideH = ConvStrides{}[0]; constexpr index_t ConvStrideH = ConvStrides{}[0];
constexpr index_t ConvStrideW = ConvStrides{}[1]; constexpr index_t ConvStrideW = ConvStrides{}[1];
...@@ -103,67 +97,65 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_padded_lds_double_buf ...@@ -103,67 +97,65 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_padded_lds_double_buf
constexpr index_t BBlockWork = B / BPerBlock; constexpr index_t BBlockWork = B / BPerBlock;
constexpr auto block_work_desc = constexpr auto block_work_desc =
make_cluster_descriptor(Sequence<KBlockWork, BBlockWork>{}); make_ConstantTensorDescriptor_packed(Sequence<KBlockWork, BBlockWork>{});
const auto block_work_id = block_work_desc.CalculateClusterIndex(get_block_1d_id()); const auto block_work_multi_id =
block_work_desc.GetMultiIndexFrom1dIndex(get_block_1d_id());
const index_t k_block_data_on_global = block_work_id[0] * KPerBlock; const index_t k_block_data_on_global = block_work_multi_id[0] * KPerBlock;
const index_t b_block_data_on_global = block_work_id[1] * BPerBlock; const index_t b_block_data_on_global = block_work_multi_id[1] * BPerBlock;
// input tensor // input tensor
// global mem // tensor descriptor in device memory [N, Ho, Wo]
constexpr auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor( constexpr auto in_n_ho_wo_global_desc =
in_n_c_hi_wi_global_desc, in_n_c_h_w_global_desc.Extract(I0, I2, I3)
make_tuple( .StridedSlice(I1, Number<Ho>{}, Number<ConvStrideH>{})
PassThrough<N>{}, PassThrough<C>{}, Pad<Sequence<Hi, Wi>, LeftPads, RightPads>{}), .StridedSlice(I2, Number<Wo>{}, Number<ConvStrideW>{});
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{})); // batch descriptor for device memory
constexpr auto in_c_y_x_global_desc =
constexpr auto in_n_c_y_ho_x_wo_global_desc = transform_tensor_descriptor( in_n_c_h_w_global_desc.StridedSlice(I2, Number<Y>{}, Number<ConvDilationH>{})
in_n_c_hip_wip_global_desc, .StridedSlice(I3, Number<X>{}, Number<ConvDilationW>{})
make_tuple(PassThrough<N>{}, .Extract(Sequence<1, 2, 3>{});
PassThrough<C>{},
Embed<Sequence<Y, Ho>, Sequence<ConvDilationH, ConvStrideH, 0>>{}, // merged tensor descriptor in device memory [E, B], src of blockwise copy
Embed<Sequence<X, Wo>, Sequence<ConvDilationW, ConvStrideW, 0>>{}), constexpr auto in_e_b_global_desc =
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_ConstantMergedTensorDescriptor(in_c_y_x_global_desc.Embed(in_n_ho_wo_global_desc),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{})); Sequence<0, 1, 2>{},
Sequence<3, 4, 5>{});
constexpr auto in_e_b_global_desc = transform_tensor_descriptor(
in_n_c_y_ho_x_wo_global_desc, // memory layout descriptor in LDS [E, B], dst of blockwise copy
make_tuple(Merge<Sequence<C, Y, X>>{}, Merge<Sequence<N, Ho, Wo>>{}),
make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
// LDS mem
// be careful of LDS alignment // be careful of LDS alignment
constexpr auto in_e_b_block_desc = constexpr auto in_e_b_block_desc =
make_native_tensor_descriptor_packed(Sequence<EPerBlock, BPerBlock>{}); make_ConstantTensorDescriptor_packed(Sequence<EPerBlock, BPerBlock>{});
// input blockwise copy // input blockwise copy
// slice a merged tensor, reorder and copy to a normal tensor
// this copy operator already has blockwise offset built-in
auto blockwise_in_copy = auto blockwise_in_copy =
BlockwiseGenericTensorSliceCopy_v4<BlockSize, BlockwiseGenericTensorSliceCopy_v2_deprecated<BlockSize,
decltype(in_e_b_global_desc), decltype(in_e_b_global_desc),
decltype(in_e_b_block_desc), decltype(in_e_b_block_desc),
decltype(in_e_b_block_desc.GetLengths()), decltype(in_e_b_block_desc.GetLengths()),
InBlockCopySubLengths_E_B, InBlockCopySubLengths_E_B,
InBlockCopyClusterLengths_E_B, InBlockCopyClusterLengths_E_B,
InBlockCopyThreadClusterArrangeOrder, InBlockCopyThreadClusterArrangeOrder,
InBlockCopySrcAccessOrder, InBlockCopySrcAccessOrder,
InBlockCopyDstAccessOrder, InBlockCopyDstAccessOrder,
1, 1,
1, 1,
InBlockCopyDataPerAccess_B, InBlockCopyDataPerAccess_B,
InBlockCopyDataPerAccess_B>( InBlockCopyDataPerAccess_B>(
{0, b_block_data_on_global}, {0, 0}); {0, b_block_data_on_global}, {0, 0});
// weight tensor // weight tensor
// global mem // tensor descriptor in device memory, src of blockwise copy
constexpr auto wei_e_k_global_desc = reorder_tensor_descriptor_given_upper2lower( constexpr auto wei_e_k_global_desc =
unfold_tensor_descriptor(wei_k_c_y_x_global_desc, I1, I3), Sequence<1, 0>{}); wei_k_c_y_x_global_desc.Unfold(I1, I3).ReorderGivenNew2Old(Sequence<1, 0>{});
// LDS // tensor descriptor in LDS, dst of blockwise copy
// be careful of LDS alignment // be careful of LDS alignment
constexpr auto wei_e_k_block_desc = make_native_tensor_descriptor_aligned( constexpr auto wei_e_k_block_desc = make_ConstantTensorDescriptor_aligned(
Sequence<EPerBlock, KPerBlock>{}, Sequence<EPerBlock, KPerBlock>{},
Number<math::lcm(WeiBlockCopyDstDataPerWrite_K, GemmDataPerReadA)>{}); Number<math::lcm(WeiBlockCopyDstDataPerWrite_K, GemmDataPerReadA)>{});
...@@ -173,21 +165,23 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_padded_lds_double_buf ...@@ -173,21 +165,23 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_padded_lds_double_buf
static_assert(wei_e_k_block_desc.GetStride(I0) % GemmDataPerReadA == 0, static_assert(wei_e_k_block_desc.GetStride(I0) % GemmDataPerReadA == 0,
"GemmDataPerReadA alignment requirement is not satisfied"); "GemmDataPerReadA alignment requirement is not satisfied");
// weight blockwise copy // operator for blockwise copy of weight into LDS
// slice a tensor, and copy it into another tensor
// this copy operator already has blockwise offset built-in
auto blockwise_wei_copy = auto blockwise_wei_copy =
BlockwiseGenericTensorSliceCopy_v4<BlockSize, BlockwiseGenericTensorSliceCopy_v2_deprecated<BlockSize,
decltype(wei_e_k_global_desc), decltype(wei_e_k_global_desc),
decltype(wei_e_k_block_desc), decltype(wei_e_k_block_desc),
decltype(wei_e_k_block_desc.GetLengths()), decltype(wei_e_k_block_desc.GetLengths()),
WeiBlockCopySubLengths_E_K, WeiBlockCopySubLengths_E_K,
WeiBlockCopyClusterLengths_E_K, WeiBlockCopyClusterLengths_E_K,
WeiBlockCopyThreadClusterArrangeOrder, WeiBlockCopyThreadClusterArrangeOrder,
WeiBlockCopySrcAccessOrder, WeiBlockCopySrcAccessOrder,
WeiBlockCopyDstAccessOrder, WeiBlockCopyDstAccessOrder,
0, 0,
1, 1,
WeiBlockCopySrcDataPerRead_E, WeiBlockCopySrcDataPerRead_E,
WeiBlockCopyDstDataPerWrite_K>( WeiBlockCopyDstDataPerWrite_K>(
{0, k_block_data_on_global}, {0, 0}); {0, k_block_data_on_global}, {0, 0});
// GEMM definition // GEMM definition
...@@ -253,12 +247,14 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_padded_lds_double_buf ...@@ -253,12 +247,14 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_padded_lds_double_buf
// zero out threadwise output // zero out threadwise output
threadwise_matrix_set_zero(c_k0k1_b0b1_thread_mtx_desc, p_out_thread); threadwise_matrix_set_zero(c_k0k1_b0b1_thread_mtx_desc, p_out_thread);
const Float* p_wei_block_on_global = p_wei_global;
// LDS double buffer: preload data into LDS // LDS double buffer: preload data into LDS
{ {
blockwise_in_copy.template Run<Float, Float, address_space_t::global>( blockwise_in_copy.template Run<Float, AddressSpace::global>(p_in_global,
p_in_global, p_in_block_double); p_in_block_double);
blockwise_wei_copy.template Run<Float, Float, address_space_t::global>( blockwise_wei_copy.template Run<Float, AddressSpace::global>(p_wei_global,
p_wei_global, p_wei_block_double); p_wei_block_double);
} }
// LDS double buffer: main body // LDS double buffer: main body
...@@ -289,12 +285,10 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_padded_lds_double_buf ...@@ -289,12 +285,10 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_padded_lds_double_buf
__syncthreads(); __syncthreads();
// LDS double buffer: load next data from device mem // LDS double buffer: load next data from device mem
blockwise_in_copy blockwise_in_copy.template RunLoadThreadBuffer<Float, AddressSpace::global>(
.template RunLoadThreadBuffer<Float, Float, address_space_t::global>( p_in_global, p_in_thread_buffer);
p_in_global, p_in_thread_buffer); blockwise_wei_copy.template RunLoadThreadBuffer<Float, AddressSpace::global>(
blockwise_wei_copy p_wei_global, p_wei_thread_buffer);
.template RunLoadThreadBuffer<Float, Float, address_space_t::global>(
p_wei_global, p_wei_thread_buffer);
// LDS double buffer: GEMM on current data // LDS double buffer: GEMM on current data
blockwise_gemm.Run(p_wei_block_now, p_in_block_now, p_out_thread); blockwise_gemm.Run(p_wei_block_now, p_in_block_now, p_out_thread);
...@@ -317,9 +311,9 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_padded_lds_double_buf ...@@ -317,9 +311,9 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_padded_lds_double_buf
__syncthreads(); __syncthreads();
// LDS double buffer: load next data from device mem // LDS double buffer: load next data from device mem
blockwise_in_copy.template RunLoadThreadBuffer<Float, Float, address_space_t::global>( blockwise_in_copy.template RunLoadThreadBuffer<Float, AddressSpace::global>(
p_in_global, p_in_thread_buffer); p_in_global, p_in_thread_buffer);
blockwise_wei_copy.template RunLoadThreadBuffer<Float, Float, address_space_t::global>( blockwise_wei_copy.template RunLoadThreadBuffer<Float, AddressSpace::global>(
p_wei_global, p_wei_thread_buffer); p_wei_global, p_wei_thread_buffer);
// LDS double buffer: GEMM on current data // LDS double buffer: GEMM on current data
...@@ -342,6 +336,15 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_padded_lds_double_buf ...@@ -342,6 +336,15 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_padded_lds_double_buf
// copy output: register to global memory // copy output: register to global memory
{ {
constexpr index_t K1 = GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster;
constexpr index_t B1 = GemmNPerThreadSubC * GemmNLevel0Cluster * GemmNLevel1Cluster;
// define tensor descriptor for threadwise copy
// output global descriptor, for calculating origin of thread tensor
// in global memory
constexpr auto out_k_b_global_desc = make_ConstantMergedTensorDescriptor(
out_n_k_h_w_global_desc, Sequence<1>{}, Sequence<0, 2, 3>{});
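// Editorial note (hedged, not part of this diff): merging dimensions {0, 2, 3}
// of the N-K-H-W output descriptor yields a 2-D (K, B) view with B = N * Ho * Wo.
// Assuming the usual row-major merge order, a merged index b maps back to
// (n, ho, wo) as
//     b = (n * Ho + ho) * Wo + wo
// which is the "B" dimension the blockwise GEMM iterates over for each K.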
// calculate origin of thread output tensor on global memory // calculate origin of thread output tensor on global memory
// blockwise GEMM c matrix starting index // blockwise GEMM c matrix starting index
const auto c_thread_mtx_on_block = const auto c_thread_mtx_on_block =
...@@ -353,51 +356,46 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_padded_lds_double_buf ...@@ -353,51 +356,46 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_padded_lds_double_buf
const index_t b_thread_data_on_global = const index_t b_thread_data_on_global =
b_block_data_on_global + c_thread_mtx_on_block.col; b_block_data_on_global + c_thread_mtx_on_block.col;
// src descriptor // This is a hack, because slicing a merged dimension is not supported yet.
constexpr auto out_k0_k1_b0_b1_thread_desc = make_native_tensor_descriptor_packed( // This should be replaced with the logic above, once support for slicing
Sequence<GemmMRepeat, GemmMPerThreadSubC, GemmNRepeat, GemmNPerThreadSubC>{}); // a merged dimension becomes available
// dst descriptor
// dst descriptor constexpr auto out_k0_k1_b_global_desc =
constexpr index_t K1 = GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster; make_ConstantMergedTensorDescriptor(out_n_k_h_w_global_desc.Fold(I1, Number<K1>{}),
constexpr index_t B1 = GemmNPerThreadSubC * GemmNLevel0Cluster * GemmNLevel1Cluster; Sequence<1>{},
Sequence<2>{},
constexpr index_t K0 = K / K1; Sequence<0, 3, 4>{});
constexpr index_t B0 = B / B1;
// src descriptor
constexpr auto out_k_b_global_desc = transform_tensor_descriptor( constexpr auto out_k0_k1_b_thread_desc = make_ConstantTensorDescriptor_packed(
out_n_k_ho_wo_global_desc, Sequence<GemmMRepeat, GemmMPerThreadSubC, GemmNRepeat * GemmNPerThreadSubC>{});
make_tuple(PassThrough<K>{}, Merge<Sequence<N, Ho, Wo>>{}),
make_tuple(Sequence<1>{}, Sequence<0, 2, 3>{}), using OutThreadCopySliceLengths =
make_tuple(Sequence<0>{}, Sequence<1>{})); Sequence<GemmMRepeat, GemmMPerThreadSubC, GemmNPerThreadSubC>;
constexpr auto out_k0_k1_b0_b1_global_desc = transform_tensor_descriptor( auto threadwise_out_copy = ThreadwiseGenericTensorSliceCopy_v2r1_deprecated<
out_k_b_global_desc, decltype(out_k0_k1_b_thread_desc),
make_tuple(UnMerge<Sequence<K0, K1>>{}, UnMerge<Sequence<B0, B1>>{}), decltype(out_k0_k1_b_global_desc),
make_tuple(Sequence<0>{}, Sequence<1>{}), OutThreadCopySliceLengths,
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); arithmetic_sequence_gen<0, 3, 1>::type,
arithmetic_sequence_gen<0, 3, 1>::type,
// output threadwise copy 2,
ThreadwiseGenericTensorSliceCopy_v4r2< 2,
decltype(out_k0_k1_b0_b1_thread_desc),
decltype(out_k0_k1_b0_b1_global_desc),
decltype(out_k0_k1_b0_b1_thread_desc.GetLengths()),
arithmetic_sequence_gen<0, 4, 1>::type,
3,
OutThreadCopyDataPerAccess_B, OutThreadCopyDataPerAccess_B,
OutThreadCopyDataPerAccess_B>({0, 0, 0, 0}, OutThreadCopyDataPerAccess_B>({0, 0, 0},
{k_thread_data_on_global / K1, {k_thread_data_on_global / K1,
k_thread_data_on_global % K1, k_thread_data_on_global % K1,
b_thread_data_on_global / B1, b_thread_data_on_global});
b_thread_data_on_global % B1})
#if 1 for(index_t nrepeat = 0; nrepeat < GemmNRepeat; ++nrepeat)
.template Run<Float, Float, address_space_t::generic, address_space_t::global> {
#else // tweaking threadwise_out_copy
.template Run_optimized_dst_address_calculation<Float, .template Run<Float, AddressSpace::generic, AddressSpace::global>(p_out_thread,
Float, p_out_global);
address_space_t::generic,
address_space_t::global> threadwise_out_copy.MoveSrcSliceWindow(Sequence<0, 0, GemmNPerThreadSubC>{}, True);
#endif threadwise_out_copy.MoveDstSliceWindow(Sequence<0, 0, B1>{}, True);
(p_out_thread, p_out_global); }
} }
} }
}; };
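// Editorial sketch (hedged, not part of this diff): host-side arithmetic showing
// why the output copy uses a 3-D {k0, k1, b} origin and per-repeat window moves.
// K1 and B1 follow the definitions above; all concrete numbers are illustrative.
#include <cstdio>

int main()
{
    // illustrative tuning parameters, not taken from any particular config
    const int GemmMPerThreadSubC = 4, GemmMLevel0Cluster = 4, GemmMLevel1Cluster = 4;
    const int GemmNPerThreadSubC = 4, GemmNLevel0Cluster = 4, GemmNLevel1Cluster = 4;
    const int GemmNRepeat = 2;

    const int K1 = GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster; // 64
    const int B1 = GemmNPerThreadSubC * GemmNLevel0Cluster * GemmNLevel1Cluster; // 64

    const int k_thread_data_on_global = 100; // example thread coordinate
    const int b_thread_data_on_global = 7;   // example thread coordinate

    // 3-D origin handed to the threadwise copy: {k0, k1, b}
    std::printf("origin = {%d, %d, %d}\n",
                k_thread_data_on_global / K1, // k0 = 1
                k_thread_data_on_global % K1, // k1 = 36
                b_thread_data_on_global);     // b  = 7

    // the dst window advances by B1 per repeat, so repeat n writes the slice
    // starting at column b + n * B1 of the merged B dimension
    for(int nrepeat = 0; nrepeat < GemmNRepeat; ++nrepeat)
    {
        std::printf("repeat %d starts at b = %d\n",
                    nrepeat, b_thread_data_on_global + nrepeat * B1);
    }

    return 0;
}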
......