Unverified commit fbdf4332, authored by zjing14, committed by GitHub

Add xdlops v4r4r4 into online compilation (#48)



* init for v4r4 xdlops olc

* refactor wrap

* init impl of v4r4 nchw xdlops olc

* tuning

* test perf

* fixed v4r4 nhwc

* tuned v4r4 nhwc

* use gridwise_gemm_xdlops_v2r3

* swap a/b

* add pointer support into offline v2r3

* debugging v4r4r4 transform for olc

* change timer of olc

* refactor v4r4 xdlops nchw olc

* remove transform fun in v4r4 xdlops nhwc olc
Co-authored-by: Chao Liu <chao.liu2@amd.com>
parent 0a72e4df
#ifndef CK_DRIVER_DYNAMIC_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V4R4_XDLOPS_NCHW_KCYX_NKHW_HPP
#define CK_DRIVER_DYNAMIC_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V4R4_XDLOPS_NCHW_KCYX_NKHW_HPP
#include "common_header.hpp"
#include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp"
#include "driver_dynamic_gemm_xdlops_v1.hpp"
#include "driver_dynamic_gemm_xdlops_v2.hpp"
namespace ck {
// GemmM = K
// GemmN = N * Ho * Wo
// GemmK = C * Y * X
template <typename FloatAB,
index_t GemmMPerBlock,
index_t GemmNPerBlock,
index_t GemmKPerBlock, // used by the block-tile divisibility assert below
index_t GemmMPerWave,
index_t GemmNPerWave,
index_t GemmKPack,
typename... Wei,
typename... In,
typename... Out,
typename ConvStrides,
typename ConvDilations,
typename InLeftPads,
typename InRightPads>
__host__ __device__ constexpr auto
transform_forward_convolution_into_gemm_v4r4_xdlops_nchw_kcyx_nkhw_pad(
const DynamicTensorDescriptor<Wei...>& wei_k_c_y_x_global_desc,
const DynamicTensorDescriptor<In...>& in_n_c_hi_wi_global_desc,
const DynamicTensorDescriptor<Out...>& out_n_k_ho_wo_global_desc,
const ConvStrides& conv_strides,
const ConvDilations& conv_dilations,
const InLeftPads& in_left_pads,
const InRightPads& in_right_pads)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
const auto N = in_n_c_hi_wi_global_desc.GetLength(I0);
const auto C = in_n_c_hi_wi_global_desc.GetLength(I1);
const auto K = out_n_k_ho_wo_global_desc.GetLength(I1);
const auto Hi = in_n_c_hi_wi_global_desc.GetLength(I2);
const auto Wi = in_n_c_hi_wi_global_desc.GetLength(I3);
const auto Ho = out_n_k_ho_wo_global_desc.GetLength(I2);
const auto Wo = out_n_k_ho_wo_global_desc.GetLength(I3);
const auto Y = wei_k_c_y_x_global_desc.GetLength(I2);
const auto X = wei_k_c_y_x_global_desc.GetLength(I3);
const auto ConvStrideH = conv_strides[I0];
const auto ConvStrideW = conv_strides[I1];
const auto ConvDilationH = conv_dilations[I0];
const auto ConvDilationW = conv_dilations[I1];
const auto InLeftPadH = in_left_pads[I0];
const auto InLeftPadW = in_left_pads[I1];
const auto InRightPadH = in_right_pads[I0];
const auto InRightPadW = in_right_pads[I1];
const auto GemmM = K;
const auto GemmN = N * Ho * Wo;
const auto GemmK = C * Y * X;
const auto GemmK0 = GemmK / GemmKPack;
// weight tensor
const auto wei_gemmk_gemmm_global_desc = transform_dynamic_tensor_descriptor(
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, C * Y * X)),
make_tuple(make_pass_through_transform(K), make_pass_through_transform(C * Y * X)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<1>{}, Sequence<0>{}));
const auto wei_gemmk0_gemmm_gemmk1_global_desc = transform_dynamic_tensor_descriptor(
wei_gemmk_gemmm_global_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmKPack)),
make_pass_through_transform(GemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// input tensor
const auto in_n_c_hip_wip_global_desc = transform_dynamic_tensor_descriptor(
in_n_c_hi_wi_global_desc,
make_tuple(make_pass_through_transform(N),
make_pass_through_transform(C),
make_pad_transform(Hi, InLeftPadH, InRightPadH),
make_pad_transform(Wi, InLeftPadW, InRightPadW)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto in_n_c_y_ho_x_wo_global_desc = transform_dynamic_tensor_descriptor(
in_n_c_hip_wip_global_desc,
make_tuple(make_pass_through_transform(N),
make_pass_through_transform(C),
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW))),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
const auto in_gemmk_gemmn_global_desc =
transform_dynamic_tensor_descriptor(in_n_c_y_ho_x_wo_global_desc,
make_tuple(make_merge_transform(make_tuple(C, Y, X)),
make_merge_transform(make_tuple(N, Ho, Wo))),
make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto in_gemmk0_gemmn_gemmk1_global_desc = transform_dynamic_tensor_descriptor(
in_gemmk_gemmn_global_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmKPack)),
make_pass_through_transform(GemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// output tensor
const auto out_gemmm_gemmn_global_desc = transform_dynamic_tensor_descriptor(
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K, Ho * Wo)),
make_tuple(make_pass_through_transform(K), make_merge_transform(make_tuple(N, Ho * Wo))),
make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
assert(GemmM == out_gemmm_gemmn_global_desc.GetLength(I0));
assert(GemmN == out_gemmm_gemmn_global_desc.GetLength(I1));
assert(GemmK0 == in_gemmk0_gemmn_gemmk1_global_desc.GetLength(I0));
assert(GemmK0 == wei_gemmk0_gemmm_gemmk1_global_desc.GetLength(I0));
assert(GemmM % GemmMPerBlock == 0 && GemmN % GemmNPerBlock == 0 && GemmK0 % GemmKPerBlock == 0);
constexpr auto xdlops_gemm = XdlopsGemm<FloatAB, GemmMPerWave, GemmNPerWave, GemmKPack>{};
constexpr auto CLayout = xdlops_gemm.GetCLayout();
constexpr index_t M0 = CLayout.M1();
constexpr index_t M1 = CLayout.N1();
constexpr index_t M2 = CLayout.M0();
const auto out_m0_m1_m2_n_global_desc = transform_dynamic_tensor_descriptor(
out_gemmm_gemmn_global_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmM / (M1 * M2), M1, M2)),
make_pass_through_transform(GemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}));
// out_gemm_block_cluster_desc
const auto out_gemm_block_cluster_desc = make_cluster_descriptor_v2(
make_tuple(GemmM / Number<GemmMPerBlock>{}, GemmN / Number<GemmNPerBlock>{}));
// hack to control index calculation when iterating over wei_gemmk0_gemmm_gemmk1_global tensor
constexpr auto wei_gemmk0_gemmm_gemmk1_global_iterator_hacks = make_tuple(
make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}),
make_tuple(
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}));
constexpr auto wei_gemmk0_gemmm_gemmk1_global_move_slice_window_iterator_hacks =
Sequence<0, 0, 0, 0, 0>{};
// hack to control index calculation when iterating over in_gemmk0_gemmn_gemmk1_global tensor
constexpr auto in_gemmk0_gemmn_gemmk1_global_iterator_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}),
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}));
constexpr auto in_gemmk0_gemmn_gemmk1_global_move_slice_window_iterator_hacks =
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>{};
// hack to control index calculation when iterating over out_gemmm0_gemmm1_gemmn0_gemmn1_global
// tensor hack for NKHW format
constexpr auto out_m0_m1_m2_n_global_iterator_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 1, 0, 0>{}),
make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 2, 0, 0>{}));
return make_tuple(wei_gemmk0_gemmm_gemmk1_global_desc,
in_gemmk0_gemmn_gemmk1_global_desc,
out_m0_m1_m2_n_global_desc,
out_gemm_block_cluster_desc,
wei_gemmk0_gemmm_gemmk1_global_iterator_hacks,
in_gemmk0_gemmn_gemmk1_global_iterator_hacks,
out_m0_m1_m2_n_global_iterator_hacks,
wei_gemmk0_gemmm_gemmk1_global_move_slice_window_iterator_hacks,
in_gemmk0_gemmn_gemmk1_global_move_slice_window_iterator_hacks);
}
// GemmM = K
// GemmN = N * Ho * Wo
// GemmK = C * Y * X
template <typename FloatAB,
index_t GemmMPerBlock,
index_t GemmNPerBlock,
index_t GemmKPerBlock, // used by the block-tile divisibility assert below
index_t GemmMPerWave,
index_t GemmNPerWave,
index_t GemmKPack,
typename... Wei,
typename... In,
typename... Out,
typename ConvStrides,
typename ConvDilations,
typename InLeftPads,
typename InRightPads>
__host__ __device__ constexpr auto
transform_forward_convolution_into_gemm_v4r4_xdlops_nchw_kcyx_nkhw_1x1(
const DynamicTensorDescriptor<Wei...>& wei_k_c_y_x_global_desc,
const DynamicTensorDescriptor<In...>& in_n_c_hi_wi_global_desc,
const DynamicTensorDescriptor<Out...>& out_n_k_ho_wo_global_desc,
const ConvStrides& conv_strides,
const ConvDilations& conv_dilations,
const InLeftPads& in_left_pads,
const InRightPads& in_right_pads)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
const auto N = in_n_c_hi_wi_global_desc.GetLength(I0);
const auto C = in_n_c_hi_wi_global_desc.GetLength(I1);
const auto K = out_n_k_ho_wo_global_desc.GetLength(I1);
const auto Hi = in_n_c_hi_wi_global_desc.GetLength(I2);
const auto Wi = in_n_c_hi_wi_global_desc.GetLength(I3);
const auto Ho = out_n_k_ho_wo_global_desc.GetLength(I2);
const auto Wo = out_n_k_ho_wo_global_desc.GetLength(I3);
const auto Y = wei_k_c_y_x_global_desc.GetLength(I2);
const auto X = wei_k_c_y_x_global_desc.GetLength(I3);
const auto ConvStrideH = conv_strides[I0];
const auto ConvStrideW = conv_strides[I1];
const auto ConvDilationH = conv_dilations[I0];
const auto ConvDilationW = conv_dilations[I1];
const auto InLeftPadH = in_left_pads[I0];
const auto InLeftPadW = in_left_pads[I1];
const auto InRightPadH = in_right_pads[I0];
const auto InRightPadW = in_right_pads[I1];
const auto GemmM = K;
const auto GemmN = N * Ho * Wo;
const auto GemmK = C * Y * X;
const auto GemmK0 = GemmK / GemmKPack;
assert(Y == 1 && X == 1 && ConvStrideH == 1 && ConvStrideW == 1 && ConvDilationH == 1 &&
ConvDilationW == 1 && InLeftPadH == 0 && InLeftPadW == 0 && InRightPadH == 0 &&
InRightPadW == 0);
// weight tensor
const auto wei_gemmk_gemmm_global_desc = transform_dynamic_tensor_descriptor(
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, C)),
make_tuple(make_pass_through_transform(K), make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<1>{}, Sequence<0>{}));
const auto wei_gemmk0_gemmm_gemmk1_global_desc = transform_dynamic_tensor_descriptor(
wei_gemmk_gemmm_global_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmKPack)),
make_pass_through_transform(GemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// input tensor
const auto in_gemmk_gemmn_global_desc = transform_dynamic_tensor_descriptor(
in_n_c_hi_wi_global_desc,
make_tuple(make_pass_through_transform(C), make_merge_transform(make_tuple(N, Ho, Wo))),
make_tuple(Sequence<1>{}, Sequence<0, 2, 3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto in_gemmk0_gemmn_gemmk1_global_desc = transform_dynamic_tensor_descriptor(
in_gemmk_gemmn_global_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmKPack)),
make_pass_through_transform(GemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// output tensor
const auto out_gemmm_gemmn_global_desc = transform_dynamic_tensor_descriptor(
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K, Ho * Wo)),
make_tuple(make_pass_through_transform(K), make_merge_transform(make_tuple(N, Ho * Wo))),
make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
assert(GemmM == out_gemmm_gemmn_global_desc.GetLength(I0));
assert(GemmN == out_gemmm_gemmn_global_desc.GetLength(I1));
assert(GemmK0 == in_gemmk0_gemmn_gemmk1_global_desc.GetLength(I0));
assert(GemmK0 == wei_gemmk0_gemmm_gemmk1_global_desc.GetLength(I0));
assert(GemmM % GemmMPerBlock == 0 && GemmN % GemmNPerBlock == 0 && GemmK0 % GemmKPerBlock == 0);
constexpr auto xdlops_gemm = XdlopsGemm<FloatAB, GemmMPerWave, GemmNPerWave, GemmKPack>{};
constexpr auto CLayout = xdlops_gemm.GetCLayout();
constexpr index_t M0 = CLayout.M1();
constexpr index_t M1 = CLayout.N1();
constexpr index_t M2 = CLayout.M0();
const auto out_m0_m1_m2_n_global_desc = transform_dynamic_tensor_descriptor(
out_gemmm_gemmn_global_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmM / (M1 * M2), M1, M2)),
make_pass_through_transform(GemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}));
// out_gemm_block_cluster_desc
const auto out_gemm_block_cluster_desc = make_cluster_descriptor_v2(
make_tuple(GemmM / Number<GemmMPerBlock>{}, GemmN / Number<GemmNPerBlock>{}));
// hack to control index calculation when iterating over wei_gemmk0_gemmm_gemmk1_global tensor
constexpr auto wei_gemmk0_gemmm_gemmk1_global_iterator_hacks = make_tuple(
make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}),
make_tuple(
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}));
constexpr auto wei_gemmk0_gemmm_gemmk1_global_move_slice_window_iterator_hacks =
Sequence<0, 0, 0, 0, 0>{};
// hack to control index calculation when iterating over in_gemmk0_gemmn_gemmk1_global tensor
constexpr auto in_gemmk0_gemmn_gemmk1_global_iterator_hacks = make_tuple(
make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 1, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}),
make_tuple(
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 2, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}));
constexpr auto in_gemmk0_gemmn_gemmk1_global_move_slice_window_iterator_hacks =
Sequence<0, 1, 2, 0, 0>{};
// hack to control index calculation when iterating over out_gemmm0_gemmm1_gemmn0_gemmn1_global
// tensor hack for NKHW format
constexpr auto out_m0_m1_m2_n_global_iterator_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 1, 0, 0>{}),
make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 2, 0, 0>{}));
return make_tuple(wei_gemmk0_gemmm_gemmk1_global_desc,
in_gemmk0_gemmn_gemmk1_global_desc,
out_m0_m1_m2_n_global_desc,
out_gemm_block_cluster_desc,
wei_gemmk0_gemmm_gemmk1_global_iterator_hacks,
in_gemmk0_gemmn_gemmk1_global_iterator_hacks,
out_m0_m1_m2_n_global_iterator_hacks,
wei_gemmk0_gemmm_gemmk1_global_move_slice_window_iterator_hacks,
in_gemmk0_gemmn_gemmk1_global_move_slice_window_iterator_hacks);
}
} // namespace ck
#endif
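
As a sanity check on the GemmM/GemmN/GemmK comments above, the mapping can be evaluated numerically. A minimal standalone sketch, not part of the commit; the problem sizes are made up for illustration and chosen so that GemmK divides evenly by GemmKPack:

// Implicit-GEMM mapping for forward convolution (NCHW/KCYX/NKHW):
//   GemmM = K, GemmN = N * Ho * Wo, GemmK = C * Y * X, GemmK0 = GemmK / GemmKPack
#include <cassert>
#include <cstdio>

int main()
{
    // hypothetical problem sizes, not taken from the commit
    const int N = 128, C = 192, Hi = 28, Wi = 28; // input  N, C, Hi, Wi
    const int K = 256, Y = 3, X = 3;              // weight K, C, Y, X
    const int stride = 1, dilation = 1, pad = 1;

    const int Ho = (Hi + 2 * pad - dilation * (Y - 1) - 1) / stride + 1; // 28
    const int Wo = (Wi + 2 * pad - dilation * (X - 1) - 1) / stride + 1; // 28

    const int GemmM = K;           // 256
    const int GemmN = N * Ho * Wo; // 100352
    const int GemmK = C * Y * X;   // 1728

    const int GemmKPack = 4;
    assert(GemmK % GemmKPack == 0);
    const int GemmK0 = GemmK / GemmKPack; // 432

    std::printf("GemmM=%d GemmN=%d GemmK=%d GemmK0=%d\n", GemmM, GemmN, GemmK, GemmK0);
    return 0;
}
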
#ifndef CK_DRIVER_DYNAMIC_GEMM_XDLOPS_V2
#define CK_DRIVER_DYNAMIC_GEMM_XDLOPS_V2
#include "common_header.hpp"
#include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp"
#include "gridwise_dynamic_gemm_xdlops_v2.hpp"
#include "gridwise_operation_wrapper.hpp"
namespace ck {
template <index_t BlockSize,
typename FloatAB,
typename FloatAcc,
typename FloatC,
InMemoryDataOperation CGlobalMemoryDataOperation,
typename AGlobalDesc,
typename BGlobalDesc,
typename CGlobalDesc,
typename CBlockClusterDesc,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t MPerWave,
index_t NPerWave,
index_t KPack,
index_t MRepeat,
index_t NRepeat,
typename ABlockTransferThreadSliceLengths_K_M,
typename ABlockTransferThreadClusterLengths_K_M,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_M,
bool AThreadTransferSrcResetCoordinateAfterRun,
typename BBlockTransferThreadSliceLengths_K_N,
typename BBlockTransferThreadClusterLengths_K_N,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_N,
bool BThreadTransferSrcResetCoordinateAfterRun,
typename CThreadTransferSrcDstAccessOrder,
index_t CThreadTransferSrcDstVectorDim,
index_t CThreadTransferDstScalarPerVector,
typename AGlobalIteratorHacks,
typename BGlobalIteratorHacks,
typename CGlobalIteratorHacks,
typename AGlobalMoveSliceWindowIteratorHacks,
typename BGlobalMoveSliceWindowIteratorHacks>
__host__ float launch_kernel_dynamic_gemm_xdlops_v2(const FloatAB* p_a_global,
const FloatAB* p_b_global,
FloatC* p_c_global,
const AGlobalDesc& a_k_m_global_desc,
const BGlobalDesc& b_k_n_global_desc,
const CGlobalDesc& c_m0_m1_n0_n1_global_desc,
const CBlockClusterDesc& c_block_cluster_desc,
AGlobalIteratorHacks,
BGlobalIteratorHacks,
CGlobalIteratorHacks,
AGlobalMoveSliceWindowIteratorHacks,
BGlobalMoveSliceWindowIteratorHacks,
index_t nrepeat)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
const auto M = a_k_m_global_desc.GetLength(I1);
const auto N = b_k_n_global_desc.GetLength(I1);
const auto K = a_k_m_global_desc.GetLength(I0);
if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0))
{
throw std::runtime_error("wrong! GEMM size no divisible");
}
if(!(MPerBlock % MPerWave == 0 && NPerBlock % NPerWave == 0))
{
throw std::runtime_error("wrong! GEMM size no divisible");
}
// GEMM
using gridwise_gemm =
GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v2<BlockSize,
FloatAB,
FloatAcc,
FloatC,
CGlobalMemoryDataOperation,
AGlobalDesc,
BGlobalDesc,
CGlobalDesc,
CBlockClusterDesc,
MPerBlock,
NPerBlock,
KPerBlock,
MPerWave,
NPerWave,
KPack,
MRepeat,
NRepeat,
ABlockTransferThreadSliceLengths_K_M,
ABlockTransferThreadClusterLengths_K_M,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_M,
AThreadTransferSrcResetCoordinateAfterRun,
BBlockTransferThreadSliceLengths_K_N,
BBlockTransferThreadClusterLengths_K_N,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_N,
BThreadTransferSrcResetCoordinateAfterRun,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
AGlobalIteratorHacks,
BGlobalIteratorHacks,
CGlobalIteratorHacks,
AGlobalMoveSliceWindowIteratorHacks,
BGlobalMoveSliceWindowIteratorHacks>;
const auto GridSize = (M / MPerBlock) * (N / NPerBlock);
#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
float ave_time = 0;
const auto kernel = kernel_dynamic_gemm_xdlops_v2<gridwise_gemm,
FloatAB,
FloatAB,
FloatC,
remove_reference_t<AGlobalDesc>,
remove_reference_t<BGlobalDesc>,
remove_reference_t<CGlobalDesc>,
remove_reference_t<CBlockClusterDesc>>;
ave_time = launch_and_time_kernel(kernel,
nrepeat,
dim3(GridSize),
dim3(BlockSize),
0,
0,
p_a_global,
p_b_global,
p_c_global,
a_k_m_global_desc,
b_k_n_global_desc,
c_m0_m1_n0_n1_global_desc,
c_block_cluster_desc);
return ave_time;
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
DeviceMem a_k_m_global_desc_device_buf(sizeof(AGlobalDesc));
DeviceMem b_k_n_global_desc_device_buf(sizeof(BGlobalDesc));
DeviceMem c_m0_m1_n0_n1_global_desc_device_buf(sizeof(CGlobalDesc));
DeviceMem c_block_cluster_desc_device_buf(sizeof(c_block_cluster_desc));
a_k_m_global_desc_device_buf.ToDevice(&a_k_m_global_desc);
b_k_n_global_desc_device_buf.ToDevice(&b_k_n_global_desc);
c_m0_m1_n0_n1_global_desc_device_buf.ToDevice(&c_m0_m1_n0_n1_global_desc);
c_block_cluster_desc_device_buf.ToDevice(&c_block_cluster_desc);
float ave_time = 0;
const auto kernel = kernel_dynamic_gemm_xdlops_v1<gridwise_gemm,
FloatAB,
FloatAB,
FloatC,
remove_reference_t<AGlobalDesc>,
remove_reference_t<BGlobalDesc>,
remove_reference_t<CGlobalDesc>,
remove_reference_t<CBlockClusterDesc>>;
ave_time = launch_and_time_kernel(
kernel,
nrepeat,
dim3(GridSize),
dim3(BlockSize),
0,
0,
p_a_global,
p_b_global,
p_c_global,
(void __CONSTANT__*)a_k_m_global_desc_device_buf.GetDeviceBuffer(),
(void __CONSTANT__*)b_k_n_global_desc_device_buf.GetDeviceBuffer(),
(void __CONSTANT__*)c_m0_m1_n0_n1_global_desc_device_buf.GetDeviceBuffer(),
(void __CONSTANT__*)c_block_cluster_desc_device_buf.GetDeviceBuffer());
return ave_time;
#endif
}
} // namespace ck
#endif
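
The two preprocessor branches above differ only in how the trivially copyable descriptor objects reach the kernel: passed by value as kernel arguments, or staged once in device memory and read back through a void pointer (which is what DeviceMem::ToDevice plus the reinterpret_cast in the kernel accomplish). A minimal HIP sketch of the second pattern, independent of the CK types and the __CONSTANT__ address-space qualifier; the Desc struct and names here are hypothetical stand-ins:

#include <hip/hip_runtime.h>
#include <cstdio>

struct Desc // stand-in for a tensor descriptor; must be trivially copyable
{
    int lengths[3];
};

__global__ void print_desc(const void* p_desc)
{
    // reinterpret the raw device pointer as the concrete descriptor type,
    // mirroring the void-pointer kernel variant above
    const Desc d = *reinterpret_cast<const Desc*>(p_desc);

    if(threadIdx.x == 0 && blockIdx.x == 0)
        printf("desc = {%d, %d, %d}\n", d.lengths[0], d.lengths[1], d.lengths[2]);
}

int main()
{
    Desc h_desc{{4, 64, 1}};

    void* d_desc = nullptr;
    hipMalloc(&d_desc, sizeof(Desc));                                // DeviceMem(sizeof(Desc))
    hipMemcpy(d_desc, &h_desc, sizeof(Desc), hipMemcpyHostToDevice); // ToDevice(&h_desc)

    hipLaunchKernelGGL(print_desc, dim3(1), dim3(64), 0, 0, (const void*)d_desc);
    hipDeviceSynchronize();

    hipFree(d_desc);
    return 0;
}
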
#ifndef CK_DRIVER_DYNAMIC_GEMM_XDLOPS_V2R2
#define CK_DRIVER_DYNAMIC_GEMM_XDLOPS_V2R2
#include "common_header.hpp"
#include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp"
#include "gridwise_dynamic_gemm_xdlops_v2r2.hpp"
namespace ck {
template <index_t BlockSize,
typename FloatAB,
typename FloatAcc,
typename FloatC,
InMemoryDataOperation CGlobalMemoryDataOperation,
typename AK0MK1GridDesc,
typename BK0NK1GridDesc,
typename CMNGridDesc,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t MPerWave,
index_t NPerWave,
index_t MRepeat,
index_t NRepeat,
typename ABlockTransferThreadSliceLengths_K0_M_K1,
typename ABlockTransferThreadClusterLengths_K0_M_K1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_K1,
bool AThreadTransferSrcResetCoordinateAfterRun,
typename BBlockTransferThreadSliceLengths_K0_N_K1,
typename BBlockTransferThreadClusterLengths_K0_N_K1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_K1,
bool BThreadTransferSrcResetCoordinateAfterRun,
typename CThreadTransferSrcDstAccessOrder,
index_t CThreadTransferSrcDstVectorDim,
index_t CThreadTransferDstScalarPerVector,
typename AGridIteratorHacks,
typename BGridIteratorHacks,
typename CGridIteratorHacks,
typename AGridMoveSliceWindowIteratorHacks,
typename BGridMoveSliceWindowIteratorHacks>
__host__ float driver_dynamic_gemm_xdlops_v2r2(const FloatAB* p_a_grid,
const FloatAB* p_b_grid,
FloatC* p_c_grid,
const AK0MK1GridDesc& a_k0_m_k1_grid_desc,
const BK0NK1GridDesc& b_k0_n_k1_grid_desc,
const CMNGridDesc& c_m_n_grid_desc,
AGridIteratorHacks,
BGridIteratorHacks,
CGridIteratorHacks,
AGridMoveSliceWindowIteratorHacks,
BGridMoveSliceWindowIteratorHacks,
index_t nrepeat)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto I4 = Number<4>{};
constexpr auto I5 = Number<5>{};
using GridwiseGemm =
GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r2<BlockSize,
FloatAB,
FloatAcc,
FloatC,
CGlobalMemoryDataOperation,
AK0MK1GridDesc,
BK0NK1GridDesc,
CMNGridDesc,
MPerBlock,
NPerBlock,
KPerBlock,
MPerWave,
NPerWave,
MRepeat,
NRepeat,
ABlockTransferThreadSliceLengths_K0_M_K1,
ABlockTransferThreadClusterLengths_K0_M_K1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_K1,
AThreadTransferSrcResetCoordinateAfterRun,
BBlockTransferThreadSliceLengths_K0_N_K1,
BBlockTransferThreadClusterLengths_K0_N_K1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_K1,
BThreadTransferSrcResetCoordinateAfterRun,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
AGridIteratorHacks,
BGridIteratorHacks,
CGridIteratorHacks,
AGridMoveSliceWindowIteratorHacks,
BGridMoveSliceWindowIteratorHacks>;
{
std::cout << "a_k0_m_k1_grid_desc{" << a_k0_m_k1_grid_desc.GetLength(I0) << ", "
<< a_k0_m_k1_grid_desc.GetLength(I1) << ", " << a_k0_m_k1_grid_desc.GetLength(I2)
<< "}" << std::endl;
std::cout << "b_k0_n_k1_grid_desc{" << b_k0_n_k1_grid_desc.GetLength(I0) << ", "
<< b_k0_n_k1_grid_desc.GetLength(I1) << ", " << b_k0_n_k1_grid_desc.GetLength(I2)
<< "}" << std::endl;
std::cout << "c_m_n_grid_desc{ " << c_m_n_grid_desc.GetLength(I0) << ", "
<< c_m_n_grid_desc.GetLength(I1) << "}" << std::endl;
}
if(!GridwiseGemm::CheckValidity(a_k0_m_k1_grid_desc, b_k0_n_k1_grid_desc, c_m_n_grid_desc))
{
throw std::runtime_error(
"wrong! GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r2 has invalid setting");
}
const auto c_m0_m1_m2_n_grid_desc = GridwiseGemm::MakeCM0M1M2NGridDescriptor(c_m_n_grid_desc);
using CM0M1M2NGridDesc = decltype(c_m0_m1_m2_n_grid_desc);
const auto c_block_cluster_adaptor = GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc);
using CBlockClusterAdaptor = decltype(c_block_cluster_adaptor);
const index_t grid_size = GridwiseGemm::CalculateGridSize(c_m_n_grid_desc);
const auto kernel = kernel_dynamic_gemm_xdlops_v2r2<GridwiseGemm,
FloatAB,
FloatC,
remove_reference_t<AK0MK1GridDesc>,
remove_reference_t<BK0NK1GridDesc>,
remove_reference_t<CM0M1M2NGridDesc>,
remove_reference_t<CBlockClusterAdaptor>>;
float ave_time = launch_and_time_kernel(kernel,
nrepeat,
dim3(grid_size),
dim3(BlockSize),
0,
0,
p_a_grid,
p_b_grid,
p_c_grid,
a_k0_m_k1_grid_desc,
b_k0_n_k1_grid_desc,
c_m0_m1_m2_n_grid_desc,
c_block_cluster_adaptor);
return ave_time;
}
} // namespace ck
#endif
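
launch_and_time_kernel above returns an average time over nrepeat launches (the commit message mentions changing the olc timer). A generic HIP-event sketch of that measurement pattern, written as an illustration under assumptions (one warm-up launch, event-based timing) rather than as CK's actual implementation:

#include <hip/hip_runtime.h>

// Launch `kernel` nrepeat times and return the mean time per launch in ms.
template <typename Kernel, typename... Args>
float time_kernel_avg_ms(Kernel kernel, int nrepeat, dim3 grid, dim3 block, Args... args)
{
    hipEvent_t start, stop;
    hipEventCreate(&start);
    hipEventCreate(&stop);

    // warm-up launch so one-time compilation/caching effects are not measured
    hipLaunchKernelGGL(kernel, grid, block, 0, 0, args...);

    hipEventRecord(start, nullptr);
    for(int i = 0; i < nrepeat; ++i)
        hipLaunchKernelGGL(kernel, grid, block, 0, 0, args...);
    hipEventRecord(stop, nullptr);
    hipEventSynchronize(stop);

    float total_ms = 0;
    hipEventElapsedTime(&total_ms, start, stop);

    hipEventDestroy(start);
    hipEventDestroy(stop);

    return total_ms / nrepeat;
}
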
@@ -21,6 +21,7 @@ template <index_t BlockSize,
          index_t KPerBlock,
          index_t MPerWave,
          index_t NPerWave,
+         index_t K1,
          index_t MRepeat,
          index_t NRepeat,
          typename ABlockTransferThreadSliceLengths_K0_M_K1,
@@ -83,6 +84,7 @@ __host__ float driver_dynamic_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
                                                KPerBlock,
                                                MPerWave,
                                                NPerWave,
+                                               K1,
                                                MRepeat,
                                                NRepeat,
                                                ABlockTransferThreadSliceLengths_K0_M_K1,
@@ -148,6 +150,7 @@ __host__ float driver_dynamic_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
                                        remove_reference_t<CM0M1M2NGridDesc>,
                                        remove_reference_t<CBlockClusterAdaptor>>;

+#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
    float ave_time = launch_and_time_kernel(kernel,
                                            nrepeat,
                                            dim3(grid_size),
@@ -162,6 +165,32 @@ __host__ float driver_dynamic_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
                                            c_m0_m1_m2_n_grid_desc,
                                            c_block_cluster_adaptor);
+#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
+    DeviceMem a_k0_m_k1_grid_desc_dev_buf(sizeof(AK0MK1GridDesc));
+    DeviceMem b_k0_n_k1_grid_desc_dev_buf(sizeof(BK0NK1GridDesc));
+    DeviceMem c_m0_m1_m2_n_grid_desc_dev_buf(sizeof(CM0M1M2NGridDesc));
+    DeviceMem c_block_cluster_adaptor_dev_buf(sizeof(CBlockClusterAdaptor));
+
+    a_k0_m_k1_grid_desc_dev_buf.ToDevice(&a_k0_m_k1_grid_desc);
+    b_k0_n_k1_grid_desc_dev_buf.ToDevice(&b_k0_n_k1_grid_desc);
+    c_m0_m1_m2_n_grid_desc_dev_buf.ToDevice(&c_m0_m1_m2_n_grid_desc);
+    c_block_cluster_adaptor_dev_buf.ToDevice(&c_block_cluster_adaptor);
+
+    float ave_time = launch_and_time_kernel(
+        kernel,
+        nrepeat,
+        dim3(grid_size),
+        dim3(BlockSize),
+        0,
+        0,
+        p_a_grid,
+        p_b_grid,
+        p_c_grid,
+        (void __CONSTANT__*)a_k0_m_k1_grid_desc_dev_buf.GetDeviceBuffer(),
+        (void __CONSTANT__*)b_k0_n_k1_grid_desc_dev_buf.GetDeviceBuffer(),
+        (void __CONSTANT__*)c_m0_m1_m2_n_grid_desc_dev_buf.GetDeviceBuffer(),
+        (void __CONSTANT__*)c_block_cluster_adaptor_dev_buf.GetDeviceBuffer());
+#endif
    return ave_time;
}
...
#ifndef CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4R2_NHWC_KYXC_NHWK_HPP
#define CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4R2_NHWC_KYXC_NHWK_HPP
#include "common_header.hpp"
#include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp"
namespace ck {
// GemmM = K
// GemmN = N * Ho * Wo
// GemmK = C * Y * X
template <typename... Wei,
typename... In,
typename... Out,
typename ConvStrides,
typename ConvDilations,
typename InLeftPads,
typename InRightPads,
index_t GemmK1Value>
__host__ __device__ constexpr auto
transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk_pad(
const DynamicTensorDescriptor<Wei...>& wei_k_y_x_c_grid_desc,
const DynamicTensorDescriptor<In...>& in_n_hi_wi_c_grid_desc,
const DynamicTensorDescriptor<Out...>& out_n_ho_wo_k_grid_desc,
const ConvStrides& conv_strides,
const ConvDilations& conv_dilations,
const InLeftPads& in_left_pads,
const InRightPads& in_right_pads,
Number<GemmK1Value>)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto GemmK1 = Number<GemmK1Value>{};
const auto N = in_n_hi_wi_c_grid_desc.GetLength(I0);
const auto C = in_n_hi_wi_c_grid_desc.GetLength(I3);
const auto K = out_n_ho_wo_k_grid_desc.GetLength(I3);
const auto Hi = in_n_hi_wi_c_grid_desc.GetLength(I1);
const auto Wi = in_n_hi_wi_c_grid_desc.GetLength(I2);
const auto Ho = out_n_ho_wo_k_grid_desc.GetLength(I1);
const auto Wo = out_n_ho_wo_k_grid_desc.GetLength(I2);
const auto Y = wei_k_y_x_c_grid_desc.GetLength(I1);
const auto X = wei_k_y_x_c_grid_desc.GetLength(I2);
const auto ConvStrideH = conv_strides[I0];
const auto ConvStrideW = conv_strides[I1];
const auto ConvDilationH = conv_dilations[I0];
const auto ConvDilationW = conv_dilations[I1];
const auto InLeftPadH = in_left_pads[I0];
const auto InLeftPadW = in_left_pads[I1];
const auto InRightPadH = in_right_pads[I0];
const auto InRightPadW = in_right_pads[I1];
const auto GemmM = K;
const auto GemmN = N * Ho * Wo;
const auto GemmK = C * Y * X;
const auto GemmK0 = GemmK / GemmK1;
// weight tensor
const auto wei_gemmk_gemmm_grid_desc = transform_dynamic_tensor_descriptor(
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, Y * X * C)),
make_tuple(make_pass_through_transform(K), make_pass_through_transform(Y * X * C)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<1>{}, Sequence<0>{}));
const auto wei_gemmk0_gemmm_gemmk1_grid_desc = transform_dynamic_tensor_descriptor(
wei_gemmk_gemmm_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)),
make_pass_through_transform(GemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// input tensor
const auto in_n_hip_wip_c_grid_desc = transform_dynamic_tensor_descriptor(
in_n_hi_wi_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_pad_transform(Hi, InLeftPadH, InRightPadH),
make_pad_transform(Wi, InLeftPadW, InRightPadW),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto in_n_y_ho_x_wo_c_grid_desc = transform_dynamic_tensor_descriptor(
in_n_hip_wip_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
const auto in_gemmk_gemmn_grid_desc =
transform_dynamic_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(Y, X, C)),
make_merge_transform(make_tuple(N, Ho, Wo))),
make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto in_gemmk0_gemmn_gemmk1_grid_desc = transform_dynamic_tensor_descriptor(
in_gemmk_gemmn_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)),
make_pass_through_transform(GemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// output tensor
const auto out_gemmm_gemmn_grid_desc = transform_dynamic_tensor_descriptor(
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N * Ho * Wo, K)),
make_tuple(make_pass_through_transform(N * Ho * Wo), make_pass_through_transform(K)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<1>{}, Sequence<0>{}));
return make_tuple(wei_gemmk0_gemmm_gemmk1_grid_desc,
in_gemmk0_gemmn_gemmk1_grid_desc,
out_gemmm_gemmn_grid_desc);
}
} // namespace ck
#endif
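
The unmerge transform used above splits the flat GemmK index into (GemmK0, GemmK1) with GemmK1 the fastest-varying component, which is what makes K1-wide vectorized loads possible. A small standalone illustration of that index mapping, not CK code:

#include <cassert>
#include <cstdio>

int main()
{
    const int GemmK1 = 4;  // the KPack/K1 value
    const int GemmK  = 24; // must be divisible by GemmK1
    assert(GemmK % GemmK1 == 0);

    for(int k = 0; k < GemmK; ++k)
    {
        const int k0 = k / GemmK1; // slow dimension (GemmK0)
        const int k1 = k % GemmK1; // fast dimension (GemmK1), contiguous in memory
        assert(k == k0 * GemmK1 + k1);

        if(k < 8)
            std::printf("k=%2d -> (k0=%d, k1=%d)\n", k, k0, k1);
    }
    return 0;
}
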
@@ -13,7 +13,7 @@ template <index_t BlockSize,
          class BBlockDesc,
          index_t MPerWave,
          index_t NPerWave,
-         index_t KPack>
+         index_t K1>
struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
{
@@ -32,7 +32,7 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
    static constexpr index_t N0 = BBlockDesc{}.GetLength(I1);
    static constexpr index_t N1 = BBlockDesc{}.GetLength(I2);

-    static constexpr auto xdlops_gemm = XdlopsGemm<FloatAB, MPerWave, NPerWave, KPack>{};
+    static constexpr auto xdlops_gemm = XdlopsGemm<FloatAB, MPerWave, NPerWave, K1>{};

    static constexpr index_t MWaves = M1 / MPerWave;
    static constexpr index_t NWaves = N1 / NPerWave;
@@ -119,18 +119,18 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
                      "wrong! K dimension not consistent");

        static_assert(ABlockDesc{}.GetLength(I3) == BBlockDesc{}.GetLength(I3),
-                      "wrong! KPack dimension not consistent");
+                      "wrong! K1 dimension not consistent");

        static_assert(BlockSize == MWaves * NWaves * WaveSize,
                      "BlockSize != MWaves * NWaves * WaveSize\n");

-        static_assert(KPack == BBlockDesc{}.GetLength(I3), "KPack is wrong!");
+        static_assert(K1 == BBlockDesc{}.GetLength(I3), "K1 is wrong!");

        constexpr index_t KPerBlock = ABlockDesc{}.GetLength(I0);

        static_assert(KPerBlock % xdlops_gemm.KPerXdlops == 0, "KPerBlock is wrong!");

-        static_assert(KPack % xdlops_gemm.mfma_type.k_base == 0, "KPack is wrong!");
+        static_assert(K1 % xdlops_gemm.mfma_type.k_base == 0, "K1 is wrong!");
    }

    template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
@@ -194,11 +194,11 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
    private:
    // A[K, M]
    static constexpr auto a_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2(
-        make_tuple(I1, Number<MRepeat>{}, I1, Number<KPack>{}));
+        make_tuple(I1, Number<MRepeat>{}, I1, Number<K1>{}));

    // B[K, N]
    static constexpr auto b_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2(
-        make_tuple(I1, Number<NRepeat>{}, I1, Number<KPack>{}));
+        make_tuple(I1, Number<NRepeat>{}, I1, Number<K1>{}));

    static constexpr auto c_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2(
        make_tuple(Number<MRepeat>{}, Number<NRepeat>{}));
@@ -207,20 +207,20 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
                                                                FloatAB,
                                                                ABlockDesc,
                                                                decltype(a_thread_desc_),
-                                                               Sequence<1, MRepeat, 1, KPack>,
+                                                               Sequence<1, MRepeat, 1, K1>,
                                                                Sequence<0, 1, 2, 3>,
                                                                3,
-                                                               KPack,
+                                                               K1,
                                                                1>;

    using BThreadCopy = ThreadwiseDynamicTensorSliceTransfer_v4<FloatAB,
                                                                FloatAB,
                                                                BBlockDesc,
                                                                decltype(b_thread_desc_),
-                                                               Sequence<1, NRepeat, 1, KPack>,
+                                                               Sequence<1, NRepeat, 1, K1>,
                                                                Sequence<0, 1, 2, 3>,
                                                                3,
-                                                               KPack,
+                                                               K1,
                                                                1>;

    AThreadCopy a_thread_copy_;
@@ -233,7 +233,7 @@ template <index_t BlockSize,
          class BBlockDesc,
          index_t MPerWave,
          index_t NPerWave,
-         index_t KPack>
+         index_t K1>
struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline
{
@@ -244,7 +244,7 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};

-    static constexpr auto xdlops_gemm = XdlopsGemm<float, MPerWave, NPerWave, KPack>{};
+    static constexpr auto xdlops_gemm = XdlopsGemm<float, MPerWave, NPerWave, K1>{};

    static constexpr index_t WaveSize = 64;
@@ -339,18 +339,18 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline
                      "wrong! K dimension not consistent");

        static_assert(ABlockDesc{}.GetLength(I3) == BBlockDesc{}.GetLength(I3),
-                      "wrong! KPack dimension not consistent");
+                      "wrong! K1 dimension not consistent");

        static_assert(BlockSize == MWaves * NWaves * WaveSize,
                      "BlockSize != MWaves * NWaves * WaveSize\n");

-        static_assert(KPack == BBlockDesc{}.GetLength(I3), "KPack is wrong!");
+        static_assert(K1 == BBlockDesc{}.GetLength(I3), "K1 is wrong!");

        constexpr index_t KPerBlock = ABlockDesc{}.GetLength(I0);

        static_assert(KPerBlock % xdlops_gemm.KPerXdlops == 0, "KPerBlock is wrong!");

-        static_assert(KPack % xdlops_gemm.mfma_type.k_base == 0, "KPack is wrong!");
+        static_assert(K1 % xdlops_gemm.mfma_type.k_base == 0, "K1 is wrong!");
    }

    template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
@@ -491,11 +491,11 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline
    private:
    // A[K, M]
    static constexpr auto a_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2(
-        make_tuple(I1, Number<MRepeat>{}, I1, Number<KPack>{}));
+        make_tuple(I1, Number<MRepeat>{}, I1, Number<K1>{}));

    // B[K, N]
    static constexpr auto b_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2(
-        make_tuple(I1, Number<NRepeat>{}, I1, Number<KPack>{}));
+        make_tuple(I1, Number<NRepeat>{}, I1, Number<K1>{}));

    static constexpr auto c_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2(
        make_tuple(Number<MRepeat>{}, Number<NRepeat>{}));
@@ -504,20 +504,20 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline
                                                                FloatAB,
                                                                ABlockDesc,
                                                                decltype(a_thread_desc_),
-                                                               Sequence<1, 1, 1, KPack>,
+                                                               Sequence<1, 1, 1, K1>,
                                                                Sequence<0, 1, 2, 3>,
                                                                3,
-                                                               1, // KPack,
+                                                               1, // K1,
                                                                1>;

    using BThreadCopy = ThreadwiseDynamicTensorSliceTransfer_v4<FloatAB,
                                                                FloatAB,
                                                                BBlockDesc,
                                                                decltype(b_thread_desc_),
-                                                               Sequence<1, 1, 1, KPack>,
+                                                               Sequence<1, 1, 1, K1>,
                                                                Sequence<0, 1, 2, 3>,
                                                                3,
-                                                               1, // KPack,
+                                                               1, // K1,
                                                                1>;

    AThreadCopy a_thread_copy_;
...
@@ -12,6 +12,7 @@

namespace ck {

+#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
template <typename GridwiseGemm,
          typename FloatAB,
          typename FloatC,
@@ -45,6 +46,50 @@ __global__ void
                      c_m0_m1_m2_n_grid_desc,
                      c_block_cluster_adaptor);
}
+#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
+template <typename GridwiseGemm,
+          typename FloatAB,
+          typename FloatC,
+          typename AK0MK1GridDesc,
+          typename BK0NK1GridDesc,
+          typename CM0M1M2NGridDesc,
+          typename CBlockClusterAdaptor>
+__global__ void
+#if CK_USE_LAUNCH_BOUNDS
+    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
+#endif
+        kernel_dynamic_gemm_xdlops_v2r3(const FloatAB* __restrict__ p_a_grid,
+                                        const FloatAB* __restrict__ p_b_grid,
+                                        FloatC* __restrict__ p_c_grid,
+                                        const void __CONSTANT__* p_a_k0_m_k1_grid_desc,
+                                        const void __CONSTANT__* p_b_k0_n_k1_grid_desc,
+                                        const void __CONSTANT__* p_c_m0_m1_m2_n_grid_desc,
+                                        const void __CONSTANT__* p_c_block_cluster_adaptor)
+{
+    constexpr index_t shared_block_size =
+        GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
+
+    const auto a_k0_m_k1_grid_desc =
+        *reinterpret_cast<const AK0MK1GridDesc*>((const void*)p_a_k0_m_k1_grid_desc);
+    const auto b_k0_n_k1_grid_desc =
+        *reinterpret_cast<const BK0NK1GridDesc*>((const void*)p_b_k0_n_k1_grid_desc);
+    const auto c_m0_m1_m2_n_grid_desc =
+        *reinterpret_cast<const CM0M1M2NGridDesc*>((const void*)p_c_m0_m1_m2_n_grid_desc);
+    const auto c_block_cluster_adaptor =
+        *reinterpret_cast<const CBlockClusterAdaptor*>((const void*)p_c_block_cluster_adaptor);
+
+    __shared__ FloatAB p_shared_block[shared_block_size];
+
+    GridwiseGemm::Run(p_a_grid,
+                      p_b_grid,
+                      p_c_grid,
+                      p_shared_block,
+                      a_k0_m_k1_grid_desc,
+                      b_k0_n_k1_grid_desc,
+                      c_m0_m1_m2_n_grid_desc,
+                      c_block_cluster_adaptor);
+}
+#endif

template <index_t BlockSize,
          typename FloatAB,
@@ -59,6 +104,7 @@ template <index_t BlockSize,
          index_t KPerBlock,
          index_t MPerWave,
          index_t NPerWave,
+         index_t K1Value,
          index_t MRepeat,
          index_t NRepeat,
          typename ABlockTransferThreadSliceLengths_K0_M_K1,
@@ -94,7 +140,7 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3
    static constexpr auto I3 = Number<3>{};

    // K1 should be Number<...>
-    static constexpr auto K1 = AK0MK1GridDesc{}.GetLength(I2);
+    static constexpr auto K1 = Number<K1Value>{};

    __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
    {
@@ -160,7 +206,7 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3
        const auto M = c_m_n_grid_desc.GetLength(I0);
        const auto N = c_m_n_grid_desc.GetLength(I1);

-        constexpr auto xdlops_gemm = XdlopsGemm<FloatAB, MPerWave, NPerWave, K1.value>{};
+        constexpr auto xdlops_gemm = XdlopsGemm<FloatAB, MPerWave, NPerWave, K1>{};

        constexpr auto CLayout = xdlops_gemm.GetCLayout();
@@ -267,7 +313,7 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3
        auto a_blockwise_copy =
            BlockwiseDynamicTensorSliceTransfer_v4<BlockSize,
                                                   InMemoryDataOperation::Set,
-                                                  Sequence<KPerBlock, MPerBlock, K1.value>,
+                                                  Sequence<KPerBlock, MPerBlock, K1>,
                                                   ABlockTransferThreadSliceLengths_K0_M_K1,
                                                   ABlockTransferThreadClusterLengths_K0_M_K1,
                                                   ABlockTransferThreadClusterArrangeOrder,
@@ -294,7 +340,7 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3
        auto b_blockwise_copy =
            BlockwiseDynamicTensorSliceTransfer_v4<BlockSize,
                                                   InMemoryDataOperation::Set,
-                                                  Sequence<KPerBlock, NPerBlock, K1.value>,
+                                                  Sequence<KPerBlock, NPerBlock, K1>,
                                                   BBlockTransferThreadSliceLengths_K0_N_K1,
                                                   BBlockTransferThreadClusterLengths_K0_N_K1,
                                                   BBlockTransferThreadClusterArrangeOrder,
@@ -354,7 +400,7 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3
                                                   decltype(b_k0_n0_n1_k1_block_desc),
                                                   MPerWave,
                                                   NPerWave,
-                                                  K1.value>{};
+                                                  K1>{};

        constexpr auto CLayout = blockwise_gemm.GetCLayout();
...
@@ -39,7 +39,6 @@

#if CK_USE_AMD_XDLOPS
#include "amd_xdlops.hpp"
-#include "amd_xdlops_inline_asm.hpp"
#endif

#endif
@@ -92,6 +92,8 @@ set(MCONV_KERNEL_INCLUDES

set(MCONV_KERNELS
    ../composable_kernel/src/kernel_wrapper/dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.cpp
    ../composable_kernel/src/kernel_wrapper/dynamic_convolution_forward_implicit_gemm_v4r5_nchw_kcyx_nkhw.cpp
+    ../composable_kernel/src/kernel_wrapper/dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.cpp
+    ../composable_kernel/src/kernel_wrapper/dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk.cpp
)

add_kernels("olCompiling/" "${MCONV_KERNELS}")
...
@@ -16,7 +16,7 @@
#include "device_dynamic_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk.hpp"

#define USE_DYNAMIC_MODE 1
-#define USE_CONV_BWD_V4R1_XDL_NHWC 0
+#define USE_CONV_BWD_V4R1_XDL_NHWC 1
#define USE_CONV_BWD_V4R1R2_XDL_NHWC 1

enum ConvBackwardDataAlgo
...
@@ -19,8 +19,6 @@
#include "device_dynamic_convolution_forward_implicit_gemm_v4r5r2_nchw_kcyx_nkhw.hpp"
#include "device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw.hpp"
#include "device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp"
-#include "device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nhwk.hpp"
-#include "device_dynamic_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nhwk.hpp"
#include "device_dynamic_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp"

#define USE_DYNAMIC_MODE 1
@@ -30,9 +28,7 @@
#define USE_CONV_FWD_V4R5_NCHW 0
#define USE_CONV_FWD_V4R5R2_NCHW 0
#define USE_CONV_FWD_V5R1_NCHW 0
-#define USE_CONV_FWD_V4R4_XDL_NCHW 1
-#define USE_CONV_FWD_V4R4R2_XDL_NHWC 0
-#define USE_CONV_FWD_V4R4R3_XDL_NHWC 0
+#define USE_CONV_FWD_V4R4R2_XDL_NCHW 1
#define USE_CONV_FWD_V4R4R4_XDL_NHWC 1

enum ConvForwardAlgo
@@ -43,10 +39,8 @@ enum ConvForwardAlgo
    V4R5NCHW,      // 3
    V4R5R2NCHW,    // 4
    V5R1NCHW,      // 5
-    V4R4XDLNCHW,   // 6
-    V4R4R2XDLNHWC, // 7
-    V4R4R3XDLNHWC, // 8
-    V4R4R4XDLNHWC  // 9
+    V4R4R2XDLNCHW, // 6
+    V4R4R4XDLNHWC  // 7
};

int main(int argc, char* argv[])
@@ -462,8 +456,8 @@ int main(int argc, char* argv[])
    }
#endif

-#if USE_CONV_FWD_V4R4_XDL_NCHW
-    if(algo == ConvForwardAlgo::V4R4XDLNCHW)
+#if USE_CONV_FWD_V4R4R2_XDL_NCHW
+    if(algo == ConvForwardAlgo::V4R4R2XDLNCHW)
    {
        if(layout != ConvTensorLayout::NCHW)
        {
@@ -489,60 +483,6 @@ int main(int argc, char* argv[])
    }
#endif

-#if USE_CONV_FWD_V4R4R2_XDL_NHWC
-    if(algo == ConvForwardAlgo::V4R4R2XDLNHWC)
-    {
-        if(layout != ConvTensorLayout::NHWC)
-        {
-            throw std::runtime_error("wrong! layout");
-        }
-
-        const auto tmp = f_make_for_device_nhwc();
-
-        device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nhwk<in_data_t,
-                                                                                      acc_data_t,
-                                                                                      out_data_t>(
-            tmp[I0],
-            tmp[I1],
-            tmp[I2],
-            tmp[I3],
-            tmp[I4],
-            tmp[I5],
-            tmp[I6],
-            in,
-            wei,
-            out_device,
-            nrepeat);
-    }
-#endif
-
-#if USE_CONV_FWD_V4R4R3_XDL_NHWC
-    if(algo == ConvForwardAlgo::V4R4R3XDLNHWC)
-    {
-        if(layout != ConvTensorLayout::NHWC)
-        {
-            throw std::runtime_error("wrong! layout");
-        }
-
-        const auto tmp = f_make_for_device_nhwc();
-
-        device_dynamic_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nhwk<in_data_t,
-                                                                                      acc_data_t,
-                                                                                      out_data_t>(
-            tmp[I0],
-            tmp[I1],
-            tmp[I2],
-            tmp[I3],
-            tmp[I4],
-            tmp[I5],
-            tmp[I6],
-            in,
-            wei,
-            out_device,
-            nrepeat);
-    }
-#endif
-
#if USE_CONV_FWD_V4R4R4_XDL_NHWC
    if(algo == ConvForwardAlgo::V4R4R4XDLNHWC)
    {
...
@@ -12,11 +12,17 @@
#include "conv_common.hpp"
#include "host_conv.hpp"
#include "device_tensor.hpp"

#include "olc_device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp"
#include "olc_device_dynamic_convolution_forward_implicit_gemm_v4r5_nchw_kcyx_nkhw.hpp"
+#include "olc_device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp"
+#include "olc_device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk.hpp"

#define USE_CONV_FWD_V4R4_NCHW 1
#define USE_CONV_FWD_V4R5_NCHW 1
+#define USE_CONV_FWD_V4R4_XDLOPS_NCHW 1
+#define USE_CONV_FWD_V4R4_XDLOPS_NHWC 1

#include "conv_tunables.hpp"
#include "handle.hpp"
@@ -24,10 +30,10 @@
enum ConvForwardAlgo
{
-    V4R4NCHW,
-    V4R4NHWC,
-    V4R5NCHW,
-    V5R1NCHW
+    V4R4NCHW,    // 0
+    V4R5NCHW,    // 1
+    V4R4XDLNCHW, // 2
+    V4R4XDLNHWC  // 3
};

int main(int argc, char* argv[])
@@ -105,14 +111,16 @@ int main(int argc, char* argv[])
    {
    case ConvTensorLayout::NCHW:
        // NCHW
        in_lengths_host[0]  = static_cast<std::size_t>(N);
        in_lengths_host[1]  = static_cast<std::size_t>(C);
        in_lengths_host[2]  = static_cast<std::size_t>(Hi);
        in_lengths_host[3]  = static_cast<std::size_t>(Wi);

        wei_lengths_host[0] = static_cast<std::size_t>(K);
        wei_lengths_host[1] = static_cast<std::size_t>(C);
        wei_lengths_host[2] = static_cast<std::size_t>(Y);
        wei_lengths_host[3] = static_cast<std::size_t>(X);

        out_lengths_host[0] = static_cast<std::size_t>(N);
        out_lengths_host[1] = static_cast<std::size_t>(K);
        out_lengths_host[2] = static_cast<std::size_t>(Ho);
@@ -120,14 +128,16 @@ int main(int argc, char* argv[])
        break;
    case ConvTensorLayout::NHWC:
        // NHWC
        in_lengths_host[0]  = static_cast<std::size_t>(N);
        in_lengths_host[1]  = static_cast<std::size_t>(Hi);
        in_lengths_host[2]  = static_cast<std::size_t>(Wi);
        in_lengths_host[3]  = static_cast<std::size_t>(C);

        wei_lengths_host[0] = static_cast<std::size_t>(K);
        wei_lengths_host[1] = static_cast<std::size_t>(Y);
        wei_lengths_host[2] = static_cast<std::size_t>(X);
        wei_lengths_host[3] = static_cast<std::size_t>(C);

        out_lengths_host[0] = static_cast<std::size_t>(N);
        out_lengths_host[1] = static_cast<std::size_t>(Ho);
        out_lengths_host[2] = static_cast<std::size_t>(Wo);
@@ -271,10 +281,80 @@ int main(int argc, char* argv[])
    }
#endif

+#if USE_CONV_FWD_V4R4_XDLOPS_NCHW
+    if(algo == ConvForwardAlgo::V4R4XDLNCHW)
+    {
+        if(layout != ConvTensorLayout::NCHW)
+        {
+            throw std::runtime_error("wrong! layout");
+        }
+
+        const auto tmp = f_make_for_device_nchw();
+
+        tunable_dyn_conv_fwd_v4r4_xdlops_nchw_kcyx_nkhw* tunable =
+            &default_tunable_dyn_conv_fwd_v4r4_xdlops_nchw_kcyx_nkhw;
+
+        device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw_olc<in_data_t,
+                                                                                        acc_data_t,
+                                                                                        out_data_t>(
+            handle,
+            tmp[I0],
+            tmp[I1],
+            tmp[I2],
+            conv_strides,
+            conv_dilations,
+            in_left_pads,
+            in_right_pads,
+            in,
+            wei,
+            out_device,
+            tunable,
+            nrepeat);
+    }
+#endif
+
+#if USE_CONV_FWD_V4R4_XDLOPS_NHWC
+    if(algo == ConvForwardAlgo::V4R4XDLNHWC)
+    {
+        if(layout != ConvTensorLayout::NHWC)
+        {
+            throw std::runtime_error("wrong! layout");
+        }
+
+        const auto tmp = f_make_for_device_nhwc();
+
+        tunable_dyn_conv_fwd_v4r4_xdlops_nhwc_kyxc_nhwk* tunable =
+            &default_tunable_dyn_conv_fwd_v4r4_xdlops_nhwc_kyxc_nhwk;
+
+        device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk_olc<in_data_t,
+                                                                                        acc_data_t,
+                                                                                        out_data_t>(
+            handle,
+            tmp[I0],
+            tmp[I1],
+            tmp[I2],
+            conv_strides,
+            conv_dilations,
+            in_left_pads,
+            in_right_pads,
+            in,
+            wei,
+            out_device,
+            tunable,
+            nrepeat);
+    }
+#endif
+
    if(do_verification)
    {
-        host_direct_convolution(
-            in, wei, out_host, conv_strides, conv_dilations, in_left_pads, in_right_pads);
+        host_direct_convolution(in,
+                                wei,
+                                out_host,
+                                make_tuple(conv_stride_h, conv_stride_w),
+                                make_tuple(conv_dilation_h, conv_dilation_w),
+                                make_tuple(in_left_pad_h, in_left_pad_w),
+                                make_tuple(in_right_pad_h, in_right_pad_w),
+                                layout);

        check_error(out_host, out_device);
...
@@ -50,6 +50,146 @@ static tunable_dyn_conv_fwd_v4r4_nchw_kcyx_nkhw default_tunable_dyn_conv_fwd_v4r
    {0, 1, 2}, {0, 1, 2}, 2, 1, 1, false, {3, 4, 5, 0, 1, 2},
    5, 1};

+struct tunable_dyn_conv_fwd_v4r4_xdlops_nchw_kcyx_nkhw
+{
+    ck::index_t BlockSize; // usually not tunable
+    ck::index_t MPerBlock;
+    ck::index_t NPerBlock;
+    ck::index_t KPerBlock;
+    ck::index_t MPerWave;
+    ck::index_t NPerWave;
+    ck::index_t K1;
+    ck::index_t MRepeat;
+    ck::index_t NRepeat;
+    std::array<ck::index_t, 3> ABlockTransferThreadSliceLengths_K0_M_K1;
+    std::array<ck::index_t, 3> ABlockTransferThreadClusterLengths_K0_M_K1;
+    std::array<ck::index_t, 3> ABlockTransferThreadClusterArrangeOrder;
+    std::array<ck::index_t, 3> ABlockTransferSrcAccessOrder;
+    ck::index_t ABlockTransferSrcVectorDim;
+    ck::index_t ABlockTransferSrcScalarPerVector;
+    ck::index_t ABlockTransferDstScalarPerVector_K1;
+    bool AThreadTransferSrcResetCoordinateAfterRun;
+    std::array<ck::index_t, 3> BBlockTransferThreadSliceLengths_K0_N_K1;
+    std::array<ck::index_t, 3> BBlockTransferThreadClusterLengths_K0_N_K1;
+    std::array<ck::index_t, 3> BBlockTransferThreadClusterArrangeOrder;
+    std::array<ck::index_t, 3> BBlockTransferSrcAccessOrder;
+    ck::index_t BBlockTransferSrcVectorDim;
+    ck::index_t BBlockTransferSrcScalarPerVector;
+    ck::index_t BBlockTransferDstScalarPerVector_K1;
+    bool BThreadTransferSrcResetCoordinateAfterRun;
+    std::array<ck::index_t, 8> CThreadTransferSrcDstAccessOrder;
+    ck::index_t CThreadTransferSrcDstVectorDim;
+    ck::index_t CThreadTransferDstScalarPerVector;
+};
+
+static tunable_dyn_conv_fwd_v4r4_xdlops_nchw_kcyx_nkhw
+    default_tunable_dyn_conv_fwd_v4r4_xdlops_nchw_kcyx_nkhw = {
+        256,        // BlockSize
+        128,        // MPerBlock
+        128,        // NPerBlock
+        4,          // KPerBlock
+        32,         // MPerWave
+        32,         // NPerWave
+        4,          // K1
+        2,          // MRepeat
+        2,          // NRepeat
+        {1, 2, 4},  // ABlockTransferThreadSliceLengths_K0_M_K1
+        {4, 64, 1}, // ABlockTransferThreadClusterLengths_K0_M_K1
+        {1, 0, 2},  // ABlockTransferThreadClusterArrangeOrder
+        {1, 0, 2},  // ABlockTransferSrcAccessOrder
+        2,          // ABlockTransferSrcVectorDim
+        1,          // ABlockTransferSrcScalarPerVector
+        4,          // ABlockTransferDstScalarPerVector_K1
+        false,      // AThreadTransferSrcResetCoordinateAfterRun
+        {1, 2, 4},  // BBlockTransferThreadSliceLengths_K0_N_K1
+        {4, 64, 1}, // BBlockTransferThreadClusterLengths_K0_N_K1
+        {0, 2, 1},  // BBlockTransferThreadClusterArrangeOrder
+        {1, 0, 2},  // BBlockTransferSrcAccessOrder
+        1,          // BBlockTransferSrcVectorDim
+        1,          // BBlockTransferSrcScalarPerVector
+        4,          // BBlockTransferDstScalarPerVector_K1
+        false,      // BThreadTransferSrcResetCoordinateAfterRun
+        {3, 0, 1, 2, 7, 5, 4, 6}, // CThreadTransferSrcDstAccessOrder
+        7,          // CThreadTransferSrcDstVectorDim
+        1           // CThreadTransferDstScalarPerVector
+};
+
+struct tunable_dyn_conv_fwd_v4r4_xdlops_nhwc_kyxc_nhwk
+{
+    ck::index_t BlockSize; // usually not tunable
+    ck::index_t MPerBlock;
+    ck::index_t NPerBlock;
+    ck::index_t KPerBlock;
+    ck::index_t MPerWave;
+    ck::index_t NPerWave;
+    ck::index_t K1;
+    ck::index_t MRepeat;
+    ck::index_t NRepeat;
+    std::array<ck::index_t, 3> ABlockTransferThreadSliceLengths_K0_M_K1;
+    std::array<ck::index_t, 3> ABlockTransferThreadClusterLengths_K0_M_K1;
+    std::array<ck::index_t, 3> ABlockTransferThreadClusterArrangeOrder;
+    std::array<ck::index_t, 3> ABlockTransferSrcAccessOrder;
+    ck::index_t ABlockTransferSrcVectorDim;
+    ck::index_t ABlockTransferSrcScalarPerVector;
+    ck::index_t ABlockTransferDstScalarPerVector_K1;
+    bool AThreadTransferSrcResetCoordinateAfterRun;
+    std::array<ck::index_t, 3> BBlockTransferThreadSliceLengths_K0_N_K1;
+    std::array<ck::index_t, 3> BBlockTransferThreadClusterLengths_K0_N_K1;
+    std::array<ck::index_t, 3> BBlockTransferThreadClusterArrangeOrder;
+    std::array<ck::index_t, 3> BBlockTransferSrcAccessOrder;
+    ck::index_t BBlockTransferSrcVectorDim;
+    ck::index_t BBlockTransferSrcScalarPerVector;
+    ck::index_t BBlockTransferDstScalarPerVector_K1;
+    bool BThreadTransferSrcResetCoordinateAfterRun;
+    std::array<ck::index_t, 8> CThreadTransferSrcDstAccessOrder;
+    ck::index_t CThreadTransferSrcDstVectorDim;
+    ck::index_t CThreadTransferDstScalarPerVector;
+};
+
+static tunable_dyn_conv_fwd_v4r4_xdlops_nhwc_kyxc_nhwk
+    default_tunable_dyn_conv_fwd_v4r4_xdlops_nhwc_kyxc_nhwk = {
+        256,        // BlockSize
+        128,        // MPerBlock
+        128,        // NPerBlock
+        4,          // KPerBlock
+        32,         // MPerWave
+        32,         // NPerWave
+        4,          // K1
+        2,          // MRepeat
+        2,          // NRepeat
+        {1, 2, 4},  // ABlockTransferThreadSliceLengths_K0_M_K1
+        {4, 64, 1}, // ABlockTransferThreadClusterLengths_K0_M_K1
+        {1, 0, 2},  // ABlockTransferThreadClusterArrangeOrder
+        {1, 0, 2},  // ABlockTransferSrcAccessOrder
+        2,          // ABlockTransferSrcVectorDim
+        4,          // ABlockTransferSrcScalarPerVector
+        4,          // ABlockTransferDstScalarPerVector_K1
+        false,      // AThreadTransferSrcResetCoordinateAfterRun
+        {1, 2, 4},  // BBlockTransferThreadSliceLengths_K0_N_K1
+        {4, 64, 1}, // BBlockTransferThreadClusterLengths_K0_N_K1
+        {1, 0, 2},  // BBlockTransferThreadClusterArrangeOrder
+        {1, 0, 2},  // BBlockTransferSrcAccessOrder
+        2,          // BBlockTransferSrcVectorDim
+        4,          // BBlockTransferSrcScalarPerVector
+        4,          // BBlockTransferDstScalarPerVector_K1
+        false,      // BThreadTransferSrcResetCoordinateAfterRun
+        {2, 3, 0, 1, 7, 5, 4, 6}, // CThreadTransferSrcDstAccessOrder
+        7,          // CThreadTransferSrcDstVectorDim
+        1           // CThreadTransferDstScalarPerVector
+};
+
struct tunable_dyn_conv_fwd_v4r5_nchw_kcyx_nkhw
{
    ck::index_t BlockSize;
...
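
The tunable structs above hold runtime values, while the gridwise GEMM consumes them as template arguments; the online-compilation path bridges the two by generating and compiling kernel source for the chosen combination at runtime. A hedged sketch of the general runtime-to-compile-time dispatch idea, with hypothetical names; the real olc path compiles source online rather than enumerating instantiations ahead of time:

#include <stdexcept>

// One ahead-of-time instantiation per supported tile shape.
template <int MPerBlock, int NPerBlock>
float run_gemm_instance(/* problem arguments elided */)
{
    // ... instantiate and launch the kernel specialized for this tile shape ...
    return 0.0f; // placeholder for the measured time
}

// Map a runtime tunable to one of the compile-time instantiations.
inline float dispatch_gemm(int m_per_block, int n_per_block)
{
    if(m_per_block == 128 && n_per_block == 128)
        return run_gemm_instance<128, 128>();
    if(m_per_block == 64 && n_per_block == 64)
        return run_gemm_instance<64, 64>();
    throw std::runtime_error("unsupported tunable combination");
}
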
@@ -273,6 +273,7 @@ void device_dynamic_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyx
            GemmKPerBlock,
            GemmMPerWave,
            GemmNPerWave,
+            GemmK1,
            MRepeat,
            NRepeat,
            GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1,
...