Unverified commit 5c7cec11 authored by Chao Liu, committed by GitHub

Code clean up (#20)



* tuning para,

* testing on v100

* add fp16

* remove deprecated tensor descriptor

* sync with miopen

* update build script
Co-authored-by: Jing Zhang <jizhan@amd.com>
parent 7d09790a
@@ -8,53 +8,9 @@
 #include "blockwise_generic_tensor_slice_copy.hpp"
 #include "threadwise_generic_tensor_slice_copy.hpp"
 #include "blockwise_gemm.hpp"
-#include "convolution_common.hpp"
 namespace ck {
-template <ConvolutionDirection>
-struct make_wei_e_k_global_desc_v4r1;
-template <>
-struct make_wei_e_k_global_desc_v4r1<ConvolutionDirection::Forward>
-{
-    template <typename WeiDesc>
-    __device__ constexpr auto operator()(WeiDesc) const
-    {
-        constexpr auto I1 = Number<1>{};
-        constexpr auto I3 = Number<3>{};
-        return reorder_tensor_descriptor_given_upper2lower(
-            unfold_tensor_descriptor(WeiDesc{}, I1, I3), Sequence<1, 0>{});
-    }
-};
-template <>
-struct make_wei_e_k_global_desc_v4r1<ConvolutionDirection::BackwardWeight>
-{
-    template <typename WeiDesc>
-    __device__ constexpr auto operator()(WeiDesc) const
-    {
-        constexpr auto I0 = Number<0>{};
-        constexpr auto I1 = Number<1>{};
-        constexpr auto I2 = Number<2>{};
-        constexpr auto I3 = Number<3>{};
-        constexpr auto wei_k_c_y_x_global_desc = WeiDesc{};
-        constexpr index_t K = wei_k_c_y_x_global_desc.GetLength(I0);
-        constexpr index_t C = wei_k_c_y_x_global_desc.GetLength(I1);
-        constexpr index_t Y = wei_k_c_y_x_global_desc.GetLength(I2);
-        constexpr index_t X = wei_k_c_y_x_global_desc.GetLength(I3);
-        return transform_tensor_descriptor(
-            unfold_tensor_descriptor(wei_k_c_y_x_global_desc, I2, I3),
-            make_tuple(Merge<Sequence<C, Y * X>>{}, PassThrough<K>{}),
-            make_tuple(Sequence<1, 2>{}, Sequence<0>{}),
-            make_tuple(Sequence<0>{}, Sequence<1>{}));
-    }
-};
 template <index_t GridSize,
           index_t BlockSize,
           typename Float,
@@ -66,18 +22,17 @@ template <index_t GridSize,
           typename ConvDilations,
           typename LeftPads,
           typename RightPads,
-          ConvolutionDirection ConvDirection,
           index_t BPerBlock,
           index_t KPerBlock,
           index_t EPerBlock,
           index_t GemmNRepeat,
-          index_t GemmMPerThreadSubC,
-          index_t GemmNPerThreadSubC,
+          index_t GemmMPerThread,
+          index_t GemmNPerThread,
+          index_t GemmKPerThread,
           index_t GemmMLevel0Cluster,
           index_t GemmNLevel0Cluster,
           index_t GemmMLevel1Cluster,
           index_t GemmNLevel1Cluster,
-          index_t GemmKPerThreadLoop,
           index_t GemmDataPerReadA,
           index_t GemmDataPerReadB,
           typename InBlockCopySubLengths_E_N1_B_N2,
@@ -107,19 +62,14 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
         constexpr auto True = integral_constant<bool, true>{};
-        static_assert(ConvDirection == ConvolutionDirection::Forward ||
-                          ConvDirection == ConvolutionDirection::BackwardWeight,
-                      "wrong! this kernel only support convolution forward and backward-weight");
         // this is a mess
         // TODO: find more elegent way of specifying (or calculating) performance parameters
         constexpr index_t N1 = GemmNRepeat;
-        constexpr index_t N2 = GemmNPerThreadSubC;
+        constexpr index_t N2 = GemmNPerThread;
-        static_assert((N1 * N2 * BPerBlock) %
-                              (GemmNPerThreadSubC * GemmNLevel0Cluster * GemmNLevel1Cluster) ==
-                          0,
-                      "wrong!");
+        static_assert(
+            (N1 * N2 * BPerBlock) % (GemmNPerThread * GemmNLevel0Cluster * GemmNLevel1Cluster) == 0,
+            "wrong!");
         constexpr auto in_n_c_hi_wi_global_desc = InGlobalDesc{};
         constexpr auto wei_k_c_y_x_global_desc = WeiGlobalDesc{};
@@ -240,7 +190,10 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
         // It is constructed differently, depending on whether forward or backward weight
         // convolution
         constexpr auto wei_e_k_global_desc =
-            make_wei_e_k_global_desc_v4r1<ConvDirection>{}(wei_k_c_y_x_global_desc);
+            transform_tensor_descriptor(unfold_tensor_descriptor(wei_k_c_y_x_global_desc, I2, I3),
+                                        make_tuple(Merge<Sequence<C, Y * X>>{}, PassThrough<K>{}),
+                                        make_tuple(Sequence<1, 2>{}, Sequence<0>{}),
+                                        make_tuple(Sequence<0>{}, Sequence<1>{}));
         // block tensor in LDS memory, dst of blockwise copy
         // be careful of LDS alignment
@@ -290,30 +243,29 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
                 in_e_n1_b_n2_block_desc.GetStride(I0));
         // sanity check
-        static_assert(KPerBlock % (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster) ==
-                          0,
+        static_assert(KPerBlock % (GemmMPerThread * GemmMLevel0Cluster * GemmMLevel1Cluster) == 0,
                       "wrong!");
         constexpr index_t GemmMRepeat =
-            KPerBlock / (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster);
+            KPerBlock / (GemmMPerThread * GemmMLevel0Cluster * GemmMLevel1Cluster);
         // c_thread_mtx definition: this is a mess
         // TODO:: more elegent way of defining c_thread_mtx
         constexpr auto c_k0k1_n1n2_thread_mtx_desc = make_ConstantMatrixDescriptor_packed(
-            Number<GemmMRepeat * GemmMPerThreadSubC>{}, Number<GemmNRepeat * GemmNPerThreadSubC>{});
+            Number<GemmMRepeat * GemmMPerThread>{}, Number<GemmNRepeat * GemmNPerThread>{});
         const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2<
             BlockSize,
             decltype(a_e_k_block_mtx_desc),
             decltype(b_e_n1bn2_block_mtx_desc),
             decltype(c_k0k1_n1n2_thread_mtx_desc),
-            GemmMPerThreadSubC,
-            GemmNPerThreadSubC,
+            GemmMPerThread,
+            GemmNPerThread,
+            GemmKPerThread,
             GemmMLevel0Cluster,
             GemmNLevel0Cluster,
             GemmMLevel1Cluster,
             GemmNLevel1Cluster,
-            GemmKPerThreadLoop,
             GemmDataPerReadA,
             GemmDataPerReadB>{};
@@ -432,13 +384,13 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
         // copy output: register to global memory
         {
-            constexpr index_t K1 = GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster;
+            constexpr index_t K1 = GemmMPerThread * GemmMLevel0Cluster * GemmMLevel1Cluster;
             constexpr index_t K0 = K / K1;
             // define output tensor descriptor for threadwise copy
            // thread output tensor, src of threadwise copy
             constexpr auto out_k0_k1_n1_b_n2_thread_desc = make_native_tensor_descriptor_packed(
-                Sequence<GemmMRepeat, GemmMPerThreadSubC, N1, 1, N2>{});
+                Sequence<GemmMRepeat, GemmMPerThread, N1, 1, N2>{});
             // global output tensor
             constexpr auto out_n0_n1_n2_k0_k1_ho_wo_global_desc = transform_tensor_descriptor(
...
#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R1_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER_DEPRECATED_HPP
#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R1_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER_DEPRECATED_HPP
#include "common_header.hpp"
#include "ConstantTensorDescriptor_deprecated.hpp"
#include "ConstantMergedTensorDescriptor_deprecated.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "blockwise_generic_tensor_slice_copy_deprecated.hpp"
#include "blockwise_gemm.hpp"
#include "threadwise_generic_tensor_slice_copy_deprecated.hpp"
#include "convolution_common.hpp"
namespace ck {
template <ConvolutionDirection>
struct make_wei_e_k_global_desc_v4r1_deprecated;
template <>
struct make_wei_e_k_global_desc_v4r1_deprecated<ConvolutionDirection::Forward>
{
template <typename WeiDesc>
__device__ constexpr auto operator()(WeiDesc) const
{
constexpr auto I1 = Number<1>{};
constexpr auto I3 = Number<3>{};
return WeiDesc::Unfold(I1, I3).ReorderGivenNew2Old(Sequence<1, 0>{});
}
};
template <>
struct make_wei_e_k_global_desc_v4r1_deprecated<ConvolutionDirection::BackwardWeight>
{
template <typename WeiDesc>
__device__ constexpr auto operator()(WeiDesc) const
{
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
return make_ConstantMergedTensorDescriptor(
WeiDesc::Unfold(I2, I3), Sequence<1, 2>{}, Sequence<0>{});
}
};
// define B = merge(N0, Ho, Wo)
template <index_t GridSize,
index_t BlockSize,
class Float,
class AccDataType,
class InGlobalDesc,
class WeiGlobalDesc,
class OutGlobalDesc,
class ConvStrides,
class ConvDilations,
ConvolutionDirection ConvDirection,
index_t BPerBlock,
index_t KPerBlock,
index_t EPerBlock,
index_t GemmNRepeat,
index_t GemmMPerThreadSubC,
index_t GemmNPerThreadSubC,
index_t GemmMLevel0Cluster,
index_t GemmNLevel0Cluster,
index_t GemmMLevel1Cluster,
index_t GemmNLevel1Cluster,
index_t GemmKPerThreadLoop,
index_t GemmDataPerReadA,
index_t GemmDataPerReadB,
class InBlockCopySubLengths_E_N1_B_N2,
class InBlockCopyClusterLengths_E_N1_B_N2,
class InBlockCopyThreadClusterArrangeOrder,
class InBlockCopySrcAccessOrder,
class InBlockCopyDstAccessOrder,
index_t InBlockCopySrcDataPerRead_B,
index_t InBlockCopyDstDataPerWrite_N2,
class WeiBlockCopySubLengths_E_K,
class WeiBlockCopyClusterLengths_E_K,
class WeiBlockCopyThreadClusterArrangeOrder,
class WeiBlockCopySrcAccessOrder,
class WeiBlockCopyDstAccessOrder,
index_t WeiBlockCopySrcDataPerRead_E,
index_t WeiBlockCopyDstDataPerWrite_K>
struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer_deprecated
{
__device__ void Run(const Float* const __restrict__ p_in_global,
const Float* const __restrict__ p_wei_global,
Float* const __restrict__ p_out_global) const
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto True = integral_constant<bool, true>{};
constexpr auto generic_address_space =
integral_constant<AddressSpace, AddressSpace::Generic>{};
constexpr auto global_address_space =
integral_constant<AddressSpace, AddressSpace::Global>{};
static_assert(ConvDirection == ConvolutionDirection::Forward ||
ConvDirection == ConvolutionDirection::BackwardWeight,
"wrong! this kernel only support convolution forward and backward-weight");
// this is a mess
// TODO: find more elegant way of specifying (or calculating) performance parameters
constexpr index_t N1 = GemmNRepeat;
constexpr index_t N2 = GemmNPerThreadSubC;
static_assert((N1 * N2 * BPerBlock) %
(GemmNPerThreadSubC * GemmNLevel0Cluster * GemmNLevel1Cluster) ==
0,
"wrong!");
constexpr auto in_n_c_h_w_global_desc = InGlobalDesc{};
constexpr auto wei_k_c_y_x_global_desc = WeiGlobalDesc{};
constexpr auto out_n_k_h_w_global_desc = OutGlobalDesc{};
constexpr index_t N = in_n_c_h_w_global_desc.GetLength(I0);
constexpr index_t C = in_n_c_h_w_global_desc.GetLength(I1);
constexpr index_t K = out_n_k_h_w_global_desc.GetLength(I1);
constexpr index_t Ho = out_n_k_h_w_global_desc.GetLength(I2);
constexpr index_t Wo = out_n_k_h_w_global_desc.GetLength(I3);
constexpr index_t Y = wei_k_c_y_x_global_desc.GetLength(I2);
constexpr index_t X = wei_k_c_y_x_global_desc.GetLength(I3);
constexpr index_t ConvStrideH = ConvStrides{}[0];
constexpr index_t ConvStrideW = ConvStrides{}[1];
constexpr index_t ConvDilationH = ConvDilations{}[0];
constexpr index_t ConvDilationW = ConvDilations{}[1];
static_assert(N % (N1 * N2) == 0, "wrong! cannot divide N evenly among threads");
constexpr index_t N0 = N / (N1 * N2);
constexpr index_t B = N0 * Ho * Wo;
constexpr index_t E = C * Y * X;
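// [editorial note, not part of the original file] The implicit-GEMM view behind the
// names above, with purely illustrative numbers: the weight is read as an [E, K]
// matrix, the input as an [E, N1 * B * N2] matrix, and the output tile as
// [K, N1 * B * N2], so E = C * Y * X plays the role of the GEMM reduction dimension.
// E.g. with N = 128, C = 64, Y = X = 3, Ho = Wo = 28, N1 = 2, N2 = 4:
//   N0 = 128 / (2 * 4) = 16,  B = 16 * 28 * 28 = 12544,  E = 64 * 3 * 3 = 576.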
// sanity-check for vectorized memory load
static_assert(
(Wo == 1 || (ConvStrideW == 1 || InBlockCopySrcDataPerRead_B == 1)) &&
(X == 1 || ConvDilationW % InBlockCopySrcDataPerRead_B == 0),
"wrong! alignment requirement for vectorized global load of input tensor will "
"be violated");
// divide block work by [K, B]
static_assert(K % KPerBlock == 0 && B % BPerBlock == 0 && E % EPerBlock == 0,
"wrong! cannot divide work evenly among block");
constexpr index_t KBlockWork = K / KPerBlock;
constexpr index_t BBlockWork = B / BPerBlock;
constexpr auto block_work_desc =
make_ConstantTensorDescriptor_packed(Sequence<KBlockWork, BBlockWork>{});
const auto block_work_multi_id =
block_work_desc.GetMultiIndexFrom1dIndex(get_block_1d_id());
const index_t k_block_data_on_global = block_work_multi_id[0] * KPerBlock;
const index_t b_block_data_on_global = block_work_multi_id[1] * BPerBlock;
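// [editorial note, not part of the original file] Block-index decomposition, with
// illustrative numbers: the 1D block id is unpacked over a packed [KBlockWork, BBlockWork]
// grid, so with KBlockWork = 4 and BBlockWork = 98, block 101 maps to multi-index
// (101 / 98, 101 % 98) = (1, 3) and starts at K offset 1 * KPerBlock and B offset 3 * BPerBlock.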
// input tensor
// tensor descriptor in device memory [N0, N1, N2, Ho, Wo]
constexpr auto in_n0_n1_n2_h_w_global_desc =
in_n_c_h_w_global_desc.StridedSlice(I2, Number<Ho>{}, Number<ConvStrideH>{})
.StridedSlice(I3, Number<Wo>{}, Number<ConvStrideW>{})
.Fold(I0, Number<N1>{}, Number<N2>{})
.Extract(Sequence<0, 1, 2, 4, 5>{});
// batch descriptor for device memory
constexpr auto in_c_y_x_global_desc =
in_n_c_h_w_global_desc.StridedSlice(I2, Number<Y>{}, Number<ConvDilationH>{})
.StridedSlice(I3, Number<X>{}, Number<ConvDilationW>{})
.Extract(Sequence<1, 2, 3>{});
// merged tensor descriptor in device memory [E, N1, B, N2], src of blockwise copy
constexpr auto in_e_n1_b_n2_global_merged_desc = make_ConstantMergedTensorDescriptor(
in_c_y_x_global_desc.Embed(in_n0_n1_n2_h_w_global_desc),
Sequence<0, 1, 2>{},
Sequence<4>{},
Sequence<3, 6, 7>{},
Sequence<5>{});
// memory layout descriptor in LDS [E, N1, B, N2], dst of blockwise copy
// be careful of LDS alignment
constexpr auto in_e_n1_b_n2_block_desc = make_ConstantTensorDescriptor_aligned(
Sequence<EPerBlock, N1, BPerBlock, N2>{}, Number<InBlockCopyDstDataPerWrite_N2>{});
// this check is ad-hoc
// TODO: need to properly implement tensor descriptor with multiple alignment
// requirements
static_assert(in_e_n1_b_n2_block_desc.GetStride(I1) % GemmDataPerReadB == 0,
"GemmDataPerReadB alignment requirement is not satisfied");
// input blockwise copy
// slice a merged tensor, reorder and copy to a normal tensor
// this copy operator already has blockwise offset built-in
auto blockwise_in_copy = BlockwiseGenericTensorSliceCopy_v2_deprecated<
BlockSize,
decltype(in_e_n1_b_n2_global_merged_desc),
decltype(in_e_n1_b_n2_block_desc),
decltype(in_e_n1_b_n2_block_desc.GetLengths()),
InBlockCopySubLengths_E_N1_B_N2,
InBlockCopyClusterLengths_E_N1_B_N2,
InBlockCopyThreadClusterArrangeOrder,
InBlockCopySrcAccessOrder,
InBlockCopyDstAccessOrder,
2,
3,
InBlockCopySrcDataPerRead_B,
InBlockCopyDstDataPerWrite_N2>({0, 0, b_block_data_on_global, 0}, {0, 0, 0, 0});
// weight tensor
// tensor descriptor in device memory, src of blockwise copy
// It is constructed differently, depending on whether forward or backward weight
// convolution
constexpr auto wei_e_k_global_desc =
make_wei_e_k_global_desc_v4r1_deprecated<ConvDirection>{}(wei_k_c_y_x_global_desc);
// tensor descriptor in LDS, dst of blockwise copy
// be careful of LDS alignment
constexpr auto wei_e_k_block_desc = make_ConstantTensorDescriptor_aligned(
Sequence<EPerBlock, KPerBlock>{},
Number<math::lcm(WeiBlockCopyDstDataPerWrite_K, GemmDataPerReadA)>{});
// operator for blockwise copy of weight into LDS
// slice a tensor, and copy it into another tensor
// this copy operator already has blockwise offset built-in
auto blockwise_wei_copy =
BlockwiseGenericTensorSliceCopy_v2_deprecated<BlockSize,
decltype(wei_e_k_global_desc),
decltype(wei_e_k_block_desc),
decltype(wei_e_k_block_desc.GetLengths()),
WeiBlockCopySubLengths_E_K,
WeiBlockCopyClusterLengths_E_K,
WeiBlockCopyThreadClusterArrangeOrder,
WeiBlockCopySrcAccessOrder,
WeiBlockCopyDstAccessOrder,
0,
1,
WeiBlockCopySrcDataPerRead_E,
WeiBlockCopyDstDataPerWrite_K>(
{0, k_block_data_on_global}, {0, 0});
// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
// a_mtx[EPerBlock, KPerBlock] is in LDS
// b_mtx[EPerBlock, N1 * BPerBlock * N2] is in LDS
// c_mtx[KPerBlock, N1 * BPerBlock * N2] is distributed among threads, and saved in
// register
constexpr auto a_e_k_block_mtx_desc = make_ConstantMatrixDescriptor(wei_e_k_block_desc);
constexpr auto b_e_n1bn2_block_mtx_desc =
make_ConstantMatrixDescriptor(in_e_n1_b_n2_block_desc.Unfold(I1, I3));
// sanity check
static_assert(KPerBlock % (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster) ==
0,
"wrong!");
constexpr index_t GemmMRepeat =
KPerBlock / (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster);
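// [editorial note, not part of the original file] Illustrative numbers (not taken from
// this commit): with KPerBlock = 128, GemmMPerThreadSubC = 4, GemmMLevel0Cluster = 4 and
// GemmMLevel1Cluster = 4, one repeat covers 4 * 4 * 4 = 64 rows of the block tile, so
// GemmMRepeat = 128 / 64 = 2 and each thread accumulates 2 * 4 = 8 values along the
// GEMM M (i.e. K) dimension.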
// c_thread_mtx definition: this is a mess
// TODO: more elegant way of defining c_thread_mtx
constexpr auto c_k0k1_n1n2_thread_mtx_desc = make_ConstantMatrixDescriptor_packed(
Number<GemmMRepeat * GemmMPerThreadSubC>{}, Number<GemmNRepeat * GemmNPerThreadSubC>{});
const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2<
BlockSize,
decltype(a_e_k_block_mtx_desc),
decltype(b_e_n1bn2_block_mtx_desc),
decltype(c_k0k1_n1n2_thread_mtx_desc),
GemmMPerThreadSubC,
GemmNPerThreadSubC,
GemmMLevel0Cluster,
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmKPerThreadLoop,
GemmDataPerReadA,
GemmDataPerReadB>{};
// LDS allocation for input and weight: be careful of alignment
constexpr index_t max_align = math::lcm(InBlockCopyDstDataPerWrite_N2,
WeiBlockCopyDstDataPerWrite_K,
GemmDataPerReadA,
GemmDataPerReadB);
constexpr index_t in_block_space =
math::integer_least_multiple(in_e_n1_b_n2_block_desc.GetElementSpace(), max_align);
constexpr index_t wei_block_space =
math::integer_least_multiple(wei_e_k_block_desc.GetElementSpace(), max_align);
__shared__ Float p_in_block_double[2 * in_block_space];
__shared__ Float p_wei_block_double[2 * wei_block_space];
// register allocation for output
AccDataType p_out_thread[c_k0k1_n1n2_thread_mtx_desc.GetElementSpace()];
// zero out threadwise output
threadwise_matrix_set_zero(c_k0k1_n1n2_thread_mtx_desc, p_out_thread);
// LDS double buffer: preload data into LDS
{
blockwise_in_copy.Run(
p_in_global, p_in_block_double, global_address_space, generic_address_space);
blockwise_wei_copy.Run(
p_wei_global, p_wei_block_double, global_address_space, generic_address_space);
}
// LDS double buffer: main body
for(index_t e_block_data_begin = 0; e_block_data_begin + 2 * EPerBlock < E;
e_block_data_begin += 2 * EPerBlock)
{
#pragma unroll
for(index_t iloop = 0; iloop < 2; ++iloop)
{
const bool even_loop = (iloop % 2 == 0);
Float* p_in_block_now =
even_loop ? p_in_block_double : p_in_block_double + in_block_space;
Float* p_wei_block_now =
even_loop ? p_wei_block_double : p_wei_block_double + wei_block_space;
Float* p_in_block_next =
even_loop ? p_in_block_double + in_block_space : p_in_block_double;
Float* p_wei_block_next =
even_loop ? p_wei_block_double + wei_block_space : p_wei_block_double;
Float p_in_thread_buffer[blockwise_in_copy.GetThreadBufferSize()];
Float p_wei_thread_buffer[blockwise_wei_copy.GetThreadBufferSize()];
blockwise_in_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0, 0, 0>{}, True);
blockwise_wei_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0>{}, True);
__syncthreads();
// LDS double buffer: load next data from device mem
blockwise_in_copy.RunLoadThreadBuffer(
p_in_global, p_in_thread_buffer, global_address_space, generic_address_space);
blockwise_wei_copy.RunLoadThreadBuffer(
p_wei_global, p_wei_thread_buffer, global_address_space, generic_address_space);
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(p_wei_block_now, p_in_block_now, p_out_thread);
// LDS double buffer: store next data to LDS
blockwise_in_copy.RunStoreThreadBuffer(p_in_thread_buffer, p_in_block_next);
blockwise_wei_copy.RunStoreThreadBuffer(p_wei_thread_buffer, p_wei_block_next);
}
}
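// [editorial note, not part of the original file] Reading of the loop above and the
// tail below: each trip through the outer loop consumes two EPerBlock chunks of E,
// ping-ponging between the two halves of the LDS double buffer, and it stops while at
// least two chunks remain. The tail then finishes either two chunks (when E / EPerBlock
// is even) or one (when it is odd), which is why it branches on E % (2 * EPerBlock).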
// LDS double buffer: tail
{
constexpr bool has_two_iteration_left = (E % (2 * EPerBlock) == 0);
if(has_two_iteration_left) // 2 iterations left
{
// even iteration
Float p_in_thread_buffer[blockwise_in_copy.GetThreadBufferSize()];
Float p_wei_thread_buffer[blockwise_wei_copy.GetThreadBufferSize()];
blockwise_in_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0, 0, 0>{}, True);
blockwise_wei_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0>{}, True);
__syncthreads();
// LDS double buffer: load next data from device mem
blockwise_in_copy.RunLoadThreadBuffer(
p_in_global, p_in_thread_buffer, global_address_space, generic_address_space);
blockwise_wei_copy.RunLoadThreadBuffer(
p_wei_global, p_wei_thread_buffer, global_address_space, generic_address_space);
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(p_wei_block_double, p_in_block_double, p_out_thread);
// LDS double buffer: store next data to LDS
blockwise_in_copy.RunStoreThreadBuffer(p_in_thread_buffer,
p_in_block_double + in_block_space);
blockwise_wei_copy.RunStoreThreadBuffer(p_wei_thread_buffer,
p_wei_block_double + wei_block_space);
// odd iteration
__syncthreads();
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(p_wei_block_double + wei_block_space,
p_in_block_double + in_block_space,
p_out_thread);
}
else // 1 iteration left
{
__syncthreads();
// LDS double buffer: GEMM on last data
blockwise_gemm.Run(p_wei_block_double, p_in_block_double, p_out_thread);
}
}
// copy output: register to global memory
{
constexpr index_t K1 = GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster;
// define tensor descriptor for threadwise copy
// output memory layout descriptor in register, src of threadwise copy
constexpr auto out_k0_k1_n1_b_n2_thread_mem_desc = make_ConstantTensorDescriptor_packed(
Sequence<GemmMRepeat, GemmMPerThreadSubC, N1, 1, N2>{});
// output memory layout descriptor in device memory
constexpr auto out_n0_n1_n2_k0_k1_h_w_global_mem_desc =
out_n_k_h_w_global_desc.Fold(I1, Number<K1>{}).Fold(I0, Number<N1>{}, Number<N2>{});
// output merged global tensor descriptor, dst of threadwise copy
constexpr auto out_k0_k1_n1_b_n2_global_merged_desc =
make_ConstantMergedTensorDescriptor(out_n0_n1_n2_k0_k1_h_w_global_mem_desc,
Sequence<3>{},
Sequence<4>{},
Sequence<1>{},
Sequence<0, 5, 6>{},
Sequence<2>{});
// calculate origin of thread output tensor on global memory
// blockwise GEMM c matrix starting index
const auto c_thread_mtx_on_block =
blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
const index_t k_thread_data_on_global =
k_block_data_on_global + c_thread_mtx_on_block.row;
const index_t b_thread_data_on_global =
b_block_data_on_global + c_thread_mtx_on_block.col / N2;
ThreadwiseGenericTensorSliceCopy_v2r1_deprecated<
decltype(out_k0_k1_n1_b_n2_thread_mem_desc),
decltype(out_k0_k1_n1_b_n2_global_merged_desc),
decltype(out_k0_k1_n1_b_n2_thread_mem_desc.GetLengths()),
arithmetic_sequence_gen<0, 5, 1>::type,
arithmetic_sequence_gen<0, 5, 1>::type,
3,
3,
1,
1>({0, 0, 0, 0, 0},
{k_thread_data_on_global / K1,
k_thread_data_on_global % K1,
0,
b_thread_data_on_global,
0})
.Run(p_out_thread, p_out_global, generic_address_space, global_address_space);
}
}
};
} // namespace ck
#endif // CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R1_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER_DEPRECATED_HPP
#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R2_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER
#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R2_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER
#include "common_header.hpp"
#include "ConstantTensorDescriptor_deprecated.hpp"
#include "ConstantMergedTensorDescriptor_deprecated.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "blockwise_generic_tensor_slice_copy.hpp"
#include "blockwise_gemm.hpp"
#include "threadwise_generic_tensor_slice_copy.hpp"
namespace ck {
template <index_t GridSize,
index_t BlockSize,
class Float,
class InGlobalDesc,
class WeiGlobalDesc,
class OutGlobalDesc,
class ConvStrides,
class ConvDilations,
index_t N1,
index_t N2,
index_t Ho1,
index_t Ho2,
index_t Wo1,
index_t Wo2,
index_t BPerBlock,
index_t KPerBlock,
index_t EPerBlock,
index_t GemmMPerThreadSubC,
index_t GemmNPerThreadSubC,
index_t GemmMLevel0Cluster,
index_t GemmNLevel0Cluster,
index_t GemmMLevel1Cluster,
index_t GemmNLevel1Cluster,
index_t GemmKPerThreadLoop,
index_t GemmDataPerReadA,
index_t GemmDataPerReadB,
class InBlockCopySubLengths_E_N0_Ho0_Wo0_B_N2_Ho2_Wo2,
class InBlockCopyClusterLengths_E_N0_Ho0_Wo0_B_N2_Ho2_Wo2,
class InBlockCopyThreadClusterArrangeOrder,
class InBlockCopySrcAccessOrder,
class InBlockCopyDstAccessOrder,
index_t InBlockCopyDataPerAccess_W2,
class WeiBlockCopySubLengths_E_K,
class WeiBlockCopyClusterLengths_E_K,
class WeiBlockCopyThreadClusterArrangeOrder,
class WeiBlockCopySrcAccessOrder,
class WeiBlockCopyDstAccessOrder,
index_t WeiBlockCopySrcDataPerRead_E,
index_t WeiBlockCopyDstDataPerWrite_K>
struct GridwiseConvolutionImplicitGemm_v4r2_nchw_kcyx_nkhw_lds_double_buffer
{
__device__ void Run(const Float* const __restrict__ p_in_global,
const Float* const __restrict__ p_wei_global,
Float* const __restrict__ p_out_global) const
{
// this is a mess
// TODO: find more elegant way of specifying (or calculating) performance parameters
static_assert(N2 * Ho2 * Wo2 == GemmNPerThreadSubC, "wrong!");
static_assert((N1 * Ho1 * Wo1 * BPerBlock * N2 * Ho2 * Wo2) %
(GemmNPerThreadSubC * GemmNLevel0Cluster * GemmNLevel1Cluster) ==
0,
"wrong!");
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto I5 = Number<5>{};
constexpr auto I7 = Number<7>{};
constexpr auto True = integral_constant<bool, true>{};
constexpr auto in_n_c_h_w_global_desc = InGlobalDesc{};
constexpr auto wei_k_c_y_x_global_desc = WeiGlobalDesc{};
constexpr auto out_n_k_h_w_global_desc = OutGlobalDesc{};
constexpr index_t N = in_n_c_h_w_global_desc.GetLengths()[0];
constexpr index_t C = in_n_c_h_w_global_desc.GetLengths()[1];
constexpr index_t K = out_n_k_h_w_global_desc.GetLengths()[1];
constexpr index_t Ho = out_n_k_h_w_global_desc.GetLengths()[2];
constexpr index_t Wo = out_n_k_h_w_global_desc.GetLengths()[3];
constexpr index_t Y = wei_k_c_y_x_global_desc.GetLengths()[2];
constexpr index_t X = wei_k_c_y_x_global_desc.GetLengths()[3];
constexpr index_t ConvStrideH = ConvStrides{}[0];
constexpr index_t ConvStrideW = ConvStrides{}[1];
constexpr index_t ConvDilationH = ConvDilations{}[0];
constexpr index_t ConvDilationW = ConvDilations{}[1];
constexpr index_t E = C * Y * X;
constexpr index_t B = N1 * Ho1 * Wo1;
static_assert(N % (N1 * N2) == 0 && Ho % (Ho1 * Ho2) == 0 && Wo % (Wo1 * Wo2) == 0,
"wrong!");
constexpr index_t N0 = N / (N1 * N2);
constexpr index_t Ho0 = Ho / (Ho1 * Ho2);
constexpr index_t Wo0 = Wo / (Wo1 * Wo2);
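// [editorial note, not part of the original file] How the factorizations above fit
// together in this variant, with illustrative numbers: N, Ho and Wo are each split into
// three factors. The middle factors form B = N1 * Ho1 * Wo1, which is tiled across
// blocks in BPerBlock pieces; the outer factors (N0, Ho0, Wo0) and the inner factors
// (N2, Ho2, Wo2) both live inside a block's tile, and N2 * Ho2 * Wo2 must equal
// GemmNPerThreadSubC (one thread's GEMM-N footprint). E.g. N1 = Ho1 = Wo1 = 2 gives
// B = 8, and N2 = Ho2 = Wo2 = 2 requires GemmNPerThreadSubC == 8.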
static_assert((X == 1 || ConvDilationW % InBlockCopyDataPerAccess_W2 == 0),
"wrong! aligment requirement for vectorized global load of input tensor will "
"be violated");
// divide block work by [K, B]
static_assert(K % KPerBlock == 0 && B % BPerBlock == 0 && E % (2 * EPerBlock) == 0,
"wrong! cannot divide work evenly among block");
constexpr index_t KBlockWork = K / KPerBlock;
constexpr index_t BBlockWork = B / BPerBlock;
constexpr auto block_work_desc =
make_ConstantTensorDescriptor_packed(Sequence<KBlockWork, BBlockWork>{});
const auto block_work_multi_id =
block_work_desc.GetMultiIndexFrom1dIndex(get_block_1d_id());
const index_t k_block_data_on_global = block_work_multi_id[0] * KPerBlock;
const index_t b_block_data_on_global = block_work_multi_id[1] * BPerBlock;
// input tensor
// tensor descriptor in device memory [N0, N1, N2, Ho0, Ho1, Ho2, Wo0, Wo1, Wo2]
constexpr auto in_n0_n1_n2_ho0_ho1_ho2_wo0_wo1_wo2_global_desc =
in_n_c_h_w_global_desc.Extract(I0, I2, I3)
.StridedSlice(I1, Number<Ho>{}, Number<ConvStrideH>{})
.StridedSlice(I2, Number<Wo>{}, Number<ConvStrideW>{})
.Fold(I2, Number<Wo1>{}, Number<Wo2>{})
.Fold(I1, Number<Ho1>{}, Number<Ho2>{})
.Fold(I0, Number<N1>{}, Number<N2>{});
constexpr auto in_n0_ho0_wo0_n1_ho1_wo1_n2_ho2_wo2_global_desc =
in_n0_n1_n2_ho0_ho1_ho2_wo0_wo1_wo2_global_desc.ReorderGivenNew2Old(
Sequence<0, 3, 6, 1, 4, 7, 2, 5, 8>{});
// batch descriptor for device memory
constexpr auto in_c_y_x_global_desc =
in_n_c_h_w_global_desc.StridedSlice(I2, Number<Y>{}, Number<ConvDilationH>{})
.StridedSlice(I3, Number<X>{}, Number<ConvDilationW>{})
.Extract(Sequence<1, 2, 3>{});
// merged tensor descriptor in device memory [E, N1, B, N2], src of blockwise copy
constexpr auto in_e_n0_ho0_wo0_b_n2_ho2_wo2_global_merged_desc =
make_ConstantMergedTensorDescriptor(
in_c_y_x_global_desc.Embed(in_n0_ho0_wo0_n1_ho1_wo1_n2_ho2_wo2_global_desc),
Sequence<0, 1, 2>{},
Sequence<3>{},
Sequence<4>{},
Sequence<5>{},
Sequence<6, 7, 8>{},
Sequence<9>{},
Sequence<10>{},
Sequence<11>{});
// memory layout descriptor in LDS [E, N1, B, N2], dst of blockwise copy
// be careful of LDS alignment
constexpr auto in_e_n0_ho0_wo0_b_n2_ho2_wo2_block_desc =
make_ConstantTensorDescriptor_packed(
Sequence<EPerBlock, N0, Ho0, Wo0, BPerBlock, N2, Ho2, Wo2>{});
// input blockwise copy
// slice a merged tensor, reorder and copy to a normal tensor
// this copy operator already has blockwise offset built-in
auto blockwise_in_copy = BlockwiseGenericTensorSliceCopy_v1_deprecated<
BlockSize,
Float,
decltype(in_e_n0_ho0_wo0_b_n2_ho2_wo2_global_merged_desc),
decltype(in_e_n0_ho0_wo0_b_n2_ho2_wo2_block_desc),
decltype(in_e_n0_ho0_wo0_b_n2_ho2_wo2_block_desc.GetLengths()),
InBlockCopySubLengths_E_N0_Ho0_Wo0_B_N2_Ho2_Wo2,
InBlockCopyClusterLengths_E_N0_Ho0_Wo0_B_N2_Ho2_Wo2,
InBlockCopyThreadClusterArrangeOrder,
InBlockCopySrcAccessOrder,
InBlockCopyDstAccessOrder,
InBlockCopyDataPerAccess_W2,
InBlockCopyDataPerAccess_W2>({0, 0, 0, 0, b_block_data_on_global, 0, 0, 0},
{0, 0, 0, 0, 0, 0, 0, 0});
// weight tensor
// tensor descriptor in device memory, src of blockwise copy
constexpr auto wei_e_k_global_desc =
wei_k_c_y_x_global_desc.Unfold(I1, I3).ReorderGivenNew2Old(Sequence<1, 0>{});
// tensor descriptor in LDS, dst of blockwise copy
// be careful of LDS alignment
constexpr auto wei_e_k_block_desc = make_ConstantTensorDescriptor_aligned(
Sequence<EPerBlock, KPerBlock>{},
Number<math::lcm(WeiBlockCopyDstDataPerWrite_K, GemmDataPerReadA)>{});
// operator for blockwise copy of weight into LDS
// slice a tensor, and copy it into another tensor
// this copy operator already has blockwise offset built-in
auto blockwise_wei_copy =
BlockwiseGenericTensorSliceCopy_v1_deprecated<BlockSize,
Float,
decltype(wei_e_k_global_desc),
decltype(wei_e_k_block_desc),
decltype(wei_e_k_block_desc.GetLengths()),
WeiBlockCopySubLengths_E_K,
WeiBlockCopyClusterLengths_E_K,
WeiBlockCopyThreadClusterArrangeOrder,
WeiBlockCopySrcAccessOrder,
WeiBlockCopyDstAccessOrder,
WeiBlockCopySrcDataPerRead_E,
WeiBlockCopyDstDataPerWrite_K>(
{0, k_block_data_on_global}, {0, 0});
// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
// a_mtx[EPerBlock, KPerBlock] is in LDS
// b_mtx[EPerBlock, N1 * BPerBlock * N2] is in LDS
// c_mtx[KPerBlock, N1 * BPerBlock * N2] is distributed among threads, and saved in
// register
constexpr auto a_e_k_block_mtx_desc = make_ConstantMatrixDescriptor(wei_e_k_block_desc);
// this check is ad-hoc
// TODO: need to properly implement tensor descriptor with multiple alignment
// requirements
static_assert(in_e_n0_ho0_wo0_b_n2_ho2_wo2_block_desc.GetStrides()[3] % GemmDataPerReadB ==
0,
"GemmDataPerReadB alignment requirement is not satisfied");
constexpr auto b_e_n0ho0wo0bn2ho2wo2_block_mtx_desc =
make_ConstantMatrixDescriptor(in_e_n0_ho0_wo0_b_n2_ho2_wo2_block_desc.Unfold(I1, I7));
// sanity check
static_assert(KPerBlock % (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster) ==
0,
"wrong!");
constexpr index_t GemmMRepeat =
KPerBlock / (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster);
// c_thread_mtx definition: this is a mess
// TODO: more elegant way of defining c_thread_mtx
constexpr auto c_k0k2_n0ho0wo0n2ho2wo2_thread_mtx_desc =
make_ConstantMatrixDescriptor_packed(Number<GemmMRepeat * GemmMPerThreadSubC>{},
Number<N0 * Ho0 * Wo0 * N2 * Ho2 * Wo2>{});
const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2<
BlockSize,
decltype(a_e_k_block_mtx_desc),
decltype(b_e_n0ho0wo0bn2ho2wo2_block_mtx_desc),
decltype(c_k0k2_n0ho0wo0n2ho2wo2_thread_mtx_desc),
GemmMPerThreadSubC,
GemmNPerThreadSubC,
GemmMLevel0Cluster,
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmKPerThreadLoop,
GemmDataPerReadA,
GemmDataPerReadB>{};
// LDS allocation for input and weight: be careful of alignment
constexpr index_t max_align = math::lcm(InBlockCopyDataPerAccess_W2,
WeiBlockCopyDstDataPerWrite_K,
GemmDataPerReadA,
GemmDataPerReadB);
constexpr index_t in_block_space = math::integer_least_multiple(
in_e_n0_ho0_wo0_b_n2_ho2_wo2_block_desc.GetElementSpace(), max_align);
constexpr index_t wei_block_space =
math::integer_least_multiple(wei_e_k_block_desc.GetElementSpace(), max_align);
__shared__ Float p_in_block_double[2 * in_block_space];
__shared__ Float p_wei_block_double[2 * wei_block_space];
// register allocation for output
Float p_out_thread[c_k0k2_n0ho0wo0n2ho2wo2_thread_mtx_desc.GetElementSpace()];
// zero out threadwise output
threadwise_matrix_set_zero(c_k0k2_n0ho0wo0n2ho2wo2_thread_mtx_desc, p_out_thread);
const Float* p_wei_block_on_global = p_wei_global;
// LDS double buffer: preload data into LDS
{
blockwise_in_copy.Run(p_in_global, p_in_block_double);
blockwise_wei_copy.Run(p_wei_global, p_wei_block_double);
}
// LDS double buffer: main body
for(index_t e_block_data_begin = 0; e_block_data_begin + 2 * EPerBlock < E;
e_block_data_begin += 2 * EPerBlock)
{
#pragma unroll
for(index_t iloop = 0; iloop < 2; ++iloop)
{
const bool even_loop = (iloop % 2 == 0);
Float* p_in_block_now =
even_loop ? p_in_block_double : p_in_block_double + in_block_space;
Float* p_wei_block_now =
even_loop ? p_wei_block_double : p_wei_block_double + wei_block_space;
Float* p_in_block_next =
even_loop ? p_in_block_double + in_block_space : p_in_block_double;
Float* p_wei_block_next =
even_loop ? p_wei_block_double + wei_block_space : p_wei_block_double;
Float p_in_register_buffer[blockwise_in_copy.GetRegisterBufferSize()];
Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()];
blockwise_in_copy.MoveSlicingWindowOnSourceTensor(I0, Number<EPerBlock>{}, True);
p_wei_block_on_global += EPerBlock * wei_e_k_global_desc.GetStride(I0);
__syncthreads();
// LDS double buffer: load next data from device mem
blockwise_in_copy.RunLoadRegisterBuffer(p_in_global, p_in_register_buffer);
blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_block_on_global,
p_wei_register_buffer);
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(p_wei_block_now, p_in_block_now, p_out_thread);
// LDS double buffer: store next data to LDS
blockwise_in_copy.RunStoreRegisterBuffer(p_in_register_buffer, p_in_block_next);
blockwise_wei_copy.RunStoreRegisterBuffer(p_wei_register_buffer, p_wei_block_next);
}
}
// LDS double buffer: tail
{
Float p_in_register_buffer[blockwise_in_copy.GetRegisterBufferSize()];
Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()];
// even iteration
blockwise_in_copy.MoveSlicingWindowOnSourceTensor(I0, Number<EPerBlock>{}, True);
p_wei_block_on_global += EPerBlock * wei_e_k_global_desc.GetStride(I0);
__syncthreads();
// LDS double buffer: load next data from device mem
blockwise_in_copy.RunLoadRegisterBuffer(p_in_global, p_in_register_buffer);
blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_block_on_global, p_wei_register_buffer);
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(p_wei_block_double, p_in_block_double, p_out_thread);
// LDS double buffer: store next data to LDS
blockwise_in_copy.RunStoreRegisterBuffer(p_in_register_buffer,
p_in_block_double + in_block_space);
blockwise_wei_copy.RunStoreRegisterBuffer(p_wei_register_buffer,
p_wei_block_double + wei_block_space);
// odd iteration
__syncthreads();
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(p_wei_block_double + wei_block_space,
p_in_block_double + in_block_space,
p_out_thread);
}
// copy output: register to global memory
{
constexpr index_t K2 = GemmMPerThreadSubC;
constexpr index_t K1 = GemmMLevel0Cluster * GemmMLevel1Cluster;
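// [editorial note, not part of the original file] The split above views K roughly as
// [K0, K1, K2]: K2 = GemmMPerThreadSubC is the run of contiguous output channels one
// thread owns, K1 = GemmMLevel0Cluster * GemmMLevel1Cluster counts the threads spread
// along GEMM M, and K0 covers the per-thread repeats. With illustrative numbers
// GemmMPerThreadSubC = 4 and 4 x 4 level clusters, K2 = 4 and K1 = 16, so a
// KPerBlock of 128 gives KPerBlock / (K1 * K2) = 2 repeats per thread.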
// define tensor descriptor for threadwise copy
// output memory layout descriptor in register
constexpr auto out_k0_k1_k2_n0_ho0_wo0_n1_ho1_wo1_n2_ho2_wo2_thread_mem_desc =
make_ConstantTensorDescriptor_packed(
Sequence<KPerBlock / (K1 * K2), 1, K2, N0, Ho0, Wo0, 1, 1, 1, N2, Ho2, Wo2>{});
// output tensor descriptor in register, src of threadwise copy
constexpr auto out_n0_n1_n2_k0_k1_k2_ho0_ho1_ho2_wo0_wo1_wo2_thread_desc =
out_k0_k1_k2_n0_ho0_wo0_n1_ho1_wo1_n2_ho2_wo2_thread_mem_desc.ReorderGivenNew2Old(
Sequence<3, 6, 9, 0, 1, 2, 4, 7, 10, 5, 8, 11>{});
// output memory layout descriptor in device memory, dst of threadwise copy
constexpr auto out_n0_n1_n2_k0_k1_k2_ho0_ho1_ho2_wo0_wo1_wo2_global_mem_desc =
out_n_k_h_w_global_desc.Fold(I3, Sequence<Wo1, Wo2>{})
.Fold(I2, Sequence<Ho1, Ho2>{})
.Fold(I1, Sequence<K1, K2>{})
.Fold(I0, Sequence<N1, N2>{});
// calculate origin of thread output tensor on global memory
// blockwise GEMM c matrix starting index
const auto c_thread_mtx_on_block =
blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
const index_t k_thread_data_on_global =
k_block_data_on_global + c_thread_mtx_on_block.row;
const index_t b_thread_data_on_global =
b_block_data_on_global + c_thread_mtx_on_block.col / (N2 * Ho2 * Wo2);
// output merged global tensor descriptor, for calculating origin of thread tensor
// in global memory
constexpr auto out_k_n0_ho0_wo0_b_n2_ho2_wo2_global_merged_desc =
make_ConstantMergedTensorDescriptor(
out_n0_n1_n2_k0_k1_k2_ho0_ho1_ho2_wo0_wo1_wo2_global_mem_desc.Unfold(I3, I5),
Sequence<3>{},
Sequence<0>{},
Sequence<4>{},
Sequence<7>{},
Sequence<1, 5, 8>{},
Sequence<2>{},
Sequence<6>{},
Sequence<9>{});
// origin of dst in device memory
Float* p_out_thread_on_global =
p_out_global +
out_k_n0_ho0_wo0_b_n2_ho2_wo2_global_merged_desc.GetOffsetFromMultiIndex(
k_thread_data_on_global, 0, 0, 0, b_thread_data_on_global, 0, 0, 0);
threadwise_generic_tensor_slice_copy_v1(
out_n0_n1_n2_k0_k1_k2_ho0_ho1_ho2_wo0_wo1_wo2_thread_desc,
p_out_thread,
{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
out_n0_n1_n2_k0_k1_k2_ho0_ho1_ho2_wo0_wo1_wo2_global_mem_desc,
p_out_thread_on_global,
{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
out_n0_n1_n2_k0_k1_k2_ho0_ho1_ho2_wo0_wo1_wo2_thread_desc.GetLengths(),
arithmetic_sequence_gen<0, 12, 1>::type{},
Number<1>{});
}
}
};
} // namespace ck
#endif
#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R3_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER
#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R3_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER
#include "common_header.hpp"
#include "ConstantTensorDescriptor_deprecated.hpp"
#include "ConstantMergedTensorDescriptor_deprecated.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "blockwise_generic_tensor_slice_copy.hpp"
#include "blockwise_gemm.hpp"
#include "threadwise_generic_tensor_slice_copy.hpp"
namespace ck {
template <index_t GridSize,
index_t BlockSize,
class Float,
class InGlobalDesc,
class WeiGlobalDesc,
class OutGlobalDesc,
class ConvStrides,
class ConvDilations,
index_t N0,
index_t N1,
index_t N2,
index_t Ho0,
index_t Ho1,
index_t Ho2,
index_t Wo0,
index_t Wo1,
index_t Wo2,
index_t BPerBlock,
index_t KPerBlock,
index_t EPerBlock,
index_t GemmMPerThreadSubC,
index_t GemmNPerThreadSubC,
index_t GemmMLevel0Cluster,
index_t GemmNLevel0Cluster,
index_t GemmMLevel1Cluster,
index_t GemmNLevel1Cluster,
index_t GemmKPerThreadLoop,
index_t GemmDataPerReadA,
index_t GemmDataPerReadB,
class InBlockCopySubLengths_E_N1_Ho1_Wo1_B_N2_Ho2_Wo2,
class InBlockCopyClusterLengths_E_N1_Ho1_Wo1_B_N2_Ho2_Wo2,
class InBlockCopyThreadClusterArrangeOrder,
class InBlockCopySrcAccessOrder,
class InBlockCopyDstAccessOrder,
index_t InBlockCopyDataPerAccess_W2,
class WeiBlockCopySubLengths_E_K,
class WeiBlockCopyClusterLengths_E_K,
class WeiBlockCopyThreadClusterArrangeOrder,
class WeiBlockCopySrcAccessOrder,
class WeiBlockCopyDstAccessOrder,
index_t WeiBlockCopySrcDataPerRead_E,
index_t WeiBlockCopyDstDataPerWrite_K>
struct GridwiseConvolutionImplicitGemm_v4r3_nchw_kcyx_nkhw_lds_double_buffer
{
__device__ void Run(const Float* const __restrict__ p_in_global,
const Float* const __restrict__ p_wei_global,
Float* const __restrict__ p_out_global) const
{
// this is a mess
// TODO: find more elegant way of specifying (or calculating) performance parameters
static_assert(N2 * Ho2 * Wo2 == GemmNPerThreadSubC, "wrong!");
static_assert((N1 * Ho1 * Wo1 * BPerBlock * N2 * Ho2 * Wo2) %
(GemmNPerThreadSubC * GemmNLevel0Cluster * GemmNLevel1Cluster) ==
0,
"wrong!");
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto I5 = Number<5>{};
constexpr auto I7 = Number<7>{};
constexpr auto True = integral_constant<bool, true>{};
constexpr auto in_n_c_h_w_global_desc = InGlobalDesc{};
constexpr auto wei_k_c_y_x_global_desc = WeiGlobalDesc{};
constexpr auto out_n_k_h_w_global_desc = OutGlobalDesc{};
constexpr index_t N = in_n_c_h_w_global_desc.GetLengths()[0];
constexpr index_t C = in_n_c_h_w_global_desc.GetLengths()[1];
constexpr index_t K = out_n_k_h_w_global_desc.GetLengths()[1];
constexpr index_t Ho = out_n_k_h_w_global_desc.GetLengths()[2];
constexpr index_t Wo = out_n_k_h_w_global_desc.GetLengths()[3];
constexpr index_t Y = wei_k_c_y_x_global_desc.GetLengths()[2];
constexpr index_t X = wei_k_c_y_x_global_desc.GetLengths()[3];
constexpr index_t ConvStrideH = ConvStrides{}[0];
constexpr index_t ConvStrideW = ConvStrides{}[1];
constexpr index_t ConvDilationH = ConvDilations{}[0];
constexpr index_t ConvDilationW = ConvDilations{}[1];
constexpr index_t E = C * Y * X;
constexpr index_t B = N0 * Ho0 * Wo0;
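// [editorial note, not part of the original file] Compared with the v4r2 variant above,
// the roles of the outer and middle factors are swapped here: B is built from the *0
// factors (N0, Ho0, Wo0) and is what BPerBlock tiles across blocks, while the *1 factors
// (N1, Ho1, Wo1) stay inside a block's LDS tile; the *2 factors again form one thread's
// GEMM-N footprint (N2 * Ho2 * Wo2 == GemmNPerThreadSubC).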
static_assert(N == N0 * N1 * N2 && Ho == Ho0 * Ho1 * Ho2 && Wo == Wo0 * Wo1 * Wo2,
"wrong!");
static_assert((X == 1 || ConvDilationW % InBlockCopyDataPerAccess_W2 == 0),
"wrong! aligment requirement for vectorized global load of input tensor will "
"be violated");
// divide block work by [K, B]
static_assert(K % KPerBlock == 0 && B % BPerBlock == 0 && E % (2 * EPerBlock) == 0,
"wrong! cannot divide work evenly among block");
constexpr index_t KBlockWork = K / KPerBlock;
constexpr index_t BBlockWork = B / BPerBlock;
constexpr auto block_work_desc =
make_ConstantTensorDescriptor_packed(Sequence<KBlockWork, BBlockWork>{});
const auto block_work_multi_id =
block_work_desc.GetMultiIndexFrom1dIndex(get_block_1d_id());
const index_t k_block_data_on_global = block_work_multi_id[0] * KPerBlock;
const index_t b_block_data_on_global = block_work_multi_id[1] * BPerBlock;
// input tensor
// tensor descriptor in device memory [N0, N1, N2, Ho0, Ho1, Ho2, Wo0, Wo1, Wo2]
constexpr auto in_n0_n1_n2_ho0_ho1_ho2_wo0_wo1_wo2_global_desc =
in_n_c_h_w_global_desc.Extract(I0, I2, I3)
.StridedSlice(I1, Number<Ho>{}, Number<ConvStrideH>{})
.StridedSlice(I2, Number<Wo>{}, Number<ConvStrideW>{})
.Fold(I2, Number<Wo1>{}, Number<Wo2>{})
.Fold(I1, Number<Ho1>{}, Number<Ho2>{})
.Fold(I0, Number<N1>{}, Number<N2>{});
constexpr auto in_n1_ho1_wo1_n0_ho0_wo0_n2_ho2_wo2_global_desc =
in_n0_n1_n2_ho0_ho1_ho2_wo0_wo1_wo2_global_desc.ReorderGivenNew2Old(
Sequence<1, 4, 7, 0, 3, 6, 2, 5, 8>{});
// batch descriptor for device memory
constexpr auto in_c_y_x_global_desc =
in_n_c_h_w_global_desc.StridedSlice(I2, Number<Y>{}, Number<ConvDilationH>{})
.StridedSlice(I3, Number<X>{}, Number<ConvDilationW>{})
.Extract(Sequence<1, 2, 3>{});
// merged tensor descriptor in device memory [E, N1, B, N2], src of blockwise copy
constexpr auto in_e_n1_ho1_wo1_b_n2_ho2_wo2_global_merged_desc =
make_ConstantMergedTensorDescriptor(
in_c_y_x_global_desc.Embed(in_n1_ho1_wo1_n0_ho0_wo0_n2_ho2_wo2_global_desc),
Sequence<0, 1, 2>{},
Sequence<3>{},
Sequence<4>{},
Sequence<5>{},
Sequence<6, 7, 8>{},
Sequence<9>{},
Sequence<10>{},
Sequence<11>{});
// memory layout descriptor in LDS [E, N1, B, N2], dst of blockwise copy
// be careful of LDS alignment
constexpr auto in_e_n1_ho1_wo1_b_n2_ho2_wo2_block_desc =
make_ConstantTensorDescriptor_packed(
Sequence<EPerBlock, N1, Ho1, Wo1, BPerBlock, N2, Ho2, Wo2>{});
// input blockwise copy
// slice a merged tensor, reorder and copy to a normal tensor
// this copy operator already has blockwise offset built-in
auto blockwise_in_copy = BlockwiseGenericTensorSliceCopy_v1_deprecated<
BlockSize,
Float,
decltype(in_e_n1_ho1_wo1_b_n2_ho2_wo2_global_merged_desc),
decltype(in_e_n1_ho1_wo1_b_n2_ho2_wo2_block_desc),
decltype(in_e_n1_ho1_wo1_b_n2_ho2_wo2_block_desc.GetLengths()),
InBlockCopySubLengths_E_N1_Ho1_Wo1_B_N2_Ho2_Wo2,
InBlockCopyClusterLengths_E_N1_Ho1_Wo1_B_N2_Ho2_Wo2,
InBlockCopyThreadClusterArrangeOrder,
InBlockCopySrcAccessOrder,
InBlockCopyDstAccessOrder,
InBlockCopyDataPerAccess_W2,
InBlockCopyDataPerAccess_W2>({0, 0, 0, 0, b_block_data_on_global, 0, 0, 0},
{0, 0, 0, 0, 0, 0, 0, 0});
// weight tensor
// tensor descriptor in device memory, src of blockwise copy
constexpr auto wei_e_k_global_desc =
wei_k_c_y_x_global_desc.Unfold(I1, I3).ReorderGivenNew2Old(Sequence<1, 0>{});
// tensor descriptor in LDS, dst of blockwise copy
// be careful of LDS alignment
constexpr auto wei_e_k_block_desc = make_ConstantTensorDescriptor_aligned(
Sequence<EPerBlock, KPerBlock>{},
Number<math::lcm(WeiBlockCopyDstDataPerWrite_K, GemmDataPerReadA)>{});
// operator for blockwise copy of weight into LDS
// slice a tensor, and copy it into another tensor
// this copy operator already has blockwise offset built-in
auto blockwise_wei_copy =
BlockwiseGenericTensorSliceCopy_v1_deprecated<BlockSize,
Float,
decltype(wei_e_k_global_desc),
decltype(wei_e_k_block_desc),
decltype(wei_e_k_block_desc.GetLengths()),
WeiBlockCopySubLengths_E_K,
WeiBlockCopyClusterLengths_E_K,
WeiBlockCopyThreadClusterArrangeOrder,
WeiBlockCopySrcAccessOrder,
WeiBlockCopyDstAccessOrder,
WeiBlockCopySrcDataPerRead_E,
WeiBlockCopyDstDataPerWrite_K>(
{0, k_block_data_on_global}, {0, 0});
#if 0
if(get_block_1d_id() == 0)
{
printf("id (%d %d), in offset: %d %d, wei offset %d %d\n",
get_block_1d_id(),
get_thread_local_1d_id(),
blockwise_in_copy.mThreadSrcOffset,
blockwise_in_copy.mThreadDstOffset,
blockwise_wei_copy.mThreadSrcOffset,
blockwise_wei_copy.mThreadDstOffset);
}
#endif
// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
// a_mtx[EPerBlock, KPerBlock] is in LDS
// b_mtx[EPerBlock, N1 * BPerBlock * N2] is in LDS
// c_mtx[KPerBlock, N1 * BPerBlock * N2] is distributed among threads, and saved in
// register
constexpr auto a_e_k_block_mtx_desc = make_ConstantMatrixDescriptor(wei_e_k_block_desc);
// this check is ad-hoc
// TODO: need to properly implement tensor descriptor with multiple alignment
// requirements
static_assert(in_e_n1_ho1_wo1_b_n2_ho2_wo2_block_desc.GetStrides()[3] % GemmDataPerReadB ==
0,
"GemmDataPerReadB alignment requirement is not satisfied");
constexpr auto b_e_n1ho1wo1bn2ho2wo2_block_mtx_desc =
make_ConstantMatrixDescriptor(in_e_n1_ho1_wo1_b_n2_ho2_wo2_block_desc.Unfold(I1, I7));
// sanity check
static_assert(KPerBlock % (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster) ==
0,
"wrong!");
constexpr index_t GemmMRepeat =
KPerBlock / (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster);
// c_thread_mtx definition: this is a mess
// TODO: more elegant way of defining c_thread_mtx
constexpr auto c_k0k2_n1ho1wo1n2ho2wo2_thread_mtx_desc =
make_ConstantMatrixDescriptor_packed(Number<GemmMRepeat * GemmMPerThreadSubC>{},
Number<N1 * Ho1 * Wo1 * N2 * Ho2 * Wo2>{});
const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2<
BlockSize,
decltype(a_e_k_block_mtx_desc),
decltype(b_e_n1ho1wo1bn2ho2wo2_block_mtx_desc),
decltype(c_k0k2_n1ho1wo1n2ho2wo2_thread_mtx_desc),
GemmMPerThreadSubC,
GemmNPerThreadSubC,
GemmMLevel0Cluster,
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmKPerThreadLoop,
GemmDataPerReadA,
GemmDataPerReadB>{};
// LDS allocation for input and weight: be careful of alignment
constexpr index_t max_align = math::lcm(InBlockCopyDataPerAccess_W2,
WeiBlockCopyDstDataPerWrite_K,
GemmDataPerReadA,
GemmDataPerReadB);
constexpr index_t in_block_space = math::integer_least_multiple(
in_e_n1_ho1_wo1_b_n2_ho2_wo2_block_desc.GetElementSpace(), max_align);
constexpr index_t wei_block_space =
math::integer_least_multiple(wei_e_k_block_desc.GetElementSpace(), max_align);
__shared__ Float p_in_block_double[2 * in_block_space];
__shared__ Float p_wei_block_double[2 * wei_block_space];
// register allocation for output
Float p_out_thread[c_k0k2_n1ho1wo1n2ho2wo2_thread_mtx_desc.GetElementSpace()];
// zero out threadwise output
threadwise_matrix_set_zero(c_k0k2_n1ho1wo1n2ho2wo2_thread_mtx_desc, p_out_thread);
const Float* p_wei_block_on_global = p_wei_global;
// LDS double buffer: preload data into LDS
{
blockwise_in_copy.Run(p_in_global, p_in_block_double);
blockwise_wei_copy.Run(p_wei_global, p_wei_block_double);
}
// LDS double buffer: main body
for(index_t e_block_data_begin = 0; e_block_data_begin + 2 * EPerBlock < E;
e_block_data_begin += 2 * EPerBlock)
{
#pragma unroll
for(index_t iloop = 0; iloop < 2; ++iloop)
{
const bool even_loop = (iloop % 2 == 0);
Float* p_in_block_now =
even_loop ? p_in_block_double : p_in_block_double + in_block_space;
Float* p_wei_block_now =
even_loop ? p_wei_block_double : p_wei_block_double + wei_block_space;
Float* p_in_block_next =
even_loop ? p_in_block_double + in_block_space : p_in_block_double;
Float* p_wei_block_next =
even_loop ? p_wei_block_double + wei_block_space : p_wei_block_double;
Float p_in_register_buffer[blockwise_in_copy.GetRegisterBufferSize()];
Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()];
blockwise_in_copy.MoveSlicingWindowOnSourceTensor(I0, Number<EPerBlock>{}, True);
p_wei_block_on_global += EPerBlock * wei_e_k_global_desc.GetStride(I0);
__syncthreads();
// LDS double buffer: load next data from device mem
blockwise_in_copy.RunLoadRegisterBuffer(p_in_global, p_in_register_buffer);
blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_block_on_global,
p_wei_register_buffer);
#if 0
if(get_block_1d_id() == 0)
{
printf("tid (%d %d), %f %f %f %f\n",
get_block_1d_id(),
get_thread_local_1d_id(),
p_wei_register_buffer[0],
p_wei_register_buffer[1],
p_wei_register_buffer[2],
p_wei_register_buffer[3]);
}
#endif
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(p_wei_block_now, p_in_block_now, p_out_thread);
// LDS double buffer: store next data to LDS
blockwise_in_copy.RunStoreRegisterBuffer(p_in_register_buffer, p_in_block_next);
blockwise_wei_copy.RunStoreRegisterBuffer(p_wei_register_buffer, p_wei_block_next);
}
}
// LDS double buffer: tail
{
Float p_in_register_buffer[blockwise_in_copy.GetRegisterBufferSize()];
Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()];
// even iteration
blockwise_in_copy.MoveSlicingWindowOnSourceTensor(I0, Number<EPerBlock>{}, True);
p_wei_block_on_global += EPerBlock * wei_e_k_global_desc.GetStride(I0);
__syncthreads();
// LDS double buffer: load next data from device mem
blockwise_in_copy.RunLoadRegisterBuffer(p_in_global, p_in_register_buffer);
blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_block_on_global, p_wei_register_buffer);
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(p_wei_block_double, p_in_block_double, p_out_thread);
// LDS double buffer: store next data to LDS
blockwise_in_copy.RunStoreRegisterBuffer(p_in_register_buffer,
p_in_block_double + in_block_space);
blockwise_wei_copy.RunStoreRegisterBuffer(p_wei_register_buffer,
p_wei_block_double + wei_block_space);
// odd iteration
__syncthreads();
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(p_wei_block_double + wei_block_space,
p_in_block_double + in_block_space,
p_out_thread);
}
// copy output: register to global memory
{
constexpr index_t K2 = GemmMPerThreadSubC;
constexpr index_t K1 = GemmMLevel0Cluster * GemmMLevel1Cluster;
// define tensor descriptor for threadwise copy
// output memory layout descriptor in register
constexpr auto out_k0_k1_k2_n1_ho1_wo1_n0_ho0_wo0_n2_ho2_wo2_thread_mem_desc =
make_ConstantTensorDescriptor_packed(
Sequence<KPerBlock / (K1 * K2), 1, K2, N1, Ho1, Wo1, 1, 1, 1, N2, Ho2, Wo2>{});
// output tensor descriptor in register, src of threadwise copy
constexpr auto out_n0_n1_n2_k0_k1_k2_ho0_ho1_ho2_wo0_wo1_wo2_thread_desc =
out_k0_k1_k2_n1_ho1_wo1_n0_ho0_wo0_n2_ho2_wo2_thread_mem_desc.ReorderGivenNew2Old(
Sequence<6, 3, 9, 0, 1, 2, 7, 4, 10, 8, 5, 11>{});
// output memory layout descriptor in device memory, dst of threadwise copy
constexpr auto out_n0_n1_n2_k0_k1_k2_ho0_ho1_ho2_wo0_wo1_wo2_global_mem_desc =
out_n_k_h_w_global_desc.Fold(I3, Sequence<Wo1, Wo2>{})
.Fold(I2, Sequence<Ho1, Ho2>{})
.Fold(I1, Sequence<K1, K2>{})
.Fold(I0, Sequence<N1, N2>{});
// calculate origin of thread output tensor on global memory
// blockwise GEMM c matrix starting index
const auto c_thread_mtx_on_block =
blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
const index_t k_thread_data_on_global =
k_block_data_on_global + c_thread_mtx_on_block.row;
const index_t b_thread_data_on_global =
b_block_data_on_global + c_thread_mtx_on_block.col / (N2 * Ho2 * Wo2);
// output merged global tensor descriptor, for calculating origin of thread tensor
// in global memory
constexpr auto out_k_n1_ho1_wo1_b_n2_ho2_wo2_global_merged_desc =
make_ConstantMergedTensorDescriptor(
out_n0_n1_n2_k0_k1_k2_ho0_ho1_ho2_wo0_wo1_wo2_global_mem_desc.Unfold(I3, I5),
Sequence<3>{},
Sequence<1>{},
Sequence<5>{},
Sequence<8>{},
Sequence<0, 4, 7>{},
Sequence<2>{},
Sequence<6>{},
Sequence<9>{});
// origin of dst in device memory
Float* p_out_thread_on_global =
p_out_global +
out_k_n1_ho1_wo1_b_n2_ho2_wo2_global_merged_desc.GetOffsetFromMultiIndex(
k_thread_data_on_global, 0, 0, 0, b_thread_data_on_global, 0, 0, 0);
threadwise_generic_tensor_slice_copy_v1(
out_n0_n1_n2_k0_k1_k2_ho0_ho1_ho2_wo0_wo1_wo2_thread_desc,
p_out_thread,
{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
out_n0_n1_n2_k0_k1_k2_ho0_ho1_ho2_wo0_wo1_wo2_global_mem_desc,
p_out_thread_on_global,
{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
out_n0_n1_n2_k0_k1_k2_ho0_ho1_ho2_wo0_wo1_wo2_thread_desc.GetLengths(),
arithmetic_sequence_gen<0, 12, 1>::type{},
Number<1>{});
}
}
};
} // namespace ck
#endif
...@@ -75,6 +75,7 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw ...@@ -75,6 +75,7 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw
constexpr index_t ConvDilationH = ConvDilations{}[0]; constexpr index_t ConvDilationH = ConvDilations{}[0];
constexpr index_t ConvDilationW = ConvDilations{}[1]; constexpr index_t ConvDilationW = ConvDilations{}[1];
#if 0
// sanity-check for vectorized memory load // sanity-check for vectorized memory load
static_assert((Wo == 1 || (ConvStrideW == 1 || GemmBBlockCopySrcDataPerRead_GemmN == 1)) && static_assert((Wo == 1 || (ConvStrideW == 1 || GemmBBlockCopySrcDataPerRead_GemmN == 1)) &&
(X == 1 || ConvDilationW % GemmBBlockCopySrcDataPerRead_GemmN == 0) && (X == 1 || ConvDilationW % GemmBBlockCopySrcDataPerRead_GemmN == 0) &&
...@@ -82,9 +83,10 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw ...@@ -82,9 +83,10 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw
InRightPads{}[1] % GemmBBlockCopySrcDataPerRead_GemmN == 0, InRightPads{}[1] % GemmBBlockCopySrcDataPerRead_GemmN == 0,
"wrong! aligment requirement for vectorized global load of input tensor will " "wrong! aligment requirement for vectorized global load of input tensor will "
"be violated"); "be violated");
#endif
// weight tensor // weight tensor
constexpr auto wei_e_k_global_desc = reorder_tensor_descriptor_given_upper2lower( constexpr auto wei_gemmk_gemmm_global_desc = reorder_tensor_descriptor_given_upper2lower(
unfold_tensor_descriptor(wei_k_c_y_x_global_desc, I1, I3), Sequence<1, 0>{}); unfold_tensor_descriptor(wei_k_c_y_x_global_desc, I1, I3), Sequence<1, 0>{});
// input tensor // input tensor
...@@ -108,14 +110,14 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw ...@@ -108,14 +110,14 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{})); make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
constexpr auto in_e_b_global_desc = transform_tensor_descriptor( constexpr auto in_gemmm_gemmn_global_desc = transform_tensor_descriptor(
in_n_c_y_ho_x_wo_global_desc, in_n_c_y_ho_x_wo_global_desc,
make_tuple(Merge<Sequence<C, Y, X>>{}, Merge<Sequence<N, Ho, Wo>>{}), make_tuple(Merge<Sequence<C, Y, X>>{}, Merge<Sequence<N, Ho, Wo>>{}),
make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}), make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{})); make_tuple(Sequence<0>{}, Sequence<1>{}));
// output tensor // output tensor
constexpr auto out_k_b_global_desc = constexpr auto out_gemmk_gemmn_global_desc =
transform_tensor_descriptor(unfold_tensor_descriptor(out_n_k_ho_wo_global_desc, I2, I3), transform_tensor_descriptor(unfold_tensor_descriptor(out_n_k_ho_wo_global_desc, I2, I3),
make_tuple(PassThrough<K>{}, Merge<Sequence<N, Ho * Wo>>{}), make_tuple(PassThrough<K>{}, Merge<Sequence<N, Ho * Wo>>{}),
make_tuple(Sequence<1>{}, Sequence<0, 2>{}), make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
...@@ -127,9 +129,9 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw ...@@ -127,9 +129,9 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw
BlockSize, BlockSize,
Float, Float,
AccFloat, AccFloat,
decltype(wei_e_k_global_desc), decltype(wei_gemmk_gemmm_global_desc),
decltype(in_e_b_global_desc), decltype(in_gemmm_gemmn_global_desc),
decltype(out_k_b_global_desc), decltype(out_gemmk_gemmn_global_desc),
InMemoryDataOperation::Set, InMemoryDataOperation::Set,
GemmMPerBlock, GemmMPerBlock,
GemmNPerBlock, GemmNPerBlock,
...@@ -157,7 +159,7 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw ...@@ -157,7 +159,7 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw
1, 1,
GemmBBlockCopySrcDataPerRead_GemmN, GemmBBlockCopySrcDataPerRead_GemmN,
GemmBBlockCopyDstDataPerWrite_GemmN, GemmBBlockCopyDstDataPerWrite_GemmN,
Sequence<0, 1, 2, 3>, Sequence<2, 3, 0, 1>,
3, 3,
GemmCThreadCopyDstDataPerWrite_GemmN1>{}; GemmCThreadCopyDstDataPerWrite_GemmN1>{};
......
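The wei_gemmk_gemmm view built above is likewise just an index relabeling of the packed KCYX weight tensor: unfolding (C, Y, X) into a single GemmK dimension and reordering it in front of K yields a [GemmK, GemmM] matrix with strides {1, C * Y * X}. A hedged standalone check with hypothetical sizes:

// packed KCYX weights read as a [GemmK = C * Y * X, GemmM = K] matrix
#include <cassert>

int main()
{
    constexpr int K = 4, C = 3, Y = 2, X = 2;
    constexpr int E = C * Y * X; // GemmK

    for(int k = 0; k < K; ++k)
        for(int c = 0; c < C; ++c)
            for(int y = 0; y < Y; ++y)
                for(int x = 0; x < X; ++x)
                {
                    const int kcyx_offset = ((k * C + c) * Y + y) * X + x; // packed KCYX
                    const int e           = (c * Y + y) * X + x;          // merged (C, Y, X)

                    // [GemmK, GemmM] view: stride 1 along GemmK, stride C * Y * X along GemmM
                    assert(kcyx_offset == e + k * E);
                }

    return 0;
}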
#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R4_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER_DEPRECATED_HPP
#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R4_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER_DEPRECATED_HPP
#include "common_header.hpp"
#include "ConstantTensorDescriptor_deprecated.hpp"
#include "ConstantMergedTensorDescriptor_deprecated.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "blockwise_generic_tensor_slice_copy_deprecated.hpp"
#include "blockwise_gemm.hpp"
#include "threadwise_generic_tensor_slice_copy_deprecated.hpp"
namespace ck {
// B = merge(N, Ho, Wo)
template <index_t GridSize,
index_t BlockSize,
class Float,
class InGlobalDesc,
class WeiGlobalDesc,
class OutGlobalDesc,
class ConvStrides,
class ConvDilations,
index_t BPerBlock,
index_t KPerBlock,
index_t EPerBlock,
index_t GemmMPerThreadSubC,
index_t GemmNPerThreadSubC,
index_t GemmMLevel0Cluster,
index_t GemmNLevel0Cluster,
index_t GemmMLevel1Cluster,
index_t GemmNLevel1Cluster,
index_t GemmKPerThreadLoop,
index_t GemmDataPerReadA,
index_t GemmDataPerReadB,
class InBlockCopySubLengths_E_B,
class InBlockCopyClusterLengths_E_B,
class InBlockCopyThreadClusterArrangeOrder,
class InBlockCopySrcAccessOrder,
class InBlockCopyDstAccessOrder,
index_t InBlockCopyDataPerAccess_B,
class WeiBlockCopySubLengths_E_K,
class WeiBlockCopyClusterLengths_E_K,
class WeiBlockCopyThreadClusterArrangeOrder,
class WeiBlockCopySrcAccessOrder,
class WeiBlockCopyDstAccessOrder,
index_t WeiBlockCopySrcDataPerRead_E,
index_t WeiBlockCopyDstDataPerWrite_K,
index_t OutThreadCopyDataPerAccess_B>
struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer_deprecated
{
__device__ void Run(const Float* const __restrict__ p_in_global,
const Float* const __restrict__ p_wei_global,
Float* const __restrict__ p_out_global) const
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto I5 = Number<5>{};
constexpr auto True = integral_constant<bool, true>{};
constexpr auto in_n_c_h_w_global_desc = InGlobalDesc{};
constexpr auto wei_k_c_y_x_global_desc = WeiGlobalDesc{};
constexpr auto out_n_k_h_w_global_desc = OutGlobalDesc{};
constexpr index_t N = in_n_c_h_w_global_desc.GetLengths()[0];
constexpr index_t C = in_n_c_h_w_global_desc.GetLengths()[1];
constexpr index_t K = out_n_k_h_w_global_desc.GetLengths()[1];
constexpr index_t Ho = out_n_k_h_w_global_desc.GetLengths()[2];
constexpr index_t Wo = out_n_k_h_w_global_desc.GetLengths()[3];
constexpr index_t Y = wei_k_c_y_x_global_desc.GetLengths()[2];
constexpr index_t X = wei_k_c_y_x_global_desc.GetLengths()[3];
constexpr index_t ConvStrideH = ConvStrides{}[0];
constexpr index_t ConvStrideW = ConvStrides{}[1];
constexpr index_t ConvDilationH = ConvDilations{}[0];
constexpr index_t ConvDilationW = ConvDilations{}[1];
constexpr index_t E = C * Y * X;
constexpr index_t B = N * Ho * Wo;
// sanity-check for vectorized memory load
static_assert((Wo == 1 || (ConvStrideW == 1 || InBlockCopyDataPerAccess_B == 1)) &&
(X == 1 || ConvDilationW % InBlockCopyDataPerAccess_B == 0),
"wrong! aligment requirement for vectorized global load of input tensor will "
"be violated");
// divide block work by [K, B]
static_assert(K % KPerBlock == 0 && B % BPerBlock == 0 && E % (2 * EPerBlock) == 0,
"wrong! cannot divide work evenly among block");
constexpr index_t KBlockWork = K / KPerBlock;
constexpr index_t BBlockWork = B / BPerBlock;
constexpr auto block_work_desc =
make_ConstantTensorDescriptor_packed(Sequence<KBlockWork, BBlockWork>{});
const auto block_work_multi_id =
block_work_desc.GetMultiIndexFrom1dIndex(get_block_1d_id());
const index_t k_block_data_on_global = block_work_multi_id[0] * KPerBlock;
const index_t b_block_data_on_global = block_work_multi_id[1] * BPerBlock;
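// worked example with hypothetical sizes: K = 256, KPerBlock = 64, B = 16384, BPerBlock = 128
// => KBlockWork = 4, BBlockWork = 128; block id 130 -> block_work_multi_id = {1, 2}
//    -> k_block_data_on_global = 64, b_block_data_on_global = 256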
// input tensor
// tensor descriptor in device memory [N, Ho, Wo]
constexpr auto in_n_ho_wo_global_desc =
in_n_c_h_w_global_desc.Extract(I0, I2, I3)
.StridedSlice(I1, Number<Ho>{}, Number<ConvStrideH>{})
.StridedSlice(I2, Number<Wo>{}, Number<ConvStrideW>{});
// batch descriptor for device memory
constexpr auto in_c_y_x_global_desc =
in_n_c_h_w_global_desc.StridedSlice(I2, Number<Y>{}, Number<ConvDilationH>{})
.StridedSlice(I3, Number<X>{}, Number<ConvDilationW>{})
.Extract(Sequence<1, 2, 3>{});
// merged tensor descriptor in device memory [E, B], src of blockwise copy
constexpr auto in_e_b_global_desc =
make_ConstantMergedTensorDescriptor(in_c_y_x_global_desc.Embed(in_n_ho_wo_global_desc),
Sequence<0, 1, 2>{},
Sequence<3, 4, 5>{});
// memory layout descriptor in LDS [E, B], dst of blockwise copy
// be careful of LDS alignment
constexpr auto in_e_b_block_desc =
make_ConstantTensorDescriptor_packed(Sequence<EPerBlock, BPerBlock>{});
// input blockwise copy
// slice a merged tensor, reorder and copy to a normal tensor
// this copy operator already has blockwise offset built-in
auto blockwise_in_copy =
BlockwiseGenericTensorSliceCopy_v2_deprecated<BlockSize,
decltype(in_e_b_global_desc),
decltype(in_e_b_block_desc),
decltype(in_e_b_block_desc.GetLengths()),
InBlockCopySubLengths_E_B,
InBlockCopyClusterLengths_E_B,
InBlockCopyThreadClusterArrangeOrder,
InBlockCopySrcAccessOrder,
InBlockCopyDstAccessOrder,
1,
1,
InBlockCopyDataPerAccess_B,
InBlockCopyDataPerAccess_B>(
{0, b_block_data_on_global}, {0, 0});
// weight tensor
// tensor descriptor in device memory, src of blockwise copy
constexpr auto wei_e_k_global_desc =
wei_k_c_y_x_global_desc.Unfold(I1, I3).ReorderGivenNew2Old(Sequence<1, 0>{});
// tensor descriptor in LDS, dst of blockwise copy
// be careful of LDS alignment
constexpr auto wei_e_k_block_desc = make_ConstantTensorDescriptor_aligned(
Sequence<EPerBlock, KPerBlock>{},
Number<math::lcm(WeiBlockCopyDstDataPerWrite_K, GemmDataPerReadA)>{});
// this check is ad-hoc
// TODO: need to properly implement tensor descriptor with multiple alignment
// requirements
static_assert(wei_e_k_block_desc.GetStride(I0) % GemmDataPerReadA == 0,
"GemmDataPerReadA alignment requirement is not satisfied");
// operator for blockwise copy of weight into LDS
// slice a tensor, and copy it into another tensor
// this copy operator already has blockwise offset built-in
auto blockwise_wei_copy =
BlockwiseGenericTensorSliceCopy_v2_deprecated<BlockSize,
decltype(wei_e_k_global_desc),
decltype(wei_e_k_block_desc),
decltype(wei_e_k_block_desc.GetLengths()),
WeiBlockCopySubLengths_E_K,
WeiBlockCopyClusterLengths_E_K,
WeiBlockCopyThreadClusterArrangeOrder,
WeiBlockCopySrcAccessOrder,
WeiBlockCopyDstAccessOrder,
0,
1,
WeiBlockCopySrcDataPerRead_E,
WeiBlockCopyDstDataPerWrite_K>(
{0, k_block_data_on_global}, {0, 0});
// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
// a_mtx[EPerBlock, KPerBlock] is in LDS
// b_mtx[EPerBlock, BPerBlock] is in LDS
// c_mtx[KPerBlock, BPerBlock] is distributed among threads, and saved in
// register
constexpr auto a_e_k_block_mtx_desc = make_ConstantMatrixDescriptor(wei_e_k_block_desc);
constexpr auto b_e_b_block_mtx_desc = make_ConstantMatrixDescriptor(in_e_b_block_desc);
// sanity check
static_assert(
KPerBlock % (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster) == 0 &&
BPerBlock % (GemmNPerThreadSubC * GemmNLevel0Cluster * GemmNLevel1Cluster) == 0,
"wrong!");
constexpr index_t GemmMRepeat =
KPerBlock / (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster);
constexpr index_t GemmNRepeat =
BPerBlock / (GemmNPerThreadSubC * GemmNLevel0Cluster * GemmNLevel1Cluster);
// c_thread_mtx definition: this is a mess
// TODO: more elegant way of defining c_thread_mtx
constexpr auto c_k0k1_b0b1_thread_mtx_desc = make_ConstantMatrixDescriptor_packed(
Number<GemmMRepeat * GemmMPerThreadSubC>{}, Number<GemmNRepeat * GemmNPerThreadSubC>{});
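// worked example with hypothetical parameters: KPerBlock = BPerBlock = 128,
// GemmMPerThreadSubC = GemmNPerThreadSubC = 4, and all four Level0/Level1 cluster sizes = 4
// => GemmMRepeat = GemmNRepeat = 128 / (4 * 4 * 4) = 2, so each thread accumulates an
//    8x8 c_thread matrix (GemmMRepeat * GemmMPerThreadSubC by GemmNRepeat * GemmNPerThreadSubC)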
const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2<
BlockSize,
decltype(a_e_k_block_mtx_desc),
decltype(b_e_b_block_mtx_desc),
decltype(c_k0k1_b0b1_thread_mtx_desc),
GemmMPerThreadSubC,
GemmNPerThreadSubC,
GemmMLevel0Cluster,
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmKPerThreadLoop,
GemmDataPerReadA,
GemmDataPerReadB>{};
// LDS allocation for input and weight: be careful of alignment
constexpr index_t max_align = math::lcm(InBlockCopyDataPerAccess_B,
WeiBlockCopyDstDataPerWrite_K,
GemmDataPerReadA,
GemmDataPerReadB);
constexpr index_t in_block_space =
math::integer_least_multiple(in_e_b_block_desc.GetElementSpace(), max_align);
constexpr index_t wei_block_space =
math::integer_least_multiple(wei_e_k_block_desc.GetElementSpace(), max_align);
__shared__ Float p_in_block_double[2 * in_block_space];
__shared__ Float p_wei_block_double[2 * wei_block_space];
// register allocation for output
Float p_out_thread[c_k0k1_b0b1_thread_mtx_desc.GetElementSpace()];
// zero out threadwise output
threadwise_matrix_set_zero(c_k0k1_b0b1_thread_mtx_desc, p_out_thread);
const Float* p_wei_block_on_global = p_wei_global;
// LDS double buffer: preload data into LDS
{
blockwise_in_copy.template Run<Float, AddressSpace::Global>(p_in_global,
p_in_block_double);
blockwise_wei_copy.template Run<Float, AddressSpace::Global>(p_wei_global,
p_wei_block_double);
}
// LDS double buffer: main body
for(index_t e_block_data_begin = 0; e_block_data_begin + 2 * EPerBlock < E;
e_block_data_begin += 2 * EPerBlock)
{
#pragma unroll
for(index_t iloop = 0; iloop < 2; ++iloop)
{
const bool even_loop = (iloop % 2 == 0);
Float* p_in_block_now =
even_loop ? p_in_block_double : p_in_block_double + in_block_space;
Float* p_wei_block_now =
even_loop ? p_wei_block_double : p_wei_block_double + wei_block_space;
Float* p_in_block_next =
even_loop ? p_in_block_double + in_block_space : p_in_block_double;
Float* p_wei_block_next =
even_loop ? p_wei_block_double + wei_block_space : p_wei_block_double;
Float p_in_thread_buffer[blockwise_in_copy.GetThreadBufferSize()];
Float p_wei_thread_buffer[blockwise_wei_copy.GetThreadBufferSize()];
blockwise_in_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0>{}, True);
blockwise_wei_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0>{}, True);
__syncthreads();
// LDS double buffer: load next data from device mem
blockwise_in_copy.template RunLoadThreadBuffer<Float, AddressSpace::Global>(
p_in_global, p_in_thread_buffer);
blockwise_wei_copy.template RunLoadThreadBuffer<Float, AddressSpace::Global>(
p_wei_global, p_wei_thread_buffer);
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(p_wei_block_now, p_in_block_now, p_out_thread);
// LDS double buffer: store next data to LDS
blockwise_in_copy.RunStoreThreadBuffer(p_in_thread_buffer, p_in_block_next);
blockwise_wei_copy.RunStoreThreadBuffer(p_wei_thread_buffer, p_wei_block_next);
}
}
// LDS double buffer: tail
{
Float p_in_thread_buffer[blockwise_in_copy.GetThreadBufferSize()];
Float p_wei_thread_buffer[blockwise_wei_copy.GetThreadBufferSize()];
// even iteration
blockwise_in_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0>{}, True);
blockwise_wei_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0>{}, True);
__syncthreads();
// LDS double buffer: load next data from device mem
blockwise_in_copy.template RunLoadThreadBuffer<Float, AddressSpace::Global>(
p_in_global, p_in_thread_buffer);
blockwise_wei_copy.template RunLoadThreadBuffer<Float, AddressSpace::Global>(
p_wei_global, p_wei_thread_buffer);
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(p_wei_block_double, p_in_block_double, p_out_thread);
// LDS double buffer: store next data to LDS
blockwise_in_copy.RunStoreThreadBuffer(p_in_thread_buffer,
p_in_block_double + in_block_space);
blockwise_wei_copy.RunStoreThreadBuffer(p_wei_thread_buffer,
p_wei_block_double + wei_block_space);
// odd iteration
__syncthreads();
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(p_wei_block_double + wei_block_space,
p_in_block_double + in_block_space,
p_out_thread);
}
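// worked example of the schedule with hypothetical E = 8 * EPerBlock: the preload stages
// E-chunk 0 into LDS; the main body runs three double-iterations, doing GEMMs on chunks 0-5
// while staging chunks 1-6 into the alternate buffer; the tail above finishes with GEMMs on
// chunks 6 and 7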
// copy output: register to global memory
{
constexpr index_t K1 = GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster;
constexpr index_t B1 = GemmNPerThreadSubC * GemmNLevel0Cluster * GemmNLevel1Cluster;
// define tensor descriptor for threadwise copy
// output global descriptor, for calculating origin of thread tensor
// in global memory
constexpr auto out_k_b_global_desc = make_ConstantMergedTensorDescriptor(
out_n_k_h_w_global_desc, Sequence<1>{}, Sequence<0, 2, 3>{});
// calculate origin of thread output tensor on global memory
// blockwise GEMM c matrix starting index
const auto c_thread_mtx_on_block =
blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
const index_t k_thread_data_on_global =
k_block_data_on_global + c_thread_mtx_on_block.row;
const index_t b_thread_data_on_global =
b_block_data_on_global + c_thread_mtx_on_block.col;
// This is a hack, because slicing a merged dimension is not supported yet.
// This should be replaced with the logic above, once support for slicing a merged
// dimension becomes available
// dst descriptor
constexpr auto out_k0_k1_b_global_desc =
make_ConstantMergedTensorDescriptor(out_n_k_h_w_global_desc.Fold(I1, Number<K1>{}),
Sequence<1>{},
Sequence<2>{},
Sequence<0, 3, 4>{});
// src descriptor
constexpr auto out_k0_k1_b_thread_desc = make_ConstantTensorDescriptor_packed(
Sequence<GemmMRepeat, GemmMPerThreadSubC, GemmNRepeat * GemmNPerThreadSubC>{});
using OutThreadCopySliceLengths =
Sequence<GemmMRepeat, GemmMPerThreadSubC, GemmNPerThreadSubC>;
auto threadwise_out_copy = ThreadwiseGenericTensorSliceCopy_v2r1_deprecated<
decltype(out_k0_k1_b_thread_desc),
decltype(out_k0_k1_b_global_desc),
OutThreadCopySliceLengths,
arithmetic_sequence_gen<0, 3, 1>::type,
arithmetic_sequence_gen<0, 3, 1>::type,
2,
2,
OutThreadCopyDataPerAccess_B,
OutThreadCopyDataPerAccess_B>({0, 0, 0},
{k_thread_data_on_global / K1,
k_thread_data_on_global % K1,
b_thread_data_on_global});
for(index_t nrepeat = 0; nrepeat < GemmNRepeat; ++nrepeat)
{
threadwise_out_copy
.template Run<Float, AddressSpace::Generic, AddressSpace::Global>(p_out_thread,
p_out_global);
threadwise_out_copy.MoveSrcSliceWindow(Sequence<0, 0, GemmNPerThreadSubC>{}, True);
threadwise_out_copy.MoveDstSliceWindow(Sequence<0, 0, B1>{}, True);
}
}
}
};
} // namespace ck
#endif
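The kernels above never materialize an im2col buffer: the GEMM dimensions E = C * Y * X and B = N * Ho * Wo are decoded back into tensor coordinates on the fly. Below is a hedged host-side sketch of that mapping, using hypothetical sizes, unit strides/dilations and no padding; it is not part of any header here and only checks the implicit-GEMM formulation against a direct convolution:

// direct NCHW convolution vs. the same computation written as C[k][b] = sum_e A[e][k] * B[e][b],
// with e = (c, y, x) and b = (n, ho, wo) decoded on the fly (no im2col buffer)
#include <cassert>
#include <cstddef>
#include <cstdlib>
#include <vector>

int main()
{
    // hypothetical sizes; stride = dilation = 1 and no padding keep the sketch short
    constexpr int N = 2, C = 3, H = 5, W = 5, K = 4, Y = 3, X = 3;
    constexpr int Ho = H - Y + 1, Wo = W - X + 1;
    constexpr int E = C * Y * X;   // GemmK
    constexpr int B = N * Ho * Wo; // GemmN (GemmM is K)

    std::vector<float> in(N * C * H * W), wei(K * C * Y * X);
    for(auto& v : in)
        v = float(std::rand() % 7) - 3.f;
    for(auto& v : wei)
        v = float(std::rand() % 7) - 3.f;

    auto in_at  = [&](int n, int c, int h, int w) { return in[((n * C + c) * H + h) * W + w]; };
    auto wei_at = [&](int k, int c, int y, int x) { return wei[((k * C + c) * Y + y) * X + x]; };

    // direct convolution
    std::vector<float> out_direct(N * K * Ho * Wo, 0.f);
    for(int n = 0; n < N; ++n)
        for(int k = 0; k < K; ++k)
            for(int ho = 0; ho < Ho; ++ho)
                for(int wo = 0; wo < Wo; ++wo)
                    for(int c = 0; c < C; ++c)
                        for(int y = 0; y < Y; ++y)
                            for(int x = 0; x < X; ++x)
                                out_direct[((n * K + k) * Ho + ho) * Wo + wo] +=
                                    wei_at(k, c, y, x) * in_at(n, c, ho + y, wo + x);

    // implicit GEMM over the merged dimensions
    std::vector<float> out_gemm(N * K * Ho * Wo, 0.f);
    for(int k = 0; k < K; ++k)
        for(int b = 0; b < B; ++b)
        {
            const int n = b / (Ho * Wo), ho = (b / Wo) % Ho, wo = b % Wo;

            float acc = 0.f;
            for(int e = 0; e < E; ++e)
            {
                const int c = e / (Y * X), y = (e / X) % Y, x = e % X;
                acc += wei_at(k, c, y, x) * in_at(n, c, ho + y, wo + x);
            }
            out_gemm[((n * K + k) * Ho + ho) * Wo + wo] = acc;
        }

    for(std::size_t i = 0; i < out_direct.size(); ++i)
        assert(out_direct[i] == out_gemm[i]);

    return 0;
}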
...@@ -2,7 +2,6 @@ ...@@ -2,7 +2,6 @@
#define CK_CONSTANT_MATRIX_DESCRIPTOR_HPP #define CK_CONSTANT_MATRIX_DESCRIPTOR_HPP
#include "common_header.hpp" #include "common_header.hpp"
#include "ConstantTensorDescriptor_deprecated.hpp"
#include "tensor_descriptor.hpp" #include "tensor_descriptor.hpp"
namespace ck { namespace ck {
...@@ -58,18 +57,6 @@ __host__ __device__ constexpr auto ...@@ -58,18 +57,6 @@ __host__ __device__ constexpr auto
return ConstantMatrixDescriptor<NRow, NCol, RowStride>{}; return ConstantMatrixDescriptor<NRow, NCol, RowStride>{};
} }
template <typename... Ts>
__host__ __device__ constexpr auto
make_ConstantMatrixDescriptor(ConstantTensorDescriptor_deprecated<Ts...>)
{
using TDesc = ConstantTensorDescriptor_deprecated<Ts...>;
static_assert(TDesc::GetNumOfDimension() == 2, "wrong");
static_assert(TDesc::GetStrides()[1] == 1, "wrong");
return ConstantMatrixDescriptor<TDesc::GetLengths()[0],
TDesc::GetLengths()[1],
TDesc::GetStrides()[0]>{};
}
template <typename... Ts> template <typename... Ts>
__host__ __device__ constexpr auto make_ConstantMatrixDescriptor(NativeTensorDescriptor<Ts...>) __host__ __device__ constexpr auto make_ConstantMatrixDescriptor(NativeTensorDescriptor<Ts...>)
{ {
......
#ifndef CK_CONSTANT_MERGED_TENSOR_DESCRIPTOR_DEPRECATED_HPP
#define CK_CONSTANT_MERGED_TENSOR_DESCRIPTOR_DEPRECATED_HPP
#include "common_header.hpp"
#include "ConstantTensorDescriptor_deprecated.hpp"
namespace ck {
// OriginalTensorDesc : ConstantTensorDescriptor_deprecated<...>
// it's the tensor whose dimensions are to be merged
// OriginalDimMergeSeqs : Sequence<...>...
// each is a sequence of original dimensions (of OriginalTensorDesc) to be merged
template <class OriginalTensorDesc, class... OriginalDimMergeSeqs>
struct ConstantMergedTensorDescriptor_deprecated
{
using Type = ConstantMergedTensorDescriptor_deprecated;
static constexpr auto mOriginalDimMergeSeqs = std::tuple<OriginalDimMergeSeqs...>{};
static constexpr index_t nDim = sizeof...(OriginalDimMergeSeqs);
static constexpr index_t nOriginalDim = OriginalTensorDesc::GetNumOfDimension();
__host__ __device__ constexpr ConstantMergedTensorDescriptor_deprecated()
{
static_assert(nDim <= nOriginalDim, "wrong!");
// TODO: check each of OriginalDimMergeSeqs contains at least 1, and at most
// OriginalTensorDesc::nDim number of dimensions
// TODO: check OriginalDimMergeSeqs contains all original dimensions
// TODO: check there is no duplication in OriginalDimMergeSeqs
}
__host__ __device__ static constexpr auto GetOriginalTensorDescriptor()
{
return OriginalTensorDesc{};
}
__host__ __device__ static constexpr auto GetNumOfDimension() { return Number<nDim>{}; }
template <index_t IDim>
__host__ __device__ static constexpr auto GetContainedOriginalDimensions(Number<IDim>)
{
return std::get<IDim>(mOriginalDimMergeSeqs);
}
template <index_t IDim>
__host__ __device__ static constexpr bool ContainMultipleOriginalDimensions(Number<IDim>)
{
return (std::get<IDim>(mOriginalDimMergeSeqs).GetSize() > 1);
}
template <index_t IDim>
__host__ __device__ static constexpr auto GetLength(Number<IDim>)
{
constexpr auto original_dims_partial = std::get<IDim>(mOriginalDimMergeSeqs);
return OriginalTensorDesc::Extract(original_dims_partial).GetElementSize();
}
template <index_t IDim>
__host__ __device__ static constexpr auto GetStride(Number<IDim>)
{
static_assert(!ContainMultipleOriginalDimensions(Number<IDim>{}),
"wrong! stride of a merged dimension is undefined");
constexpr auto idim_original = std::get<IDim>(mOriginalDimMergeSeqs).Back();
return OriginalTensorDesc::GetStride(Number<idim_original>{});
}
// this is a hack to return the stride of the last original dimension of a merged dimension
// TODO: refactor this once the concept of "dimension" is used
template <index_t IDim>
__host__ __device__ static constexpr auto GetLastOriginalDimensionStride(Number<IDim>)
{
constexpr auto idim_last_original = std::get<IDim>(mOriginalDimMergeSeqs).Back();
return OriginalTensorDesc::GetStride(Number<idim_last_original>{});
}
__host__ __device__ static constexpr auto GetLengths()
{
return Sequence<OriginalTensorDesc::Extract(OriginalDimMergeSeqs{}).GetElementSize()...>{};
}
__host__ __device__ static constexpr auto GetElementSize()
{
return OriginalTensorDesc::GetElementSize();
}
template <class OriginalDimsPartial>
struct lambda_1_GetOriginalMultiIndexFromMultiIndex
{
const Array<index_t, OriginalDimsPartial::GetSize()>& original_multi_id_partial;
Array<index_t, nOriginalDim>& original_multi_id;
__host__ __device__ constexpr lambda_1_GetOriginalMultiIndexFromMultiIndex(
const Array<index_t, OriginalDimsPartial::GetSize()>& original_multi_id_partial_,
Array<index_t, nOriginalDim>& original_multi_id_)
: original_multi_id_partial(original_multi_id_partial_),
original_multi_id(original_multi_id_)
{
}
template <index_t I>
__host__ __device__ constexpr void operator()(Number<I>) const
{
constexpr index_t idim_original = OriginalDimsPartial::Get(Number<I>{});
index_t itmp = original_multi_id_partial[I];
original_multi_id(idim_original) = itmp;
}
};
struct lambda_0_GetOriginalMultiIndexFromMultiIndex
{
const Array<index_t, nDim>& multi_id;
Array<index_t, nOriginalDim>& original_multi_id;
__host__ __device__ constexpr lambda_0_GetOriginalMultiIndexFromMultiIndex(
const Array<index_t, nDim>& multi_id_, Array<index_t, nOriginalDim>& original_multi_id_)
: multi_id(multi_id_), original_multi_id(original_multi_id_)
{
}
template <index_t IDim>
__host__ __device__ constexpr void operator()(Number<IDim>) const
{
constexpr auto original_dims_partial = std::get<IDim>(Type::mOriginalDimMergeSeqs);
// get partial original-multi-id corresponding to this merged dimension
const auto original_multi_id_partial =
OriginalTensorDesc::Extract(original_dims_partial)
.GetMultiIndexFrom1dIndex(multi_id[IDim]);
static_for<0, original_dims_partial.GetSize(), 1>{}(
lambda_1_GetOriginalMultiIndexFromMultiIndex<decltype(original_dims_partial)>(
original_multi_id_partial, original_multi_id));
}
};
// return type is Array<...>
__host__ __device__ static constexpr auto
GetOriginalMultiIndexFromMultiIndex(Array<index_t, nDim> multi_id)
{
Array<index_t, nOriginalDim> original_multi_id;
static_for<0, nDim, 1>{}(
lambda_0_GetOriginalMultiIndexFromMultiIndex(multi_id, original_multi_id));
return original_multi_id;
}
template <index_t... Is>
__host__ __device__ static constexpr index_t GetOffsetFromMultiIndex(Sequence<Is...>)
{
constexpr auto multi_id = sequence2array(Sequence<Is...>{});
constexpr auto original_multi_id = GetOriginalMultiIndexFromMultiIndex(multi_id);
return OriginalTensorDesc::GetOffsetFromMultiIndex(original_multi_id);
}
__host__ __device__ static constexpr index_t
GetOffsetFromMultiIndex(Array<index_t, nDim> multi_id)
{
auto original_multi_id = GetOriginalMultiIndexFromMultiIndex(multi_id);
return OriginalTensorDesc::GetOffsetFromMultiIndex(original_multi_id);
}
template <class... Is>
__host__ __device__ static constexpr index_t GetOffsetFromMultiIndex(Is... is)
{
return GetOffsetFromMultiIndex(Array<index_t, nDim>{is...});
}
__host__ __device__ static constexpr Array<index_t, nDim> GetMultiIndexFrom1dIndex(index_t id)
{
constexpr auto packed_desc = make_ConstantTensorDescriptor_packed(GetLengths());
return packed_desc.GetMultiIndexFrom1dIndex(id);
}
__host__ __device__ static constexpr auto Pack()
{
constexpr auto lengths = GetLengths();
constexpr auto strides = calculate_tensor_strides_packed(lengths);
return ConstantTensorDescriptor_deprecated<decltype(lengths), decltype(strides)>{};
}
};
template <class OriginalTensorDesc, class... OriginalDimMergeSeqs>
__host__ __device__ constexpr auto make_ConstantMergedTensorDescriptor(OriginalTensorDesc,
OriginalDimMergeSeqs...)
{
return ConstantMergedTensorDescriptor_deprecated<OriginalTensorDesc, OriginalDimMergeSeqs...>{};
}
template <class TDesc>
__host__ __device__ void print_ConstantMergedTensorDescriptor(const char* s, TDesc)
{
print_ConstantTensorDescriptor(s, TDesc::GetOriginalTensorDescriptor());
}
} // namespace ck
#endif
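As a standalone illustration (hypothetical sizes) of why a merged dimension such as B = merge(N, H, W) over a packed NKHW tensor has no single stride: stepping the merged index by 1 usually moves the offset by 1, but crossing from one n to the next jumps over the K dimension. This is the kind of carry handling MergedTensorCoordinate_deprecated performs later in this change:

// merged dimension B = merge(N, H, W) over a packed NKHW tensor: the offset step between
// consecutive merged indices is 1 inside one n, but jumps across the K dimension between n's
#include <cassert>

int main()
{
    constexpr int N = 2, K = 3, H = 4, W = 5;
    constexpr int k = 1; // fix the non-merged dimension

    auto offset = [&](int b) {
        const int n = b / (H * W), h = (b / W) % H, w = b % W; // decode the merged index
        return ((n * K + k) * H + h) * W + w;                  // packed NKHW offset
    };

    for(int b = 0; b + 1 < N * H * W; ++b)
    {
        const int step = offset(b + 1) - offset(b);

        if((b + 1) % (H * W) == 0)
            assert(step == 1 + (K - 1) * H * W); // crossing n: skip the other K - 1 slices
        else
            assert(step == 1); // within one n: contiguous
    }

    return 0;
}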
#ifndef CK_CONSTANT_TENSOR_DESCRIPTOR_DEPRECATED_HPP
#define CK_CONSTANT_TENSOR_DESCRIPTOR_DEPRECATED_HPP
#include "common_header.hpp"
namespace ck {
template <class Lengths>
__host__ __device__ constexpr auto calculate_tensor_strides_packed_deprecated(Lengths)
{
return reverse_inclusive_scan_sequence(
Lengths{}.PopFront(), math::multiplies<index_t>{}, Number<1>{})
.PushBack(Number<1>{});
}
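// e.g. Lengths = Sequence<2, 3, 4, 5>  =>  packed strides = Sequence<60, 20, 5, 1>;
// the "aligned" variant below first rounds the last length up to a multiple of Align for the
// stride calculation (e.g. Sequence<8, 3, 7> with Align = 4 gives strides Sequence<24, 8, 1>)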
template <class Lengths, index_t Align>
__host__ __device__ constexpr auto calculate_tensor_strides_aligned_old(Lengths, Number<Align>)
{
constexpr index_t L_back_align =
Align * math::integer_divide_ceiler<index_t>{}(Lengths{}.Back(), Align);
return calculate_tensor_strides_packed_deprecated(
Lengths{}.Modify(Number<Lengths{}.GetSize() - 1>{}, Number<L_back_align>{}));
}
template <class Lengths, class Strides>
struct ConstantTensorDescriptor_deprecated
{
using Type = ConstantTensorDescriptor_deprecated;
static constexpr index_t nDim = Lengths::GetSize();
__host__ __device__ constexpr ConstantTensorDescriptor_deprecated()
{
static_assert(Lengths::GetSize() == Strides::GetSize(), "nDim not consistent");
}
__host__ __device__ static constexpr auto GetOriginalTensorDescriptor() { return Type{}; }
template <index_t IDim>
__host__ __device__ static constexpr auto GetContainedOriginalDimensions(Number<IDim>)
{
return Sequence<IDim>{};
}
__host__ __device__ static constexpr auto GetNumOfDimension() { return Number<nDim>{}; }
__host__ __device__ static constexpr auto GetLengths() { return Lengths{}; }
__host__ __device__ static constexpr auto GetStrides() { return Strides{}; }
__host__ __device__ static constexpr auto GetLength(index_t IDim) { return Lengths{}[IDim]; }
__host__ __device__ static constexpr auto GetStride(index_t IDim) { return Strides{}[IDim]; }
struct lambda_AreDimensionsContinuous
{
bool& is_continuous;
__host__ __device__ constexpr lambda_AreDimensionsContinuous(bool& is_continuous_)
: is_continuous(is_continuous_)
{
}
template <index_t IDim_>
__host__ __device__ constexpr void operator()(Number<IDim_>) const
{
constexpr auto IDim = Number<IDim_>{};
constexpr auto IDim_p1 = Number<IDim_ + 1>{};
is_continuous =
is_continuous && (GetStride(IDim) >= GetStride(IDim_p1) &&
GetStride(IDim) == GetStride(IDim_p1) * GetLength(IDim_p1));
}
};
__host__ __device__ static constexpr bool AreDimensionsContinuous()
{
bool is_continuous = true;
static_for<0, nDim - 1, 1>{}(lambda_AreDimensionsContinuous(is_continuous));
return is_continuous;
}
__host__ __device__ static constexpr bool IsPackedTensor()
{
return AreDimensionsContinuous() && GetStride(Number<nDim - 1>{}) == 1;
}
template <class T>
__host__ __device__ static constexpr bool ContainMultipleOriginalDimensions(T)
{
return false;
}
__host__ __device__ static constexpr auto GetElementSize()
{
return Number<reduce_on_sequence(Lengths{}, math::multiplies<index_t>{}, Number<1>{})>{};
}
__host__ __device__ static constexpr auto GetElementSpace()
{
constexpr index_t element_space_unaligned = reduce_on_sequence(
(GetLengths() - Number<1>{}) * GetStrides(), math::plus<index_t>{}, Number<1>{});
return Number<element_space_unaligned>{};
}
// emulate constexpr lambda
template <index_t NSize>
struct lambda_GetOffsetFromMultiIndex
{
Array<index_t, NSize>& multi_id;
index_t& offset;
__host__
__device__ constexpr lambda_GetOffsetFromMultiIndex(Array<index_t, NSize>& multi_id_,
index_t& offset_)
: multi_id(multi_id_), offset(offset_)
{
}
template <class X>
__host__ __device__ constexpr void operator()(X IDim) const
{
offset += multi_id[IDim] * Type::GetStride(IDim);
}
};
template <index_t NSize>
__host__ __device__ static constexpr index_t
GetOffsetFromMultiIndex(Array<index_t, NSize> multi_id)
{
static_assert(NSize == nDim, "wrong! Dimension not consistent");
index_t offset = 0;
static_for<0, nDim, 1>{}(lambda_GetOffsetFromMultiIndex<NSize>(multi_id, offset));
return offset;
}
template <class... Is>
__host__ __device__ static constexpr index_t GetOffsetFromMultiIndex(Is... is)
{
return GetOffsetFromMultiIndex(Array<index_t, sizeof...(Is)>{is...});
}
template <index_t... Is>
__host__ __device__ static constexpr auto GetOffsetFromMultiIndex(Sequence<Is...>)
{
static_assert(sizeof...(Is) == nDim, "wrong! Dimension not consistent");
constexpr auto multi_id = Sequence<Is...>{};
return Number<reduce_on_sequence(
multi_id * GetStrides(), math::plus<index_t>{}, Number<0>{})>{};
}
// emulate constexpr lambda
template <class PackedStrides>
struct lambda_GetMultiIndexFrom1dIndex
{
index_t& id;
Array<index_t, nDim>& multi_id;
__host__
__device__ constexpr lambda_GetMultiIndexFrom1dIndex(index_t& id_,
Array<index_t, nDim>& multi_id_)
: id(id_), multi_id(multi_id_)
{
}
template <class IDim_>
__host__ __device__ constexpr void operator()(IDim_) const
{
constexpr auto IDim = IDim_{};
constexpr index_t stride = PackedStrides::Get(IDim);
multi_id(IDim) = id / stride;
id -= multi_id[IDim] * stride;
}
};
__host__ __device__ static constexpr Array<index_t, nDim> GetMultiIndexFrom1dIndex(index_t id)
{
Array<index_t, nDim> multi_id;
using PackedStrides = decltype(calculate_tensor_strides_packed_deprecated(GetLengths()));
// calculate the index in each dimension, starting from the highest (slowest-varying) one
static_for<0, nDim - 1, 1>{}(lambda_GetMultiIndexFrom1dIndex<PackedStrides>(id, multi_id));
multi_id(Number<nDim - 1>{}) = id / PackedStrides::Get(Number<nDim - 1>{});
return multi_id;
}
__host__ __device__ static constexpr auto
GetOriginalMultiIndexFromMultiIndex(Array<index_t, nDim> multi_id)
{
return multi_id;
}
// This function doesn't do carry check on the highest dimension for positive stepping (or
// borrow check on the highest dimension for negative stepping), for performance reasons. It is
// the user's responsibility to make sure the result "new_multi_id" is not out-of-bound on the
// highest dimension for positive stepping (or on the lowest dimension for negative stepping)
template <bool PositiveDirection>
__host__ __device__ static Array<index_t, nDim>
UpdateMultiIndexGivenStepSizeOf1dIndex(Array<index_t, nDim> old_multi_id,
index_t step_size_of_1d_index,
integral_constant<bool, PositiveDirection>)
{
Array<index_t, nDim> new_multi_id;
const auto step_sizes = GetMultiIndexFrom1dIndex(step_size_of_1d_index);
static_if<PositiveDirection>{}([&](auto) {
new_multi_id = old_multi_id + step_sizes;
bool carry = false;
// do carry check in reversed order, starting from lowest dimension
// don't check the highest dimension
static_for<0, nDim, 1>{}([&](auto IDimReverse) {
constexpr index_t idim = nDim - 1 - IDimReverse;
constexpr auto IDim = Number<idim>{};
if(carry)
{
++new_multi_id(idim);
}
carry = false;
if(new_multi_id[idim] >= GetLength(IDim))
{
new_multi_id(idim) -= GetLength(IDim);
carry = true;
}
});
}).Else([&](auto) {
// shift up multi-id to avoid unsigned integer underflow during intermediate
// calculations. After the shift, should have new_multi_id[...] >= 1
new_multi_id = old_multi_id + (GetLengths() - step_sizes);
bool borrow = false;
// do borrow check in reversed order, starting from lowest dimension
// don't check the highest dimension
static_for<0, nDim, 1>{}([&](auto IDimReverse) {
constexpr index_t idim = nDim - 1 - IDimReverse;
constexpr auto IDim = Number<idim>{};
if(borrow)
{
--new_multi_id(idim);
}
borrow = false;
if(new_multi_id[idim] < GetLength(IDim))
{
new_multi_id(idim) += GetLength(IDim);
borrow = true;
}
});
// shift back down multi-id
// here, should have new_multi_id[...] >= GetLengths()
new_multi_id = new_multi_id - GetLengths();
});
return new_multi_id;
}
template <index_t... IDims>
__host__ __device__ static constexpr auto Extract(Number<IDims>... extract_dims)
{
static_assert(sizeof...(IDims) <= GetNumOfDimension(),
"wrong! too many number of dimensions to be extracted");
using extract_lengths = decltype(Lengths::Extract(extract_dims...));
using extract_strides = decltype(Strides::Extract(extract_dims...));
return ConstantTensorDescriptor_deprecated<extract_lengths, extract_strides>{};
}
template <index_t... IDims>
__host__ __device__ static constexpr auto Extract(Sequence<IDims...>)
{
return Extract(Number<IDims>{}...);
}
template <class... Ts>
__host__ __device__ static constexpr auto Embed(ConstantTensorDescriptor_deprecated<Ts...>)
{
using leaf_tensor = ConstantTensorDescriptor_deprecated<Ts...>;
return ConstantTensorDescriptor_deprecated<
decltype(GetLengths().PushBack(leaf_tensor::GetLengths())),
decltype(GetStrides().PushBack(leaf_tensor::GetStrides()))>{};
}
template <index_t IDimVector, index_t DataPerVector>
struct lambda_IsVectorizationAllowed
{
bool& is_allowed;
__host__ __device__ constexpr lambda_IsVectorizationAllowed(bool& is_allowed_)
: is_allowed(is_allowed_)
{
}
template <index_t IDim_>
__host__ __device__ constexpr void operator()(Number<IDim_>) const
{
constexpr auto IDim = Number<IDim_>{};
if(IDimVector != IDim && Strides::Get(IDim) % DataPerVector != 0)
{
is_allowed = false;
}
}
};
template <index_t IDimVector, index_t DataPerVector>
__host__ __device__ static constexpr bool IsVectorizationAllowed(Number<IDimVector>,
Number<DataPerVector>)
{
bool is_allowed = (Strides{}[IDimVector] == 1 || DataPerVector == 1) &&
Lengths{}[IDimVector] % DataPerVector == 0;
static_for<0, nDim, 1>{}(
lambda_IsVectorizationAllowed<IDimVector, DataPerVector>{is_allowed});
return is_allowed;
}
template <index_t IDim, index_t DataPerVector>
__host__ __device__ static constexpr auto Vectorize(Number<IDim>, Number<DataPerVector>)
{
constexpr auto idim = Number<IDim>{};
constexpr auto data_per_vector = Number<DataPerVector>{};
static_assert(IsVectorizationAllowed(idim, data_per_vector), "wrong!");
using vectorized_lengths =
decltype(Lengths::Modify(Number<IDim>{}, Number<Lengths{}[IDim] / DataPerVector>{}));
using vectorized_strides =
decltype((Strides{} / Number<DataPerVector>{}).Modify(Number<IDim>{}, Number<1>{}));
return ConstantTensorDescriptor_deprecated<vectorized_lengths, vectorized_strides>{};
}
template <index_t IDim, index_t SliceLen>
__host__ __device__ static constexpr auto Slice(Number<IDim>, Number<SliceLen>)
{
using slice_lengths = decltype(Lengths::Modify(Number<IDim>{}, Number<SliceLen>{}));
return ConstantTensorDescriptor_deprecated<slice_lengths, Strides>{};
}
template <index_t... Is>
__host__ __device__ static constexpr auto Slice(Sequence<Is...> slice_lengths)
{
static_assert(slice_lengths.GetSize() == nDim, "wrong!");
return ConstantTensorDescriptor_deprecated<decltype(slice_lengths), Strides>{};
}
template <index_t IDim, index_t SliceLength, index_t SliceStride>
__host__ __device__ static constexpr auto
StridedSlice(Number<IDim>, Number<SliceLength>, Number<SliceStride>)
{
constexpr index_t new_stride = Strides::Get(Number<IDim>{}) * SliceStride;
using new_lengths = decltype(Lengths::Modify(Number<IDim>{}, Number<SliceLength>{}));
using new_strides = decltype(Strides::Modify(Number<IDim>{}, Number<new_stride>{}));
return ConstantTensorDescriptor_deprecated<new_lengths, new_strides>{};
}
template <index_t IDim, index_t... FoldIntervals>
__host__ __device__ static constexpr auto Fold(Number<IDim>, Number<FoldIntervals>...)
{
constexpr auto fold_intervals = Sequence<FoldIntervals...>{};
constexpr index_t fold_intervals_product =
reduce_on_sequence(fold_intervals, math::multiplies<index_t>{}, Number<1>{});
constexpr auto unfold_length = GetLength(Number<IDim>{});
constexpr auto unfold_stride = GetStride(Number<IDim>{});
// length of the dimension to be folded needs to be divisible by fold_intervals_product,
// otherwise, folding is invalid
static_assert(unfold_length % fold_intervals_product == 0,
"wrong! length on the dimension to be folded cannot be evenly divided!");
// folded lengths
constexpr auto fold_lengths =
Sequence<unfold_length / fold_intervals_product>{}.PushBack(fold_intervals);
// folded strides
constexpr auto fold_strides =
Number<unfold_stride>{} *
reverse_inclusive_scan_sequence(
fold_intervals.PushBack(Number<1>{}), math::multiplies<index_t>{}, Number<1>{});
// left and right
constexpr auto left = typename arithmetic_sequence_gen<0, IDim, 1>::type{};
constexpr auto right =
typename arithmetic_sequence_gen<IDim + 1, GetNumOfDimension(), 1>::type{};
constexpr auto new_lengths =
GetLengths().Extract(left).PushBack(fold_lengths).PushBack(GetLengths().Extract(right));
constexpr auto new_strides =
GetStrides().Extract(left).PushBack(fold_strides).PushBack(GetStrides().Extract(right));
return ConstantTensorDescriptor_deprecated<decltype(new_lengths), decltype(new_strides)>{};
}
template <index_t IDim, index_t... FoldIntervals>
__host__ __device__ static constexpr auto Fold(Number<IDim>, Sequence<FoldIntervals...>)
{
return Fold(Number<IDim>{}, Number<FoldIntervals>{}...);
}
// this function unfolds dimensions [FirstUnfoldDim, ..., LastUnfoldDim] into 1 dimension
template <index_t FirstUnfoldDim, index_t LastUnfoldDim>
__host__ __device__ static constexpr auto Unfold(Number<FirstUnfoldDim>, Number<LastUnfoldDim>)
{
static_assert(FirstUnfoldDim >= 0 && LastUnfoldDim < nDim &&
FirstUnfoldDim <= LastUnfoldDim,
"wrong! should have FirstUnfoldDim <= LastUnfoldDim!");
// left and right
constexpr auto left = typename arithmetic_sequence_gen<0, FirstUnfoldDim, 1>::type{};
constexpr auto middle =
typename arithmetic_sequence_gen<FirstUnfoldDim, LastUnfoldDim + 1, 1>::type{};
constexpr auto right =
typename arithmetic_sequence_gen<LastUnfoldDim + 1, GetNumOfDimension(), 1>::type{};
// dimensions to be unfolded need to be continuous
static_assert(Type::Extract(middle).AreDimensionsContinuous(), "wrong! not unfoldable");
// unfolded length, stride
constexpr index_t unfold_length = reduce_on_sequence(
GetLengths().Extract(middle), math::multiplies<index_t>{}, Number<1>{});
constexpr index_t unfold_stride = GetStride(Number<LastUnfoldDim>{});
// new lengths, strides
constexpr auto new_lengths = GetLengths()
.Extract(left)
.PushBack(Number<unfold_length>{})
.PushBack(GetLengths().Extract(right));
constexpr auto new_strides = GetStrides()
.Extract(left)
.PushBack(Number<unfold_stride>{})
.PushBack(GetStrides().Extract(right));
return ConstantTensorDescriptor_deprecated<decltype(new_lengths), decltype(new_strides)>{};
}
__host__ __device__ static constexpr auto Pack()
{
using packed_strides = decltype(calculate_tensor_strides_packed_deprecated(Lengths{}));
return ConstantTensorDescriptor_deprecated<Lengths, packed_strides>{};
}
template <class MapNew2Old>
__host__ __device__ static constexpr auto ReorderGivenNew2Old(MapNew2Old)
{
return ConstantTensorDescriptor_deprecated<
decltype(Lengths::ReorderGivenNew2Old(MapNew2Old{})),
decltype(Strides::ReorderGivenNew2Old(MapNew2Old{}))>{};
}
template <class MapOld2New>
__host__ __device__ static constexpr auto ReorderGivenOld2New(MapOld2New)
{
return ConstantTensorDescriptor_deprecated<
decltype(Lengths::ReorderGivenOld2New(MapOld2New{})),
decltype(Strides::ReorderGivenOld2New(MapOld2New{}))>{};
}
};
template <class Lengths>
__host__ __device__ constexpr auto make_ConstantTensorDescriptor_packed(Lengths)
{
using Strides = decltype(calculate_tensor_strides_packed_deprecated(Lengths{}));
return ConstantTensorDescriptor_deprecated<Lengths, Strides>{};
}
template <class Lengths, class Strides>
__host__ __device__ constexpr auto make_ConstantTensorDescriptor(Lengths, Strides)
{
return ConstantTensorDescriptor_deprecated<Lengths, Strides>{};
}
template <class Lengths, index_t Align>
__host__ __device__ constexpr auto make_ConstantTensorDescriptor_aligned(Lengths, Number<Align>)
{
using Strides = decltype(calculate_tensor_strides_aligned_old(Lengths{}, Number<Align>{}));
return ConstantTensorDescriptor_deprecated<Lengths, Strides>{};
}
template <index_t... Lengths, index_t... Strides>
__host__ __device__ void print_ConstantTensorDescriptor(
const char* s, ConstantTensorDescriptor_deprecated<Sequence<Lengths...>, Sequence<Strides...>>)
{
constexpr index_t ndim = sizeof...(Lengths);
static_assert(ndim > 0 && ndim <= 12, "wrong!");
static_if<ndim == 1>{}([&](auto) {
printf("%s dim %u, lengths {%u}, strides {%u}\n", s, ndim, Lengths..., Strides...);
});
static_if<ndim == 2>{}([&](auto) {
printf("%s dim %u, lengths {%u %u}, strides {%u %u}\n", s, ndim, Lengths..., Strides...);
});
static_if<ndim == 3>{}([&](auto) {
printf(
"%s dim %u, lengths {%u %u %u}, strides {%u %u %u}\n", s, ndim, Lengths..., Strides...);
});
static_if<ndim == 4>{}([&](auto) {
printf("%s dim %u, lengths {%u %u %u %u}, strides {%u %u %u %u}\n",
s,
ndim,
Lengths...,
Strides...);
});
static_if<ndim == 5>{}([&](auto) {
printf("%s dim %u, lengths {%u %u %u %u %u}, strides {%u %u %u %u %u}\n",
s,
ndim,
Lengths...,
Strides...);
});
static_if<ndim == 6>{}([&](auto) {
printf("%s dim %u, lengths {%u %u %u %u %u %u}, strides {%u %u %u %u %u %u}\n",
s,
ndim,
Lengths...,
Strides...);
});
static_if<ndim == 7>{}([&](auto) {
printf("%s dim %u, lengths {%u %u %u %u %u %u %u}, strides {%u %u %u %u %u %u %u}\n",
s,
ndim,
Lengths...,
Strides...);
});
static_if<ndim == 8>{}([&](auto) {
printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u}, strides {%u %u %u %u %u %u %u %u}\n",
s,
ndim,
Lengths...,
Strides...);
});
static_if<ndim == 9>{}([&](auto) {
printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u %u}, strides {%u %u %u %u %u %u %u %u "
"%u}\n",
s,
ndim,
Lengths...,
Strides...);
});
static_if<ndim == 10>{}([&](auto) {
printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u %u %u}, strides {%u %u %u %u %u %u %u "
"%u %u %u}\n",
s,
ndim,
Lengths...,
Strides...);
});
static_if<ndim == 11>{}([&](auto) {
printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u %u %u %u}, strides {%u %u %u %u %u %u "
"%u %u "
"%u %u %u}\n",
s,
ndim,
Lengths...,
Strides...);
});
static_if<ndim == 12>{}([&](auto) {
printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u %u %u %u %u}, strides {%u %u %u %u %u "
"%u %u %u %u "
"%u %u %u}\n",
s,
ndim,
Lengths...,
Strides...);
});
}
} // namespace ck
#endif
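UpdateMultiIndexGivenStepSizeOf1dIndex above adds a 1-D step to a multi-index by decomposing the step into per-dimension components and propagating carries from the lowest dimension upward. A standalone host-side sketch of the positive-direction case with hypothetical lengths; as in the header above, the result is only meaningful while the updated index stays in bounds:

// positive-direction multi-index update: decompose the 1-D step into per-dimension steps,
// add them, and propagate carries from the lowest dimension towards the highest
#include <array>
#include <cassert>

int main()
{
    constexpr int nDim = 3;
    constexpr std::array<int, nDim> lengths = {4, 3, 5};
    constexpr std::array<int, nDim> strides = {15, 5, 1}; // packed strides of {4, 3, 5}

    const auto decompose = [&](int id) {
        std::array<int, nDim> m{};
        for(int d = 0; d < nDim; ++d)
        {
            m[d] = id / strides[d];
            id -= m[d] * strides[d];
        }
        return m;
    };

    constexpr int total = 4 * 3 * 5;

    for(int old_id = 0; old_id < total; ++old_id)
        for(int step = 0; old_id + step < total; ++step)
        {
            const auto old_multi  = decompose(old_id);
            const auto step_multi = decompose(step);

            std::array<int, nDim> new_multi{};
            for(int d = 0; d < nDim; ++d)
                new_multi[d] = old_multi[d] + step_multi[d];

            // carry check in reversed order, starting from the lowest dimension
            for(int d = nDim - 1; d > 0; --d)
                if(new_multi[d] >= lengths[d])
                {
                    new_multi[d] -= lengths[d];
                    ++new_multi[d - 1];
                }

            // the carried result matches re-decomposing the updated 1-D index
            assert(new_multi == decompose(old_id + step));
        }

    return 0;
}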
#ifndef CK_TENSOR_COORDINATE_DEPRECATED_HPP
#define CK_TENSOR_COORDINATE_DEPRECATED_HPP
#include "common_header.hpp"
#include "ConstantTensorDescriptor_deprecated.hpp"
#include "ConstantMergedTensorDescriptor_deprecated.hpp"
namespace ck {
// TensorDesc is ConstantTensorDescriptor_deprecated
template <class TensorDesc>
struct NormalTensorCoordinate_deprecated
{
using type = NormalTensorCoordinate_deprecated;
using tensor_desc_type = TensorDesc;
static constexpr index_t nDim = tensor_desc_type::GetNumOfDimension();
__host__
__device__ constexpr NormalTensorCoordinate_deprecated(Array<index_t, nDim> tensor_index)
: mOffset{tensor_desc_type::GetOffsetFromMultiIndex(tensor_index)}
{
}
template <class... Xs>
__host__ __device__ constexpr NormalTensorCoordinate_deprecated(Xs... xs)
: NormalTensorCoordinate_deprecated(Array<index_t, nDim>{xs...})
{
}
template <index_t... Xs>
__host__ __device__ constexpr NormalTensorCoordinate_deprecated(Sequence<Xs...>)
: NormalTensorCoordinate_deprecated(Array<index_t, nDim>{Xs...})
{
}
__host__ __device__ constexpr index_t GetOffset() const { return mOffset; }
// T is Array or Sequence
template <class T>
__host__ __device__ type operator+=(T step_sizes)
{
static_assert(is_same<typename T::data_type, index_t>{} && T::GetSize() == nDim, "wrong!");
mOffset += tensor_desc_type::GetOffsetFromMultiIndex(step_sizes);
return *this;
}
template <class T>
__host__ __device__ type operator-=(T step_sizes)
{
static_assert(is_same<typename T::data_type, index_t>{} && T::GetSize() == nDim, "wrong!");
mOffset -= tensor_desc_type::GetOffsetFromMultiIndex(step_sizes);
return *this;
}
template <class T>
__host__ __device__ constexpr type operator+(T step_sizes) const
{
type coord = *this;
coord += step_sizes;
return coord;
}
template <class T>
__host__ __device__ constexpr type operator-(T step_sizes) const
{
type coord = *this;
coord -= step_sizes;
return coord;
}
// reposition the point of origin, and return the compensated offset.
// This is a hack to reduce index calculation while looping over a tensor whose origin is
// this TensorCoordinate. It does so by handing the run-time offset back to the caller, who
// can add it to the run-time pointer to the tensor data, so that only one run-time variable
// (the updated pointer) is needed, instead of two (the old pointer plus this offset).
// TODO: after introducing the concept of a "run-time tensor view", which contains the
// run-time pointer to the data, always keep track of the pointer instead of both the
// offset and the pointer. This also brings the additional benefit that we don't need to
// worry about the offset underflowing (because the offset is an unsigned integer) when
// updating it.
__host__ __device__ constexpr index_t RepositionOrigin()
{
index_t offset_diff = mOffset;
mOffset = 0;
return offset_diff;
}
private:
index_t mOffset;
};
// TensorDesc is ConstantMergedTensorDescriptor_deprecated
template <class TensorDesc>
struct MergedTensorCoordinate_deprecated
{
using type = MergedTensorCoordinate_deprecated;
using tensor_desc_type = TensorDesc;
static constexpr index_t nDim = tensor_desc_type::GetNumOfDimension();
static constexpr index_t nOriginalDim =
tensor_desc_type::GetOriginalTensorDescriptor().GetNumOfDimension();
__host__
__device__ constexpr MergedTensorCoordinate_deprecated(Array<index_t, nDim> tensor_index)
: mOriginalIndex{tensor_desc_type::GetOriginalMultiIndexFromMultiIndex(tensor_index)}
{
// partial offset on each dimension
static_for<0, nDim, 1>{}([&](auto idim) {
constexpr auto partial_original_dims =
tensor_desc_type::GetContainedOriginalDimensions(idim);
constexpr auto partial_original_desc =
tensor_desc_type::GetOriginalTensorDescriptor().Extract(partial_original_dims);
mPartialOffsets(idim) = partial_original_desc.GetOffsetFromMultiIndex(
extract_array(mOriginalIndex, partial_original_dims));
});
// complete offset
mOffset =
accumulate_on_array(mPartialOffsets, math::plus<index_t>{}, static_cast<index_t>(0));
}
template <class... Xs>
__host__ __device__ constexpr MergedTensorCoordinate_deprecated(Xs... xs)
: MergedTensorCoordinate_deprecated(Array<index_t, nDim>{xs...})
{
}
__host__ __device__ constexpr index_t GetOffset() const { return mOffset; }
template <class IDim, class T, bool PositiveDirection>
__host__ __device__ void
MoveOnDimension(IDim idim_, T step_size, integral_constant<bool, PositiveDirection>)
{
constexpr auto idim = idim_;
// if step_size is known at compile time
static_if<is_static<T>::value>{}(
[&](auto) { static_if<T{} == 0>{}([&](auto) { return; }); });
// update original index
static_if<tensor_desc_type::ContainMultipleOriginalDimensions(idim)>{}([&](auto) {
constexpr auto partial_original_dims =
tensor_desc_type::GetContainedOriginalDimensions(idim);
constexpr index_t ndim_partial_original = partial_original_dims.GetSize();
constexpr auto partial_original_desc =
tensor_desc_type::GetOriginalTensorDescriptor().Extract(partial_original_dims);
const auto partial_original_step_sizes =
partial_original_desc.GetMultiIndexFrom1dIndex(step_size);
// update partial original multi-id
auto partial_original_id = extract_array(mOriginalIndex, partial_original_dims);
static_if<PositiveDirection>{}([&](auto) {
partial_original_id += partial_original_step_sizes;
bool carry = false;
// do carry check in reversed order, starting from lowest dimension
// don't check the highest dimension
static_for<0, ndim_partial_original - 1, 1>{}([&](auto IReverse) {
constexpr index_t i = ndim_partial_original - 1 - IReverse;
if(carry)
{
++partial_original_id(i);
}
carry = false;
if(partial_original_id[i] >= partial_original_desc.GetLength(i))
{
partial_original_id(i) -= partial_original_desc.GetLength(i);
carry = true;
}
});
// highest dimension
if(carry)
{
++partial_original_id(0);
}
}).Else([&](auto) {
// shift up multi-id to avoid unsigned integer underflow during intermediate
// calculations. After the shift, should have new_multi_id[...] >= 1
partial_original_id +=
partial_original_desc.GetLengths() - partial_original_step_sizes;
bool borrow = false;
// do borrow check in reversed order, starting from lowest dimension
// don't check the highest dimension
static_for<0, ndim_partial_original - 1, 1>{}([&](auto IReverse) {
constexpr index_t i = ndim_partial_original - 1 - IReverse;
if(borrow)
{
--partial_original_id(i);
}
borrow = false;
if(partial_original_id[i] < partial_original_desc.GetLength(i))
{
partial_original_id(i) += partial_original_desc.GetLength(i);
borrow = true;
}
});
// highest dimension
if(borrow)
{
--partial_original_id(0);
}
// shift back down multi-id
// here, should have new_multi_id[...] >= GetLengths()
partial_original_id = partial_original_id - partial_original_desc.GetLengths();
});
// update "mOriginalIndex"
static_for<0, ndim_partial_original, 1>{}([&](auto I) {
constexpr auto idim_original = partial_original_dims[I];
mOriginalIndex(idim_original) = partial_original_id[I];
});
// calculate new partial offset on this merged dimension
const index_t old_partial_offset = mPartialOffsets[idim];
mPartialOffsets(idim) =
partial_original_desc.GetOffsetFromMultiIndex(partial_original_id);
// update "mThreadSrcOffset", do "+" before "-" to avoid underflow
mOffset = (mOffset + mPartialOffsets[idim]) - old_partial_offset;
}).Else([&](auto fwd) {
static_if<PositiveDirection>{}([&](auto) {
mOffset += step_size * fwd(tensor_desc_type{}).GetStride(idim);
}).Else([&](auto) { mOffset -= step_size * fwd(tensor_desc_type{}).GetStride(idim); });
});
}
// T is Array or Sequence
template <class T>
__host__ __device__ type operator+=(T step_sizes)
{
static_assert(is_same<typename T::data_type, index_t>{} && T::GetSize() == nDim, "wrong!");
static_for<0, nDim, 1>{}([&](auto idim) {
// compiler should remove dead code path, because step_sizes is known at
// compile time
if(step_sizes[idim] != 0)
{
this->MoveOnDimension(idim, step_sizes[idim], integral_constant<bool, true>{});
}
});
return *this;
}
template <class T>
__host__ __device__ type operator-=(T step_sizes)
{
static_assert(is_same<typename T::data_type, index_t>{} && T::GetSize() == nDim, "wrong!");
static_for<0, nDim, 1>{}([&](auto idim) {
// compiler should remove dead code path, because step_sizes is known at
// compile time
if(step_sizes[idim] != 0)
{
this->MoveOnDimension(idim, step_sizes[idim], integral_constant<bool, false>{});
}
});
return *this;
}
template <class T>
__host__ __device__ constexpr type operator+(T step_sizes) const
{
type coord = *this;
coord += step_sizes;
return coord;
}
template <class T>
__host__ __device__ constexpr type operator-(T step_sizes) const
{
type coord = *this;
coord -= step_sizes;
return coord;
}
__host__ __device__ static constexpr index_t RepositionOrigin() { return 0; }
private:
// Allocate register memory for all merged dimensions and normal dimensions.
// However, only merged dimensions whose indices are involved in arithmetic after the
// construction of this TensorCoordinate (e.g. when the user moves a slicing window
// along a merged dimension) actually use this register memory.
// We rely on the compiler to optimize away the registers allocated for normal
// dimensions, and for merged dimensions that are never involved in index arithmetic
// after the construction of the TensorCoordinate.
// TODO: refactor TensorCoordinate after introducing the concept of "dimensions",
// and simplify the implementation of ConstantMergedTensorDescriptor_deprecated, so we
// don't need to count on the compiler to optimize away those registers for us
Array<index_t, nOriginalDim> mOriginalIndex;
Array<index_t, nDim> mPartialOffsets;
// complete offset
index_t mOffset;
};
template <class TensorDesc>
struct TensorCoordinate_deprecated
{
private:
template <class... Ts>
__host__ __device__ static constexpr auto
MakeDummyTensorCoordinate(ConstantTensorDescriptor_deprecated<Ts...>)
{
return NormalTensorCoordinate_deprecated<ConstantTensorDescriptor_deprecated<Ts...>>();
}
template <class... Ts>
__host__ __device__ static constexpr auto
MakeDummyTensorCoordinate(ConstantMergedTensorDescriptor_deprecated<Ts...>)
{
return MergedTensorCoordinate_deprecated<
ConstantMergedTensorDescriptor_deprecated<Ts...>>();
}
public:
using type = decltype(MakeDummyTensorCoordinate(TensorDesc{}));
};
} // namespace ck
#endif
#ifndef CK_TENSOR_COORDINATE_HELPER_HPP
#define CK_TENSOR_COORDINATE_HELPER_HPP
#include "tensor_coordiante_hpp"
namespace ck {
template <typename TensorDesc>
__host__ __device__ constexpr auto
make_tensor_coordinate(TensorDesc, MultiIndex<TensorDesc::GetNumOfDimension()> idx)
{
return typename TensorCoordinate<TensorDesc>::type(idx);
}
} // namespace ck
#endif
...@@ -18,11 +18,11 @@ template <index_t BlockSize, ...@@ -18,11 +18,11 @@ template <index_t BlockSize,
typename ThreadMatrixC, typename ThreadMatrixC,
index_t MPerThreadSubC, index_t MPerThreadSubC,
index_t NPerThreadSubC, index_t NPerThreadSubC,
index_t KPerThreadLoop,
index_t MLevel0ThreadCluster, index_t MLevel0ThreadCluster,
index_t NLevel0ThreadCluster, index_t NLevel0ThreadCluster,
index_t MLevel1ThreadCluster, index_t MLevel1ThreadCluster,
index_t NLevel1ThreadCluster, index_t NLevel1ThreadCluster,
index_t KPerThreadLoop,
index_t ThreadGemmADataPerRead_M, index_t ThreadGemmADataPerRead_M,
index_t ThreadGemmBDataPerRead_N> index_t ThreadGemmBDataPerRead_N>
struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2 struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
......
...@@ -15,6 +15,8 @@ namespace ck { ...@@ -15,6 +15,8 @@ namespace ck {
// The dimension access order can be different for src and dst. // The dimension access order can be different for src and dst.
// Will do valid mapping check on src data: Read 0 if src data has an invalid mapping // Will do valid mapping check on src data: Read 0 if src data has an invalid mapping
// Will do valid mapping check on dst data: No write if dst data has an invalid mapping // Will do valid mapping check on dst data: No write if dst data has an invalid mapping
// BlockSize can be equal to or larger than the ThreadCluster size, which means some threads
// may not do threadwise copy
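// e.g. with BlockSize = 256 and ThreadClusterLengths whose product is 128, threads 0..127
// perform the copy while threads 128..255 simply skip it (see the guards added below)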
template <index_t BlockSize, template <index_t BlockSize,
typename BlockSrcDesc, typename BlockSrcDesc,
typename BlockDstDesc, typename BlockDstDesc,
...@@ -31,7 +33,9 @@ template <index_t BlockSize, ...@@ -31,7 +33,9 @@ template <index_t BlockSize,
AddressSpace SrcAddressSpace = AddressSpace::Generic, AddressSpace SrcAddressSpace = AddressSpace::Generic,
AddressSpace ThreadBufferAddressSpace = AddressSpace::Generic, AddressSpace ThreadBufferAddressSpace = AddressSpace::Generic,
AddressSpace DstAddressSpace = AddressSpace::Generic, AddressSpace DstAddressSpace = AddressSpace::Generic,
InMemoryDataOperation DstInMemOp = InMemoryDataOperation::Set> InMemoryDataOperation DstInMemOp = InMemoryDataOperation::Set,
index_t SrcDataStride = 1,
index_t DstDataStride = 1>
struct BlockwiseGenericTensorSliceCopy_v4 struct BlockwiseGenericTensorSliceCopy_v4
{ {
static constexpr index_t nDim = BlockSrcDesc::GetNumOfDimension(); static constexpr index_t nDim = BlockSrcDesc::GetNumOfDimension();
...@@ -52,23 +56,23 @@ struct BlockwiseGenericTensorSliceCopy_v4 ...@@ -52,23 +56,23 @@ struct BlockwiseGenericTensorSliceCopy_v4
is_same<BlockSliceLengths, decltype(ThreadSliceLengths{} * ThreadClusterLengths{})>{}, is_same<BlockSliceLengths, decltype(ThreadSliceLengths{} * ThreadClusterLengths{})>{},
"wrong! threads should be mapped to cover entire slicing window"); "wrong! threads should be mapped to cover entire slicing window");
// map threads to cluster static_assert(BlockSize >= mThreadClusterDesc.GetElementSize(),
constexpr auto thread_cluster_desc = "wrong! BlockSize too small");
make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});
static_assert(BlockSize == thread_cluster_desc.GetElementSize(), if(BlockSize == mThreadClusterDesc.GetElementSize() or
"wrong! BlockSize not consistent with ThreadClusterLengths"); get_thread_local_1d_id() < mThreadClusterDesc.GetElementSize())
{
const auto thread_cluster_id = const auto thread_cluster_id =
thread_cluster_desc.CalculateClusterIndex(get_thread_local_1d_id()); mThreadClusterDesc.CalculateClusterIndex(get_thread_local_1d_id());
const auto thread_data_id_begin = thread_cluster_id * ThreadSliceLengths{}; const auto thread_data_id_begin = thread_cluster_id * ThreadSliceLengths{};
mThreadwiseLoad.SetSrcSliceOrigin(src_block_slice_origin + thread_data_id_begin); mThreadwiseLoad.SetSrcSliceOrigin(src_block_slice_origin + thread_data_id_begin);
mThreadwiseLoad.SetDstSliceOrigin(make_zero_array<index_t, nDim>()); mThreadwiseLoad.SetDstSliceOrigin(make_zero_array<index_t, nDim>());
mThreadwiseStore.SetSrcSliceOrigin(make_zero_array<index_t, nDim>()); mThreadwiseStore.SetSrcSliceOrigin(make_zero_array<index_t, nDim>());
mThreadwiseStore.SetDstSliceOrigin(dst_block_slice_origin + thread_data_id_begin); mThreadwiseStore.SetDstSliceOrigin(dst_block_slice_origin + thread_data_id_begin);
}
} }
__device__ static constexpr index_t GetThreadBufferSize() __device__ static constexpr index_t GetThreadBufferSize()
...@@ -83,14 +87,18 @@ struct BlockwiseGenericTensorSliceCopy_v4 ...@@ -83,14 +87,18 @@ struct BlockwiseGenericTensorSliceCopy_v4
constexpr bool has_optimized_address_calculation = constexpr bool has_optimized_address_calculation =
decltype(mThreadwiseStore)::HasWorkingOptimizedAddressCalculation(); decltype(mThreadwiseStore)::HasWorkingOptimizedAddressCalculation();
// TODO: threadwise copy is still being tweaked if(BlockSize == mThreadClusterDesc.GetElementSize() or
if(has_optimized_address_calculation) get_thread_local_1d_id() < mThreadClusterDesc.GetElementSize())
{
mThreadwiseLoad.Run_optimized_src_address_calculation(p_block_src, p_thread_buffer);
}
else
{ {
mThreadwiseLoad.Run(p_block_src, p_thread_buffer); // TODO: threadwise copy is still being tweaked
if(has_optimized_address_calculation)
{
mThreadwiseLoad.Run_optimized_src_address_calculation(p_block_src, p_thread_buffer);
}
else
{
mThreadwiseLoad.Run(p_block_src, p_thread_buffer);
}
} }
} }
...@@ -101,14 +109,19 @@ struct BlockwiseGenericTensorSliceCopy_v4 ...@@ -101,14 +109,19 @@ struct BlockwiseGenericTensorSliceCopy_v4
constexpr bool has_optimized_address_calculation = constexpr bool has_optimized_address_calculation =
decltype(mThreadwiseStore)::HasWorkingOptimizedAddressCalculation(); decltype(mThreadwiseStore)::HasWorkingOptimizedAddressCalculation();
// TODO: threadwise copy is still being tweaked if(BlockSize == mThreadClusterDesc.GetElementSize() or
if(has_optimized_address_calculation) get_thread_local_1d_id() < mThreadClusterDesc.GetElementSize())
{
mThreadwiseStore.Run_optimized_dst_address_calculation(p_thread_buffer, p_block_dst);
}
else
{ {
mThreadwiseStore.Run(p_thread_buffer, p_block_dst); // TODO: threadwise copy is still being tweaked
if(has_optimized_address_calculation)
{
mThreadwiseStore.Run_optimized_dst_address_calculation(p_thread_buffer,
p_block_dst);
}
else
{
mThreadwiseStore.Run(p_thread_buffer, p_block_dst);
}
} }
} }
...@@ -123,10 +136,14 @@ struct BlockwiseGenericTensorSliceCopy_v4 ...@@ -123,10 +136,14 @@ struct BlockwiseGenericTensorSliceCopy_v4
BlockSrcData p_thread_buffer[GetThreadBufferSize()]; BlockSrcData p_thread_buffer[GetThreadBufferSize()];
RunLoadThreadBuffer(p_block_src, p_thread_buffer); if(BlockSize == mThreadClusterDesc.GetElementSize() or
get_thread_local_1d_id() < mThreadClusterDesc.GetElementSize())
{
RunLoadThreadBuffer(p_block_src, p_thread_buffer);
// if there is type conversion, it's done during store // if there is type conversion, it's done during store
RunStoreThreadBuffer(p_thread_buffer, p_block_dst); RunStoreThreadBuffer(p_thread_buffer, p_block_dst);
}
} }
template <typename T, bool PositiveDirection> template <typename T, bool PositiveDirection>
...@@ -134,7 +151,11 @@ struct BlockwiseGenericTensorSliceCopy_v4 ...@@ -134,7 +151,11 @@ struct BlockwiseGenericTensorSliceCopy_v4
MoveSrcSliceWindow(const T& step_sizes, MoveSrcSliceWindow(const T& step_sizes,
integral_constant<bool, PositiveDirection> positive_direction) integral_constant<bool, PositiveDirection> positive_direction)
{ {
mThreadwiseLoad.MoveSrcSliceWindow(step_sizes, positive_direction); if(BlockSize == mThreadClusterDesc.GetElementSize() or
get_thread_local_1d_id() < mThreadClusterDesc.GetElementSize())
{
mThreadwiseLoad.MoveSrcSliceWindow(step_sizes, positive_direction);
}
} }
template <typename T, bool PositiveDirection> template <typename T, bool PositiveDirection>
...@@ -142,7 +163,11 @@ struct BlockwiseGenericTensorSliceCopy_v4 ...@@ -142,7 +163,11 @@ struct BlockwiseGenericTensorSliceCopy_v4
MoveDstSliceWindow(const T& step_sizes, MoveDstSliceWindow(const T& step_sizes,
integral_constant<bool, PositiveDirection> positive_direction) integral_constant<bool, PositiveDirection> positive_direction)
{ {
mThreadwiseStore.MoveDstSliceWindow(step_sizes, positive_direction); if(BlockSize == mThreadClusterDesc.GetElementSize() or
get_thread_local_1d_id() < mThreadClusterDesc.GetElementSize())
{
mThreadwiseStore.MoveDstSliceWindow(step_sizes, positive_direction);
}
} }
private: private:
...@@ -157,7 +182,9 @@ struct BlockwiseGenericTensorSliceCopy_v4 ...@@ -157,7 +182,9 @@ struct BlockwiseGenericTensorSliceCopy_v4
1, 1,
SrcAddressSpace, SrcAddressSpace,
ThreadBufferAddressSpace, ThreadBufferAddressSpace,
InMemoryDataOperation::Set>; InMemoryDataOperation::Set,
SrcDataStride,
1>;
using ThreadwiseStore = ThreadwiseGenericTensorSliceCopy_v4r2<ThreadBufferDesc, using ThreadwiseStore = ThreadwiseGenericTensorSliceCopy_v4r2<ThreadBufferDesc,
BlockDstDesc, BlockDstDesc,
...@@ -168,7 +195,12 @@ struct BlockwiseGenericTensorSliceCopy_v4 ...@@ -168,7 +195,12 @@ struct BlockwiseGenericTensorSliceCopy_v4
DstDataPerWrite, DstDataPerWrite,
ThreadBufferAddressSpace, ThreadBufferAddressSpace,
DstAddressSpace, DstAddressSpace,
DstInMemOp>; DstInMemOp,
1,
DstDataStride>;
static constexpr auto mThreadClusterDesc =
make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});
ThreadwiseLoad mThreadwiseLoad; ThreadwiseLoad mThreadwiseLoad;
ThreadwiseStore mThreadwiseStore; ThreadwiseStore mThreadwiseStore;
......
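The hunk above changes BlockwiseGenericTensorSliceCopy_v4 so that BlockSize may exceed the thread-cluster size: only threads whose 1-D id falls inside the cluster take part in the copy. Below is a minimal host-side sketch of that guard, assuming a packed row-major cluster layout and ignoring ThreadClusterArrangeOrder; the names and sizes are illustrative, not the library's API.

#include <array>
#include <cstdio>

// Decompose a 1-D thread id into a cluster multi-index, assuming a packed
// row-major layout over the cluster lengths (the arrange order is ignored
// here for brevity).
template <std::size_t NDim>
std::array<int, NDim> calculate_cluster_index(int tid, const std::array<int, NDim>& cluster_lengths)
{
    std::array<int, NDim> idx{};
    for(int d = static_cast<int>(NDim) - 1; d >= 0; --d)
    {
        idx[d] = tid % cluster_lengths[d];
        tid /= cluster_lengths[d];
    }
    return idx;
}

int main()
{
    constexpr int BlockSize = 256;                    // launched threads
    const std::array<int, 2> cluster_lengths{8, 16};  // only 128 threads do copy work
    const std::array<int, 2> slice_lengths{2, 4};     // per-thread slice

    int cluster_size = 1;
    for(int l : cluster_lengths)
        cluster_size *= l;

    for(int tid = 0; tid < BlockSize; ++tid)
    {
        // same guard as the new code: threads beyond the cluster stay idle
        if(!(BlockSize == cluster_size || tid < cluster_size))
            continue;

        const auto cid          = calculate_cluster_index(tid, cluster_lengths);
        const int data_begin[2] = {cid[0] * slice_lengths[0], cid[1] * slice_lengths[1]};
        if(tid < 3)
            std::printf("tid %d -> data origin (%d, %d)\n", tid, data_begin[0], data_begin[1]);
    }
}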
#ifndef CK_BLOCKWISE_GENERIC_TENSOR_SLICE_COPY_DEPRECATED_HPP
#define CK_BLOCKWISE_GENERIC_TENSOR_SLICE_COPY_DEPRECATED_HPP
#include "common_header.hpp"
#include "ConstantTensorDescriptor_deprecated.hpp"
#include "ConstantMergedTensorDescriptor_deprecated.hpp"
#include "tensor_coordinate_deprecated.hpp"
#include "threadwise_generic_tensor_slice_copy_deprecated.hpp"
namespace ck {
// Slice a (normal or merged) tensor, and copy it into another (normal or merged) tensor.
// The memory layout (ordering of dimensions) can be different between src and dst.
// This function assumes each thread is reading and writing a normal (not merged) tensor,
// to simplify index calculations. To satisfy this assumption, the user needs to make sure
// that, on a merged dimension that contains multiple original dimensions, the length of
// the last original dimension is evenly divisible by its sub-length. Also, the
// repeat-length on the merged dimension needs to be 1. These sanity checks are performed
// in the constructor of BlockwiseGenericTensorSliceCopy_v1_deprecated.
template <index_t BlockSize,
typename SrcDesc,
typename DstDesc,
typename SliceLengths,
typename SubLengths,
typename ThreadClusterLengths,
typename ThreadClusterArrangeOrder,
typename SrcDimAccessOrder,
typename DstDimAccessOrder,
index_t SrcVectorAccessDim,
index_t DstVectorAccessDim,
index_t SrcDataPerAccess,
index_t DstDataPerAccess>
struct BlockwiseGenericTensorSliceCopy_v1_deprecated
{
static constexpr index_t nDim = SrcDesc::GetNumOfDimension();
static constexpr index_t nOriginalDimSrc =
SrcDesc::GetOriginalTensorDescriptor().GetNumOfDimension();
static constexpr index_t nOriginalDimDst =
DstDesc::GetOriginalTensorDescriptor().GetNumOfDimension();
// per-thread offset
index_t mThreadSrcOffset;
index_t mThreadDstOffset;
// "mThreadSrcOriginalMultiId", "mThreadSrcPartialOffsets, "mThreadDstOriginalMultiId",
// "mThreadDstPartialOffsets" are always calculated inside constructor, and would be
// updated if slicing-window is moved. However, they will not be used if you always move
// the slicing-window along a non-merged dimension. In that case, compiler should be
// able to remove these calculation.
// TODO: make sure compiler would actually remove them in that case
// partial offset in each (merged) dimension
Array<index_t, nDim> mThreadSrcPartialOffsets;
Array<index_t, nDim> mThreadDstPartialOffsets;
// multi-id of original tensor
Array<index_t, nOriginalDimSrc> mThreadSrcOriginalMultiId;
Array<index_t, nOriginalDimDst> mThreadDstOriginalMultiId;
__device__
BlockwiseGenericTensorSliceCopy_v1_deprecated(Array<index_t, nDim> src_block_data_id_begin,
Array<index_t, nDim> dst_block_data_id_begin)
{
// check NDim consistency
static_assert(
nDim == SrcDesc::GetNumOfDimension() && nDim == DstDesc::GetNumOfDimension() &&
nDim == SliceLengths::GetSize() && nDim == SubLengths::GetSize() &&
nDim == ThreadClusterLengths::GetSize() &&
nDim == ThreadClusterArrangeOrder::GetSize() &&
nDim == SrcDimAccessOrder::GetSize() && nDim == DstDimAccessOrder::GetSize(),
"wrong");
// check thread arrange order and read/write access order are valid
static_assert(is_valid_sequence_map<ThreadClusterArrangeOrder>::value &&
is_valid_sequence_map<SrcDimAccessOrder>::value &&
is_valid_sequence_map<DstDimAccessOrder>::value,
"wrong!");
// thread cluster
constexpr auto thread_cluster_desc = make_ConstantTensorDescriptor_packed(
ThreadClusterLengths::ReorderGivenNew2Old(ThreadClusterArrangeOrder{}));
// BlockSize
static_assert(BlockSize == thread_cluster_desc.GetElementSize(), "wrong! BlockSize");
// divide work
constexpr auto data_per_cluster_per_dims = SubLengths{} * ThreadClusterLengths{};
static_for<0, nDim, 1>{}([&](auto IDim) {
static_assert(SliceLengths::Get(IDim) % data_per_cluster_per_dims.Get(IDim) == 0,
"wrong! cannot evenly divide sliced tensor into cluster");
});
constexpr auto repeat_lengths = SliceLengths{} / data_per_cluster_per_dims;
// additional check for merged dimension
static_for<0, nDim, 1>{}([&](auto IDim_) {
// src
static_if<SrcDesc::ContainMultipleOriginalDimensions(IDim_)>{}([&](auto) {
constexpr auto IDim = decltype(IDim_){};
// on a merged dimension that contains multiple original dimensions,
// the length of the last original dimension needs to be evenly divisible by its
// sub-length,
// so each thread is effectively reading a normal (not merged) tensor
constexpr auto idim_last_original_src =
SrcDesc::GetContainedOriginalDimensions(IDim).Back();
static_assert(
SrcDesc::GetOriginalTensorDescriptor().GetLength(idim_last_original_src) %
SubLengths::Get(IDim) ==
0,
"wrong!");
// merged dimension should have repeat_lengths = 1
static_assert(repeat_lengths[IDim] == 1,
"wrong! repeat_lengths shoud be 1 on merged dimension");
});
// dst
static_if<DstDesc::ContainMultipleOriginalDimensions(IDim_)>{}([&](auto) {
constexpr auto IDim = decltype(IDim_){};
// on a merged dimension that contains multiple original dimensions,
// the length of the last original dimension needs to be evenly divisible by its
// sub-length,
// so each thread is effectively reading a normal (not merged) tensor
constexpr auto idim_last_original_dst =
DstDesc::GetContainedOriginalDimensions(IDim).Back();
static_assert(
DstDesc::GetOriginalTensorDescriptor().GetLength(idim_last_original_dst) %
SubLengths::Get(IDim) ==
0,
"wrong!");
// merged dimension should have repeat_lengths = 1
static_assert(repeat_lengths[IDim] == 1,
"wrong! repeat_lengths shoud be 1 on merged dimension");
});
});
// calculate mThreadSrcOffset, mThreadDstOffset
const auto thread_cluster_id =
thread_cluster_desc.GetMultiIndexFrom1dIndex(get_thread_local_1d_id());
const auto data_cluster_id =
reorder_array_given_old2new(thread_cluster_id, ThreadClusterArrangeOrder{});
const auto thread_data_id_begin = data_cluster_id * SubLengths{};
// original multi-id
mThreadSrcOriginalMultiId = SrcDesc::GetOriginalMultiIndexFromMultiIndex(
src_block_data_id_begin + thread_data_id_begin);
mThreadDstOriginalMultiId = DstDesc::GetOriginalMultiIndexFromMultiIndex(
dst_block_data_id_begin + thread_data_id_begin);
// partial offset on each dimension
static_for<0, nDim, 1>{}([&](auto IDim) {
constexpr auto src_partial_original_dims =
SrcDesc::GetContainedOriginalDimensions(IDim);
constexpr auto src_partial_original_desc =
SrcDesc::GetOriginalTensorDescriptor().Extract(src_partial_original_dims);
mThreadSrcPartialOffsets(IDim) = src_partial_original_desc.GetOffsetFromMultiIndex(
extract_array(mThreadSrcOriginalMultiId, src_partial_original_dims));
});
static_for<0, nDim, 1>{}([&](auto IDim) {
constexpr auto dst_partial_original_dims =
DstDesc::GetContainedOriginalDimensions(IDim);
constexpr auto dst_partial_original_desc =
DstDesc::GetOriginalTensorDescriptor().Extract(dst_partial_original_dims);
mThreadDstPartialOffsets(IDim) = dst_partial_original_desc.GetOffsetFromMultiIndex(
extract_array(mThreadDstOriginalMultiId, dst_partial_original_dims));
});
// complete offset
mThreadSrcOffset = accumulate_on_array(
mThreadSrcPartialOffsets, math::plus<index_t>{}, static_cast<index_t>(0));
mThreadDstOffset = accumulate_on_array(
mThreadDstPartialOffsets, math::plus<index_t>{}, static_cast<index_t>(0));
}
__device__ static constexpr auto GetRegisterBufferDescriptor()
{
constexpr auto repeat_lengths = SliceLengths{} / (SubLengths{} * ThreadClusterLengths{});
return make_ConstantTensorDescriptor_packed(SubLengths{} * repeat_lengths);
}
__device__ static constexpr index_t GetThreadBufferSize()
{
return GetRegisterBufferDescriptor().GetElementSpace();
}
template <typename TData>
__device__ void RunLoadThreadBuffer(const TData* __restrict__ p_src,
TData* __restrict__ p_buffer) const
{
constexpr auto thread_sub_tensor_lengths = SubLengths{};
constexpr auto data_per_cluster_per_dims =
thread_sub_tensor_lengths * ThreadClusterLengths{};
constexpr auto repeat_lengths = SliceLengths{} / (SubLengths{} * ThreadClusterLengths{});
constexpr auto thread_buffer_desc = GetRegisterBufferDescriptor();
#if CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_BLOCKWISE_GENERIC_SLICE_COPY_V1
static_ford<decltype(repeat_lengths)>{}([&](auto repeat_id) {
constexpr auto src_thread_data_id_begin = repeat_id * data_per_cluster_per_dims;
constexpr auto buffer_data_id_begin = repeat_id * thread_sub_tensor_lengths;
constexpr index_t src_offset =
SrcDesc::GetOffsetFromMultiIndex(src_thread_data_id_begin);
constexpr index_t buffer_offset =
thread_buffer_desc.GetOffsetFromMultiIndex(buffer_data_id_begin);
#else
ford<decltype(repeat_lengths)>{}([&](auto repeat_id) {
const auto src_thread_data_id_begin = repeat_id * data_per_cluster_per_dims;
const auto buffer_data_id_begin = repeat_id * thread_sub_tensor_lengths;
const index_t src_offset = SrcDesc::GetOffsetFromMultiIndex(src_thread_data_id_begin);
const index_t buffer_offset =
thread_buffer_desc.GetOffsetFromMultiIndex(buffer_data_id_begin);
#endif
// Position the origin of the per-thread window at the point where the multi-index
// of the SrcDesc (which might be a merged tensor) is all-zero. This threadwise slice
// copy assumes each thread is copying a normal (not merged) tensor.
// To satisfy this assumption, the user needs to make sure that, on a merged dimension
// that contains multiple original dimensions, the length of the last original
// dimension is evenly divisible by its sub-length. Also, the repeat-length on
// the merged dimension needs to be 1. These sanity checks are performed in the
// constructor of BlockwiseGenericTensorSliceCopy_v1_deprecated.
ThreadwiseGenericTensorSliceCopy_v1r2_deprecated<SrcDesc,
decltype(thread_buffer_desc),
SubLengths,
SrcDimAccessOrder,
SrcVectorAccessDim,
SrcDataPerAccess,
1>(make_zero_array<index_t, nDim>(),
make_zero_array<index_t, nDim>())
.Run(p_src + src_offset + mThreadSrcOffset, p_buffer + buffer_offset);
});
}
template <typename TData>
__device__ void RunStoreThreadBuffer(const TData* __restrict__ p_buffer,
TData* __restrict__ p_dst) const
{
constexpr auto thread_sub_tensor_lengths = SubLengths{};
constexpr auto data_per_cluster_per_dims =
thread_sub_tensor_lengths * ThreadClusterLengths{};
constexpr auto repeat_lengths = SliceLengths{} / (SubLengths{} * ThreadClusterLengths{});
constexpr auto thread_buffer_desc = GetRegisterBufferDescriptor();
#if CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_BLOCKWISE_GENERIC_SLICE_COPY_V1
static_ford<decltype(repeat_lengths)>{}([&](auto repeat_id) {
constexpr auto buffer_data_id_begin = repeat_id * thread_sub_tensor_lengths;
constexpr auto dst_data_id_begin = repeat_id * data_per_cluster_per_dims;
constexpr index_t buffer_offset =
thread_buffer_desc.GetOffsetFromMultiIndex(buffer_data_id_begin);
constexpr index_t dst_offset = DstDesc::GetOffsetFromMultiIndex(dst_data_id_begin);
#else
ford<decltype(repeat_lengths)>{}([&](auto repeat_id) {
const auto buffer_data_id_begin = repeat_id * thread_sub_tensor_lengths;
const auto dst_data_id_begin = repeat_id * data_per_cluster_per_dims;
const index_t buffer_offset =
thread_buffer_desc.GetOffsetFromMultiIndex(buffer_data_id_begin);
const index_t dst_offset = DstDesc::GetOffsetFromMultiIndex(dst_data_id_begin);
#endif
// Position the origin of the per-thread window at the point where the multi-index
// of the SrcDesc (might be a merged tensor) is all-zero. This threadwise slice
// copy assumes each thread is copying a normal (not merged) tensor.
// To satisfy this assumption, the user needs to make sure that, on a merged dimension
// that contains multiple original dimensions, the length of the last original
// dimension is evenly divisible by its sub-length. Also, the repeat-length on
// the merged dimension needs to be 1. These sanity checks are performed in the
// constructor of BlockwiseGenericTensorSliceCopy_v1_deprecated.
ThreadwiseGenericTensorSliceCopy_v1r2_deprecated<decltype(thread_buffer_desc),
DstDesc,
SubLengths,
DstDimAccessOrder,
DstVectorAccessDim,
1,
DstDataPerAccess>(
make_zero_array<index_t, nDim>(), make_zero_array<index_t, nDim>())
.Run(p_buffer + buffer_offset, p_dst + dst_offset + mThreadDstOffset);
});
}
template <typename TData>
__device__ void Run(const TData* __restrict__ p_src, TData* __restrict__ p_dst) const
{
TData p_buffer[GetThreadBufferSize()];
RunLoadThreadBuffer(p_src, p_buffer);
RunStoreThreadBuffer(p_buffer, p_dst);
}
// When moving the slicing window along a merged dimension, if the strides of the
// original dimensions contained by the merged dimension are not in descending order,
// then there is no guarantee that the new offset will be larger than the old offset
// for movement in the positive direction (and vice versa for movement in the negative
// direction). As a result, there is the possibility that the offset calculation may
// result in unsigned integer underflow (due to the "-" operation). However, this hazard
// should not happen, as long as the user makes sure the slicing window is never moved
// out of the boundary of the tensor being sliced. This function doesn't do a runtime
// sanity check on out-of-bound slicing windows, for performance reasons.
template <index_t IDim_, index_t StepSize, bool PositiveDirection>
__device__ void MoveSlicingWindowOnSourceTensor(
Number<IDim_>, Number<StepSize>, integral_constant<bool, PositiveDirection> direction)
{
constexpr auto IDim = Number<IDim_>{};
static_if<SrcDesc::ContainMultipleOriginalDimensions(IDim)>{}([&](auto) {
// Logic for a merged dimension; it also works for a non-merged dimension, but may
// be unnecessarily complicated for the compiler to remove calculations that are
// useless for a non-merged dimension.
// extract partial original dimensions
constexpr auto src_partial_original_dims =
SrcDesc::GetContainedOriginalDimensions(IDim);
constexpr auto src_partial_original_desc =
SrcDesc::GetOriginalTensorDescriptor().Extract(src_partial_original_dims);
// calculate new partial original multi-id
auto old_src_partial_original_id =
extract_array(mThreadSrcOriginalMultiId, src_partial_original_dims);
auto new_src_partial_original_id =
src_partial_original_desc.UpdateMultiIndexGivenStepSizeOf1dIndex(
old_src_partial_original_id, StepSize, direction);
// update "mThreadSrcOriginalMultiId"
static_for<0, decltype(src_partial_original_dims)::GetSize(), 1>{}([&](auto I) {
constexpr auto IDimOriginal = src_partial_original_dims[I];
mThreadSrcOriginalMultiId(IDimOriginal) = new_src_partial_original_id[I];
});
// calculate new partial offset on this merged dimension
const index_t old_src_partial_offset = mThreadSrcPartialOffsets[IDim];
const index_t new_src_partial_offset =
src_partial_original_desc.GetOffsetFromMultiIndex(new_src_partial_original_id);
// update "mThreadSrcPartialOffsets"
mThreadSrcPartialOffsets(IDim) = new_src_partial_offset;
// update "mThreadSrcOffset", do "+" before "-" to avoid underflow
mThreadSrcOffset = (mThreadSrcOffset + new_src_partial_offset) - old_src_partial_offset;
}).Else([&](auto) {
// Logic for a non-merged dimension. If you are never going to move the slicing window on
// a merged dimension, then "mThreadSrcOriginalMultiId" and "mThreadSrcPartialOffsets",
// which are being calculated here, will never be used later. In this case, the compiler
// should be able to remove these calculations.
// TODO: make sure the compiler would actually remove them in this case.
// It is the user's responsibility to make sure the slicing window will not be moved out
// of the boundary of the tensor being sliced. Otherwise, there might be hazards like
// unsigned integer underflow. There is NO runtime sanity check to prevent this hazard.
constexpr auto IDimOriginal = SrcDesc::GetContainedOriginalDimensions(IDim).Front();
static_if<PositiveDirection>{}([&](auto fwd) {
mThreadSrcOffset += StepSize * fwd(SrcDesc{}).GetStride(IDim);
mThreadSrcOriginalMultiId(IDimOriginal) += StepSize;
mThreadSrcPartialOffsets(IDim) += StepSize * fwd(SrcDesc{}).GetStride(IDim);
}).Else([&](auto fwd) {
mThreadSrcOffset -= StepSize * fwd(SrcDesc{}).GetStride(IDim);
mThreadSrcOriginalMultiId(IDimOriginal) -= StepSize;
mThreadSrcPartialOffsets(IDim) -= StepSize * fwd(SrcDesc{}).GetStride(IDim);
});
});
}
template <typename T, bool PositiveDirection>
__device__ void
MoveSrcSliceWindow(T step_sizes, integral_constant<bool, PositiveDirection> positive_direction)
{
static_for<0, nDim, 1>{}([&](auto idim) {
if(step_sizes[idim] != 0)
{
MoveSlicingWindowOnSourceTensor(idim, step_sizes[idim], positive_direction);
}
});
}
};
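// A worked example of the constructor's divide-work and merged-dimension checks above,
// with illustrative (assumed) numbers: take SliceLengths = {8, 128}, SubLengths = {4, 4}
// and ThreadClusterLengths = {2, 32}. Then data_per_cluster_per_dims = {8, 128},
// repeat_lengths = {1, 1}, and BlockSize must be 2 * 32 = 64. If dimension 1 is a merged
// dimension whose last original dimension has length 8, the sub-length 4 divides it
// evenly and repeat_lengths[1] == 1, so each thread's window never crosses the fold of
// the merged dimension and behaves like a normal (not merged) tensor, which is exactly
// what the per-thread copy assumes.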
// This version uses TensorCoordinate.
// Slice a (normal or merged) tensor, and copy it into another (normal or merged) tensor.
// The memory layout (ordering of dimensions) can be different between src and dst.
template <index_t BlockSize,
typename SrcDesc,
typename DstDesc,
typename SliceLengths,
typename SubLengths,
typename ThreadClusterLengths,
typename ThreadClusterArrangeOrder,
typename SrcDimAccessOrder,
typename DstDimAccessOrder,
index_t SrcVectorAccessDim,
index_t DstVectorAccessDim,
index_t SrcDataPerAccess,
index_t DstDataPerAccess>
struct BlockwiseGenericTensorSliceCopy_v2_deprecated
{
static constexpr index_t nDim = SrcDesc::GetNumOfDimension();
using Index = MultiIndex<nDim>;
__device__ constexpr BlockwiseGenericTensorSliceCopy_v2_deprecated(
const Index& src_block_slice_origin, const Index& dst_block_slice_origin)
{
static_assert(
nDim == SrcDesc::GetNumOfDimension() && nDim == DstDesc::GetNumOfDimension() &&
nDim == SliceLengths::GetSize() && nDim == SubLengths::GetSize() &&
nDim == ThreadClusterLengths::GetSize() &&
nDim == ThreadClusterArrangeOrder::GetSize() &&
nDim == SrcDimAccessOrder::GetSize() && nDim == DstDimAccessOrder::GetSize(),
"wrong! nDim not consistent");
static_assert(is_same<SliceLengths, decltype(SubLengths{} * ThreadClusterLengths{})>{},
"wrong! threads should be mapped to cover entire slicing window");
constexpr auto thread_cluster_desc = make_ConstantTensorDescriptor_packed(
ThreadClusterLengths::ReorderGivenNew2Old(ThreadClusterArrangeOrder{}));
static_assert(BlockSize == thread_cluster_desc.GetElementSize(),
"wrong! BlockSize not consistent with ThreadClusterLengths");
const auto thread_cluster_id =
thread_cluster_desc.GetMultiIndexFrom1dIndex(get_thread_local_1d_id());
const auto data_cluster_id =
reorder_array_given_old2new(thread_cluster_id, ThreadClusterArrangeOrder{});
const auto thread_data_id_begin = data_cluster_id * SubLengths{};
mThreadwiseLoad.SetSrcSliceOrigin(src_block_slice_origin + thread_data_id_begin);
mThreadwiseLoad.SetDstSliceOrigin(make_zero_array<index_t, nDim>());
mThreadwiseStore.SetSrcSliceOrigin(make_zero_array<index_t, nDim>());
mThreadwiseStore.SetDstSliceOrigin(dst_block_slice_origin + thread_data_id_begin);
}
__device__ static constexpr index_t GetThreadBufferSize()
{
return ThreadBufferDesc::GetElementSpace();
}
template <typename BlockSrcData,
typename ThreadBufferData,
AddressSpace BlockSrcAddressSpace,
AddressSpace ThreadBufferAddressSpace>
__device__ void
RunLoadThreadBuffer(const BlockSrcData* p_block_src,
ThreadBufferData* p_thread_buffer,
integral_constant<AddressSpace, BlockSrcAddressSpace>,
integral_constant<AddressSpace, ThreadBufferAddressSpace>) const
{
constexpr auto block_src_address_space =
integral_constant<AddressSpace, BlockSrcAddressSpace>{};
constexpr auto thread_buffer_address_space =
integral_constant<AddressSpace, ThreadBufferAddressSpace>{};
mThreadwiseLoad.Run(
p_block_src, p_thread_buffer, block_src_address_space, thread_buffer_address_space);
}
template <typename BlockSrcData, typename ThreadBufferData>
__device__ void RunLoadThreadBuffer(const BlockSrcData* p_block_src,
ThreadBufferData* p_thread_buffer) const
{
constexpr auto generic_address_space =
integral_constant<AddressSpace, AddressSpace::Generic>{};
RunLoadThreadBuffer(
p_block_src, p_thread_buffer, generic_address_space, generic_address_space);
}
template <typename ThreadBufferData,
typename BlockDstData,
AddressSpace ThreadBufferAddressSpace,
AddressSpace BlockDstAddressSpace>
__device__ void
RunStoreThreadBuffer(const ThreadBufferData* p_thread_buffer,
BlockDstData* p_block_dst,
integral_constant<AddressSpace, ThreadBufferAddressSpace>,
integral_constant<AddressSpace, BlockDstAddressSpace>) const
{
constexpr auto thread_buffer_address_space =
integral_constant<AddressSpace, ThreadBufferAddressSpace>{};
constexpr auto block_dst_address_space =
integral_constant<AddressSpace, BlockDstAddressSpace>{};
mThreadwiseStore.Run(
p_thread_buffer, p_block_dst, thread_buffer_address_space, block_dst_address_space);
}
template <typename ThreadBufferData, typename BlockDstData>
__device__ void RunStoreThreadBuffer(const ThreadBufferData* p_thread_buffer,
BlockDstData* p_block_dst) const
{
constexpr auto generic_address_space =
integral_constant<AddressSpace, AddressSpace::Generic>{};
RunStoreThreadBuffer(
p_thread_buffer, p_block_dst, generic_address_space, generic_address_space);
}
template <typename BlockSrcData,
typename BlockDstData,
AddressSpace BlockSrcAddressSpace,
AddressSpace BlockDstAddressSpace>
__device__ void
Run(const BlockSrcData* p_block_src,
BlockDstData* p_block_dst,
integral_constant<AddressSpace, BlockSrcAddressSpace> block_src_address_space,
integral_constant<AddressSpace, BlockDstAddressSpace> block_dst_address_space) const
{
BlockSrcData p_thread_buffer[GetThreadBufferSize()];
constexpr auto generic_address_space =
integral_constant<AddressSpace, AddressSpace::Generic>{};
RunLoadThreadBuffer(
p_block_src, p_thread_buffer, block_src_address_space, generic_address_space);
// if there is type conversion, it's done during store
RunStoreThreadBuffer(
p_thread_buffer, p_block_dst, generic_address_space, block_dst_address_space);
}
template <typename BlockSrcData, typename BlockDstData>
__device__ void Run(const BlockSrcData* p_block_src, BlockDstData* p_block_dst) const
{
constexpr auto generic_address_space =
integral_constant<AddressSpace, AddressSpace::Generic>{};
Run(p_block_src, p_block_dst, generic_address_space, generic_address_space);
}
template <typename T, bool PositiveDirection>
__device__ void
MoveSrcSliceWindow(T step_sizes, integral_constant<bool, PositiveDirection> positive_direction)
{
mThreadwiseLoad.MoveSrcSliceWindow(step_sizes, positive_direction);
}
template <typename T, bool PositiveDirection>
__device__ void
MoveDstSliceWindow(T step_sizes, integral_constant<bool, PositiveDirection> positive_direction)
{
mThreadwiseStore.MoveDstSliceWindow(step_sizes, positive_direction);
}
private:
using ThreadBufferDesc = decltype(make_ConstantTensorDescriptor_packed(SubLengths{}));
using ThreadwiseLoad = ThreadwiseGenericTensorSliceCopy_v2r1_deprecated<SrcDesc,
ThreadBufferDesc,
SubLengths,
SrcDimAccessOrder,
SrcDimAccessOrder,
SrcVectorAccessDim,
SrcVectorAccessDim,
SrcDataPerAccess,
1>;
using ThreadwiseStore = ThreadwiseGenericTensorSliceCopy_v2r1_deprecated<ThreadBufferDesc,
DstDesc,
SubLengths,
DstDimAccessOrder,
DstDimAccessOrder,
DstVectorAccessDim,
DstVectorAccessDim,
1,
DstDataPerAccess>;
ThreadwiseLoad mThreadwiseLoad;
ThreadwiseStore mThreadwiseStore;
};
} // namespace ck
#endif
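When BlockwiseGenericTensorSliceCopy_v1_deprecated moves its slicing window on a merged dimension, it recomputes that dimension's partial offset and folds it into the total offset doing "+" before "-", as noted in the comments above, so the unsigned intermediate never underflows. A minimal sketch of that update with illustrative numbers (not taken from this commit):

#include <cstdio>

int main()
{
    using index_t = unsigned;

    // Illustrative numbers: the window moves so that this merged dimension's
    // partial offset shrinks from 40 to 16, while the other dimensions
    // contribute 100 to the total offset.
    index_t total_offset      = 100 + 40;
    const index_t old_partial = 40;
    const index_t new_partial = 16;

    // "new - old" alone underflows for an unsigned index type ...
    const index_t delta = new_partial - old_partial;
    std::printf("naive delta underflows to %u\n", delta);

    // ... so the code adds first and subtracts afterwards:
    total_offset = (total_offset + new_partial) - old_partial;
    std::printf("updated total offset = %u (expected 116)\n", total_offset);
}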
...@@ -186,7 +186,6 @@ struct GridwiseGemmTransposedANormalBNormalC_v1 ...@@ -186,7 +186,6 @@ struct GridwiseGemmTransposedANormalBNormalC_v1
"wrong!"); "wrong!");
constexpr index_t GemmMRepeat = MPerBlock / (MPerThread * MLevel0Cluster * MLevel1Cluster); constexpr index_t GemmMRepeat = MPerBlock / (MPerThread * MLevel0Cluster * MLevel1Cluster);
constexpr index_t GemmNRepeat = NPerBlock / (NPerThread * NLevel0Cluster * NLevel1Cluster); constexpr index_t GemmNRepeat = NPerBlock / (NPerThread * NLevel0Cluster * NLevel1Cluster);
// c_thread_mtx definition: this is a mess // c_thread_mtx definition: this is a mess
...@@ -201,11 +200,11 @@ struct GridwiseGemmTransposedANormalBNormalC_v1 ...@@ -201,11 +200,11 @@ struct GridwiseGemmTransposedANormalBNormalC_v1
decltype(c_m0m1_n0n1_thread_mtx_desc), decltype(c_m0m1_n0n1_thread_mtx_desc),
MPerThread, MPerThread,
NPerThread, NPerThread,
KPerThread,
MLevel0Cluster, MLevel0Cluster,
NLevel0Cluster, NLevel0Cluster,
MLevel1Cluster, MLevel1Cluster,
NLevel1Cluster, NLevel1Cluster,
KPerThread,
ThreadGemmAThreadCopySrcDataPerRead_M, ThreadGemmAThreadCopySrcDataPerRead_M,
ThreadGemmBThreadCopySrcDataPerRead_N>{}; ThreadGemmBThreadCopySrcDataPerRead_N>{};
......
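The hunk above (and the tensor-contraction kernel below) derive the per-thread repeat counts as GemmMRepeat = MPerBlock / (MPerThread * MLevel0Cluster * MLevel1Cluster), and likewise for N. A small sketch with assumed tile sizes, which are illustrative rather than this commit's tuning parameters:

#include <cstdio>

int main()
{
    // Illustrative tile sizes (not taken from this commit's tuning parameters).
    constexpr int MPerBlock = 128, MPerThread = 4;
    constexpr int MLevel0Cluster = 4, MLevel1Cluster = 4;

    // Threads along M cover MPerThread * MLevel0Cluster * MLevel1Cluster = 64
    // rows per pass, so each thread repeats its M tile MPerBlock / 64 = 2 times.
    static_assert(MPerBlock % (MPerThread * MLevel0Cluster * MLevel1Cluster) == 0,
                  "tile must divide evenly");
    constexpr int GemmMRepeat = MPerBlock / (MPerThread * MLevel0Cluster * MLevel1Cluster);

    std::printf("GemmMRepeat = %d\n", GemmMRepeat);
}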
#ifndef CK_GRIDWISE_TENSOR_CONTRACTION_HPP
#define CK_GRIDWISE_TENSOR_CONTRACTION_HPP
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "blockwise_generic_tensor_slice_copy.hpp"
#include "threadwise_generic_tensor_slice_copy.hpp"
#include "blockwise_gemm.hpp"
namespace ck {
template <index_t GridSize,
index_t BlockSize,
typename Float,
typename AccFloat,
typename AGlobalDesc,
typename BGlobalDesc,
typename CGlobalDesc,
typename CBlockLengths,
index_t KPerBlock,
InMemoryDataOperation CGlobalMemoryDataOperation>
struct GridwiseTensorContraction_v1
{
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() {}
__device__ void Run(const Float* __restrict__ p_a_global,
const Float* __restrict__ p_b_global,
Float* __restrict__ p_c_global,
Float* __restrict__ p_shared_block) const
{
/// \todo sanity-check on AGlobalDesc, BGlobalDesc, CGlobalDesc length consistency
/// \todo sanity-check on CBlockLengths
constexpr auto True = integral_constant<bool, true>{};
constexpr auto a_global_desc = AGlobalDesc{};
constexpr auto b_global_desc = BGlobalDesc{};
constexpr auto c_global_desc = CGlobalDesc{};
constexpr auto K = a_global_desc.GetLengths()[0];
// don't do anything if K == 0
if(K == 0)
{
return;
}
// lds max alignment
constexpr index_t max_lds_align = math::lcm(ABlockCopyDstDataPerWrite_M,
BBlockCopyDstDataPerWrite_N,
ThreadGemmAThreadCopySrcDataPerRead_M,
ThreadGemmBThreadCopySrcDataPerRead_N);
// divide block work by [M, N]
static_assert(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0,
"wrong! cannot divide work evenly among block");
constexpr index_t MBlockWork = M / MPerBlock;
constexpr index_t NBlockWork = N / NPerBlock;
constexpr auto block_work_desc =
make_cluster_descriptor(Sequence<MBlockWork, NBlockWork>{});
const auto block_work_id = block_work_desc.CalculateClusterIndex(get_block_1d_id());
const index_t m_block_data_on_global = block_work_id[0] * MPerBlock;
const index_t n_block_data_on_global = block_work_id[1] * NPerBlock;
// A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto a_k_m_block_desc = make_native_tensor_descriptor_aligned(
Sequence<KPerBlock, MPerBlock>{}, Number<max_lds_align>{});
// A matrix blockwise copy
auto a_blockwise_copy =
BlockwiseGenericTensorSliceCopy_v4<BlockSize,
AGlobalDesc,
decltype(a_block_desc),
decltype(a_k_m_block_desc.GetLengths()),
ABlockCopyThreadSliceLengths_K_M,
ABlockCopyThreadClusterLengths_K_M,
ABlockCopyThreadClusterArrangeOrder,
ABlockCopySrcAccessOrder,
Sequence<0, 1>,
ABlockCopySrcVectorReadDim,
1,
ABlockCopySrcDataPerRead,
ABlockCopyDstDataPerWrite_M,
AddressSpace::Global,
AddressSpace::Vgpr,
AddressSpace::Lds,
InMemoryDataOperation::Set>(
{0, m_block_data_on_global}, {0, 0});
// B matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto b_k_n_block_desc = make_native_tensor_descriptor_aligned(
Sequence<KPerBlock, NPerBlock>{}, Number<max_lds_align>{});
// B matrix blockwise copy
auto b_blockwise_copy =
BlockwiseGenericTensorSliceCopy_v4<BlockSize,
decltype(b_k_n_global_desc),
decltype(b_k_n_block_desc),
decltype(b_k_n_block_desc.GetLengths()),
BBlockCopyThreadSliceLengths_K_N,
BBlockCopyThreadClusterLengths_K_N,
BBlockCopyThreadClusterArrangeOrder,
BBlockCopySrcAccessOrder,
Sequence<0, 1>,
BBlockCopySrcVectorReadDim,
1,
BBlockCopySrcDataPerRead,
BBlockCopyDstDataPerWrite_N,
AddressSpace::Global,
AddressSpace::Vgpr,
AddressSpace::Lds,
InMemoryDataOperation::Set>(
{0, n_block_data_on_global}, {0, 0});
// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
// a_mtx[KPerBlock, MPerBlock] is in LDS
// b_mtx[KPerBlock, NPerBlock] is in LDS
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// registers
constexpr auto a_k_m_block_mtx_desc = make_ConstantMatrixDescriptor(a_k_m_block_desc);
constexpr auto b_k_n_block_mtx_desc = make_ConstantMatrixDescriptor(b_k_n_block_desc);
// sanity check
static_assert(MPerBlock % (MPerThread * MLevel0Cluster * MLevel1Cluster) == 0 &&
NPerBlock % (NPerThread * NLevel0Cluster * NLevel1Cluster) == 0,
"wrong!");
constexpr index_t GemmMRepeat = MPerBlock / (MPerThread * MLevel0Cluster * MLevel1Cluster);
constexpr index_t GemmNRepeat = NPerBlock / (NPerThread * NLevel0Cluster * NLevel1Cluster);
// c_thread_mtx definition: this is a mess
// TODO: more elegant way of defining c_thread_mtx
constexpr auto c_m0m1_n0n1_thread_mtx_desc = make_ConstantMatrixDescriptor_packed(
Number<GemmMRepeat * MPerThread>{}, Number<GemmNRepeat * NPerThread>{});
const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2<
BlockSize,
decltype(a_k_m_block_mtx_desc),
decltype(b_k_n_block_mtx_desc),
decltype(c_m0m1_n0n1_thread_mtx_desc),
MPerThread,
NPerThread,
MLevel0Cluster,
NLevel0Cluster,
MLevel1Cluster,
NLevel1Cluster,
KPerThread,
ThreadGemmAThreadCopySrcDataPerRead_M,
ThreadGemmBThreadCopySrcDataPerRead_N>{};
// LDS allocation for A and B: be careful of alignment
constexpr index_t a_block_space =
math::integer_least_multiple(a_k_m_block_desc.GetElementSpace(), max_lds_align);
constexpr index_t b_block_space =
math::integer_least_multiple(b_k_n_block_desc.GetElementSpace(), max_lds_align);
Float* p_a_block_double = p_shared_block;
Float* p_b_block_double = p_shared_block + 2 * a_block_space;
// register allocation for output
AccFloat p_c_thread[c_m0m1_n0n1_thread_mtx_desc.GetElementSpace()];
// zero out threadwise output
threadwise_matrix_set_zero(c_m0m1_n0n1_thread_mtx_desc, p_c_thread);
// LDS double buffer: preload data into LDS
{
a_blockwise_copy.Run(p_a_global, p_a_block_double);
b_blockwise_copy.Run(p_b_global, p_b_block_double);
}
constexpr auto a_block_slice_copy_steps = Sequence<KPerBlock, 0>{};
constexpr auto b_block_slice_copy_steps = Sequence<KPerBlock, 0>{};
// LDS double buffer: main body
for(index_t k_block_data_begin = 0; k_block_data_begin + 2 * KPerBlock < K;
k_block_data_begin += 2 * KPerBlock)
{
#pragma unroll
for(index_t iloop = 0; iloop < 2; ++iloop)
{
const bool even_loop = (iloop % 2 == 0);
Float* p_a_block_now =
even_loop ? p_a_block_double : p_a_block_double + a_block_space;
Float* p_b_block_now =
even_loop ? p_b_block_double : p_b_block_double + b_block_space;
Float* p_a_block_next =
even_loop ? p_a_block_double + a_block_space : p_a_block_double;
Float* p_b_block_next =
even_loop ? p_b_block_double + b_block_space : p_b_block_double;
Float p_a_thread_buffer[a_blockwise_copy.GetThreadBufferSize()];
Float p_b_thread_buffer[b_blockwise_copy.GetThreadBufferSize()];
a_blockwise_copy.MoveSrcSliceWindow(a_block_slice_copy_steps, True);
b_blockwise_copy.MoveSrcSliceWindow(b_block_slice_copy_steps, True);
__syncthreads();
// LDS double buffer: load next data from device mem
a_blockwise_copy.RunLoadThreadBuffer(p_a_global, p_a_thread_buffer);
b_blockwise_copy.RunLoadThreadBuffer(p_b_global, p_b_thread_buffer);
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(p_a_block_now, p_b_block_now, p_c_thread);
// LDS double buffer: store next data to LDS
a_blockwise_copy.RunStoreThreadBuffer(p_a_thread_buffer, p_a_block_next);
b_blockwise_copy.RunStoreThreadBuffer(p_b_thread_buffer, p_b_block_next);
}
}
// LDS double buffer: tail
{
constexpr bool has_two_iteration_left = (K % (2 * KPerBlock) == 0);
if(has_two_iteration_left) // if there are 2 iterations left
{
Float p_a_thread_buffer[a_blockwise_copy.GetThreadBufferSize()];
Float p_b_thread_buffer[b_blockwise_copy.GetThreadBufferSize()];
a_blockwise_copy.MoveSrcSliceWindow(a_block_slice_copy_steps, True);
b_blockwise_copy.MoveSrcSliceWindow(b_block_slice_copy_steps, True);
__syncthreads();
// LDS double buffer: load last data from device mem
a_blockwise_copy.RunLoadThreadBuffer(p_a_global, p_a_thread_buffer);
b_blockwise_copy.RunLoadThreadBuffer(p_b_global, p_b_thread_buffer);
// LDS double buffer: GEMM on 2nd-last data
blockwise_gemm.Run(p_a_block_double, p_b_block_double, p_c_thread);
// LDS double buffer: store last data to LDS
a_blockwise_copy.RunStoreThreadBuffer(p_a_thread_buffer,
p_a_block_double + a_block_space);
b_blockwise_copy.RunStoreThreadBuffer(p_b_thread_buffer,
p_b_block_double + b_block_space);
__syncthreads();
// LDS double buffer: GEMM on last data
blockwise_gemm.Run(
p_a_block_double + a_block_space, p_b_block_double + b_block_space, p_c_thread);
}
else // if there is 1 iteration left
{
__syncthreads();
// LDS double buffer: GEMM on last data
blockwise_gemm.Run(p_a_block_double, p_b_block_double, p_c_thread);
}
}
// copy output: register to global memory
{
constexpr index_t M1 = MPerThread * MLevel0Cluster * MLevel1Cluster;
constexpr index_t M0 = M / M1;
constexpr index_t N1 = NPerThread * NLevel0Cluster * NLevel1Cluster;
constexpr index_t N0 = N / N1;
// define input tensor descriptor for threadwise copy
// thread input tensor, src of threadwise copy
constexpr auto c_m0_m1_n0_n1_thread_desc = make_native_tensor_descriptor_packed(
Sequence<GemmMRepeat, MPerThread, GemmNRepeat, NPerThread>{});
constexpr auto c_m0_m1_n0_n1_global_desc = transform_tensor_descriptor(
c_m_n_global_desc,
make_tuple(UnMerge<Sequence<M0, M1>>{}, UnMerge<Sequence<N0, N1>>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
// calculate origin of thread input tensor on global memory
// blockwise GEMM c matrix starting index
const auto c_thread_mtx_on_block =
blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
const index_t m_thread_data_on_global =
m_block_data_on_global + c_thread_mtx_on_block.row;
const index_t n_thread_data_on_global =
n_block_data_on_global + c_thread_mtx_on_block.col;
ThreadwiseGenericTensorSliceCopy_v4r2<decltype(c_m0_m1_n0_n1_thread_desc),
decltype(c_m0_m1_n0_n1_global_desc),
decltype(c_m0_m1_n0_n1_thread_desc.GetLengths()),
CThreadCopySrcDstAccessOrder,
CThreadCopySrcDstVectorReadWriteDim,
1,
CThreadCopyDstDataPerWrite,
AddressSpace::Vgpr,
AddressSpace::Global,
CGlobalMemoryDataOperation>(
{0, 0, 0, 0},
{m_thread_data_on_global / M1,
m_thread_data_on_global % M1,
n_thread_data_on_global / N1,
n_thread_data_on_global % N1})
.Run(p_c_thread, p_c_global);
}
}
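// A worked example of the output index split above, with illustrative numbers (not this
// commit's tuning parameters): if MPerThread = 4, MLevel0Cluster = 4 and
// MLevel1Cluster = 4, then M1 = 64. A thread whose first output row sits at
// m_thread_data_on_global = 70 therefore starts writing at
// (m0, m1) = (70 / 64, 70 % 64) = (1, 6) in the [M0, M1, N0, N1] view of the C tensor;
// the N index splits the same way with N1 = NPerThread * NLevel0Cluster * NLevel1Cluster.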
__device__ void Run(const Float* __restrict__ p_a_global,
const Float* __restrict__ p_b_global,
Float* __restrict__ p_c_global) const
{
constexpr index_t shared_block_size = GetSharedMemoryNumberOfByte() / sizeof(Float);
__shared__ Float p_shared_block[shared_block_size];
Run(p_a_global, p_b_global, p_c_global, p_shared_block);
}
};
} // namespace ck
#endif
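GridwiseTensorContraction_v1 above preloads one KPerBlock chunk into LDS and then ping-pongs between the two halves of the shared block, so the GEMM on the current chunk overlaps with fetching the next one; the tail handles the last one or two chunks. A minimal host-side sketch of just that control flow, with assumed sizes (K = 40, KPerBlock = 8), not a device implementation:

#include <cstdio>

int main()
{
    constexpr int K = 40, KPerBlock = 8; // illustrative sizes, K % KPerBlock == 0

    int buf      = 0;         // 0 = first half of LDS, 1 = second half
    int k_loaded = KPerBlock; // the first chunk is preloaded
    std::printf("preload chunk [0, %d) into buffer 0\n", KPerBlock);

    // main body: two chunks per iteration, alternating buffers
    for(int k = 0; k + 2 * KPerBlock < K; k += 2 * KPerBlock)
    {
        for(int i = 0; i < 2; ++i)
        {
            std::printf("load chunk [%d, %d) into buffer %d while GEMM runs on buffer %d\n",
                        k_loaded, k_loaded + KPerBlock, 1 - buf, buf);
            k_loaded += KPerBlock;
            buf = 1 - buf;
        }
    }

    // tail: either two chunks or one chunk remain
    if(K % (2 * KPerBlock) == 0)
    {
        std::printf("load chunk [%d, %d) into buffer %d, GEMM on buffer %d, then GEMM on buffer %d\n",
                    k_loaded, k_loaded + KPerBlock, 1 - buf, buf, 1 - buf);
    }
    else
    {
        std::printf("GEMM on the single remaining buffer %d\n", buf);
    }
}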
...@@ -23,7 +23,9 @@ template <typename SrcDesc, ...@@ -23,7 +23,9 @@ template <typename SrcDesc,
index_t DstDataPerWrite, index_t DstDataPerWrite,
AddressSpace SrcAddressSpace = AddressSpace::Generic, AddressSpace SrcAddressSpace = AddressSpace::Generic,
AddressSpace DstAddressSpace = AddressSpace::Generic, AddressSpace DstAddressSpace = AddressSpace::Generic,
InMemoryDataOperation DstInMemOp = InMemoryDataOperation::Set> InMemoryDataOperation DstInMemOp = InMemoryDataOperation::Set,
index_t SrcDataStride = 1,
index_t DstDataStride = 1>
struct ThreadwiseGenericTensorSliceCopy_v4r2 struct ThreadwiseGenericTensorSliceCopy_v4r2
{ {
static constexpr index_t nDim = SliceLengths::Size(); static constexpr index_t nDim = SliceLengths::Size();
...@@ -116,7 +118,9 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2 ...@@ -116,7 +118,9 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
SrcDataPerRead, SrcDataPerRead,
SrcAddressSpace, SrcAddressSpace,
AddressSpace::Vgpr, AddressSpace::Vgpr,
InMemoryDataOperation::Set>( InMemoryDataOperation::Set,
SrcDataStride,
1>(
p_src, src_coord.GetOffset(), p_src_long_vector, buffer_offset); p_src, src_coord.GetOffset(), p_src_long_vector, buffer_offset);
} }
} }
...@@ -148,7 +152,9 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2 ...@@ -148,7 +152,9 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
DstDataPerWrite, DstDataPerWrite,
AddressSpace::Vgpr, AddressSpace::Vgpr,
DstAddressSpace, DstAddressSpace,
DstInMemOp>( DstInMemOp,
1,
DstDataStride>(
p_dst_long_vector, buffer_offset, p_dst, dst_coord.GetOffset()); p_dst_long_vector, buffer_offset, p_dst, dst_coord.GetOffset());
} }
} }
......
#ifndef CK_THREADWISE_GENERIC_TENSOR_SLICE_COPY_DEPRECATED_HPP
#define CK_THREADWISE_GENERIC_TENSOR_SLICE_COPY_DEPRECATED_HPP
#include "common_header.hpp"
#include "ConstantTensorDescriptor_deprecated.hpp"
#include "ConstantMergedTensorDescriptor_deprecated.hpp"
#include "tensor_coordinate_deprecated.hpp"
namespace ck {
// This threadwise copy allows vector access of src and dst.
// It allows the vector size to be different on src and dst.
// The dimension of vector access should be the same on src and dst.
// The dimension access order should be the same on src and dst.
// It is designed for cases where one of src and dst is a register buffer, and
// the other is device memory or LDS.
template <typename SrcDesc,
typename DstDesc,
typename SliceLengths,
typename DimAccessOrder,
index_t VectorAccessDim,
index_t SrcDataPerAccess,
index_t DstDataPerAccess>
struct ThreadwiseGenericTensorSliceCopy_v1r2_deprecated
{
static constexpr index_t nDim = SliceLengths::GetSize();
__device__ constexpr ThreadwiseGenericTensorSliceCopy_v1r2_deprecated(
Array<index_t, nDim> src_slice_origin, Array<index_t, nDim> dst_slice_origin)
: mSrcSliceOrigin(src_slice_origin), mDstSliceOrigin(dst_slice_origin)
{
static_assert(nDim == SrcDesc::GetNumOfDimension() &&
nDim == DstDesc::GetNumOfDimension() && nDim == SliceLengths::GetSize() &&
nDim == DimAccessOrder::GetSize(),
"wrong! # of dimensions not the same");
static_assert(is_valid_sequence_map<DimAccessOrder>::value, "wrong! map is not valid");
static_assert(
SliceLengths{}[VectorAccessDim] % math::lcm(SrcDataPerAccess, DstDataPerAccess) == 0,
"wrong! cannot evenly divide");
// check vectorized memory access
constexpr auto vector_access_dim = Number<VectorAccessDim>{};
static_if<!SrcDesc::ContainMultipleOriginalDimensions(vector_access_dim)>{}([&](auto fwd) {
static_assert(
(fwd(SrcDesc{}).GetStride(vector_access_dim) == 1 || SrcDataPerAccess == 1),
"wrong! vectorized access is allowed only if stride == 1");
}).Else([&](auto fwd) {
static_assert((fwd(SrcDesc{}).GetLastOriginalDimensionStride(vector_access_dim) == 1 ||
SrcDataPerAccess == 1),
"wrong! vectorized access is allowed only if stride == 1");
});
static_if<!DstDesc::ContainMultipleOriginalDimensions(vector_access_dim)>{}([&](auto fwd) {
static_assert(
(fwd(DstDesc{}).GetStride(vector_access_dim) == 1 || DstDataPerAccess == 1),
"wrong! vectorized access is allowed only if stride == 1");
}).Else([&](auto fwd) {
static_assert((fwd(DstDesc{}).GetLastOriginalDimensionStride(vector_access_dim) == 1 ||
DstDataPerAccess == 1),
"wrong! vectorized access is allowed only if stride == 1");
});
}
__device__ constexpr ThreadwiseGenericTensorSliceCopy_v1r2_deprecated()
: ThreadwiseGenericTensorSliceCopy_v1r2_deprecated(make_zero_array<index_t, nDim>(),
make_zero_array<index_t, nDim>())
{
}
__device__ void SetSrcSliceOrigin(Array<index_t, nDim> src_slice_origin)
{
mSrcSliceOrigin = src_slice_origin;
}
__device__ void SetDstSliceOrigin(Array<index_t, nDim> dst_slice_origin)
{
mDstSliceOrigin = dst_slice_origin;
}
template <class SrcData, class DstData>
__device__ void Run(const SrcData* p_src, DstData* p_dst) const
{
using src_vector_t = typename vector_type<SrcData, SrcDataPerAccess>::MemoryType;
using dst_vector_t = typename vector_type<DstData, DstDataPerAccess>::MemoryType;
constexpr auto vector_access_dim = Number<VectorAccessDim>{};
constexpr auto src_data_per_access = Number<SrcDataPerAccess>{};
constexpr auto dst_data_per_access = Number<DstDataPerAccess>{};
constexpr auto long_vector_size = Number<math::lcm(SrcDataPerAccess, DstDataPerAccess)>{};
constexpr auto long_vector_access_lengths = SliceLengths::Modify(
vector_access_dim, SliceLengths::Get(vector_access_dim) / long_vector_size);
ford<decltype(long_vector_access_lengths), DimAccessOrder>{}(
[&](auto long_vector_access_id) {
// data id w.r.t slicing-window
auto long_vector_data_begin_id = long_vector_access_id;
long_vector_data_begin_id(vector_access_dim) =
long_vector_size * long_vector_access_id[vector_access_dim];
// buffer to hold a long-vector
SrcData p_src_long_vector[long_vector_size];
DstData p_dst_long_vector[long_vector_size];
// load data from src to the long-vector buffer
for(index_t i = 0; i < long_vector_size / src_data_per_access; ++i)
{
auto scalar_id = make_zero_array<index_t, nDim>();
scalar_id(vector_access_dim) = i * src_data_per_access;
const index_t src_offset = SrcDesc::GetOffsetFromMultiIndex(
mSrcSliceOrigin + (long_vector_data_begin_id + scalar_id));
const index_t buffer_offset = i * src_data_per_access;
*reinterpret_cast<src_vector_t*>(&p_src_long_vector[buffer_offset]) =
*reinterpret_cast<const src_vector_t*>(&p_src[src_offset]);
}
// type conversion
for(index_t i = 0; i < long_vector_size; ++i)
{
p_dst_long_vector[i] = type_convert<DstData>{}(p_src_long_vector[i]);
}
// store data from the long-vector buffer to dst
for(index_t i = 0; i < long_vector_size / dst_data_per_access; ++i)
{
auto scalar_id = make_zero_array<index_t, nDim>();
scalar_id(vector_access_dim) = i * dst_data_per_access;
const index_t buffer_offset = i * dst_data_per_access;
const index_t dst_offset = DstDesc::GetOffsetFromMultiIndex(
mDstSliceOrigin + (long_vector_data_begin_id + scalar_id));
*reinterpret_cast<dst_vector_t*>(&p_dst[dst_offset]) =
*reinterpret_cast<dst_vector_t*>(&p_dst_long_vector[buffer_offset]);
}
});
}
private:
Array<index_t, nDim> mSrcSliceOrigin;
Array<index_t, nDim> mDstSliceOrigin;
};
// This version uses TensorCoordinate_deprecated.
// This threadwise copy allows vector access of src and dst.
// It allows the dimensions of vector access to be different on src and dst.
// It also allows the vector size to be different on src and dst.
// It also allows the order of access to be different on src and dst.
// It uses registers as a buffer to hold all data moving from src to dst.
// It is designed for copying a small amount of data, where src and dst are
// device memory or LDS.
// When copying a large amount of data, let's hope the compiler will reduce the
// registers used for the buffer.
template <typename SrcDesc,
typename DstDesc,
typename SliceLengths,
typename SrcDimAccessOrder,
typename DstDimAccessOrder,
index_t SrcVectorAccessDim,
index_t DstVectorAccessDim,
index_t SrcDataPerAccess,
index_t DstDataPerAccess>
struct ThreadwiseGenericTensorSliceCopy_v2r1_deprecated
{
static constexpr index_t nDim = SliceLengths::GetSize();
using Index = MultiIndex<nDim>;
using SrcCoordinate = typename TensorCoordinate_deprecated<SrcDesc>::type;
using DstCoordinate = typename TensorCoordinate_deprecated<DstDesc>::type;
__device__ constexpr ThreadwiseGenericTensorSliceCopy_v2r1_deprecated(
const Index& src_slice_origin, const Index& dst_slice_origin)
: mSrcSliceOrigin(src_slice_origin), mDstSliceOrigin(dst_slice_origin)
{
static_assert(nDim == SrcDesc::GetNumOfDimension() &&
nDim == DstDesc::GetNumOfDimension() && nDim == SliceLengths::GetSize() &&
nDim == SrcDimAccessOrder::GetSize() &&
nDim == DstDimAccessOrder::GetSize(),
"wrong! # of dimensions not the same");
static_assert(is_valid_sequence_map<SrcDimAccessOrder>::value &&
is_valid_sequence_map<DstDimAccessOrder>::value,
"wrong! map is not valid");
static_assert(SliceLengths{}[SrcVectorAccessDim] % SrcDataPerAccess == 0 &&
SliceLengths{}[DstVectorAccessDim] % DstDataPerAccess == 0,
"wrong! cannot evenly divide");
// check vectorized memory access
constexpr auto src_vector_access_dim = Number<SrcVectorAccessDim>{};
constexpr auto dst_vector_access_dim = Number<DstVectorAccessDim>{};
static_if<!SrcDesc::ContainMultipleOriginalDimensions(src_vector_access_dim)>{}(
[&](auto fwd) {
static_assert(
(fwd(SrcDesc{}).GetStride(src_vector_access_dim) == 1 || SrcDataPerAccess == 1),
"wrong! vectorized access is allowed only if stride == 1");
})
.Else([&](auto fwd) {
static_assert(
(fwd(SrcDesc{}).GetLastOriginalDimensionStride(src_vector_access_dim) == 1 ||
SrcDataPerAccess == 1),
"wrong! vectorized access is allowed only if stride == 1");
});
static_if<!DstDesc::ContainMultipleOriginalDimensions(dst_vector_access_dim)>{}(
[&](auto fwd) {
static_assert(
(fwd(DstDesc{}).GetStride(dst_vector_access_dim) == 1 || DstDataPerAccess == 1),
"wrong! vectorized access is allowed only if stride == 1");
})
.Else([&](auto fwd) {
static_assert(
(fwd(DstDesc{}).GetLastOriginalDimensionStride(dst_vector_access_dim) == 1 ||
DstDataPerAccess == 1),
"wrong! vectorized access is allowed only if stride == 1");
});
}
__device__ constexpr ThreadwiseGenericTensorSliceCopy_v2r1_deprecated()
: ThreadwiseGenericTensorSliceCopy_v2r1_deprecated(make_zero_array<index_t, nDim>(),
make_zero_array<index_t, nDim>())
{
}
__device__ void SetSrcSliceOrigin(SrcCoordinate src_slice_origin)
{
mSrcSliceOrigin = src_slice_origin;
}
__device__ void SetDstSliceOrigin(DstCoordinate dst_slice_origin)
{
mDstSliceOrigin = dst_slice_origin;
}
template <typename TDesc, class Lengths>
struct IsolateMergedDimLengths
{
template <typename IDim>
__device__ constexpr index_t operator()(IDim idim) const
{
return TDesc::ContainMultipleOriginalDimensions(idim) ? Lengths{}[idim] : 1;
}
};
template <typename SrcData,
typename DstData,
AddressSpace SrcAddressSpace,
AddressSpace DstAddressSpace>
__device__ void Run(const SrcData* p_src,
DstData* p_dst,
integral_constant<AddressSpace, SrcAddressSpace>,
integral_constant<AddressSpace, DstAddressSpace>) const
{
constexpr auto buffer_desc = make_ConstantTensorDescriptor_packed(SliceLengths{});
SrcData p_src_buffer_[buffer_desc.GetElementSpace()];
SrcData* p_src_buffer = p_src_buffer_;
// copy data from src into buffer
{
using src_vector_t = typename vector_type<SrcData, SrcDataPerAccess>::MemoryType;
constexpr auto src_vector_access_dim = Number<SrcVectorAccessDim>{};
constexpr auto src_data_per_access = Number<SrcDataPerAccess>{};
constexpr auto src_access_lengths = SliceLengths::Modify(
src_vector_access_dim,
SliceLengths::Get(src_vector_access_dim) / src_data_per_access);
// Offsets w.r.t. merged dimensions need to be calculated at run-time. Offsets w.r.t.
// normal dimensions are known at compile time.
// Below is a hack to isolate the merged dimension ids from the normal dimension ids, so
// the corresponding offsets can be calculated separately at run-time and compile-time.
// src_merged_dim_access_lengths has the same value as src_access_lengths on src's
// merged dimensions, and has value = 1 on normal dimensions;
// src_normal_dim_access_lengths has the same value as src_access_lengths on src's
// normal dimensions, and has value = 1 on merged dimensions.
constexpr auto src_merged_dim_access_lengths = typename sequence_gen<
nDim,
IsolateMergedDimLengths<SrcDesc, decltype(src_access_lengths)>>::type{};
constexpr auto src_normal_dim_access_lengths =
src_access_lengths + Number<1>{} - src_merged_dim_access_lengths;
ford<decltype(src_merged_dim_access_lengths), SrcDimAccessOrder>{}(
[&](auto src_merged_dim_access_id) {
auto src_merged_dim_data_id = src_merged_dim_access_id;
src_merged_dim_data_id(src_vector_access_dim) =
src_merged_dim_access_id[src_vector_access_dim] * src_data_per_access;
// offset w.r.t. merged dimensions needs to be computed at run-time
const index_t src_merged_offset =
(mSrcSliceOrigin + src_merged_dim_data_id).GetOffset();
ford<decltype(src_normal_dim_access_lengths), SrcDimAccessOrder>{}([&](
auto src_normal_dim_access_id) {
auto src_normal_dim_data_id = src_normal_dim_access_id;
src_normal_dim_data_id(src_vector_access_dim) =
src_normal_dim_access_id[src_vector_access_dim] * src_data_per_access;
// offset w.r.t. normal dimension is known at compile-time
const index_t src_normal_offset =
SrcDesc::GetOffsetFromMultiIndex(src_normal_dim_data_id);
src_vector_t vector_data;
// Read vector from src.
// 1. The source-code version can take src in all kinds of memory-space
// 2. The intrinsic version using buffer_load can only take
// src from global memory
//
// Comment for loading from global memory:
// When:
// 1) using source code, in order for the compiler to emit an optimal
// load instruction, or
// 2) using the buffer_load intrinsic, in order for the ISA to be valid,
// the following assumptions need to be satisfied:
// 1. p_src needs to be block-invariant (assumption)
// 2. src_normal_offset must be calculated at compile time (guaranteed by
// the algorithm)
// 3. src_merged_offset can be a runtime value (no assumption imposed)
static_if<SrcAddressSpace == AddressSpace::Global>{}([&](auto fwd) {
#if CK_USE_AMD_BUFFER_ADDRESSING
vector_data = amd_intrinsic_buffer_load<SrcData, SrcDataPerAccess>(
fwd(p_src), src_merged_offset, src_normal_offset);
#else
vector_data = *reinterpret_cast<const src_vector_t*>(
&p_src[src_normal_offset + src_merged_offset]);
#endif
}).Else([&](auto) {
// src can be all kinds of memory-space.
vector_data = *reinterpret_cast<const src_vector_t*>(
&p_src[src_normal_offset + src_merged_offset]);
});
// unpack vector into buffer
for(index_t i = 0; i < SrcDataPerAccess; ++i)
{
auto scalar_id = make_zero_array<index_t, nDim>();
scalar_id(src_vector_access_dim) = i;
const index_t buffer_offset = buffer_desc.GetOffsetFromMultiIndex(
src_merged_dim_data_id + src_normal_dim_data_id + scalar_id);
p_src_buffer[buffer_offset] =
reinterpret_cast<const SrcData*>(&vector_data)[i];
}
});
});
}
// type conversion
// TODO: would compiler do a good job reusing register for buffer?
DstData p_dst_buffer_[buffer_desc.GetElementSpace()];
DstData* p_dst_buffer = p_dst_buffer_;
ford<SliceLengths>{}([&](auto idx) {
p_dst_buffer[buffer_desc.GetOffsetFromMultiIndex(idx)] =
type_convert<DstData>{}(p_src_buffer[buffer_desc.GetOffsetFromMultiIndex(idx)]);
});
// copy data from buffer into dst
{
using dst_vector_t = typename vector_type<DstData, DstDataPerAccess>::MemoryType;
constexpr auto dst_vector_access_dim = Number<DstVectorAccessDim>{};
constexpr auto dst_data_per_access = Number<DstDataPerAccess>{};
constexpr auto dst_access_lengths = SliceLengths::Modify(
dst_vector_access_dim,
SliceLengths::Get(dst_vector_access_dim) / dst_data_per_access);
constexpr auto dst_merged_dim_access_lengths = typename sequence_gen<
nDim,
IsolateMergedDimLengths<DstDesc, decltype(dst_access_lengths)>>::type{};
constexpr auto dst_normal_dim_access_lengths =
dst_access_lengths + Number<1>{} - dst_merged_dim_access_lengths;
ford<decltype(dst_merged_dim_access_lengths), DstDimAccessOrder>{}([&](
auto dst_merged_dim_access_id) {
auto dst_merged_dim_data_id = dst_merged_dim_access_id;
dst_merged_dim_data_id(dst_vector_access_dim) =
dst_merged_dim_access_id[dst_vector_access_dim] * dst_data_per_access;
// offset w.r.t. merged dimensions needs to be computed at run-time
const index_t dst_merged_offset =
(mDstSliceOrigin + dst_merged_dim_data_id).GetOffset();
ford<decltype(dst_normal_dim_access_lengths), DstDimAccessOrder>{}([&](
auto dst_normal_dim_access_id) {
auto dst_normal_dim_data_id = dst_normal_dim_access_id;
dst_normal_dim_data_id(dst_vector_access_dim) =
dst_normal_dim_access_id[dst_vector_access_dim] * dst_data_per_access;
dst_vector_t vector_data;
// pack vector from buffer
for(index_t i = 0; i < DstDataPerAccess; ++i)
{
auto scalar_id = make_zero_array<index_t, nDim>();
scalar_id(dst_vector_access_dim) = i;
const index_t buffer_offset = buffer_desc.GetOffsetFromMultiIndex(
dst_merged_dim_data_id + dst_normal_dim_data_id + scalar_id);
reinterpret_cast<DstData*>(&vector_data)[i] = p_dst_buffer[buffer_offset];
}
// offset w.r.t. normal dimension is known at compile-time
const index_t dst_normal_offset =
DstDesc::GetOffsetFromMultiIndex(dst_normal_dim_data_id);
// Write vector into dst.
// 1. The source-code version can take dst from any memory space.
// 2. The intrinsic version using buffer_store can only take
//    dst from global memory.
//
// Comment for storing into global memory:
// When:
// 1) using source code, in order for the compiler to emit an optimal
//    store instruction, or
// 2) using the buffer_store intrinsic, in order for the ISA to be valid,
// the following assumptions need to be satisfied:
// 1. p_dst needs to be block-invariant (assumption)
// 2. dst_normal_offset must be calculated at compile time (guaranteed by
//    the algorithm)
// 3. dst_merged_offset can be a runtime value (no assumption imposed)
static_if<DstAddressSpace == AddressSpace::Global>{}([&](auto fwd) {
#if CK_USE_AMD_BUFFER_ADDRESSING
amd_intrinsic_buffer_store<DstData, DstDataPerAccess>(
vector_data, fwd(p_dst), dst_merged_offset, dst_normal_offset);
#else
*reinterpret_cast<dst_vector_t*>(
&p_dst[dst_normal_offset + dst_merged_offset]) = vector_data;
#endif
}).Else([&](auto) {
// dst can be in any memory space
*reinterpret_cast<dst_vector_t*>(
&p_dst[dst_normal_offset + dst_merged_offset]) = vector_data;
});
});
});
}
}
template <typename SrcData, typename DstData>
__device__ void Run(const SrcData* p_src, DstData* p_dst) const
{
constexpr auto generic_address_space =
integral_constant<AddressSpace, AddressSpace::Generic>{};
Run(p_src, p_dst, generic_address_space, generic_address_space);
}
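// Hedged usage sketch (editorial, not part of this header; "threadwise_copy",
// "p_in_global" and "p_in_buffer" are illustrative names): a caller that knows
// its source lives in global memory while its destination is generic can
// request the buffer_load path on the read side only:
//
//   constexpr auto global_address_space =
//       integral_constant<AddressSpace, AddressSpace::Global>{};
//   constexpr auto generic_address_space =
//       integral_constant<AddressSpace, AddressSpace::Generic>{};
//
//   threadwise_copy.Run(p_in_global, p_in_buffer, global_address_space, generic_address_space);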
// T can be Sequence or Array
template <typename T, bool PositiveDirection>
__device__ void MoveSrcSliceWindow(T step_sizes, integral_constant<bool, PositiveDirection>)
{
static_if<PositiveDirection>{}([&](auto) {
mSrcSliceOrigin += step_sizes;
}).Else([&](auto) { mSrcSliceOrigin -= step_sizes; });
}
template <typename T, bool PositiveDirection>
__device__ void MoveDstSliceWindow(T step_sizes, integral_constant<bool, PositiveDirection>)
{
static_if<PositiveDirection>{}([&](auto) {
mDstSliceOrigin += step_sizes;
}).Else([&](auto) { mDstSliceOrigin -= step_sizes; });
}
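// Hedged usage sketch (illustrative names; the 4-d step shape is an assumption
// and must match the rank of the source tensor descriptor): between main-loop
// iterations, the source slice window is typically advanced by a compile-time
// step along the reduction dimension, e.g.
//
//   constexpr auto True = integral_constant<bool, true>{};
//   blockwise_in_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0, 0, 0>{}, True);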
private:
SrcCoordinate mSrcSliceOrigin;
DstCoordinate mDstSliceOrigin;
};
} // namespace ck
#endif
...@@ -8,65 +8,149 @@ namespace ck { ...@@ -8,65 +8,149 @@ namespace ck {
// For 128bit SGPRs in buffer_load and buffer_store instructions // For 128bit SGPRs in buffer_load and buffer_store instructions
// https://rocm-documentation.readthedocs.io/en/latest/GCN_ISA_Manuals/testdocbook.html#vector-memory-buffer-instructions // https://rocm-documentation.readthedocs.io/en/latest/GCN_ISA_Manuals/testdocbook.html#vector-memory-buffer-instructions
template <typename T> template <typename T>
union BufferLoadStoreDwordConfig union BufferAddressConfig
{ {
int32x4_t data; int32x4_t data;
T* address[2]; T* address[2];
int32_t range[4]; int32_t range[4];
}; };
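// Hedged sketch of how this union is populated below (the field meanings follow
// the buffer resource descriptor in the linked ISA manual; the exact bits
// encoded by the constant 0x00027000 are an assumption of this note, not
// defined by this header):
//
//   BufferAddressConfig<float> cfg;
//   cfg.address[0] = p_block;    // dwords 0-1: 64-bit base address
//   cfg.range[2]   = -1;         // dword 2: number of records (maximum range)
//   cfg.range[3]   = 0x00027000; // dword 3: remaining descriptor/config bits
//   // cfg.data is the int32x4_t rsrc passed to __llvm_amdgcn_buffer_load_*/_store_*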
__device__ float __llvm_amdgcn_buffer_load(int32x4_t rsrc, __device__ float __llvm_amdgcn_buffer_load_f32(int32x4_t rsrc,
index_t vindex, index_t vindex,
index_t offset, index_t offset,
bool glc, bool glc,
bool slc) __asm("llvm.amdgcn.buffer.load.f32"); bool slc) __asm("llvm.amdgcn.buffer.load.f32");
__device__ float2_t __llvm_amdgcn_buffer_loadx2(int32x4_t rsrc, __device__ float2_t
__llvm_amdgcn_buffer_load_f32x2(int32x4_t rsrc,
index_t vindex,
index_t offset,
bool glc,
bool slc) __asm("llvm.amdgcn.buffer.load.v2f32");
__device__ float4_t
__llvm_amdgcn_buffer_load_f32x4(int32x4_t rsrc,
index_t vindex,
index_t offset,
bool glc,
bool slc) __asm("llvm.amdgcn.buffer.load.v4f32");
__device__ half_t __llvm_amdgcn_buffer_load_f16(int32x4_t rsrc,
index_t vindex, index_t vindex,
index_t offset, index_t offset,
bool glc, bool glc,
bool slc) __asm("llvm.amdgcn.buffer.load.v2f32"); bool slc) __asm("llvm.amdgcn.buffer.load.f16");
__device__ float4_t __llvm_amdgcn_buffer_loadx4(int32x4_t rsrc, __device__ half2_t __llvm_amdgcn_buffer_load_f16x2(int32x4_t rsrc,
index_t vindex,
index_t offset,
bool glc,
bool slc) __asm("llvm.amdgcn.buffer.load.v2f16");
__device__ half4_t __llvm_amdgcn_buffer_load_f16x4(int32x4_t rsrc,
index_t vindex,
index_t offset,
bool glc,
bool slc) __asm("llvm.amdgcn.buffer.load.v4f16");
__device__ ushort __llvm_amdgcn_buffer_load_bf16(int32x4_t rsrc,
index_t vindex,
index_t offset,
bool glc,
bool slc) __asm("llvm.amdgcn.buffer.load.bf16");
__device__ ushort2_t
__llvm_amdgcn_buffer_load_bf16x2(int32x4_t rsrc,
index_t vindex,
index_t offset,
bool glc,
bool slc) __asm("llvm.amdgcn.buffer.load.v2bf16");
__device__ ushort4_t
__llvm_amdgcn_buffer_load_bf16x4(int32x4_t rsrc,
index_t vindex,
index_t offset,
bool glc,
bool slc) __asm("llvm.amdgcn.buffer.load.v4bf16");
__device__ void __llvm_amdgcn_buffer_store_f32(float vdata,
int32x4_t rsrc,
index_t vindex,
index_t offset,
bool glc,
bool slc) __asm("llvm.amdgcn.buffer.store.f32");
__device__ void __llvm_amdgcn_buffer_store_f32x2(float2_t vdata,
int32x4_t rsrc,
index_t vindex,
index_t offset,
bool glc,
bool slc) __asm("llvm.amdgcn.buffer.store.v2f32");
__device__ void __llvm_amdgcn_buffer_store_f32x4(float4_t vdata,
int32x4_t rsrc,
index_t vindex,
index_t offset,
bool glc,
bool slc) __asm("llvm.amdgcn.buffer.store.v4f32");
__device__ void __llvm_amdgcn_buffer_store_f16(half_t vdata,
int32x4_t rsrc,
index_t vindex,
index_t offset,
bool glc,
bool slc) __asm("llvm.amdgcn.buffer.store.f16");
__device__ void __llvm_amdgcn_buffer_store_f16x2(half2_t vdata,
int32x4_t rsrc,
index_t vindex,
index_t offset,
bool glc,
bool slc) __asm("llvm.amdgcn.buffer.store.v2f16");
__device__ void __llvm_amdgcn_buffer_store_f16x4(half4_t vdata,
int32x4_t rsrc,
index_t vindex,
index_t offset,
bool glc,
bool slc) __asm("llvm.amdgcn.buffer.store.v4f16");
__device__ void __llvm_amdgcn_buffer_store_bf16(ushort vdata,
int32x4_t rsrc,
index_t vindex, index_t vindex,
index_t offset, index_t offset,
bool glc, bool glc,
bool slc) __asm("llvm.amdgcn.buffer.load.v4f32"); bool slc) __asm("llvm.amdgcn.buffer.store.bf16");
__device__ void __llvm_amdgcn_buffer_store(float vdata,
int32x4_t rsrc,
index_t vindex,
index_t offset,
bool glc,
bool slc) __asm("llvm.amdgcn.buffer.store.f32");
__device__ void __llvm_amdgcn_buffer_storex2(float2_t vdata,
int32x4_t rsrc,
index_t vindex,
index_t offset,
bool glc,
bool slc) __asm("llvm.amdgcn.buffer.store.v2f32");
__device__ void __llvm_amdgcn_buffer_storex4(float4_t vdata,
int32x4_t rsrc,
index_t vindex,
index_t offset,
bool glc,
bool slc) __asm("llvm.amdgcn.buffer.store.v4f32");
__device__ void __device__ void
__llvm_amdgcn_buffer_atomic_add(float vdata, __llvm_amdgcn_buffer_store_bf16x2(ushort2_t vdata,
int32x4_t rsrc, int32x4_t rsrc,
index_t vindex, index_t vindex,
index_t offset, index_t offset,
bool slc) __asm("llvm.amdgcn.buffer.atomic.fadd.f32"); bool glc,
bool slc) __asm("llvm.amdgcn.buffer.store.v2bf16");
__device__ void
__llvm_amdgcn_buffer_store_bf16x4(ushort4_t vdata,
int32x4_t rsrc,
index_t vindex,
index_t offset,
bool glc,
bool slc) __asm("llvm.amdgcn.buffer.store.v4bf16");
__device__ void
__llvm_amdgcn_buffer_atomic_add_f32(float vdata,
int32x4_t rsrc,
index_t vindex,
index_t offset,
bool slc) __asm("llvm.amdgcn.buffer.atomic.fadd.f32");
// buffer_load requires: // buffer_load requires:
// 1) p_src must be in global memory space, dst must be vgpr // 1) p_src must be in global memory space, dst must be vgpr
// 2) p_src to be a block-invariant pointer. // 2) p_src to be a block-invariant pointer.
// It is user's responsibility to make sure that is true. // It is user's responsibility to make sure that is true.
template <typename T, index_t VectorSize> template <typename T, index_t VectorSize>
__device__ typename vector_type<T, VectorSize>::MemoryType amd_intrinsic_buffer_load( __device__ typename vector_type<T, VectorSize>::MemoryType amd_buffer_load(
const T* p_src_block, index_t src_thread_data_offset, index_t src_const_data_offset); const T* p_src_block, index_t src_thread_data_offset, index_t src_const_data_offset);
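// Hedged usage sketch for the renamed load API ("p_in_global" and the offset
// names are illustrative):
//
//   // vectorized load of 4 floats through a block-invariant global pointer
//   float4_t v = amd_buffer_load<float, 4>(p_in_global,
//                                          thread_data_offset, // per-thread, runtime
//                                          const_data_offset); // compile-time constant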
// buffer_store requires: // buffer_store requires:
...@@ -74,30 +158,44 @@ __device__ typename vector_type<T, VectorSize>::MemoryType amd_intrinsic_buffer_ ...@@ -74,30 +158,44 @@ __device__ typename vector_type<T, VectorSize>::MemoryType amd_intrinsic_buffer_
// 2) p_dst to be a block-invariant pointer. // 2) p_dst to be a block-invariant pointer.
// It is user's responsibility to make sure that is true. // It is user's responsibility to make sure that is true.
template <typename T, index_t VectorSize> template <typename T, index_t VectorSize>
__device__ void __device__ void amd_buffer_store(const T* p_src,
amd_intrinsic_buffer_store(const typename vector_type<T, VectorSize>::MemoryType& src, T* p_dst_block,
T* p_dst_block, index_t dst_thread_data_offset,
index_t dst_thread_data_offset, index_t dst_const_data_offset);
index_t dst_const_data_offset);
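// Hedged usage sketch for the renamed store API (illustrative names; note that
// it now takes a pointer to the source data rather than a vector reference):
//
//   float2_t v = /* data held in VGPRs */;
//   amd_buffer_store<float, 2>(reinterpret_cast<const float*>(&v),
//                              p_out_global, // block-invariant global pointer
//                              thread_data_offset,
//                              const_data_offset);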
template <typename T, index_t VectorSize> template <typename T, index_t VectorSize>
__device__ void __device__ void amd_buffer_atomic_add(const T* p_src,
amd_intrinsic_buffer_atomic_add(const typename vector_type<T, VectorSize>::MemoryType& src, T* p_dst_block,
T* p_dst_block, index_t dst_thread_data_offset,
index_t dst_thread_data_offset, index_t dst_const_data_offset);
index_t dst_const_data_offset);
template <> template <>
__device__ float amd_intrinsic_buffer_load<float, 1>(const float* p_src_block, __device__ float amd_buffer_load<float, 1>(const float* p_src_block,
index_t src_thread_data_offset, index_t src_thread_data_offset,
index_t src_const_data_offset) index_t src_const_data_offset)
{ {
float dst; BufferAddressConfig<float> src_block_config;
// fill in byte 0 - 1
src_block_config.address[0] = const_cast<float*>(p_src_block);
// fill in byte 2
src_block_config.range[2] = -1;
// fill in byte 3
src_block_config.range[3] = 0x00027000;
index_t src_thread_addr_offset = src_thread_data_offset * sizeof(float); index_t src_thread_addr_offset = src_thread_data_offset * sizeof(float);
index_t src_const_addr_offset = src_const_data_offset * sizeof(float); index_t src_const_addr_offset = src_const_data_offset * sizeof(float);
BufferLoadStoreDwordConfig<float> src_block_config; return __llvm_amdgcn_buffer_load_f32(
src_block_config.data, 0, src_thread_addr_offset + src_const_addr_offset, false, false);
}
template <>
__device__ float2_t amd_buffer_load<float, 2>(const float* p_src_block,
index_t src_thread_data_offset,
index_t src_const_data_offset)
{
BufferAddressConfig<float> src_block_config;
// fill in byte 0 - 1 // fill in byte 0 - 1
src_block_config.address[0] = const_cast<float*>(p_src_block); src_block_config.address[0] = const_cast<float*>(p_src_block);
...@@ -106,102 +204,283 @@ __device__ float amd_intrinsic_buffer_load<float, 1>(const float* p_src_block, ...@@ -106,102 +204,283 @@ __device__ float amd_intrinsic_buffer_load<float, 1>(const float* p_src_block,
// fill in byte 3 // fill in byte 3
src_block_config.range[3] = 0x00027000; src_block_config.range[3] = 0x00027000;
#if CK_USE_AMD_BUFFER_ADDRESSING_INTRINSIC index_t src_thread_addr_offset = src_thread_data_offset * sizeof(float);
dst = __llvm_amdgcn_buffer_load( index_t src_const_addr_offset = src_const_data_offset * sizeof(float);
src_block_config.data, 0, src_thread_addr_offset + src_const_addr_offset, false, false);
#else
asm volatile(
"\n \
buffer_load_dword %0, %1, %2, %3 offen offset:0 \n \
s_waitcnt 0 \n \
"
: "=v"(dst)
: "v"(src_thread_addr_offset), "s"(src_block_config.data), "s"(src_const_addr_offset));
#endif
return dst; return __llvm_amdgcn_buffer_load_f32x2(
src_block_config.data, 0, src_thread_addr_offset + src_const_addr_offset, false, false);
} }
template <> template <>
__device__ float2_t amd_intrinsic_buffer_load<float, 2>(const float* p_src_block, __device__ float4_t amd_buffer_load<float, 4>(const float* p_src_block,
index_t src_thread_data_offset, index_t src_thread_data_offset,
index_t src_const_data_offset) index_t src_const_data_offset)
{ {
float2_t dst; BufferAddressConfig<float> src_block_config;
// fill in byte 0 - 1
src_block_config.address[0] = const_cast<float*>(p_src_block);
// fill in byte 2
src_block_config.range[2] = -1;
// fill in byte 3
src_block_config.range[3] = 0x00027000;
index_t src_thread_addr_offset = src_thread_data_offset * sizeof(float); index_t src_thread_addr_offset = src_thread_data_offset * sizeof(float);
index_t src_const_addr_offset = src_const_data_offset * sizeof(float); index_t src_const_addr_offset = src_const_data_offset * sizeof(float);
BufferLoadStoreDwordConfig<float> src_block_config; return __llvm_amdgcn_buffer_load_f32x4(
src_block_config.data, 0, src_thread_addr_offset + src_const_addr_offset, false, false);
}
template <>
__device__ half_t amd_buffer_load<half_t, 1>(const half_t* p_src_block,
index_t src_thread_data_offset,
index_t src_const_data_offset)
{
BufferAddressConfig<half_t> src_block_config;
// fill in byte 0 - 1 // fill in byte 0 - 1
src_block_config.address[0] = const_cast<float*>(p_src_block); src_block_config.address[0] = const_cast<half_t*>(p_src_block);
// fill in byte 2 // fill in byte 2
src_block_config.range[2] = -1; src_block_config.range[2] = -1;
// fill in byte 3 // fill in byte 3
src_block_config.range[3] = 0x00027000; src_block_config.range[3] = 0x00027000;
#if CK_USE_AMD_BUFFER_ADDRESSING_INTRINSIC #if !CK_WORKAROUND_SWDEV_231101
dst = __llvm_amdgcn_buffer_loadx2( index_t src_thread_addr_offset = src_thread_data_offset * sizeof(half_t);
index_t src_const_addr_offset = src_const_data_offset * sizeof(half_t);
return __llvm_amdgcn_buffer_load_f16(
src_block_config.data, 0, src_thread_addr_offset + src_const_addr_offset, false, false); src_block_config.data, 0, src_thread_addr_offset + src_const_addr_offset, false, false);
#else #else
asm volatile( return p_src_block[src_thread_data_offset + src_const_data_offset];
"\n \
buffer_load_dwordx2 %0, %1, %2, %3 offen offset:0 \n \
s_waitcnt 0 \n \
"
: "=v"(dst)
: "v"(src_thread_addr_offset), "s"(src_block_config.data), "s"(src_const_addr_offset));
#endif #endif
}
return dst; template <>
__device__ half2_t amd_buffer_load<half_t, 2>(const half_t* p_src_block,
index_t src_thread_data_offset,
index_t src_const_data_offset)
{
BufferAddressConfig<half_t> src_block_config;
// fill in byte 0 - 1
src_block_config.address[0] = const_cast<half_t*>(p_src_block);
// fill in byte 2
src_block_config.range[2] = -1;
// fill in byte 3
src_block_config.range[3] = 0x00027000;
index_t src_thread_addr_offset = src_thread_data_offset * sizeof(half_t);
index_t src_const_addr_offset = src_const_data_offset * sizeof(half_t);
#if !CK_WORKAROUND_SWDEV_231101
return __llvm_amdgcn_buffer_load_f16x2(
src_block_config.data, 0, src_thread_addr_offset + src_const_addr_offset, false, false);
#else
float dst_out_tmp = __llvm_amdgcn_buffer_load_f32(
src_block_config.data, 0, src_thread_addr_offset + src_const_addr_offset, false, false);
return *reinterpret_cast<half2_t*>(&dst_out_tmp);
#endif
} }
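// Editorial note: under CK_WORKAROUND_SWDEV_231101 the 1-element half/bfloat16
// specializations fall back to an ordinary pointer dereference, while the wider
// ones go through a same-width f32/f32x2/f32x4 buffer load and reinterpret the
// result, so the size of the memory transaction is unchanged; only the
// intrinsic's element type differs.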
template <> template <>
__device__ float4_t amd_intrinsic_buffer_load<float, 4>(const float* p_src_block, __device__ half4_t amd_buffer_load<half_t, 4>(const half_t* p_src_block,
index_t src_thread_data_offset, index_t src_thread_data_offset,
index_t src_const_data_offset) index_t src_const_data_offset)
{ {
float4_t dst; BufferAddressConfig<half_t> src_block_config;
index_t src_thread_addr_offset = src_thread_data_offset * sizeof(float); // fill in byte 0 - 1
index_t src_const_addr_offset = src_const_data_offset * sizeof(float); src_block_config.address[0] = const_cast<half_t*>(p_src_block);
// fill in byte 2
src_block_config.range[2] = -1;
// fill in byte 3
src_block_config.range[3] = 0x00027000;
index_t src_thread_addr_offset = src_thread_data_offset * sizeof(half_t);
index_t src_const_addr_offset = src_const_data_offset * sizeof(half_t);
BufferLoadStoreDwordConfig<float> src_block_config; #if !CK_WORKAROUND_SWDEV_231101
return __llvm_amdgcn_buffer_load_f16x4(
src_block_config.data, 0, src_thread_addr_offset + src_const_addr_offset, false, false);
#else
float2_t dst_out_tmp = __llvm_amdgcn_buffer_load_f32x2(
src_block_config.data, 0, src_thread_addr_offset + src_const_addr_offset, false, false);
return *reinterpret_cast<half4_t*>(&dst_out_tmp);
#endif
}
template <>
__device__ half8_t amd_buffer_load<half_t, 8>(const half_t* p_src_block,
index_t src_thread_data_offset,
index_t src_const_data_offset)
{
BufferAddressConfig<half_t> src_block_config;
// fill in byte 0 - 1 // fill in byte 0 - 1
src_block_config.address[0] = const_cast<float*>(p_src_block); src_block_config.address[0] = const_cast<half_t*>(p_src_block);
// fill in byte 2
src_block_config.range[2] = -1;
// fill in byte 3
src_block_config.range[3] = 0x00027000;
index_t src_thread_addr_offset = src_thread_data_offset * sizeof(half_t);
index_t src_const_addr_offset = src_const_data_offset * sizeof(half_t);
#if !CK_WORKAROUND_SWDEV_231101
static_assert(false, "wrong! not supported");
#else
float4_t dst_out_tmp = __llvm_amdgcn_buffer_load_f32x4(
src_block_config.data, 0, src_thread_addr_offset + src_const_addr_offset, false, false);
return *reinterpret_cast<half8_t*>(&dst_out_tmp);
#endif
}
template <>
__device__ ushort amd_buffer_load<ushort, 1>(const ushort* p_src_block,
index_t src_thread_data_offset,
index_t src_const_data_offset)
{
BufferAddressConfig<ushort> src_block_config;
// fill in byte 0 - 1
src_block_config.address[0] = const_cast<ushort*>(p_src_block);
// fill in byte 2 // fill in byte 2
src_block_config.range[2] = -1; src_block_config.range[2] = -1;
// fill in byte 3 // fill in byte 3
src_block_config.range[3] = 0x00027000; src_block_config.range[3] = 0x00027000;
#if CK_USE_AMD_BUFFER_ADDRESSING_INTRINSIC #if !CK_WORKAROUND_SWDEV_231101
dst = __llvm_amdgcn_buffer_loadx4( index_t src_thread_addr_offset = src_thread_data_offset * sizeof(ushort);
index_t src_const_addr_offset = src_const_data_offset * sizeof(ushort);
return __llvm_amdgcn_buffer_load_bf16(
src_block_config.data, 0, src_thread_addr_offset + src_const_addr_offset, false, false); src_block_config.data, 0, src_thread_addr_offset + src_const_addr_offset, false, false);
#else #else
asm volatile( return p_src_block[src_thread_data_offset + src_const_data_offset];
"\n \
buffer_load_dwordx4 %0, %1, %2, %3 offen offset:0 \n \
s_waitcnt 0 \n \
"
: "=v"(dst)
: "v"(src_thread_addr_offset), "s"(src_block_config.data), "s"(src_const_addr_offset));
#endif #endif
}
template <>
__device__ ushort2_t amd_buffer_load<ushort, 2>(const ushort* p_src_block,
index_t src_thread_data_offset,
index_t src_const_data_offset)
{
BufferAddressConfig<ushort> src_block_config;
return dst; // fill in byte 0 - 1
src_block_config.address[0] = const_cast<ushort*>(p_src_block);
// fill in byte 2
src_block_config.range[2] = -1;
// fill in byte 3
src_block_config.range[3] = 0x00027000;
index_t src_thread_addr_offset = src_thread_data_offset * sizeof(ushort);
index_t src_const_addr_offset = src_const_data_offset * sizeof(ushort);
#if !CK_WORKAROUND_SWDEV_231101
return __llvm_amdgcn_buffer_load_bf16x2(
src_block_config.data, 0, src_thread_addr_offset + src_const_addr_offset, false, false);
#else
float dst_out_tmp = __llvm_amdgcn_buffer_load_f32(
src_block_config.data, 0, src_thread_addr_offset + src_const_addr_offset, false, false);
return *reinterpret_cast<ushort2_t*>(&dst_out_tmp);
#endif
} }
template <> template <>
__device__ void amd_intrinsic_buffer_store<float, 1>(const float& src, __device__ ushort4_t amd_buffer_load<ushort, 4>(const ushort* p_src_block,
float* p_dst_block, index_t src_thread_data_offset,
index_t dst_thread_data_offset, index_t src_const_data_offset)
index_t dst_const_data_offset)
{ {
BufferAddressConfig<ushort> src_block_config;
// fill in byte 0 - 1
src_block_config.address[0] = const_cast<ushort*>(p_src_block);
// fill in byte 2
src_block_config.range[2] = -1;
// fill in byte 3
src_block_config.range[3] = 0x00027000;
index_t src_thread_addr_offset = src_thread_data_offset * sizeof(ushort);
index_t src_const_addr_offset = src_const_data_offset * sizeof(ushort);
#if !CK_WORKAROUND_SWDEV_231101
return __llvm_amdgcn_buffer_load_bf16x4(
src_block_config.data, 0, src_thread_addr_offset + src_const_addr_offset, false, false);
#else
float2_t dst_out_tmp = __llvm_amdgcn_buffer_load_f32x2(
src_block_config.data, 0, src_thread_addr_offset + src_const_addr_offset, false, false);
return *reinterpret_cast<ushort4_t*>(&dst_out_tmp);
#endif
}
template <>
__device__ ushort8_t amd_buffer_load<ushort, 8>(const ushort* p_src_block,
index_t src_thread_data_offset,
index_t src_const_data_offset)
{
BufferAddressConfig<ushort> src_block_config;
// fill in byte 0 - 1
src_block_config.address[0] = const_cast<ushort*>(p_src_block);
// fill in byte 2
src_block_config.range[2] = -1;
// fill in byte 3
src_block_config.range[3] = 0x00027000;
index_t src_thread_addr_offset = src_thread_data_offset * sizeof(ushort);
index_t src_const_addr_offset = src_const_data_offset * sizeof(ushort);
#if !CK_WORKAROUND_SWDEV_231101
static_assert(false, "wrong! not implemented");
#else
float4_t dst_out_tmp = __llvm_amdgcn_buffer_load_f32x4(
src_block_config.data, 0, src_thread_addr_offset + src_const_addr_offset, false, false);
return *reinterpret_cast<ushort8_t*>(&dst_out_tmp);
#endif
}
template <>
__device__ void amd_buffer_store<float, 1>(const float* p_src,
float* p_dst_block,
index_t dst_thread_data_offset,
index_t dst_const_data_offset)
{
BufferAddressConfig<float> dst_block_config;
// fill in byte 0 - 1
dst_block_config.address[0] = p_dst_block;
// fill in byte 2
dst_block_config.range[2] = -1;
// fill in byte 3
dst_block_config.range[3] = 0x00027000;
index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(float); index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(float);
index_t dst_const_addr_offset = dst_const_data_offset * sizeof(float); index_t dst_const_addr_offset = dst_const_data_offset * sizeof(float);
BufferLoadStoreDwordConfig<float> dst_block_config; __llvm_amdgcn_buffer_store_f32(*p_src,
dst_block_config.data,
0,
dst_thread_addr_offset + dst_const_addr_offset,
false,
false);
}
template <>
__device__ void amd_buffer_store<float, 2>(const float* p_src,
float* p_dst_block,
index_t dst_thread_data_offset,
index_t dst_const_data_offset)
{
BufferAddressConfig<float> dst_block_config;
// fill in byte 0 - 1 // fill in byte 0 - 1
dst_block_config.address[0] = p_dst_block; dst_block_config.address[0] = p_dst_block;
...@@ -210,35 +489,50 @@ __device__ void amd_intrinsic_buffer_store<float, 1>(const float& src, ...@@ -210,35 +489,50 @@ __device__ void amd_intrinsic_buffer_store<float, 1>(const float& src,
// fill in byte 3 // fill in byte 3
dst_block_config.range[3] = 0x00027000; dst_block_config.range[3] = 0x00027000;
#if CK_USE_AMD_BUFFER_ADDRESSING_INTRINSIC index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(float);
__llvm_amdgcn_buffer_store(src, index_t dst_const_addr_offset = dst_const_data_offset * sizeof(float);
dst_block_config.data,
0, __llvm_amdgcn_buffer_store_f32x2(*reinterpret_cast<const float2_t*>(p_src),
dst_thread_addr_offset + dst_const_addr_offset, dst_block_config.data,
false, 0,
false); dst_thread_addr_offset + dst_const_addr_offset,
#else false,
asm volatile("\n \ false);
buffer_store_dword %1, %2, %0, %3 offen offset:0 \n \
"
:
: "s"(dst_block_config.data),
"v"(src),
"v"(dst_thread_addr_offset),
"s"(dst_const_addr_offset));
#endif
} }
template <> template <>
__device__ void amd_intrinsic_buffer_store<float, 2>(const float2_t& src, __device__ void amd_buffer_store<float, 4>(const float* p_src,
float* p_dst_block, float* p_dst_block,
index_t dst_thread_data_offset, index_t dst_thread_data_offset,
index_t dst_const_data_offset) index_t dst_const_data_offset)
{ {
BufferAddressConfig<float> dst_block_config;
// fill in byte 0 - 1
dst_block_config.address[0] = p_dst_block;
// fill in byte 2
dst_block_config.range[2] = -1;
// fill in byte 3
dst_block_config.range[3] = 0x00027000;
index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(float); index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(float);
index_t dst_const_addr_offset = dst_const_data_offset * sizeof(float); index_t dst_const_addr_offset = dst_const_data_offset * sizeof(float);
BufferLoadStoreDwordConfig<float> dst_block_config; __llvm_amdgcn_buffer_store_f32x4(*reinterpret_cast<const float4_t*>(p_src),
dst_block_config.data,
0,
dst_thread_addr_offset + dst_const_addr_offset,
false,
false);
}
template <>
__device__ void amd_buffer_store<half_t, 1>(const half_t* p_src,
half_t* p_dst_block,
index_t dst_thread_data_offset,
index_t dst_const_data_offset)
{
BufferAddressConfig<half_t> dst_block_config;
// fill in byte 0 - 1 // fill in byte 0 - 1
dst_block_config.address[0] = p_dst_block; dst_block_config.address[0] = p_dst_block;
...@@ -247,35 +541,68 @@ __device__ void amd_intrinsic_buffer_store<float, 2>(const float2_t& src, ...@@ -247,35 +541,68 @@ __device__ void amd_intrinsic_buffer_store<float, 2>(const float2_t& src,
// fill in byte 3 // fill in byte 3
dst_block_config.range[3] = 0x00027000; dst_block_config.range[3] = 0x00027000;
#if CK_USE_AMD_BUFFER_ADDRESSING_INTRINSIC #if !CK_WORKAROUND_SWDEV_231101
__llvm_amdgcn_buffer_storex2(src, index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(half_t);
dst_block_config.data, index_t dst_const_addr_offset = dst_const_data_offset * sizeof(half_t);
0,
dst_thread_addr_offset + dst_const_addr_offset, __llvm_amdgcn_buffer_store_f16(*p_src,
false, dst_block_config.data,
false); 0,
dst_thread_addr_offset + dst_const_addr_offset,
false,
false);
#else #else
asm volatile("\n \ p_dst_block[dst_thread_data_offset + dst_const_data_offset] = *p_src;
buffer_store_dwordx2 %1, %2, %0, %3 offen offset:0 \n \
"
:
: "s"(dst_block_config.data),
"v"(src),
"v"(dst_thread_addr_offset),
"s"(dst_const_addr_offset));
#endif #endif
} }
template <> template <>
__device__ void amd_intrinsic_buffer_store<float, 4>(const float4_t& src, __device__ void amd_buffer_store<half_t, 2>(const half_t* p_src,
float* p_dst_block, half_t* p_dst_block,
index_t dst_thread_data_offset, index_t dst_thread_data_offset,
index_t dst_const_data_offset) index_t dst_const_data_offset)
{ {
index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(float); BufferAddressConfig<half_t> dst_block_config;
index_t dst_const_addr_offset = dst_const_data_offset * sizeof(float);
// fill in byte 0 - 1
dst_block_config.address[0] = p_dst_block;
// fill in byte 2
dst_block_config.range[2] = -1;
// fill in byte 3
dst_block_config.range[3] = 0x00027000;
index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(half_t);
index_t dst_const_addr_offset = dst_const_data_offset * sizeof(half_t);
BufferLoadStoreDwordConfig<float> dst_block_config; #if !CK_WORKAROUND_SWDEV_231101
__llvm_amdgcn_buffer_store_f16x2(*reinterpret_cast<const half2_t*>(p_src),
dst_block_config.data,
0,
dst_thread_addr_offset + dst_const_addr_offset,
false,
false);
#else
const float* p_src_tmp = reinterpret_cast<const float*>(p_src);
__llvm_amdgcn_buffer_store_f32(*p_src_tmp,
dst_block_config.data,
0,
dst_thread_addr_offset + dst_const_addr_offset,
false,
false);
#endif
}
template <>
__device__ void amd_buffer_store<half_t, 4>(const half_t* p_src,
half_t* p_dst_block,
index_t dst_thread_data_offset,
index_t dst_const_data_offset)
{
index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(half_t);
index_t dst_const_addr_offset = dst_const_data_offset * sizeof(half_t);
BufferAddressConfig<half_t> dst_block_config;
// fill in byte 0 - 1 // fill in byte 0 - 1
dst_block_config.address[0] = p_dst_block; dst_block_config.address[0] = p_dst_block;
...@@ -284,35 +611,99 @@ __device__ void amd_intrinsic_buffer_store<float, 4>(const float4_t& src, ...@@ -284,35 +611,99 @@ __device__ void amd_intrinsic_buffer_store<float, 4>(const float4_t& src,
// fill in byte 3 // fill in byte 3
dst_block_config.range[3] = 0x00027000; dst_block_config.range[3] = 0x00027000;
#if CK_USE_AMD_BUFFER_ADDRESSING_INTRINSIC #if !CK_WORKAROUND_SWDEV_231101
__llvm_amdgcn_buffer_storex4(src, __llvm_amdgcn_buffer_store_f16x4(*reinterpret_cast<const half4_t*>(p_src),
dst_block_config.data, dst_block_config.data,
0, 0,
dst_thread_addr_offset + dst_const_addr_offset, dst_thread_addr_offset + dst_const_addr_offset,
false, false,
false); false);
#else #else
asm volatile("\n \ const float2_t* p_src_tmp = reinterpret_cast<const float2_t*>(p_src);
buffer_store_dwordx4 %1, %2, %0, %3 offen offset:0 \n \
" __llvm_amdgcn_buffer_store_f32x2(*p_src_tmp,
: dst_block_config.data,
: "s"(dst_block_config.data), 0,
"v"(src), dst_thread_addr_offset + dst_const_addr_offset,
"v"(dst_thread_addr_offset), false,
"s"(dst_const_addr_offset)); false);
#endif #endif
} }
template <> template <>
__device__ void amd_intrinsic_buffer_atomic_add<float, 1>(const float& src, __device__ void amd_buffer_store<ushort, 1>(const ushort* p_src,
float* p_dst_block, ushort* p_dst_block,
index_t dst_thread_data_offset, index_t dst_thread_data_offset,
index_t dst_const_data_offset) index_t dst_const_data_offset)
{ {
index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(float); BufferAddressConfig<ushort> dst_block_config;
index_t dst_const_addr_offset = dst_const_data_offset * sizeof(float);
// fill in byte 0 - 1
dst_block_config.address[0] = p_dst_block;
// fill in byte 2
dst_block_config.range[2] = -1;
// fill in byte 3
dst_block_config.range[3] = 0x00027000;
#if !CK_WORKAROUND_SWDEV_231101
index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(ushort);
index_t dst_const_addr_offset = dst_const_data_offset * sizeof(ushort);
__llvm_amdgcn_buffer_store_bf16(*p_src,
dst_block_config.data,
0,
dst_thread_addr_offset + dst_const_addr_offset,
false,
false);
#else
p_dst_block[dst_thread_data_offset + dst_const_data_offset] = *p_src;
#endif
}
template <>
__device__ void amd_buffer_store<ushort, 2>(const ushort* p_src,
ushort* p_dst_block,
index_t dst_thread_data_offset,
index_t dst_const_data_offset)
{
BufferAddressConfig<ushort> dst_block_config;
// fill in byte 0 - 1
dst_block_config.address[0] = p_dst_block;
// fill in byte 2
dst_block_config.range[2] = -1;
// fill in byte 3
dst_block_config.range[3] = 0x00027000;
BufferLoadStoreDwordConfig<float> dst_block_config; index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(ushort);
index_t dst_const_addr_offset = dst_const_data_offset * sizeof(ushort);
#if !CK_WORKAROUND_SWDEV_231101
__llvm_amdgcn_buffer_store_bf16x2(*reinterpret_cast<const ushort2_t*>(p_src),
dst_block_config.data,
0,
dst_thread_addr_offset + dst_const_addr_offset,
false,
false);
#else
const float* p_src_tmp = reinterpret_cast<const float*>(p_src);
__llvm_amdgcn_buffer_store_f32(*p_src_tmp,
dst_block_config.data,
0,
dst_thread_addr_offset + dst_const_addr_offset,
false,
false);
#endif
}
template <>
__device__ void amd_buffer_store<ushort, 4>(const ushort* p_src,
ushort* p_dst_block,
index_t dst_thread_data_offset,
index_t dst_const_data_offset)
{
BufferAddressConfig<ushort> dst_block_config;
// fill in byte 0 - 1 // fill in byte 0 - 1
dst_block_config.address[0] = p_dst_block; dst_block_config.address[0] = p_dst_block;
...@@ -321,13 +712,75 @@ __device__ void amd_intrinsic_buffer_atomic_add<float, 1>(const float& src, ...@@ -321,13 +712,75 @@ __device__ void amd_intrinsic_buffer_atomic_add<float, 1>(const float& src,
// fill in byte 3 // fill in byte 3
dst_block_config.range[3] = 0x00027000; dst_block_config.range[3] = 0x00027000;
#if CK_USE_AMD_BUFFER_ADDRESSING_INTRINSIC index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(ushort);
__llvm_amdgcn_buffer_atomic_add( index_t dst_const_addr_offset = dst_const_data_offset * sizeof(ushort);
src, dst_block_config.data, 0, dst_thread_addr_offset + dst_const_addr_offset, false);
#if !CK_WORKAROUND_SWDEV_231101
__llvm_amdgcn_buffer_store_bf16x4(*reinterpret_cast<const ushort4_t*>(p_src),
dst_block_config.data,
0,
dst_thread_addr_offset + dst_const_addr_offset,
false,
false);
#else #else
static_assert(false, " wrong! not implemented"); const float2_t* p_src_tmp = reinterpret_cast<const float2_t*>(p_src);
__llvm_amdgcn_buffer_store_f32x2(*p_src_tmp,
dst_block_config.data,
0,
dst_thread_addr_offset + dst_const_addr_offset,
false,
false);
#endif #endif
} }
template <>
__device__ void amd_buffer_atomic_add<float, 1>(const float* p_src,
float* p_dst_block,
index_t dst_thread_data_offset,
index_t dst_const_data_offset)
{
BufferAddressConfig<float> dst_block_config;
// fill in byte 0 - 1
dst_block_config.address[0] = p_dst_block;
// fill in byte 2
dst_block_config.range[2] = -1;
// fill in byte 3
dst_block_config.range[3] = 0x00027000;
index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(float);
index_t dst_const_addr_offset = dst_const_data_offset * sizeof(float);
__llvm_amdgcn_buffer_atomic_add_f32(
*p_src, dst_block_config.data, 0, dst_thread_addr_offset + dst_const_addr_offset, false);
}
template <>
__device__ void amd_buffer_atomic_add<float, 2>(const float* p_src,
float* p_dst_block,
index_t dst_thread_data_offset,
index_t dst_const_data_offset)
{
for(index_t i = 0; i < 2; ++i)
{
amd_buffer_atomic_add<float, 1>(
&p_src[i], p_dst_block, dst_thread_data_offset, dst_const_data_offset + i);
}
}
template <>
__device__ void amd_buffer_atomic_add<float, 4>(const float* p_src,
float* p_dst_block,
index_t dst_thread_data_offset,
index_t dst_const_data_offset)
{
for(index_t i = 0; i < 4; ++i)
{
amd_buffer_atomic_add<float, 1>(
&p_src[i], p_dst_block, dst_thread_data_offset, dst_const_data_offset + i);
}
}
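// Hedged usage sketch (illustrative names): atomically accumulating a partial
// result vector into global memory. The 2- and 4-wide overloads above simply
// issue one scalar f32 buffer atomic fadd per element, since only the scalar
// intrinsic is declared in this header.
//
//   float4_t acc = /* partial results in VGPRs */;
//   amd_buffer_atomic_add<float, 4>(reinterpret_cast<const float*>(&acc),
//                                   p_out_global,
//                                   thread_data_offset,
//                                   const_data_offset);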
} // namespace ck } // namespace ck
#endif #endif
...@@ -16,15 +16,12 @@ ...@@ -16,15 +16,12 @@
#include "functional3.hpp" #include "functional3.hpp"
#include "functional4.hpp" #include "functional4.hpp"
#include "in_memory_operation.hpp" #include "in_memory_operation.hpp"
#include "synchronization.hpp"
#if CK_USE_AMD_INLINE_ASM #if CK_USE_AMD_INLINE_ASM
#include "amd_inline_asm.hpp" #include "amd_inline_asm.hpp"
#endif #endif
#if CK_USE_AMD_BUFFER_ADDRESSING
#include "amd_buffer_addressing.hpp"
#endif
#if CK_USE_AMD_XDLOPS #if CK_USE_AMD_XDLOPS
#include "amd_xdlops.hpp" #include "amd_xdlops.hpp"
#endif #endif
......