Commit 9b280cc5 authored by Chao Liu

remove dead code

parent 98a2cfcc
#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R1_NCHW_KCYX_NKHW_HPP
#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R1_NCHW_KCYX_NKHW_HPP
#include "common_header.hpp"
#include "ConstantTensorDescriptor.hpp"
#include "ConstantMergedTensorDescriptor.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "blockwise_generic_tensor_slice_copy.hpp"
#include "blockwise_gemm.hpp"
#include "threadwise_generic_tensor_slice_copy.hpp"
namespace ck {
// define B = merge(N0, Ho, Wo)
template <index_t GridSize,
index_t BlockSize,
typename Float,
typename InGlobalDesc,
typename WeiGlobalDesc,
typename OutGlobalDesc,
typename ConvStrides,
typename ConvDilations,
index_t BPerBlock,
index_t KPerBlock,
index_t EPerBlock,
index_t GemmNRepeat,
index_t GemmMPerThreadSubC,
index_t GemmNPerThreadSubC,
index_t GemmMLevel0Cluster,
index_t GemmNLevel0Cluster,
index_t GemmMLevel1Cluster,
index_t GemmNLevel1Cluster,
index_t GemmKPerThreadLoop,
index_t GemmDataPerReadA,
index_t GemmDataPerReadB,
typename InBlockCopySubLengths_E_N1_B_N2,
typename InBlockCopyClusterLengths_E_N1_B_N2,
typename InBlockCopyThreadClusterArrangeOrder,
typename InBlockCopySrcAccessOrder,
typename InBlockCopyDstAccessOrder,
index_t InBlockCopySrcDataPerRead_B,
index_t InBlockCopyDstDataPerWrite_N2,
typename WeiBlockCopySubLengths_E_K,
typename WeiBlockCopyClusterLengths_E_K,
typename WeiBlockCopyThreadClusterArrangeOrder,
typename WeiBlockCopySrcAccessOrder,
typename WeiBlockCopyDstAccessOrder,
index_t WeiBlockCopySrcDataPerRead_E,
index_t WeiBlockCopyDstDataPerWrite_K>
struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw
{
__device__ void Run(const Float* const __restrict__ p_in_global,
const Float* const __restrict__ p_wei_global,
Float* const __restrict__ p_out_global) const
{
// this is a mess
// TODO: find a more elegant way of specifying (or calculating) performance parameters
constexpr index_t N1 = GemmNRepeat;
constexpr index_t N2 = GemmNPerThreadSubC;
static_assert((N1 * N2 * BPerBlock) %
(GemmNPerThreadSubC * GemmNLevel0Cluster * GemmNLevel1Cluster) ==
0,
"wrong!");
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto I5 = Number<5>{};
constexpr auto True = integral_constant<bool, true>{};
constexpr auto in_n_c_h_w_global_desc = InGlobalDesc{};
constexpr auto wei_k_c_y_x_global_desc = WeiGlobalDesc{};
constexpr auto out_n_k_h_w_global_desc = OutGlobalDesc{};
constexpr index_t N = in_n_c_h_w_global_desc.GetLength(I0);
constexpr index_t C = in_n_c_h_w_global_desc.GetLength(I1);
constexpr index_t K = out_n_k_h_w_global_desc.GetLength(I1);
constexpr index_t Ho = out_n_k_h_w_global_desc.GetLength(I2);
constexpr index_t Wo = out_n_k_h_w_global_desc.GetLength(I3);
constexpr index_t Y = wei_k_c_y_x_global_desc.GetLength(I2);
constexpr index_t X = wei_k_c_y_x_global_desc.GetLength(I3);
constexpr index_t ConvStrideH = ConvStrides{}[0];
constexpr index_t ConvStrideW = ConvStrides{}[1];
constexpr index_t ConvDilationH = ConvDilations{}[0];
constexpr index_t ConvDilationW = ConvDilations{}[1];
static_assert(N % (N1 * N2) == 0, "wrong! cannot divide N evenly among threads");
constexpr index_t N0 = N / (N1 * N2);
constexpr index_t B = N0 * Ho * Wo;
constexpr index_t E = C * Y * X;
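// illustration with hypothetical sizes: N = 128, N1 = 2, N2 = 4 give
// N0 = 128 / (2 * 4) = 16; with Ho = Wo = 14, C = 64 and Y = X = 3,
// B = 16 * 14 * 14 = 3136 and E = 64 * 3 * 3 = 576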
// sanity-check for vectorized memory load
static_assert((Wo == 1 || (ConvStrideW == 1 || InBlockCopySrcDataPerRead_B == 1)) &&
(X == 1 || ConvDilationW % InBlockCopySrcDataPerRead_B == 0),
"wrong! aligment requirement for vectorized global load of input tensor will "
"be violated");
// divide block work by [K, B]
static_assert(K % KPerBlock == 0 && B % BPerBlock == 0 && E % EPerBlock == 0,
"wrong! cannot divide work evenly among block");
constexpr index_t KBlockWork = K / KPerBlock;
constexpr index_t BBlockWork = B / BPerBlock;
constexpr auto block_work_desc =
make_ConstantTensorDescriptor_packed(Sequence<KBlockWork, BBlockWork>{});
const auto block_work_multi_id =
block_work_desc.GetMultiIndexFrom1dIndex(get_block_1d_id());
const index_t k_block_data_on_global = block_work_multi_id[0] * KPerBlock;
const index_t b_block_data_on_global = block_work_multi_id[1] * BPerBlock;
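// illustration with hypothetical sizes: K = 256, KPerBlock = 64, B = 3136,
// BPerBlock = 49 give KBlockWork = 4 and BBlockWork = 64; block 1d id 70
// then maps to block_work_multi_id = {70 / 64, 70 % 64} = {1, 6}, i.e. this
// block covers k in [64, 128) and b in [294, 343)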
// input tensor
// tensor descriptor in device memory [N0, N1, N2, Ho, Wo]
constexpr auto in_n0_n1_n2_h_w_global_desc =
in_n_c_h_w_global_desc.StridedSlice(I2, Number<Ho>{}, Number<ConvStrideH>{})
.StridedSlice(I3, Number<Wo>{}, Number<ConvStrideW>{})
.Fold(I0, Number<N1>{}, Number<N2>{})
.Extract(Sequence<0, 1, 2, 4, 5>{});
// batch descriptor for device memory
constexpr auto in_c_y_x_global_desc =
in_n_c_h_w_global_desc.StridedSlice(I2, Number<Y>{}, Number<ConvDilationH>{})
.StridedSlice(I3, Number<X>{}, Number<ConvDilationW>{})
.Extract(Sequence<1, 2, 3>{});
// merged tensor descriptor in device memory [E, N1, B, N2], src of blockwise copy
constexpr auto in_e_n1_b_n2_global_merged_desc = make_ConstantMergedTensorDescriptor(
in_c_y_x_global_desc.Embed(in_n0_n1_n2_h_w_global_desc),
Sequence<0, 1, 2>{},
Sequence<4>{},
Sequence<3, 6, 7>{},
Sequence<5>{});
// memory layout descriptor in LDS [E, N1, B, N2], dst of blockwise copy
// be careful of LDS alignment
constexpr auto in_e_n1_b_n2_block_desc = make_ConstantTensorDescriptor_aligned(
Sequence<EPerBlock, N1, BPerBlock, N2>{}, Number<InBlockCopyDstDataPerWrite_N2>{});
// this check is ad-hoc
// TODO: need to properly implement tensor descriptor with multiple alignment
// requirements
static_assert(in_e_n1_b_n2_block_desc.GetStride(I1) % GemmDataPerReadB == 0,
"GemmDataPerReadB alignment requirement is not satisfied");
// input blockwise copy
// slice a merged tensor, reorder and copy to a normal tensor
// this copy operator already has blockwise offset built-in
auto blockwise_in_copy =
BlockwiseGenericTensorSliceCopy_v2<BlockSize,
decltype(in_e_n1_b_n2_global_merged_desc),
decltype(in_e_n1_b_n2_block_desc),
decltype(in_e_n1_b_n2_block_desc.GetLengths()),
InBlockCopySubLengths_E_N1_B_N2,
InBlockCopyClusterLengths_E_N1_B_N2,
InBlockCopyThreadClusterArrangeOrder,
InBlockCopySrcAccessOrder,
InBlockCopyDstAccessOrder,
2,
3,
InBlockCopySrcDataPerRead_B,
InBlockCopyDstDataPerWrite_N2>(
{0, 0, b_block_data_on_global, 0}, {0, 0, 0, 0});
// weight tensor
// tensor descriptor in device memory, src of blockwise copy
constexpr auto wei_e_k_global_desc =
wei_k_c_y_x_global_desc.Unfold(I1, I3).ReorderGivenNew2Old(Sequence<1, 0>{});
// tensor descriptor in LDS, dst of blockwise copy
// be careful of LDS alignment
constexpr auto wei_e_k_block_desc = make_ConstantTensorDescriptor_aligned(
Sequence<EPerBlock, KPerBlock>{},
Number<math::lcm(WeiBlockCopyDstDataPerWrite_K, GemmDataPerReadA)>{});
// this check is ad-hoc
// TODO: need to properly implement tensor descriptor with multiple alignment
// requirements
static_assert(wei_e_k_block_desc.GetStride(I0) % GemmDataPerReadA == 0,
"GemmDataPerReadA alignment requirement is not satisfied");
// operator for blockwise copy of weight into LDS
// slice a tensor, and copy it into another tensor
// this copy operator already has blockwise offset built-in
auto blockwise_wei_copy =
BlockwiseGenericTensorSliceCopy_v2<BlockSize,
decltype(wei_e_k_global_desc),
decltype(wei_e_k_block_desc),
decltype(wei_e_k_block_desc.GetLengths()),
WeiBlockCopySubLengths_E_K,
WeiBlockCopyClusterLengths_E_K,
WeiBlockCopyThreadClusterArrangeOrder,
WeiBlockCopySrcAccessOrder,
WeiBlockCopyDstAccessOrder,
0,
1,
WeiBlockCopySrcDataPerRead_E,
WeiBlockCopyDstDataPerWrite_K>(
{0, k_block_data_on_global}, {0, 0});
// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
// a_mtx[EPerBlock, KPerBlock] is in LDS
// b_mtx[EPerBlock, N1 * BPerBlock * N2] is in LDS
// c_mtx[KPerBlock, N1 * BPerBlock * N2] is distributed among threads, and saved in
// register
constexpr auto a_e_k_block_mtx_desc = make_ConstantMatrixDescriptor(wei_e_k_block_desc);
constexpr auto b_e_n1bn2_block_mtx_desc =
make_ConstantMatrixDescriptor(in_e_n1_b_n2_block_desc.Unfold(I1, I3));
// sanity check
static_assert(KPerBlock % (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster) ==
0,
"wrong!");
constexpr index_t GemmMRepeat =
KPerBlock / (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster);
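// illustration with hypothetical values: KPerBlock = 128,
// GemmMPerThreadSubC = 4, GemmMLevel0Cluster = 4, GemmMLevel1Cluster = 4
// give GemmMRepeat = 128 / (4 * 4 * 4) = 2, i.e. each thread owns two
// GemmMPerThreadSubC-row sub-tiles of the C matrix along M (= K)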
// c_thread_mtx definition: this is a mess
// TODO: find a more elegant way of defining c_thread_mtx
constexpr auto c_k0k2_n1n2_thread_mtx_desc = make_ConstantMatrixDescriptor_packed(
Number<GemmMRepeat * GemmMPerThreadSubC>{}, Number<GemmNRepeat * GemmNPerThreadSubC>{});
const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2<
BlockSize,
decltype(a_e_k_block_mtx_desc),
decltype(b_e_n1bn2_block_mtx_desc),
decltype(c_k0k2_n1n2_thread_mtx_desc),
GemmMPerThreadSubC,
GemmNPerThreadSubC,
GemmMLevel0Cluster,
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmKPerThreadLoop,
GemmDataPerReadA,
GemmDataPerReadB>{};
// LDS allocation for input and weight: be careful of alignment
constexpr index_t max_align = math::lcm(InBlockCopyDstDataPerWrite_N2,
WeiBlockCopyDstDataPerWrite_K,
GemmDataPerReadA,
GemmDataPerReadB);
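// illustration with hypothetical values: InBlockCopyDstDataPerWrite_N2 = 4,
// WeiBlockCopyDstDataPerWrite_K = 1, GemmDataPerReadA = 4,
// GemmDataPerReadB = 2 give max_align = lcm(4, 1, 4, 2) = 4, so both LDS
// buffers are padded to a multiple of 4 elements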
constexpr index_t in_block_space =
math::integer_least_multiple(in_e_n1_b_n2_block_desc.GetElementSpace(), max_align);
constexpr index_t wei_block_space =
math::integer_least_multiple(wei_e_k_block_desc.GetElementSpace(), max_align);
__shared__ Float p_in_block[in_block_space];
__shared__ Float p_wei_block[wei_block_space];
// register allocation for output
Float p_out_thread[c_k0k2_n1n2_thread_mtx_desc.GetElementSpace()];
// zero out threadwise output
threadwise_matrix_set_zero(c_k0k2_n1n2_thread_mtx_desc, p_out_thread);
// do work
for(index_t e = 0; e < E; e += EPerBlock)
{
blockwise_in_copy.Run(p_in_global, p_in_block);
blockwise_wei_copy.Run(p_wei_global, p_wei_block);
__syncthreads();
blockwise_gemm.Run(p_wei_block, p_in_block, p_out_thread);
__syncthreads();
blockwise_in_copy.MoveSrcSliceWindow(make_multi_index(EPerBlock, 0, 0, 0), True);
blockwise_wei_copy.MoveSrcSliceWindow(make_multi_index(EPerBlock, 0), True);
}
// copy output: register to global memory
{
#if 0
constexpr index_t K2 = GemmMPerThreadSubC;
constexpr index_t K1 = GemmMLevel0Cluster * GemmMLevel1Cluster;
// define tensor descriptor for threadwise copy
// output memory layout descriptor in register
constexpr auto out_k0_k1_k2_n1_n0_h_w_n2_thread_mem_desc =
make_ConstantTensorDescriptor_packed(
Sequence<KPerBlock / (K1 * K2), 1, K2, N1, 1, 1, 1, N2>{});
// output tensor descriptor in register, src of threadwise copy
constexpr auto out_n0_n1_n2_k0_k1_k2_h_w_thread_desc =
out_k0_k1_k2_n1_n0_h_w_n2_thread_mem_desc.ReorderGivenNew2Old(
Sequence<4, 3, 7, 0, 1, 2, 5, 6>{});
// output memory layout descriptor in device memory, dst of threadwise copy
constexpr auto out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc =
out_n_k_h_w_global_desc.Fold(I1, Number<K1>{}, Number<K2>{})
.Fold(I0, Number<N1>{}, Number<N2>{});
// calculate origin of thread output tensor on global memory
// blockwise GEMM c matrix starting index
const auto c_thread_mtx_on_block =
blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
const index_t k_thread_data_on_global =
k_block_data_on_global + c_thread_mtx_on_block.row;
const index_t b_thread_data_on_global =
b_block_data_on_global + c_thread_mtx_on_block.col / N2;
// output merged global tensor descriptor, for calculating origin of thread tensor
// in global memory
constexpr auto out_k_n1_b_n2_global_merged_desc = make_ConstantMergedTensorDescriptor(
out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc.Unfold(I3, I5),
Sequence<3>{},
Sequence<1>{},
Sequence<0, 4, 5>{},
Sequence<2>{});
// origin of dst in device memory
Float* p_out_thread_on_global =
p_out_global +
out_k_n1_b_n2_global_merged_desc.GetOffsetFromMultiIndex(
k_thread_data_on_global, 0, b_thread_data_on_global, 0);
ThreadwiseGenericTensorSliceCopy_v2r1<
decltype(out_n0_n1_n2_k0_k1_k2_h_w_thread_desc),
decltype(out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc),
decltype(out_n0_n1_n2_k0_k1_k2_h_w_thread_desc.GetLengths()),
arithmetic_sequence_gen<0, 8, 1>::type,
arithmetic_sequence_gen<0, 8, 1>::type,
7,
7,
1,
1>({0, 0, 0, 0, 0, 0, 0, 0}, {0, 0, 0, 0, 0, 0, 0, 0})
.Run(p_out_thread, p_out_thread_on_global);
#else
constexpr index_t K1 = GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster;
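// the global k index of a thread's data is split as k = k0 * K1 + k1 to
// match the Fold(I1, Number<K1>{}) of the output descriptor below;
// e.g. with hypothetical values K1 = 64 and k_thread_data_on_global = 70,
// the copy origin below becomes k0 = 70 / 64 = 1, k1 = 70 % 64 = 6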
// define tensor descriptor for threadwise copy
// output memory layout descriptor in register, src of threadwise copy
constexpr auto out_k0_k1_n1_b_n2_thread_mem_desc = make_ConstantTensorDescriptor_packed(
Sequence<GemmMRepeat, GemmMPerThreadSubC, N1, 1, N2>{});
// output memory layout descriptor in device memory
constexpr auto out_n0_n1_n2_k0_k1_h_w_global_mem_desc =
out_n_k_h_w_global_desc.Fold(I1, Number<K1>{}).Fold(I0, Number<N1>{}, Number<N2>{});
// output merged global tensor descriptor, dst of threadwise copy
constexpr auto out_k0_k1_n1_b_n2_global_merged_desc =
make_ConstantMergedTensorDescriptor(out_n0_n1_n2_k0_k1_h_w_global_mem_desc,
Sequence<3>{},
Sequence<4>{},
Sequence<1>{},
Sequence<0, 5, 6>{},
Sequence<2>{});
// calculate origin of thread output tensor on global memory
// blockwise GEMM c matrix starting index
const auto c_thread_mtx_on_block =
blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
const index_t k_thread_data_on_global =
k_block_data_on_global + c_thread_mtx_on_block.row;
const index_t b_thread_data_on_global =
b_block_data_on_global + c_thread_mtx_on_block.col / N2;
ThreadwiseGenericTensorSliceCopy_v2r1<
decltype(out_k0_k1_n1_b_n2_thread_mem_desc),
decltype(out_k0_k1_n1_b_n2_global_merged_desc),
decltype(out_k0_k1_n1_b_n2_thread_mem_desc.GetLengths()),
arithmetic_sequence_gen<0, 5, 1>::type,
arithmetic_sequence_gen<0, 5, 1>::type,
3,
3,
1,
1>({0, 0, 0, 0, 0},
{k_thread_data_on_global / K1,
k_thread_data_on_global % K1,
0,
b_thread_data_on_global,
0})
.template Run_amd_experiment<Float, 0, 2>(p_out_thread, p_out_global);
#endif
}
}
};
} // namespace ck
#endif
#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R4_NCHW_KCYX_NKHW_HPP
#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R4_NCHW_KCYX_NKHW_HPP
#include "common_header.hpp"
#include "ConstantTensorDescriptor.hpp"
#include "ConstantMergedTensorDescriptor.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "blockwise_generic_tensor_slice_copy.hpp"
#include "blockwise_gemm.hpp"
#include "threadwise_generic_tensor_slice_copy.hpp"
namespace ck {
// B = merge(N, Ho, Wo)
template <index_t GridSize,
index_t BlockSize,
class Float,
class InGlobalDesc,
class WeiGlobalDesc,
class OutGlobalDesc,
class ConvStrides,
class ConvDilations,
index_t BPerBlock,
index_t KPerBlock,
index_t EPerBlock,
index_t GemmMPerThreadSubC,
index_t GemmNPerThreadSubC,
index_t GemmMLevel0Cluster,
index_t GemmNLevel0Cluster,
index_t GemmMLevel1Cluster,
index_t GemmNLevel1Cluster,
index_t GemmKPerThreadLoop,
index_t GemmDataPerReadA,
index_t GemmDataPerReadB,
class InBlockCopySubLengths_E_B,
class InBlockCopyClusterLengths_E_B,
class InBlockCopyThreadClusterArrangeOrder,
class InBlockCopySrcAccessOrder,
class InBlockCopyDstAccessOrder,
index_t InBlockCopyDataPerAccess_B,
class WeiBlockCopySubLengths_E_K,
class WeiBlockCopyClusterLengths_E_K,
class WeiBlockCopyThreadClusterArrangeOrder,
class WeiBlockCopySrcAccessOrder,
class WeiBlockCopyDstAccessOrder,
index_t WeiBlockCopySrcDataPerRead_E,
index_t WeiBlockCopyDstDataPerWrite_K,
index_t OutThreadCopyDataPerAccess_B>
struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw
{
__device__ void Run(const Float* const __restrict__ p_in_global,
const Float* const __restrict__ p_wei_global,
Float* const __restrict__ p_out_global) const
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto I5 = Number<5>{};
constexpr auto True = integral_constant<bool, true>{};
constexpr auto in_n_c_h_w_global_desc = InGlobalDesc{};
constexpr auto wei_k_c_y_x_global_desc = WeiGlobalDesc{};
constexpr auto out_n_k_h_w_global_desc = OutGlobalDesc{};
constexpr index_t N = in_n_c_h_w_global_desc.GetLengths()[0];
constexpr index_t C = in_n_c_h_w_global_desc.GetLengths()[1];
constexpr index_t K = out_n_k_h_w_global_desc.GetLengths()[1];
constexpr index_t Ho = out_n_k_h_w_global_desc.GetLengths()[2];
constexpr index_t Wo = out_n_k_h_w_global_desc.GetLengths()[3];
constexpr index_t Y = wei_k_c_y_x_global_desc.GetLengths()[2];
constexpr index_t X = wei_k_c_y_x_global_desc.GetLengths()[3];
constexpr index_t ConvStrideH = ConvStrides{}[0];
constexpr index_t ConvStrideW = ConvStrides{}[1];
constexpr index_t ConvDilationH = ConvDilations{}[0];
constexpr index_t ConvDilationW = ConvDilations{}[1];
constexpr index_t E = C * Y * X;
constexpr index_t B = N * Ho * Wo;
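// illustration with hypothetical sizes: N = 128, Ho = Wo = 14, C = 64 and
// Y = X = 3 give B = 128 * 14 * 14 = 25088 and E = 64 * 3 * 3 = 576; the
// implicit GEMM below is then trans([E, K]) * [E, B], accumulated over E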
// sanity-check for vectorized memory load
static_assert((Wo == 1 || (ConvStrideW == 1 || InBlockCopyDataPerAccess_B == 1)) &&
(X == 1 || ConvDilationW % InBlockCopyDataPerAccess_B == 0),
"wrong! aligment requirement for vectorized global load of input tensor will "
"be violated");
// divide block work by [K, B]
static_assert(K % KPerBlock == 0 && B % BPerBlock == 0 && E % EPerBlock == 0,
"wrong! cannot divide work evenly among block");
constexpr index_t KBlockWork = K / KPerBlock;
constexpr index_t BBlockWork = B / BPerBlock;
constexpr auto block_work_desc =
make_ConstantTensorDescriptor_packed(Sequence<KBlockWork, BBlockWork>{});
const auto block_work_multi_id =
block_work_desc.GetMultiIndexFrom1dIndex(get_block_1d_id());
const index_t k_block_data_on_global = block_work_multi_id[0] * KPerBlock;
const index_t b_block_data_on_global = block_work_multi_id[1] * BPerBlock;
// input tensor
// tensor descriptor in device memory [N, Ho, Wo]
constexpr auto in_n_ho_wo_global_desc =
in_n_c_h_w_global_desc.Extract(I0, I2, I3)
.StridedSlice(I1, Number<Ho>{}, Number<ConvStrideH>{})
.StridedSlice(I2, Number<Wo>{}, Number<ConvStrideW>{});
// batch descriptor for device memory
constexpr auto in_c_y_x_global_desc =
in_n_c_h_w_global_desc.StridedSlice(I2, Number<Y>{}, Number<ConvDilationH>{})
.StridedSlice(I3, Number<X>{}, Number<ConvDilationW>{})
.Extract(Sequence<1, 2, 3>{});
// merged tensor descriptor in device memory [E, B], src of blockwise copy
constexpr auto in_e_b_global_desc =
make_ConstantMergedTensorDescriptor(in_c_y_x_global_desc.Embed(in_n_ho_wo_global_desc),
Sequence<0, 1, 2>{},
Sequence<3, 4, 5>{});
// memory layout descriptor in LDS [E, B], dst of blockwise copy
// be careful of LDS alignment
constexpr auto in_e_b_block_desc =
make_ConstantTensorDescriptor_packed(Sequence<EPerBlock, BPerBlock>{});
// input blockwise copy
// slice a merged tensor, reorder and copy to a normal tensor
// this copy operator already has blockwise offset built-in
auto blockwise_in_copy =
BlockwiseGenericTensorSliceCopy_v2<BlockSize,
decltype(in_e_b_global_desc),
decltype(in_e_b_block_desc),
decltype(in_e_b_block_desc.GetLengths()),
InBlockCopySubLengths_E_B,
InBlockCopyClusterLengths_E_B,
InBlockCopyThreadClusterArrangeOrder,
InBlockCopySrcAccessOrder,
InBlockCopyDstAccessOrder,
1,
1,
InBlockCopyDataPerAccess_B,
InBlockCopyDataPerAccess_B>(
{0, b_block_data_on_global}, {0, 0});
// weight tensor
// tensor descriptor in device memory, src of blockwise copy
constexpr auto wei_e_k_global_desc =
wei_k_c_y_x_global_desc.Unfold(I1, I3).ReorderGivenNew2Old(Sequence<1, 0>{});
// tensor descriptor in LDS, dst of blockwise copy
// be careful of LDS alignment
constexpr auto wei_e_k_block_desc = make_ConstantTensorDescriptor_aligned(
Sequence<EPerBlock, KPerBlock>{},
Number<math::lcm(WeiBlockCopyDstDataPerWrite_K, GemmDataPerReadA)>{});
// this check is ad-hoc
// TODO: need to properly implement tensor descriptor with multiple alignment
// requirements
static_assert(wei_e_k_block_desc.GetStride(I0) % GemmDataPerReadA == 0,
"GemmDataPerReadA alignment requirement is not satisfied");
// operator for blockwise copy of weight into LDS
// slice a tensor, and copy it into another tensor
// this copy operator already has blockwise offset built-in
auto blockwise_wei_copy =
BlockwiseGenericTensorSliceCopy_v2<BlockSize,
decltype(wei_e_k_global_desc),
decltype(wei_e_k_block_desc),
decltype(wei_e_k_block_desc.GetLengths()),
WeiBlockCopySubLengths_E_K,
WeiBlockCopyClusterLengths_E_K,
WeiBlockCopyThreadClusterArrangeOrder,
WeiBlockCopySrcAccessOrder,
WeiBlockCopyDstAccessOrder,
0,
1,
WeiBlockCopySrcDataPerRead_E,
WeiBlockCopyDstDataPerWrite_K>(
{0, k_block_data_on_global}, {0, 0});
// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
// a_mtx[EPerBlock, KPerBlock] is in LDS
// b_mtx[EPerBlock, BPerBlock] is in LDS
// c_mtx[KPerBlock, BPerBlock] is distributed among threads, and saved in
// register
constexpr auto a_e_k_block_mtx_desc = make_ConstantMatrixDescriptor(wei_e_k_block_desc);
constexpr auto b_e_b_block_mtx_desc = make_ConstantMatrixDescriptor(in_e_b_block_desc);
// sanity check
static_assert(
KPerBlock % (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster) == 0 &&
BPerBlock % (GemmNPerThreadSubC * GemmNLevel0Cluster * GemmNLevel1Cluster) == 0,
"wrong!");
constexpr index_t GemmMRepeat =
KPerBlock / (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster);
constexpr index_t GemmNRepeat =
BPerBlock / (GemmNPerThreadSubC * GemmNLevel0Cluster * GemmNLevel1Cluster);
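// illustration with hypothetical values: KPerBlock = BPerBlock = 128 and
// 4 x 4 x 4 sub-tile/cluster factors on both M and N give
// GemmMRepeat = GemmNRepeat = 128 / (4 * 4 * 4) = 2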
// c_thread_mtx definition: this is a mess
// TODO: find a more elegant way of defining c_thread_mtx
constexpr auto c_k0k1_b0b1_thread_mtx_desc = make_ConstantMatrixDescriptor_packed(
Number<GemmMRepeat * GemmMPerThreadSubC>{}, Number<GemmNRepeat * GemmNPerThreadSubC>{});
const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2<
BlockSize,
decltype(a_e_k_block_mtx_desc),
decltype(b_e_b_block_mtx_desc),
decltype(c_k0k1_b0b1_thread_mtx_desc),
GemmMPerThreadSubC,
GemmNPerThreadSubC,
GemmMLevel0Cluster,
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmKPerThreadLoop,
GemmDataPerReadA,
GemmDataPerReadB>{};
// LDS allocation for input and weight: be careful of alignment
constexpr index_t max_align = math::lcm(InBlockCopyDataPerAccess_B,
WeiBlockCopyDstDataPerWrite_K,
GemmDataPerReadA,
GemmDataPerReadB);
constexpr index_t in_block_space =
math::integer_least_multiple(in_e_b_block_desc.GetElementSpace(), max_align);
constexpr index_t wei_block_space =
math::integer_least_multiple(wei_e_k_block_desc.GetElementSpace(), max_align);
__shared__ Float p_in_block[in_block_space];
__shared__ Float p_wei_block[wei_block_space];
// register allocation for output
Float p_out_thread[c_k0k1_b0b1_thread_mtx_desc.GetElementSpace()];
// zero out threadwise output
threadwise_matrix_set_zero(c_k0k1_b0b1_thread_mtx_desc, p_out_thread);
for(index_t e_block_data_begin = 0; e_block_data_begin < E; e_block_data_begin += EPerBlock)
{
blockwise_in_copy.Run(p_in_global, p_in_block);
blockwise_wei_copy.Run(p_wei_global, p_wei_block);
__syncthreads();
blockwise_gemm.Run(p_wei_block, p_in_block, p_out_thread);
__syncthreads();
blockwise_in_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0>{}, True);
blockwise_wei_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0>{}, True);
}
// copy output: register to global memory
{
constexpr index_t K1 = GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster;
constexpr index_t B1 = GemmNPerThreadSubC * GemmNLevel0Cluster * GemmNLevel1Cluster;
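// illustration with hypothetical values: GemmMPerThreadSubC = 4,
// GemmMLevel0Cluster = 4, GemmMLevel1Cluster = 4 give K1 = 64, and the
// corresponding N-side factors give B1; consecutive sub-tiles owned by one
// thread along N are B1 apart in global memory, which is why the loop below
// advances the dst window by B1 but the src window only by GemmNPerThreadSubC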
// define tensor descriptor for threadwise copy
// output global descriptor, for calculating origin of thread tensor
// in global memory
constexpr auto out_k_b_global_desc = make_ConstantMergedTensorDescriptor(
out_n_k_h_w_global_desc, Sequence<1>{}, Sequence<0, 2, 3>{});
// calculate origin of thread output tensor on global memory
// blockwise GEMM c matrix starting index
const auto c_thread_mtx_on_block =
blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
const index_t k_thread_data_on_global =
k_block_data_on_global + c_thread_mtx_on_block.row;
const index_t b_thread_data_on_global =
b_block_data_on_global + c_thread_mtx_on_block.col;
// This is a hack, because slicing a merged dimension is not supported yet.
// This should be replaced with the logic above, once support for slicing a
// merged dimension becomes available
// dst descriptor
constexpr auto out_k0_k1_b_global_desc =
make_ConstantMergedTensorDescriptor(out_n_k_h_w_global_desc.Fold(I1, Number<K1>{}),
Sequence<1>{},
Sequence<2>{},
Sequence<0, 3, 4>{});
// src descriptor
constexpr auto out_k0_k1_b_thread_desc = make_ConstantTensorDescriptor_packed(
Sequence<GemmMRepeat, GemmMPerThreadSubC, GemmNRepeat * GemmNPerThreadSubC>{});
using OutThreadCopySliceLengths =
Sequence<GemmMRepeat, GemmMPerThreadSubC, GemmNPerThreadSubC>;
auto threadwise_out_copy =
ThreadwiseGenericTensorSliceCopy_v2r1<decltype(out_k0_k1_b_thread_desc),
decltype(out_k0_k1_b_global_desc),
OutThreadCopySliceLengths,
arithmetic_sequence_gen<0, 3, 1>::type,
arithmetic_sequence_gen<0, 3, 1>::type,
2,
2,
OutThreadCopyDataPerAccess_B,
OutThreadCopyDataPerAccess_B>(
{0, 0, 0},
{k_thread_data_on_global / K1,
k_thread_data_on_global % K1,
b_thread_data_on_global});
for(index_t nrepeat = 0; nrepeat < GemmNRepeat; ++nrepeat)
{
threadwise_out_copy.Run(p_out_thread, p_out_global);
threadwise_out_copy.MoveSrcSliceWindow(Sequence<0, 0, GemmNPerThreadSubC>{}, True);
threadwise_out_copy.MoveDstSliceWindow(Sequence<0, 0, B1>{}, True);
}
}
}
};
} // namespace ck
#endif
#pragma once
#include "common_header.hpp"
#include "ConstantTensorDescriptor.hpp"
#include "blockwise_2d_tensor_op.hpp"
#include "blockwise_4d_tensor_op.hpp"
#include "blockwise_direct_convolution.hpp"
#include "threadwise_4d_tensor_op.hpp"
#include "threadwise_direct_convolution.hpp"
namespace ck {
template <class TInWei,
class TOut,
class TAccum,
class InGlobalDesc,
class WeiGlobalDesc,
class OutGlobalDesc,
index_t ScalarPerVector,
index_t NPerBlock,
index_t KPerBlock,
index_t CPerBlock,
index_t HoPerBlock,
index_t WoPerBlock,
index_t NPerThread,
index_t KPerThread,
index_t CPerThread,
index_t HoPerThread,
index_t WoPerThread,
index_t InBlockCopyDataPerRead,
index_t WeiBlockCopyDataPerRead,
index_t BlockSize,
index_t GridSize>
__global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw(
const typename vector_type<TInWei,
ScalarPerVector>::MemoryType* const __restrict__ p_in_vec_global,
const typename vector_type<TInWei,
ScalarPerVector>::MemoryType* const __restrict__ p_wei_vec_global,
TOut* const __restrict__ p_out_global)
{
using in_scalar_t = TInWei;
using in_vector_mem_t = typename vector_type<in_scalar_t, ScalarPerVector>::MemoryType;
using out_scalar_t = TOut;
using accum_t = TAccum;
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto in_nchw_vec_global_desc = InGlobalDesc{};
constexpr auto wei_kcyx_vec_global_desc = WeiGlobalDesc{};
constexpr auto out_nkhw_global_desc = OutGlobalDesc{};
constexpr index_t N = in_nchw_vec_global_desc.GetLength(I0);
constexpr index_t K = wei_kcyx_vec_global_desc.GetLength(I0);
constexpr index_t C = wei_kcyx_vec_global_desc.GetLength(I1);
constexpr index_t Y = wei_kcyx_vec_global_desc.GetLength(I2);
constexpr index_t X = wei_kcyx_vec_global_desc.GetLength(I3);
constexpr auto wei_ke_vec_global_desc = make_ConstantTensorDescriptor(
Sequence<K, C * Y * X>{}); // 2d view of wei for blockwise copy
constexpr index_t HiPerBlock = HoPerBlock + Y - 1;
constexpr index_t WiPerBlock = WoPerBlock + X - 1;
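// a block computing a HoPerBlock x WoPerBlock output tile needs the
// overlapping HiPerBlock x WiPerBlock input window; e.g. a hypothetical
// 3x3 filter (Y = X = 3) with HoPerBlock = 16 needs
// HiPerBlock = 16 + 3 - 1 = 18 input rows (unit stride, no dilation)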
constexpr auto in_nchw_vec_block_desc = make_ConstantTensorDescriptor_aligned(
Sequence<NPerBlock, CPerBlock, HiPerBlock, WiPerBlock>{}, Number<InBlockCopyDataPerRead>{});
constexpr auto wei_ke_vec_block_desc = make_ConstantTensorDescriptor_aligned(
Sequence<KPerBlock, CPerBlock * Y * X>{},
Number<WeiBlockCopyDataPerRead>{}); // 2d view of wei for blockwise copy
constexpr auto wei_kcyx_vec_block_desc =
make_ConstantTensorDescriptor(Sequence<KPerBlock, CPerBlock, Y, X>{},
Sequence<wei_ke_vec_block_desc.GetStride(I0), Y * X, X, 1>{});
// shared mem
constexpr index_t in_block_element_size =
in_nchw_vec_block_desc.GetElementSpace(Number<InBlockCopyDataPerRead>{});
constexpr index_t wei_block_element_size =
wei_kcyx_vec_block_desc.GetElementSpace(Number<WeiBlockCopyDataPerRead>{});
constexpr index_t max_align = InBlockCopyDataPerRead > WeiBlockCopyDataPerRead
? InBlockCopyDataPerRead
: WeiBlockCopyDataPerRead;
__shared__ in_vector_mem_t
p_in_vec_block[max_align * ((in_block_element_size + max_align - 1) / max_align)];
__shared__ in_vector_mem_t
p_wei_vec_block[max_align * ((wei_block_element_size + max_align - 1) / max_align)];
// threadwise tensors
constexpr index_t HiPerThread = HoPerThread + Y - 1;
constexpr index_t WiPerThread = WoPerThread + X - 1;
constexpr auto in_nchw_vec_thread_block_desc =
make_ConstantTensorDescriptor(Sequence<NPerThread, CPerThread, HiPerThread, WiPerThread>{},
in_nchw_vec_block_desc.GetStrides());
constexpr auto wei_kcyx_vec_thread_block_desc = make_ConstantTensorDescriptor(
Sequence<KPerThread, CPerThread, Y, X>{}, wei_kcyx_vec_block_desc.GetStrides());
constexpr auto out_nkhw_thread_desc = get_convolution_output_default_4d_tensor_descriptor(
in_nchw_vec_thread_block_desc, wei_kcyx_vec_thread_block_desc);
// register
out_scalar_t p_out_thread[out_nkhw_thread_desc.GetElementSpace()];
// divide block work
constexpr index_t NBlockWork = (out_nkhw_global_desc.GetLength(I0) + NPerBlock - 1) / NPerBlock;
constexpr index_t KBlockWork = (out_nkhw_global_desc.GetLength(I1) + KPerBlock - 1) / KPerBlock;
constexpr index_t HBlockWork =
(out_nkhw_global_desc.GetLength(I2) + HoPerBlock - 1) / HoPerBlock;
constexpr index_t WBlockWork =
(out_nkhw_global_desc.GetLength(I3) + WoPerBlock - 1) / WoPerBlock;
const index_t block_id = blockIdx.x;
index_t itmp = block_id;
const index_t n_block_work_id = itmp / (KBlockWork * HBlockWork * WBlockWork);
itmp -= n_block_work_id * (KBlockWork * HBlockWork * WBlockWork);
const index_t k_block_work_id = itmp / (HBlockWork * WBlockWork);
itmp -= k_block_work_id * (HBlockWork * WBlockWork);
const index_t h_block_work_id = itmp / WBlockWork;
const index_t w_block_work_id = itmp - h_block_work_id * WBlockWork;
const index_t n_block_data_begin = n_block_work_id * NPerBlock;
const index_t k_block_data_begin = k_block_work_id * KPerBlock;
const index_t ho_block_data_begin = h_block_work_id * HoPerBlock;
const index_t wo_block_data_begin = w_block_work_id * WoPerBlock;
const index_t hi_block_data_begin = ho_block_data_begin; // minus padding
const index_t wi_block_data_begin = wo_block_data_begin; // minus padding
// divide thread work
constexpr index_t NThreadWork = (NPerBlock + NPerThread - 1) / NPerThread;
constexpr index_t KThreadWork = (KPerBlock + KPerThread - 1) / KPerThread;
constexpr index_t HThreadWork = (HoPerBlock + HoPerThread - 1) / HoPerThread;
constexpr index_t WThreadWork = (WoPerBlock + WoPerThread - 1) / WoPerThread;
const index_t thread_id = get_thread_local_1d_id();
itmp = thread_id;
const index_t n_thread_work_id = itmp / (KThreadWork * HThreadWork * WThreadWork);
itmp -= n_thread_work_id * (KThreadWork * HThreadWork * WThreadWork);
const index_t k_thread_work_id = itmp / (HThreadWork * WThreadWork);
itmp -= k_thread_work_id * (HThreadWork * WThreadWork);
const index_t h_thread_work_id = itmp / WThreadWork;
const index_t w_thread_work_id = itmp - h_thread_work_id * WThreadWork;
const index_t n_thread_data_begin = n_thread_work_id * NPerThread;
const index_t k_thread_data_begin = k_thread_work_id * KPerThread;
const index_t ho_thread_data_begin = h_thread_work_id * HoPerThread;
const index_t wo_thread_data_begin = w_thread_work_id * WoPerThread;
const index_t hi_thread_data_begin = ho_thread_data_begin;
const index_t wi_thread_data_begin = wo_thread_data_begin;
constexpr auto blockwise_in_copy =
Blockwise4dTensorCopy1<BlockSize,
in_vector_mem_t,
decltype(in_nchw_vec_global_desc),
decltype(in_nchw_vec_block_desc),
decltype(in_nchw_vec_block_desc.GetLengths()),
InBlockCopyDataPerRead>{};
#if 0
constexpr auto blockwise_wei_copy =
Blockwise4dTensorCopy1<BlockSize,
in_vector_mem_t,
decltype(wei_kcyx_vec_global_desc),
decltype(wei_kcyx_vec_block_desc),
decltype(wei_kcyx_vec_block_desc.GetLengths()),
1>{};
#elif 1
const auto blockwise_wei_copy =
Blockwise2dTensorCopy3<BlockSize,
in_vector_mem_t,
decltype(wei_ke_vec_global_desc),
decltype(wei_ke_vec_block_desc),
decltype(wei_ke_vec_block_desc.GetLengths()),
WeiBlockCopyDataPerRead>{};
#endif
#if 1 // debug
// set threadwise output tensor to 0
threadwise_4d_tensor_set_zero(out_nkhw_thread_desc, p_out_thread);
#endif
for(index_t c_block_data_begin = 0; c_block_data_begin < C;
c_block_data_begin += CPerBlock, __syncthreads())
{
// copy input tensor to LDS
blockwise_in_copy.Run(
p_in_vec_global +
in_nchw_vec_global_desc.GetOffsetFromMultiIndex(n_block_data_begin,
c_block_data_begin,
hi_block_data_begin,
wi_block_data_begin),
p_in_vec_block);
// copy weight tensor to LDS
blockwise_wei_copy.Run(p_wei_vec_global +
wei_kcyx_vec_global_desc.GetOffsetFromMultiIndex(
k_block_data_begin, c_block_data_begin, 0, 0),
p_wei_vec_block);
__syncthreads();
for(index_t c_thread_data = 0; c_thread_data < CPerBlock; c_thread_data += CPerThread)
{
// threadwise convolution
#if 1
threadwise_direct_convolution_2(
in_nchw_vec_thread_block_desc,
p_in_vec_block +
in_nchw_vec_block_desc.GetOffsetFromMultiIndex(n_thread_data_begin,
c_thread_data,
hi_thread_data_begin,
wi_thread_data_begin),
wei_kcyx_vec_thread_block_desc,
p_wei_vec_block +
wei_kcyx_vec_block_desc.GetOffsetFromMultiIndex(
k_thread_data_begin, c_thread_data, 0, 0),
out_nkhw_thread_desc,
p_out_thread);
#elif 0
threadwise_direct_convolution_3(
in_nchw_vec_thread_block_desc,
p_in_vec_block +
in_nchw_vec_block_desc.GetOffsetFromMultiIndex(n_thread_data_begin,
c_thread_data,
hi_thread_data_begin,
wi_thread_data_begin),
wei_kcyx_vec_thread_block_desc,
p_wei_vec_block +
wei_kcyx_vec_block_desc.GetOffsetFromMultiIndex(
k_thread_data_begin, c_thread_data, 0, 0),
out_nkhw_thread_desc,
p_out_thread);
#endif
}
}
// copy output tensor from register to global mem
threadwise_4d_tensor_copy(out_nkhw_thread_desc,
p_out_thread,
out_nkhw_global_desc,
p_out_global +
out_nkhw_global_desc.GetOffsetFromMultiIndex(
n_block_data_begin + n_thread_data_begin,
k_block_data_begin + k_thread_data_begin,
ho_block_data_begin + ho_thread_data_begin,
wo_block_data_begin + wo_thread_data_begin),
out_nkhw_thread_desc.GetLengths());
}
} // namespace ck
#pragma once
#include "common_header.hpp"
#include "ConstantTensorDescriptor.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "blockwise_4d_tensor_op.hpp"
#include "blockwise_2d_tensor_op.hpp"
#include "threadwise_4d_tensor_op.hpp"
#include "blockwise_gemm.hpp"
namespace ck {
template <index_t GridSize,
index_t BlockSize,
class Float,
class InGlobalDesc,
class WeiGlobalDesc,
class OutGlobalDesc,
class LowerPads,
class UpperPads,
index_t NPerBlock,
index_t KPerBlock,
index_t CPerBlock,
index_t HoPerBlock,
index_t WoPerBlock,
index_t NPerThread,
index_t KPerThread,
index_t CPerThread,
index_t HoPerThread,
index_t WoPerThread,
index_t WeiBlockCopyThreadPerDim0,
index_t WeiBlockCopyThreadPerDim1>
__global__ void gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded(
const Float* const __restrict__ p_in_global,
const Float* const __restrict__ p_wei_global,
Float* const __restrict__ p_out_global)
{
// NPerThread == NPerBlock, because the format of input in LDS is [C,Hi,Wi,N];
// for the GEMM trans([C,K]) * [C,Wo*N], one thread needs to cover all of "N"
// if we used [C,Hi,N,Wi,N] in LDS, then NPerThread could differ from NPerBlock
static_assert(NPerBlock % NPerThread == 0, "wrong! NPerBlock % NPerThread !=0");
static_assert((NPerThread < NPerBlock && WoPerThread == 1) || NPerThread == NPerBlock,
"wrong!");
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto in_chwn_global_desc = InGlobalDesc{};
constexpr auto wei_cyxk_global_desc = WeiGlobalDesc{};
constexpr auto out_khwn_global_desc = OutGlobalDesc{};
constexpr index_t C = in_chwn_global_desc.GetLength(I0);
constexpr index_t K = out_khwn_global_desc.GetLength(I0);
constexpr index_t Ho = out_khwn_global_desc.GetLength(I1);
constexpr index_t Wo = out_khwn_global_desc.GetLength(I2);
constexpr index_t N = out_khwn_global_desc.GetLength(I3);
constexpr index_t Y = wei_cyxk_global_desc.GetLength(I1);
constexpr index_t X = wei_cyxk_global_desc.GetLength(I2);
constexpr index_t HPadLow = LowerPads{}.Get(I0);
constexpr index_t WPadLow = LowerPads{}.Get(I1);
constexpr index_t HPadUp = UpperPads{}.Get(I0);
constexpr index_t WPadUp = UpperPads{}.Get(I1);
constexpr index_t HiPerBlock = HoPerBlock + Y - 1;
constexpr index_t WiPerBlock = WoPerBlock + X - 1;
// divide block work: [K, Ho, Wo, N]
constexpr index_t KBlockWork = (K + KPerBlock - 1) / KPerBlock;
constexpr index_t HBlockWork = (Ho + HoPerBlock - 1) / HoPerBlock;
constexpr index_t WBlockWork = (Wo + WoPerBlock - 1) / WoPerBlock;
constexpr index_t NBlockWork = (N + NPerBlock - 1) / NPerBlock;
const index_t k_block_work_id = get_block_1d_id() / (HBlockWork * WBlockWork * NBlockWork);
index_t itmp = get_block_1d_id() - k_block_work_id * (HBlockWork * WBlockWork * NBlockWork);
const index_t h_block_work_id = itmp / (WBlockWork * NBlockWork);
itmp -= h_block_work_id * (WBlockWork * NBlockWork);
const index_t w_block_work_id = itmp / NBlockWork;
const index_t n_block_work_id = itmp - w_block_work_id * NBlockWork;
const index_t k_block_data_begin = k_block_work_id * KPerBlock;
const index_t ho_block_data_begin = h_block_work_id * HoPerBlock;
const index_t wo_block_data_begin = w_block_work_id * WoPerBlock;
const index_t n_block_data_begin = n_block_work_id * NPerBlock;
// flattened (2d) tensor view of wei in global mem
constexpr auto wei_ek_global_desc = make_ConstantTensorDescriptor(Sequence<C * Y * X, K>{});
// tensor view of blockwise input and weight in LDS
constexpr auto in_chwn_block_desc =
make_ConstantTensorDescriptor(Sequence<CPerBlock, HiPerBlock, WiPerBlock, NPerBlock>{});
constexpr auto wei_cyxk_block_desc =
make_ConstantTensorDescriptor(Sequence<CPerBlock, Y, X, KPerBlock>{});
// flattened (2d) tensor view of wei in LDS
constexpr auto wei_ek_block_desc =
make_ConstantTensorDescriptor(Sequence<CPerBlock * Y * X, KPerBlock>{});
// tensor view of threadwise output in register
constexpr auto out_hkwn_thread_desc =
make_ConstantTensorDescriptor(Sequence<HoPerThread, KPerThread, WoPerThread, NPerThread>{});
#if 0
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
{
print_ConstantTensorDescriptor(in_chwn_block_desc, "in_chwn_block_desc");
print_ConstantTensorDescriptor(wei_cyxk_block_desc, "wei_cyxk_block_desc");
print_ConstantTensorDescriptor(out_hkwn_thread_desc, "out_hkwn_thread_desc");
}
#endif
// blockwise copy
// input: format is [C, Hi, Wi, N]
const index_t h_block_pad_low = h_block_work_id == 0 ? HPadLow : 0;
const index_t w_block_pad_low = w_block_work_id == 0 ? WPadLow : 0;
const index_t h_block_pad_up = h_block_work_id == HBlockWork - 1 ? HPadUp : 0;
const index_t w_block_pad_up = w_block_work_id == WBlockWork - 1 ? WPadUp : 0;
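// illustration with hypothetical values: HBlockWork = 4 and
// HPadLow = HPadUp = 1 mean only the h_block_work_id == 0 block applies the
// low pad and only the h_block_work_id == 3 block applies the high pad;
// interior blocks read the input unpadded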
#if 0
if(get_thread_local_1d_id() == 0)
{
printf(
"%u %u, h_block_pad_low %u w_block_pad_low %u h_block_pad_up %u w_block_pad_up %u\n",
get_block_1d_id(),
get_thread_local_1d_id(),
h_block_pad_low,
w_block_pad_low,
h_block_pad_up,
w_block_pad_up);
}
#endif
constexpr auto blockwise_in_copy =
BlockwiseChwnTensorCopyPadded<BlockSize,
Float,
decltype(in_chwn_global_desc),
decltype(in_chwn_block_desc),
decltype(in_chwn_block_desc.GetLengths()),
LowerPads>{};
#if 0
// weight: format is [C,Y,X,K]
constexpr auto blockwise_wei_copy =
Blockwise4dTensorCopy1<BlockSize,
Float,
decltype(wei_cyxk_global_desc),
decltype(wei_cyxk_block_desc),
decltype(wei_cyxk_block_desc.GetLengths())>{};
#elif 0
// weight: format is [C*Y*X,K]
constexpr auto blockwise_wei_copy =
Blockwise2dTensorCopy1<BlockSize,
Float,
decltype(wei_ek_global_desc),
decltype(wei_ek_block_desc),
decltype(wei_ek_block_desc.GetLengths())>{};
#elif 1
// weight: format is [C*Y*X,K]
const auto blockwise_wei_copy = Blockwise2dTensorCopy2<BlockSize,
Float,
decltype(wei_ek_global_desc),
decltype(wei_ek_block_desc),
decltype(wei_ek_block_desc.GetLengths()),
WeiBlockCopyThreadPerDim0,
WeiBlockCopyThreadPerDim1>{};
#endif
// a series of blockwise batched GEMM
// C_matrix += transpose(A_matrix) * B_matrix
// A_matrix and B_matrix saved in LDS, C_matrix saved in register
// A_matrix[C,K] is a sub-matrix of wei_block[C,Y,X,K]
// B_matrix[C,Wo*N] is a sub-matrix of in_block[C,Hi,Wi,N]
// C_matrix[K,Wo*N] is a sub-matrix of out_block[Ho,K,Wo,N]
constexpr auto a_cxk_block_mtx_desc = make_ConstantMatrixDescriptor(
Number<CPerBlock>{}, Number<KPerBlock>{}, Number<wei_cyxk_block_desc.GetStride(I0)>{});
constexpr auto b_cxwn_block_mtx_desc =
make_ConstantMatrixDescriptor(Number<CPerBlock>{},
Number<WoPerBlock * NPerBlock>{},
Number<in_chwn_block_desc.GetStride(I0)>{});
constexpr auto c_kxwn_thread_mtx_desc =
make_ConstantMatrixDescriptor(Number<KPerThread>{}, Number<WoPerThread * NPerThread>{});
const auto blockwise_batch_gemm =
Blockwise1dStridedBatchedGemmBlockABlockBThreadC<BlockSize,
decltype(a_cxk_block_mtx_desc),
decltype(b_cxwn_block_mtx_desc),
decltype(c_kxwn_thread_mtx_desc),
true,
false,
false,
0,
in_chwn_block_desc.GetStride(I1),
out_hkwn_thread_desc.GetStride(I0),
HoPerBlock,
HoPerThread,
CPerThread,
true>{};
// LDS
constexpr index_t in_block_element_size = in_chwn_block_desc.GetElementSpace();
constexpr index_t wei_block_element_size = wei_cyxk_block_desc.GetElementSpace();
__shared__ Float p_in_block[in_block_element_size];
__shared__ Float p_wei_block[wei_block_element_size];
// register
Float p_out_thread[out_hkwn_thread_desc.GetElementSpace()];
// set threadwise output tensor to 0
threadwise_4d_tensor_set_zero(out_hkwn_thread_desc, p_out_thread);
const Float* p_wei_global_block_begin =
p_wei_global + wei_ek_global_desc.GetOffsetFromMultiIndex(0, k_block_data_begin);
for(index_t c_block_data_begin = 0; c_block_data_begin < C; c_block_data_begin += CPerBlock,
p_wei_global_block_begin += CPerBlock * wei_ek_global_desc.GetStride(I0),
__syncthreads())
{
#if 1
// input: global mem to LDS,
blockwise_in_copy.Run(p_in_global,
c_block_data_begin,
ho_block_data_begin,
wo_block_data_begin,
n_block_data_begin,
p_in_block,
h_block_pad_low,
w_block_pad_low,
h_block_pad_up,
w_block_pad_up);
#endif
#if 1
// weight: global mem to LDS,
blockwise_wei_copy.Run(p_wei_global_block_begin, p_wei_block);
#endif
__syncthreads();
// a series of batched GEMM
for(index_t y = 0; y < Y; ++y)
{
for(index_t x = 0; x < X; ++x)
{
auto f_accum = [](auto& acc, const auto& v) { acc += v; };
blockwise_batch_gemm.Run(
p_wei_block + wei_cyxk_block_desc.GetOffsetFromMultiIndex(0, y, x, 0),
p_in_block + in_chwn_block_desc.GetOffsetFromMultiIndex(0, y, x, 0),
p_out_thread,
f_accum);
}
}
}
const auto matrix_c_index =
blockwise_batch_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
const index_t ho_thread_data_begin = matrix_c_index.batch;
const index_t k_thread_data_begin = matrix_c_index.row;
const index_t wo_thread_data_begin = matrix_c_index.col / NPerBlock;
const index_t n_thread_data_begin = matrix_c_index.col - wo_thread_data_begin * NPerBlock;
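// illustration with hypothetical values: NPerBlock = 32 and
// matrix_c_index.col = 70 give wo_thread_data_begin = 70 / 32 = 2 and
// n_thread_data_begin = 70 - 2 * 32 = 6, since the GEMM N dimension is the
// flattened [Wo, N] pair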
#if 0
printf("block %u %u, %u %u %u %u, %u %u %u %u, %f \n",
get_block_1d_id(), get_thread_local_1d_id(),
ho_block_data_begin, k_block_data_begin, wo_block_data_begin, n_block_data_begin,
ho_thread_data_begin, k_thread_data_begin, wo_thread_data_begin, n_thread_data_begin,
p_out_thread[0]);
#endif
// output: register to global mem,
// convert out_thread[Ho,K,Wo,N] to out_global[K,Ho,Wo,N]
constexpr auto reorder_khwn_from_hkwn = Sequence<1, 0, 2, 3>{};
threadwise_4d_tensor_copy_reorder_by_get_dst_from_src(
out_hkwn_thread_desc,
p_out_thread,
out_khwn_global_desc,
p_out_global +
out_khwn_global_desc.GetOffsetFromMultiIndex(k_block_data_begin + k_thread_data_begin,
ho_block_data_begin + ho_thread_data_begin,
wo_block_data_begin + wo_thread_data_begin,
n_block_data_begin + n_thread_data_begin),
out_hkwn_thread_desc.GetLengths(),
reorder_khwn_from_hkwn);
}
} // namespace ck
@@ -5,12 +5,6 @@
namespace ck {
template <index_t Length>
struct Dimension
{
__host__ __device__ static constexpr auto GetLength() { return Number<Length>{}; }
};
template <index_t Length, index_t Stride>
struct NativeDimension
{
#ifndef CK_TENSOR_VIEW_HPP
#define CK_TENSOR_VIEW_HPP
#include "common_header.hpp"
#include "ConstantTensorDescriptor.hpp"
#include "ConstantMergedTensorDescriptor.hpp"
#include "tensor_coordinate_deprecated.hpp"
namespace ck {
// TensorDesc is ConstantTensorDescriptor or ConstantMergedTensorDescriptor
template <class TensorDesc, class TData>
struct NormalTensorView
{
using type = NormalTensorView;
using tensor_desc_type = TensorDesc;
using coordinate_type = typename NormalTensorCoordinate_deprecated<TensorDesc>::type;
using data_type = TData;
static constexpr auto nDim = TensorDesc::GetNumOfDimension();
__host__ __device__ constexpr NormalTensorView(TData* p_data) : mpData{p_data} {}
__host__ __device__ constexpr NormalTensorView() : NormalTensorView{nullptr} {}
__host__ __device__ static constexpr auto GetNumOfDimension() { return nDim; }
__host__ __device__ static constexpr auto GetLengths() { return TensorDesc::GetLengths(); }
__host__ __device__ const TData& operator[](coordinate_type coord) const
{
return mpData[coord.GetOffset()];
}
__host__ __device__ TData& operator()(coordinate_type coord) const
{
return mpData[coord.GetOffset()];
}
template <class IDim, class DataPerVector>
__host__ __device__ static constexpr auto IsVectorizationAllowed(IDim, DataPerVector)
{
return TensorDesc::IsVectorizationAllowed(IDim{}, DataPerVector{});
}
template <class IDim, class DataPerVector>
__host__ __device__ auto Vectorize(IDim idim, DataPerVector data_per_vector) const
{
static_assert(IsVectorizationAllowed(idim, data_per_vector), "wrong!");
using vector_t = typename vector_type<TData, data_per_vector>::MemoryType;
return NormalTensorView<decltype(TensorDesc::Vectorize(idim, data_per_vector)), vector_t>(
reinterpret_cast<vector_t*>(mpData));
}
template <index_t... Is>
__host__ __device__ auto Slice(coordinate_type slice_origin, Sequence<Is...> slice_lengths)
{
static_assert(slice_lengths.GetSize() == nDim, "wrong!");
return NormalTensorView<decltype(TensorDesc::Slice(slice_lengths)), TData>(
mpData + slice_origin.GetOffset());
}
template <class IDim, class SliceLen>
__host__ __device__ auto
Slice(coordinate_type slice_origin, IDim idim, SliceLen slice_len) const
{
return NormalTensorView<decltype(TensorDesc::Slice(idim, slice_len)), TData>(
mpData + slice_origin.GetOffset());
}
// slice_window is a slicing window on "*this"
template <class SliceWindow, class T, bool PositiveDirection>
__device__ void MoveSliceWindow(SliceWindow& slice_window,
T step_sizes,
integral_constant<bool, PositiveDirection>)
{
if(PositiveDirection)
{
slice_window.mpData += coordinate_type{step_sizes}.GetOffset();
}
else
{
slice_window.mpData -= coordinate_type{step_sizes}.GetOffset();
}
}
// private:
data_type* mpData;
};
template <class... Xs, class TData>
__host__ __device__ constexpr auto make_TensorView(ConstantTensorDescriptor<Xs...>, TData* p_data)
{
return NormalTensorView<ConstantTensorDescriptor<Xs...>, TData>{p_data};
}
} // namespace ck
#endif
#ifndef CK_TENSOR_VISIT_HPP
#define CK_TENSOR_VISIT_HPP
#include "common_header.hpp"
#include "dimension.hpp"
#include "dimension_transform.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_coordinate.hpp"
namespace ck {
template <class TensorDescriptor>
struct TensorVisit
{
using Index = typename TensorDescriptor::Index;
using Coordinate = typename TensorCoordinate<TensorDescriptor>::type;
__host__ __device__ static void Run_v1(Index idx_begin)
{
const auto coord_begin = Coordinate(idx_begin);
ford<TensorDescriptor::GetLengths()>{}(
[&](auto idx_diff) { index_t offset = (coord_begin + idx_diff).GetOffset(); });
}
__host__ __device__ static void Run_v2(Index idx_begin)
{
const auto coord_begin = Coordinate(idx_begin);
ford<TensorDescriptor::GetLengths()>{}([&](auto idx_diff) {
index_t offset_diff = coord_begin.GetOffsetDiff(idx_diff);
index_t offset = coord_begin.GetOffset() + offset_diff;
});
}
__host__ __device__ static void Run_v3(Index idx_begin)
{
const auto coord_begin = Coordinate(idx_begin);
constexpr auto linear_dimensions = TensorDescriptor::GetLinearDimensions();
constexpr auto nonlinear_dimensions = TensorDescriptor::GetNonLinearDimensions();
constexpr auto lengths = TensorDescriptor::GetLengths();
constexpr auto linear_dimension_lengths_hack =
lambda_HackLengths{}(lengths, linear_dimensions);
constexpr auto nonlinear_dimension_lengths_hack =
lambda_HackLengths{}(lengths, nonlinear_dimensions);
ford<nonlinear_dimension_lengths_hack>{}([&](auto idx_diff_nonlinear_hack) {
// run-time component
index_t offset_diff_nonlinear = coord_begin.GetOffsetDiff(idx_diff_nonlinear_hack);
ford<linear_dimension_lengths_hack>{}([&](auto idx_diff_linear_hack) {
// compile-time component
index_t offset_diff_linear = coord_begin.GetOffsetDiff(idx_diff_linear_hack);
index_t offset =
coord_begin.GetOffset() + offset_diff_nonlinear + offset_diff_linear;
});
});
}
__host__ __device__ static void Run_v4(Index idx_begin)
{
const auto coord_begin = Coordinate(idx_begin);
constexpr auto linear_dimensions = TensorDescriptor::GetLinearDimensions();
constexpr auto nonlinear_independent_dimension_groups =
TensorDescriptor::GetNonLinearIndependentDimensionGroups();
constexpr auto lengths = TensorDescriptor::GetLengths();
constexpr auto linear_dimension_lengths = lambda_HackLengths{}(lengths, linear_dimensions);
// run-time component
index_t offset_diff_nonlinear = 0;
template <index_t NGroup>
struct f_recursion
{
template <index_t IGroup>
__host__ __device__ void Run(Number<IGroup>)
{
constexpr auto nonlinear_independent_dimensions_igroup =
nonlinear_independent_dimension_groups.Get(Number<IGroup>{});
constexpr auto nonlinear_independent_lengths_igroup =
lambda_HackLengths{}(lengths, nonlinear_independent_dimensions_igroup);
ford<nonlinear_independent_lengths_igroup>{}(
[&](auto idx_diff_nonlinear_igroup_hack) {
// run-time component
offset_diff_nonlinear +=
coord_begin.GetOffsetDiff(idx_diff_nonlinear_igroup_hack);
Run(Number<IGroup + 1>{});
});
};
// inner-most work
template <>
__host__ __device__ void Run(Number<NGroup>)
{
ford<linear_dimension_lengths>{}([&](auto idx_diff_linear_hack) {
// compile-time component
index_t offset_diff_linear = coord_begin.GetOffsetDiff(idx_diff_linear_hack);
index_t offset =
coord_begin.GetOffset() + offset_diff_nonlinear + offset_diff_linear;
});
}
};
f_recursion<nonlinear_independent_dimension_groups.GetSize()>{}.Run(Number<0>{});
}
};
} // namespace ck
#endif
#ifndef CK_BLOCKWISE_3D_TENSOR_OP_HPP
#define CK_BLOCKWISE_3D_TENSOR_OP_HPP
#include "common_header.hpp"
#include "ConstantTensorDescriptor.hpp"
namespace ck {
template <index_t BlockSize,
class Float,
class SrcDesc,
class DstDesc,
class CopyLengths,
index_t DataPerRead>
struct Blockwise3dTensorCopy1
{
using vector_t = typename vector_type<Float, DataPerRead>::MemoryType;
__device__ constexpr Blockwise3dTensorCopy1()
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
static_assert(DataPerRead == 1 ||
(SrcDesc{}.GetStride(I2) == 1 && DstDesc{}.GetStride(I2) == 1),
"wrong! only support stride2 == 1 if DataPerRead > 1!\n");
static_assert(DataPerRead == 1 || DataPerRead == 2 || DataPerRead == 4,
"wrong! only support DataPerRead == 1, 2 or 4!\n");
static_assert(SrcDesc{}.GetStride(I1) % DataPerRead == 0 &&
DstDesc{}.GetStride(I1) % DataPerRead == 0,
"src and dst stride1 should be multiple of DataPerRead to keep alignment");
// we allow out-of-bound read from src in the D2 dimension,
// but we need to make sure dst stride1 is big enough,
// so that the out-of-bound write won't contaminate the next line in dst
constexpr index_t L2 = CopyLengths{}.Get(I2);
constexpr index_t read_per_d2 = math::integer_divide_ceil(L2, DataPerRead);
static_assert(read_per_d2 * DataPerRead <= DstDesc{}.GetStride(I1),
"wrong! out-of-bound write will contaminate next line!\n");
}
__device__ void Run(const Float* __restrict__ p_src, Float* __restrict__ p_dst) const
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto src_desc = SrcDesc{};
constexpr auto dst_desc = DstDesc{};
constexpr index_t L0 = CopyLengths{}.Get(I0);
constexpr index_t L1 = CopyLengths{}.Get(I1);
constexpr index_t L2 = CopyLengths{}.Get(I2);
constexpr index_t read_per_d2 = math::integer_divide_ceil(L2, DataPerRead);
constexpr auto ref_desc = make_ConstantTensorDescriptor(Sequence<L0, L1, read_per_d2>{});
constexpr index_t NLoop = ref_desc.GetElementSize() / BlockSize;
auto f_copy = [&](index_t is) {
index_t did[3];
did[0] = is / ref_desc.GetStride(I0);
is -= did[0] * ref_desc.GetStride(I0);
did[1] = is / ref_desc.GetStride(I1);
is -= did[1] * ref_desc.GetStride(I1);
did[2] = is / ref_desc.GetStride(I2);
const index_t src_index =
src_desc.GetOffsetFromMultiIndex(did[0], did[1], did[2] * DataPerRead);
const index_t dst_index =
dst_desc.GetOffsetFromMultiIndex(did[0], did[1], did[2] * DataPerRead);
*(reinterpret_cast<vector_t*>(p_dst + dst_index)) =
*(reinterpret_cast<const vector_t*>(p_src + src_index));
};
for(index_t iloop = 0; iloop < NLoop; ++iloop)
{
index_t is = get_thread_local_1d_id() + iloop * BlockSize;
f_copy(is);
}
constexpr bool has_tail = (ref_desc.GetElementSize() > NLoop * BlockSize);
if(has_tail)
{
index_t is = get_thread_local_1d_id() + NLoop * BlockSize;
if(is < ref_desc.GetElementSize())
{
f_copy(is);
}
}
}
};
// starting point needs to be aligned to float4, float2 or float
// stride2 needs to be 1 for both source and destination
template <index_t BlockSize,
class Float,
class SrcDesc,
class DstDesc,
class CopyLengths,
class ThreadPerDims,
index_t DataPerRead>
struct Blockwise3dTensorCopy3
{
using vector_t = typename vector_type<Float, DataPerRead>::MemoryType;
index_t mSrcMyThreadOffset;
index_t mDstMyThreadOffset;
__device__ Blockwise3dTensorCopy3()
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
static_assert(DataPerRead == 1 ||
(SrcDesc{}.GetStride(I2) == 1 && DstDesc{}.GetStride(I2) == 1),
"wrong! only support stride3 == 1 if DataPerRead > 1!\n");
static_assert(DataPerRead == 1 || DataPerRead == 2 || DataPerRead == 4,
"wrong! only support DataPerRead == 1, 2 or 4!\n");
static_assert(
SrcDesc{}.GetStride(I1) % DataPerRead == 0 &&
DstDesc{}.GetStride(I1) % DataPerRead == 0,
"wrong! src and dst stride1 should be multiple of DataPerRead to keep alignment");
constexpr index_t L0 = CopyLengths{}.Get(I0);
constexpr index_t L1 = CopyLengths{}.Get(I1);
constexpr index_t L2 = CopyLengths{}.Get(I2);
constexpr index_t thread_per_d0 = ThreadPerDims{}.Get(I0);
constexpr index_t thread_per_d1 = ThreadPerDims{}.Get(I1);
constexpr index_t thread_per_d2 = ThreadPerDims{}.Get(I2);
// we allow out-of-bound read from src in D2 dimension,
// but we need to make sure dst stride1 is big enough,
// so that the out-of-bound write won't contaminate the next line in dst
constexpr index_t nloop_d2 = math::integer_divide_ceil(L2, thread_per_d2 * DataPerRead);
static_assert(nloop_d2 * thread_per_d2 * DataPerRead <= DstDesc{}.GetStride(I1),
"wrong! out-of-bound write will contaminate next line!\n");
static_assert(L0 % thread_per_d0 == 0 && L1 % thread_per_d1 == 0,
"wrong! L0, L1, L2 should be divided evenly!\n");
static_assert(BlockSize >= thread_per_d0 * thread_per_d1 * thread_per_d2,
"wrrong! BlockSize is not big enough for ThreadPerDims!");
constexpr index_t num_active_thread =
reduce_on_sequence(ThreadPerDims{}, math::multiplies<index_t>{}, Number<1>{});
if(BlockSize > num_active_thread)
{
if(get_thread_local_1d_id() >= num_active_thread)
{
return;
}
}
constexpr auto thread_cluster_desc = make_ConstantTensorDescriptor(ThreadPerDims{});
const auto thread_multi_id =
thread_cluster_desc.GetMultiIndexFrom1dIndex(get_thread_local_1d_id());
mSrcMyThreadOffset = SrcDesc{}.GetOffsetFromMultiIndex(
thread_multi_id[0], thread_multi_id[1], thread_multi_id[2] * DataPerRead);
mDstMyThreadOffset = DstDesc{}.GetOffsetFromMultiIndex(
thread_multi_id[0], thread_multi_id[1], thread_multi_id[2] * DataPerRead);
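// illustration with hypothetical values: ThreadPerDims = {2, 4, 8} and
// DataPerRead = 2 give a packed thread cluster with strides {32, 8, 1};
// thread 1d id 37 maps to cluster coordinates {1, 0, 5}, so its per-thread
// offsets are taken at multi-index (1, 0, 5 * 2) of src and dst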
}
__device__ void Run(const Float* __restrict__ p_src, Float* __restrict__ p_dst) const
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr index_t L0 = CopyLengths{}.Get(I0);
constexpr index_t L1 = CopyLengths{}.Get(I1);
constexpr index_t L2 = CopyLengths{}.Get(I2);
constexpr index_t thread_per_d0 = ThreadPerDims{}.Get(I0);
constexpr index_t thread_per_d1 = ThreadPerDims{}.Get(I1);
constexpr index_t thread_per_d2 = ThreadPerDims{}.Get(I2);
constexpr index_t num_active_thread = thread_per_d0 * thread_per_d1 * thread_per_d2;
if(BlockSize > num_active_thread)
{
if(get_thread_local_1d_id() >= num_active_thread)
{
return;
}
}
constexpr index_t nloop_d0 = L0 / thread_per_d0;
constexpr index_t nloop_d1 = L1 / thread_per_d1;
constexpr index_t nloop_d2 = math::integer_divide_ceil(L2, thread_per_d2 * DataPerRead);
#pragma unroll
for(index_t iloop_d0 = 0; iloop_d0 < nloop_d0; ++iloop_d0)
{
#pragma unroll
for(index_t iloop_d1 = 0; iloop_d1 < nloop_d1; ++iloop_d1)
{
#pragma unroll
for(index_t iloop_d2 = 0; iloop_d2 < nloop_d2; ++iloop_d2)
{
const index_t src_offset =
SrcDesc{}.GetOffsetFromMultiIndex(iloop_d0 * thread_per_d0,
iloop_d1 * thread_per_d1,
iloop_d2 * thread_per_d2 * DataPerRead);
const index_t dst_offset =
DstDesc{}.GetOffsetFromMultiIndex(iloop_d0 * thread_per_d0,
iloop_d1 * thread_per_d1,
iloop_d2 * thread_per_d2 * DataPerRead);
*(reinterpret_cast<vector_t*>(&p_dst[dst_offset + mDstMyThreadOffset])) = *(
reinterpret_cast<const vector_t*>(&p_src[src_offset + mSrcMyThreadOffset]));
}
}
}
}
__device__ static constexpr index_t GetRegisterBufferSize()
{
static_assert(is_same<Float, float>{}, "wrong! only support float!\n");
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr index_t L0 = CopyLengths{}.Get(I0);
constexpr index_t L1 = CopyLengths{}.Get(I1);
constexpr index_t L2 = CopyLengths{}.Get(I2);
constexpr index_t thread_per_d0 = ThreadPerDims{}.Get(I0);
constexpr index_t thread_per_d1 = ThreadPerDims{}.Get(I1);
constexpr index_t thread_per_d2 = ThreadPerDims{}.Get(I2);
constexpr index_t nloop_d0 = L0 / thread_per_d0;
constexpr index_t nloop_d1 = L1 / thread_per_d1;
constexpr index_t nloop_d2 = math::integer_divide_ceil(L2, thread_per_d2 * DataPerRead);
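// worked example (illustrative numbers): CopyLengths = {8, 16, 32},
// ThreadPerDims = {2, 4, 8} and DataPerRead = 4 give nloop_d0 = 4,
// nloop_d1 = 4 and nloop_d2 = ceil(32 / 32) = 1, so each thread buffers
// 4 * 4 * 4 * 1 = 64 floats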
return DataPerRead * nloop_d0 * nloop_d1 * nloop_d2;
}
__device__ void RunLoadRegisterBuffer(const Float* __restrict__ p_src,
Float* __restrict__ p_clipboard) const
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr index_t L0 = CopyLengths{}.Get(I0);
constexpr index_t L1 = CopyLengths{}.Get(I1);
constexpr index_t L2 = CopyLengths{}.Get(I2);
constexpr index_t thread_per_d0 = ThreadPerDims{}.Get(I0);
constexpr index_t thread_per_d1 = ThreadPerDims{}.Get(I1);
constexpr index_t thread_per_d2 = ThreadPerDims{}.Get(I2);
constexpr index_t num_active_thread = thread_per_d0 * thread_per_d1 * thread_per_d2;
if(BlockSize > num_active_thread)
{
if(get_thread_local_1d_id() >= num_active_thread)
{
return;
}
}
constexpr index_t nloop_d0 = L0 / thread_per_d0;
constexpr index_t nloop_d1 = L1 / thread_per_d1;
constexpr index_t nloop_d2 = math::integer_divide_ceil(L2, thread_per_d2 * DataPerRead);
constexpr auto clipboard_desc =
make_ConstantTensorDescriptor(Sequence<nloop_d0, nloop_d1, nloop_d2 * DataPerRead>{});
#pragma unroll
for(index_t iloop_d0 = 0; iloop_d0 < nloop_d0; ++iloop_d0)
{
#pragma unroll
for(index_t iloop_d1 = 0; iloop_d1 < nloop_d1; ++iloop_d1)
{
#pragma unroll
for(index_t iloop_d2 = 0; iloop_d2 < nloop_d2; ++iloop_d2)
{
const index_t src_offset =
SrcDesc{}.GetOffsetFromMultiIndex(iloop_d0 * thread_per_d0,
iloop_d1 * thread_per_d1,
iloop_d2 * thread_per_d2 * DataPerRead);
const index_t clipboard_offset = clipboard_desc.GetOffsetFromMultiIndex(
iloop_d0, iloop_d1, iloop_d2 * DataPerRead);
*(reinterpret_cast<vector_t*>(&p_clipboard[clipboard_offset])) = *(
reinterpret_cast<const vector_t*>(&p_src[src_offset + mSrcMyThreadOffset]));
}
}
}
}
__device__ void RunStoreRegisterBuffer(const Float* __restrict__ p_clipboard,
Float* __restrict__ p_dst) const
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr index_t L0 = CopyLengths{}.Get(I0);
constexpr index_t L1 = CopyLengths{}.Get(I1);
constexpr index_t L2 = CopyLengths{}.Get(I2);
constexpr index_t thread_per_d0 = ThreadPerDims{}.Get(I0);
constexpr index_t thread_per_d1 = ThreadPerDims{}.Get(I1);
constexpr index_t thread_per_d2 = ThreadPerDims{}.Get(I2);
constexpr index_t num_active_thread = thread_per_d0 * thread_per_d1 * thread_per_d2;
if(BlockSize > num_active_thread)
{
if(get_thread_local_1d_id() >= num_active_thread)
{
return;
}
}
constexpr index_t nloop_d0 = L0 / thread_per_d0;
constexpr index_t nloop_d1 = L1 / thread_per_d1;
constexpr index_t nloop_d2 = math::integer_divide_ceil(L2, thread_per_d2 * DataPerRead);
constexpr auto clipboard_desc =
make_ConstantTensorDescriptor(Sequence<nloop_d0, nloop_d1, nloop_d2 * DataPerRead>{});
#pragma unroll
for(index_t iloop_d0 = 0; iloop_d0 < nloop_d0; ++iloop_d0)
{
#pragma unroll
for(index_t iloop_d1 = 0; iloop_d1 < nloop_d1; ++iloop_d1)
{
#pragma unroll
for(index_t iloop_d2 = 0; iloop_d2 < nloop_d2; ++iloop_d2)
{
const index_t clipboard_offset = clipboard_desc.GetOffsetFromMultiIndex(
iloop_d0, iloop_d1, iloop_d2 * DataPerRead);
const index_t dst_offset =
DstDesc{}.GetOffsetFromMultiIndex(iloop_d0 * thread_per_d0,
iloop_d1 * thread_per_d1,
iloop_d2 * thread_per_d2 * DataPerRead);
*(reinterpret_cast<vector_t*>(&p_dst[dst_offset + mDstMyThreadOffset])) =
*(reinterpret_cast<const vector_t*>(&p_clipboard[clipboard_offset]));
}
}
}
}
};
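// Usage sketch for Blockwise3dTensorCopy3 (illustrative only; the packed
// descriptor and the surrounding launch context are assumptions, not part of
// this header). A 64-thread block arranged 2x4x8 copies an 8x16x32 float tile
// with float4 reads along dim 2; this satisfies every static_assert above,
// since the packed stride1 is 32 and nloop_d2 * 8 * 4 == 32:
//
//   using TileDesc =
//       decltype(make_ConstantTensorDescriptor_packed(Sequence<8, 16, 32>{}));
//
//   __device__ void copy_tile_example(const float* p_src, float* p_dst)
//   {
//       const auto blockwise_copy = Blockwise3dTensorCopy3<64,
//                                                          float,
//                                                          TileDesc,
//                                                          TileDesc,
//                                                          Sequence<8, 16, 32>,
//                                                          Sequence<2, 4, 8>,
//                                                          4>{};
//       blockwise_copy.Run(p_src, p_dst);
//   }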
} // namespace ck
#endif
@@ -4,7 +4,6 @@
#include "common_header.hpp"
#include "ConstantTensorDescriptor.hpp"
#include "ConstantMergedTensorDescriptor.hpp"
#include "tensor_view.hpp"
#include "tensor_coordinate_deprecated.hpp"
#include "threadwise_generic_tensor_slice_copy_deprecated.hpp"
@@ -484,14 +483,8 @@ struct BlockwiseGenericTensorSliceCopy_v2
address_space_t ThreadBufferAddressSpace = address_space_t::generic>
__device__ void RunLoadThreadBuffer(const TData* p_block_src, TData* p_thread_buffer) const
{
#if 0
mThreadwiseLoad.Run(p_block_src, p_thread_buffer);
#else // tweaking
mThreadwiseLoad.template Run_optimized_address_calculation<TData,
BlockSrcAddressSpace,
ThreadBufferAddressSpace>(
p_block_src, p_thread_buffer);
#endif
mThreadwiseLoad.template Run<TData, BlockSrcAddressSpace, ThreadBufferAddressSpace>(
p_block_src, p_thread_buffer);
}
template <typename TData,
@@ -499,14 +492,8 @@ struct BlockwiseGenericTensorSliceCopy_v2
address_space_t BlockDstAddressSpace = address_space_t::generic>
__device__ void RunStoreThreadBuffer(const TData* p_thread_buffer, TData* p_block_dst) const
{
#if 0
mThreadwiseStore.Run(p_thread_buffer, p_block_dst);
#else // tweaking
mThreadwiseStore.template Run_optimized_address_calculation<TData,
ThreadBufferAddressSpace,
BlockDstAddressSpace>(
p_thread_buffer, p_block_dst);
#endif
mThreadwiseStore.template Run<TData, ThreadBufferAddressSpace, BlockDstAddressSpace>(
p_thread_buffer, p_block_dst);
}
template <typename TData,
@@ -563,130 +550,6 @@ struct BlockwiseGenericTensorSliceCopy_v2
ThreadwiseStore mThreadwiseStore;
};
// this version uses TensorView and TensorCoordinate_deprecated
template <index_t BlockSize,
typename SrcTensor,
typename DstTensor,
typename SliceLengths,
typename SubLengths,
typename ThreadClusterLengths,
typename ThreadClusterArrangeOrder,
typename SrcDimAccessOrder,
typename DstDimAccessOrder,
index_t SrcVectorAccessDim,
index_t DstVectorAccessDim,
index_t SrcDataPerAccess,
index_t DstDataPerAccess>
struct BlockwiseGenericTensorSliceCopy_v3
{
static constexpr index_t nDim = SrcTensor::GetNumOfDimension();
using data_type = remove_cv_t<typename SrcTensor::data_type>;
using SrcCoordinate = typename SrcTensor::coordinate_type;
using DstCoordinate = typename DstTensor::coordinate_type;
__device__ constexpr BlockwiseGenericTensorSliceCopy_v3(SrcTensor src_block,
SrcCoordinate src_block_slice_origin,
DstTensor dst_block,
DstCoordinate dst_block_slice_origin)
: mThreadBuffer{make_TensorView(ThreadBufferDesc{}, mpBuffer)}
{
static_assert(
nDim == SrcTensor::GetNumOfDimension() && nDim == DstTensor::GetNumOfDimension() &&
nDim == SliceLengths::GetSize() && nDim == SubLengths::GetSize() &&
nDim == ThreadClusterLengths::GetSize() &&
nDim == ThreadClusterArrangeOrder::GetSize() &&
nDim == SrcDimAccessOrder::GetSize() && nDim == DstDimAccessOrder::GetSize(),
"wrong! nDim not consistent");
static_assert(is_same<SliceLengths, decltype(SubLengths{} * ThreadClusterLengths{})>{},
"wrong! threads should be mapped to cover entire slicing window");
static_assert(is_same<remove_cv_t<typename SrcTensor::data_type>,
remove_cv_t<typename DstTensor::data_type>>{},
"wrong! type conversion not supported yet");
constexpr auto thread_cluster_desc = make_ConstantTensorDescriptor_packed(
ThreadClusterLengths::ReorderGivenNew2Old(ThreadClusterArrangeOrder{}));
static_assert(BlockSize == thread_cluster_desc.GetElementSize(),
"wrong! BlockSize not consistent with ThreadClusterLengths");
const auto thread_cluster_id =
thread_cluster_desc.GetMultiIndexFrom1dIndex(get_thread_local_1d_id());
const auto data_cluster_id =
reorder_array_given_old2new(thread_cluster_id, ThreadClusterArrangeOrder{});
const auto thread_data_id_begin = data_cluster_id * SubLengths{};
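// worked example (illustrative, with the identity arrange order
// Sequence<0, 1>): ThreadClusterLengths = {4, 8} and SubLengths = {2, 4}
// give a packed 4x8 thread cluster, so thread 10 gets cluster id (1, 2)
// and thread_data_id_begin = (1 * 2, 2 * 4) = (2, 8): it owns the 2x4
// sub-tile at offset (2, 8) inside the 8x32 block slice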
mThreadwiseLoad = ThreadwiseLoad(src_block,
src_block_slice_origin + thread_data_id_begin,
mThreadBuffer,
make_zero_array<index_t, nDim>());
mThreadwiseStore = ThreadwiseStore(mThreadBuffer,
make_zero_array<index_t, nDim>(),
dst_block,
dst_block_slice_origin + thread_data_id_begin);
}
__device__ void RunLoadRegisterBuffer() { mThreadwiseLoad.Run(); }
__device__ void RunStoreRegisterBuffer() const { mThreadwiseStore.Run(); }
__device__ void Run()
{
mThreadwiseLoad.Run();
mThreadwiseStore.Run();
}
template <typename T, bool PositiveDirection>
__device__ void
MoveSrcSliceWindow(T step_sizes, integral_constant<bool, PositiveDirection> positive_direction)
{
mThreadwiseLoad.MoveSrcSliceWindow(step_sizes, positive_direction);
}
template <typename T, bool PositiveDirection>
__device__ void
MoveDstSliceWindow(T step_sizes, integral_constant<bool, PositiveDirection> positive_direction)
{
mThreadwiseStore.MoveDstSliceWindow(step_sizes, positive_direction);
}
private:
using ThreadBufferDesc = decltype(make_ConstantTensorDescriptor_packed(SubLengths{}));
using ThreadBufferTensor = NormalTensorView<ThreadBufferDesc, data_type>;
using ThreadwiseLoad = ThreadwiseGenericTensorSliceCopy_v3r1<SrcTensor,
ThreadBufferTensor,
SubLengths,
SrcDimAccessOrder,
SrcDimAccessOrder,
SrcVectorAccessDim,
SrcVectorAccessDim,
SrcDataPerAccess,
1>;
using ThreadwiseStore = ThreadwiseGenericTensorSliceCopy_v3r1<ThreadBufferTensor,
DstTensor,
SubLengths,
DstDimAccessOrder,
DstDimAccessOrder,
DstVectorAccessDim,
DstVectorAccessDim,
1,
DstDataPerAccess>;
data_type mpBuffer[ThreadBufferDesc::GetElementSpace()];
ThreadBufferTensor mThreadBuffer;
ThreadwiseLoad mThreadwiseLoad;
ThreadwiseStore mThreadwiseStore;
};
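// Usage sketch for BlockwiseGenericTensorSliceCopy_v3 (illustrative only;
// SrcBlockDesc/DstBlockDesc are hypothetical descriptors, and the zero
// arrays are assumed to convert to the coordinate types, as they do for the
// threadwise copies above):
//
//   auto src_view = make_TensorView(SrcBlockDesc{}, p_src_block);
//   auto dst_view = make_TensorView(DstBlockDesc{}, p_dst_block);
//
//   auto blockwise_copy = BlockwiseGenericTensorSliceCopy_v3<...>(
//       src_view, make_zero_array<index_t, nDim>(),
//       dst_view, make_zero_array<index_t, nDim>());
//
//   blockwise_copy.Run();                                 // load, then store
//   blockwise_copy.MoveSrcSliceWindow(step_sizes, True);  // advance and repeat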
} // namespace ck
#endif
#ifndef CK_BLOCKWISE_TENSOR_SLICE_COPY_HPP
#define CK_BLOCKWISE_TENSOR_SLICE_COPY_HPP
#include "common_header.hpp"
#include "ConstantTensorDescriptor.hpp"
#include "threadwise_tensor_slice_copy.hpp"
namespace ck {
template <index_t BlockSize,
class Float,
class SrcDesc,
class DstDesc,
class SrcLengths,
class SrcSubLengths,
class SrcClusterLengths,
class MapDst2Src,
class MapThreadCluster2SrcCluster,
index_t SrcDataPerRead,
index_t DstDataPerWrite>
struct BlockwiseTensorSliceReorderCopy_v3
{
static constexpr index_t nDim = SrcLengths::GetSize();
index_t mThreadSrcOffset;
index_t mThreadDstOffset;
__device__
BlockwiseTensorSliceReorderCopy_v3(Array<index_t, nDim> src_block_data_multi_id_begin,
Array<index_t, nDim> dst_block_data_multi_id_begin)
{
constexpr auto src_desc = SrcDesc{};
constexpr auto dst_desc = DstDesc{};
constexpr auto src_lengths = SrcLengths{};
constexpr auto map_dst2src = MapDst2Src{};
constexpr auto src_sub_lengths = SrcSubLengths{};
constexpr auto dst_sub_lengths = src_sub_lengths.ReorderGivenNew2Old(map_dst2src);
constexpr auto map_thread_cluster_2_src_cluster = MapThreadCluster2SrcCluster{};
constexpr auto src_cluster_lengths = SrcClusterLengths{};
constexpr auto thread_cluster_lengths =
src_cluster_lengths.ReorderGivenNew2Old(map_thread_cluster_2_src_cluster);
constexpr auto thread_cluster_desc =
make_ConstantTensorDescriptor_packed(thread_cluster_lengths);
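// worked example (illustrative numbers): with MapDst2Src = Sequence<1, 0>
// (a transpose), src_sub_lengths = {2, 4} reorders to dst_sub_lengths =
// {4, 2}; likewise MapThreadCluster2SrcCluster = Sequence<1, 0> turns
// src_cluster_lengths = {4, 8} into a packed 8x4 thread cluster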
// sanity check: data type
static_assert(is_same<Float, float>{}, "wrong! only support float for now!\n");
// sanity check: nDim
static_assert(SrcDesc::GetNumOfDimension() == nDim &&
DstDesc::GetNumOfDimension() == nDim && SrcLengths::GetSize() == nDim &&
SrcSubLengths::GetSize() == nDim &&
SrcClusterLengths::GetSize() == nDim && MapDst2Src::GetSize() == nDim &&
MapThreadCluster2SrcCluster::GetSize() == nDim,
"wrong! nDim is not consistent\n");
// sanity check: BlockSize
constexpr index_t num_active_thread = thread_cluster_desc.GetElementSize();
static_assert(BlockSize >= num_active_thread,
"wrong! BlockSize is not big enough for the thread cluster!");
// sanity check: work division
static_for<0, nDim, 1>{}([&](auto IDim) {
constexpr auto I = decltype(IDim){};
constexpr index_t src_len = src_lengths.Get(I);
constexpr index_t src_sub_len = src_sub_lengths.Get(I);
constexpr index_t src_cluster_len = src_cluster_lengths.Get(I);
static_assert(src_len % (src_sub_len * src_cluster_len) == 0,
"wrong! cannot evenly divide Src tensor lengths");
});
// sanity check: src read
static_assert(SrcDataPerRead == 1 || SrcDataPerRead == 2 || SrcDataPerRead == 4,
"wrong! only support SrcDataPerRead == 1, 2 or 4!\n");
static_assert(SrcDataPerRead == 1 || src_desc.GetStride(Number<nDim - 1>{}) == 1,
"wrong! only support src.stride(nDim-1) == 1 if SrcDataPerRead > 1!\n");
static_assert(src_sub_lengths.Get(Number<nDim - 1>{}) % SrcDataPerRead == 0,
"wrong! src_sub_lengths[nDim-1] % SrcDataPerRead != 0\n");
static_assert(src_desc.GetStride(Number<nDim - 2>{}) % SrcDataPerRead == 0,
"wrong! should satisfy src_desc.stride(nDim-2) % SrcDataPerRead == 0, to "
"keep alignment");
// sanity check: dst write
static_assert(DstDataPerWrite == 1 || DstDataPerWrite == 2 || DstDataPerWrite == 4,
"wrong! only support DstDataPerWrite == 1, 2 or 4!\n");
static_assert(DstDataPerWrite == 1 || dst_desc.GetStride(Number<nDim - 1>{}) == 1,
"wrong! only support dst.stride(nDim-1) == 1 if DstDataPerWrite > 1!\n");
static_assert(dst_sub_lengths.Get(Number<nDim - 1>{}) % DstDataPerWrite == 0,
"wrong! dst_sub_lengths[nDim-1] % DstDataPerWrite != 0\n");
static_assert(dst_desc.GetStride(Number<nDim - 2>{}) % DstDataPerWrite == 0,
"wrong! should satisfy dst_desc.stride(nDim-2) % DstDataPerWrite == 0, to "
"keep alignment");
// start dividing work
if(BlockSize > num_active_thread)
{
if(get_thread_local_1d_id() >= num_active_thread)
{
return;
}
}
const auto thread_multi_id =
thread_cluster_desc.GetMultiIndexFrom1dIndex(get_thread_local_1d_id());
// compiler: will thread_multi_id, src_data_multi_id and dst_data_multi_id
// use separate registers, or share one copy?
auto src_data_multi_id =
reorder_array_given_old2new(thread_multi_id, map_thread_cluster_2_src_cluster);
static_for<0, nDim, 1>{}([&](auto IDim) {
constexpr index_t idim = IDim;
// compiler: will the index really be computed here, or be merged into
// GetOffsetFromMultiIndex and optimized away?
src_data_multi_id(idim) *= src_sub_lengths.Get(IDim);
});
// compiler: will the index really be computed here, or be merged into
// GetOffsetFromMultiIndex and optimized away?
const auto dst_data_multi_id = reorder_array_given_new2old(src_data_multi_id, map_dst2src);
mThreadSrcOffset =
src_desc.GetOffsetFromMultiIndex(src_data_multi_id + src_block_data_multi_id_begin);
mThreadDstOffset =
dst_desc.GetOffsetFromMultiIndex(dst_data_multi_id + dst_block_data_multi_id_begin);
#if 0
if(get_block_1d_id() == 0 && get_thread_local_1d_id() == 0)
{
print_ConstantTensorDescriptor(thread_cluster_desc, "thread_cluster_desc: ");
}
if(get_block_1d_id() == 0)
{
printf("id %5u %5u: "
"thread_multi_id: %u %u, "
"src_block_data_multi_id_begin: %u %u, "
"src_data_multi_id: %u %u, "
"mThreadSrcOffset %u, mThreadDstOffset %u \n",
get_block_1d_id(),
get_thread_local_1d_id(),
thread_multi_id[0],
thread_multi_id[1],
src_block_data_multi_id_begin[0],
src_block_data_multi_id_begin[1],
src_data_multi_id[0],
src_data_multi_id[1],
mThreadSrcOffset,
mThreadDstOffset);
}
#endif
}
__device__ static constexpr index_t GetRegisterBufferSize()
{
constexpr auto thread_sub_tensor_lengths = SrcSubLengths{};
constexpr auto src_data_per_cluster_per_dims =
thread_sub_tensor_lengths * SrcClusterLengths{};
constexpr auto repeat_lengths = transform_sequences(
math::integer_divide_ceiler<index_t>{}, SrcLengths{}, src_data_per_cluster_per_dims);
constexpr auto thread_tensor_lengths = thread_sub_tensor_lengths * repeat_lengths;
constexpr auto thread_tensor_desc =
make_ConstantTensorDescriptor_packed(thread_tensor_lengths);
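// worked example (illustrative numbers): SrcLengths = {8, 32},
// SrcSubLengths = {2, 4} and SrcClusterLengths = {4, 4} cover {8, 16} per
// cluster pass, so repeat_lengths = {1, 2} and each thread buffers a
// {2, 8} tensor, i.e. 16 elements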
return thread_tensor_desc.GetElementSpace();
}
__device__ void RunLoadRegisterBuffer(const Float* __restrict__ p_src,
Float* __restrict__ p_clipboard) const
{
constexpr auto thread_sub_tensor_lengths = SrcSubLengths{};
constexpr auto src_data_per_cluster_per_dims =
thread_sub_tensor_lengths * SrcClusterLengths{};
constexpr auto repeat_lengths = transform_sequences(
math::integer_divide_ceiler<index_t>{}, SrcLengths{}, src_data_per_cluster_per_dims);
constexpr auto thread_tensor_lengths = thread_sub_tensor_lengths * repeat_lengths;
constexpr auto thread_tensor_desc =
make_ConstantTensorDescriptor_packed(thread_tensor_lengths);
static_ford<decltype(repeat_lengths)>{}([&](auto repeat_multi_id_) {
constexpr auto repeat_multi_id = decltype(repeat_multi_id_){};
constexpr auto src_data_multi_id = repeat_multi_id * src_data_per_cluster_per_dims;
constexpr auto clipboard_data_multi_id = repeat_multi_id * thread_sub_tensor_lengths;
constexpr index_t src_offset = SrcDesc{}.GetOffsetFromMultiIndex(src_data_multi_id);
constexpr index_t clipboard_offset =
thread_tensor_desc.GetOffsetFromMultiIndex(clipboard_data_multi_id);
threadwise_tensor_slice_copy(SrcDesc{},
p_src + src_offset + mThreadSrcOffset,
thread_tensor_desc,
p_clipboard + clipboard_offset,
thread_sub_tensor_lengths,
Number<SrcDataPerRead>{});
});
}
__device__ void RunStoreRegisterBuffer(const Float* __restrict__ p_clipboard,
Float* __restrict__ p_dst) const
{
constexpr auto thread_sub_tensor_lengths = SrcSubLengths{};
constexpr auto src_data_per_cluster_per_dims =
thread_sub_tensor_lengths * SrcClusterLengths{};
constexpr auto repeat_lengths = transform_sequences(
math::integer_divide_ceiler<index_t>{}, SrcLengths{}, src_data_per_cluster_per_dims);
constexpr auto thread_tensor_lengths = thread_sub_tensor_lengths * repeat_lengths;
constexpr auto thread_tensor_desc =
make_ConstantTensorDescriptor_packed(thread_tensor_lengths);
static_ford<decltype(repeat_lengths)>{}([&](auto repeat_multi_id_) {
constexpr auto repeat_multi_id = decltype(repeat_multi_id_){};
constexpr auto clipboard_data_multi_id = repeat_multi_id * thread_sub_tensor_lengths;
constexpr auto src_data_multi_id = repeat_multi_id * src_data_per_cluster_per_dims;
// reorder src_data_multi_id to get dst_data_multi_id
constexpr auto dst_data_multi_id = src_data_multi_id.ReorderGivenNew2Old(MapDst2Src{});
constexpr index_t clipboard_offset =
thread_tensor_desc.GetOffsetFromMultiIndex(clipboard_data_multi_id);
constexpr index_t dst_offset = DstDesc{}.GetOffsetFromMultiIndex(dst_data_multi_id);
// write in the order of dst
#if 1
threadwise_tensor_slice_copy_reorder_given_dst2src_v2(thread_tensor_desc,
p_clipboard + clipboard_offset,
DstDesc{},
p_dst + dst_offset +
mThreadDstOffset,
thread_sub_tensor_lengths,
MapDst2Src{});
#else
threadwise_tensor_slice_copy_reorder_given_dst2src_v3(thread_tensor_desc,
p_clipboard + clipboard_offset,
DstDesc{},
p_dst + dst_offset +
mThreadDstOffset,
thread_sub_tensor_lengths,
MapDst2Src{},
Number<DstDataPerWrite>{});
#endif
});
}
__device__ void Run(const Float* __restrict__ p_src, Float* __restrict__ p_dst) const
{
Float p_clipboard[GetRegisterBufferSize()];
RunLoadRegisterBuffer(p_src, p_clipboard);
RunStoreRegisterBuffer(p_clipboard, p_dst);
}
// this function doesn't do a sanity check on whether the slicing window goes
// out of the boundary of the tensor being sliced
template <index_t IDim_, index_t StepSize, bool PositiveDirection>
__device__ void MoveSlicingWindowOnSourceTensor(
Number<IDim_>, Number<StepSize>, integral_constant<bool, PositiveDirection> direction)
{
constexpr auto IDim = Number<IDim_>{};
static_if<PositiveDirection>{}([&](auto fwd) {
mThreadSrcOffset += StepSize * fwd(SrcDesc{}).GetStride(IDim);
}).Else([&](auto fwd) { mThreadSrcOffset -= StepSize * fwd(SrcDesc{}).GetStride(IDim); });
}
};
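// Usage sketch for BlockwiseTensorSliceReorderCopy_v3 (illustrative only;
// SrcDesc/DstDesc are hypothetical descriptors, and the Sequence parameters
// are chosen to satisfy the static_asserts above):
//
//   using Reorder = BlockwiseTensorSliceReorderCopy_v3<
//       64,               // BlockSize (16 threads active)
//       float,
//       SrcDesc,          // hypothetical, packed, stride(nDim-1) == 1
//       DstDesc,          // hypothetical
//       Sequence<8, 32>,  // SrcLengths
//       Sequence<2, 4>,   // SrcSubLengths
//       Sequence<4, 4>,   // SrcClusterLengths
//       Sequence<1, 0>,   // MapDst2Src: transpose on store
//       Sequence<0, 1>,   // MapThreadCluster2SrcCluster
//       4,                // SrcDataPerRead (float4)
//       1>;               // DstDataPerWrite
//
//   Reorder blockwise_copy(make_zero_array<index_t, 2>(),
//                          make_zero_array<index_t, 2>());
//   blockwise_copy.Run(p_src, p_dst);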
} // namespace ck
#endif
@@ -48,36 +48,6 @@ struct type_convert
}
};
template <class T>
__device__ void fused_multiply_accumulate(T& d, const T& s0, const T& s1)
{
d += s0 * s1;
}
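// For example, a threadwise GEMM inner loop can accumulate through this
// generic overload (EPerThread, p_a and p_b are hypothetical):
//
//   float acc = 0.0f;
//   for(index_t e = 0; e < EPerThread; ++e)
//   {
//       fused_multiply_accumulate(acc, p_a[e], p_b[e]);
//   }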
#if 0
__device__ void fused_multiply_accumulate(half& d, const half& s0, const half& s1) { d += s0 * s1; }
__device__ void fused_multiply_accumulate(half& d, const half2& s0, const half2& s1)
{
d += s0.x * s1.x;
d += s0.y * s1.y;
}
__device__ void fused_multiply_accumulate(float& d, const half2& s0, const half2& s1)
{
d += s0.x * s1.x + s0.y * s1.y;
}
__device__ void fused_multiply_accumulate(char& d, const char& s0, const char& s1) { d += s0 * s1; }
// TODO: this interface is misleading; s0 and s1 are actually int8x4.
// need to make a better interface
__device__ void fused_multiply_accumulate(int32_t& d, const int32_t& s0, const int32_t& s1)
{
d = __dp4a(s0, s1, d);
}
#endif
} // namespace ck
#endif