Commit b37cb71f authored by Wen-Heng (Jack) Chung's avatar Wen-Heng (Jack) Chung
Browse files

Enable bwd wrw

parent c5143bca
#ifndef CK_CONVOLUTION_COMMON_HPP
#define CK_CONVOLUTION_COMMON_HPP
namespace ck {
enum ConvolutionDirection
{
Forward,
BackwardData,
BackwardWeight
};
} // namespace ck
#endif
...@@ -2,22 +2,21 @@ ...@@ -2,22 +2,21 @@
#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4_FP16_BFP16_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER_HPP #define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4_FP16_BFP16_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER_HPP
#include "common_header.hpp" #include "common_header.hpp"
#include "ConstantTensorDescriptor.hpp" #include "ConstantTensorDescriptor_deprecated.hpp"
#include "ConstantMergedTensorDescriptor.hpp" #include "ConstantMergedTensorDescriptor_deprecated.hpp"
#include "ConstantMatrixDescriptor.hpp" #include "ConstantMatrixDescriptor.hpp"
#include "blockwise_generic_tensor_slice_copy.hpp" #include "blockwise_generic_tensor_slice_copy_deprecated.hpp"
#include "blockwise_gemm.hpp" #include "blockwise_gemm.hpp"
#include "threadwise_generic_tensor_slice_copy.hpp" #include "threadwise_generic_tensor_slice_copy_deprecated.hpp"
#include "implicitgemm_params.hpp" #include "convolution_common.hpp"
namespace ck { namespace ck {
template <ImplicitGemmDirection conv_dir, typename WeiDesc, index_t NonVectorizedC> template <ConvolutionDirection, typename WeiDesc, index_t NonVectorizedC>
struct make_vectorized_WeiDesc struct make_vectorized_WeiDesc;
{
};
template <typename WeiDesc, index_t NonVectorizedC> template <typename WeiDesc, index_t NonVectorizedC>
struct make_vectorized_WeiDesc<ImplicitGemmDirection::ForwardData, WeiDesc, NonVectorizedC> struct make_vectorized_WeiDesc<ConvolutionDirection::Forward, WeiDesc, NonVectorizedC>
{ {
__device__ constexpr auto get(WeiDesc&) __device__ constexpr auto get(WeiDesc&)
{ {
...@@ -30,8 +29,9 @@ struct make_vectorized_WeiDesc<ImplicitGemmDirection::ForwardData, WeiDesc, NonV ...@@ -30,8 +29,9 @@ struct make_vectorized_WeiDesc<ImplicitGemmDirection::ForwardData, WeiDesc, NonV
.ReorderGivenNew2Old(Sequence<2, 0, 1>{}); .ReorderGivenNew2Old(Sequence<2, 0, 1>{});
} }
}; };
template <typename WeiDesc, index_t NonVectorizedC> template <typename WeiDesc, index_t NonVectorizedC>
struct make_vectorized_WeiDesc<ImplicitGemmDirection::BackwardWeight, WeiDesc, NonVectorizedC> struct make_vectorized_WeiDesc<ConvolutionDirection::BackwardWeight, WeiDesc, NonVectorizedC>
{ {
__device__ constexpr auto get(WeiDesc& desc) __device__ constexpr auto get(WeiDesc& desc)
{ {
...@@ -56,6 +56,7 @@ template <index_t GridSize, ...@@ -56,6 +56,7 @@ template <index_t GridSize,
class OutGlobalDesc, class OutGlobalDesc,
class ConvStrides, class ConvStrides,
class ConvDilations, class ConvDilations,
ConvolutionDirection ConvDirection,
index_t BPerBlock, index_t BPerBlock,
index_t KPerBlock, index_t KPerBlock,
index_t EPerBlock, index_t EPerBlock,
...@@ -83,8 +84,7 @@ template <index_t GridSize, ...@@ -83,8 +84,7 @@ template <index_t GridSize,
class WeiBlockCopySrcAccessOrder, class WeiBlockCopySrcAccessOrder,
class WeiBlockCopyDstAccessOrder, class WeiBlockCopyDstAccessOrder,
index_t WeiBlockCopySrcDataPerRead_E, index_t WeiBlockCopySrcDataPerRead_E,
index_t WeiBlockCopyDstDataPerWrite_K, index_t WeiBlockCopyDstDataPerWrite_K>
ImplicitGemmDirection conv_dir>
struct GridwiseConvolutionImplicitGemm_v4_fp16_bfp16_nchw_kcyx_nkhw_lds_double_buffer struct GridwiseConvolutionImplicitGemm_v4_fp16_bfp16_nchw_kcyx_nkhw_lds_double_buffer
{ {
__device__ void Run(const Float* const __restrict__ p_in_global, __device__ void Run(const Float* const __restrict__ p_in_global,
...@@ -198,27 +198,28 @@ struct GridwiseConvolutionImplicitGemm_v4_fp16_bfp16_nchw_kcyx_nkhw_lds_double_b ...@@ -198,27 +198,28 @@ struct GridwiseConvolutionImplicitGemm_v4_fp16_bfp16_nchw_kcyx_nkhw_lds_double_b
// input blockwise copy // input blockwise copy
// slice a merged tensor, reorder and copy to a normal tensor // slice a merged tensor, reorder and copy to a normal tensor
// this copy operator already has blockwise offset built-in // this copy operator already has blockwise offset built-in
auto blockwise_in_copy = auto blockwise_in_copy = BlockwiseGenericTensorSliceCopy_v1_deprecated<
BlockwiseGenericTensorSliceCopy_v1<BlockSize, BlockSize,
decltype(in_e_n1_b_n2_2eor4e_global_merged_desc), decltype(in_e_n1_b_n2_2eor4e_global_merged_desc),
decltype(in_e_n1_b_n2_2eor4e_block_desc), decltype(in_e_n1_b_n2_2eor4e_block_desc),
decltype( decltype(in_e_n1_b_n2_2eor4e_block_desc.GetLengths()),
in_e_n1_b_n2_2eor4e_block_desc.GetLengths()), InBlockCopySubLengths_E_N1_B_N2_EPACK,
InBlockCopySubLengths_E_N1_B_N2_EPACK, InBlockCopyClusterLengths_E_N1_B_N2_EPACK,
InBlockCopyClusterLengths_E_N1_B_N2_EPACK, InBlockCopyThreadClusterArrangeOrder,
InBlockCopyThreadClusterArrangeOrder, InBlockCopySrcAccessOrder,
InBlockCopySrcAccessOrder, InBlockCopyDstAccessOrder,
InBlockCopyDstAccessOrder, 2,
2, 4,
4, InBlockCopySrcDataPerRead_B,
InBlockCopySrcDataPerRead_B, InBlockCopyDstDataPerWrite_EPACK>({0, 0, b_block_data_on_global, 0, 0},
InBlockCopyDstDataPerWrite_EPACK>( {0, 0, 0, 0, 0});
{0, 0, b_block_data_on_global, 0, 0}, {0, 0, 0, 0, 0});
// weight tensor // weight tensor
// tensor descriptor in device memory, src of blockwise copy // tensor descriptor in device memory, src of blockwise copy
constexpr auto wei_e_k_2eor4e_global_desc = constexpr auto wei_e_k_2eor4e_global_desc =
make_vectorized_WeiDesc<conv_dir, decltype(wei_k_c_y_x_global_desc), nonVectorizedC>{} make_vectorized_WeiDesc<ConvDirection,
decltype(wei_k_c_y_x_global_desc),
nonVectorizedC>{}
.get(wei_k_c_y_x_global_desc); .get(wei_k_c_y_x_global_desc);
// tensor descriptor in LDS, dst of blockwise copy // tensor descriptor in LDS, dst of blockwise copy
...@@ -235,21 +236,20 @@ struct GridwiseConvolutionImplicitGemm_v4_fp16_bfp16_nchw_kcyx_nkhw_lds_double_b ...@@ -235,21 +236,20 @@ struct GridwiseConvolutionImplicitGemm_v4_fp16_bfp16_nchw_kcyx_nkhw_lds_double_b
// operator for blockwise copy of weight into LDS // operator for blockwise copy of weight into LDS
// slice a tensor, and copy it into another tensor // slice a tensor, and copy it into another tensor
// this copy operator already have blockwise offset built-in // this copy operator already have blockwise offset built-in
auto blockwise_wei_copy = auto blockwise_wei_copy = BlockwiseGenericTensorSliceCopy_v1_deprecated<
BlockwiseGenericTensorSliceCopy_v1<BlockSize, BlockSize,
decltype(wei_e_k_2eor4e_global_desc), decltype(wei_e_k_2eor4e_global_desc),
decltype(wei_e_k_2eor4e_block_desc), decltype(wei_e_k_2eor4e_block_desc),
decltype(wei_e_k_2eor4e_block_desc.GetLengths()), decltype(wei_e_k_2eor4e_block_desc.GetLengths()),
WeiBlockCopySubLengths_E_K_EPACK, WeiBlockCopySubLengths_E_K_EPACK,
WeiBlockCopyClusterLengths_E_K_EPACK, WeiBlockCopyClusterLengths_E_K_EPACK,
WeiBlockCopyThreadClusterArrangeOrder, WeiBlockCopyThreadClusterArrangeOrder,
WeiBlockCopySrcAccessOrder, WeiBlockCopySrcAccessOrder,
WeiBlockCopyDstAccessOrder, WeiBlockCopyDstAccessOrder,
0, 0,
2, 2,
WeiBlockCopySrcDataPerRead_E, WeiBlockCopySrcDataPerRead_E,
WeiBlockCopyDstDataPerWrite_EPACK>( WeiBlockCopyDstDataPerWrite_EPACK>({0, k_block_data_on_global, 0}, {0, 0, 0});
{0, k_block_data_on_global, 0}, {0, 0, 0});
// GEMM definition // GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx // c_mtx += transpose(a_mtx) * b_mtx
...@@ -279,7 +279,6 @@ struct GridwiseConvolutionImplicitGemm_v4_fp16_bfp16_nchw_kcyx_nkhw_lds_double_b ...@@ -279,7 +279,6 @@ struct GridwiseConvolutionImplicitGemm_v4_fp16_bfp16_nchw_kcyx_nkhw_lds_double_b
const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2< const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2<
BlockSize, BlockSize,
EPACK,
decltype(a_e_k_block_mtx_desc), decltype(a_e_k_block_mtx_desc),
decltype(b_e_n1bn2_block_mtx_desc), decltype(b_e_n1bn2_block_mtx_desc),
decltype(c_k0k2_n1n2_thread_mtx_desc), decltype(c_k0k2_n1n2_thread_mtx_desc),
...@@ -347,12 +346,12 @@ struct GridwiseConvolutionImplicitGemm_v4_fp16_bfp16_nchw_kcyx_nkhw_lds_double_b ...@@ -347,12 +346,12 @@ struct GridwiseConvolutionImplicitGemm_v4_fp16_bfp16_nchw_kcyx_nkhw_lds_double_b
Float* p_wei_block_next = Float* p_wei_block_next =
even_loop ? p_wei_block_double + wei_block_space : p_wei_block_double; even_loop ? p_wei_block_double + wei_block_space : p_wei_block_double;
Float p_in_register_buffer[blockwise_in_copy.GetRegisterBufferSize()]; Float p_in_thread_buffer[blockwise_in_copy.GetThreadBufferSize()];
Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()]; Float p_wei_thread_buffer[blockwise_wei_copy.GetThreadBufferSize()];
blockwise_in_copy.MoveSlicingWindowOnSourceTensor(I0, Number<EPerBlock>{}, True); blockwise_in_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0, 0, 0, 0>{}, True);
static_if<conv_dir == ImplicitGemmDirection::BackwardWeight>{}([&](auto fwd) { static_if<ConvDirection == ConvolutionDirection::BackwardWeight>{}([&](auto fwd) {
fwd(blockwise_wei_copy).MoveSrcSlicingWindow(Sequence<EPerBlock, 0, 0>{}, True); fwd(blockwise_wei_copy).MoveSrcSliceWindow(Sequence<EPerBlock, 0, 0>{}, True);
}).Else([&](auto fwd) { }).Else([&](auto fwd) {
p_wei_block_on_global += p_wei_block_on_global +=
EPerBlock * fwd(wei_e_k_2eor4e_global_desc).GetStride(I0); EPerBlock * fwd(wei_e_k_2eor4e_global_desc).GetStride(I0);
...@@ -361,9 +360,8 @@ struct GridwiseConvolutionImplicitGemm_v4_fp16_bfp16_nchw_kcyx_nkhw_lds_double_b ...@@ -361,9 +360,8 @@ struct GridwiseConvolutionImplicitGemm_v4_fp16_bfp16_nchw_kcyx_nkhw_lds_double_b
__syncthreads(); __syncthreads();
// LDS doubel buffer: load next data from device mem // LDS doubel buffer: load next data from device mem
blockwise_in_copy.RunLoadRegisterBuffer(p_in_global, p_in_register_buffer); blockwise_in_copy.RunLoadThreadBuffer(p_in_global, p_in_thread_buffer);
blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_block_on_global, blockwise_wei_copy.RunLoadThreadBuffer(p_wei_block_on_global, p_wei_thread_buffer);
p_wei_register_buffer);
// LDS double buffer: GEMM on current data // LDS double buffer: GEMM on current data
const typename vector_type<Float, EPACK>::MemoryType* p_a_block_vec = const typename vector_type<Float, EPACK>::MemoryType* p_a_block_vec =
...@@ -375,20 +373,20 @@ struct GridwiseConvolutionImplicitGemm_v4_fp16_bfp16_nchw_kcyx_nkhw_lds_double_b ...@@ -375,20 +373,20 @@ struct GridwiseConvolutionImplicitGemm_v4_fp16_bfp16_nchw_kcyx_nkhw_lds_double_b
blockwise_gemm.Run(p_a_block_vec, p_b_block_vec, p_out_thread); blockwise_gemm.Run(p_a_block_vec, p_b_block_vec, p_out_thread);
// LDS double buffer: store next data to LDS // LDS double buffer: store next data to LDS
blockwise_in_copy.RunStoreRegisterBuffer(p_in_register_buffer, p_in_block_next); blockwise_in_copy.RunStoreThreadBuffer(p_in_thread_buffer, p_in_block_next);
blockwise_wei_copy.RunStoreRegisterBuffer(p_wei_register_buffer, p_wei_block_next); blockwise_wei_copy.RunStoreThreadBuffer(p_wei_thread_buffer, p_wei_block_next);
} }
} }
// LDS double buffer: tail // LDS double buffer: tail
{ {
Float p_in_register_buffer[blockwise_in_copy.GetRegisterBufferSize()]; Float p_in_thread_buffer[blockwise_in_copy.GetThreadBufferSize()];
Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()]; Float p_wei_thread_buffer[blockwise_wei_copy.GetThreadBufferSize()];
// even iteration // even iteration
blockwise_in_copy.MoveSlicingWindowOnSourceTensor(I0, Number<EPerBlock>{}, True); blockwise_in_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0, 0, 0, 0>{}, True);
static_if<conv_dir == ImplicitGemmDirection::BackwardWeight>{}([&](auto fwd) { static_if<ConvDirection == ConvolutionDirection::BackwardWeight>{}([&](auto fwd) {
fwd(blockwise_wei_copy).MoveSrcSlicingWindow(Sequence<EPerBlock, 0, 0>{}, True); fwd(blockwise_wei_copy).MoveSrcSliceWindow(Sequence<EPerBlock, 0, 0>{}, True);
}).Else([&](auto fwd) { }).Else([&](auto fwd) {
p_wei_block_on_global += EPerBlock * fwd(wei_e_k_2eor4e_global_desc).GetStride(I0); p_wei_block_on_global += EPerBlock * fwd(wei_e_k_2eor4e_global_desc).GetStride(I0);
}); });
...@@ -396,8 +394,8 @@ struct GridwiseConvolutionImplicitGemm_v4_fp16_bfp16_nchw_kcyx_nkhw_lds_double_b ...@@ -396,8 +394,8 @@ struct GridwiseConvolutionImplicitGemm_v4_fp16_bfp16_nchw_kcyx_nkhw_lds_double_b
__syncthreads(); __syncthreads();
// LDS doubel buffer: load next data from device mem // LDS doubel buffer: load next data from device mem
blockwise_in_copy.RunLoadRegisterBuffer(p_in_global, p_in_register_buffer); blockwise_in_copy.RunLoadThreadBuffer(p_in_global, p_in_thread_buffer);
blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_block_on_global, p_wei_register_buffer); blockwise_wei_copy.RunLoadThreadBuffer(p_wei_block_on_global, p_wei_thread_buffer);
// LDS double buffer: GEMM on current data // LDS double buffer: GEMM on current data
// Vectorize the pointer to match with how half/bfloat16 datatypes are // Vectorize the pointer to match with how half/bfloat16 datatypes are
...@@ -415,10 +413,10 @@ struct GridwiseConvolutionImplicitGemm_v4_fp16_bfp16_nchw_kcyx_nkhw_lds_double_b ...@@ -415,10 +413,10 @@ struct GridwiseConvolutionImplicitGemm_v4_fp16_bfp16_nchw_kcyx_nkhw_lds_double_b
blockwise_gemm.Run(p_a_block_vec, p_b_block_vec, p_out_thread); blockwise_gemm.Run(p_a_block_vec, p_b_block_vec, p_out_thread);
// LDS double buffer: store next data to LDS // LDS double buffer: store next data to LDS
blockwise_in_copy.RunStoreRegisterBuffer(p_in_register_buffer, blockwise_in_copy.RunStoreThreadBuffer(p_in_thread_buffer,
p_in_block_double + in_block_space); p_in_block_double + in_block_space);
blockwise_wei_copy.RunStoreRegisterBuffer(p_wei_register_buffer, blockwise_wei_copy.RunStoreThreadBuffer(p_wei_thread_buffer,
p_wei_block_double + wei_block_space); p_wei_block_double + wei_block_space);
// odd iteration // odd iteration
__syncthreads(); __syncthreads();
...@@ -479,7 +477,7 @@ struct GridwiseConvolutionImplicitGemm_v4_fp16_bfp16_nchw_kcyx_nkhw_lds_double_b ...@@ -479,7 +477,7 @@ struct GridwiseConvolutionImplicitGemm_v4_fp16_bfp16_nchw_kcyx_nkhw_lds_double_b
out_k_n1_b_n2_global_merged_desc.GetOffsetFromMultiIndex( out_k_n1_b_n2_global_merged_desc.GetOffsetFromMultiIndex(
k_thread_data_on_global, 0, b_thread_data_on_global, 0); k_thread_data_on_global, 0, b_thread_data_on_global, 0);
ThreadwiseGenericTensorSliceCopy_v1r2< ThreadwiseGenericTensorSliceCopy_v1r2_deprecated<
decltype(out_n0_n1_n2_k0_k1_k2_h_w_thread_desc), decltype(out_n0_n1_n2_k0_k1_k2_h_w_thread_desc),
decltype(out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc), decltype(out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc),
decltype(out_n0_n1_n2_k0_k1_k2_h_w_thread_desc.GetLengths()), decltype(out_n0_n1_n2_k0_k1_k2_h_w_thread_desc.GetLengths()),
......
...@@ -2,13 +2,13 @@ ...@@ -2,13 +2,13 @@
#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4_NCHW_KC1x1_NKHW_LDS_DOUBLE_BUFFER_HPP #define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4_NCHW_KC1x1_NKHW_LDS_DOUBLE_BUFFER_HPP
#include "common_header.hpp" #include "common_header.hpp"
#include "ConstantTensorDescriptor.hpp" #include "ConstantTensorDescriptor_deprecated.hpp"
#include "ConstantMergedTensorDescriptor.hpp" #include "ConstantMergedTensorDescriptor_deprecated.hpp"
#include "ConstantMatrixDescriptor.hpp" #include "ConstantMatrixDescriptor.hpp"
#include "blockwise_generic_tensor_slice_copy.hpp" #include "blockwise_generic_tensor_slice_copy_deprecated.hpp"
#include "blockwise_gemm.hpp" #include "blockwise_gemm.hpp"
#include "threadwise_generic_tensor_slice_copy.hpp" #include "threadwise_generic_tensor_slice_copy_deprecated.hpp"
#include "implicitgemm_params.hpp" #include "convolution_common.hpp"
namespace ck { namespace ck {
...@@ -21,7 +21,7 @@ template <index_t GridSize, ...@@ -21,7 +21,7 @@ template <index_t GridSize,
class WeiGlobalDesc, class WeiGlobalDesc,
class OutGlobalDesc, // exchanged outside for backward class OutGlobalDesc, // exchanged outside for backward
class ConvStrides, class ConvStrides,
ImplicitGemmDirection Direction, ConvolutionDirection ConvDirection,
index_t BPerBlock, index_t BPerBlock,
index_t KPerBlock, index_t KPerBlock,
index_t EPerBlock, index_t EPerBlock,
...@@ -56,7 +56,7 @@ struct GridwiseConvolutionImplicitGemm_v4_nchw_kc1x1_nkhw_lds_double_buffer ...@@ -56,7 +56,7 @@ struct GridwiseConvolutionImplicitGemm_v4_nchw_kc1x1_nkhw_lds_double_buffer
const Float* const __restrict__ p_wei_global, const Float* const __restrict__ p_wei_global,
Float* const __restrict__ p_out_global) const Float* const __restrict__ p_out_global) const
{ {
constexpr bool isForward = Direction == ImplicitGemmDirection::ForwardData; constexpr bool isForward = (ConvDirection == ConvolutionDirection::Forward);
// this is a mess // this is a mess
// TODO: find more elegent way of specifying (or calculating) performance parameters // TODO: find more elegent way of specifying (or calculating) performance parameters
...@@ -161,21 +161,20 @@ struct GridwiseConvolutionImplicitGemm_v4_nchw_kc1x1_nkhw_lds_double_buffer ...@@ -161,21 +161,20 @@ struct GridwiseConvolutionImplicitGemm_v4_nchw_kc1x1_nkhw_lds_double_buffer
// input blockwise copy // input blockwise copy
// slice a merged tensor, reorder and copy to a normal tensor // slice a merged tensor, reorder and copy to a normal tensor
// this copy operator already has blockwise offset built-in // this copy operator already has blockwise offset built-in
auto blockwise_in_copy = auto blockwise_in_copy = BlockwiseGenericTensorSliceCopy_v1_deprecated<
BlockwiseGenericTensorSliceCopy_v1<BlockSize, BlockSize,
decltype(in_e_n1_b_n2_global_merged_desc), decltype(in_e_n1_b_n2_global_merged_desc),
decltype(in_e_n1_b_n2_block_desc), decltype(in_e_n1_b_n2_block_desc),
decltype(in_e_n1_b_n2_block_desc.GetLengths()), decltype(in_e_n1_b_n2_block_desc.GetLengths()),
InBlockCopySubLengths_E_N1_B_N2, InBlockCopySubLengths_E_N1_B_N2,
InBlockCopyClusterLengths_E_N1_B_N2, InBlockCopyClusterLengths_E_N1_B_N2,
InBlockCopyThreadClusterArrangeOrder, InBlockCopyThreadClusterArrangeOrder,
InBlockCopySrcAccessOrder, InBlockCopySrcAccessOrder,
InBlockCopyDstAccessOrder, InBlockCopyDstAccessOrder,
2, 2,
3, 3,
InBlockCopySrcDataPerRead_B, InBlockCopySrcDataPerRead_B,
InBlockCopyDstDataPerWrite_N2>( InBlockCopyDstDataPerWrite_N2>({0, 0, b_block_data_on_global, 0}, {0, 0, 0, 0});
{0, 0, b_block_data_on_global, 0}, {0, 0, 0, 0});
// weight tensor // weight tensor
// tensor descriptor in device memory, src of blockwise copy // tensor descriptor in device memory, src of blockwise copy
...@@ -198,19 +197,19 @@ struct GridwiseConvolutionImplicitGemm_v4_nchw_kc1x1_nkhw_lds_double_buffer ...@@ -198,19 +197,19 @@ struct GridwiseConvolutionImplicitGemm_v4_nchw_kc1x1_nkhw_lds_double_buffer
// slice a tensor, and copy it into another tensor // slice a tensor, and copy it into another tensor
// this copy operator already have blockwise offset built-in // this copy operator already have blockwise offset built-in
auto blockwise_wei_copy = auto blockwise_wei_copy =
BlockwiseGenericTensorSliceCopy_v1<BlockSize, BlockwiseGenericTensorSliceCopy_v1_deprecated<BlockSize,
decltype(wei_e_k_global_desc), decltype(wei_e_k_global_desc),
decltype(wei_e_k_block_desc), decltype(wei_e_k_block_desc),
decltype(wei_e_k_block_desc.GetLengths()), decltype(wei_e_k_block_desc.GetLengths()),
WeiBlockCopySubLengths_E_K, WeiBlockCopySubLengths_E_K,
WeiBlockCopyClusterLengths_E_K, WeiBlockCopyClusterLengths_E_K,
WeiBlockCopyThreadClusterArrangeOrder, WeiBlockCopyThreadClusterArrangeOrder,
WeiBlockCopySrcAccessOrder, WeiBlockCopySrcAccessOrder,
WeiBlockCopyDstAccessOrder, WeiBlockCopyDstAccessOrder,
0, 0,
1, 1,
WeiBlockCopySrcDataPerRead_E, WeiBlockCopySrcDataPerRead_E,
WeiBlockCopyDstDataPerWrite_K>( WeiBlockCopyDstDataPerWrite_K>(
{0, k_block_data_on_global}, {0, 0}); {0, k_block_data_on_global}, {0, 0});
// GEMM definition // GEMM definition
...@@ -239,7 +238,6 @@ struct GridwiseConvolutionImplicitGemm_v4_nchw_kc1x1_nkhw_lds_double_buffer ...@@ -239,7 +238,6 @@ struct GridwiseConvolutionImplicitGemm_v4_nchw_kc1x1_nkhw_lds_double_buffer
const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2< const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2<
BlockSize, BlockSize,
1, // EPACK = 1
decltype(a_e_k_block_mtx_desc), decltype(a_e_k_block_mtx_desc),
decltype(b_e_n1bn2_block_mtx_desc), decltype(b_e_n1bn2_block_mtx_desc),
decltype(c_k0k2_n1n2_thread_mtx_desc), decltype(c_k0k2_n1n2_thread_mtx_desc),
...@@ -301,51 +299,50 @@ struct GridwiseConvolutionImplicitGemm_v4_nchw_kc1x1_nkhw_lds_double_buffer ...@@ -301,51 +299,50 @@ struct GridwiseConvolutionImplicitGemm_v4_nchw_kc1x1_nkhw_lds_double_buffer
Float* p_wei_block_next = Float* p_wei_block_next =
even_loop ? p_wei_block_double + wei_block_space : p_wei_block_double; even_loop ? p_wei_block_double + wei_block_space : p_wei_block_double;
Float p_in_register_buffer[blockwise_in_copy.GetRegisterBufferSize()]; Float p_in_thread_buffer[blockwise_in_copy.GetThreadBufferSize()];
Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()]; Float p_wei_thread_buffer[blockwise_wei_copy.GetThreadBufferSize()];
blockwise_in_copy.MoveSrcSlicingWindow(Sequence<EPerBlock, 0, 0, 0>{}, True); blockwise_in_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0, 0, 0>{}, True);
p_wei_block_on_global += EPerBlock * wei_e_k_global_desc.GetStride(I0); p_wei_block_on_global += EPerBlock * wei_e_k_global_desc.GetStride(I0);
__syncthreads(); __syncthreads();
// LDS doubel buffer: load next data from device mem // LDS doubel buffer: load next data from device mem
blockwise_in_copy.RunLoadRegisterBuffer(p_in_global, p_in_register_buffer); blockwise_in_copy.RunLoadThreadBuffer(p_in_global, p_in_thread_buffer);
blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_block_on_global, blockwise_wei_copy.RunLoadThreadBuffer(p_wei_block_on_global, p_wei_thread_buffer);
p_wei_register_buffer);
// LDS double buffer: GEMM on current data // LDS double buffer: GEMM on current data
blockwise_gemm.Run(p_wei_block_now, p_in_block_now, p_out_thread); blockwise_gemm.Run(p_wei_block_now, p_in_block_now, p_out_thread);
// LDS double buffer: store next data to LDS // LDS double buffer: store next data to LDS
blockwise_in_copy.RunStoreRegisterBuffer(p_in_register_buffer, p_in_block_next); blockwise_in_copy.RunStoreThreadBuffer(p_in_thread_buffer, p_in_block_next);
blockwise_wei_copy.RunStoreRegisterBuffer(p_wei_register_buffer, p_wei_block_next); blockwise_wei_copy.RunStoreThreadBuffer(p_wei_thread_buffer, p_wei_block_next);
} }
} }
// LDS double buffer: tail // LDS double buffer: tail
{ {
Float p_in_register_buffer[blockwise_in_copy.GetRegisterBufferSize()]; Float p_in_thread_buffer[blockwise_in_copy.GetThreadBufferSize()];
Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()]; Float p_wei_thread_buffer[blockwise_wei_copy.GetThreadBufferSize()];
// even iteration // even iteration
blockwise_in_copy.MoveSrcSlicingWindow(Sequence<EPerBlock, 0, 0, 0>{}, True); blockwise_in_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0, 0, 0>{}, True);
p_wei_block_on_global += EPerBlock * wei_e_k_global_desc.GetStride(I0); p_wei_block_on_global += EPerBlock * wei_e_k_global_desc.GetStride(I0);
__syncthreads(); __syncthreads();
// LDS doubel buffer: load next data from device mem // LDS doubel buffer: load next data from device mem
blockwise_in_copy.RunLoadRegisterBuffer(p_in_global, p_in_register_buffer); blockwise_in_copy.RunLoadThreadBuffer(p_in_global, p_in_thread_buffer);
blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_block_on_global, p_wei_register_buffer); blockwise_wei_copy.RunLoadThreadBuffer(p_wei_block_on_global, p_wei_thread_buffer);
// LDS double buffer: GEMM on current data // LDS double buffer: GEMM on current data
blockwise_gemm.Run(p_wei_block_double, p_in_block_double, p_out_thread); blockwise_gemm.Run(p_wei_block_double, p_in_block_double, p_out_thread);
// LDS double buffer: store next data to LDS // LDS double buffer: store next data to LDS
blockwise_in_copy.RunStoreRegisterBuffer(p_in_register_buffer, blockwise_in_copy.RunStoreThreadBuffer(p_in_thread_buffer,
p_in_block_double + in_block_space); p_in_block_double + in_block_space);
blockwise_wei_copy.RunStoreRegisterBuffer(p_wei_register_buffer, blockwise_wei_copy.RunStoreThreadBuffer(p_wei_thread_buffer,
p_wei_block_double + wei_block_space); p_wei_block_double + wei_block_space);
// odd iteration // odd iteration
__syncthreads(); __syncthreads();
...@@ -437,7 +434,7 @@ struct GridwiseConvolutionImplicitGemm_v4_nchw_kc1x1_nkhw_lds_double_buffer ...@@ -437,7 +434,7 @@ struct GridwiseConvolutionImplicitGemm_v4_nchw_kc1x1_nkhw_lds_double_buffer
out_k_n1_b_n2_global_merged_desc.GetOffsetFromMultiIndex( out_k_n1_b_n2_global_merged_desc.GetOffsetFromMultiIndex(
k_thread_data_on_global, 0, b_thread_data_on_global, 0); k_thread_data_on_global, 0, b_thread_data_on_global, 0);
ThreadwiseGenericTensorSliceCopy_v1r2< ThreadwiseGenericTensorSliceCopy_v1r2_deprecated<
decltype(out_n0_n1_n2_k0_k1_k2_h_w_thread_desc), decltype(out_n0_n1_n2_k0_k1_k2_h_w_thread_desc),
decltype(out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc), decltype(out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc),
decltype(out_n0_n1_n2_k0_k1_k2_h_w_thread_desc.GetLengths()), decltype(out_n0_n1_n2_k0_k1_k2_h_w_thread_desc.GetLengths()),
......
...@@ -2,12 +2,12 @@ ...@@ -2,12 +2,12 @@
#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R4_FP16_BFP16_NCHW_KCYX_NKHW_HPP_LDS_DOUBLE_BUFFER_HPP #define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R4_FP16_BFP16_NCHW_KCYX_NKHW_HPP_LDS_DOUBLE_BUFFER_HPP
#include "common_header.hpp" #include "common_header.hpp"
#include "ConstantTensorDescriptor.hpp" #include "ConstantTensorDescriptor_deprecated.hpp"
#include "ConstantMergedTensorDescriptor.hpp" #include "ConstantMergedTensorDescriptor_deprecated.hpp"
#include "ConstantMatrixDescriptor.hpp" #include "ConstantMatrixDescriptor.hpp"
#include "blockwise_generic_tensor_slice_copy.hpp" #include "blockwise_generic_tensor_slice_copy_deprecated.hpp"
#include "blockwise_gemm_xdlops.hpp" #include "blockwise_gemm_xdlops.hpp"
#include "threadwise_generic_tensor_slice_copy.hpp" #include "threadwise_generic_tensor_slice_copy_deprecated.hpp"
#include "implicitgemm_params.hpp" #include "implicitgemm_params.hpp"
namespace ck { namespace ck {
...@@ -173,12 +173,10 @@ struct GridwiseConvolutionImplicitGemm_v4r4_xdlops_fp16_bfp16_nchw_kcyx_nkhw_lds ...@@ -173,12 +173,10 @@ struct GridwiseConvolutionImplicitGemm_v4r4_xdlops_fp16_bfp16_nchw_kcyx_nkhw_lds
// input blockwise copy // input blockwise copy
// slice a merged tensor, reorder and copy to a normal tensor // slice a merged tensor, reorder and copy to a normal tensor
// this copy operator already has blockwise offset built-in // this copy operator already has blockwise offset built-in
auto blockwise_in_copy = BlockwiseGenericTensorSliceCopy_v2< auto blockwise_in_copy = BlockwiseGenericTensorSliceCopy_v2_deprecated<
BlockSize, BlockSize,
decltype(in_e_b_global_desc), decltype(in_e_b_global_desc),
decltype(in_e_b_block_desc), decltype(in_e_b_block_desc),
MergedTensorCoordinate<decltype(in_e_b_global_desc)>,
NormalTensorCoordinate<decltype(in_e_b_block_desc)>,
decltype(in_e_b_block_desc.GetLengths()), decltype(in_e_b_block_desc.GetLengths()),
InBlockCopySubLengths_E_B, InBlockCopySubLengths_E_B,
InBlockCopyClusterLengths_E_B, InBlockCopyClusterLengths_E_B,
...@@ -209,12 +207,10 @@ struct GridwiseConvolutionImplicitGemm_v4r4_xdlops_fp16_bfp16_nchw_kcyx_nkhw_lds ...@@ -209,12 +207,10 @@ struct GridwiseConvolutionImplicitGemm_v4r4_xdlops_fp16_bfp16_nchw_kcyx_nkhw_lds
// operator for blockwise copy of weight into LDS // operator for blockwise copy of weight into LDS
// slice a tensor, and copy it into another tensor // slice a tensor, and copy it into another tensor
// this copy operator already have blockwise offset built-in // this copy operator already have blockwise offset built-in
auto blockwise_wei_copy = BlockwiseGenericTensorSliceCopy_v2< auto blockwise_wei_copy = BlockwiseGenericTensorSliceCopy_v2_deprecated<
BlockSize, BlockSize,
decltype(wei_e_k_global_desc), decltype(wei_e_k_global_desc),
decltype(wei_e_k_block_desc), decltype(wei_e_k_block_desc),
NormalTensorCoordinate<decltype(wei_e_k_global_desc)>,
NormalTensorCoordinate<decltype(wei_e_k_block_desc)>,
decltype(wei_e_k_block_desc.GetLengths()), decltype(wei_e_k_block_desc.GetLengths()),
WeiBlockCopySubLengths_E_K, WeiBlockCopySubLengths_E_K,
WeiBlockCopyClusterLengths_E_K, WeiBlockCopyClusterLengths_E_K,
...@@ -300,22 +296,21 @@ struct GridwiseConvolutionImplicitGemm_v4r4_xdlops_fp16_bfp16_nchw_kcyx_nkhw_lds ...@@ -300,22 +296,21 @@ struct GridwiseConvolutionImplicitGemm_v4r4_xdlops_fp16_bfp16_nchw_kcyx_nkhw_lds
Float* p_wei_block_next = Float* p_wei_block_next =
even_loop ? p_wei_block_double + wei_block_space : p_wei_block_double; even_loop ? p_wei_block_double + wei_block_space : p_wei_block_double;
Float p_in_register_buffer[blockwise_in_copy.GetRegisterBufferSize()]; Float p_in_thread_buffer[blockwise_in_copy.GetThreadBufferSize()];
Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()]; Float p_wei_thread_buffer[blockwise_wei_copy.GetThreadBufferSize()];
blockwise_in_copy.MoveSrcSlicingWindow(Sequence<EPerBlock, 0, 0>{}, True); blockwise_in_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0, 0>{}, True);
static_if<conv_dir == ImplicitGemmDirection::BackwardWeight>{}([&](auto ) { static_if<conv_dir == ImplicitGemmDirection::BackwardWeight>{}([&](auto ) {
blockwise_wei_copy.MoveSrcSlicingWindow(Sequence<EPerBlock, 0, 0>{}, True); blockwise_wei_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0, 0>{}, True);
}).Else([&](auto fwd) { }).Else([&](auto ) {
p_wei_block_on_global += EPerBlock * fwd(wei_e_k_global_desc).GetStride(I0); p_wei_block_on_global += EPerBlock * wei_e_k_global_desc.GetStride(I0);
}); });
__syncthreads(); __syncthreads();
// LDS doubel buffer: load next data from device mem // LDS doubel buffer: load next data from device mem
blockwise_in_copy.RunLoadRegisterBuffer(p_in_global, p_in_register_buffer); blockwise_in_copy.RunLoadThreadBuffer(p_in_global, p_in_thread_buffer);
blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_block_on_global, blockwise_wei_copy.RunLoadThreadBuffer(p_wei_block_on_global, p_wei_thread_buffer);
p_wei_register_buffer);
// LDS double buffer: GEMM on current data // LDS double buffer: GEMM on current data
const typename vector_type<Float, EPack>::MemoryType* p_a_block_vec = const typename vector_type<Float, EPack>::MemoryType* p_a_block_vec =
...@@ -327,29 +322,29 @@ struct GridwiseConvolutionImplicitGemm_v4r4_xdlops_fp16_bfp16_nchw_kcyx_nkhw_lds ...@@ -327,29 +322,29 @@ struct GridwiseConvolutionImplicitGemm_v4r4_xdlops_fp16_bfp16_nchw_kcyx_nkhw_lds
blockwise_gemm.Run(p_a_block_vec, p_b_block_vec, p_out_thread); blockwise_gemm.Run(p_a_block_vec, p_b_block_vec, p_out_thread);
// LDS double buffer: store next data to LDS // LDS double buffer: store next data to LDS
blockwise_in_copy.RunStoreRegisterBuffer(p_in_register_buffer, p_in_block_next); blockwise_in_copy.RunStoreThreadBuffer(p_in_thread_buffer, p_in_block_next);
blockwise_wei_copy.RunStoreRegisterBuffer(p_wei_register_buffer, p_wei_block_next); blockwise_wei_copy.RunStoreThreadBuffer(p_wei_thread_buffer, p_wei_block_next);
} }
} }
// LDS double buffer: tail // LDS double buffer: tail
{ {
Float p_in_register_buffer[blockwise_in_copy.GetRegisterBufferSize()]; Float p_in_thread_buffer[blockwise_in_copy.GetThreadBufferSize()];
Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()]; Float p_wei_thread_buffer[blockwise_wei_copy.GetThreadBufferSize()];
// even iteration // even iteration
blockwise_in_copy.MoveSrcSlicingWindow(Sequence<EPerBlock, 0, 0>{}, True); blockwise_in_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0, 0>{}, True);
static_if<conv_dir == ImplicitGemmDirection::BackwardWeight>{}([&](auto ) { static_if<conv_dir == ImplicitGemmDirection::BackwardWeight>{}([&](auto ) {
blockwise_wei_copy.MoveSrcSlicingWindow(Sequence<EPerBlock, 0, 0>{}, True); blockwise_wei_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0, 0>{}, True);
}).Else([&](auto fwd) { }).Else([&](auto ) {
p_wei_block_on_global += EPerBlock * fwd(wei_e_k_global_desc).GetStride(I0); p_wei_block_on_global += EPerBlock * wei_e_k_global_desc.GetStride(I0);
}); });
__syncthreads(); __syncthreads();
// LDS doubel buffer: load next data from device mem // LDS doubel buffer: load next data from device mem
blockwise_in_copy.RunLoadRegisterBuffer(p_in_global, p_in_register_buffer); blockwise_in_copy.RunLoadThreadBuffer(p_in_global, p_in_thread_buffer);
blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_block_on_global, p_wei_register_buffer); blockwise_wei_copy.RunLoadThreadBuffer(p_wei_block_on_global, p_wei_thread_buffer);
// LDS double buffer: GEMM on current data // LDS double buffer: GEMM on current data
// Vectorize the pointer to match with how half/bfloat16 datatypes are // Vectorize the pointer to match with how half/bfloat16 datatypes are
...@@ -368,10 +363,10 @@ struct GridwiseConvolutionImplicitGemm_v4r4_xdlops_fp16_bfp16_nchw_kcyx_nkhw_lds ...@@ -368,10 +363,10 @@ struct GridwiseConvolutionImplicitGemm_v4r4_xdlops_fp16_bfp16_nchw_kcyx_nkhw_lds
blockwise_gemm.Run(p_a_block_vec, p_b_block_vec, p_out_thread); blockwise_gemm.Run(p_a_block_vec, p_b_block_vec, p_out_thread);
// LDS double buffer: store next data to LDS // LDS double buffer: store next data to LDS
blockwise_in_copy.RunStoreRegisterBuffer(p_in_register_buffer, blockwise_in_copy.RunStoreThreadBuffer(p_in_thread_buffer,
p_in_block_double + in_block_space); p_in_block_double + in_block_space);
blockwise_wei_copy.RunStoreRegisterBuffer(p_wei_register_buffer, blockwise_wei_copy.RunStoreThreadBuffer(p_wei_thread_buffer,
p_wei_block_double + wei_block_space); p_wei_block_double + wei_block_space);
// odd iteration // odd iteration
__syncthreads(); __syncthreads();
...@@ -426,11 +421,9 @@ struct GridwiseConvolutionImplicitGemm_v4r4_xdlops_fp16_bfp16_nchw_kcyx_nkhw_lds ...@@ -426,11 +421,9 @@ struct GridwiseConvolutionImplicitGemm_v4r4_xdlops_fp16_bfp16_nchw_kcyx_nkhw_lds
const index_t b_thread_data_on_global = const index_t b_thread_data_on_global =
b_block_data_on_global + c_thread_mtx_on_block.col; b_block_data_on_global + c_thread_mtx_on_block.col;
auto threadwise_out_copy = ThreadwiseGenericTensorSliceCopy_v2r1< auto threadwise_out_copy = ThreadwiseGenericTensorSliceCopy_v2r1_deprecated<
decltype(out_k0_k1_k2_b_thread_desc), decltype(out_k0_k1_k2_b_thread_desc),
decltype(out_k0_k1_k2_b_global_desc), decltype(out_k0_k1_k2_b_global_desc),
NormalTensorCoordinate<decltype(out_k0_k1_k2_b_thread_desc)>,
MergedTensorCoordinate<decltype(out_k0_k1_k2_b_global_desc)>,
OutThreadCopySliceLengths, OutThreadCopySliceLengths,
arithmetic_sequence_gen<0, 4, 1>::type, arithmetic_sequence_gen<0, 4, 1>::type,
arithmetic_sequence_gen<0, 4, 1>::type, arithmetic_sequence_gen<0, 4, 1>::type,
......
...@@ -2,12 +2,12 @@ ...@@ -2,12 +2,12 @@
#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R4_NCHW_KC1X1_NKHW_HPP_LDS_DOUBLE_BUFFER_HPP #define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R4_NCHW_KC1X1_NKHW_HPP_LDS_DOUBLE_BUFFER_HPP
#include "common_header.hpp" #include "common_header.hpp"
#include "ConstantTensorDescriptor.hpp" #include "ConstantTensorDescriptor_deprecated.hpp"
#include "ConstantMergedTensorDescriptor.hpp" #include "ConstantMergedTensorDescriptor_deprecated.hpp"
#include "ConstantMatrixDescriptor.hpp" #include "ConstantMatrixDescriptor.hpp"
#include "blockwise_generic_tensor_slice_copy.hpp" #include "blockwise_generic_tensor_slice_copy_deprecated.hpp"
#include "blockwise_gemm_xdlops.hpp" #include "blockwise_gemm_xdlops.hpp"
#include "threadwise_generic_tensor_slice_copy.hpp" #include "threadwise_generic_tensor_slice_copy_deprecated.hpp"
#include "implicitgemm_params.hpp" #include "implicitgemm_params.hpp"
namespace ck { namespace ck {
...@@ -131,21 +131,19 @@ struct GridwiseConvolutionImplicitGemm_v4r4_xdlops_nchw_kc1x1_nkhw_lds_double_bu ...@@ -131,21 +131,19 @@ struct GridwiseConvolutionImplicitGemm_v4r4_xdlops_nchw_kc1x1_nkhw_lds_double_bu
// slice a merged tensor, reorder and copy to a normal tensor // slice a merged tensor, reorder and copy to a normal tensor
// this copy operator already has blockwise offset built-in // this copy operator already has blockwise offset built-in
auto blockwise_in_copy = auto blockwise_in_copy =
BlockwiseGenericTensorSliceCopy_v2<BlockSize, BlockwiseGenericTensorSliceCopy_v2_deprecated<BlockSize,
decltype(in_e_b_global_desc), decltype(in_e_b_global_desc),
decltype(in_e_b_block_desc), decltype(in_e_b_block_desc),
MergedTensorCoordinate<decltype(in_e_b_global_desc)>, decltype(in_e_b_block_desc.GetLengths()),
NormalTensorCoordinate<decltype(in_e_b_block_desc)>, InBlockCopySubLengths_E_B,
decltype(in_e_b_block_desc.GetLengths()), InBlockCopyClusterLengths_E_B,
InBlockCopySubLengths_E_B, InBlockCopyThreadClusterArrangeOrder,
InBlockCopyClusterLengths_E_B, InBlockCopySrcAccessOrder,
InBlockCopyThreadClusterArrangeOrder, InBlockCopyDstAccessOrder,
InBlockCopySrcAccessOrder, 1,
InBlockCopyDstAccessOrder, 1,
1, InBlockCopyDataPerAccess_B,
1, InBlockCopyDataPerAccess_B>(
InBlockCopyDataPerAccess_B,
InBlockCopyDataPerAccess_B>(
{0, b_block_data_on_global}, {0, 0}); {0, b_block_data_on_global}, {0, 0});
// weight tensor // weight tensor
...@@ -167,22 +165,21 @@ struct GridwiseConvolutionImplicitGemm_v4r4_xdlops_nchw_kc1x1_nkhw_lds_double_bu ...@@ -167,22 +165,21 @@ struct GridwiseConvolutionImplicitGemm_v4r4_xdlops_nchw_kc1x1_nkhw_lds_double_bu
// operator for blockwise copy of weight into LDS // operator for blockwise copy of weight into LDS
// slice a tensor, and copy it into another tensor // slice a tensor, and copy it into another tensor
// this copy operator already have blockwise offset built-in // this copy operator already have blockwise offset built-in
auto blockwise_wei_copy = BlockwiseGenericTensorSliceCopy_v2< auto blockwise_wei_copy =
BlockSize, BlockwiseGenericTensorSliceCopy_v2_deprecated<BlockSize,
decltype(wei_e_k_global_desc), decltype(wei_e_k_global_desc),
decltype(wei_e_k_block_desc), decltype(wei_e_k_block_desc),
NormalTensorCoordinate<decltype(wei_e_k_global_desc)>, decltype(wei_e_k_block_desc.GetLengths()),
NormalTensorCoordinate<decltype(wei_e_k_block_desc)>, WeiBlockCopySubLengths_E_K,
decltype(wei_e_k_block_desc.GetLengths()), WeiBlockCopyClusterLengths_E_K,
WeiBlockCopySubLengths_E_K, WeiBlockCopyThreadClusterArrangeOrder,
WeiBlockCopyClusterLengths_E_K, WeiBlockCopySrcAccessOrder,
WeiBlockCopyThreadClusterArrangeOrder, WeiBlockCopyDstAccessOrder,
WeiBlockCopySrcAccessOrder, 0,
WeiBlockCopyDstAccessOrder, 1,
0, WeiBlockCopySrcDataPerRead_E,
1, WeiBlockCopyDstDataPerWrite_K>(
WeiBlockCopySrcDataPerRead_E, {0, k_block_data_on_global}, {0, 0});
WeiBlockCopyDstDataPerWrite_K>({0, k_block_data_on_global}, {0, 0});
// GEMM definition // GEMM definition
constexpr auto a_e_k_block_mtx_desc = make_ConstantMatrixDescriptor(wei_e_k_block_desc); constexpr auto a_e_k_block_mtx_desc = make_ConstantMatrixDescriptor(wei_e_k_block_desc);
...@@ -253,51 +250,50 @@ struct GridwiseConvolutionImplicitGemm_v4r4_xdlops_nchw_kc1x1_nkhw_lds_double_bu ...@@ -253,51 +250,50 @@ struct GridwiseConvolutionImplicitGemm_v4r4_xdlops_nchw_kc1x1_nkhw_lds_double_bu
Float* p_wei_block_next = Float* p_wei_block_next =
even_loop ? p_wei_block_double + wei_block_space : p_wei_block_double; even_loop ? p_wei_block_double + wei_block_space : p_wei_block_double;
Float p_in_register_buffer[blockwise_in_copy.GetRegisterBufferSize()]; Float p_in_thread_buffer[blockwise_in_copy.GetThreadBufferSize()];
Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()]; Float p_wei_thread_buffer[blockwise_wei_copy.GetThreadBufferSize()];
blockwise_in_copy.MoveSrcSlicingWindow(Sequence<EPerBlock, 0>{}, True); blockwise_in_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0>{}, True);
p_wei_block_on_global += EPerBlock * wei_e_k_global_desc.GetStrides()[0]; p_wei_block_on_global += EPerBlock * wei_e_k_global_desc.GetStrides()[0];
__syncthreads(); __syncthreads();
// LDS doubel buffer: load next data from device mem // LDS doubel buffer: load next data from device mem
blockwise_in_copy.RunLoadRegisterBuffer(p_in_global, p_in_register_buffer); blockwise_in_copy.RunLoadThreadBuffer(p_in_global, p_in_thread_buffer);
blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_block_on_global, blockwise_wei_copy.RunLoadThreadBuffer(p_wei_block_on_global, p_wei_thread_buffer);
p_wei_register_buffer);
// LDS double buffer: GEMM on current data // LDS double buffer: GEMM on current data
blockwise_gemm.Run(p_wei_block_now, p_in_block_now, p_out_thread); blockwise_gemm.Run(p_wei_block_now, p_in_block_now, p_out_thread);
// LDS double buffer: store next data to LDS // LDS double buffer: store next data to LDS
blockwise_in_copy.RunStoreRegisterBuffer(p_in_register_buffer, p_in_block_next); blockwise_in_copy.RunStoreThreadBuffer(p_in_thread_buffer, p_in_block_next);
blockwise_wei_copy.RunStoreRegisterBuffer(p_wei_register_buffer, p_wei_block_next); blockwise_wei_copy.RunStoreThreadBuffer(p_wei_thread_buffer, p_wei_block_next);
} }
} }
// LDS double buffer: tail // LDS double buffer: tail
{ {
Float p_in_register_buffer[blockwise_in_copy.GetRegisterBufferSize()]; Float p_in_thread_buffer[blockwise_in_copy.GetThreadBufferSize()];
Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()]; Float p_wei_thread_buffer[blockwise_wei_copy.GetThreadBufferSize()];
// even iteration // even iteration
blockwise_in_copy.MoveSrcSlicingWindow(Sequence<EPerBlock, 0>{}, True); blockwise_in_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0>{}, True);
p_wei_block_on_global += EPerBlock * wei_e_k_global_desc.GetStrides()[0]; p_wei_block_on_global += EPerBlock * wei_e_k_global_desc.GetStrides()[0];
__syncthreads(); __syncthreads();
// LDS doubel buffer: load next data from device mem // LDS doubel buffer: load next data from device mem
blockwise_in_copy.RunLoadRegisterBuffer(p_in_global, p_in_register_buffer); blockwise_in_copy.RunLoadThreadBuffer(p_in_global, p_in_thread_buffer);
blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_block_on_global, p_wei_register_buffer); blockwise_wei_copy.RunLoadThreadBuffer(p_wei_block_on_global, p_wei_thread_buffer);
// LDS double buffer: GEMM on current data // LDS double buffer: GEMM on current data
blockwise_gemm.Run(p_wei_block_double, p_in_block_double, p_out_thread); blockwise_gemm.Run(p_wei_block_double, p_in_block_double, p_out_thread);
// LDS double buffer: store next data to LDS // LDS double buffer: store next data to LDS
blockwise_in_copy.RunStoreRegisterBuffer(p_in_register_buffer, blockwise_in_copy.RunStoreThreadBuffer(p_in_thread_buffer,
p_in_block_double + in_block_space); p_in_block_double + in_block_space);
blockwise_wei_copy.RunStoreRegisterBuffer(p_wei_register_buffer, blockwise_wei_copy.RunStoreThreadBuffer(p_wei_thread_buffer,
p_wei_block_double + wei_block_space); p_wei_block_double + wei_block_space);
// odd iteration // odd iteration
__syncthreads(); __syncthreads();
...@@ -373,11 +369,9 @@ struct GridwiseConvolutionImplicitGemm_v4r4_xdlops_nchw_kc1x1_nkhw_lds_double_bu ...@@ -373,11 +369,9 @@ struct GridwiseConvolutionImplicitGemm_v4r4_xdlops_nchw_kc1x1_nkhw_lds_double_bu
const index_t b_thread_data_on_global = const index_t b_thread_data_on_global =
b_block_data_on_global + c_thread_mtx_on_block.col; b_block_data_on_global + c_thread_mtx_on_block.col;
auto threadwise_out_copy = ThreadwiseGenericTensorSliceCopy_v2r1< auto threadwise_out_copy = ThreadwiseGenericTensorSliceCopy_v2r1_deprecated<
decltype(out_k0_k1_k2_b_thread_desc), decltype(out_k0_k1_k2_b_thread_desc),
decltype(out_k0_k1_k2_b_global_desc), decltype(out_k0_k1_k2_b_global_desc),
NormalTensorCoordinate<decltype(out_k0_k1_k2_b_thread_desc)>,
MergedTensorCoordinate<decltype(out_k0_k1_k2_b_global_desc)>,
OutThreadCopySliceLengths, OutThreadCopySliceLengths,
arithmetic_sequence_gen<0, 4, 1>::type, arithmetic_sequence_gen<0, 4, 1>::type,
arithmetic_sequence_gen<0, 4, 1>::type, arithmetic_sequence_gen<0, 4, 1>::type,
......
...@@ -2,12 +2,12 @@ ...@@ -2,12 +2,12 @@
#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R4_XDLOPS_NCHW_KCYX_NKHW_HPP_LDS_DOUBLE_BUFFER_HPP #define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R4_XDLOPS_NCHW_KCYX_NKHW_HPP_LDS_DOUBLE_BUFFER_HPP
#include "common_header.hpp" #include "common_header.hpp"
#include "ConstantTensorDescriptor.hpp" #include "ConstantTensorDescriptor_deprecated.hpp"
#include "ConstantMergedTensorDescriptor.hpp" #include "ConstantMergedTensorDescriptor_deprecated.hpp"
#include "ConstantMatrixDescriptor.hpp" #include "ConstantMatrixDescriptor.hpp"
#include "blockwise_generic_tensor_slice_copy.hpp" #include "blockwise_generic_tensor_slice_copy_deprecated.hpp"
#include "blockwise_gemm_xdlops.hpp" #include "blockwise_gemm_xdlops.hpp"
#include "threadwise_generic_tensor_slice_copy.hpp" #include "threadwise_generic_tensor_slice_copy_deprecated.hpp"
#include "implicitgemm_params.hpp" #include "implicitgemm_params.hpp"
namespace ck { namespace ck {
...@@ -80,8 +80,6 @@ struct GridwiseConvolutionImplicitGemm_v4r4_xdlops_nchw_kcyx_nkhw_lds_double_buf ...@@ -80,8 +80,6 @@ struct GridwiseConvolutionImplicitGemm_v4r4_xdlops_nchw_kcyx_nkhw_lds_double_buf
const Float* const __restrict__ p_wei_global, const Float* const __restrict__ p_wei_global,
Float* const __restrict__ p_out_global) const Float* const __restrict__ p_out_global) const
{ {
if(blockIdx.x*blockDim.x + threadIdx.x == 0)
printf("conv dir %d",conv_dir);
constexpr auto I0 = Number<0>{}; constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{}; constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{}; constexpr auto I2 = Number<2>{};
...@@ -162,21 +160,19 @@ struct GridwiseConvolutionImplicitGemm_v4r4_xdlops_nchw_kcyx_nkhw_lds_double_buf ...@@ -162,21 +160,19 @@ struct GridwiseConvolutionImplicitGemm_v4r4_xdlops_nchw_kcyx_nkhw_lds_double_buf
// slice a merged tensor, reorder and copy to a normal tensor // slice a merged tensor, reorder and copy to a normal tensor
// this copy operator already has blockwise offset built-in // this copy operator already has blockwise offset built-in
auto blockwise_in_copy = auto blockwise_in_copy =
BlockwiseGenericTensorSliceCopy_v2<BlockSize, BlockwiseGenericTensorSliceCopy_v2_deprecated<BlockSize,
decltype(in_e_b_global_desc), decltype(in_e_b_global_desc),
decltype(in_e_b_block_desc), decltype(in_e_b_block_desc),
MergedTensorCoordinate<decltype(in_e_b_global_desc)>, decltype(in_e_b_block_desc.GetLengths()),
NormalTensorCoordinate<decltype(in_e_b_block_desc)>, InBlockCopySubLengths_E_B,
decltype(in_e_b_block_desc.GetLengths()), InBlockCopyClusterLengths_E_B,
InBlockCopySubLengths_E_B, InBlockCopyThreadClusterArrangeOrder,
InBlockCopyClusterLengths_E_B, InBlockCopySrcAccessOrder,
InBlockCopyThreadClusterArrangeOrder, InBlockCopyDstAccessOrder,
InBlockCopySrcAccessOrder, 1,
InBlockCopyDstAccessOrder, 1,
1, InBlockCopyDataPerAccess_B,
1, InBlockCopyDataPerAccess_B>(
InBlockCopyDataPerAccess_B,
InBlockCopyDataPerAccess_B>(
{0, b_block_data_on_global}, {0, 0}); {0, b_block_data_on_global}, {0, 0});
// weight tensor // weight tensor
...@@ -185,7 +181,7 @@ struct GridwiseConvolutionImplicitGemm_v4r4_xdlops_nchw_kcyx_nkhw_lds_double_buf ...@@ -185,7 +181,7 @@ struct GridwiseConvolutionImplicitGemm_v4r4_xdlops_nchw_kcyx_nkhw_lds_double_buf
make_WeiDesc_Xdlops<conv_dir, decltype(wei_k_c_y_x_global_desc)>{}.get( make_WeiDesc_Xdlops<conv_dir, decltype(wei_k_c_y_x_global_desc)>{}.get(
wei_k_c_y_x_global_desc); wei_k_c_y_x_global_desc);
// tensor descriptor in LDS, dst of blockwise copy // tensor descriptor in LDS, dst of blockwise copy
// be careful of LDS alignment // be careful of LDS alignment
constexpr auto wei_e_k_block_desc = make_ConstantTensorDescriptor_aligned( constexpr auto wei_e_k_block_desc = make_ConstantTensorDescriptor_aligned(
Sequence<EPerBlock, KPerBlock>{}, Sequence<EPerBlock, KPerBlock>{},
...@@ -194,22 +190,21 @@ struct GridwiseConvolutionImplicitGemm_v4r4_xdlops_nchw_kcyx_nkhw_lds_double_buf ...@@ -194,22 +190,21 @@ struct GridwiseConvolutionImplicitGemm_v4r4_xdlops_nchw_kcyx_nkhw_lds_double_buf
// operator for blockwise copy of weight into LDS // operator for blockwise copy of weight into LDS
// slice a tensor, and copy it into another tensor // slice a tensor, and copy it into another tensor
// this copy operator already have blockwise offset built-in // this copy operator already have blockwise offset built-in
auto blockwise_wei_copy = BlockwiseGenericTensorSliceCopy_v2< auto blockwise_wei_copy =
BlockSize, BlockwiseGenericTensorSliceCopy_v2_deprecated<BlockSize,
decltype(wei_e_k_global_desc), decltype(wei_e_k_global_desc),
decltype(wei_e_k_block_desc), decltype(wei_e_k_block_desc),
NormalTensorCoordinate<decltype(wei_e_k_global_desc)>, decltype(wei_e_k_block_desc.GetLengths()),
NormalTensorCoordinate<decltype(wei_e_k_block_desc)>, WeiBlockCopySubLengths_E_K,
decltype(wei_e_k_block_desc.GetLengths()), WeiBlockCopyClusterLengths_E_K,
WeiBlockCopySubLengths_E_K, WeiBlockCopyThreadClusterArrangeOrder,
WeiBlockCopyClusterLengths_E_K, WeiBlockCopySrcAccessOrder,
WeiBlockCopyThreadClusterArrangeOrder, WeiBlockCopyDstAccessOrder,
WeiBlockCopySrcAccessOrder, 0,
WeiBlockCopyDstAccessOrder, 1,
0, WeiBlockCopySrcDataPerRead_E,
1, WeiBlockCopyDstDataPerWrite_K>(
WeiBlockCopySrcDataPerRead_E, {0, k_block_data_on_global}, {0, 0});
WeiBlockCopyDstDataPerWrite_K>({0, k_block_data_on_global}, {0, 0});
// GEMM definition // GEMM definition
constexpr auto a_e_k_block_mtx_desc = make_ConstantMatrixDescriptor(wei_e_k_block_desc); constexpr auto a_e_k_block_mtx_desc = make_ConstantMatrixDescriptor(wei_e_k_block_desc);
...@@ -261,6 +256,7 @@ struct GridwiseConvolutionImplicitGemm_v4r4_xdlops_nchw_kcyx_nkhw_lds_double_buf ...@@ -261,6 +256,7 @@ struct GridwiseConvolutionImplicitGemm_v4r4_xdlops_nchw_kcyx_nkhw_lds_double_buf
blockwise_wei_copy.Run(p_wei_global, p_wei_block_double); blockwise_wei_copy.Run(p_wei_global, p_wei_block_double);
} }
#if 1
// LDS double buffer: main body // LDS double buffer: main body
for(index_t e_block_data_begin = 0; e_block_data_begin + 2 * EPerBlock < E; for(index_t e_block_data_begin = 0; e_block_data_begin + 2 * EPerBlock < E;
e_block_data_begin += 2 * EPerBlock) e_block_data_begin += 2 * EPerBlock)
...@@ -280,59 +276,51 @@ struct GridwiseConvolutionImplicitGemm_v4r4_xdlops_nchw_kcyx_nkhw_lds_double_buf ...@@ -280,59 +276,51 @@ struct GridwiseConvolutionImplicitGemm_v4r4_xdlops_nchw_kcyx_nkhw_lds_double_buf
Float* p_wei_block_next = Float* p_wei_block_next =
even_loop ? p_wei_block_double + wei_block_space : p_wei_block_double; even_loop ? p_wei_block_double + wei_block_space : p_wei_block_double;
Float p_in_register_buffer[blockwise_in_copy.GetRegisterBufferSize()]; Float p_in_thread_buffer[blockwise_in_copy.GetThreadBufferSize()];
Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()]; Float p_wei_thread_buffer[blockwise_wei_copy.GetThreadBufferSize()];
blockwise_in_copy.MoveSrcSlicingWindow(Sequence<EPerBlock, 0>{}, True); blockwise_in_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0>{}, True);
static_if<conv_dir == ImplicitGemmDirection::BackwardWeight>{}([&](auto ) { blockwise_wei_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0>{}, True);
blockwise_wei_copy.MoveSrcSlicingWindow(Sequence<EPerBlock, 0>{}, True);
}).Else([&](auto fwd) {
p_wei_block_on_global += EPerBlock * fwd(wei_e_k_global_desc).GetStride(I0);
});
__syncthreads(); __syncthreads();
// LDS doubel buffer: load next data from device mem // LDS doubel buffer: load next data from device mem
blockwise_in_copy.RunLoadRegisterBuffer(p_in_global, p_in_register_buffer); blockwise_in_copy.RunLoadThreadBuffer(p_in_global, p_in_thread_buffer);
blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_block_on_global, blockwise_wei_copy.RunLoadThreadBuffer(p_wei_block_on_global, p_wei_thread_buffer);
p_wei_register_buffer);
// LDS double buffer: GEMM on current data // LDS double buffer: GEMM on current data
blockwise_gemm.Run(p_wei_block_now, p_in_block_now, p_out_thread); blockwise_gemm.Run(p_wei_block_now, p_in_block_now, p_out_thread);
// LDS double buffer: store next data to LDS // LDS double buffer: store next data to LDS
blockwise_in_copy.RunStoreRegisterBuffer(p_in_register_buffer, p_in_block_next); blockwise_in_copy.RunStoreThreadBuffer(p_in_thread_buffer, p_in_block_next);
blockwise_wei_copy.RunStoreRegisterBuffer(p_wei_register_buffer, p_wei_block_next); blockwise_wei_copy.RunStoreThreadBuffer(p_wei_thread_buffer, p_wei_block_next);
} }
} }
#endif
// LDS double buffer: tail // LDS double buffer: tail
{ {
Float p_in_register_buffer[blockwise_in_copy.GetRegisterBufferSize()]; Float p_in_thread_buffer[blockwise_in_copy.GetThreadBufferSize()];
Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()]; Float p_wei_thread_buffer[blockwise_wei_copy.GetThreadBufferSize()];
// even iteration // even iteration
blockwise_in_copy.MoveSrcSlicingWindow(Sequence<EPerBlock, 0>{}, True); blockwise_in_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0>{}, True);
static_if<conv_dir == ImplicitGemmDirection::BackwardWeight>{}([&](auto ) { blockwise_wei_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0>{}, True);
blockwise_wei_copy.MoveSrcSlicingWindow(Sequence<EPerBlock, 0>{}, True);
}).Else([&](auto fwd) {
p_wei_block_on_global += EPerBlock * fwd(wei_e_k_global_desc).GetStride(I0);
});
__syncthreads(); __syncthreads();
// LDS doubel buffer: load next data from device mem // LDS doubel buffer: load next data from device mem
blockwise_in_copy.RunLoadRegisterBuffer(p_in_global, p_in_register_buffer); blockwise_in_copy.RunLoadThreadBuffer(p_in_global, p_in_thread_buffer);
blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_block_on_global, p_wei_register_buffer); blockwise_wei_copy.RunLoadThreadBuffer(p_wei_block_on_global, p_wei_thread_buffer);
// LDS double buffer: GEMM on current data // LDS double buffer: GEMM on current data
blockwise_gemm.Run(p_wei_block_double, p_in_block_double, p_out_thread); blockwise_gemm.Run(p_wei_block_double, p_in_block_double, p_out_thread);
// LDS double buffer: store next data to LDS // LDS double buffer: store next data to LDS
blockwise_in_copy.RunStoreRegisterBuffer(p_in_register_buffer, blockwise_in_copy.RunStoreThreadBuffer(p_in_thread_buffer,
p_in_block_double + in_block_space); p_in_block_double + in_block_space);
blockwise_wei_copy.RunStoreRegisterBuffer(p_wei_register_buffer, blockwise_wei_copy.RunStoreThreadBuffer(p_wei_thread_buffer,
p_wei_block_double + wei_block_space); p_wei_block_double + wei_block_space);
// odd iteration // odd iteration
__syncthreads(); __syncthreads();
...@@ -384,11 +372,9 @@ struct GridwiseConvolutionImplicitGemm_v4r4_xdlops_nchw_kcyx_nkhw_lds_double_buf ...@@ -384,11 +372,9 @@ struct GridwiseConvolutionImplicitGemm_v4r4_xdlops_nchw_kcyx_nkhw_lds_double_buf
const index_t b_thread_data_on_global = const index_t b_thread_data_on_global =
b_block_data_on_global + c_thread_mtx_on_block.col; b_block_data_on_global + c_thread_mtx_on_block.col;
auto threadwise_out_copy = ThreadwiseGenericTensorSliceCopy_v2r1< auto threadwise_out_copy = ThreadwiseGenericTensorSliceCopy_v2r1_deprecated<
decltype(out_k0_k1_k2_b_thread_desc), decltype(out_k0_k1_k2_b_thread_desc),
decltype(out_k0_k1_k2_b_global_desc), decltype(out_k0_k1_k2_b_global_desc),
NormalTensorCoordinate<decltype(out_k0_k1_k2_b_thread_desc)>,
MergedTensorCoordinate<decltype(out_k0_k1_k2_b_global_desc)>,
OutThreadCopySliceLengths, OutThreadCopySliceLengths,
arithmetic_sequence_gen<0, 4, 1>::type, arithmetic_sequence_gen<0, 4, 1>::type,
arithmetic_sequence_gen<0, 4, 1>::type, arithmetic_sequence_gen<0, 4, 1>::type,
......
...@@ -2,7 +2,8 @@ ...@@ -2,7 +2,8 @@
#define CK_CONSTANT_MATRIX_DESCRIPTOR_HPP #define CK_CONSTANT_MATRIX_DESCRIPTOR_HPP
#include "common_header.hpp" #include "common_header.hpp"
#include "ConstantTensorDescriptor.hpp" #include "ConstantTensorDescriptor_deprecated.hpp"
#include "tensor_descriptor.hpp"
namespace ck { namespace ck {
...@@ -31,6 +32,11 @@ struct ConstantMatrixDescriptor ...@@ -31,6 +32,11 @@ struct ConstantMatrixDescriptor
return irow * RowStride_ + icol; return irow * RowStride_ + icol;
} }
__host__ __device__ static index_t CalculateOffset(index_t irow, index_t icol)
{
return GetOffsetFromMultiIndex(irow, icol);
}
template <index_t SubNRow, index_t SubNCol> template <index_t SubNRow, index_t SubNCol>
__host__ __device__ static constexpr auto MakeSubMatrixDescriptor(Number<SubNRow>, __host__ __device__ static constexpr auto MakeSubMatrixDescriptor(Number<SubNRow>,
Number<SubNCol>) Number<SubNCol>)
...@@ -52,10 +58,22 @@ __host__ __device__ constexpr auto ...@@ -52,10 +58,22 @@ __host__ __device__ constexpr auto
return ConstantMatrixDescriptor<NRow, NCol, RowStride>{}; return ConstantMatrixDescriptor<NRow, NCol, RowStride>{};
} }
template <class... Ts> template <typename... Ts>
__host__ __device__ constexpr auto make_ConstantMatrixDescriptor(ConstantTensorDescriptor<Ts...>) __host__ __device__ constexpr auto
make_ConstantMatrixDescriptor(ConstantTensorDescriptor_deprecated<Ts...>)
{
using TDesc = ConstantTensorDescriptor_deprecated<Ts...>;
static_assert(TDesc::GetNumOfDimension() == 2, "wrong");
static_assert(TDesc::GetStrides()[1] == 1, "wrong");
return ConstantMatrixDescriptor<TDesc::GetLengths()[0],
TDesc::GetLengths()[1],
TDesc::GetStrides()[0]>{};
}
template <typename... Ts>
__host__ __device__ constexpr auto make_ConstantMatrixDescriptor(NativeTensorDescriptor<Ts...>)
{ {
using TDesc = ConstantTensorDescriptor<Ts...>; using TDesc = NativeTensorDescriptor<Ts...>;
static_assert(TDesc::GetNumOfDimension() == 2, "wrong"); static_assert(TDesc::GetNumOfDimension() == 2, "wrong");
static_assert(TDesc::GetStrides()[1] == 1, "wrong"); static_assert(TDesc::GetStrides()[1] == 1, "wrong");
return ConstantMatrixDescriptor<TDesc::GetLengths()[0], return ConstantMatrixDescriptor<TDesc::GetLengths()[0],
...@@ -63,7 +81,7 @@ __host__ __device__ constexpr auto make_ConstantMatrixDescriptor(ConstantTensorD ...@@ -63,7 +81,7 @@ __host__ __device__ constexpr auto make_ConstantMatrixDescriptor(ConstantTensorD
TDesc::GetStrides()[0]>{}; TDesc::GetStrides()[0]>{};
} }
template <class TDesc> template <typename TDesc>
__host__ __device__ void print_ConstantMatrixDescriptor(TDesc, const char* s) __host__ __device__ void print_ConstantMatrixDescriptor(TDesc, const char* s)
{ {
printf( printf(
......
#ifndef CK_CONSTANT_MERGED_TENSOR_DESCRIPTOR_DEPRECATED_HPP
#define CK_CONSTANT_MERGED_TENSOR_DESCRIPTOR_DEPRECATED_HPP
#include "common_header.hpp"
#include "ConstantTensorDescriptor_deprecated.hpp"
namespace ck {
// OriginalTensorDesc : ConstantTensorDescriptor_deprecated<...>
// it's the tensor whose dimensions are to be merged
// OriginalDimMergeSeqs : Sequence<...>...
// each is a sequence of original dimensions (of OriginalTensorDesc) to be merged
template <class OriginalTensorDesc, class... OriginalDimMergeSeqs>
struct ConstantMergedTensorDescriptor_deprecated
{
using Type = ConstantMergedTensorDescriptor_deprecated;
static constexpr auto mOriginalDimMergeSeqs = std::tuple<OriginalDimMergeSeqs...>{};
static constexpr index_t nDim = sizeof...(OriginalDimMergeSeqs);
static constexpr index_t nOriginalDim = OriginalTensorDesc::GetNumOfDimension();
__host__ __device__ constexpr ConstantMergedTensorDescriptor_deprecated()
{
static_assert(nDim <= nOriginalDim, "wrong!");
// TODO: check each of OriginalDimMergeSeqs contains at least 1, and at most
// OriginalTensorDesc::nDim number of dimensions
// TODO: check OriginalDimMergeSeqs contains all original dimensions
// TODO: check there is no duplication in OriginalDimMergeSeqs
}
__host__ __device__ static constexpr auto GetOriginalTensorDescriptor()
{
return OriginalTensorDesc{};
}
__host__ __device__ static constexpr auto GetNumOfDimension() { return Number<nDim>{}; }
template <index_t IDim>
__host__ __device__ static constexpr auto GetContainedOriginalDimensions(Number<IDim>)
{
return std::get<IDim>(mOriginalDimMergeSeqs);
}
template <index_t IDim>
__host__ __device__ static constexpr bool ContainMultipleOriginalDimensions(Number<IDim>)
{
return (std::get<IDim>(mOriginalDimMergeSeqs).GetSize() > 1);
}
template <index_t IDim>
__host__ __device__ static constexpr auto GetLength(Number<IDim>)
{
constexpr auto original_dims_partial = std::get<IDim>(mOriginalDimMergeSeqs);
return OriginalTensorDesc::Extract(original_dims_partial).GetElementSize();
}
template <index_t IDim>
__host__ __device__ static constexpr auto GetStride(Number<IDim>)
{
static_assert(!ContainMultipleOriginalDimensions(Number<IDim>{}),
"wrong! stride of a merged dimension is undefined");
constexpr auto idim_original = std::get<IDim>(mOriginalDimMergeSeqs).Back();
return OriginalTensorDesc::GetStride(Number<idim_original>{});
}
// this is a hack to return the stride of the last original dimension of a merged dimension
// TODO: refactor this once the concept of "dimension" is used
template <index_t IDim>
__host__ __device__ static constexpr auto GetLastOriginalDimensionStride(Number<IDim>)
{
constexpr auto idim_last_original = std::get<IDim>(mOriginalDimMergeSeqs).Back();
return OriginalTensorDesc::GetStride(Number<idim_last_original>{});
}
__host__ __device__ static constexpr auto GetLengths()
{
return Sequence<OriginalTensorDesc::Extract(OriginalDimMergeSeqs{}).GetElementSize()...>{};
}
__host__ __device__ static constexpr auto GetElementSize()
{
return OriginalTensorDesc::GetElementSize();
}
template <class OriginalDimsPartial>
struct lambda_1_GetOriginalMultiIndexFromMultiIndex
{
const Array<index_t, OriginalDimsPartial::GetSize()>& original_multi_id_partial;
Array<index_t, nOriginalDim>& original_multi_id;
__host__ __device__ constexpr lambda_1_GetOriginalMultiIndexFromMultiIndex(
const Array<index_t, OriginalDimsPartial::GetSize()>& original_multi_id_partial_,
Array<index_t, nOriginalDim>& original_multi_id_)
: original_multi_id_partial(original_multi_id_partial_),
original_multi_id(original_multi_id_)
{
}
template <index_t I>
__host__ __device__ constexpr void operator()(Number<I>) const
{
constexpr index_t idim_original = OriginalDimsPartial::Get(Number<I>{});
index_t itmp = original_multi_id_partial[I];
original_multi_id(idim_original) = itmp;
}
};
struct lambda_0_GetOriginalMultiIndexFromMultiIndex
{
const Array<index_t, nDim>& multi_id;
Array<index_t, nOriginalDim>& original_multi_id;
__host__ __device__ constexpr lambda_0_GetOriginalMultiIndexFromMultiIndex(
const Array<index_t, nDim>& multi_id_, Array<index_t, nOriginalDim>& original_multi_id_)
: multi_id(multi_id_), original_multi_id(original_multi_id_)
{
}
template <index_t IDim>
__host__ __device__ constexpr void operator()(Number<IDim>) const
{
constexpr auto original_dims_partial = std::get<IDim>(Type::mOriginalDimMergeSeqs);
// get partial original-multi-id corresponding to this merged dimension
const auto original_multi_id_partial =
OriginalTensorDesc::Extract(original_dims_partial)
.GetMultiIndexFrom1dIndex(multi_id[IDim]);
static_for<0, original_dims_partial.GetSize(), 1>{}(
lambda_1_GetOriginalMultiIndexFromMultiIndex<decltype(original_dims_partial)>(
original_multi_id_partial, original_multi_id));
}
};
// return type is Array<...>
__host__ __device__ static constexpr auto
GetOriginalMultiIndexFromMultiIndex(Array<index_t, nDim> multi_id)
{
Array<index_t, nOriginalDim> original_multi_id;
static_for<0, nDim, 1>{}(
lambda_0_GetOriginalMultiIndexFromMultiIndex(multi_id, original_multi_id));
return original_multi_id;
}
template <index_t... Is>
__host__ __device__ static constexpr index_t GetOffsetFromMultiIndex(Sequence<Is...>)
{
constexpr auto multi_id = sequence2array(Sequence<Is...>{});
constexpr auto original_multi_id = GetOriginalMultiIndexFromMultiIndex(multi_id);
return OriginalTensorDesc::GetOffsetFromMultiIndex(original_multi_id);
}
__host__ __device__ static constexpr index_t
GetOffsetFromMultiIndex(Array<index_t, nDim> multi_id)
{
auto original_multi_id = GetOriginalMultiIndexFromMultiIndex(multi_id);
return OriginalTensorDesc::GetOffsetFromMultiIndex(original_multi_id);
}
template <class... Is>
__host__ __device__ static constexpr index_t GetOffsetFromMultiIndex(Is... is)
{
return GetOffsetFromMultiIndex(Array<index_t, nDim>{is...});
}
__host__ __device__ static constexpr Array<index_t, nDim> GetMultiIndexFrom1dIndex(index_t id)
{
constexpr auto packed_desc = make_ConstantTensorDescriptor_packed(GetLengths());
return packed_desc.GetMultiIndexFrom1dIndex(id);
}
__host__ __device__ static constexpr auto Pack()
{
constexpr auto lengths = GetLengths();
constexpr auto strides = calculate_tensor_strides_packed(lengths);
return ConstantTensorDescriptor_deprecated<decltype(lengths), decltype(strides)>{};
}
};
template <class OriginalTensorDesc, class... OriginalDimMergeSeqs>
__host__ __device__ constexpr auto make_ConstantMergedTensorDescriptor(OriginalTensorDesc,
OriginalDimMergeSeqs...)
{
return ConstantMergedTensorDescriptor_deprecated<OriginalTensorDesc, OriginalDimMergeSeqs...>{};
}
template <class TDesc>
__host__ __device__ void print_ConstantMergedTensorDescriptor(const char* s, TDesc)
{
print_ConstantTensorDescriptor(s, TDesc::GetOriginalTensorDescriptor());
}
} // namespace ck
#endif
#ifndef CK_DIMENSION_HPP
#define CK_DIMENSION_HPP
#include "common_header.hpp"
namespace ck {
template <index_t Length, index_t Stride>
struct NativeDimension
{
__host__ __device__ static constexpr auto GetLength() { return Number<Length>{}; }
__host__ __device__ static constexpr auto GetStride() { return Number<Stride>{}; }
};
} // namespace ck
#endif
#ifndef CK_MULTI_INDEX_TRANSFORM_HPP
#define CK_MULTI_INDEX_TRANSFORM_HPP
#include "common_header.hpp"
namespace ck {
template <index_t N>
using MultiIndex = Array<index_t, N>;
template <typename... Xs>
__host__ __device__ constexpr auto make_multi_index(Xs... xs)
{
return MultiIndex<sizeof...(Xs)>(xs...);
}
template <index_t Length>
struct PassThrough
{
using LowerIndex = MultiIndex<1>;
using UpperIndex = MultiIndex<1>;
__host__ __device__ static constexpr auto GetNumOfLowerDimension() { return Number<1>{}; }
__host__ __device__ static constexpr auto GetNumOfUpperDimension() { return Number<1>{}; }
__host__ __device__ static constexpr auto GetUpperLengths() { return Sequence<Length>{}; }
__host__ __device__ static constexpr auto CalculateLowerIndex(const UpperIndex& idx_up)
{
return idx_up;
}
__host__ __device__ static constexpr auto
CalculateLowerIndexDiff(const UpperIndex& idx_up_diff,
const UpperIndex& /* idx_up_old */,
const LowerIndex& /* idx_low_old */)
{
return idx_up_diff;
}
__host__ __device__ static constexpr bool IsLinearTransform() { return true; }
__host__ __device__ static constexpr bool
IsUpperIndexMappedToValidLowerIndex(const UpperIndex& /* idx_up */)
{
return true;
}
};
// LowerLengths: Sequence<...>
template <typename LowerLengths, typename LeftPads, typename RightPads>
struct Pad
{
static constexpr index_t nDim = LowerLengths::Size();
using LowerIndex = MultiIndex<nDim>;
using UpperIndex = MultiIndex<nDim>;
__host__ __device__ static constexpr auto GetNumOfLowerDimension() { return Number<nDim>{}; }
__host__ __device__ static constexpr auto GetNumOfUpperDimension() { return Number<nDim>{}; }
__host__ __device__ static constexpr auto GetUpperLengths()
{
return LowerLengths{} + LeftPads{} + RightPads{};
}
__host__ __device__ static constexpr auto CalculateLowerIndex(const UpperIndex& idx_up)
{
return idx_up - LeftPads{};
}
__host__ __device__ static constexpr auto
CalculateLowerIndexDiff(const UpperIndex& idx_up_diff,
const UpperIndex& /* idx_up_old */,
const LowerIndex& /* idx_low_old */)
{
return idx_up_diff;
}
__host__ __device__ static constexpr bool IsLinearTransform() { return true; }
__host__ __device__ constexpr bool
IsUpperIndexMappedToValidLowerIndex(const UpperIndex& idx_up) const
{
#if 0
struct lambda_no_pad
{
__host__ __device__ constexpr bool operator()(index_t x) const { return x == 0; }
};
if(sequence_all_of(LeftPads{}, lambda_no_pad{}) &&
sequence_all_of(RightPads{}, lambda_no_pad{}))
{
return true;
}
else
#endif
{
bool flag = true;
static_for<0, nDim, 1>{}([&](auto idim) {
// only check if there is left-padding
static_if<(LeftPads::At(idim) != 0)>{}(
[&](auto) { flag = flag && idx_up[idim] >= LeftPads::At(idim); });
// only check if there is right-padding
static_if<(RightPads::At(idim) != 0)>{}([&](auto) {
flag = flag && (idx_up[idim] < LeftPads::At(idim) + LowerLengths::At(idim));
});
});
return flag;
}
}
};
// LowerLengths: Sequence<...>
template <typename LowerLengths>
struct Merge
{
static constexpr index_t nDimLow = LowerLengths::Size();
static constexpr index_t nDimUp = 1;
using LowerIndex = MultiIndex<nDimLow>;
using UpperIndex = MultiIndex<nDimUp>;
__host__ __device__ static constexpr auto GetNumOfLowerDimension() { return Number<nDimLow>{}; }
__host__ __device__ static constexpr auto GetNumOfUpperDimension() { return Number<nDimUp>{}; }
__host__ __device__ static constexpr auto GetUpperLengths()
{
return Sequence<reduce_on_sequence(
LowerLengths{}, math::multiplies<index_t>{}, Number<1>{})>{};
}
// emulate constexpr lambda
template <typename PseudoLowStrides>
struct lambda_CalculateLowerIndex
{
index_t& itmp;
LowerIndex& idx_low;
__host__ __device__ explicit constexpr lambda_CalculateLowerIndex(index_t& itmp_,
LowerIndex& idx_low_)
: itmp(itmp_), idx_low(idx_low_)
{
}
template <typename IDim>
__host__ __device__ constexpr void operator()(IDim idim) const
{
constexpr index_t stride = PseudoLowStrides::At(idim);
idx_low(idim) = itmp / stride;
itmp -= idx_low[idim] * stride;
}
};
__host__ __device__ static constexpr auto CalculateLowerIndex(const UpperIndex& idx_up)
{
LowerIndex idx_low;
index_t itmp = idx_up[0];
constexpr auto pseudo_low_strides =
reverse_inclusive_scan_sequence(
LowerLengths::PopFront(), math::multiplies<index_t>{}, Number<1>{})
.PushBack(Number<1>{});
static_for<0, nDimLow - 1, 1>{}(
lambda_CalculateLowerIndex<decltype(pseudo_low_strides)>(itmp, idx_low));
idx_low(nDimLow - 1) = itmp / pseudo_low_strides[nDimLow - 1];
return idx_low;
}
// idx_low_diff depends on idx_low_old, so idx_low need to be up-to-date
// If idx_up_diff is known at compile-time, many calculations can be optimized
// away by compiler
// This function assume idx_low_old is not out-of-bound
__host__ __device__ static constexpr auto
CalculateLowerIndexDiff(const UpperIndex& idx_up_diff,
const UpperIndex& /* idx_up_old */,
const LowerIndex& idx_low_old)
{
// do nothing if idx_up_diff == 0
if(idx_up_diff[0] == 0)
{
return make_zero_array<index_t, nDimLow>();
}
// CalculateLowerIndex(idx_up_diff) has multiple integer divisions.
// If idx_up_diff is known at compile-time, the calculation can
// be done at compile-time. However, if idx_up_diff is only known
// at run-time, then the calculation will also be computed at
// run-time, and can be very expensive.
LowerIndex idx_low_new = idx_low_old + CalculateLowerIndex(idx_up_diff);
if(idx_up_diff[0] > 0)
{
bool carry = false;
// do carry check in reversed order, starting from lowest dimension
// don't check the highest dimension
static_for<0, nDimLow - 1, 1>{}([&](auto ireverse) {
constexpr index_t i = nDimLow - 1 - ireverse;
if(carry)
{
++idx_low_new(i);
}
carry = false;
if(idx_low_new[i] >= LowerLengths::At(i))
{
idx_low_new(i) -= LowerLengths::At(i);
carry = true;
}
});
// highest dimension, no out-of-bound check
if(carry)
{
++idx_low_new(0);
}
}
else if(idx_up_diff[0] < 0)
{
bool borrow = false;
// do borrow check in reversed order, starting from lowest dimension
// don't check the highest dimension
static_for<0, nDimLow - 1, 1>{}([&](auto ireverse) {
constexpr index_t i = nDimLow - 1 - ireverse;
if(borrow)
{
--idx_low_new(i);
}
borrow = false;
if(idx_low_new[i] < 0)
{
idx_low_new(i) += LowerLengths::At(i);
borrow = true;
}
});
// highest dimension, no out-of-bound check
if(borrow)
{
--idx_low_new(0);
}
}
return idx_low_new - idx_low_old;
}
__host__ __device__ static constexpr bool IsLinearTransform() { return false; }
__host__ __device__ static constexpr bool
IsUpperIndexMappedToValidLowerIndex(const UpperIndex& /* idx_up */)
{
return true;
}
};
// UpperLengths: Sequence<...>
template <typename UpperLengths>
struct UnMerge
{
static constexpr index_t nDimLow = 1;
static constexpr index_t nDimUp = UpperLengths::Size();
using LowerIndex = MultiIndex<nDimLow>;
using UpperIndex = MultiIndex<nDimUp>;
__host__ __device__ static constexpr auto GetNumOfLowerDimension() { return Number<nDimLow>{}; }
__host__ __device__ static constexpr auto GetNumOfUpperDimension() { return Number<nDimUp>{}; }
__host__ __device__ static constexpr auto GetUpperLengths() { return UpperLengths{}; }
__host__ __device__ static constexpr auto CalculateLowerIndex(const UpperIndex& idx_up)
{
LowerIndex idx_low{0};
constexpr auto pseudo_up_strides =
reverse_inclusive_scan_sequence(
UpperLengths::PopFront(), math::multiplies<index_t>{}, Number<1>{})
.PushBack(Number<1>{});
static_for<0, nDimUp, 1>{}(
[&](auto idim) { idx_low(0) += idx_up[idim] * pseudo_up_strides[idim]; });
return idx_low;
}
__host__ __device__ static constexpr auto
CalculateLowerIndexDiff(const UpperIndex& idx_up_diff,
const UpperIndex& /* idx_up_old */,
const LowerIndex& /* idx_low_old */)
{
return CalculateLowerIndex(idx_up_diff);
}
__host__ __device__ static constexpr bool IsLinearTransform() { return true; }
__host__ __device__ static constexpr bool
IsUpperIndexMappedToValidLowerIndex(const UpperIndex& /* idx_up */)
{
return true;
}
};
// UpperLengths: Sequence<...>
// Coefficients: Sequence<...>
// idx_low = coefficients[0, ...nDimUp-1] * idx_up[0, ...nDimUp-1] + coefficients[nDimUp]
template <typename UpperLengths, typename Coefficients>
struct Embed
{
static constexpr index_t nDimLow = 1;
static constexpr index_t nDimUp = UpperLengths::Size();
using LowerIndex = MultiIndex<nDimLow>;
using UpperIndex = MultiIndex<nDimUp>;
__host__ __device__ explicit constexpr Embed()
{
static_assert(UpperLengths::GetSize() == nDimUp && Coefficients::GetSize() == nDimUp + 1,
"wrong! # of dimensions not consistent");
}
__host__ __device__ static constexpr auto GetNumOfUpperDimension() { return Number<nDimUp>{}; }
__host__ __device__ static constexpr auto GetNumOfLowerDimension() { return Number<nDimLow>{}; }
__host__ __device__ static constexpr auto GetUpperLengths() { return UpperLengths{}; }
__host__ __device__ static constexpr auto CalculateLowerIndex(const UpperIndex& idx_up)
{
LowerIndex idx_low(Coefficients{}[nDimUp]);
static_for<0, nDimUp, 1>{}(
[&](auto idim) { idx_low(0) += idx_up[idim] * Coefficients{}[idim]; });
return idx_low;
}
__host__ __device__ static constexpr auto
CalculateLowerIndexDiff(const UpperIndex& idx_up_diff,
const UpperIndex& /* idx_up_old */,
const LowerIndex& /* idx_low_old */)
{
LowerIndex idx_low_diff{0};
static_for<0, nDimUp, 1>{}(
[&](auto idim) { idx_low_diff(0) += idx_up_diff[idim] * Coefficients{}[idim]; });
return idx_low_diff;
}
__host__ __device__ static constexpr bool IsLinearTransform() { return true; }
__host__ __device__ static constexpr bool
IsUpperIndexMappedToValidLowerIndex(const UpperIndex& /* idx_up */)
{
return true;
}
};
template <index_t LowerLength, index_t VectorSize>
struct Vectorize
{
using LowerIndex = MultiIndex<1>;
using UpperIndex = MultiIndex<1>;
__host__ __device__ constexpr Vectorize()
{
static_assert(VectorSize > 0 && LowerLength % VectorSize == 0,
"wrong! cannot evenly divide");
}
__host__ __device__ static constexpr auto GetNumOfLowerDimension() { return Number<1>{}; }
__host__ __device__ static constexpr auto GetNumOfUpperDimension() { return Number<1>{}; }
__host__ __device__ static constexpr auto GetUpperLengths()
{
return Sequence<LowerLength / VectorSize>{};
}
__host__ __device__ static constexpr auto CalculateLowerIndex(const UpperIndex& idx_up)
{
return VectorSize * idx_up;
}
__host__ __device__ static constexpr auto
CalculateLowerIndexDiff(const UpperIndex& idx_up_diff,
const UpperIndex& /* idx_up_old */,
const LowerIndex& /* idx_low_old */)
{
return VectorSize * idx_up_diff;
}
__host__ __device__ static constexpr bool IsLinearTransform() { return true; }
__host__ __device__ static constexpr bool
IsUpperIndexMappedToValidLowerIndex(const UpperIndex& /* idx_up */)
{
return true;
}
};
} // namespace ck
#endif
...@@ -2,299 +2,248 @@ ...@@ -2,299 +2,248 @@
#define CK_TENSOR_COORDINATE_HPP #define CK_TENSOR_COORDINATE_HPP
#include "common_header.hpp" #include "common_header.hpp"
#include "ConstantTensorDescriptor.hpp" #include "dimension.hpp"
#include "ConstantMergedTensorDescriptor.hpp" #include "multi_index_transform.hpp"
#include "tensor_descriptor.hpp"
namespace ck { namespace ck {
template <class TensorDesc> // A "tensor cooridnate" is an opaque object that represents a "point of location" inside a tensor
struct NormalTensorCoordinate // At the bare minimun, user should be able to query the following information from a tensor
// coordinate:
// 1. Tensor descriptor
// 2. Location, represented in the form of multi-index
// 3. Location, represented in the form of the offset to the origin of the tensor
// 4. If the location is inside invalid area or not, i.e. the padding area of an implicitly padded
// tensor is considered invalid, because the padding area doesn't have any physical memory
// allocation
// A tensor cooridnate also provides following functionality:
// 1. Given step size in each dimension, update itself, or return a new tensor cooridnate, so user
// can freely move the "point of location" inside the tensor
// wrapper class for NativeTensorCoordinate and TransformedTensorCoordinate
template <typename TensorDesc>
struct TensorCoordinate;
// tensor coordinate for native tensor
template <typename NativeTensorDesc>
struct NativeTensorCoordinate
{ {
using type = NormalTensorCoordinate; using type = NativeTensorCoordinate;
using tensor_desc_type = TensorDesc; using tensor_desc_type = NativeTensorDesc;
static constexpr index_t nDim = tensor_desc_type::GetNumOfDimension(); static constexpr index_t nDim = tensor_desc_type::GetNumOfDimension();
using Index = MultiIndex<nDim>;
__host__ __device__ constexpr NormalTensorCoordinate(Array<index_t, nDim> tensor_index) __host__ __device__ constexpr NativeTensorCoordinate(Index idx)
: mOffset{tensor_desc_type::GetOffsetFromMultiIndex(tensor_index)} : mIndex(idx), mOffset(tensor_desc_type::CalculateOffset(idx))
{ {
} }
template <class... Xs> template <typename... Xs>
__host__ __device__ constexpr NormalTensorCoordinate(Xs... xs) __host__ __device__ constexpr NativeTensorCoordinate(Xs... xs)
: NormalTensorCoordinate(Array<index_t, nDim>{xs...}) : NativeTensorCoordinate(Index{xs...})
{ {
} }
__host__ __device__ constexpr index_t GetOffset() const { return mOffset; } template <index_t... Xs>
__host__ __device__ constexpr NativeTensorCoordinate(Sequence<Xs...>)
: NativeTensorCoordinate(Index{Xs...})
{
}
__host__ __device__ static constexpr auto GetTensorDescriptor() { return tensor_desc_type{}; }
__host__ __device__ constexpr const Index& GetIndex() const { return mIndex; }
// T is Array or Sequence __host__ __device__ constexpr const index_t& GetOffset() const { return mOffset; }
template <class T>
__host__ __device__ type operator+=(T step_sizes) __host__ __device__ constexpr type operator+=(const Index& idx_diff)
{ {
static_assert(is_same<typename T::data_type, index_t>{} && T::GetSize() == nDim, "wrong!"); // mIndex is updated here, but some (or all) of its entries may never be used
// compiler should remove those entries as dead code
mIndex += idx_diff;
mOffset += tensor_desc_type::GetOffsetFromMultiIndex(step_sizes); mOffset += tensor_desc_type::CalculateOffsetDiff(idx_diff);
return *this; return *this;
} }
template <class T> __host__ __device__ constexpr type operator-=(const Index& idx_diff)
__host__ __device__ type operator-=(T step_sizes)
{ {
static_assert(is_same<typename T::data_type, index_t>{} && T::GetSize() == nDim, "wrong!"); // mIndex is updated here, but some (or all) of its entries may never be used
// compiler should remove those entries as dead code
mIndex -= idx_diff;
mOffset -= tensor_desc_type::GetOffsetFromMultiIndex(step_sizes); mOffset -= tensor_desc_type::CalculateOffsetDiff(idx_diff);
return *this; return *this;
} }
template <class T> __host__ __device__ constexpr type operator+(const Index& idx_diff) const
__host__ __device__ constexpr type operator+(T step_sizes) const
{ {
type coord = *this; type coord = *this;
coord += step_sizes; coord += idx_diff;
return coord; return coord;
} }
template <class T> __host__ __device__ constexpr type operator-(const Index& idx_diff) const
__host__ __device__ constexpr type operator-(T step_sizes) const
{ {
type coord = *this; type coord = *this;
coord -= step_sizes; coord -= idx_diff;
return coord; return coord;
} }
// reposition point of origin, and return compensated offset. __host__ __device__ static constexpr index_t CalculateOffsetDiff(const Index& idx_diff)
// This is a hack to reduce index calculation during looping over
// a tensor whose origin is this TensorCoordinate. It does so, by spitting
// out the run-time offset to the pointer (to the tensor data) held by this
// TensorCoordiante, so the caller can add the offset into the run-time pointer of
// the data, so only 1 run-time variable (update pointer) is needed, instead
// of 2 run-time variables (old pointer and this offset)
// TODO: after introducing the concept of "run-time tensor view", which contains the
// run-time pointer to the data, always keep track of the pointer, instead of both
// offset and the pointer. This also bring additional benefit that we don't need to
// worry the offset might underflow (because offset is unsigned integer) when updating it.
__host__ __device__ constexpr index_t RepositOrigin()
{ {
index_t offset_diff = mOffset; return tensor_desc_type::CalculateOffsetDiff(idx_diff);
mOffset = 0;
return offset_diff;
} }
__host__ __device__ static constexpr bool IsUpperIndexMappedToValidOffset() { return true; }
private: private:
// mIndex may be saved and updated, however, the value of some (or all) of its entries may
// never be used. Compiler should be able to remove these entries as well as its calculation
// as dead code.
// TODO: make sure compiler indeed remove these dead code
Index mIndex;
index_t mOffset; index_t mOffset;
}; };
template <class TensorDesc> // tensor coordinate for transformed tensor
struct MergedTensorCoordinate template <typename TransformedTensorDesc>
struct TransformedTensorCoordinate
{ {
using type = MergedTensorCoordinate; using tensor_desc_type = TransformedTensorDesc;
using tensor_desc_type = TensorDesc; using LowerCoord =
typename TensorCoordinate<decltype(tensor_desc_type::GetLowerTensorDescriptor())>::type;
using UpperCoord = TransformedTensorCoordinate;
static constexpr index_t nDim = tensor_desc_type::GetNumOfDimension(); static constexpr index_t nDim = tensor_desc_type::GetNumOfDimension();
static constexpr index_t nOriginalDim = using UpperIndex = MultiIndex<nDim>;
tensor_desc_type::GetOriginalTensorDescriptor().GetNumOfDimension();
__host__ __device__ constexpr MergedTensorCoordinate(Array<index_t, nDim> tensor_index) __host__ __device__ constexpr TransformedTensorCoordinate(UpperIndex idx)
: mOriginalIndex{tensor_desc_type::GetOriginalMultiIndexFromMultiIndex(tensor_index)} : mIndexUp{idx}, mCoordLow{tensor_desc_type::CalculateLowerIndex(idx)}
{ {
// partial offset on each dimension }
static_for<0, nDim, 1>{}([&](auto idim) {
constexpr auto partial_original_dims =
tensor_desc_type::GetContainedOriginalDimensions(idim);
constexpr auto partial_original_desc =
tensor_desc_type::GetOriginalTensorDescriptor().Extract(partial_original_dims);
mPartialOffsets(idim) = partial_original_desc.GetOffsetFromMultiIndex(
extract_array(mOriginalIndex, partial_original_dims));
});
// complete offset template <typename... Xs>
mOffset = __host__ __device__ constexpr TransformedTensorCoordinate(Xs... xs)
accumulate_on_array(mPartialOffsets, math::plus<index_t>{}, static_cast<index_t>(0)); : TransformedTensorCoordinate(UpperIndex{xs...})
{
} }
template <class... Xs> template <index_t... Xs>
__host__ __device__ constexpr MergedTensorCoordinate(Xs... xs) __host__ __device__ constexpr TransformedTensorCoordinate(Sequence<Xs...>)
: MergedTensorCoordinate(Array<index_t, nDim>{xs...}) : TransformedTensorCoordinate(UpperIndex{Xs...})
{ {
} }
__host__ __device__ constexpr index_t GetOffset() const { return mOffset; } __host__ __device__ static constexpr auto GetTensorDescriptor() { return tensor_desc_type{}; }
__host__ __device__ constexpr const LowerCoord& GetLowerCoordinate() const { return mCoordLow; }
__host__ __device__ constexpr const UpperIndex& GetUpperIndex() const { return mIndexUp; }
template <class IDim, class T, bool PositiveDirection> __host__ __device__ constexpr const UpperIndex& GetIndex() const { return GetUpperIndex(); }
__host__ __device__ void
MoveOnDimension(IDim idim_, T step_size, integral_constant<bool, PositiveDirection>) __host__ __device__ constexpr const index_t& GetOffset() const
{ {
constexpr auto idim = idim_; return GetLowerCoordinate().GetOffset();
// if step_size is known at compile time
static_if<is_static<T>::value>{}(
[&](auto) { static_if<T{} == 0>{}([&](auto) { return; }); });
// update original index
static_if<tensor_desc_type::ContainMultipleOriginalDimensions(idim)>{}([&](auto) {
constexpr auto partial_original_dims =
tensor_desc_type::GetContainedOriginalDimensions(idim);
constexpr index_t ndim_partial_original = partial_original_dims.GetSize();
constexpr auto partial_original_desc =
tensor_desc_type::GetOriginalTensorDescriptor().Extract(partial_original_dims);
const auto partial_original_step_sizes =
partial_original_desc.GetMultiIndexFrom1dIndex(step_size);
// update partial original multi-id
auto partial_original_id = extract_array(mOriginalIndex, partial_original_dims);
static_if<PositiveDirection>{}([&](auto) {
partial_original_id += partial_original_step_sizes;
bool carry = false;
// do carry check in reversed order, starting from lowest dimension
// don't check the highest dimension
static_for<0, ndim_partial_original, 1>{}([&](auto IReverse) {
constexpr index_t i = ndim_partial_original - 1 - IReverse;
if(carry)
{
++partial_original_id(i);
}
carry = false;
if(partial_original_id[i] >= partial_original_desc.GetLength(i))
{
partial_original_id(i) -= partial_original_desc.GetLength(i);
carry = true;
}
});
}).Else([&](auto) {
// shift up multi-id to avoid unsigned integer underflow during intermediate
// calculations. After the shift, should have new_multi_id[...] >= 1
partial_original_id +=
partial_original_desc.GetLengths() - partial_original_step_sizes;
bool borrow = false;
// do borrow check in reversed order, starting from lowest dimension
// don't check the highest dimension
static_for<0, ndim_partial_original, 1>{}([&](auto IReverse) {
constexpr index_t i = ndim_partial_original - 1 - IReverse;
if(borrow)
{
--partial_original_id(i);
}
borrow = false;
if(partial_original_id[i] < partial_original_desc.GetLength(i))
{
partial_original_id(i) += partial_original_desc.GetLength(i);
borrow = true;
}
});
// shift back down multi-id
// here, should have new_multi_id[...] >= GetLengths()
partial_original_id = partial_original_id - partial_original_desc.GetLengths();
});
// update "mOriginalIndex"
static_for<0, ndim_partial_original, 1>{}([&](auto I) {
constexpr auto idim_original = partial_original_dims[I];
mOriginalIndex(idim_original) = partial_original_id[I];
});
// calculate new partial offset on this merged dimension
const index_t old_partial_offset = mPartialOffsets[idim];
mPartialOffsets(idim) =
partial_original_desc.GetOffsetFromMultiIndex(partial_original_id);
// update "mThreadSrcOffset", do "+" before "-" to avoid underflow
mOffset = (mOffset + mPartialOffsets[idim]) - old_partial_offset;
}).Else([&](auto fwd) {
static_if<PositiveDirection>{}([&](auto) {
mOffset += step_size * fwd(tensor_desc_type{}).GetStride(idim);
}).Else([&](auto) { mOffset -= step_size * fwd(tensor_desc_type{}).GetStride(idim); });
});
} }
// T is Array or Sequence __host__ __device__ constexpr UpperCoord operator+=(const UpperIndex& idx_up_diff)
template <class T>
__host__ __device__ type operator+=(T step_sizes)
{ {
static_assert(is_same<typename T::data_type, index_t>{} && T::GetSize() == nDim, // For transformation of multi-index difference, not all transformation functions need to
"wrong! the rank of step size doesn't match with that of tensor coordinate"); // know the old lower-index or the old upper-index. We pass both of them to the
// transformation function. The transformation function itself decides to use them or not.
mCoordLow += tensor_desc_type::CalculateLowerIndexDiff(
idx_up_diff, GetIndex(), GetLowerCoordinate().GetIndex());
static_for<0, nDim, 1>{}([&](auto idim) { // mIndexUp is updated here, but some (or all) of its entries may never be used
if(step_sizes[idim] != 0) // compiler should remove those entries as dead code
{ mIndexUp += idx_up_diff;
this->MoveOnDimension(idim, step_sizes[idim], integral_constant<bool, true>{});
}
});
return *this; return *this;
} }
template <class T> __host__ __device__ constexpr UpperCoord operator-=(const UpperIndex& idx_up_diff)
__host__ __device__ type operator-=(T step_sizes)
{ {
static_assert(is_same<typename T::data_type, index_t>{} && T::GetSize() == nDim, mCoordLow -= tensor_desc_type::CalculateLowerIndexDiff(
"wrong! the rank of step size doesn't match with that of tensor coordinate"); idx_up_diff, GetIndex(), GetLowerCoordinate().GetIndex());
static_for<0, nDim, 1>{}([&](auto idim) { // mIndex is updated here, but some (or all) of its entries may never be used
if(step_sizes[idim] != 0) // compiler should remove those entries as dead code
{ mIndexUp -= idx_up_diff;
this->MoveOnDimension(idim, step_sizes[idim], integral_constant<bool, false>{});
}
});
return *this; return *this;
} }
template <class T> __host__ __device__ constexpr UpperCoord operator+(const UpperIndex& idx_up_diff) const
__host__ __device__ constexpr type operator+(T step_sizes) const
{ {
type coord = *this; UpperCoord coord_up = *this;
coord += step_sizes; coord_up += idx_up_diff;
return coord; return coord_up;
} }
template <class T> __host__ __device__ constexpr UpperCoord operator-(const UpperIndex& idx_up_diff) const
__host__ __device__ constexpr type operator-(T step_sizes) const
{ {
type coord = *this; UpperCoord coord_up = *this;
coord -= step_sizes; coord_up -= idx_up_diff;
return coord; return coord_up;
} }
__host__ __device__ static constexpr index_t RepositOrigin() { return 0; } // Calculate offset diff without updating tensor-coordinate
// If idx_up_diff is know at compile time, and has only non-zero entries on linear dimensions,
// then all calculation can be done at compile-time.
// TODO: this function is not compiled to expected ISA
__host__ __device__ constexpr index_t CalculateOffsetDiff(const UpperIndex& idx_up_diff) const
{
// For transformation of multi-index difference, not all transformation functions need to
// know the old lower-index or the old upper-index. We pass both of them to the
// transformation function. The transformation function itself decides to use them or not.
const auto idx_low_diff = tensor_desc_type::CalculateLowerIndexDiff(
idx_up_diff, GetIndex(), GetLowerCoordinate().GetIndex());
return GetLowerCoordinate().CalculateOffsetDiff(idx_low_diff);
}
__host__ __device__ constexpr bool IsUpperIndexMappedToValidOffset() const
{
return tensor_desc_type::IsUpperIndexMappedToValidLowerIndex(GetIndex()) &&
mCoordLow.IsUpperIndexMappedToValidOffset();
}
private: private:
// Allocate register memory for all merged dimensions and normal dimensions. // mIndexUp may be calculated and updated, however, the value of some (or all) of its entries
// However, only those merged dimensions, whose index will be involved in arithmetic // may
// after the construction of this TensorCoordinate (e.g. when user move a slicing // never be used. Compiler should be able to remove these entries as well as its calculation
// window on the merged dimension), will use these register memory. // as dead code.
// Let's hope compiler will optimize away those register memory allocated for normal // TODO: make sure compiler indeed remove these dead code
// dimensions, and those merged dimensions, that would never be involved in index UpperIndex mIndexUp;
// arithmetic after construction of TensorCoordinate. LowerCoord mCoordLow;
// TODO: refactor TensorCoordinate, after introducing the concept of "dimensions" };
// and simplify implementation of ConstantMergedTensorDescriptor, so we don't need to
// count on compiler to optimize way those register memory for us template <typename TensorDesc>
Array<index_t, nOriginalDim> mOriginalIndex; struct TensorCoordinate
Array<index_t, nDim> mPartialOffsets; {
private:
// complete offset template <typename... Ts>
index_t mOffset; __host__ __device__ static constexpr auto
MakeDummyTensorCoordinate(NativeTensorDescriptor<Ts...>)
{
return NativeTensorCoordinate<NativeTensorDescriptor<Ts...>>(
make_zero_array<index_t, TensorDesc::GetNumOfDimension()>());
}
template <typename... Ts>
__host__ __device__ static constexpr auto
MakeDummyTensorCoordinate(TransformedTensorDescriptor<Ts...>)
{
return TransformedTensorCoordinate<TransformedTensorDescriptor<Ts...>>(
make_zero_array<index_t, TensorDesc::GetNumOfDimension()>());
}
public:
using type = decltype(MakeDummyTensorCoordinate(TensorDesc{}));
}; };
} // namespace ck } // namespace ck
......
#ifndef CK_TENSOR_COORDINATE_DEPRECATED_HPP
#define CK_TENSOR_COORDINATE_DEPRECATED_HPP
#include "common_header.hpp"
#include "ConstantTensorDescriptor_deprecated.hpp"
#include "ConstantMergedTensorDescriptor_deprecated.hpp"
namespace ck {
// TensorDesc is ConstantTensorDescriptor_deprecated
template <class TensorDesc>
struct NormalTensorCoordinate_deprecated
{
using type = NormalTensorCoordinate_deprecated;
using tensor_desc_type = TensorDesc;
static constexpr index_t nDim = tensor_desc_type::GetNumOfDimension();
__host__
__device__ constexpr NormalTensorCoordinate_deprecated(Array<index_t, nDim> tensor_index)
: mOffset{tensor_desc_type::GetOffsetFromMultiIndex(tensor_index)}
{
}
template <class... Xs>
__host__ __device__ constexpr NormalTensorCoordinate_deprecated(Xs... xs)
: NormalTensorCoordinate_deprecated(Array<index_t, nDim>{xs...})
{
}
template <index_t... Xs>
__host__ __device__ constexpr NormalTensorCoordinate_deprecated(Sequence<Xs...>)
: NormalTensorCoordinate_deprecated(Array<index_t, nDim>{Xs...})
{
}
__host__ __device__ constexpr index_t GetOffset() const { return mOffset; }
// T is Array or Sequence
template <class T>
__host__ __device__ type operator+=(T step_sizes)
{
static_assert(is_same<typename T::data_type, index_t>{} && T::GetSize() == nDim, "wrong!");
mOffset += tensor_desc_type::GetOffsetFromMultiIndex(step_sizes);
return *this;
}
template <class T>
__host__ __device__ type operator-=(T step_sizes)
{
static_assert(is_same<typename T::data_type, index_t>{} && T::GetSize() == nDim, "wrong!");
mOffset -= tensor_desc_type::GetOffsetFromMultiIndex(step_sizes);
return *this;
}
template <class T>
__host__ __device__ constexpr type operator+(T step_sizes) const
{
type coord = *this;
coord += step_sizes;
return coord;
}
template <class T>
__host__ __device__ constexpr type operator-(T step_sizes) const
{
type coord = *this;
coord -= step_sizes;
return coord;
}
// reposition point of origin, and return compensated offset.
// This is a hack to reduce index calculation during looping over
// a tensor whose origin is this TensorCoordinate. It does so, by spitting
// out the run-time offset to the pointer (to the tensor data) held by this
// TensorCoordiante, so the caller can add the offset into the run-time pointer of
// the data, so only 1 run-time variable (update pointer) is needed, instead
// of 2 run-time variables (old pointer and this offset)
// TODO: after introducing the concept of "run-time tensor view", which contains the
// run-time pointer to the data, always keep track of the pointer, instead of both
// offset and the pointer. This also bring additional benefit that we don't need to
// worry the offset might underflow (because offset is unsigned integer) when updating it.
__host__ __device__ constexpr index_t RepositionOrigin()
{
index_t offset_diff = mOffset;
mOffset = 0;
return offset_diff;
}
private:
index_t mOffset;
};
// TensorDesc is ConstantMergedTensorDescriptor_deprecated
template <class TensorDesc>
struct MergedTensorCoordinate_deprecated
{
using type = MergedTensorCoordinate_deprecated;
using tensor_desc_type = TensorDesc;
static constexpr index_t nDim = tensor_desc_type::GetNumOfDimension();
static constexpr index_t nOriginalDim =
tensor_desc_type::GetOriginalTensorDescriptor().GetNumOfDimension();
__host__
__device__ constexpr MergedTensorCoordinate_deprecated(Array<index_t, nDim> tensor_index)
: mOriginalIndex{tensor_desc_type::GetOriginalMultiIndexFromMultiIndex(tensor_index)}
{
// partial offset on each dimension
static_for<0, nDim, 1>{}([&](auto idim) {
constexpr auto partial_original_dims =
tensor_desc_type::GetContainedOriginalDimensions(idim);
constexpr auto partial_original_desc =
tensor_desc_type::GetOriginalTensorDescriptor().Extract(partial_original_dims);
mPartialOffsets(idim) = partial_original_desc.GetOffsetFromMultiIndex(
extract_array(mOriginalIndex, partial_original_dims));
});
// complete offset
mOffset =
accumulate_on_array(mPartialOffsets, math::plus<index_t>{}, static_cast<index_t>(0));
}
template <class... Xs>
__host__ __device__ constexpr MergedTensorCoordinate_deprecated(Xs... xs)
: MergedTensorCoordinate_deprecated(Array<index_t, nDim>{xs...})
{
}
__host__ __device__ constexpr index_t GetOffset() const { return mOffset; }
template <class IDim, class T, bool PositiveDirection>
__host__ __device__ void
MoveOnDimension(IDim idim_, T step_size, integral_constant<bool, PositiveDirection>)
{
constexpr auto idim = idim_;
// if step_size is known at compile time
static_if<is_static<T>::value>{}(
[&](auto) { static_if<T{} == 0>{}([&](auto) { return; }); });
// update original index
static_if<tensor_desc_type::ContainMultipleOriginalDimensions(idim)>{}([&](auto) {
constexpr auto partial_original_dims =
tensor_desc_type::GetContainedOriginalDimensions(idim);
constexpr index_t ndim_partial_original = partial_original_dims.GetSize();
constexpr auto partial_original_desc =
tensor_desc_type::GetOriginalTensorDescriptor().Extract(partial_original_dims);
const auto partial_original_step_sizes =
partial_original_desc.GetMultiIndexFrom1dIndex(step_size);
// update partial original multi-id
auto partial_original_id = extract_array(mOriginalIndex, partial_original_dims);
static_if<PositiveDirection>{}([&](auto) {
partial_original_id += partial_original_step_sizes;
bool carry = false;
// do carry check in reversed order, starting from lowest dimension
// don't check the highest dimension
static_for<0, ndim_partial_original - 1, 1>{}([&](auto IReverse) {
constexpr index_t i = ndim_partial_original - 1 - IReverse;
if(carry)
{
++partial_original_id(i);
}
carry = false;
if(partial_original_id[i] >= partial_original_desc.GetLength(i))
{
partial_original_id(i) -= partial_original_desc.GetLength(i);
carry = true;
}
});
// highest dimension
if(carry)
{
++partial_original_id(0);
}
}).Else([&](auto) {
// shift up multi-id to avoid unsigned integer underflow during intermediate
// calculations. After the shift, should have new_multi_id[...] >= 1
partial_original_id +=
partial_original_desc.GetLengths() - partial_original_step_sizes;
bool borrow = false;
// do borrow check in reversed order, starting from lowest dimension
// don't check the highest dimension
static_for<0, ndim_partial_original - 1, 1>{}([&](auto IReverse) {
constexpr index_t i = ndim_partial_original - 1 - IReverse;
if(borrow)
{
--partial_original_id(i);
}
borrow = false;
if(partial_original_id[i] < partial_original_desc.GetLength(i))
{
partial_original_id(i) += partial_original_desc.GetLength(i);
borrow = true;
}
});
// highest dimension
if(borrow)
{
--partial_original_id(0);
}
// shift back down multi-id
// here, should have new_multi_id[...] >= GetLengths()
partial_original_id = partial_original_id - partial_original_desc.GetLengths();
});
// update "mOriginalIndex"
static_for<0, ndim_partial_original, 1>{}([&](auto I) {
constexpr auto idim_original = partial_original_dims[I];
mOriginalIndex(idim_original) = partial_original_id[I];
});
// calculate new partial offset on this merged dimension
const index_t old_partial_offset = mPartialOffsets[idim];
mPartialOffsets(idim) =
partial_original_desc.GetOffsetFromMultiIndex(partial_original_id);
// update "mThreadSrcOffset", do "+" before "-" to avoid underflow
mOffset = (mOffset + mPartialOffsets[idim]) - old_partial_offset;
}).Else([&](auto fwd) {
static_if<PositiveDirection>{}([&](auto) {
mOffset += step_size * fwd(tensor_desc_type{}).GetStride(idim);
}).Else([&](auto) { mOffset -= step_size * fwd(tensor_desc_type{}).GetStride(idim); });
});
}
// T is Array or Sequence
template <class T>
__host__ __device__ type operator+=(T step_sizes)
{
static_assert(is_same<typename T::data_type, index_t>{} && T::GetSize() == nDim, "wrong!");
static_for<0, nDim, 1>{}([&](auto idim) {
// compiler should remove dead code path, because step_sizes is known at
// compile time
if(step_sizes[idim] != 0)
{
this->MoveOnDimension(idim, step_sizes[idim], integral_constant<bool, true>{});
}
});
return *this;
}
template <class T>
__host__ __device__ type operator-=(T step_sizes)
{
static_assert(is_same<typename T::data_type, index_t>{} && T::GetSize() == nDim, "wrong!");
static_for<0, nDim, 1>{}([&](auto idim) {
// compiler should remove dead code path, because step_sizes is known at
// compile time
if(step_sizes[idim] != 0)
{
this->MoveOnDimension(idim, step_sizes[idim], integral_constant<bool, false>{});
}
});
return *this;
}
template <class T>
__host__ __device__ constexpr type operator+(T step_sizes) const
{
type coord = *this;
coord += step_sizes;
return coord;
}
template <class T>
__host__ __device__ constexpr type operator-(T step_sizes) const
{
type coord = *this;
coord -= step_sizes;
return coord;
}
__host__ __device__ static constexpr index_t RepositionOrigin() { return 0; }
private:
// Allocate register memory for all merged dimensions and normal dimensions.
// However, only those merged dimensions, whose index will be involved in arithmetic
// after the construction of this TensorCoordinate (e.g. when user move a slicing
// window on the merged dimension), will use these register memory.
// Let's hope compiler will optimize away those register memory allocated for normal
// dimensions, and those merged dimensions, that would never be involved in index
// arithmetic after construction of TensorCoordinate.
// TODO: refactor TensorCoordinate, after introducing the concept of "dimensions"
// and simplify implementation of ConstantMergedTensorDescriptor_deprecated, so we don't need to
// count on compiler to optimize away those register memory for us
Array<index_t, nOriginalDim> mOriginalIndex;
Array<index_t, nDim> mPartialOffsets;
// complete offset
index_t mOffset;
};
template <class TensorDesc>
struct TensorCoordinate_deprecated
{
private:
template <class... Ts>
__host__ __device__ static constexpr auto
MakeDummyTensorCoordinate(ConstantTensorDescriptor_deprecated<Ts...>)
{
return NormalTensorCoordinate_deprecated<ConstantTensorDescriptor_deprecated<Ts...>>();
}
template <class... Ts>
__host__ __device__ static constexpr auto
MakeDummyTensorCoordinate(ConstantMergedTensorDescriptor_deprecated<Ts...>)
{
return MergedTensorCoordinate_deprecated<
ConstantMergedTensorDescriptor_deprecated<Ts...>>();
}
public:
using type = decltype(MakeDummyTensorCoordinate(TensorDesc{}));
};
} // namespace ck
#endif
#ifndef CK_TENSOR_COORDINATE_HELPER_HPP
#define CK_TENSOR_COORDINATE_HELPER_HPP
#include "tensor_coordiante_hpp"
namespace ck {
template <typename TensorDesc>
__host__ __device__ constexpr auto
make_tensor_coordinate(TensorDesc, MultiIndex<TensorDesc::GetNumOfDimension()> idx)
{
return typename TensorCoordinate<TensorDesc>::type(idx);
}
} // namespace ck
#endif
#ifndef CK_TENSOR_DESCRIPTOR_HPP
#define CK_TENSOR_DESCRIPTOR_HPP
#include "common_header.hpp"
#include "dimension.hpp"
#include "multi_index_transform.hpp"
namespace ck {
// tensor descriptor for "native tensor"
// A "native tensor" is a "true" tensor that can be represented by Lengths and Strides
template <typename... NativeDimensions>
struct NativeTensorDescriptor
{
using type = NativeTensorDescriptor;
static constexpr index_t nDim = sizeof...(NativeDimensions);
static constexpr auto mDimensions = make_tuple(NativeDimensions{}...);
using Index = MultiIndex<nDim>;
__host__ __device__ static constexpr auto GetNumOfDimension() { return Number<nDim>{}; }
template <index_t IDim>
__host__ __device__ static constexpr auto GetLength(Number<IDim>)
{
return mDimensions.At(Number<IDim>{}).GetLength();
}
template <index_t IDim>
__host__ __device__ static constexpr auto GetStride(Number<IDim>)
{
return mDimensions.At(Number<IDim>{}).GetStride();
}
template <index_t... IDims>
__host__ __device__ static constexpr auto GetLengths(Sequence<IDims...>)
{
return Sequence<GetLength(Number<IDims>{})...>{};
}
template <index_t... IDims>
__host__ __device__ static constexpr auto GetStrides(Sequence<IDims...>)
{
return Sequence<GetStride(Number<IDims>{})...>{};
}
template <index_t IDim, index_t... IDims>
__host__ __device__ static constexpr auto GetLengths(Number<IDim>, Number<IDims>...)
{
return GetLengths(Sequence<IDim, IDims...>{});
}
template <index_t IDim, index_t... IDims>
__host__ __device__ static constexpr auto GetStrides(Number<IDim>, Number<IDims>...)
{
return GetStrides(Sequence<IDim, IDims...>{});
}
__host__ __device__ static constexpr auto GetLengths()
{
return GetLengths(typename arithmetic_sequence_gen<0, nDim, 1>::type{});
}
__host__ __device__ static constexpr auto GetStrides()
{
return GetStrides(typename arithmetic_sequence_gen<0, nDim, 1>::type{});
}
__host__ __device__ static constexpr index_t GetElementSize()
{
return reduce_on_sequence(GetLengths(), math::multiplies<index_t>{}, Number<1>{});
}
__host__ __device__ static constexpr index_t GetElementSpace()
{
return reduce_on_sequence(
(GetLengths() - Number<1>{}) * GetStrides(), math::plus<index_t>{}, Number<1>{});
}
// TODO: this cannot return constepxr because of use of lambda
__host__ __device__ static constexpr index_t CalculateOffset(const Index& idx)
{
index_t offset = 0;
static_for<0, nDim, 1>{}([&](auto idim) { offset += idx[idim] * GetStride(idim); });
return offset;
}
__host__ __device__ static constexpr index_t CalculateOffsetDiff(const Index& idx_diff)
{
index_t offset_diff = 0;
static_for<0, nDim, 1>{}(
[&](auto idim) { offset_diff += idx_diff[idim] * GetStride(idim); });
return offset_diff;
}
template <index_t IDim>
__host__ __device__ static constexpr bool IsLinearDimension(Number<IDim>)
{
return true;
}
__host__ __device__ static constexpr auto GetLinearDimensionMask()
{
return typename uniform_sequence_gen<nDim, 1>::type{};
}
__host__ __device__ static constexpr auto GetNonLinearDimensionMask()
{
return typename uniform_sequence_gen<nDim, 0>::type{};
}
__host__ __device__ static constexpr auto GetNonLinearDimensions() { return Sequence<>{}; }
__host__ __device__ static constexpr auto GetNonLinearIndependentDimensionGroups()
{
return Tuple<>{};
}
__host__ __device__ static constexpr bool
IsUpperIndexMappedToValidOffset(const Index& /* idx */)
{
return true;
}
};
// Tensor descriptor for "transformed tensor"
template <typename LowTensorDescriptor, // NativeTensorDescriptor or TransformedTensorDescriptor
typename Transforms, // Tuple<MultIndexTransforms...>
typename LowDimensionIds, // Tuple<Sequence<...>>
typename UpDimensionIds> // Tuple<Sequence<...>>
struct TransformedTensorDescriptor
{
using type = TransformedTensorDescriptor;
static constexpr index_t nTransform = Transforms::Size();
struct lambda_merge_sequences
{
template <typename... Seqs>
__host__ __device__ constexpr auto operator()(Seqs... seqs) const
{
return merge_sequences(seqs...);
}
};
__host__ __device__ static constexpr auto GetNumOfLowerDimension()
{
// Here, we assume all lower-dimensions are active
// TODO: sanity-check all lower-dimension are indeed active
using duplicated_low_active_dims =
decltype(unpack(lambda_merge_sequences{}, LowDimensionIds{}));
using low_active_dims = typename sequence_unique_sort<duplicated_low_active_dims,
math::less<index_t>,
math::equal<index_t>>::type;
return low_active_dims::Size();
}
__host__ __device__ static constexpr auto GetNumOfUpperDimension()
{
using duplicated_up_active_dims =
decltype(unpack(lambda_merge_sequences{}, UpDimensionIds{}));
using up_active_dims = typename sequence_unique_sort<duplicated_up_active_dims,
math::less<index_t>,
math::equal<index_t>>::type;
return up_active_dims::Size();
}
static constexpr index_t nDimUp = GetNumOfUpperDimension();
static constexpr index_t nDimLow = GetNumOfLowerDimension();
using UpperIndex = MultiIndex<nDimUp>;
using LowerIndex = MultiIndex<nDimLow>;
__host__ __device__ constexpr TransformedTensorDescriptor()
{
static_assert(nTransform == Transforms::Size() && nTransform == LowDimensionIds::Size() &&
nTransform == UpDimensionIds::Size(),
"wrong! # of transformations not the same");
// sanity check:
// LowDimensionIds should include all low-dimensions,
// UpDimensionIds should include all up-dimensions
using mingled_up_dimension_ids =
decltype(unpack(lambda_merge_sequences{}, UpDimensionIds{}));
using sorted_up_dimension_ids =
typename sequence_sort<mingled_up_dimension_ids, math::less<index_t>>::type;
static_assert(sorted_up_dimension_ids::Size() == nDimUp &&
is_valid_sequence_map<sorted_up_dimension_ids>{},
"wrong! UpDimensionIds is not configured correctly");
using mingled_low_dimension_ids =
decltype(unpack(lambda_merge_sequences{}, LowDimensionIds{}));
using sorted_low_dimension_ids =
typename sequence_sort<mingled_low_dimension_ids, math::less<index_t>>::type;
static_assert(sorted_low_dimension_ids::Size() == nDimLow &&
is_valid_sequence_map<sorted_low_dimension_ids>{},
"wrong! LowDimensionIds is not configured correctly");
// TODO: sanity check: while a up-dimension could be associated with multille
// transformation, a low-dimension should be associated with only one transformation
// TODO: sanity-check: GetLowerLengths of each transform should be consistent with lengths
// of lower-tensor-descriptor
}
__host__ __device__ static constexpr auto GetNumOfDimension()
{
return GetNumOfUpperDimension();
}
__host__ __device__ static constexpr auto GetLowerTensorDescriptor()
{
return LowTensorDescriptor{};
}
struct lambda_GetUpperLengths
{
template <typename Transform>
__host__ __device__ constexpr auto operator()(const Transform& tran) const
{
return tran.GetUpperLengths();
}
};
__host__ __device__ static constexpr auto GetUpperLengths()
{
constexpr auto tuple_of_up_lengths =
transform_tuples(lambda_GetUpperLengths{}, Transforms{});
constexpr auto mingled_up_lengths = unpack(lambda_merge_sequences{}, tuple_of_up_lengths);
constexpr auto mingled_up_dimension_ids =
unpack(lambda_merge_sequences{}, UpDimensionIds{});
// TODO: sanity-check mingled_up_dimension_ids contain all upper-dimensions
// TODO: sanity-check mingled_up_lengths have no conflicting upper-length
// sort by upper-dimension-ids
using sort_up_dimension_ids = sequence_unique_sort<decltype(mingled_up_dimension_ids),
math::less<index_t>,
math::equal<index_t>>;
// sanity-check sorted-upper-dimension-ids should be Sequence<0, 1, ... nDimUp-1>
static_assert(is_same<typename sort_up_dimension_ids::type,
typename arithmetic_sequence_gen<0, nDimUp, 1>::type>{},
"wrong! UpDimensionIds is not configured correctly");
constexpr auto sorted2unsorted_map = typename sort_up_dimension_ids::sorted2unsorted_map{};
constexpr auto sorted_up_lengths =
pick_sequence_elements_by_ids(mingled_up_lengths, sorted2unsorted_map);
return sorted_up_lengths;
}
__host__ __device__ static constexpr auto GetLengths() { return GetUpperLengths(); }
template <index_t IDim>
__host__ __device__ static constexpr auto GetLength(Number<IDim>)
{
return GetLengths()[IDim];
}
template <index_t... IDims>
__host__ __device__ static constexpr auto GetLengths(Sequence<IDims...>)
{
return Sequence<GetLength(Number<IDims>{})...>{};
}
template <index_t IDim, index_t... IDims>
__host__ __device__ static constexpr auto GetLengths(Number<IDim>, Number<IDims>...)
{
return GetLengths(Sequence<IDim, IDims...>{});
}
__host__ __device__ static constexpr index_t GetElementSize()
{
return reduce_on_sequence(GetLengths(), math::multiplies<index_t>{}, Number<1>{});
}
__host__ __device__ static constexpr index_t GetElementSpace()
{
// TODO: Is this the correct definition for transformed tensor?
return GetLowerTensorDescriptor().GetElementSpace();
}
// TODO: right now return value is not constexpr because use of non-constexpr lambda
__host__ __device__ static constexpr LowerIndex CalculateLowerIndex(const UpperIndex& idx_up)
{
LowerIndex idx_low;
static_for<0, nTransform, 1>{}([&](auto itran) {
constexpr auto tran = Transforms{}.At(itran);
const auto idx_up_part = pick_array_element(idx_up, UpDimensionIds{}.At(itran));
auto idx_low_part = pick_array_element(idx_low, LowDimensionIds{}.At(itran));
// this assume each lower (single) index is only assocaited with one transformation,
// which is required for index transformation, and has been checked during constructor
// of TransformedTensorDescriptor
idx_low_part = tran.CalculateLowerIndex(to_array(idx_up_part));
});
return idx_low;
}
// TODO: right now return value is not constexpr because use of non-constepxr lambda
__host__ __device__ static constexpr LowerIndex CalculateLowerIndexDiff(
const UpperIndex& idx_up_diff, const UpperIndex& idx_up_old, const LowerIndex& idx_low_old)
{
LowerIndex idx_low_diff;
static_for<0, nTransform, 1>{}([&](auto itran) {
constexpr auto tran = Transforms{}.At(itran);
const auto idx_up_diff_part =
pick_array_element(idx_up_diff, UpDimensionIds{}.At(itran));
const auto idx_up_old_part = pick_array_element(idx_up_old, UpDimensionIds{}.At(itran));
const auto idx_low_old_part =
pick_array_element(idx_low_old, LowDimensionIds{}.At(itran));
auto idx_low_diff_part = pick_array_element(idx_low_diff, LowDimensionIds{}.At(itran));
// this assume each lower (single) index is associated with only one transformation,
// which is required for index transformation, and has been checked during constructor
// of TransformedTensorDescriptor
idx_low_diff_part = tran.CalculateLowerIndexDiff(
to_array(idx_up_diff_part), to_array(idx_up_old_part), to_array(idx_low_old_part));
});
return idx_low_diff;
}
__host__ __device__ static constexpr index_t CalculateOffset(const UpperIndex& idx_up)
{
return GetLowerTensorDescriptor().CalculateOffset(CalculateLowerIndex(idx_up));
}
struct lambda_sequence_logical_and
{
template <typename... Seqs>
__host__ __device__ constexpr auto operator()(Seqs...) const
{
return typename sequence_reduce<logical_and<index_t>, Seqs...>::type{};
}
};
template <typename T>
struct lambda_is_true
{
__host__ __device__ constexpr auto operator()(const T& x) const
{
// TODO: remove static_cast once Sequence can take bool as entries
return static_cast<bool>(x) == true;
}
};
struct lambda_get_linear_dimension_mask_of_single_tranform
{
// check only one transform at a time
template <typename Transform, typename LowDimensionId, typename UpDimensionId>
__host__ __device__ constexpr auto
operator()(Transform, LowDimensionId, UpDimensionId) const
{
// judge if transformation is linear
constexpr bool is_linear_transform = Transform::IsLinearTransform();
// judge if all lower dimension are linear
constexpr bool are_all_low_dim_linear = sequence_all_of(
pick_sequence_elements_by_ids(GetLowerTensorDescriptor().GetLinearDimensionMask(),
LowDimensionId{}),
lambda_is_true<index_t>{});
// create linear mask for upper dimensions
constexpr bool are_up_dim_linear = is_linear_transform && are_all_low_dim_linear;
constexpr auto mask_of_up_linear_dims = modify_sequence_elements_by_ids(
typename uniform_sequence_gen<nDimUp, 1>::type{},
typename uniform_sequence_gen<UpDimensionId::Size(), are_up_dim_linear>::type{},
UpDimensionId{});
return mask_of_up_linear_dims;
}
};
// TODO: this is a hack, transform_tuples() doesn't compile, would complain about constexpr
template <typename F, typename X, typename Y, typename Z, index_t... Is>
__host__ __device__ static constexpr auto
dummy_transform_tuples_impl(F f, X x, Y y, Z z, Sequence<Is...>)
{
return make_tuple(f(x.At(Number<Is>{}), y.At(Number<Is>{}), z.At(Number<Is>{}))...);
}
__host__ __device__ static constexpr auto GetLinearDimensionMask()
{
#if 0
// create tuple of linear dimension masks, for all transformations
// TODO: this doesn't compile, because transform_tuples() complain about constexpr
constexpr auto tuple_of_linear_dimension_mask =
transform_tuples(lambda_get_linear_dimension_mask_of_single_tranform{},
Transforms{},
LowDimensionIds{},
UpDimensionIds{});
#else
// create tuple of linear dimension masks, for all transformations
// TODO: this is a hack
constexpr auto tuple_of_linear_dimension_mask = dummy_transform_tuples_impl(
lambda_get_linear_dimension_mask_of_single_tranform{},
Transforms{},
LowDimensionIds{},
UpDimensionIds{},
typename arithmetic_sequence_gen<0, Transforms::Size(), 1>::type{});
#endif
// reduce tuple of masks into one mask
constexpr auto linear_dimension_mask =
unpack(lambda_sequence_logical_and{}, tuple_of_linear_dimension_mask);
return linear_dimension_mask;
}
__host__ __device__ static constexpr auto GetNonLinearDimensionMask()
{
return GetLinearDimensionMask().Transform(logical_not<index_t>{});
}
template <index_t IDim>
__host__ __device__ static constexpr bool IsLinearDimension(Number<IDim>)
{
return GetLinearDimensionMask().At(Number<IDim>{});
}
__host__ __device__ static constexpr auto GetLinearDimensions()
{
constexpr auto linear_dimension_mask = GetLinearDimensionMask();
return pick_sequence_elements_by_mask(
typename arithmetic_sequence_gen<0, nDimUp, 1>::type{}, linear_dimension_mask);
}
__host__ __device__ static constexpr auto GetNonLinearDimensions()
{
constexpr auto nonlinear_dimension_mask = GetNonLinearDimensionMask();
return pick_sequence_elements_by_mask(
typename arithmetic_sequence_gen<0, nDimUp, 1>::type{}, nonlinear_dimension_mask);
}
#if 0
__host__ __device__ static constexpr auto GetNonLinearIndependentDimensionGroups()
{
// TODO: not implemented
}
#endif
__host__ __device__ static constexpr bool
IsUpperIndexMappedToValidLowerIndex(const UpperIndex& idx_up)
{
bool flag = true;
static_for<0, nTransform, 1>{}([&](auto itran) {
constexpr auto tran = Transforms{}.At(itran);
const auto idx_up_part = pick_array_element(idx_up, UpDimensionIds{}.At(itran));
flag = flag && tran.IsUpperIndexMappedToValidLowerIndex(to_array(idx_up_part));
});
return flag;
}
// Whenever this function is called, it will call CalculateLowerIndex() recursively.
// If you have created a tensor coordinate already, instead of calling this function,
// you should call TensorCoordinate::IsUpperIndexMappedToValidOffset() which would
// be less expensive.
__host__ __device__ static constexpr bool
IsUpperIndexMappedToValidOffset(const UpperIndex& idx_up)
{
return IsUpperIndexMappedToValidLowerIndex(idx_up) &&
GetLowerTensorDescriptor().IsUpperIndexMappedToValidOffset(
CalculateLowerIndex(idx_up));
}
};
} // namespace ck
#endif
#ifndef CK_TENSOR_DESCRIPTOR_HELPER_HPP
#define CK_TENSOR_DESCRIPTOR_HELPER_HPP
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
namespace ck {
template <typename Lengths>
__host__ __device__ constexpr auto calculate_tensor_strides_packed(Lengths)
{
return reverse_inclusive_scan_sequence(
Lengths{}.PopFront(), math::multiplies<index_t>{}, Number<1>{})
.PushBack(Number<1>{});
}
template <typename Lengths, index_t Align>
__host__ __device__ constexpr auto calculate_tensor_strides_aligned(Lengths, Number<Align>)
{
constexpr index_t L_back_align =
Align * math::integer_divide_ceiler<index_t>{}(Lengths{}.Back(), Align);
return calculate_tensor_strides_packed(
Lengths{}.Modify(Number<Lengths{}.GetSize() - 1>{}, Number<L_back_align>{}));
}
template <index_t... Lengths, index_t... Strides>
__host__ __device__ constexpr auto make_native_tensor_descriptor(Sequence<Lengths...>,
Sequence<Strides...>)
{
return NativeTensorDescriptor<NativeDimension<Lengths, Strides>...>{};
}
template <typename Lengths>
__host__ __device__ constexpr auto make_native_tensor_descriptor_packed(Lengths)
{
constexpr auto strides = calculate_tensor_strides_packed(Lengths{});
return make_native_tensor_descriptor(Lengths{}, strides);
}
template <typename Lengths, index_t Align>
__host__ __device__ constexpr auto make_native_tensor_descriptor_aligned(Lengths, Number<Align>)
{
constexpr auto strides = calculate_tensor_strides_aligned(Lengths{}, Number<Align>{});
return make_native_tensor_descriptor(Lengths{}, strides);
}
template <typename LowTensorDescriptor,
typename Transforms,
typename LowDimensionIds,
typename UpDimensionIds>
__host__ __device__ constexpr auto
transform_tensor_descriptor(LowTensorDescriptor, Transforms, LowDimensionIds, UpDimensionIds)
{
return TransformedTensorDescriptor<LowTensorDescriptor,
Transforms,
LowDimensionIds,
UpDimensionIds>{};
}
template <typename LowerTensorDescriptor,
index_t... LowerLengths,
index_t... LowerDimensionIds,
index_t... UpperDimensionIds>
__host__ __device__ constexpr auto
reorder_transformed_tensor_descriptor_impl(LowerTensorDescriptor,
Sequence<LowerLengths...>,
Sequence<LowerDimensionIds...>,
Sequence<UpperDimensionIds...>)
{
return TransformedTensorDescriptor<LowerTensorDescriptor,
Tuple<PassThrough<LowerLengths>...>,
Tuple<Sequence<LowerDimensionIds>...>,
Tuple<Sequence<UpperDimensionIds>...>>{};
}
// reorder a NativeTensorDescriptor
template <typename... Ts, typename MapLower2Upper>
__host__ __device__ constexpr auto
reorder_tensor_descriptor_given_lower2upper(NativeTensorDescriptor<Ts...>, MapLower2Upper)
{
static_assert(is_valid_sequence_map<MapLower2Upper>{},
"wrong! MapLower2Upper is not a valid map");
constexpr auto old_desc = NativeTensorDescriptor<Ts...>{};
static_assert(old_desc.GetNumOfDimension() == MapLower2Upper::Size(), "wrong!");
constexpr auto new_lengths = old_desc.GetLengths().ReorderGivenOld2New(MapLower2Upper{});
constexpr auto new_strides = old_desc.GetStrides().ReorderGivenOld2New(MapLower2Upper{});
return make_native_tensor_descriptor(new_lengths, new_strides);
}
// reorder a TransformedTensorDescriptor
template <typename... Ts, typename MapLower2Upper>
__host__ __device__ constexpr auto
reorder_tensor_descriptor_given_lower2upper(TransformedTensorDescriptor<Ts...>, MapLower2Upper)
{
static_assert(is_valid_sequence_map<MapLower2Upper>{},
"wrong! MapLower2Upper is not a valid map");
constexpr auto low_desc = TransformedTensorDescriptor<Ts...>{};
static_assert(low_desc.GetNumOfDimension() == MapLower2Upper::Size(), "wrong!");
return reorder_transformed_tensor_descriptor_impl(
low_desc,
low_desc.GetLengths(),
typename arithmetic_sequence_gen<0, low_desc.GetNumOfDimension(), 1>::type{},
MapLower2Upper{});
}
template <typename LowerTensorDescriptor, typename MapUpper2Lower>
__host__ __device__ constexpr auto
reorder_tensor_descriptor_given_upper2lower(LowerTensorDescriptor, MapUpper2Lower)
{
return reorder_tensor_descriptor_given_lower2upper(
LowerTensorDescriptor{}, typename sequence_map_inverse<MapUpper2Lower>::type{});
}
template <typename Lengths, typename Strides>
__host__ __device__ constexpr bool are_dimensions_unfoldable(Lengths, Strides)
{
static_assert(Lengths::Size() == Strides::Size(), "wrong!");
bool flag = true;
for(index_t i = 0; i < Lengths::Size() - 1; ++i)
{
flag = flag && Strides::At(i) == Strides::At(i + 1) * Lengths::At(i + 1);
}
return flag;
}
// unfold only support NativeTennsorDescriptor, for now
template <index_t FirstUnfoldDim, index_t LastUnfoldDim, typename... Ts>
__host__ __device__ constexpr auto unfold_tensor_descriptor(NativeTensorDescriptor<Ts...> desc,
Number<FirstUnfoldDim>,
Number<LastUnfoldDim>)
{
constexpr index_t nDim = desc.GetNumOfDimension();
static_assert(FirstUnfoldDim >= 0 && LastUnfoldDim < nDim && FirstUnfoldDim <= LastUnfoldDim,
"wrong! should have FirstUnfoldDim <= LastUnfoldDim!");
// left and right
constexpr auto left = typename arithmetic_sequence_gen<0, FirstUnfoldDim, 1>::type{};
constexpr auto middle =
typename arithmetic_sequence_gen<FirstUnfoldDim, LastUnfoldDim + 1, 1>::type{};
constexpr auto right = typename arithmetic_sequence_gen<LastUnfoldDim + 1, nDim, 1>::type{};
// sanity-checknfoldable
static_assert(are_dimensions_unfoldable(desc.GetLengths(middle), desc.GetStrides(middle)),
"wrong! not unfoldable");
// unfolded length, stride
constexpr index_t unfold_length =
reduce_on_sequence(desc.GetLengths(middle), math::multiplies<index_t>{}, Number<1>{});
constexpr index_t unfold_stride = desc.GetStride(Number<LastUnfoldDim>{});
// new lengths, strides
constexpr auto new_lengths =
desc.GetLengths(left).PushBack(Number<unfold_length>{}).PushBack(desc.GetLengths(right));
constexpr auto new_strides =
desc.GetStrides(left).PushBack(Number<unfold_stride>{}).PushBack(desc.GetStrides(right));
return make_native_tensor_descriptor(new_lengths, new_strides);
}
// a cluster map 1d index to N-d index
template <typename Lengths, typename ArrangeOrder>
struct ClusterDescriptor
{
static constexpr index_t nDim = Lengths::Size();
static constexpr auto mDesc = transform_tensor_descriptor(
make_native_tensor_descriptor_packed(Lengths{}),
make_tuple(Merge<decltype(Lengths::ReorderGivenNew2Old(ArrangeOrder{}))>{}),
make_tuple(ArrangeOrder{}),
make_tuple(Sequence<0>{}));
__host__ __device__ constexpr ClusterDescriptor()
{
static_assert(Lengths::Size() == nDim && ArrangeOrder::Size() == nDim,
"wrong! size not the same");
static_assert(is_valid_sequence_map<ArrangeOrder>{}, "wrong! ArrangeOrder is wrong");
}
__host__ __device__ static constexpr index_t GetElementSize() { return mDesc.GetElementSize(); }
__host__ __device__ static constexpr auto CalculateClusterIndex(index_t idx_1d)
{
return mDesc.CalculateLowerIndex(MultiIndex<1>{idx_1d});
}
};
template <typename Lengths,
typename ArrangeOrder = typename arithmetic_sequence_gen<0, Lengths::Size(), 1>::type>
__host__ __device__ constexpr auto make_cluster_descriptor(
Lengths, ArrangeOrder order = typename arithmetic_sequence_gen<0, Lengths::Size(), 1>::type{})
{
return ClusterDescriptor<Lengths, decltype(order)>{};
}
} // namespace ck
#endif
...@@ -95,9 +95,21 @@ __device__ void WaveWiseGemmMx64(const FloatA* const __restrict__ p_a_wave, ...@@ -95,9 +95,21 @@ __device__ void WaveWiseGemmMx64(const FloatA* const __restrict__ p_a_wave,
(mfma_info::group_size * mfma_info::num_blks_wave) + (mfma_info::group_size * mfma_info::num_blks_wave) +
a_off; // A is transposed a_off; // A is transposed
index_t bindex = b_off + lane_b + n * mfma_info::num_threads_blk; index_t bindex = b_off + lane_b + n * mfma_info::num_threads_blk;
p_c_thread[m + n * output_m + b * output_m * mfma_info::num_blks_wave] += // p_c_thread[m + n * output_m + b * output_m * mfma_info::num_blks_wave] +=
math::inner_product_with_conversion<FloatC>{}(p_a_wave[aindex], // math::inner_product_with_conversion<FloatC>{}(p_a_wave[aindex],
p_b_wave[bindex]); // p_b_wave[bindex]);
index_t cindex = m + n * output_m + b * output_m * mfma_info::num_blks_wave;
if(blockIdx.x*blockDim.x + threadIdx.x == 0 && cindex == 0)
{
printf("Run p_c[%d] = %f, p_a[%d] = %f, p_b[%d] = %f\n",
cindex,
p_c_thread[cindex],
aindex,
p_a_wave[aindex],
bindex,
p_b_wave[bindex]);
p_c_thread[cindex+k] = p_a_wave[aindex];
}
} }
} }
} }
...@@ -251,6 +263,9 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_xdlops ...@@ -251,6 +263,9 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_xdlops
constexpr index_t N = BlockMatrixB::NCol(); constexpr index_t N = BlockMatrixB::NCol();
constexpr index_t K = BlockMatrixA::NRow(); constexpr index_t K = BlockMatrixA::NRow();
if(blockIdx.x*blockDim.x + threadIdx.x == 0)
printf("Run M %d, N %d, K %d\n", M, N, K);
// static_if<EnableXdlops>{}([&](auto) { // static_if<EnableXdlops>{}([&](auto) {
// WaveWiseGemmMx64_xdlops<M, // WaveWiseGemmMx64_xdlops<M,
// N, // N,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment