"vscode:/vscode.git/clone" did not exist on "fd5778695469bfe8bbdebc07f18a30742c36575e"
Unverified commit 81c942cd, authored by Chao Liu, committed by GitHub

Deprecate static kernel (#42)

* deprecate static kernels
parent b8b2d0a6
@@ -38,8 +36,6 @@ link_libraries(${OpenMP_pthread_LIBRARY})
 #GPU backend
 if(DEVICE_BACKEND STREQUAL "AMD")
     find_package(HIP REQUIRED)
-elseif(DEVICE_BACKEND STREQUAL "NVIDIA")
-    enable_language(CUDA)
 endif()
 #
@@ -64,13 +62,7 @@ endif()
 if(DEVICE_BACKEND STREQUAL "AMD")
     configure_file("${PROJECT_SOURCE_DIR}/composable_kernel/include/utility/config.amd.hpp.in" "${PROJECT_BINARY_DIR}/composable_kernel/include/utility/config.hpp")
     configure_file("${PROJECT_SOURCE_DIR}/composable_kernel/include/utility/float_type.amd.hpp.in" "${PROJECT_BINARY_DIR}/composable_kernel/include/utility/float_type.hpp")
-    configure_file("${PROJECT_SOURCE_DIR}/composable_kernel/include/utility/in_memory_operation.amd.hpp.in" "${PROJECT_BINARY_DIR}/composable_kernel/include/utility/in_memory_operation.hpp")
     configure_file("${PROJECT_SOURCE_DIR}/composable_kernel/include/utility/synchronization.amd.hpp.in" "${PROJECT_BINARY_DIR}/composable_kernel/include/utility/synchronization.hpp")
-elseif(DEVICE_BACKEND STREQUAL "NVIDIA")
-    configure_file("${PROJECT_SOURCE_DIR}/composable_kernel/include/utility/config.nvidia.hpp.in" "${PROJECT_BINARY_DIR}/composable_kernel/include/utility/config.hpp")
-    configure_file("${PROJECT_SOURCE_DIR}/composable_kernel/include/utility/float_type.nvidia.hpp.in" "${PROJECT_BINARY_DIR}/composable_kernel/include/utility/float_type.hpp")
-    configure_file("${PROJECT_SOURCE_DIR}/composable_kernel/include/utility/in_memory_operation.nvidia.hpp.in" "${PROJECT_BINARY_DIR}/composable_kernel/include/utility/in_memory_operation.hpp")
-    configure_file("${PROJECT_SOURCE_DIR}/composable_kernel/include/utility/synchronization.nvidia.hpp.in" "${PROJECT_BINARY_DIR}/composable_kernel/include/utility/synchronization.hpp")
 endif()
 add_subdirectory(driver)
@@ -80,26 +72,17 @@ set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}")
 message("Compiling options for drivers: ${CMAKE_CXX_FLAGS}")
 if(DEVICE_BACKEND STREQUAL "AMD")
-    set(CONV_SOURCE driver/conv_driver.cpp)
-    set(CONV_BWD_DATA_SOURCE driver/conv_bwd_data_driver.cpp)
     set(CONV_V2_SOURCE driver/conv_driver_v2.cpp)
     set(CONV_BWD_DATA_V2_SOURCE driver/conv_bwd_data_driver_v2.cpp)
     set(CONV_V2_OLC_SOURCE driver/conv_driver_v2_olc.cpp)
-elseif(DEVICE_BACKEND STREQUAL "NVIDIA")
-    set(CONV_SOURCE driver/conv_driver.cu)
-    set(CONV_BWD_DATA_SOURCE driver/conv_bwd_data_driver.cu)
 endif()
-add_executable(conv_driver ${CONV_SOURCE})
-add_executable(conv_bwd_data_driver ${CONV_BWD_DATA_SOURCE})
 add_executable(conv_driver_v2 ${CONV_V2_SOURCE})
 add_executable(conv_bwd_data_driver_v2 ${CONV_BWD_DATA_V2_SOURCE})
 add_executable(conv_driver_v2_olc ${CONV_V2_OLC_SOURCE})
 target_include_directories(conv_driver_v2_olc PRIVATE driver/olCompiling/include/)
-target_link_libraries(conv_driver PRIVATE modConv)
-target_link_libraries(conv_bwd_data_driver PRIVATE modConv)
 target_link_libraries(conv_driver_v2 PRIVATE modConv)
 target_link_libraries(conv_bwd_data_driver_v2 PRIVATE modConv)
 target_link_libraries(conv_driver_v2_olc PRIVATE modConv)
#ifndef CK_GRIDWISE_CONVOLUTION_BACKWARD_DATA_IMPLICIT_GEMM_V1R1_NCHW_KCYX_NKHW_HPP
#define CK_GRIDWISE_CONVOLUTION_BACKWARD_DATA_IMPLICIT_GEMM_V1R1_NCHW_KCYX_NKHW_HPP
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm.hpp"
namespace ck {
// GemmM = C * Y * X
// GemmN = N * Ho * Wo
// GemmK = K
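//
// Illustrative sizing (hypothetical layer, not taken from this repo): for N = 128,
// C = 192, K = 256, a 3x3 filter (Y = X = 3) and Ho = Wo = 28, the backward-data
// problem becomes a single GEMM with GemmM = 192 * 3 * 3 = 1728,
// GemmN = 128 * 28 * 28 = 100352 and GemmK = 256, where A = weight (GemmK x GemmM),
// B = output (GemmK x GemmN) and C = input (GemmM x GemmN), matching the descriptors
// built in Run() below.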
template <index_t GridSize,
index_t BlockSize,
typename Float,
typename AccFloat,
typename InGlobalDesc,
typename WeiGlobalDesc,
typename OutGlobalDesc,
typename ConvStrides,
typename ConvDilations,
typename InLeftPads,
typename InRightPads,
index_t GemmMPerBlock,
index_t GemmNPerBlock,
index_t GemmKPerBlock,
index_t GemmMPerThread,
index_t GemmNPerThread,
index_t GemmKPerThread,
index_t GemmMLevel0Cluster,
index_t GemmNLevel0Cluster,
index_t GemmMLevel1Cluster,
index_t GemmNLevel1Cluster,
index_t ThreadGemmDataPerRead_GemmM,
index_t ThreadGemmDataPerRead_GemmN,
typename GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
typename GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
index_t GemmABlockCopySrcDataPerRead_GemmM,
index_t GemmABlockCopyDstDataPerWrite_GemmM,
typename GemmBBlockCopyThreadSliceLengths_GemmK_GemmN,
typename GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
index_t GemmBBlockCopySrcDataPerRead_GemmN,
index_t GemmBBlockCopyDstDataPerWrite_GemmN,
index_t GemmCThreadCopyDstDataPerWrite_GemmN1>
struct GridwiseConvolutionBackwardDataImplicitGemm_v1r1_nchw_kcyx_nkhw
{
__device__ void Run(Float* __restrict__ p_in_global,
const Float* __restrict__ p_wei_global,
const Float* __restrict__ p_out_global) const
{
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto in_n_c_hi_wi_global_desc = InGlobalDesc{};
constexpr auto wei_k_c_y_x_global_desc = WeiGlobalDesc{};
constexpr auto out_n_k_ho_wo_global_desc = OutGlobalDesc{};
constexpr index_t N = in_n_c_hi_wi_global_desc.GetLengths()[0];
constexpr index_t C = in_n_c_hi_wi_global_desc.GetLengths()[1];
constexpr index_t Hi = in_n_c_hi_wi_global_desc.GetLengths()[2];
constexpr index_t Wi = in_n_c_hi_wi_global_desc.GetLengths()[3];
constexpr index_t K = out_n_k_ho_wo_global_desc.GetLengths()[1];
constexpr index_t Ho = out_n_k_ho_wo_global_desc.GetLengths()[2];
constexpr index_t Wo = out_n_k_ho_wo_global_desc.GetLengths()[3];
constexpr index_t Y = wei_k_c_y_x_global_desc.GetLengths()[2];
constexpr index_t X = wei_k_c_y_x_global_desc.GetLengths()[3];
constexpr index_t ConvStrideH = ConvStrides{}[0];
constexpr index_t ConvStrideW = ConvStrides{}[1];
constexpr index_t ConvDilationH = ConvDilations{}[0];
constexpr index_t ConvDilationW = ConvDilations{}[1];
//\todo add static_assert to check alignment for global vector load/store
// static_assert(...);
// weight tensor
constexpr auto wei_gemmk_gemmm_global_desc =
unfold_tensor_descriptor(wei_k_c_y_x_global_desc, I1, I3);
// input tensor
constexpr auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor(
in_n_c_hi_wi_global_desc,
make_tuple(PassThrough<N>{},
PassThrough<C>{},
Pad<Sequence<Hi, Wi>, InLeftPads, InRightPads>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}));
constexpr index_t Hip = in_n_c_hip_wip_global_desc.GetLengths()[2];
constexpr index_t Wip = in_n_c_hip_wip_global_desc.GetLengths()[3];
constexpr auto in_n_c_y_ho_x_wo_global_desc = transform_tensor_descriptor(
in_n_c_hip_wip_global_desc,
make_tuple(PassThrough<N>{},
PassThrough<C>{},
Embed<Hip, Sequence<Y, Ho>, Sequence<ConvDilationH, ConvStrideH, 0>>{},
Embed<Wip, Sequence<X, Wo>, Sequence<ConvDilationW, ConvStrideW, 0>>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
constexpr auto in_gemmm_gemmn_global_desc = transform_tensor_descriptor(
in_n_c_y_ho_x_wo_global_desc,
make_tuple(Merge<Sequence<C, Y, X>>{}, Merge<Sequence<N, Ho, Wo>>{}),
make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
// output tensor
constexpr auto out_gemmk_gemmn_global_desc =
transform_tensor_descriptor(unfold_tensor_descriptor(out_n_k_ho_wo_global_desc, I2, I3),
make_tuple(PassThrough<K>{}, Merge<Sequence<N, Ho * Wo>>{}),
make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
// GEMM
// \todo there are more combinations of Y, ConvDilationH and ConvStrideH that don't need
// atomic, find out all of them
constexpr bool not_need_atomic = (ConvStrideH >= ConvDilationH * (Y - 1) + 1) and
(ConvStrideW >= ConvDilationW * (X - 1) + 1);
constexpr auto in_memory_op =
not_need_atomic ? InMemoryDataOperation::Set : InMemoryDataOperation::AtomicAdd;
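// Example of why the check matters (hypothetical 3x3 filter, stride 1, dilation 1):
// the Embed transform below maps (y, ho) to the padded input row
// ConvDilationH * y + ConvStrideH * ho, so (y, ho) = (0, 2), (1, 1) and (2, 0) all land
// on row 2 and their partial results must be accumulated with AtomicAdd. Once
// ConvStrideH >= ConvDilationH * (Y - 1) + 1 (e.g. stride 3 for a 3x3 filter with
// dilation 1), each input element is written by at most one (y, ho) pair and Set suffices.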
constexpr auto gridwise_gemm =
GridwiseGemmTransposedANormalBNormalC_v1<GridSize,
BlockSize,
Float,
AccFloat,
decltype(wei_gemmk_gemmm_global_desc),
decltype(out_gemmk_gemmn_global_desc),
decltype(in_gemmm_gemmn_global_desc),
in_memory_op,
GemmMPerBlock,
GemmNPerBlock,
GemmKPerBlock,
GemmMPerThread,
GemmNPerThread,
GemmKPerThread,
GemmMLevel0Cluster,
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
ThreadGemmDataPerRead_GemmM,
ThreadGemmDataPerRead_GemmN,
GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
Sequence<0, 1>,
Sequence<0, 1>,
1,
GemmABlockCopySrcDataPerRead_GemmM,
GemmABlockCopyDstDataPerWrite_GemmM,
GemmBBlockCopyThreadSliceLengths_GemmK_GemmN,
GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
Sequence<0, 1>,
Sequence<0, 1>,
1,
GemmBBlockCopySrcDataPerRead_GemmN,
GemmBBlockCopyDstDataPerWrite_GemmN,
Sequence<0, 1, 2, 3>,
3,
GemmCThreadCopyDstDataPerWrite_GemmN1>{};
gridwise_gemm.Run(p_wei_global, p_out_global, p_in_global);
}
};
} // namespace ck
#endif
#ifndef CK_GRIDWISE_CONVOLUTION_BACKWARD_DATA_IMPLICIT_GEMM_V1R2_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER_HPP
#define CK_GRIDWISE_CONVOLUTION_BACKWARD_DATA_IMPLICIT_GEMM_V1R2_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER_HPP
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "blockwise_generic_tensor_slice_copy.hpp"
#include "threadwise_generic_tensor_slice_copy.hpp"
#include "blockwise_gemm.hpp"
namespace ck {
template <index_t GridSize,
index_t BlockSize,
typename Float,
typename AccFloat,
typename InGlobalDesc,
typename WeiGlobalDesc,
typename OutGlobalDesc,
typename ConvStrides,
typename ConvDilations,
typename LeftPads,
typename RightPads,
index_t EPerBlock,
index_t BPerBlock,
index_t KPerBlock,
index_t GemmMPerThread,
index_t GemmNPerThread,
index_t GemmKPerThread,
index_t GemmMLevel0Cluster,
index_t GemmNLevel0Cluster,
index_t GemmMLevel1Cluster,
index_t GemmNLevel1Cluster,
index_t GemmDataPerReadA,
index_t GemmDataPerReadB,
typename OutBlockCopySubLengths_K_B_N0,
typename OutBlockCopyClusterLengths_K_B_N0,
index_t OutBlockCopySrcDataPerRead_B,
index_t OutBlockCopyDstDataPerWrite_N0,
typename WeiBlockCopySubLengths_K_E_C0,
typename WeiBlockCopyClusterLengths_K_E_C0,
index_t WeiBlockCopySrcDataPerRead_E,
index_t WeiBlockCopyDstDataPerWrite_C0,
index_t InThreadCopyDstDataPerWrite_B>
struct GridwiseConvolutionBackwardDataImplicitGemm_v1r2_nchw_kcyx_nkhw_lds_double_buffer
{
__device__ void Run(Float* const __restrict__ p_in_global,
const Float* const __restrict__ p_wei_global,
const Float* const __restrict__ p_out_global) const
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto True = integral_constant<bool, true>{};
constexpr auto in_n_c_hi_wi_global_desc = InGlobalDesc{};
constexpr auto wei_k_c_y_x_global_desc = WeiGlobalDesc{};
constexpr auto out_n_k_ho_wo_global_desc = OutGlobalDesc{};
constexpr index_t N = in_n_c_hi_wi_global_desc.GetLengths()[0];
constexpr index_t C = in_n_c_hi_wi_global_desc.GetLengths()[1];
constexpr index_t Hi = in_n_c_hi_wi_global_desc.GetLengths()[2];
constexpr index_t Wi = in_n_c_hi_wi_global_desc.GetLengths()[3];
constexpr index_t K = out_n_k_ho_wo_global_desc.GetLengths()[1];
constexpr index_t Ho = out_n_k_ho_wo_global_desc.GetLengths()[2];
constexpr index_t Wo = out_n_k_ho_wo_global_desc.GetLengths()[3];
constexpr index_t Y = wei_k_c_y_x_global_desc.GetLengths()[2];
constexpr index_t X = wei_k_c_y_x_global_desc.GetLengths()[3];
constexpr index_t ConvStrideH = ConvStrides{}[0];
constexpr index_t ConvStrideW = ConvStrides{}[1];
constexpr index_t ConvDilationH = ConvDilations{}[0];
constexpr index_t ConvDilationW = ConvDilations{}[1];
constexpr index_t C0 = GemmMPerThread;
constexpr index_t N0 = GemmNPerThread;
static_assert(C % C0 == 0 && N % N0 == 0, "wrong!");
constexpr index_t C1 = C / C0;
constexpr index_t N1 = N / N0;
constexpr index_t E = C1 * Y * X;
constexpr index_t B = N1 * Ho * Wo;
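// Numeric sketch (hypothetical sizes, only to show the decomposition): with C = 256,
// N = 128, GemmMPerThread = GemmNPerThread = 4, a 3x3 filter and Ho = Wo = 14, we get
// C0 = N0 = 4, C1 = 64, N1 = 32, so E = 64 * 3 * 3 = 576 and B = 32 * 14 * 14 = 6272;
// the per-thread factors C0 and N0 become the last dimensions of the LDS tile
// descriptors below.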
// sanity-check for vectorized memory load
static_assert((Wo == 1 || (ConvStrideW == 1 || InThreadCopyDstDataPerWrite_B == 1)) &&
(X == 1 || ConvDilationW % InThreadCopyDstDataPerWrite_B == 0),
"wrong! aligment requirement for vectorized global load of input tensor will "
"be violated");
// divide block work by [K, B]
static_assert(E % EPerBlock == 0 && B % BPerBlock == 0 && K % KPerBlock == 0,
"wrong! cannot divide work evenly among block");
constexpr index_t EBlockWork = E / EPerBlock;
constexpr index_t BBlockWork = B / BPerBlock;
constexpr auto block_work_desc =
make_cluster_descriptor(Sequence<EBlockWork, BBlockWork>{});
const auto block_work_id = block_work_desc.CalculateClusterIndex(get_block_1d_id());
const index_t e_block_data_on_global = block_work_id[Number<0>{}] * EPerBlock;
const index_t b_block_data_on_global = block_work_id[Number<1>{}] * BPerBlock;
// output tensor
// global tensor in global memory, src of blockwise copy
constexpr auto out_n_k_howo_global_desc =
unfold_tensor_descriptor(out_n_k_ho_wo_global_desc, I2, I3);
constexpr auto out_n0_n1_k_howo_global_desc = transform_tensor_descriptor(
out_n_k_howo_global_desc,
make_tuple(UnMerge<Sequence<N0, N1>>{}, PassThrough<K>{}, PassThrough<Ho * Wo>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{}));
constexpr auto out_k_b_n0_global_desc = transform_tensor_descriptor(
out_n0_n1_k_howo_global_desc,
make_tuple(PassThrough<K>{}, Merge<Sequence<N1, Ho * Wo>>{}, PassThrough<N0>{}),
make_tuple(Sequence<2>{}, Sequence<1, 3>{}, Sequence<0>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
// block tensor in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto out_k_b_n0_block_desc = make_native_tensor_descriptor_aligned(
Sequence<KPerBlock, BPerBlock, N0>{}, Number<OutBlockCopyDstDataPerWrite_N0>{});
// output tensor blockwise copy
auto blockwise_out_copy =
BlockwiseGenericTensorSliceCopy_v4<BlockSize,
decltype(out_k_b_n0_global_desc),
decltype(out_k_b_n0_block_desc),
decltype(out_k_b_n0_block_desc.GetLengths()),
OutBlockCopySubLengths_K_B_N0,
OutBlockCopyClusterLengths_K_B_N0,
Sequence<0, 1, 2>,
Sequence<0, 1, 2>,
Sequence<0, 1, 2>,
1,
2,
OutBlockCopySrcDataPerRead_B,
OutBlockCopyDstDataPerWrite_N0,
AddressSpace::Global,
AddressSpace::Vgpr,
AddressSpace::Lds,
InMemoryDataOperation::Set>(
make_multi_index(0, b_block_data_on_global, 0), make_multi_index(0, 0, 0));
// weight tensor
// global tensor in global memory, src of blockwise copy
constexpr auto wei_k_cyx_global_desc =
unfold_tensor_descriptor(wei_k_c_y_x_global_desc, I1, I3);
constexpr auto wei_k_c0_e_global_desc =
transform_tensor_descriptor(wei_k_cyx_global_desc,
make_tuple(PassThrough<K>{}, UnMerge<Sequence<C0, E>>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}));
constexpr auto wei_k_e_c0_global_desc = reorder_tensor_descriptor_given_lower2upper(
wei_k_c0_e_global_desc, Sequence<0, 2, 1>{});
// block tensor in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto wei_k_e_c0_block_desc = make_native_tensor_descriptor_aligned(
Sequence<KPerBlock, EPerBlock, C0>{}, Number<WeiBlockCopyDstDataPerWrite_C0>{});
// weight tensor blockwise copy
auto blockwise_wei_copy =
BlockwiseGenericTensorSliceCopy_v4<BlockSize,
decltype(wei_k_e_c0_global_desc),
decltype(wei_k_e_c0_block_desc),
decltype(wei_k_e_c0_block_desc.GetLengths()),
WeiBlockCopySubLengths_K_E_C0,
WeiBlockCopyClusterLengths_K_E_C0,
Sequence<0, 1, 2>,
Sequence<0, 1, 2>,
Sequence<0, 1, 2>,
1,
2,
WeiBlockCopySrcDataPerRead_E,
WeiBlockCopyDstDataPerWrite_C0,
AddressSpace::Global,
AddressSpace::Vgpr,
AddressSpace::Lds,
InMemoryDataOperation::Set>(
make_multi_index(0, e_block_data_on_global, 0), make_multi_index(0, 0, 0));
// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
// a_mtx[KPerBlock, EPerBlock*C0] is in LDS
// b_mtx[KPerBlock, BPerBlock*N0] is in LDS
// c_mtx[EPerBlock*C0, BPerBlock*N0] is distributed among threads, and saved in
// register
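// Shape check with hypothetical tuning values (not taken from any driver config):
// KPerBlock = 8, EPerBlock = BPerBlock = 16, C0 = N0 = 4 and a 256-thread block with
// GemmMLevel0Cluster = GemmNLevel0Cluster = GemmMLevel1Cluster = GemmNLevel1Cluster = 4
// give a_mtx 8x64, b_mtx 8x64 and c_mtx 64x64; GemmMRepeat = GemmNRepeat = 1, so each
// thread accumulates a 4x4 sub-tile and 256 threads * 16 elements cover all 4096
// entries of c_mtx.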
constexpr auto a_k_ec0_block_mtx_desc = make_ConstantMatrixDescriptor(
wei_k_e_c0_block_desc.GetLength(I0),
wei_k_e_c0_block_desc.GetLength(I1) * wei_k_e_c0_block_desc.GetLength(I2),
wei_k_e_c0_block_desc.GetStride(I0));
constexpr auto b_k_bn0_block_mtx_desc = make_ConstantMatrixDescriptor(
out_k_b_n0_block_desc.GetLength(I0),
out_k_b_n0_block_desc.GetLength(I1) * out_k_b_n0_block_desc.GetLength(I2),
out_k_b_n0_block_desc.GetStride(I0));
// sanity check alignment
// TODO: this check is ad-hoc; it should instead be guaranteed by enforcing the alignment
// of wei_k_e_c0_block_desc and out_k_b_n0_block_desc
static_assert(a_k_ec0_block_mtx_desc.RowStride() % GemmDataPerReadB == 0, "wrong!");
static_assert(b_k_bn0_block_mtx_desc.RowStride() % GemmDataPerReadA == 0, "wrong!");
// sanity check
static_assert(EPerBlock % (GemmMLevel0Cluster * GemmMLevel1Cluster) == 0 &&
BPerBlock % (GemmNLevel0Cluster * GemmNLevel1Cluster) == 0,
"wrong!");
constexpr index_t GemmMRepeat = EPerBlock / (GemmMLevel0Cluster * GemmMLevel1Cluster);
constexpr index_t GemmNRepeat = BPerBlock / (GemmNLevel0Cluster * GemmNLevel1Cluster);
// c_thread_mtx definition: this is a mess
// TODO: find a more elegant way of defining c_thread_mtx
constexpr auto c_e0e1c0_b0b1n0_thread_mtx_desc = make_ConstantMatrixDescriptor_packed(
Number<GemmMRepeat * GemmMPerThread>{}, Number<GemmNRepeat * GemmNPerThread>{});
const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2<
BlockSize,
decltype(a_k_ec0_block_mtx_desc),
decltype(b_k_bn0_block_mtx_desc),
decltype(c_e0e1c0_b0b1n0_thread_mtx_desc),
GemmMPerThread,
GemmNPerThread,
GemmKPerThread,
GemmMLevel0Cluster,
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmDataPerReadA,
GemmDataPerReadB>{};
// LDS allocation for input and weight: be careful of alignment
constexpr index_t max_lds_align = math::lcm(WeiBlockCopyDstDataPerWrite_C0,
OutBlockCopyDstDataPerWrite_N0,
GemmDataPerReadA,
GemmDataPerReadB);
constexpr index_t out_block_space =
math::integer_least_multiple(out_k_b_n0_block_desc.GetElementSpace(), max_lds_align);
constexpr index_t wei_block_space =
math::integer_least_multiple(wei_k_e_c0_block_desc.GetElementSpace(), max_lds_align);
__shared__ Float p_out_block_double[2 * out_block_space];
__shared__ Float p_wei_block_double[2 * wei_block_space];
// register allocation for output
AccFloat p_in_thread[c_e0e1c0_b0b1n0_thread_mtx_desc.GetElementSpace()];
// zero out threadwise output
threadwise_matrix_set_zero(c_e0e1c0_b0b1n0_thread_mtx_desc, p_in_thread);
// LDS double buffer: preload data into LDS
{
blockwise_out_copy.Run(p_out_global, p_out_block_double);
blockwise_wei_copy.Run(p_wei_global, p_wei_block_double);
}
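// Double-buffer pipeline: each inner iteration below runs the GEMM on the chunk already
// resident in one half of LDS while fetching the next KPerBlock chunk into the other
// half, so global loads overlap with math. Counting sketch (hypothetical K = 64,
// KPerBlock = 8, i.e. 8 chunks): the preload above brings in chunk 0, the main loop runs
// for k_block_data_begin = 0, 16, 32 and consumes chunks 0..5, and the tail block below
// handles the remaining two chunks because K % (2 * KPerBlock) == 0.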
// LDS double buffer: main body
for(index_t k_block_data_begin = 0; k_block_data_begin + 2 * KPerBlock < K;
k_block_data_begin += 2 * KPerBlock)
{
#pragma unroll
for(index_t iloop = 0; iloop < 2; ++iloop)
{
const bool even_loop = (iloop % 2 == 0);
Float* p_out_block_now =
even_loop ? p_out_block_double : p_out_block_double + out_block_space;
Float* p_wei_block_now =
even_loop ? p_wei_block_double : p_wei_block_double + wei_block_space;
Float* p_out_block_next =
even_loop ? p_out_block_double + out_block_space : p_out_block_double;
Float* p_wei_block_next =
even_loop ? p_wei_block_double + wei_block_space : p_wei_block_double;
Float p_out_thread_buffer[blockwise_out_copy.GetThreadBufferSize()];
Float p_wei_thread_buffer[blockwise_wei_copy.GetThreadBufferSize()];
blockwise_out_copy.MoveSrcSliceWindow(Sequence<KPerBlock, 0, 0>{}, True);
blockwise_wei_copy.MoveSrcSliceWindow(Sequence<KPerBlock, 0, 0>{}, True);
__syncthreads();
// LDS double buffer: load next data from device mem
blockwise_out_copy.RunLoadThreadBuffer(p_out_global, p_out_thread_buffer);
blockwise_wei_copy.RunLoadThreadBuffer(p_wei_global, p_wei_thread_buffer);
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(p_wei_block_now, p_out_block_now, p_in_thread);
// LDS double buffer: store next data to LDS
blockwise_out_copy.RunStoreThreadBuffer(p_out_thread_buffer, p_out_block_next);
blockwise_wei_copy.RunStoreThreadBuffer(p_wei_thread_buffer, p_wei_block_next);
}
}
// LDS double buffer: tail
{
constexpr bool has_two_iteration_left = (K % (2 * KPerBlock) == 0);
if(has_two_iteration_left) // 2 iterations left
{
Float p_out_thread_buffer[blockwise_out_copy.GetThreadBufferSize()];
Float p_wei_thread_buffer[blockwise_wei_copy.GetThreadBufferSize()];
blockwise_out_copy.MoveSrcSliceWindow(Sequence<KPerBlock, 0, 0>{}, True);
blockwise_wei_copy.MoveSrcSliceWindow(Sequence<KPerBlock, 0, 0>{}, True);
__syncthreads();
// LDS double buffer: load last data from device mem
blockwise_out_copy.RunLoadThreadBuffer(p_out_global, p_out_thread_buffer);
blockwise_wei_copy.RunLoadThreadBuffer(p_wei_global, p_wei_thread_buffer);
// LDS double buffer: GEMM on 2nd-last data
blockwise_gemm.Run(p_wei_block_double, p_out_block_double, p_in_thread);
// LDS double buffer: store last data to LDS
blockwise_out_copy.RunStoreThreadBuffer(p_out_thread_buffer,
p_out_block_double + out_block_space);
blockwise_wei_copy.RunStoreThreadBuffer(p_wei_thread_buffer,
p_wei_block_double + wei_block_space);
__syncthreads();
// LDS double buffer: GEMM on last data
blockwise_gemm.Run(p_wei_block_double + wei_block_space,
p_out_block_double + out_block_space,
p_in_thread);
}
else // 1 iteration left
{
__syncthreads();
// LDS double buffer: GEMM on last data
blockwise_gemm.Run(p_wei_block_double, p_out_block_double, p_in_thread);
}
}
{
#if 1 // debug
// input: register to global memory, atomic add
constexpr auto in_memory_op = (Y <= ConvStrideH && X <= ConvStrideW)
? InMemoryDataOperation::Set
: InMemoryDataOperation::AtomicAdd;
#else
constexpr auto in_memory_op = InMemoryDataOperation::AtomicAdd;
#endif
constexpr index_t E1 = GemmMLevel0Cluster * GemmMLevel1Cluster;
constexpr index_t E0 = E / E1;
constexpr index_t B1 = GemmNLevel0Cluster * GemmNLevel1Cluster;
constexpr index_t B0 = B / B1;
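// E and B are split so that a thread's output coordinates can be expressed as
// (e0, e1) = (e / E1, e % E1) and (b0, b1) = (b / B1, b % B1) when the per-thread
// results are written back to global memory below. Hypothetical numbers: E = 576 with
// GemmMLevel0Cluster * GemmMLevel1Cluster = 16 gives E1 = 16, E0 = 36.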
// define input tensor descriptor for threadwise copy
// thread input tensor, src of threadwise copy
constexpr auto in_e0_e1_c0_b0_b1_n0_thread_desc = make_native_tensor_descriptor_packed(
Sequence<GemmMRepeat, 1, GemmMPerThread, GemmNRepeat, 1, GemmNPerThread>{});
// global input tensor, dst of threadwise copy
constexpr auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor(
in_n_c_hi_wi_global_desc,
make_tuple(PassThrough<N>{},
PassThrough<C>{},
Pad<Sequence<Hi, Wi>, LeftPads, RightPads>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}));
constexpr auto in_n0_n1_c0_c1_y_ho_x_wo_global_desc = transform_tensor_descriptor(
in_n_c_hip_wip_global_desc,
make_tuple(UnMerge<Sequence<N0, N1>>{},
UnMerge<Sequence<C0, C1>>{},
Embed<Hi + LeftPads::At(0) + RightPads::At(0),
Sequence<Y, Ho>,
Sequence<ConvDilationH, ConvStrideH, 0>>{},
Embed<Wi + LeftPads::At(1) + RightPads::At(1),
Sequence<X, Wo>,
Sequence<ConvDilationW, ConvStrideW, 0>>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}, Sequence<6, 7>{}));
constexpr auto in_e_c0_b_n0_global_desc = transform_tensor_descriptor(
in_n0_n1_c0_c1_y_ho_x_wo_global_desc,
make_tuple(Merge<Sequence<C1, Y, X>>{},
PassThrough<C0>{},
Merge<Sequence<N1, Ho, Wo>>{},
PassThrough<N0>{}),
make_tuple(Sequence<3, 4, 6>{}, Sequence<2>{}, Sequence<1, 5, 7>{}, Sequence<0>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
constexpr auto in_e0_e1_c0_b0_b1_n0_global_desc = transform_tensor_descriptor(
in_e_c0_b_n0_global_desc,
make_tuple(UnMerge<Sequence<E0, E1>>{},
PassThrough<C0>{},
UnMerge<Sequence<B0, B1>>{},
PassThrough<N0>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3, 4>{}, Sequence<5>{}));
// calculate origin of thread input tensor on global memory
// blockwise GEMM c matrix starting index
const auto c_thread_mtx_on_block =
blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
const index_t e_thread_data_on_global =
e_block_data_on_global + c_thread_mtx_on_block.row / GemmMPerThread;
const index_t b_thread_data_on_global =
b_block_data_on_global + c_thread_mtx_on_block.col / GemmNPerThread;
ThreadwiseGenericTensorSliceCopy_v4r2<
decltype(in_e0_e1_c0_b0_b1_n0_thread_desc),
decltype(in_e0_e1_c0_b0_b1_n0_global_desc),
decltype(in_e0_e1_c0_b0_b1_n0_thread_desc.GetLengths()),
Sequence<0, 1, 2, 3, 4, 5>,
4,
1,
InThreadCopyDstDataPerWrite_B,
AddressSpace::Vgpr,
AddressSpace::Global,
in_memory_op>(make_multi_index(0, 0, 0, 0, 0, 0),
make_multi_index(e_thread_data_on_global / E1,
e_thread_data_on_global % E1,
0,
b_thread_data_on_global / B1,
b_thread_data_on_global % B1,
0))
.Run(p_in_thread, p_in_global);
}
}
};
} // namespace ck
#endif
#ifndef CK_GRIDWISE_CONVOLUTION_BACKWARD_DATA_IMPLICIT_GEMM_V4R1_NCHW_KCYX_NKHW_HPP
#define CK_GRIDWISE_CONVOLUTION_BACKWARD_DATA_IMPLICIT_GEMM_V4R1_NCHW_KCYX_NKHW_HPP
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm.hpp"
namespace ck {
// Number of GEMMs: YTilda * XTilda
// GemmM = C
// GemmN = N * HTildaSlice * WTildaSlice
// GemmK = K * YDotSlice * XDotSlice
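//
// Worked example (hypothetical 3x3 filter, stride 2x2, dilation 1x1): GcdStrideDilation
// is 1, so YTilda = XTilda = 2 and the backward-data pass is split into 4 GEMMs; with
// YDot = XDot = ceil(3 / 2) = 2, the slices YDotSlice and XDotSlice are 2 or 1 depending
// on iYTilda and iXTilda, giving GemmK = K * {4, 2, 2, 1}. The four GemmK values sum to
// 9 * K = K * Y * X, i.e. the filter taps are partitioned across the GEMMs.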
template <index_t GridSize,
index_t BlockSize,
typename Float,
typename AccFloat,
typename InGlobalDesc,
typename WeiGlobalDesc,
typename OutGlobalDesc,
typename ConvStrides,
typename ConvDilations,
typename InLeftPads,
typename InRightPads,
index_t GemmMPerBlock,
index_t GemmNPerBlock,
index_t GemmKPerBlock,
index_t GemmMPerThread,
index_t GemmNPerThread,
index_t GemmKPerThread,
index_t GemmMLevel0Cluster,
index_t GemmNLevel0Cluster,
index_t GemmMLevel1Cluster,
index_t GemmNLevel1Cluster,
index_t ThreadGemmDataPerRead_GemmM,
index_t ThreadGemmDataPerRead_GemmN,
typename GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
typename GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
index_t GemmABlockCopySrcDataPerRead_GemmM,
index_t GemmABlockCopyDstDataPerWrite_GemmM,
typename GemmBBlockCopyThreadSliceLengths_GemmK_GemmN,
typename GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
index_t GemmBBlockCopySrcDataPerRead_GemmN,
index_t GemmBBlockCopyDstDataPerWrite_GemmN,
index_t GemmCThreadCopyDstDataPerWrite_GemmN1>
struct GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw
{
__host__ __device__ static constexpr index_t GetNumberOfGemm()
{
constexpr index_t ConvStrideH = ConvStrides{}[0];
constexpr index_t ConvStrideW = ConvStrides{}[1];
constexpr index_t ConvDilationH = ConvDilations{}[0];
constexpr index_t ConvDilationW = ConvDilations{}[1];
constexpr index_t GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
constexpr index_t GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
constexpr index_t YTilda = ConvStrideH / GcdStrideDilationH;
constexpr index_t XTilda = ConvStrideW / GcdStrideDilationW;
return YTilda * XTilda;
}
__host__ __device__ static constexpr auto GetGemmSizeImpl(index_t iYTilda, index_t iXTilda)
{
constexpr index_t N = InGlobalDesc::GetLengths()[0];
constexpr index_t C = InGlobalDesc::GetLengths()[1];
constexpr index_t Hi = InGlobalDesc::GetLengths()[2];
constexpr index_t Wi = InGlobalDesc::GetLengths()[3];
constexpr index_t K = OutGlobalDesc::GetLengths()[1];
constexpr index_t Ho = OutGlobalDesc::GetLengths()[2];
constexpr index_t Wo = OutGlobalDesc::GetLengths()[3];
constexpr index_t Y = WeiGlobalDesc::GetLengths()[2];
constexpr index_t X = WeiGlobalDesc::GetLengths()[3];
constexpr index_t ConvStrideH = ConvStrides{}[0];
constexpr index_t ConvStrideW = ConvStrides{}[1];
constexpr index_t ConvDilationH = ConvDilations{}[0];
constexpr index_t ConvDilationW = ConvDilations{}[1];
constexpr index_t GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
constexpr index_t GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
constexpr index_t YTilda = ConvStrideH / GcdStrideDilationH;
constexpr index_t XTilda = ConvStrideW / GcdStrideDilationW;
constexpr index_t YDot = math::integer_divide_ceil(Y, YTilda);
constexpr index_t XDot = math::integer_divide_ceil(X, XTilda);
constexpr index_t HTilda =
Ho + math::integer_divide_ceil(ConvDilationH * (Y - 1), ConvStrideH);
constexpr index_t WTilda =
Wo + math::integer_divide_ceil(ConvDilationW * (X - 1), ConvStrideW);
// only work on HTilda and WTilda that contribute to non-padding area of input tensor
constexpr index_t iHTildaLeft = math::integer_divide_floor(
math::max(0, InLeftPads{}[0] - ConvDilationH * (YTilda - 1)), ConvStrides{}[0]);
constexpr index_t iWTildaLeft = math::integer_divide_floor(
math::max(0, InLeftPads{}[1] - ConvDilationW * (XTilda - 1)), ConvStrides{}[1]);
constexpr index_t iHTildaRight = math::min(
HTilda, math::integer_divide_ceil(InLeftPads{}[0] + Hi - 1, ConvStrides{}[0]) + 1);
constexpr index_t iWTildaRight = math::min(
WTilda, math::integer_divide_ceil(InLeftPads{}[1] + Wi - 1, ConvStrides{}[1]) + 1);
constexpr index_t HTildaSlice = iHTildaRight - iHTildaLeft;
constexpr index_t WTildaSlice = iWTildaRight - iWTildaLeft;
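// Example of the slicing (hypothetical Hi = 28, 3x3 filter, stride 1, dilation 1,
// pad 1): Ho = 28, YTilda = 1, HTilda = Ho + ceil(1 * 2 / 1) = 30, iHTildaLeft = 1 and
// iHTildaRight = min(30, 28 + 1) = 29, so HTildaSlice = 28 = Hi; the two HTilda rows
// that only ever touch the zero padding are dropped from GemmN.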
// GemmM and GemmN
constexpr index_t GemmM = C;
constexpr index_t GemmN = N * HTildaSlice * WTildaSlice;
// GemmK is different for each GEMM
index_t YDotSlice = math::integer_divide_ceil(Y - iYTilda, YTilda);
index_t XDotSlice = math::integer_divide_ceil(X - iXTilda, XTilda);
index_t GemmK = K * YDotSlice * XDotSlice;
return Array<index_t, 3>{GemmM, GemmN, GemmK};
}
__host__ __device__ static constexpr auto GetGemmSize(index_t gemm_id)
{
constexpr index_t ConvStrideW = ConvStrides{}[1];
constexpr index_t ConvDilationW = ConvDilations{}[1];
constexpr index_t GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
constexpr index_t XTilda = ConvStrideW / GcdStrideDilationW;
index_t iYTilda = gemm_id / XTilda;
index_t iXTilda = gemm_id % XTilda;
return GetGemmSizeImpl(iYTilda, iXTilda);
}
template <index_t iYTilda, index_t iXTilda>
__device__ static void RunImpl(Float* __restrict__ p_in_global,
const Float* __restrict__ p_wei_global,
const Float* __restrict__ p_out_global)
{
constexpr auto in_n_c_hi_wi_global_desc = InGlobalDesc{};
constexpr auto wei_k_c_y_x_global_desc = WeiGlobalDesc{};
constexpr auto out_n_k_ho_wo_global_desc = OutGlobalDesc{};
constexpr index_t N = in_n_c_hi_wi_global_desc.GetLengths()[0];
constexpr index_t C = in_n_c_hi_wi_global_desc.GetLengths()[1];
constexpr index_t Hi = in_n_c_hi_wi_global_desc.GetLengths()[2];
constexpr index_t Wi = in_n_c_hi_wi_global_desc.GetLengths()[3];
constexpr index_t K = out_n_k_ho_wo_global_desc.GetLengths()[1];
constexpr index_t Ho = out_n_k_ho_wo_global_desc.GetLengths()[2];
constexpr index_t Wo = out_n_k_ho_wo_global_desc.GetLengths()[3];
constexpr index_t Y = wei_k_c_y_x_global_desc.GetLengths()[2];
constexpr index_t X = wei_k_c_y_x_global_desc.GetLengths()[3];
constexpr index_t ConvStrideH = ConvStrides{}[0];
constexpr index_t ConvStrideW = ConvStrides{}[1];
constexpr index_t ConvDilationH = ConvDilations{}[0];
constexpr index_t ConvDilationW = ConvDilations{}[1];
constexpr index_t GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
constexpr index_t GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
constexpr index_t YTilda = ConvStrideH / GcdStrideDilationH;
constexpr index_t XTilda = ConvStrideW / GcdStrideDilationW;
constexpr index_t YDot = math::integer_divide_ceil(Y, YTilda);
constexpr index_t XDot = math::integer_divide_ceil(X, XTilda);
constexpr index_t YDotSlice = math::integer_divide_ceil(Y - iYTilda, YTilda);
constexpr index_t XDotSlice = math::integer_divide_ceil(X - iXTilda, XTilda);
constexpr index_t HTilda =
Ho + math::integer_divide_ceil(ConvDilationH * (Y - 1), ConvStrideH);
constexpr index_t WTilda =
Wo + math::integer_divide_ceil(ConvDilationW * (X - 1), ConvStrideW);
// only work on HTilda and WTilda that contribute to non-padding area of input tensor
constexpr index_t iHTildaLeft = math::integer_divide_floor(
math::max(0, InLeftPads{}[0] - ConvDilationH * (YTilda - 1)), ConvStrides{}[0]);
constexpr index_t iWTildaLeft = math::integer_divide_floor(
math::max(0, InLeftPads{}[1] - ConvDilationW * (XTilda - 1)), ConvStrides{}[1]);
constexpr index_t iHTildaRight = math::min(
HTilda, math::integer_divide_ceil(InLeftPads{}[0] + Hi - 1, ConvStrides{}[0]) + 1);
constexpr index_t iWTildaRight = math::min(
WTilda, math::integer_divide_ceil(InLeftPads{}[1] + Wi - 1, ConvStrides{}[1]) + 1);
constexpr index_t HTildaSlice = iHTildaRight - iHTildaLeft;
constexpr index_t WTildaSlice = iWTildaRight - iWTildaLeft;
// A matrix: weight
// weight out-of-bound check can be skipped
constexpr bool wei_skip_out_of_bound_check = true;
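// Why skipping is safe (sketch, using the same hypothetical 3x3 / stride-2 case as
// above): the Embed below decodes y = (ConvStrideH / GcdStrideDilationH) * ydot + ytilda,
// so the index can nominally reach an out-of-range tap (e.g. ydot = 1, ytilda = 1 gives
// y = 3 for Y = 3); however the Slice to YDotSlice for the current iYTilda removes
// exactly those entries, so every (ydot, ytilda) actually visited maps to a valid
// filter tap.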
constexpr auto wei_k_c_ydot_ytilda_xdot_xtilda_global_desc = transform_tensor_descriptor(
wei_k_c_y_x_global_desc,
make_tuple(PassThrough<K>{},
PassThrough<C>{},
Embed<Y,
Sequence<YDot, YTilda>,
Sequence<ConvStrideH / GcdStrideDilationH, 1, 0>,
wei_skip_out_of_bound_check>{},
Embed<X,
Sequence<XDot, XTilda>,
Sequence<ConvStrideW / GcdStrideDilationW, 1, 0>,
wei_skip_out_of_bound_check>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
constexpr auto wei_k_c_ydotslice_xdotslice_global_desc = transform_tensor_descriptor(
wei_k_c_ydot_ytilda_xdot_xtilda_global_desc,
make_tuple(
PassThrough<K>{},
PassThrough<C>{},
Slice<Sequence<YDot, XDot>, Sequence<0, 0>, Sequence<YDotSlice, XDotSlice>>{},
Freeze<Sequence<YTilda, XTilda>, Sequence<iYTilda, iXTilda>>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 4>{}, Sequence<3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<>{}));
constexpr auto wei_gemmk_gemmm_global_desc = transform_tensor_descriptor(
wei_k_c_ydotslice_xdotslice_global_desc,
make_tuple(Merge<Sequence<K, YDotSlice, XDotSlice>>{}, PassThrough<C>{}),
make_tuple(Sequence<0, 2, 3>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
// B matrix: output tensor
// TODO sometimes output tensor out-of-bound check can be skipped, find out all such
// situations
#if !CK_EXPERIMENTAL_IMPLICIT_GEMM_BACKWARD_DATA_V4R1_OUTPUT_SKIP_OUT_OF_BOUND_CHECK
constexpr bool out_skip_out_of_bound_check = false;
#else
constexpr bool out_skip_out_of_bound_check = true;
#endif
constexpr auto out_n_k_ydot_htilda_xdot_wtilda_global_desc = transform_tensor_descriptor(
out_n_k_ho_wo_global_desc,
make_tuple(PassThrough<N>{},
PassThrough<K>{},
Embed<Ho,
Sequence<YDot, HTilda>,
Sequence<-ConvDilationH / GcdStrideDilationH, 1, 0>,
out_skip_out_of_bound_check>{},
Embed<Wo,
Sequence<XDot, WTilda>,
Sequence<-ConvDilationW / GcdStrideDilationW, 1, 0>,
out_skip_out_of_bound_check>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
constexpr auto out_n_k_ydot_htildaslice_xdot_wtildaslice_global_desc =
transform_tensor_descriptor(
out_n_k_ydot_htilda_xdot_wtilda_global_desc,
make_tuple(PassThrough<N>{},
PassThrough<K>{},
PassThrough<YDot>{},
PassThrough<XDot>{},
Slice<Sequence<HTilda, WTilda>,
Sequence<iHTildaLeft, iWTildaLeft>,
Sequence<iHTildaRight, iWTildaRight>>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}));
constexpr auto out_n_k_ydotslice_htildaslice_xdotslice_wtildaslice_global_desc =
transform_tensor_descriptor(
out_n_k_ydot_htildaslice_xdot_wtildaslice_global_desc,
make_tuple(
PassThrough<N>{},
PassThrough<K>{},
PassThrough<HTildaSlice>{},
PassThrough<WTildaSlice>{},
Slice<Sequence<YDot, XDot>, Sequence<0, 0>, Sequence<YDotSlice, XDotSlice>>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<3>{}, Sequence<5>{}, Sequence<2, 4>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<3>{}, Sequence<5>{}, Sequence<2, 4>{}));
constexpr auto out_gemmk_gemmn_global_desc = transform_tensor_descriptor(
out_n_k_ydotslice_htildaslice_xdotslice_wtildaslice_global_desc,
make_tuple(Merge<Sequence<K, YDotSlice, XDotSlice>>{},
Merge<Sequence<N, HTildaSlice, WTildaSlice>>{}),
make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
// C matrix: input tensor
// TODO sometimes input out-of-bound check can be skipped, find out all such situations
#if !CK_EXPERIMENTAL_IMPLICIT_GEMM_BACKWARD_DATA_V4R1_INPUT_SKIP_OUT_OF_BOUND_CHECK
constexpr bool in_skip_out_of_bound_check = false;
#else
constexpr bool in_skip_out_of_bound_check = true;
#endif
constexpr auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor(
in_n_c_hi_wi_global_desc,
make_tuple(
PassThrough<N>{},
PassThrough<C>{},
Pad<Sequence<Hi, Wi>, InLeftPads, InRightPads, in_skip_out_of_bound_check>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}));
constexpr index_t Hip = in_n_c_hip_wip_global_desc.GetLengths()[2];
constexpr index_t Wip = in_n_c_hip_wip_global_desc.GetLengths()[3];
constexpr auto in_n_c_ytilda_htilda_xtilda_wtilda_global_desc = transform_tensor_descriptor(
in_n_c_hip_wip_global_desc,
make_tuple(PassThrough<N>{},
PassThrough<C>{},
Embed<Hip,
Sequence<YTilda, HTilda>,
Sequence<ConvDilationH, ConvStrideH, 0>,
in_skip_out_of_bound_check>{},
Embed<Wip,
Sequence<XTilda, WTilda>,
Sequence<ConvDilationW, ConvStrideW, 0>,
in_skip_out_of_bound_check>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
constexpr auto in_n_c_htildaslice_wtildaslice_global_desc = transform_tensor_descriptor(
in_n_c_ytilda_htilda_xtilda_wtilda_global_desc,
make_tuple(PassThrough<N>{},
PassThrough<C>{},
Freeze<Sequence<YTilda, XTilda>, Sequence<iYTilda, iXTilda>>{},
Slice<Sequence<HTilda, WTilda>,
Sequence<iHTildaLeft, iWTildaLeft>,
Sequence<iHTildaRight, iWTildaRight>>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 4>{}, Sequence<3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<>{}, Sequence<2, 3>{}));
constexpr auto in_gemmm_gemmn_global_desc = transform_tensor_descriptor(
in_n_c_htildaslice_wtildaslice_global_desc,
make_tuple(PassThrough<C>{}, Merge<Sequence<N, HTildaSlice, WTildaSlice>>{}),
make_tuple(Sequence<1>{}, Sequence<0, 2, 3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
constexpr auto gridwise_gemm =
GridwiseGemmTransposedANormalBNormalC_v1<GridSize,
BlockSize,
Float,
AccFloat,
decltype(wei_gemmk_gemmm_global_desc),
decltype(out_gemmk_gemmn_global_desc),
decltype(in_gemmm_gemmn_global_desc),
InMemoryDataOperation::Set,
GemmMPerBlock,
GemmNPerBlock,
GemmKPerBlock,
GemmMPerThread,
GemmNPerThread,
GemmKPerThread,
GemmMLevel0Cluster,
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
ThreadGemmDataPerRead_GemmM,
ThreadGemmDataPerRead_GemmN,
GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
Sequence<0, 1>,
Sequence<0, 1>,
1,
GemmABlockCopySrcDataPerRead_GemmM,
GemmABlockCopyDstDataPerWrite_GemmM,
GemmBBlockCopyThreadSliceLengths_GemmK_GemmN,
GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
Sequence<0, 1>,
Sequence<0, 1>,
1,
GemmBBlockCopySrcDataPerRead_GemmN,
GemmBBlockCopyDstDataPerWrite_GemmN,
Sequence<0, 1, 2, 3>,
3,
GemmCThreadCopyDstDataPerWrite_GemmN1>{};
gridwise_gemm.Run(p_wei_global, p_out_global, p_in_global);
}
template <index_t GemmId>
__device__ static void Run(Float* __restrict__ p_in_global,
const Float* __restrict__ p_wei_global,
const Float* __restrict__ p_out_global,
Number<GemmId>)
{
constexpr index_t ConvStrideH = ConvStrides{}[0];
constexpr index_t ConvStrideW = ConvStrides{}[1];
constexpr index_t ConvDilationH = ConvDilations{}[0];
constexpr index_t ConvDilationW = ConvDilations{}[1];
constexpr index_t GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
constexpr index_t GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
constexpr index_t YTilda = ConvStrideH / GcdStrideDilationH;
constexpr index_t XTilda = ConvStrideW / GcdStrideDilationW;
constexpr index_t iYTilda = GemmId / XTilda;
constexpr index_t iXTilda = GemmId % XTilda;
static_assert(iYTilda < YTilda && iXTilda < XTilda, "wrong! iYtilda, iXtilda");
RunImpl<iYTilda, iXTilda>(p_in_global, p_wei_global, p_out_global);
}
};
} // namespace ck
#endif
#ifndef CK_GRIDWISE_CONVOLUTION_BACKWARD_DATA_IMPLICIT_GEMM_V5R1_NHWC_KYXC_NHWK_HPP
#define CK_GRIDWISE_CONVOLUTION_BACKWARD_DATA_IMPLICIT_GEMM_V5R1_NHWC_KYXC_NHWK_HPP
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm.hpp"
namespace ck {
// Number of GEMMs = YTilda * XTilda
// GemmM = C
// GemmN = N * HTildaSlice * WTildaSlice
// GemmK0 = YDotSlice
// GemmK1 = XDotSlice
// GemmK2 = K
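//
// Compared with v4r1, the reduction dimension is kept as three separate dimensions
// instead of being merged into a single GemmK. For the hypothetical 3x3 / stride-2
// example this gives (GemmK0, GemmK1) in {(2, 2), (2, 1), (1, 2), (1, 1)} across the
// four GEMMs, with GemmK2 = K in all of them; keeping K as its own, innermost dimension
// lines up with the NHWC/NHWK layouts, where K (and C) are contiguous in memory
// (see GemmBBlockCopySrcDataPerRead_GemmK2 below).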
template <index_t GridSize,
index_t BlockSize,
typename Float,
typename AccFloat,
typename InGlobalDesc,
typename WeiGlobalDesc,
typename OutGlobalDesc,
typename ConvStrides,
typename ConvDilations,
typename InLeftPads,
typename InRightPads,
index_t GemmMPerBlock,
index_t GemmNPerBlock,
index_t GemmKPerBlock,
index_t GemmMPerThread,
index_t GemmNPerThread,
index_t GemmKPerThread,
index_t GemmMLevel0Cluster,
index_t GemmNLevel0Cluster,
index_t GemmMLevel1Cluster,
index_t GemmNLevel1Cluster,
index_t ThreadGemmDataPerRead_GemmM,
index_t ThreadGemmDataPerRead_GemmN,
typename GemmABlockCopyThreadSliceLengths_GemmK0_GemmK1_GemmK2_GemmM,
typename GemmABlockCopyThreadClusterLengths_GemmK0_GemmK1_GemmK2_GemmM,
index_t GemmABlockCopySrcDataPerRead_GemmM,
index_t GemmABlockCopyDstDataPerWrite_GemmM,
typename GemmBBlockCopyThreadSliceLengths_GemmK0_GemmK1_GemmK2_GemmN,
typename GemmBBlockCopyThreadClusterLengths_GemmK0_GemmK1_GemmK2_GemmN,
index_t GemmBBlockCopySrcDataPerRead_GemmK2,
index_t GemmBBlockCopyDstDataPerWrite_GemmN,
index_t GemmCThreadCopyDstDataPerWrite_GemmN1>
struct GridwiseConvolutionBackwardDataImplicitGemm_v5r1_nhwc_kyxc_nhwk
{
__host__ __device__ static constexpr index_t GetNumberOfGemm()
{
constexpr index_t ConvStrideH = ConvStrides{}[0];
constexpr index_t ConvStrideW = ConvStrides{}[1];
constexpr index_t ConvDilationH = ConvDilations{}[0];
constexpr index_t ConvDilationW = ConvDilations{}[1];
constexpr index_t GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
constexpr index_t GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
constexpr index_t YTilda = ConvStrideH / GcdStrideDilationH;
constexpr index_t XTilda = ConvStrideW / GcdStrideDilationW;
return YTilda * XTilda;
}
__host__ __device__ static constexpr auto GetGemmSizeImpl(index_t iYTilda, index_t iXTilda)
{
constexpr index_t N = InGlobalDesc::GetLengths()[0];
constexpr index_t Hi = InGlobalDesc::GetLengths()[1];
constexpr index_t Wi = InGlobalDesc::GetLengths()[2];
constexpr index_t C = InGlobalDesc::GetLengths()[3];
constexpr index_t Ho = OutGlobalDesc::GetLengths()[1];
constexpr index_t Wo = OutGlobalDesc::GetLengths()[2];
constexpr index_t K = OutGlobalDesc::GetLengths()[3];
constexpr index_t Y = WeiGlobalDesc::GetLengths()[1];
constexpr index_t X = WeiGlobalDesc::GetLengths()[2];
constexpr index_t ConvStrideH = ConvStrides{}[0];
constexpr index_t ConvStrideW = ConvStrides{}[1];
constexpr index_t ConvDilationH = ConvDilations{}[0];
constexpr index_t ConvDilationW = ConvDilations{}[1];
constexpr index_t GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
constexpr index_t GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
constexpr index_t YTilda = ConvStrideH / GcdStrideDilationH;
constexpr index_t XTilda = ConvStrideW / GcdStrideDilationW;
constexpr index_t YDot = math::integer_divide_ceil(Y, YTilda);
constexpr index_t XDot = math::integer_divide_ceil(X, XTilda);
constexpr index_t HTilda =
Ho + math::integer_divide_ceil(ConvDilationH * (Y - 1), ConvStrideH);
constexpr index_t WTilda =
Wo + math::integer_divide_ceil(ConvDilationW * (X - 1), ConvStrideW);
// only work on HTilda and WTilda that contribute to non-padding area of input tensor
constexpr index_t iHTildaLeft = math::integer_divide_floor(
math::max(0, InLeftPads{}[0] - ConvDilationH * (YTilda - 1)), ConvStrides{}[0]);
constexpr index_t iWTildaLeft = math::integer_divide_floor(
math::max(0, InLeftPads{}[1] - ConvDilationW * (XTilda - 1)), ConvStrides{}[1]);
constexpr index_t iHTildaRight = math::min(
HTilda, math::integer_divide_ceil(InLeftPads{}[0] + Hi - 1, ConvStrides{}[0]) + 1);
constexpr index_t iWTildaRight = math::min(
WTilda, math::integer_divide_ceil(InLeftPads{}[1] + Wi - 1, ConvStrides{}[1]) + 1);
constexpr index_t HTildaSlice = iHTildaRight - iHTildaLeft;
constexpr index_t WTildaSlice = iWTildaRight - iWTildaLeft;
// GemmM and GemmN
constexpr index_t GemmM = C;
constexpr index_t GemmN = N * HTildaSlice * WTildaSlice;
// GemmK is different for each GEMM
index_t YDotSlice = math::integer_divide_ceil(Y - iYTilda, YTilda);
index_t XDotSlice = math::integer_divide_ceil(X - iXTilda, XTilda);
index_t GemmK0 = YDotSlice;
index_t GemmK1 = XDotSlice;
index_t GemmK2 = K;
return make_multi_index(GemmM, GemmN, GemmK0, GemmK1, GemmK2);
}
__host__ __device__ static constexpr auto GetGemmSize(index_t gemm_id)
{
constexpr index_t ConvStrideW = ConvStrides{}[1];
constexpr index_t ConvDilationW = ConvDilations{}[1];
constexpr index_t GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
constexpr index_t XTilda = ConvStrideW / GcdStrideDilationW;
index_t iYTilda = gemm_id / XTilda;
index_t iXTilda = gemm_id % XTilda;
return GetGemmSizeImpl(iYTilda, iXTilda);
}
template <index_t iYTilda, index_t iXTilda>
__device__ static void RunImpl(Float* __restrict__ p_in_global,
const Float* __restrict__ p_wei_global,
const Float* __restrict__ p_out_global)
{
constexpr auto in_n_hi_wi_c_global_desc = InGlobalDesc{};
constexpr auto wei_k_y_x_c_global_desc = WeiGlobalDesc{};
constexpr auto out_n_ho_wo_k_global_desc = OutGlobalDesc{};
constexpr index_t N = in_n_hi_wi_c_global_desc.GetLengths()[0];
constexpr index_t Hi = in_n_hi_wi_c_global_desc.GetLengths()[1];
constexpr index_t Wi = in_n_hi_wi_c_global_desc.GetLengths()[2];
constexpr index_t C = in_n_hi_wi_c_global_desc.GetLengths()[3];
constexpr index_t Ho = out_n_ho_wo_k_global_desc.GetLengths()[1];
constexpr index_t Wo = out_n_ho_wo_k_global_desc.GetLengths()[2];
constexpr index_t K = out_n_ho_wo_k_global_desc.GetLengths()[3];
constexpr index_t Y = wei_k_y_x_c_global_desc.GetLengths()[1];
constexpr index_t X = wei_k_y_x_c_global_desc.GetLengths()[2];
constexpr index_t ConvStrideH = ConvStrides{}[0];
constexpr index_t ConvStrideW = ConvStrides{}[1];
constexpr index_t ConvDilationH = ConvDilations{}[0];
constexpr index_t ConvDilationW = ConvDilations{}[1];
constexpr index_t GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
constexpr index_t GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
constexpr index_t YTilda = ConvStrideH / GcdStrideDilationH;
constexpr index_t XTilda = ConvStrideW / GcdStrideDilationW;
constexpr index_t YDot = math::integer_divide_ceil(Y, YTilda);
constexpr index_t XDot = math::integer_divide_ceil(X, XTilda);
constexpr index_t YDotSlice = math::integer_divide_ceil(Y - iYTilda, YTilda);
constexpr index_t XDotSlice = math::integer_divide_ceil(X - iXTilda, XTilda);
constexpr index_t HTilda =
Ho + math::integer_divide_ceil(ConvDilationH * (Y - 1), ConvStrideH);
constexpr index_t WTilda =
Wo + math::integer_divide_ceil(ConvDilationW * (X - 1), ConvStrideW);
// only work on HTilda and WTilda that contribute to non-padding area of input tensor
constexpr index_t iHTildaLeft = math::integer_divide_floor(
math::max(0, InLeftPads{}[0] - ConvDilationH * (YTilda - 1)), ConvStrides{}[0]);
constexpr index_t iWTildaLeft = math::integer_divide_floor(
math::max(0, InLeftPads{}[1] - ConvDilationW * (XTilda - 1)), ConvStrides{}[1]);
constexpr index_t iHTildaRight = math::min(
HTilda, math::integer_divide_ceil(InLeftPads{}[0] + Hi - 1, ConvStrides{}[0]) + 1);
constexpr index_t iWTildaRight = math::min(
WTilda, math::integer_divide_ceil(InLeftPads{}[1] + Wi - 1, ConvStrides{}[1]) + 1);
constexpr index_t HTildaSlice = iHTildaRight - iHTildaLeft;
constexpr index_t WTildaSlice = iWTildaRight - iWTildaLeft;
// A matrix: weight
// weight out-of-bound check can be skipped
constexpr bool wei_skip_out_of_bound_check = true;
constexpr auto wei_k_ydot_ytilda_xdot_xtilda_c_global_desc = transform_tensor_descriptor(
wei_k_y_x_c_global_desc,
make_tuple(PassThrough<K>{},
Embed<Y,
Sequence<YDot, YTilda>,
Sequence<ConvStrideH / GcdStrideDilationH, 1, 0>,
wei_skip_out_of_bound_check>{},
Embed<X,
Sequence<XDot, XTilda>,
Sequence<ConvStrideW / GcdStrideDilationW, 1, 0>,
wei_skip_out_of_bound_check>{},
PassThrough<C>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
constexpr auto wei_k_ydotslice_xdotslice_c_global_desc = transform_tensor_descriptor(
wei_k_ydot_ytilda_xdot_xtilda_c_global_desc,
make_tuple(
PassThrough<K>{},
Slice<Sequence<YDot, XDot>, Sequence<0, 0>, Sequence<YDotSlice, XDotSlice>>{},
Freeze<Sequence<YTilda, XTilda>, Sequence<iYTilda, iXTilda>>{},
PassThrough<C>{}),
make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2, 4>{}, Sequence<5>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<>{}, Sequence<3>{}));
constexpr auto wei_gemmk0_gemmk1_gemmk2_gemmm_global_desc =
reorder_tensor_descriptor_given_lower2upper(wei_k_ydotslice_xdotslice_c_global_desc,
Sequence<2, 0, 1, 3>{});
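// The reorder maps (K, YDotSlice, XDotSlice, C) to (YDotSlice, XDotSlice, K, C), i.e.
// (GemmK0, GemmK1, GemmK2, GemmM), matching the layout announced at the top of the file.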
// B matrix: output tensor
// TODO sometimes output tensor out-of-bound check can be skipped, find out all such
// situations
#if !CK_EXPERIMENTAL_IMPLICIT_GEMM_BACKWARD_DATA_V4R1_OUTPUT_SKIP_OUT_OF_BOUND_CHECK
constexpr bool out_skip_out_of_bound_check = false;
#else
constexpr bool out_skip_out_of_bound_check = true;
#endif
constexpr auto out_n_ydot_htilda_xdot_wtilda_k_global_desc = transform_tensor_descriptor(
out_n_ho_wo_k_global_desc,
make_tuple(PassThrough<N>{},
Embed<Ho,
Sequence<YDot, HTilda>,
Sequence<-ConvDilationH / GcdStrideDilationH, 1, 0>,
out_skip_out_of_bound_check>{},
Embed<Wo,
Sequence<XDot, WTilda>,
Sequence<-ConvDilationW / GcdStrideDilationW, 1, 0>,
out_skip_out_of_bound_check>{},
PassThrough<K>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
constexpr auto out_n_ydotslice_htildaslice_xdotslice_wtildaslice_k_global_desc =
transform_tensor_descriptor(
out_n_ydot_htilda_xdot_wtilda_k_global_desc,
make_tuple(
PassThrough<N>{},
Slice<Sequence<YDot, XDot>, Sequence<0, 0>, Sequence<YDotSlice, XDotSlice>>{},
Slice<Sequence<HTilda, WTilda>,
Sequence<iHTildaLeft, iWTildaLeft>,
Sequence<iHTildaRight, iWTildaRight>>{},
PassThrough<K>{}),
make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2, 4>{}, Sequence<5>{}),
make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2, 4>{}, Sequence<5>{}));
constexpr auto out_gemmk0_gemmk1_gemmk2_gemmn_global_desc = transform_tensor_descriptor(
out_n_ydotslice_htildaslice_xdotslice_wtildaslice_k_global_desc,
make_tuple(PassThrough<YDotSlice>{},
PassThrough<XDotSlice>{},
PassThrough<K>{},
Merge<Sequence<N, HTildaSlice, WTildaSlice>>{}),
make_tuple(Sequence<1>{}, Sequence<3>{}, Sequence<5>{}, Sequence<0, 2, 4>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
// C matrix: input tensor
// TODO sometimes input out-of-bound check can be skipped, find out all such situations
#if !CK_EXPERIMENTAL_IMPLICIT_GEMM_BACKWARD_DATA_V4R1_INPUT_SKIP_OUT_OF_BOUND_CHECK
constexpr bool in_skip_out_of_bound_check = false;
#else
constexpr bool in_skip_out_of_bound_check = true;
#endif
constexpr auto in_n_hip_wip_c_global_desc = transform_tensor_descriptor(
in_n_hi_wi_c_global_desc,
make_tuple(PassThrough<N>{},
Pad<Sequence<Hi, Wi>, InLeftPads, InRightPads, in_skip_out_of_bound_check>{},
PassThrough<C>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
constexpr index_t Hip = in_n_hip_wip_c_global_desc.GetLengths()[1];
constexpr index_t Wip = in_n_hip_wip_c_global_desc.GetLengths()[2];
constexpr auto in_n_ytilda_htilda_xtilda_wtilda_c_global_desc = transform_tensor_descriptor(
in_n_hip_wip_c_global_desc,
make_tuple(PassThrough<N>{},
Embed<Hip,
Sequence<YTilda, HTilda>,
Sequence<ConvDilationH, ConvStrideH, 0>,
in_skip_out_of_bound_check>{},
Embed<Wip,
Sequence<XTilda, WTilda>,
Sequence<ConvDilationW, ConvStrideW, 0>,
in_skip_out_of_bound_check>{},
PassThrough<C>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
constexpr auto in_n_htildaslice_wtildaslice_c_global_desc = transform_tensor_descriptor(
in_n_ytilda_htilda_xtilda_wtilda_c_global_desc,
make_tuple(PassThrough<N>{},
Freeze<Sequence<YTilda, XTilda>, Sequence<iYTilda, iXTilda>>{},
Slice<Sequence<HTilda, WTilda>,
Sequence<iHTildaLeft, iWTildaLeft>,
Sequence<iHTildaRight, iWTildaRight>>{},
PassThrough<C>{}),
make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2, 4>{}, Sequence<5>{}),
make_tuple(Sequence<0>{}, Sequence<>{}, Sequence<1, 2>{}, Sequence<3>{}));
constexpr auto in_gemmm_gemmn_global_desc = transform_tensor_descriptor(
in_n_htildaslice_wtildaslice_c_global_desc,
make_tuple(PassThrough<C>{}, Merge<Sequence<N, HTildaSlice, WTildaSlice>>{}),
make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
// call GEMM
constexpr auto gridwise_gemm = GridwiseGemmTransposedANormalBNormalC_v2<
GridSize,
BlockSize,
Float,
AccFloat,
decltype(wei_gemmk0_gemmk1_gemmk2_gemmm_global_desc),
decltype(out_gemmk0_gemmk1_gemmk2_gemmn_global_desc),
decltype(in_gemmm_gemmn_global_desc),
InMemoryDataOperation::Set,
GemmMPerBlock,
GemmNPerBlock,
GemmKPerBlock,
GemmMPerThread,
GemmNPerThread,
GemmKPerThread,
GemmMLevel0Cluster,
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
ThreadGemmDataPerRead_GemmM,
ThreadGemmDataPerRead_GemmN,
GemmABlockCopyThreadSliceLengths_GemmK0_GemmK1_GemmK2_GemmM,
GemmABlockCopyThreadClusterLengths_GemmK0_GemmK1_GemmK2_GemmM,
Sequence<0, 1, 2, 3>,
Sequence<0, 1, 2, 3>,
3,
GemmABlockCopySrcDataPerRead_GemmM,
GemmABlockCopyDstDataPerWrite_GemmM,
GemmBBlockCopyThreadSliceLengths_GemmK0_GemmK1_GemmK2_GemmN,
GemmBBlockCopyThreadClusterLengths_GemmK0_GemmK1_GemmK2_GemmN,
Sequence<0, 1, 3, 2>,
Sequence<0, 1, 3, 2>,
2,
GemmBBlockCopySrcDataPerRead_GemmK2,
GemmBBlockCopyDstDataPerWrite_GemmN,
Sequence<2, 3, 0, 1>,
3,
GemmCThreadCopyDstDataPerWrite_GemmN1>{};
gridwise_gemm.Run(p_wei_global, p_out_global, p_in_global);
}
template <index_t GemmId>
__device__ static void Run(Float* __restrict__ p_in_global,
const Float* __restrict__ p_wei_global,
const Float* __restrict__ p_out_global,
Number<GemmId>)
{
constexpr index_t ConvStrideH = ConvStrides{}[0];
constexpr index_t ConvStrideW = ConvStrides{}[1];
constexpr index_t ConvDilationH = ConvDilations{}[0];
constexpr index_t ConvDilationW = ConvDilations{}[1];
constexpr index_t GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
constexpr index_t GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
constexpr index_t YTilda = ConvStrideH / GcdStrideDilationH;
constexpr index_t XTilda = ConvStrideW / GcdStrideDilationW;
constexpr index_t iYTilda = GemmId / XTilda;
constexpr index_t iXTilda = GemmId % XTilda;
static_assert(iYTilda < YTilda && iXTilda < XTilda, "wrong! iYTilda or iXTilda out of range");
RunImpl<iYTilda, iXTilda>(p_in_global, p_wei_global, p_out_global);
}
};
} // namespace ck
#endif
#ifndef CK_GRIDWISE_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V4R1_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER_HPP
#define CK_GRIDWISE_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V4R1_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER_HPP
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "blockwise_generic_tensor_slice_copy.hpp"
#include "threadwise_generic_tensor_slice_copy.hpp"
#include "blockwise_gemm.hpp"
namespace ck {
template <index_t GridSize,
index_t BlockSize,
typename Float,
typename AccDataType,
typename InGlobalDesc,
typename WeiGlobalDesc,
typename OutGlobalDesc,
typename ConvStrides,
typename ConvDilations,
typename LeftPads,
typename RightPads,
index_t BPerBlock,
index_t KPerBlock,
index_t EPerBlock,
index_t GemmNRepeat,
index_t GemmMPerThread,
index_t GemmNPerThread,
index_t GemmKPerThread,
index_t GemmMLevel0Cluster,
index_t GemmNLevel0Cluster,
index_t GemmMLevel1Cluster,
index_t GemmNLevel1Cluster,
index_t GemmDataPerReadA,
index_t GemmDataPerReadB,
typename InBlockCopySubLengths_E_N1_B_N2,
typename InBlockCopyClusterLengths_E_N1_B_N2,
typename InBlockCopyThreadClusterArrangeOrder,
typename InBlockCopySrcAccessOrder,
typename InBlockCopyDstAccessOrder,
index_t InBlockCopySrcDataPerRead_B,
index_t InBlockCopyDstDataPerWrite_N2,
typename WeiBlockCopySubLengths_E_K,
typename WeiBlockCopyClusterLengths_E_K,
typename WeiBlockCopyThreadClusterArrangeOrder,
typename WeiBlockCopySrcAccessOrder,
typename WeiBlockCopyDstAccessOrder,
index_t WeiBlockCopySrcDataPerRead_E,
index_t WeiBlockCopyDstDataPerWrite_K>
struct GridwiseConvolutionForwardImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
{
__device__ void Run(const Float* const __restrict__ p_in_global,
const Float* const __restrict__ p_wei_global,
Float* const __restrict__ p_out_global) const
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto True = integral_constant<bool, true>{};
// this is a mess
// TODO: find a more elegant way of specifying (or calculating) performance parameters
constexpr index_t N1 = GemmNRepeat;
constexpr index_t N2 = GemmNPerThread;
static_assert(
(N1 * N2 * BPerBlock) % (GemmNPerThread * GemmNLevel0Cluster * GemmNLevel1Cluster) == 0,
"wrong!");
constexpr auto in_n_c_hi_wi_global_desc = InGlobalDesc{};
constexpr auto wei_k_c_y_x_global_desc = WeiGlobalDesc{};
constexpr auto out_n_k_ho_wo_global_desc = OutGlobalDesc{};
constexpr index_t N = in_n_c_hi_wi_global_desc.GetLengths()[0];
constexpr index_t C = in_n_c_hi_wi_global_desc.GetLengths()[1];
constexpr index_t Hi = in_n_c_hi_wi_global_desc.GetLengths()[2];
constexpr index_t Wi = in_n_c_hi_wi_global_desc.GetLengths()[3];
constexpr index_t K = out_n_k_ho_wo_global_desc.GetLengths()[1];
constexpr index_t Ho = out_n_k_ho_wo_global_desc.GetLengths()[2];
constexpr index_t Wo = out_n_k_ho_wo_global_desc.GetLengths()[3];
constexpr index_t Y = wei_k_c_y_x_global_desc.GetLengths()[2];
constexpr index_t X = wei_k_c_y_x_global_desc.GetLengths()[3];
constexpr index_t ConvStrideH = ConvStrides{}[0];
constexpr index_t ConvStrideW = ConvStrides{}[1];
constexpr index_t ConvDilationH = ConvDilations{}[0];
constexpr index_t ConvDilationW = ConvDilations{}[1];
static_assert(N % (N1 * N2) == 0, "wrong! cannot divide N evenly among threads");
constexpr index_t N0 = N / (N1 * N2);
constexpr index_t B = N0 * Ho * Wo;
constexpr index_t E = C * Y * X;
// sanity-check for vectorized memory load
static_assert((Wo == 1 || (ConvStrideW == 1 || InBlockCopySrcDataPerRead_B == 1)) &&
(X == 1 || ConvDilationW % InBlockCopySrcDataPerRead_B == 0),
"wrong! aligment requirement for vectorized global load of input tensor will "
"be violated");
// divide block work by [K, B]
static_assert(K % KPerBlock == 0 && B % BPerBlock == 0 && E % EPerBlock == 0,
"wrong! cannot divide work evenly among block");
constexpr index_t KBlockWork = K / KPerBlock;
constexpr index_t BBlockWork = B / BPerBlock;
constexpr auto block_work_desc =
make_cluster_descriptor(Sequence<KBlockWork, BBlockWork>{});
const auto block_work_id = block_work_desc.CalculateClusterIndex(get_block_1d_id());
const index_t k_block_data_on_global = block_work_id[I0] * KPerBlock;
const index_t b_block_data_on_global = block_work_id[I1] * BPerBlock;
// input tensor
// global tensor in global memory
constexpr auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor(
in_n_c_hi_wi_global_desc,
make_tuple(
PassThrough<N>{}, PassThrough<C>{}, Pad<Sequence<Hi, Wi>, LeftPads, RightPads>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}));
constexpr index_t Hip = in_n_c_hip_wip_global_desc.GetLengths()[2];
constexpr index_t Wip = in_n_c_hip_wip_global_desc.GetLengths()[3];
constexpr auto in_n0_n1_n2_c_y_ho_x_wo_global_desc = transform_tensor_descriptor(
in_n_c_hip_wip_global_desc,
make_tuple(UnMerge<Sequence<N0, N1, N2>>{},
PassThrough<C>{},
Embed<Hip, Sequence<Y, Ho>, Sequence<ConvDilationH, ConvStrideH, 0>>{},
Embed<Wip, Sequence<X, Wo>, Sequence<ConvDilationW, ConvStrideW, 0>>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}, Sequence<4, 5>{}, Sequence<6, 7>{}));
// global tensor in global memory, src of blockwise copy
constexpr auto in_e_n1_b_n2_global_desc = transform_tensor_descriptor(
in_n0_n1_n2_c_y_ho_x_wo_global_desc,
make_tuple(Merge<Sequence<C, Y, X>>{},
PassThrough<N1>{},
Merge<Sequence<N0, Ho, Wo>>{},
PassThrough<N2>{}),
make_tuple(Sequence<3, 4, 6>{}, Sequence<1>{}, Sequence<0, 5, 7>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
// block tensor in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto in_e_n1_b_n2_block_desc = make_native_tensor_descriptor_aligned(
Sequence<EPerBlock, N1, BPerBlock, N2>{}, Number<InBlockCopyDstDataPerWrite_N2>{});
// this check is ad-hoc
// TODO: need to properly implement tensor descriptor with multiple alignment
// requirements
static_assert(in_e_n1_b_n2_block_desc.GetStride(I1) % GemmDataPerReadB == 0,
"GemmDataPerReadB alignment requirement is not satisfied");
// input tensor blockwise copy
auto blockwise_in_copy =
BlockwiseGenericTensorSliceCopy_v4<BlockSize,
decltype(in_e_n1_b_n2_global_desc),
decltype(in_e_n1_b_n2_block_desc),
decltype(in_e_n1_b_n2_block_desc.GetLengths()),
InBlockCopySubLengths_E_N1_B_N2,
InBlockCopyClusterLengths_E_N1_B_N2,
InBlockCopyThreadClusterArrangeOrder,
InBlockCopySrcAccessOrder,
InBlockCopyDstAccessOrder,
2,
3,
InBlockCopySrcDataPerRead_B,
InBlockCopyDstDataPerWrite_N2,
AddressSpace::Global,
AddressSpace::Vgpr,
AddressSpace::Lds,
InMemoryDataOperation::Set>(
make_multi_index(0, 0, b_block_data_on_global, 0), make_multi_index(0, 0, 0, 0));
// weight tensor
// global tensor in global memory, src of blockwise copy
// It is constructed differently depending on whether this is forward convolution or
// backward-weight convolution
constexpr auto wei_e_k_global_desc =
transform_tensor_descriptor(unfold_tensor_descriptor(wei_k_c_y_x_global_desc, I2, I3),
make_tuple(Merge<Sequence<C, Y * X>>{}, PassThrough<K>{}),
make_tuple(Sequence<1, 2>{}, Sequence<0>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
// block tensor in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto wei_e_k_block_desc = make_native_tensor_descriptor_aligned(
Sequence<EPerBlock, KPerBlock>{},
Number<math::lcm(WeiBlockCopyDstDataPerWrite_K, GemmDataPerReadA)>{});
// this check is ad-hoc
// TODO: need to properly implement tensor descriptor with multiple alignment
// requirements
static_assert(wei_e_k_block_desc.GetStride(I0) % GemmDataPerReadA == 0,
"GemmDataPerReadA alignment requirement is not satisfied");
// weight tensor blockwise copy
auto blockwise_wei_copy =
BlockwiseGenericTensorSliceCopy_v4<BlockSize,
decltype(wei_e_k_global_desc),
decltype(wei_e_k_block_desc),
decltype(wei_e_k_block_desc.GetLengths()),
WeiBlockCopySubLengths_E_K,
WeiBlockCopyClusterLengths_E_K,
WeiBlockCopyThreadClusterArrangeOrder,
WeiBlockCopySrcAccessOrder,
WeiBlockCopyDstAccessOrder,
0,
1,
WeiBlockCopySrcDataPerRead_E,
WeiBlockCopyDstDataPerWrite_K,
AddressSpace::Global,
AddressSpace::Vgpr,
AddressSpace::Lds,
InMemoryDataOperation::Set>(
make_multi_index(0, k_block_data_on_global), make_multi_index(0, 0));
// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
// a_mtx[EPerBlock, KPerBlock] is in LDS
// b_mtx[EPerBlock, N1 * BPerBlock * N2] is in LDS
// c_mtx[KPerBlock, N1 * BPerBlock * N2] is distributed among threads, and saved in
// register
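// Illustrative shapes (hypothetical tile sizes, not taken from any tuned configuration):
// with EPerBlock = 8, KPerBlock = 128, N1 = 2, BPerBlock = 16, N2 = 4,
// a_mtx is 8 x 128 and b_mtx is 8 x (2 * 16 * 4) = 8 x 128 in LDS, while each thread
// accumulates a (GemmMRepeat * GemmMPerThread) x (GemmNRepeat * GemmNPerThread) tile of
// c_mtx in registers (see c_k0k1_n1n2_thread_mtx_desc below)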
constexpr auto a_e_k_block_mtx_desc = make_ConstantMatrixDescriptor(wei_e_k_block_desc);
constexpr auto b_e_n1bn2_block_mtx_desc = make_ConstantMatrixDescriptor(
in_e_n1_b_n2_block_desc.GetLength(I0),
in_e_n1_b_n2_block_desc.GetLength(I1) * in_e_n1_b_n2_block_desc.GetLength(I2) *
in_e_n1_b_n2_block_desc.GetLength(I3),
in_e_n1_b_n2_block_desc.GetStride(I0));
// sanity check
static_assert(KPerBlock % (GemmMPerThread * GemmMLevel0Cluster * GemmMLevel1Cluster) == 0,
"wrong!");
constexpr index_t GemmMRepeat =
KPerBlock / (GemmMPerThread * GemmMLevel0Cluster * GemmMLevel1Cluster);
// c_thread_mtx definition: this is a mess
// TODO: find a more elegant way of defining c_thread_mtx
constexpr auto c_k0k1_n1n2_thread_mtx_desc = make_ConstantMatrixDescriptor_packed(
Number<GemmMRepeat * GemmMPerThread>{}, Number<GemmNRepeat * GemmNPerThread>{});
const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2<
BlockSize,
decltype(a_e_k_block_mtx_desc),
decltype(b_e_n1bn2_block_mtx_desc),
decltype(c_k0k1_n1n2_thread_mtx_desc),
GemmMPerThread,
GemmNPerThread,
GemmKPerThread,
GemmMLevel0Cluster,
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmDataPerReadA,
GemmDataPerReadB>{};
// LDS allocation for input and weight: be careful of alignment
constexpr index_t max_align = math::lcm(InBlockCopyDstDataPerWrite_N2,
WeiBlockCopyDstDataPerWrite_K,
GemmDataPerReadA,
GemmDataPerReadB);
constexpr index_t in_block_space =
math::integer_least_multiple(in_e_n1_b_n2_block_desc.GetElementSpace(), max_align);
constexpr index_t wei_block_space =
math::integer_least_multiple(wei_e_k_block_desc.GetElementSpace(), max_align);
__shared__ Float p_in_block_double[2 * in_block_space];
__shared__ Float p_wei_block_double[2 * wei_block_space];
// register allocation for output
AccDataType p_out_thread[c_k0k1_n1n2_thread_mtx_desc.GetElementSpace()];
// zero out threadwise output
threadwise_matrix_set_zero(c_k0k1_n1n2_thread_mtx_desc, p_out_thread);
// LDS double buffer: preload data into LDS
{
blockwise_in_copy.Run(p_in_global, p_in_block_double);
blockwise_wei_copy.Run(p_wei_global, p_wei_block_double);
}
constexpr auto in_block_slice_copy_steps = Sequence<EPerBlock, 0, 0, 0>{};
constexpr auto wei_block_slice_copy_steps = Sequence<EPerBlock, 0>{};
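// Pipeline sketch (illustrative): the main loop below consumes two EPerBlock slices of the E
// dimension per trip, ping-ponging between the two halves of p_in_block_double and
// p_wei_block_double so that the GEMM on the slice already resident in LDS overlaps with the
// global load of the next slice; the tail then drains the remaining one or two slices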
// LDS double buffer: main body
for(index_t e_block_data_begin = 0; e_block_data_begin + 2 * EPerBlock < E;
e_block_data_begin += 2 * EPerBlock)
{
#pragma unroll
for(index_t iloop = 0; iloop < 2; ++iloop)
{
const bool even_loop = (iloop % 2 == 0);
Float* p_in_block_now =
even_loop ? p_in_block_double : p_in_block_double + in_block_space;
Float* p_wei_block_now =
even_loop ? p_wei_block_double : p_wei_block_double + wei_block_space;
Float* p_in_block_next =
even_loop ? p_in_block_double + in_block_space : p_in_block_double;
Float* p_wei_block_next =
even_loop ? p_wei_block_double + wei_block_space : p_wei_block_double;
Float p_in_thread_buffer[blockwise_in_copy.GetThreadBufferSize()];
Float p_wei_thread_buffer[blockwise_wei_copy.GetThreadBufferSize()];
blockwise_in_copy.MoveSrcSliceWindow(in_block_slice_copy_steps, True);
blockwise_wei_copy.MoveSrcSliceWindow(wei_block_slice_copy_steps, True);
__syncthreads();
// LDS double buffer: load next data from device mem
blockwise_in_copy.RunLoadThreadBuffer(p_in_global, p_in_thread_buffer);
blockwise_wei_copy.RunLoadThreadBuffer(p_wei_global, p_wei_thread_buffer);
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(p_wei_block_now, p_in_block_now, p_out_thread);
// LDS double buffer: store next data to LDS
blockwise_in_copy.RunStoreThreadBuffer(p_in_thread_buffer, p_in_block_next);
blockwise_wei_copy.RunStoreThreadBuffer(p_wei_thread_buffer, p_wei_block_next);
}
}
// LDS double buffer: tail
{
constexpr bool has_two_iteration_left = (E % (2 * EPerBlock) == 0);
if(has_two_iteration_left) // 2 iterations left
{
Float p_in_thread_buffer[blockwise_in_copy.GetThreadBufferSize()];
Float p_wei_thread_buffer[blockwise_wei_copy.GetThreadBufferSize()];
blockwise_in_copy.MoveSrcSliceWindow(in_block_slice_copy_steps, True);
blockwise_wei_copy.MoveSrcSliceWindow(wei_block_slice_copy_steps, True);
__syncthreads();
// LDS double buffer: load last data from device mem
blockwise_in_copy.RunLoadThreadBuffer(p_in_global, p_in_thread_buffer);
blockwise_wei_copy.RunLoadThreadBuffer(p_wei_global, p_wei_thread_buffer);
// LDS double buffer: GEMM on 2nd-last data
blockwise_gemm.Run(p_wei_block_double, p_in_block_double, p_out_thread);
// LDS double buffer: store last data to LDS
blockwise_in_copy.RunStoreThreadBuffer(p_in_thread_buffer,
p_in_block_double + in_block_space);
blockwise_wei_copy.RunStoreThreadBuffer(p_wei_thread_buffer,
p_wei_block_double + wei_block_space);
__syncthreads();
// LDS double buffer: GEMM on last data
blockwise_gemm.Run(p_wei_block_double + wei_block_space,
p_in_block_double + in_block_space,
p_out_thread);
}
else // 1 iteration left
{
__syncthreads();
// LDS double buffer: GEMM on last data
blockwise_gemm.Run(p_wei_block_double, p_in_block_double, p_out_thread);
}
}
// copy output: register to global memory
{
constexpr index_t K1 = GemmMPerThread * GemmMLevel0Cluster * GemmMLevel1Cluster;
constexpr index_t K0 = K / K1;
// define output tensor descriptor for threadwise copy
// thread output tensor, src of threadwise copy
constexpr auto out_k0_k1_n1_b_n2_thread_desc = make_native_tensor_descriptor_packed(
Sequence<GemmMRepeat, GemmMPerThread, N1, 1, N2>{});
// global output tensor
constexpr auto out_n0_n1_n2_k0_k1_ho_wo_global_desc = transform_tensor_descriptor(
out_n_k_ho_wo_global_desc,
make_tuple(UnMerge<Sequence<N0, N1, N2>>{},
UnMerge<Sequence<K0, K1>>{},
PassThrough<Ho>{},
PassThrough<Wo>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}, Sequence<6>{}));
// global output tensor, dst of threadwise copy
constexpr auto out_k0_k1_n1_b_n2_global_desc = transform_tensor_descriptor(
out_n0_n1_n2_k0_k1_ho_wo_global_desc,
make_tuple(PassThrough<K0>{},
PassThrough<K1>{},
PassThrough<N1>{},
Merge<Sequence<N0, Ho, Wo>>{},
PassThrough<N2>{}),
make_tuple(Sequence<3>{},
Sequence<4>{},
Sequence<1>{},
Sequence<0, 5, 6>{},
Sequence<2>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
// calculate origin of thread output tensor on global memory
// blockwise GEMM c matrix starting index
const auto c_thread_mtx_on_block =
blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
const index_t k_thread_data_on_global =
k_block_data_on_global + c_thread_mtx_on_block.row;
const index_t b_thread_data_on_global =
b_block_data_on_global + c_thread_mtx_on_block.col / N2;
ThreadwiseGenericTensorSliceCopy_v4r2<
decltype(out_k0_k1_n1_b_n2_thread_desc),
decltype(out_k0_k1_n1_b_n2_global_desc),
decltype(out_k0_k1_n1_b_n2_thread_desc.GetLengths()),
arithmetic_sequence_gen<0, 5, 1>::type,
3,
1,
1,
AddressSpace::Vgpr,
AddressSpace::Global,
InMemoryDataOperation::Set>(make_multi_index(0, 0, 0, 0, 0),
make_multi_index(k_thread_data_on_global / K1,
k_thread_data_on_global % K1,
0,
b_thread_data_on_global,
0))
.Run(p_out_thread, p_out_global);
}
}
};
} // namespace ck
#endif
#ifndef CK_GRIDWISE_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V4R4_NCHW_KCYX_NKHW_HPP
#define CK_GRIDWISE_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V4R4_NCHW_KCYX_NKHW_HPP
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm.hpp"
namespace ck {
// GemmM = K
// GemmN = N * Ho * Wo
// GemmK = C * Y * X
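// Worked example (hypothetical problem size, for illustration only):
// N = 128, C = 64, Hi = Wi = 56, K = 256, Y = X = 3, stride 1, dilation 1, pad 1
// => Ho = Wo = 56, GemmM = 256, GemmN = 128 * 56 * 56 = 401408, GemmK = 64 * 3 * 3 = 576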
template <index_t GridSize,
index_t BlockSize,
typename Float,
typename AccFloat,
typename InGlobalDesc,
typename WeiGlobalDesc,
typename OutGlobalDesc,
typename ConvStrides,
typename ConvDilations,
typename InLeftPads,
typename InRightPads,
index_t GemmMPerBlock,
index_t GemmNPerBlock,
index_t GemmKPerBlock,
index_t GemmMPerThread,
index_t GemmNPerThread,
index_t GemmKPerThread,
index_t GemmMLevel0Cluster,
index_t GemmNLevel0Cluster,
index_t GemmMLevel1Cluster,
index_t GemmNLevel1Cluster,
index_t ThreadGemmDataPerRead_GemmM,
index_t ThreadGemmDataPerRead_GemmN,
typename GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
typename GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
index_t GemmABlockCopySrcDataPerRead_GemmK,
index_t GemmABlockCopyDstDataPerWrite_GemmM,
typename GemmBBlockCopyThreadSliceLengths_GemmK_GemmN,
typename GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
index_t GemmBBlockCopySrcDataPerRead_GemmN,
index_t GemmBBlockCopyDstDataPerWrite_GemmN,
index_t GemmCThreadCopyDstDataPerWrite_GemmN1>
struct GridwiseConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw
{
__device__ void Run(const Float* const __restrict__ p_in_global,
const Float* const __restrict__ p_wei_global,
Float* const __restrict__ p_out_global) const
{
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto in_n_c_hi_wi_global_desc = InGlobalDesc{};
constexpr auto wei_k_c_y_x_global_desc = WeiGlobalDesc{};
constexpr auto out_n_k_ho_wo_global_desc = OutGlobalDesc{};
constexpr index_t N = in_n_c_hi_wi_global_desc.GetLengths()[0];
constexpr index_t C = in_n_c_hi_wi_global_desc.GetLengths()[1];
constexpr index_t Hi = in_n_c_hi_wi_global_desc.GetLengths()[2];
constexpr index_t Wi = in_n_c_hi_wi_global_desc.GetLengths()[3];
constexpr index_t K = out_n_k_ho_wo_global_desc.GetLengths()[1];
constexpr index_t Ho = out_n_k_ho_wo_global_desc.GetLengths()[2];
constexpr index_t Wo = out_n_k_ho_wo_global_desc.GetLengths()[3];
constexpr index_t Y = wei_k_c_y_x_global_desc.GetLengths()[2];
constexpr index_t X = wei_k_c_y_x_global_desc.GetLengths()[3];
constexpr index_t ConvStrideH = ConvStrides{}[0];
constexpr index_t ConvStrideW = ConvStrides{}[1];
constexpr index_t ConvDilationH = ConvDilations{}[0];
constexpr index_t ConvDilationW = ConvDilations{}[1];
#if 0
// sanity-check for vectorized memory load
static_assert((Wo == 1 || (ConvStrideW == 1 || GemmBBlockCopySrcDataPerRead_GemmN == 1)) &&
(X == 1 || ConvDilationW % GemmBBlockCopySrcDataPerRead_GemmN == 0) &&
InLeftPads{}[1] % GemmBBlockCopySrcDataPerRead_GemmN == 0 &&
InRightPads{}[1] % GemmBBlockCopySrcDataPerRead_GemmN == 0,
"wrong! aligment requirement for vectorized global load of input tensor will "
"be violated");
#endif
// weight tensor
constexpr auto wei_gemmk_gemmm_global_desc = reorder_tensor_descriptor_given_upper2lower(
unfold_tensor_descriptor(wei_k_c_y_x_global_desc, I1, I3), Sequence<1, 0>{});
// input tensor
constexpr auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor(
in_n_c_hi_wi_global_desc,
make_tuple(PassThrough<N>{},
PassThrough<C>{},
Pad<Sequence<Hi, Wi>, InLeftPads, InRightPads>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}));
constexpr index_t Hip = in_n_c_hip_wip_global_desc.GetLengths()[2];
constexpr index_t Wip = in_n_c_hip_wip_global_desc.GetLengths()[3];
constexpr auto in_n_c_y_ho_x_wo_global_desc = transform_tensor_descriptor(
in_n_c_hip_wip_global_desc,
make_tuple(PassThrough<N>{},
PassThrough<C>{},
Embed<Hip, Sequence<Y, Ho>, Sequence<ConvDilationH, ConvStrideH, 0>>{},
Embed<Wip, Sequence<X, Wo>, Sequence<ConvDilationW, ConvStrideW, 0>>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
constexpr auto in_gemmk_gemmn_global_desc = transform_tensor_descriptor(
in_n_c_y_ho_x_wo_global_desc,
make_tuple(Merge<Sequence<C, Y, X>>{}, Merge<Sequence<N, Ho, Wo>>{}),
make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
// output tensor
constexpr auto out_gemmm_gemmn_global_desc =
transform_tensor_descriptor(unfold_tensor_descriptor(out_n_k_ho_wo_global_desc, I2, I3),
make_tuple(PassThrough<K>{}, Merge<Sequence<N, Ho * Wo>>{}),
make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
// GEMM
constexpr auto gridwise_gemm =
GridwiseGemmTransposedANormalBNormalC_v1<GridSize,
BlockSize,
Float,
AccFloat,
decltype(wei_gemmk_gemmm_global_desc),
decltype(in_gemmk_gemmn_global_desc),
decltype(out_gemmm_gemmn_global_desc),
InMemoryDataOperation::Set,
GemmMPerBlock,
GemmNPerBlock,
GemmKPerBlock,
GemmMPerThread,
GemmNPerThread,
GemmKPerThread,
GemmMLevel0Cluster,
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
ThreadGemmDataPerRead_GemmM,
ThreadGemmDataPerRead_GemmN,
GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
Sequence<1, 0>,
Sequence<1, 0>,
0,
GemmABlockCopySrcDataPerRead_GemmK,
GemmABlockCopyDstDataPerWrite_GemmM,
GemmBBlockCopyThreadSliceLengths_GemmK_GemmN,
GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
Sequence<0, 1>,
Sequence<0, 1>,
1,
GemmBBlockCopySrcDataPerRead_GemmN,
GemmBBlockCopyDstDataPerWrite_GemmN,
Sequence<2, 3, 0, 1>,
3,
GemmCThreadCopyDstDataPerWrite_GemmN1>{};
gridwise_gemm.Run(p_wei_global, p_in_global, p_out_global);
}
};
} // namespace ck
#endif
#ifndef CK_GRIDWISE_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V4R4_NHWC_KYXC_NHWK_HPP
#define CK_GRIDWISE_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V4R4_NHWC_KYXC_NHWK_HPP
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm.hpp"
namespace ck {
// GemmM = K
// GemmN = N * Ho * Wo
// GemmK = C * Y * X
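// Note: GemmM/GemmN/GemmK have the same sizes as in the NCHW kernel above, but here GemmK is
// formed by merging (Y, X, C) instead of (C, Y, X) (see wei_gemmk_gemmm_global_desc and
// in_gemmk_gemmn_global_desc below), matching the KYXC weight layout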
template <index_t GridSize,
index_t BlockSize,
typename Float,
typename AccFloat,
typename InGlobalDesc,
typename WeiGlobalDesc,
typename OutGlobalDesc,
typename ConvStrides,
typename ConvDilations,
typename InLeftPads,
typename InRightPads,
index_t GemmMPerBlock,
index_t GemmNPerBlock,
index_t GemmKPerBlock,
index_t GemmMPerThread,
index_t GemmNPerThread,
index_t GemmKPerThread,
index_t GemmMLevel0Cluster,
index_t GemmNLevel0Cluster,
index_t GemmMLevel1Cluster,
index_t GemmNLevel1Cluster,
index_t ThreadGemmDataPerRead_GemmM,
index_t ThreadGemmDataPerRead_GemmN,
typename GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
typename GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
index_t GemmABlockCopySrcDataPerRead_GemmK,
index_t GemmABlockCopyDstDataPerWrite_GemmM,
typename GemmBBlockCopyThreadSliceLengths_GemmK_GemmN,
typename GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
index_t GemmBBlockCopySrcDataPerRead_GemmK,
index_t GemmBBlockCopyDstDataPerWrite_GemmN,
index_t GemmCThreadCopyDstDataPerWrite_GemmM1>
struct GridwiseConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk
{
__device__ void Run(const Float* const __restrict__ p_in_global,
const Float* const __restrict__ p_wei_global,
Float* const __restrict__ p_out_global) const
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto in_n_hi_wi_c_global_desc = InGlobalDesc{};
constexpr auto wei_k_y_x_c_global_desc = WeiGlobalDesc{};
constexpr auto out_n_ho_wo_k_global_desc = OutGlobalDesc{};
constexpr index_t N = in_n_hi_wi_c_global_desc.GetLengths()[I0];
constexpr index_t Hi = in_n_hi_wi_c_global_desc.GetLengths()[I1];
constexpr index_t Wi = in_n_hi_wi_c_global_desc.GetLengths()[I2];
constexpr index_t C = in_n_hi_wi_c_global_desc.GetLengths()[I3];
constexpr index_t K = out_n_ho_wo_k_global_desc.GetLengths()[I3];
constexpr index_t Ho = out_n_ho_wo_k_global_desc.GetLengths()[I1];
constexpr index_t Wo = out_n_ho_wo_k_global_desc.GetLengths()[I2];
constexpr index_t Y = wei_k_y_x_c_global_desc.GetLengths()[I1];
constexpr index_t X = wei_k_y_x_c_global_desc.GetLengths()[I2];
constexpr index_t ConvStrideH = ConvStrides{}[I0];
constexpr index_t ConvStrideW = ConvStrides{}[I1];
constexpr index_t ConvDilationH = ConvDilations{}[I0];
constexpr index_t ConvDilationW = ConvDilations{}[I1];
// weight tensor
constexpr auto wei_gemmk_gemmm_global_desc = reorder_tensor_descriptor_given_upper2lower(
unfold_tensor_descriptor(wei_k_y_x_c_global_desc, I1, I3), Sequence<1, 0>{});
// input tensor
constexpr auto in_n_hip_wip_c_global_desc =
transform_tensor_descriptor(in_n_hi_wi_c_global_desc,
make_tuple(PassThrough<N>{},
Pad<Sequence<Hi, Wi>, InLeftPads, InRightPads>{},
PassThrough<C>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
constexpr index_t Hip = in_n_hip_wip_c_global_desc.GetLengths()[I1];
constexpr index_t Wip = in_n_hip_wip_c_global_desc.GetLengths()[I2];
constexpr auto in_n_y_ho_x_wo_c_global_desc = transform_tensor_descriptor(
in_n_hip_wip_c_global_desc,
make_tuple(PassThrough<N>{},
Embed<Hip, Sequence<Y, Ho>, Sequence<ConvDilationH, ConvStrideH, 0>>{},
Embed<Wip, Sequence<X, Wo>, Sequence<ConvDilationW, ConvStrideW, 0>>{},
PassThrough<C>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
constexpr auto in_gemmk_gemmn_global_desc = transform_tensor_descriptor(
in_n_y_ho_x_wo_c_global_desc,
make_tuple(Merge<Sequence<Y, X, C>>{}, Merge<Sequence<N, Ho, Wo>>{}),
make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
// output tensor
constexpr auto out_gemmm_gemmn_global_desc = transform_tensor_descriptor(
unfold_tensor_descriptor(out_n_ho_wo_k_global_desc, I0, I2),
make_tuple(PassThrough<K>{}, Merge<Sequence<N * Ho * Wo>>{}),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
// GEMM
constexpr auto gridwise_gemm =
GridwiseGemmTransposedANormalBNormalC_v1<GridSize,
BlockSize,
Float,
AccFloat,
decltype(wei_gemmk_gemmm_global_desc),
decltype(in_gemmk_gemmn_global_desc),
decltype(out_gemmm_gemmn_global_desc),
InMemoryDataOperation::Set,
GemmMPerBlock,
GemmNPerBlock,
GemmKPerBlock,
GemmMPerThread,
GemmNPerThread,
GemmKPerThread,
GemmMLevel0Cluster,
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
ThreadGemmDataPerRead_GemmM,
ThreadGemmDataPerRead_GemmN,
GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
Sequence<1, 0>,
Sequence<1, 0>,
0,
GemmABlockCopySrcDataPerRead_GemmK,
GemmABlockCopyDstDataPerWrite_GemmM,
GemmBBlockCopyThreadSliceLengths_GemmK_GemmN,
GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
Sequence<1, 0>,
Sequence<1, 0>,
0,
GemmBBlockCopySrcDataPerRead_GemmK,
GemmBBlockCopyDstDataPerWrite_GemmN,
Sequence<2, 3, 0, 1>,
1,
GemmCThreadCopyDstDataPerWrite_GemmM1>{};
gridwise_gemm.Run(p_wei_global, p_in_global, p_out_global);
}
};
} // namespace ck
#endif
#ifndef CK_CONSTANT_MATRIX_DESCRIPTOR_HPP
#define CK_CONSTANT_MATRIX_DESCRIPTOR_HPP
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
namespace ck {
template <index_t NRow_, index_t NCol_, index_t RowStride_>
struct ConstantMatrixDescriptor
{
__host__ __device__ constexpr ConstantMatrixDescriptor()
{
static_assert(NCol_ <= RowStride_, "wrong! NCol > RowStride!");
}
__host__ __device__ static constexpr index_t NRow() { return NRow_; }
__host__ __device__ static constexpr index_t NCol() { return NCol_; }
__host__ __device__ static constexpr index_t RowStride() { return RowStride_; }
__host__ __device__ static constexpr auto GetLengths() { return Sequence<NRow_, NCol_>{}; }
__host__ __device__ static constexpr index_t GetElementSize() { return NRow_ * NCol_; }
__host__ __device__ static constexpr index_t GetElementSpace() { return NRow_ * RowStride_; }
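// Illustrative example: NRow_ = 4, NCol_ = 3, RowStride_ = 8 describes a 4 x 3 matrix stored
// with 5 padding elements per row, so GetElementSize() = 12, GetElementSpace() = 32 and
// GetOffsetFromMultiIndex(2, 1) = 2 * 8 + 1 = 17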
__host__ __device__ static index_t GetOffsetFromMultiIndex(index_t irow, index_t icol)
{
return irow * RowStride_ + icol;
}
__host__ __device__ static index_t CalculateOffset(index_t irow, index_t icol)
{
return GetOffsetFromMultiIndex(irow, icol);
}
template <index_t SubNRow, index_t SubNCol>
__host__ __device__ static constexpr auto MakeSubMatrixDescriptor(Number<SubNRow>,
Number<SubNCol>)
{
return ConstantMatrixDescriptor<SubNRow, SubNCol, RowStride_>{};
}
};
template <index_t NRow, index_t NCol>
__host__ __device__ constexpr auto make_ConstantMatrixDescriptor_packed(Number<NRow>, Number<NCol>)
{
return ConstantMatrixDescriptor<NRow, NCol, NCol>{};
}
template <index_t NRow, index_t NCol, index_t RowStride>
__host__ __device__ constexpr auto
make_ConstantMatrixDescriptor(Number<NRow>, Number<NCol>, Number<RowStride>)
{
return ConstantMatrixDescriptor<NRow, NCol, RowStride>{};
}
template <typename... Ts>
__host__ __device__ constexpr auto make_ConstantMatrixDescriptor(NativeTensorDescriptor<Ts...>)
{
using TDesc = NativeTensorDescriptor<Ts...>;
static_assert(TDesc::GetNumOfDimension() == 2, "wrong");
static_assert(TDesc::GetStrides()[1] == 1, "wrong");
return ConstantMatrixDescriptor<TDesc::GetLengths()[0],
TDesc::GetLengths()[1],
TDesc::GetStrides()[0]>{};
}
template <typename TDesc>
__host__ __device__ void print_ConstantMatrixDescriptor(TDesc, const char* s)
{
printf(
"%s NRow %u NCol %u RowStride %u\n", s, TDesc::NRow(), TDesc::NCol(), TDesc::RowStride());
}
} // namespace ck
#endif
...@@ -2,50 +2,10 @@
#define CK_CLUSTER_DESCRIPTOR_HPP
#include "common_header.hpp"
// TODO remove dependency on deprecated tensor descriptor
#include "tensor_descriptor.hpp"
#include "tensor_adaptor.hpp"
namespace ck {
// a cluster map 1d index to N-d index
template <typename Lengths, typename ArrangeOrder>
struct ClusterDescriptor
{
static constexpr index_t nDim = Lengths::Size();
static constexpr auto mDesc = transform_tensor_descriptor(
make_native_tensor_descriptor_packed(Lengths{}),
make_tuple(Merge<decltype(Lengths::ReorderGivenNew2Old(ArrangeOrder{}))>{}),
make_tuple(ArrangeOrder{}),
make_tuple(Sequence<0>{}));
__host__ __device__ constexpr ClusterDescriptor()
{
static_assert(Lengths::Size() == nDim && ArrangeOrder::Size() == nDim,
"wrong! size not the same");
static_assert(is_valid_sequence_map<ArrangeOrder>{}, "wrong! ArrangeOrder is wrong");
}
__host__ __device__ static constexpr index_t GetElementSize() { return mDesc.GetElementSize(); }
__host__ __device__ static constexpr auto CalculateClusterIndex(index_t idx_1d)
{
return mDesc.CalculateLowerIndex(MultiIndex<1>{idx_1d});
}
};
template <typename Lengths,
typename ArrangeOrder = typename arithmetic_sequence_gen<0, Lengths::Size(), 1>::type>
__host__ __device__ constexpr auto make_cluster_descriptor(
Lengths, ArrangeOrder order = typename arithmetic_sequence_gen<0, Lengths::Size(), 1>::type{})
{
return ClusterDescriptor<Lengths, decltype(order)>{};
}
#if 1
template <typename Lengths,
typename ArrangeOrder = typename arithmetic_sequence_gen<0, Lengths::Size(), 1>::type>
__host__ __device__ constexpr auto make_cluster_descriptor_v2(
...@@ -68,7 +28,6 @@ __host__ __device__ constexpr auto make_cluster_descriptor_v2(
return make_single_stage_tensor_adaptor(
make_tuple(transform), make_tuple(low_dim_old_top_ids), make_tuple(up_dim_new_top_ids));
}
#endif
} // namespace ck
#endif
#ifndef CK_DIMENSION_HPP
#define CK_DIMENSION_HPP
#include "common_header.hpp"
namespace ck {
template <index_t Length, index_t Stride>
struct NativeDimension
{
__host__ __device__ static constexpr auto GetLength() { return Number<Length>{}; }
__host__ __device__ static constexpr auto GetStride() { return Number<Stride>{}; }
};
} // namespace ck
#endif
#ifndef CK_MULTI_INDEX_TRANSFORM_HPP
#define CK_MULTI_INDEX_TRANSFORM_HPP
#include "common_header.hpp"
#include "multi_index.hpp"
namespace ck {
template <index_t Length>
struct PassThrough
{
using LowerIndex = MultiIndex<1>;
using UpperIndex = MultiIndex<1>;
__host__ __device__ static constexpr auto GetNumOfLowerDimension() { return Number<1>{}; }
__host__ __device__ static constexpr auto GetNumOfUpperDimension() { return Number<1>{}; }
__host__ __device__ static constexpr auto GetUpperLengths() { return Sequence<Length>{}; }
__host__ __device__ static constexpr auto CalculateLowerIndex(const UpperIndex& idx_up)
{
return idx_up;
}
__host__ __device__ static constexpr auto
CalculateLowerIndexDiff(const UpperIndex& idx_up_diff,
const UpperIndex& /* idx_up_old */,
const LowerIndex& /* idx_low_old */)
{
return idx_up_diff;
}
__host__ __device__ static constexpr bool IsLinearTransform() { return true; }
__host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
{
return true;
}
};
// By default, this will automatically judge whether the is-valid check for the
// upper-to-lower-index mapping is necessary
// However, the check will be skipped if SkipIsValidCheck is set to true by the user
// LowerLengths: Sequence<...>
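// Illustrative example (hypothetical sizes): Pad<Sequence<4, 4>, Sequence<1, 1>, Sequence<1, 1>>
// has upper lengths {6, 6}; CalculateLowerIndex maps upper index {0, 0} to lower index {-1, -1},
// which lies in the padding area, so with non-zero pads the transform reports that a valid
// upper index is not always mapped to a valid lower index unless SkipIsValidCheck is set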
template <typename LowerLengths,
typename LeftPads,
typename RightPads,
bool SkipIsValidCheck = false>
struct Pad
{
static constexpr index_t nDim = LowerLengths::Size();
using LowerIndex = MultiIndex<nDim>;
using UpperIndex = MultiIndex<nDim>;
__host__ __device__ constexpr Pad()
{
static_assert(LowerLengths::GetSize() == nDim && LeftPads::GetSize() == nDim &&
RightPads::GetSize() == nDim,
"wrong! # of dimensions not consistent");
}
__host__ __device__ static constexpr auto GetNumOfLowerDimension() { return Number<nDim>{}; }
__host__ __device__ static constexpr auto GetNumOfUpperDimension() { return Number<nDim>{}; }
__host__ __device__ static constexpr auto GetUpperLengths()
{
return LowerLengths{} + LeftPads{} + RightPads{};
}
__host__ __device__ static constexpr auto CalculateLowerIndex(const UpperIndex& idx_up)
{
return idx_up - LeftPads{};
}
__host__ __device__ static constexpr auto
CalculateLowerIndexDiff(const UpperIndex& idx_up_diff,
const UpperIndex& /* idx_up_old */,
const LowerIndex& /* idx_low_old */)
{
return idx_up_diff;
}
__host__ __device__ static constexpr bool IsLinearTransform() { return true; }
__host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
{
// skip the valid check if the user requests it
if(SkipIsValidCheck)
{
return true;
}
bool flag = true;
for(index_t i = 0; i < nDim; ++i)
{
flag = flag && LeftPads::At(i) == 0 && RightPads::At(i) == 0;
}
return flag;
}
};
// LowerLengths: Sequence<...>
// SliceBegins: Sequence<...>
// SliceEnds: Sequence<...>
template <typename LowerLengths, typename SliceBegins, typename SliceEnds>
struct Slice
{
static constexpr index_t nDim = LowerLengths::Size();
using LowerIndex = MultiIndex<nDim>;
using UpperIndex = MultiIndex<nDim>;
__host__ __device__ constexpr Slice()
{
static_assert(LowerLengths::GetSize() == nDim && SliceBegins::GetSize() == nDim &&
SliceEnds::GetSize() == nDim,
"wrong! # of dimensions not consistent");
#if 0
// TODO: would not compile, error on constexpr
static_for<0, nDim, 1>{}([&](auto idim) {
static_assert(SliceBegins::At(idim) <= SliceEnds::At(idim) &&
SliceBegins::At(idim) >= 0 &&
SliceEnds::At(idim) <= LowerLengths::At(idim),
"wrong! Slice config is wrong");
});
#endif
}
__host__ __device__ static constexpr auto GetNumOfLowerDimension() { return Number<nDim>{}; }
__host__ __device__ static constexpr auto GetNumOfUpperDimension() { return Number<nDim>{}; }
__host__ __device__ static constexpr auto GetUpperLengths()
{
return SliceEnds{} - SliceBegins{};
}
__host__ __device__ static constexpr auto CalculateLowerIndex(const UpperIndex& idx_up)
{
return idx_up + SliceBegins{};
}
__host__ __device__ static constexpr auto
CalculateLowerIndexDiff(const UpperIndex& idx_up_diff,
const UpperIndex& /* idx_up_old */,
const LowerIndex& /* idx_low_old */)
{
return idx_up_diff;
}
__host__ __device__ static constexpr bool IsLinearTransform() { return true; }
__host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
{
return true;
}
};
// LowerLengths: Sequence<...>
template <typename LowerLengths>
struct Merge
{
static constexpr index_t nDimLow = LowerLengths::Size();
static constexpr index_t nDimUp = 1;
using LowerIndex = MultiIndex<nDimLow>;
using UpperIndex = MultiIndex<nDimUp>;
__host__ __device__ static constexpr auto GetNumOfLowerDimension() { return Number<nDimLow>{}; }
__host__ __device__ static constexpr auto GetNumOfUpperDimension() { return Number<nDimUp>{}; }
__host__ __device__ static constexpr auto GetUpperLengths()
{
return Sequence<reduce_on_sequence(
LowerLengths{}, math::multiplies<index_t>{}, Number<1>{})>{};
}
// emulate constexpr lambda
template <typename PseudoLowStrides>
struct lambda_CalculateLowerIndex
{
index_t& itmp;
LowerIndex& idx_low;
__host__ __device__ constexpr lambda_CalculateLowerIndex(index_t& itmp_,
LowerIndex& idx_low_)
: itmp(itmp_), idx_low(idx_low_)
{
}
template <typename IDim>
__host__ __device__ constexpr void operator()(IDim idim) const
{
constexpr index_t stride = PseudoLowStrides::At(idim);
idx_low(idim) = itmp / stride;
itmp -= idx_low[idim] * stride;
}
};
__host__ __device__ static constexpr auto CalculateLowerIndex(const UpperIndex& idx_up)
{
LowerIndex idx_low;
index_t itmp = idx_up[Number<0>{}];
constexpr auto pseudo_low_strides =
reverse_inclusive_scan_sequence(
LowerLengths::PopFront(), math::multiplies<index_t>{}, Number<1>{})
.PushBack(Number<1>{});
static_for<0, nDimLow - 1, 1>{}(
lambda_CalculateLowerIndex<decltype(pseudo_low_strides)>(itmp, idx_low));
idx_low(Number<nDimLow - 1>{}) = itmp / pseudo_low_strides[Number<nDimLow - 1>{}];
return idx_low;
}
// idx_low_diff depends on idx_low_old, so idx_low needs to be up-to-date
// If idx_up_diff is known at compile-time, many calculations can be optimized
// away by the compiler
// This function assumes idx_low_old is not out-of-bound
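// Worked example (illustrative): for Merge<Sequence<2, 3, 4>> the pseudo strides are {12, 4, 1},
// so upper index {13} maps to lower index {1, 0, 1}. On the diff path, starting from
// idx_low_old = {0, 2, 3} (upper index 11), an upper diff of +1 gives idx_low_diff_tmp = {0, 0, 1};
// the carry pass turns {0, 2, 4} into {1, 0, 0}, so the returned lower diff is {1, -2, -3}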
__host__ __device__ static constexpr auto
CalculateLowerIndexDiff(const UpperIndex& idx_up_diff,
const UpperIndex& /* idx_up_old */,
const LowerIndex& idx_low_old)
{
if(idx_up_diff[Number<0>{}] == 0)
{
return make_zero_multi_index<nDimLow>();
}
else
{
// CalculateLowerIndex(idx_up_diff) has multiple integer divisions.
// If idx_up_diff is known at compile-time, the calculation can
// be done at compile-time. However, if idx_up_diff is only known
// at run-time, then the calculation will also be computed at
// run-time, and can be very expensive.
LowerIndex idx_low_diff_tmp = CalculateLowerIndex(idx_up_diff);
// find out the last low dimension that changed
index_t last_changed_low_dim = 0;
static_for<0, nDimLow, 1>{}([&](auto i) {
if(idx_low_diff_tmp[i] != 0)
{
last_changed_low_dim = i;
}
});
LowerIndex idx_low_new = idx_low_old + idx_low_diff_tmp;
if(idx_up_diff[Number<0>{}] > 0)
{
// do carry check on each low dimension in reversed order
// starting from the first digit that changed
// don't check the highest dimension
bool carry = false;
static_for<nDimLow - 1, 0, -1>{}([&](auto i) {
if(i <= last_changed_low_dim)
{
if(carry)
{
++idx_low_new(i);
}
carry = false;
if(idx_low_new[i] >= LowerLengths::At(i))
{
idx_low_new(i) -= LowerLengths::At(i);
carry = true;
}
}
});
// highest dimension, no out-of-bound check
if(carry)
{
++idx_low_new(Number<0>{});
}
}
else
{
// do borrow check on each low dimension in reversed order
// starting from the first digit that changed
// don't check the highest dimension
bool borrow = false;
static_for<nDimLow - 1, 0, -1>{}([&](auto i) {
if(i <= last_changed_low_dim)
{
if(borrow)
{
--idx_low_new(i);
}
borrow = false;
if(idx_low_new[i] < 0)
{
idx_low_new(i) += LowerLengths::At(i);
borrow = true;
}
}
});
// highest dimension, no out-of-bound check
if(borrow)
{
--idx_low_new(Number<0>{});
}
}
return idx_low_new - idx_low_old;
}
}
__host__ __device__ static constexpr bool IsLinearTransform() { return false; }
__host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
{
return true;
}
};
// UpperLengths: Sequence<...>
template <typename UpperLengths>
struct UnMerge
{
static constexpr index_t nDimLow = 1;
static constexpr index_t nDimUp = UpperLengths::Size();
using LowerIndex = MultiIndex<nDimLow>;
using UpperIndex = MultiIndex<nDimUp>;
__host__ __device__ static constexpr auto GetNumOfLowerDimension() { return Number<nDimLow>{}; }
__host__ __device__ static constexpr auto GetNumOfUpperDimension() { return Number<nDimUp>{}; }
__host__ __device__ static constexpr auto GetUpperLengths() { return UpperLengths{}; }
__host__ __device__ static constexpr auto CalculateLowerIndex(const UpperIndex& idx_up)
{
LowerIndex idx_low = make_multi_index(0);
constexpr auto pseudo_up_strides =
reverse_inclusive_scan_sequence(
UpperLengths::PopFront(), math::multiplies<index_t>{}, Number<1>{})
.PushBack(Number<1>{});
static_for<0, nDimUp, 1>{}(
[&](auto idim) { idx_low(Number<0>{}) += idx_up[idim] * pseudo_up_strides[idim]; });
return idx_low;
}
__host__ __device__ static constexpr auto
CalculateLowerIndexDiff(const UpperIndex& idx_up_diff,
const UpperIndex& /* idx_up_old */,
const LowerIndex& /* idx_low_old */)
{
return CalculateLowerIndex(idx_up_diff);
}
__host__ __device__ static constexpr bool IsLinearTransform() { return true; }
__host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
{
return true;
}
};
// By default, this will automatically judge whether the is-valid check for the
// upper-to-lower-index mapping is necessary
// However, the check will be skipped if SkipIsValidCheck is set to true by the user
// UpperLengths: Sequence<...>
// Coefficients: Sequence<...>
// idx_low = coefficients[0, ...nDimUp-1] * idx_up[0, ...nDimUp-1] + coefficients[nDimUp]
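// Illustrative example: Embed<Hip, Sequence<Y, Ho>, Sequence<ConvDilationH, ConvStrideH, 0>>,
// as used by the convolution descriptors above, maps upper index (y, ho) to lower index
// y * ConvDilationH + ho * ConvStrideH, i.e. the row offset of a sliding-window element
// inside the padded input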
template <index_t LowerLength,
typename UpperLengths,
typename Coefficients,
bool SkipIsValidCheck = false>
struct Embed
{
static constexpr index_t nDimLow = 1;
static constexpr index_t nDimUp = UpperLengths::Size();
using LowerIndex = MultiIndex<nDimLow>;
using UpperIndex = MultiIndex<nDimUp>;
__host__ __device__ constexpr Embed()
{
static_assert(UpperLengths::GetSize() == nDimUp && Coefficients::GetSize() == nDimUp + 1,
"wrong! # of dimensions not consistent");
}
__host__ __device__ static constexpr auto GetNumOfUpperDimension() { return Number<nDimUp>{}; }
__host__ __device__ static constexpr auto GetNumOfLowerDimension() { return Number<nDimLow>{}; }
__host__ __device__ static constexpr auto GetUpperLengths() { return UpperLengths{}; }
__host__ __device__ static constexpr auto CalculateLowerIndex(const UpperIndex& idx_up)
{
LowerIndex idx_low = make_multi_index(Coefficients{}[Number<nDimUp>{}]);
static_for<0, nDimUp, 1>{}(
[&](auto i) { idx_low(Number<0>{}) += idx_up[i] * Coefficients{}[i]; });
return idx_low;
}
__host__ __device__ static constexpr auto
CalculateLowerIndexDiff(const UpperIndex& idx_up_diff,
const UpperIndex& /* idx_up_old */,
const LowerIndex& /* idx_low_old */)
{
LowerIndex idx_low_diff = make_multi_index(0);
static_for<0, nDimUp, 1>{}(
[&](auto i) { idx_low_diff(Number<0>{}) += idx_up_diff[i] * Coefficients{}[i]; });
return idx_low_diff;
}
__host__ __device__ static constexpr bool IsLinearTransform() { return true; }
__host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
{
// skip the valid check if the user requests it
if(SkipIsValidCheck)
{
return true;
}
bool flag = true;
index_t ncorner = 1;
for(index_t idim = 0; idim < nDimUp; ++idim)
{
ncorner *= 2;
}
// loop over each corner of the upper tensor
for(index_t icorner = 0; icorner < ncorner; ++icorner)
{
// generate upper index for each corner
auto idx_up = make_zero_multi_index<nDimUp>();
index_t itmp = icorner;
static_for<nDimUp, 0, -1>{}([&](auto idim) {
auto idim_m1 = idim - Number<1>{};
idx_up(idim_m1) = itmp % 2 == 0 ? 0 : UpperLengths::At(idim_m1) - 1;
itmp /= 2;
});
// calculate lower index
auto idx_low = CalculateLowerIndex(idx_up);
// judge if lower index is valid
flag = flag && idx_low[Number<0>{}] >= 0 && idx_low[Number<0>{}] < LowerLength;
}
return flag;
}
};
// LowerLengths: Sequence<...>
// LowerFreezePoint: Sequence<...>
template <typename LowerLengths, typename LowerFreezePoint>
struct Freeze
{
static constexpr index_t nDimLow = LowerLengths::Size();
static constexpr index_t nDimUp = 0;
using LowerIndex = MultiIndex<nDimLow>;
using UpperIndex = MultiIndex<nDimUp>;
__host__ __device__ constexpr Freeze()
{
// TODO: sanity check: LowerFreezePoint should be within range of LowerLengths
}
__host__ __device__ static constexpr auto GetNumOfLowerDimension() { return Number<nDimLow>{}; }
__host__ __device__ static constexpr auto GetNumOfUpperDimension() { return Number<0>{}; }
__host__ __device__ static constexpr auto GetUpperLengths() { return Sequence<>{}; }
__host__ __device__ static constexpr auto CalculateLowerIndex(const UpperIndex& /*idx_up*/)
{
return to_multi_index(LowerFreezePoint{});
}
__host__ __device__ static constexpr auto
CalculateLowerIndexDiff(const UpperIndex& /* idx_up_diff */,
const UpperIndex& /* idx_up_old */,
const LowerIndex& /* idx_low_old */)
{
return make_zero_multi_index<nDimLow>();
}
__host__ __device__ static constexpr bool IsLinearTransform() { return true; }
__host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
{
return true;
}
};
} // namespace ck
#endif
#ifndef CK_PRINT_TENSOR_DESCRIPTOR_HPP
#define CK_PRINT_TENSOR_DESCRIPTOR_HPP
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
namespace ck {
template <typename... NativeDimensions>
__host__ __device__ void
print_tensor_descriptor(const char* s, const NativeTensorDescriptor<NativeDimensions...>& desc)
{
print_tensor_descriptor_impl(s, desc.GetLengths(), desc.GetStrides());
}
template <typename... Ts>
__host__ __device__ void print_tensor_descriptor(const char* s,
const TransformedTensorDescriptor<Ts...>& desc)
{
print_tensor_descriptor_impl(s, desc.GetLengths());
}
template <index_t... Lengths, index_t... Strides>
__host__ __device__ void
print_tensor_descriptor_impl(const char* s, Sequence<Lengths...>, Sequence<Strides...>)
{
constexpr index_t nDim = sizeof...(Lengths);
static_assert(nDim > 0 && nDim <= 12, "wrong!");
static_if<nDim == 1>{}([&](auto) {
printf("%s dim %u, lengths {%u}, strides {%u}\n", s, nDim, Lengths..., Strides...);
});
static_if<nDim == 2>{}([&](auto) {
printf("%s dim %u, lengths {%u %u}, strides {%u %u}\n", s, nDim, Lengths..., Strides...);
});
static_if<nDim == 3>{}([&](auto) {
printf(
"%s dim %u, lengths {%u %u %u}, strides {%u %u %u}\n", s, nDim, Lengths..., Strides...);
});
static_if<nDim == 4>{}([&](auto) {
printf("%s dim %u, lengths {%u %u %u %u}, strides {%u %u %u %u}\n",
s,
nDim,
Lengths...,
Strides...);
});
static_if<nDim == 5>{}([&](auto) {
printf("%s dim %u, lengths {%u %u %u %u %u}, strides {%u %u %u %u %u}\n",
s,
nDim,
Lengths...,
Strides...);
});
static_if<nDim == 6>{}([&](auto) {
printf("%s dim %u, lengths {%u %u %u %u %u %u}, strides {%u %u %u %u %u %u}\n",
s,
nDim,
Lengths...,
Strides...);
});
static_if<nDim == 7>{}([&](auto) {
printf("%s dim %u, lengths {%u %u %u %u %u %u %u}, strides {%u %u %u %u %u %u %u}\n",
s,
nDim,
Lengths...,
Strides...);
});
static_if<nDim == 8>{}([&](auto) {
printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u}, strides {%u %u %u %u %u %u %u %u}\n",
s,
nDim,
Lengths...,
Strides...);
});
static_if<nDim == 9>{}([&](auto) {
printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u %u}, strides {%u %u %u %u %u %u %u %u "
"%u}\n",
s,
nDim,
Lengths...,
Strides...);
});
static_if<nDim == 10>{}([&](auto) {
printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u %u %u}, strides {%u %u %u %u %u %u %u "
"%u %u %u}\n",
s,
nDim,
Lengths...,
Strides...);
});
static_if<nDim == 11>{}([&](auto) {
printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u %u %u %u}, strides {%u %u %u %u %u %u "
"%u %u "
"%u %u %u}\n",
s,
nDim,
Lengths...,
Strides...);
});
static_if<nDim == 12>{}([&](auto) {
printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u %u %u %u %u}, strides {%u %u %u %u %u "
"%u %u %u %u "
"%u %u %u}\n",
s,
nDim,
Lengths...,
Strides...);
});
}
template <index_t... Lengths>
__host__ __device__ void print_tensor_descriptor_impl(const char* s, Sequence<Lengths...>)
{
constexpr index_t nDim = sizeof...(Lengths);
static_assert(nDim > 0 && nDim <= 12, "wrong!");
static_if<nDim == 1>{}([&](auto) { printf("%s dim %u, lengths {%u}\n", s, nDim, Lengths...); });
static_if<nDim == 2>{}(
[&](auto) { printf("%s dim %u, lengths {%u %u}\n", s, nDim, Lengths...); });
static_if<nDim == 3>{}(
[&](auto) { printf("%s dim %u, lengths {%u %u %u}\n", s, nDim, Lengths...); });
static_if<nDim == 4>{}(
[&](auto) { printf("%s dim %u, lengths {%u %u %u %u}\n", s, nDim, Lengths...); });
static_if<nDim == 5>{}(
[&](auto) { printf("%s dim %u, lengths {%u %u %u %u %u}\n", s, nDim, Lengths...); });
static_if<nDim == 6>{}(
[&](auto) { printf("%s dim %u, lengths {%u %u %u %u %u %u}, \n", s, nDim, Lengths...); });
static_if<nDim == 7>{}(
[&](auto) { printf("%s dim %u, lengths {%u %u %u %u %u %u %u}\n", s, nDim, Lengths...); });
static_if<nDim == 8>{}([&](auto) {
printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u}\n", s, nDim, Lengths...);
});
static_if<nDim == 9>{}([&](auto) {
printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u %u}\n", s, nDim, Lengths...);
});
static_if<nDim == 10>{}([&](auto) {
printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u %u %u}\n", s, nDim, Lengths...);
});
static_if<nDim == 11>{}([&](auto) {
printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u %u %u %u}\n", s, nDim, Lengths...);
});
static_if<nDim == 12>{}([&](auto) {
printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u %u %u %u %u}\n", s, nDim, Lengths...);
});
}
} // namespace ck
#endif
#ifndef CK_TENSOR_COORDINATE_HPP
#define CK_TENSOR_COORDINATE_HPP
#include "common_header.hpp"
#include "dimension.hpp"
#include "multi_index_transform.hpp"
#include "tensor_descriptor.hpp"
namespace ck {
// A "tensor cooridnate" is an opaque object that represents a "point of location" inside a tensor
// At the bare minimun, user should be able to query the following information from a tensor
// coordinate:
// 1. Tensor descriptor
// 2. Location, represented in the form of multi-index
// 3. Location, represented in the form of the offset to the origin of the tensor
// 4. If the location is inside invalid area or not, i.e. the padding area of an implicitly padded
// tensor is considered invalid, because the padding area doesn't have any physical memory
// allocation
// A tensor cooridnate also provides following functionality:
// 1. Given step size in each dimension, update itself, or return a new tensor cooridnate, so user
// can freely move the "point of location" inside the tensor
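// Typical usage sketch (illustrative): construct a coordinate from a multi-index, move it with
// operator+= and a step multi-index, read GetOffset() to address memory, and, for tensors with
// padding, check IsOffsetValidAssumingUpperIndexIsValid() before the actual load/store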
// wrapper class for NativeTensorCoordinate and TransformedTensorCoordinate
template <typename TensorDesc>
struct TensorCoordinate;
// tensor coordinate for native tensor
template <typename NativeTensorDesc>
struct NativeTensorCoordinate
{
using type = NativeTensorCoordinate;
using tensor_desc_type = NativeTensorDesc;
static constexpr index_t nDim = tensor_desc_type::GetNumOfDimension();
using Index = MultiIndex<nDim>;
__host__ __device__ constexpr NativeTensorCoordinate(Index idx)
: mIndex(idx), mOffset(tensor_desc_type::CalculateOffset(idx))
{
}
template <typename... Xs>
__host__ __device__ constexpr NativeTensorCoordinate(Xs... xs)
: NativeTensorCoordinate(make_multi_index(xs...))
{
}
template <index_t... Xs>
__host__ __device__ constexpr NativeTensorCoordinate(Sequence<Xs...>)
: NativeTensorCoordinate(make_multi_index(Xs...))
{
}
__host__ __device__ static constexpr auto GetTensorDescriptor() { return tensor_desc_type{}; }
__host__ __device__ constexpr const Index& GetUpperIndex() const { return mIndex; }
__host__ __device__ constexpr const Index& GetIndex() const { return mIndex; }
__host__ __device__ constexpr const index_t& GetOffset() const { return mOffset; }
__host__ __device__ constexpr type operator+=(const Index& idx_diff)
{
// mIndex is updated here, but some (or all) of its entries may never be used
// compiler should remove those entries as dead code
mIndex += idx_diff;
mOffset += tensor_desc_type::CalculateOffsetDiff(idx_diff);
return *this;
}
__host__ __device__ constexpr type operator-=(const Index& idx_diff)
{
// mIndex is updated here, but some (or all) of its entries may never be used
// compiler should remove those entries as dead code
mIndex -= idx_diff;
mOffset -= tensor_desc_type::CalculateOffsetDiff(idx_diff);
return *this;
}
__host__ __device__ constexpr type operator+(const Index& idx_diff) const
{
type coord = *this;
coord += idx_diff;
return coord;
}
__host__ __device__ constexpr type operator-(const Index& idx_diff) const
{
type coord = *this;
coord -= idx_diff;
return coord;
}
__host__ __device__ static constexpr index_t CalculateOffsetDiff(const Index& idx_diff)
{
return tensor_desc_type::CalculateOffsetDiff(idx_diff);
}
// evaluated at run-time
__host__ __device__ constexpr bool IsUpperIndexValid() const
{
return tensor_desc_type::IsUpperIndexValid(GetUpperIndex());
}
// evaluated at run-time
__host__ __device__ constexpr bool IsOffsetValid() const
{
// For native tensor, offset is valid if upper-index is valid
return IsUpperIndexValid();
}
// evaluated at compile-time
__host__ __device__ static constexpr bool IsOffsetValidAssumingUpperIndexIsValid()
{
return true;
}
private:
// mIndex may be saved and updated, however, the value of some (or all) of its entries may
// never be used. The compiler should be able to remove these entries, as well as their
// calculation, as dead code.
// TODO: make sure the compiler indeed removes this dead code
Index mIndex;
index_t mOffset;
};
// tensor coordinate for transformed tensor
template <typename TransformedTensorDesc>
struct TransformedTensorCoordinate
{
using tensor_desc_type = TransformedTensorDesc;
using LowerCoord =
typename TensorCoordinate<decltype(tensor_desc_type::GetLowerTensorDescriptor())>::type;
using UpperCoord = TransformedTensorCoordinate;
static constexpr index_t nDim = tensor_desc_type::GetNumOfDimension();
using UpperIndex = MultiIndex<nDim>;
__host__ __device__ constexpr TransformedTensorCoordinate(UpperIndex idx)
: mIndexUp{idx}, mCoordLow{tensor_desc_type::CalculateLowerIndex(idx)}
{
}
template <typename... Xs>
__host__ __device__ constexpr TransformedTensorCoordinate(Xs... xs)
: TransformedTensorCoordinate(UpperIndex{xs...})
{
}
template <index_t... Xs>
__host__ __device__ constexpr TransformedTensorCoordinate(Sequence<Xs...>)
: TransformedTensorCoordinate(UpperIndex{Xs...})
{
}
__host__ __device__ static constexpr auto GetTensorDescriptor() { return tensor_desc_type{}; }
__host__ __device__ constexpr const LowerCoord& GetLowerCoordinate() const { return mCoordLow; }
__host__ __device__ constexpr const UpperIndex& GetUpperIndex() const { return mIndexUp; }
__host__ __device__ constexpr const UpperIndex& GetIndex() const { return GetUpperIndex(); }
__host__ __device__ constexpr const index_t& GetOffset() const
{
return GetLowerCoordinate().GetOffset();
}
__host__ __device__ constexpr UpperCoord operator+=(const UpperIndex& idx_up_diff)
{
// For transformation of multi-index difference, not all transformation functions need to
// know the old lower-index or the old upper-index. We pass both of them to the
// transformation function. The transformation function itself decides whether to use them or not.
mCoordLow += tensor_desc_type::CalculateLowerIndexDiff(
idx_up_diff, GetIndex(), GetLowerCoordinate().GetIndex());
// mIndexUp is updated here, but some (or all) of its entries may never be used
// compiler should remove those entries as dead code
mIndexUp += idx_up_diff;
return *this;
}
__host__ __device__ constexpr UpperCoord operator-=(const UpperIndex& idx_up_diff)
{
mCoordLow -= tensor_desc_type::CalculateLowerIndexDiff(
idx_up_diff, GetIndex(), GetLowerCoordinate().GetIndex());
// mIndexUp is updated here, but some (or all) of its entries may never be used
// compiler should remove those entries as dead code
mIndexUp -= idx_up_diff;
return *this;
}
__host__ __device__ constexpr UpperCoord operator+(const UpperIndex& idx_up_diff) const
{
UpperCoord coord_up = *this;
coord_up += idx_up_diff;
return coord_up;
}
__host__ __device__ constexpr UpperCoord operator-(const UpperIndex& idx_up_diff) const
{
UpperCoord coord_up = *this;
coord_up -= idx_up_diff;
return coord_up;
}
// Calculate offset diff without updating the tensor-coordinate.
// If idx_up_diff is known at compile time and has non-zero entries only on linear dimensions,
// then all calculation can be done at compile-time.
// TODO: this function is not compiled to expected ISA
__host__ __device__ constexpr index_t CalculateOffsetDiff(const UpperIndex& idx_up_diff) const
{
// For transformation of a multi-index difference, not all transformation functions need to
// know the old lower-index or the old upper-index. We pass both of them to the
// transformation function, which decides whether to use them.
const auto idx_low_diff = tensor_desc_type::CalculateLowerIndexDiff(
idx_up_diff, GetIndex(), GetLowerCoordinate().GetIndex());
return GetLowerCoordinate().CalculateOffsetDiff(idx_low_diff);
}
// evaluated at run-time
__host__ __device__ constexpr bool IsUpperIndexValid() const
{
return tensor_desc_type::IsUpperIndexValid(GetUpperIndex());
}
// evaluated at run-time
__host__ __device__ constexpr bool IsOffsetValid() const
{
return IsUpperIndexValid() && GetLowerCoordinate().IsOffsetValid();
}
// most evaluation is done at compile-time
__host__ __device__ constexpr bool IsLowerIndexValidAssumingUpperIndexIsValid() const
{
return tensor_desc_type::IsLowerIndexValidAssumingUpperIndexIsValid(
GetLowerCoordinate().GetIndex());
}
// most evaluation is done at compile-time
__host__ __device__ constexpr bool IsOffsetValidAssumingUpperIndexIsValid() const
{
return IsLowerIndexValidAssumingUpperIndexIsValid() &&
GetLowerCoordinate().IsOffsetValidAssumingUpperIndexIsValid();
}
private:
// mIndexUp may be calculated and updated, however, the values of some (or all) of its entries
// may never be used. The compiler should be able to remove these entries, as well as their
// calculation, as dead code.
// TODO: make sure the compiler indeed removes this dead code
UpperIndex mIndexUp;
LowerCoord mCoordLow;
};
template <typename TensorDesc>
struct TensorCoordinate
{
private:
template <typename... Ts>
__host__ __device__ static constexpr auto
MakeDummyTensorCoordinate(NativeTensorDescriptor<Ts...>)
{
return NativeTensorCoordinate<NativeTensorDescriptor<Ts...>>(
make_zero_multi_index<TensorDesc::GetNumOfDimension()>());
}
template <typename... Ts>
__host__ __device__ static constexpr auto
MakeDummyTensorCoordinate(TransformedTensorDescriptor<Ts...>)
{
return TransformedTensorCoordinate<TransformedTensorDescriptor<Ts...>>(
make_zero_multi_index<TensorDesc::GetNumOfDimension()>());
}
public:
using type = decltype(MakeDummyTensorCoordinate(TensorDesc{}));
};
} // namespace ck
#endif
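// A minimal usage sketch of the coordinate types above (illustration only, not part of the
// original source; it assumes make_native_tensor_descriptor_packed from the tensor-descriptor
// helper header further below). For a packed 4x8 native descriptor the strides are {8, 1}, so
// moving a coordinate by a multi-index difference updates the linear offset incrementally
// instead of recomputing the full dot product.
#if 0
constexpr auto desc = ck::make_native_tensor_descriptor_packed(ck::Sequence<4, 8>{});
using Coord         = typename ck::TensorCoordinate<decltype(desc)>::type;
auto coord = Coord{ck::make_multi_index(1, 2)}; // offset = 1 * 8 + 2 * 1 = 10
coord += ck::make_multi_index(0, 3);            // offset becomes 13
const ck::index_t offset = coord.GetOffset();
#endif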
#ifndef CK_TENSOR_DESCRIPTOR_HPP
#define CK_TENSOR_DESCRIPTOR_HPP
#include "common_header.hpp"
#include "dimension.hpp"
#include "multi_index_transform.hpp"
namespace ck {
// tensor descriptor for "native tensor"
// A "native tensor" is a "true" tensor that can be represented by Lengths and Strides
template <typename... NativeDimensions>
struct NativeTensorDescriptor
{
using type = NativeTensorDescriptor;
static constexpr index_t nDim = sizeof...(NativeDimensions);
static constexpr auto mDimensions = make_tuple(NativeDimensions{}...);
using Index = MultiIndex<nDim>;
__host__ __device__ static constexpr auto GetNumOfDimension() { return Number<nDim>{}; }
template <index_t IDim>
__host__ __device__ static constexpr auto GetLength(Number<IDim>)
{
return mDimensions.At(Number<IDim>{}).GetLength();
}
template <index_t IDim>
__host__ __device__ static constexpr auto GetStride(Number<IDim>)
{
return mDimensions.At(Number<IDim>{}).GetStride();
}
template <index_t... IDims>
__host__ __device__ static constexpr auto GetLengths(Sequence<IDims...>)
{
return Sequence<GetLength(Number<IDims>{})...>{};
}
template <index_t... IDims>
__host__ __device__ static constexpr auto GetStrides(Sequence<IDims...>)
{
return Sequence<GetStride(Number<IDims>{})...>{};
}
template <index_t IDim, index_t... IDims>
__host__ __device__ static constexpr auto GetLengths(Number<IDim>, Number<IDims>...)
{
return GetLengths(Sequence<IDim, IDims...>{});
}
template <index_t IDim, index_t... IDims>
__host__ __device__ static constexpr auto GetStrides(Number<IDim>, Number<IDims>...)
{
return GetStrides(Sequence<IDim, IDims...>{});
}
__host__ __device__ static constexpr auto GetLengths()
{
return GetLengths(typename arithmetic_sequence_gen<0, nDim, 1>::type{});
}
__host__ __device__ static constexpr auto GetStrides()
{
return GetStrides(typename arithmetic_sequence_gen<0, nDim, 1>::type{});
}
__host__ __device__ static constexpr index_t GetElementSize()
{
return reduce_on_sequence(GetLengths(), math::multiplies<index_t>{}, Number<1>{});
}
__host__ __device__ static constexpr index_t GetElementSpace()
{
return reduce_on_sequence(
(GetLengths() - Number<1>{}) * GetStrides(), math::plus<index_t>{}, Number<1>{});
}
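// Example (illustration only, not in the original source): for lengths {4, 8} and strides
// {10, 1}, GetElementSize() = 4 * 8 = 32 data elements, while GetElementSpace() =
// (4 - 1) * 10 + (8 - 1) * 1 + 1 = 38 storage elements, since each row is padded by 2.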
// TODO: this cannot return constexpr because of the use of a lambda
__host__ __device__ static constexpr index_t CalculateOffset(const Index& idx)
{
index_t offset = 0;
static_for<0, nDim, 1>{}([&](auto idim) { offset += idx[idim] * GetStride(idim); });
return offset;
}
__host__ __device__ static constexpr index_t CalculateOffsetDiff(const Index& idx_diff)
{
index_t offset_diff = 0;
static_for<0, nDim, 1>{}(
[&](auto idim) { offset_diff += idx_diff[idim] * GetStride(idim); });
return offset_diff;
}
template <index_t IDim>
__host__ __device__ static constexpr bool IsLinearDimension(Number<IDim>)
{
return true;
}
__host__ __device__ static constexpr auto GetLinearDimensionMask()
{
return typename uniform_sequence_gen<nDim, 1>::type{};
}
__host__ __device__ static constexpr auto GetNonLinearDimensionMask()
{
return typename uniform_sequence_gen<nDim, 0>::type{};
}
__host__ __device__ static constexpr auto GetNonLinearDimensions() { return Sequence<>{}; }
__host__ __device__ static constexpr auto GetNonLinearIndependentDimensionGroups()
{
return Tuple<>{};
}
// a multi-index is valid if there is a corresponding point for it in the tensor
__host__ __device__ static constexpr bool IsUpperIndexValid(const Index& idx)
{
bool flag = true;
for(index_t i = 0; i < nDim; ++i)
{
flag = flag && idx[i] >= 0 && idx[i] < GetLengths()[i];
}
return flag;
}
};
// Tensor descriptor for "transformed tensor"
template <typename LowTensorDescriptor, // NativeTensorDescriptor or TransformedTensorDescriptor
typename Transforms, // Tuple<MultIndexTransforms...>
typename LowDimensionIds, // Tuple<Sequence<...>>
typename UpDimensionIds> // Tuple<Sequence<...>>
struct TransformedTensorDescriptor
{
using type = TransformedTensorDescriptor;
static constexpr index_t nTransform = Transforms::Size();
struct lambda_merge_sequences
{
template <typename... Seqs>
__host__ __device__ constexpr auto operator()(Seqs... seqs) const
{
return merge_sequences(seqs...);
}
};
__host__ __device__ static constexpr auto GetNumOfLowerDimension()
{
// Here, we assume all lower-dimensions are active
// TODO: sanity-check that all lower-dimensions are indeed active
using duplicated_low_active_dims =
decltype(unpack(lambda_merge_sequences{}, LowDimensionIds{}));
using low_active_dims = typename sequence_unique_sort<duplicated_low_active_dims,
math::less<index_t>,
math::equal<index_t>>::type;
return low_active_dims::Size();
}
__host__ __device__ static constexpr auto GetNumOfUpperDimension()
{
using duplicated_up_active_dims =
decltype(unpack(lambda_merge_sequences{}, UpDimensionIds{}));
using up_active_dims = typename sequence_unique_sort<duplicated_up_active_dims,
math::less<index_t>,
math::equal<index_t>>::type;
return up_active_dims::Size();
}
static constexpr index_t nDimUp = GetNumOfUpperDimension();
static constexpr index_t nDimLow = GetNumOfLowerDimension();
using UpperIndex = MultiIndex<nDimUp>;
using LowerIndex = MultiIndex<nDimLow>;
__host__ __device__ constexpr TransformedTensorDescriptor()
{
static_assert(nTransform == Transforms::Size() && nTransform == LowDimensionIds::Size() &&
nTransform == UpDimensionIds::Size(),
"wrong! # of transformations not the same");
// sanity check:
// LowDimensionIds should include all low-dimensions,
// UpDimensionIds should include all up-dimensions
using mingled_up_dimension_ids =
decltype(unpack(lambda_merge_sequences{}, UpDimensionIds{}));
using sorted_up_dimension_ids =
typename sequence_sort<mingled_up_dimension_ids, math::less<index_t>>::type;
static_assert(sorted_up_dimension_ids::Size() == nDimUp &&
is_valid_sequence_map<sorted_up_dimension_ids>{},
"wrong! UpDimensionIds is not configured correctly");
using mingled_low_dimension_ids =
decltype(unpack(lambda_merge_sequences{}, LowDimensionIds{}));
using sorted_low_dimension_ids =
typename sequence_sort<mingled_low_dimension_ids, math::less<index_t>>::type;
static_assert(sorted_low_dimension_ids::Size() == nDimLow &&
is_valid_sequence_map<sorted_low_dimension_ids>{},
"wrong! LowDimensionIds is not configured correctly");
// TODO: sanity check: while an up-dimension can be associated with multiple
// transformations, a low-dimension should be associated with only one transformation
// TODO: sanity-check: GetLowerLengths of each transform should be consistent with lengths
// of lower-tensor-descriptor
}
__host__ __device__ static constexpr auto GetNumOfDimension()
{
return GetNumOfUpperDimension();
}
__host__ __device__ static constexpr auto GetLowerTensorDescriptor()
{
return LowTensorDescriptor{};
}
struct lambda_GetUpperLengths
{
template <typename Transform>
__host__ __device__ constexpr auto operator()(const Transform& tran) const
{
return tran.GetUpperLengths();
}
};
__host__ __device__ static constexpr auto GetUpperLengths()
{
constexpr auto tuple_of_up_lengths =
transform_tuples(lambda_GetUpperLengths{}, Transforms{});
constexpr auto mingled_up_lengths = unpack(lambda_merge_sequences{}, tuple_of_up_lengths);
constexpr auto mingled_up_dimension_ids =
unpack(lambda_merge_sequences{}, UpDimensionIds{});
// TODO: sanity-check that mingled_up_dimension_ids contains all upper-dimensions
// TODO: sanity-check that mingled_up_lengths have no conflicting upper-lengths
// sort by upper-dimension-ids
using sort_up_dimension_ids = sequence_unique_sort<decltype(mingled_up_dimension_ids),
math::less<index_t>,
math::equal<index_t>>;
// sanity-check sorted-upper-dimension-ids should be Sequence<0, 1, ... nDimUp-1>
static_assert(is_same<typename sort_up_dimension_ids::type,
typename arithmetic_sequence_gen<0, nDimUp, 1>::type>{},
"wrong! UpDimensionIds is not configured correctly");
constexpr auto sorted2unsorted_map = typename sort_up_dimension_ids::sorted2unsorted_map{};
constexpr auto sorted_up_lengths =
pick_sequence_elements_by_ids(mingled_up_lengths, sorted2unsorted_map);
return sorted_up_lengths;
}
__host__ __device__ static constexpr auto GetLengths() { return GetUpperLengths(); }
template <index_t IDim>
__host__ __device__ static constexpr auto GetLength(Number<IDim>)
{
return GetLengths()[IDim];
}
template <index_t... IDims>
__host__ __device__ static constexpr auto GetLengths(Sequence<IDims...>)
{
return Sequence<GetLength(Number<IDims>{})...>{};
}
template <index_t IDim, index_t... IDims>
__host__ __device__ static constexpr auto GetLengths(Number<IDim>, Number<IDims>...)
{
return GetLengths(Sequence<IDim, IDims...>{});
}
__host__ __device__ static constexpr index_t GetElementSize()
{
return reduce_on_sequence(GetLengths(), math::multiplies<index_t>{}, Number<1>{});
}
__host__ __device__ static constexpr index_t GetElementSpace()
{
// TODO: Is this the correct definition for transformed tensor?
return GetLowerTensorDescriptor().GetElementSpace();
}
// TODO: right now the return value is not constexpr because of the use of a non-constexpr lambda
__host__ __device__ static constexpr LowerIndex CalculateLowerIndex(const UpperIndex& idx_up)
{
LowerIndex idx_low;
static_for<0, nTransform, 1>{}([&](auto itran) {
constexpr auto tran = Transforms{}.At(itran);
const auto idx_up_part = pick_container_element(idx_up, UpDimensionIds{}.At(itran));
auto idx_low_part = pick_container_element(idx_low, LowDimensionIds{}.At(itran));
// this assumes each lower (single) index is associated with only one transformation,
// which is required for index transformation, and has been checked in the constructor
// of TransformedTensorDescriptor
idx_low_part = tran.CalculateLowerIndex(to_multi_index(idx_up_part));
});
return idx_low;
}
// TODO: right now the return value is not constexpr because of the use of a non-constexpr lambda
__host__ __device__ static constexpr LowerIndex CalculateLowerIndexDiff(
const UpperIndex& idx_up_diff, const UpperIndex& idx_up_old, const LowerIndex& idx_low_old)
{
LowerIndex idx_low_diff;
static_for<0, nTransform, 1>{}([&](auto itran) {
constexpr auto tran = Transforms{}.At(itran);
const auto idx_up_diff_part =
pick_container_element(idx_up_diff, UpDimensionIds{}.At(itran));
const auto idx_up_old_part =
pick_container_element(idx_up_old, UpDimensionIds{}.At(itran));
const auto idx_low_old_part =
pick_container_element(idx_low_old, LowDimensionIds{}.At(itran));
auto idx_low_diff_part =
pick_container_element(idx_low_diff, LowDimensionIds{}.At(itran));
// this assumes each lower (single) index is associated with only one transformation,
// which is required for index transformation, and has been checked in the constructor
// of TransformedTensorDescriptor
idx_low_diff_part = tran.CalculateLowerIndexDiff(to_multi_index(idx_up_diff_part),
to_multi_index(idx_up_old_part),
to_multi_index(idx_low_old_part));
});
return idx_low_diff;
}
__host__ __device__ static constexpr index_t CalculateOffset(const UpperIndex& idx_up)
{
return GetLowerTensorDescriptor().CalculateOffset(CalculateLowerIndex(idx_up));
}
struct lambda_sequence_logical_and
{
template <typename... Seqs>
__host__ __device__ constexpr auto operator()(Seqs...) const
{
return typename sequence_reduce<logical_and<index_t>, Seqs...>::type{};
}
};
template <typename T>
struct lambda_is_true
{
__host__ __device__ constexpr auto operator()(const T& x) const
{
// TODO: remove static_cast once Sequence can take bool as entries
return static_cast<bool>(x) == true;
}
};
struct lambda_get_linear_dimension_mask_of_single_tranform
{
// check only one transform at a time
template <typename Transform, typename LowDimensionId, typename UpDimensionId>
__host__ __device__ constexpr auto
operator()(Transform, LowDimensionId, UpDimensionId) const
{
// judge whether the transformation is linear
constexpr bool is_linear_transform = Transform::IsLinearTransform();
// judge whether all lower dimensions are linear
constexpr bool are_all_low_dim_linear = sequence_all_of(
pick_sequence_elements_by_ids(GetLowerTensorDescriptor().GetLinearDimensionMask(),
LowDimensionId{}),
lambda_is_true<index_t>{});
// create linear mask for upper dimensions
constexpr bool are_up_dim_linear = is_linear_transform && are_all_low_dim_linear;
constexpr auto mask_of_up_linear_dims = modify_sequence_elements_by_ids(
typename uniform_sequence_gen<nDimUp, 1>::type{},
typename uniform_sequence_gen<UpDimensionId::Size(), are_up_dim_linear>::type{},
UpDimensionId{});
return mask_of_up_linear_dims;
}
};
// TODO: this is a hack; transform_tuples() doesn't compile here, it complains about constexpr
template <typename F, typename X, typename Y, typename Z, index_t... Is>
__host__ __device__ static constexpr auto
dummy_transform_tuples_impl(F f, X x, Y y, Z z, Sequence<Is...>)
{
return make_tuple(f(x.At(Number<Is>{}), y.At(Number<Is>{}), z.At(Number<Is>{}))...);
}
__host__ __device__ static constexpr auto GetLinearDimensionMask()
{
#if 0
// create tuple of linear dimension masks, for all transformations
// TODO: this doesn't compile, because transform_tuples() complain about constexpr
constexpr auto tuple_of_linear_dimension_mask =
transform_tuples(lambda_get_linear_dimension_mask_of_single_tranform{},
Transforms{},
LowDimensionIds{},
UpDimensionIds{});
#else
// create tuple of linear dimension masks, for all transformations
// TODO: this is a hack
constexpr auto tuple_of_linear_dimension_mask = dummy_transform_tuples_impl(
lambda_get_linear_dimension_mask_of_single_tranform{},
Transforms{},
LowDimensionIds{},
UpDimensionIds{},
typename arithmetic_sequence_gen<0, Transforms::Size(), 1>::type{});
#endif
// reduce tuple of masks into one mask
constexpr auto linear_dimension_mask =
unpack(lambda_sequence_logical_and{}, tuple_of_linear_dimension_mask);
return linear_dimension_mask;
}
__host__ __device__ static constexpr auto GetNonLinearDimensionMask()
{
return GetLinearDimensionMask().Transform(logical_not<index_t>{});
}
template <index_t IDim>
__host__ __device__ static constexpr bool IsLinearDimension(Number<IDim>)
{
return GetLinearDimensionMask().At(Number<IDim>{});
}
__host__ __device__ static constexpr auto GetLinearDimensions()
{
constexpr auto linear_dimension_mask = GetLinearDimensionMask();
return pick_sequence_elements_by_mask(
typename arithmetic_sequence_gen<0, nDimUp, 1>::type{}, linear_dimension_mask);
}
__host__ __device__ static constexpr auto GetNonLinearDimensions()
{
constexpr auto nonlinear_dimension_mask = GetNonLinearDimensionMask();
return pick_sequence_elements_by_mask(
typename arithmetic_sequence_gen<0, nDimUp, 1>::type{}, nonlinear_dimension_mask);
}
#if 0
__host__ __device__ static constexpr auto GetNonLinearIndependentDimensionGroups()
{
// TODO: not implemented
}
#endif
// a multi-index is valid if there is a corresponding point for it in the tensor
__host__ __device__ constexpr bool IsUpperIndexValid(const UpperIndex& idx_up) const
{
bool flag = true;
for(index_t i = 0; i < nDimUp; ++i)
{
flag = flag && idx_up[i] >= 0 && idx_up[i] < GetLengths()[i];
}
return flag;
}
// this function is for optimization purposes; it's called by the tensor coordinate
// it tells you whether a lower-index is valid, assuming the upper-index is valid
__host__ __device__ static constexpr bool
IsLowerIndexValidAssumingUpperIndexIsValid(const LowerIndex& idx_low)
{
bool flag = true;
static_for<0, nTransform, 1>{}([&](auto itran) {
constexpr auto tran = Transforms{}.At(itran);
// only check a transformation if it does not always map a valid upper-index to a valid lower-index
constexpr bool is_valid_up_always_mapped_to_valid_low =
decltype(tran)::IsValidUpperIndexAlwaysMappedToValidLowerIndex();
if(!is_valid_up_always_mapped_to_valid_low)
{
constexpr auto low_dims_part = LowDimensionIds{}.At(itran);
constexpr auto low_lengths_part =
GetLowerTensorDescriptor().GetLengths(low_dims_part);
const auto idx_low_part =
to_multi_index(pick_container_element(idx_low, low_dims_part));
static_for<0, decltype(low_dims_part)::Size(), 1>{}([&](auto i) {
flag = flag && idx_low_part[i] >= 0 && idx_low_part[i] < low_lengths_part[i];
});
}
});
return flag;
}
};
} // namespace ck
#endif
#ifndef CK_TENSOR_DESCRIPTOR_HELPER_HPP
#define CK_TENSOR_DESCRIPTOR_HELPER_HPP
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
namespace ck {
template <typename Lengths>
__host__ __device__ constexpr auto calculate_tensor_strides_packed(Lengths)
{
return reverse_inclusive_scan_sequence(
Lengths{}.PopFront(), math::multiplies<index_t>{}, Number<1>{})
.PushBack(Number<1>{});
}
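// Example (illustration only, not in the original source): for Lengths = Sequence<2, 3, 4>,
// the reverse inclusive scan of the popped-front lengths {3, 4} under multiplication gives
// {12, 4}; pushing back 1 yields the packed strides Sequence<12, 4, 1>.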
template <typename Lengths, index_t Align>
__host__ __device__ constexpr auto calculate_tensor_strides_aligned(Lengths, Number<Align>)
{
constexpr index_t L_back_align =
Align * math::integer_divide_ceiler<index_t>{}(Lengths{}.Back(), Align);
return calculate_tensor_strides_packed(
Lengths{}.Modify(Number<Lengths{}.GetSize() - 1>{}, Number<L_back_align>{}));
}
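// Example (illustration only, not in the original source): for Lengths = Sequence<4, 8, 3>
// and Align = 4, the last length is padded up to 4, so the aligned strides are the packed
// strides of {4, 8, 4}, i.e. Sequence<32, 4, 1>.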
template <index_t... Lengths, index_t... Strides>
__host__ __device__ constexpr auto make_native_tensor_descriptor(Sequence<Lengths...>,
Sequence<Strides...>)
{
return NativeTensorDescriptor<NativeDimension<Lengths, Strides>...>{};
}
template <typename Lengths>
__host__ __device__ constexpr auto make_native_tensor_descriptor_packed(Lengths)
{
constexpr auto strides = calculate_tensor_strides_packed(Lengths{});
return make_native_tensor_descriptor(Lengths{}, strides);
}
template <typename Lengths, index_t Align>
__host__ __device__ constexpr auto make_native_tensor_descriptor_aligned(Lengths, Number<Align>)
{
constexpr auto strides = calculate_tensor_strides_aligned(Lengths{}, Number<Align>{});
return make_native_tensor_descriptor(Lengths{}, strides);
}
template <typename LowTensorDescriptor,
typename Transforms,
typename LowDimensionIds,
typename UpDimensionIds>
__host__ __device__ constexpr auto
transform_tensor_descriptor(LowTensorDescriptor, Transforms, LowDimensionIds, UpDimensionIds)
{
return TransformedTensorDescriptor<LowTensorDescriptor,
Transforms,
LowDimensionIds,
UpDimensionIds>{};
}
template <typename LowerTensorDescriptor,
index_t... LowerLengths,
index_t... LowerDimensionIds,
index_t... UpperDimensionIds>
__host__ __device__ constexpr auto
reorder_transformed_tensor_descriptor_impl(LowerTensorDescriptor,
Sequence<LowerLengths...>,
Sequence<LowerDimensionIds...>,
Sequence<UpperDimensionIds...>)
{
return TransformedTensorDescriptor<LowerTensorDescriptor,
Tuple<PassThrough<LowerLengths>...>,
Tuple<Sequence<LowerDimensionIds>...>,
Tuple<Sequence<UpperDimensionIds>...>>{};
}
// reorder a NativeTensorDescriptor
template <typename... Ts, typename MapLower2Upper>
__host__ __device__ constexpr auto
reorder_tensor_descriptor_given_lower2upper(NativeTensorDescriptor<Ts...>, MapLower2Upper)
{
static_assert(is_valid_sequence_map<MapLower2Upper>{},
"wrong! MapLower2Upper is not a valid map");
constexpr auto old_desc = NativeTensorDescriptor<Ts...>{};
static_assert(old_desc.GetNumOfDimension() == MapLower2Upper::Size(), "wrong!");
constexpr auto new_lengths = old_desc.GetLengths().ReorderGivenOld2New(MapLower2Upper{});
constexpr auto new_strides = old_desc.GetStrides().ReorderGivenOld2New(MapLower2Upper{});
return make_native_tensor_descriptor(new_lengths, new_strides);
}
// reorder a TransformedTensorDescriptor
template <typename... Ts, typename MapLower2Upper>
__host__ __device__ constexpr auto
reorder_tensor_descriptor_given_lower2upper(TransformedTensorDescriptor<Ts...>, MapLower2Upper)
{
static_assert(is_valid_sequence_map<MapLower2Upper>{},
"wrong! MapLower2Upper is not a valid map");
constexpr auto low_desc = TransformedTensorDescriptor<Ts...>{};
static_assert(low_desc.GetNumOfDimension() == MapLower2Upper::Size(), "wrong!");
return reorder_transformed_tensor_descriptor_impl(
low_desc,
low_desc.GetLengths(),
typename arithmetic_sequence_gen<0, low_desc.GetNumOfDimension(), 1>::type{},
MapLower2Upper{});
}
template <typename LowerTensorDescriptor, typename MapUpper2Lower>
__host__ __device__ constexpr auto
reorder_tensor_descriptor_given_upper2lower(LowerTensorDescriptor, MapUpper2Lower)
{
return reorder_tensor_descriptor_given_lower2upper(
LowerTensorDescriptor{}, typename sequence_map_inverse<MapUpper2Lower>::type{});
}
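// A minimal sketch of the reorder helpers above (illustration only, not part of the original
// source). Reordering a packed NCHW descriptor into NHWC order only permutes lengths and
// strides; the upper2lower map below lists, for each new (upper) dimension, the old (lower)
// dimension it comes from.
#if 0
constexpr auto nchw_desc = make_native_tensor_descriptor_packed(Sequence<1, 8, 16, 16>{});
// upper index order N, H, W, C maps to lower (original) dimensions 0, 2, 3, 1
constexpr auto nhwc_desc =
    reorder_tensor_descriptor_given_upper2lower(nchw_desc, Sequence<0, 2, 3, 1>{});
#endif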
template <typename Lengths, typename Strides>
__host__ __device__ constexpr bool are_dimensions_unfoldable(Lengths, Strides)
{
static_assert(Lengths::Size() == Strides::Size(), "wrong!");
bool flag = true;
for(index_t i = 0; i < Lengths::Size() - 1; ++i)
{
flag = flag && Strides::At(i) == Strides::At(i + 1) * Lengths::At(i + 1);
}
return flag;
}
// unfold only supports NativeTensorDescriptor, for now
template <index_t FirstUnfoldDim, index_t LastUnfoldDim, typename... Ts>
__host__ __device__ constexpr auto unfold_tensor_descriptor(NativeTensorDescriptor<Ts...> desc,
Number<FirstUnfoldDim>,
Number<LastUnfoldDim>)
{
constexpr index_t nDim = desc.GetNumOfDimension();
static_assert(FirstUnfoldDim >= 0 && LastUnfoldDim < nDim && FirstUnfoldDim <= LastUnfoldDim,
"wrong! should have FirstUnfoldDim <= LastUnfoldDim!");
// left and right
constexpr auto left = typename arithmetic_sequence_gen<0, FirstUnfoldDim, 1>::type{};
constexpr auto middle =
typename arithmetic_sequence_gen<FirstUnfoldDim, LastUnfoldDim + 1, 1>::type{};
constexpr auto right = typename arithmetic_sequence_gen<LastUnfoldDim + 1, nDim, 1>::type{};
// sanity-check if unfold-able
static_assert(are_dimensions_unfoldable(desc.GetLengths(middle), desc.GetStrides(middle)),
"wrong! not unfold-able");
// unfolded length, stride
constexpr index_t unfold_length =
reduce_on_sequence(desc.GetLengths(middle), math::multiplies<index_t>{}, Number<1>{});
constexpr index_t unfold_stride = desc.GetStride(Number<LastUnfoldDim>{});
// new lengths, strides
constexpr auto new_lengths =
desc.GetLengths(left).PushBack(Number<unfold_length>{}).PushBack(desc.GetLengths(right));
constexpr auto new_strides =
desc.GetStrides(left).PushBack(Number<unfold_stride>{}).PushBack(desc.GetStrides(right));
return make_native_tensor_descriptor(new_lengths, new_strides);
}
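// Example (illustration only, not in the original source): unfolding dimensions 1..2 of a
// packed descriptor with lengths {2, 3, 4} (strides {12, 4, 1}) passes the unfoldable check,
// since stride(1) == stride(2) * length(2) = 4, and yields a 2D descriptor with lengths
// {2, 12} and strides {12, 1}.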
} // namespace ck
#endif
#ifndef CK_BLOCKWISE_BATCHED_GEMM_HPP
#define CK_BLOCKWISE_BATCHED_GEMM_HPP
#include "common_header.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "threadwise_gemm.hpp"
#ifndef CK_BLOCKWISE_GEMM_USE_AMD_INLINE_ASM
#define CK_BLOCKWISE_GEMM_USE_AMD_INLINE_ASM 1
#endif
namespace ck {
template <index_t BlockSize,
class BlockMatrixA,
class BlockMatrixB,
class ThreadMatrixC,
index_t BlockMatrixStrideA,
index_t BlockMatrixStrideB,
index_t ThreadMatrixStrideC,
index_t BatchSize,
index_t MPerThreadSubC,
index_t NPerThreadSubC,
index_t MLevel0Cluster,
index_t NLevel0Cluster,
index_t MLevel1Cluster,
index_t NLevel1Cluster,
index_t KPerThreadLoop,
index_t BatchPerThread,
index_t DataPerReadA,
index_t DataPerReadB>
struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
{
index_t mMyThreadOffsetA = 0;
index_t mMyThreadOffsetB = 0;
struct MatrixIndex
{
index_t batch;
index_t row;
index_t col;
};
__device__ BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2()
{
static_assert(BatchSize % BatchPerThread == 0,
"wrong! BatchSize is not dividable by BatchPerThread");
constexpr index_t BatchThreadWork = BatchSize / BatchPerThread;
constexpr index_t ThreadPerLevel1Cluster =
MLevel0Cluster * NLevel0Cluster * MLevel1Cluster * NLevel1Cluster;
static_assert(BlockSize == BatchThreadWork * ThreadPerLevel1Cluster,
"wrong! wrong blocksize\n");
constexpr auto a_block_mtx = BlockMatrixA{};
constexpr auto b_block_mtx = BlockMatrixB{};
constexpr auto c_thread_mtx = ThreadMatrixC{};
static_assert(a_block_mtx.NRow() == b_block_mtx.NRow(),
"wrong! K dimension not consistent\n");
constexpr index_t M = a_block_mtx.NCol(); // A is transposed
constexpr index_t N = b_block_mtx.NCol();
constexpr index_t MPerThread = c_thread_mtx.NRow();
constexpr index_t NPerThread = c_thread_mtx.NCol();
static_assert((MPerThread % MPerThreadSubC == 0) && (NPerThread % NPerThreadSubC == 0),
"wrong! Cannot evenly divide thread work among repeat \n");
constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
constexpr index_t NRepeat = NPerThread / NPerThreadSubC;
static_assert((M % MRepeat == 0) && (N % NRepeat == 0),
"wrong! Cannot evenly divide work among repeat\n");
constexpr index_t MPerLevel1Cluster = M / MRepeat;
constexpr index_t NPerLevel1Cluster = N / NRepeat;
static_assert((MPerLevel1Cluster % MLevel1Cluster == 0) &&
(NPerLevel1Cluster % NLevel1Cluster == 0),
"wrong! Cannot evenly divide work among Level1Cluster\n");
constexpr index_t MPerLevel0Cluster = MPerLevel1Cluster / MLevel1Cluster;
constexpr index_t NPerLevel0Cluster = NPerLevel1Cluster / NLevel1Cluster;
static_assert((MPerLevel0Cluster % MLevel0Cluster == 0) &&
(NPerLevel0Cluster % NLevel0Cluster == 0),
"wrong! Cannot evenly divide work among Level0Cluster\n");
static_assert((MPerThreadSubC == MPerLevel0Cluster / MLevel0Cluster) &&
(NPerThreadSubC == NPerLevel0Cluster / NLevel0Cluster),
"wrong! thread work size is wrong\n");
const auto c_thread_mtx_index = GetBeginOfThreadMatrixC(get_thread_local_1d_id());
mMyThreadOffsetA = c_thread_mtx_index.batch * BlockMatrixStrideA +
a_block_mtx.GetOffsetFromMultiIndex(0, c_thread_mtx_index.row);
mMyThreadOffsetB = c_thread_mtx_index.batch * BlockMatrixStrideB +
b_block_mtx.GetOffsetFromMultiIndex(0, c_thread_mtx_index.col);
}
__device__ MatrixIndex GetBeginOfThreadMatrixC(index_t thread_id) const
{
constexpr index_t ThreadPerLevel1Cluster =
MLevel0Cluster * NLevel0Cluster * MLevel1Cluster * NLevel1Cluster;
constexpr index_t ThreadPerLevel0Cluster = MLevel0Cluster * NLevel0Cluster;
index_t batch_work_id = thread_id / ThreadPerLevel1Cluster;
index_t cluster_id = thread_id - batch_work_id * ThreadPerLevel1Cluster;
index_t level1_id = cluster_id / ThreadPerLevel0Cluster;
index_t level1_m_id = level1_id / NLevel1Cluster;
index_t level1_n_id = level1_id % NLevel1Cluster;
index_t level0_id = cluster_id % ThreadPerLevel0Cluster;
index_t level0_m_id = level0_id / NLevel0Cluster;
index_t level0_n_id = level0_id % NLevel0Cluster;
constexpr index_t MPerLevel0Cluster = MPerThreadSubC * MLevel0Cluster;
constexpr index_t NPerLevel0Cluster = NPerThreadSubC * NLevel0Cluster;
return MatrixIndex{batch_work_id * BatchPerThread,
level1_m_id * MPerLevel0Cluster + level0_m_id * MPerThreadSubC,
level1_n_id * NPerLevel0Cluster + level0_n_id * NPerThreadSubC};
}
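// Worked example (illustration only, not in the original source), assuming hypothetical
// parameters MLevel0Cluster = NLevel0Cluster = MLevel1Cluster = NLevel1Cluster = 2,
// MPerThreadSubC = NPerThreadSubC = 4 and BatchPerThread = 1:
// ThreadPerLevel1Cluster = 16 and ThreadPerLevel0Cluster = 4. For thread_id = 21:
// batch_work_id = 1, cluster_id = 5, level1_(m, n)_id = (0, 1), level0_(m, n)_id = (0, 1),
// so the thread's C sub-block begins at MatrixIndex{1, 0 * 8 + 0 * 4, 1 * 8 + 1 * 4} = {1, 0, 12}.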
// this should be optimized away because input will be known at compile time
__device__ static MatrixIndex
GetDistanceFromBeginOfThreadMatrixC(index_t batch_in_c, index_t m_in_c, index_t n_in_c)
{
constexpr auto c_thread_mtx = ThreadMatrixC{};
constexpr index_t MPerThread = c_thread_mtx.NRow();
constexpr index_t NPerThread = c_thread_mtx.NCol();
constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
constexpr index_t NRepeat = NPerThread / NPerThreadSubC;
constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
index_t m_repeat = m_in_c / MPerThreadSubC;
index_t n_repeat = n_in_c / NPerThreadSubC;
index_t m_in_sub_c = m_in_c % MPerThreadSubC;
index_t n_in_sub_c = n_in_c % NPerThreadSubC;
return MatrixIndex{batch_in_c,
m_repeat * MPerLevel1Cluster + m_in_sub_c,
n_repeat * NPerLevel1Cluster + n_in_sub_c};
}
template <class FloatA, class FloatB, class FloatC>
__device__ void Run_source(const FloatA* __restrict__ p_a_block,
const FloatB* __restrict__ p_b_block,
FloatC* __restrict__ p_c_thread) const
{
constexpr auto True = integral_constant<bool, true>{};
constexpr auto False = integral_constant<bool, false>{};
constexpr auto a_block_mtx = BlockMatrixA{};
constexpr auto b_block_mtx = BlockMatrixB{};
constexpr auto c_thread_mtx = ThreadMatrixC{};
constexpr index_t KPerBlock = a_block_mtx.NRow(); // A is transposed
constexpr index_t MPerThread = c_thread_mtx.NRow();
constexpr index_t NPerThread = c_thread_mtx.NCol();
// thread A, B for GEMM
// A is transposed, B is not
constexpr auto a_thread_mtx =
make_ConstantMatrixDescriptor_packed(Number<KPerThreadLoop>{}, Number<MPerThread>{});
constexpr auto b_thread_mtx =
make_ConstantMatrixDescriptor_packed(Number<KPerThreadLoop>{}, Number<NPerThread>{});
// thread A-sub, B-sub for copy
constexpr auto a_thread_sub_mtx = make_ConstantMatrixDescriptor(
Number<KPerThreadLoop>{}, Number<MPerThreadSubC>{}, Number<MPerThread>{});
constexpr auto b_thread_sub_mtx = make_ConstantMatrixDescriptor(
Number<KPerThreadLoop>{}, Number<NPerThreadSubC>{}, Number<NPerThread>{});
FloatA p_a_thread[a_thread_mtx.GetElementSpace()];
FloatB p_b_thread[b_thread_mtx.GetElementSpace()];
constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
constexpr index_t NRepeat = NPerThread / NPerThreadSubC;
// loop over k
#pragma unroll
for(index_t k_begin = 0; k_begin < KPerBlock; k_begin += KPerThreadLoop)
{
// loop over batch
#pragma unroll
for(index_t ib = 0; ib < BatchPerThread; ++ib)
{
// read next batch of a, b
if(BlockMatrixStrideA != 0 or ib == 0)
{
#pragma unroll
for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
{
threadwise_matrix_copy(a_block_mtx,
p_a_block +
a_block_mtx.GetOffsetFromMultiIndex(
k_begin, m_repeat * MPerLevel1Cluster) +
ib * BlockMatrixStrideA + mMyThreadOffsetA,
a_thread_mtx,
p_a_thread + a_thread_mtx.GetOffsetFromMultiIndex(
0, m_repeat * MPerThreadSubC),
a_thread_sub_mtx.GetLengths(),
Number<DataPerReadA>{});
}
}
if(BlockMatrixStrideB != 0 or ib == 0)
{
#pragma unroll
for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
{
threadwise_matrix_copy(b_block_mtx,
p_b_block +
b_block_mtx.GetOffsetFromMultiIndex(
k_begin, n_repeat * NPerLevel1Cluster) +
ib * BlockMatrixStrideB + mMyThreadOffsetB,
b_thread_mtx,
p_b_thread + b_thread_mtx.GetOffsetFromMultiIndex(
0, n_repeat * NPerThreadSubC),
b_thread_sub_mtx.GetLengths(),
Number<DataPerReadB>{});
}
}
threadwise_gemm(a_thread_mtx,
True,
p_a_thread,
b_thread_mtx,
False,
p_b_thread,
c_thread_mtx,
False,
p_c_thread + ib * ThreadMatrixStrideC);
}
}
}
#if CK_USE_AMD_INLINE_ASM
template <class FloatA, class FloatB, class FloatC>
__device__ void Run_amd_asm(const FloatA* __restrict__ p_a_block,
const FloatB* __restrict__ p_b_block,
FloatC* __restrict__ p_c_thread) const
{
constexpr auto a_block_mtx = BlockMatrixA{};
constexpr auto b_block_mtx = BlockMatrixB{};
constexpr auto c_thread_mtx = ThreadMatrixC{};
constexpr index_t K = a_block_mtx.NRow(); // A is transposed
constexpr index_t MPerThread = c_thread_mtx.NRow();
constexpr index_t NPerThread = c_thread_mtx.NCol();
// thread A, B for GEMM
// A is transposed, B is not
constexpr auto a_thread_mtx =
make_ConstantMatrixDescriptor_packed(Number<KPerThreadLoop>{}, Number<MPerThread>{});
constexpr auto b_thread_mtx =
make_ConstantMatrixDescriptor_packed(Number<KPerThreadLoop>{}, Number<NPerThread>{});
// thread A-sub, B-sub for copy
constexpr auto a_thread_sub_mtx = make_ConstantMatrixDescriptor(
Number<KPerThreadLoop>{}, Number<MPerThreadSubC>{}, Number<MPerThread>{});
constexpr auto b_thread_sub_mtx = make_ConstantMatrixDescriptor(
Number<KPerThreadLoop>{}, Number<NPerThreadSubC>{}, Number<NPerThread>{});
FloatA p_a_thread[a_thread_mtx.GetElementSpace()];
FloatB p_b_thread[b_thread_mtx.GetElementSpace()];
constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
// assertion for inline asm
static_assert(is_same<FloatA, float>{} && is_same<FloatB, float>{} &&
is_same<FloatC, float>{},
"Run_amd_asm only deal with float\n");
static_assert(MPerThreadSubC == 4 && NPerThreadSubC == 4 && KPerThreadLoop == 1 &&
MPerThread == 8 && NPerThread == 8,
"Run_amd_asm cannot deal with this GEMM shape yet\n");
static_assert(DataPerReadA == 4 && DataPerReadB == 4, "Run_amd_asm only does float4 reads\n");
static_assert(BlockMatrixStrideA == 0 && BatchPerThread == 1,
"Run_amd_asm can only deal with BlockMatrixStrideA == 0 && BatchPerThread == "
"1 for now\n");
using Float4 = vector_type<float, 4>::type;
Float4* reg_a = (Float4*)(p_a_thread);
Float4* reg_b = (Float4*)(p_b_thread);
Float4* reg_c = (Float4*)(p_c_thread);
reg_a[0] = *reinterpret_cast<const Float4*>(&p_a_block[mMyThreadOffsetA]);
reg_b[0] = *reinterpret_cast<const Float4*>(&p_b_block[mMyThreadOffsetB]);
reg_b[1] = *reinterpret_cast<const Float4*>(
&p_b_block[b_block_mtx.GetOffsetFromMultiIndex(0, NPerLevel1Cluster) +
mMyThreadOffsetB]);
reg_a[1] = *reinterpret_cast<const Float4*>(
&p_a_block[a_block_mtx.GetOffsetFromMultiIndex(0, MPerLevel1Cluster) +
mMyThreadOffsetA]);
outerProduct4x4(reg_a[0], reg_b[0], reg_c[0], reg_c[2], reg_c[4], reg_c[6]);
outerProduct4x4(reg_a[0], reg_b[1], reg_c[1], reg_c[3], reg_c[5], reg_c[7]);
#pragma unroll
for(index_t k = 1; k < K; ++k)
{
reg_a[0] = *reinterpret_cast<const Float4*>(
&p_a_block[a_block_mtx.GetOffsetFromMultiIndex(k, 0) + mMyThreadOffsetA]);
outerProduct4x4(reg_a[1], reg_b[0], reg_c[8], reg_c[10], reg_c[12], reg_c[14]);
reg_b[0] = *reinterpret_cast<const Float4*>(
&p_b_block[b_block_mtx.GetOffsetFromMultiIndex(k, 0) + mMyThreadOffsetB]);
outerProduct4x4(reg_a[1], reg_b[1], reg_c[9], reg_c[11], reg_c[13], reg_c[15]);
reg_b[1] = *reinterpret_cast<const Float4*>(
&p_b_block[b_block_mtx.GetOffsetFromMultiIndex(k, NPerLevel1Cluster) +
mMyThreadOffsetB]);
reg_a[1] = *reinterpret_cast<const Float4*>(
&p_a_block[a_block_mtx.GetOffsetFromMultiIndex(k, MPerLevel1Cluster) +
mMyThreadOffsetA]);
outerProduct4x4(reg_a[0], reg_b[0], reg_c[0], reg_c[2], reg_c[4], reg_c[6]);
outerProduct4x4(reg_a[0], reg_b[1], reg_c[1], reg_c[3], reg_c[5], reg_c[7]);
}
outerProduct4x4(reg_a[1], reg_b[0], reg_c[8], reg_c[10], reg_c[12], reg_c[14]);
outerProduct4x4(reg_a[1], reg_b[1], reg_c[9], reg_c[11], reg_c[13], reg_c[15]);
}
#endif
template <class FloatA, class FloatB, class FloatC>
__device__ void Run(const FloatA* __restrict__ p_a_block,
const FloatB* __restrict__ p_b_block,
FloatC* __restrict__ p_c_thread) const
{
#if CK_USE_AMD_INLINE_ASM && CK_BLOCKWISE_GEMM_USE_AMD_INLINE_ASM
Run_amd_asm(p_a_block, p_b_block, p_c_thread);
#else
Run_source(p_a_block, p_b_block, p_c_thread);
#endif
}
template <class BlockMatrixC, index_t BlockMatrixStrideC, class FloatC>
__device__ void CopyThreadMatrixCToBlockMatrixC(const FloatC* __restrict__ p_c_thread,
FloatC* __restrict__ p_c_block) const
{
constexpr auto c_block_mtx = BlockMatrixC{};
constexpr auto c_thread_mtx = ThreadMatrixC{};
constexpr index_t MPerThread = c_thread_mtx.NRow();
constexpr index_t NPerThread = c_thread_mtx.NCol();
constexpr auto c_thread_sub_mtx = make_ConstantMatrixDescriptor(
Number<MPerThreadSubC>{}, Number<NPerThreadSubC>{}, Number<NPerThread>{});
constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
constexpr index_t NRepeat = NPerThread / NPerThreadSubC;
const auto c_thread_mtx_begin = GetBeginOfThreadMatrixC(get_thread_local_1d_id());
const index_t c_thread_offset =
c_thread_mtx_begin.batch * BlockMatrixStrideC +
c_block_mtx.GetOffsetFromMultiIndex(c_thread_mtx_begin.row, c_thread_mtx_begin.col);
for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
{
for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
{
threadwise_matrix_copy(
c_thread_sub_mtx,
p_c_thread + c_thread_sub_mtx.GetOffsetFromMultiIndex(
m_repeat * MPerLevel1Cluster, n_repeat * NPerLevel1Cluster),
c_block_mtx,
p_c_block +
c_block_mtx.GetOffsetFromMultiIndex(m_repeat * MPerLevel1Cluster,
n_repeat * NPerLevel1Cluster) +
c_thread_offset,
c_thread_sub_mtx.GetLengths());
}
}
}
};
} // namespace ck
#endif
#ifndef CK_BLOCKWISE_GEMM_HPP
#define CK_BLOCKWISE_GEMM_HPP
#include "common_header.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "threadwise_gemm.hpp"
namespace ck {
// blockwise GEMM: C += transpose(A) * B
// A and B are visible to the whole block, C is distributed among the threads
// If the following numbers are powers of 2, index calculation is greatly reduced:
// MPerThreadSubC, NPerThreadSubC, MLevel0ThreadCluster, NLevel0ThreadCluster,
// MLevel1ThreadCluster, NLevel1ThreadCluster
template <index_t BlockSize,
typename BlockMatrixA,
typename BlockMatrixB,
typename ThreadMatrixC,
index_t MPerThreadSubC,
index_t NPerThreadSubC,
index_t KPerThreadLoop,
index_t MLevel0ThreadCluster,
index_t NLevel0ThreadCluster,
index_t MLevel1ThreadCluster,
index_t NLevel1ThreadCluster,
index_t ThreadGemmADataPerRead_M,
index_t ThreadGemmBDataPerRead_N>
struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
{
struct MatrixIndex
{
index_t row;
index_t col;
};
index_t mMyThreadOffsetA;
index_t mMyThreadOffsetB;
__device__ BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2()
{
constexpr index_t ThreadPerLevel1Cluster = MLevel0ThreadCluster * NLevel0ThreadCluster *
MLevel1ThreadCluster * NLevel1ThreadCluster;
static_assert(BlockSize == ThreadPerLevel1Cluster, "wrong! wrong blocksize\n");
static_assert(BlockMatrixA::NRow() == BlockMatrixB::NRow(),
"wrong! K dimension not consistent\n");
constexpr index_t M = BlockMatrixA::NCol(); // A is transposed
constexpr index_t N = BlockMatrixB::NCol();
static_assert(M % (MPerThreadSubC * MLevel0ThreadCluster * MLevel1ThreadCluster) == 0 &&
N % (NPerThreadSubC * NLevel0ThreadCluster * NLevel1ThreadCluster) == 0,
"wrong! Cannot evenly divide work among\n");
static_assert(
is_same<decltype(ThreadMatrixC::GetLengths()), decltype(GetThreadMatrixCLengths())>{},
"wrong! ThreadMatrixC lengths is wrong");
auto c_thread_mtx_index = GetBeginOfThreadMatrixC(get_thread_local_1d_id());
mMyThreadOffsetA = BlockMatrixA::GetOffsetFromMultiIndex(0, c_thread_mtx_index.row);
mMyThreadOffsetB = BlockMatrixB::GetOffsetFromMultiIndex(0, c_thread_mtx_index.col);
}
__device__ static constexpr auto GetThreadMatrixCLengths()
{
constexpr index_t M = BlockMatrixA::NCol(); // A is transposed
constexpr index_t N = BlockMatrixB::NCol();
constexpr index_t MRepeat =
M / (MPerThreadSubC * MLevel0ThreadCluster * MLevel1ThreadCluster);
constexpr index_t NRepeat =
N / (NPerThreadSubC * NLevel0ThreadCluster * NLevel1ThreadCluster);
return Sequence<MRepeat * MPerThreadSubC, NRepeat * NPerThreadSubC>{};
}
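// Example (illustration only, not in the original source): with hypothetical block matrices of
// M = N = 128, MPerThreadSubC = NPerThreadSubC = 4 and
// MLevel0ThreadCluster = NLevel0ThreadCluster = MLevel1ThreadCluster = NLevel1ThreadCluster = 4,
// MRepeat = NRepeat = 128 / (4 * 4 * 4) = 2, so each thread owns an 8x8 tile of C.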
__device__ static MatrixIndex GetBeginOfThreadMatrixC(index_t thread_id)
{
constexpr index_t ThreadPerLevel0Cluster = MLevel0ThreadCluster * NLevel0ThreadCluster;
index_t level1_id = thread_id / ThreadPerLevel0Cluster;
index_t level1_m_id = level1_id / NLevel1ThreadCluster;
index_t level1_n_id = level1_id % NLevel1ThreadCluster;
index_t level0_id = thread_id % ThreadPerLevel0Cluster;
index_t level0_m_id = level0_id / NLevel0ThreadCluster;
index_t level0_n_id = level0_id % NLevel0ThreadCluster;
constexpr index_t MPerLevel0Cluster = MPerThreadSubC * MLevel0ThreadCluster;
constexpr index_t NPerLevel0Cluster = NPerThreadSubC * NLevel0ThreadCluster;
return MatrixIndex{level1_m_id * MPerLevel0Cluster + level0_m_id * MPerThreadSubC,
level1_n_id * NPerLevel0Cluster + level0_n_id * NPerThreadSubC};
}
template <typename FloatA, typename FloatB, typename FloatC>
__device__ void
Run_naive(const FloatA* p_a_block, const FloatB* p_b_block, FloatC* p_c_thread) const
{
constexpr auto a_block_mtx = BlockMatrixA{};
constexpr auto b_block_mtx = BlockMatrixB{};
constexpr auto c_thread_mtx = ThreadMatrixC{};
constexpr index_t K = a_block_mtx.NRow();
constexpr index_t MPerThread = c_thread_mtx.NRow();
constexpr index_t NPerThread = c_thread_mtx.NCol();
constexpr index_t MPerLevel1Cluster =
MPerThreadSubC * MLevel0ThreadCluster * MLevel1ThreadCluster;
constexpr index_t NPerLevel1Cluster =
NPerThreadSubC * NLevel0ThreadCluster * NLevel1ThreadCluster;
constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
constexpr index_t NRepeat = NPerThread / NPerThreadSubC;
// thread A, B for GEMM
constexpr auto a_thread_mtx =
make_ConstantMatrixDescriptor_packed(Number<KPerThreadLoop>{}, Number<MPerThread>{});
constexpr auto b_thread_mtx =
make_ConstantMatrixDescriptor_packed(Number<KPerThreadLoop>{}, Number<NPerThread>{});
FloatA p_a_thread[a_thread_mtx.GetElementSpace()];
FloatB p_b_thread[b_thread_mtx.GetElementSpace()];
constexpr auto a_thread_copy = ThreadwiseMatrixSliceCopy<BlockMatrixA,
decltype(a_thread_mtx),
KPerThreadLoop,
MPerThreadSubC,
ThreadGemmADataPerRead_M>{};
constexpr auto b_thread_copy = ThreadwiseMatrixSliceCopy<BlockMatrixB,
decltype(b_thread_mtx),
KPerThreadLoop,
NPerThreadSubC,
ThreadGemmBDataPerRead_N>{};
constexpr auto threadwise_gemm =
ThreadwiseGemmTransANormalBNormalC<decltype(a_thread_mtx),
decltype(b_thread_mtx),
decltype(c_thread_mtx)>{};
#pragma unroll
// loop over k
for(index_t k_begin = 0; k_begin < K; k_begin += KPerThreadLoop)
{
#pragma unroll
// read A
for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
{
a_thread_copy.Run(
p_a_block + a_block_mtx.CalculateOffset(k_begin, m_repeat * MPerLevel1Cluster) +
mMyThreadOffsetA,
p_a_thread + a_thread_mtx.CalculateOffset(0, m_repeat * MPerThreadSubC));
}
#pragma unroll
// read B
for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
{
b_thread_copy.Run(
p_b_block + b_block_mtx.CalculateOffset(k_begin, n_repeat * NPerLevel1Cluster) +
mMyThreadOffsetB,
p_b_thread + b_thread_mtx.CalculateOffset(0, n_repeat * NPerThreadSubC));
}
// C += A * B
threadwise_gemm.Run(p_a_thread, p_b_thread, p_c_thread);
}
}
template <typename FloatA, typename FloatB, typename FloatC>
__device__ void
Run_pipelined_2x2(const FloatA* p_a_block, const FloatB* p_b_block, FloatC* p_c_thread) const
{
constexpr auto a_block_mtx = BlockMatrixA{};
constexpr auto b_block_mtx = BlockMatrixB{};
constexpr auto c_thread_mtx = ThreadMatrixC{};
constexpr index_t K = a_block_mtx.NRow();
constexpr index_t MPerThread = c_thread_mtx.NRow();
constexpr index_t NPerThread = c_thread_mtx.NCol();
constexpr index_t MPerLevel1Cluster =
MPerThreadSubC * MLevel0ThreadCluster * MLevel1ThreadCluster;
constexpr index_t NPerLevel1Cluster =
NPerThreadSubC * NLevel0ThreadCluster * NLevel1ThreadCluster;
constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
constexpr index_t NRepeat = NPerThread / NPerThreadSubC;
static_assert(MRepeat == 2 && NRepeat == 2,
"wrong! inline asm cannot deal with this GEMM config yet");
// thread A, B
constexpr auto a_thread_mtx =
make_ConstantMatrixDescriptor_packed(Number<KPerThreadLoop>{}, Number<MPerThread>{});
constexpr auto b_thread_mtx =
make_ConstantMatrixDescriptor_packed(Number<KPerThreadLoop>{}, Number<NPerThread>{});
// thread A-sub, B-sub
constexpr auto a_thread_sub_mtx = a_thread_mtx.MakeSubMatrixDescriptor(
Number<KPerThreadLoop>{}, Number<MPerThreadSubC>{});
constexpr auto b_thread_sub_mtx = b_thread_mtx.MakeSubMatrixDescriptor(
Number<KPerThreadLoop>{}, Number<NPerThreadSubC>{});
// thread C-sub
constexpr auto c_thread_sub_mtx = ThreadMatrixC::MakeSubMatrixDescriptor(
Number<MPerThreadSubC>{}, Number<NPerThreadSubC>{});
FloatA p_a_thread[a_thread_mtx.GetElementSpace()];
FloatB p_b_thread[b_thread_mtx.GetElementSpace()];
constexpr auto a_thread_copy = ThreadwiseMatrixSliceCopy<BlockMatrixA,
decltype(a_thread_mtx),
KPerThreadLoop,
MPerThreadSubC,
ThreadGemmADataPerRead_M>{};
constexpr auto b_thread_copy = ThreadwiseMatrixSliceCopy<BlockMatrixB,
decltype(b_thread_mtx),
KPerThreadLoop,
NPerThreadSubC,
ThreadGemmBDataPerRead_N>{};
constexpr auto threadwise_gemm =
ThreadwiseGemmTransANormalBNormalC<decltype(a_thread_sub_mtx),
decltype(b_thread_sub_mtx),
decltype(c_thread_sub_mtx)>{};
const FloatA* p_a_block_off = p_a_block + mMyThreadOffsetA;
const FloatB* p_b_block_off = p_b_block + mMyThreadOffsetB;
// read A_sub_0
a_thread_copy.Run(p_a_block_off, p_a_thread);
// read B_sub_0
b_thread_copy.Run(p_b_block_off, p_b_thread);
// read B_sub_1
b_thread_copy.Run(p_b_block_off + b_block_mtx.CalculateOffset(0, NPerLevel1Cluster),
p_b_thread + b_thread_mtx.CalculateOffset(0, NPerThreadSubC));
// read A_sub_1
a_thread_copy.Run(p_a_block_off + a_block_mtx.CalculateOffset(0, MPerLevel1Cluster),
p_a_thread + a_thread_mtx.CalculateOffset(0, MPerThreadSubC));
// C_sub_00 += transpose(A_sub_0) * B_sub_0
threadwise_gemm.Run(p_a_thread, p_b_thread, p_c_thread);
// C_sub_01 += transpose(A_sub_0) * B_sub_1
threadwise_gemm.Run(p_a_thread,
p_b_thread + b_thread_mtx.CalculateOffset(0, NPerThreadSubC),
p_c_thread + ThreadMatrixC::CalculateOffset(0, NPerThreadSubC));
#pragma unroll
// loop over rest of k
for(index_t k = KPerThreadLoop; k < K; k += KPerThreadLoop)
{
// read A_sub_0
a_thread_copy.Run(p_a_block_off + a_block_mtx.CalculateOffset(k, 0), p_a_thread);
// C_sub_10 += transpose(A_sub_1) * B_sub_0
threadwise_gemm.Run(p_a_thread + a_thread_mtx.CalculateOffset(0, MPerThreadSubC),
p_b_thread,
p_c_thread + ThreadMatrixC::CalculateOffset(MPerThreadSubC, 0));
// read B_sub_0
b_thread_copy.Run(p_b_block_off + b_block_mtx.CalculateOffset(k, 0), p_b_thread);
// C_sub_11 += transpose(A_sub_1) * B_sub_1
threadwise_gemm.Run(p_a_thread + a_thread_mtx.CalculateOffset(0, MPerThreadSubC),
p_b_thread + b_thread_mtx.CalculateOffset(0, NPerThreadSubC),
p_c_thread +
ThreadMatrixC::CalculateOffset(MPerThreadSubC, NPerThreadSubC));
// read B_sub_1
b_thread_copy.Run(p_b_block_off + b_block_mtx.CalculateOffset(k, NPerLevel1Cluster),
p_b_thread + b_thread_mtx.CalculateOffset(0, NPerThreadSubC));
// read A_sub_1
a_thread_copy.Run(p_a_block_off + a_block_mtx.CalculateOffset(k, MPerLevel1Cluster),
p_a_thread + a_thread_mtx.CalculateOffset(0, MPerThreadSubC));
// C_sub_00 += transpose(A_sub_0) * B_sub_0
threadwise_gemm.Run(p_a_thread, p_b_thread, p_c_thread);
// C_sub_01 += transpose(A_sub_0) * B_sub_1
threadwise_gemm.Run(p_a_thread,
p_b_thread + b_thread_mtx.CalculateOffset(0, NPerThreadSubC),
p_c_thread + ThreadMatrixC::CalculateOffset(0, NPerThreadSubC));
}
// C_sub_10 += transpose(A_sub_1) * B_sub_0
threadwise_gemm.Run(p_a_thread + a_thread_mtx.CalculateOffset(0, MPerThreadSubC),
p_b_thread,
p_c_thread + ThreadMatrixC::CalculateOffset(MPerThreadSubC, 0));
// C_sub_11 += transpose(A_sub_1) * B_sub_1
threadwise_gemm.Run(p_a_thread + a_thread_mtx.CalculateOffset(0, MPerThreadSubC),
p_b_thread + b_thread_mtx.CalculateOffset(0, NPerThreadSubC),
p_c_thread +
ThreadMatrixC::CalculateOffset(MPerThreadSubC, NPerThreadSubC));
}
template <typename FloatA, typename FloatB, typename FloatC>
__device__ void Run(const FloatA* p_a_block, const FloatB* p_b_block, FloatC* p_c_thread) const
{
#if CK_EXPERIMENTAL_BLOCKWISE_GEMM_USE_PIPELINE
constexpr index_t MPerThread = ThreadMatrixC::NRow();
constexpr index_t NPerThread = ThreadMatrixC::NCol();
constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
constexpr index_t NRepeat = NPerThread / NPerThreadSubC;
if constexpr(MRepeat == 2 && NRepeat == 2)
{
Run_pipelined_2x2(p_a_block, p_b_block, p_c_thread);
}
else
{
Run_naive(p_a_block, p_b_block, p_c_thread);
}
#else
Run_naive(p_a_block, p_b_block, p_c_thread);
#endif
}
};
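// A minimal instantiation sketch (illustration only, not part of the original source). The
// concrete sizes below are hypothetical but satisfy the static_asserts in the constructor:
// BlockSize = 4 * 4 * 4 * 4 = 256 threads, each accumulating an 8x8 tile of the 128x128 C block.
// p_a_block / p_b_block / p_c_thread are assumed to be the LDS and register pointers set up by
// the enclosing kernel.
#if 0
constexpr auto a_block_mtx  = make_ConstantMatrixDescriptor_packed(Number<8>{}, Number<128>{}); // K x M
constexpr auto b_block_mtx  = make_ConstantMatrixDescriptor_packed(Number<8>{}, Number<128>{}); // K x N
constexpr auto c_thread_mtx = make_ConstantMatrixDescriptor_packed(Number<8>{}, Number<8>{});   // per-thread C
const auto blockwise_gemm =
    BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2<256,                    // BlockSize
                                                            decltype(a_block_mtx),
                                                            decltype(b_block_mtx),
                                                            decltype(c_thread_mtx),
                                                            4,   // MPerThreadSubC
                                                            4,   // NPerThreadSubC
                                                            1,   // KPerThreadLoop
                                                            4,   // MLevel0ThreadCluster
                                                            4,   // NLevel0ThreadCluster
                                                            4,   // MLevel1ThreadCluster
                                                            4,   // NLevel1ThreadCluster
                                                            4,   // ThreadGemmADataPerRead_M
                                                            4>{};// ThreadGemmBDataPerRead_N
blockwise_gemm.Run(p_a_block, p_b_block, p_c_thread); // per thread: C(8x8) += transpose(A) * B
#endif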
} // namespace ck
#endif
#ifndef CK_BLOCKWISE_GENERIC_TENSOR_SLICE_COPY_HPP
#define CK_BLOCKWISE_GENERIC_TENSOR_SLICE_COPY_HPP
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "tensor_coordinate.hpp"
#include "cluster_descriptor.hpp"
#include "threadwise_generic_tensor_slice_copy.hpp"
namespace ck {
// This blockwise copy allows vector access of src and dst.
// It allows the vector size to be different on src and dst.
// The dimension of vector access can be different for src and dst.
// The dimension access order can be different for src and dst.
// Will do a valid-mapping check on src data: read 0 if src data has an invalid mapping
// Will do a valid-mapping check on dst data: no write if dst data has an invalid mapping
// BlockSize can be equal to or larger than the ThreadCluster size, which means some threads may
// not do threadwise copy
template <index_t BlockSize,
typename BlockSrcDesc,
typename BlockDstDesc,
typename BlockSliceLengths,
typename ThreadSliceLengths,
typename ThreadClusterLengths,
typename ThreadClusterArrangeOrder,
typename SrcDimAccessOrder,
typename DstDimAccessOrder,
index_t SrcVectorReadDim,
index_t DstVectorWriteDim,
index_t SrcDataPerRead,
index_t DstDataPerWrite,
AddressSpace SrcAddressSpace = AddressSpace::Generic,
AddressSpace ThreadBufferAddressSpace = AddressSpace::Generic,
AddressSpace DstAddressSpace = AddressSpace::Generic,
InMemoryDataOperation DstInMemOp = InMemoryDataOperation::Set,
index_t SrcDataStride = 1,
index_t DstDataStride = 1>
struct BlockwiseGenericTensorSliceCopy_v4
{
static constexpr index_t nDim = BlockSrcDesc::GetNumOfDimension();
using Index = MultiIndex<nDim>;
__device__ constexpr BlockwiseGenericTensorSliceCopy_v4(const Index& src_block_slice_origin,
const Index& dst_block_slice_origin)
{
static_assert(nDim == BlockSrcDesc::GetNumOfDimension() &&
nDim == BlockDstDesc::GetNumOfDimension() &&
nDim == BlockSliceLengths::Size() && nDim == ThreadSliceLengths::Size() &&
nDim == ThreadClusterLengths::Size() &&
nDim == ThreadClusterArrangeOrder::Size() &&
nDim == SrcDimAccessOrder::Size() && nDim == DstDimAccessOrder::Size(),
"wrong! nDim not consistent");
static_assert(
is_same<BlockSliceLengths, decltype(ThreadSliceLengths{} * ThreadClusterLengths{})>{},
"wrong! threads should be mapped to cover entire slicing window");
static_assert(BlockSize >= mThreadClusterDesc.GetElementSize(),
"wrong! BlockSize too small");
if(BlockSize == mThreadClusterDesc.GetElementSize() or
get_thread_local_1d_id() < mThreadClusterDesc.GetElementSize())
{
const auto thread_cluster_id =
mThreadClusterDesc.CalculateClusterIndex(get_thread_local_1d_id());
const auto thread_data_id_begin = thread_cluster_id * ThreadSliceLengths{};
mThreadwiseLoad.SetSrcSliceOrigin(src_block_slice_origin + thread_data_id_begin);
mThreadwiseLoad.SetDstSliceOrigin(make_zero_multi_index<nDim>());
mThreadwiseStore.SetSrcSliceOrigin(make_zero_multi_index<nDim>());
mThreadwiseStore.SetDstSliceOrigin(dst_block_slice_origin + thread_data_id_begin);
}
}
__device__ static constexpr index_t GetThreadBufferSize()
{
return ThreadBufferDesc::GetElementSpace();
}
template <typename BlockSrcData, typename ThreadBufferData>
__device__ void RunLoadThreadBuffer(const BlockSrcData* p_block_src,
ThreadBufferData* p_thread_buffer) const
{
if(BlockSize == mThreadClusterDesc.GetElementSize() or
get_thread_local_1d_id() < mThreadClusterDesc.GetElementSize())
{
mThreadwiseLoad.Run(p_block_src, p_thread_buffer);
}
}
template <typename ThreadBufferData, typename BlockDstData>
__device__ void RunStoreThreadBuffer(const ThreadBufferData* p_thread_buffer,
BlockDstData* p_block_dst) const
{
if(BlockSize == mThreadClusterDesc.GetElementSize() or
get_thread_local_1d_id() < mThreadClusterDesc.GetElementSize())
{
mThreadwiseStore.Run(p_thread_buffer, p_block_dst);
}
}
template <typename BlockSrcData, typename BlockDstData>
__device__ void Run(const BlockSrcData* p_block_src, BlockDstData* p_block_dst) const
{
static_assert(ThreadBufferAddressSpace == AddressSpace::Vgpr,
"wrong! This function use vgpr as its thread "
"buffer. However, you have set RunLoadThreadBuffer and RunStoreThreadBuffer "
"to use ThreadBufferAddressSpace as their thread buffer, which is not vgpr. "
"Behavior may be different");
BlockSrcData p_thread_buffer[GetThreadBufferSize()];
if(BlockSize == mThreadClusterDesc.GetElementSize() or
get_thread_local_1d_id() < mThreadClusterDesc.GetElementSize())
{
RunLoadThreadBuffer(p_block_src, p_thread_buffer);
// if there is type conversion, it's done during store
RunStoreThreadBuffer(p_thread_buffer, p_block_dst);
}
}
template <typename T, bool PositiveDirection>
__device__ void
MoveSrcSliceWindow(const T& step_sizes,
integral_constant<bool, PositiveDirection> positive_direction)
{
if(BlockSize == mThreadClusterDesc.GetElementSize() or
get_thread_local_1d_id() < mThreadClusterDesc.GetElementSize())
{
mThreadwiseLoad.MoveSrcSliceWindow(step_sizes, positive_direction);
}
}
template <typename T, bool PositiveDirection>
__device__ void
MoveDstSliceWindow(const T& step_sizes,
integral_constant<bool, PositiveDirection> positive_direction)
{
if(BlockSize == mThreadClusterDesc.GetElementSize() or
get_thread_local_1d_id() < mThreadClusterDesc.GetElementSize())
{
mThreadwiseStore.MoveDstSliceWindow(step_sizes, positive_direction);
}
}
private:
using ThreadBufferDesc = decltype(make_native_tensor_descriptor_packed(ThreadSliceLengths{}));
using ThreadwiseLoad = ThreadwiseGenericTensorSliceCopy_v4r2<BlockSrcDesc,
ThreadBufferDesc,
ThreadSliceLengths,
SrcDimAccessOrder,
SrcVectorReadDim,
SrcDataPerRead,
1,
SrcAddressSpace,
ThreadBufferAddressSpace,
InMemoryDataOperation::Set,
SrcDataStride,
1>;
using ThreadwiseStore = ThreadwiseGenericTensorSliceCopy_v4r2<ThreadBufferDesc,
BlockDstDesc,
ThreadSliceLengths,
DstDimAccessOrder,
DstVectorWriteDim,
1,
DstDataPerWrite,
ThreadBufferAddressSpace,
DstAddressSpace,
DstInMemOp,
1,
DstDataStride>;
static constexpr auto mThreadClusterDesc =
make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});
ThreadwiseLoad mThreadwiseLoad;
ThreadwiseStore mThreadwiseStore;
};
} // namespace ck
#endif
#ifndef CK_GRIDWISE_GEMM_HPP
#define CK_GRIDWISE_GEMM_HPP
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "blockwise_generic_tensor_slice_copy.hpp"
#include "threadwise_generic_tensor_slice_copy.hpp"
#include "blockwise_gemm.hpp"
namespace ck {
template <index_t GridSize,
index_t BlockSize,
typename Float,
typename AccFloat,
typename AGlobalDesc,
typename BGlobalDesc,
typename CGlobalDesc,
InMemoryDataOperation CGlobalMemoryDataOperation,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t MPerThread,
index_t NPerThread,
index_t KPerThread,
index_t MLevel0Cluster,
index_t NLevel0Cluster,
index_t MLevel1Cluster,
index_t NLevel1Cluster,
index_t ThreadGemmAThreadCopySrcDataPerRead_M,
index_t ThreadGemmBThreadCopySrcDataPerRead_N,
typename ABlockCopyThreadSliceLengths_K_M,
typename ABlockCopyThreadClusterLengths_K_M,
typename ABlockCopyThreadClusterArrangeOrder,
typename ABlockCopySrcAccessOrder,
index_t ABlockCopySrcVectorReadDim,
index_t ABlockCopySrcDataPerRead,
index_t ABlockCopyDstDataPerWrite_M,
typename BBlockCopyThreadSliceLengths_K_N,
typename BBlockCopyThreadClusterLengths_K_N,
typename BBlockCopyThreadClusterArrangeOrder,
typename BBlockCopySrcAccessOrder,
index_t BBlockCopySrcVectorReadDim,
index_t BBlockCopySrcDataPerRead,
index_t BBlockCopyDstDataPerWrite_N,
typename CThreadCopySrcDstAccessOrder,
index_t CThreadCopySrcDstVectorReadWriteDim,
index_t CThreadCopyDstDataPerWrite>
struct GridwiseGemmTransposedANormalBNormalC_v1
{
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
constexpr index_t max_lds_align = math::lcm(ABlockCopyDstDataPerWrite_M,
BBlockCopyDstDataPerWrite_N,
ThreadGemmAThreadCopySrcDataPerRead_M,
ThreadGemmBThreadCopySrcDataPerRead_N);
// A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto a_k_m_block_desc = make_native_tensor_descriptor_aligned(
Sequence<KPerBlock, MPerBlock>{}, Number<max_lds_align>{});
// B matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto b_k_n_block_desc = make_native_tensor_descriptor_aligned(
Sequence<KPerBlock, NPerBlock>{}, Number<max_lds_align>{});
// LDS allocation for A and B: be careful of alignment
constexpr index_t a_block_space_size =
math::integer_least_multiple(a_k_m_block_desc.GetElementSpace(), max_lds_align);
constexpr index_t b_block_space_size =
math::integer_least_multiple(b_k_n_block_desc.GetElementSpace(), max_lds_align);
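// factor of 2: A and B are double-buffered in LDS (even/odd halves)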
return 2 * (a_block_space_size + b_block_space_size) * sizeof(Float);
}
__device__ void Run(const Float* __restrict__ p_a_global,
const Float* __restrict__ p_b_global,
Float* __restrict__ p_c_global,
Float* __restrict__ p_shared_block) const
{
constexpr auto True = integral_constant<bool, true>{};
constexpr auto a_k_m_global_desc = AGlobalDesc{};
constexpr auto b_k_n_global_desc = BGlobalDesc{};
constexpr auto c_m_n_global_desc = CGlobalDesc{};
constexpr auto K = a_k_m_global_desc.GetLengths()[0];
constexpr auto M = a_k_m_global_desc.GetLengths()[1];
constexpr auto N = b_k_n_global_desc.GetLengths()[1];
// don't do anything if K == 0
if(K == 0)
{
return;
}
// lds max alignment
constexpr index_t max_lds_align = math::lcm(ABlockCopyDstDataPerWrite_M,
BBlockCopyDstDataPerWrite_N,
ThreadGemmAThreadCopySrcDataPerRead_M,
ThreadGemmBThreadCopySrcDataPerRead_N);
// divide block work by [M, N]
static_assert(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0,
"wrong! cannot divide work evenly among block");
constexpr index_t MBlockWork = M / MPerBlock;
constexpr index_t NBlockWork = N / NPerBlock;
constexpr auto block_work_desc =
make_cluster_descriptor(Sequence<MBlockWork, NBlockWork>{});
const auto block_work_id = block_work_desc.CalculateClusterIndex(get_block_1d_id());
const index_t m_block_data_on_global = block_work_id[Number<0>{}] * MPerBlock;
const index_t n_block_data_on_global = block_work_id[Number<1>{}] * NPerBlock;
// A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto a_k_m_block_desc = make_native_tensor_descriptor_aligned(
Sequence<KPerBlock, MPerBlock>{}, Number<max_lds_align>{});
// A matrix blockwise copy
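// source window origin is (k = 0, m = m_block_data_on_global) in the global
// tensor; destination window origin is (0, 0) in the LDS block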
auto a_blockwise_copy =
BlockwiseGenericTensorSliceCopy_v4<BlockSize,
decltype(a_k_m_global_desc),
decltype(a_k_m_block_desc),
decltype(a_k_m_block_desc.GetLengths()),
ABlockCopyThreadSliceLengths_K_M,
ABlockCopyThreadClusterLengths_K_M,
ABlockCopyThreadClusterArrangeOrder,
ABlockCopySrcAccessOrder,
Sequence<0, 1>,
ABlockCopySrcVectorReadDim,
1,
ABlockCopySrcDataPerRead,
ABlockCopyDstDataPerWrite_M,
AddressSpace::Global,
AddressSpace::Vgpr,
AddressSpace::Lds,
InMemoryDataOperation::Set>(
make_multi_index(0, m_block_data_on_global), make_multi_index(0, 0));
// B matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto b_k_n_block_desc = make_native_tensor_descriptor_aligned(
Sequence<KPerBlock, NPerBlock>{}, Number<max_lds_align>{});
// B matrix blockwise copy
auto b_blockwise_copy =
BlockwiseGenericTensorSliceCopy_v4<BlockSize,
decltype(b_k_n_global_desc),
decltype(b_k_n_block_desc),
decltype(b_k_n_block_desc.GetLengths()),
BBlockCopyThreadSliceLengths_K_N,
BBlockCopyThreadClusterLengths_K_N,
BBlockCopyThreadClusterArrangeOrder,
BBlockCopySrcAccessOrder,
Sequence<0, 1>,
BBlockCopySrcVectorReadDim,
1,
BBlockCopySrcDataPerRead,
BBlockCopyDstDataPerWrite_N,
AddressSpace::Global,
AddressSpace::Vgpr,
AddressSpace::Lds,
InMemoryDataOperation::Set>(
make_multi_index(0, n_block_data_on_global), make_multi_index(0, 0));
// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
// a_mtx[KPerBlock, MPerBlock] is in LDS
// b_mtx[KPerBlock, NPerBlock] is in LDS
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// register
constexpr auto a_k_m_block_mtx_desc = make_ConstantMatrixDescriptor(a_k_m_block_desc);
constexpr auto b_k_n_block_mtx_desc = make_ConstantMatrixDescriptor(b_k_n_block_desc);
// sanity check
static_assert(MPerBlock % (MPerThread * MLevel0Cluster * MLevel1Cluster) == 0 &&
NPerBlock % (NPerThread * NLevel0Cluster * NLevel1Cluster) == 0,
"wrong!");
constexpr index_t GemmMRepeat = MPerBlock / (MPerThread * MLevel0Cluster * MLevel1Cluster);
constexpr index_t GemmNRepeat = NPerBlock / (NPerThread * NLevel0Cluster * NLevel1Cluster);
// c_thread_mtx definition: this is a mess
// TODO: find a more elegant way of defining c_thread_mtx
constexpr auto c_m0m1_n0n1_thread_mtx_desc = make_ConstantMatrixDescriptor_packed(
Number<GemmMRepeat * MPerThread>{}, Number<GemmNRepeat * NPerThread>{});
const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2<
BlockSize,
decltype(a_k_m_block_mtx_desc),
decltype(b_k_n_block_mtx_desc),
decltype(c_m0m1_n0n1_thread_mtx_desc),
MPerThread,
NPerThread,
KPerThread,
MLevel0Cluster,
NLevel0Cluster,
MLevel1Cluster,
NLevel1Cluster,
ThreadGemmAThreadCopySrcDataPerRead_M,
ThreadGemmBThreadCopySrcDataPerRead_N>{};
// LDS allocation for A and B: be careful of alignment
constexpr index_t a_block_space_size =
math::integer_least_multiple(a_k_m_block_desc.GetElementSpace(), max_lds_align);
constexpr index_t b_block_space_size =
math::integer_least_multiple(b_k_n_block_desc.GetElementSpace(), max_lds_align);
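// LDS layout for the double buffer: [ A even | A odd | B even | B odd ]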
Float* p_a_block_double = p_shared_block;
Float* p_b_block_double = p_shared_block + 2 * a_block_space_size;
// register allocation for output
AccFloat p_c_thread[c_m0m1_n0n1_thread_mtx_desc.GetElementSpace()];
// zero out threadwise output
threadwise_matrix_set_zero(c_m0m1_n0n1_thread_mtx_desc, p_c_thread);
// LDS double buffer: preload data into LDS
{
a_blockwise_copy.Run(p_a_global, p_a_block_double);
b_blockwise_copy.Run(p_b_global, p_b_block_double);
}
constexpr auto a_block_slice_copy_step = Sequence<KPerBlock, 0>{};
constexpr auto b_block_slice_copy_step = Sequence<KPerBlock, 0>{};
Float* p_a_block_even = p_a_block_double;
Float* p_b_block_even = p_b_block_double;
Float* p_a_block_odd = p_a_block_double + a_block_space_size;
Float* p_b_block_odd = p_b_block_double + b_block_space_size;
// LDS double buffer: main body
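// each trip through this loop consumes 2 * KPerBlock of K: while the GEMM
// runs on the chunk already in LDS, the next chunk is loaded from global
// memory into registers and then stored into the other half of the buffer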
for(index_t k_block_data_begin = 0; k_block_data_begin < K - 2 * KPerBlock;
k_block_data_begin += 2 * KPerBlock)
{
Float p_a_thread_buffer[a_blockwise_copy.GetThreadBufferSize()];
Float p_b_thread_buffer[b_blockwise_copy.GetThreadBufferSize()];
// even iteration
a_blockwise_copy.MoveSrcSliceWindow(a_block_slice_copy_step, True);
b_blockwise_copy.MoveSrcSliceWindow(b_block_slice_copy_step, True);
__syncthreads();
// LDS double buffer: load next data from device mem
a_blockwise_copy.RunLoadThreadBuffer(p_a_global, p_a_thread_buffer);
b_blockwise_copy.RunLoadThreadBuffer(p_b_global, p_b_thread_buffer);
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(p_a_block_even, p_b_block_even, p_c_thread);
// LDS double buffer: store next data to LDS
a_blockwise_copy.RunStoreThreadBuffer(p_a_thread_buffer, p_a_block_odd);
b_blockwise_copy.RunStoreThreadBuffer(p_b_thread_buffer, p_b_block_odd);
// odd iteration
a_blockwise_copy.MoveSrcSliceWindow(a_block_slice_copy_step, True);
b_blockwise_copy.MoveSrcSliceWindow(b_block_slice_copy_step, True);
__syncthreads();
// LDS double buffer: load next data from device mem
a_blockwise_copy.RunLoadThreadBuffer(p_a_global, p_a_thread_buffer);
b_blockwise_copy.RunLoadThreadBuffer(p_b_global, p_b_thread_buffer);
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(p_a_block_odd, p_b_block_odd, p_c_thread);
// LDS double buffer: store next data to LDS
a_blockwise_copy.RunStoreThreadBuffer(p_a_thread_buffer, p_a_block_even);
b_blockwise_copy.RunStoreThreadBuffer(p_b_thread_buffer, p_b_block_even);
}
// LDS double buffer: tail
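// the main loop stops 2 * KPerBlock early: if K is a multiple of
// 2 * KPerBlock, two chunks remain (one already in LDS plus one more to
// load); otherwise a single chunk is already in LDS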
{
constexpr bool has_two_iteration_left = (K % (2 * KPerBlock) == 0);
if(has_two_iteration_left) // if there are 2 iterations left
{
Float p_a_thread_buffer[a_blockwise_copy.GetThreadBufferSize()];
Float p_b_thread_buffer[b_blockwise_copy.GetThreadBufferSize()];
a_blockwise_copy.MoveSrcSliceWindow(a_block_slice_copy_step, True);
b_blockwise_copy.MoveSrcSliceWindow(b_block_slice_copy_step, True);
__syncthreads();
// LDS double buffer: load last data from device mem
a_blockwise_copy.RunLoadThreadBuffer(p_a_global, p_a_thread_buffer);
b_blockwise_copy.RunLoadThreadBuffer(p_b_global, p_b_thread_buffer);
// LDS double buffer: GEMM on 2nd-last data
blockwise_gemm.Run(p_a_block_double, p_b_block_double, p_c_thread);
// LDS double buffer: store last data to LDS
a_blockwise_copy.RunStoreThreadBuffer(p_a_thread_buffer,
p_a_block_double + a_block_space_size);
b_blockwise_copy.RunStoreThreadBuffer(p_b_thread_buffer,
p_b_block_double + b_block_space_size);
__syncthreads();
// LDS double buffer: GEMM on last data
blockwise_gemm.Run(p_a_block_double + a_block_space_size,
p_b_block_double + b_block_space_size,
p_c_thread);
}
else // if there is 1 iteration left
{
__syncthreads();
// LDS double buffer: GEMM on last data
blockwise_gemm.Run(p_a_block_double, p_b_block_double, p_c_thread);
}
}
// output: copy C from register to global memory
{
constexpr index_t M1 = MPerThread * MLevel0Cluster * MLevel1Cluster;
constexpr index_t M0 = M / M1;
constexpr index_t N1 = NPerThread * NLevel0Cluster * NLevel1Cluster;
constexpr index_t N0 = N / N1;
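// view C as [M0, M1, N0, N1]: M1 (N1) is the M (N) footprint of one repeat
// of the whole thread cluster, and M0 (N0) counts those tiles across C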
// define input tensor descriptor for threadwise copy
// thread input tensor, src of threadwise copy
constexpr auto c_m0_m1_n0_n1_thread_desc = make_native_tensor_descriptor_packed(
Sequence<GemmMRepeat, MPerThread, GemmNRepeat, NPerThread>{});
constexpr auto c_m0_m1_n0_n1_global_desc = transform_tensor_descriptor(
c_m_n_global_desc,
make_tuple(UnMerge<Sequence<M0, M1>>{}, UnMerge<Sequence<N0, N1>>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
// calculate origin of thread input tensor on global memory
// blockwise GEMM c matrix starting index
const auto c_thread_mtx_on_block =
blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
const index_t m_thread_data_on_global =
m_block_data_on_global + c_thread_mtx_on_block.row;
const index_t n_thread_data_on_global =
n_block_data_on_global + c_thread_mtx_on_block.col;
ThreadwiseGenericTensorSliceCopy_v4r2<decltype(c_m0_m1_n0_n1_thread_desc),
decltype(c_m0_m1_n0_n1_global_desc),
decltype(c_m0_m1_n0_n1_thread_desc.GetLengths()),
CThreadCopySrcDstAccessOrder,
CThreadCopySrcDstVectorReadWriteDim,
1,
CThreadCopyDstDataPerWrite,
AddressSpace::Vgpr,
AddressSpace::Global,
CGlobalMemoryDataOperation>(
make_multi_index(0, 0, 0, 0),
make_multi_index(m_thread_data_on_global / M1,
m_thread_data_on_global % M1,
n_thread_data_on_global / N1,
n_thread_data_on_global % N1))
.Run(p_c_thread, p_c_global);
}
}
__device__ void Run(const Float* __restrict__ p_a_global,
const Float* __restrict__ p_b_global,
Float* __restrict__ p_c_global) const
{
constexpr index_t shared_block_size = GetSharedMemoryNumberOfByte() / sizeof(Float);
__shared__ Float p_shared_block[shared_block_size];
Run(p_a_global, p_b_global, p_c_global, p_shared_block);
}
};
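// Illustrative usage sketch (assumption, not part of this header): a thin
// kernel wrapper that instantiates this gridwise GEMM and calls the Run
// overload that allocates its own LDS. The names gridwise_gemm_kernel and
// GridwiseGemm below are hypothetical placeholders; the actual drivers live
// under driver/ in this repository.
//
//   template <typename GridwiseGemm, typename Float>
//   __global__ void gridwise_gemm_kernel(const Float* __restrict__ p_a_global,
//                                        const Float* __restrict__ p_b_global,
//                                        Float* __restrict__ p_c_global)
//   {
//       GridwiseGemm{}.Run(p_a_global, p_b_global, p_c_global);
//   }
//
// launched with GridSize blocks of BlockSize threads, so that get_block_1d_id()
// ranges over the MBlockWork * NBlockWork tiles of C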
template <index_t GridSize,
index_t BlockSize,
typename Float,
typename AccFloat,
typename AGlobalDesc,
typename BGlobalDesc,
typename CGlobalDesc,
InMemoryDataOperation CGlobalMemoryDataOperation,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t MPerThread,
index_t NPerThread,
index_t KPerThread,
index_t MLevel0Cluster,
index_t NLevel0Cluster,
index_t MLevel1Cluster,
index_t NLevel1Cluster,
index_t ThreadGemmAThreadCopySrcDataPerRead_M,
index_t ThreadGemmBThreadCopySrcDataPerRead_N,
typename ABlockCopyThreadSliceLengths_K0_K1_K2_M,
typename ABlockCopyThreadClusterLengths_K0_K1_K2_M,
typename ABlockCopyThreadClusterArrangeOrder,
typename ABlockCopySrcAccessOrder,
index_t ABlockCopySrcVectorReadDim,
index_t ABlockCopySrcDataPerRead,
index_t ABlockCopyDstDataPerWrite_M,
typename BBlockCopyThreadSliceLengths_K0_K1_K2_N,
typename BBlockCopyThreadClusterLengths_K0_K1_K2_N,
typename BBlockCopyThreadClusterArrangeOrder,
typename BBlockCopySrcAccessOrder,
index_t BBlockCopySrcVectorReadDim,
index_t BBlockCopySrcDataPerRead,
index_t BBlockCopyDstDataPerWrite_N,
typename CThreadCopySrcDstAccessOrder,
index_t CThreadCopySrcDstVectorReadWriteDim,
index_t CThreadCopyDstDataPerWrite>
struct GridwiseGemmTransposedANormalBNormalC_v2
{
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
constexpr index_t max_lds_align = math::lcm(ABlockCopyDstDataPerWrite_M,
BBlockCopyDstDataPerWrite_N,
ThreadGemmAThreadCopySrcDataPerRead_M,
ThreadGemmBThreadCopySrcDataPerRead_N);
// A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto a_k_m_block_desc = make_native_tensor_descriptor_aligned(
Sequence<KPerBlock, MPerBlock>{}, Number<max_lds_align>{});
// B matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto b_k_n_block_desc = make_native_tensor_descriptor_aligned(
Sequence<KPerBlock, NPerBlock>{}, Number<max_lds_align>{});
// LDS allocation for A and B: be careful of alignment
constexpr index_t a_block_space_size =
math::integer_least_multiple(a_k_m_block_desc.GetElementSpace(), max_lds_align);
constexpr index_t b_block_space_size =
math::integer_least_multiple(b_k_n_block_desc.GetElementSpace(), max_lds_align);
return 2 * (a_block_space_size + b_block_space_size) * sizeof(Float);
}
__device__ void Run(const Float* __restrict__ p_a_global,
const Float* __restrict__ p_b_global,
Float* __restrict__ p_c_global,
Float* __restrict__ p_shared_block) const
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto True = integral_constant<bool, true>{};
constexpr auto False = integral_constant<bool, false>{};
constexpr auto a_k0_k1_k2_m_global_desc = AGlobalDesc{};
constexpr auto b_k0_k1_k2_n_global_desc = BGlobalDesc{};
constexpr auto c_m_n_global_desc = CGlobalDesc{};
constexpr auto K0 = a_k0_k1_k2_m_global_desc.GetLengths()[I0];
constexpr auto K1 = a_k0_k1_k2_m_global_desc.GetLengths()[I1];
constexpr auto K = a_k0_k1_k2_m_global_desc.GetLengths()[I2];
constexpr auto M = c_m_n_global_desc.GetLengths()[I0];
constexpr auto N = c_m_n_global_desc.GetLengths()[I1];
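// A/B global descriptors are 4-d [K0, K1, K2, M or N]; the innermost K2
// (called K below) is consumed by the double-buffered pipeline, while the
// loops further down step the slice window along K1 and then K0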
// don't do anything if K == 0
if(K == 0)
{
return;
}
// lds max alignment
constexpr index_t max_lds_align = math::lcm(ABlockCopyDstDataPerWrite_M,
BBlockCopyDstDataPerWrite_N,
ThreadGemmAThreadCopySrcDataPerRead_M,
ThreadGemmBThreadCopySrcDataPerRead_N);
// divide block work by [M, N]
static_assert(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0,
"wrong! cannot divide work evenly among block");
constexpr index_t MBlockWork = M / MPerBlock;
constexpr index_t NBlockWork = N / NPerBlock;
constexpr auto block_work_desc =
make_cluster_descriptor(Sequence<MBlockWork, NBlockWork>{});
const auto block_work_id = block_work_desc.CalculateClusterIndex(get_block_1d_id());
const index_t m_block_data_on_global = block_work_id[I0] * MPerBlock;
const index_t n_block_data_on_global = block_work_id[I1] * NPerBlock;
// A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto a_k0_k1_k2_m_block_desc = make_native_tensor_descriptor_aligned(
Sequence<1, 1, KPerBlock, MPerBlock>{}, Number<max_lds_align>{});
// A matrix blockwise copy
auto a_blockwise_copy =
BlockwiseGenericTensorSliceCopy_v4<BlockSize,
decltype(a_k0_k1_k2_m_global_desc),
decltype(a_k0_k1_k2_m_block_desc),
decltype(a_k0_k1_k2_m_block_desc.GetLengths()),
ABlockCopyThreadSliceLengths_K0_K1_K2_M,
ABlockCopyThreadClusterLengths_K0_K1_K2_M,
ABlockCopyThreadClusterArrangeOrder,
ABlockCopySrcAccessOrder,
Sequence<0, 1, 2, 3>,
ABlockCopySrcVectorReadDim,
3,
ABlockCopySrcDataPerRead,
ABlockCopyDstDataPerWrite_M,
AddressSpace::Global,
AddressSpace::Vgpr,
AddressSpace::Lds,
InMemoryDataOperation::Set>(
make_multi_index(0, 0, 0, m_block_data_on_global), make_multi_index(0, 0, 0, 0));
// B matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto b_k0_k1_k2_n_block_desc = make_native_tensor_descriptor_aligned(
Sequence<1, 1, KPerBlock, NPerBlock>{}, Number<max_lds_align>{});
// B matrix blockwise copy
auto b_blockwise_copy =
BlockwiseGenericTensorSliceCopy_v4<BlockSize,
decltype(b_k0_k1_k2_n_global_desc),
decltype(b_k0_k1_k2_n_block_desc),
decltype(b_k0_k1_k2_n_block_desc.GetLengths()),
BBlockCopyThreadSliceLengths_K0_K1_K2_N,
BBlockCopyThreadClusterLengths_K0_K1_K2_N,
BBlockCopyThreadClusterArrangeOrder,
BBlockCopySrcAccessOrder,
Sequence<0, 1, 2, 3>,
BBlockCopySrcVectorReadDim,
3,
BBlockCopySrcDataPerRead,
BBlockCopyDstDataPerWrite_N,
AddressSpace::Global,
AddressSpace::Vgpr,
AddressSpace::Lds,
InMemoryDataOperation::Set>(
make_multi_index(0, 0, 0, n_block_data_on_global), make_multi_index(0, 0, 0, 0));
// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
// a_mtx[KPerBlock, MPerBlock] is in LDS
// b_mtx[KPerBlock, NPerBlock] is in LDS
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// register
constexpr auto a_k_m_block_mtx_desc = make_ConstantMatrixDescriptor(
unfold_tensor_descriptor(a_k0_k1_k2_m_block_desc, I0, I2));
constexpr auto b_k_n_block_mtx_desc = make_ConstantMatrixDescriptor(
unfold_tensor_descriptor(b_k0_k1_k2_n_block_desc, I0, I2));
// sanity check
static_assert(MPerBlock % (MPerThread * MLevel0Cluster * MLevel1Cluster) == 0 &&
NPerBlock % (NPerThread * NLevel0Cluster * NLevel1Cluster) == 0,
"wrong!");
constexpr index_t GemmMRepeat = MPerBlock / (MPerThread * MLevel0Cluster * MLevel1Cluster);
constexpr index_t GemmNRepeat = NPerBlock / (NPerThread * NLevel0Cluster * NLevel1Cluster);
// c_thread_mtx definition: this is a mess
// TODO: find a more elegant way of defining c_thread_mtx
constexpr auto c_m0m1_n0n1_thread_mtx_desc = make_ConstantMatrixDescriptor_packed(
Number<GemmMRepeat * MPerThread>{}, Number<GemmNRepeat * NPerThread>{});
const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2<
BlockSize,
decltype(a_k_m_block_mtx_desc),
decltype(b_k_n_block_mtx_desc),
decltype(c_m0m1_n0n1_thread_mtx_desc),
MPerThread,
NPerThread,
KPerThread,
MLevel0Cluster,
NLevel0Cluster,
MLevel1Cluster,
NLevel1Cluster,
ThreadGemmAThreadCopySrcDataPerRead_M,
ThreadGemmBThreadCopySrcDataPerRead_N>{};
// LDS allocation for A and B: be careful of alignment
constexpr index_t a_block_space_size =
math::integer_least_multiple(a_k0_k1_k2_m_block_desc.GetElementSpace(), max_lds_align);
constexpr index_t b_block_space_size =
math::integer_least_multiple(b_k0_k1_k2_n_block_desc.GetElementSpace(), max_lds_align);
Float* p_a_block_double = p_shared_block;
Float* p_b_block_double = p_shared_block + 2 * a_block_space_size;
// register allocation for output
AccFloat p_c_thread[c_m0m1_n0n1_thread_mtx_desc.GetElementSpace()];
// zero out threadwise output
threadwise_matrix_set_zero(c_m0m1_n0n1_thread_mtx_desc, p_c_thread);
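// outer loops over K0 and K1: for each (k0, k1) the double-buffered pipeline
// below walks the K2 dimension in KPerBlock chunks; afterwards the source
// slice window is rewound on K2 and advanced on K1 (and finally on K0)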
for(index_t k0 = 0; k0 < K0; ++k0)
{
for(index_t k1 = 0; k1 < K1; ++k1)
{
// LDS double buffer: preload data into LDS
{
a_blockwise_copy.Run(p_a_global, p_a_block_double);
b_blockwise_copy.Run(p_b_global, p_b_block_double);
}
constexpr auto a_block_slice_copy_step = Sequence<0, 0, KPerBlock, 0>{};
constexpr auto b_block_slice_copy_step = Sequence<0, 0, KPerBlock, 0>{};
// LDS double buffer: main body
for(index_t k_block_data_begin = 0; k_block_data_begin + 2 * KPerBlock < K;
k_block_data_begin += 2 * KPerBlock)
{
#pragma unroll
for(index_t iloop = 0; iloop < 2; ++iloop)
{
const bool even_loop = (iloop % 2 == 0);
Float* p_a_block_now =
even_loop ? p_a_block_double : p_a_block_double + a_block_space_size;
Float* p_b_block_now =
even_loop ? p_b_block_double : p_b_block_double + b_block_space_size;
Float* p_a_block_next =
even_loop ? p_a_block_double + a_block_space_size : p_a_block_double;
Float* p_b_block_next =
even_loop ? p_b_block_double + b_block_space_size : p_b_block_double;
Float p_a_thread_buffer[a_blockwise_copy.GetThreadBufferSize()];
Float p_b_thread_buffer[b_blockwise_copy.GetThreadBufferSize()];
a_blockwise_copy.MoveSrcSliceWindow(a_block_slice_copy_step, True);
b_blockwise_copy.MoveSrcSliceWindow(b_block_slice_copy_step, True);
__syncthreads();
// LDS double buffer: load next data from device mem
a_blockwise_copy.RunLoadThreadBuffer(p_a_global, p_a_thread_buffer);
b_blockwise_copy.RunLoadThreadBuffer(p_b_global, p_b_thread_buffer);
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(p_a_block_now, p_b_block_now, p_c_thread);
// LDS double buffer: store next data to LDS
a_blockwise_copy.RunStoreThreadBuffer(p_a_thread_buffer, p_a_block_next);
b_blockwise_copy.RunStoreThreadBuffer(p_b_thread_buffer, p_b_block_next);
}
}
// LDS double buffer: tail
{
constexpr bool has_two_iteration_left = (K % (2 * KPerBlock) == 0);
if(has_two_iteration_left) // if there are 2 iterations left
{
Float p_a_thread_buffer[a_blockwise_copy.GetThreadBufferSize()];
Float p_b_thread_buffer[b_blockwise_copy.GetThreadBufferSize()];
a_blockwise_copy.MoveSrcSliceWindow(a_block_slice_copy_step, True);
b_blockwise_copy.MoveSrcSliceWindow(b_block_slice_copy_step, True);
__syncthreads();
// LDS double buffer: load last data from device mem
a_blockwise_copy.RunLoadThreadBuffer(p_a_global, p_a_thread_buffer);
b_blockwise_copy.RunLoadThreadBuffer(p_b_global, p_b_thread_buffer);
// LDS double buffer: GEMM on 2nd-last data
blockwise_gemm.Run(p_a_block_double, p_b_block_double, p_c_thread);
// LDS double buffer: store last data to LDS
a_blockwise_copy.RunStoreThreadBuffer(
p_a_thread_buffer, p_a_block_double + a_block_space_size);
b_blockwise_copy.RunStoreThreadBuffer(
p_b_thread_buffer, p_b_block_double + b_block_space_size);
__syncthreads();
// LDS double buffer: GEMM on last data
blockwise_gemm.Run(p_a_block_double + a_block_space_size,
p_b_block_double + b_block_space_size,
p_c_thread);
}
else // if there is 1 iteration left
{
__syncthreads();
// LDS double buffer: GEMM on last data
blockwise_gemm.Run(p_a_block_double, p_b_block_double, p_c_thread);
}
}
// reset slice window on K2 dimension, then move forward on K1 dimension
a_blockwise_copy.MoveSrcSliceWindow(Sequence<0, 0, K - KPerBlock, 0>{}, False);
b_blockwise_copy.MoveSrcSliceWindow(Sequence<0, 0, K - KPerBlock, 0>{}, False);
a_blockwise_copy.MoveSrcSliceWindow(Sequence<0, 1, 0, 0>{}, True);
b_blockwise_copy.MoveSrcSliceWindow(Sequence<0, 1, 0, 0>{}, True);
}
// reset slice window on K1 dimension, then move forward on K0 dimension
a_blockwise_copy.MoveSrcSliceWindow(Sequence<0, K1, 0, 0>{}, False);
b_blockwise_copy.MoveSrcSliceWindow(Sequence<0, K1, 0, 0>{}, False);
a_blockwise_copy.MoveSrcSliceWindow(Sequence<1, 0, 0, 0>{}, True);
b_blockwise_copy.MoveSrcSliceWindow(Sequence<1, 0, 0, 0>{}, True);
}
// output: copy C from register to global memory
{
constexpr index_t M1 = MPerThread * MLevel0Cluster * MLevel1Cluster;
constexpr index_t M0 = M / M1;
constexpr index_t N1 = NPerThread * NLevel0Cluster * NLevel1Cluster;
constexpr index_t N0 = N / N1;
// define input tensor descriptor for threadwise copy
// thread input tensor, src of threadwise copy
constexpr auto c_m0_m1_n0_n1_thread_desc = make_native_tensor_descriptor_packed(
Sequence<GemmMRepeat, MPerThread, GemmNRepeat, NPerThread>{});
constexpr auto c_m0_m1_n0_n1_global_desc = transform_tensor_descriptor(
c_m_n_global_desc,
make_tuple(UnMerge<Sequence<M0, M1>>{}, UnMerge<Sequence<N0, N1>>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
// calculate origin of thread input tensor on global memory
// blockwise GEMM c matrix starting index
const auto c_thread_mtx_on_block =
blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
const index_t m_thread_data_on_global =
m_block_data_on_global + c_thread_mtx_on_block.row;
const index_t n_thread_data_on_global =
n_block_data_on_global + c_thread_mtx_on_block.col;
ThreadwiseGenericTensorSliceCopy_v4r2<decltype(c_m0_m1_n0_n1_thread_desc),
decltype(c_m0_m1_n0_n1_global_desc),
decltype(c_m0_m1_n0_n1_thread_desc.GetLengths()),
CThreadCopySrcDstAccessOrder,
CThreadCopySrcDstVectorReadWriteDim,
1,
CThreadCopyDstDataPerWrite,
AddressSpace::Vgpr,
AddressSpace::Global,
CGlobalMemoryDataOperation>(
make_multi_index(0, 0, 0, 0),
make_multi_index(m_thread_data_on_global / M1,
m_thread_data_on_global % M1,
n_thread_data_on_global / N1,
n_thread_data_on_global % N1))
.Run(p_c_thread, p_c_global);
}
}
__device__ void Run(const Float* __restrict__ p_a_global,
const Float* __restrict__ p_b_global,
Float* __restrict__ p_c_global) const
{
constexpr index_t shared_block_size = GetSharedMemoryNumberOfByte() / sizeof(Float);
__shared__ Float p_shared_block[shared_block_size];
Run(p_a_global, p_b_global, p_c_global, p_shared_block);
}
};
} // namespace ck
#endif