Unverified commit fbdf4332, authored by zjing14 and committed by GitHub

Add xdlops v4r4r4 into online compilation (#48)



* init for v4r4 xdlops olc

* refactor wrap

* init impl of v4r4 nchw xdlops olc

* tuning

* test perf

* fixed v4r4 nhwc

* tuned v4r4 nhwc

* use gridwise_gemm_xdlops_v2r3

* swap a/b

* add pointer support into offline v2r3

* debugging v4r4r4 transform for olc

* change timer of olc

* refactor v4r4 xdlops nchw olc

* remove transform fun in v4r4 xdlops nhwc olc
Co-authored-by: Chao Liu <chao.liu2@amd.com>
parent 0a72e4df
#ifndef CK_DRIVER_DYNAMIC_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V4R4_XDLOPS_NCHW_KCYX_NKHW_HPP
#define CK_DRIVER_DYNAMIC_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V4R4_XDLOPS_NCHW_KCYX_NKHW_HPP
#include "common_header.hpp"
#include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp"
#include "driver_dynamic_gemm_xdlops_v1.hpp"
#include "driver_dynamic_gemm_xdlops_v2.hpp"
namespace ck {
// GemmM = K
// GemmN = N * Ho * Wo
// GemmK = C * Y * X
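// e.g. with hypothetical sizes N = 128, C = 64, K = 256, Ho = Wo = 14, Y = X = 3:
//   GemmM = 256, GemmN = 128 * 14 * 14 = 25088, GemmK = 64 * 3 * 3 = 576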
template <typename FloatAB,
index_t GemmMPerBlock,
index_t GemmNPerBlock,
index_t GemmKPerBlock,
index_t GemmMPerWave,
index_t GemmNPerWave,
index_t GemmKPack,
typename... Wei,
typename... In,
typename... Out,
typename ConvStrides,
typename ConvDilations,
typename InLeftPads,
typename InRightPads>
__host__ __device__ constexpr auto
transform_forward_convolution_into_gemm_v4r4_xdlops_nchw_kcyx_nkhw_pad(
const DynamicTensorDescriptor<Wei...>& wei_k_c_y_x_global_desc,
const DynamicTensorDescriptor<In...>& in_n_c_hi_wi_global_desc,
const DynamicTensorDescriptor<Out...>& out_n_k_ho_wo_global_desc,
const ConvStrides& conv_strides,
const ConvDilations& conv_dilations,
const InLeftPads& in_left_pads,
const InRightPads& in_right_pads)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
const auto N = in_n_c_hi_wi_global_desc.GetLength(I0);
const auto C = in_n_c_hi_wi_global_desc.GetLength(I1);
const auto K = out_n_k_ho_wo_global_desc.GetLength(I1);
const auto Hi = in_n_c_hi_wi_global_desc.GetLength(I2);
const auto Wi = in_n_c_hi_wi_global_desc.GetLength(I3);
const auto Ho = out_n_k_ho_wo_global_desc.GetLength(I2);
const auto Wo = out_n_k_ho_wo_global_desc.GetLength(I3);
const auto Y = wei_k_c_y_x_global_desc.GetLength(I2);
const auto X = wei_k_c_y_x_global_desc.GetLength(I3);
const auto ConvStrideH = conv_strides[I0];
const auto ConvStrideW = conv_strides[I1];
const auto ConvDilationH = conv_dilations[I0];
const auto ConvDilationW = conv_dilations[I1];
const auto InLeftPadH = in_left_pads[I0];
const auto InLeftPadW = in_left_pads[I1];
const auto InRightPadH = in_right_pads[I0];
const auto InRightPadW = in_right_pads[I1];
const auto GemmM = K;
const auto GemmN = N * Ho * Wo;
const auto GemmK = C * Y * X;
const auto GemmK0 = GemmK / GemmKPack;
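// GemmK is viewed as GemmK0 * GemmKPack; GemmKPack is the innermost K slice handed to
// XdlopsGemm below (this assumes GemmK % GemmKPack == 0)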
// weight tensor
const auto wei_gemmk_gemmm_global_desc = transform_dynamic_tensor_descriptor(
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, C * Y * X)),
make_tuple(make_pass_through_transform(K), make_pass_through_transform(C * Y * X)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<1>{}, Sequence<0>{}));
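// the naive [K, C*Y*X] weight view is re-ordered (output dims swapped) into
// [GemmK = C*Y*X, GemmM = K]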
const auto wei_gemmk0_gemmm_gemmk1_global_desc = transform_dynamic_tensor_descriptor(
wei_gemmk_gemmm_global_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmKPack)),
make_pass_through_transform(GemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// input tensor
const auto in_n_c_hip_wip_global_desc = transform_dynamic_tensor_descriptor(
in_n_c_hi_wi_global_desc,
make_tuple(make_pass_through_transform(N),
make_pass_through_transform(C),
make_pad_transform(Hi, InLeftPadH, InRightPadH),
make_pad_transform(Wi, InLeftPadW, InRightPadW)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto in_n_c_y_ho_x_wo_global_desc = transform_dynamic_tensor_descriptor(
in_n_c_hip_wip_global_desc,
make_tuple(make_pass_through_transform(N),
make_pass_through_transform(C),
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW))),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
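// the embed transforms express each padded input coordinate as
// hip = y * ConvDilationH + ho * ConvStrideH (and likewise for wip),
// so Hip/Wip are split into (Y, Ho) and (X, Wo)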
const auto in_gemmk_gemmn_global_desc =
transform_dynamic_tensor_descriptor(in_n_c_y_ho_x_wo_global_desc,
make_tuple(make_merge_transform(make_tuple(C, Y, X)),
make_merge_transform(make_tuple(N, Ho, Wo))),
make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto in_gemmk0_gemmn_gemmk1_global_desc = transform_dynamic_tensor_descriptor(
in_gemmk_gemmn_global_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmKPack)),
make_pass_through_transform(GemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
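// (C, Y, X) are merged into GemmK and (N, Ho, Wo) into GemmN, then GemmK is split into
// (GemmK0, GemmKPack) the same way as for the weight tensor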
// output tensor
const auto out_gemmm_gemmn_global_desc = transform_dynamic_tensor_descriptor(
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K, Ho * Wo)),
make_tuple(make_pass_through_transform(K), make_merge_transform(make_tuple(N, Ho * Wo))),
make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
assert(GemmM == out_gemmm_gemmn_global_desc.GetLength(I0));
assert(GemmN == out_gemmm_gemmn_global_desc.GetLength(I1));
assert(GemmK0 == in_gemmk0_gemmn_gemmk1_global_desc.GetLength(I0));
assert(GemmK0 == wei_gemmk0_gemmm_gemmk1_global_desc.GetLength(I0));
assert(GemmM % GemmMPerBlock == 0 && GemmN % GemmNPerBlock == 0 && GemmK0 % GemmKPerBlock == 0);
constexpr auto xdlops_gemm = XdlopsGemm<FloatAB, GemmMPerWave, GemmNPerWave, GemmKPack>{};
constexpr auto CLayout = xdlops_gemm.GetCLayout();
constexpr index_t M0 = CLayout.M1();
constexpr index_t M1 = CLayout.N1();
constexpr index_t M2 = CLayout.M0();
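// GemmM is further viewed below as (GemmM / (M1 * M2), M1, M2) to match the per-wave
// accumulator (C) layout reported by XdlopsGemm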
const auto out_m0_m1_m2_n_global_desc = transform_dynamic_tensor_descriptor(
out_gemmm_gemmn_global_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmM / (M1 * M2), M1, M2)),
make_pass_through_transform(GemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}));
// out_gemm_block_cluster_desc
const auto out_gemm_block_cluster_desc = make_cluster_descriptor_v2(
make_tuple(GemmM / Number<GemmMPerBlock>{}, GemmN / Number<GemmNPerBlock>{}));
// hack to control index calculation when iterating over wei_gemmk0_gemmm_gemmk1_global tensor
constexpr auto wei_gemmk0_gemmm_gemmk1_global_iterator_hacks = make_tuple(
make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}),
make_tuple(
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}));
constexpr auto wei_gemmk0_gemmm_gemmk1_global_move_slice_window_iterator_hacks =
Sequence<0, 0, 0, 0, 0>{};
// hack to control index calculation when iterating over in_gemmk0_gemmn_gemmk1_global tensor
constexpr auto in_gemmk0_gemmn_gemmk1_global_iterator_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}),
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}));
constexpr auto in_gemmk0_gemmn_gemmk1_global_move_slice_window_iterator_hacks =
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>{};
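// note: the non-zero entries in these hack Sequences presumably select specialized
// index-update paths in the transfer, while zero entries mean the generic path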
// hack to control index calculation when iterating over out_gemmm0_gemmm1_gemmn0_gemmn1_global
// tensor hack for NKHW format
constexpr auto out_m0_m1_m2_n_global_iterator_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 1, 0, 0>{}),
make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 2, 0, 0>{}));
return make_tuple(wei_gemmk0_gemmm_gemmk1_global_desc,
in_gemmk0_gemmn_gemmk1_global_desc,
out_m0_m1_m2_n_global_desc,
out_gemm_block_cluster_desc,
wei_gemmk0_gemmm_gemmk1_global_iterator_hacks,
in_gemmk0_gemmn_gemmk1_global_iterator_hacks,
out_m0_m1_m2_n_global_iterator_hacks,
wei_gemmk0_gemmm_gemmk1_global_move_slice_window_iterator_hacks,
in_gemmk0_gemmn_gemmk1_global_move_slice_window_iterator_hacks);
}
// GemmM = K
// GemmN = N * Ho * Wo
// GemmK = C * Y * X
template <typename FloatAB,
index_t GemmMPerBlock,
index_t GemmNPerBlock,
index_t GemmKPerBlock,
index_t GemmMPerWave,
index_t GemmNPerWave,
index_t GemmKPack,
typename... Wei,
typename... In,
typename... Out,
typename ConvStrides,
typename ConvDilations,
typename InLeftPads,
typename InRightPads>
__host__ __device__ constexpr auto
transform_forward_convolution_into_gemm_v4r4_xdlops_nchw_kcyx_nkhw_1x1(
const DynamicTensorDescriptor<Wei...>& wei_k_c_y_x_global_desc,
const DynamicTensorDescriptor<In...>& in_n_c_hi_wi_global_desc,
const DynamicTensorDescriptor<Out...>& out_n_k_ho_wo_global_desc,
const ConvStrides& conv_strides,
const ConvDilations& conv_dilations,
const InLeftPads& in_left_pads,
const InRightPads& in_right_pads)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
const auto N = in_n_c_hi_wi_global_desc.GetLength(I0);
const auto C = in_n_c_hi_wi_global_desc.GetLength(I1);
const auto K = out_n_k_ho_wo_global_desc.GetLength(I1);
const auto Hi = in_n_c_hi_wi_global_desc.GetLength(I2);
const auto Wi = in_n_c_hi_wi_global_desc.GetLength(I3);
const auto Ho = out_n_k_ho_wo_global_desc.GetLength(I2);
const auto Wo = out_n_k_ho_wo_global_desc.GetLength(I3);
const auto Y = wei_k_c_y_x_global_desc.GetLength(I2);
const auto X = wei_k_c_y_x_global_desc.GetLength(I3);
const auto ConvStrideH = conv_strides[I0];
const auto ConvStrideW = conv_strides[I1];
const auto ConvDilationH = conv_dilations[I0];
const auto ConvDilationW = conv_dilations[I1];
const auto InLeftPadH = in_left_pads[I0];
const auto InLeftPadW = in_left_pads[I1];
const auto InRightPadH = in_right_pads[I0];
const auto InRightPadW = in_right_pads[I1];
const auto GemmM = K;
const auto GemmN = N * Ho * Wo;
const auto GemmK = C * Y * X;
const auto GemmK0 = GemmK / GemmKPack;
assert(Y == 1 && X == 1 && ConvStrideH == 1 && ConvStrideW == 1 && ConvDilationH == 1 &&
ConvDilationW == 1 && InLeftPadH == 0 && InLeftPadW == 0 && InRightPadH == 0 &&
InRightPadW == 0);
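// with a 1x1 filter, unit stride/dilation and zero padding, GemmK degenerates to C and the
// input can be read directly as [C, N*Ho*Wo] without pad/embed transforms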
// weight tensor
const auto wei_gemmk_gemmm_global_desc = transform_dynamic_tensor_descriptor(
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, C)),
make_tuple(make_pass_through_transform(K), make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<1>{}, Sequence<0>{}));
const auto wei_gemmk0_gemmm_gemmk1_global_desc = transform_dynamic_tensor_descriptor(
wei_gemmk_gemmm_global_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmKPack)),
make_pass_through_transform(GemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// input tensor
const auto in_gemmk_gemmn_global_desc = transform_dynamic_tensor_descriptor(
in_n_c_hi_wi_global_desc,
make_tuple(make_pass_through_transform(C), make_merge_transform(make_tuple(N, Ho, Wo))),
make_tuple(Sequence<1>{}, Sequence<0, 2, 3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto in_gemmk0_gemmn_gemmk1_global_desc = transform_dynamic_tensor_descriptor(
in_gemmk_gemmn_global_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmKPack)),
make_pass_through_transform(GemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// output tensor
const auto out_gemmm_gemmn_global_desc = transform_dynamic_tensor_descriptor(
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K, Ho * Wo)),
make_tuple(make_pass_through_transform(K), make_merge_transform(make_tuple(N, Ho * Wo))),
make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
assert(GemmM == out_gemmm_gemmn_global_desc.GetLength(I0));
assert(GemmN == out_gemmm_gemmn_global_desc.GetLength(I1));
assert(GemmK0 == in_gemmk0_gemmn_gemmk1_global_desc.GetLength(I0));
assert(GemmK0 == wei_gemmk0_gemmm_gemmk1_global_desc.GetLength(I0));
assert(GemmM % GemmMPerBlock == 0 && GemmN % GemmNPerBlock == 0 && GemmK0 % GemmKPerBlock == 0);
constexpr auto xdlops_gemm = XdlopsGemm<FloatAB, GemmMPerWave, GemmNPerWave, GemmKPack>{};
constexpr auto CLayout = xdlops_gemm.GetCLayout();
constexpr index_t M0 = CLayout.M1();
constexpr index_t M1 = CLayout.N1();
constexpr index_t M2 = CLayout.M0();
const auto out_m0_m1_m2_n_global_desc = transform_dynamic_tensor_descriptor(
out_gemmm_gemmn_global_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmM / (M1 * M2), M1, M2)),
make_pass_through_transform(GemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}));
// out_gemm_block_cluster_desc
const auto out_gemm_block_cluster_desc = make_cluster_descriptor_v2(
make_tuple(GemmM / Number<GemmMPerBlock>{}, GemmN / Number<GemmNPerBlock>{}));
// hack to control index calculation when iterating over wei_gemmk0_gemmm_gemmk1_global tensor
constexpr auto wei_gemmk0_gemmm_gemmk1_global_iterator_hacks = make_tuple(
make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}),
make_tuple(
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}));
constexpr auto wei_gemmk0_gemmm_gemmk1_global_move_slice_window_iterator_hacks =
Sequence<0, 0, 0, 0, 0>{};
// hack to control index calculation when iterating over in_gemmk0_gemmn_gemmk1_global tensor
constexpr auto in_gemmk0_gemmn_gemmk1_global_iterator_hacks = make_tuple(
make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 1, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}),
make_tuple(
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 2, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}));
constexpr auto in_gemmk0_gemmn_gemmk1_global_move_slice_window_iterator_hacks =
Sequence<0, 1, 2, 0, 0>{};
// hack to control index calculation when iterating over out_gemmm0_gemmm1_gemmn0_gemmn1_global
// tensor hack for NKHW format
constexpr auto out_m0_m1_m2_n_global_iterator_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 1, 0, 0>{}),
make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 2, 0, 0>{}));
return make_tuple(wei_gemmk0_gemmm_gemmk1_global_desc,
in_gemmk0_gemmn_gemmk1_global_desc,
out_m0_m1_m2_n_global_desc,
out_gemm_block_cluster_desc,
wei_gemmk0_gemmm_gemmk1_global_iterator_hacks,
in_gemmk0_gemmn_gemmk1_global_iterator_hacks,
out_m0_m1_m2_n_global_iterator_hacks,
wei_gemmk0_gemmm_gemmk1_global_move_slice_window_iterator_hacks,
in_gemmk0_gemmn_gemmk1_global_move_slice_window_iterator_hacks);
}
} // namespace ck
#endif
#ifndef CK_DRIVER_DYNAMIC_GEMM_XDLOPS_V1
#define CK_DRIVER_DYNAMIC_GEMM_XDLOPS_V1
#include "common_header.hpp"
#include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp"
#include "gridwise_dynamic_gemm_xdlops.hpp"
#include "gridwise_operation_wrapper.hpp"
namespace ck {
template <index_t BlockSize,
typename FloatAB,
typename FloatAcc,
typename FloatC,
InMemoryDataOperation CGlobalMemoryDataOperation,
typename AGlobalDesc,
typename BGlobalDesc,
typename CGlobalDesc,
typename CBlockClusterDesc,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t MPerWave,
index_t NPerWave,
index_t KPack,
index_t MRepeat,
index_t NRepeat,
typename ABlockTransferThreadSliceLengths_K_M,
typename ABlockTransferThreadClusterLengths_K_M,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_M,
bool AThreadTransferSrcResetCoordinateAfterRun,
typename BBlockTransferThreadSliceLengths_K_N,
typename BBlockTransferThreadClusterLengths_K_N,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_N,
bool BThreadTransferSrcResetCoordinateAfterRun,
typename CThreadTransferSrcDstAccessOrder,
index_t CThreadTransferSrcDstVectorDim,
index_t CThreadTransferDstScalarPerVector,
typename AGlobalIteratorHacks,
typename BGlobalIteratorHacks,
typename CGlobalIteratorHacks,
typename AGlobalMoveSliceWindowIteratorHacks,
typename BGlobalMoveSliceWindowIteratorHacks>
__host__ float launch_kernel_dynamic_gemm_xdlops_v1(const FloatAB* p_a_global,
const FloatAB* p_b_global,
FloatC* p_c_global,
const AGlobalDesc& a_k_m_global_desc,
const BGlobalDesc& b_k_n_global_desc,
const CGlobalDesc& c_m0_m1_n0_n1_global_desc,
const CBlockClusterDesc& c_block_cluster_desc,
AGlobalIteratorHacks,
BGlobalIteratorHacks,
CGlobalIteratorHacks,
AGlobalMoveSliceWindowIteratorHacks,
BGlobalMoveSliceWindowIteratorHacks,
index_t nrepeat)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
const auto M = a_k_m_global_desc.GetLength(I1);
const auto N = b_k_n_global_desc.GetLength(I1);
const auto K = a_k_m_global_desc.GetLength(I0);
if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0))
{
throw std::runtime_error("wrong! GEMM size no divisible");
}
if(!(MPerBlock % MPerWave == 0 && NPerBlock % NPerWave == 0))
{
throw std::runtime_error("wrong! GEMM size no divisible");
}
// GEMM
using gridwise_gemm =
GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1<BlockSize,
FloatAB,
FloatAcc,
FloatC,
CGlobalMemoryDataOperation,
AGlobalDesc,
BGlobalDesc,
CGlobalDesc,
CBlockClusterDesc,
MPerBlock,
NPerBlock,
KPerBlock,
MPerWave,
NPerWave,
KPack,
MRepeat,
NRepeat,
ABlockTransferThreadSliceLengths_K_M,
ABlockTransferThreadClusterLengths_K_M,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_M,
AThreadTransferSrcResetCoordinateAfterRun,
BBlockTransferThreadSliceLengths_K_N,
BBlockTransferThreadClusterLengths_K_N,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_N,
BThreadTransferSrcResetCoordinateAfterRun,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
AGlobalIteratorHacks,
BGlobalIteratorHacks,
CGlobalIteratorHacks,
AGlobalMoveSliceWindowIteratorHacks,
BGlobalMoveSliceWindowIteratorHacks>;
const auto GridSize = (M / MPerBlock) * (N / NPerBlock);
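// one workgroup per (MPerBlock x NPerBlock) tile of C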
const bool has_main_k_block_loop = (K + KPerBlock) / (2 * KPerBlock) > 1;
const bool has_double_tail_k_block_loop = (K / KPerBlock) % 2 == 0;
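// the gridwise GEMM appears to consume K in double-buffered pairs of KPerBlock slices;
// these flags pick the matching main-loop / tail specialization of the kernel template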
std::cerr << "has_main_k_block_loop = " << has_main_k_block_loop
<< " has_double_tail_k_block_loop = " << has_double_tail_k_block_loop << std::endl;
#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
float ave_time = 0;
if(has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel = kernel_dynamic_gemm_xdlops_v1<gridwise_gemm,
FloatAB,
FloatAB,
FloatC,
remove_reference_t<AGlobalDesc>,
remove_reference_t<BGlobalDesc>,
remove_reference_t<CGlobalDesc>,
remove_reference_t<CBlockClusterDesc>,
true,
true>;
ave_time = launch_and_time_kernel(kernel,
nrepeat,
dim3(GridSize),
dim3(BlockSize),
0,
0,
p_a_global,
p_b_global,
p_c_global,
a_k_m_global_desc,
b_k_n_global_desc,
c_m0_m1_n0_n1_global_desc,
c_block_cluster_desc);
}
else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
{
const auto kernel = kernel_dynamic_gemm_xdlops_v1<gridwise_gemm,
FloatAB,
FloatAB,
FloatC,
remove_reference_t<AGlobalDesc>,
remove_reference_t<BGlobalDesc>,
remove_reference_t<CGlobalDesc>,
remove_reference_t<CBlockClusterDesc>,
true,
false>;
ave_time = launch_and_time_kernel(kernel,
nrepeat,
dim3(GridSize),
dim3(BlockSize),
0,
0,
p_a_global,
p_b_global,
p_c_global,
a_k_m_global_desc,
b_k_n_global_desc,
c_m0_m1_n0_n1_global_desc,
c_block_cluster_desc);
}
else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel = kernel_dynamic_gemm_xdlops_v1<gridwise_gemm,
FloatAB,
FloatAB,
FloatC,
remove_reference_t<AGlobalDesc>,
remove_reference_t<BGlobalDesc>,
remove_reference_t<CGlobalDesc>,
remove_reference_t<CBlockClusterDesc>,
false,
true>;
ave_time = launch_and_time_kernel(kernel,
nrepeat,
dim3(GridSize),
dim3(BlockSize),
0,
0,
p_a_global,
p_b_global,
p_c_global,
a_k_m_global_desc,
b_k_n_global_desc,
c_m0_m1_n0_n1_global_desc,
c_block_cluster_desc);
}
else
{
const auto kernel = kernel_dynamic_gemm_xdlops_v1<gridwise_gemm,
FloatAB,
FloatAB,
FloatC,
remove_reference_t<AGlobalDesc>,
remove_reference_t<BGlobalDesc>,
remove_reference_t<CGlobalDesc>,
remove_reference_t<CBlockClusterDesc>,
false,
false>;
ave_time = launch_and_time_kernel(kernel,
nrepeat,
dim3(GridSize),
dim3(BlockSize),
0,
0,
p_a_global,
p_b_global,
p_c_global,
a_k_m_global_desc,
b_k_n_global_desc,
c_m0_m1_n0_n1_global_desc,
c_block_cluster_desc);
}
return ave_time;
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
DeviceMem a_k_m_global_desc_device_buf(sizeof(AGlobalDesc));
DeviceMem b_k_n_global_desc_device_buf(sizeof(BGlobalDesc));
DeviceMem c_m0_m1_n0_n1_global_desc_device_buf(sizeof(CGlobalDesc));
DeviceMem c_block_cluster_desc_device_buf(sizeof(c_block_cluster_desc));
a_k_m_global_desc_device_buf.ToDevice(&a_k_m_global_desc);
b_k_n_global_desc_device_buf.ToDevice(&b_k_n_global_desc);
c_m0_m1_n0_n1_global_desc_device_buf.ToDevice(&c_m0_m1_n0_n1_global_desc);
c_block_cluster_desc_device_buf.ToDevice(&c_block_cluster_desc);
float ave_time = 0;
if(has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel = kernel_dynamic_gemm_xdlops_v1<gridwise_gemm,
FloatAB,
FloatAB,
FloatC,
remove_reference_t<AGlobalDesc>,
remove_reference_t<BGlobalDesc>,
remove_reference_t<CGlobalDesc>,
remove_reference_t<CBlockClusterDesc>,
true,
true>;
ave_time = launch_and_time_kernel(
kernel,
nrepeat,
dim3(GridSize),
dim3(BlockSize),
0,
0,
p_a_global,
p_b_global,
p_c_global,
(void __CONSTANT__*)a_k_m_global_desc_device_buf.GetDeviceBuffer(),
(void __CONSTANT__*)b_k_n_global_desc_device_buf.GetDeviceBuffer(),
(void __CONSTANT__*)c_m0_m1_n0_n1_global_desc_device_buf.GetDeviceBuffer(),
(void __CONSTANT__*)c_block_cluster_desc_device_buf.GetDeviceBuffer());
}
else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
{
const auto kernel = kernel_dynamic_gemm_xdlops_v1<gridwise_gemm,
FloatAB,
FloatAB,
FloatC,
remove_reference_t<AGlobalDesc>,
remove_reference_t<BGlobalDesc>,
remove_reference_t<CGlobalDesc>,
remove_reference_t<CBlockClusterDesc>,
true,
false>;
ave_time = launch_and_time_kernel(
kernel,
nrepeat,
dim3(GridSize),
dim3(BlockSize),
0,
0,
p_a_global,
p_b_global,
p_c_global,
(void __CONSTANT__*)a_k_m_global_desc_device_buf.GetDeviceBuffer(),
(void __CONSTANT__*)b_k_n_global_desc_device_buf.GetDeviceBuffer(),
(void __CONSTANT__*)c_m0_m1_n0_n1_global_desc_device_buf.GetDeviceBuffer(),
(void __CONSTANT__*)c_block_cluster_desc_device_buf.GetDeviceBuffer());
}
else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel = kernel_dynamic_gemm_xdlops_v1<gridwise_gemm,
FloatAB,
FloatAB,
FloatC,
remove_reference_t<AGlobalDesc>,
remove_reference_t<BGlobalDesc>,
remove_reference_t<CGlobalDesc>,
remove_reference_t<CBlockClusterDesc>,
false,
true>;
ave_time = launch_and_time_kernel(
kernel,
nrepeat,
dim3(GridSize),
dim3(BlockSize),
0,
0,
p_a_global,
p_b_global,
p_c_global,
(void __CONSTANT__*)a_k_m_global_desc_device_buf.GetDeviceBuffer(),
(void __CONSTANT__*)b_k_n_global_desc_device_buf.GetDeviceBuffer(),
(void __CONSTANT__*)c_m0_m1_n0_n1_global_desc_device_buf.GetDeviceBuffer(),
(void __CONSTANT__*)c_block_cluster_desc_device_buf.GetDeviceBuffer());
}
else
{
const auto kernel = kernel_dynamic_gemm_xdlops_v1<gridwise_gemm,
FloatAB,
FloatAB,
FloatC,
remove_reference_t<AGlobalDesc>,
remove_reference_t<BGlobalDesc>,
remove_reference_t<CGlobalDesc>,
remove_reference_t<CBlockClusterDesc>,
false,
false>;
ave_time = launch_and_time_kernel(
kernel,
nrepeat,
dim3(GridSize),
dim3(BlockSize),
0,
0,
p_a_global,
p_b_global,
p_c_global,
(void __CONSTANT__*)a_k_m_global_desc_device_buf.GetDeviceBuffer(),
(void __CONSTANT__*)b_k_n_global_desc_device_buf.GetDeviceBuffer(),
(void __CONSTANT__*)c_m0_m1_n0_n1_global_desc_device_buf.GetDeviceBuffer(),
(void __CONSTANT__*)c_block_cluster_desc_device_buf.GetDeviceBuffer());
}
return ave_time;
#endif
}
} // namespace ck
#endif
#ifndef CK_DRIVER_DYNAMIC_GEMM_XDLOPS_V2
#define CK_DRIVER_DYNAMIC_GEMM_XDLOPS_V2
#include "common_header.hpp"
#include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp"
#include "gridwise_dynamic_gemm_xdlops_v2.hpp"
#include "gridwise_operation_wrapper.hpp"
namespace ck {
template <index_t BlockSize,
typename FloatAB,
typename FloatAcc,
typename FloatC,
InMemoryDataOperation CGlobalMemoryDataOperation,
typename AGlobalDesc,
typename BGlobalDesc,
typename CGlobalDesc,
typename CBlockClusterDesc,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t MPerWave,
index_t NPerWave,
index_t KPack,
index_t MRepeat,
index_t NRepeat,
typename ABlockTransferThreadSliceLengths_K_M,
typename ABlockTransferThreadClusterLengths_K_M,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_M,
bool AThreadTransferSrcResetCoordinateAfterRun,
typename BBlockTransferThreadSliceLengths_K_N,
typename BBlockTransferThreadClusterLengths_K_N,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_N,
bool BThreadTransferSrcResetCoordinateAfterRun,
typename CThreadTransferSrcDstAccessOrder,
index_t CThreadTransferSrcDstVectorDim,
index_t CThreadTransferDstScalarPerVector,
typename AGlobalIteratorHacks,
typename BGlobalIteratorHacks,
typename CGlobalIteratorHacks,
typename AGlobalMoveSliceWindowIteratorHacks,
typename BGlobalMoveSliceWindowIteratorHacks>
__host__ float launch_kernel_dynamic_gemm_xdlops_v2(const FloatAB* p_a_global,
const FloatAB* p_b_global,
FloatC* p_c_global,
const AGlobalDesc& a_k_m_global_desc,
const BGlobalDesc& b_k_n_global_desc,
const CGlobalDesc& c_m0_m1_n0_n1_global_desc,
const CBlockClusterDesc& c_block_cluster_desc,
AGlobalIteratorHacks,
BGlobalIteratorHacks,
CGlobalIteratorHacks,
AGlobalMoveSliceWindowIteratorHacks,
BGlobalMoveSliceWindowIteratorHacks,
index_t nrepeat)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
const auto M = a_k_m_global_desc.GetLength(I1);
const auto N = b_k_n_global_desc.GetLength(I1);
const auto K = a_k_m_global_desc.GetLength(I0);
if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0))
{
throw std::runtime_error("wrong! GEMM size no divisible");
}
if(!(MPerBlock % MPerWave == 0 && NPerBlock % NPerWave == 0))
{
throw std::runtime_error("wrong! GEMM size no divisible");
}
// GEMM
using gridwise_gemm =
GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v2<BlockSize,
FloatAB,
FloatAcc,
FloatC,
CGlobalMemoryDataOperation,
AGlobalDesc,
BGlobalDesc,
CGlobalDesc,
CBlockClusterDesc,
MPerBlock,
NPerBlock,
KPerBlock,
MPerWave,
NPerWave,
KPack,
MRepeat,
NRepeat,
ABlockTransferThreadSliceLengths_K_M,
ABlockTransferThreadClusterLengths_K_M,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_M,
AThreadTransferSrcResetCoordinateAfterRun,
BBlockTransferThreadSliceLengths_K_N,
BBlockTransferThreadClusterLengths_K_N,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_N,
BThreadTransferSrcResetCoordinateAfterRun,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
AGlobalIteratorHacks,
BGlobalIteratorHacks,
CGlobalIteratorHacks,
AGlobalMoveSliceWindowIteratorHacks,
BGlobalMoveSliceWindowIteratorHacks>;
const auto GridSize = (M / MPerBlock) * (N / NPerBlock);
#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
float ave_time = 0;
const auto kernel = kernel_dynamic_gemm_xdlops_v2<gridwise_gemm,
FloatAB,
FloatAB,
FloatC,
remove_reference_t<AGlobalDesc>,
remove_reference_t<BGlobalDesc>,
remove_reference_t<CGlobalDesc>,
remove_reference_t<CBlockClusterDesc>>;
ave_time = launch_and_time_kernel(kernel,
nrepeat,
dim3(GridSize),
dim3(BlockSize),
0,
0,
p_a_global,
p_b_global,
p_c_global,
a_k_m_global_desc,
b_k_n_global_desc,
c_m0_m1_n0_n1_global_desc,
c_block_cluster_desc);
return ave_time;
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
DeviceMem a_k_m_global_desc_device_buf(sizeof(AGlobalDesc));
DeviceMem b_k_n_global_desc_device_buf(sizeof(BGlobalDesc));
DeviceMem c_m0_m1_n0_n1_global_desc_device_buf(sizeof(CGlobalDesc));
DeviceMem c_block_cluster_desc_device_buf(sizeof(c_block_cluster_desc));
a_k_m_global_desc_device_buf.ToDevice(&a_k_m_global_desc);
b_k_n_global_desc_device_buf.ToDevice(&b_k_n_global_desc);
c_m0_m1_n0_n1_global_desc_device_buf.ToDevice(&c_m0_m1_n0_n1_global_desc);
c_block_cluster_desc_device_buf.ToDevice(&c_block_cluster_desc);
float ave_time = 0;
const auto kernel = kernel_dynamic_gemm_xdlops_v2<gridwise_gemm,
FloatAB,
FloatAB,
FloatC,
remove_reference_t<AGlobalDesc>,
remove_reference_t<BGlobalDesc>,
remove_reference_t<CGlobalDesc>,
remove_reference_t<CBlockClusterDesc>>;
ave_time = launch_and_time_kernel(
kernel,
nrepeat,
dim3(GridSize),
dim3(BlockSize),
0,
0,
p_a_global,
p_b_global,
p_c_global,
(void __CONSTANT__*)a_k_m_global_desc_device_buf.GetDeviceBuffer(),
(void __CONSTANT__*)b_k_n_global_desc_device_buf.GetDeviceBuffer(),
(void __CONSTANT__*)c_m0_m1_n0_n1_global_desc_device_buf.GetDeviceBuffer(),
(void __CONSTANT__*)c_block_cluster_desc_device_buf.GetDeviceBuffer());
return ave_time;
#endif
}
} // namespace ck
#endif
#ifndef CK_DRIVER_DYNAMIC_GEMM_XDLOPS_V2R2
#define CK_DRIVER_DYNAMIC_GEMM_XDLOPS_V2R2
#include "common_header.hpp"
#include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp"
#include "gridwise_dynamic_gemm_xdlops_v2r2.hpp"
namespace ck {
template <index_t BlockSize,
typename FloatAB,
typename FloatAcc,
typename FloatC,
InMemoryDataOperation CGlobalMemoryDataOperation,
typename AK0MK1GridDesc,
typename BK0NK1GridDesc,
typename CMNGridDesc,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t MPerWave,
index_t NPerWave,
index_t MRepeat,
index_t NRepeat,
typename ABlockTransferThreadSliceLengths_K0_M_K1,
typename ABlockTransferThreadClusterLengths_K0_M_K1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_K1,
bool AThreadTransferSrcResetCoordinateAfterRun,
typename BBlockTransferThreadSliceLengths_K0_N_K1,
typename BBlockTransferThreadClusterLengths_K0_N_K1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_K1,
bool BThreadTransferSrcResetCoordinateAfterRun,
typename CThreadTransferSrcDstAccessOrder,
index_t CThreadTransferSrcDstVectorDim,
index_t CThreadTransferDstScalarPerVector,
typename AGridIteratorHacks,
typename BGridIteratorHacks,
typename CGridIteratorHacks,
typename AGridMoveSliceWindowIteratorHacks,
typename BGridMoveSliceWindowIteratorHacks>
__host__ float driver_dynamic_gemm_xdlops_v2r2(const FloatAB* p_a_grid,
const FloatAB* p_b_grid,
FloatC* p_c_grid,
const AK0MK1GridDesc& a_k0_m_k1_grid_desc,
const BK0NK1GridDesc& b_k0_n_k1_grid_desc,
const CMNGridDesc& c_m_n_grid_desc,
AGridIteratorHacks,
BGridIteratorHacks,
CGridIteratorHacks,
AGridMoveSliceWindowIteratorHacks,
BGridMoveSliceWindowIteratorHacks,
index_t nrepeat)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto I4 = Number<4>{};
constexpr auto I5 = Number<5>{};
using GridwiseGemm =
GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r2<BlockSize,
FloatAB,
FloatAcc,
FloatC,
CGlobalMemoryDataOperation,
AK0MK1GridDesc,
BK0NK1GridDesc,
CMNGridDesc,
MPerBlock,
NPerBlock,
KPerBlock,
MPerWave,
NPerWave,
MRepeat,
NRepeat,
ABlockTransferThreadSliceLengths_K0_M_K1,
ABlockTransferThreadClusterLengths_K0_M_K1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_K1,
AThreadTransferSrcResetCoordinateAfterRun,
BBlockTransferThreadSliceLengths_K0_N_K1,
BBlockTransferThreadClusterLengths_K0_N_K1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_K1,
BThreadTransferSrcResetCoordinateAfterRun,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
AGridIteratorHacks,
BGridIteratorHacks,
CGridIteratorHacks,
AGridMoveSliceWindowIteratorHacks,
BGridMoveSliceWindowIteratorHacks>;
{
std::cout << "a_k0_m_k1_grid_desc{" << a_k0_m_k1_grid_desc.GetLength(I0) << ", "
<< a_k0_m_k1_grid_desc.GetLength(I1) << ", " << a_k0_m_k1_grid_desc.GetLength(I2)
<< "}" << std::endl;
std::cout << "b_k0_n_k1_grid_desc{" << b_k0_n_k1_grid_desc.GetLength(I0) << ", "
<< b_k0_n_k1_grid_desc.GetLength(I1) << ", " << b_k0_n_k1_grid_desc.GetLength(I2)
<< "}" << std::endl;
std::cout << "c_m_n_grid_desc{ " << c_m_n_grid_desc.GetLength(I0) << ", "
<< c_m_n_grid_desc.GetLength(I1) << "}" << std::endl;
}
if(!GridwiseGemm::CheckValidity(a_k0_m_k1_grid_desc, b_k0_n_k1_grid_desc, c_m_n_grid_desc))
{
throw std::runtime_error(
"wrong! GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r2 has invalid setting");
}
const auto c_m0_m1_m2_n_grid_desc = GridwiseGemm::MakeCM0M1M2NGridDescriptor(c_m_n_grid_desc);
using CM0M1M2NGridDesc = decltype(c_m0_m1_m2_n_grid_desc);
const auto c_block_cluster_adaptor = GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc);
using CBlockClusterAdaptor = decltype(c_block_cluster_adaptor);
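// c_block_cluster_adaptor maps the 1-D block id onto the (M, N) tile grid of C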
const index_t grid_size = GridwiseGemm::CalculateGridSize(c_m_n_grid_desc);
const auto kernel = kernel_dynamic_gemm_xdlops_v2r2<GridwiseGemm,
FloatAB,
FloatC,
remove_reference_t<AK0MK1GridDesc>,
remove_reference_t<BK0NK1GridDesc>,
remove_reference_t<CM0M1M2NGridDesc>,
remove_reference_t<CBlockClusterAdaptor>>;
float ave_time = launch_and_time_kernel(kernel,
nrepeat,
dim3(grid_size),
dim3(BlockSize),
0,
0,
p_a_grid,
p_b_grid,
p_c_grid,
a_k0_m_k1_grid_desc,
b_k0_n_k1_grid_desc,
c_m0_m1_m2_n_grid_desc,
c_block_cluster_adaptor);
return ave_time;
}
} // namespace ck
#endif
@@ -21,6 +21,7 @@ template <index_t BlockSize,
 index_t KPerBlock,
 index_t MPerWave,
 index_t NPerWave,
+index_t K1,
 index_t MRepeat,
 index_t NRepeat,
 typename ABlockTransferThreadSliceLengths_K0_M_K1,
@@ -83,6 +84,7 @@ __host__ float driver_dynamic_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
 KPerBlock,
 MPerWave,
 NPerWave,
+K1,
 MRepeat,
 NRepeat,
 ABlockTransferThreadSliceLengths_K0_M_K1,
@@ -148,6 +150,7 @@ __host__ float driver_dynamic_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
 remove_reference_t<CM0M1M2NGridDesc>,
 remove_reference_t<CBlockClusterAdaptor>>;
+#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
 float ave_time = launch_and_time_kernel(kernel,
 nrepeat,
 dim3(grid_size),
@@ -162,6 +165,32 @@ __host__ float driver_dynamic_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
 c_m0_m1_m2_n_grid_desc,
 c_block_cluster_adaptor);
+#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
+DeviceMem a_k0_m_k1_grid_desc_dev_buf(sizeof(AK0MK1GridDesc));
+DeviceMem b_k0_n_k1_grid_desc_dev_buf(sizeof(BK0NK1GridDesc));
+DeviceMem c_m0_m1_m2_n_grid_desc_dev_buf(sizeof(CM0M1M2NGridDesc));
+DeviceMem c_block_cluster_adaptor_dev_buf(sizeof(CBlockClusterAdaptor));
+a_k0_m_k1_grid_desc_dev_buf.ToDevice(&a_k0_m_k1_grid_desc);
+b_k0_n_k1_grid_desc_dev_buf.ToDevice(&b_k0_n_k1_grid_desc);
+c_m0_m1_m2_n_grid_desc_dev_buf.ToDevice(&c_m0_m1_m2_n_grid_desc);
+c_block_cluster_adaptor_dev_buf.ToDevice(&c_block_cluster_adaptor);
+float ave_time = launch_and_time_kernel(
+kernel,
+nrepeat,
+dim3(grid_size),
+dim3(BlockSize),
+0,
+0,
+p_a_grid,
+p_b_grid,
+p_c_grid,
+(void __CONSTANT__*)a_k0_m_k1_grid_desc_dev_buf.GetDeviceBuffer(),
+(void __CONSTANT__*)b_k0_n_k1_grid_desc_dev_buf.GetDeviceBuffer(),
+(void __CONSTANT__*)c_m0_m1_m2_n_grid_desc_dev_buf.GetDeviceBuffer(),
+(void __CONSTANT__*)c_block_cluster_adaptor_dev_buf.GetDeviceBuffer());
+#endif
 return ave_time;
 }
#ifndef CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4R2_NHWC_KYXC_NHWK_HPP
#define CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4R2_NHWC_KYXC_NHWK_HPP
#include "common_header.hpp"
#include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp"
namespace ck {
// GemmM = K
// GemmN = N * Ho * Wo
// GemmK = C * Y * X
template <typename... Wei,
typename... In,
typename... Out,
typename ConvStrides,
typename ConvDilations,
typename InLeftPads,
typename InRightPads,
index_t GemmK1Value>
__host__ __device__ constexpr auto
transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk_pad(
const DynamicTensorDescriptor<Wei...>& wei_k_y_x_c_grid_desc,
const DynamicTensorDescriptor<In...>& in_n_hi_wi_c_grid_desc,
const DynamicTensorDescriptor<Out...>& out_n_ho_wo_k_grid_desc,
const ConvStrides& conv_strides,
const ConvDilations& conv_dilations,
const InLeftPads& in_left_pads,
const InRightPads& in_right_pads,
Number<GemmK1Value>)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto GemmK1 = Number<GemmK1Value>{};
const auto N = in_n_hi_wi_c_grid_desc.GetLength(I0);
const auto C = in_n_hi_wi_c_grid_desc.GetLength(I3);
const auto K = out_n_ho_wo_k_grid_desc.GetLength(I3);
const auto Hi = in_n_hi_wi_c_grid_desc.GetLength(I1);
const auto Wi = in_n_hi_wi_c_grid_desc.GetLength(I2);
const auto Ho = out_n_ho_wo_k_grid_desc.GetLength(I1);
const auto Wo = out_n_ho_wo_k_grid_desc.GetLength(I2);
const auto Y = wei_k_y_x_c_grid_desc.GetLength(I1);
const auto X = wei_k_y_x_c_grid_desc.GetLength(I2);
const auto ConvStrideH = conv_strides[I0];
const auto ConvStrideW = conv_strides[I1];
const auto ConvDilationH = conv_dilations[I0];
const auto ConvDilationW = conv_dilations[I1];
const auto InLeftPadH = in_left_pads[I0];
const auto InLeftPadW = in_left_pads[I1];
const auto InRightPadH = in_right_pads[I0];
const auto InRightPadW = in_right_pads[I1];
const auto GemmM = K;
const auto GemmN = N * Ho * Wo;
const auto GemmK = C * Y * X;
const auto GemmK0 = GemmK / GemmK1;
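// GemmK = Y*X*C is split as GemmK0 * GemmK1; with the NHWC/KYXC layouts the innermost
// GemmK1 slice runs along C, presumably enabling unit-stride (vectorized) global reads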
// weight tensor
const auto wei_gemmk_gemmm_grid_desc = transform_dynamic_tensor_descriptor(
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, Y * X * C)),
make_tuple(make_pass_through_transform(K), make_pass_through_transform(Y * X * C)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<1>{}, Sequence<0>{}));
const auto wei_gemmk0_gemmm_gemmk1_grid_desc = transform_dynamic_tensor_descriptor(
wei_gemmk_gemmm_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)),
make_pass_through_transform(GemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// input tensor
const auto in_n_hip_wip_c_grid_desc = transform_dynamic_tensor_descriptor(
in_n_hi_wi_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_pad_transform(Hi, InLeftPadH, InRightPadH),
make_pad_transform(Wi, InLeftPadW, InRightPadW),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto in_n_y_ho_x_wo_c_grid_desc = transform_dynamic_tensor_descriptor(
in_n_hip_wip_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
const auto in_gemmk_gemmn_grid_desc =
transform_dynamic_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(Y, X, C)),
make_merge_transform(make_tuple(N, Ho, Wo))),
make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto in_gemmk0_gemmn_gemmk1_grid_desc = transform_dynamic_tensor_descriptor(
in_gemmk_gemmn_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)),
make_pass_through_transform(GemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// output tensor
const auto out_gemmm_gemmn_grid_desc = transform_dynamic_tensor_descriptor(
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N * Ho * Wo, K)),
make_tuple(make_pass_through_transform(N * Ho * Wo), make_pass_through_transform(K)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<1>{}, Sequence<0>{}));
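// the naive [N*Ho*Wo, K] output view is re-ordered into [GemmM = K, GemmN = N*Ho*Wo]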
return make_tuple(wei_gemmk0_gemmm_gemmk1_grid_desc,
in_gemmk0_gemmn_gemmk1_grid_desc,
out_gemmm_gemmn_grid_desc);
}
} // namespace ck
#endif
@@ -13,7 +13,7 @@ template <index_t BlockSize,
 class BBlockDesc,
 index_t MPerWave,
 index_t NPerWave,
-index_t KPack>
+index_t K1>
 struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
 {
@@ -32,7 +32,7 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
 static constexpr index_t N0 = BBlockDesc{}.GetLength(I1);
 static constexpr index_t N1 = BBlockDesc{}.GetLength(I2);
-static constexpr auto xdlops_gemm = XdlopsGemm<FloatAB, MPerWave, NPerWave, KPack>{};
+static constexpr auto xdlops_gemm = XdlopsGemm<FloatAB, MPerWave, NPerWave, K1>{};
 static constexpr index_t MWaves = M1 / MPerWave;
 static constexpr index_t NWaves = N1 / NPerWave;
@@ -119,18 +119,18 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
 "wrong! K dimension not consistent");
 static_assert(ABlockDesc{}.GetLength(I3) == BBlockDesc{}.GetLength(I3),
-"wrong! KPack dimension not consistent");
+"wrong! K1 dimension not consistent");
 static_assert(BlockSize == MWaves * NWaves * WaveSize,
 "BlockSize != MWaves * NWaves * WaveSize\n");
-static_assert(KPack == BBlockDesc{}.GetLength(I3), "KPack is wrong!");
+static_assert(K1 == BBlockDesc{}.GetLength(I3), "K1 is wrong!");
 constexpr index_t KPerBlock = ABlockDesc{}.GetLength(I0);
 static_assert(KPerBlock % xdlops_gemm.KPerXdlops == 0, "KPerBlock is wrong!");
-static_assert(KPack % xdlops_gemm.mfma_type.k_base == 0, "KPack is wrong!");
+static_assert(K1 % xdlops_gemm.mfma_type.k_base == 0, "K1 is wrong!");
 }
 template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
@@ -194,11 +194,11 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
 private:
 // A[K, M]
 static constexpr auto a_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2(
-make_tuple(I1, Number<MRepeat>{}, I1, Number<KPack>{}));
+make_tuple(I1, Number<MRepeat>{}, I1, Number<K1>{}));
 // B[K, N]
 static constexpr auto b_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2(
-make_tuple(I1, Number<NRepeat>{}, I1, Number<KPack>{}));
+make_tuple(I1, Number<NRepeat>{}, I1, Number<K1>{}));
 static constexpr auto c_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2(
 make_tuple(Number<MRepeat>{}, Number<NRepeat>{}));
@@ -207,20 +207,20 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
 FloatAB,
 ABlockDesc,
 decltype(a_thread_desc_),
-Sequence<1, MRepeat, 1, KPack>,
+Sequence<1, MRepeat, 1, K1>,
 Sequence<0, 1, 2, 3>,
 3,
-KPack,
+K1,
 1>;
 using BThreadCopy = ThreadwiseDynamicTensorSliceTransfer_v4<FloatAB,
 FloatAB,
 BBlockDesc,
 decltype(b_thread_desc_),
-Sequence<1, NRepeat, 1, KPack>,
+Sequence<1, NRepeat, 1, K1>,
 Sequence<0, 1, 2, 3>,
 3,
-KPack,
+K1,
 1>;
 AThreadCopy a_thread_copy_;
@@ -233,7 +233,7 @@ template <index_t BlockSize,
 class BBlockDesc,
 index_t MPerWave,
 index_t NPerWave,
-index_t KPack>
+index_t K1>
 struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline
 {
@@ -244,7 +244,7 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline
 static constexpr auto I2 = Number<2>{};
 static constexpr auto I3 = Number<3>{};
-static constexpr auto xdlops_gemm = XdlopsGemm<float, MPerWave, NPerWave, KPack>{};
+static constexpr auto xdlops_gemm = XdlopsGemm<float, MPerWave, NPerWave, K1>{};
 static constexpr index_t WaveSize = 64;
@@ -339,18 +339,18 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline
 "wrong! K dimension not consistent");
 static_assert(ABlockDesc{}.GetLength(I3) == BBlockDesc{}.GetLength(I3),
-"wrong! KPack dimension not consistent");
+"wrong! K1 dimension not consistent");
 static_assert(BlockSize == MWaves * NWaves * WaveSize,
 "BlockSize != MWaves * NWaves * WaveSize\n");
-static_assert(KPack == BBlockDesc{}.GetLength(I3), "KPack is wrong!");
+static_assert(K1 == BBlockDesc{}.GetLength(I3), "K1 is wrong!");
 constexpr index_t KPerBlock = ABlockDesc{}.GetLength(I0);
 static_assert(KPerBlock % xdlops_gemm.KPerXdlops == 0, "KPerBlock is wrong!");
-static_assert(KPack % xdlops_gemm.mfma_type.k_base == 0, "KPack is wrong!");
+static_assert(K1 % xdlops_gemm.mfma_type.k_base == 0, "K1 is wrong!");
 }
 template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
@@ -491,11 +491,11 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline
 private:
 // A[K, M]
 static constexpr auto a_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2(
-make_tuple(I1, Number<MRepeat>{}, I1, Number<KPack>{}));
+make_tuple(I1, Number<MRepeat>{}, I1, Number<K1>{}));
 // B[K, N]
 static constexpr auto b_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2(
-make_tuple(I1, Number<NRepeat>{}, I1, Number<KPack>{}));
+make_tuple(I1, Number<NRepeat>{}, I1, Number<K1>{}));
 static constexpr auto c_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2(
 make_tuple(Number<MRepeat>{}, Number<NRepeat>{}));
@@ -504,20 +504,20 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline
 FloatAB,
 ABlockDesc,
 decltype(a_thread_desc_),
-Sequence<1, 1, 1, KPack>,
+Sequence<1, 1, 1, K1>,
 Sequence<0, 1, 2, 3>,
 3,
-1, // KPack,
+1, // K1,
 1>;
 using BThreadCopy = ThreadwiseDynamicTensorSliceTransfer_v4<FloatAB,
 FloatAB,
 BBlockDesc,
 decltype(b_thread_desc_),
-Sequence<1, 1, 1, KPack>,
+Sequence<1, 1, 1, K1>,
 Sequence<0, 1, 2, 3>,
 3,
-1, // KPack,
+1, // K1,
 1>;
 AThreadCopy a_thread_copy_;
#ifndef CK_GRIDWISE_DYNAMIC_GEMM_XDLOPS_HPP
#define CK_GRIDWISE_DYNAMIC_GEMM_XDLOPS_HPP
#include "common_header.hpp"
#include "dynamic_multi_index_transform_helper.hpp"
#include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp"
#include "blockwise_gemm_xdlops.hpp"
#include "blockwise_dynamic_tensor_slice_transfer.hpp"
#include "threadwise_dynamic_tensor_slice_transfer.hpp"
#include "threadwise_dynamic_tensor_slice_set.hpp"
namespace ck {
#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
template <typename GridwiseGemm,
typename FloatA,
typename FloatB,
typename FloatC,
typename AGlobalDesc,
typename BGlobalDesc,
typename CGlobalDesc,
typename CBlockClusterDesc,
bool HasMainKBlockLoop,
bool HasDoubleTailKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_dynamic_gemm_xdlops_v1(const FloatA* __restrict__ p_a_global,
const FloatB* __restrict__ p_b_global,
FloatC* __restrict__ p_c_global,
const AGlobalDesc a_k0_m_k1_global_desc,
const BGlobalDesc b_k0_n_k1_global_desc,
const CGlobalDesc c_m0_m1_m2_n_global_desc,
const CBlockClusterDesc c_block_cluster_desc)
{
GridwiseGemm::Run(p_a_global,
p_b_global,
p_c_global,
a_k0_m_k1_global_desc,
b_k0_n_k1_global_desc,
c_m0_m1_m2_n_global_desc,
c_block_cluster_desc,
integral_constant<bool, HasMainKBlockLoop>{},
integral_constant<bool, HasDoubleTailKBlockLoop>{});
}
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
// pass tensor descriptors by __CONSTANT__ void pointer
// __CONSTANT__ tells the compiler that the void pointers in the kernel signature point to the
// non-modifiable parameter address space, so it can enable the corresponding optimizations
template <typename GridwiseGemm,
typename FloatA,
typename FloatB,
typename FloatC,
typename AGlobalDesc,
typename BGlobalDesc,
typename CGlobalDesc,
typename CBlockClusterDesc,
bool HasMainKBlockLoop,
bool HasDoubleTailKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_dynamic_gemm_xdlops_v1(const FloatA* __restrict__ p_a_global,
const FloatB* __restrict__ p_b_global,
FloatC* __restrict__ p_c_global,
const void __CONSTANT__* p_a_k0_m_k1_global_desc,
const void __CONSTANT__* p_b_k0_n_k1_global_desc,
const void __CONSTANT__* p_c_m0_m1_m2_n_global_desc,
const void __CONSTANT__* p_c_block_cluster_desc)
{
// first cast the __CONSTANT__ void pointer to a plain const void*,
// then cast that to the concrete descriptor type;
// the tensor descriptor's copy constructor cannot take an address_space(4) pointer directly
const auto a_k0_m_k1_global_desc =
*reinterpret_cast<const AGlobalDesc*>((const void*)p_a_k0_m_k1_global_desc);
const auto b_k0_n_k1_global_desc =
*reinterpret_cast<const BGlobalDesc*>((const void*)p_b_k0_n_k1_global_desc);
const auto c_m0_m1_m2_n_global_desc =
*reinterpret_cast<const CGlobalDesc*>((const void*)p_c_m0_m1_m2_n_global_desc);
const auto c_block_cluster_desc =
*reinterpret_cast<const CBlockClusterDesc*>((const void*)p_c_block_cluster_desc);
GridwiseGemm::Run(p_a_global,
p_b_global,
p_c_global,
a_k0_m_k1_global_desc,
b_k0_n_k1_global_desc,
c_m0_m1_m2_n_global_desc,
c_block_cluster_desc,
integral_constant<bool, HasMainKBlockLoop>{},
integral_constant<bool, HasDoubleTailKBlockLoop>{});
}
#endif
template <index_t BlockSize,
typename FloatAB,
typename FloatAcc,
typename FloatC,
InMemoryDataOperation CGlobalMemoryDataOperation,
typename AGlobalDesc,
typename BGlobalDesc,
typename CGlobalDesc,
typename CBlockClusterDesc,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t MPerWave,
index_t NPerWave,
index_t KPack,
index_t MRepeat,
index_t NRepeat,
typename ABlockTransferThreadSliceLengths_K_M_KPack,
typename ABlockTransferThreadClusterLengths_K_M_KPack,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_KPack,
bool AThreadTransferSrcResetCoordinateAfterRun,
typename BBlockTransferThreadSliceLengths_K_N_KPack,
typename BBlockTransferThreadClusterLengths_K_N_KPack,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_KPack,
bool BThreadTransferSrcResetCoordinateAfterRun,
typename CThreadTransferSrcDstAccessOrder,
index_t CThreadTransferSrcDstVectorDim,
index_t CThreadTransferDstScalarPerVector,
typename AGlobalIteratorHacks,
typename BGlobalIteratorHacks,
typename CGlobalIteratorHacks,
typename AGlobalMoveSliceWindowIteratorHacks,
typename BGlobalMoveSliceWindowIteratorHacks>
struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
{
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
constexpr auto max_lds_align = Number<KPack>{};
// A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto a_k0_m_k1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}, Number<KPack>{}), max_lds_align);
// B matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto b_k0_n_k1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}, Number<KPack>{}), max_lds_align);
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_space_size =
math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align);
constexpr auto b_block_space_size =
math::integer_least_multiple(b_k0_n_k1_block_desc.GetElementSpaceSize(), max_lds_align);
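// factor of 2: the gridwise GEMM presumably keeps ping-pong (double) buffers of the A and
// B blocks in LDS, matching the double-tail K-loop handling in the driver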
return 2 * (a_block_space_size + b_block_space_size) * sizeof(FloatAB);
}
template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
__device__ static void Run(const FloatAB* __restrict__ p_a_global,
const FloatAB* __restrict__ p_b_global,
FloatC* __restrict__ p_c_global,
const AGlobalDesc& a_k0_m_k1_global_desc,
const BGlobalDesc& b_k0_n_k1_global_desc,
const CGlobalDesc& c_m0_m1_m2_n_global_desc,
const CBlockClusterDesc& c_block_cluster_desc,
FloatAB* __restrict__ p_shared_block,
integral_constant<bool, HasMainKBlockLoop>,
integral_constant<bool, HasDoubleTailKBlockLoop>)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
const auto a_global_buf = make_dynamic_buffer<AddressSpace::Global>(
p_a_global, a_k0_m_k1_global_desc.GetElementSpaceSize());
const auto b_global_buf = make_dynamic_buffer<AddressSpace::Global>(
p_b_global, b_k0_n_k1_global_desc.GetElementSpaceSize());
auto c_global_buf = make_dynamic_buffer<AddressSpace::Global>(
p_c_global, c_m0_m1_m2_n_global_desc.GetElementSpaceSize());
const auto K0 = a_k0_m_k1_global_desc.GetLength(I0);
const auto M = a_k0_m_k1_global_desc.GetLength(I1);
const auto N = b_k0_n_k1_global_desc.GetLength(I1);
const auto K1 = b_k0_n_k1_global_desc.GetLength(I2);
// divide block work by [M, N]
const auto block_work_idx =
c_block_cluster_desc.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
        // HACK: this forces m/n_block_data_idx_on_global into SGPRs
const index_t m_block_data_idx_on_global =
__builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
const index_t n_block_data_idx_on_global =
__builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);
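        // Worked example (hypothetical sizes): for M = 1024, N = 512 with
        // MPerBlock = NPerBlock = 128, the block cluster is 8 x 4 = 32 workgroups.
        // Assuming the cluster descriptor is a row-major merge of (M / MPerBlock, N / NPerBlock),
        // block id 9 maps to block_work_idx = (2, 1), i.e. the C tile whose origin is
        // (m, n) = (2 * 128, 1 * 128) = (256, 128).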
// lds max alignment
constexpr auto max_lds_align = Number<KPack>{};
// A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto a_k0_m_k1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}, Number<KPack>{}), max_lds_align);
// B matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto b_k0_n_k1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}, Number<KPack>{}), max_lds_align);
// A matrix blockwise copy
auto a_blockwise_copy =
BlockwiseDynamicTensorSliceTransfer_v4<BlockSize,
InMemoryDataOperation::Set,
Sequence<KPerBlock, MPerBlock, KPack>,
ABlockTransferThreadSliceLengths_K_M_KPack,
ABlockTransferThreadClusterLengths_K_M_KPack,
ABlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatAB,
decltype(a_k0_m_k1_global_desc),
decltype(a_k0_m_k1_block_desc),
ABlockTransferSrcAccessOrder,
Sequence<1, 0, 2>,
ABlockTransferSrcVectorDim,
2,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_KPack,
1,
1,
AThreadTransferSrcResetCoordinateAfterRun,
true>(
a_k0_m_k1_global_desc,
make_multi_index(0, m_block_data_idx_on_global, 0),
a_k0_m_k1_block_desc,
make_multi_index(0, 0, 0));
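        // Illustrative configuration (hypothetical tuning parameters, not prescribed here):
        // for BlockSize = 256 and a block slice of Sequence<4, 128, 8>, one consistent choice is
        //     ABlockTransferThreadSliceLengths_K_M_KPack   = Sequence<1, 2, 8>
        //     ABlockTransferThreadClusterLengths_K_M_KPack = Sequence<4, 64, 1>
        // since 1*4 = 4, 2*64 = 128, 8*1 = 8 reproduce the slice and 4*64*1 = 256 matches
        // BlockSize; each thread then moves 1 * 2 * 8 = 16 elements per RunRead/RunWrite pair.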
// B matrix blockwise copy
auto b_blockwise_copy =
BlockwiseDynamicTensorSliceTransfer_v4<BlockSize,
InMemoryDataOperation::Set,
Sequence<KPerBlock, NPerBlock, KPack>,
BBlockTransferThreadSliceLengths_K_N_KPack,
BBlockTransferThreadClusterLengths_K_N_KPack,
BBlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatAB,
decltype(b_k0_n_k1_global_desc),
decltype(b_k0_n_k1_block_desc),
BBlockTransferSrcAccessOrder,
Sequence<1, 0, 2>,
BBlockTransferSrcVectorDim,
2,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_KPack,
1,
1,
BThreadTransferSrcResetCoordinateAfterRun,
true>(
b_k0_n_k1_global_desc,
make_multi_index(0, n_block_data_idx_on_global, 0),
b_k0_n_k1_block_desc,
make_multi_index(0, 0, 0));
// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
// a_mtx[KPerBlock, MPerBlock] is in LDS
// b_mtx[KPerBlock, NPerBlock] is in LDS
        // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
        // registers
// sanity check
static_assert(MPerBlock % (MPerWave * MRepeat) == 0 &&
NPerBlock % (NPerWave * NRepeat) == 0,
"wrong!");
constexpr auto a_k0_m0_m1_k1_block_desc = transform_dynamic_tensor_descriptor(
a_k0_m_k1_block_desc,
make_tuple(make_pass_through_transform(Number<KPerBlock>{}),
make_unmerge_transform(
make_tuple(Number<MRepeat>{}, Number<MPerBlock / MRepeat>{})),
make_pass_through_transform(Number<KPack>{})),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
constexpr auto b_k0_n0_n1_k1_block_desc = transform_dynamic_tensor_descriptor(
b_k0_n_k1_block_desc,
make_tuple(make_pass_through_transform(Number<KPerBlock>{}),
make_unmerge_transform(
make_tuple(Number<NRepeat>{}, Number<NPerBlock / NRepeat>{})),
make_pass_through_transform(Number<KPack>{})),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
const auto blockwise_gemm =
BlockwiseGemmXdlops_km_kn_m0m1m2n_v1<BlockSize,
FloatAB,
decltype(a_k0_m0_m1_k1_block_desc),
decltype(b_k0_n0_n1_k1_block_desc),
MPerWave,
NPerWave,
KPack>{};
constexpr auto CLayout = blockwise_gemm.GetCLayout();
constexpr index_t BlkSize = CLayout.GetBlkSize();
constexpr index_t NumBlks = CLayout.GetNumBlks();
constexpr index_t NumXdlops = CLayout.GetNumXdlops();
constexpr auto c_mr_nr_nx_desc = make_dynamic_naive_tensor_descriptor_packed_v2(
make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, Number<NumXdlops>{}));
constexpr auto c_blk_nb_bs_desc = make_dynamic_naive_tensor_descriptor_packed_v2(
make_tuple(Number<NumBlks>{}, Number<BlkSize>{}));
StaticBuffer<AddressSpace::Vgpr,
vector_type<FloatAcc, c_blk_nb_bs_desc.GetElementSpaceSize()>,
c_mr_nr_nx_desc.GetElementSpaceSize()>
c_thread_buf;
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_space_size =
math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align);
constexpr auto b_block_space_size =
math::integer_least_multiple(b_k0_n_k1_block_desc.GetElementSpaceSize(), max_lds_align);
FloatAB* p_a_block_double = p_shared_block;
FloatAB* p_b_block_double = p_shared_block + 2 * a_block_space_size;
// register allocation for output
// auto c_thread_buf = make_static_buffer<AddressSpace::Vgpr, FloatAcc>(
// c_m0_m1_n0_n1_thread_desc.GetElementSpaceSize());
// ThreadwiseDynamicTensorSliceSet_v1<FloatAcc,
// decltype(c_m0_m1_n0_n1_thread_desc),
// Sequence<MRepeat, MPerThread, NRepeat, NPerThread>>{}
//.Run(c_m0_m1_n0_n1_thread_desc, make_tuple(I0, I0, I0, I0), c_thread_buf, FloatAcc{0});
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0);
        // hack to control index calculation when iterating over the A and B matrices in threadwise copy
constexpr auto a_k0_m_k1_global_iterator_hacks = AGlobalIteratorHacks{};
constexpr auto b_k0_n_k1_global_iterator_hacks = BGlobalIteratorHacks{};
        // hack to control index calculation when moving the slice window over the A and B
        // matrices in threadwise copy
constexpr auto a_k0_m_k1_global_move_slice_window_iterator_hack =
AGlobalMoveSliceWindowIteratorHacks{};
constexpr auto b_k0_n_k1_global_move_slice_window_iterator_hack =
BGlobalMoveSliceWindowIteratorHacks{};
auto a_block_even_buf = make_dynamic_buffer<AddressSpace::Lds>(
p_a_block_double, a_k0_m_k1_block_desc.GetElementSpaceSize());
auto b_block_even_buf = make_dynamic_buffer<AddressSpace::Lds>(
p_b_block_double, b_k0_n_k1_block_desc.GetElementSpaceSize());
auto a_block_odd_buf = make_dynamic_buffer<AddressSpace::Lds>(
p_a_block_double + a_block_space_size, a_k0_m_k1_block_desc.GetElementSpaceSize());
auto b_block_odd_buf = make_dynamic_buffer<AddressSpace::Lds>(
p_b_block_double + b_block_space_size, b_k0_n_k1_block_desc.GetElementSpaceSize());
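        // Resulting LDS layout (in elements of FloatAB), consistent with
        // GetSharedMemoryNumberOfByte():
        //     [0, a_block_space_size)                            A, even buffer
        //     [a_block_space_size, 2 * a_block_space_size)       A, odd  buffer
        //     [2 * a_block_space_size, 2 * a_space + b_space)    B, even buffer
        //     [2 * a_space + b_space, 2 * a_space + 2 * b_space) B, odd  buffer
        // e.g. with the illustrative 4 x 128 x 8 half-precision tile used above, the four
        // buffers start at element offsets 0, 4096, 8192 and 12288.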
// LDS double buffer: preload data into LDS
{
a_blockwise_copy.RunRead(
a_k0_m_k1_global_desc, a_global_buf, a_k0_m_k1_global_iterator_hacks);
b_blockwise_copy.RunRead(
b_k0_n_k1_global_desc, b_global_buf, b_k0_n_k1_global_iterator_hacks);
a_blockwise_copy.RunWrite(a_k0_m_k1_block_desc, a_block_even_buf);
b_blockwise_copy.RunWrite(b_k0_n_k1_block_desc, b_block_even_buf);
}
if constexpr(HasMainKBlockLoop)
{
index_t k_block_data_begin = 0;
// LDS double buffer: main body
// use Do-While loop instead of For loop to simplify control flow
do
{
// even iteration
a_blockwise_copy.MoveSrcSliceWindow(
a_k0_m_k1_global_desc,
a_block_slice_copy_step,
a_k0_m_k1_global_move_slice_window_iterator_hack);
b_blockwise_copy.MoveSrcSliceWindow(
b_k0_n_k1_global_desc,
b_block_slice_copy_step,
b_k0_n_k1_global_move_slice_window_iterator_hack);
__syncthreads();
                // LDS double buffer: load next data from device mem
a_blockwise_copy.RunRead(
a_k0_m_k1_global_desc, a_global_buf, a_k0_m_k1_global_iterator_hacks);
b_blockwise_copy.RunRead(
b_k0_n_k1_global_desc, b_global_buf, b_k0_n_k1_global_iterator_hacks);
asm volatile("s_nop 0");
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(a_block_even_buf, b_block_even_buf, c_thread_buf);
// LDS double buffer: store next data to LDS
a_blockwise_copy.RunWrite(a_k0_m_k1_block_desc, a_block_odd_buf);
b_blockwise_copy.RunWrite(b_k0_n_k1_block_desc, b_block_odd_buf);
// odd iteration
a_blockwise_copy.MoveSrcSliceWindow(
a_k0_m_k1_global_desc,
a_block_slice_copy_step,
a_k0_m_k1_global_move_slice_window_iterator_hack);
b_blockwise_copy.MoveSrcSliceWindow(
b_k0_n_k1_global_desc,
b_block_slice_copy_step,
b_k0_n_k1_global_move_slice_window_iterator_hack);
__syncthreads();
                // LDS double buffer: load next data from device mem
a_blockwise_copy.RunRead(
a_k0_m_k1_global_desc, a_global_buf, a_k0_m_k1_global_iterator_hacks);
b_blockwise_copy.RunRead(
b_k0_n_k1_global_desc, b_global_buf, b_k0_n_k1_global_iterator_hacks);
asm volatile("s_nop 0");
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(a_block_odd_buf, b_block_odd_buf, c_thread_buf);
// LDS double buffer: store next data to LDS
a_blockwise_copy.RunWrite(a_k0_m_k1_block_desc, a_block_even_buf);
b_blockwise_copy.RunWrite(b_k0_n_k1_block_desc, b_block_even_buf);
k_block_data_begin += 2 * KPerBlock;
} while(k_block_data_begin < K0 - 2 * KPerBlock);
}
// LDS double buffer: tail
if constexpr(HasDoubleTailKBlockLoop) // if has 2 iteration left
{
a_blockwise_copy.MoveSrcSliceWindow(a_k0_m_k1_global_desc,
a_block_slice_copy_step,
a_k0_m_k1_global_move_slice_window_iterator_hack);
b_blockwise_copy.MoveSrcSliceWindow(b_k0_n_k1_global_desc,
b_block_slice_copy_step,
b_k0_n_k1_global_move_slice_window_iterator_hack);
__syncthreads();
// LDS double buffer: load last data from device mem
a_blockwise_copy.RunRead(
a_k0_m_k1_global_desc, a_global_buf, a_k0_m_k1_global_iterator_hacks);
b_blockwise_copy.RunRead(
b_k0_n_k1_global_desc, b_global_buf, b_k0_n_k1_global_iterator_hacks);
// LDS double buffer: GEMM on 2nd-last data
blockwise_gemm.Run(a_block_even_buf, b_block_even_buf, c_thread_buf);
// LDS double buffer: store last data to LDS
a_blockwise_copy.RunWrite(a_k0_m_k1_block_desc, a_block_odd_buf);
b_blockwise_copy.RunWrite(b_k0_n_k1_block_desc, b_block_odd_buf);
__syncthreads();
// LDS double buffer: GEMM on last data
blockwise_gemm.Run(a_block_odd_buf, b_block_odd_buf, c_thread_buf);
}
else // if has 1 iteration left
{
__syncthreads();
// LDS double buffer: GEMM on last data
blockwise_gemm.Run(a_block_even_buf, b_block_even_buf, c_thread_buf);
}
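        // Trip-count example: the host is expected to pick HasMainKBlockLoop and
        // HasDoubleTailKBlockLoop to match this pipeline; a consistent choice (an assumption,
        // not mandated here) is HasMainKBlockLoop = (K0 / KPerBlock > 2) and
        // HasDoubleTailKBlockLoop = (K0 / KPerBlock % 2 == 0).
        // With K0 / KPerBlock = 8 K-slabs, the preload stages slab 0, the main loop runs
        // 3 iterations covering slabs 0..5 while prefetching slabs 1..6, and the double tail
        // computes slabs 6 and 7, so all 8 slabs are consumed exactly once.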
// output: register to global memory
{
constexpr index_t M0 = CLayout.M1();
constexpr index_t M1 = CLayout.N1();
constexpr index_t M2 = CLayout.M0();
constexpr auto c_m0_m1_m2_n_thread_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(
make_tuple(Number<M0>{}, Number<1>{}, Number<M2>{}, Number<1>{}));
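            // The write-out below decomposes the flat M coordinate of each thread into
            // (m0, m1, m2) = (m / (M2 * M1), m % (M2 * M1) / M2, m % M2), where M1 and M2 come
            // from the xdlops C layout. Worked example with hypothetical layout values
            // M1 = 4, M2 = 4: m_thread_data_on_global = 54 maps to (m0, m1, m2) = (3, 1, 2),
            // and indeed 3 * (4 * 4) + 1 * 4 + 2 = 54.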
StaticBuffer<AddressSpace::Vgpr, FloatC, BlkSize> c_blk_buf_;
static_for<0, MRepeat, 1>{}([&](auto mr_i) {
static_for<0, NRepeat, 1>{}([&](auto nr_i) {
static_for<0, NumXdlops, 1>{}([&](auto xdlops_i) {
static_for<0, NumBlks, 1>{}([&](auto blk_i) {
auto c_blk = c_thread_buf[Number<c_mr_nr_nx_desc.CalculateOffset(
make_tuple(mr_i, nr_i, xdlops_i))>{}];
static_for<0, BlkSize, 1>{}([&](auto j) {
c_blk_buf_(j) = c_blk.template AsType<FloatAcc>()[Number<
c_blk_nb_bs_desc.CalculateOffset(make_tuple(blk_i, j))>{}];
});
// calculate origin of thread output tensor on global memory
// blockwise GEMM c matrix starting index
const auto c_thread_mtx_on_block =
blockwise_gemm.CalculateCThreadOriginDataIndex(
mr_i, nr_i, xdlops_i, blk_i);
const index_t m_thread_data_on_global =
m_block_data_idx_on_global + c_thread_mtx_on_block[I0];
const index_t n_thread_data_on_global =
n_block_data_idx_on_global + c_thread_mtx_on_block[I1];
constexpr auto c_m0_m1_m2_n_global_tensor_iterator_hacks =
CGlobalIteratorHacks{};
ThreadwiseDynamicTensorSliceTransfer_v1r3<
FloatC,
FloatC,
decltype(c_m0_m1_m2_n_thread_desc),
decltype(c_m0_m1_m2_n_global_desc),
Sequence<M0, 1, M2, 1>,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
CGlobalMemoryDataOperation,
1,
true>{c_m0_m1_m2_n_global_desc,
make_multi_index(m_thread_data_on_global / (M2 * M1),
m_thread_data_on_global % (M2 * M1) / M2,
m_thread_data_on_global % M2,
n_thread_data_on_global)}
.Run(c_m0_m1_m2_n_thread_desc,
make_tuple(I0, I0, I0, I0),
c_blk_buf_,
c_m0_m1_m2_n_global_desc,
c_global_buf,
c_m0_m1_m2_n_global_tensor_iterator_hacks);
});
});
});
});
}
}
template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
__device__ static void Run(const FloatAB* __restrict__ p_a_global,
const FloatAB* __restrict__ p_b_global,
FloatC* __restrict__ p_c_global,
const AGlobalDesc& a_k0_m_k1_global_desc,
const BGlobalDesc& b_k0_n_k1_global_desc,
const CGlobalDesc& c_m0_m1_m2_n_global_desc,
const CBlockClusterDesc& c_block_cluster_desc,
integral_constant<bool, HasMainKBlockLoop>,
integral_constant<bool, HasDoubleTailKBlockLoop>)
{
constexpr index_t shared_block_size = GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
__shared__ FloatAB p_shared_block[shared_block_size];
Run(p_a_global,
p_b_global,
p_c_global,
a_k0_m_k1_global_desc,
b_k0_n_k1_global_desc,
c_m0_m1_m2_n_global_desc,
c_block_cluster_desc,
p_shared_block,
integral_constant<bool, HasMainKBlockLoop>{},
integral_constant<bool, HasDoubleTailKBlockLoop>{});
}
};
} // namespace ck
#endif
#ifndef CK_GRIDWISE_DYNAMIC_GEMM_XDLOPS_V2_HPP
#define CK_GRIDWISE_DYNAMIC_GEMM_XDLOPS_V2_HPP
#include "common_header.hpp"
#include "dynamic_multi_index_transform_helper.hpp"
#include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp"
#include "blockwise_gemm_xdlops.hpp"
#include "blockwise_dynamic_tensor_slice_transfer.hpp"
#include "threadwise_dynamic_tensor_slice_transfer.hpp"
#include "threadwise_dynamic_tensor_slice_set.hpp"
namespace ck {
#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
template <typename GridwiseGemm,
typename FloatA,
typename FloatB,
typename FloatC,
typename AGlobalDesc,
typename BGlobalDesc,
typename CGlobalDesc,
typename CBlockClusterDesc>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_dynamic_gemm_xdlops_v2(const FloatA* __restrict__ p_a_global,
const FloatB* __restrict__ p_b_global,
FloatC* __restrict__ p_c_global,
const AGlobalDesc a_k0_m_k1_global_desc,
const BGlobalDesc b_k0_n_k1_global_desc,
const CGlobalDesc c_m0_m1_m2_n_global_desc,
const CBlockClusterDesc c_block_cluster_desc)
{
GridwiseGemm::Run(p_a_global,
p_b_global,
p_c_global,
a_k0_m_k1_global_desc,
b_k0_n_k1_global_desc,
c_m0_m1_m2_n_global_desc,
c_block_cluster_desc);
}
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
// pass tensor descriptors by __CONSTANT__ void pointer
// __CONSTANT__ is needed to inform the compiler that the void pointers in the kernel
// signature point to the non-modifiable parameter address space, so the compiler can enable
// the corresponding optimizations
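// A minimal host-side sketch of this calling convention (illustrative only; the buffer,
// stream and type names below are hypothetical and not defined in this header). The concrete
// tensor descriptors are first written into device buffers, e.g. by a "*_prepare" kernel like
// the one added later in this commit, and those buffers are then handed to this kernel as
// __CONSTANT__ void pointers:
//
//     void *p_a_desc_buf, *p_b_desc_buf, *p_c_desc_buf, *p_cluster_buf;
//     hipMalloc(&p_a_desc_buf, a_desc_size); // ... and likewise for the other three buffers
//
//     // 1) fill the dynamic tensor descriptors on the device
//     hipLaunchKernelGGL(some_prepare_kernel, dim3(1), dim3(1), 0, stream,
//                        /* conv sizes, strides, dilations, pads ... */
//                        p_a_desc_buf, p_b_desc_buf, p_c_desc_buf, p_cluster_buf);
//
//     // 2) launch the GEMM, passing the descriptor buffers by void pointer
//     const auto gemm_kernel = kernel_dynamic_gemm_xdlops_v2<GridwiseGemmInstance,
//                                                            FloatAB, FloatAB, FloatC,
//                                                            ADesc, BDesc, CDesc, CClusterDesc>;
//     hipLaunchKernelGGL(gemm_kernel, dim3(grid_size), dim3(BlockSize), 0, stream,
//                        p_a, p_b, p_c,
//                        (const void*)p_a_desc_buf, (const void*)p_b_desc_buf,
//                        (const void*)p_c_desc_buf, (const void*)p_cluster_buf);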
template <typename GridwiseGemm,
typename FloatA,
typename FloatB,
typename FloatC,
typename AGlobalDesc,
typename BGlobalDesc,
typename CGlobalDesc,
typename CBlockClusterDesc>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_dynamic_gemm_xdlops_v2(const FloatA* __restrict__ p_a_global,
const FloatB* __restrict__ p_b_global,
FloatC* __restrict__ p_c_global,
const void __CONSTANT__* p_a_k0_m_k1_global_desc,
const void __CONSTANT__* p_b_k0_n_k1_global_desc,
const void __CONSTANT__* p_c_m0_m1_m2_n_global_desc,
const void __CONSTANT__* p_c_block_cluster_desc)
{
    // first cast the __CONSTANT__ void pointer to a generic void pointer,
    // then cast that void pointer to the descriptor type;
    // the tensor descriptor's copy constructor cannot take a pointer in address_space(4)
const auto a_k0_m_k1_global_desc =
*reinterpret_cast<const AGlobalDesc*>((const void*)p_a_k0_m_k1_global_desc);
const auto b_k0_n_k1_global_desc =
*reinterpret_cast<const BGlobalDesc*>((const void*)p_b_k0_n_k1_global_desc);
const auto c_m0_m1_m2_n_global_desc =
*reinterpret_cast<const CGlobalDesc*>((const void*)p_c_m0_m1_m2_n_global_desc);
const auto c_block_cluster_desc =
*reinterpret_cast<const CBlockClusterDesc*>((const void*)p_c_block_cluster_desc);
GridwiseGemm::Run(p_a_global,
p_b_global,
p_c_global,
a_k0_m_k1_global_desc,
b_k0_n_k1_global_desc,
c_m0_m1_m2_n_global_desc,
                      c_block_cluster_desc);
}
#endif
template <index_t BlockSize,
typename FloatAB,
typename FloatAcc,
typename FloatC,
InMemoryDataOperation CGlobalMemoryDataOperation,
typename AGlobalDesc,
typename BGlobalDesc,
typename CGlobalDesc,
typename CBlockClusterDesc,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t MPerWave,
index_t NPerWave,
index_t KPack,
index_t MRepeat,
index_t NRepeat,
typename ABlockTransferThreadSliceLengths_K_M_KPack,
typename ABlockTransferThreadClusterLengths_K_M_KPack,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_KPack,
bool AThreadTransferSrcResetCoordinateAfterRun,
typename BBlockTransferThreadSliceLengths_K_N_KPack,
typename BBlockTransferThreadClusterLengths_K_N_KPack,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_KPack,
bool BThreadTransferSrcResetCoordinateAfterRun,
typename CThreadTransferSrcDstAccessOrder,
index_t CThreadTransferSrcDstVectorDim,
index_t CThreadTransferDstScalarPerVector,
typename AGlobalIteratorHacks,
typename BGlobalIteratorHacks,
typename CGlobalIteratorHacks,
typename AGlobalMoveSliceWindowIteratorHacks,
typename BGlobalMoveSliceWindowIteratorHacks>
struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v2
{
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
constexpr auto max_lds_align = Number<KPack>{};
// A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto a_k0_m_k1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}, Number<KPack>{}), max_lds_align);
// B matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto b_k0_n_k1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}, Number<KPack>{}), max_lds_align);
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_space_size =
math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align);
constexpr auto b_block_space_size =
math::integer_least_multiple(b_k0_n_k1_block_desc.GetElementSpaceSize(), max_lds_align);
return (a_block_space_size + b_block_space_size) * sizeof(FloatAB);
}
__device__ static void Run(const FloatAB* __restrict__ p_a_global,
const FloatAB* __restrict__ p_b_global,
FloatC* __restrict__ p_c_global,
const AGlobalDesc& a_k0_m_k1_global_desc,
const BGlobalDesc& b_k0_n_k1_global_desc,
const CGlobalDesc& c_m0_m1_m2_n_global_desc,
const CBlockClusterDesc& c_block_cluster_desc,
FloatAB* __restrict__ p_shared_block)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
const auto a_global_buf = make_dynamic_buffer<AddressSpace::Global>(
p_a_global, a_k0_m_k1_global_desc.GetElementSpaceSize());
const auto b_global_buf = make_dynamic_buffer<AddressSpace::Global>(
p_b_global, b_k0_n_k1_global_desc.GetElementSpaceSize());
auto c_global_buf = make_dynamic_buffer<AddressSpace::Global>(
p_c_global, c_m0_m1_m2_n_global_desc.GetElementSpaceSize());
const auto K0 = a_k0_m_k1_global_desc.GetLength(I0);
const auto M = a_k0_m_k1_global_desc.GetLength(I1);
const auto N = b_k0_n_k1_global_desc.GetLength(I1);
const auto K1 = b_k0_n_k1_global_desc.GetLength(I2);
// divide block work by [M, N]
const auto block_work_idx =
c_block_cluster_desc.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
        // HACK: this forces m/n_block_data_idx_on_global into SGPRs
const index_t m_block_data_idx_on_global =
__builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
const index_t n_block_data_idx_on_global =
__builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);
// lds max alignment
constexpr auto max_lds_align = Number<KPack>{};
// A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto a_k0_m_k1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}, Number<KPack>{}), max_lds_align);
// B matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto b_k0_n_k1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}, Number<KPack>{}), max_lds_align);
// A matrix blockwise copy
auto a_blockwise_copy =
BlockwiseDynamicTensorSliceTransfer_v4<BlockSize,
InMemoryDataOperation::Set,
Sequence<KPerBlock, MPerBlock, KPack>,
ABlockTransferThreadSliceLengths_K_M_KPack,
ABlockTransferThreadClusterLengths_K_M_KPack,
ABlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatAB,
decltype(a_k0_m_k1_global_desc),
decltype(a_k0_m_k1_block_desc),
ABlockTransferSrcAccessOrder,
Sequence<1, 0, 2>,
ABlockTransferSrcVectorDim,
2,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_KPack,
1,
1,
AThreadTransferSrcResetCoordinateAfterRun,
true>(
a_k0_m_k1_global_desc,
make_multi_index(0, m_block_data_idx_on_global, 0),
a_k0_m_k1_block_desc,
make_multi_index(0, 0, 0));
// B matrix blockwise copy
auto b_blockwise_copy =
BlockwiseDynamicTensorSliceTransfer_v4<BlockSize,
InMemoryDataOperation::Set,
Sequence<KPerBlock, NPerBlock, KPack>,
BBlockTransferThreadSliceLengths_K_N_KPack,
BBlockTransferThreadClusterLengths_K_N_KPack,
BBlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatAB,
decltype(b_k0_n_k1_global_desc),
decltype(b_k0_n_k1_block_desc),
BBlockTransferSrcAccessOrder,
Sequence<1, 0, 2>,
BBlockTransferSrcVectorDim,
2,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_KPack,
1,
1,
BThreadTransferSrcResetCoordinateAfterRun,
true>(
b_k0_n_k1_global_desc,
make_multi_index(0, n_block_data_idx_on_global, 0),
b_k0_n_k1_block_desc,
make_multi_index(0, 0, 0));
// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
// a_mtx[KPerBlock, MPerBlock] is in LDS
// b_mtx[KPerBlock, NPerBlock] is in LDS
        // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
        // registers
// sanity check
static_assert(MPerBlock % (MPerWave * MRepeat) == 0 &&
NPerBlock % (NPerWave * NRepeat) == 0,
"wrong!");
constexpr auto a_k0_m0_m1_k1_block_desc = transform_dynamic_tensor_descriptor(
a_k0_m_k1_block_desc,
make_tuple(make_pass_through_transform(Number<KPerBlock>{}),
make_unmerge_transform(
make_tuple(Number<MRepeat>{}, Number<MPerBlock / MRepeat>{})),
make_pass_through_transform(Number<KPack>{})),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
constexpr auto b_k0_n0_n1_k1_block_desc = transform_dynamic_tensor_descriptor(
b_k0_n_k1_block_desc,
make_tuple(make_pass_through_transform(Number<KPerBlock>{}),
make_unmerge_transform(
make_tuple(Number<NRepeat>{}, Number<NPerBlock / NRepeat>{})),
make_pass_through_transform(Number<KPack>{})),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
const auto blockwise_gemm =
BlockwiseGemmXdlops_km_kn_m0m1m2n_v1<BlockSize,
FloatAB,
decltype(a_k0_m0_m1_k1_block_desc),
decltype(b_k0_n0_n1_k1_block_desc),
MPerWave,
NPerWave,
KPack>{};
constexpr auto CLayout = blockwise_gemm.GetCLayout();
constexpr index_t BlkSize = CLayout.GetBlkSize();
constexpr index_t NumBlks = CLayout.GetNumBlks();
constexpr index_t NumXdlops = CLayout.GetNumXdlops();
constexpr auto c_mr_nr_nx_desc = make_dynamic_naive_tensor_descriptor_packed_v2(
make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, Number<NumXdlops>{}));
constexpr auto c_blk_nb_bs_desc = make_dynamic_naive_tensor_descriptor_packed_v2(
make_tuple(Number<NumBlks>{}, Number<BlkSize>{}));
StaticBuffer<AddressSpace::Vgpr,
vector_type<FloatAcc, c_blk_nb_bs_desc.GetElementSpaceSize()>,
c_mr_nr_nx_desc.GetElementSpaceSize()>
c_thread_buf;
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_space_size =
math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align);
constexpr auto b_block_space_size =
math::integer_least_multiple(b_k0_n_k1_block_desc.GetElementSpaceSize(), max_lds_align);
FloatAB* p_a_block = p_shared_block;
FloatAB* p_b_block = p_shared_block + a_block_space_size;
// register allocation for output
// auto c_thread_buf = make_static_buffer<AddressSpace::Vgpr, FloatAcc>(
// c_m0_m1_n0_n1_thread_desc.GetElementSpaceSize());
// ThreadwiseDynamicTensorSliceSet_v1<FloatAcc,
// decltype(c_m0_m1_n0_n1_thread_desc),
// Sequence<MRepeat, MPerThread, NRepeat, NPerThread>>{}
//.Run(c_m0_m1_n0_n1_thread_desc, make_tuple(I0, I0, I0, I0), c_thread_buf, FloatAcc{0});
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0);
        // hack to control index calculation when iterating over the A and B matrices in threadwise copy
constexpr auto a_k0_m_k1_global_iterator_hacks = AGlobalIteratorHacks{};
constexpr auto b_k0_n_k1_global_iterator_hacks = BGlobalIteratorHacks{};
        // hack to control index calculation when moving the slice window over the A and B
        // matrices in threadwise copy
constexpr auto a_k0_m_k1_global_move_slice_window_iterator_hack =
AGlobalMoveSliceWindowIteratorHacks{};
constexpr auto b_k0_n_k1_global_move_slice_window_iterator_hack =
BGlobalMoveSliceWindowIteratorHacks{};
auto a_block_buf = make_dynamic_buffer<AddressSpace::Lds>(
p_a_block, a_k0_m_k1_block_desc.GetElementSpaceSize());
auto b_block_buf = make_dynamic_buffer<AddressSpace::Lds>(
p_b_block, b_k0_n_k1_block_desc.GetElementSpaceSize());
// preload data into LDS
{
a_blockwise_copy.RunRead(
a_k0_m_k1_global_desc, a_global_buf, a_k0_m_k1_global_iterator_hacks);
b_blockwise_copy.RunRead(
b_k0_n_k1_global_desc, b_global_buf, b_k0_n_k1_global_iterator_hacks);
a_blockwise_copy.RunWrite(a_k0_m_k1_block_desc, a_block_buf);
b_blockwise_copy.RunWrite(b_k0_n_k1_block_desc, b_block_buf);
}
// main body
index_t k_block_data_begin = 0;
do
{
a_blockwise_copy.MoveSrcSliceWindow(a_k0_m_k1_global_desc,
a_block_slice_copy_step,
a_k0_m_k1_global_move_slice_window_iterator_hack);
b_blockwise_copy.MoveSrcSliceWindow(b_k0_n_k1_global_desc,
b_block_slice_copy_step,
b_k0_n_k1_global_move_slice_window_iterator_hack);
a_blockwise_copy.RunRead(
a_k0_m_k1_global_desc, a_global_buf, a_k0_m_k1_global_iterator_hacks);
b_blockwise_copy.RunRead(
b_k0_n_k1_global_desc, b_global_buf, b_k0_n_k1_global_iterator_hacks);
block_sync_lds();
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
block_sync_lds();
a_blockwise_copy.RunWrite(a_k0_m_k1_block_desc, a_block_buf);
b_blockwise_copy.RunWrite(b_k0_n_k1_block_desc, b_block_buf);
k_block_data_begin += KPerBlock;
} while(k_block_data_begin < (K0 - KPerBlock));
// tail
{
block_sync_lds();
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
}
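        // Pipeline note: with a single LDS buffer per matrix, RunRead stages the next K-slab
        // in registers while the GEMM consumes the slab currently in LDS, and the two
        // block_sync_lds() calls separate the read-after-write and write-after-read phases.
        // e.g. for K0 / KPerBlock = 8 slabs, the loop above runs 7 times and the tail GEMM
        // consumes the 8th slab.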
// output: register to global memory
{
constexpr index_t M0 = CLayout.M1();
constexpr index_t M1 = CLayout.N1();
constexpr index_t M2 = CLayout.M0();
constexpr auto c_m0_m1_m2_n_thread_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(
make_tuple(Number<M0>{}, Number<1>{}, Number<M2>{}, Number<1>{}));
StaticBuffer<AddressSpace::Vgpr, FloatC, BlkSize> c_blk_buf_;
static_for<0, MRepeat, 1>{}([&](auto mr_i) {
static_for<0, NRepeat, 1>{}([&](auto nr_i) {
static_for<0, NumXdlops, 1>{}([&](auto xdlops_i) {
static_for<0, NumBlks, 1>{}([&](auto blk_i) {
auto c_blk = c_thread_buf[Number<c_mr_nr_nx_desc.CalculateOffset(
make_tuple(mr_i, nr_i, xdlops_i))>{}];
static_for<0, BlkSize, 1>{}([&](auto j) {
c_blk_buf_(j) = c_blk.template AsType<FloatAcc>()[Number<
c_blk_nb_bs_desc.CalculateOffset(make_tuple(blk_i, j))>{}];
});
// calculate origin of thread output tensor on global memory
// blockwise GEMM c matrix starting index
const auto c_thread_mtx_on_block =
blockwise_gemm.CalculateCThreadOriginDataIndex(
mr_i, nr_i, xdlops_i, blk_i);
const index_t m_thread_data_on_global =
m_block_data_idx_on_global + c_thread_mtx_on_block[I0];
const index_t n_thread_data_on_global =
n_block_data_idx_on_global + c_thread_mtx_on_block[I1];
constexpr auto c_m0_m1_m2_n_global_tensor_iterator_hacks =
CGlobalIteratorHacks{};
ThreadwiseDynamicTensorSliceTransfer_v1r3<
FloatC,
FloatC,
decltype(c_m0_m1_m2_n_thread_desc),
decltype(c_m0_m1_m2_n_global_desc),
Sequence<M0, 1, M2, 1>,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
CGlobalMemoryDataOperation,
1,
true>{c_m0_m1_m2_n_global_desc,
make_multi_index(m_thread_data_on_global / (M2 * M1),
m_thread_data_on_global % (M2 * M1) / M2,
m_thread_data_on_global % M2,
n_thread_data_on_global)}
.Run(c_m0_m1_m2_n_thread_desc,
make_tuple(I0, I0, I0, I0),
c_blk_buf_,
c_m0_m1_m2_n_global_desc,
c_global_buf,
c_m0_m1_m2_n_global_tensor_iterator_hacks);
});
});
});
});
}
}
__device__ static void Run(const FloatAB* __restrict__ p_a_global,
const FloatAB* __restrict__ p_b_global,
FloatC* __restrict__ p_c_global,
const AGlobalDesc& a_k0_m_k1_global_desc,
const BGlobalDesc& b_k0_n_k1_global_desc,
const CGlobalDesc& c_m0_m1_m2_n_global_desc,
const CBlockClusterDesc& c_block_cluster_desc)
{
constexpr index_t shared_block_size = GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
__shared__ FloatAB p_shared_block[shared_block_size];
Run(p_a_global,
p_b_global,
p_c_global,
a_k0_m_k1_global_desc,
b_k0_n_k1_global_desc,
c_m0_m1_m2_n_global_desc,
c_block_cluster_desc,
p_shared_block);
}
};
} // namespace ck
#endif
#ifndef CK_GRIDWISE_DYNAMIC_GEMM_XDLOPS_V2R2_HPP
#define CK_GRIDWISE_DYNAMIC_GEMM_XDLOPS_V2R2_HPP
#include "common_header.hpp"
#include "dynamic_multi_index_transform_helper.hpp"
#include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp"
#include "blockwise_gemm_xdlops.hpp"
#include "blockwise_dynamic_tensor_slice_transfer.hpp"
#include "threadwise_dynamic_tensor_slice_transfer.hpp"
#include "threadwise_dynamic_tensor_slice_set.hpp"
namespace ck {
template <typename GridwiseGemm,
typename FloatAB,
typename FloatC,
typename AK0MK1GridDesc,
typename BK0NK1GridDesc,
typename CM0M1M2NGridDesc,
typename CBlockClusterAdaptor>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_dynamic_gemm_xdlops_v2r2(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
const AK0MK1GridDesc a_k0_m_k1_grid_desc,
const BK0NK1GridDesc b_k0_n_k1_grid_desc,
const CM0M1M2NGridDesc c_m0_m1_m2_n_grid_desc,
const CBlockClusterAdaptor c_block_cluster_adaptor)
{
constexpr index_t shared_block_size =
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
__shared__ FloatAB p_shared_block[shared_block_size];
GridwiseGemm::Run(p_a_grid,
p_b_grid,
p_c_grid,
p_shared_block,
a_k0_m_k1_grid_desc,
b_k0_n_k1_grid_desc,
c_m0_m1_m2_n_grid_desc,
c_block_cluster_adaptor);
}
template <index_t BlockSize,
typename FloatAB,
typename FloatAcc,
typename FloatC,
InMemoryDataOperation CGlobalMemoryDataOperation,
typename AK0MK1GridDesc,
typename BK0NK1GridDesc,
typename CMNGridDesc,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t MPerWave,
index_t NPerWave,
index_t MRepeat,
index_t NRepeat,
typename ABlockTransferThreadSliceLengths_K0_M_K1,
typename ABlockTransferThreadClusterLengths_K0_M_K1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_K1,
bool AThreadTransferSrcResetCoordinateAfterRun,
typename BBlockTransferThreadSliceLengths_K0_N_K1,
typename BBlockTransferThreadClusterLengths_K0_N_K1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_K1,
bool BThreadTransferSrcResetCoordinateAfterRun,
typename CThreadTransferSrcDstAccessOrder,
index_t CThreadTransferSrcDstVectorDim,
index_t CThreadTransferDstScalarPerVector,
typename AGridIteratorHacks,
typename BGridIteratorHacks,
typename CGridIteratorHacks,
typename AGridMoveSliceWindowIteratorHacks,
typename BGridMoveSliceWindowIteratorHacks>
struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r2
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
// K1 should be Number<...>
static constexpr auto K1 = AK0MK1GridDesc{}.GetLength(I2);
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
constexpr auto max_lds_align = K1;
// A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto a_k0_m_k1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
// B matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto b_k0_n_k1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_space_size =
math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align);
constexpr auto b_block_space_size =
math::integer_least_multiple(b_k0_n_k1_block_desc.GetElementSpaceSize(), max_lds_align);
return (a_block_space_size + b_block_space_size) * sizeof(FloatAB);
}
__host__ __device__ static constexpr bool
CheckValidity(const AK0MK1GridDesc& a_k0_m_k1_grid_desc,
const BK0NK1GridDesc& b_k0_n_k1_grid_desc,
const CMNGridDesc& c_m_n_grid_desc)
{
// TODO: turn on this
static_assert(is_known_at_compile_time<remove_cv_t<decltype(K1)>>::value,
"wrong! K1 need to be known at compile-time");
const auto M = a_k0_m_k1_grid_desc.GetLength(I1);
const auto N = b_k0_n_k1_grid_desc.GetLength(I1);
const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0);
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
return (M == c_m_n_grid_desc.GetLength(I0) && N == c_m_n_grid_desc.GetLength(I1) &&
K0 == b_k0_n_k1_grid_desc.GetLength(I0) &&
K1 == a_k0_m_k1_grid_desc.GetLength(I2) &&
K1 == b_k0_n_k1_grid_desc.GetLength(I2)) &&
(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % KPerBlock == 0) &&
(MPerBlock % MPerWave == 0 && NPerBlock % NPerWave == 0);
}
__host__ __device__ static constexpr index_t
CalculateGridSize(const CMNGridDesc& c_m_n_grid_desc)
{
const auto M = c_m_n_grid_desc.GetLength(I0);
const auto N = c_m_n_grid_desc.GetLength(I1);
const index_t grid_size = (M / MPerBlock) * (N / NPerBlock);
return grid_size;
}
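    // Worked example (hypothetical problem and tile sizes): for M = N = 1024, K0 = 128,
    // MPerBlock = NPerBlock = 128, KPerBlock = 8 and MPerWave = NPerWave = 64, every
    // divisibility test in CheckValidity() passes and CalculateGridSize() returns
    // (1024 / 128) * (1024 / 128) = 64 workgroups.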
__host__ __device__ static constexpr auto
MakeCM0M1M2NGridDescriptor(const CMNGridDesc& c_m_n_grid_desc)
{
const auto M = c_m_n_grid_desc.GetLength(I0);
const auto N = c_m_n_grid_desc.GetLength(I1);
constexpr auto xdlops_gemm = XdlopsGemm<FloatAB, MPerWave, NPerWave, K1.value>{};
constexpr auto CLayout = xdlops_gemm.GetCLayout();
constexpr auto M0 = Number<CLayout.M1()>{};
constexpr auto M1 = Number<CLayout.N1()>{};
constexpr auto M2 = Number<CLayout.M0()>{};
const auto c_m0_m1_m2_n_grid_desc = transform_dynamic_tensor_descriptor(
c_m_n_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(M / (M1 * M2), M1, M2)),
make_pass_through_transform(N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}));
return c_m0_m1_m2_n_grid_desc;
}
__host__ __device__ static constexpr auto
MakeCBlockClusterAdaptor(const CMNGridDesc& c_m_n_grid_desc)
{
const auto M = c_m_n_grid_desc.GetLength(I0);
const auto N = c_m_n_grid_desc.GetLength(I1);
constexpr auto M1 = Number<MPerBlock>{};
constexpr auto N1 = Number<NPerBlock>{};
const auto M0 = M / M1;
const auto N0 = N / N1;
const auto c_blockid_to_m0_n0_block_cluster_adaptor =
make_single_stage_tensor_adaptor(make_tuple(make_merge_transform(make_tuple(M0, N0))),
make_tuple(Sequence<0, 1>{}),
make_tuple(Sequence<0>{}));
return c_blockid_to_m0_n0_block_cluster_adaptor;
}
using CM0M1M2NGridDesc = decltype(MakeCM0M1M2NGridDescriptor(CMNGridDesc{}));
using CBlockClusterAdaptor = decltype(MakeCBlockClusterAdaptor(CMNGridDesc{}));
__device__ static void Run(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
FloatAB* __restrict__ p_shared_block,
const AK0MK1GridDesc& a_k0_m_k1_grid_desc,
const BK0NK1GridDesc& b_k0_n_k1_grid_desc,
const CM0M1M2NGridDesc& c_m0_m1_m2_n_grid_desc,
const CBlockClusterAdaptor& c_block_cluster_adaptor)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
const auto a_grid_buf = make_dynamic_buffer<AddressSpace::Global>(
p_a_grid, a_k0_m_k1_grid_desc.GetElementSpaceSize());
const auto b_grid_buf = make_dynamic_buffer<AddressSpace::Global>(
p_b_grid, b_k0_n_k1_grid_desc.GetElementSpaceSize());
auto c_grid_buf = make_dynamic_buffer<AddressSpace::Global>(
p_c_grid, c_m0_m1_m2_n_grid_desc.GetElementSpaceSize());
const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0);
const auto M = a_k0_m_k1_grid_desc.GetLength(I1);
const auto N = b_k0_n_k1_grid_desc.GetLength(I1);
// divide block work by [M, N]
const auto block_work_idx =
c_block_cluster_adaptor.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
        // HACK: this forces m/n_block_data_idx_on_grid into SGPRs
const index_t m_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
const index_t n_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);
// lds max alignment
constexpr auto max_lds_align = K1;
// A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto a_k0_m_k1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
// B matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto b_k0_n_k1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
// A matrix blockwise copy
auto a_blockwise_copy =
BlockwiseDynamicTensorSliceTransfer_v4<BlockSize,
InMemoryDataOperation::Set,
Sequence<KPerBlock, MPerBlock, K1.value>,
ABlockTransferThreadSliceLengths_K0_M_K1,
ABlockTransferThreadClusterLengths_K0_M_K1,
ABlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatAB,
decltype(a_k0_m_k1_grid_desc),
decltype(a_k0_m_k1_block_desc),
ABlockTransferSrcAccessOrder,
Sequence<1, 0, 2>,
ABlockTransferSrcVectorDim,
2,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_K1,
1,
1,
AThreadTransferSrcResetCoordinateAfterRun,
true>(
a_k0_m_k1_grid_desc,
make_multi_index(0, m_block_data_idx_on_grid, 0),
a_k0_m_k1_block_desc,
make_multi_index(0, 0, 0));
// B matrix blockwise copy
auto b_blockwise_copy =
BlockwiseDynamicTensorSliceTransfer_v4<BlockSize,
InMemoryDataOperation::Set,
Sequence<KPerBlock, NPerBlock, K1.value>,
BBlockTransferThreadSliceLengths_K0_N_K1,
BBlockTransferThreadClusterLengths_K0_N_K1,
BBlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatAB,
decltype(b_k0_n_k1_grid_desc),
decltype(b_k0_n_k1_block_desc),
BBlockTransferSrcAccessOrder,
Sequence<1, 0, 2>,
BBlockTransferSrcVectorDim,
2,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_K1,
1,
1,
BThreadTransferSrcResetCoordinateAfterRun,
true>(
b_k0_n_k1_grid_desc,
make_multi_index(0, n_block_data_idx_on_grid, 0),
b_k0_n_k1_block_desc,
make_multi_index(0, 0, 0));
// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
// a_mtx[KPerBlock, MPerBlock] is in LDS
// b_mtx[KPerBlock, NPerBlock] is in LDS
        // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
        // registers
// sanity check
static_assert(MPerBlock % (MPerWave * MRepeat) == 0 &&
NPerBlock % (NPerWave * NRepeat) == 0,
"wrong!");
constexpr auto a_k0_m0_m1_k1_block_desc = transform_dynamic_tensor_descriptor(
a_k0_m_k1_block_desc,
make_tuple(make_pass_through_transform(Number<KPerBlock>{}),
make_unmerge_transform(
make_tuple(Number<MRepeat>{}, Number<MPerBlock / MRepeat>{})),
make_pass_through_transform(K1)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
constexpr auto b_k0_n0_n1_k1_block_desc = transform_dynamic_tensor_descriptor(
b_k0_n_k1_block_desc,
make_tuple(make_pass_through_transform(Number<KPerBlock>{}),
make_unmerge_transform(
make_tuple(Number<NRepeat>{}, Number<NPerBlock / NRepeat>{})),
make_pass_through_transform(K1)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
const auto blockwise_gemm =
BlockwiseGemmXdlops_km_kn_m0m1m2n_v1<BlockSize,
FloatAB,
decltype(a_k0_m0_m1_k1_block_desc),
decltype(b_k0_n0_n1_k1_block_desc),
MPerWave,
NPerWave,
K1.value>{};
constexpr auto CLayout = blockwise_gemm.GetCLayout();
constexpr index_t BlkSize = CLayout.GetBlkSize();
constexpr index_t NumBlks = CLayout.GetNumBlks();
constexpr index_t NumXdlops = CLayout.GetNumXdlops();
constexpr auto c_mr_nr_nx_desc = make_dynamic_naive_tensor_descriptor_packed_v2(
make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, Number<NumXdlops>{}));
constexpr auto c_blk_nb_bs_desc = make_dynamic_naive_tensor_descriptor_packed_v2(
make_tuple(Number<NumBlks>{}, Number<BlkSize>{}));
StaticBuffer<AddressSpace::Vgpr,
vector_type<FloatAcc, c_blk_nb_bs_desc.GetElementSpaceSize()>,
c_mr_nr_nx_desc.GetElementSpaceSize()>
c_thread_buf;
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_space_size =
math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align);
constexpr auto b_block_space_size =
math::integer_least_multiple(b_k0_n_k1_block_desc.GetElementSpaceSize(), max_lds_align);
FloatAB* p_a_block = p_shared_block;
FloatAB* p_b_block = p_shared_block + a_block_space_size;
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0);
        // hack to control index calculation when iterating over the A and B matrices in threadwise copy
constexpr auto a_k0_m_k1_grid_iterator_hacks = AGridIteratorHacks{};
constexpr auto b_k0_n_k1_grid_iterator_hacks = BGridIteratorHacks{};
        // hack to control index calculation when moving the slice window over the A and B
        // matrices in threadwise copy
constexpr auto a_k0_m_k1_grid_move_slice_window_iterator_hack =
AGridMoveSliceWindowIteratorHacks{};
constexpr auto b_k0_n_k1_grid_move_slice_window_iterator_hack =
BGridMoveSliceWindowIteratorHacks{};
auto a_block_buf = make_dynamic_buffer<AddressSpace::Lds>(
p_a_block, a_k0_m_k1_block_desc.GetElementSpaceSize());
auto b_block_buf = make_dynamic_buffer<AddressSpace::Lds>(
p_b_block, b_k0_n_k1_block_desc.GetElementSpaceSize());
// preload data into LDS
{
a_blockwise_copy.RunRead(
a_k0_m_k1_grid_desc, a_grid_buf, a_k0_m_k1_grid_iterator_hacks);
b_blockwise_copy.RunRead(
b_k0_n_k1_grid_desc, b_grid_buf, b_k0_n_k1_grid_iterator_hacks);
a_blockwise_copy.RunWrite(a_k0_m_k1_block_desc, a_block_buf);
b_blockwise_copy.RunWrite(b_k0_n_k1_block_desc, b_block_buf);
}
// main body
index_t k_block_data_begin = 0;
do
{
a_blockwise_copy.MoveSrcSliceWindow(a_k0_m_k1_grid_desc,
a_block_slice_copy_step,
a_k0_m_k1_grid_move_slice_window_iterator_hack);
b_blockwise_copy.MoveSrcSliceWindow(b_k0_n_k1_grid_desc,
b_block_slice_copy_step,
b_k0_n_k1_grid_move_slice_window_iterator_hack);
a_blockwise_copy.RunRead(
a_k0_m_k1_grid_desc, a_grid_buf, a_k0_m_k1_grid_iterator_hacks);
block_sync_lds();
b_blockwise_copy.RunRead(
b_k0_n_k1_grid_desc, b_grid_buf, b_k0_n_k1_grid_iterator_hacks);
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
block_sync_lds();
a_blockwise_copy.RunWrite(a_k0_m_k1_block_desc, a_block_buf);
b_blockwise_copy.RunWrite(b_k0_n_k1_block_desc, b_block_buf);
k_block_data_begin += KPerBlock;
} while(k_block_data_begin < (K0 - KPerBlock));
// tail
{
block_sync_lds();
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
}
// output: register to global memory
{
constexpr index_t M0 = CLayout.M1();
constexpr index_t M1 = CLayout.N1();
constexpr index_t M2 = CLayout.M0();
constexpr auto c_m0_m1_m2_n_thread_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(
make_tuple(Number<M0>{}, Number<1>{}, Number<M2>{}, Number<1>{}));
StaticBuffer<AddressSpace::Vgpr, FloatC, BlkSize> c_blk_buf_;
static_for<0, MRepeat, 1>{}([&](auto mr_i) {
static_for<0, NRepeat, 1>{}([&](auto nr_i) {
static_for<0, NumXdlops, 1>{}([&](auto xdlops_i) {
static_for<0, NumBlks, 1>{}([&](auto blk_i) {
auto c_blk = c_thread_buf[Number<c_mr_nr_nx_desc.CalculateOffset(
make_tuple(mr_i, nr_i, xdlops_i))>{}];
static_for<0, BlkSize, 1>{}([&](auto j) {
c_blk_buf_(j) = c_blk.template AsType<FloatAcc>()[Number<
c_blk_nb_bs_desc.CalculateOffset(make_tuple(blk_i, j))>{}];
});
// calculate origin of thread output tensor on global memory
// blockwise GEMM c matrix starting index
const auto c_thread_mtx_on_block =
blockwise_gemm.CalculateCThreadOriginDataIndex(
mr_i, nr_i, xdlops_i, blk_i);
const index_t m_thread_data_on_grid =
m_block_data_idx_on_grid + c_thread_mtx_on_block[I0];
const index_t n_thread_data_on_grid =
n_block_data_idx_on_grid + c_thread_mtx_on_block[I1];
constexpr auto c_m0_m1_m2_n_grid_tensor_iterator_hacks =
CGridIteratorHacks{};
ThreadwiseDynamicTensorSliceTransfer_v1r3<
FloatC,
FloatC,
decltype(c_m0_m1_m2_n_thread_desc),
decltype(c_m0_m1_m2_n_grid_desc),
Sequence<M0, 1, M2, 1>,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
CGlobalMemoryDataOperation,
1,
true>{c_m0_m1_m2_n_grid_desc,
make_multi_index(m_thread_data_on_grid / (M2 * M1),
m_thread_data_on_grid % (M2 * M1) / M2,
m_thread_data_on_grid % M2,
n_thread_data_on_grid)}
.Run(c_m0_m1_m2_n_thread_desc,
make_tuple(I0, I0, I0, I0),
c_blk_buf_,
c_m0_m1_m2_n_grid_desc,
c_grid_buf,
c_m0_m1_m2_n_grid_tensor_iterator_hacks);
});
});
});
});
}
}
};
} // namespace ck
#endif
@@ -12,6 +12,7 @@
 namespace ck {
+#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
 template <typename GridwiseGemm,
           typename FloatAB,
           typename FloatC,
@@ -45,6 +46,50 @@ __global__ void
                                     c_m0_m1_m2_n_grid_desc,
                                     c_block_cluster_adaptor);
 }
+#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
+template <typename GridwiseGemm,
+          typename FloatAB,
+          typename FloatC,
+          typename AK0MK1GridDesc,
+          typename BK0NK1GridDesc,
+          typename CM0M1M2NGridDesc,
+          typename CBlockClusterAdaptor>
+__global__ void
+#if CK_USE_LAUNCH_BOUNDS
+    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
+#endif
+    kernel_dynamic_gemm_xdlops_v2r3(const FloatAB* __restrict__ p_a_grid,
+                                    const FloatAB* __restrict__ p_b_grid,
+                                    FloatC* __restrict__ p_c_grid,
+                                    const void __CONSTANT__* p_a_k0_m_k1_grid_desc,
+                                    const void __CONSTANT__* p_b_k0_n_k1_grid_desc,
+                                    const void __CONSTANT__* p_c_m0_m1_m2_n_grid_desc,
+                                    const void __CONSTANT__* p_c_block_cluster_adaptor)
+{
+    constexpr index_t shared_block_size =
+        GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
+    const auto a_k0_m_k1_grid_desc =
+        *reinterpret_cast<const AK0MK1GridDesc*>((const void*)p_a_k0_m_k1_grid_desc);
+    const auto b_k0_n_k1_grid_desc =
+        *reinterpret_cast<const BK0NK1GridDesc*>((const void*)p_b_k0_n_k1_grid_desc);
+    const auto c_m0_m1_m2_n_grid_desc =
+        *reinterpret_cast<const CM0M1M2NGridDesc*>((const void*)p_c_m0_m1_m2_n_grid_desc);
+    const auto c_block_cluster_adaptor =
+        *reinterpret_cast<const CBlockClusterAdaptor*>((const void*)p_c_block_cluster_adaptor);
+    __shared__ FloatAB p_shared_block[shared_block_size];
+    GridwiseGemm::Run(p_a_grid,
+                      p_b_grid,
+                      p_c_grid,
+                      p_shared_block,
+                      a_k0_m_k1_grid_desc,
+                      b_k0_n_k1_grid_desc,
+                      c_m0_m1_m2_n_grid_desc,
+                      c_block_cluster_adaptor);
+}
+#endif
 template <index_t BlockSize,
           typename FloatAB,
@@ -59,6 +104,7 @@ template <index_t BlockSize,
           index_t KPerBlock,
           index_t MPerWave,
           index_t NPerWave,
+          index_t K1Value,
           index_t MRepeat,
           index_t NRepeat,
           typename ABlockTransferThreadSliceLengths_K0_M_K1,
@@ -94,7 +140,7 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3
     static constexpr auto I3 = Number<3>{};
     // K1 should be Number<...>
-    static constexpr auto K1 = AK0MK1GridDesc{}.GetLength(I2);
+    static constexpr auto K1 = Number<K1Value>{};
     __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
     {
@@ -160,7 +206,7 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3
         const auto M = c_m_n_grid_desc.GetLength(I0);
         const auto N = c_m_n_grid_desc.GetLength(I1);
-        constexpr auto xdlops_gemm = XdlopsGemm<FloatAB, MPerWave, NPerWave, K1.value>{};
+        constexpr auto xdlops_gemm = XdlopsGemm<FloatAB, MPerWave, NPerWave, K1>{};
         constexpr auto CLayout = xdlops_gemm.GetCLayout();
@@ -267,7 +313,7 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3
         auto a_blockwise_copy =
             BlockwiseDynamicTensorSliceTransfer_v4<BlockSize,
                                                    InMemoryDataOperation::Set,
-                                                   Sequence<KPerBlock, MPerBlock, K1.value>,
+                                                   Sequence<KPerBlock, MPerBlock, K1>,
                                                    ABlockTransferThreadSliceLengths_K0_M_K1,
                                                    ABlockTransferThreadClusterLengths_K0_M_K1,
                                                    ABlockTransferThreadClusterArrangeOrder,
@@ -294,7 +340,7 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3
         auto b_blockwise_copy =
             BlockwiseDynamicTensorSliceTransfer_v4<BlockSize,
                                                    InMemoryDataOperation::Set,
-                                                   Sequence<KPerBlock, NPerBlock, K1.value>,
+                                                   Sequence<KPerBlock, NPerBlock, K1>,
                                                    BBlockTransferThreadSliceLengths_K0_N_K1,
                                                    BBlockTransferThreadClusterLengths_K0_N_K1,
                                                    BBlockTransferThreadClusterArrangeOrder,
@@ -354,7 +400,7 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3
                                                  decltype(b_k0_n0_n1_k1_block_desc),
                                                  MPerWave,
                                                  NPerWave,
-                                                 K1.value>{};
+                                                 K1>{};
         constexpr auto CLayout = blockwise_gemm.GetCLayout();
...
@@ -39,7 +39,6 @@
 #if CK_USE_AMD_XDLOPS
 #include "amd_xdlops.hpp"
-#include "amd_xdlops_inline_asm.hpp"
 #endif
 #endif
#include "common_header.hpp"
#include "type_helper.hpp"
#include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp"
#include "gridwise_dynamic_gemm_xdlops_v2r3.hpp"
#include "transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw.hpp"
using namespace ck;
using FloatAB = typename get_type_from_type_id<static_cast<char>(CK_PARAM_IN_WEI_DATATYPE)>::type;
using FloatC = typename get_type_from_type_id<static_cast<char>(CK_PARAM_OUT_DATATYPE)>::type;
using FloatAcc = typename get_type_from_type_id<static_cast<char>(CK_PARAM_CONV_COMPTYPE)>::type;
constexpr index_t BlockSize = CK_PARAM_BlockSize;
constexpr index_t MPerBlock = CK_PARAM_MPerBlock;
constexpr index_t NPerBlock = CK_PARAM_NPerBlock;
constexpr index_t KPerBlock = CK_PARAM_KPerBlock;
constexpr index_t MPerWave = CK_PARAM_MPerWave;
constexpr index_t NPerWave = CK_PARAM_NPerWave;
constexpr index_t MRepeat = CK_PARAM_MRepeat;
constexpr index_t NRepeat = CK_PARAM_NRepeat;
constexpr index_t K1 = CK_PARAM_K1;
using ABlockTransferThreadSliceLengths_K0_M_K1 =
Sequence<CK_PARAM_ABlockTransferThreadSliceLengths_K0_M_K1>;
using ABlockTransferThreadClusterLengths_K0_M_K1 =
Sequence<CK_PARAM_ABlockTransferThreadClusterLengths_K0_M_K1>;
using ABlockTransferThreadClusterArrangeOrder =
Sequence<CK_PARAM_ABlockTransferThreadClusterArrangeOrder>;
using ABlockTransferSrcAccessOrder = Sequence<CK_PARAM_ABlockTransferSrcAccessOrder>;
constexpr index_t ABlockTransferSrcVectorDim = CK_PARAM_ABlockTransferSrcVectorDim;
constexpr index_t ABlockTransferSrcScalarPerVector = CK_PARAM_ABlockTransferSrcScalarPerVector;
constexpr index_t ABlockTransferDstScalarPerVector_K1 =
CK_PARAM_ABlockTransferDstScalarPerVector_K1;
constexpr bool AThreadTransferSrcResetCoordinateAfterRun =
static_cast<bool>(CK_PARAM_AThreadTransferSrcResetCoordinateAfterRun);
using BBlockTransferThreadSliceLengths_K0_N_K1 =
Sequence<CK_PARAM_BBlockTransferThreadSliceLengths_K0_N_K1>;
using BBlockTransferThreadClusterLengths_K0_N_K1 =
Sequence<CK_PARAM_BBlockTransferThreadClusterLengths_K0_N_K1>;
using BBlockTransferThreadClusterArrangeOrder =
Sequence<CK_PARAM_BBlockTransferThreadClusterArrangeOrder>;
using BBlockTransferSrcAccessOrder = Sequence<CK_PARAM_BBlockTransferSrcAccessOrder>;
constexpr index_t BBlockTransferSrcVectorDim = CK_PARAM_BBlockTransferSrcVectorDim;
constexpr index_t BBlockTransferSrcScalarPerVector = CK_PARAM_BBlockTransferSrcScalarPerVector;
constexpr index_t BBlockTransferDstScalarPerVector_K1 =
CK_PARAM_BBlockTransferDstScalarPerVector_K1;
constexpr bool BThreadTransferSrcResetCoordinateAfterRun =
static_cast<bool>(CK_PARAM_BThreadTransferSrcResetCoordinateAfterRun);
using CThreadTransferSrcDstAccessOrder = Sequence<CK_PARAM_CThreadTransferSrcDstAccessOrder>;
constexpr index_t CThreadTransferSrcDstVectorDim = CK_PARAM_CThreadTransferSrcDstVectorDim;
constexpr index_t CThreadTransferDstScalarPerVector = CK_PARAM_CThreadTransferDstScalarPerVector;
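// The CK_PARAM_* macros above are not defined in this file; the online-compilation host code
// supplies them on the compiler command line when it builds this source at runtime.
// A hypothetical invocation (values for illustration only, not a recommended configuration):
//
//     hipcc <this source> ... -DCK_PARAM_BlockSize=256
//           -DCK_PARAM_MPerBlock=128 -DCK_PARAM_NPerBlock=128 -DCK_PARAM_KPerBlock=4
//           -DCK_PARAM_MPerWave=64 -DCK_PARAM_NPerWave=64
//           -DCK_PARAM_MRepeat=1 -DCK_PARAM_NRepeat=1 -DCK_PARAM_K1=8
//           "-DCK_PARAM_ABlockTransferThreadSliceLengths_K0_M_K1=1,2,8"
//           "-DCK_PARAM_ABlockTransferThreadClusterLengths_K0_M_K1=4,64,1"
//           ... (remaining CK_PARAM_* flags elided)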
extern "C" __global__ void
dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw_prepare(
int n,
int c,
int hi,
int wi,
int k,
int y,
int x,
int convStrideH,
int convStrideW,
int convDilationY,
int convDilationX,
int leftPadH,
int leftPadW,
int rightPadH,
int rightPadW,
void* p_a_k0_m_k1_grid_desc,
void* p_b_k0_n_k1_grid_desc,
void* p_c_m0_m1_m2_n_grid_desc,
void* p_c_blockid_to_m0_n0_block_cluster_adaptor)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
const index_t ho = (hi + leftPadH + rightPadH - convDilationY * (y - 1) - 1) / convStrideH + 1;
const index_t wo = (wi + leftPadW + rightPadW - convDilationX * (x - 1) - 1) / convStrideW + 1;
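    // Standard convolution output-size formula; e.g. hi = 224, leftPadH = rightPadH = 3,
    // y = 7, convDilationY = 1, convStrideH = 2 gives
    // ho = (224 + 3 + 3 - 1 * (7 - 1) - 1) / 2 + 1 = 112.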
const auto in_n_c_hi_wi_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(n, c, hi, wi));
const auto wei_k_c_y_x_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(k, c, y, x));
const auto out_n_k_ho_wo_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(n, k, ho, wo));
const auto descs = transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw_pad(
wei_k_c_y_x_desc,
in_n_c_hi_wi_desc,
out_n_k_ho_wo_desc,
make_tuple(convStrideH, convStrideW),
make_tuple(convDilationY, convDilationX),
make_tuple(leftPadH, leftPadW),
make_tuple(rightPadH, rightPadW),
Number<K1>{});
const auto a_k0_m_k1_grid_desc = descs[I0];
const auto b_k0_n_k1_grid_desc = descs[I1];
const auto c_m_n_grid_desc = descs[I2];
using AK0MK1GridDesc = decltype(a_k0_m_k1_grid_desc);
using BK0NK1GridDesc = decltype(b_k0_n_k1_grid_desc);
using CMNGridDesc = decltype(c_m_n_grid_desc);
using AGridIteratorHacks = decltype(make_tuple(
make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}),
make_tuple(
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{})));
using BGridIteratorHacks =
decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}),
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{})));
using CGridIteratorHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 1, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 1, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 1, 0, 0>{}),
make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 2, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 2, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 2, 0, 0>{})));
using AGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0>;
using BGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>;
using GridwiseGemm =
GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3<BlockSize,
FloatAB,
FloatAcc,
FloatC,
InMemoryDataOperation::Set,
AK0MK1GridDesc,
BK0NK1GridDesc,
CMNGridDesc,
MPerBlock,
NPerBlock,
KPerBlock,
MPerWave,
NPerWave,
K1,
MRepeat,
NRepeat,
ABlockTransferThreadSliceLengths_K0_M_K1,
ABlockTransferThreadClusterLengths_K0_M_K1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_K1,
AThreadTransferSrcResetCoordinateAfterRun,
BBlockTransferThreadSliceLengths_K0_N_K1,
BBlockTransferThreadClusterLengths_K0_N_K1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_K1,
BThreadTransferSrcResetCoordinateAfterRun,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
AGridIteratorHacks,
BGridIteratorHacks,
CGridIteratorHacks,
AGridMoveSliceWindowIteratorHacks,
BGridMoveSliceWindowIteratorHacks,
false>;
auto c_m0_m1_m2_n_grid_desc = GridwiseGemm::MakeCM0M1M2NGridDescriptor(c_m_n_grid_desc);
auto c_blockid_to_m0_n0_block_cluster_adaptor =
GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc);
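// only one thread needs to commit the results: the four descriptor/adaptor objects are
// written into the host-provided device buffers, and the main kernel below reads them
// back through its __CONSTANT__ pointer arguments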
if(hipThreadIdx_x == 0)
{
*static_cast<remove_cv_t<decltype(a_k0_m_k1_grid_desc)>*>(p_a_k0_m_k1_grid_desc) =
a_k0_m_k1_grid_desc;
*static_cast<remove_cv_t<decltype(b_k0_n_k1_grid_desc)>*>(p_b_k0_n_k1_grid_desc) =
b_k0_n_k1_grid_desc;
*static_cast<decltype(c_m0_m1_m2_n_grid_desc)*>(p_c_m0_m1_m2_n_grid_desc) =
c_m0_m1_m2_n_grid_desc;
*static_cast<decltype(c_blockid_to_m0_n0_block_cluster_adaptor)*>(
p_c_blockid_to_m0_n0_block_cluster_adaptor) = c_blockid_to_m0_n0_block_cluster_adaptor;
}
};
extern "C" __global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw(
const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
const void __CONSTANT__* p_a_k0_m_k1_grid_desc,
const void __CONSTANT__* p_b_k0_n_k1_grid_desc,
const void __CONSTANT__* p_c_m0_m1_m2_n_grid_desc,
const void __CONSTANT__* p_c_blockid_to_m0_n0_block_cluster_adaptor)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
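// the fixed 256x256x28x28 / 3x3 problem below is a placeholder: it is only used to deduce
// the compile-time *types* of the grid descriptors; the actual runtime descriptors are
// loaded from the __CONSTANT__ buffers that the prepare kernel filled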
constexpr auto in_n_c_hi_wi_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(256, 256, 28, 28));
constexpr auto wei_k_c_y_x_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(256, 256, 3, 3));
constexpr auto out_n_k_ho_wo_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(256, 256, 28, 28));
constexpr auto descs =
transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw_pad(wei_k_c_y_x_desc,
in_n_c_hi_wi_desc,
out_n_k_ho_wo_desc,
make_tuple(1, 1),
make_tuple(1, 1),
make_tuple(1, 1),
make_tuple(1, 1),
Number<K1>{});
constexpr auto a_k0_m_k1_grid_desc_tmp = descs[I0];
constexpr auto b_k0_n_k1_grid_desc_tmp = descs[I1];
constexpr auto c_m_n_grid_desc = descs[I2];
using AGridIteratorHacks = decltype(make_tuple(
make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}),
make_tuple(
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{})));
using BGridIteratorHacks =
decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}),
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{})));
using CGridIteratorHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 1, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 1, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 1, 0, 0>{}),
make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 2, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 2, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 2, 0, 0>{})));
using AGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0>;
using BGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>;
using AK0MK1GridDesc = decltype(a_k0_m_k1_grid_desc_tmp);
using BK0NK1GridDesc = decltype(b_k0_n_k1_grid_desc_tmp);
using CMNGridDesc = decltype(c_m_n_grid_desc);
using GridwiseGemm =
GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3<BlockSize,
FloatAB,
FloatAcc,
FloatC,
InMemoryDataOperation::Set,
AK0MK1GridDesc,
BK0NK1GridDesc,
CMNGridDesc,
MPerBlock,
NPerBlock,
KPerBlock,
MPerWave,
NPerWave,
K1,
MRepeat,
NRepeat,
ABlockTransferThreadSliceLengths_K0_M_K1,
ABlockTransferThreadClusterLengths_K0_M_K1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_K1,
AThreadTransferSrcResetCoordinateAfterRun,
BBlockTransferThreadSliceLengths_K0_N_K1,
BBlockTransferThreadClusterLengths_K0_N_K1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_K1,
BThreadTransferSrcResetCoordinateAfterRun,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
AGridIteratorHacks,
BGridIteratorHacks,
CGridIteratorHacks,
AGridMoveSliceWindowIteratorHacks,
BGridMoveSliceWindowIteratorHacks,
false>;
constexpr auto c_m0_m1_m2_n_grid_desc_tmp =
GridwiseGemm::MakeCM0M1M2NGridDescriptor(c_m_n_grid_desc);
constexpr auto c_blockid_to_m0_n0_block_cluster_adaptor_tmp =
GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc);
using CM0M1M2NGridDesc = decltype(c_m0_m1_m2_n_grid_desc_tmp);
using CBlockIdToM0N0BlockClusterAdaptor =
decltype(c_blockid_to_m0_n0_block_cluster_adaptor_tmp);
const auto a_k0_m_k1_grid_desc =
*reinterpret_cast<const AK0MK1GridDesc*>((const void*)p_a_k0_m_k1_grid_desc);
const auto b_k0_n_k1_grid_desc =
*reinterpret_cast<const BK0NK1GridDesc*>((const void*)p_b_k0_n_k1_grid_desc);
const auto c_m0_m1_m2_n_grid_desc =
*reinterpret_cast<const CM0M1M2NGridDesc*>((const void*)p_c_m0_m1_m2_n_grid_desc);
const auto c_blockid_to_m0_n0_block_cluster_adaptor =
*reinterpret_cast<const CBlockIdToM0N0BlockClusterAdaptor*>(
(const void*)p_c_blockid_to_m0_n0_block_cluster_adaptor);
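// LDS is sized from the gridwise GEMM's byte requirement for the A/B block buffers,
// expressed here as a count of FloatAB elements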
constexpr index_t shared_block_size =
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
__shared__ FloatAB p_shared_block[shared_block_size];
GridwiseGemm::Run(p_a_grid,
p_b_grid,
p_c_grid,
p_shared_block,
a_k0_m_k1_grid_desc,
b_k0_n_k1_grid_desc,
c_m0_m1_m2_n_grid_desc,
c_blockid_to_m0_n0_block_cluster_adaptor);
};
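// Illustrative host-side sequence (a sketch under assumptions, not part of this wrapper):
// the *_prepare kernel above is launched once to materialize the runtime tensor descriptors
// in small device buffers, and the same buffers are then handed to the main kernel. Buffer
// size, grid size and the weight/input-to-A/B mapping are assumptions here; the real
// online-compilation driver owns those details and launches through its handle, so the
// casts to the kernels' __CONSTANT__ pointer parameters are elided below.
#if 0
void example_launch(const FloatAB* p_wei_dev, const FloatAB* p_in_dev, FloatC* p_out_dev,
                    int n, int c, int hi, int wi, int k, int y, int x,
                    int stride_h, int stride_w, int dilation_h, int dilation_w,
                    int pad_h, int pad_w, int grid_size, hipStream_t stream)
{
    // workspace for the four descriptor/adaptor objects written by the prepare kernel;
    // 4096 bytes per object is an assumed upper bound
    void* p_desc[4];
    for(auto& p : p_desc)
        hipMalloc(&p, 4096);

    hipLaunchKernelGGL(dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw_prepare,
                       dim3(1), dim3(1), 0, stream,
                       n, c, hi, wi, k, y, x,
                       stride_h, stride_w, dilation_h, dilation_w,
                       pad_h, pad_w, pad_h, pad_w,
                       p_desc[0], p_desc[1], p_desc[2], p_desc[3]);

    hipLaunchKernelGGL(dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw,
                       dim3(grid_size), dim3(BlockSize), 0, stream,
                       p_wei_dev, // assumed A = weight (GemmM = K) for this NCHW wrapper
                       p_in_dev,  // assumed B = input  (GemmN = N * Ho * Wo)
                       p_out_dev,
                       p_desc[0], p_desc[1], p_desc[2], p_desc[3]);
}
#endif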
#include "common_header.hpp"
#include "type_helper.hpp"
#include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp"
#include "gridwise_dynamic_gemm_xdlops_v2r3.hpp"
#include "transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp"
using namespace ck;
using FloatAB = typename get_type_from_type_id<static_cast<char>(CK_PARAM_IN_WEI_DATATYPE)>::type;
using FloatC = typename get_type_from_type_id<static_cast<char>(CK_PARAM_OUT_DATATYPE)>::type;
using FloatAcc = typename get_type_from_type_id<static_cast<char>(CK_PARAM_CONV_COMPTYPE)>::type;
constexpr index_t BlockSize = CK_PARAM_BlockSize;
constexpr index_t MPerBlock = CK_PARAM_MPerBlock;
constexpr index_t NPerBlock = CK_PARAM_NPerBlock;
constexpr index_t KPerBlock = CK_PARAM_KPerBlock;
constexpr index_t MPerWave = CK_PARAM_MPerWave;
constexpr index_t NPerWave = CK_PARAM_NPerWave;
constexpr index_t MRepeat = CK_PARAM_MRepeat;
constexpr index_t NRepeat = CK_PARAM_NRepeat;
constexpr index_t K1 = CK_PARAM_K1;
using ABlockTransferThreadSliceLengths_K0_M_K1 =
Sequence<CK_PARAM_ABlockTransferThreadSliceLengths_K0_M_K1>;
using ABlockTransferThreadClusterLengths_K0_M_K1 =
Sequence<CK_PARAM_ABlockTransferThreadClusterLengths_K0_M_K1>;
using ABlockTransferThreadClusterArrangeOrder =
Sequence<CK_PARAM_ABlockTransferThreadClusterArrangeOrder>;
using ABlockTransferSrcAccessOrder = Sequence<CK_PARAM_ABlockTransferSrcAccessOrder>;
constexpr index_t ABlockTransferSrcVectorDim = CK_PARAM_ABlockTransferSrcVectorDim;
constexpr index_t ABlockTransferSrcScalarPerVector = CK_PARAM_ABlockTransferSrcScalarPerVector;
constexpr index_t ABlockTransferDstScalarPerVector_K1 =
CK_PARAM_ABlockTransferDstScalarPerVector_K1;
constexpr bool AThreadTransferSrcResetCoordinateAfterRun =
static_cast<bool>(CK_PARAM_AThreadTransferSrcResetCoordinateAfterRun);
using BBlockTransferThreadSliceLengths_K0_N_K1 =
Sequence<CK_PARAM_BBlockTransferThreadSliceLengths_K0_N_K1>;
using BBlockTransferThreadClusterLengths_K0_N_K1 =
Sequence<CK_PARAM_BBlockTransferThreadClusterLengths_K0_N_K1>;
using BBlockTransferThreadClusterArrangeOrder =
Sequence<CK_PARAM_BBlockTransferThreadClusterArrangeOrder>;
using BBlockTransferSrcAccessOrder = Sequence<CK_PARAM_BBlockTransferSrcAccessOrder>;
constexpr index_t BBlockTransferSrcVectorDim = CK_PARAM_BBlockTransferSrcVectorDim;
constexpr index_t BBlockTransferSrcScalarPerVector = CK_PARAM_BBlockTransferSrcScalarPerVector;
constexpr index_t BBlockTransferDstScalarPerVector_K1 =
CK_PARAM_BBlockTransferDstScalarPerVector_K1;
constexpr bool BThreadTransferSrcResetCoordinateAfterRun =
static_cast<bool>(CK_PARAM_BThreadTransferSrcResetCoordinateAfterRun);
using CThreadTransferSrcDstAccessOrder = Sequence<CK_PARAM_CThreadTransferSrcDstAccessOrder>;
constexpr index_t CThreadTransferSrcDstVectorDim = CK_PARAM_CThreadTransferSrcDstVectorDim;
constexpr index_t CThreadTransferDstScalarPerVector = CK_PARAM_CThreadTransferDstScalarPerVector;
extern "C" __global__ void
dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk_prepare(
int n,
int hi,
int wi,
int c,
int k,
int y,
int x,
int convStrideH,
int convStrideW,
int convDilationY,
int convDilationX,
int leftPadH,
int leftPadW,
int rightPadH,
int rightPadW,
void* p_a_k0_m_k1_grid_desc,
void* p_b_k0_n_k1_grid_desc,
void* p_c_m0_m1_m2_n_grid_desc,
void* p_c_blockid_to_m0_n0_block_cluster_adaptor)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
const index_t ho = (hi + leftPadH + rightPadH - convDilationY * (y - 1) - 1) / convStrideH + 1;
const index_t wo = (wi + leftPadW + rightPadW - convDilationX * (x - 1) - 1) / convStrideW + 1;
const auto in_n_hi_wi_c_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(n, hi, wi, c));
const auto wei_k_y_x_c_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(k, y, x, c));
const auto out_n_ho_wo_k_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(n, ho, wo, k));
const auto descs = transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad(
in_n_hi_wi_c_desc,
wei_k_y_x_c_desc,
out_n_ho_wo_k_desc,
make_tuple(convStrideH, convStrideW),
make_tuple(convDilationY, convDilationX),
make_tuple(leftPadH, leftPadW),
make_tuple(rightPadH, rightPadW),
Number<K1>{});
const auto a_k0_m_k1_grid_desc = descs[I0];
const auto b_k0_n_k1_grid_desc = descs[I1];
const auto c_m_n_grid_desc = descs[I2];
using AK0MK1GridDesc = decltype(a_k0_m_k1_grid_desc);
using BK0NK1GridDesc = decltype(b_k0_n_k1_grid_desc);
using CMNGridDesc = decltype(c_m_n_grid_desc);
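// note: relative to the NCHW wrapper above, the operand roles appear swapped here — the NHWC
// input descriptor (which carries the pad/embed/merge transforms, hence the longer
// 13-element iterator-hack sequences) is mapped to the A side, and the K-Y-X-C weight to B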
using BGridIteratorHacks = decltype(make_tuple(
make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}),
make_tuple(
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{})));
using AGridIteratorHacks =
decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}),
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{})));
using CGridIteratorHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 1, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 1, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 1, 0, 0>{}),
make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 2, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 2, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 2, 0, 0>{})));
using AGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>;
using BGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0>;
using GridwiseGemm =
GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3<BlockSize,
FloatAB,
FloatAcc,
FloatC,
InMemoryDataOperation::Set,
AK0MK1GridDesc,
BK0NK1GridDesc,
CMNGridDesc,
MPerBlock,
NPerBlock,
KPerBlock,
MPerWave,
NPerWave,
K1,
MRepeat,
NRepeat,
ABlockTransferThreadSliceLengths_K0_M_K1,
ABlockTransferThreadClusterLengths_K0_M_K1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_K1,
AThreadTransferSrcResetCoordinateAfterRun,
BBlockTransferThreadSliceLengths_K0_N_K1,
BBlockTransferThreadClusterLengths_K0_N_K1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_K1,
BThreadTransferSrcResetCoordinateAfterRun,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
AGridIteratorHacks,
BGridIteratorHacks,
CGridIteratorHacks,
AGridMoveSliceWindowIteratorHacks,
BGridMoveSliceWindowIteratorHacks,
false>;
auto c_m0_m1_m2_n_grid_desc = GridwiseGemm::MakeCM0M1M2NGridDescriptor(c_m_n_grid_desc);
auto c_blockid_to_m0_n0_block_cluster_adaptor =
GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc);
if(hipThreadIdx_x == 0)
{
*static_cast<remove_cv_t<decltype(a_k0_m_k1_grid_desc)>*>(p_a_k0_m_k1_grid_desc) =
a_k0_m_k1_grid_desc;
*static_cast<remove_cv_t<decltype(b_k0_n_k1_grid_desc)>*>(p_b_k0_n_k1_grid_desc) =
b_k0_n_k1_grid_desc;
*static_cast<decltype(c_m0_m1_m2_n_grid_desc)*>(p_c_m0_m1_m2_n_grid_desc) =
c_m0_m1_m2_n_grid_desc;
*static_cast<decltype(c_blockid_to_m0_n0_block_cluster_adaptor)*>(
p_c_blockid_to_m0_n0_block_cluster_adaptor) = c_blockid_to_m0_n0_block_cluster_adaptor;
}
};
extern "C" __global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk(
const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
const void __CONSTANT__* p_a_k0_m_k1_grid_desc,
const void __CONSTANT__* p_b_k0_n_k1_grid_desc,
const void __CONSTANT__* p_c_m0_m1_m2_n_grid_desc,
const void __CONSTANT__* p_c_blockid_to_m0_n0_block_cluster_adaptor)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto in_n_hi_wi_c_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(256, 28, 28, 256));
constexpr auto wei_k_y_x_c_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(256, 3, 3, 256));
constexpr auto out_n_ho_wo_k_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(256, 28, 28, 256));
constexpr auto descs =
transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad(in_n_hi_wi_c_desc,
wei_k_y_x_c_desc,
out_n_ho_wo_k_desc,
make_tuple(1, 1),
make_tuple(1, 1),
make_tuple(1, 1),
make_tuple(1, 1),
Number<K1>{});
constexpr auto a_k0_m_k1_grid_desc_tmp = descs[I0];
constexpr auto b_k0_n_k1_grid_desc_tmp = descs[I1];
constexpr auto c_m_n_grid_desc = descs[I2];
using AK0MK1GridDesc = decltype(a_k0_m_k1_grid_desc_tmp);
using BK0NK1GridDesc = decltype(b_k0_n_k1_grid_desc_tmp);
using CMNGridDesc = decltype(c_m_n_grid_desc);
using BGridIteratorHacks = decltype(make_tuple(
make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}),
make_tuple(
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{})));
using AGridIteratorHacks =
decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}),
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{})));
using CGridIteratorHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 1, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 1, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 1, 0, 0>{}),
make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 2, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 2, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 2, 0, 0>{})));
using AGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>;
using BGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0>;
using GridwiseGemm =
GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3<BlockSize,
FloatAB,
FloatAcc,
FloatC,
InMemoryDataOperation::Set,
AK0MK1GridDesc,
BK0NK1GridDesc,
CMNGridDesc,
MPerBlock,
NPerBlock,
KPerBlock,
MPerWave,
NPerWave,
K1,
MRepeat,
NRepeat,
ABlockTransferThreadSliceLengths_K0_M_K1,
ABlockTransferThreadClusterLengths_K0_M_K1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_K1,
AThreadTransferSrcResetCoordinateAfterRun,
BBlockTransferThreadSliceLengths_K0_N_K1,
BBlockTransferThreadClusterLengths_K0_N_K1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_K1,
BThreadTransferSrcResetCoordinateAfterRun,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
AGridIteratorHacks,
BGridIteratorHacks,
CGridIteratorHacks,
AGridMoveSliceWindowIteratorHacks,
BGridMoveSliceWindowIteratorHacks,
false>;
constexpr auto c_m0_m1_m2_n_grid_desc_tmp =
GridwiseGemm::MakeCM0M1M2NGridDescriptor(c_m_n_grid_desc);
constexpr auto c_blockid_to_m0_n0_block_cluster_adaptor_tmp =
GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc);
using CM0M1M2NGridDesc = decltype(c_m0_m1_m2_n_grid_desc_tmp);
using CBlockIdToM0N0BlockClusterAdaptor =
decltype(c_blockid_to_m0_n0_block_cluster_adaptor_tmp);
const auto a_k0_m_k1_grid_desc =
*reinterpret_cast<const AK0MK1GridDesc*>((const void*)p_a_k0_m_k1_grid_desc);
const auto b_k0_n_k1_grid_desc =
*reinterpret_cast<const BK0NK1GridDesc*>((const void*)p_b_k0_n_k1_grid_desc);
const auto c_m0_m1_m2_n_grid_desc =
*reinterpret_cast<const CM0M1M2NGridDesc*>((const void*)p_c_m0_m1_m2_n_grid_desc);
const auto c_blockid_to_m0_n0_block_cluster_adaptor =
*reinterpret_cast<const CBlockIdToM0N0BlockClusterAdaptor*>(
(const void*)p_c_blockid_to_m0_n0_block_cluster_adaptor);
constexpr index_t shared_block_size =
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
__shared__ FloatAB p_shared_block[shared_block_size];
GridwiseGemm::Run(p_a_grid,
p_b_grid,
p_c_grid,
p_shared_block,
a_k0_m_k1_grid_desc,
b_k0_n_k1_grid_desc,
c_m0_m1_m2_n_grid_desc,
c_blockid_to_m0_n0_block_cluster_adaptor);
};
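// Grid-size sketch (an assumption, not code from this commit): the block-cluster adaptor
// above tiles C into MPerBlock x NPerBlock blocks, so the launch grid is expected to be
// roughly one workgroup per tile. Which of K and N*Ho*Wo plays GemmM depends on the A/B
// mapping chosen by the layout-specific transform.
inline index_t example_expected_grid_size(index_t GemmM, index_t GemmN)
{
    // assumes GemmM and GemmN are multiples of the block tile, as the gridwise GEMM requires
    return (GemmM / MPerBlock) * (GemmN / NPerBlock);
}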
@@ -92,6 +92,8 @@ set(MCONV_KERNEL_INCLUDES
set(MCONV_KERNELS
../composable_kernel/src/kernel_wrapper/dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.cpp
../composable_kernel/src/kernel_wrapper/dynamic_convolution_forward_implicit_gemm_v4r5_nchw_kcyx_nkhw.cpp
+../composable_kernel/src/kernel_wrapper/dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.cpp
+../composable_kernel/src/kernel_wrapper/dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk.cpp
)
add_kernels("olCompiling/" "${MCONV_KERNELS}")
@@ -16,7 +16,7 @@
#include "device_dynamic_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk.hpp"
#define USE_DYNAMIC_MODE 1
-#define USE_CONV_BWD_V4R1_XDL_NHWC 0
+#define USE_CONV_BWD_V4R1_XDL_NHWC 1
#define USE_CONV_BWD_V4R1R2_XDL_NHWC 1
enum ConvBackwardDataAlgo
@@ -19,8 +19,6 @@
#include "device_dynamic_convolution_forward_implicit_gemm_v4r5r2_nchw_kcyx_nkhw.hpp"
#include "device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw.hpp"
#include "device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp"
-#include "device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nhwk.hpp"
-#include "device_dynamic_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nhwk.hpp"
#include "device_dynamic_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp"
#define USE_DYNAMIC_MODE 1
@@ -30,9 +28,7 @@
#define USE_CONV_FWD_V4R5_NCHW 0
#define USE_CONV_FWD_V4R5R2_NCHW 0
#define USE_CONV_FWD_V5R1_NCHW 0
-#define USE_CONV_FWD_V4R4_XDL_NCHW 1
-#define USE_CONV_FWD_V4R4R2_XDL_NHWC 0
-#define USE_CONV_FWD_V4R4R3_XDL_NHWC 0
+#define USE_CONV_FWD_V4R4R2_XDL_NCHW 1
#define USE_CONV_FWD_V4R4R4_XDL_NHWC 1
enum ConvForwardAlgo
@@ -43,10 +39,8 @@ enum ConvForwardAlgo
V4R5NCHW, // 3
V4R5R2NCHW, // 4
V5R1NCHW, // 5
-V4R4XDLNCHW, // 6
-V4R4R2XDLNHWC, // 7
-V4R4R3XDLNHWC, // 8
-V4R4R4XDLNHWC // 9
+V4R4R2XDLNCHW, // 6
+V4R4R4XDLNHWC // 7
};
int main(int argc, char* argv[])
@@ -462,8 +456,8 @@ int main(int argc, char* argv[])
}
#endif
-#if USE_CONV_FWD_V4R4_XDL_NCHW
-if(algo == ConvForwardAlgo::V4R4XDLNCHW)
+#if USE_CONV_FWD_V4R4R2_XDL_NCHW
+if(algo == ConvForwardAlgo::V4R4R2XDLNCHW)
{
if(layout != ConvTensorLayout::NCHW)
{
@@ -489,60 +483,6 @@
}
#endif
#if USE_CONV_FWD_V4R4R2_XDL_NHWC
if(algo == ConvForwardAlgo::V4R4R2XDLNHWC)
{
if(layout != ConvTensorLayout::NHWC)
{
throw std::runtime_error("wrong! layout");
}
const auto tmp = f_make_for_device_nhwc();
device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nhwk<in_data_t,
acc_data_t,
out_data_t>(
tmp[I0],
tmp[I1],
tmp[I2],
tmp[I3],
tmp[I4],
tmp[I5],
tmp[I6],
in,
wei,
out_device,
nrepeat);
}
#endif
#if USE_CONV_FWD_V4R4R3_XDL_NHWC
if(algo == ConvForwardAlgo::V4R4R3XDLNHWC)
{
if(layout != ConvTensorLayout::NHWC)
{
throw std::runtime_error("wrong! layout");
}
const auto tmp = f_make_for_device_nhwc();
device_dynamic_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nhwk<in_data_t,
acc_data_t,
out_data_t>(
tmp[I0],
tmp[I1],
tmp[I2],
tmp[I3],
tmp[I4],
tmp[I5],
tmp[I6],
in,
wei,
out_device,
nrepeat);
}
#endif
#if USE_CONV_FWD_V4R4R4_XDL_NHWC
if(algo == ConvForwardAlgo::V4R4R4XDLNHWC)
{
@@ -12,11 +12,17 @@
#include "conv_common.hpp"
#include "host_conv.hpp"
#include "device_tensor.hpp"
#include "olc_device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp"
#include "olc_device_dynamic_convolution_forward_implicit_gemm_v4r5_nchw_kcyx_nkhw.hpp"
+#include "olc_device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp"
+#include "olc_device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk.hpp"
#define USE_CONV_FWD_V4R4_NCHW 1
#define USE_CONV_FWD_V4R5_NCHW 1
+#define USE_CONV_FWD_V4R4_XDLOPS_NCHW 1
+#define USE_CONV_FWD_V4R4_XDLOPS_NHWC 1
#include "conv_tunables.hpp"
#include "handle.hpp"
@@ -24,10 +30,10 @@
enum ConvForwardAlgo
{
-V4R4NCHW,
-V4R4NHWC,
-V4R5NCHW,
-V5R1NCHW
+V4R4NCHW, // 0
+V4R5NCHW, // 1
+V4R4XDLNCHW, // 2
+V4R4XDLNHWC // 3
};
int main(int argc, char* argv[])
@@ -105,14 +111,16 @@ int main(int argc, char* argv[])
{
case ConvTensorLayout::NCHW:
// NCHW
in_lengths_host[0] = static_cast<std::size_t>(N);
in_lengths_host[1] = static_cast<std::size_t>(C);
in_lengths_host[2] = static_cast<std::size_t>(Hi);
in_lengths_host[3] = static_cast<std::size_t>(Wi);
wei_lengths_host[0] = static_cast<std::size_t>(K);
wei_lengths_host[1] = static_cast<std::size_t>(C);
wei_lengths_host[2] = static_cast<std::size_t>(Y);
wei_lengths_host[3] = static_cast<std::size_t>(X);
out_lengths_host[0] = static_cast<std::size_t>(N);
out_lengths_host[1] = static_cast<std::size_t>(K);
out_lengths_host[2] = static_cast<std::size_t>(Ho);
@@ -120,14 +128,16 @@ int main(int argc, char* argv[])
break;
case ConvTensorLayout::NHWC:
// NHWC
in_lengths_host[0] = static_cast<std::size_t>(N);
in_lengths_host[1] = static_cast<std::size_t>(Hi);
in_lengths_host[2] = static_cast<std::size_t>(Wi);
in_lengths_host[3] = static_cast<std::size_t>(C);
wei_lengths_host[0] = static_cast<std::size_t>(K);
wei_lengths_host[1] = static_cast<std::size_t>(Y);
wei_lengths_host[2] = static_cast<std::size_t>(X);
wei_lengths_host[3] = static_cast<std::size_t>(C);
out_lengths_host[0] = static_cast<std::size_t>(N);
out_lengths_host[1] = static_cast<std::size_t>(Ho);
out_lengths_host[2] = static_cast<std::size_t>(Wo);
@@ -271,10 +281,80 @@ int main(int argc, char* argv[])
}
#endif
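// The two dispatch blocks added below wire the new online-compiled xdlops kernels into the
// driver: each one builds the device tensors, picks the default tunable from
// conv_tunables.hpp, and calls the corresponding *_olc device wrapper with the handle.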
#if USE_CONV_FWD_V4R4_XDLOPS_NCHW
if(algo == ConvForwardAlgo::V4R4XDLNCHW)
{
if(layout != ConvTensorLayout::NCHW)
{
throw std::runtime_error("wrong! layout");
}
const auto tmp = f_make_for_device_nchw();
tunable_dyn_conv_fwd_v4r4_xdlops_nchw_kcyx_nkhw* tunable =
&default_tunable_dyn_conv_fwd_v4r4_xdlops_nchw_kcyx_nkhw;
device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw_olc<in_data_t,
acc_data_t,
out_data_t>(
handle,
tmp[I0],
tmp[I1],
tmp[I2],
conv_strides,
conv_dilations,
in_left_pads,
in_right_pads,
in,
wei,
out_device,
tunable,
nrepeat);
}
#endif
#if USE_CONV_FWD_V4R4_XDLOPS_NHWC
if(algo == ConvForwardAlgo::V4R4XDLNHWC)
{
if(layout != ConvTensorLayout::NHWC)
{
throw std::runtime_error("wrong! layout");
}
const auto tmp = f_make_for_device_nhwc();
tunable_dyn_conv_fwd_v4r4_xdlops_nhwc_kyxc_nhwk* tunable =
&default_tunable_dyn_conv_fwd_v4r4_xdlops_nhwc_kyxc_nhwk;
device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk_olc<in_data_t,
acc_data_t,
out_data_t>(
handle,
tmp[I0],
tmp[I1],
tmp[I2],
conv_strides,
conv_dilations,
in_left_pads,
in_right_pads,
in,
wei,
out_device,
tunable,
nrepeat);
}
#endif
if(do_verification)
{
-host_direct_convolution(
-    in, wei, out_host, conv_strides, conv_dilations, in_left_pads, in_right_pads);
+host_direct_convolution(in,
+                        wei,
+                        out_host,
+                        make_tuple(conv_stride_h, conv_stride_w),
+                        make_tuple(conv_dilation_h, conv_dilation_w),
+                        make_tuple(in_left_pad_h, in_left_pad_w),
+                        make_tuple(in_right_pad_h, in_right_pad_w),
+                        layout);
check_error(out_host, out_device);
@@ -50,6 +50,146 @@ static tunable_dyn_conv_fwd_v4r4_nchw_kcyx_nkhw default_tunable_dyn_conv_fwd_v4r
{0, 1, 2}, {0, 1, 2}, 2, 1, 1, false, {3, 4, 5, 0, 1, 2},
5, 1};
struct tunable_dyn_conv_fwd_v4r4_xdlops_nchw_kcyx_nkhw
{
ck::index_t BlockSize; // usually not tunable
ck::index_t MPerBlock;
ck::index_t NPerBlock;
ck::index_t KPerBlock;
ck::index_t MPerWave;
ck::index_t NPerWave;
ck::index_t K1;
ck::index_t MRepeat;
ck::index_t NRepeat;
std::array<ck::index_t, 3> ABlockTransferThreadSliceLengths_K0_M_K1;
std::array<ck::index_t, 3> ABlockTransferThreadClusterLengths_K0_M_K1;
std::array<ck::index_t, 3> ABlockTransferThreadClusterArrangeOrder;
std::array<ck::index_t, 3> ABlockTransferSrcAccessOrder;
ck::index_t ABlockTransferSrcVectorDim;
ck::index_t ABlockTransferSrcScalarPerVector;
ck::index_t ABlockTransferDstScalarPerVector_K1;
bool AThreadTransferSrcResetCoordinateAfterRun;
std::array<ck::index_t, 3> BBlockTransferThreadSliceLengths_K0_N_K1;
std::array<ck::index_t, 3> BBlockTransferThreadClusterLengths_K0_N_K1;
std::array<ck::index_t, 3> BBlockTransferThreadClusterArrangeOrder;
std::array<ck::index_t, 3> BBlockTransferSrcAccessOrder;
ck::index_t BBlockTransferSrcVectorDim;
ck::index_t BBlockTransferSrcScalarPerVector;
ck::index_t BBlockTransferDstScalarPerVector_K1;
bool BThreadTransferSrcResetCoordinateAfterRun;
std::array<ck::index_t, 8> CThreadTransferSrcDstAccessOrder;
ck::index_t CThreadTransferSrcDstVectorDim;
ck::index_t CThreadTransferDstScalarPerVector;
};
static tunable_dyn_conv_fwd_v4r4_xdlops_nchw_kcyx_nkhw
default_tunable_dyn_conv_fwd_v4r4_xdlops_nchw_kcyx_nkhw = {
256, // BlockSize
128, // MPerBlock,
128, // NPerBlock,
4, // KPerBlock,
32, // MPerWave,
32, // NPerWave,
4, // K1,
2, // MRepeat,
2, // NRepeat,
{1, 2, 4}, // ABlockTransferThreadSliceLengths_K0_M_K1,
{4, 64, 1}, // ABlockTransferThreadClusterLengths_K0_M_K1,
{1, 0, 2}, // ABlockTransferThreadClusterArrangeOrder,
{1, 0, 2}, // ABlockTransferSrcAccessOrder,
2, // ABlockTransferSrcVectorDim
1, // ABlockTransferSrcScalarPerVector,
4, // ABlockTransferDstScalarPerVector_K1,
false, // AThreadTransferSrcResetCoordinateAfterRun,
{1, 2, 4}, // BBlockTransferThreadSliceLengths_K0_N_K1,
{4, 64, 1}, // BBlockTransferThreadClusterLengths_K0_N_K1,
{0, 2, 1}, // BBlockTransferThreadClusterArrangeOrder,
{1, 0, 2}, // BBlockTransferSrcAccessOrder,
1, // BBlockTransferSrcVectorDim
1, // BBlockTransferSrcScalarPerVector
4, // BBlockTransferDstScalarPerVector_K1
false, // BThreadTransferSrcResetCoordinateAfterRun
{3, 0, 1, 2, 7, 5, 4, 6}, // CThreadTransferSrcDstAccessOrder
7, // CThreadTransferSrcDstVectorDim,
1 // CThreadTransferDstScalarPerVector
};
struct tunable_dyn_conv_fwd_v4r4_xdlops_nhwc_kyxc_nhwk
{
ck::index_t BlockSize; // usually not tunable
ck::index_t MPerBlock;
ck::index_t NPerBlock;
ck::index_t KPerBlock;
ck::index_t MPerWave;
ck::index_t NPerWave;
ck::index_t K1;
ck::index_t MRepeat;
ck::index_t NRepeat;
std::array<ck::index_t, 3> ABlockTransferThreadSliceLengths_K0_M_K1;
std::array<ck::index_t, 3> ABlockTransferThreadClusterLengths_K0_M_K1;
std::array<ck::index_t, 3> ABlockTransferThreadClusterArrangeOrder;
std::array<ck::index_t, 3> ABlockTransferSrcAccessOrder;
ck::index_t ABlockTransferSrcVectorDim;
ck::index_t ABlockTransferSrcScalarPerVector;
ck::index_t ABlockTransferDstScalarPerVector_K1;
bool AThreadTransferSrcResetCoordinateAfterRun;
std::array<ck::index_t, 3> BBlockTransferThreadSliceLengths_K0_N_K1;
std::array<ck::index_t, 3> BBlockTransferThreadClusterLengths_K0_N_K1;
std::array<ck::index_t, 3> BBlockTransferThreadClusterArrangeOrder;
std::array<ck::index_t, 3> BBlockTransferSrcAccessOrder;
ck::index_t BBlockTransferSrcVectorDim;
ck::index_t BBlockTransferSrcScalarPerVector;
ck::index_t BBlockTransferDstScalarPerVector_K1;
bool BThreadTransferSrcResetCoordinateAfterRun;
std::array<ck::index_t, 8> CThreadTransferSrcDstAccessOrder;
ck::index_t CThreadTransferSrcDstVectorDim;
ck::index_t CThreadTransferDstScalarPerVector;
};
static tunable_dyn_conv_fwd_v4r4_xdlops_nhwc_kyxc_nhwk
default_tunable_dyn_conv_fwd_v4r4_xdlops_nhwc_kyxc_nhwk = {
256, // BlockSize
128, // MPerBlock,
128, // NPerBlock,
4, // KPerBlock,
32, // MPerWave,
32, // NPerWave,
4, // K1,
2, // MRepeat,
2, // NRepeat,
{1, 2, 4}, // ABlockTransferThreadSliceLengths_K0_M_K1,
{4, 64, 1}, // ABlockTransferThreadClusterLengths_K0_M_K1,
{1, 0, 2}, // ABlockTransferThreadClusterArrangeOrder,
{1, 0, 2}, // ABlockTransferSrcAccessOrder,
2, // ABlockTransferSrcVectorDim
4, // ABlockTransferSrcScalarPerVector,
4, // ABlockTransferDstScalarPerVector_K1,
false, // AThreadTransferSrcResetCoordinateAfterRun,
{1, 2, 4}, // BBlockTransferThreadSliceLengths_K0_N_K1,
{4, 64, 1}, // BBlockTransferThreadClusterLengths_K0_N_K1,
{1, 0, 2}, // BBlockTransferThreadClusterArrangeOrder,
{1, 0, 2}, // BBlockTransferSrcAccessOrder,
2, // BBlockTransferSrcVectorDim
4, // BBlockTransferSrcScalarPerVector
4, // BBlockTransferDstScalarPerVector_K1
false, // BThreadTransferSrcResetCoordinateAfterRun
{2, 3, 0, 1, 7, 5, 4, 6}, // CThreadTransferSrcDstAccessOrder
7, // CThreadTransferSrcDstVectorDim,
1 // CThreadTransferDstScalarPerVector
};
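// Illustrative sketch (assumed helper, not part of this header): the online-compilation
// driver is expected to turn a tunable instance like the one above into "-DCK_PARAM_*"
// compile definitions, which the kernel-wrapper .cpp files consume. The helper name and
// the exact definition list are assumptions for illustration only.
#include <array>
#include <string>
inline std::string example_ck_param_flags(const tunable_dyn_conv_fwd_v4r4_xdlops_nhwc_kyxc_nhwk& t)
{
    // array-valued tunables expand to the comma-separated template arguments of Sequence<...>
    auto seq3 = [](const std::array<ck::index_t, 3>& a) {
        return std::to_string(a[0]) + "," + std::to_string(a[1]) + "," + std::to_string(a[2]);
    };
    std::string s;
    s += " -DCK_PARAM_BlockSize=" + std::to_string(t.BlockSize);
    s += " -DCK_PARAM_MPerBlock=" + std::to_string(t.MPerBlock);
    s += " -DCK_PARAM_NPerBlock=" + std::to_string(t.NPerBlock);
    s += " -DCK_PARAM_KPerBlock=" + std::to_string(t.KPerBlock);
    s += " -DCK_PARAM_K1=" + std::to_string(t.K1);
    s += " -DCK_PARAM_ABlockTransferThreadSliceLengths_K0_M_K1=" +
         seq3(t.ABlockTransferThreadSliceLengths_K0_M_K1);
    // ... the remaining fields follow the same two patterns
    return s;
}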
struct tunable_dyn_conv_fwd_v4r5_nchw_kcyx_nkhw
{
ck::index_t BlockSize;
@@ -273,6 +273,7 @@ void device_dynamic_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyx
GemmKPerBlock,
GemmMPerWave,
GemmNPerWave,
+GemmK1,
MRepeat,
NRepeat,
GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1,