Commit b37cb71f authored by Wen-Heng (Jack) Chung's avatar Wen-Heng (Jack) Chung
Browse files

Enable bwd wrw

parent c5143bca
#ifndef CK_CONVOLUTION_COMMON_HPP
#define CK_CONVOLUTION_COMMON_HPP
namespace ck {
enum ConvolutionDirection
{
Forward,
BackwardData,
BackwardWeight
};
} // namespace ck
#endif
...@@ -2,22 +2,21 @@ ...@@ -2,22 +2,21 @@
#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4_FP16_BFP16_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER_HPP #define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4_FP16_BFP16_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER_HPP
#include "common_header.hpp" #include "common_header.hpp"
#include "ConstantTensorDescriptor.hpp" #include "ConstantTensorDescriptor_deprecated.hpp"
#include "ConstantMergedTensorDescriptor.hpp" #include "ConstantMergedTensorDescriptor_deprecated.hpp"
#include "ConstantMatrixDescriptor.hpp" #include "ConstantMatrixDescriptor.hpp"
#include "blockwise_generic_tensor_slice_copy.hpp" #include "blockwise_generic_tensor_slice_copy_deprecated.hpp"
#include "blockwise_gemm.hpp" #include "blockwise_gemm.hpp"
#include "threadwise_generic_tensor_slice_copy.hpp" #include "threadwise_generic_tensor_slice_copy_deprecated.hpp"
#include "implicitgemm_params.hpp" #include "convolution_common.hpp"
namespace ck { namespace ck {
template <ImplicitGemmDirection conv_dir, typename WeiDesc, index_t NonVectorizedC> template <ConvolutionDirection, typename WeiDesc, index_t NonVectorizedC>
struct make_vectorized_WeiDesc struct make_vectorized_WeiDesc;
{
};
template <typename WeiDesc, index_t NonVectorizedC> template <typename WeiDesc, index_t NonVectorizedC>
struct make_vectorized_WeiDesc<ImplicitGemmDirection::ForwardData, WeiDesc, NonVectorizedC> struct make_vectorized_WeiDesc<ConvolutionDirection::Forward, WeiDesc, NonVectorizedC>
{ {
__device__ constexpr auto get(WeiDesc&) __device__ constexpr auto get(WeiDesc&)
{ {
...@@ -30,8 +29,9 @@ struct make_vectorized_WeiDesc<ImplicitGemmDirection::ForwardData, WeiDesc, NonV ...@@ -30,8 +29,9 @@ struct make_vectorized_WeiDesc<ImplicitGemmDirection::ForwardData, WeiDesc, NonV
.ReorderGivenNew2Old(Sequence<2, 0, 1>{}); .ReorderGivenNew2Old(Sequence<2, 0, 1>{});
} }
}; };
template <typename WeiDesc, index_t NonVectorizedC> template <typename WeiDesc, index_t NonVectorizedC>
struct make_vectorized_WeiDesc<ImplicitGemmDirection::BackwardWeight, WeiDesc, NonVectorizedC> struct make_vectorized_WeiDesc<ConvolutionDirection::BackwardWeight, WeiDesc, NonVectorizedC>
{ {
__device__ constexpr auto get(WeiDesc& desc) __device__ constexpr auto get(WeiDesc& desc)
{ {
...@@ -56,6 +56,7 @@ template <index_t GridSize, ...@@ -56,6 +56,7 @@ template <index_t GridSize,
class OutGlobalDesc, class OutGlobalDesc,
class ConvStrides, class ConvStrides,
class ConvDilations, class ConvDilations,
ConvolutionDirection ConvDirection,
index_t BPerBlock, index_t BPerBlock,
index_t KPerBlock, index_t KPerBlock,
index_t EPerBlock, index_t EPerBlock,
...@@ -83,8 +84,7 @@ template <index_t GridSize, ...@@ -83,8 +84,7 @@ template <index_t GridSize,
class WeiBlockCopySrcAccessOrder, class WeiBlockCopySrcAccessOrder,
class WeiBlockCopyDstAccessOrder, class WeiBlockCopyDstAccessOrder,
index_t WeiBlockCopySrcDataPerRead_E, index_t WeiBlockCopySrcDataPerRead_E,
index_t WeiBlockCopyDstDataPerWrite_K, index_t WeiBlockCopyDstDataPerWrite_K>
ImplicitGemmDirection conv_dir>
struct GridwiseConvolutionImplicitGemm_v4_fp16_bfp16_nchw_kcyx_nkhw_lds_double_buffer struct GridwiseConvolutionImplicitGemm_v4_fp16_bfp16_nchw_kcyx_nkhw_lds_double_buffer
{ {
__device__ void Run(const Float* const __restrict__ p_in_global, __device__ void Run(const Float* const __restrict__ p_in_global,
...@@ -198,27 +198,28 @@ struct GridwiseConvolutionImplicitGemm_v4_fp16_bfp16_nchw_kcyx_nkhw_lds_double_b ...@@ -198,27 +198,28 @@ struct GridwiseConvolutionImplicitGemm_v4_fp16_bfp16_nchw_kcyx_nkhw_lds_double_b
// input blockwise copy // input blockwise copy
// slice a merged tensor, reorder and copy to a normal tensor // slice a merged tensor, reorder and copy to a normal tensor
// this copy operator already has blockwise offset built-in // this copy operator already has blockwise offset built-in
auto blockwise_in_copy = auto blockwise_in_copy = BlockwiseGenericTensorSliceCopy_v1_deprecated<
BlockwiseGenericTensorSliceCopy_v1<BlockSize, BlockSize,
decltype(in_e_n1_b_n2_2eor4e_global_merged_desc), decltype(in_e_n1_b_n2_2eor4e_global_merged_desc),
decltype(in_e_n1_b_n2_2eor4e_block_desc), decltype(in_e_n1_b_n2_2eor4e_block_desc),
decltype( decltype(in_e_n1_b_n2_2eor4e_block_desc.GetLengths()),
in_e_n1_b_n2_2eor4e_block_desc.GetLengths()), InBlockCopySubLengths_E_N1_B_N2_EPACK,
InBlockCopySubLengths_E_N1_B_N2_EPACK, InBlockCopyClusterLengths_E_N1_B_N2_EPACK,
InBlockCopyClusterLengths_E_N1_B_N2_EPACK, InBlockCopyThreadClusterArrangeOrder,
InBlockCopyThreadClusterArrangeOrder, InBlockCopySrcAccessOrder,
InBlockCopySrcAccessOrder, InBlockCopyDstAccessOrder,
InBlockCopyDstAccessOrder, 2,
2, 4,
4, InBlockCopySrcDataPerRead_B,
InBlockCopySrcDataPerRead_B, InBlockCopyDstDataPerWrite_EPACK>({0, 0, b_block_data_on_global, 0, 0},
InBlockCopyDstDataPerWrite_EPACK>( {0, 0, 0, 0, 0});
{0, 0, b_block_data_on_global, 0, 0}, {0, 0, 0, 0, 0});
// weight tensor // weight tensor
// tensor descriptor in device memory, src of blockwise copy // tensor descriptor in device memory, src of blockwise copy
constexpr auto wei_e_k_2eor4e_global_desc = constexpr auto wei_e_k_2eor4e_global_desc =
make_vectorized_WeiDesc<conv_dir, decltype(wei_k_c_y_x_global_desc), nonVectorizedC>{} make_vectorized_WeiDesc<ConvDirection,
decltype(wei_k_c_y_x_global_desc),
nonVectorizedC>{}
.get(wei_k_c_y_x_global_desc); .get(wei_k_c_y_x_global_desc);
// tensor descriptor in LDS, dst of blockwise copy // tensor descriptor in LDS, dst of blockwise copy
...@@ -235,21 +236,20 @@ struct GridwiseConvolutionImplicitGemm_v4_fp16_bfp16_nchw_kcyx_nkhw_lds_double_b ...@@ -235,21 +236,20 @@ struct GridwiseConvolutionImplicitGemm_v4_fp16_bfp16_nchw_kcyx_nkhw_lds_double_b
// operator for blockwise copy of weight into LDS // operator for blockwise copy of weight into LDS
// slice a tensor, and copy it into another tensor // slice a tensor, and copy it into another tensor
// this copy operator already have blockwise offset built-in // this copy operator already have blockwise offset built-in
auto blockwise_wei_copy = auto blockwise_wei_copy = BlockwiseGenericTensorSliceCopy_v1_deprecated<
BlockwiseGenericTensorSliceCopy_v1<BlockSize, BlockSize,
decltype(wei_e_k_2eor4e_global_desc), decltype(wei_e_k_2eor4e_global_desc),
decltype(wei_e_k_2eor4e_block_desc), decltype(wei_e_k_2eor4e_block_desc),
decltype(wei_e_k_2eor4e_block_desc.GetLengths()), decltype(wei_e_k_2eor4e_block_desc.GetLengths()),
WeiBlockCopySubLengths_E_K_EPACK, WeiBlockCopySubLengths_E_K_EPACK,
WeiBlockCopyClusterLengths_E_K_EPACK, WeiBlockCopyClusterLengths_E_K_EPACK,
WeiBlockCopyThreadClusterArrangeOrder, WeiBlockCopyThreadClusterArrangeOrder,
WeiBlockCopySrcAccessOrder, WeiBlockCopySrcAccessOrder,
WeiBlockCopyDstAccessOrder, WeiBlockCopyDstAccessOrder,
0, 0,
2, 2,
WeiBlockCopySrcDataPerRead_E, WeiBlockCopySrcDataPerRead_E,
WeiBlockCopyDstDataPerWrite_EPACK>( WeiBlockCopyDstDataPerWrite_EPACK>({0, k_block_data_on_global, 0}, {0, 0, 0});
{0, k_block_data_on_global, 0}, {0, 0, 0});
// GEMM definition // GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx // c_mtx += transpose(a_mtx) * b_mtx
...@@ -279,7 +279,6 @@ struct GridwiseConvolutionImplicitGemm_v4_fp16_bfp16_nchw_kcyx_nkhw_lds_double_b ...@@ -279,7 +279,6 @@ struct GridwiseConvolutionImplicitGemm_v4_fp16_bfp16_nchw_kcyx_nkhw_lds_double_b
const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2< const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2<
BlockSize, BlockSize,
EPACK,
decltype(a_e_k_block_mtx_desc), decltype(a_e_k_block_mtx_desc),
decltype(b_e_n1bn2_block_mtx_desc), decltype(b_e_n1bn2_block_mtx_desc),
decltype(c_k0k2_n1n2_thread_mtx_desc), decltype(c_k0k2_n1n2_thread_mtx_desc),
...@@ -347,12 +346,12 @@ struct GridwiseConvolutionImplicitGemm_v4_fp16_bfp16_nchw_kcyx_nkhw_lds_double_b ...@@ -347,12 +346,12 @@ struct GridwiseConvolutionImplicitGemm_v4_fp16_bfp16_nchw_kcyx_nkhw_lds_double_b
Float* p_wei_block_next = Float* p_wei_block_next =
even_loop ? p_wei_block_double + wei_block_space : p_wei_block_double; even_loop ? p_wei_block_double + wei_block_space : p_wei_block_double;
Float p_in_register_buffer[blockwise_in_copy.GetRegisterBufferSize()]; Float p_in_thread_buffer[blockwise_in_copy.GetThreadBufferSize()];
Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()]; Float p_wei_thread_buffer[blockwise_wei_copy.GetThreadBufferSize()];
blockwise_in_copy.MoveSlicingWindowOnSourceTensor(I0, Number<EPerBlock>{}, True); blockwise_in_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0, 0, 0, 0>{}, True);
static_if<conv_dir == ImplicitGemmDirection::BackwardWeight>{}([&](auto fwd) { static_if<ConvDirection == ConvolutionDirection::BackwardWeight>{}([&](auto fwd) {
fwd(blockwise_wei_copy).MoveSrcSlicingWindow(Sequence<EPerBlock, 0, 0>{}, True); fwd(blockwise_wei_copy).MoveSrcSliceWindow(Sequence<EPerBlock, 0, 0>{}, True);
}).Else([&](auto fwd) { }).Else([&](auto fwd) {
p_wei_block_on_global += p_wei_block_on_global +=
EPerBlock * fwd(wei_e_k_2eor4e_global_desc).GetStride(I0); EPerBlock * fwd(wei_e_k_2eor4e_global_desc).GetStride(I0);
...@@ -361,9 +360,8 @@ struct GridwiseConvolutionImplicitGemm_v4_fp16_bfp16_nchw_kcyx_nkhw_lds_double_b ...@@ -361,9 +360,8 @@ struct GridwiseConvolutionImplicitGemm_v4_fp16_bfp16_nchw_kcyx_nkhw_lds_double_b
__syncthreads(); __syncthreads();
// LDS doubel buffer: load next data from device mem // LDS doubel buffer: load next data from device mem
blockwise_in_copy.RunLoadRegisterBuffer(p_in_global, p_in_register_buffer); blockwise_in_copy.RunLoadThreadBuffer(p_in_global, p_in_thread_buffer);
blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_block_on_global, blockwise_wei_copy.RunLoadThreadBuffer(p_wei_block_on_global, p_wei_thread_buffer);
p_wei_register_buffer);
// LDS double buffer: GEMM on current data // LDS double buffer: GEMM on current data
const typename vector_type<Float, EPACK>::MemoryType* p_a_block_vec = const typename vector_type<Float, EPACK>::MemoryType* p_a_block_vec =
...@@ -375,20 +373,20 @@ struct GridwiseConvolutionImplicitGemm_v4_fp16_bfp16_nchw_kcyx_nkhw_lds_double_b ...@@ -375,20 +373,20 @@ struct GridwiseConvolutionImplicitGemm_v4_fp16_bfp16_nchw_kcyx_nkhw_lds_double_b
blockwise_gemm.Run(p_a_block_vec, p_b_block_vec, p_out_thread); blockwise_gemm.Run(p_a_block_vec, p_b_block_vec, p_out_thread);
// LDS double buffer: store next data to LDS // LDS double buffer: store next data to LDS
blockwise_in_copy.RunStoreRegisterBuffer(p_in_register_buffer, p_in_block_next); blockwise_in_copy.RunStoreThreadBuffer(p_in_thread_buffer, p_in_block_next);
blockwise_wei_copy.RunStoreRegisterBuffer(p_wei_register_buffer, p_wei_block_next); blockwise_wei_copy.RunStoreThreadBuffer(p_wei_thread_buffer, p_wei_block_next);
} }
} }
// LDS double buffer: tail // LDS double buffer: tail
{ {
Float p_in_register_buffer[blockwise_in_copy.GetRegisterBufferSize()]; Float p_in_thread_buffer[blockwise_in_copy.GetThreadBufferSize()];
Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()]; Float p_wei_thread_buffer[blockwise_wei_copy.GetThreadBufferSize()];
// even iteration // even iteration
blockwise_in_copy.MoveSlicingWindowOnSourceTensor(I0, Number<EPerBlock>{}, True); blockwise_in_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0, 0, 0, 0>{}, True);
static_if<conv_dir == ImplicitGemmDirection::BackwardWeight>{}([&](auto fwd) { static_if<ConvDirection == ConvolutionDirection::BackwardWeight>{}([&](auto fwd) {
fwd(blockwise_wei_copy).MoveSrcSlicingWindow(Sequence<EPerBlock, 0, 0>{}, True); fwd(blockwise_wei_copy).MoveSrcSliceWindow(Sequence<EPerBlock, 0, 0>{}, True);
}).Else([&](auto fwd) { }).Else([&](auto fwd) {
p_wei_block_on_global += EPerBlock * fwd(wei_e_k_2eor4e_global_desc).GetStride(I0); p_wei_block_on_global += EPerBlock * fwd(wei_e_k_2eor4e_global_desc).GetStride(I0);
}); });
...@@ -396,8 +394,8 @@ struct GridwiseConvolutionImplicitGemm_v4_fp16_bfp16_nchw_kcyx_nkhw_lds_double_b ...@@ -396,8 +394,8 @@ struct GridwiseConvolutionImplicitGemm_v4_fp16_bfp16_nchw_kcyx_nkhw_lds_double_b
__syncthreads(); __syncthreads();
// LDS doubel buffer: load next data from device mem // LDS doubel buffer: load next data from device mem
blockwise_in_copy.RunLoadRegisterBuffer(p_in_global, p_in_register_buffer); blockwise_in_copy.RunLoadThreadBuffer(p_in_global, p_in_thread_buffer);
blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_block_on_global, p_wei_register_buffer); blockwise_wei_copy.RunLoadThreadBuffer(p_wei_block_on_global, p_wei_thread_buffer);
// LDS double buffer: GEMM on current data // LDS double buffer: GEMM on current data
// Vectorize the pointer to match with how half/bfloat16 datatypes are // Vectorize the pointer to match with how half/bfloat16 datatypes are
...@@ -415,10 +413,10 @@ struct GridwiseConvolutionImplicitGemm_v4_fp16_bfp16_nchw_kcyx_nkhw_lds_double_b ...@@ -415,10 +413,10 @@ struct GridwiseConvolutionImplicitGemm_v4_fp16_bfp16_nchw_kcyx_nkhw_lds_double_b
blockwise_gemm.Run(p_a_block_vec, p_b_block_vec, p_out_thread); blockwise_gemm.Run(p_a_block_vec, p_b_block_vec, p_out_thread);
// LDS double buffer: store next data to LDS // LDS double buffer: store next data to LDS
blockwise_in_copy.RunStoreRegisterBuffer(p_in_register_buffer, blockwise_in_copy.RunStoreThreadBuffer(p_in_thread_buffer,
p_in_block_double + in_block_space); p_in_block_double + in_block_space);
blockwise_wei_copy.RunStoreRegisterBuffer(p_wei_register_buffer, blockwise_wei_copy.RunStoreThreadBuffer(p_wei_thread_buffer,
p_wei_block_double + wei_block_space); p_wei_block_double + wei_block_space);
// odd iteration // odd iteration
__syncthreads(); __syncthreads();
...@@ -479,7 +477,7 @@ struct GridwiseConvolutionImplicitGemm_v4_fp16_bfp16_nchw_kcyx_nkhw_lds_double_b ...@@ -479,7 +477,7 @@ struct GridwiseConvolutionImplicitGemm_v4_fp16_bfp16_nchw_kcyx_nkhw_lds_double_b
out_k_n1_b_n2_global_merged_desc.GetOffsetFromMultiIndex( out_k_n1_b_n2_global_merged_desc.GetOffsetFromMultiIndex(
k_thread_data_on_global, 0, b_thread_data_on_global, 0); k_thread_data_on_global, 0, b_thread_data_on_global, 0);
ThreadwiseGenericTensorSliceCopy_v1r2< ThreadwiseGenericTensorSliceCopy_v1r2_deprecated<
decltype(out_n0_n1_n2_k0_k1_k2_h_w_thread_desc), decltype(out_n0_n1_n2_k0_k1_k2_h_w_thread_desc),
decltype(out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc), decltype(out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc),
decltype(out_n0_n1_n2_k0_k1_k2_h_w_thread_desc.GetLengths()), decltype(out_n0_n1_n2_k0_k1_k2_h_w_thread_desc.GetLengths()),
......
...@@ -2,13 +2,13 @@ ...@@ -2,13 +2,13 @@
#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4_NCHW_KC1x1_NKHW_LDS_DOUBLE_BUFFER_HPP #define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4_NCHW_KC1x1_NKHW_LDS_DOUBLE_BUFFER_HPP
#include "common_header.hpp" #include "common_header.hpp"
#include "ConstantTensorDescriptor.hpp" #include "ConstantTensorDescriptor_deprecated.hpp"
#include "ConstantMergedTensorDescriptor.hpp" #include "ConstantMergedTensorDescriptor_deprecated.hpp"
#include "ConstantMatrixDescriptor.hpp" #include "ConstantMatrixDescriptor.hpp"
#include "blockwise_generic_tensor_slice_copy.hpp" #include "blockwise_generic_tensor_slice_copy_deprecated.hpp"
#include "blockwise_gemm.hpp" #include "blockwise_gemm.hpp"
#include "threadwise_generic_tensor_slice_copy.hpp" #include "threadwise_generic_tensor_slice_copy_deprecated.hpp"
#include "implicitgemm_params.hpp" #include "convolution_common.hpp"
namespace ck { namespace ck {
...@@ -21,7 +21,7 @@ template <index_t GridSize, ...@@ -21,7 +21,7 @@ template <index_t GridSize,
class WeiGlobalDesc, class WeiGlobalDesc,
class OutGlobalDesc, // exchanged outside for backward class OutGlobalDesc, // exchanged outside for backward
class ConvStrides, class ConvStrides,
ImplicitGemmDirection Direction, ConvolutionDirection ConvDirection,
index_t BPerBlock, index_t BPerBlock,
index_t KPerBlock, index_t KPerBlock,
index_t EPerBlock, index_t EPerBlock,
...@@ -56,7 +56,7 @@ struct GridwiseConvolutionImplicitGemm_v4_nchw_kc1x1_nkhw_lds_double_buffer ...@@ -56,7 +56,7 @@ struct GridwiseConvolutionImplicitGemm_v4_nchw_kc1x1_nkhw_lds_double_buffer
const Float* const __restrict__ p_wei_global, const Float* const __restrict__ p_wei_global,
Float* const __restrict__ p_out_global) const Float* const __restrict__ p_out_global) const
{ {
constexpr bool isForward = Direction == ImplicitGemmDirection::ForwardData; constexpr bool isForward = (ConvDirection == ConvolutionDirection::Forward);
// this is a mess // this is a mess
// TODO: find more elegent way of specifying (or calculating) performance parameters // TODO: find more elegent way of specifying (or calculating) performance parameters
...@@ -161,21 +161,20 @@ struct GridwiseConvolutionImplicitGemm_v4_nchw_kc1x1_nkhw_lds_double_buffer ...@@ -161,21 +161,20 @@ struct GridwiseConvolutionImplicitGemm_v4_nchw_kc1x1_nkhw_lds_double_buffer
// input blockwise copy // input blockwise copy
// slice a merged tensor, reorder and copy to a normal tensor // slice a merged tensor, reorder and copy to a normal tensor
// this copy operator already has blockwise offset built-in // this copy operator already has blockwise offset built-in
auto blockwise_in_copy = auto blockwise_in_copy = BlockwiseGenericTensorSliceCopy_v1_deprecated<
BlockwiseGenericTensorSliceCopy_v1<BlockSize, BlockSize,
decltype(in_e_n1_b_n2_global_merged_desc), decltype(in_e_n1_b_n2_global_merged_desc),
decltype(in_e_n1_b_n2_block_desc), decltype(in_e_n1_b_n2_block_desc),
decltype(in_e_n1_b_n2_block_desc.GetLengths()), decltype(in_e_n1_b_n2_block_desc.GetLengths()),
InBlockCopySubLengths_E_N1_B_N2, InBlockCopySubLengths_E_N1_B_N2,
InBlockCopyClusterLengths_E_N1_B_N2, InBlockCopyClusterLengths_E_N1_B_N2,
InBlockCopyThreadClusterArrangeOrder, InBlockCopyThreadClusterArrangeOrder,
InBlockCopySrcAccessOrder, InBlockCopySrcAccessOrder,
InBlockCopyDstAccessOrder, InBlockCopyDstAccessOrder,
2, 2,
3, 3,
InBlockCopySrcDataPerRead_B, InBlockCopySrcDataPerRead_B,
InBlockCopyDstDataPerWrite_N2>( InBlockCopyDstDataPerWrite_N2>({0, 0, b_block_data_on_global, 0}, {0, 0, 0, 0});
{0, 0, b_block_data_on_global, 0}, {0, 0, 0, 0});
// weight tensor // weight tensor
// tensor descriptor in device memory, src of blockwise copy // tensor descriptor in device memory, src of blockwise copy
...@@ -198,19 +197,19 @@ struct GridwiseConvolutionImplicitGemm_v4_nchw_kc1x1_nkhw_lds_double_buffer ...@@ -198,19 +197,19 @@ struct GridwiseConvolutionImplicitGemm_v4_nchw_kc1x1_nkhw_lds_double_buffer
// slice a tensor, and copy it into another tensor // slice a tensor, and copy it into another tensor
// this copy operator already have blockwise offset built-in // this copy operator already have blockwise offset built-in
auto blockwise_wei_copy = auto blockwise_wei_copy =
BlockwiseGenericTensorSliceCopy_v1<BlockSize, BlockwiseGenericTensorSliceCopy_v1_deprecated<BlockSize,
decltype(wei_e_k_global_desc), decltype(wei_e_k_global_desc),
decltype(wei_e_k_block_desc), decltype(wei_e_k_block_desc),
decltype(wei_e_k_block_desc.GetLengths()), decltype(wei_e_k_block_desc.GetLengths()),
WeiBlockCopySubLengths_E_K, WeiBlockCopySubLengths_E_K,
WeiBlockCopyClusterLengths_E_K, WeiBlockCopyClusterLengths_E_K,
WeiBlockCopyThreadClusterArrangeOrder, WeiBlockCopyThreadClusterArrangeOrder,
WeiBlockCopySrcAccessOrder, WeiBlockCopySrcAccessOrder,
WeiBlockCopyDstAccessOrder, WeiBlockCopyDstAccessOrder,
0, 0,
1, 1,
WeiBlockCopySrcDataPerRead_E, WeiBlockCopySrcDataPerRead_E,
WeiBlockCopyDstDataPerWrite_K>( WeiBlockCopyDstDataPerWrite_K>(
{0, k_block_data_on_global}, {0, 0}); {0, k_block_data_on_global}, {0, 0});
// GEMM definition // GEMM definition
...@@ -239,7 +238,6 @@ struct GridwiseConvolutionImplicitGemm_v4_nchw_kc1x1_nkhw_lds_double_buffer ...@@ -239,7 +238,6 @@ struct GridwiseConvolutionImplicitGemm_v4_nchw_kc1x1_nkhw_lds_double_buffer
const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2< const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2<
BlockSize, BlockSize,
1, // EPACK = 1
decltype(a_e_k_block_mtx_desc), decltype(a_e_k_block_mtx_desc),
decltype(b_e_n1bn2_block_mtx_desc), decltype(b_e_n1bn2_block_mtx_desc),
decltype(c_k0k2_n1n2_thread_mtx_desc), decltype(c_k0k2_n1n2_thread_mtx_desc),
...@@ -301,51 +299,50 @@ struct GridwiseConvolutionImplicitGemm_v4_nchw_kc1x1_nkhw_lds_double_buffer ...@@ -301,51 +299,50 @@ struct GridwiseConvolutionImplicitGemm_v4_nchw_kc1x1_nkhw_lds_double_buffer
Float* p_wei_block_next = Float* p_wei_block_next =
even_loop ? p_wei_block_double + wei_block_space : p_wei_block_double; even_loop ? p_wei_block_double + wei_block_space : p_wei_block_double;
Float p_in_register_buffer[blockwise_in_copy.GetRegisterBufferSize()]; Float p_in_thread_buffer[blockwise_in_copy.GetThreadBufferSize()];
Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()]; Float p_wei_thread_buffer[blockwise_wei_copy.GetThreadBufferSize()];
blockwise_in_copy.MoveSrcSlicingWindow(Sequence<EPerBlock, 0, 0, 0>{}, True); blockwise_in_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0, 0, 0>{}, True);
p_wei_block_on_global += EPerBlock * wei_e_k_global_desc.GetStride(I0); p_wei_block_on_global += EPerBlock * wei_e_k_global_desc.GetStride(I0);
__syncthreads(); __syncthreads();
// LDS doubel buffer: load next data from device mem // LDS doubel buffer: load next data from device mem
blockwise_in_copy.RunLoadRegisterBuffer(p_in_global, p_in_register_buffer); blockwise_in_copy.RunLoadThreadBuffer(p_in_global, p_in_thread_buffer);
blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_block_on_global, blockwise_wei_copy.RunLoadThreadBuffer(p_wei_block_on_global, p_wei_thread_buffer);
p_wei_register_buffer);
// LDS double buffer: GEMM on current data // LDS double buffer: GEMM on current data
blockwise_gemm.Run(p_wei_block_now, p_in_block_now, p_out_thread); blockwise_gemm.Run(p_wei_block_now, p_in_block_now, p_out_thread);
// LDS double buffer: store next data to LDS // LDS double buffer: store next data to LDS
blockwise_in_copy.RunStoreRegisterBuffer(p_in_register_buffer, p_in_block_next); blockwise_in_copy.RunStoreThreadBuffer(p_in_thread_buffer, p_in_block_next);
blockwise_wei_copy.RunStoreRegisterBuffer(p_wei_register_buffer, p_wei_block_next); blockwise_wei_copy.RunStoreThreadBuffer(p_wei_thread_buffer, p_wei_block_next);
} }
} }
// LDS double buffer: tail // LDS double buffer: tail
{ {
Float p_in_register_buffer[blockwise_in_copy.GetRegisterBufferSize()]; Float p_in_thread_buffer[blockwise_in_copy.GetThreadBufferSize()];
Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()]; Float p_wei_thread_buffer[blockwise_wei_copy.GetThreadBufferSize()];
// even iteration // even iteration
blockwise_in_copy.MoveSrcSlicingWindow(Sequence<EPerBlock, 0, 0, 0>{}, True); blockwise_in_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0, 0, 0>{}, True);
p_wei_block_on_global += EPerBlock * wei_e_k_global_desc.GetStride(I0); p_wei_block_on_global += EPerBlock * wei_e_k_global_desc.GetStride(I0);
__syncthreads(); __syncthreads();
// LDS doubel buffer: load next data from device mem // LDS doubel buffer: load next data from device mem
blockwise_in_copy.RunLoadRegisterBuffer(p_in_global, p_in_register_buffer); blockwise_in_copy.RunLoadThreadBuffer(p_in_global, p_in_thread_buffer);
blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_block_on_global, p_wei_register_buffer); blockwise_wei_copy.RunLoadThreadBuffer(p_wei_block_on_global, p_wei_thread_buffer);
// LDS double buffer: GEMM on current data // LDS double buffer: GEMM on current data
blockwise_gemm.Run(p_wei_block_double, p_in_block_double, p_out_thread); blockwise_gemm.Run(p_wei_block_double, p_in_block_double, p_out_thread);
// LDS double buffer: store next data to LDS // LDS double buffer: store next data to LDS
blockwise_in_copy.RunStoreRegisterBuffer(p_in_register_buffer, blockwise_in_copy.RunStoreThreadBuffer(p_in_thread_buffer,
p_in_block_double + in_block_space); p_in_block_double + in_block_space);
blockwise_wei_copy.RunStoreRegisterBuffer(p_wei_register_buffer, blockwise_wei_copy.RunStoreThreadBuffer(p_wei_thread_buffer,
p_wei_block_double + wei_block_space); p_wei_block_double + wei_block_space);
// odd iteration // odd iteration
__syncthreads(); __syncthreads();
...@@ -437,7 +434,7 @@ struct GridwiseConvolutionImplicitGemm_v4_nchw_kc1x1_nkhw_lds_double_buffer ...@@ -437,7 +434,7 @@ struct GridwiseConvolutionImplicitGemm_v4_nchw_kc1x1_nkhw_lds_double_buffer
out_k_n1_b_n2_global_merged_desc.GetOffsetFromMultiIndex( out_k_n1_b_n2_global_merged_desc.GetOffsetFromMultiIndex(
k_thread_data_on_global, 0, b_thread_data_on_global, 0); k_thread_data_on_global, 0, b_thread_data_on_global, 0);
ThreadwiseGenericTensorSliceCopy_v1r2< ThreadwiseGenericTensorSliceCopy_v1r2_deprecated<
decltype(out_n0_n1_n2_k0_k1_k2_h_w_thread_desc), decltype(out_n0_n1_n2_k0_k1_k2_h_w_thread_desc),
decltype(out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc), decltype(out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc),
decltype(out_n0_n1_n2_k0_k1_k2_h_w_thread_desc.GetLengths()), decltype(out_n0_n1_n2_k0_k1_k2_h_w_thread_desc.GetLengths()),
......
#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R1_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER_HPP
#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R1_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER_HPP
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "blockwise_generic_tensor_slice_copy.hpp"
#include "threadwise_generic_tensor_slice_copy.hpp"
#include "blockwise_gemm.hpp"
#include "convolution_common.hpp"
namespace ck {
template <ConvolutionDirection>
struct make_wei_e_k_global_desc_v4r1;
template <>
struct make_wei_e_k_global_desc_v4r1<ConvolutionDirection::Forward>
{
template <typename WeiDesc>
__device__ constexpr auto operator()(WeiDesc) const
{
constexpr auto I1 = Number<1>{};
constexpr auto I3 = Number<3>{};
return reorder_tensor_descriptor_given_upper2lower(
unfold_tensor_descriptor(WeiDesc{}, I1, I3), Sequence<1, 0>{});
}
};
template <>
struct make_wei_e_k_global_desc_v4r1<ConvolutionDirection::BackwardWeight>
{
template <typename WeiDesc>
__device__ constexpr auto operator()(WeiDesc) const
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto wei_k_c_y_x_global_desc = WeiDesc{};
constexpr index_t K = wei_k_c_y_x_global_desc.GetLength(I0);
constexpr index_t C = wei_k_c_y_x_global_desc.GetLength(I1);
constexpr index_t Y = wei_k_c_y_x_global_desc.GetLength(I2);
constexpr index_t X = wei_k_c_y_x_global_desc.GetLength(I3);
return transform_tensor_descriptor(
unfold_tensor_descriptor(wei_k_c_y_x_global_desc, I2, I3),
make_tuple(Merge<Sequence<C, Y * X>>{}, PassThrough<K>{}),
make_tuple(Sequence<1, 2>{}, Sequence<0>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
};
template <index_t GridSize,
index_t BlockSize,
typename Float,
typename AccDataType,
typename InGlobalDesc,
typename WeiGlobalDesc,
typename OutGlobalDesc,
typename ConvStrides,
typename ConvDilations,
typename LeftPads,
typename RightPads,
ConvolutionDirection ConvDirection,
index_t BPerBlock,
index_t KPerBlock,
index_t EPerBlock,
index_t GemmNRepeat,
index_t GemmMPerThreadSubC,
index_t GemmNPerThreadSubC,
index_t GemmMLevel0Cluster,
index_t GemmNLevel0Cluster,
index_t GemmMLevel1Cluster,
index_t GemmNLevel1Cluster,
index_t GemmKPerThreadLoop,
index_t GemmDataPerReadA,
index_t GemmDataPerReadB,
typename InBlockCopySubLengths_E_N1_B_N2,
typename InBlockCopyClusterLengths_E_N1_B_N2,
typename InBlockCopyThreadClusterArrangeOrder,
typename InBlockCopySrcAccessOrder,
typename InBlockCopyDstAccessOrder,
index_t InBlockCopySrcDataPerRead_B,
index_t InBlockCopyDstDataPerWrite_N2,
typename WeiBlockCopySubLengths_E_K,
typename WeiBlockCopyClusterLengths_E_K,
typename WeiBlockCopyThreadClusterArrangeOrder,
typename WeiBlockCopySrcAccessOrder,
typename WeiBlockCopyDstAccessOrder,
index_t WeiBlockCopySrcDataPerRead_E,
index_t WeiBlockCopyDstDataPerWrite_K>
struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
{
__device__ void Run(const Float* const __restrict__ p_in_global,
const Float* const __restrict__ p_wei_global,
Float* const __restrict__ p_out_global) const
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto True = integral_constant<bool, true>{};
constexpr auto generic_address_space =
integral_constant<AddressSpace, AddressSpace::generic>{};
constexpr auto global_address_space =
integral_constant<AddressSpace, AddressSpace::global>{};
static_assert(ConvDirection == ConvolutionDirection::Forward ||
ConvDirection == ConvolutionDirection::BackwardWeight,
"wrong! this kernel only support convolution forward and backward-weight");
// this is a mess
// TODO: find more elegent way of specifying (or calculating) performance parameters
constexpr index_t N1 = GemmNRepeat;
constexpr index_t N2 = GemmNPerThreadSubC;
static_assert((N1 * N2 * BPerBlock) %
(GemmNPerThreadSubC * GemmNLevel0Cluster * GemmNLevel1Cluster) ==
0,
"wrong!");
constexpr auto in_n_c_hi_wi_global_desc = InGlobalDesc{};
constexpr auto wei_k_c_y_x_global_desc = WeiGlobalDesc{};
constexpr auto out_n_k_ho_wo_global_desc = OutGlobalDesc{};
constexpr index_t N = in_n_c_hi_wi_global_desc.GetLength(I0);
constexpr index_t C = in_n_c_hi_wi_global_desc.GetLength(I1);
constexpr index_t Hi = in_n_c_hi_wi_global_desc.GetLength(I2);
constexpr index_t Wi = in_n_c_hi_wi_global_desc.GetLength(I3);
constexpr index_t K = out_n_k_ho_wo_global_desc.GetLength(I1);
constexpr index_t Ho = out_n_k_ho_wo_global_desc.GetLength(I2);
constexpr index_t Wo = out_n_k_ho_wo_global_desc.GetLength(I3);
constexpr index_t Y = wei_k_c_y_x_global_desc.GetLength(I2);
constexpr index_t X = wei_k_c_y_x_global_desc.GetLength(I3);
constexpr index_t ConvStrideH = ConvStrides{}[0];
constexpr index_t ConvStrideW = ConvStrides{}[1];
constexpr index_t ConvDilationH = ConvDilations{}[0];
constexpr index_t ConvDilationW = ConvDilations{}[1];
static_assert(N % (N1 * N2) == 0, "wrong! cannot divice N evenly among thread");
constexpr index_t N0 = N / (N1 * N2);
constexpr index_t B = N0 * Ho * Wo;
constexpr index_t E = C * Y * X;
// sanity-check for vectorized memory load
static_assert((Wo == 1 || (ConvStrideW == 1 || InBlockCopySrcDataPerRead_B == 1)) &&
(X == 1 || ConvDilationW % InBlockCopySrcDataPerRead_B == 0),
"wrong! aligment requirement for vectorized global load of input tensor will "
"be violated");
// divide block work by [K, B]
static_assert(K % KPerBlock == 0 && B % BPerBlock == 0 && E % EPerBlock == 0,
"wrong! cannot divide work evenly among block");
constexpr index_t KBlockWork = K / KPerBlock;
constexpr index_t BBlockWork = B / BPerBlock;
constexpr auto block_work_desc =
make_cluster_descriptor(Sequence<KBlockWork, BBlockWork>{});
const auto block_work_id = block_work_desc.CalculateClusterIndex(get_block_1d_id());
const index_t k_block_data_on_global = block_work_id[0] * KPerBlock;
const index_t b_block_data_on_global = block_work_id[1] * BPerBlock;
// input tensor
// global tensor in global memory
constexpr auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor(
in_n_c_hi_wi_global_desc,
make_tuple(
PassThrough<N>{}, PassThrough<C>{}, Pad<Sequence<Hi, Wi>, LeftPads, RightPads>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}));
constexpr auto in_n0_n1_n2_c_y_ho_x_wo_global_desc = transform_tensor_descriptor(
in_n_c_hip_wip_global_desc,
make_tuple(UnMerge<Sequence<N0, N1, N2>>{},
PassThrough<C>{},
Embed<Sequence<Y, Ho>, Sequence<ConvDilationH, ConvStrideH, 0>>{},
Embed<Sequence<X, Wo>, Sequence<ConvDilationW, ConvStrideW, 0>>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}, Sequence<4, 5>{}, Sequence<6, 7>{}));
// global tensor in global memory, src of blockwise copy
constexpr auto in_e_n1_b_n2_global_desc = transform_tensor_descriptor(
in_n0_n1_n2_c_y_ho_x_wo_global_desc,
make_tuple(Merge<Sequence<C, Y, X>>{},
PassThrough<N1>{},
Merge<Sequence<N0, Ho, Wo>>{},
PassThrough<N2>{}),
make_tuple(Sequence<3, 4, 6>{}, Sequence<1>{}, Sequence<0, 5, 7>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
// block tensor in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto in_e_n1_b_n2_block_desc = make_native_tensor_descriptor_aligned(
Sequence<EPerBlock, N1, BPerBlock, N2>{}, Number<InBlockCopyDstDataPerWrite_N2>{});
// this check is ad-hoc
// TODO: need to properly implement tensor descriptor with multiple alignment
// requirements
static_assert(in_e_n1_b_n2_block_desc.GetStride(I1) % GemmDataPerReadB == 0,
"GemmDataPerReadB alignment requirement is not satisfied");
// input tensor blockwise copy
auto blockwise_in_copy =
BlockwiseGenericTensorSliceCopy_v4<BlockSize,
decltype(in_e_n1_b_n2_global_desc),
decltype(in_e_n1_b_n2_block_desc),
decltype(in_e_n1_b_n2_block_desc.GetLengths()),
InBlockCopySubLengths_E_N1_B_N2,
InBlockCopyClusterLengths_E_N1_B_N2,
InBlockCopyThreadClusterArrangeOrder,
InBlockCopySrcAccessOrder,
InBlockCopyDstAccessOrder,
2,
3,
InBlockCopySrcDataPerRead_B,
InBlockCopyDstDataPerWrite_N2>(
{0, 0, b_block_data_on_global, 0}, {0, 0, 0, 0});
// weight tensor
// global tensor in global memory, src of blockwise copy
// It is constructed differently, depending on whether forward or backward weight
// convolution
constexpr auto wei_e_k_global_desc =
make_wei_e_k_global_desc_v4r1<ConvDirection>{}(wei_k_c_y_x_global_desc);
// block tensor in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto wei_e_k_block_desc = make_native_tensor_descriptor_aligned(
Sequence<EPerBlock, KPerBlock>{},
Number<math::lcm(WeiBlockCopyDstDataPerWrite_K, GemmDataPerReadA)>{});
// this check is ad-hoc
// TODO: need to properly implement tensor descriptor with multiple alignment
// requirements
static_assert(wei_e_k_block_desc.GetStride(I0) % GemmDataPerReadA == 0,
"GemmDataPerReadA alignment requirement is not satisfied");
// weight tensor blockwise copy
auto blockwise_wei_copy =
BlockwiseGenericTensorSliceCopy_v4<BlockSize,
decltype(wei_e_k_global_desc),
decltype(wei_e_k_block_desc),
decltype(wei_e_k_block_desc.GetLengths()),
WeiBlockCopySubLengths_E_K,
WeiBlockCopyClusterLengths_E_K,
WeiBlockCopyThreadClusterArrangeOrder,
WeiBlockCopySrcAccessOrder,
WeiBlockCopyDstAccessOrder,
0,
1,
WeiBlockCopySrcDataPerRead_E,
WeiBlockCopyDstDataPerWrite_K>(
{0, k_block_data_on_global}, {0, 0});
// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
// a_mtx[EPerBlock, KPerBlock] is in LDS
// b_mtx[EPerBlocl, N1 * BPerBlock * N2] is in LDS
// c_mtx[KPerBlock, N1 * BPerBlock * N2] is distributed among threads, and saved in
// register
constexpr auto a_e_k_block_mtx_desc = make_ConstantMatrixDescriptor(wei_e_k_block_desc);
constexpr auto b_e_n1bn2_block_mtx_desc = make_ConstantMatrixDescriptor(
in_e_n1_b_n2_block_desc.GetLength(I0),
in_e_n1_b_n2_block_desc.GetLength(I1) * in_e_n1_b_n2_block_desc.GetLength(I2) *
in_e_n1_b_n2_block_desc.GetLength(I3),
in_e_n1_b_n2_block_desc.GetStride(I0));
// sanity check
static_assert(KPerBlock % (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster) ==
0,
"wrong!");
constexpr index_t GemmMRepeat =
KPerBlock / (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster);
// c_thread_mtx definition: this is a mess
// TODO:: more elegent way of defining c_thread_mtx
constexpr auto c_k0k1_n1n2_thread_mtx_desc = make_ConstantMatrixDescriptor_packed(
Number<GemmMRepeat * GemmMPerThreadSubC>{}, Number<GemmNRepeat * GemmNPerThreadSubC>{});
const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2<
BlockSize,
decltype(a_e_k_block_mtx_desc),
decltype(b_e_n1bn2_block_mtx_desc),
decltype(c_k0k1_n1n2_thread_mtx_desc),
GemmMPerThreadSubC,
GemmNPerThreadSubC,
GemmMLevel0Cluster,
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmKPerThreadLoop,
GemmDataPerReadA,
GemmDataPerReadB>{};
// LDS allocation for input and weight: be careful of alignment
constexpr index_t max_align = math::lcm(InBlockCopyDstDataPerWrite_N2,
WeiBlockCopyDstDataPerWrite_K,
GemmDataPerReadA,
GemmDataPerReadB);
constexpr index_t in_block_space =
math::integer_least_multiple(in_e_n1_b_n2_block_desc.GetElementSpace(), max_align);
constexpr index_t wei_block_space =
math::integer_least_multiple(wei_e_k_block_desc.GetElementSpace(), max_align);
__shared__ Float p_in_block_double[2 * in_block_space];
__shared__ Float p_wei_block_double[2 * wei_block_space];
// register allocation for output
AccDataType p_out_thread[c_k0k1_n1n2_thread_mtx_desc.GetElementSpace()];
// zero out threadwise output
threadwise_matrix_set_zero(c_k0k1_n1n2_thread_mtx_desc, p_out_thread);
// LDS double buffer: preload data into LDS
{
blockwise_in_copy.Run(
p_in_global, p_in_block_double, global_address_space, generic_address_space);
blockwise_wei_copy.Run(
p_wei_global, p_wei_block_double, global_address_space, generic_address_space);
}
// LDS double buffer: main body
for(index_t e_block_data_begin = 0; e_block_data_begin + 2 * EPerBlock < E;
e_block_data_begin += 2 * EPerBlock)
{
#pragma unroll
for(index_t iloop = 0; iloop < 2; ++iloop)
{
const bool even_loop = (iloop % 2 == 0);
Float* p_in_block_now =
even_loop ? p_in_block_double : p_in_block_double + in_block_space;
Float* p_wei_block_now =
even_loop ? p_wei_block_double : p_wei_block_double + wei_block_space;
Float* p_in_block_next =
even_loop ? p_in_block_double + in_block_space : p_in_block_double;
Float* p_wei_block_next =
even_loop ? p_wei_block_double + wei_block_space : p_wei_block_double;
Float p_in_thread_buffer[blockwise_in_copy.GetThreadBufferSize()];
Float p_wei_thread_buffer[blockwise_wei_copy.GetThreadBufferSize()];
blockwise_in_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0, 0, 0>{}, True);
blockwise_wei_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0>{}, True);
__syncthreads();
// LDS doubel buffer: load next data from device mem
blockwise_in_copy.RunLoadThreadBuffer(
p_in_global, p_in_thread_buffer, global_address_space, generic_address_space);
blockwise_wei_copy.RunLoadThreadBuffer(
p_wei_global, p_wei_thread_buffer, global_address_space, generic_address_space);
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(p_wei_block_now, p_in_block_now, p_out_thread);
// LDS double buffer: store next data to LDS
blockwise_in_copy.RunStoreThreadBuffer(p_in_thread_buffer, p_in_block_next);
blockwise_wei_copy.RunStoreThreadBuffer(p_wei_thread_buffer, p_wei_block_next);
}
}
// LDS double buffer: tail
{
constexpr bool has_two_iteration_left = (E % (2 * EPerBlock) == 0);
if(has_two_iteration_left) // if has 2 iteration left
{
Float p_in_thread_buffer[blockwise_in_copy.GetThreadBufferSize()];
Float p_wei_thread_buffer[blockwise_wei_copy.GetThreadBufferSize()];
blockwise_in_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0, 0, 0>{}, True);
blockwise_wei_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0>{}, True);
__syncthreads();
// LDS double buffer: load last data from device mem
blockwise_in_copy.RunLoadThreadBuffer(
p_in_global, p_in_thread_buffer, global_address_space, generic_address_space);
blockwise_wei_copy.RunLoadThreadBuffer(
p_wei_global, p_wei_thread_buffer, global_address_space, generic_address_space);
// LDS double buffer: GEMM on 2nd-last data
blockwise_gemm.Run(p_wei_block_double, p_in_block_double, p_out_thread);
// LDS double buffer: store last data to LDS
blockwise_in_copy.RunStoreThreadBuffer(p_in_thread_buffer,
p_in_block_double + in_block_space);
blockwise_wei_copy.RunStoreThreadBuffer(p_wei_thread_buffer,
p_wei_block_double + wei_block_space);
__syncthreads();
// LDS double buffer: GEMM on last data
blockwise_gemm.Run(p_wei_block_double + wei_block_space,
p_in_block_double + in_block_space,
p_out_thread);
}
else // if has 1 iteration left
{
__syncthreads();
// LDS double buffer: GEMM on last data
blockwise_gemm.Run(p_wei_block_double, p_in_block_double, p_out_thread);
}
}
// copy output: register to global memory
{
constexpr index_t K1 = GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster;
constexpr index_t K0 = K / K1;
// define output tensor descriptor for threadwise copy
// thread output tensor, src of threadwise copy
constexpr auto out_k0_k1_n1_b_n2_thread_desc = make_native_tensor_descriptor_packed(
Sequence<GemmMRepeat, GemmMPerThreadSubC, N1, 1, N2>{});
// global output tensor
constexpr auto out_n0_n1_n2_k0_k1_ho_wo_global_desc = transform_tensor_descriptor(
out_n_k_ho_wo_global_desc,
make_tuple(UnMerge<Sequence<N0, N1, N2>>{},
UnMerge<Sequence<K0, K1>>{},
PassThrough<Ho>{},
PassThrough<Wo>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}, Sequence<6>{}));
// global output tensor, dst of threadwise copy
constexpr auto out_k0_k1_n1_b_n2_global_desc = transform_tensor_descriptor(
out_n0_n1_n2_k0_k1_ho_wo_global_desc,
make_tuple(PassThrough<K0>{},
PassThrough<K1>{},
PassThrough<N1>{},
Merge<Sequence<N0, Ho, Wo>>{},
PassThrough<N2>{}),
make_tuple(Sequence<3>{},
Sequence<4>{},
Sequence<1>{},
Sequence<0, 5, 6>{},
Sequence<2>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
// calculate origin of thread output tensor on global memory
// blockwise GEMM c matrix starting index
const auto c_thread_mtx_on_block =
blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
const index_t k_thread_data_on_global =
k_block_data_on_global + c_thread_mtx_on_block.row;
const index_t b_thread_data_on_global =
b_block_data_on_global + c_thread_mtx_on_block.col / N2;
ThreadwiseGenericTensorSliceCopy_v4r2<decltype(out_k0_k1_n1_b_n2_thread_desc),
decltype(out_k0_k1_n1_b_n2_global_desc),
decltype(
out_k0_k1_n1_b_n2_thread_desc.GetLengths()),
arithmetic_sequence_gen<0, 5, 1>::type,
3,
1,
1>({0, 0, 0, 0, 0},
{k_thread_data_on_global / K1,
k_thread_data_on_global % K1,
0,
b_thread_data_on_global,
0})
.Run(p_out_thread, p_out_global, generic_address_space, global_address_space);
}
}
};
} // namespace ck
#endif
#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R1_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER_DEPRECATED_HPP
#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R1_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER_DEPRECATED_HPP
#include "common_header.hpp"
#include "ConstantTensorDescriptor_deprecated.hpp"
#include "ConstantMergedTensorDescriptor_deprecated.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "blockwise_generic_tensor_slice_copy_deprecated.hpp"
#include "blockwise_gemm.hpp"
#include "threadwise_generic_tensor_slice_copy_deprecated.hpp"
#include "convolution_common.hpp"
namespace ck {
template <ConvolutionDirection>
struct make_wei_e_k_global_desc_v4r1_deprecated;
template <>
struct make_wei_e_k_global_desc_v4r1_deprecated<ConvolutionDirection::Forward>
{
template <typename WeiDesc>
__device__ constexpr auto operator()(WeiDesc) const
{
constexpr auto I1 = Number<1>{};
constexpr auto I3 = Number<3>{};
return WeiDesc::Unfold(I1, I3).ReorderGivenNew2Old(Sequence<1, 0>{});
}
};
template <>
struct make_wei_e_k_global_desc_v4r1_deprecated<ConvolutionDirection::BackwardWeight>
{
template <typename WeiDesc>
__device__ constexpr auto operator()(WeiDesc) const
{
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
return make_ConstantMergedTensorDescriptor(
WeiDesc::Unfold(I2, I3), Sequence<1, 2>{}, Sequence<0>{});
}
};
// define B = merge(N0, Ho, Wo)
template <index_t GridSize,
index_t BlockSize,
class Float,
class AccDataType,
class InGlobalDesc,
class WeiGlobalDesc,
class OutGlobalDesc,
class ConvStrides,
class ConvDilations,
ConvolutionDirection ConvDirection,
index_t BPerBlock,
index_t KPerBlock,
index_t EPerBlock,
index_t GemmNRepeat,
index_t GemmMPerThreadSubC,
index_t GemmNPerThreadSubC,
index_t GemmMLevel0Cluster,
index_t GemmNLevel0Cluster,
index_t GemmMLevel1Cluster,
index_t GemmNLevel1Cluster,
index_t GemmKPerThreadLoop,
index_t GemmDataPerReadA,
index_t GemmDataPerReadB,
class InBlockCopySubLengths_E_N1_B_N2,
class InBlockCopyClusterLengths_E_N1_B_N2,
class InBlockCopyThreadClusterArrangeOrder,
class InBlockCopySrcAccessOrder,
class InBlockCopyDstAccessOrder,
index_t InBlockCopySrcDataPerRead_B,
index_t InBlockCopyDstDataPerWrite_N2,
class WeiBlockCopySubLengths_E_K,
class WeiBlockCopyClusterLengths_E_K,
class WeiBlockCopyThreadClusterArrangeOrder,
class WeiBlockCopySrcAccessOrder,
class WeiBlockCopyDstAccessOrder,
index_t WeiBlockCopySrcDataPerRead_E,
index_t WeiBlockCopyDstDataPerWrite_K>
struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer_deprecated
{
__device__ void Run(const Float* const __restrict__ p_in_global,
const Float* const __restrict__ p_wei_global,
Float* const __restrict__ p_out_global) const
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto True = integral_constant<bool, true>{};
constexpr auto generic_address_space =
integral_constant<AddressSpace, AddressSpace::generic>{};
constexpr auto global_address_space =
integral_constant<AddressSpace, AddressSpace::global>{};
static_assert(ConvDirection == ConvolutionDirection::Forward ||
ConvDirection == ConvolutionDirection::BackwardWeight,
"wrong! this kernel only support convolution forward and backward-weight");
// this is a mess
// TODO: find more elegent way of specifying (or calculating) performance parameters
constexpr index_t N1 = GemmNRepeat;
constexpr index_t N2 = GemmNPerThreadSubC;
static_assert((N1 * N2 * BPerBlock) %
(GemmNPerThreadSubC * GemmNLevel0Cluster * GemmNLevel1Cluster) ==
0,
"wrong!");
constexpr auto in_n_c_h_w_global_desc = InGlobalDesc{};
constexpr auto wei_k_c_y_x_global_desc = WeiGlobalDesc{};
constexpr auto out_n_k_h_w_global_desc = OutGlobalDesc{};
constexpr index_t N = in_n_c_h_w_global_desc.GetLength(I0);
constexpr index_t C = in_n_c_h_w_global_desc.GetLength(I1);
constexpr index_t K = out_n_k_h_w_global_desc.GetLength(I1);
constexpr index_t Ho = out_n_k_h_w_global_desc.GetLength(I2);
constexpr index_t Wo = out_n_k_h_w_global_desc.GetLength(I3);
constexpr index_t Y = wei_k_c_y_x_global_desc.GetLength(I2);
constexpr index_t X = wei_k_c_y_x_global_desc.GetLength(I3);
constexpr index_t ConvStrideH = ConvStrides{}[0];
constexpr index_t ConvStrideW = ConvStrides{}[1];
constexpr index_t ConvDilationH = ConvDilations{}[0];
constexpr index_t ConvDilationW = ConvDilations{}[1];
static_assert(N % (N1 * N2) == 0, "wrong! cannot divice N evenly among thread");
constexpr index_t N0 = N / (N1 * N2);
constexpr index_t B = N0 * Ho * Wo;
constexpr index_t E = C * Y * X;
// sanity-check for vectorized memory load
static_assert((Wo == 1 || (ConvStrideW == 1 || InBlockCopySrcDataPerRead_B == 1)) &&
(X == 1 || ConvDilationW % InBlockCopySrcDataPerRead_B == 0),
"wrong! aligment requirement for vectorized global load of input tensor will "
"be violated");
// divide block work by [K, B]
static_assert(K % KPerBlock == 0 && B % BPerBlock == 0 && E % (2 * EPerBlock) == 0,
"wrong! cannot divide work evenly among block");
constexpr index_t KBlockWork = K / KPerBlock;
constexpr index_t BBlockWork = B / BPerBlock;
constexpr auto block_work_desc =
make_ConstantTensorDescriptor_packed(Sequence<KBlockWork, BBlockWork>{});
const auto block_work_multi_id =
block_work_desc.GetMultiIndexFrom1dIndex(get_block_1d_id());
const index_t k_block_data_on_global = block_work_multi_id[0] * KPerBlock;
const index_t b_block_data_on_global = block_work_multi_id[1] * BPerBlock;
// input tensor
// tensor descriptor in device memory [N0, N1, N2, Ho, Wo]
constexpr auto in_n0_n1_n2_h_w_global_desc =
in_n_c_h_w_global_desc.StridedSlice(I2, Number<Ho>{}, Number<ConvStrideH>{})
.StridedSlice(I3, Number<Wo>{}, Number<ConvStrideW>{})
.Fold(I0, Number<N1>{}, Number<N2>{})
.Extract(Sequence<0, 1, 2, 4, 5>{});
// batch descritpor for device memory
constexpr auto in_c_y_x_global_desc =
in_n_c_h_w_global_desc.StridedSlice(I2, Number<Y>{}, Number<ConvDilationH>{})
.StridedSlice(I3, Number<X>{}, Number<ConvDilationW>{})
.Extract(Sequence<1, 2, 3>{});
// merged tensor descriptor in device memory [E, N1, B, N2], src of blockwise copy
constexpr auto in_e_n1_b_n2_global_merged_desc = make_ConstantMergedTensorDescriptor(
in_c_y_x_global_desc.Embed(in_n0_n1_n2_h_w_global_desc),
Sequence<0, 1, 2>{},
Sequence<4>{},
Sequence<3, 6, 7>{},
Sequence<5>{});
// memory layout descriptor in LDS [E, N1, B, N2], dst of blockwise copy
// be careful of LDS alignment
constexpr auto in_e_n1_b_n2_block_desc = make_ConstantTensorDescriptor_aligned(
Sequence<EPerBlock, N1, BPerBlock, N2>{}, Number<InBlockCopyDstDataPerWrite_N2>{});
// this check is ad-hoc
// TODO: need to properly implement tensor descriptor with multiple alignment
// requirements
static_assert(in_e_n1_b_n2_block_desc.GetStride(I1) % GemmDataPerReadB == 0,
"GemmDataPerReadB alignment requirement is not satisfied");
// input blockwise copy
// slice a merged tensor, reorder and copy to a normal tensor
// this copy operator already has blockwise offset built-in
auto blockwise_in_copy = BlockwiseGenericTensorSliceCopy_v2_deprecated<
BlockSize,
decltype(in_e_n1_b_n2_global_merged_desc),
decltype(in_e_n1_b_n2_block_desc),
decltype(in_e_n1_b_n2_block_desc.GetLengths()),
InBlockCopySubLengths_E_N1_B_N2,
InBlockCopyClusterLengths_E_N1_B_N2,
InBlockCopyThreadClusterArrangeOrder,
InBlockCopySrcAccessOrder,
InBlockCopyDstAccessOrder,
2,
3,
InBlockCopySrcDataPerRead_B,
InBlockCopyDstDataPerWrite_N2>({0, 0, b_block_data_on_global, 0}, {0, 0, 0, 0});
// weight tensor
// Iensor descriptor in device memory, src of blockwise copy
// It is constructed differently, depending on whether forward or backward weight
// convolution
constexpr auto wei_e_k_global_desc =
make_wei_e_k_global_desc_v4r1_deprecated<ConvDirection>{}(wei_k_c_y_x_global_desc);
// tensor descriptor in LDS, dst of blockwise copy
// be careful of LDS alignment
constexpr auto wei_e_k_block_desc = make_ConstantTensorDescriptor_aligned(
Sequence<EPerBlock, KPerBlock>{},
Number<math::lcm(WeiBlockCopyDstDataPerWrite_K, GemmDataPerReadA)>{});
// operator for blockwise copy of weight into LDS
// slice a tensor, and copy it into another tensor
// this copy operator already have blockwise offset built-in
auto blockwise_wei_copy =
BlockwiseGenericTensorSliceCopy_v2_deprecated<BlockSize,
decltype(wei_e_k_global_desc),
decltype(wei_e_k_block_desc),
decltype(wei_e_k_block_desc.GetLengths()),
WeiBlockCopySubLengths_E_K,
WeiBlockCopyClusterLengths_E_K,
WeiBlockCopyThreadClusterArrangeOrder,
WeiBlockCopySrcAccessOrder,
WeiBlockCopyDstAccessOrder,
0,
1,
WeiBlockCopySrcDataPerRead_E,
WeiBlockCopyDstDataPerWrite_K>(
{0, k_block_data_on_global}, {0, 0});
// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
// a_mtx[EPerBlock, KPerBlock] is in LDS
// b_mtx[EPerBlocl, N1 * BPerBlock * N2] is in LDS
// c_mtx[KPerBlock, N1 * BPerBlock * N2] is distributed among threads, and saved in
// register
constexpr auto a_e_k_block_mtx_desc = make_ConstantMatrixDescriptor(wei_e_k_block_desc);
constexpr auto b_e_n1bn2_block_mtx_desc =
make_ConstantMatrixDescriptor(in_e_n1_b_n2_block_desc.Unfold(I1, I3));
// sanity check
static_assert(KPerBlock % (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster) ==
0,
"wrong!");
constexpr index_t GemmMRepeat =
KPerBlock / (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster);
// c_thread_mtx definition: this is a mess
// TODO:: more elegent way of defining c_thread_mtx
constexpr auto c_k0k1_n1n2_thread_mtx_desc = make_ConstantMatrixDescriptor_packed(
Number<GemmMRepeat * GemmMPerThreadSubC>{}, Number<GemmNRepeat * GemmNPerThreadSubC>{});
const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2<
BlockSize,
decltype(a_e_k_block_mtx_desc),
decltype(b_e_n1bn2_block_mtx_desc),
decltype(c_k0k1_n1n2_thread_mtx_desc),
GemmMPerThreadSubC,
GemmNPerThreadSubC,
GemmMLevel0Cluster,
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmKPerThreadLoop,
GemmDataPerReadA,
GemmDataPerReadB>{};
// LDS allocation for input and weight: be careful of alignment
constexpr index_t max_align = math::lcm(InBlockCopyDstDataPerWrite_N2,
WeiBlockCopyDstDataPerWrite_K,
GemmDataPerReadA,
GemmDataPerReadB);
constexpr index_t in_block_space =
math::integer_least_multiple(in_e_n1_b_n2_block_desc.GetElementSpace(), max_align);
constexpr index_t wei_block_space =
math::integer_least_multiple(wei_e_k_block_desc.GetElementSpace(), max_align);
__shared__ Float p_in_block_double[2 * in_block_space];
__shared__ Float p_wei_block_double[2 * wei_block_space];
// register allocation for output
AccDataType p_out_thread[c_k0k1_n1n2_thread_mtx_desc.GetElementSpace()];
// zero out threadwise output
threadwise_matrix_set_zero(c_k0k1_n1n2_thread_mtx_desc, p_out_thread);
// LDS double buffer: preload data into LDS
{
blockwise_in_copy.Run(
p_in_global, p_in_block_double, global_address_space, generic_address_space);
blockwise_wei_copy.Run(
p_wei_global, p_wei_block_double, global_address_space, generic_address_space);
}
// LDS double buffer: main body
for(index_t e_block_data_begin = 0; e_block_data_begin + 2 * EPerBlock < E;
e_block_data_begin += 2 * EPerBlock)
{
#pragma unroll
for(index_t iloop = 0; iloop < 2; ++iloop)
{
const bool even_loop = (iloop % 2 == 0);
Float* p_in_block_now =
even_loop ? p_in_block_double : p_in_block_double + in_block_space;
Float* p_wei_block_now =
even_loop ? p_wei_block_double : p_wei_block_double + wei_block_space;
Float* p_in_block_next =
even_loop ? p_in_block_double + in_block_space : p_in_block_double;
Float* p_wei_block_next =
even_loop ? p_wei_block_double + wei_block_space : p_wei_block_double;
Float p_in_thread_buffer[blockwise_in_copy.GetThreadBufferSize()];
Float p_wei_thread_buffer[blockwise_wei_copy.GetThreadBufferSize()];
blockwise_in_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0, 0, 0>{}, True);
blockwise_wei_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0>{}, True);
__syncthreads();
// LDS doubel buffer: load next data from device mem
blockwise_in_copy.RunLoadThreadBuffer(
p_in_global, p_in_thread_buffer, global_address_space, generic_address_space);
blockwise_wei_copy.RunLoadThreadBuffer(
p_wei_global, p_wei_thread_buffer, global_address_space, generic_address_space);
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(p_wei_block_now, p_in_block_now, p_out_thread);
// LDS double buffer: store next data to LDS
blockwise_in_copy.RunStoreThreadBuffer(p_in_thread_buffer, p_in_block_next);
blockwise_wei_copy.RunStoreThreadBuffer(p_wei_thread_buffer, p_wei_block_next);
}
}
// LDS double buffer: tail
{
// even iteration
Float p_in_thread_buffer[blockwise_in_copy.GetThreadBufferSize()];
Float p_wei_thread_buffer[blockwise_wei_copy.GetThreadBufferSize()];
blockwise_in_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0, 0, 0>{}, True);
blockwise_wei_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0>{}, True);
__syncthreads();
// LDS doubel buffer: load next data from device mem
blockwise_in_copy.RunLoadThreadBuffer(
p_in_global, p_in_thread_buffer, global_address_space, generic_address_space);
blockwise_wei_copy.RunLoadThreadBuffer(
p_wei_global, p_wei_thread_buffer, global_address_space, generic_address_space);
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(p_wei_block_double, p_in_block_double, p_out_thread);
// LDS double buffer: store next data to LDS
blockwise_in_copy.RunStoreThreadBuffer(p_in_thread_buffer,
p_in_block_double + in_block_space);
blockwise_wei_copy.RunStoreThreadBuffer(p_wei_thread_buffer,
p_wei_block_double + wei_block_space);
// odd iteration
__syncthreads();
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(p_wei_block_double + wei_block_space,
p_in_block_double + in_block_space,
p_out_thread);
}
// copy output: register to global memory
{
constexpr index_t K1 = GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster;
// define tensor descriptor for threadwise copy
// output memory layout descriptor in register, src of threadwise copy
constexpr auto out_k0_k1_n1_b_n2_thread_mem_desc = make_ConstantTensorDescriptor_packed(
Sequence<GemmMRepeat, GemmMPerThreadSubC, N1, 1, N2>{});
// output memory layout descriptor in device memory
constexpr auto out_n0_n1_n2_k0_k1_h_w_global_mem_desc =
out_n_k_h_w_global_desc.Fold(I1, Number<K1>{}).Fold(I0, Number<N1>{}, Number<N2>{});
// output merged global tensor descriptor, dst of threadwise copy
constexpr auto out_k0_k1_n1_b_n2_global_merged_desc =
make_ConstantMergedTensorDescriptor(out_n0_n1_n2_k0_k1_h_w_global_mem_desc,
Sequence<3>{},
Sequence<4>{},
Sequence<1>{},
Sequence<0, 5, 6>{},
Sequence<2>{});
// calculate origin of thread output tensor on global memory
// blockwise GEMM c matrix starting index
const auto c_thread_mtx_on_block =
blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
const index_t k_thread_data_on_global =
k_block_data_on_global + c_thread_mtx_on_block.row;
const index_t b_thread_data_on_global =
b_block_data_on_global + c_thread_mtx_on_block.col / N2;
ThreadwiseGenericTensorSliceCopy_v2r1_deprecated<
decltype(out_k0_k1_n1_b_n2_thread_mem_desc),
decltype(out_k0_k1_n1_b_n2_global_merged_desc),
decltype(out_k0_k1_n1_b_n2_thread_mem_desc.GetLengths()),
arithmetic_sequence_gen<0, 5, 1>::type,
arithmetic_sequence_gen<0, 5, 1>::type,
3,
3,
1,
1>({0, 0, 0, 0, 0},
{k_thread_data_on_global / K1,
k_thread_data_on_global % K1,
0,
b_thread_data_on_global,
0})
.Run(p_out_thread, p_out_global, generic_address_space, global_address_space);
}
}
};
} // namespace ck
#endif // CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER_DEPRECATED_HPP
...@@ -2,12 +2,12 @@ ...@@ -2,12 +2,12 @@
#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R4_FP16_BFP16_NCHW_KCYX_NKHW_HPP_LDS_DOUBLE_BUFFER_HPP #define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R4_FP16_BFP16_NCHW_KCYX_NKHW_HPP_LDS_DOUBLE_BUFFER_HPP
#include "common_header.hpp" #include "common_header.hpp"
#include "ConstantTensorDescriptor.hpp" #include "ConstantTensorDescriptor_deprecated.hpp"
#include "ConstantMergedTensorDescriptor.hpp" #include "ConstantMergedTensorDescriptor_deprecated.hpp"
#include "ConstantMatrixDescriptor.hpp" #include "ConstantMatrixDescriptor.hpp"
#include "blockwise_generic_tensor_slice_copy.hpp" #include "blockwise_generic_tensor_slice_copy_deprecated.hpp"
#include "blockwise_gemm_xdlops.hpp" #include "blockwise_gemm_xdlops.hpp"
#include "threadwise_generic_tensor_slice_copy.hpp" #include "threadwise_generic_tensor_slice_copy_deprecated.hpp"
#include "implicitgemm_params.hpp" #include "implicitgemm_params.hpp"
namespace ck { namespace ck {
...@@ -173,12 +173,10 @@ struct GridwiseConvolutionImplicitGemm_v4r4_xdlops_fp16_bfp16_nchw_kcyx_nkhw_lds ...@@ -173,12 +173,10 @@ struct GridwiseConvolutionImplicitGemm_v4r4_xdlops_fp16_bfp16_nchw_kcyx_nkhw_lds
// input blockwise copy // input blockwise copy
// slice a merged tensor, reorder and copy to a normal tensor // slice a merged tensor, reorder and copy to a normal tensor
// this copy operator already has blockwise offset built-in // this copy operator already has blockwise offset built-in
auto blockwise_in_copy = BlockwiseGenericTensorSliceCopy_v2< auto blockwise_in_copy = BlockwiseGenericTensorSliceCopy_v2_deprecated<
BlockSize, BlockSize,
decltype(in_e_b_global_desc), decltype(in_e_b_global_desc),
decltype(in_e_b_block_desc), decltype(in_e_b_block_desc),
MergedTensorCoordinate<decltype(in_e_b_global_desc)>,
NormalTensorCoordinate<decltype(in_e_b_block_desc)>,
decltype(in_e_b_block_desc.GetLengths()), decltype(in_e_b_block_desc.GetLengths()),
InBlockCopySubLengths_E_B, InBlockCopySubLengths_E_B,
InBlockCopyClusterLengths_E_B, InBlockCopyClusterLengths_E_B,
...@@ -209,12 +207,10 @@ struct GridwiseConvolutionImplicitGemm_v4r4_xdlops_fp16_bfp16_nchw_kcyx_nkhw_lds ...@@ -209,12 +207,10 @@ struct GridwiseConvolutionImplicitGemm_v4r4_xdlops_fp16_bfp16_nchw_kcyx_nkhw_lds
// operator for blockwise copy of weight into LDS // operator for blockwise copy of weight into LDS
// slice a tensor, and copy it into another tensor // slice a tensor, and copy it into another tensor
// this copy operator already have blockwise offset built-in // this copy operator already have blockwise offset built-in
auto blockwise_wei_copy = BlockwiseGenericTensorSliceCopy_v2< auto blockwise_wei_copy = BlockwiseGenericTensorSliceCopy_v2_deprecated<
BlockSize, BlockSize,
decltype(wei_e_k_global_desc), decltype(wei_e_k_global_desc),
decltype(wei_e_k_block_desc), decltype(wei_e_k_block_desc),
NormalTensorCoordinate<decltype(wei_e_k_global_desc)>,
NormalTensorCoordinate<decltype(wei_e_k_block_desc)>,
decltype(wei_e_k_block_desc.GetLengths()), decltype(wei_e_k_block_desc.GetLengths()),
WeiBlockCopySubLengths_E_K, WeiBlockCopySubLengths_E_K,
WeiBlockCopyClusterLengths_E_K, WeiBlockCopyClusterLengths_E_K,
...@@ -300,22 +296,21 @@ struct GridwiseConvolutionImplicitGemm_v4r4_xdlops_fp16_bfp16_nchw_kcyx_nkhw_lds ...@@ -300,22 +296,21 @@ struct GridwiseConvolutionImplicitGemm_v4r4_xdlops_fp16_bfp16_nchw_kcyx_nkhw_lds
Float* p_wei_block_next = Float* p_wei_block_next =
even_loop ? p_wei_block_double + wei_block_space : p_wei_block_double; even_loop ? p_wei_block_double + wei_block_space : p_wei_block_double;
Float p_in_register_buffer[blockwise_in_copy.GetRegisterBufferSize()]; Float p_in_thread_buffer[blockwise_in_copy.GetThreadBufferSize()];
Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()]; Float p_wei_thread_buffer[blockwise_wei_copy.GetThreadBufferSize()];
blockwise_in_copy.MoveSrcSlicingWindow(Sequence<EPerBlock, 0, 0>{}, True); blockwise_in_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0, 0>{}, True);
static_if<conv_dir == ImplicitGemmDirection::BackwardWeight>{}([&](auto ) { static_if<conv_dir == ImplicitGemmDirection::BackwardWeight>{}([&](auto ) {
blockwise_wei_copy.MoveSrcSlicingWindow(Sequence<EPerBlock, 0, 0>{}, True); blockwise_wei_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0, 0>{}, True);
}).Else([&](auto fwd) { }).Else([&](auto ) {
p_wei_block_on_global += EPerBlock * fwd(wei_e_k_global_desc).GetStride(I0); p_wei_block_on_global += EPerBlock * wei_e_k_global_desc.GetStride(I0);
}); });
__syncthreads(); __syncthreads();
// LDS doubel buffer: load next data from device mem // LDS doubel buffer: load next data from device mem
blockwise_in_copy.RunLoadRegisterBuffer(p_in_global, p_in_register_buffer); blockwise_in_copy.RunLoadThreadBuffer(p_in_global, p_in_thread_buffer);
blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_block_on_global, blockwise_wei_copy.RunLoadThreadBuffer(p_wei_block_on_global, p_wei_thread_buffer);
p_wei_register_buffer);
// LDS double buffer: GEMM on current data // LDS double buffer: GEMM on current data
const typename vector_type<Float, EPack>::MemoryType* p_a_block_vec = const typename vector_type<Float, EPack>::MemoryType* p_a_block_vec =
...@@ -327,29 +322,29 @@ struct GridwiseConvolutionImplicitGemm_v4r4_xdlops_fp16_bfp16_nchw_kcyx_nkhw_lds ...@@ -327,29 +322,29 @@ struct GridwiseConvolutionImplicitGemm_v4r4_xdlops_fp16_bfp16_nchw_kcyx_nkhw_lds
blockwise_gemm.Run(p_a_block_vec, p_b_block_vec, p_out_thread); blockwise_gemm.Run(p_a_block_vec, p_b_block_vec, p_out_thread);
// LDS double buffer: store next data to LDS // LDS double buffer: store next data to LDS
blockwise_in_copy.RunStoreRegisterBuffer(p_in_register_buffer, p_in_block_next); blockwise_in_copy.RunStoreThreadBuffer(p_in_thread_buffer, p_in_block_next);
blockwise_wei_copy.RunStoreRegisterBuffer(p_wei_register_buffer, p_wei_block_next); blockwise_wei_copy.RunStoreThreadBuffer(p_wei_thread_buffer, p_wei_block_next);
} }
} }
// LDS double buffer: tail // LDS double buffer: tail
{ {
Float p_in_register_buffer[blockwise_in_copy.GetRegisterBufferSize()]; Float p_in_thread_buffer[blockwise_in_copy.GetThreadBufferSize()];
Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()]; Float p_wei_thread_buffer[blockwise_wei_copy.GetThreadBufferSize()];
// even iteration // even iteration
blockwise_in_copy.MoveSrcSlicingWindow(Sequence<EPerBlock, 0, 0>{}, True); blockwise_in_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0, 0>{}, True);
static_if<conv_dir == ImplicitGemmDirection::BackwardWeight>{}([&](auto ) { static_if<conv_dir == ImplicitGemmDirection::BackwardWeight>{}([&](auto ) {
blockwise_wei_copy.MoveSrcSlicingWindow(Sequence<EPerBlock, 0, 0>{}, True); blockwise_wei_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0, 0>{}, True);
}).Else([&](auto fwd) { }).Else([&](auto ) {
p_wei_block_on_global += EPerBlock * fwd(wei_e_k_global_desc).GetStride(I0); p_wei_block_on_global += EPerBlock * wei_e_k_global_desc.GetStride(I0);
}); });
__syncthreads(); __syncthreads();
// LDS doubel buffer: load next data from device mem // LDS doubel buffer: load next data from device mem
blockwise_in_copy.RunLoadRegisterBuffer(p_in_global, p_in_register_buffer); blockwise_in_copy.RunLoadThreadBuffer(p_in_global, p_in_thread_buffer);
blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_block_on_global, p_wei_register_buffer); blockwise_wei_copy.RunLoadThreadBuffer(p_wei_block_on_global, p_wei_thread_buffer);
// LDS double buffer: GEMM on current data // LDS double buffer: GEMM on current data
// Vectorize the pointer to match with how half/bfloat16 datatypes are // Vectorize the pointer to match with how half/bfloat16 datatypes are
...@@ -368,10 +363,10 @@ struct GridwiseConvolutionImplicitGemm_v4r4_xdlops_fp16_bfp16_nchw_kcyx_nkhw_lds ...@@ -368,10 +363,10 @@ struct GridwiseConvolutionImplicitGemm_v4r4_xdlops_fp16_bfp16_nchw_kcyx_nkhw_lds
blockwise_gemm.Run(p_a_block_vec, p_b_block_vec, p_out_thread); blockwise_gemm.Run(p_a_block_vec, p_b_block_vec, p_out_thread);
// LDS double buffer: store next data to LDS // LDS double buffer: store next data to LDS
blockwise_in_copy.RunStoreRegisterBuffer(p_in_register_buffer, blockwise_in_copy.RunStoreThreadBuffer(p_in_thread_buffer,
p_in_block_double + in_block_space); p_in_block_double + in_block_space);
blockwise_wei_copy.RunStoreRegisterBuffer(p_wei_register_buffer, blockwise_wei_copy.RunStoreThreadBuffer(p_wei_thread_buffer,
p_wei_block_double + wei_block_space); p_wei_block_double + wei_block_space);
// odd iteration // odd iteration
__syncthreads(); __syncthreads();
...@@ -426,11 +421,9 @@ struct GridwiseConvolutionImplicitGemm_v4r4_xdlops_fp16_bfp16_nchw_kcyx_nkhw_lds ...@@ -426,11 +421,9 @@ struct GridwiseConvolutionImplicitGemm_v4r4_xdlops_fp16_bfp16_nchw_kcyx_nkhw_lds
const index_t b_thread_data_on_global = const index_t b_thread_data_on_global =
b_block_data_on_global + c_thread_mtx_on_block.col; b_block_data_on_global + c_thread_mtx_on_block.col;
auto threadwise_out_copy = ThreadwiseGenericTensorSliceCopy_v2r1< auto threadwise_out_copy = ThreadwiseGenericTensorSliceCopy_v2r1_deprecated<
decltype(out_k0_k1_k2_b_thread_desc), decltype(out_k0_k1_k2_b_thread_desc),
decltype(out_k0_k1_k2_b_global_desc), decltype(out_k0_k1_k2_b_global_desc),
NormalTensorCoordinate<decltype(out_k0_k1_k2_b_thread_desc)>,
MergedTensorCoordinate<decltype(out_k0_k1_k2_b_global_desc)>,
OutThreadCopySliceLengths, OutThreadCopySliceLengths,
arithmetic_sequence_gen<0, 4, 1>::type, arithmetic_sequence_gen<0, 4, 1>::type,
arithmetic_sequence_gen<0, 4, 1>::type, arithmetic_sequence_gen<0, 4, 1>::type,
......
...@@ -2,12 +2,12 @@ ...@@ -2,12 +2,12 @@
#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R4_NCHW_KC1X1_NKHW_HPP_LDS_DOUBLE_BUFFER_HPP #define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R4_NCHW_KC1X1_NKHW_HPP_LDS_DOUBLE_BUFFER_HPP
#include "common_header.hpp" #include "common_header.hpp"
#include "ConstantTensorDescriptor.hpp" #include "ConstantTensorDescriptor_deprecated.hpp"
#include "ConstantMergedTensorDescriptor.hpp" #include "ConstantMergedTensorDescriptor_deprecated.hpp"
#include "ConstantMatrixDescriptor.hpp" #include "ConstantMatrixDescriptor.hpp"
#include "blockwise_generic_tensor_slice_copy.hpp" #include "blockwise_generic_tensor_slice_copy_deprecated.hpp"
#include "blockwise_gemm_xdlops.hpp" #include "blockwise_gemm_xdlops.hpp"
#include "threadwise_generic_tensor_slice_copy.hpp" #include "threadwise_generic_tensor_slice_copy_deprecated.hpp"
#include "implicitgemm_params.hpp" #include "implicitgemm_params.hpp"
namespace ck { namespace ck {
...@@ -131,21 +131,19 @@ struct GridwiseConvolutionImplicitGemm_v4r4_xdlops_nchw_kc1x1_nkhw_lds_double_bu ...@@ -131,21 +131,19 @@ struct GridwiseConvolutionImplicitGemm_v4r4_xdlops_nchw_kc1x1_nkhw_lds_double_bu
// slice a merged tensor, reorder and copy to a normal tensor // slice a merged tensor, reorder and copy to a normal tensor
// this copy operator already has blockwise offset built-in // this copy operator already has blockwise offset built-in
auto blockwise_in_copy = auto blockwise_in_copy =
BlockwiseGenericTensorSliceCopy_v2<BlockSize, BlockwiseGenericTensorSliceCopy_v2_deprecated<BlockSize,
decltype(in_e_b_global_desc), decltype(in_e_b_global_desc),
decltype(in_e_b_block_desc), decltype(in_e_b_block_desc),
MergedTensorCoordinate<decltype(in_e_b_global_desc)>, decltype(in_e_b_block_desc.GetLengths()),
NormalTensorCoordinate<decltype(in_e_b_block_desc)>, InBlockCopySubLengths_E_B,
decltype(in_e_b_block_desc.GetLengths()), InBlockCopyClusterLengths_E_B,
InBlockCopySubLengths_E_B, InBlockCopyThreadClusterArrangeOrder,
InBlockCopyClusterLengths_E_B, InBlockCopySrcAccessOrder,
InBlockCopyThreadClusterArrangeOrder, InBlockCopyDstAccessOrder,
InBlockCopySrcAccessOrder, 1,
InBlockCopyDstAccessOrder, 1,
1, InBlockCopyDataPerAccess_B,
1, InBlockCopyDataPerAccess_B>(
InBlockCopyDataPerAccess_B,
InBlockCopyDataPerAccess_B>(
{0, b_block_data_on_global}, {0, 0}); {0, b_block_data_on_global}, {0, 0});
// weight tensor // weight tensor
...@@ -167,22 +165,21 @@ struct GridwiseConvolutionImplicitGemm_v4r4_xdlops_nchw_kc1x1_nkhw_lds_double_bu ...@@ -167,22 +165,21 @@ struct GridwiseConvolutionImplicitGemm_v4r4_xdlops_nchw_kc1x1_nkhw_lds_double_bu
// operator for blockwise copy of weight into LDS // operator for blockwise copy of weight into LDS
// slice a tensor, and copy it into another tensor // slice a tensor, and copy it into another tensor
// this copy operator already have blockwise offset built-in // this copy operator already have blockwise offset built-in
auto blockwise_wei_copy = BlockwiseGenericTensorSliceCopy_v2< auto blockwise_wei_copy =
BlockSize, BlockwiseGenericTensorSliceCopy_v2_deprecated<BlockSize,
decltype(wei_e_k_global_desc), decltype(wei_e_k_global_desc),
decltype(wei_e_k_block_desc), decltype(wei_e_k_block_desc),
NormalTensorCoordinate<decltype(wei_e_k_global_desc)>, decltype(wei_e_k_block_desc.GetLengths()),
NormalTensorCoordinate<decltype(wei_e_k_block_desc)>, WeiBlockCopySubLengths_E_K,
decltype(wei_e_k_block_desc.GetLengths()), WeiBlockCopyClusterLengths_E_K,
WeiBlockCopySubLengths_E_K, WeiBlockCopyThreadClusterArrangeOrder,
WeiBlockCopyClusterLengths_E_K, WeiBlockCopySrcAccessOrder,
WeiBlockCopyThreadClusterArrangeOrder, WeiBlockCopyDstAccessOrder,
WeiBlockCopySrcAccessOrder, 0,
WeiBlockCopyDstAccessOrder, 1,
0, WeiBlockCopySrcDataPerRead_E,
1, WeiBlockCopyDstDataPerWrite_K>(
WeiBlockCopySrcDataPerRead_E, {0, k_block_data_on_global}, {0, 0});
WeiBlockCopyDstDataPerWrite_K>({0, k_block_data_on_global}, {0, 0});
// GEMM definition // GEMM definition
constexpr auto a_e_k_block_mtx_desc = make_ConstantMatrixDescriptor(wei_e_k_block_desc); constexpr auto a_e_k_block_mtx_desc = make_ConstantMatrixDescriptor(wei_e_k_block_desc);
...@@ -253,51 +250,50 @@ struct GridwiseConvolutionImplicitGemm_v4r4_xdlops_nchw_kc1x1_nkhw_lds_double_bu ...@@ -253,51 +250,50 @@ struct GridwiseConvolutionImplicitGemm_v4r4_xdlops_nchw_kc1x1_nkhw_lds_double_bu
Float* p_wei_block_next = Float* p_wei_block_next =
even_loop ? p_wei_block_double + wei_block_space : p_wei_block_double; even_loop ? p_wei_block_double + wei_block_space : p_wei_block_double;
Float p_in_register_buffer[blockwise_in_copy.GetRegisterBufferSize()]; Float p_in_thread_buffer[blockwise_in_copy.GetThreadBufferSize()];
Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()]; Float p_wei_thread_buffer[blockwise_wei_copy.GetThreadBufferSize()];
blockwise_in_copy.MoveSrcSlicingWindow(Sequence<EPerBlock, 0>{}, True); blockwise_in_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0>{}, True);
p_wei_block_on_global += EPerBlock * wei_e_k_global_desc.GetStrides()[0]; p_wei_block_on_global += EPerBlock * wei_e_k_global_desc.GetStrides()[0];
__syncthreads(); __syncthreads();
// LDS doubel buffer: load next data from device mem // LDS doubel buffer: load next data from device mem
blockwise_in_copy.RunLoadRegisterBuffer(p_in_global, p_in_register_buffer); blockwise_in_copy.RunLoadThreadBuffer(p_in_global, p_in_thread_buffer);
blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_block_on_global, blockwise_wei_copy.RunLoadThreadBuffer(p_wei_block_on_global, p_wei_thread_buffer);
p_wei_register_buffer);
// LDS double buffer: GEMM on current data // LDS double buffer: GEMM on current data
blockwise_gemm.Run(p_wei_block_now, p_in_block_now, p_out_thread); blockwise_gemm.Run(p_wei_block_now, p_in_block_now, p_out_thread);
// LDS double buffer: store next data to LDS // LDS double buffer: store next data to LDS
blockwise_in_copy.RunStoreRegisterBuffer(p_in_register_buffer, p_in_block_next); blockwise_in_copy.RunStoreThreadBuffer(p_in_thread_buffer, p_in_block_next);
blockwise_wei_copy.RunStoreRegisterBuffer(p_wei_register_buffer, p_wei_block_next); blockwise_wei_copy.RunStoreThreadBuffer(p_wei_thread_buffer, p_wei_block_next);
} }
} }
// LDS double buffer: tail // LDS double buffer: tail
{ {
Float p_in_register_buffer[blockwise_in_copy.GetRegisterBufferSize()]; Float p_in_thread_buffer[blockwise_in_copy.GetThreadBufferSize()];
Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()]; Float p_wei_thread_buffer[blockwise_wei_copy.GetThreadBufferSize()];
// even iteration // even iteration
blockwise_in_copy.MoveSrcSlicingWindow(Sequence<EPerBlock, 0>{}, True); blockwise_in_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0>{}, True);
p_wei_block_on_global += EPerBlock * wei_e_k_global_desc.GetStrides()[0]; p_wei_block_on_global += EPerBlock * wei_e_k_global_desc.GetStrides()[0];
__syncthreads(); __syncthreads();
// LDS doubel buffer: load next data from device mem // LDS doubel buffer: load next data from device mem
blockwise_in_copy.RunLoadRegisterBuffer(p_in_global, p_in_register_buffer); blockwise_in_copy.RunLoadThreadBuffer(p_in_global, p_in_thread_buffer);
blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_block_on_global, p_wei_register_buffer); blockwise_wei_copy.RunLoadThreadBuffer(p_wei_block_on_global, p_wei_thread_buffer);
// LDS double buffer: GEMM on current data // LDS double buffer: GEMM on current data
blockwise_gemm.Run(p_wei_block_double, p_in_block_double, p_out_thread); blockwise_gemm.Run(p_wei_block_double, p_in_block_double, p_out_thread);
// LDS double buffer: store next data to LDS // LDS double buffer: store next data to LDS
blockwise_in_copy.RunStoreRegisterBuffer(p_in_register_buffer, blockwise_in_copy.RunStoreThreadBuffer(p_in_thread_buffer,
p_in_block_double + in_block_space); p_in_block_double + in_block_space);
blockwise_wei_copy.RunStoreRegisterBuffer(p_wei_register_buffer, blockwise_wei_copy.RunStoreThreadBuffer(p_wei_thread_buffer,
p_wei_block_double + wei_block_space); p_wei_block_double + wei_block_space);
// odd iteration // odd iteration
__syncthreads(); __syncthreads();
...@@ -373,11 +369,9 @@ struct GridwiseConvolutionImplicitGemm_v4r4_xdlops_nchw_kc1x1_nkhw_lds_double_bu ...@@ -373,11 +369,9 @@ struct GridwiseConvolutionImplicitGemm_v4r4_xdlops_nchw_kc1x1_nkhw_lds_double_bu
const index_t b_thread_data_on_global = const index_t b_thread_data_on_global =
b_block_data_on_global + c_thread_mtx_on_block.col; b_block_data_on_global + c_thread_mtx_on_block.col;
auto threadwise_out_copy = ThreadwiseGenericTensorSliceCopy_v2r1< auto threadwise_out_copy = ThreadwiseGenericTensorSliceCopy_v2r1_deprecated<
decltype(out_k0_k1_k2_b_thread_desc), decltype(out_k0_k1_k2_b_thread_desc),
decltype(out_k0_k1_k2_b_global_desc), decltype(out_k0_k1_k2_b_global_desc),
NormalTensorCoordinate<decltype(out_k0_k1_k2_b_thread_desc)>,
MergedTensorCoordinate<decltype(out_k0_k1_k2_b_global_desc)>,
OutThreadCopySliceLengths, OutThreadCopySliceLengths,
arithmetic_sequence_gen<0, 4, 1>::type, arithmetic_sequence_gen<0, 4, 1>::type,
arithmetic_sequence_gen<0, 4, 1>::type, arithmetic_sequence_gen<0, 4, 1>::type,
......
...@@ -2,12 +2,12 @@ ...@@ -2,12 +2,12 @@
#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R4_XDLOPS_NCHW_KCYX_NKHW_HPP_LDS_DOUBLE_BUFFER_HPP #define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R4_XDLOPS_NCHW_KCYX_NKHW_HPP_LDS_DOUBLE_BUFFER_HPP
#include "common_header.hpp" #include "common_header.hpp"
#include "ConstantTensorDescriptor.hpp" #include "ConstantTensorDescriptor_deprecated.hpp"
#include "ConstantMergedTensorDescriptor.hpp" #include "ConstantMergedTensorDescriptor_deprecated.hpp"
#include "ConstantMatrixDescriptor.hpp" #include "ConstantMatrixDescriptor.hpp"
#include "blockwise_generic_tensor_slice_copy.hpp" #include "blockwise_generic_tensor_slice_copy_deprecated.hpp"
#include "blockwise_gemm_xdlops.hpp" #include "blockwise_gemm_xdlops.hpp"
#include "threadwise_generic_tensor_slice_copy.hpp" #include "threadwise_generic_tensor_slice_copy_deprecated.hpp"
#include "implicitgemm_params.hpp" #include "implicitgemm_params.hpp"
namespace ck { namespace ck {
...@@ -80,8 +80,6 @@ struct GridwiseConvolutionImplicitGemm_v4r4_xdlops_nchw_kcyx_nkhw_lds_double_buf ...@@ -80,8 +80,6 @@ struct GridwiseConvolutionImplicitGemm_v4r4_xdlops_nchw_kcyx_nkhw_lds_double_buf
const Float* const __restrict__ p_wei_global, const Float* const __restrict__ p_wei_global,
Float* const __restrict__ p_out_global) const Float* const __restrict__ p_out_global) const
{ {
if(blockIdx.x*blockDim.x + threadIdx.x == 0)
printf("conv dir %d",conv_dir);
constexpr auto I0 = Number<0>{}; constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{}; constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{}; constexpr auto I2 = Number<2>{};
...@@ -162,21 +160,19 @@ struct GridwiseConvolutionImplicitGemm_v4r4_xdlops_nchw_kcyx_nkhw_lds_double_buf ...@@ -162,21 +160,19 @@ struct GridwiseConvolutionImplicitGemm_v4r4_xdlops_nchw_kcyx_nkhw_lds_double_buf
// slice a merged tensor, reorder and copy to a normal tensor // slice a merged tensor, reorder and copy to a normal tensor
// this copy operator already has blockwise offset built-in // this copy operator already has blockwise offset built-in
auto blockwise_in_copy = auto blockwise_in_copy =
BlockwiseGenericTensorSliceCopy_v2<BlockSize, BlockwiseGenericTensorSliceCopy_v2_deprecated<BlockSize,
decltype(in_e_b_global_desc), decltype(in_e_b_global_desc),
decltype(in_e_b_block_desc), decltype(in_e_b_block_desc),
MergedTensorCoordinate<decltype(in_e_b_global_desc)>, decltype(in_e_b_block_desc.GetLengths()),
NormalTensorCoordinate<decltype(in_e_b_block_desc)>, InBlockCopySubLengths_E_B,
decltype(in_e_b_block_desc.GetLengths()), InBlockCopyClusterLengths_E_B,
InBlockCopySubLengths_E_B, InBlockCopyThreadClusterArrangeOrder,
InBlockCopyClusterLengths_E_B, InBlockCopySrcAccessOrder,
InBlockCopyThreadClusterArrangeOrder, InBlockCopyDstAccessOrder,
InBlockCopySrcAccessOrder, 1,
InBlockCopyDstAccessOrder, 1,
1, InBlockCopyDataPerAccess_B,
1, InBlockCopyDataPerAccess_B>(
InBlockCopyDataPerAccess_B,
InBlockCopyDataPerAccess_B>(
{0, b_block_data_on_global}, {0, 0}); {0, b_block_data_on_global}, {0, 0});
// weight tensor // weight tensor
...@@ -185,7 +181,7 @@ struct GridwiseConvolutionImplicitGemm_v4r4_xdlops_nchw_kcyx_nkhw_lds_double_buf ...@@ -185,7 +181,7 @@ struct GridwiseConvolutionImplicitGemm_v4r4_xdlops_nchw_kcyx_nkhw_lds_double_buf
make_WeiDesc_Xdlops<conv_dir, decltype(wei_k_c_y_x_global_desc)>{}.get( make_WeiDesc_Xdlops<conv_dir, decltype(wei_k_c_y_x_global_desc)>{}.get(
wei_k_c_y_x_global_desc); wei_k_c_y_x_global_desc);
// tensor descriptor in LDS, dst of blockwise copy // tensor descriptor in LDS, dst of blockwise copy
// be careful of LDS alignment // be careful of LDS alignment
constexpr auto wei_e_k_block_desc = make_ConstantTensorDescriptor_aligned( constexpr auto wei_e_k_block_desc = make_ConstantTensorDescriptor_aligned(
Sequence<EPerBlock, KPerBlock>{}, Sequence<EPerBlock, KPerBlock>{},
...@@ -194,22 +190,21 @@ struct GridwiseConvolutionImplicitGemm_v4r4_xdlops_nchw_kcyx_nkhw_lds_double_buf ...@@ -194,22 +190,21 @@ struct GridwiseConvolutionImplicitGemm_v4r4_xdlops_nchw_kcyx_nkhw_lds_double_buf
// operator for blockwise copy of weight into LDS // operator for blockwise copy of weight into LDS
// slice a tensor, and copy it into another tensor // slice a tensor, and copy it into another tensor
// this copy operator already have blockwise offset built-in // this copy operator already have blockwise offset built-in
auto blockwise_wei_copy = BlockwiseGenericTensorSliceCopy_v2< auto blockwise_wei_copy =
BlockSize, BlockwiseGenericTensorSliceCopy_v2_deprecated<BlockSize,
decltype(wei_e_k_global_desc), decltype(wei_e_k_global_desc),
decltype(wei_e_k_block_desc), decltype(wei_e_k_block_desc),
NormalTensorCoordinate<decltype(wei_e_k_global_desc)>, decltype(wei_e_k_block_desc.GetLengths()),
NormalTensorCoordinate<decltype(wei_e_k_block_desc)>, WeiBlockCopySubLengths_E_K,
decltype(wei_e_k_block_desc.GetLengths()), WeiBlockCopyClusterLengths_E_K,
WeiBlockCopySubLengths_E_K, WeiBlockCopyThreadClusterArrangeOrder,
WeiBlockCopyClusterLengths_E_K, WeiBlockCopySrcAccessOrder,
WeiBlockCopyThreadClusterArrangeOrder, WeiBlockCopyDstAccessOrder,
WeiBlockCopySrcAccessOrder, 0,
WeiBlockCopyDstAccessOrder, 1,
0, WeiBlockCopySrcDataPerRead_E,
1, WeiBlockCopyDstDataPerWrite_K>(
WeiBlockCopySrcDataPerRead_E, {0, k_block_data_on_global}, {0, 0});
WeiBlockCopyDstDataPerWrite_K>({0, k_block_data_on_global}, {0, 0});
// GEMM definition // GEMM definition
constexpr auto a_e_k_block_mtx_desc = make_ConstantMatrixDescriptor(wei_e_k_block_desc); constexpr auto a_e_k_block_mtx_desc = make_ConstantMatrixDescriptor(wei_e_k_block_desc);
...@@ -261,6 +256,7 @@ struct GridwiseConvolutionImplicitGemm_v4r4_xdlops_nchw_kcyx_nkhw_lds_double_buf ...@@ -261,6 +256,7 @@ struct GridwiseConvolutionImplicitGemm_v4r4_xdlops_nchw_kcyx_nkhw_lds_double_buf
blockwise_wei_copy.Run(p_wei_global, p_wei_block_double); blockwise_wei_copy.Run(p_wei_global, p_wei_block_double);
} }
#if 1
// LDS double buffer: main body // LDS double buffer: main body
for(index_t e_block_data_begin = 0; e_block_data_begin + 2 * EPerBlock < E; for(index_t e_block_data_begin = 0; e_block_data_begin + 2 * EPerBlock < E;
e_block_data_begin += 2 * EPerBlock) e_block_data_begin += 2 * EPerBlock)
...@@ -280,59 +276,51 @@ struct GridwiseConvolutionImplicitGemm_v4r4_xdlops_nchw_kcyx_nkhw_lds_double_buf ...@@ -280,59 +276,51 @@ struct GridwiseConvolutionImplicitGemm_v4r4_xdlops_nchw_kcyx_nkhw_lds_double_buf
Float* p_wei_block_next = Float* p_wei_block_next =
even_loop ? p_wei_block_double + wei_block_space : p_wei_block_double; even_loop ? p_wei_block_double + wei_block_space : p_wei_block_double;
Float p_in_register_buffer[blockwise_in_copy.GetRegisterBufferSize()]; Float p_in_thread_buffer[blockwise_in_copy.GetThreadBufferSize()];
Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()]; Float p_wei_thread_buffer[blockwise_wei_copy.GetThreadBufferSize()];
blockwise_in_copy.MoveSrcSlicingWindow(Sequence<EPerBlock, 0>{}, True); blockwise_in_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0>{}, True);
static_if<conv_dir == ImplicitGemmDirection::BackwardWeight>{}([&](auto ) { blockwise_wei_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0>{}, True);
blockwise_wei_copy.MoveSrcSlicingWindow(Sequence<EPerBlock, 0>{}, True);
}).Else([&](auto fwd) {
p_wei_block_on_global += EPerBlock * fwd(wei_e_k_global_desc).GetStride(I0);
});
__syncthreads(); __syncthreads();
// LDS doubel buffer: load next data from device mem // LDS doubel buffer: load next data from device mem
blockwise_in_copy.RunLoadRegisterBuffer(p_in_global, p_in_register_buffer); blockwise_in_copy.RunLoadThreadBuffer(p_in_global, p_in_thread_buffer);
blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_block_on_global, blockwise_wei_copy.RunLoadThreadBuffer(p_wei_block_on_global, p_wei_thread_buffer);
p_wei_register_buffer);
// LDS double buffer: GEMM on current data // LDS double buffer: GEMM on current data
blockwise_gemm.Run(p_wei_block_now, p_in_block_now, p_out_thread); blockwise_gemm.Run(p_wei_block_now, p_in_block_now, p_out_thread);
// LDS double buffer: store next data to LDS // LDS double buffer: store next data to LDS
blockwise_in_copy.RunStoreRegisterBuffer(p_in_register_buffer, p_in_block_next); blockwise_in_copy.RunStoreThreadBuffer(p_in_thread_buffer, p_in_block_next);
blockwise_wei_copy.RunStoreRegisterBuffer(p_wei_register_buffer, p_wei_block_next); blockwise_wei_copy.RunStoreThreadBuffer(p_wei_thread_buffer, p_wei_block_next);
} }
} }
#endif
// LDS double buffer: tail // LDS double buffer: tail
{ {
Float p_in_register_buffer[blockwise_in_copy.GetRegisterBufferSize()]; Float p_in_thread_buffer[blockwise_in_copy.GetThreadBufferSize()];
Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()]; Float p_wei_thread_buffer[blockwise_wei_copy.GetThreadBufferSize()];
// even iteration // even iteration
blockwise_in_copy.MoveSrcSlicingWindow(Sequence<EPerBlock, 0>{}, True); blockwise_in_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0>{}, True);
static_if<conv_dir == ImplicitGemmDirection::BackwardWeight>{}([&](auto ) { blockwise_wei_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0>{}, True);
blockwise_wei_copy.MoveSrcSlicingWindow(Sequence<EPerBlock, 0>{}, True);
}).Else([&](auto fwd) {
p_wei_block_on_global += EPerBlock * fwd(wei_e_k_global_desc).GetStride(I0);
});
__syncthreads(); __syncthreads();
// LDS doubel buffer: load next data from device mem // LDS doubel buffer: load next data from device mem
blockwise_in_copy.RunLoadRegisterBuffer(p_in_global, p_in_register_buffer); blockwise_in_copy.RunLoadThreadBuffer(p_in_global, p_in_thread_buffer);
blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_block_on_global, p_wei_register_buffer); blockwise_wei_copy.RunLoadThreadBuffer(p_wei_block_on_global, p_wei_thread_buffer);
// LDS double buffer: GEMM on current data // LDS double buffer: GEMM on current data
blockwise_gemm.Run(p_wei_block_double, p_in_block_double, p_out_thread); blockwise_gemm.Run(p_wei_block_double, p_in_block_double, p_out_thread);
// LDS double buffer: store next data to LDS // LDS double buffer: store next data to LDS
blockwise_in_copy.RunStoreRegisterBuffer(p_in_register_buffer, blockwise_in_copy.RunStoreThreadBuffer(p_in_thread_buffer,
p_in_block_double + in_block_space); p_in_block_double + in_block_space);
blockwise_wei_copy.RunStoreRegisterBuffer(p_wei_register_buffer, blockwise_wei_copy.RunStoreThreadBuffer(p_wei_thread_buffer,
p_wei_block_double + wei_block_space); p_wei_block_double + wei_block_space);
// odd iteration // odd iteration
__syncthreads(); __syncthreads();
...@@ -384,11 +372,9 @@ struct GridwiseConvolutionImplicitGemm_v4r4_xdlops_nchw_kcyx_nkhw_lds_double_buf ...@@ -384,11 +372,9 @@ struct GridwiseConvolutionImplicitGemm_v4r4_xdlops_nchw_kcyx_nkhw_lds_double_buf
const index_t b_thread_data_on_global = const index_t b_thread_data_on_global =
b_block_data_on_global + c_thread_mtx_on_block.col; b_block_data_on_global + c_thread_mtx_on_block.col;
auto threadwise_out_copy = ThreadwiseGenericTensorSliceCopy_v2r1< auto threadwise_out_copy = ThreadwiseGenericTensorSliceCopy_v2r1_deprecated<
decltype(out_k0_k1_k2_b_thread_desc), decltype(out_k0_k1_k2_b_thread_desc),
decltype(out_k0_k1_k2_b_global_desc), decltype(out_k0_k1_k2_b_global_desc),
NormalTensorCoordinate<decltype(out_k0_k1_k2_b_thread_desc)>,
MergedTensorCoordinate<decltype(out_k0_k1_k2_b_global_desc)>,
OutThreadCopySliceLengths, OutThreadCopySliceLengths,
arithmetic_sequence_gen<0, 4, 1>::type, arithmetic_sequence_gen<0, 4, 1>::type,
arithmetic_sequence_gen<0, 4, 1>::type, arithmetic_sequence_gen<0, 4, 1>::type,
......
...@@ -2,7 +2,8 @@ ...@@ -2,7 +2,8 @@
#define CK_CONSTANT_MATRIX_DESCRIPTOR_HPP #define CK_CONSTANT_MATRIX_DESCRIPTOR_HPP
#include "common_header.hpp" #include "common_header.hpp"
#include "ConstantTensorDescriptor.hpp" #include "ConstantTensorDescriptor_deprecated.hpp"
#include "tensor_descriptor.hpp"
namespace ck { namespace ck {
...@@ -31,6 +32,11 @@ struct ConstantMatrixDescriptor ...@@ -31,6 +32,11 @@ struct ConstantMatrixDescriptor
return irow * RowStride_ + icol; return irow * RowStride_ + icol;
} }
__host__ __device__ static index_t CalculateOffset(index_t irow, index_t icol)
{
return GetOffsetFromMultiIndex(irow, icol);
}
template <index_t SubNRow, index_t SubNCol> template <index_t SubNRow, index_t SubNCol>
__host__ __device__ static constexpr auto MakeSubMatrixDescriptor(Number<SubNRow>, __host__ __device__ static constexpr auto MakeSubMatrixDescriptor(Number<SubNRow>,
Number<SubNCol>) Number<SubNCol>)
...@@ -52,10 +58,22 @@ __host__ __device__ constexpr auto ...@@ -52,10 +58,22 @@ __host__ __device__ constexpr auto
return ConstantMatrixDescriptor<NRow, NCol, RowStride>{}; return ConstantMatrixDescriptor<NRow, NCol, RowStride>{};
} }
template <class... Ts> template <typename... Ts>
__host__ __device__ constexpr auto make_ConstantMatrixDescriptor(ConstantTensorDescriptor<Ts...>) __host__ __device__ constexpr auto
make_ConstantMatrixDescriptor(ConstantTensorDescriptor_deprecated<Ts...>)
{
using TDesc = ConstantTensorDescriptor_deprecated<Ts...>;
static_assert(TDesc::GetNumOfDimension() == 2, "wrong");
static_assert(TDesc::GetStrides()[1] == 1, "wrong");
return ConstantMatrixDescriptor<TDesc::GetLengths()[0],
TDesc::GetLengths()[1],
TDesc::GetStrides()[0]>{};
}
template <typename... Ts>
__host__ __device__ constexpr auto make_ConstantMatrixDescriptor(NativeTensorDescriptor<Ts...>)
{ {
using TDesc = ConstantTensorDescriptor<Ts...>; using TDesc = NativeTensorDescriptor<Ts...>;
static_assert(TDesc::GetNumOfDimension() == 2, "wrong"); static_assert(TDesc::GetNumOfDimension() == 2, "wrong");
static_assert(TDesc::GetStrides()[1] == 1, "wrong"); static_assert(TDesc::GetStrides()[1] == 1, "wrong");
return ConstantMatrixDescriptor<TDesc::GetLengths()[0], return ConstantMatrixDescriptor<TDesc::GetLengths()[0],
...@@ -63,7 +81,7 @@ __host__ __device__ constexpr auto make_ConstantMatrixDescriptor(ConstantTensorD ...@@ -63,7 +81,7 @@ __host__ __device__ constexpr auto make_ConstantMatrixDescriptor(ConstantTensorD
TDesc::GetStrides()[0]>{}; TDesc::GetStrides()[0]>{};
} }
template <class TDesc> template <typename TDesc>
__host__ __device__ void print_ConstantMatrixDescriptor(TDesc, const char* s) __host__ __device__ void print_ConstantMatrixDescriptor(TDesc, const char* s)
{ {
printf( printf(
......
#ifndef CK_CONSTANT_MERGED_TENSOR_DESCRIPTOR_DEPRECATED_HPP
#define CK_CONSTANT_MERGED_TENSOR_DESCRIPTOR_DEPRECATED_HPP
#include "common_header.hpp"
#include "ConstantTensorDescriptor_deprecated.hpp"
namespace ck {
// OriginalTensorDesc : ConstantTensorDescriptor_deprecated<...>
// it's the tensor whose dimensions are to be merged
// OriginalDimMergeSeqs : Sequence<...>...
// each is a sequence of original dimensions (of OriginalTensorDesc) to be merged
template <class OriginalTensorDesc, class... OriginalDimMergeSeqs>
struct ConstantMergedTensorDescriptor_deprecated
{
using Type = ConstantMergedTensorDescriptor_deprecated;
static constexpr auto mOriginalDimMergeSeqs = std::tuple<OriginalDimMergeSeqs...>{};
static constexpr index_t nDim = sizeof...(OriginalDimMergeSeqs);
static constexpr index_t nOriginalDim = OriginalTensorDesc::GetNumOfDimension();
__host__ __device__ constexpr ConstantMergedTensorDescriptor_deprecated()
{
static_assert(nDim <= nOriginalDim, "wrong!");
// TODO: check each of OriginalDimMergeSeqs contains at least 1, and at most
// OriginalTensorDesc::nDim number of dimensions
// TODO: check OriginalDimMergeSeqs contains all original dimensions
// TODO: check there is no duplication in OriginalDimMergeSeqs
}
__host__ __device__ static constexpr auto GetOriginalTensorDescriptor()
{
return OriginalTensorDesc{};
}
__host__ __device__ static constexpr auto GetNumOfDimension() { return Number<nDim>{}; }
template <index_t IDim>
__host__ __device__ static constexpr auto GetContainedOriginalDimensions(Number<IDim>)
{
return std::get<IDim>(mOriginalDimMergeSeqs);
}
template <index_t IDim>
__host__ __device__ static constexpr bool ContainMultipleOriginalDimensions(Number<IDim>)
{
return (std::get<IDim>(mOriginalDimMergeSeqs).GetSize() > 1);
}
template <index_t IDim>
__host__ __device__ static constexpr auto GetLength(Number<IDim>)
{
constexpr auto original_dims_partial = std::get<IDim>(mOriginalDimMergeSeqs);
return OriginalTensorDesc::Extract(original_dims_partial).GetElementSize();
}
template <index_t IDim>
__host__ __device__ static constexpr auto GetStride(Number<IDim>)
{
static_assert(!ContainMultipleOriginalDimensions(Number<IDim>{}),
"wrong! stride of a merged dimension is undefined");
constexpr auto idim_original = std::get<IDim>(mOriginalDimMergeSeqs).Back();
return OriginalTensorDesc::GetStride(Number<idim_original>{});
}
// this is a hack to return the stride of the last original dimension of a merged dimension
// TODO: refactor this once the concept of "dimension" is used
template <index_t IDim>
__host__ __device__ static constexpr auto GetLastOriginalDimensionStride(Number<IDim>)
{
constexpr auto idim_last_original = std::get<IDim>(mOriginalDimMergeSeqs).Back();
return OriginalTensorDesc::GetStride(Number<idim_last_original>{});
}
__host__ __device__ static constexpr auto GetLengths()
{
return Sequence<OriginalTensorDesc::Extract(OriginalDimMergeSeqs{}).GetElementSize()...>{};
}
__host__ __device__ static constexpr auto GetElementSize()
{
return OriginalTensorDesc::GetElementSize();
}
template <class OriginalDimsPartial>
struct lambda_1_GetOriginalMultiIndexFromMultiIndex
{
const Array<index_t, OriginalDimsPartial::GetSize()>& original_multi_id_partial;
Array<index_t, nOriginalDim>& original_multi_id;
__host__ __device__ constexpr lambda_1_GetOriginalMultiIndexFromMultiIndex(
const Array<index_t, OriginalDimsPartial::GetSize()>& original_multi_id_partial_,
Array<index_t, nOriginalDim>& original_multi_id_)
: original_multi_id_partial(original_multi_id_partial_),
original_multi_id(original_multi_id_)
{
}
template <index_t I>
__host__ __device__ constexpr void operator()(Number<I>) const
{
constexpr index_t idim_original = OriginalDimsPartial::Get(Number<I>{});
index_t itmp = original_multi_id_partial[I];
original_multi_id(idim_original) = itmp;
}
};
struct lambda_0_GetOriginalMultiIndexFromMultiIndex
{
const Array<index_t, nDim>& multi_id;
Array<index_t, nOriginalDim>& original_multi_id;
__host__ __device__ constexpr lambda_0_GetOriginalMultiIndexFromMultiIndex(
const Array<index_t, nDim>& multi_id_, Array<index_t, nOriginalDim>& original_multi_id_)
: multi_id(multi_id_), original_multi_id(original_multi_id_)
{
}
template <index_t IDim>
__host__ __device__ constexpr void operator()(Number<IDim>) const
{
constexpr auto original_dims_partial = std::get<IDim>(Type::mOriginalDimMergeSeqs);
// get partial original-multi-id corresponding to this merged dimension
const auto original_multi_id_partial =
OriginalTensorDesc::Extract(original_dims_partial)
.GetMultiIndexFrom1dIndex(multi_id[IDim]);
static_for<0, original_dims_partial.GetSize(), 1>{}(
lambda_1_GetOriginalMultiIndexFromMultiIndex<decltype(original_dims_partial)>(
original_multi_id_partial, original_multi_id));
}
};
// return type is Array<...>
__host__ __device__ static constexpr auto
GetOriginalMultiIndexFromMultiIndex(Array<index_t, nDim> multi_id)
{
Array<index_t, nOriginalDim> original_multi_id;
static_for<0, nDim, 1>{}(
lambda_0_GetOriginalMultiIndexFromMultiIndex(multi_id, original_multi_id));
return original_multi_id;
}
template <index_t... Is>
__host__ __device__ static constexpr index_t GetOffsetFromMultiIndex(Sequence<Is...>)
{
constexpr auto multi_id = sequence2array(Sequence<Is...>{});
constexpr auto original_multi_id = GetOriginalMultiIndexFromMultiIndex(multi_id);
return OriginalTensorDesc::GetOffsetFromMultiIndex(original_multi_id);
}
__host__ __device__ static constexpr index_t
GetOffsetFromMultiIndex(Array<index_t, nDim> multi_id)
{
auto original_multi_id = GetOriginalMultiIndexFromMultiIndex(multi_id);
return OriginalTensorDesc::GetOffsetFromMultiIndex(original_multi_id);
}
template <class... Is>
__host__ __device__ static constexpr index_t GetOffsetFromMultiIndex(Is... is)
{
return GetOffsetFromMultiIndex(Array<index_t, nDim>{is...});
}
__host__ __device__ static constexpr Array<index_t, nDim> GetMultiIndexFrom1dIndex(index_t id)
{
constexpr auto packed_desc = make_ConstantTensorDescriptor_packed(GetLengths());
return packed_desc.GetMultiIndexFrom1dIndex(id);
}
__host__ __device__ static constexpr auto Pack()
{
constexpr auto lengths = GetLengths();
constexpr auto strides = calculate_tensor_strides_packed(lengths);
return ConstantTensorDescriptor_deprecated<decltype(lengths), decltype(strides)>{};
}
};
template <class OriginalTensorDesc, class... OriginalDimMergeSeqs>
__host__ __device__ constexpr auto make_ConstantMergedTensorDescriptor(OriginalTensorDesc,
OriginalDimMergeSeqs...)
{
return ConstantMergedTensorDescriptor_deprecated<OriginalTensorDesc, OriginalDimMergeSeqs...>{};
}
template <class TDesc>
__host__ __device__ void print_ConstantMergedTensorDescriptor(const char* s, TDesc)
{
print_ConstantTensorDescriptor(s, TDesc::GetOriginalTensorDescriptor());
}
} // namespace ck
#endif
#ifndef CK_CONSTANT_TENSOR_DESCRIPTOR_DEPRECATED_HPP
#define CK_CONSTANT_TENSOR_DESCRIPTOR_DEPRECATED_HPP
#include "common_header.hpp"
namespace ck {
template <class Lengths>
__host__ __device__ constexpr auto calculate_tensor_strides_packed_deprecated(Lengths)
{
return reverse_inclusive_scan_sequence(
Lengths{}.PopFront(), math::multiplies<index_t>{}, Number<1>{})
.PushBack(Number<1>{});
}
template <class Lengths, index_t Align>
__host__ __device__ constexpr auto calculate_tensor_strides_aligned_old(Lengths, Number<Align>)
{
constexpr index_t L_back_align =
Align * math::integer_divide_ceiler<index_t>{}(Lengths{}.Back(), Align);
return calculate_tensor_strides_packed_deprecated(
Lengths{}.Modify(Number<Lengths{}.GetSize() - 1>{}, Number<L_back_align>{}));
}
template <class Lengths, class Strides>
struct ConstantTensorDescriptor_deprecated
{
using Type = ConstantTensorDescriptor_deprecated;
static constexpr index_t nDim = Lengths::GetSize();
__host__ __device__ constexpr ConstantTensorDescriptor_deprecated()
{
static_assert(Lengths::GetSize() == Strides::GetSize(), "nDim not consistent");
}
__host__ __device__ static constexpr auto GetOriginalTensorDescriptor() { return Type{}; }
template <index_t IDim>
__host__ __device__ static constexpr auto GetContainedOriginalDimensions(Number<IDim>)
{
return Sequence<IDim>{};
}
__host__ __device__ static constexpr auto GetNumOfDimension() { return Number<nDim>{}; }
__host__ __device__ static constexpr auto GetLengths() { return Lengths{}; }
__host__ __device__ static constexpr auto GetStrides() { return Strides{}; }
__host__ __device__ static constexpr auto GetLength(index_t IDim) { return Lengths{}[IDim]; }
__host__ __device__ static constexpr auto GetStride(index_t IDim) { return Strides{}[IDim]; }
struct lambda_AreDimensionsContinuous
{
bool& is_continuous;
__host__ __device__ constexpr lambda_AreDimensionsContinuous(bool& is_continuous_)
: is_continuous(is_continuous_)
{
}
template <index_t IDim_>
__host__ __device__ constexpr void operator()(Number<IDim_>) const
{
constexpr auto IDim = Number<IDim_>{};
constexpr auto IDim_p1 = Number<IDim_ + 1>{};
is_continuous =
is_continuous && (GetStride(IDim) >= GetStride(IDim_p1) &&
GetStride(IDim) == GetStride(IDim_p1) * GetLength(IDim_p1));
}
};
__host__ __device__ static constexpr bool AreDimensionsContinuous()
{
bool is_continuous = true;
static_for<0, nDim - 1, 1>{}(lambda_AreDimensionsContinuous(is_continuous));
return is_continuous;
}
__host__ __device__ static constexpr bool IsPackedTensor()
{
return AreDimensionsContinuous() && GetStride(Number<nDim - 1>{}) == 1;
}
template <class T>
__host__ __device__ static constexpr bool ContainMultipleOriginalDimensions(T)
{
return false;
}
__host__ __device__ static constexpr auto GetElementSize()
{
return Number<reduce_on_sequence(Lengths{}, math::multiplies<index_t>{}, Number<1>{})>{};
}
__host__ __device__ static constexpr auto GetElementSpace()
{
constexpr index_t element_space_unaligned = reduce_on_sequence(
(GetLengths() - Number<1>{}) * GetStrides(), math::plus<index_t>{}, Number<1>{});
return Number<element_space_unaligned>{};
}
// emulate constexpr lambda
template <index_t NSize>
struct lambda_GetOffsetFromMultiIndex
{
Array<index_t, NSize>& multi_id;
index_t& offset;
__host__
__device__ constexpr lambda_GetOffsetFromMultiIndex(Array<index_t, NSize>& multi_id_,
index_t& offset_)
: multi_id(multi_id_), offset(offset_)
{
}
template <class X>
__host__ __device__ constexpr void operator()(X IDim) const
{
offset += multi_id[IDim] * Type::GetStride(IDim);
}
};
template <index_t NSize>
__host__ __device__ static constexpr index_t
GetOffsetFromMultiIndex(Array<index_t, NSize> multi_id)
{
static_assert(NSize == nDim, "wrong! Dimension not consistent");
index_t offset = 0;
static_for<0, nDim, 1>{}(lambda_GetOffsetFromMultiIndex<NSize>(multi_id, offset));
return offset;
}
template <class... Is>
__host__ __device__ static constexpr index_t GetOffsetFromMultiIndex(Is... is)
{
return GetOffsetFromMultiIndex(Array<index_t, sizeof...(Is)>{is...});
}
template <index_t... Is>
__host__ __device__ static constexpr auto GetOffsetFromMultiIndex(Sequence<Is...>)
{
static_assert(sizeof...(Is) == nDim, "wrong! Dimension not consistent");
constexpr auto multi_id = Sequence<Is...>{};
return Number<reduce_on_sequence(
multi_id * GetStrides(), math::plus<index_t>{}, Number<0>{})>{};
}
// emulate constexpr lambda
template <class PackedStrides>
struct lambda_GetMultiIndexFrom1dIndex
{
index_t& id;
Array<index_t, nDim>& multi_id;
__host__
__device__ constexpr lambda_GetMultiIndexFrom1dIndex(index_t& id_,
Array<index_t, nDim>& multi_id_)
: id(id_), multi_id(multi_id_)
{
}
template <class IDim_>
__host__ __device__ constexpr void operator()(IDim_) const
{
constexpr auto IDim = IDim_{};
constexpr index_t stride = PackedStrides::Get(IDim);
multi_id(IDim) = id / stride;
id -= multi_id[IDim] * stride;
}
};
__host__ __device__ static constexpr Array<index_t, nDim> GetMultiIndexFrom1dIndex(index_t id)
{
Array<index_t, nDim> multi_id;
using PackedStrides = decltype(calculate_tensor_strides_packed_deprecated(GetLengths()));
// calculate index in each of the dimensions in the order of their dimension
static_for<0, nDim - 1, 1>{}(lambda_GetMultiIndexFrom1dIndex<PackedStrides>(id, multi_id));
multi_id(Number<nDim - 1>{}) = id / PackedStrides::Get(Number<nDim - 1>{});
return multi_id;
}
__host__ __device__ static constexpr auto
GetOriginalMultiIndexFromMultiIndex(Array<index_t, nDim> multi_id)
{
return multi_id;
}
// This function doesn't do carry check on the highest dimension for positive stepping (or
// borrow check on the highest dimension for negative stepping) , for performance reason. It is
// the user's responsibility to make sure the result "new_mutli_id" is not out-of-bound on the
// highest dimension for positive stepping (or on the lowest dimension for negative stepping)
template <bool PositiveDirection>
__host__ __device__ static Array<index_t, nDim>
UpdateMultiIndexGivenStepSizeOf1dIndex(Array<index_t, nDim> old_multi_id,
index_t step_size_of_1d_index,
integral_constant<bool, PositiveDirection>)
{
Array<index_t, nDim> new_multi_id;
const auto step_sizes = GetMultiIndexFrom1dIndex(step_size_of_1d_index);
static_if<PositiveDirection>{}([&](auto) {
new_multi_id = old_multi_id + step_sizes;
bool carry = false;
// do carry check in reversed order, starting from lowest dimension
// don't check the highest dimension
static_for<0, nDim, 1>{}([&](auto IDimReverse) {
constexpr index_t idim = nDim - 1 - IDimReverse;
constexpr auto IDim = Number<idim>{};
if(carry)
{
++new_multi_id(idim);
}
carry = false;
if(new_multi_id[idim] >= GetLength(IDim))
{
new_multi_id(idim) -= GetLength(IDim);
carry = true;
}
});
}).Else([&](auto) {
// shift up multi-id to avoid unsigned integer underflow during intermediate
// calculations. After the shift, should have new_multi_id[...] >= 1
new_multi_id = old_multi_id + (GetLengths() - step_sizes);
bool borrow = false;
// do borrow check in reversed order, starting from lowest dimension
// don't check the highest dimension
static_for<0, nDim, 1>{}([&](auto IDimReverse) {
constexpr index_t idim = nDim - 1 - IDimReverse;
constexpr auto IDim = Number<idim>{};
if(borrow)
{
--new_multi_id(idim);
}
borrow = false;
if(new_multi_id[idim] < GetLength(IDim))
{
new_multi_id(idim) += GetLength(IDim);
borrow = true;
}
});
// shift back down multi-id
// here, should have new_multi_id[...] >= GetLengths()
new_multi_id = new_multi_id - GetLengths();
});
return new_multi_id;
}
template <index_t... IDims>
__host__ __device__ static constexpr auto Extract(Number<IDims>... extract_dims)
{
static_assert(sizeof...(IDims) <= GetNumOfDimension(),
"wrong! too many number of dimensions to be extracted");
using extract_lengths = decltype(Lengths::Extract(extract_dims...));
using extract_strides = decltype(Strides::Extract(extract_dims...));
return ConstantTensorDescriptor_deprecated<extract_lengths, extract_strides>{};
}
template <index_t... IDims>
__host__ __device__ static constexpr auto Extract(Sequence<IDims...>)
{
return Extract(Number<IDims>{}...);
}
template <class... Ts>
__host__ __device__ static constexpr auto Embed(ConstantTensorDescriptor_deprecated<Ts...>)
{
using leaf_tensor = ConstantTensorDescriptor_deprecated<Ts...>;
return ConstantTensorDescriptor_deprecated<
decltype(GetLengths().PushBack(leaf_tensor::GetLengths())),
decltype(GetStrides().PushBack(leaf_tensor::GetStrides()))>{};
}
template <index_t IDimVector, index_t DataPerVector>
struct lambda_IsVectorizationAllowed
{
bool& is_allowed;
__host__ __device__ constexpr lambda_IsVectorizationAllowed(bool& is_allowed_)
: is_allowed(is_allowed_)
{
}
template <index_t IDim_>
__host__ __device__ constexpr void operator()(Number<IDim_>) const
{
constexpr auto IDim = Number<IDim_>{};
if(IDimVector != IDim && Strides::Get(IDim) % DataPerVector != 0)
{
is_allowed = false;
}
}
};
template <index_t IDimVector, index_t DataPerVector>
__host__ __device__ static constexpr bool IsVectorizationAllowed(Number<IDimVector>,
Number<DataPerVector>)
{
bool is_allowed = (Strides{}[IDimVector] == 1 || DataPerVector == 1) &&
Lengths{}[IDimVector] % DataPerVector == 0;
static_for<0, nDim, 1>{}(
lambda_IsVectorizationAllowed<IDimVector, DataPerVector>{is_allowed});
return is_allowed;
}
template <index_t IDim, index_t DataPerVector>
__host__ __device__ static constexpr auto Vectorize(Number<IDim>, Number<DataPerVector>)
{
constexpr auto idim = Number<IDim>{};
constexpr auto data_per_vector = Number<DataPerVector>{};
static_assert(IsVectorizationAllowed(idim, data_per_vector), "wrong!");
using vectorized_lengths =
decltype(Lengths::Modify(Number<IDim>{}, Number<Lengths{}[IDim] / DataPerVector>{}));
using vectorized_strides =
decltype((Strides{} / Number<DataPerVector>{}).Modify(Number<IDim>{}, Number<1>{}));
return ConstantTensorDescriptor_deprecated<vectorized_lengths, vectorized_strides>{};
}
template <index_t IDim, index_t SliceLen>
__host__ __device__ static constexpr auto Slice(Number<IDim>, Number<SliceLen>)
{
using slice_lengths = decltype(Lengths::Modify(Number<IDim>{}, Number<SliceLen>{}));
return ConstantTensorDescriptor_deprecated<slice_lengths, Strides>{};
}
template <index_t... Is>
__host__ __device__ static constexpr auto Slice(Sequence<Is...> slice_lengths)
{
static_assert(slice_lengths.GetSize() == nDim, "wrong!");
return ConstantTensorDescriptor_deprecated<decltype(slice_lengths), Strides>{};
}
template <index_t IDim, index_t SliceLength, index_t SliceStride>
__host__ __device__ static constexpr auto
StridedSlice(Number<IDim>, Number<SliceLength>, Number<SliceStride>)
{
constexpr index_t new_stride = Strides::Get(Number<IDim>{}) * SliceStride;
using new_lengths = decltype(Lengths::Modify(Number<IDim>{}, Number<SliceLength>{}));
using new_strides = decltype(Strides::Modify(Number<IDim>{}, Number<new_stride>{}));
return ConstantTensorDescriptor_deprecated<new_lengths, new_strides>{};
}
template <index_t IDim, index_t... FoldIntervals>
__host__ __device__ static constexpr auto Fold(Number<IDim>, Number<FoldIntervals>...)
{
constexpr auto fold_intervals = Sequence<FoldIntervals...>{};
constexpr index_t fold_intervals_product =
reduce_on_sequence(fold_intervals, math::multiplies<index_t>{}, Number<1>{});
constexpr auto unfold_length = GetLength(Number<IDim>{});
constexpr auto unfold_stride = GetStride(Number<IDim>{});
// length of the dimension to be folded needs to be dividable by fold_interval_product,
// otherwise, folding is invalid
static_assert(unfold_length % fold_intervals_product == 0,
"wrong! length on the dimension to be folded cannot be evenly divided!");
// folded lengths
constexpr auto fold_lengths =
Sequence<unfold_length / fold_intervals_product>{}.PushBack(fold_intervals);
// folded strides
constexpr auto fold_strides =
Number<unfold_stride>{} *
reverse_inclusive_scan_sequence(
fold_intervals.PushBack(Number<1>{}), math::multiplies<index_t>{}, Number<1>{});
// left and right
constexpr auto left = typename arithmetic_sequence_gen<0, IDim, 1>::type{};
constexpr auto right =
typename arithmetic_sequence_gen<IDim + 1, GetNumOfDimension(), 1>::type{};
constexpr auto new_lengths =
GetLengths().Extract(left).PushBack(fold_lengths).PushBack(GetLengths().Extract(right));
constexpr auto new_strides =
GetStrides().Extract(left).PushBack(fold_strides).PushBack(GetStrides().Extract(right));
return ConstantTensorDescriptor_deprecated<decltype(new_lengths), decltype(new_strides)>{};
}
template <index_t IDim, index_t... FoldIntervals>
__host__ __device__ static constexpr auto Fold(Number<IDim>, Sequence<FoldIntervals...>)
{
return Fold(Number<IDim>{}, Number<FoldIntervals>{}...);
}
// this function unfold dimension [FirstUnfoldDim, ..., LastUnfoldDim] into 1 dimension
template <index_t FirstUnfoldDim, index_t LastUnfoldDim>
__host__ __device__ static constexpr auto Unfold(Number<FirstUnfoldDim>, Number<LastUnfoldDim>)
{
static_assert(FirstUnfoldDim >= 0 && LastUnfoldDim < nDim &&
FirstUnfoldDim <= LastUnfoldDim,
"wrong! should have FirstUnfoldDim <= LastUnfoldDim!");
// left and right
constexpr auto left = typename arithmetic_sequence_gen<0, FirstUnfoldDim, 1>::type{};
constexpr auto middle =
typename arithmetic_sequence_gen<FirstUnfoldDim, LastUnfoldDim + 1, 1>::type{};
constexpr auto right =
typename arithmetic_sequence_gen<LastUnfoldDim + 1, GetNumOfDimension(), 1>::type{};
// dimensions to be unfolded need to be continuous
static_assert(Type::Extract(middle).AreDimensionsContinuous(), "wrong! not unfoldable");
// unfolded length, stride
constexpr index_t unfold_length = reduce_on_sequence(
GetLengths().Extract(middle), math::multiplies<index_t>{}, Number<1>{});
constexpr index_t unfold_stride = GetStride(Number<LastUnfoldDim>{});
// new lengths, strides
constexpr auto new_lengths = GetLengths()
.Extract(left)
.PushBack(Number<unfold_length>{})
.PushBack(GetLengths().Extract(right));
constexpr auto new_strides = GetStrides()
.Extract(left)
.PushBack(Number<unfold_stride>{})
.PushBack(GetStrides().Extract(right));
return ConstantTensorDescriptor_deprecated<decltype(new_lengths), decltype(new_strides)>{};
}
__host__ __device__ static constexpr auto Pack()
{
using packed_strides = decltype(calculate_tensor_strides_packed_deprecated(Lengths{}));
return ConstantTensorDescriptor_deprecated<Lengths, packed_strides>{};
}
template <class MapNew2Old>
__host__ __device__ static constexpr auto ReorderGivenNew2Old(MapNew2Old)
{
return ConstantTensorDescriptor_deprecated<
decltype(Lengths::ReorderGivenNew2Old(MapNew2Old{})),
decltype(Strides::ReorderGivenNew2Old(MapNew2Old{}))>{};
}
template <class MapOld2New>
__host__ __device__ static constexpr auto ReorderGivenOld2New(MapOld2New)
{
return ConstantTensorDescriptor_deprecated<
decltype(Lengths::ReorderGivenOld2New(MapOld2New{})),
decltype(Strides::ReorderGivenOld2New(MapOld2New{}))>{};
}
};
template <class Lengths>
__host__ __device__ constexpr auto make_ConstantTensorDescriptor_packed(Lengths)
{
using Strides = decltype(calculate_tensor_strides_packed_deprecated(Lengths{}));
return ConstantTensorDescriptor_deprecated<Lengths, Strides>{};
}
template <class Lengths, class Strides>
__host__ __device__ constexpr auto make_ConstantTensorDescriptor(Lengths, Strides)
{
return ConstantTensorDescriptor_deprecated<Lengths, Strides>{};
}
template <class Lengths, index_t Align>
__host__ __device__ constexpr auto make_ConstantTensorDescriptor_aligned(Lengths, Number<Align>)
{
using Strides = decltype(calculate_tensor_strides_aligned_old(Lengths{}, Number<Align>{}));
return ConstantTensorDescriptor_deprecated<Lengths, Strides>{};
}
template <index_t... Lengths, index_t... Strides>
__host__ __device__ void print_ConstantTensorDescriptor(
const char* s, ConstantTensorDescriptor_deprecated<Sequence<Lengths...>, Sequence<Strides...>>)
{
constexpr index_t ndim = sizeof...(Lengths);
static_assert(ndim > 0 && ndim <= 12, "wrong!");
static_if<ndim == 1>{}([&](auto) {
printf("%s dim %u, lengths {%u}, strides {%u}\n", s, ndim, Lengths..., Strides...);
});
static_if<ndim == 2>{}([&](auto) {
printf("%s dim %u, lengths {%u %u}, strides {%u %u}\n", s, ndim, Lengths..., Strides...);
});
static_if<ndim == 3>{}([&](auto) {
printf(
"%s dim %u, lengths {%u %u %u}, strides {%u %u %u}\n", s, ndim, Lengths..., Strides...);
});
static_if<ndim == 4>{}([&](auto) {
printf("%s dim %u, lengths {%u %u %u %u}, strides {%u %u %u %u}\n",
s,
ndim,
Lengths...,
Strides...);
});
static_if<ndim == 5>{}([&](auto) {
printf("%s dim %u, lengths {%u %u %u %u %u}, strides {%u %u %u %u %u}\n",
s,
ndim,
Lengths...,
Strides...);
});
static_if<ndim == 6>{}([&](auto) {
printf("%s dim %u, lengths {%u %u %u %u %u %u}, strides {%u %u %u %u %u %u}\n",
s,
ndim,
Lengths...,
Strides...);
});
static_if<ndim == 7>{}([&](auto) {
printf("%s dim %u, lengths {%u %u %u %u %u %u %u}, strides {%u %u %u %u %u %u %u}\n",
s,
ndim,
Lengths...,
Strides...);
});
static_if<ndim == 8>{}([&](auto) {
printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u}, strides {%u %u %u %u %u %u %u %u}\n",
s,
ndim,
Lengths...,
Strides...);
});
static_if<ndim == 9>{}([&](auto) {
printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u %u}, strides {%u %u %u %u %u %u %u %u "
"%u}\n",
s,
ndim,
Lengths...,
Strides...);
});
static_if<ndim == 10>{}([&](auto) {
printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u %u %u}, strides {%u %u %u %u %u %u %u "
"%u %u %u}\n",
s,
ndim,
Lengths...,
Strides...);
});
static_if<ndim == 11>{}([&](auto) {
printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u %u %u %u}, strides {%u %u %u %u %u %u "
"%u %u "
"%u %u %u}\n",
s,
ndim,
Lengths...,
Strides...);
});
static_if<ndim == 12>{}([&](auto) {
printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u %u %u %u %u}, strides {%u %u %u %u %u "
"%u %u %u %u "
"%u %u %u}\n",
s,
ndim,
Lengths...,
Strides...);
});
}
} // namespace ck
#endif
#ifndef CK_DIMENSION_HPP
#define CK_DIMENSION_HPP
#include "common_header.hpp"
namespace ck {
template <index_t Length, index_t Stride>
struct NativeDimension
{
__host__ __device__ static constexpr auto GetLength() { return Number<Length>{}; }
__host__ __device__ static constexpr auto GetStride() { return Number<Stride>{}; }
};
} // namespace ck
#endif
#ifndef CK_MULTI_INDEX_TRANSFORM_HPP
#define CK_MULTI_INDEX_TRANSFORM_HPP
#include "common_header.hpp"
namespace ck {
template <index_t N>
using MultiIndex = Array<index_t, N>;
template <typename... Xs>
__host__ __device__ constexpr auto make_multi_index(Xs... xs)
{
return MultiIndex<sizeof...(Xs)>(xs...);
}
template <index_t Length>
struct PassThrough
{
using LowerIndex = MultiIndex<1>;
using UpperIndex = MultiIndex<1>;
__host__ __device__ static constexpr auto GetNumOfLowerDimension() { return Number<1>{}; }
__host__ __device__ static constexpr auto GetNumOfUpperDimension() { return Number<1>{}; }
__host__ __device__ static constexpr auto GetUpperLengths() { return Sequence<Length>{}; }
__host__ __device__ static constexpr auto CalculateLowerIndex(const UpperIndex& idx_up)
{
return idx_up;
}
__host__ __device__ static constexpr auto
CalculateLowerIndexDiff(const UpperIndex& idx_up_diff,
const UpperIndex& /* idx_up_old */,
const LowerIndex& /* idx_low_old */)
{
return idx_up_diff;
}
__host__ __device__ static constexpr bool IsLinearTransform() { return true; }
__host__ __device__ static constexpr bool
IsUpperIndexMappedToValidLowerIndex(const UpperIndex& /* idx_up */)
{
return true;
}
};
// LowerLengths: Sequence<...>
template <typename LowerLengths, typename LeftPads, typename RightPads>
struct Pad
{
static constexpr index_t nDim = LowerLengths::Size();
using LowerIndex = MultiIndex<nDim>;
using UpperIndex = MultiIndex<nDim>;
__host__ __device__ static constexpr auto GetNumOfLowerDimension() { return Number<nDim>{}; }
__host__ __device__ static constexpr auto GetNumOfUpperDimension() { return Number<nDim>{}; }
__host__ __device__ static constexpr auto GetUpperLengths()
{
return LowerLengths{} + LeftPads{} + RightPads{};
}
__host__ __device__ static constexpr auto CalculateLowerIndex(const UpperIndex& idx_up)
{
return idx_up - LeftPads{};
}
__host__ __device__ static constexpr auto
CalculateLowerIndexDiff(const UpperIndex& idx_up_diff,
const UpperIndex& /* idx_up_old */,
const LowerIndex& /* idx_low_old */)
{
return idx_up_diff;
}
__host__ __device__ static constexpr bool IsLinearTransform() { return true; }
__host__ __device__ constexpr bool
IsUpperIndexMappedToValidLowerIndex(const UpperIndex& idx_up) const
{
#if 0
struct lambda_no_pad
{
__host__ __device__ constexpr bool operator()(index_t x) const { return x == 0; }
};
if(sequence_all_of(LeftPads{}, lambda_no_pad{}) &&
sequence_all_of(RightPads{}, lambda_no_pad{}))
{
return true;
}
else
#endif
{
bool flag = true;
static_for<0, nDim, 1>{}([&](auto idim) {
// only check if there is left-padding
static_if<(LeftPads::At(idim) != 0)>{}(
[&](auto) { flag = flag && idx_up[idim] >= LeftPads::At(idim); });
// only check if there is right-padding
static_if<(RightPads::At(idim) != 0)>{}([&](auto) {
flag = flag && (idx_up[idim] < LeftPads::At(idim) + LowerLengths::At(idim));
});
});
return flag;
}
}
};
// LowerLengths: Sequence<...>
template <typename LowerLengths>
struct Merge
{
static constexpr index_t nDimLow = LowerLengths::Size();
static constexpr index_t nDimUp = 1;
using LowerIndex = MultiIndex<nDimLow>;
using UpperIndex = MultiIndex<nDimUp>;
__host__ __device__ static constexpr auto GetNumOfLowerDimension() { return Number<nDimLow>{}; }
__host__ __device__ static constexpr auto GetNumOfUpperDimension() { return Number<nDimUp>{}; }
__host__ __device__ static constexpr auto GetUpperLengths()
{
return Sequence<reduce_on_sequence(
LowerLengths{}, math::multiplies<index_t>{}, Number<1>{})>{};
}
// emulate constexpr lambda
template <typename PseudoLowStrides>
struct lambda_CalculateLowerIndex
{
index_t& itmp;
LowerIndex& idx_low;
__host__ __device__ explicit constexpr lambda_CalculateLowerIndex(index_t& itmp_,
LowerIndex& idx_low_)
: itmp(itmp_), idx_low(idx_low_)
{
}
template <typename IDim>
__host__ __device__ constexpr void operator()(IDim idim) const
{
constexpr index_t stride = PseudoLowStrides::At(idim);
idx_low(idim) = itmp / stride;
itmp -= idx_low[idim] * stride;
}
};
__host__ __device__ static constexpr auto CalculateLowerIndex(const UpperIndex& idx_up)
{
LowerIndex idx_low;
index_t itmp = idx_up[0];
constexpr auto pseudo_low_strides =
reverse_inclusive_scan_sequence(
LowerLengths::PopFront(), math::multiplies<index_t>{}, Number<1>{})
.PushBack(Number<1>{});
static_for<0, nDimLow - 1, 1>{}(
lambda_CalculateLowerIndex<decltype(pseudo_low_strides)>(itmp, idx_low));
idx_low(nDimLow - 1) = itmp / pseudo_low_strides[nDimLow - 1];
return idx_low;
}
// idx_low_diff depends on idx_low_old, so idx_low need to be up-to-date
// If idx_up_diff is known at compile-time, many calculations can be optimized
// away by compiler
// This function assume idx_low_old is not out-of-bound
__host__ __device__ static constexpr auto
CalculateLowerIndexDiff(const UpperIndex& idx_up_diff,
const UpperIndex& /* idx_up_old */,
const LowerIndex& idx_low_old)
{
// do nothing if idx_up_diff == 0
if(idx_up_diff[0] == 0)
{
return make_zero_array<index_t, nDimLow>();
}
// CalculateLowerIndex(idx_up_diff) has multiple integer divisions.
// If idx_up_diff is known at compile-time, the calculation can
// be done at compile-time. However, if idx_up_diff is only known
// at run-time, then the calculation will also be computed at
// run-time, and can be very expensive.
LowerIndex idx_low_new = idx_low_old + CalculateLowerIndex(idx_up_diff);
if(idx_up_diff[0] > 0)
{
bool carry = false;
// do carry check in reversed order, starting from lowest dimension
// don't check the highest dimension
static_for<0, nDimLow - 1, 1>{}([&](auto ireverse) {
constexpr index_t i = nDimLow - 1 - ireverse;
if(carry)
{
++idx_low_new(i);
}
carry = false;
if(idx_low_new[i] >= LowerLengths::At(i))
{
idx_low_new(i) -= LowerLengths::At(i);
carry = true;
}
});
// highest dimension, no out-of-bound check
if(carry)
{
++idx_low_new(0);
}
}
else if(idx_up_diff[0] < 0)
{
bool borrow = false;
// do borrow check in reversed order, starting from lowest dimension
// don't check the highest dimension
static_for<0, nDimLow - 1, 1>{}([&](auto ireverse) {
constexpr index_t i = nDimLow - 1 - ireverse;
if(borrow)
{
--idx_low_new(i);
}
borrow = false;
if(idx_low_new[i] < 0)
{
idx_low_new(i) += LowerLengths::At(i);
borrow = true;
}
});
// highest dimension, no out-of-bound check
if(borrow)
{
--idx_low_new(0);
}
}
return idx_low_new - idx_low_old;
}
__host__ __device__ static constexpr bool IsLinearTransform() { return false; }
__host__ __device__ static constexpr bool
IsUpperIndexMappedToValidLowerIndex(const UpperIndex& /* idx_up */)
{
return true;
}
};
// UpperLengths: Sequence<...>
template <typename UpperLengths>
struct UnMerge
{
static constexpr index_t nDimLow = 1;
static constexpr index_t nDimUp = UpperLengths::Size();
using LowerIndex = MultiIndex<nDimLow>;
using UpperIndex = MultiIndex<nDimUp>;
__host__ __device__ static constexpr auto GetNumOfLowerDimension() { return Number<nDimLow>{}; }
__host__ __device__ static constexpr auto GetNumOfUpperDimension() { return Number<nDimUp>{}; }
__host__ __device__ static constexpr auto GetUpperLengths() { return UpperLengths{}; }
__host__ __device__ static constexpr auto CalculateLowerIndex(const UpperIndex& idx_up)
{
LowerIndex idx_low{0};
constexpr auto pseudo_up_strides =
reverse_inclusive_scan_sequence(
UpperLengths::PopFront(), math::multiplies<index_t>{}, Number<1>{})
.PushBack(Number<1>{});
static_for<0, nDimUp, 1>{}(
[&](auto idim) { idx_low(0) += idx_up[idim] * pseudo_up_strides[idim]; });
return idx_low;
}
__host__ __device__ static constexpr auto
CalculateLowerIndexDiff(const UpperIndex& idx_up_diff,
const UpperIndex& /* idx_up_old */,
const LowerIndex& /* idx_low_old */)
{
return CalculateLowerIndex(idx_up_diff);
}
__host__ __device__ static constexpr bool IsLinearTransform() { return true; }
__host__ __device__ static constexpr bool
IsUpperIndexMappedToValidLowerIndex(const UpperIndex& /* idx_up */)
{
return true;
}
};
// UpperLengths: Sequence<...>
// Coefficients: Sequence<...>
// idx_low = coefficients[0, ...nDimUp-1] * idx_up[0, ...nDimUp-1] + coefficients[nDimUp]
template <typename UpperLengths, typename Coefficients>
struct Embed
{
static constexpr index_t nDimLow = 1;
static constexpr index_t nDimUp = UpperLengths::Size();
using LowerIndex = MultiIndex<nDimLow>;
using UpperIndex = MultiIndex<nDimUp>;
__host__ __device__ explicit constexpr Embed()
{
static_assert(UpperLengths::GetSize() == nDimUp && Coefficients::GetSize() == nDimUp + 1,
"wrong! # of dimensions not consistent");
}
__host__ __device__ static constexpr auto GetNumOfUpperDimension() { return Number<nDimUp>{}; }
__host__ __device__ static constexpr auto GetNumOfLowerDimension() { return Number<nDimLow>{}; }
__host__ __device__ static constexpr auto GetUpperLengths() { return UpperLengths{}; }
__host__ __device__ static constexpr auto CalculateLowerIndex(const UpperIndex& idx_up)
{
LowerIndex idx_low(Coefficients{}[nDimUp]);
static_for<0, nDimUp, 1>{}(
[&](auto idim) { idx_low(0) += idx_up[idim] * Coefficients{}[idim]; });
return idx_low;
}
__host__ __device__ static constexpr auto
CalculateLowerIndexDiff(const UpperIndex& idx_up_diff,
const UpperIndex& /* idx_up_old */,
const LowerIndex& /* idx_low_old */)
{
LowerIndex idx_low_diff{0};
static_for<0, nDimUp, 1>{}(
[&](auto idim) { idx_low_diff(0) += idx_up_diff[idim] * Coefficients{}[idim]; });
return idx_low_diff;
}
__host__ __device__ static constexpr bool IsLinearTransform() { return true; }
__host__ __device__ static constexpr bool
IsUpperIndexMappedToValidLowerIndex(const UpperIndex& /* idx_up */)
{
return true;
}
};
template <index_t LowerLength, index_t VectorSize>
struct Vectorize
{
using LowerIndex = MultiIndex<1>;
using UpperIndex = MultiIndex<1>;
__host__ __device__ constexpr Vectorize()
{
static_assert(VectorSize > 0 && LowerLength % VectorSize == 0,
"wrong! cannot evenly divide");
}
__host__ __device__ static constexpr auto GetNumOfLowerDimension() { return Number<1>{}; }
__host__ __device__ static constexpr auto GetNumOfUpperDimension() { return Number<1>{}; }
__host__ __device__ static constexpr auto GetUpperLengths()
{
return Sequence<LowerLength / VectorSize>{};
}
__host__ __device__ static constexpr auto CalculateLowerIndex(const UpperIndex& idx_up)
{
return VectorSize * idx_up;
}
__host__ __device__ static constexpr auto
CalculateLowerIndexDiff(const UpperIndex& idx_up_diff,
const UpperIndex& /* idx_up_old */,
const LowerIndex& /* idx_low_old */)
{
return VectorSize * idx_up_diff;
}
__host__ __device__ static constexpr bool IsLinearTransform() { return true; }
__host__ __device__ static constexpr bool
IsUpperIndexMappedToValidLowerIndex(const UpperIndex& /* idx_up */)
{
return true;
}
};
} // namespace ck
#endif
...@@ -2,299 +2,248 @@ ...@@ -2,299 +2,248 @@
#define CK_TENSOR_COORDINATE_HPP #define CK_TENSOR_COORDINATE_HPP
#include "common_header.hpp" #include "common_header.hpp"
#include "ConstantTensorDescriptor.hpp" #include "dimension.hpp"
#include "ConstantMergedTensorDescriptor.hpp" #include "multi_index_transform.hpp"
#include "tensor_descriptor.hpp"
namespace ck { namespace ck {
template <class TensorDesc> // A "tensor cooridnate" is an opaque object that represents a "point of location" inside a tensor
struct NormalTensorCoordinate // At the bare minimun, user should be able to query the following information from a tensor
// coordinate:
// 1. Tensor descriptor
// 2. Location, represented in the form of multi-index
// 3. Location, represented in the form of the offset to the origin of the tensor
// 4. If the location is inside invalid area or not, i.e. the padding area of an implicitly padded
// tensor is considered invalid, because the padding area doesn't have any physical memory
// allocation
// A tensor cooridnate also provides following functionality:
// 1. Given step size in each dimension, update itself, or return a new tensor cooridnate, so user
// can freely move the "point of location" inside the tensor
// wrapper class for NativeTensorCoordinate and TransformedTensorCoordinate
template <typename TensorDesc>
struct TensorCoordinate;
// tensor coordinate for native tensor
template <typename NativeTensorDesc>
struct NativeTensorCoordinate
{ {
using type = NormalTensorCoordinate; using type = NativeTensorCoordinate;
using tensor_desc_type = TensorDesc; using tensor_desc_type = NativeTensorDesc;
static constexpr index_t nDim = tensor_desc_type::GetNumOfDimension(); static constexpr index_t nDim = tensor_desc_type::GetNumOfDimension();
using Index = MultiIndex<nDim>;
__host__ __device__ constexpr NormalTensorCoordinate(Array<index_t, nDim> tensor_index) __host__ __device__ constexpr NativeTensorCoordinate(Index idx)
: mOffset{tensor_desc_type::GetOffsetFromMultiIndex(tensor_index)} : mIndex(idx), mOffset(tensor_desc_type::CalculateOffset(idx))
{ {
} }
template <class... Xs> template <typename... Xs>
__host__ __device__ constexpr NormalTensorCoordinate(Xs... xs) __host__ __device__ constexpr NativeTensorCoordinate(Xs... xs)
: NormalTensorCoordinate(Array<index_t, nDim>{xs...}) : NativeTensorCoordinate(Index{xs...})
{ {
} }
__host__ __device__ constexpr index_t GetOffset() const { return mOffset; } template <index_t... Xs>
__host__ __device__ constexpr NativeTensorCoordinate(Sequence<Xs...>)
: NativeTensorCoordinate(Index{Xs...})
{
}
__host__ __device__ static constexpr auto GetTensorDescriptor() { return tensor_desc_type{}; }
__host__ __device__ constexpr const Index& GetIndex() const { return mIndex; }
// T is Array or Sequence __host__ __device__ constexpr const index_t& GetOffset() const { return mOffset; }
template <class T>
__host__ __device__ type operator+=(T step_sizes) __host__ __device__ constexpr type operator+=(const Index& idx_diff)
{ {
static_assert(is_same<typename T::data_type, index_t>{} && T::GetSize() == nDim, "wrong!"); // mIndex is updated here, but some (or all) of its entries may never be used
// compiler should remove those entries as dead code
mIndex += idx_diff;
mOffset += tensor_desc_type::GetOffsetFromMultiIndex(step_sizes); mOffset += tensor_desc_type::CalculateOffsetDiff(idx_diff);
return *this; return *this;
} }
template <class T> __host__ __device__ constexpr type operator-=(const Index& idx_diff)
__host__ __device__ type operator-=(T step_sizes)
{ {
static_assert(is_same<typename T::data_type, index_t>{} && T::GetSize() == nDim, "wrong!"); // mIndex is updated here, but some (or all) of its entries may never be used
// compiler should remove those entries as dead code
mIndex -= idx_diff;
mOffset -= tensor_desc_type::GetOffsetFromMultiIndex(step_sizes); mOffset -= tensor_desc_type::CalculateOffsetDiff(idx_diff);
return *this; return *this;
} }
template <class T> __host__ __device__ constexpr type operator+(const Index& idx_diff) const
__host__ __device__ constexpr type operator+(T step_sizes) const
{ {
type coord = *this; type coord = *this;
coord += step_sizes; coord += idx_diff;
return coord; return coord;
} }
template <class T> __host__ __device__ constexpr type operator-(const Index& idx_diff) const
__host__ __device__ constexpr type operator-(T step_sizes) const
{ {
type coord = *this; type coord = *this;
coord -= step_sizes; coord -= idx_diff;
return coord; return coord;
} }
// reposition point of origin, and return compensated offset. __host__ __device__ static constexpr index_t CalculateOffsetDiff(const Index& idx_diff)
// This is a hack to reduce index calculation during looping over
// a tensor whose origin is this TensorCoordinate. It does so, by spitting
// out the run-time offset to the pointer (to the tensor data) held by this
// TensorCoordiante, so the caller can add the offset into the run-time pointer of
// the data, so only 1 run-time variable (update pointer) is needed, instead
// of 2 run-time variables (old pointer and this offset)
// TODO: after introducing the concept of "run-time tensor view", which contains the
// run-time pointer to the data, always keep track of the pointer, instead of both
// offset and the pointer. This also bring additional benefit that we don't need to
// worry the offset might underflow (because offset is unsigned integer) when updating it.
__host__ __device__ constexpr index_t RepositOrigin()
{ {
index_t offset_diff = mOffset; return tensor_desc_type::CalculateOffsetDiff(idx_diff);
mOffset = 0;
return offset_diff;
} }
__host__ __device__ static constexpr bool IsUpperIndexMappedToValidOffset() { return true; }
private: private:
// mIndex may be saved and updated, however, the value of some (or all) of its entries may
// never be used. Compiler should be able to remove these entries as well as its calculation
// as dead code.
// TODO: make sure compiler indeed remove these dead code
Index mIndex;
index_t mOffset; index_t mOffset;
}; };
template <class TensorDesc> // tensor coordinate for transformed tensor
struct MergedTensorCoordinate template <typename TransformedTensorDesc>
struct TransformedTensorCoordinate
{ {
using type = MergedTensorCoordinate; using tensor_desc_type = TransformedTensorDesc;
using tensor_desc_type = TensorDesc; using LowerCoord =
typename TensorCoordinate<decltype(tensor_desc_type::GetLowerTensorDescriptor())>::type;
using UpperCoord = TransformedTensorCoordinate;
static constexpr index_t nDim = tensor_desc_type::GetNumOfDimension(); static constexpr index_t nDim = tensor_desc_type::GetNumOfDimension();
static constexpr index_t nOriginalDim = using UpperIndex = MultiIndex<nDim>;
tensor_desc_type::GetOriginalTensorDescriptor().GetNumOfDimension();
__host__ __device__ constexpr MergedTensorCoordinate(Array<index_t, nDim> tensor_index) __host__ __device__ constexpr TransformedTensorCoordinate(UpperIndex idx)
: mOriginalIndex{tensor_desc_type::GetOriginalMultiIndexFromMultiIndex(tensor_index)} : mIndexUp{idx}, mCoordLow{tensor_desc_type::CalculateLowerIndex(idx)}
{ {
// partial offset on each dimension }
static_for<0, nDim, 1>{}([&](auto idim) {
constexpr auto partial_original_dims =
tensor_desc_type::GetContainedOriginalDimensions(idim);
constexpr auto partial_original_desc =
tensor_desc_type::GetOriginalTensorDescriptor().Extract(partial_original_dims);
mPartialOffsets(idim) = partial_original_desc.GetOffsetFromMultiIndex(
extract_array(mOriginalIndex, partial_original_dims));
});
// complete offset template <typename... Xs>
mOffset = __host__ __device__ constexpr TransformedTensorCoordinate(Xs... xs)
accumulate_on_array(mPartialOffsets, math::plus<index_t>{}, static_cast<index_t>(0)); : TransformedTensorCoordinate(UpperIndex{xs...})
{
} }
template <class... Xs> template <index_t... Xs>
__host__ __device__ constexpr MergedTensorCoordinate(Xs... xs) __host__ __device__ constexpr TransformedTensorCoordinate(Sequence<Xs...>)
: MergedTensorCoordinate(Array<index_t, nDim>{xs...}) : TransformedTensorCoordinate(UpperIndex{Xs...})
{ {
} }
__host__ __device__ constexpr index_t GetOffset() const { return mOffset; } __host__ __device__ static constexpr auto GetTensorDescriptor() { return tensor_desc_type{}; }
__host__ __device__ constexpr const LowerCoord& GetLowerCoordinate() const { return mCoordLow; }
__host__ __device__ constexpr const UpperIndex& GetUpperIndex() const { return mIndexUp; }
template <class IDim, class T, bool PositiveDirection> __host__ __device__ constexpr const UpperIndex& GetIndex() const { return GetUpperIndex(); }
__host__ __device__ void
MoveOnDimension(IDim idim_, T step_size, integral_constant<bool, PositiveDirection>) __host__ __device__ constexpr const index_t& GetOffset() const
{ {
constexpr auto idim = idim_; return GetLowerCoordinate().GetOffset();
// if step_size is known at compile time
static_if<is_static<T>::value>{}(
[&](auto) { static_if<T{} == 0>{}([&](auto) { return; }); });
// update original index
static_if<tensor_desc_type::ContainMultipleOriginalDimensions(idim)>{}([&](auto) {
constexpr auto partial_original_dims =
tensor_desc_type::GetContainedOriginalDimensions(idim);
constexpr index_t ndim_partial_original = partial_original_dims.GetSize();
constexpr auto partial_original_desc =
tensor_desc_type::GetOriginalTensorDescriptor().Extract(partial_original_dims);
const auto partial_original_step_sizes =
partial_original_desc.GetMultiIndexFrom1dIndex(step_size);
// update partial original multi-id
auto partial_original_id = extract_array(mOriginalIndex, partial_original_dims);
static_if<PositiveDirection>{}([&](auto) {
partial_original_id += partial_original_step_sizes;
bool carry = false;
// do carry check in reversed order, starting from lowest dimension
// don't check the highest dimension
static_for<0, ndim_partial_original, 1>{}([&](auto IReverse) {
constexpr index_t i = ndim_partial_original - 1 - IReverse;
if(carry)
{
++partial_original_id(i);
}
carry = false;
if(partial_original_id[i] >= partial_original_desc.GetLength(i))
{
partial_original_id(i) -= partial_original_desc.GetLength(i);
carry = true;
}
});
}).Else([&](auto) {
// shift up multi-id to avoid unsigned integer underflow during intermediate
// calculations. After the shift, should have new_multi_id[...] >= 1
partial_original_id +=
partial_original_desc.GetLengths() - partial_original_step_sizes;
bool borrow = false;
// do borrow check in reversed order, starting from lowest dimension
// don't check the highest dimension
static_for<0, ndim_partial_original, 1>{}([&](auto IReverse) {
constexpr index_t i = ndim_partial_original - 1 - IReverse;
if(borrow)
{
--partial_original_id(i);
}
borrow = false;
if(partial_original_id[i] < partial_original_desc.GetLength(i))
{
partial_original_id(i) += partial_original_desc.GetLength(i);
borrow = true;
}
});
// shift back down multi-id
// here, should have new_multi_id[...] >= GetLengths()
partial_original_id = partial_original_id - partial_original_desc.GetLengths();
});
// update "mOriginalIndex"
static_for<0, ndim_partial_original, 1>{}([&](auto I) {
constexpr auto idim_original = partial_original_dims[I];
mOriginalIndex(idim_original) = partial_original_id[I];
});
// calculate new partial offset on this merged dimension
const index_t old_partial_offset = mPartialOffsets[idim];
mPartialOffsets(idim) =
partial_original_desc.GetOffsetFromMultiIndex(partial_original_id);
// update "mThreadSrcOffset", do "+" before "-" to avoid underflow
mOffset = (mOffset + mPartialOffsets[idim]) - old_partial_offset;
}).Else([&](auto fwd) {
static_if<PositiveDirection>{}([&](auto) {
mOffset += step_size * fwd(tensor_desc_type{}).GetStride(idim);
}).Else([&](auto) { mOffset -= step_size * fwd(tensor_desc_type{}).GetStride(idim); });
});
} }
// T is Array or Sequence __host__ __device__ constexpr UpperCoord operator+=(const UpperIndex& idx_up_diff)
template <class T>
__host__ __device__ type operator+=(T step_sizes)
{ {
static_assert(is_same<typename T::data_type, index_t>{} && T::GetSize() == nDim, // For transformation of multi-index difference, not all transformation functions need to
"wrong! the rank of step size doesn't match with that of tensor coordinate"); // know the old lower-index or the old upper-index. We pass both of them to the
// transformation function. The transformation function itself decides to use them or not.
mCoordLow += tensor_desc_type::CalculateLowerIndexDiff(
idx_up_diff, GetIndex(), GetLowerCoordinate().GetIndex());
static_for<0, nDim, 1>{}([&](auto idim) { // mIndexUp is updated here, but some (or all) of its entries may never be used
if(step_sizes[idim] != 0) // compiler should remove those entries as dead code
{ mIndexUp += idx_up_diff;
this->MoveOnDimension(idim, step_sizes[idim], integral_constant<bool, true>{});
}
});
return *this; return *this;
} }
template <class T> __host__ __device__ constexpr UpperCoord operator-=(const UpperIndex& idx_up_diff)
__host__ __device__ type operator-=(T step_sizes)
{ {
static_assert(is_same<typename T::data_type, index_t>{} && T::GetSize() == nDim, mCoordLow -= tensor_desc_type::CalculateLowerIndexDiff(
"wrong! the rank of step size doesn't match with that of tensor coordinate"); idx_up_diff, GetIndex(), GetLowerCoordinate().GetIndex());
static_for<0, nDim, 1>{}([&](auto idim) { // mIndex is updated here, but some (or all) of its entries may never be used
if(step_sizes[idim] != 0) // compiler should remove those entries as dead code
{ mIndexUp -= idx_up_diff;
this->MoveOnDimension(idim, step_sizes[idim], integral_constant<bool, false>{});
}
});
return *this; return *this;
} }
template <class T> __host__ __device__ constexpr UpperCoord operator+(const UpperIndex& idx_up_diff) const
__host__ __device__ constexpr type operator+(T step_sizes) const
{ {
type coord = *this; UpperCoord coord_up = *this;
coord += step_sizes; coord_up += idx_up_diff;
return coord; return coord_up;
} }
template <class T> __host__ __device__ constexpr UpperCoord operator-(const UpperIndex& idx_up_diff) const
__host__ __device__ constexpr type operator-(T step_sizes) const
{ {
type coord = *this; UpperCoord coord_up = *this;
coord -= step_sizes; coord_up -= idx_up_diff;
return coord; return coord_up;
} }
__host__ __device__ static constexpr index_t RepositOrigin() { return 0; } // Calculate offset diff without updating tensor-coordinate
// If idx_up_diff is know at compile time, and has only non-zero entries on linear dimensions,
// then all calculation can be done at compile-time.
// TODO: this function is not compiled to expected ISA
__host__ __device__ constexpr index_t CalculateOffsetDiff(const UpperIndex& idx_up_diff) const
{
// For transformation of multi-index difference, not all transformation functions need to
// know the old lower-index or the old upper-index. We pass both of them to the
// transformation function. The transformation function itself decides to use them or not.
const auto idx_low_diff = tensor_desc_type::CalculateLowerIndexDiff(
idx_up_diff, GetIndex(), GetLowerCoordinate().GetIndex());
return GetLowerCoordinate().CalculateOffsetDiff(idx_low_diff);
}
__host__ __device__ constexpr bool IsUpperIndexMappedToValidOffset() const
{
return tensor_desc_type::IsUpperIndexMappedToValidLowerIndex(GetIndex()) &&
mCoordLow.IsUpperIndexMappedToValidOffset();
}
private: private:
// Allocate register memory for all merged dimensions and normal dimensions. // mIndexUp may be calculated and updated, however, the value of some (or all) of its entries
// However, only those merged dimensions, whose index will be involved in arithmetic // may
// after the construction of this TensorCoordinate (e.g. when user move a slicing // never be used. Compiler should be able to remove these entries as well as its calculation
// window on the merged dimension), will use these register memory. // as dead code.
// Let's hope compiler will optimize away those register memory allocated for normal // TODO: make sure compiler indeed remove these dead code
// dimensions, and those merged dimensions, that would never be involved in index UpperIndex mIndexUp;
// arithmetic after construction of TensorCoordinate. LowerCoord mCoordLow;
// TODO: refactor TensorCoordinate, after introducing the concept of "dimensions" };
// and simplify implementation of ConstantMergedTensorDescriptor, so we don't need to
// count on compiler to optimize way those register memory for us template <typename TensorDesc>
Array<index_t, nOriginalDim> mOriginalIndex; struct TensorCoordinate
Array<index_t, nDim> mPartialOffsets; {
private:
// complete offset template <typename... Ts>
index_t mOffset; __host__ __device__ static constexpr auto
MakeDummyTensorCoordinate(NativeTensorDescriptor<Ts...>)
{
return NativeTensorCoordinate<NativeTensorDescriptor<Ts...>>(
make_zero_array<index_t, TensorDesc::GetNumOfDimension()>());
}
template <typename... Ts>
__host__ __device__ static constexpr auto
MakeDummyTensorCoordinate(TransformedTensorDescriptor<Ts...>)
{
return TransformedTensorCoordinate<TransformedTensorDescriptor<Ts...>>(
make_zero_array<index_t, TensorDesc::GetNumOfDimension()>());
}
public:
using type = decltype(MakeDummyTensorCoordinate(TensorDesc{}));
}; };
} // namespace ck } // namespace ck
......
#ifndef CK_TENSOR_COORDINATE_DEPRECATED_HPP
#define CK_TENSOR_COORDINATE_DEPRECATED_HPP
#include "common_header.hpp"
#include "ConstantTensorDescriptor_deprecated.hpp"
#include "ConstantMergedTensorDescriptor_deprecated.hpp"
namespace ck {
// TensorDesc is ConstantTensorDescriptor_deprecated
template <class TensorDesc>
struct NormalTensorCoordinate_deprecated
{
using type = NormalTensorCoordinate_deprecated;
using tensor_desc_type = TensorDesc;
static constexpr index_t nDim = tensor_desc_type::GetNumOfDimension();
__host__
__device__ constexpr NormalTensorCoordinate_deprecated(Array<index_t, nDim> tensor_index)
: mOffset{tensor_desc_type::GetOffsetFromMultiIndex(tensor_index)}
{
}
template <class... Xs>
__host__ __device__ constexpr NormalTensorCoordinate_deprecated(Xs... xs)
: NormalTensorCoordinate_deprecated(Array<index_t, nDim>{xs...})
{
}
template <index_t... Xs>
__host__ __device__ constexpr NormalTensorCoordinate_deprecated(Sequence<Xs...>)
: NormalTensorCoordinate_deprecated(Array<index_t, nDim>{Xs...})
{
}
__host__ __device__ constexpr index_t GetOffset() const { return mOffset; }
// T is Array or Sequence
template <class T>
__host__ __device__ type operator+=(T step_sizes)
{
static_assert(is_same<typename T::data_type, index_t>{} && T::GetSize() == nDim, "wrong!");
mOffset += tensor_desc_type::GetOffsetFromMultiIndex(step_sizes);
return *this;
}
template <class T>
__host__ __device__ type operator-=(T step_sizes)
{
static_assert(is_same<typename T::data_type, index_t>{} && T::GetSize() == nDim, "wrong!");
mOffset -= tensor_desc_type::GetOffsetFromMultiIndex(step_sizes);
return *this;
}
template <class T>
__host__ __device__ constexpr type operator+(T step_sizes) const
{
type coord = *this;
coord += step_sizes;
return coord;
}
template <class T>
__host__ __device__ constexpr type operator-(T step_sizes) const
{
type coord = *this;
coord -= step_sizes;
return coord;
}
// reposition point of origin, and return compensated offset.
// This is a hack to reduce index calculation during looping over
// a tensor whose origin is this TensorCoordinate. It does so, by spitting
// out the run-time offset to the pointer (to the tensor data) held by this
// TensorCoordiante, so the caller can add the offset into the run-time pointer of
// the data, so only 1 run-time variable (update pointer) is needed, instead
// of 2 run-time variables (old pointer and this offset)
// TODO: after introducing the concept of "run-time tensor view", which contains the
// run-time pointer to the data, always keep track of the pointer, instead of both
// offset and the pointer. This also bring additional benefit that we don't need to
// worry the offset might underflow (because offset is unsigned integer) when updating it.
__host__ __device__ constexpr index_t RepositionOrigin()
{
index_t offset_diff = mOffset;
mOffset = 0;
return offset_diff;
}
private:
index_t mOffset;
};
// TensorDesc is ConstantMergedTensorDescriptor_deprecated
template <class TensorDesc>
struct MergedTensorCoordinate_deprecated
{
using type = MergedTensorCoordinate_deprecated;
using tensor_desc_type = TensorDesc;
static constexpr index_t nDim = tensor_desc_type::GetNumOfDimension();
static constexpr index_t nOriginalDim =
tensor_desc_type::GetOriginalTensorDescriptor().GetNumOfDimension();
__host__
__device__ constexpr MergedTensorCoordinate_deprecated(Array<index_t, nDim> tensor_index)
: mOriginalIndex{tensor_desc_type::GetOriginalMultiIndexFromMultiIndex(tensor_index)}
{
// partial offset on each dimension
static_for<0, nDim, 1>{}([&](auto idim) {
constexpr auto partial_original_dims =
tensor_desc_type::GetContainedOriginalDimensions(idim);
constexpr auto partial_original_desc =
tensor_desc_type::GetOriginalTensorDescriptor().Extract(partial_original_dims);
mPartialOffsets(idim) = partial_original_desc.GetOffsetFromMultiIndex(
extract_array(mOriginalIndex, partial_original_dims));
});
// complete offset
mOffset =
accumulate_on_array(mPartialOffsets, math::plus<index_t>{}, static_cast<index_t>(0));
}
template <class... Xs>
__host__ __device__ constexpr MergedTensorCoordinate_deprecated(Xs... xs)
: MergedTensorCoordinate_deprecated(Array<index_t, nDim>{xs...})
{
}
__host__ __device__ constexpr index_t GetOffset() const { return mOffset; }
template <class IDim, class T, bool PositiveDirection>
__host__ __device__ void
MoveOnDimension(IDim idim_, T step_size, integral_constant<bool, PositiveDirection>)
{
constexpr auto idim = idim_;
// if step_size is known at compile time
static_if<is_static<T>::value>{}(
[&](auto) { static_if<T{} == 0>{}([&](auto) { return; }); });
// update original index
static_if<tensor_desc_type::ContainMultipleOriginalDimensions(idim)>{}([&](auto) {
constexpr auto partial_original_dims =
tensor_desc_type::GetContainedOriginalDimensions(idim);
constexpr index_t ndim_partial_original = partial_original_dims.GetSize();
constexpr auto partial_original_desc =
tensor_desc_type::GetOriginalTensorDescriptor().Extract(partial_original_dims);
const auto partial_original_step_sizes =
partial_original_desc.GetMultiIndexFrom1dIndex(step_size);
// update partial original multi-id
auto partial_original_id = extract_array(mOriginalIndex, partial_original_dims);
static_if<PositiveDirection>{}([&](auto) {
partial_original_id += partial_original_step_sizes;
bool carry = false;
// do carry check in reversed order, starting from lowest dimension
// don't check the highest dimension
static_for<0, ndim_partial_original - 1, 1>{}([&](auto IReverse) {
constexpr index_t i = ndim_partial_original - 1 - IReverse;
if(carry)
{
++partial_original_id(i);
}
carry = false;
if(partial_original_id[i] >= partial_original_desc.GetLength(i))
{
partial_original_id(i) -= partial_original_desc.GetLength(i);
carry = true;
}
});
// highest dimension
if(carry)
{
++partial_original_id(0);
}
}).Else([&](auto) {
// shift up multi-id to avoid unsigned integer underflow during intermediate
// calculations. After the shift, should have new_multi_id[...] >= 1
partial_original_id +=
partial_original_desc.GetLengths() - partial_original_step_sizes;
bool borrow = false;
// do borrow check in reversed order, starting from lowest dimension
// don't check the highest dimension
static_for<0, ndim_partial_original - 1, 1>{}([&](auto IReverse) {
constexpr index_t i = ndim_partial_original - 1 - IReverse;
if(borrow)
{
--partial_original_id(i);
}
borrow = false;
if(partial_original_id[i] < partial_original_desc.GetLength(i))
{
partial_original_id(i) += partial_original_desc.GetLength(i);
borrow = true;
}
});
// highest dimension
if(borrow)
{
--partial_original_id(0);
}
// shift back down multi-id
// here, should have new_multi_id[...] >= GetLengths()
partial_original_id = partial_original_id - partial_original_desc.GetLengths();
});
// update "mOriginalIndex"
static_for<0, ndim_partial_original, 1>{}([&](auto I) {
constexpr auto idim_original = partial_original_dims[I];
mOriginalIndex(idim_original) = partial_original_id[I];
});
// calculate new partial offset on this merged dimension
const index_t old_partial_offset = mPartialOffsets[idim];
mPartialOffsets(idim) =
partial_original_desc.GetOffsetFromMultiIndex(partial_original_id);
// update "mThreadSrcOffset", do "+" before "-" to avoid underflow
mOffset = (mOffset + mPartialOffsets[idim]) - old_partial_offset;
}).Else([&](auto fwd) {
static_if<PositiveDirection>{}([&](auto) {
mOffset += step_size * fwd(tensor_desc_type{}).GetStride(idim);
}).Else([&](auto) { mOffset -= step_size * fwd(tensor_desc_type{}).GetStride(idim); });
});
}
// T is Array or Sequence
template <class T>
__host__ __device__ type operator+=(T step_sizes)
{
static_assert(is_same<typename T::data_type, index_t>{} && T::GetSize() == nDim, "wrong!");
static_for<0, nDim, 1>{}([&](auto idim) {
// compiler should remove dead code path, because step_sizes is known at
// compile time
if(step_sizes[idim] != 0)
{
this->MoveOnDimension(idim, step_sizes[idim], integral_constant<bool, true>{});
}
});
return *this;
}
template <class T>
__host__ __device__ type operator-=(T step_sizes)
{
static_assert(is_same<typename T::data_type, index_t>{} && T::GetSize() == nDim, "wrong!");
static_for<0, nDim, 1>{}([&](auto idim) {
// compiler should remove dead code path, because step_sizes is known at
// compile time
if(step_sizes[idim] != 0)
{
this->MoveOnDimension(idim, step_sizes[idim], integral_constant<bool, false>{});
}
});
return *this;
}
template <class T>
__host__ __device__ constexpr type operator+(T step_sizes) const
{
type coord = *this;
coord += step_sizes;
return coord;
}
template <class T>
__host__ __device__ constexpr type operator-(T step_sizes) const
{
type coord = *this;
coord -= step_sizes;
return coord;
}
__host__ __device__ static constexpr index_t RepositionOrigin() { return 0; }
private:
// Allocate register memory for all merged dimensions and normal dimensions.
// However, only those merged dimensions, whose index will be involved in arithmetic
// after the construction of this TensorCoordinate (e.g. when user move a slicing
// window on the merged dimension), will use these register memory.
// Let's hope compiler will optimize away those register memory allocated for normal
// dimensions, and those merged dimensions, that would never be involved in index
// arithmetic after construction of TensorCoordinate.
// TODO: refactor TensorCoordinate, after introducing the concept of "dimensions"
// and simplify implementation of ConstantMergedTensorDescriptor_deprecated, so we don't need to
// count on compiler to optimize away those register memory for us
Array<index_t, nOriginalDim> mOriginalIndex;
Array<index_t, nDim> mPartialOffsets;
// complete offset
index_t mOffset;
};
template <class TensorDesc>
struct TensorCoordinate_deprecated
{
private:
template <class... Ts>
__host__ __device__ static constexpr auto
MakeDummyTensorCoordinate(ConstantTensorDescriptor_deprecated<Ts...>)
{
return NormalTensorCoordinate_deprecated<ConstantTensorDescriptor_deprecated<Ts...>>();
}
template <class... Ts>
__host__ __device__ static constexpr auto
MakeDummyTensorCoordinate(ConstantMergedTensorDescriptor_deprecated<Ts...>)
{
return MergedTensorCoordinate_deprecated<
ConstantMergedTensorDescriptor_deprecated<Ts...>>();
}
public:
using type = decltype(MakeDummyTensorCoordinate(TensorDesc{}));
};
} // namespace ck
#endif
#ifndef CK_TENSOR_COORDINATE_HELPER_HPP
#define CK_TENSOR_COORDINATE_HELPER_HPP
#include "tensor_coordiante_hpp"
namespace ck {
template <typename TensorDesc>
__host__ __device__ constexpr auto
make_tensor_coordinate(TensorDesc, MultiIndex<TensorDesc::GetNumOfDimension()> idx)
{
return typename TensorCoordinate<TensorDesc>::type(idx);
}
} // namespace ck
#endif
#ifndef CK_TENSOR_DESCRIPTOR_HPP
#define CK_TENSOR_DESCRIPTOR_HPP
#include "common_header.hpp"
#include "dimension.hpp"
#include "multi_index_transform.hpp"
namespace ck {
// tensor descriptor for "native tensor"
// A "native tensor" is a "true" tensor that can be represented by Lengths and Strides
template <typename... NativeDimensions>
struct NativeTensorDescriptor
{
using type = NativeTensorDescriptor;
static constexpr index_t nDim = sizeof...(NativeDimensions);
static constexpr auto mDimensions = make_tuple(NativeDimensions{}...);
using Index = MultiIndex<nDim>;
__host__ __device__ static constexpr auto GetNumOfDimension() { return Number<nDim>{}; }
template <index_t IDim>
__host__ __device__ static constexpr auto GetLength(Number<IDim>)
{
return mDimensions.At(Number<IDim>{}).GetLength();
}
template <index_t IDim>
__host__ __device__ static constexpr auto GetStride(Number<IDim>)
{
return mDimensions.At(Number<IDim>{}).GetStride();
}
template <index_t... IDims>
__host__ __device__ static constexpr auto GetLengths(Sequence<IDims...>)
{
return Sequence<GetLength(Number<IDims>{})...>{};
}
template <index_t... IDims>
__host__ __device__ static constexpr auto GetStrides(Sequence<IDims...>)
{
return Sequence<GetStride(Number<IDims>{})...>{};
}
template <index_t IDim, index_t... IDims>
__host__ __device__ static constexpr auto GetLengths(Number<IDim>, Number<IDims>...)
{
return GetLengths(Sequence<IDim, IDims...>{});
}
template <index_t IDim, index_t... IDims>
__host__ __device__ static constexpr auto GetStrides(Number<IDim>, Number<IDims>...)
{
return GetStrides(Sequence<IDim, IDims...>{});
}
__host__ __device__ static constexpr auto GetLengths()
{
return GetLengths(typename arithmetic_sequence_gen<0, nDim, 1>::type{});
}
__host__ __device__ static constexpr auto GetStrides()
{
return GetStrides(typename arithmetic_sequence_gen<0, nDim, 1>::type{});
}
__host__ __device__ static constexpr index_t GetElementSize()
{
return reduce_on_sequence(GetLengths(), math::multiplies<index_t>{}, Number<1>{});
}
__host__ __device__ static constexpr index_t GetElementSpace()
{
return reduce_on_sequence(
(GetLengths() - Number<1>{}) * GetStrides(), math::plus<index_t>{}, Number<1>{});
}
// TODO: this cannot return constepxr because of use of lambda
__host__ __device__ static constexpr index_t CalculateOffset(const Index& idx)
{
index_t offset = 0;
static_for<0, nDim, 1>{}([&](auto idim) { offset += idx[idim] * GetStride(idim); });
return offset;
}
__host__ __device__ static constexpr index_t CalculateOffsetDiff(const Index& idx_diff)
{
index_t offset_diff = 0;
static_for<0, nDim, 1>{}(
[&](auto idim) { offset_diff += idx_diff[idim] * GetStride(idim); });
return offset_diff;
}
template <index_t IDim>
__host__ __device__ static constexpr bool IsLinearDimension(Number<IDim>)
{
return true;
}
__host__ __device__ static constexpr auto GetLinearDimensionMask()
{
return typename uniform_sequence_gen<nDim, 1>::type{};
}
__host__ __device__ static constexpr auto GetNonLinearDimensionMask()
{
return typename uniform_sequence_gen<nDim, 0>::type{};
}
__host__ __device__ static constexpr auto GetNonLinearDimensions() { return Sequence<>{}; }
__host__ __device__ static constexpr auto GetNonLinearIndependentDimensionGroups()
{
return Tuple<>{};
}
__host__ __device__ static constexpr bool
IsUpperIndexMappedToValidOffset(const Index& /* idx */)
{
return true;
}
};
// Tensor descriptor for "transformed tensor"
template <typename LowTensorDescriptor, // NativeTensorDescriptor or TransformedTensorDescriptor
typename Transforms, // Tuple<MultIndexTransforms...>
typename LowDimensionIds, // Tuple<Sequence<...>>
typename UpDimensionIds> // Tuple<Sequence<...>>
struct TransformedTensorDescriptor
{
using type = TransformedTensorDescriptor;
static constexpr index_t nTransform = Transforms::Size();
struct lambda_merge_sequences
{
template <typename... Seqs>
__host__ __device__ constexpr auto operator()(Seqs... seqs) const
{
return merge_sequences(seqs...);
}
};
__host__ __device__ static constexpr auto GetNumOfLowerDimension()
{
// Here, we assume all lower-dimensions are active
// TODO: sanity-check all lower-dimension are indeed active
using duplicated_low_active_dims =
decltype(unpack(lambda_merge_sequences{}, LowDimensionIds{}));
using low_active_dims = typename sequence_unique_sort<duplicated_low_active_dims,
math::less<index_t>,
math::equal<index_t>>::type;
return low_active_dims::Size();
}
__host__ __device__ static constexpr auto GetNumOfUpperDimension()
{
using duplicated_up_active_dims =
decltype(unpack(lambda_merge_sequences{}, UpDimensionIds{}));
using up_active_dims = typename sequence_unique_sort<duplicated_up_active_dims,
math::less<index_t>,
math::equal<index_t>>::type;
return up_active_dims::Size();
}
static constexpr index_t nDimUp = GetNumOfUpperDimension();
static constexpr index_t nDimLow = GetNumOfLowerDimension();
using UpperIndex = MultiIndex<nDimUp>;
using LowerIndex = MultiIndex<nDimLow>;
__host__ __device__ constexpr TransformedTensorDescriptor()
{
static_assert(nTransform == Transforms::Size() && nTransform == LowDimensionIds::Size() &&
nTransform == UpDimensionIds::Size(),
"wrong! # of transformations not the same");
// sanity check:
// LowDimensionIds should include all low-dimensions,
// UpDimensionIds should include all up-dimensions
using mingled_up_dimension_ids =
decltype(unpack(lambda_merge_sequences{}, UpDimensionIds{}));
using sorted_up_dimension_ids =
typename sequence_sort<mingled_up_dimension_ids, math::less<index_t>>::type;
static_assert(sorted_up_dimension_ids::Size() == nDimUp &&
is_valid_sequence_map<sorted_up_dimension_ids>{},
"wrong! UpDimensionIds is not configured correctly");
using mingled_low_dimension_ids =
decltype(unpack(lambda_merge_sequences{}, LowDimensionIds{}));
using sorted_low_dimension_ids =
typename sequence_sort<mingled_low_dimension_ids, math::less<index_t>>::type;
static_assert(sorted_low_dimension_ids::Size() == nDimLow &&
is_valid_sequence_map<sorted_low_dimension_ids>{},
"wrong! LowDimensionIds is not configured correctly");
// TODO: sanity check: while a up-dimension could be associated with multille
// transformation, a low-dimension should be associated with only one transformation
// TODO: sanity-check: GetLowerLengths of each transform should be consistent with lengths
// of lower-tensor-descriptor
}
__host__ __device__ static constexpr auto GetNumOfDimension()
{
return GetNumOfUpperDimension();
}
__host__ __device__ static constexpr auto GetLowerTensorDescriptor()
{
return LowTensorDescriptor{};
}
struct lambda_GetUpperLengths
{
template <typename Transform>
__host__ __device__ constexpr auto operator()(const Transform& tran) const
{
return tran.GetUpperLengths();
}
};
__host__ __device__ static constexpr auto GetUpperLengths()
{
constexpr auto tuple_of_up_lengths =
transform_tuples(lambda_GetUpperLengths{}, Transforms{});
constexpr auto mingled_up_lengths = unpack(lambda_merge_sequences{}, tuple_of_up_lengths);
constexpr auto mingled_up_dimension_ids =
unpack(lambda_merge_sequences{}, UpDimensionIds{});
// TODO: sanity-check mingled_up_dimension_ids contain all upper-dimensions
// TODO: sanity-check mingled_up_lengths have no conflicting upper-length
// sort by upper-dimension-ids
using sort_up_dimension_ids = sequence_unique_sort<decltype(mingled_up_dimension_ids),
math::less<index_t>,
math::equal<index_t>>;
// sanity-check sorted-upper-dimension-ids should be Sequence<0, 1, ... nDimUp-1>
static_assert(is_same<typename sort_up_dimension_ids::type,
typename arithmetic_sequence_gen<0, nDimUp, 1>::type>{},
"wrong! UpDimensionIds is not configured correctly");
constexpr auto sorted2unsorted_map = typename sort_up_dimension_ids::sorted2unsorted_map{};
constexpr auto sorted_up_lengths =
pick_sequence_elements_by_ids(mingled_up_lengths, sorted2unsorted_map);
return sorted_up_lengths;
}
__host__ __device__ static constexpr auto GetLengths() { return GetUpperLengths(); }
template <index_t IDim>
__host__ __device__ static constexpr auto GetLength(Number<IDim>)
{
return GetLengths()[IDim];
}
template <index_t... IDims>
__host__ __device__ static constexpr auto GetLengths(Sequence<IDims...>)
{
return Sequence<GetLength(Number<IDims>{})...>{};
}
template <index_t IDim, index_t... IDims>
__host__ __device__ static constexpr auto GetLengths(Number<IDim>, Number<IDims>...)
{
return GetLengths(Sequence<IDim, IDims...>{});
}
__host__ __device__ static constexpr index_t GetElementSize()
{
return reduce_on_sequence(GetLengths(), math::multiplies<index_t>{}, Number<1>{});
}
__host__ __device__ static constexpr index_t GetElementSpace()
{
// TODO: Is this the correct definition for transformed tensor?
return GetLowerTensorDescriptor().GetElementSpace();
}
// TODO: right now return value is not constexpr because use of non-constexpr lambda
__host__ __device__ static constexpr LowerIndex CalculateLowerIndex(const UpperIndex& idx_up)
{
LowerIndex idx_low;
static_for<0, nTransform, 1>{}([&](auto itran) {
constexpr auto tran = Transforms{}.At(itran);
const auto idx_up_part = pick_array_element(idx_up, UpDimensionIds{}.At(itran));
auto idx_low_part = pick_array_element(idx_low, LowDimensionIds{}.At(itran));
// this assume each lower (single) index is only assocaited with one transformation,
// which is required for index transformation, and has been checked during constructor
// of TransformedTensorDescriptor
idx_low_part = tran.CalculateLowerIndex(to_array(idx_up_part));
});
return idx_low;
}
// TODO: right now return value is not constexpr because use of non-constepxr lambda
__host__ __device__ static constexpr LowerIndex CalculateLowerIndexDiff(
const UpperIndex& idx_up_diff, const UpperIndex& idx_up_old, const LowerIndex& idx_low_old)
{
LowerIndex idx_low_diff;
static_for<0, nTransform, 1>{}([&](auto itran) {
constexpr auto tran = Transforms{}.At(itran);
const auto idx_up_diff_part =
pick_array_element(idx_up_diff, UpDimensionIds{}.At(itran));
const auto idx_up_old_part = pick_array_element(idx_up_old, UpDimensionIds{}.At(itran));
const auto idx_low_old_part =
pick_array_element(idx_low_old, LowDimensionIds{}.At(itran));
auto idx_low_diff_part = pick_array_element(idx_low_diff, LowDimensionIds{}.At(itran));
// this assume each lower (single) index is associated with only one transformation,
// which is required for index transformation, and has been checked during constructor
// of TransformedTensorDescriptor
idx_low_diff_part = tran.CalculateLowerIndexDiff(
to_array(idx_up_diff_part), to_array(idx_up_old_part), to_array(idx_low_old_part));
});
return idx_low_diff;
}
__host__ __device__ static constexpr index_t CalculateOffset(const UpperIndex& idx_up)
{
return GetLowerTensorDescriptor().CalculateOffset(CalculateLowerIndex(idx_up));
}
struct lambda_sequence_logical_and
{
template <typename... Seqs>
__host__ __device__ constexpr auto operator()(Seqs...) const
{
return typename sequence_reduce<logical_and<index_t>, Seqs...>::type{};
}
};
template <typename T>
struct lambda_is_true
{
__host__ __device__ constexpr auto operator()(const T& x) const
{
// TODO: remove static_cast once Sequence can take bool as entries
return static_cast<bool>(x) == true;
}
};
struct lambda_get_linear_dimension_mask_of_single_tranform
{
// check only one transform at a time
template <typename Transform, typename LowDimensionId, typename UpDimensionId>
__host__ __device__ constexpr auto
operator()(Transform, LowDimensionId, UpDimensionId) const
{
// judge if transformation is linear
constexpr bool is_linear_transform = Transform::IsLinearTransform();
// judge if all lower dimension are linear
constexpr bool are_all_low_dim_linear = sequence_all_of(
pick_sequence_elements_by_ids(GetLowerTensorDescriptor().GetLinearDimensionMask(),
LowDimensionId{}),
lambda_is_true<index_t>{});
// create linear mask for upper dimensions
constexpr bool are_up_dim_linear = is_linear_transform && are_all_low_dim_linear;
constexpr auto mask_of_up_linear_dims = modify_sequence_elements_by_ids(
typename uniform_sequence_gen<nDimUp, 1>::type{},
typename uniform_sequence_gen<UpDimensionId::Size(), are_up_dim_linear>::type{},
UpDimensionId{});
return mask_of_up_linear_dims;
}
};
// TODO: this is a hack, transform_tuples() doesn't compile, would complain about constexpr
template <typename F, typename X, typename Y, typename Z, index_t... Is>
__host__ __device__ static constexpr auto
dummy_transform_tuples_impl(F f, X x, Y y, Z z, Sequence<Is...>)
{
return make_tuple(f(x.At(Number<Is>{}), y.At(Number<Is>{}), z.At(Number<Is>{}))...);
}
__host__ __device__ static constexpr auto GetLinearDimensionMask()
{
#if 0
// create tuple of linear dimension masks, for all transformations
// TODO: this doesn't compile, because transform_tuples() complain about constexpr
constexpr auto tuple_of_linear_dimension_mask =
transform_tuples(lambda_get_linear_dimension_mask_of_single_tranform{},
Transforms{},
LowDimensionIds{},
UpDimensionIds{});
#else
// create tuple of linear dimension masks, for all transformations
// TODO: this is a hack
constexpr auto tuple_of_linear_dimension_mask = dummy_transform_tuples_impl(
lambda_get_linear_dimension_mask_of_single_tranform{},
Transforms{},
LowDimensionIds{},
UpDimensionIds{},
typename arithmetic_sequence_gen<0, Transforms::Size(), 1>::type{});
#endif
// reduce tuple of masks into one mask
constexpr auto linear_dimension_mask =
unpack(lambda_sequence_logical_and{}, tuple_of_linear_dimension_mask);
return linear_dimension_mask;
}
__host__ __device__ static constexpr auto GetNonLinearDimensionMask()
{
return GetLinearDimensionMask().Transform(logical_not<index_t>{});
}
template <index_t IDim>
__host__ __device__ static constexpr bool IsLinearDimension(Number<IDim>)
{
return GetLinearDimensionMask().At(Number<IDim>{});
}
__host__ __device__ static constexpr auto GetLinearDimensions()
{
constexpr auto linear_dimension_mask = GetLinearDimensionMask();
return pick_sequence_elements_by_mask(
typename arithmetic_sequence_gen<0, nDimUp, 1>::type{}, linear_dimension_mask);
}
__host__ __device__ static constexpr auto GetNonLinearDimensions()
{
constexpr auto nonlinear_dimension_mask = GetNonLinearDimensionMask();
return pick_sequence_elements_by_mask(
typename arithmetic_sequence_gen<0, nDimUp, 1>::type{}, nonlinear_dimension_mask);
}
#if 0
__host__ __device__ static constexpr auto GetNonLinearIndependentDimensionGroups()
{
// TODO: not implemented
}
#endif
__host__ __device__ static constexpr bool
IsUpperIndexMappedToValidLowerIndex(const UpperIndex& idx_up)
{
bool flag = true;
static_for<0, nTransform, 1>{}([&](auto itran) {
constexpr auto tran = Transforms{}.At(itran);
const auto idx_up_part = pick_array_element(idx_up, UpDimensionIds{}.At(itran));
flag = flag && tran.IsUpperIndexMappedToValidLowerIndex(to_array(idx_up_part));
});
return flag;
}
// Whenever this function is called, it will call CalculateLowerIndex() recursively.
// If you have created a tensor coordinate already, instead of calling this function,
// you should call TensorCoordinate::IsUpperIndexMappedToValidOffset() which would
// be less expensive.
__host__ __device__ static constexpr bool
IsUpperIndexMappedToValidOffset(const UpperIndex& idx_up)
{
return IsUpperIndexMappedToValidLowerIndex(idx_up) &&
GetLowerTensorDescriptor().IsUpperIndexMappedToValidOffset(
CalculateLowerIndex(idx_up));
}
};
} // namespace ck
#endif
#ifndef CK_TENSOR_DESCRIPTOR_HELPER_HPP
#define CK_TENSOR_DESCRIPTOR_HELPER_HPP
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
namespace ck {
template <typename Lengths>
__host__ __device__ constexpr auto calculate_tensor_strides_packed(Lengths)
{
return reverse_inclusive_scan_sequence(
Lengths{}.PopFront(), math::multiplies<index_t>{}, Number<1>{})
.PushBack(Number<1>{});
}
template <typename Lengths, index_t Align>
__host__ __device__ constexpr auto calculate_tensor_strides_aligned(Lengths, Number<Align>)
{
constexpr index_t L_back_align =
Align * math::integer_divide_ceiler<index_t>{}(Lengths{}.Back(), Align);
return calculate_tensor_strides_packed(
Lengths{}.Modify(Number<Lengths{}.GetSize() - 1>{}, Number<L_back_align>{}));
}
template <index_t... Lengths, index_t... Strides>
__host__ __device__ constexpr auto make_native_tensor_descriptor(Sequence<Lengths...>,
Sequence<Strides...>)
{
return NativeTensorDescriptor<NativeDimension<Lengths, Strides>...>{};
}
template <typename Lengths>
__host__ __device__ constexpr auto make_native_tensor_descriptor_packed(Lengths)
{
constexpr auto strides = calculate_tensor_strides_packed(Lengths{});
return make_native_tensor_descriptor(Lengths{}, strides);
}
template <typename Lengths, index_t Align>
__host__ __device__ constexpr auto make_native_tensor_descriptor_aligned(Lengths, Number<Align>)
{
constexpr auto strides = calculate_tensor_strides_aligned(Lengths{}, Number<Align>{});
return make_native_tensor_descriptor(Lengths{}, strides);
}
template <typename LowTensorDescriptor,
typename Transforms,
typename LowDimensionIds,
typename UpDimensionIds>
__host__ __device__ constexpr auto
transform_tensor_descriptor(LowTensorDescriptor, Transforms, LowDimensionIds, UpDimensionIds)
{
return TransformedTensorDescriptor<LowTensorDescriptor,
Transforms,
LowDimensionIds,
UpDimensionIds>{};
}
template <typename LowerTensorDescriptor,
index_t... LowerLengths,
index_t... LowerDimensionIds,
index_t... UpperDimensionIds>
__host__ __device__ constexpr auto
reorder_transformed_tensor_descriptor_impl(LowerTensorDescriptor,
Sequence<LowerLengths...>,
Sequence<LowerDimensionIds...>,
Sequence<UpperDimensionIds...>)
{
return TransformedTensorDescriptor<LowerTensorDescriptor,
Tuple<PassThrough<LowerLengths>...>,
Tuple<Sequence<LowerDimensionIds>...>,
Tuple<Sequence<UpperDimensionIds>...>>{};
}
// reorder a NativeTensorDescriptor
template <typename... Ts, typename MapLower2Upper>
__host__ __device__ constexpr auto
reorder_tensor_descriptor_given_lower2upper(NativeTensorDescriptor<Ts...>, MapLower2Upper)
{
static_assert(is_valid_sequence_map<MapLower2Upper>{},
"wrong! MapLower2Upper is not a valid map");
constexpr auto old_desc = NativeTensorDescriptor<Ts...>{};
static_assert(old_desc.GetNumOfDimension() == MapLower2Upper::Size(), "wrong!");
constexpr auto new_lengths = old_desc.GetLengths().ReorderGivenOld2New(MapLower2Upper{});
constexpr auto new_strides = old_desc.GetStrides().ReorderGivenOld2New(MapLower2Upper{});
return make_native_tensor_descriptor(new_lengths, new_strides);
}
// reorder a TransformedTensorDescriptor
template <typename... Ts, typename MapLower2Upper>
__host__ __device__ constexpr auto
reorder_tensor_descriptor_given_lower2upper(TransformedTensorDescriptor<Ts...>, MapLower2Upper)
{
static_assert(is_valid_sequence_map<MapLower2Upper>{},
"wrong! MapLower2Upper is not a valid map");
constexpr auto low_desc = TransformedTensorDescriptor<Ts...>{};
static_assert(low_desc.GetNumOfDimension() == MapLower2Upper::Size(), "wrong!");
return reorder_transformed_tensor_descriptor_impl(
low_desc,
low_desc.GetLengths(),
typename arithmetic_sequence_gen<0, low_desc.GetNumOfDimension(), 1>::type{},
MapLower2Upper{});
}
template <typename LowerTensorDescriptor, typename MapUpper2Lower>
__host__ __device__ constexpr auto
reorder_tensor_descriptor_given_upper2lower(LowerTensorDescriptor, MapUpper2Lower)
{
return reorder_tensor_descriptor_given_lower2upper(
LowerTensorDescriptor{}, typename sequence_map_inverse<MapUpper2Lower>::type{});
}
template <typename Lengths, typename Strides>
__host__ __device__ constexpr bool are_dimensions_unfoldable(Lengths, Strides)
{
static_assert(Lengths::Size() == Strides::Size(), "wrong!");
bool flag = true;
for(index_t i = 0; i < Lengths::Size() - 1; ++i)
{
flag = flag && Strides::At(i) == Strides::At(i + 1) * Lengths::At(i + 1);
}
return flag;
}
// unfold only support NativeTennsorDescriptor, for now
template <index_t FirstUnfoldDim, index_t LastUnfoldDim, typename... Ts>
__host__ __device__ constexpr auto unfold_tensor_descriptor(NativeTensorDescriptor<Ts...> desc,
Number<FirstUnfoldDim>,
Number<LastUnfoldDim>)
{
constexpr index_t nDim = desc.GetNumOfDimension();
static_assert(FirstUnfoldDim >= 0 && LastUnfoldDim < nDim && FirstUnfoldDim <= LastUnfoldDim,
"wrong! should have FirstUnfoldDim <= LastUnfoldDim!");
// left and right
constexpr auto left = typename arithmetic_sequence_gen<0, FirstUnfoldDim, 1>::type{};
constexpr auto middle =
typename arithmetic_sequence_gen<FirstUnfoldDim, LastUnfoldDim + 1, 1>::type{};
constexpr auto right = typename arithmetic_sequence_gen<LastUnfoldDim + 1, nDim, 1>::type{};
// sanity-checknfoldable
static_assert(are_dimensions_unfoldable(desc.GetLengths(middle), desc.GetStrides(middle)),
"wrong! not unfoldable");
// unfolded length, stride
constexpr index_t unfold_length =
reduce_on_sequence(desc.GetLengths(middle), math::multiplies<index_t>{}, Number<1>{});
constexpr index_t unfold_stride = desc.GetStride(Number<LastUnfoldDim>{});
// new lengths, strides
constexpr auto new_lengths =
desc.GetLengths(left).PushBack(Number<unfold_length>{}).PushBack(desc.GetLengths(right));
constexpr auto new_strides =
desc.GetStrides(left).PushBack(Number<unfold_stride>{}).PushBack(desc.GetStrides(right));
return make_native_tensor_descriptor(new_lengths, new_strides);
}
// a cluster map 1d index to N-d index
template <typename Lengths, typename ArrangeOrder>
struct ClusterDescriptor
{
static constexpr index_t nDim = Lengths::Size();
static constexpr auto mDesc = transform_tensor_descriptor(
make_native_tensor_descriptor_packed(Lengths{}),
make_tuple(Merge<decltype(Lengths::ReorderGivenNew2Old(ArrangeOrder{}))>{}),
make_tuple(ArrangeOrder{}),
make_tuple(Sequence<0>{}));
__host__ __device__ constexpr ClusterDescriptor()
{
static_assert(Lengths::Size() == nDim && ArrangeOrder::Size() == nDim,
"wrong! size not the same");
static_assert(is_valid_sequence_map<ArrangeOrder>{}, "wrong! ArrangeOrder is wrong");
}
__host__ __device__ static constexpr index_t GetElementSize() { return mDesc.GetElementSize(); }
__host__ __device__ static constexpr auto CalculateClusterIndex(index_t idx_1d)
{
return mDesc.CalculateLowerIndex(MultiIndex<1>{idx_1d});
}
};
template <typename Lengths,
typename ArrangeOrder = typename arithmetic_sequence_gen<0, Lengths::Size(), 1>::type>
__host__ __device__ constexpr auto make_cluster_descriptor(
Lengths, ArrangeOrder order = typename arithmetic_sequence_gen<0, Lengths::Size(), 1>::type{})
{
return ClusterDescriptor<Lengths, decltype(order)>{};
}
} // namespace ck
#endif
...@@ -5,25 +5,23 @@ ...@@ -5,25 +5,23 @@
#include "ConstantMatrixDescriptor.hpp" #include "ConstantMatrixDescriptor.hpp"
#include "threadwise_gemm.hpp" #include "threadwise_gemm.hpp"
#ifndef CK_BLOCKWISE_GEMM_USE_AMD_INLINE_ASM
#define CK_BLOCKWISE_GEMM_USE_AMD_INLINE_ASM 0
#endif
namespace ck { namespace ck {
// if following number are power of 2, index calculation shall be greatly reduced: // blockwise GEMM: C += transpose(A) * B
// MPerThreadSubC, NPerThreadSubC, MLevel0Cluster, NLevel0Cluster, MLevel1Cluster, NLevel1Cluster // A and B are visable to the whole block, C is distributed among each thread
// If following number are power of 2, index calculation shall be greatly reduced:
// MPerThreadSubC, NPerThreadSubC, MLevel0ThreadCluster, NLevel0ThreadCluster,
// MLevel1ThreadCluster, NLevel1ThreadCluster
template <index_t BlockSize, template <index_t BlockSize,
index_t EPack, typename BlockMatrixA,
class BlockMatrixA, typename BlockMatrixB,
class BlockMatrixB, typename ThreadMatrixC,
class ThreadMatrixC,
index_t MPerThreadSubC, index_t MPerThreadSubC,
index_t NPerThreadSubC, index_t NPerThreadSubC,
index_t MLevel0Cluster, index_t MLevel0ThreadCluster,
index_t NLevel0Cluster, index_t NLevel0ThreadCluster,
index_t MLevel1Cluster, index_t MLevel1ThreadCluster,
index_t NLevel1Cluster, index_t NLevel1ThreadCluster,
index_t KPerThreadLoop, index_t KPerThreadLoop,
index_t DataPerReadA, index_t DataPerReadA,
index_t DataPerReadB> index_t DataPerReadB>
...@@ -40,8 +38,8 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2 ...@@ -40,8 +38,8 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
__device__ BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2() __device__ BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2()
{ {
constexpr index_t ThreadPerLevel1Cluster = constexpr index_t ThreadPerLevel1Cluster = MLevel0ThreadCluster * NLevel0ThreadCluster *
MLevel0Cluster * NLevel0Cluster * MLevel1Cluster * NLevel1Cluster; MLevel1ThreadCluster * NLevel1ThreadCluster;
static_assert(BlockSize == ThreadPerLevel1Cluster, "wrong! wrong blocksize\n"); static_assert(BlockSize == ThreadPerLevel1Cluster, "wrong! wrong blocksize\n");
...@@ -51,8 +49,8 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2 ...@@ -51,8 +49,8 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
constexpr index_t M = BlockMatrixA::NCol(); // A is transposed constexpr index_t M = BlockMatrixA::NCol(); // A is transposed
constexpr index_t N = BlockMatrixB::NCol(); constexpr index_t N = BlockMatrixB::NCol();
static_assert(M % (MPerThreadSubC * MLevel0Cluster * MLevel1Cluster) == 0 && static_assert(M % (MPerThreadSubC * MLevel0ThreadCluster * MLevel1ThreadCluster) == 0 &&
N % (NPerThreadSubC * NLevel0Cluster * NLevel1Cluster) == 0, N % (NPerThreadSubC * NLevel0ThreadCluster * NLevel1ThreadCluster) == 0,
"wrong! Cannot evenly divide work among\n"); "wrong! Cannot evenly divide work among\n");
static_assert( static_assert(
...@@ -70,26 +68,28 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2 ...@@ -70,26 +68,28 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
constexpr index_t M = BlockMatrixA::NCol(); // A is transposed constexpr index_t M = BlockMatrixA::NCol(); // A is transposed
constexpr index_t N = BlockMatrixB::NCol(); constexpr index_t N = BlockMatrixB::NCol();
constexpr index_t MRepeat = M / (MPerThreadSubC * MLevel0Cluster * MLevel1Cluster); constexpr index_t MRepeat =
constexpr index_t NRepeat = N / (NPerThreadSubC * NLevel0Cluster * NLevel1Cluster); M / (MPerThreadSubC * MLevel0ThreadCluster * MLevel1ThreadCluster);
constexpr index_t NRepeat =
N / (NPerThreadSubC * NLevel0ThreadCluster * NLevel1ThreadCluster);
return Sequence<MRepeat * MPerThreadSubC, NRepeat * NPerThreadSubC>{}; return Sequence<MRepeat * MPerThreadSubC, NRepeat * NPerThreadSubC>{};
} }
__device__ static MatrixIndex GetBeginOfThreadMatrixC(index_t thread_id) __device__ static MatrixIndex GetBeginOfThreadMatrixC(index_t thread_id)
{ {
constexpr index_t ThreadPerLevel0Cluster = MLevel0Cluster * NLevel0Cluster; constexpr index_t ThreadPerLevel0Cluster = MLevel0ThreadCluster * NLevel0ThreadCluster;
index_t level1_id = thread_id / ThreadPerLevel0Cluster; index_t level1_id = thread_id / ThreadPerLevel0Cluster;
index_t level1_m_id = level1_id / NLevel1Cluster; index_t level1_m_id = level1_id / NLevel1ThreadCluster;
index_t level1_n_id = level1_id % NLevel1Cluster; index_t level1_n_id = level1_id % NLevel1ThreadCluster;
index_t level0_id = thread_id % ThreadPerLevel0Cluster; index_t level0_id = thread_id % ThreadPerLevel0Cluster;
index_t level0_m_id = level0_id / NLevel0Cluster; index_t level0_m_id = level0_id / NLevel0ThreadCluster;
index_t level0_n_id = level0_id % NLevel0Cluster; index_t level0_n_id = level0_id % NLevel0ThreadCluster;
constexpr index_t MPerLevel0Cluster = MPerThreadSubC * MLevel0Cluster; constexpr index_t MPerLevel0Cluster = MPerThreadSubC * MLevel0ThreadCluster;
constexpr index_t NPerLevel0Cluster = NPerThreadSubC * NLevel0Cluster; constexpr index_t NPerLevel0Cluster = NPerThreadSubC * NLevel0ThreadCluster;
return MatrixIndex{level1_m_id * MPerLevel0Cluster + level0_m_id * MPerThreadSubC, return MatrixIndex{level1_m_id * MPerLevel0Cluster + level0_m_id * MPerThreadSubC,
level1_n_id * NPerLevel0Cluster + level0_n_id * NPerThreadSubC}; level1_n_id * NPerLevel0Cluster + level0_n_id * NPerThreadSubC};
...@@ -100,8 +100,10 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2 ...@@ -100,8 +100,10 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
{ {
constexpr auto c_thread_mtx = ThreadMatrixC{}; constexpr auto c_thread_mtx = ThreadMatrixC{};
constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster; constexpr index_t MPerLevel1Cluster =
constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster; MPerThreadSubC * MLevel0ThreadCluster * MLevel1ThreadCluster;
constexpr index_t NPerLevel1Cluster =
NPerThreadSubC * NLevel0ThreadCluster * NLevel1ThreadCluster;
index_t m_repeat = m_in_c / MPerThreadSubC; index_t m_repeat = m_in_c / MPerThreadSubC;
index_t n_repeat = n_in_c / NPerThreadSubC; index_t n_repeat = n_in_c / NPerThreadSubC;
...@@ -113,168 +115,27 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2 ...@@ -113,168 +115,27 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
n_repeat * NPerLevel1Cluster + n_in_sub_c}; n_repeat * NPerLevel1Cluster + n_in_sub_c};
} }
#if CK_USE_AMD_INLINE_ASM template <typename FloatA, typename FloatB, typename FloatC>
// 1 in specialized template represent pack size. fp32 = 1
template <index_t PACKSIZE>
__device__ void outerProduct(const typename vector_type<float, 4>::MemoryType& a,
const typename vector_type<float, 4>::MemoryType& b,
typename vector_type<float, 4>::MemoryType* c) const
{
static_assert(1 == PACKSIZE, "only packsize of 1 is supported with float datatype!");
constexpr index_t NRepeat = 2;
outerProduct1x4(a.x, b, c[0 * NRepeat]);
outerProduct1x4(a.y, b, c[1 * NRepeat]);
outerProduct1x4(a.z, b, c[2 * NRepeat]);
outerProduct1x4(a.w, b, c[3 * NRepeat]);
}
// 1 in specialized template represent pack size. fp32 = 1
template <index_t PACKSIZE>
__device__ void outerProduct(const typename vector_type<float, 2>::MemoryType& a,
const typename vector_type<float, 4>::MemoryType& b,
typename vector_type<float, 4>::MemoryType* c) const
{
static_assert(1 == PACKSIZE, "only packsize of 1 is supported with float datatype!");
constexpr index_t NRepeat = 2;
outerProduct1x4(a.x, b, c[0 * NRepeat]);
outerProduct1x4(a.y, b, c[1 * NRepeat]);
}
// 1 in specialized template represent pack size. fp32 = 1
template <index_t PACKSIZE>
__device__ void outerProduct(const typename vector_type<float, 4>::MemoryType& a,
const typename vector_type<float, 2>::MemoryType& b,
typename vector_type<float, 2>::MemoryType* c) const
{
static_assert(1 == PACKSIZE, "only packsize of 1 is supported with float datatype!");
constexpr index_t NRepeat = 2;
outerProduct1x2(a.x, b, c[0 * NRepeat]);
outerProduct1x2(a.y, b, c[1 * NRepeat]);
outerProduct1x2(a.z, b, c[2 * NRepeat]);
outerProduct1x2(a.w, b, c[3 * NRepeat]);
}
// 1 in specialized template represent pack size. fp32 = 1
template <index_t PACKSIZE>
__device__ void outerProduct(const typename vector_type<float, 2>::MemoryType& a,
const typename vector_type<float, 2>::MemoryType& b,
typename vector_type<float, 2>::MemoryType* c) const
{
static_assert(1 == PACKSIZE, "only packsize of 1 is supported with float datatype!");
constexpr index_t NRepeat = 2;
outerProduct1x2(a.x, b, c[0 * NRepeat]);
outerProduct1x2(a.y, b, c[1 * NRepeat]);
}
// PACKSIZE for fp16 could be 4 or 2
template <index_t PACKSIZE>
__device__ void __device__ void
outerProduct(const typename vector_type<typename vector_type<half, PACKSIZE>::MemoryType, Run_naive(const FloatA* p_a_block, const FloatB* p_b_block, FloatC* p_c_thread) const
4>::MemoryType& a,
const typename vector_type<typename vector_type<half, PACKSIZE>::MemoryType,
4>::MemoryType& b,
typename vector_type<float, 4>::MemoryType* c) const
{
static_assert(2 == PACKSIZE || 4 == PACKSIZE,
"only packsize of 2,4 is supported with float datatype!");
constexpr index_t NRepeat = 2;
const typename vector_type<half, PACKSIZE>::MemoryType* reg_a =
reinterpret_cast<const typename vector_type<half, PACKSIZE>::MemoryType*>(&a);
outerProduct1x4Half<PACKSIZE>(reg_a[0], b, c[0 * NRepeat]);
outerProduct1x4Half<PACKSIZE>(reg_a[1], b, c[1 * NRepeat]);
outerProduct1x4Half<PACKSIZE>(reg_a[2], b, c[2 * NRepeat]);
outerProduct1x4Half<PACKSIZE>(reg_a[3], b, c[3 * NRepeat]);
}
// PACKSIZE for fp16 could be 4 or 2
template <index_t PACKSIZE>
__device__ void
outerProduct(const typename vector_type<typename vector_type<half, PACKSIZE>::MemoryType,
2>::MemoryType& a,
const typename vector_type<typename vector_type<half, PACKSIZE>::MemoryType,
2>::MemoryType& b,
typename vector_type<float, 2>::MemoryType* c) const
{
static_assert(2 == PACKSIZE || 4 == PACKSIZE,
"only packsize of 2,4 is supported with float datatype!");
constexpr index_t NRepeat = 2;
const typename vector_type<half, PACKSIZE>::MemoryType* reg_a =
reinterpret_cast<const typename vector_type<half, PACKSIZE>::MemoryType*>(&a);
outerProduct1x2Half<PACKSIZE>(reg_a[0], b, c[0 * NRepeat]);
outerProduct1x2Half<PACKSIZE>(reg_a[1], b, c[1 * NRepeat]);
}
// PACKSIZE for fp16 could be 4 or 2
template <index_t PACKSIZE>
__device__ void
outerProduct1x4Half(const typename vector_type<half, PACKSIZE>::MemoryType& a,
const typename vector_type<typename vector_type<half, PACKSIZE>::MemoryType,
4>::MemoryType& b,
vector_type<float, 4>::MemoryType& c) const
{
static_if<PACKSIZE == 4>{}([&](auto) {
outerProduct1x4dot2TwoTimes(reinterpret_cast<const half2*>(&a),
reinterpret_cast<const half2*>(&b),
reinterpret_cast<float*>(&c));
}).Else([&](auto) {
static_if<PACKSIZE == 2>{}([&](auto) {
outerProduct1x4dot2(reinterpret_cast<const half2*>(&a),
reinterpret_cast<const half2*>(&b),
reinterpret_cast<float*>(&c));
}).Else([&](auto fwd) {
// not implemented
static_assert(fwd(false), "wrong! packsize = 1 for fp16 is insensible.");
});
});
}
// PACKSIZE for fp16 could be 4 or 2
template <index_t PACKSIZE>
__device__ void
outerProduct1x2Half(const typename vector_type<half, PACKSIZE>::MemoryType& a,
const typename vector_type<typename vector_type<half, PACKSIZE>::MemoryType,
2>::MemoryType& b,
vector_type<float, 2>::MemoryType& c) const
{
static_if<PACKSIZE == 4>{}([&](auto) {
outerProduct1x2dot2TwoTimes(reinterpret_cast<const half2*>(&a),
reinterpret_cast<const half2*>(&b),
reinterpret_cast<float*>(&c));
}).Else([&](auto) {
static_if<PACKSIZE == 2>{}([&](auto) {
outerProduct1x2dot2(reinterpret_cast<const half2*>(&a),
reinterpret_cast<const half2*>(&b),
reinterpret_cast<float*>(&c));
}).Else([&](auto fwd) {
// not implemented
static_assert(fwd(false), "wrong! packsize = 1 for fp16 is insensible.");
});
});
}
template <class FloatA, class FloatB, class FloatC>
__device__ void Run_amd_asm(const FloatA* __restrict__ p_a_block,
const FloatB* __restrict__ p_b_block,
FloatC* __restrict__ p_c_thread) const
{ {
constexpr auto a_block_mtx = BlockMatrixA{}; constexpr auto a_block_mtx = BlockMatrixA{};
constexpr auto b_block_mtx = BlockMatrixB{}; constexpr auto b_block_mtx = BlockMatrixB{};
constexpr auto c_thread_mtx = ThreadMatrixC{}; constexpr auto c_thread_mtx = ThreadMatrixC{};
constexpr index_t M = a_block_mtx.NCol();
constexpr index_t N = b_block_mtx.NCol();
constexpr index_t K = a_block_mtx.NRow(); constexpr index_t K = a_block_mtx.NRow();
constexpr index_t MPerThread = c_thread_mtx.NRow(); constexpr index_t MPerThread = c_thread_mtx.NRow();
constexpr index_t NPerThread = c_thread_mtx.NCol(); constexpr index_t NPerThread = c_thread_mtx.NCol();
constexpr index_t MPerLevel1Cluster =
MPerThreadSubC * MLevel0ThreadCluster * MLevel1ThreadCluster;
constexpr index_t NPerLevel1Cluster =
NPerThreadSubC * NLevel0ThreadCluster * NLevel1ThreadCluster;
constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
constexpr index_t NRepeat = NPerThread / NPerThreadSubC;
// thread A, B for GEMM // thread A, B for GEMM
constexpr auto a_thread_mtx = constexpr auto a_thread_mtx =
make_ConstantMatrixDescriptor_packed(Number<KPerThreadLoop>{}, Number<MPerThread>{}); make_ConstantMatrixDescriptor_packed(Number<KPerThreadLoop>{}, Number<MPerThread>{});
...@@ -282,66 +143,58 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2 ...@@ -282,66 +143,58 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
constexpr auto b_thread_mtx = constexpr auto b_thread_mtx =
make_ConstantMatrixDescriptor_packed(Number<KPerThreadLoop>{}, Number<NPerThread>{}); make_ConstantMatrixDescriptor_packed(Number<KPerThreadLoop>{}, Number<NPerThread>{});
constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
static_assert((MPerThreadSubC == 4 || MPerThreadSubC == 2) &&
(NPerThreadSubC == 4 || NPerThreadSubC == 2) && KPerThreadLoop == 1,
"M/NPerThreadSubC wrong!");
static_assert(MPerThread % 4 == 0 && NPerThread % 4 == 0, "M/NPerThread % 4 != 0");
constexpr index_t MRepeat = M / (MPerThreadSubC * MLevel0Cluster * MLevel1Cluster);
constexpr index_t NRepeat = N / (NPerThreadSubC * NLevel0Cluster * NLevel1Cluster);
static_assert(MRepeat == 2 && NRepeat == 2, "M/NRepeat != 2");
using typeA = typename vector_type<FloatA, MPerThreadSubC>::MemoryType;
using typeB = typename vector_type<FloatB, NPerThreadSubC>::MemoryType;
using typeC = typename vector_type<FloatC, NPerThreadSubC>::MemoryType;
FloatA p_a_thread[a_thread_mtx.GetElementSpace()]; FloatA p_a_thread[a_thread_mtx.GetElementSpace()];
FloatB p_b_thread[b_thread_mtx.GetElementSpace()]; FloatB p_b_thread[b_thread_mtx.GetElementSpace()];
typeA* reg_a = reinterpret_cast<typeA*>(p_a_thread); constexpr auto a_thread_copy = ThreadwiseMatrixSliceCopy<BlockMatrixA,
typeB* reg_b = reinterpret_cast<typeB*>(p_b_thread); decltype(a_thread_mtx),
typeC* reg_c = reinterpret_cast<typeC*>(p_c_thread); KPerThreadLoop,
MPerThreadSubC,
reg_a[0] = *reinterpret_cast<const typeA*>(&p_a_block[mMyThreadOffsetA]); DataPerReadA>{};
reg_b[0] = *reinterpret_cast<const typeB*>(&p_b_block[mMyThreadOffsetB]);
reg_b[1] = constexpr auto b_thread_copy = ThreadwiseMatrixSliceCopy<BlockMatrixB,
*reinterpret_cast<const typeB*>(&p_b_block[(mMyThreadOffsetB + NPerLevel1Cluster)]); decltype(b_thread_mtx),
reg_a[1] = KPerThreadLoop,
*reinterpret_cast<const typeA*>(&p_a_block[(mMyThreadOffsetA + MPerLevel1Cluster)]); NPerThreadSubC,
outerProduct<EPack>(reg_a[0], reg_b[0], &reg_c[0]); DataPerReadB>{};
outerProduct<EPack>(reg_a[0], reg_b[1], &reg_c[1]);
constexpr auto threadwise_gemm =
ThreadwiseGemmTransANormalBNormalC<decltype(a_thread_mtx),
decltype(b_thread_mtx),
decltype(c_thread_mtx)>{};
#pragma unroll #pragma unroll
for(index_t k = 1; k < K; ++k) // loop over k
for(index_t k_begin = 0; k_begin < K; k_begin += KPerThreadLoop)
{ {
reg_a[0] = *reinterpret_cast<const typeA*>(&p_a_block[(mMyThreadOffsetA + k * M)]); #pragma unroll
outerProduct<EPack>(reg_a[1], reg_b[0], &reg_c[NRepeat * MPerThreadSubC]); // read A
reg_b[0] = *reinterpret_cast<const typeB*>(&p_b_block[(mMyThreadOffsetB + k * N)]); for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
outerProduct<EPack>(reg_a[1], reg_b[1], &reg_c[NRepeat * MPerThreadSubC + 1]); {
reg_b[1] = *reinterpret_cast<const typeB*>( a_thread_copy.Run(
&p_b_block[(mMyThreadOffsetB + k * N + NPerLevel1Cluster)]); p_a_block + a_block_mtx.CalculateOffset(k_begin, m_repeat * MPerLevel1Cluster) +
reg_a[1] = *reinterpret_cast<const typeA*>( mMyThreadOffsetA,
&p_a_block[(mMyThreadOffsetA + k * M + MPerLevel1Cluster)]); p_a_thread + a_thread_mtx.CalculateOffset(0, m_repeat * MPerThreadSubC));
outerProduct<EPack>(reg_a[0], reg_b[0], &reg_c[0]); }
outerProduct<EPack>(reg_a[0], reg_b[1], &reg_c[1]);
#pragma unroll
// read B
for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
{
b_thread_copy.Run(
p_b_block + b_block_mtx.CalculateOffset(k_begin, n_repeat * NPerLevel1Cluster) +
mMyThreadOffsetB,
p_b_thread + b_thread_mtx.CalculateOffset(0, n_repeat * NPerThreadSubC));
}
// C += A * B
threadwise_gemm.Run(p_a_thread, p_b_thread, p_c_thread);
} }
outerProduct<EPack>(reg_a[1], reg_b[0], &reg_c[NRepeat * MPerThreadSubC]);
outerProduct<EPack>(reg_a[1], reg_b[1], &reg_c[NRepeat * MPerThreadSubC + 1]);
} }
#endif
template <class FloatA, class FloatB, class FloatC> template <typename FloatA, typename FloatB, typename FloatC>
__device__ void Run_source(const FloatA* const __restrict__ p_a_block, __device__ void
const FloatB* const __restrict__ p_b_block, Run_pipelined_2x2(const FloatA* p_a_block, const FloatB* p_b_block, FloatC* p_c_thread) const
FloatC* const __restrict__ p_c_thread) const
{ {
constexpr auto True = integral_constant<bool, true>{};
constexpr auto False = integral_constant<bool, false>{};
constexpr auto a_block_mtx = BlockMatrixA{}; constexpr auto a_block_mtx = BlockMatrixA{};
constexpr auto b_block_mtx = BlockMatrixB{}; constexpr auto b_block_mtx = BlockMatrixB{};
constexpr auto c_thread_mtx = ThreadMatrixC{}; constexpr auto c_thread_mtx = ThreadMatrixC{};
...@@ -351,91 +204,144 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2 ...@@ -351,91 +204,144 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
constexpr index_t MPerThread = c_thread_mtx.NRow(); constexpr index_t MPerThread = c_thread_mtx.NRow();
constexpr index_t NPerThread = c_thread_mtx.NCol(); constexpr index_t NPerThread = c_thread_mtx.NCol();
// thread A, B for GEMM constexpr index_t MPerLevel1Cluster =
constexpr auto a_thread_mtx = make_ConstantMatrixDescriptor( MPerThreadSubC * MLevel0ThreadCluster * MLevel1ThreadCluster;
Number<KPerThreadLoop>{}, Number<MPerThread>{}, Number<MPerThread>{}); constexpr index_t NPerLevel1Cluster =
NPerThreadSubC * NLevel0ThreadCluster * NLevel1ThreadCluster;
constexpr auto b_thread_mtx = make_ConstantMatrixDescriptor( constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
Number<KPerThreadLoop>{}, Number<NPerThread>{}, Number<NPerThread>{}); constexpr index_t NRepeat = NPerThread / NPerThreadSubC;
// thread A-sub, B-sub for copy static_assert(MRepeat == 2 && NRepeat == 2,
constexpr auto a_thread_sub_mtx = make_ConstantMatrixDescriptor( "wrong! inline asm cannot deal with this GEMM config yet");
Number<KPerThreadLoop>{}, Number<MPerThreadSubC>{}, Number<MPerThread>{});
constexpr auto b_thread_sub_mtx = make_ConstantMatrixDescriptor( // thread A, B
Number<KPerThreadLoop>{}, Number<NPerThreadSubC>{}, Number<NPerThread>{}); constexpr auto a_thread_mtx =
make_ConstantMatrixDescriptor_packed(Number<KPerThreadLoop>{}, Number<MPerThread>{});
constexpr auto b_thread_mtx =
make_ConstantMatrixDescriptor_packed(Number<KPerThreadLoop>{}, Number<NPerThread>{});
// thread A-sub, B-sub
constexpr auto a_thread_sub_mtx = a_thread_mtx.MakeSubMatrixDescriptor(
Number<KPerThreadLoop>{}, Number<MPerThreadSubC>{});
constexpr auto b_thread_sub_mtx = b_thread_mtx.MakeSubMatrixDescriptor(
Number<KPerThreadLoop>{}, Number<NPerThreadSubC>{});
// thread C-sub
constexpr auto c_thread_sub_mtx = ThreadMatrixC::MakeSubMatrixDescriptor(
Number<MPerThreadSubC>{}, Number<NPerThreadSubC>{});
FloatA p_a_thread[a_thread_mtx.GetElementSpace()]; FloatA p_a_thread[a_thread_mtx.GetElementSpace()];
FloatB p_b_thread[b_thread_mtx.GetElementSpace()]; FloatB p_b_thread[b_thread_mtx.GetElementSpace()];
constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster; constexpr auto a_thread_copy = ThreadwiseMatrixSliceCopy<BlockMatrixA,
constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster; decltype(a_thread_mtx),
KPerThreadLoop,
MPerThreadSubC,
DataPerReadA>{};
constexpr index_t MRepeat = MPerThread / MPerThreadSubC; constexpr auto b_thread_copy = ThreadwiseMatrixSliceCopy<BlockMatrixB,
constexpr index_t NRepeat = NPerThread / NPerThreadSubC; decltype(b_thread_mtx),
KPerThreadLoop,
NPerThreadSubC,
DataPerReadB>{};
#pragma unroll constexpr auto threadwise_gemm =
// loop over k ThreadwiseGemmTransANormalBNormalC<decltype(a_thread_sub_mtx),
for(index_t k_begin = 0; k_begin < K; k_begin += KPerThreadLoop) decltype(b_thread_sub_mtx),
{ decltype(c_thread_sub_mtx)>{};
#pragma unroll
// copy A-sub to form A const FloatA* p_a_block_off = p_a_block + mMyThreadOffsetA;
for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat) const FloatB* p_b_block_off = p_b_block + mMyThreadOffsetB;
{
threadwise_matrix_copy( // read A_sub_0
a_block_mtx, a_thread_copy.Run(p_a_block_off, p_a_thread);
p_a_block +
a_block_mtx.GetOffsetFromMultiIndex(k_begin, m_repeat * MPerLevel1Cluster) + // read B_sub_0
mMyThreadOffsetA, b_thread_copy.Run(p_b_block_off, p_b_thread);
a_thread_mtx,
p_a_thread + a_thread_mtx.GetOffsetFromMultiIndex(0, m_repeat * MPerThreadSubC), // read B_sub_1
a_thread_sub_mtx.GetLengths(), b_thread_copy.Run(p_b_block_off + b_block_mtx.CalculateOffset(0, NPerLevel1Cluster),
Number<DataPerReadA>{}); p_b_thread + b_thread_mtx.CalculateOffset(0, NPerThreadSubC));
}
// read A_sub_1
a_thread_copy.Run(p_a_block_off + a_block_mtx.CalculateOffset(0, MPerLevel1Cluster),
p_a_thread + a_thread_mtx.CalculateOffset(0, MPerThreadSubC));
// C_sub_00 += transpose(A_sub_0) * B_sub_0
threadwise_gemm.Run(p_a_thread, p_b_thread, p_c_thread);
// C_sub_01 += transpose(A_sub_0) * B_sub_1
threadwise_gemm.Run(p_a_thread,
p_b_thread + b_thread_mtx.CalculateOffset(0, NPerThreadSubC),
p_c_thread + ThreadMatrixC::CalculateOffset(0, NPerThreadSubC));
#pragma unroll #pragma unroll
// copy B-sub to form B // loop over rest of k
for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat) for(index_t k = KPerThreadLoop; k < K; k += KPerThreadLoop)
{ {
threadwise_matrix_copy( // read A_sub_0
b_block_mtx, a_thread_copy.Run(p_a_block_off + a_block_mtx.CalculateOffset(k, 0), p_a_thread);
p_b_block +
b_block_mtx.GetOffsetFromMultiIndex(k_begin, n_repeat * NPerLevel1Cluster) + // C_sub_10 += transpose(A_sub_1) * B_sub_0
mMyThreadOffsetB, threadwise_gemm.Run(p_a_thread + a_thread_mtx.CalculateOffset(0, MPerThreadSubC),
b_thread_mtx, p_b_thread,
p_b_thread + b_thread_mtx.GetOffsetFromMultiIndex(0, n_repeat * NPerThreadSubC), p_c_thread + ThreadMatrixC::CalculateOffset(MPerThreadSubC, 0));
b_thread_sub_mtx.GetLengths(),
Number<DataPerReadB>{}); // read B_sub_0
} b_thread_copy.Run(p_b_block_off + b_block_mtx.CalculateOffset(k, 0), p_b_thread);
// C_sub_11 += transpose(A_sub_1) * B_sub_1
threadwise_gemm.Run(p_a_thread + a_thread_mtx.CalculateOffset(0, MPerThreadSubC),
p_b_thread + b_thread_mtx.CalculateOffset(0, NPerThreadSubC),
p_c_thread +
ThreadMatrixC::CalculateOffset(MPerThreadSubC, NPerThreadSubC));
// read B_sub_1
b_thread_copy.Run(p_b_block_off + b_block_mtx.CalculateOffset(k, NPerLevel1Cluster),
p_b_thread + b_thread_mtx.CalculateOffset(0, NPerThreadSubC));
// read A_sub_1
a_thread_copy.Run(p_a_block_off + a_block_mtx.CalculateOffset(k, MPerLevel1Cluster),
p_a_thread + a_thread_mtx.CalculateOffset(0, MPerThreadSubC));
// C_sub_00 += transpose(A_sub_0) * B_sub_0
threadwise_gemm.Run(p_a_thread, p_b_thread, p_c_thread);
// C_sub_01 += transpose(A_sub_0) * B_sub_1
threadwise_gemm.Run(p_a_thread,
p_b_thread + b_thread_mtx.CalculateOffset(0, NPerThreadSubC),
p_c_thread + ThreadMatrixC::CalculateOffset(0, NPerThreadSubC));
}
// C = A * B // C_sub_10 += transpose(A_sub_1) * B_sub_0
threadwise_gemm(a_thread_mtx, threadwise_gemm.Run(p_a_thread + a_thread_mtx.CalculateOffset(0, MPerThreadSubC),
True,
p_a_thread,
b_thread_mtx,
False,
p_b_thread, p_b_thread,
c_thread_mtx, p_c_thread + ThreadMatrixC::CalculateOffset(MPerThreadSubC, 0));
False,
p_c_thread); // C_sub_11 += transpose(A_sub_1) * B_sub_1
} threadwise_gemm.Run(p_a_thread + a_thread_mtx.CalculateOffset(0, MPerThreadSubC),
p_b_thread + b_thread_mtx.CalculateOffset(0, NPerThreadSubC),
p_c_thread +
ThreadMatrixC::CalculateOffset(MPerThreadSubC, NPerThreadSubC));
} }
template <class FloatA, class FloatB, class FloatC> template <typename FloatA, typename FloatB, typename FloatC>
__device__ void Run(const FloatA* __restrict__ p_a_block, __device__ void Run(const FloatA* p_a_block, const FloatB* p_b_block, FloatC* p_c_thread) const
const FloatB* __restrict__ p_b_block,
FloatC* __restrict__ p_c_thread) const
{ {
#if CK_USE_AMD_INLINE_ASM && CK_BLOCKWISE_GEMM_USE_AMD_INLINE_ASM #if CK_EXPERIMENTAL_BLOCKWISE_GEMM_USE_PIPELINE
// The assembly path doesn't support bfloat16 using asm instructions constexpr index_t MPerThread = ThreadMatrixC::NRow();
#if MIOPEN_USE_BFP16 == 1 constexpr index_t NPerThread = ThreadMatrixC::NCol();
Run_source(p_a_block, p_b_block, p_c_thread);
constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
constexpr index_t NRepeat = NPerThread / NPerThreadSubC;
static_if<MRepeat == 2 && NRepeat == 2>{}([&](auto) {
Run_pipelined_2x2(p_a_block, p_b_block, p_c_thread);
}).Else([&](auto) { Run_naive(p_a_block, p_b_block, p_c_thread); });
#else #else
Run_amd_asm(p_a_block, p_b_block, p_c_thread); Run_naive(p_a_block, p_b_block, p_c_thread);
#endif #endif
#else
Run_source(p_a_block, p_b_block, p_c_thread);
#endif // CK_USE_AMD_INLINE_ASM
} }
}; };
......
...@@ -95,9 +95,21 @@ __device__ void WaveWiseGemmMx64(const FloatA* const __restrict__ p_a_wave, ...@@ -95,9 +95,21 @@ __device__ void WaveWiseGemmMx64(const FloatA* const __restrict__ p_a_wave,
(mfma_info::group_size * mfma_info::num_blks_wave) + (mfma_info::group_size * mfma_info::num_blks_wave) +
a_off; // A is transposed a_off; // A is transposed
index_t bindex = b_off + lane_b + n * mfma_info::num_threads_blk; index_t bindex = b_off + lane_b + n * mfma_info::num_threads_blk;
p_c_thread[m + n * output_m + b * output_m * mfma_info::num_blks_wave] += // p_c_thread[m + n * output_m + b * output_m * mfma_info::num_blks_wave] +=
math::inner_product_with_conversion<FloatC>{}(p_a_wave[aindex], // math::inner_product_with_conversion<FloatC>{}(p_a_wave[aindex],
p_b_wave[bindex]); // p_b_wave[bindex]);
index_t cindex = m + n * output_m + b * output_m * mfma_info::num_blks_wave;
if(blockIdx.x*blockDim.x + threadIdx.x == 0 && cindex == 0)
{
printf("Run p_c[%d] = %f, p_a[%d] = %f, p_b[%d] = %f\n",
cindex,
p_c_thread[cindex],
aindex,
p_a_wave[aindex],
bindex,
p_b_wave[bindex]);
p_c_thread[cindex+k] = p_a_wave[aindex];
}
} }
} }
} }
...@@ -251,6 +263,9 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_xdlops ...@@ -251,6 +263,9 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_xdlops
constexpr index_t N = BlockMatrixB::NCol(); constexpr index_t N = BlockMatrixB::NCol();
constexpr index_t K = BlockMatrixA::NRow(); constexpr index_t K = BlockMatrixA::NRow();
if(blockIdx.x*blockDim.x + threadIdx.x == 0)
printf("Run M %d, N %d, K %d\n", M, N, K);
// static_if<EnableXdlops>{}([&](auto) { // static_if<EnableXdlops>{}([&](auto) {
// WaveWiseGemmMx64_xdlops<M, // WaveWiseGemmMx64_xdlops<M,
// N, // N,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment