Unverified commit 01055d95 authored by Chao Liu, committed by GitHub

No raw index calculation (#31)



* Replace most raw index calculations with coordinate transformations
* Overhaul blockwise and threadwise GEMM
* Overhaul driver for gridwise GEMM kernel
Co-authored-by: Jing Zhang <jizhan@amd.com>
parent d075adf1
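
The transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_{pad,no_pad,1x1} helpers in this header only build the GEMM view of the convolution: GEMM-shaped tensor descriptors, a block cluster descriptor, and the iterator hacks the gridwise GEMM needs. Kernel selection and launch presumably move behind the driver included via driver_dynamic_gemm_v1.hpp. A minimal host-side sketch of the intended call pattern (the driver call at the end is a hypothetical placeholder, not an API defined in this change):

    // Build the GEMM view of a padded NCHW/KCYX/NKHW forward convolution.
    const auto gemm_descs =
        ck::transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_pad<GemmMPerBlock,
                                                                            GemmNPerBlock,
                                                                            GemmM1,
                                                                            GemmN1>(
            wei_k_c_y_x_desc, in_n_c_hi_wi_desc, out_n_k_ho_wo_desc,
            conv_strides, conv_dilations, in_left_pads, in_right_pads);

    // gemm_descs packs, in order: the A (weight) and B (input) GEMM descriptors, the blocked
    // C (output) descriptor, the block cluster descriptor, and the iterator /
    // move-slice-window hacks for A, B and C.
    // A gridwise GEMM driver would then unpack the tuple and launch the kernel, e.g.
    //   driver_dynamic_gemm_v1<...>(gemm_descs, p_wei_global, p_in_global, p_out_global);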
@@ -4,1532 +4,441 @@
#include "common_header.hpp" #include "common_header.hpp"
#include "dynamic_tensor_descriptor.hpp" #include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp" #include "dynamic_tensor_descriptor_helper.hpp"
#include "gridwise_dynamic_gemm.hpp" #include "driver_dynamic_gemm_v1.hpp"
#include "gridwise_operation_wrapper.hpp"
namespace ck { namespace ck {
// GemmM = K // GemmM = K
// GemmN = N * Ho * Wo // GemmN = N * Ho * Wo
// GemmK = C * Y * X // GemmK = C * Y * X
template <index_t BlockSize, template <index_t GemmMPerBlock,
typename FloatAB,
typename FloatAcc,
typename FloatC,
index_t GemmMPerBlock,
index_t GemmNPerBlock, index_t GemmNPerBlock,
index_t GemmKPerBlock, index_t GemmM1,
index_t GemmMPerThread, index_t GemmN1,
index_t GemmNPerThread, typename... Wei,
index_t GemmKPerThread, typename... In,
index_t GemmMLevel0Cluster, typename... Out,
index_t GemmNLevel0Cluster, typename ConvStrides,
index_t GemmMLevel1Cluster, typename ConvDilations,
index_t GemmNLevel1Cluster, typename InLeftPads,
typename GemmABlockTransferThreadSliceLengths_GemmK_GemmM, typename InRightPads>
typename GemmABlockTransferThreadClusterLengths_GemmK_GemmM, __host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_pad(
index_t GemmABlockTransferSrcScalarPerVector_GemmK, const DynamicTensorDescriptor<Wei...>& wei_k_c_y_x_global_desc,
index_t GemmABlockTransferDstScalarPerVector_GemmM, const DynamicTensorDescriptor<In...>& in_n_c_hi_wi_global_desc,
typename GemmBBlockTransferThreadSliceLengths_GemmK_GemmN, const DynamicTensorDescriptor<Out...>& out_n_k_ho_wo_global_desc,
typename GemmBBlockTransferThreadClusterLengths_GemmK_GemmN, const ConvStrides& conv_strides,
index_t GemmBBlockTransferSrcScalarPerVector_GemmN, const ConvDilations& conv_dilations,
index_t GemmBBlockTransferDstScalarPerVector_GemmN, const InLeftPads& in_left_pads,
index_t GemmCThreadTransferDstScalarPerVector_GemmN1> const InRightPads& in_right_pads)
struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_pad
{ {
template <typename... Wei, constexpr auto I0 = Number<0>{};
typename... In, constexpr auto I1 = Number<1>{};
typename... Out, constexpr auto I2 = Number<2>{};
typename ConvStrides, constexpr auto I3 = Number<3>{};
typename ConvDilations,
typename InLeftPads, const auto N = in_n_c_hi_wi_global_desc.GetLength(I0);
typename InRightPads> const auto C = in_n_c_hi_wi_global_desc.GetLength(I1);
__host__ void Run(const DynamicTensorDescriptor<Wei...>& wei_k_c_y_x_global_desc, const auto K = out_n_k_ho_wo_global_desc.GetLength(I1);
const DynamicTensorDescriptor<In...>& in_n_c_hi_wi_global_desc,
const DynamicTensorDescriptor<Out...>& out_n_k_ho_wo_global_desc, const auto Hi = in_n_c_hi_wi_global_desc.GetLength(I2);
const ConvStrides& conv_strides, const auto Wi = in_n_c_hi_wi_global_desc.GetLength(I3);
const ConvDilations& conv_dilations,
const InLeftPads& in_left_pads, const auto Ho = out_n_k_ho_wo_global_desc.GetLength(I2);
const InRightPads& in_right_pads, const auto Wo = out_n_k_ho_wo_global_desc.GetLength(I3);
const FloatAB* __restrict__ p_wei_global,
const FloatAB* __restrict__ p_in_global, const auto Y = wei_k_c_y_x_global_desc.GetLength(I2);
FloatC* __restrict__ p_out_global) const const auto X = wei_k_c_y_x_global_desc.GetLength(I3);
{
constexpr auto I0 = Number<0>{}; const auto ConvStrideH = conv_strides[I0];
constexpr auto I1 = Number<1>{}; const auto ConvStrideW = conv_strides[I1];
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{}; const auto ConvDilationH = conv_dilations[I0];
const auto ConvDilationW = conv_dilations[I1];
const auto N = in_n_c_hi_wi_global_desc.GetLength(I0);
const auto C = in_n_c_hi_wi_global_desc.GetLength(I1); const auto InLeftPadH = in_left_pads[I0];
const auto K = out_n_k_ho_wo_global_desc.GetLength(I1); const auto InLeftPadW = in_left_pads[I1];
const auto Hi = in_n_c_hi_wi_global_desc.GetLength(I2); const auto InRightPadH = in_right_pads[I0];
const auto Wi = in_n_c_hi_wi_global_desc.GetLength(I3); const auto InRightPadW = in_right_pads[I1];
const auto Ho = out_n_k_ho_wo_global_desc.GetLength(I2); // weight tensor
const auto Wo = out_n_k_ho_wo_global_desc.GetLength(I3); const auto wei_gemmk_gemmm_global_desc = transform_dynamic_tensor_descriptor(
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, C * Y * X)),
const auto Y = wei_k_c_y_x_global_desc.GetLength(I2); make_tuple(make_pass_through_transform(K), make_pass_through_transform(C * Y * X)),
const auto X = wei_k_c_y_x_global_desc.GetLength(I3); make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<1>{}, Sequence<0>{}));
const auto ConvStrideH = conv_strides[I0];
const auto ConvStrideW = conv_strides[I1]; // input tensor
const auto in_n_c_hip_wip_global_desc = transform_dynamic_tensor_descriptor(
const auto ConvDilationH = conv_dilations[I0]; in_n_c_hi_wi_global_desc,
const auto ConvDilationW = conv_dilations[I1]; make_tuple(make_pass_through_transform(N),
make_pass_through_transform(C),
const auto InLeftPadH = in_left_pads[I0]; make_pad_transform(Hi, InLeftPadH, InRightPadH),
const auto InLeftPadW = in_left_pads[I1]; make_pad_transform(Wi, InLeftPadW, InRightPadW)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
const auto InRightPadH = in_right_pads[I0]; make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto InRightPadW = in_right_pads[I1];
const auto in_n_c_y_ho_x_wo_global_desc = transform_dynamic_tensor_descriptor(
// weight tensor in_n_c_hip_wip_global_desc,
const auto wei_gemmk_gemmm_global_desc = transform_dynamic_tensor_descriptor( make_tuple(make_pass_through_transform(N),
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, C * Y * X)), make_pass_through_transform(C),
make_tuple(make_pass_through_transform(K), make_pass_through_transform(C * Y * X)), make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW))),
make_tuple(Sequence<1>{}, Sequence<0>{})); make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
// input tensor
const auto in_n_c_hip_wip_global_desc = transform_dynamic_tensor_descriptor( const auto in_gemmk_gemmn_global_desc =
in_n_c_hi_wi_global_desc, transform_dynamic_tensor_descriptor(in_n_c_y_ho_x_wo_global_desc,
make_tuple(make_pass_through_transform(N), make_tuple(make_merge_transform(make_tuple(C, Y, X)),
make_pass_through_transform(C), make_merge_transform(make_tuple(N, Ho, Wo))),
make_pad_transform(Hi, InLeftPadH, InRightPadH), make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
make_pad_transform(Wi, InLeftPadW, InRightPadW)), make_tuple(Sequence<0>{}, Sequence<1>{}));
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); // output tensor
const auto out_gemmm_gemmn_global_desc = transform_dynamic_tensor_descriptor(
const auto in_n_c_y_ho_x_wo_global_desc = transform_dynamic_tensor_descriptor( make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K, Ho * Wo)),
in_n_c_hip_wip_global_desc, make_tuple(make_pass_through_transform(K), make_merge_transform(make_tuple(N, Ho * Wo))),
make_tuple( make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
make_pass_through_transform(N), make_tuple(Sequence<0>{}, Sequence<1>{}));
make_pass_through_transform(C),
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), const auto GemmM = out_gemmm_gemmn_global_desc.GetLength(I0);
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW))), const auto GemmN = out_gemmm_gemmn_global_desc.GetLength(I1);
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), const auto GemmK = wei_gemmk_gemmm_global_desc.GetLength(I0);
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
assert(GemmM % GemmMPerBlock == 0 && GemmN % GemmNPerBlock == 0 && GemmK % GemmKPerBlock == 0);
const auto in_gemmk_gemmn_global_desc = transform_dynamic_tensor_descriptor(
in_n_c_y_ho_x_wo_global_desc, const auto GemmM0 = GemmM / Number<GemmM1>{};
make_tuple(make_merge_transform(make_tuple(C, Y, X)), const auto GemmN0 = GemmN / Number<GemmN1>{};
make_merge_transform(make_tuple(N, Ho, Wo))),
make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}), const auto out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc = transform_dynamic_tensor_descriptor(
make_tuple(Sequence<0>{}, Sequence<1>{})); out_gemmm_gemmn_global_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmM0, GemmM1)),
// output tensor make_unmerge_transform(make_tuple(GemmN0, GemmN1))),
const auto out_gemmm_gemmn_global_desc = transform_dynamic_tensor_descriptor( make_tuple(Sequence<0>{}, Sequence<1>{}),
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K, Ho * Wo)), make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
make_tuple(make_pass_through_transform(K),
make_merge_transform(make_tuple(N, Ho * Wo))), // out_gemm_block_cluster_desc
make_tuple(Sequence<1>{}, Sequence<0, 2>{}), const auto out_gemm_block_cluster_desc = make_cluster_descriptor_v2(
make_tuple(Sequence<0>{}, Sequence<1>{})); make_tuple(GemmM / Number<GemmMPerBlock>{}, GemmN / Number<GemmNPerBlock>{}));
const auto GemmM = out_gemmm_gemmn_global_desc.GetLength(I0); // hack to control index calculation when iterating over wei_gemmk_gemmm_global tensor
const auto GemmN = out_gemmm_gemmn_global_desc.GetLength(I1); constexpr auto wei_gemmk_gemmm_global_iterator_hacks =
const auto GemmK = wei_gemmk_gemmm_global_desc.GetLength(I0); make_tuple(make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}),
make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}));
if(!(GemmM % GemmMPerBlock == 0 && GemmN % GemmNPerBlock == 0 &&
GemmK % GemmKPerBlock == 0)) constexpr auto wei_gemmk_gemmm_global_move_slice_window_iterator_hacks = Sequence<0, 0, 0>{};
{
throw std::runtime_error("wrong! GEMM size no divisible"); // hack to control index calculation when iterating over in_gemmk_gemmn_global tensor
} constexpr auto in_gemmk_gemmn_global_iterator_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0>{},
constexpr auto GemmM1 = Number<GemmMPerThread * GemmMLevel0Cluster * GemmMLevel1Cluster>{}; Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1>{}),
constexpr auto GemmN1 = Number<GemmNPerThread * GemmNLevel0Cluster * GemmNLevel1Cluster>{}; make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2>{}));
const auto GemmM0 = GemmM / GemmM1;
const auto GemmN0 = GemmN / GemmN1; constexpr auto in_gemmk_gemmn_global_move_slice_window_iterator_hacks =
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2>{};
const auto out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc =
transform_dynamic_tensor_descriptor( // hack to control index calculation when iterating over out_gemmm0_gemmm1_gemmn0_gemmn1_global
out_gemmm_gemmn_global_desc, // tensor hack for NKHW format
make_tuple(make_unmerge_transform(make_tuple(GemmM0, GemmM1)), constexpr auto out_gemmm0_gemmm1_gemmn0_gemmn1_global_iterator_hacks =
make_unmerge_transform(make_tuple(GemmN0, GemmN1))), make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
make_tuple(Sequence<0>{}, Sequence<1>{}), Sequence<0, 0, 0, 0, 0>{},
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); Sequence<0, 0, 1, 0, 0>{},
Sequence<0, 0, 1, 0, 0>{}),
// hack to control index calculation when iterating over a_k_m_global tensor make_tuple(Sequence<0, 0, 0, 0, 0>{},
constexpr auto a_k_m_global_iterator_hacks = Sequence<0, 0, 0, 0, 0>{},
make_tuple(make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}), Sequence<0, 0, 2, 0, 0>{},
make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{})); Sequence<0, 0, 2, 0, 0>{}));
constexpr auto a_k_m_global_move_slice_window_iterator_hack = Sequence<0, 0, 0>{}; return make_tuple(wei_gemmk_gemmm_global_desc,
in_gemmk_gemmn_global_desc,
// hack to control index calculation when iterating over b_k_n_global tensor out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc,
constexpr auto b_k_n_global_iterator_hacks = out_gemm_block_cluster_desc,
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0>{}, wei_gemmk_gemmm_global_iterator_hacks,
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1>{}), in_gemmk_gemmn_global_iterator_hacks,
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0>{}, out_gemmm0_gemmm1_gemmn0_gemmn1_global_iterator_hacks,
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2>{})); wei_gemmk_gemmm_global_move_slice_window_iterator_hacks,
in_gemmk_gemmn_global_move_slice_window_iterator_hacks);
constexpr auto b_k_n_global_move_slice_window_iterator_hack = }
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2>{};
// hack to control index calculation when iterating over c_m0_m1_n0_n1_global tensor
// hack for NKHW format
constexpr auto c_m0_m1_n0_n1_global_tensor_iterator_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 1, 0, 0>{},
Sequence<0, 0, 1, 0, 0>{}),
make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 2, 0, 0>{},
Sequence<0, 0, 2, 0, 0>{}));
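
// A sketch, not part of this header: how a caller is expected to use the block cluster
// descriptor returned above. It covers (GemmM / GemmMPerBlock) x (GemmN / GemmNPerBlock)
// output tiles, so the launch grid presumably has one workgroup per tile, matching the
// grid-size computation the previous driver used:
//
//     const index_t grid_size = (GemmM / GemmMPerBlock) * (GemmN / GemmNPerBlock);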
// GEMM
using gridwise_gemm = GridwiseDynamicGemm_km_kn_m0m1n0n1_v1<
BlockSize,
FloatAB,
FloatAcc,
FloatC,
InMemoryDataOperation::Set,
decltype(wei_gemmk_gemmm_global_desc),
decltype(in_gemmk_gemmn_global_desc),
decltype(out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc),
GemmMPerBlock,
GemmNPerBlock,
GemmKPerBlock,
GemmMPerThread,
GemmNPerThread,
GemmKPerThread,
GemmMLevel0Cluster,
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmABlockTransferThreadSliceLengths_GemmK_GemmM,
GemmABlockTransferThreadClusterLengths_GemmK_GemmM,
Sequence<1, 0>,
Sequence<1, 0>,
0,
GemmABlockTransferSrcScalarPerVector_GemmK,
GemmABlockTransferDstScalarPerVector_GemmM,
false, // don't move back src coordinate after threadwise copy
GemmBBlockTransferThreadSliceLengths_GemmK_GemmN,
GemmBBlockTransferThreadClusterLengths_GemmK_GemmN,
Sequence<0, 1>,
Sequence<0, 1>,
1,
GemmBBlockTransferSrcScalarPerVector_GemmN,
GemmBBlockTransferDstScalarPerVector_GemmN,
false, // don't move back src coordinate after threadwise copy, which will be fused with
// MoveSrcSliceWindow() to save addr computation
Sequence<2, 3, 0, 1>,
3,
GemmCThreadTransferDstScalarPerVector_GemmN1,
decltype(a_k_m_global_iterator_hacks),
decltype(b_k_n_global_iterator_hacks),
decltype(c_m0_m1_n0_n1_global_tensor_iterator_hacks),
decltype(a_k_m_global_move_slice_window_iterator_hack),
decltype(b_k_n_global_move_slice_window_iterator_hack)>;
const auto GridSize = (GemmM / GemmMPerBlock) * (GemmN / GemmNPerBlock);
const bool has_main_k_block_loop = (GemmK + GemmKPerBlock) / (2 * GemmKPerBlock) > 1;
const bool has_double_tail_k_block_loop = (GemmK / GemmKPerBlock) % 2 == 0;
#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
index_t nrepeat = 100;
for(index_t i = 0; i < 5; ++i)
{
std::cout << "Start running " << nrepeat << " times..." << std::endl;
KernelTimer timer;
timer.Start();
for(index_t j = 0; j < nrepeat; ++j)
{
if(has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel =
run_gridwise_operation<gridwise_gemm,
decltype(wei_gemmk_gemmm_global_desc),
const FloatAB*,
decltype(in_gemmk_gemmn_global_desc),
const FloatAB*,
decltype(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc),
FloatC*,
integral_constant<bool, true>,
integral_constant<bool, true>>;
launch_kernel(kernel,
dim3(GridSize),
dim3(BlockSize),
0,
0,
wei_gemmk_gemmm_global_desc,
p_wei_global,
in_gemmk_gemmn_global_desc,
p_in_global,
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc,
p_out_global,
integral_constant<bool, true>{},
integral_constant<bool, true>{});
}
else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
{
const auto kernel =
run_gridwise_operation<gridwise_gemm,
decltype(wei_gemmk_gemmm_global_desc),
const FloatAB*,
decltype(in_gemmk_gemmn_global_desc),
const FloatAB*,
decltype(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc),
FloatC*,
integral_constant<bool, true>,
integral_constant<bool, false>>;
launch_kernel(kernel,
dim3(GridSize),
dim3(BlockSize),
0,
0,
wei_gemmk_gemmm_global_desc,
p_wei_global,
in_gemmk_gemmn_global_desc,
p_in_global,
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc,
p_out_global,
integral_constant<bool, true>{},
integral_constant<bool, false>{});
}
else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel =
run_gridwise_operation<gridwise_gemm,
decltype(wei_gemmk_gemmm_global_desc),
const FloatAB*,
decltype(in_gemmk_gemmn_global_desc),
const FloatAB*,
decltype(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc),
FloatC*,
integral_constant<bool, false>,
integral_constant<bool, true>>;
launch_kernel(kernel,
dim3(GridSize),
dim3(BlockSize),
0,
0,
wei_gemmk_gemmm_global_desc,
p_wei_global,
in_gemmk_gemmn_global_desc,
p_in_global,
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc,
p_out_global,
integral_constant<bool, false>{},
integral_constant<bool, true>{});
}
else
{
const auto kernel =
run_gridwise_operation<gridwise_gemm,
decltype(wei_gemmk_gemmm_global_desc),
const FloatAB*,
decltype(in_gemmk_gemmn_global_desc),
const FloatAB*,
decltype(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc),
FloatC*,
integral_constant<bool, false>,
integral_constant<bool, false>>;
launch_kernel(kernel,
dim3(GridSize),
dim3(BlockSize),
0,
0,
wei_gemmk_gemmm_global_desc,
p_wei_global,
in_gemmk_gemmn_global_desc,
p_in_global,
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc,
p_out_global,
integral_constant<bool, false>{},
integral_constant<bool, false>{});
}
}
timer.End();
float ave_time = timer.GetElapsedTime() / nrepeat;
float perf = (float)calculate_convolution_flops(in_n_c_hi_wi_global_desc,
wei_k_c_y_x_global_desc,
out_n_k_ho_wo_global_desc) /
(std::size_t(1000) * 1000 * 1000) / ave_time;
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
<< std::endl;
}
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
using ADesc = decltype(wei_gemmk_gemmm_global_desc);
using BDesc = decltype(in_gemmk_gemmn_global_desc);
using CDesc = decltype(out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc);
DeviceMem wei_gemmk_gemmm_global_desc_device_buf(sizeof(ADesc));
DeviceMem in_gemmk_gemmn_global_desc_device_buf(sizeof(BDesc));
DeviceMem out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf(sizeof(CDesc));
wei_gemmk_gemmm_global_desc_device_buf.ToDevice(&wei_gemmk_gemmm_global_desc);
in_gemmk_gemmn_global_desc_device_buf.ToDevice(&in_gemmk_gemmn_global_desc);
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf.ToDevice(
&out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc);
index_t nrepeat = 100;
for(index_t i = 0; i < 5; ++i)
{
std::cout << "Start running " << nrepeat << " times..." << std::endl;
KernelTimer timer;
timer.Start();
for(index_t j = 0; j < nrepeat; ++j)
{
if(has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel = run_gridwise_dynamic_gemm_v1<gridwise_gemm,
ADesc,
FloatAB,
BDesc,
FloatAB,
CDesc,
FloatC,
true,
true>;
launch_kernel(
kernel,
dim3(GridSize),
dim3(BlockSize),
0,
0,
(void __CONSTANT__*)
wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer(),
p_wei_global,
(void __CONSTANT__*)in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer(),
p_in_global,
(void __CONSTANT__*)
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
.GetDeviceBuffer(),
p_out_global);
}
else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
{
const auto kernel = run_gridwise_dynamic_gemm_v1<gridwise_gemm,
ADesc,
FloatAB,
BDesc,
FloatAB,
CDesc,
FloatC,
true,
false>;
launch_kernel(
kernel,
dim3(GridSize),
dim3(BlockSize),
0,
0,
(void __CONSTANT__*)
wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer(),
p_wei_global,
(void __CONSTANT__*)in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer(),
p_in_global,
(void __CONSTANT__*)
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
.GetDeviceBuffer(),
p_out_global);
}
else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel = run_gridwise_dynamic_gemm_v1<gridwise_gemm,
ADesc,
FloatAB,
BDesc,
FloatAB,
CDesc,
FloatC,
false,
true>;
launch_kernel(
kernel,
dim3(GridSize),
dim3(BlockSize),
0,
0,
(void __CONSTANT__*)
wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer(),
p_wei_global,
(void __CONSTANT__*)in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer(),
p_in_global,
(void __CONSTANT__*)
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
.GetDeviceBuffer(),
p_out_global);
}
else
{
const auto kernel = run_gridwise_dynamic_gemm_v1<gridwise_gemm,
ADesc,
FloatAB,
BDesc,
FloatAB,
CDesc,
FloatC,
false,
false>;
launch_kernel(
kernel,
dim3(GridSize),
dim3(BlockSize),
0,
0,
(void __CONSTANT__*)
wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer(),
p_wei_global,
(void __CONSTANT__*)in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer(),
p_in_global,
(void __CONSTANT__*)
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
.GetDeviceBuffer(),
p_out_global);
}
}
timer.End();
float ave_time = timer.GetElapsedTime() / nrepeat;
float perf = (float)calculate_convolution_flops(in_n_c_hi_wi_global_desc,
wei_k_c_y_x_global_desc,
out_n_k_ho_wo_global_desc) /
(std::size_t(1000) * 1000 * 1000) / ave_time;
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
<< std::endl;
}
#endif
}
};
// GemmM = K
// GemmN = N * Ho * Wo
// GemmK = C * Y * X
template <index_t GemmMPerBlock,
          index_t GemmNPerBlock,
          index_t GemmM1,
          index_t GemmN1,
          typename... Wei,
          typename... In,
          typename... Out,
          typename ConvStrides,
          typename ConvDilations,
          typename InLeftPads,
          typename InRightPads>
__host__ __device__ constexpr auto
transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_no_pad(
    const DynamicTensorDescriptor<Wei...>& wei_k_c_y_x_global_desc,
    const DynamicTensorDescriptor<In...>& in_n_c_hi_wi_global_desc,
    const DynamicTensorDescriptor<Out...>& out_n_k_ho_wo_global_desc,
    const ConvStrides& conv_strides,
    const ConvDilations& conv_dilations,
    const InLeftPads& in_left_pads,
    const InRightPads& in_right_pads)
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    const auto N = in_n_c_hi_wi_global_desc.GetLength(I0);
    const auto C = in_n_c_hi_wi_global_desc.GetLength(I1);
    const auto K = out_n_k_ho_wo_global_desc.GetLength(I1);

    const auto Hi = in_n_c_hi_wi_global_desc.GetLength(I2);
    const auto Wi = in_n_c_hi_wi_global_desc.GetLength(I3);

    const auto Ho = out_n_k_ho_wo_global_desc.GetLength(I2);
    const auto Wo = out_n_k_ho_wo_global_desc.GetLength(I3);

    const auto Y = wei_k_c_y_x_global_desc.GetLength(I2);
    const auto X = wei_k_c_y_x_global_desc.GetLength(I3);

    const auto ConvStrideH = conv_strides[I0];
    const auto ConvStrideW = conv_strides[I1];

    const auto ConvDilationH = conv_dilations[I0];
    const auto ConvDilationW = conv_dilations[I1];

    const auto InLeftPadH = in_left_pads[I0];
    const auto InLeftPadW = in_left_pads[I1];

    const auto InRightPadH = in_right_pads[I0];
    const auto InRightPadW = in_right_pads[I1];

    assert(InLeftPadH == 0 && InLeftPadW == 0 && InRightPadH == 0 && InRightPadW == 0);

    // weight tensor
    const auto wei_gemmk_gemmm_global_desc = transform_dynamic_tensor_descriptor(
        make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, C * Y * X)),
        make_tuple(make_pass_through_transform(K), make_pass_through_transform(C * Y * X)),
        make_tuple(Sequence<0>{}, Sequence<1>{}),
        make_tuple(Sequence<1>{}, Sequence<0>{}));

    // input tensor
    const auto in_n_c_y_ho_x_wo_global_desc = transform_dynamic_tensor_descriptor(
        in_n_c_hi_wi_global_desc,
        make_tuple(make_pass_through_transform(N),
                   make_pass_through_transform(C),
                   make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
                   make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW))),
        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));

    const auto in_gemmk_gemmn_global_desc =
        transform_dynamic_tensor_descriptor(in_n_c_y_ho_x_wo_global_desc,
                                            make_tuple(make_merge_transform(make_tuple(C, Y, X)),
                                                       make_merge_transform(make_tuple(N, Ho, Wo))),
                                            make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
                                            make_tuple(Sequence<0>{}, Sequence<1>{}));

    // output tensor
    const auto out_gemmm_gemmn_global_desc = transform_dynamic_tensor_descriptor(
        make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K, Ho * Wo)),
        make_tuple(make_pass_through_transform(K), make_merge_transform(make_tuple(N, Ho * Wo))),
        make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
        make_tuple(Sequence<0>{}, Sequence<1>{}));

    const auto GemmM = out_gemmm_gemmn_global_desc.GetLength(I0);
    const auto GemmN = out_gemmm_gemmn_global_desc.GetLength(I1);
    const auto GemmK = wei_gemmk_gemmm_global_desc.GetLength(I0);

    assert(GemmM % GemmMPerBlock == 0 && GemmN % GemmNPerBlock == 0 && GemmK % GemmKPerBlock == 0);

    const auto GemmM0 = GemmM / Number<GemmM1>{};
    const auto GemmN0 = GemmN / Number<GemmN1>{};

    const auto out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc = transform_dynamic_tensor_descriptor(
        out_gemmm_gemmn_global_desc,
        make_tuple(make_unmerge_transform(make_tuple(GemmM0, GemmM1)),
                   make_unmerge_transform(make_tuple(GemmN0, GemmN1))),
        make_tuple(Sequence<0>{}, Sequence<1>{}),
        make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));

    // out_gemm_block_cluster_desc
    const auto out_gemm_block_cluster_desc = make_cluster_descriptor_v2(
        make_tuple(GemmM / Number<GemmMPerBlock>{}, GemmN / Number<GemmNPerBlock>{}));

    // hack to control index calculation when iterating over a_k_m_global tensor
    constexpr auto wei_gemmk_gemmm_global_iterator_hacks =
        make_tuple(make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}),
                   make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}));

    constexpr auto wei_gemmk_gemmm_global_move_slice_window_iterator_hacks = Sequence<0, 0, 0>{};

    // hack to control index calculation when iterating over b_k_n_global tensor
    constexpr auto in_gemmk_gemmn_global_iterator_hacks =
        make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 1, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 1>{}),
                   make_tuple(Sequence<0, 0, 0, 0, 0, 2, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 2>{}));

    constexpr auto in_gemmk_gemmn_global_move_slice_window_iterator_hacks =
        Sequence<0, 0, 0, 0, 0, 1, 2>{};

    // hack to control index calculation when iterating over c_m0_m1_n0_n1_global tensor
    // hack for NKHW format
    constexpr auto out_gemmm0_gemmm1_gemmn0_gemmn1_global_iterator_hacks =
        make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
                              Sequence<0, 0, 0, 0, 0>{},
                              Sequence<0, 0, 1, 0, 0>{},
                              Sequence<0, 0, 1, 0, 0>{}),
                   make_tuple(Sequence<0, 0, 0, 0, 0>{},
                              Sequence<0, 0, 0, 0, 0>{},
                              Sequence<0, 0, 2, 0, 0>{},
                              Sequence<0, 0, 2, 0, 0>{}));

    return make_tuple(wei_gemmk_gemmm_global_desc,
                      in_gemmk_gemmn_global_desc,
                      out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc,
                      out_gemm_block_cluster_desc,
                      wei_gemmk_gemmm_global_iterator_hacks,
                      in_gemmk_gemmn_global_iterator_hacks,
                      out_gemmm0_gemmm1_gemmn0_gemmn1_global_iterator_hacks,
                      wei_gemmk_gemmm_global_move_slice_window_iterator_hacks,
                      in_gemmk_gemmn_global_move_slice_window_iterator_hacks);
}
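
// A worked example of the embed transform used above (a sketch, assuming the usual im2col
// indexing): with zero padding, the embed over (Y, Ho) with coefficients
// (ConvDilationH, ConvStrideH) maps a filter row y and an output row ho to the input row
//
//     hi = y * ConvDilationH + ho * ConvStrideH
//
// e.g. ConvDilationH = 1, ConvStrideH = 2, y = 2, ho = 3  ->  hi = 2 * 1 + 3 * 2 = 8,
// so the copy loops move coordinates through the descriptor instead of doing this raw
// index arithmetic themselves.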
// GEMM
using gridwise_gemm = GridwiseDynamicGemm_km_kn_m0m1n0n1_v1<
BlockSize,
FloatAB,
FloatAcc,
FloatC,
InMemoryDataOperation::Set,
decltype(wei_gemmk_gemmm_global_desc),
decltype(in_gemmk_gemmn_global_desc),
decltype(out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc),
GemmMPerBlock,
GemmNPerBlock,
GemmKPerBlock,
GemmMPerThread,
GemmNPerThread,
GemmKPerThread,
GemmMLevel0Cluster,
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmABlockTransferThreadSliceLengths_GemmK_GemmM,
GemmABlockTransferThreadClusterLengths_GemmK_GemmM,
Sequence<1, 0>,
Sequence<1, 0>,
0,
GemmABlockTransferSrcScalarPerVector_GemmK,
GemmABlockTransferDstScalarPerVector_GemmM,
false, // don't move back src coordinate after threadwise copy
GemmBBlockTransferThreadSliceLengths_GemmK_GemmN,
GemmBBlockTransferThreadClusterLengths_GemmK_GemmN,
Sequence<0, 1>,
Sequence<0, 1>,
1,
GemmBBlockTransferSrcScalarPerVector_GemmN,
GemmBBlockTransferDstScalarPerVector_GemmN,
false, // don't move back src coordinate after threadwise copy, which will be fused with
// MoveSrcSliceWindow() to save addr computation
Sequence<2, 3, 0, 1>,
3,
GemmCThreadTransferDstScalarPerVector_GemmN1,
decltype(a_k_m_global_iterator_hacks),
decltype(b_k_n_global_iterator_hacks),
decltype(c_m0_m1_n0_n1_global_tensor_iterator_hacks),
decltype(a_k_m_global_move_slice_window_iterator_hack),
decltype(b_k_n_global_move_slice_window_iterator_hack)>;
const auto GridSize = (GemmM / GemmMPerBlock) * (GemmN / GemmNPerBlock);
const bool has_main_k_block_loop = (GemmK + GemmKPerBlock) / (2 * GemmKPerBlock) > 1;
const bool has_double_tail_k_block_loop = (GemmK / GemmKPerBlock) % 2 == 0;
#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
index_t nrepeat = 100;
for(index_t i = 0; i < 5; ++i)
{
std::cout << "Start running " << nrepeat << " times..." << std::endl;
KernelTimer timer;
timer.Start();
for(index_t j = 0; j < nrepeat; ++j)
{
if(has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel =
run_gridwise_operation<gridwise_gemm,
decltype(wei_gemmk_gemmm_global_desc),
const FloatAB*,
decltype(in_gemmk_gemmn_global_desc),
const FloatAB*,
decltype(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc),
FloatC*,
integral_constant<bool, true>,
integral_constant<bool, true>>;
launch_kernel(kernel,
dim3(GridSize),
dim3(BlockSize),
0,
0,
wei_gemmk_gemmm_global_desc,
p_wei_global,
in_gemmk_gemmn_global_desc,
p_in_global,
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc,
p_out_global,
integral_constant<bool, true>{},
integral_constant<bool, true>{});
}
else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
{
const auto kernel =
run_gridwise_operation<gridwise_gemm,
decltype(wei_gemmk_gemmm_global_desc),
const FloatAB*,
decltype(in_gemmk_gemmn_global_desc),
const FloatAB*,
decltype(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc),
FloatC*,
integral_constant<bool, true>,
integral_constant<bool, false>>;
launch_kernel(kernel,
dim3(GridSize),
dim3(BlockSize),
0,
0,
wei_gemmk_gemmm_global_desc,
p_wei_global,
in_gemmk_gemmn_global_desc,
p_in_global,
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc,
p_out_global,
integral_constant<bool, true>{},
integral_constant<bool, false>{});
}
else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel =
run_gridwise_operation<gridwise_gemm,
decltype(wei_gemmk_gemmm_global_desc),
const FloatAB*,
decltype(in_gemmk_gemmn_global_desc),
const FloatAB*,
decltype(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc),
FloatC*,
integral_constant<bool, false>,
integral_constant<bool, true>>;
launch_kernel(kernel,
dim3(GridSize),
dim3(BlockSize),
0,
0,
wei_gemmk_gemmm_global_desc,
p_wei_global,
in_gemmk_gemmn_global_desc,
p_in_global,
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc,
p_out_global,
integral_constant<bool, false>{},
integral_constant<bool, true>{});
}
else
{
const auto kernel =
run_gridwise_operation<gridwise_gemm,
decltype(wei_gemmk_gemmm_global_desc),
const FloatAB*,
decltype(in_gemmk_gemmn_global_desc),
const FloatAB*,
decltype(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc),
FloatC*,
integral_constant<bool, false>,
integral_constant<bool, false>>;
launch_kernel(kernel,
dim3(GridSize),
dim3(BlockSize),
0,
0,
wei_gemmk_gemmm_global_desc,
p_wei_global,
in_gemmk_gemmn_global_desc,
p_in_global,
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc,
p_out_global,
integral_constant<bool, false>{},
integral_constant<bool, false>{});
}
}
timer.End();
float ave_time = timer.GetElapsedTime() / nrepeat;
float perf = (float)calculate_convolution_flops(in_n_c_hi_wi_global_desc,
wei_k_c_y_x_global_desc,
out_n_k_ho_wo_global_desc) /
(std::size_t(1000) * 1000 * 1000) / ave_time;
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
<< std::endl;
}
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
using ADesc = decltype(wei_gemmk_gemmm_global_desc);
using BDesc = decltype(in_gemmk_gemmn_global_desc);
using CDesc = decltype(out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc);
DeviceMem wei_gemmk_gemmm_global_desc_device_buf(sizeof(ADesc));
DeviceMem in_gemmk_gemmn_global_desc_device_buf(sizeof(BDesc));
DeviceMem out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf(sizeof(CDesc));
wei_gemmk_gemmm_global_desc_device_buf.ToDevice(&wei_gemmk_gemmm_global_desc);
in_gemmk_gemmn_global_desc_device_buf.ToDevice(&in_gemmk_gemmn_global_desc);
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf.ToDevice(
&out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc);
index_t nrepeat = 100;
for(index_t i = 0; i < 5; ++i)
{
std::cout << "Start running " << nrepeat << " times..." << std::endl;
KernelTimer timer;
timer.Start();
for(index_t j = 0; j < nrepeat; ++j)
{
if(has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel = run_gridwise_dynamic_gemm_v1<gridwise_gemm,
ADesc,
FloatAB,
BDesc,
FloatAB,
CDesc,
FloatC,
true,
true>;
launch_kernel(
kernel,
dim3(GridSize),
dim3(BlockSize),
0,
0,
(void __CONSTANT__*)
wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer(),
p_wei_global,
(void __CONSTANT__*)in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer(),
p_in_global,
(void __CONSTANT__*)
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
.GetDeviceBuffer(),
p_out_global);
}
else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
{
const auto kernel = run_gridwise_dynamic_gemm_v1<gridwise_gemm,
ADesc,
FloatAB,
BDesc,
FloatAB,
CDesc,
FloatC,
true,
false>;
launch_kernel(
kernel,
dim3(GridSize),
dim3(BlockSize),
0,
0,
(void __CONSTANT__*)
wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer(),
p_wei_global,
(void __CONSTANT__*)in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer(),
p_in_global,
(void __CONSTANT__*)
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
.GetDeviceBuffer(),
p_out_global);
}
else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel = run_gridwise_dynamic_gemm_v1<gridwise_gemm,
ADesc,
FloatAB,
BDesc,
FloatAB,
CDesc,
FloatC,
false,
true>;
launch_kernel(
kernel,
dim3(GridSize),
dim3(BlockSize),
0,
0,
(void __CONSTANT__*)
wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer(),
p_wei_global,
(void __CONSTANT__*)in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer(),
p_in_global,
(void __CONSTANT__*)
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
.GetDeviceBuffer(),
p_out_global);
}
else
{
const auto kernel = run_gridwise_dynamic_gemm_v1<gridwise_gemm,
ADesc,
FloatAB,
BDesc,
FloatAB,
CDesc,
FloatC,
false,
false>;
launch_kernel(
kernel,
dim3(GridSize),
dim3(BlockSize),
0,
0,
(void __CONSTANT__*)
wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer(),
p_wei_global,
(void __CONSTANT__*)in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer(),
p_in_global,
(void __CONSTANT__*)
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
.GetDeviceBuffer(),
p_out_global);
}
}
timer.End();
float ave_time = timer.GetElapsedTime() / nrepeat;
float perf = (float)calculate_convolution_flops(in_n_c_hi_wi_global_desc,
wei_k_c_y_x_global_desc,
out_n_k_ho_wo_global_desc) /
(std::size_t(1000) * 1000 * 1000) / ave_time;
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
<< std::endl;
}
#endif
}
};
// GemmM = K
// GemmN = N * Ho * Wo
// GemmK = C * Y * X
template <index_t GemmMPerBlock,
          index_t GemmNPerBlock,
          index_t GemmM1,
          index_t GemmN1,
          typename... Wei,
          typename... In,
          typename... Out,
          typename ConvStrides,
          typename ConvDilations,
          typename InLeftPads,
          typename InRightPads>
__host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_1x1(
    const DynamicTensorDescriptor<Wei...>& wei_k_c_y_x_global_desc,
    const DynamicTensorDescriptor<In...>& in_n_c_hi_wi_global_desc,
    const DynamicTensorDescriptor<Out...>& out_n_k_ho_wo_global_desc,
    const ConvStrides& conv_strides,
    const ConvDilations& conv_dilations,
    const InLeftPads& in_left_pads,
    const InRightPads& in_right_pads)
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    const auto N = in_n_c_hi_wi_global_desc.GetLength(I0);
    const auto C = in_n_c_hi_wi_global_desc.GetLength(I1);
    const auto K = out_n_k_ho_wo_global_desc.GetLength(I1);

    const auto Hi = in_n_c_hi_wi_global_desc.GetLength(I2);
    const auto Wi = in_n_c_hi_wi_global_desc.GetLength(I3);

    const auto Ho = out_n_k_ho_wo_global_desc.GetLength(I2);
    const auto Wo = out_n_k_ho_wo_global_desc.GetLength(I3);

    const auto Y = wei_k_c_y_x_global_desc.GetLength(I2);
    const auto X = wei_k_c_y_x_global_desc.GetLength(I3);

    const auto ConvStrideH = conv_strides[I0];
    const auto ConvStrideW = conv_strides[I1];

    const auto ConvDilationH = conv_dilations[I0];
    const auto ConvDilationW = conv_dilations[I1];

    const auto InLeftPadH = in_left_pads[I0];
    const auto InLeftPadW = in_left_pads[I1];

    const auto InRightPadH = in_right_pads[I0];
    const auto InRightPadW = in_right_pads[I1];

    assert(Y == 1 && X == 1 && ConvStrideH == 1 && ConvStrideW == 1 && ConvDilationH == 1 &&
           ConvDilationW == 1 && InLeftPadH == 0 && InLeftPadW == 0 && InRightPadH == 0 &&
           InRightPadW == 0);

    // weight tensor
    const auto wei_gemmk_gemmm_global_desc = transform_dynamic_tensor_descriptor(
        make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, C)),
        make_tuple(make_pass_through_transform(K), make_pass_through_transform(C)),
        make_tuple(Sequence<0>{}, Sequence<1>{}),
        make_tuple(Sequence<1>{}, Sequence<0>{}));

    // input tensor
    const auto in_gemmk_gemmn_global_desc = transform_dynamic_tensor_descriptor(
        in_n_c_hi_wi_global_desc,
        make_tuple(make_pass_through_transform(C), make_merge_transform(make_tuple(N, Ho, Wo))),
        make_tuple(Sequence<1>{}, Sequence<0, 2, 3>{}),
        make_tuple(Sequence<0>{}, Sequence<1>{}));

    // output tensor
    const auto out_gemmm_gemmn_global_desc = transform_dynamic_tensor_descriptor(
        make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K, Ho * Wo)),
        make_tuple(make_pass_through_transform(K), make_merge_transform(make_tuple(N, Ho * Wo))),
        make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
        make_tuple(Sequence<0>{}, Sequence<1>{}));

    const auto GemmM = out_gemmm_gemmn_global_desc.GetLength(I0);
    const auto GemmN = out_gemmm_gemmn_global_desc.GetLength(I1);
    const auto GemmK = wei_gemmk_gemmm_global_desc.GetLength(I0);

    assert(GemmM % GemmMPerBlock == 0 && GemmN % GemmNPerBlock == 0 && GemmK % GemmKPerBlock == 0);

    const auto GemmM0 = GemmM / Number<GemmM1>{};
    const auto GemmN0 = GemmN / Number<GemmN1>{};

    const auto out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc = transform_dynamic_tensor_descriptor(
        out_gemmm_gemmn_global_desc,
        make_tuple(make_unmerge_transform(make_tuple(GemmM0, GemmM1)),
                   make_unmerge_transform(make_tuple(GemmN0, GemmN1))),
        make_tuple(Sequence<0>{}, Sequence<1>{}),
        make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));

    // out_gemm_block_cluster_desc
    const auto out_gemm_block_cluster_desc = make_cluster_descriptor_v2(
        make_tuple(GemmM / Number<GemmMPerBlock>{}, GemmN / Number<GemmNPerBlock>{}));

    // hack to control index calculation when iterating over a_k_m_global tensor
    constexpr auto wei_gemmk_gemmm_global_iterator_hacks =
        make_tuple(make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}),
                   make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}));

    constexpr auto wei_gemmk_gemmm_global_move_slice_window_iterator_hacks = Sequence<0, 0, 0>{};

    // hack to control index calculation when iterating over b_k_n_global tensor
    constexpr auto in_gemmk_gemmn_global_iterator_hacks =
        make_tuple(make_tuple(Sequence<0, 1, 0>{}, Sequence<0, 0, 1>{}),
                   make_tuple(Sequence<0, 2, 0>{}, Sequence<0, 0, 2>{}));

    constexpr auto in_gemmk_gemmn_global_move_slice_window_iterator_hacks = Sequence<0, 1, 2>{};

    // hack to control index calculation when iterating over c_m0_m1_n0_n1_global tensor
    constexpr auto out_gemmm0_gemmm1_gemmn0_gemmn1_global_iterator_hacks =
        make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
                              Sequence<0, 0, 0, 0, 0>{},
                              Sequence<0, 0, 1, 0, 0>{},
                              Sequence<0, 0, 1, 0, 0>{}),
                   make_tuple(Sequence<0, 0, 0, 0, 0>{},
                              Sequence<0, 0, 0, 0, 0>{},
                              Sequence<0, 0, 2, 0, 0>{},
                              Sequence<0, 0, 2, 0, 0>{}));

    return make_tuple(wei_gemmk_gemmm_global_desc,
                      in_gemmk_gemmn_global_desc,
                      out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc,
                      out_gemm_block_cluster_desc,
                      wei_gemmk_gemmm_global_iterator_hacks,
                      in_gemmk_gemmn_global_iterator_hacks,
                      out_gemmm0_gemmm1_gemmn0_gemmn1_global_iterator_hacks,
                      wei_gemmk_gemmm_global_move_slice_window_iterator_hacks,
                      in_gemmk_gemmn_global_move_slice_window_iterator_hacks);
}
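
// A small worked example for the 1x1 case above (a sketch): with Y = X = 1, unit strides and
// dilations, and zero padding, Ho = Hi and Wo = Wi, so the NCHW input buffer is read directly
// as a GemmK x GemmN = C x (N * Ho * Wo) matrix with no pad or embed transforms in between.
// For N = 1, C = 8, Hi = Wi = 4: GemmK = 8, GemmN = 1 * 4 * 4 = 16, and GemmM = K.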
// hack to control index calculation when iterating over c_m0_m1_n0_n1_global tensor
constexpr auto c_m0_m1_n0_n1_global_tensor_iterator_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 1, 0, 0>{},
Sequence<0, 0, 1, 0, 0>{}),
make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 2, 0, 0>{},
Sequence<0, 0, 2, 0, 0>{}));
// GEMM
using gridwise_gemm = GridwiseDynamicGemm_km_kn_m0m1n0n1_v1<
BlockSize,
FloatAB,
FloatAcc,
FloatC,
InMemoryDataOperation::Set,
decltype(wei_gemmk_gemmm_global_desc),
decltype(in_gemmk_gemmn_global_desc),
decltype(out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc),
GemmMPerBlock,
GemmNPerBlock,
GemmKPerBlock,
GemmMPerThread,
GemmNPerThread,
GemmKPerThread,
GemmMLevel0Cluster,
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmABlockTransferThreadSliceLengths_GemmK_GemmM,
GemmABlockTransferThreadClusterLengths_GemmK_GemmM,
Sequence<1, 0>,
Sequence<1, 0>,
0,
GemmABlockTransferSrcScalarPerVector_GemmK,
GemmABlockTransferDstScalarPerVector_GemmM,
false, // don't move back src coordinate after threadwise copy
GemmBBlockTransferThreadSliceLengths_GemmK_GemmN,
GemmBBlockTransferThreadClusterLengths_GemmK_GemmN,
Sequence<0, 1>,
Sequence<0, 1>,
1,
GemmBBlockTransferSrcScalarPerVector_GemmN,
GemmBBlockTransferDstScalarPerVector_GemmN,
false, // don't move back src coordinate after threadwise copy, which will be fused with
// MoveSrcSliceWindow() to save addr computation
Sequence<2, 3, 0, 1>,
3,
GemmCThreadTransferDstScalarPerVector_GemmN1,
decltype(a_k_m_global_iterator_hacks),
decltype(b_k_n_global_iterator_hacks),
decltype(c_m0_m1_n0_n1_global_tensor_iterator_hacks),
decltype(a_k_m_global_move_slice_window_iterator_hack),
decltype(b_k_n_global_move_slice_window_iterator_hack)>;
const auto GridSize = (GemmM / GemmMPerBlock) * (GemmN / GemmNPerBlock);
const bool has_main_k_block_loop = (GemmK + GemmKPerBlock) / (2 * GemmKPerBlock) > 1;
const bool has_double_tail_k_block_loop = (GemmK / GemmKPerBlock) % 2 == 0;
#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
index_t nrepeat = 100;
for(index_t i = 0; i < 5; ++i)
{
std::cout << "Start running " << nrepeat << " times..." << std::endl;
KernelTimer timer;
timer.Start();
for(index_t j = 0; j < nrepeat; ++j)
{
if(has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel =
run_gridwise_operation<gridwise_gemm,
decltype(wei_gemmk_gemmm_global_desc),
const FloatAB*,
decltype(in_gemmk_gemmn_global_desc),
const FloatAB*,
decltype(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc),
FloatC*,
integral_constant<bool, true>,
integral_constant<bool, true>>;
launch_kernel(kernel,
dim3(GridSize),
dim3(BlockSize),
0,
0,
wei_gemmk_gemmm_global_desc,
p_wei_global,
in_gemmk_gemmn_global_desc,
p_in_global,
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc,
p_out_global,
integral_constant<bool, true>{},
integral_constant<bool, true>{});
}
else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
{
const auto kernel =
run_gridwise_operation<gridwise_gemm,
decltype(wei_gemmk_gemmm_global_desc),
const FloatAB*,
decltype(in_gemmk_gemmn_global_desc),
const FloatAB*,
decltype(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc),
FloatC*,
integral_constant<bool, true>,
integral_constant<bool, false>>;
launch_kernel(kernel,
dim3(GridSize),
dim3(BlockSize),
0,
0,
wei_gemmk_gemmm_global_desc,
p_wei_global,
in_gemmk_gemmn_global_desc,
p_in_global,
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc,
p_out_global,
integral_constant<bool, true>{},
integral_constant<bool, false>{});
}
else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel =
run_gridwise_operation<gridwise_gemm,
decltype(wei_gemmk_gemmm_global_desc),
const FloatAB*,
decltype(in_gemmk_gemmn_global_desc),
const FloatAB*,
decltype(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc),
FloatC*,
integral_constant<bool, false>,
integral_constant<bool, true>>;
launch_kernel(kernel,
dim3(GridSize),
dim3(BlockSize),
0,
0,
wei_gemmk_gemmm_global_desc,
p_wei_global,
in_gemmk_gemmn_global_desc,
p_in_global,
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc,
p_out_global,
integral_constant<bool, false>{},
integral_constant<bool, true>{});
}
else
{
const auto kernel =
run_gridwise_operation<gridwise_gemm,
decltype(wei_gemmk_gemmm_global_desc),
const FloatAB*,
decltype(in_gemmk_gemmn_global_desc),
const FloatAB*,
decltype(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc),
FloatC*,
integral_constant<bool, false>,
integral_constant<bool, false>>;
launch_kernel(kernel,
dim3(GridSize),
dim3(BlockSize),
0,
0,
wei_gemmk_gemmm_global_desc,
p_wei_global,
in_gemmk_gemmn_global_desc,
p_in_global,
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc,
p_out_global,
integral_constant<bool, false>{},
integral_constant<bool, false>{});
}
}
timer.End();
float ave_time = timer.GetElapsedTime() / nrepeat;
float perf = (float)calculate_convolution_flops(in_n_c_hi_wi_global_desc,
wei_k_c_y_x_global_desc,
out_n_k_ho_wo_global_desc) /
(std::size_t(1000) * 1000 * 1000) / ave_time;
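// note: dividing the FLOP count by 1e9 gives GFLOP, and dividing GFLOP by a time
// measured in milliseconds yields TFlop/s, matching the label printed below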
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
<< std::endl;
}
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
using ADesc = decltype(wei_gemmk_gemmm_global_desc);
using BDesc = decltype(in_gemmk_gemmn_global_desc);
using CDesc = decltype(out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc);
DeviceMem wei_gemmk_gemmm_global_desc_device_buf(sizeof(ADesc));
DeviceMem in_gemmk_gemmn_global_desc_device_buf(sizeof(BDesc));
DeviceMem out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf(sizeof(CDesc));
wei_gemmk_gemmm_global_desc_device_buf.ToDevice(&wei_gemmk_gemmm_global_desc);
in_gemmk_gemmn_global_desc_device_buf.ToDevice(&in_gemmk_gemmn_global_desc);
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf.ToDevice(
&out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc);
index_t nrepeat = 100;
for(index_t i = 0; i < 5; ++i)
{
std::cout << "Start running " << nrepeat << " times..." << std::endl;
KernelTimer timer;
timer.Start();
for(index_t j = 0; j < nrepeat; ++j)
{
if(has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel = run_gridwise_dynamic_gemm_v1<gridwise_gemm,
ADesc,
FloatAB,
BDesc,
FloatAB,
CDesc,
FloatC,
true,
true>;
launch_kernel(
kernel,
dim3(GridSize),
dim3(BlockSize),
0,
0,
(void __CONSTANT__*)
wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer(),
p_wei_global,
(void __CONSTANT__*)in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer(),
p_in_global,
(void __CONSTANT__*)
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
.GetDeviceBuffer(),
p_out_global);
}
else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
{
const auto kernel = run_gridwise_dynamic_gemm_v1<gridwise_gemm,
ADesc,
FloatAB,
BDesc,
FloatAB,
CDesc,
FloatC,
true,
false>;
launch_kernel(
kernel,
dim3(GridSize),
dim3(BlockSize),
0,
0,
(void __CONSTANT__*)
wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer(),
p_wei_global,
(void __CONSTANT__*)in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer(),
p_in_global,
(void __CONSTANT__*)
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
.GetDeviceBuffer(),
p_out_global);
}
else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel = run_gridwise_dynamic_gemm_v1<gridwise_gemm,
ADesc,
FloatAB,
BDesc,
FloatAB,
CDesc,
FloatC,
false,
true>;
launch_kernel(
kernel,
dim3(GridSize),
dim3(BlockSize),
0,
0,
(void __CONSTANT__*)
wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer(),
p_wei_global,
(void __CONSTANT__*)in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer(),
p_in_global,
(void __CONSTANT__*)
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
.GetDeviceBuffer(),
p_out_global);
}
else
{
const auto kernel = run_gridwise_dynamic_gemm_v1<gridwise_gemm,
ADesc,
FloatAB,
BDesc,
FloatAB,
CDesc,
FloatC,
false,
false>;
launch_kernel(
kernel,
dim3(GridSize),
dim3(BlockSize),
0,
0,
(void __CONSTANT__*)
wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer(),
p_wei_global,
(void __CONSTANT__*)in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer(),
p_in_global,
(void __CONSTANT__*)
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
.GetDeviceBuffer(),
p_out_global);
}
}
timer.End();
float ave_time = timer.GetElapsedTime() / nrepeat;
float perf = (float)calculate_convolution_flops(in_n_c_hi_wi_global_desc,
wei_k_c_y_x_global_desc,
out_n_k_ho_wo_global_desc) /
(std::size_t(1000) * 1000 * 1000) / ave_time;
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
<< std::endl;
}
#endif
}
};
} // namespace ck } // namespace ck
#endif #endif
...@@ -4,1015 +4,297 @@ ...@@ -4,1015 +4,297 @@
#include "common_header.hpp" #include "common_header.hpp"
#include "dynamic_tensor_descriptor.hpp" #include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp" #include "dynamic_tensor_descriptor_helper.hpp"
#include "gridwise_dynamic_gemm.hpp" #include "driver_dynamic_gemm_v1.hpp"
#include "gridwise_operation_wrapper.hpp"
namespace ck { namespace ck {
// GemmM = K // GemmM = K
// GemmN = N * Ho * Wo // GemmN = N * Ho * Wo
// GemmK = Y * X * C // GemmK = C * Y * X
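// Illustrative sizing (assumed numbers, not taken from this PR): with N = 2, C = 8,
// K = 16, Hi = Wi = 34, Y = X = 3, unit stride and dilation, and zero padding,
// Ho = Wo = 32, so GemmM = K = 16, GemmN = N * Ho * Wo = 2 * 32 * 32 = 2048 and
// GemmK = Y * X * C = 3 * 3 * 8 = 72.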
template <index_t BlockSize, template <index_t GemmMPerBlock,
typename FloatAB,
typename FloatAcc,
typename FloatC,
index_t GemmMPerBlock,
index_t GemmNPerBlock, index_t GemmNPerBlock,
index_t GemmKPerBlock, index_t GemmM1,
index_t GemmMPerThread, index_t GemmN1,
index_t GemmNPerThread, typename... Wei,
index_t GemmKPerThread, typename... In,
index_t GemmMLevel0Cluster, typename... Out,
index_t GemmNLevel0Cluster, typename ConvStrides,
index_t GemmMLevel1Cluster, typename ConvDilations,
index_t GemmNLevel1Cluster, typename InLeftPads,
typename GemmABlockTransferThreadSliceLengths_GemmK_GemmM, typename InRightPads>
typename GemmABlockTransferThreadClusterLengths_GemmK_GemmM, __host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_nhwc_kyxc_nhwk_pad(
index_t GemmABlockTransferSrcScalarPerVector_GemmK, const DynamicTensorDescriptor<Wei...>& wei_k_y_x_c_global_desc,
index_t GemmABlockTransferDstScalarPerVector_GemmM, const DynamicTensorDescriptor<In...>& in_n_hi_wi_c_global_desc,
typename GemmBBlockTransferThreadSliceLengths_GemmK_GemmN, const DynamicTensorDescriptor<Out...>& out_n_ho_wo_k_global_desc,
typename GemmBBlockTransferThreadClusterLengths_GemmK_GemmN, const ConvStrides& conv_strides,
index_t GemmBBlockTransferSrcScalarPerVector_GemmK, const ConvDilations& conv_dilations,
index_t GemmBBlockTransferDstScalarPerVector_GemmN, const InLeftPads& in_left_pads,
index_t GemmCThreadTransferDstScalarPerVector_GemmM1> const InRightPads& in_right_pads)
struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_pad
{ {
template <typename... Wei, constexpr auto I0 = Number<0>{};
typename... In, constexpr auto I1 = Number<1>{};
typename... Out, constexpr auto I2 = Number<2>{};
typename ConvStrides, constexpr auto I3 = Number<3>{};
typename ConvDilations,
typename InLeftPads, const auto N = in_n_hi_wi_c_global_desc.GetLength(I0);
typename InRightPads> const auto C = in_n_hi_wi_c_global_desc.GetLength(I3);
__host__ void Run(const DynamicTensorDescriptor<Wei...>& wei_k_y_x_c_global_desc, const auto K = out_n_ho_wo_k_global_desc.GetLength(I3);
const DynamicTensorDescriptor<In...>& in_n_hi_wi_c_global_desc,
const DynamicTensorDescriptor<Out...>& out_n_ho_wo_k_global_desc, const auto Hi = in_n_hi_wi_c_global_desc.GetLength(I1);
const ConvStrides& conv_strides, const auto Wi = in_n_hi_wi_c_global_desc.GetLength(I2);
const ConvDilations& conv_dilations,
const InLeftPads& in_left_pads, const auto Ho = out_n_ho_wo_k_global_desc.GetLength(I1);
const InRightPads& in_right_pads, const auto Wo = out_n_ho_wo_k_global_desc.GetLength(I2);
const FloatAB* __restrict__ p_wei_global,
const FloatAB* __restrict__ p_in_global, const auto Y = wei_k_y_x_c_global_desc.GetLength(I1);
FloatC* __restrict__ p_out_global) const const auto X = wei_k_y_x_c_global_desc.GetLength(I2);
{
constexpr auto I0 = Number<0>{}; const auto ConvStrideH = conv_strides[I0];
constexpr auto I1 = Number<1>{}; const auto ConvStrideW = conv_strides[I1];
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{}; const auto ConvDilationH = conv_dilations[I0];
const auto ConvDilationW = conv_dilations[I1];
const auto N = in_n_hi_wi_c_global_desc.GetLength(I0);
const auto C = in_n_hi_wi_c_global_desc.GetLength(I3); const auto InLeftPadH = in_left_pads[I0];
const auto K = out_n_ho_wo_k_global_desc.GetLength(I3); const auto InLeftPadW = in_left_pads[I1];
const auto Hi = in_n_hi_wi_c_global_desc.GetLength(I1); const auto InRightPadH = in_right_pads[I0];
const auto Wi = in_n_hi_wi_c_global_desc.GetLength(I2); const auto InRightPadW = in_right_pads[I1];
const auto Ho = out_n_ho_wo_k_global_desc.GetLength(I1); // weight tensor
const auto Wo = out_n_ho_wo_k_global_desc.GetLength(I2); const auto wei_gemmk_gemmm_global_desc = transform_dynamic_tensor_descriptor(
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, Y * X * C)),
const auto Y = wei_k_y_x_c_global_desc.GetLength(I1); make_tuple(make_pass_through_transform(K), make_pass_through_transform(Y * X * C)),
const auto X = wei_k_y_x_c_global_desc.GetLength(I2); make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<1>{}, Sequence<0>{}));
const auto ConvStrideH = conv_strides[I0];
const auto ConvStrideW = conv_strides[I1]; // input tensor
const auto in_n_hip_wip_c_global_desc = transform_dynamic_tensor_descriptor(
const auto ConvDilationH = conv_dilations[I0]; in_n_hi_wi_c_global_desc,
const auto ConvDilationW = conv_dilations[I1]; make_tuple(make_pass_through_transform(N),
make_pad_transform(Hi, InLeftPadH, InRightPadH),
const auto InLeftPadH = in_left_pads[I0]; make_pad_transform(Wi, InLeftPadW, InRightPadW),
const auto InLeftPadW = in_left_pads[I1]; make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
const auto InRightPadH = in_right_pads[I0]; make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto InRightPadW = in_right_pads[I1];
const auto in_n_y_ho_x_wo_c_global_desc = transform_dynamic_tensor_descriptor(
// weight tensor in_n_hip_wip_c_global_desc,
const auto wei_gemmk_gemmm_global_desc = transform_dynamic_tensor_descriptor( make_tuple(make_pass_through_transform(N),
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, Y * X * C)), make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
make_tuple(make_pass_through_transform(K), make_pass_through_transform(Y * X * C)), make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_pass_through_transform(C)),
make_tuple(Sequence<1>{}, Sequence<0>{})); make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
// input tensor
const auto in_n_hip_wip_c_global_desc = transform_dynamic_tensor_descriptor( const auto in_gemmk_gemmn_global_desc =
in_n_hi_wi_c_global_desc, transform_dynamic_tensor_descriptor(in_n_y_ho_x_wo_c_global_desc,
make_tuple(make_pass_through_transform(N), make_tuple(make_merge_transform(make_tuple(Y, X, C)),
make_pad_transform(Hi, InLeftPadH, InRightPadH), make_merge_transform(make_tuple(N, Ho, Wo))),
make_pad_transform(Wi, InLeftPadW, InRightPadW), make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}),
make_pass_through_transform(C)), make_tuple(Sequence<0>{}, Sequence<1>{}));
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); // output tensor
const auto out_gemmm_gemmn_global_desc = transform_dynamic_tensor_descriptor(
const auto in_n_y_ho_x_wo_c_global_desc = transform_dynamic_tensor_descriptor( make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N * Ho * Wo, K)),
in_n_hip_wip_c_global_desc, make_tuple(make_pass_through_transform(N * Ho * Wo), make_pass_through_transform(K)),
make_tuple( make_tuple(Sequence<0>{}, Sequence<1>{}),
make_pass_through_transform(N), make_tuple(Sequence<1>{}, Sequence<0>{}));
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), const auto GemmM = out_gemmm_gemmn_global_desc.GetLength(I0);
make_pass_through_transform(C)), const auto GemmN = out_gemmm_gemmn_global_desc.GetLength(I1);
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), const auto GemmK = wei_gemmk_gemmm_global_desc.GetLength(I0);
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
assert(GemmM % GemmMPerBlock == 0 && GemmN % GemmNPerBlock == 0 && GemmK % GemmKPerBlock == 0);
const auto in_gemmk_gemmn_global_desc = transform_dynamic_tensor_descriptor(
in_n_y_ho_x_wo_c_global_desc, const auto GemmM0 = GemmM / Number<GemmM1>{};
make_tuple(make_merge_transform(make_tuple(Y, X, C)), const auto GemmN0 = GemmN / Number<GemmN1>{};
make_merge_transform(make_tuple(N, Ho, Wo))),
make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}), const auto out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc = transform_dynamic_tensor_descriptor(
make_tuple(Sequence<0>{}, Sequence<1>{})); out_gemmm_gemmn_global_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmM0, GemmM1)),
// output tensor make_unmerge_transform(make_tuple(GemmN0, GemmN1))),
const auto out_gemmm_gemmn_global_desc = transform_dynamic_tensor_descriptor( make_tuple(Sequence<0>{}, Sequence<1>{}),
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N * Ho * Wo, K)), make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
make_tuple(make_pass_through_transform(N * Ho * Wo), make_pass_through_transform(K)),
make_tuple(Sequence<0>{}, Sequence<1>{}), // out_gemm_block_cluster_desc
make_tuple(Sequence<1>{}, Sequence<0>{})); const auto out_gemm_block_cluster_desc = make_cluster_descriptor_v2(
make_tuple(GemmM / Number<GemmMPerBlock>{}, GemmN / Number<GemmNPerBlock>{}));
const auto GemmM = out_gemmm_gemmn_global_desc.GetLength(I0);
const auto GemmN = out_gemmm_gemmn_global_desc.GetLength(I1); // hack to control index calculation when iterating over a_k_m_global tensor
const auto GemmK = wei_gemmk_gemmm_global_desc.GetLength(I0); constexpr auto wei_gemmk_gemmm_global_iterator_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}),
if(!(GemmM % GemmMPerBlock == 0 && GemmN % GemmNPerBlock == 0 && make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}));
GemmK % GemmKPerBlock == 0))
{ constexpr auto wei_gemmk_gemmm_global_move_slice_window_iterator_hacks = Sequence<0, 0, 0>{};
throw std::runtime_error("wrong! GEMM size no divisible");
} // hack to control index calculation when iterating over b_k_n_global tensor
constexpr auto in_gemmk_gemmn_global_iterator_hacks =
constexpr auto GemmM1 = Number<GemmMPerThread * GemmMLevel0Cluster * GemmMLevel1Cluster>{}; make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0>{},
constexpr auto GemmN1 = Number<GemmNPerThread * GemmNLevel0Cluster * GemmNLevel1Cluster>{}; Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1>{}),
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0>{},
const auto GemmM0 = GemmM / GemmM1; Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2>{}));
const auto GemmN0 = GemmN / GemmN1;
constexpr auto in_gemmk_gemmn_global_move_slice_window_iterator_hacks =
const auto out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2>{};
transform_dynamic_tensor_descriptor(
out_gemmm_gemmn_global_desc, // hack to control index calculation when iterating over c_m0_m1_n0_n1_global tensor
make_tuple(make_unmerge_transform(make_tuple(GemmM0, GemmM1)), // hack for NKHW format
make_unmerge_transform(make_tuple(GemmN0, GemmN1))), constexpr auto out_gemmm0_gemmm1_gemmn0_gemmn1_global_iterator_hacks =
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 1, 0, 0>{},
// hack to control index calculation when iterating over a_k_m_global tensor Sequence<0, 0, 1, 0, 0>{}),
constexpr auto a_k_m_global_iterator_hacks = make_tuple(Sequence<0, 0, 0, 0, 0>{},
make_tuple(make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}), Sequence<0, 0, 0, 0, 0>{},
make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{})); Sequence<0, 0, 2, 0, 0>{},
Sequence<0, 0, 2, 0, 0>{}));
constexpr auto a_k_m_global_move_slice_window_iterator_hack = Sequence<0, 0, 0>{};
return make_tuple(wei_gemmk_gemmm_global_desc,
// hack to control index calculation when iterating over b_k_n_global tensor in_gemmk_gemmn_global_desc,
constexpr auto b_k_n_global_iterator_hacks = out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc,
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0>{}, out_gemm_block_cluster_desc,
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1>{}), wei_gemmk_gemmm_global_iterator_hacks,
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0>{}, in_gemmk_gemmn_global_iterator_hacks,
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2>{})); out_gemmm0_gemmm1_gemmn0_gemmn1_global_iterator_hacks,
wei_gemmk_gemmm_global_move_slice_window_iterator_hacks,
constexpr auto b_k_n_global_move_slice_window_iterator_hack = in_gemmk_gemmn_global_move_slice_window_iterator_hacks);
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2>{}; }
// hack to control index calculation when iterating over c_m0_m1_n0_n1_global tensor // GemmM = K
// hack for NKHW format // GemmN = N * Ho * Wo
constexpr auto c_m0_m1_n0_n1_global_tensor_iterator_hacks = // GemmK = C * Y * X
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, template <index_t GemmMPerBlock,
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 1, 0, 0>{},
Sequence<0, 0, 1, 0, 0>{}),
make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 2, 0, 0>{},
Sequence<0, 0, 2, 0, 0>{}));
// GEMM
using gridwise_gemm = GridwiseDynamicGemm_km_kn_m0m1n0n1_v1<
BlockSize,
FloatAB,
FloatAcc,
FloatC,
InMemoryDataOperation::Set,
decltype(wei_gemmk_gemmm_global_desc),
decltype(in_gemmk_gemmn_global_desc),
decltype(out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc),
GemmMPerBlock,
GemmNPerBlock,
GemmKPerBlock,
GemmMPerThread,
GemmNPerThread,
GemmKPerThread,
GemmMLevel0Cluster,
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmABlockTransferThreadSliceLengths_GemmK_GemmM,
GemmABlockTransferThreadClusterLengths_GemmK_GemmM,
Sequence<1, 0>,
Sequence<1, 0>,
0,
GemmABlockTransferSrcScalarPerVector_GemmK,
GemmABlockTransferDstScalarPerVector_GemmM,
false, // don't move back src coordinate after threadwise copy
GemmBBlockTransferThreadSliceLengths_GemmK_GemmN,
GemmBBlockTransferThreadClusterLengths_GemmK_GemmN,
Sequence<1, 0>,
Sequence<1, 0>,
0,
GemmBBlockTransferSrcScalarPerVector_GemmK,
GemmBBlockTransferDstScalarPerVector_GemmN,
false, // don't move back src coordinate after threadwise copy, which will be fused with
// MoveSrcSliceWindow() to save addr computation
Sequence<2, 3, 0, 1>,
1,
GemmCThreadTransferDstScalarPerVector_GemmM1,
decltype(a_k_m_global_iterator_hacks),
decltype(b_k_n_global_iterator_hacks),
decltype(c_m0_m1_n0_n1_global_tensor_iterator_hacks),
decltype(a_k_m_global_move_slice_window_iterator_hack),
decltype(b_k_n_global_move_slice_window_iterator_hack)>;
const auto GridSize = (GemmM / GemmMPerBlock) * (GemmN / GemmNPerBlock);
const bool has_main_k_block_loop = (GemmK + GemmKPerBlock) / (2 * GemmKPerBlock) > 1;
const bool has_double_tail_k_block_loop = (GemmK / GemmKPerBlock) % 2 == 0;
printf("%s: BlockSize %d, GridSize %d \n", __func__, BlockSize, GridSize);
#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
index_t nrepeat = 100;
for(index_t i = 0; i < 5; ++i)
{
std::cout << "Start running " << nrepeat << " times..." << std::endl;
KernelTimer timer;
timer.Start();
for(index_t j = 0; j < nrepeat; ++j)
{
if(has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel =
run_gridwise_operation<gridwise_gemm,
decltype(wei_gemmk_gemmm_global_desc),
const FloatAB*,
decltype(in_gemmk_gemmn_global_desc),
const FloatAB*,
decltype(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc),
FloatC*,
integral_constant<bool, true>,
integral_constant<bool, true>>;
launch_kernel(kernel,
dim3(GridSize),
dim3(BlockSize),
0,
0,
wei_gemmk_gemmm_global_desc,
p_wei_global,
in_gemmk_gemmn_global_desc,
p_in_global,
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc,
p_out_global,
integral_constant<bool, true>{},
integral_constant<bool, true>{});
}
else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
{
const auto kernel =
run_gridwise_operation<gridwise_gemm,
decltype(wei_gemmk_gemmm_global_desc),
const FloatAB*,
decltype(in_gemmk_gemmn_global_desc),
const FloatAB*,
decltype(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc),
FloatC*,
integral_constant<bool, true>,
integral_constant<bool, false>>;
launch_kernel(kernel,
dim3(GridSize),
dim3(BlockSize),
0,
0,
wei_gemmk_gemmm_global_desc,
p_wei_global,
in_gemmk_gemmn_global_desc,
p_in_global,
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc,
p_out_global,
integral_constant<bool, true>{},
integral_constant<bool, false>{});
}
else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel =
run_gridwise_operation<gridwise_gemm,
decltype(wei_gemmk_gemmm_global_desc),
const FloatAB*,
decltype(in_gemmk_gemmn_global_desc),
const FloatAB*,
decltype(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc),
FloatC*,
integral_constant<bool, false>,
integral_constant<bool, true>>;
launch_kernel(kernel,
dim3(GridSize),
dim3(BlockSize),
0,
0,
wei_gemmk_gemmm_global_desc,
p_wei_global,
in_gemmk_gemmn_global_desc,
p_in_global,
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc,
p_out_global,
integral_constant<bool, false>{},
integral_constant<bool, true>{});
}
else
{
const auto kernel =
run_gridwise_operation<gridwise_gemm,
decltype(wei_gemmk_gemmm_global_desc),
const FloatAB*,
decltype(in_gemmk_gemmn_global_desc),
const FloatAB*,
decltype(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc),
FloatC*,
integral_constant<bool, false>,
integral_constant<bool, false>>;
launch_kernel(kernel,
dim3(GridSize),
dim3(BlockSize),
0,
0,
wei_gemmk_gemmm_global_desc,
p_wei_global,
in_gemmk_gemmn_global_desc,
p_in_global,
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc,
p_out_global,
integral_constant<bool, false>{},
integral_constant<bool, false>{});
}
}
timer.End();
float ave_time = timer.GetElapsedTime() / nrepeat;
float perf = (float)(std::size_t(2) * N * K * Ho * Wo * C * Y * X) /
(std::size_t(1000) * 1000 * 1000) / ave_time;
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
<< std::endl;
}
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
using ADesc = decltype(wei_gemmk_gemmm_global_desc);
using BDesc = decltype(in_gemmk_gemmn_global_desc);
using CDesc = decltype(out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc);
DeviceMem wei_gemmk_gemmm_global_desc_device_buf(sizeof(ADesc));
DeviceMem in_gemmk_gemmn_global_desc_device_buf(sizeof(BDesc));
DeviceMem out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf(sizeof(CDesc));
wei_gemmk_gemmm_global_desc_device_buf.ToDevice(&wei_gemmk_gemmm_global_desc);
in_gemmk_gemmn_global_desc_device_buf.ToDevice(&in_gemmk_gemmn_global_desc);
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf.ToDevice(
&out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc);
index_t nrepeat = 100;
for(index_t i = 0; i < 5; ++i)
{
std::cout << "Start running " << nrepeat << " times..." << std::endl;
KernelTimer timer;
timer.Start();
for(index_t j = 0; j < nrepeat; ++j)
{
if(has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel = run_gridwise_dynamic_gemm_v1<gridwise_gemm,
ADesc,
FloatAB,
BDesc,
FloatAB,
CDesc,
FloatC,
true,
true>;
launch_kernel(
kernel,
dim3(GridSize),
dim3(BlockSize),
0,
0,
(void __CONSTANT__*)
wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer(),
p_wei_global,
(void __CONSTANT__*)in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer(),
p_in_global,
(void __CONSTANT__*)
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
.GetDeviceBuffer(),
p_out_global);
}
else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
{
const auto kernel = run_gridwise_dynamic_gemm_v1<gridwise_gemm,
ADesc,
FloatAB,
BDesc,
FloatAB,
CDesc,
FloatC,
true,
false>;
launch_kernel(
kernel,
dim3(GridSize),
dim3(BlockSize),
0,
0,
(void __CONSTANT__*)
wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer(),
p_wei_global,
(void __CONSTANT__*)in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer(),
p_in_global,
(void __CONSTANT__*)
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
.GetDeviceBuffer(),
p_out_global);
}
else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel = run_gridwise_dynamic_gemm_v1<gridwise_gemm,
ADesc,
FloatAB,
BDesc,
FloatAB,
CDesc,
FloatC,
false,
true>;
launch_kernel(
kernel,
dim3(GridSize),
dim3(BlockSize),
0,
0,
(void __CONSTANT__*)
wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer(),
p_wei_global,
(void __CONSTANT__*)in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer(),
p_in_global,
(void __CONSTANT__*)
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
.GetDeviceBuffer(),
p_out_global);
}
else
{
const auto kernel = run_gridwise_dynamic_gemm_v1<gridwise_gemm,
ADesc,
FloatAB,
BDesc,
FloatAB,
CDesc,
FloatC,
false,
false>;
launch_kernel(
kernel,
dim3(GridSize),
dim3(BlockSize),
0,
0,
(void __CONSTANT__*)
wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer(),
p_wei_global,
(void __CONSTANT__*)in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer(),
p_in_global,
(void __CONSTANT__*)
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
.GetDeviceBuffer(),
p_out_global);
}
}
timer.End();
float ave_time = timer.GetElapsedTime() / nrepeat;
float perf = (float)(std::size_t(2) * N * K * Ho * Wo * C * Y * X) /
(std::size_t(1000) * 1000 * 1000) / ave_time;
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
<< std::endl;
}
#endif
}
};
template <index_t BlockSize,
typename FloatAB,
typename FloatAcc,
typename FloatC,
index_t GemmMPerBlock,
index_t GemmNPerBlock, index_t GemmNPerBlock,
index_t GemmKPerBlock, index_t GemmM1,
index_t GemmMPerThread, index_t GemmN1,
index_t GemmNPerThread, typename... Wei,
index_t GemmKPerThread, typename... In,
index_t GemmMLevel0Cluster, typename... Out,
index_t GemmNLevel0Cluster, typename ConvStrides,
index_t GemmMLevel1Cluster, typename ConvDilations,
index_t GemmNLevel1Cluster, typename InLeftPads,
typename GemmABlockTransferThreadSliceLengths_GemmK_GemmM, typename InRightPads>
typename GemmABlockTransferThreadClusterLengths_GemmK_GemmM, __host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_nhwc_kyxc_nhwk_1x1(
index_t GemmABlockTransferSrcScalarPerVector_GemmK, const DynamicTensorDescriptor<Wei...>& wei_k_y_x_c_global_desc,
index_t GemmABlockTransferDstScalarPerVector_GemmM, const DynamicTensorDescriptor<In...>& in_n_hi_wi_c_global_desc,
typename GemmBBlockTransferThreadSliceLengths_GemmK_GemmN, const DynamicTensorDescriptor<Out...>& out_n_ho_wo_k_global_desc,
typename GemmBBlockTransferThreadClusterLengths_GemmK_GemmN, const ConvStrides& conv_strides,
index_t GemmBBlockTransferSrcScalarPerVector_GemmK, const ConvDilations& conv_dilations,
index_t GemmBBlockTransferDstScalarPerVector_GemmN, const InLeftPads& in_left_pads,
index_t GemmCThreadTransferDstScalarPerVector_GemmM1> const InRightPads& in_right_pads)
struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_1x1
{ {
template <typename... Wei, constexpr auto I0 = Number<0>{};
typename... In, constexpr auto I1 = Number<1>{};
typename... Out, constexpr auto I2 = Number<2>{};
typename ConvStrides, constexpr auto I3 = Number<3>{};
typename ConvDilations,
typename InLeftPads, const auto N = in_n_hi_wi_c_global_desc.GetLength(I0);
typename InRightPads> const auto C = in_n_hi_wi_c_global_desc.GetLength(I3);
__host__ void Run(const DynamicTensorDescriptor<Wei...>& wei_k_y_x_c_global_desc, const auto K = out_n_ho_wo_k_global_desc.GetLength(I3);
const DynamicTensorDescriptor<In...>& in_n_hi_wi_c_global_desc,
const DynamicTensorDescriptor<Out...>& out_n_ho_wo_k_global_desc, const auto Hi = in_n_hi_wi_c_global_desc.GetLength(I1);
const ConvStrides& conv_strides, const auto Wi = in_n_hi_wi_c_global_desc.GetLength(I2);
const ConvDilations& conv_dilations,
const InLeftPads& in_left_pads, const auto Ho = out_n_ho_wo_k_global_desc.GetLength(I1);
const InRightPads& in_right_pads, const auto Wo = out_n_ho_wo_k_global_desc.GetLength(I2);
const FloatAB* __restrict__ p_wei_global,
const FloatAB* __restrict__ p_in_global, const auto Y = wei_k_y_x_c_global_desc.GetLength(I1);
FloatC* __restrict__ p_out_global) const const auto X = wei_k_y_x_c_global_desc.GetLength(I2);
{
constexpr auto I0 = Number<0>{}; const auto ConvStrideH = conv_strides[I0];
constexpr auto I1 = Number<1>{}; const auto ConvStrideW = conv_strides[I1];
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{}; const auto ConvDilationH = conv_dilations[I0];
const auto ConvDilationW = conv_dilations[I1];
const auto N = in_n_hi_wi_c_global_desc.GetLength(I0);
const auto C = in_n_hi_wi_c_global_desc.GetLength(I3); const auto InLeftPadH = in_left_pads[I0];
const auto K = out_n_ho_wo_k_global_desc.GetLength(I3); const auto InLeftPadW = in_left_pads[I1];
const auto Hi = in_n_hi_wi_c_global_desc.GetLength(I1); const auto InRightPadH = in_right_pads[I0];
const auto Wi = in_n_hi_wi_c_global_desc.GetLength(I2); const auto InRightPadW = in_right_pads[I1];
const auto Ho = out_n_ho_wo_k_global_desc.GetLength(I1); assert(Y == 1 && X == 1 && ConvStrideH == 1 && ConvStrideW == 1 && ConvDilationH == 1 &&
const auto Wo = out_n_ho_wo_k_global_desc.GetLength(I2); ConvDilationW == 1 && InLeftPadH == 0 && InLeftPadW == 0 && InRightPadH == 0 &&
InRightPadW == 0);
const auto Y = wei_k_y_x_c_global_desc.GetLength(I1);
const auto X = wei_k_y_x_c_global_desc.GetLength(I2); // weight tensor
const auto wei_gemmk_gemmm_global_desc = transform_dynamic_tensor_descriptor(
const auto ConvStrideH = conv_strides[I0]; make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, C)),
const auto ConvStrideW = conv_strides[I1]; make_tuple(make_pass_through_transform(K), make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
const auto ConvDilationH = conv_dilations[I0]; make_tuple(Sequence<1>{}, Sequence<0>{}));
const auto ConvDilationW = conv_dilations[I1];
// input tensor
const auto InLeftPadH = in_left_pads[I0]; const auto in_gemmk_gemmn_global_desc = transform_dynamic_tensor_descriptor(
const auto InLeftPadW = in_left_pads[I1]; make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N * Ho * Wo, C)),
make_tuple(make_pass_through_transform(N * Ho * Wo), make_pass_through_transform(C)),
const auto InRightPadH = in_right_pads[I0]; make_tuple(Sequence<0>{}, Sequence<1>{}),
const auto InRightPadW = in_right_pads[I1]; make_tuple(Sequence<1>{}, Sequence<0>{}));
if(!(Y == 1 && X == 1 && ConvStrideH == 1 && ConvStrideW == 1 && ConvDilationH == 1 && // output tensor
ConvDilationW == 1 && InLeftPadH == 0 && InLeftPadW == 0 && InRightPadH == 0 && const auto out_gemmm_gemmn_global_desc = transform_dynamic_tensor_descriptor(
InRightPadW == 0)) make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N * Ho * Wo, K)),
{ make_tuple(make_pass_through_transform(N * Ho * Wo), make_pass_through_transform(K)),
throw std::runtime_error("wrong! 1x1, stride 1, no padding"); make_tuple(Sequence<0>{}, Sequence<1>{}),
} make_tuple(Sequence<1>{}, Sequence<0>{}));
// weight tensor const auto GemmM = out_gemmm_gemmn_global_desc.GetLength(I0);
const auto wei_gemmk_gemmm_global_desc = transform_dynamic_tensor_descriptor( const auto GemmN = out_gemmm_gemmn_global_desc.GetLength(I1);
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, C)), const auto GemmK = wei_gemmk_gemmm_global_desc.GetLength(I0);
make_tuple(make_pass_through_transform(K), make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}), assert(GemmM % GemmMPerBlock == 0 && GemmN % GemmNPerBlock == 0 && GemmK % GemmKPerBlock == 0);
make_tuple(Sequence<1>{}, Sequence<0>{}));
const auto GemmM0 = GemmM / Number<GemmM1>{};
// input tensor const auto GemmN0 = GemmN / Number<GemmN1>{};
const auto in_gemmk_gemmn_global_desc = transform_dynamic_tensor_descriptor(
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N * Ho * Wo, C)), const auto out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc = transform_dynamic_tensor_descriptor(
make_tuple(make_pass_through_transform(N * Ho * Wo), make_pass_through_transform(C)), out_gemmm_gemmn_global_desc,
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(make_unmerge_transform(make_tuple(GemmM0, GemmM1)),
make_tuple(Sequence<1>{}, Sequence<0>{})); make_unmerge_transform(make_tuple(GemmN0, GemmN1))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
// output tensor make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
const auto out_gemmm_gemmn_global_desc = transform_dynamic_tensor_descriptor(
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N * Ho * Wo, K)), // out_gemm_block_cluster_desc
make_tuple(make_pass_through_transform(N * Ho * Wo), make_pass_through_transform(K)), const auto out_gemm_block_cluster_desc = make_cluster_descriptor_v2(
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(GemmM / Number<GemmMPerBlock>{}, GemmN / Number<GemmNPerBlock>{}));
make_tuple(Sequence<1>{}, Sequence<0>{}));
// hack to control index calculation when iterating over wei_gemmk_gemmm_global_iterator_hacks
const auto GemmM = out_gemmm_gemmn_global_desc.GetLength(I0); // tensor
const auto GemmN = out_gemmm_gemmn_global_desc.GetLength(I1); constexpr auto wei_gemmk_gemmm_global_iterator_hacks =
const auto GemmK = wei_gemmk_gemmm_global_desc.GetLength(I0); make_tuple(make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}),
make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}));
if(!(GemmM % GemmMPerBlock == 0 && GemmN % GemmNPerBlock == 0 &&
GemmK % GemmKPerBlock == 0)) constexpr auto wei_gemmk_gemmm_global_move_slice_window_iterator_hacks = Sequence<0, 0, 0>{};
{
throw std::runtime_error("wrong! GEMM size no divisible"); // hack to control index calculation when iterating over b_k_n_global tensor
} constexpr auto in_gemmk_gemmn_global_iterator_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}),
constexpr auto GemmM1 = Number<GemmMPerThread * GemmMLevel0Cluster * GemmMLevel1Cluster>{}; make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}));
constexpr auto GemmN1 = Number<GemmNPerThread * GemmNLevel0Cluster * GemmNLevel1Cluster>{};
constexpr auto in_gemmk_gemmn_global_move_slice_window_iterator_hacks = Sequence<0, 0, 0>{};
const auto GemmM0 = GemmM / GemmM1;
const auto GemmN0 = GemmN / GemmN1; // hack to control index calculation when iterating over c_m0_m1_n0_n1_global tensor
constexpr auto out_gemmm0_gemmm1_gemmn0_gemmn1_global_iterator_hacks =
const auto out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc = make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
transform_dynamic_tensor_descriptor( Sequence<0, 0, 0, 0, 0>{},
out_gemmm_gemmn_global_desc, Sequence<0, 0, 0, 0, 0>{},
make_tuple(make_unmerge_transform(make_tuple(GemmM0, GemmM1)), Sequence<0, 0, 0, 0, 0>{}),
make_unmerge_transform(make_tuple(GemmN0, GemmN1))), make_tuple(Sequence<0, 0, 0, 0, 0>{},
make_tuple(Sequence<0>{}, Sequence<1>{}), Sequence<0, 0, 0, 0, 0>{},
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{}));
// hack to control index calculation when iterating over a_k_m_global tensor
constexpr auto a_k_m_global_iterator_hacks = return make_tuple(wei_gemmk_gemmm_global_desc,
make_tuple(make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}), in_gemmk_gemmn_global_desc,
make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{})); out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc,
out_gemm_block_cluster_desc,
constexpr auto a_k_m_global_move_slice_window_iterator_hack = Sequence<0, 0, 0>{}; wei_gemmk_gemmm_global_iterator_hacks,
in_gemmk_gemmn_global_iterator_hacks,
// hack to control index calculation when iterating over b_k_n_global tensor out_gemmm0_gemmm1_gemmn0_gemmn1_global_iterator_hacks,
constexpr auto b_k_n_global_iterator_hacks = wei_gemmk_gemmm_global_move_slice_window_iterator_hacks,
make_tuple(make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}), in_gemmk_gemmn_global_move_slice_window_iterator_hacks);
make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{})); }
constexpr auto b_k_n_global_move_slice_window_iterator_hack = Sequence<0, 0, 0>{};
// hack to control index calculation when iterating over c_m0_m1_n0_n1_global tensor
constexpr auto c_m0_m1_n0_n1_global_tensor_iterator_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{}),
make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{}));
// GEMM
using gridwise_gemm = GridwiseDynamicGemm_km_kn_m0m1n0n1_v1<
BlockSize,
FloatAB,
FloatAcc,
FloatC,
InMemoryDataOperation::Set,
decltype(wei_gemmk_gemmm_global_desc),
decltype(in_gemmk_gemmn_global_desc),
decltype(out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc),
GemmMPerBlock,
GemmNPerBlock,
GemmKPerBlock,
GemmMPerThread,
GemmNPerThread,
GemmKPerThread,
GemmMLevel0Cluster,
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmABlockTransferThreadSliceLengths_GemmK_GemmM,
GemmABlockTransferThreadClusterLengths_GemmK_GemmM,
Sequence<1, 0>,
Sequence<1, 0>,
0,
GemmABlockTransferSrcScalarPerVector_GemmK,
GemmABlockTransferDstScalarPerVector_GemmM,
false, // don't move back src coordinate after threadwise copy
GemmBBlockTransferThreadSliceLengths_GemmK_GemmN,
GemmBBlockTransferThreadClusterLengths_GemmK_GemmN,
Sequence<1, 0>,
Sequence<1, 0>,
0,
GemmBBlockTransferSrcScalarPerVector_GemmK,
GemmBBlockTransferDstScalarPerVector_GemmN,
false, // don't move back src coordinate after threadwise copy, which will be fused with
// MoveSrcSliceWindow() to save addr computation
Sequence<2, 3, 0, 1>,
1,
GemmCThreadTransferDstScalarPerVector_GemmM1,
decltype(a_k_m_global_iterator_hacks),
decltype(b_k_n_global_iterator_hacks),
decltype(c_m0_m1_n0_n1_global_tensor_iterator_hacks),
decltype(a_k_m_global_move_slice_window_iterator_hack),
decltype(b_k_n_global_move_slice_window_iterator_hack)>;
const auto GridSize = (GemmM / GemmMPerBlock) * (GemmN / GemmNPerBlock);
const bool has_main_k_block_loop = (GemmK + GemmKPerBlock) / (2 * GemmKPerBlock) > 1;
const bool has_double_tail_k_block_loop = (GemmK / GemmKPerBlock) % 2 == 0;
printf("%s: BlockSize %d, GridSize %d \n", __func__, BlockSize, GridSize);
#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
index_t nrepeat = 100;
for(index_t i = 0; i < 5; ++i)
{
std::cout << "Start running " << nrepeat << " times..." << std::endl;
KernelTimer timer;
timer.Start();
for(index_t j = 0; j < nrepeat; ++j)
{
if(has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel =
run_gridwise_operation<gridwise_gemm,
decltype(wei_gemmk_gemmm_global_desc),
const FloatAB*,
decltype(in_gemmk_gemmn_global_desc),
const FloatAB*,
decltype(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc),
FloatC*,
integral_constant<bool, true>,
integral_constant<bool, true>>;
launch_kernel(kernel,
dim3(GridSize),
dim3(BlockSize),
0,
0,
wei_gemmk_gemmm_global_desc,
p_wei_global,
in_gemmk_gemmn_global_desc,
p_in_global,
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc,
p_out_global,
integral_constant<bool, true>{},
integral_constant<bool, true>{});
}
else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
{
const auto kernel =
run_gridwise_operation<gridwise_gemm,
decltype(wei_gemmk_gemmm_global_desc),
const FloatAB*,
decltype(in_gemmk_gemmn_global_desc),
const FloatAB*,
decltype(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc),
FloatC*,
integral_constant<bool, true>,
integral_constant<bool, false>>;
launch_kernel(kernel,
dim3(GridSize),
dim3(BlockSize),
0,
0,
wei_gemmk_gemmm_global_desc,
p_wei_global,
in_gemmk_gemmn_global_desc,
p_in_global,
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc,
p_out_global,
integral_constant<bool, true>{},
integral_constant<bool, false>{});
}
else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel =
run_gridwise_operation<gridwise_gemm,
decltype(wei_gemmk_gemmm_global_desc),
const FloatAB*,
decltype(in_gemmk_gemmn_global_desc),
const FloatAB*,
decltype(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc),
FloatC*,
integral_constant<bool, false>,
integral_constant<bool, true>>;
launch_kernel(kernel,
dim3(GridSize),
dim3(BlockSize),
0,
0,
wei_gemmk_gemmm_global_desc,
p_wei_global,
in_gemmk_gemmn_global_desc,
p_in_global,
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc,
p_out_global,
integral_constant<bool, false>{},
integral_constant<bool, true>{});
}
else
{
const auto kernel =
run_gridwise_operation<gridwise_gemm,
decltype(wei_gemmk_gemmm_global_desc),
const FloatAB*,
decltype(in_gemmk_gemmn_global_desc),
const FloatAB*,
decltype(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc),
FloatC*,
integral_constant<bool, false>,
integral_constant<bool, false>>;
launch_kernel(kernel,
dim3(GridSize),
dim3(BlockSize),
0,
0,
wei_gemmk_gemmm_global_desc,
p_wei_global,
in_gemmk_gemmn_global_desc,
p_in_global,
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc,
p_out_global,
integral_constant<bool, false>{},
integral_constant<bool, false>{});
}
}
timer.End();
float ave_time = timer.GetElapsedTime() / nrepeat;
float perf = (float)(std::size_t(2) * N * K * Ho * Wo * C * Y * X) /
(std::size_t(1000) * 1000 * 1000) / ave_time;
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
<< std::endl;
}
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
using ADesc = decltype(wei_gemmk_gemmm_global_desc);
using BDesc = decltype(in_gemmk_gemmn_global_desc);
using CDesc = decltype(out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc);
DeviceMem wei_gemmk_gemmm_global_desc_device_buf(sizeof(ADesc));
DeviceMem in_gemmk_gemmn_global_desc_device_buf(sizeof(BDesc));
DeviceMem out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf(sizeof(CDesc));
wei_gemmk_gemmm_global_desc_device_buf.ToDevice(&wei_gemmk_gemmm_global_desc);
in_gemmk_gemmn_global_desc_device_buf.ToDevice(&in_gemmk_gemmn_global_desc);
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf.ToDevice(
&out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc);
index_t nrepeat = 100;
for(index_t i = 0; i < 5; ++i)
{
std::cout << "Start running " << nrepeat << " times..." << std::endl;
KernelTimer timer;
timer.Start();
for(index_t j = 0; j < nrepeat; ++j)
{
if(has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel = run_gridwise_dynamic_gemm_v1<gridwise_gemm,
ADesc,
FloatAB,
BDesc,
FloatAB,
CDesc,
FloatC,
true,
true>;
launch_kernel(
kernel,
dim3(GridSize),
dim3(BlockSize),
0,
0,
(void __CONSTANT__*)
wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer(),
p_wei_global,
(void __CONSTANT__*)in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer(),
p_in_global,
(void __CONSTANT__*)
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
.GetDeviceBuffer(),
p_out_global);
}
else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
{
const auto kernel = run_gridwise_dynamic_gemm_v1<gridwise_gemm,
ADesc,
FloatAB,
BDesc,
FloatAB,
CDesc,
FloatC,
true,
false>;
launch_kernel(
kernel,
dim3(GridSize),
dim3(BlockSize),
0,
0,
(void __CONSTANT__*)
wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer(),
p_wei_global,
(void __CONSTANT__*)in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer(),
p_in_global,
(void __CONSTANT__*)
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
.GetDeviceBuffer(),
p_out_global);
}
else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel = run_gridwise_dynamic_gemm_v1<gridwise_gemm,
ADesc,
FloatAB,
BDesc,
FloatAB,
CDesc,
FloatC,
false,
true>;
launch_kernel(
kernel,
dim3(GridSize),
dim3(BlockSize),
0,
0,
(void __CONSTANT__*)
wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer(),
p_wei_global,
(void __CONSTANT__*)in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer(),
p_in_global,
(void __CONSTANT__*)
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
.GetDeviceBuffer(),
p_out_global);
}
else
{
const auto kernel = run_gridwise_dynamic_gemm_v1<gridwise_gemm,
ADesc,
FloatAB,
BDesc,
FloatAB,
CDesc,
FloatC,
false,
false>;
launch_kernel(
kernel,
dim3(GridSize),
dim3(BlockSize),
0,
0,
(void __CONSTANT__*)
wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer(),
p_wei_global,
(void __CONSTANT__*)in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer(),
p_in_global,
(void __CONSTANT__*)
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
.GetDeviceBuffer(),
p_out_global);
}
}
timer.End();
float ave_time = timer.GetElapsedTime() / nrepeat;
float perf = (float)(std::size_t(2) * N * K * Ho * Wo * C * Y * X) /
(std::size_t(1000) * 1000 * 1000) / ave_time;
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
<< std::endl;
}
#endif
}
};
} // namespace ck } // namespace ck
#endif #endif
#ifndef CK_DRIVER_DYNAMIC_GEMM_V1
#define CK_DRIVER_DYNAMIC_GEMM_V1
#include "common_header.hpp"
#include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp"
#include "gridwise_dynamic_gemm.hpp"
#include "gridwise_operation_wrapper.hpp"
namespace ck {
template <index_t BlockSize,
typename FloatAB,
typename FloatAcc,
typename FloatC,
InMemoryDataOperation CGlobalMemoryDataOperation,
typename AGlobalDesc,
typename BGlobalDesc,
typename CGlobalDesc,
typename CBlockClusterDesc,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t MPerThread,
index_t NPerThread,
index_t KPerThread,
index_t MLevel0Cluster,
index_t NLevel0Cluster,
index_t MLevel1Cluster,
index_t NLevel1Cluster,
typename ABlockTransferThreadSliceLengths_K_M,
typename ABlockTransferThreadClusterLengths_K_M,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_M,
bool AThreadTransferSrcResetCoordinateAfterRun,
typename BBlockTransferThreadSliceLengths_K_N,
typename BBlockTransferThreadClusterLengths_K_N,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_N,
bool BThreadTransferSrcResetCoordinateAfterRun,
typename CThreadTransferSrcDstAccessOrder,
index_t CThreadTransferSrcDstVectorDim,
index_t CThreadTransferDstScalarPerVector,
typename AGlobalIteratorHacks,
typename BGlobalIteratorHacks,
typename CGlobalIteratorHacks,
typename AGlobalMoveSliceWindowIteratorHacks,
typename BGlobalMoveSliceWindowIteratorHacks>
__host__ float launch_kernel_dynamic_gemm_v1(const FloatAB* p_a_global,
const FloatAB* p_b_global,
FloatC* p_c_global,
const AGlobalDesc& a_k_m_global_desc,
const BGlobalDesc& b_k_n_global_desc,
const CGlobalDesc& c_m0_m1_n0_n1_global_desc,
const CBlockClusterDesc& c_block_cluster_desc,
AGlobalIteratorHacks,
BGlobalIteratorHacks,
CGlobalIteratorHacks,
AGlobalMoveSliceWindowIteratorHacks,
BGlobalMoveSliceWindowIteratorHacks,
index_t nrepeat)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
const auto M = a_k_m_global_desc.GetLength(I1);
const auto N = b_k_n_global_desc.GetLength(I1);
const auto K = a_k_m_global_desc.GetLength(I0);
if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0))
{
throw std::runtime_error("wrong! GEMM size no divisible");
}
constexpr auto M1 = Number<MPerThread * MLevel0Cluster * MLevel1Cluster>{};
constexpr auto N1 = Number<NPerThread * NLevel0Cluster * NLevel1Cluster>{};
if(!(MPerBlock % M1 == 0 && NPerBlock % N1 == 0))
{
throw std::runtime_error("wrong! GEMM size no divisible");
}
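// Example (assumed tuning values, not mandated by this driver): MPerThread = 4,
// MLevel0Cluster = 4 and MLevel1Cluster = 4 give M1 = 64, so MPerBlock must be a
// multiple of 64; N1 constrains NPerBlock the same way.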
// GEMM
using gridwise_gemm =
GridwiseDynamicGemm_km_kn_m0m1n0n1_v1<BlockSize,
FloatAB,
FloatAcc,
FloatC,
CGlobalMemoryDataOperation,
AGlobalDesc,
BGlobalDesc,
CGlobalDesc,
CBlockClusterDesc,
MPerBlock,
NPerBlock,
KPerBlock,
MPerThread,
NPerThread,
KPerThread,
MLevel0Cluster,
NLevel0Cluster,
MLevel1Cluster,
NLevel1Cluster,
ABlockTransferThreadSliceLengths_K_M,
ABlockTransferThreadClusterLengths_K_M,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_M,
AThreadTransferSrcResetCoordinateAfterRun,
BBlockTransferThreadSliceLengths_K_N,
BBlockTransferThreadClusterLengths_K_N,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_N,
BThreadTransferSrcResetCoordinateAfterRun,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
AGlobalIteratorHacks,
BGlobalIteratorHacks,
CGlobalIteratorHacks,
AGlobalMoveSliceWindowIteratorHacks,
BGlobalMoveSliceWindowIteratorHacks>;
const auto GridSize = (M / MPerBlock) * (N / NPerBlock);
const bool has_main_k_block_loop = (K + KPerBlock) / (2 * KPerBlock) > 1;
const bool has_double_tail_k_block_loop = (K / KPerBlock) % 2 == 0;
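// Example grid sizing (assumed sizes, not prescribed here): M = N = 1024 with
// MPerBlock = NPerBlock = 128 yields (1024 / 128) * (1024 / 128) = 64 workgroups,
// one per MPerBlock x NPerBlock tile of C.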
#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
float ave_time = 0;
if(has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel = run_gridwise_operation<gridwise_gemm,
remove_reference_t<AGlobalDesc>,
const FloatAB*,
remove_reference_t<BGlobalDesc>,
const FloatAB*,
remove_reference_t<CGlobalDesc>,
FloatC*,
remove_reference_t<CBlockClusterDesc>,
integral_constant<bool, true>,
integral_constant<bool, true>>;
ave_time = launch_and_time_kernel(kernel,
nrepeat,
dim3(GridSize),
dim3(BlockSize),
0,
0,
a_k_m_global_desc,
p_a_global,
b_k_n_global_desc,
p_b_global,
c_m0_m1_n0_n1_global_desc,
p_c_global,
c_block_cluster_desc,
integral_constant<bool, true>{},
integral_constant<bool, true>{});
}
else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
{
const auto kernel = run_gridwise_operation<gridwise_gemm,
remove_reference_t<AGlobalDesc>,
const FloatAB*,
remove_reference_t<BGlobalDesc>,
const FloatAB*,
remove_reference_t<CGlobalDesc>,
FloatC*,
remove_reference_t<CBlockClusterDesc>,
integral_constant<bool, true>,
integral_constant<bool, false>>;
ave_time = launch_and_time_kernel(kernel,
nrepeat,
dim3(GridSize),
dim3(BlockSize),
0,
0,
a_k_m_global_desc,
p_a_global,
b_k_n_global_desc,
p_b_global,
c_m0_m1_n0_n1_global_desc,
p_c_global,
c_block_cluster_desc,
integral_constant<bool, true>{},
integral_constant<bool, false>{});
}
else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel = run_gridwise_operation<gridwise_gemm,
remove_reference_t<AGlobalDesc>,
const FloatAB*,
remove_reference_t<BGlobalDesc>,
const FloatAB*,
remove_reference_t<CGlobalDesc>,
FloatC*,
remove_reference_t<CBlockClusterDesc>,
integral_constant<bool, false>,
integral_constant<bool, true>>;
ave_time = launch_and_time_kernel(kernel,
nrepeat,
dim3(GridSize),
dim3(BlockSize),
0,
0,
a_k_m_global_desc,
p_a_global,
b_k_n_global_desc,
p_b_global,
c_m0_m1_n0_n1_global_desc,
p_c_global,
c_block_cluster_desc,
integral_constant<bool, false>{},
integral_constant<bool, true>{});
}
else
{
const auto kernel = run_gridwise_operation<gridwise_gemm,
remove_reference_t<AGlobalDesc>,
const FloatAB*,
remove_reference_t<BGlobalDesc>,
const FloatAB*,
remove_reference_t<CGlobalDesc>,
FloatC*,
remove_reference_t<CBlockClusterDesc>,
integral_constant<bool, false>,
integral_constant<bool, false>>;
ave_time = launch_and_time_kernel(kernel,
nrepeat,
dim3(GridSize),
dim3(BlockSize),
0,
0,
a_k_m_global_desc,
p_a_global,
b_k_n_global_desc,
p_b_global,
c_m0_m1_n0_n1_global_desc,
p_c_global,
c_block_cluster_desc,
integral_constant<bool, false>{},
integral_constant<bool, false>{});
}
return ave_time;
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
DeviceMem a_k_m_global_desc_device_buf(sizeof(AGlobalDesc));
DeviceMem b_k_n_global_desc_device_buf(sizeof(BGlobalDesc));
DeviceMem c_m0_m1_n0_n1_global_desc_device_buf(sizeof(CGlobalDesc));
DeviceMem c_block_cluster_desc_device_buf(sizeof(c_block_cluster_desc));
a_k_m_global_desc_device_buf.ToDevice(&a_k_m_global_desc);
b_k_n_global_desc_device_buf.ToDevice(&b_k_n_global_desc);
c_m0_m1_n0_n1_global_desc_device_buf.ToDevice(&c_m0_m1_n0_n1_global_desc);
c_block_cluster_desc_device_buf.ToDevice(&c_block_cluster_desc);
float ave_time = 0;
if(has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel = run_gridwise_operation<gridwise_gemm,
remove_reference_t<AGlobalDesc>,
const FloatAB*,
remove_reference_t<BGlobalDesc>,
const FloatAB*,
remove_reference_t<CGlobalDesc>,
FloatC*,
remove_reference_t<CBlockClusterDesc>,
true,
true>;
ave_time = launch_and_time_kernel(
kernel,
nrepeat,
dim3(GridSize),
dim3(BlockSize),
0,
0,
(void __CONSTANT__*)a_k_m_global_desc_device_buf.GetDeviceBuffer(),
p_a_global,
(void __CONSTANT__*)b_k_n_global_desc_device_buf.GetDeviceBuffer(),
p_b_global,
(void __CONSTANT__*)c_m0_m1_n0_n1_global_desc_device_buf.GetDeviceBuffer(),
p_c_global,
(void __CONSTANT__*)c_block_cluster_desc_device_buf.GetDeviceBuffer());
}
else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
{
const auto kernel = kernel_dynamic_gemm_v1<gridwise_gemm,
remove_reference_t<AGlobalDesc>,
FloatAB,
remove_reference_t<BGlobalDesc>,
FloatAB,
remove_reference_t<CGlobalDesc>,
FloatC,
remove_reference_t<CBlockClusterDesc>,
true,
false>;
ave_time = launch_and_time_kernel(
kernel,
nrepeat,
dim3(GridSize),
dim3(BlockSize),
0,
0,
(void __CONSTANT__*)a_k_m_global_desc_device_buf.GetDeviceBuffer(),
p_a_global,
(void __CONSTANT__*)b_k_n_global_desc_device_buf.GetDeviceBuffer(),
p_b_global,
(void __CONSTANT__*)c_m0_m1_n0_n1_global_desc_device_buf.GetDeviceBuffer(),
p_c_global,
(void __CONSTANT__*)c_block_cluster_desc_device_buf.GetDeviceBuffer());
}
else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel = kernel_dynamic_gemm_v1<gridwise_gemm,
remove_reference_t<AGlobalDesc>,
FloatAB,
remove_reference_t<BGlobalDesc>,
FloatAB,
remove_reference_t<CGlobalDesc>,
FloatC,
remove_reference_t<CBlockClusterDesc>,
false,
true>;
ave_time = launch_and_time_kernel(
kernel,
nrepeat,
dim3(GridSize),
dim3(BlockSize),
0,
0,
(void __CONSTANT__*)a_k_m_global_desc_device_buf.GetDeviceBuffer(),
p_a_global,
(void __CONSTANT__*)b_k_n_global_desc_device_buf.GetDeviceBuffer(),
p_b_global,
(void __CONSTANT__*)c_m0_m1_n0_n1_global_desc_device_buf.GetDeviceBuffer(),
p_c_global,
(void __CONSTANT__*)c_block_cluster_desc_device_buf.GetDeviceBuffer());
}
else
{
const auto kernel = kernel_dynamic_gemm_v1<gridwise_gemm,
remove_reference_t<AGlobalDesc>,
FloatAB,
remove_reference_t<BGlobalDesc>,
FloatAB,
remove_reference_t<CGlobalDesc>,
FloatC,
remove_reference_t<CBlockClusterDesc>,
false,
false>;
ave_time = launch_and_time_kernel(
kernel,
nrepeat,
dim3(GridSize),
dim3(BlockSize),
0,
0,
(void __CONSTANT__*)a_k_m_global_desc_device_buf.GetDeviceBuffer(),
p_a_global,
(void __CONSTANT__*)b_k_n_global_desc_device_buf.GetDeviceBuffer(),
p_b_global,
(void __CONSTANT__*)c_m0_m1_n0_n1_global_desc_device_buf.GetDeviceBuffer(),
p_c_global,
(void __CONSTANT__*)c_block_cluster_desc_device_buf.GetDeviceBuffer());
}
return ave_time;
#endif
}
} // namespace ck
#endif
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
// TODO remove dependency on deprecated tensor descriptor // TODO remove dependency on deprecated tensor descriptor
#include "tensor_descriptor.hpp" #include "tensor_descriptor.hpp"
#include "tensor_adaptor.hpp"
namespace ck { namespace ck {
...@@ -44,5 +45,30 @@ __host__ __device__ constexpr auto make_cluster_descriptor( ...@@ -44,5 +45,30 @@ __host__ __device__ constexpr auto make_cluster_descriptor(
return ClusterDescriptor<Lengths, decltype(order)>{}; return ClusterDescriptor<Lengths, decltype(order)>{};
} }
#if 1
template <typename Lengths,
typename ArrangeOrder = typename arithmetic_sequence_gen<0, Lengths::Size(), 1>::type>
__host__ __device__ constexpr auto make_cluster_descriptor_v2(
const Lengths& lengths,
ArrangeOrder order = typename arithmetic_sequence_gen<0, Lengths::Size(), 1>::type{})
{
constexpr index_t ndim_low = Lengths::Size();
const auto reordered_lengths = container_reorder_given_new2old(lengths, order);
const auto low_lengths = generate_tuple(
[&](auto idim_low) { return reordered_lengths[idim_low]; }, Number<ndim_low>{});
const auto transform = make_merge_transform(low_lengths);
constexpr auto low_dim_old_top_ids = ArrangeOrder{};
constexpr auto up_dim_new_top_ids = Sequence<0>{};
return make_single_stage_tensor_adaptor(
make_tuple(transform), make_tuple(low_dim_old_top_ids), make_tuple(up_dim_new_top_ids));
}
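// Sketch of the intended use (mirrors the convolution drivers above; assumes the
// returned adaptor exposes CalculateBottomIndex() as the gridwise GEMM expects):
//
//   const auto block_cluster_desc = make_cluster_descriptor_v2(
//       make_tuple(GemmM / Number<GemmMPerBlock>{}, GemmN / Number<GemmNPerBlock>{}));
//   // inside the kernel: recover this block's (m, n) tile from its 1D block id
//   const auto block_work_idx =
//       block_cluster_desc.CalculateBottomIndex(make_multi_index(get_block_1d_id()));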
#endif
} // namespace ck } // namespace ck
#endif #endif
...@@ -1282,7 +1282,7 @@ struct DynamicFreeze ...@@ -1282,7 +1282,7 @@ struct DynamicFreeze
__host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low, __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low,
const UpIdx& idx_up) const const UpIdx& idx_up) const
{ {
static_assert(LowIdx::Size() == 1 && UpIdx::Size() == 1, static_assert(LowIdx::Size() == 1 && UpIdx::Size() == 0,
"wrong! inconsistent # of dimension"); "wrong! inconsistent # of dimension");
idx_low = low_idx_; idx_low = low_idx_;
...@@ -1299,7 +1299,7 @@ struct DynamicFreeze ...@@ -1299,7 +1299,7 @@ struct DynamicFreeze
const UpIdx& idx_up_new, const UpIdx& idx_up_new,
Number<Hack>) Number<Hack>)
{ {
idx_diff_low(Number<0>{}) = index_t{Number<0>{}}; idx_diff_low(Number<0>{}) = 0;
} }
__host__ __device__ static constexpr bool IsLinearTransform() { return true; } __host__ __device__ static constexpr bool IsLinearTransform() { return true; }
...@@ -1328,5 +1328,90 @@ struct DynamicFreeze ...@@ -1328,5 +1328,90 @@ struct DynamicFreeze
} }
}; };
template <typename VectorSize, typename UpLength>
struct DynamicVectorize
{
using LowerIndex = MultiIndex<1>;
using UpperIndex = MultiIndex<1>;
using UpLengths = decltype(make_tuple(UpLength{}));
UpLengths up_lengths_;
VectorSize vector_size_;
__host__ __device__ constexpr DynamicVectorize() = default;
__host__ __device__ constexpr DynamicVectorize(const VectorSize& vector_size,
const UpLength& up_length)
: vector_size_{vector_size}, up_lengths_{make_tuple(up_length)}
{
}
__host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 1; }
__host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; }
__host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; }
template <typename LowIdx, typename UpIdx>
__host__ __device__ void CalculateLowerIndex(LowIdx& idx_low, const UpIdx& idx_up) const
{
static_assert(LowIdx::Size() == 1 && UpIdx::Size() == 1,
"wrong! inconsistent # of dimension");
idx_low(Number<0>{}) = vector_size_ * idx_up[Number<0>{}];
}
template <typename LowIdxDiff,
typename UpIdxDiff,
typename LowIdx,
typename UpIdx,
index_t Hack>
__host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
const UpIdxDiff& idx_diff_up,
LowIdx& idx_low,
const UpIdx& idx_up_new,
Number<Hack>) const
{
static_assert(LowIdxDiff::Size() == 1 && UpIdxDiff::Size() == 1 && LowIdx::Size() == 1 &&
UpIdx::Size() == 1,
"wrong! inconsistent # of dimension");
constexpr auto I0 = Number<0>{};
idx_diff_low(I0) = vector_size_ * idx_diff_up[I0];
idx_low += idx_diff_low;
}
__host__ __device__ static constexpr bool IsLinearTransform() { return true; }
__host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
{
return true;
}
template <typename UpIdx>
__host__ __device__ static constexpr bool
IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& /* idx_up */)
{
return true;
}
__host__ __device__ static constexpr bool IsKnownAtCompileTime()
{
return is_known_at_compile_time<UpLengths>::value;
}
__host__ __device__ void Print() const
{
printf("{");
printf("DynamicVectorize, ");
printf("up_lengths_");
print_multi_index(up_lengths_);
printf("}");
}
};
} // namespace ck } // namespace ck
#endif #endif
...@@ -74,5 +74,12 @@ __host__ __device__ constexpr auto make_freeze_transform(const LowerIndex& low_i ...@@ -74,5 +74,12 @@ __host__ __device__ constexpr auto make_freeze_transform(const LowerIndex& low_i
return DynamicFreeze<LowerIndex>{low_idx}; return DynamicFreeze<LowerIndex>{low_idx};
} }
template <typename VectorSize, typename UpLength>
__host__ __device__ constexpr auto make_vectorize_transform(const VectorSize& vector_size,
const UpLength& up_length)
{
return DynamicVectorize<VectorSize, UpLength>{vector_size, up_length};
}
} // namespace ck } // namespace ck
#endif #endif
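A small illustrative sketch (assumption: plain integers in place of Number<>) of the mapping that DynamicVectorize and make_vectorize_transform encode: an upper index over vectors maps to a lower index over scalars by scaling with the vector size, and index differences scale the same way in UpdateLowerIndex.

#include <cassert>

// lower (scalar) index of the first element of vector idx_up
int vectorize_lower_index(int vector_size, int idx_up) { return vector_size * idx_up; }

int main()
{
    // a dimension of 16 scalars viewed as 4 vectors of size 4
    const int vector_size = 4;
    assert(vectorize_lower_index(vector_size, 3) == 12); // vector 3 starts at scalar 12
    assert(vectorize_lower_index(vector_size, 1) == 4);  // an upper step of 1 is a lower step of 4
    return 0;
}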
...@@ -12,25 +12,6 @@ struct DynamicTensorCoordinate; ...@@ -12,25 +12,6 @@ struct DynamicTensorCoordinate;
template <index_t NTransform, index_t NDimVisible, typename UpdateLowerIndexHack> template <index_t NTransform, index_t NDimVisible, typename UpdateLowerIndexHack>
struct DynamicTensorCoordinateIterator; struct DynamicTensorCoordinateIterator;
template <typename LowerDimensionIdss, typename UpperDimensionIdss>
__host__ __device__ constexpr index_t GetNumOfHiddenDimension(LowerDimensionIdss,
UpperDimensionIdss)
{
constexpr auto all_low_dim_ids =
unpack([](auto&&... xs) constexpr { return merge_sequences(xs...); }, LowerDimensionIdss{});
constexpr auto all_up_dim_ids =
unpack([](auto&&... xs) constexpr { return merge_sequences(xs...); }, UpperDimensionIdss{});
constexpr auto all_dim_ids = merge_sequences(all_low_dim_ids, all_up_dim_ids);
using unique_sort_all_dim_ids = typename sequence_unique_sort<decltype(all_dim_ids),
math::less<index_t>,
math::equal<index_t>>::type;
return unique_sort_all_dim_ids::Size();
}
// Transforms: Tuple<transforms...> // Transforms: Tuple<transforms...>
// LowerDimensionIdss : Tuple<Sequence<...>, ...> // LowerDimensionIdss : Tuple<Sequence<...>, ...>
// UpperDimensionIdss : Tuple<Sequence<...>, ...> // UpperDimensionIdss : Tuple<Sequence<...>, ...>
...@@ -374,13 +355,13 @@ transform_dynamic_tensor_descriptor(const OldTensorDescriptor& old_tensor_desc, ...@@ -374,13 +355,13 @@ transform_dynamic_tensor_descriptor(const OldTensorDescriptor& old_tensor_desc,
unordered_new_visible_dim_hidden_ids.ReorderGivenOld2New(new_visible_dim_unordered2ordered); unordered_new_visible_dim_hidden_ids.ReorderGivenOld2New(new_visible_dim_unordered2ordered);
// put everything together // put everything together
const auto all_transforms = container_cat(old_tensor_desc.GetTransforms(), new_transforms); const auto all_transforms = container_concat(old_tensor_desc.GetTransforms(), new_transforms);
constexpr auto all_low_dim_hidden_idss = constexpr auto all_low_dim_hidden_idss =
container_cat(OldTensorDescriptor::GetLowerDimensionIdss(), low_dim_hidden_idss); container_concat(OldTensorDescriptor::GetLowerDimensionIdss(), low_dim_hidden_idss);
constexpr auto all_up_dim_hidden_idss = constexpr auto all_up_dim_hidden_idss =
container_cat(OldTensorDescriptor::GetUpperDimensionIdss(), up_dim_hidden_idss); container_concat(OldTensorDescriptor::GetUpperDimensionIdss(), up_dim_hidden_idss);
const auto element_space_size = old_tensor_desc.GetElementSpaceSize(); const auto element_space_size = old_tensor_desc.GetElementSpaceSize();
......
#ifndef CK_TENSOR_ADAPTOR_HPP
#define CK_TENSOR_ADAPTOR_HPP
#include "common_header.hpp"
#include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp"
namespace ck {
// Transforms: Tuple<transforms...>
// LowerDimensionHiddenIdss : Tuple<Sequence<...>, ...>
// UpperDimensionHiddenIdss : Tuple<Sequence<...>, ...>
// BottomDimensionHiddenIds : Sequence<...>
// TopDimensionHiddenIds : Sequence<...>
template <typename Transforms,
typename LowerDimensionHiddenIdss,
typename UpperDimensionHiddenIdss,
typename BottomDimensionHiddenIds,
typename TopDimensionHiddenIds>
struct TensorAdaptor
{
__host__ __device__ static constexpr index_t GetNumOfTransform() { return Transforms::Size(); }
__host__ __device__ constexpr const auto& GetTransforms() const { return transforms_; }
__host__ __device__ static constexpr auto GetLowerDimensionHiddenIdss()
{
return LowerDimensionHiddenIdss{};
}
__host__ __device__ static constexpr auto GetUpperDimensionHiddenIdss()
{
return UpperDimensionHiddenIdss{};
}
__host__ __device__ static constexpr auto GetTopDimensionHiddenIds()
{
return TopDimensionHiddenIds{};
}
__host__ __device__ static constexpr auto GetBottomDimensionHiddenIds()
{
return BottomDimensionHiddenIds{};
}
__host__ __device__ static constexpr auto InitializeElementSize(const Transforms& transforms)
{
const auto lengths = generate_tuple(
[&](auto idim_top) {
constexpr auto tmp = GetTransformAndItsUpperDimension(idim_top);
constexpr index_t itran = tmp[Number<0>{}];
constexpr index_t idim_up = tmp[Number<1>{}];
constexpr bool found = tmp[Number<2>{}];
static_assert(found == true,
"wrong! not found matching transformation and upper-dimension");
const auto length =
transforms[Number<itran>{}].GetUpperLengths()[Number<idim_up>{}];
return length;
},
Number<ndim_top_>{});
// TODO: make container_reduce support tuple of Number and index_t
return container_reduce(lengths, math::multiplies_v2{}, Number<1>{});
}
template <index_t IDim>
__host__ __device__ static constexpr auto GetTransformAndItsUpperDimension(Number<IDim>)
{
constexpr auto idim_top = Number<IDim>{};
constexpr index_t idim_hidden = TopDimensionHiddenIds::At(idim_top);
index_t itran_found = 0;
index_t idim_up_found = 0;
bool found = false;
static_for<0, ntransform_, 1>{}([&](auto itran) {
constexpr auto up_dim_ids = UpperDimensionHiddenIdss{}[itran];
static_for<0, up_dim_ids.Size(), 1>{}([&](auto idim_up) {
if constexpr(up_dim_ids[idim_up] == idim_hidden)
{
itran_found = itran;
idim_up_found = idim_up;
found = true;
}
});
});
return make_tuple(itran_found, idim_up_found, found);
}
__host__ __device__ static constexpr index_t GetNumOfBottomDimension()
{
return BottomDimensionHiddenIds::Size();
}
__host__ __device__ static constexpr index_t GetNumOfTopDimension()
{
return TopDimensionHiddenIds::Size();
}
__host__ __device__ static constexpr index_t GetNumOfHiddenDimension()
{
constexpr auto all_low_dim_ids =
unpack([](auto&&... xs) constexpr { return merge_sequences(xs...); },
LowerDimensionHiddenIdss{});
constexpr auto all_up_dim_ids =
unpack([](auto&&... xs) constexpr { return merge_sequences(xs...); },
UpperDimensionHiddenIdss{});
constexpr auto all_dim_ids = merge_sequences(all_low_dim_ids, all_up_dim_ids);
using unique_sort_all_dim_ids = typename sequence_unique_sort<decltype(all_dim_ids),
math::less<index_t>,
math::equal<index_t>>::type;
return unique_sort_all_dim_ids::Size();
}
constexpr static index_t ntransform_ = GetNumOfTransform();
constexpr static index_t ndim_hidden_ = GetNumOfHiddenDimension();
constexpr static index_t ndim_bottom_ = GetNumOfBottomDimension();
constexpr static index_t ndim_top_ = GetNumOfTopDimension();
using HiddenIndex = MultiIndex<ndim_hidden_>;
using BottomIndex = MultiIndex<ndim_bottom_>;
using TopIndex = MultiIndex<ndim_top_>;
// may be index_t or Number<>
using ElementSize = remove_cv_t<decltype(InitializeElementSize(Transforms{}))>;
public:
__host__ __device__ constexpr TensorAdaptor() = default;
__host__ __device__ constexpr TensorAdaptor(const Transforms& transforms)
: transforms_{transforms}, element_size_{InitializeElementSize(transforms)}
{
static_assert(Transforms::Size() == ntransform_ &&
LowerDimensionHiddenIdss::Size() == ntransform_ &&
UpperDimensionHiddenIdss::Size() == ntransform_,
"wrong! inconsistent # of transformations");
// TODO check dependency of dimensions is valid
}
__host__ __device__ constexpr auto GetElementSize() const { return element_size_; }
template <typename TopIdx>
__host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
{
static_assert(TopIdx::Size() == TopDimensionHiddenIds::Size(),
"wrong! # of dimension inconsistent");
constexpr index_t ntransform = GetNumOfTransform();
constexpr index_t ndim_hidden = GetNumOfHiddenDimension();
MultiIndex<ndim_hidden> idx_hidden;
// initialize the top-dimension (upper-most) index
set_container_subset(idx_hidden, GetTopDimensionHiddenIds(), idx_top);
// calculate hidden index
static_for<ntransform, 0, -1>{}([&](auto itran_p1) {
auto itran = itran_p1 - Number<1>{};
const auto& tran = GetTransforms().At(itran);
constexpr auto dims_low = GetLowerDimensionHiddenIdss().At(itran);
constexpr auto dims_up = GetUpperDimensionHiddenIdss().At(itran);
const auto idx_up = get_container_subset(idx_hidden, dims_up);
MultiIndex<dims_low.Size()> idx_low;
tran.CalculateLowerIndex(idx_low, idx_up);
set_container_subset(idx_hidden, dims_low, idx_low);
});
return get_container_subset(idx_hidden, BottomDimensionHiddenIds{});
}
__host__ __device__ void Print() const
{
printf("{");
printf("TensorAdaptor, ");
static_for<0, ntransform_, 1>{}([&](auto i) {
printf("transforms: ");
transforms_[i].Print();
printf("LowerDimensionHiddenIds:");
LowerDimensionHiddenIdss{}.At(i).Print();
printf("UpperDimensionHiddenIds:");
UpperDimensionHiddenIdss{}.At(i).Print();
});
printf("BottomDimensionHiddenIds:");
BottomDimensionHiddenIds::Print();
printf("TopDimensionHiddenIds:");
TopDimensionHiddenIds::Print();
printf("}");
}
private:
Transforms transforms_;
ElementSize element_size_;
};
template <typename TensorAdaptor0, typename TensorAdaptor1>
__host__ __device__ constexpr auto chain_tensor_adaptors(const TensorAdaptor0& adaptor0,
const TensorAdaptor1& adaptor1)
{
static_assert(TensorAdaptor0::GetNumOfTopDimension() ==
TensorAdaptor1::GetNumOfBottomDimension(),
"wrong!");
// all_transforms = transform0 + transform1
const auto all_transforms =
container_concat(adaptor0.GetTransforms(), adaptor1.GetTransforms());
// shift
constexpr index_t adaptor0_max_hidden_id = [&]() {
index_t adaptor0_max_hidden_id = NumericLimits<index_t>::Min();
static_for<0, TensorAdaptor0::GetNumOfTransform(), 1>{}([&](auto itran) {
constexpr index_t ndim_low =
TensorAdaptor0{}.GetTransforms()[itran].GetNumOfLowerDimension();
static_for<0, ndim_low, 1>{}([&](auto idim_low) {
adaptor0_max_hidden_id =
math::max(adaptor0_max_hidden_id,
TensorAdaptor0::GetLowerDimensionHiddenIdss()[itran][idim_low].value);
});
constexpr index_t ndim_up =
TensorAdaptor0{}.GetTransforms()[itran].GetNumOfUpperDimension();
static_for<0, ndim_up, 1>{}([&](auto idim_up) {
adaptor0_max_hidden_id =
math::max(adaptor0_max_hidden_id,
TensorAdaptor0::GetUpperDimensionHiddenIdss()[itran][idim_up].value);
});
});
return adaptor0_max_hidden_id;
}();
constexpr index_t adaptor1_min_hidden_id = [&]() {
index_t adaptor1_min_hidden_id = NumericLimits<index_t>::Max();
static_for<0, TensorAdaptor1::GetNumOfTransform(), 1>{}([&](auto itran) {
constexpr index_t ndim_low =
TensorAdaptor1{}.GetTransforms()[itran].GetNumOfLowerDimension();
// get the min of all lower dimensions, excluding the bottom dimensions (because their ids will
// be matched with the top ids from adaptor0)
static_for<0, ndim_low, 1>{}([&](auto idim_low) {
constexpr index_t low_dim_hidden_id =
TensorAdaptor1::GetLowerDimensionHiddenIdss()[itran][idim_low].value;
bool is_bottom_dim = false;
static_for<0, TensorAdaptor1::GetNumOfBottomDimension(), 1>{}([&](auto i) {
if constexpr(low_dim_hidden_id ==
TensorAdaptor1::GetBottomDimensionHiddenIds()[i])
{
is_bottom_dim = true;
}
});
if(!is_bottom_dim)
{
adaptor1_min_hidden_id = math::min(adaptor1_min_hidden_id, low_dim_hidden_id);
}
});
constexpr index_t ndim_up =
TensorAdaptor1{}.GetTransforms()[itran].GetNumOfUpperDimension();
// get the min of all upper dimensions
static_for<0, ndim_up, 1>{}([&](auto idim_up) {
adaptor1_min_hidden_id =
math::min(adaptor1_min_hidden_id,
TensorAdaptor1::GetUpperDimensionHiddenIdss()[itran][idim_up].value);
});
});
return adaptor1_min_hidden_id;
}();
constexpr index_t adaptor1_hidden_id_shift =
adaptor0_max_hidden_id + 1 - adaptor1_min_hidden_id;
constexpr index_t ndim_bottom_1 = TensorAdaptor1::GetNumOfBottomDimension();
// all_low_dim_hidden_idss =
// low_dim_hidden_idss_0 + match_hidden_id_for_1(shift_hidden_id_for_1(low_dim_hidden_idss_1))
constexpr auto low_dim_hidden_idss_1 = generate_tuple(
// generate sequence of ids for a transform
[&](auto itran) {
constexpr auto ndim_low_1 = TensorAdaptor1::GetLowerDimensionHiddenIdss()[itran].Size();
constexpr auto low_dim_hidden_ids_1 =
TensorAdaptor1::GetLowerDimensionHiddenIdss()[itran];
// sequence in, sequence out
constexpr auto low_dim_hidden_ids_1_mod = [&]() constexpr
{
auto low_dim_hidden_ids_1_mod = to_multi_index(low_dim_hidden_ids_1);
// shift hidden id so every dim id is unique
static_for<0, ndim_low_1, 1>{}([&](auto idim_low_1) {
low_dim_hidden_ids_1_mod(idim_low_1) += adaptor1_hidden_id_shift;
});
// match hidden id
static_for<0, ndim_low_1, 1>{}([&](auto idim_low_1) {
static_for<0, ndim_bottom_1, 1>{}([&](auto idim_bottom_1) {
// if this low dim is bottom dim, then do id matching
if constexpr(low_dim_hidden_ids_1[idim_low_1] ==
TensorAdaptor1::GetBottomDimensionHiddenIds()[idim_bottom_1])
{
low_dim_hidden_ids_1_mod(idim_low_1) =
TensorAdaptor0::GetTopDimensionHiddenIds()[idim_bottom_1];
}
});
});
return low_dim_hidden_ids_1_mod;
}
();
return generate_sequence_v2(
[&](auto i) constexpr { return Number<low_dim_hidden_ids_1_mod[i]>{}; },
Number<ndim_low_1>{});
},
Number<TensorAdaptor1::GetNumOfTransform()>{});
constexpr auto all_low_dim_hidden_idss =
container_concat(TensorAdaptor0::GetLowerDimensionHiddenIdss(), low_dim_hidden_idss_1);
// all_up_dim_hidden_idss =
// up_dim_hidden_idss_0 + shift_hidden_id_for_1(up_dim_hidden_idss_1)
constexpr auto up_dim_hidden_idss_1 = generate_tuple(
// generate sequence of ids for a transform
[&](auto itran) {
constexpr auto ndim_up_1 = TensorAdaptor1::GetUpperDimensionHiddenIdss()[itran].Size();
constexpr auto up_dim_hidden_ids_1 =
TensorAdaptor1::GetUpperDimensionHiddenIdss()[itran];
// sequence in, constexpr tuple out
constexpr auto up_dim_hidden_ids_1_mod = [&]() constexpr
{
auto up_dim_hidden_ids_1_mod = to_multi_index(up_dim_hidden_ids_1);
// shift hidden id
static_for<0, ndim_up_1, 1>{}([&](auto idim_up_1) {
up_dim_hidden_ids_1_mod(idim_up_1) += adaptor1_hidden_id_shift;
});
return up_dim_hidden_ids_1_mod;
}
();
// constexpr tuple to sequence
return generate_sequence_v2(
[&](auto i) constexpr { return Number<up_dim_hidden_ids_1_mod[i]>{}; },
Number<ndim_up_1>{});
},
Number<TensorAdaptor1::GetNumOfTransform()>{});
constexpr auto all_up_dim_hidden_idss =
container_concat(TensorAdaptor0::GetUpperDimensionHiddenIdss(), up_dim_hidden_idss_1);
// bottom_dim_hidden_ids = bottom_dim_hidden_ids_0
constexpr auto bottom_dim_hidden_ids = TensorAdaptor0::GetBottomDimensionHiddenIds();
// top_dim_hidden_ids = shift_hidden_id(top_dim_hidden_ids_1)
constexpr auto top_dim_hidden_ids =
TensorAdaptor1::GetTopDimensionHiddenIds() + Number<adaptor1_hidden_id_shift>{};
// put everything together
return TensorAdaptor<remove_cv_t<decltype(all_transforms)>,
remove_cv_t<decltype(all_low_dim_hidden_idss)>,
remove_cv_t<decltype(all_up_dim_hidden_idss)>,
remove_cv_t<decltype(bottom_dim_hidden_ids)>,
remove_cv_t<decltype(top_dim_hidden_ids)>>{all_transforms};
}
// Transforms: Tuple<transforms...>
// LowerDimensionOldTopIdss: Tuple<Sequence<...>, ...>
// UpperDimensionNewTopIdss: Tuple<Sequence<...>, ...>
template <typename Transforms, typename LowerDimensionOldTopIdss, typename UpperDimensionNewTopIdss>
__host__ __device__ constexpr auto make_single_stage_tensor_adaptor(const Transforms& transforms,
LowerDimensionOldTopIdss,
UpperDimensionNewTopIdss)
{
constexpr index_t ntransform = Transforms::Size();
static_assert(LowerDimensionOldTopIdss::Size() == ntransform &&
UpperDimensionNewTopIdss::Size() == ntransform,
"wrong!");
// sanity check on LowerDimensionOldTopIdss and UpperDimensionNewTopIdss
constexpr auto all_low_dim_old_top_ids =
unpack([](auto&&... xs) constexpr { return merge_sequences(xs...); },
LowerDimensionOldTopIdss{});
constexpr auto all_up_dim_new_top_ids =
unpack([](auto&&... xs) constexpr { return merge_sequences(xs...); },
UpperDimensionNewTopIdss{});
static_assert(is_valid_sequence_map<decltype(all_low_dim_old_top_ids)>::value &&
is_valid_sequence_map<decltype(all_up_dim_new_top_ids)>::value,
"wrong!");
constexpr index_t ndim_old_top = all_low_dim_old_top_ids.Size();
constexpr index_t ndim_new_top = all_up_dim_new_top_ids.Size();
// low_dim_hidden_idss
constexpr auto low_dim_hidden_idss = LowerDimensionOldTopIdss{};
// up_dim_hidden_idss: shift UpperDimensionNewTopIdss by ndim_bottom
constexpr auto up_dim_hidden_idss = generate_tuple(
[](auto itran) { return UpperDimensionNewTopIdss{}[itran] + Number<ndim_old_top>{}; },
Number<ntransform>{});
// bottom_dim_hidden_ids
constexpr auto bottom_dim_hidden_ids =
typename arithmetic_sequence_gen<0, ndim_old_top, 1>::type{};
// top_dim_hidden_ids
constexpr auto top_dim_hidden_ids =
typename arithmetic_sequence_gen<0, ndim_new_top, 1>::type{} + Number<ndim_old_top>{};
return TensorAdaptor<remove_cv_t<Transforms>,
remove_cv_t<decltype(low_dim_hidden_idss)>,
remove_cv_t<decltype(up_dim_hidden_idss)>,
remove_cv_t<decltype(bottom_dim_hidden_ids)>,
remove_cv_t<decltype(top_dim_hidden_ids)>>{transforms};
}
template <typename X,
typename... Xs,
typename std::enable_if<sizeof...(Xs) >= 2, bool>::type = false>
__host__ __device__ constexpr auto chain_tensor_adaptors(const X& x, const Xs&... xs)
{
return chain_tensor_adaptors(x, chain_tensor_adaptors(xs...));
}
} // namespace ck
#endif
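An illustrative sketch of what chain_tensor_adaptors composes (plain functions stand in for the CK transforms; the lengths are made up): the chained adaptor's CalculateBottomIndex maps a top index through adaptor1 first, then feeds adaptor1's bottom index into adaptor0 as its top index.

#include <array>
#include <cassert>

// "adaptor1": merge a 2-d top index {i, j} with lengths {2, 3} into a 1-d bottom index
int adaptor1_bottom(std::array<int, 2> top) { return top[0] * 3 + top[1]; }

// "adaptor0": unmerge a 1-d top index into a 2-d bottom index with lengths {3, 2}
std::array<int, 2> adaptor0_bottom(int top) { return {top / 2, top % 2}; }

int main()
{
    // chained adaptor: adaptor1 top -> shared hidden index -> adaptor0 bottom
    const auto bottom = adaptor0_bottom(adaptor1_bottom({1, 2})); // hidden index 5
    assert(bottom[0] == 2 && bottom[1] == 1);
    return 0;
}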
...@@ -67,26 +67,18 @@ struct BlockwiseDynamicTensorSliceTransfer_v4 ...@@ -67,26 +67,18 @@ struct BlockwiseDynamicTensorSliceTransfer_v4
if(BlockSize == thread_cluster_desc_.GetElementSize() or if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
{ {
const auto thread_cluster_id = const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
thread_cluster_desc_.CalculateClusterIndex(get_thread_local_1d_id()); make_multi_index(get_thread_local_1d_id()));
const auto thread_data_id_begin = thread_cluster_id * ThreadSliceLengths{}; const auto thread_data_idx_begin = thread_cluster_idx * ThreadSliceLengths{};
threadwise_transfer_.SetSrcSliceOrigin(src_desc, threadwise_transfer_.SetSrcSliceOrigin(src_desc,
src_block_slice_origin + thread_data_id_begin); src_block_slice_origin + thread_data_idx_begin);
threadwise_transfer_.SetDstSliceOrigin(dst_desc, threadwise_transfer_.SetDstSliceOrigin(dst_desc,
dst_block_slice_origin + thread_data_id_begin); dst_block_slice_origin + thread_data_idx_begin);
} }
} }
__device__ static constexpr auto CalculateThreadDataBegin()
{
const auto thread_cluster_id =
thread_cluster_desc_.CalculateClusterIndex(get_thread_local_1d_id());
return thread_cluster_id * ThreadSliceLengths{};
}
template <typename SrcIteratorHacks> template <typename SrcIteratorHacks>
__device__ void RunRead(const SrcDesc& src_desc, __device__ void RunRead(const SrcDesc& src_desc,
const SrcData* p_src, const SrcData* p_src,
...@@ -141,8 +133,9 @@ struct BlockwiseDynamicTensorSliceTransfer_v4 ...@@ -141,8 +133,9 @@ struct BlockwiseDynamicTensorSliceTransfer_v4
} }
} }
private:
static constexpr auto thread_cluster_desc_ = static constexpr auto thread_cluster_desc_ =
make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{}); make_cluster_descriptor_v2(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});
using ThreadwiseTransfer = using ThreadwiseTransfer =
ThreadwiseDynamicTensorSliceTransfer_v3<ThreadSliceLengths, ThreadwiseDynamicTensorSliceTransfer_v3<ThreadSliceLengths,
......
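As a concrete illustration of the new cluster-index path above (an assumption-laden sketch with plain ints; the actual code goes through make_cluster_descriptor_v2 and CalculateBottomIndex), the per-thread slice origin is the thread's cluster index scaled element-wise by ThreadSliceLengths.

#include <array>
#include <cassert>

int main()
{
    const int tid = 13;
    const std::array<int, 2> cluster_lengths{4, 8}; // ThreadClusterLengths
    const std::array<int, 2> slice_lengths{2, 4};   // ThreadSliceLengths

    // default arrange order: row-major decomposition of the flat thread id
    const std::array<int, 2> cluster_idx{tid / cluster_lengths[1], tid % cluster_lengths[1]};

    // thread_data_idx_begin = thread_cluster_idx * ThreadSliceLengths
    const std::array<int, 2> data_idx_begin{cluster_idx[0] * slice_lengths[0],
                                            cluster_idx[1] * slice_lengths[1]};

    assert(data_idx_begin[0] == 2 && data_idx_begin[1] == 20);
    return 0;
}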
...@@ -7,62 +7,175 @@ ...@@ -7,62 +7,175 @@
namespace ck { namespace ck {
// C[M, N] += transpose(A[K, M]) * B[K, N] // C[M0, M1, N0, N1] += transpose(A[K, M0, M1]) * B[K, N0, N1]
// A and B are visible to the whole block, C is distributed among the threads // A and B are visible to the whole block, C is distributed among the threads
// Assume: // Assume:
// 1. A: // 1. A:
// 1. BlockMatrixA is known at compile-time // 1. ABlockDesc is known at compile-time
// 2. ABlockBuffer is DynamicBuffer // 2. ABlockBuffer is DynamicBuffer
// 2. B: // 2. B:
// 1. BlockMatrixB is known at compile-time // 1. BBlockDesc is known at compile-time
// 2. BBlockBuffer is DynamicBuffer // 2. BBlockBuffer is DynamicBuffer
// 3. C: // 3. C:
// 1. ThreadMatrixC is known at compile-time // 1. CThreadDesc is known at compile-time
// 2. CThreadBuffer is StaticBuffer // 2. CThreadBuffer is StaticBuffer
template <index_t BlockSize, template <index_t BlockSize,
typename FloatA, typename FloatA,
typename FloatB, typename FloatB,
typename FloatC, typename FloatC,
typename BlockMatrixA, typename ABlockDesc,
typename BlockMatrixB, typename BBlockDesc,
typename ThreadMatrixC, typename CThreadDesc,
index_t MPerThreadSubC, index_t M1PerThread,
index_t NPerThreadSubC, index_t N1PerThread,
index_t KPerThreadLoop, index_t KPerThread,
index_t MLevel0ThreadCluster, index_t MLevel0ThreadCluster,
index_t NLevel0ThreadCluster, index_t NLevel0ThreadCluster,
index_t MLevel1ThreadCluster, index_t MLevel1ThreadCluster,
index_t NLevel1ThreadCluster, index_t NLevel1ThreadCluster,
index_t ThreadGemmADataPerRead_M, index_t AThreadCopyScalarPerVector_M1,
index_t ThreadGemmBDataPerRead_N, index_t BThreadCopyScalarPerVector_N1,
typename std::enable_if<BlockMatrixA::IsKnownAtCompileTime() && typename std::enable_if<ABlockDesc::IsKnownAtCompileTime() &&
BlockMatrixB::IsKnownAtCompileTime() && BBlockDesc::IsKnownAtCompileTime() &&
ThreadMatrixC::IsKnownAtCompileTime(), CThreadDesc::IsKnownAtCompileTime(),
bool>::type = false> bool>::type = false>
struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1 struct BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v1
{ {
struct MatrixIndex using AIndex = MultiIndex<3>;
using BIndex = MultiIndex<3>;
using CIndex = MultiIndex<4>;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
public:
__device__ BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v1()
: c_thread_origin_data_idx_{CalculateCThreadOriginDataIndex(get_thread_local_1d_id())},
a_thread_copy_{
make_tuple(0, c_thread_origin_data_idx_[I0], c_thread_origin_data_idx_[I1])},
b_thread_copy_{
make_tuple(0, c_thread_origin_data_idx_[I2], c_thread_origin_data_idx_[I3])}
{
static_assert(ABlockDesc::IsKnownAtCompileTime() && BBlockDesc::IsKnownAtCompileTime() &&
CThreadDesc::IsKnownAtCompileTime(),
"wrong! Desc should be known at compile-time");
static_assert(BlockSize == MLevel0ThreadCluster * MLevel1ThreadCluster *
NLevel0ThreadCluster * NLevel1ThreadCluster,
"wrong! blocksize and cluster size not consistent");
static_assert(ABlockDesc{}.GetLength(I0) == BBlockDesc{}.GetLength(I0),
"wrong! K dimension not consistent");
}
__device__ static CIndex CalculateCThreadOriginDataIndex(index_t thread_id)
{
constexpr index_t M0 = ABlockDesc{}.GetLength(I1);
constexpr index_t N0 = BBlockDesc{}.GetLength(I1);
constexpr index_t M1 = ABlockDesc{}.GetLength(I2);
constexpr index_t N1 = BBlockDesc{}.GetLength(I2);
// 4-d data space into 4-d thread space
constexpr auto adaptor0 = make_single_stage_tensor_adaptor(
make_tuple(make_vectorize_transform(M0, 1),
make_vectorize_transform(M1PerThread, M1 / M1PerThread),
make_vectorize_transform(N0, 1),
make_vectorize_transform(N1PerThread, N1 / N1PerThread)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
// thread position in the 4-d thread space
constexpr auto adaptor1 = make_single_stage_tensor_adaptor(
make_tuple(
make_freeze_transform(make_multi_index(0)),
make_unmerge_transform(make_tuple(MLevel1ThreadCluster, MLevel0ThreadCluster)),
make_freeze_transform(make_multi_index(0)),
make_unmerge_transform(make_tuple(NLevel1ThreadCluster, NLevel0ThreadCluster))),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<>{}, Sequence<0, 1>{}, Sequence<>{}, Sequence<2, 3>{}));
// 4-d thread space to 1-d thread space
constexpr auto adaptor2 = make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(MLevel1ThreadCluster,
NLevel1ThreadCluster,
MLevel0ThreadCluster,
NLevel0ThreadCluster))),
make_tuple(Sequence<0, 2, 1, 3>{}),
make_tuple(Sequence<0>{}));
constexpr auto cluster_desc = chain_tensor_adaptors(adaptor0, adaptor1, adaptor2);
return cluster_desc.CalculateBottomIndex(make_multi_index(get_thread_local_1d_id()));
}
template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
__device__ void Run(const ABlockBuffer& a_block_buf,
const BBlockBuffer& b_block_buf,
CThreadBuffer& c_thread_buf) const
{ {
index_t row; auto a_thread_buf = make_static_buffer<FloatA>(a_thread_desc_.GetElementSpaceSize());
index_t col; auto b_thread_buf = make_static_buffer<FloatB>(b_thread_desc_.GetElementSpaceSize());
};
constexpr auto threadwise_gemm =
ThreadwiseGemm_km0m1_kn0n1_m0m1n0n1<FloatA,
FloatB,
FloatC,
decltype(a_thread_desc_),
decltype(b_thread_desc_),
CThreadDesc,
Sequence<KPerThread>,
Sequence<M0_, M1PerThread>,
Sequence<N0_, N1PerThread>>{};
constexpr index_t K = ABlockDesc{}.GetLength(I0);
static_for<0, K, KPerThread>{}([&](auto k) {
a_thread_copy_.Run(ABlockDesc{},
make_tuple(k, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(I0, I0, I0),
a_thread_buf);
b_thread_copy_.Run(BBlockDesc{},
make_tuple(k, I0, I0),
b_block_buf,
b_thread_desc_,
make_tuple(I0, I0, I0),
b_thread_buf);
threadwise_gemm.Run(a_thread_buf,
make_tuple(I0, I0, I0),
b_thread_buf,
make_tuple(I0, I0, I0),
c_thread_buf,
make_tuple(I0, I0, I0, I0));
});
}
private: private:
static constexpr auto a_thread_mtx_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2( static constexpr index_t M0_ = ABlockDesc{}.GetLength(I1);
make_tuple(Number<KPerThreadLoop>{}, ThreadMatrixC{}.GetLength(Number<0>{}))); static constexpr index_t N0_ = BBlockDesc{}.GetLength(I1);
static constexpr auto b_thread_mtx_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2( // A[K, M0, M1]
make_tuple(Number<KPerThreadLoop>{}, ThreadMatrixC{}.GetLength(Number<1>{}))); static constexpr auto a_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2(
make_tuple(Number<KPerThread>{}, Number<M0_>{}, Number<M1PerThread>{}));
// B[K, N0, N1]
static constexpr auto b_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2(
make_tuple(Number<KPerThread>{}, Number<N0_>{}, Number<N1PerThread>{}));
using AThreadCopy = using AThreadCopy =
ThreadwiseDynamicTensorSliceTransfer_v4<FloatA, ThreadwiseDynamicTensorSliceTransfer_v4<FloatA,
FloatA, FloatA,
BlockMatrixA, ABlockDesc,
decltype(a_thread_mtx_desc_), decltype(a_thread_desc_),
Sequence<KPerThreadLoop, MPerThreadSubC>, Sequence<KPerThread, M0_, M1PerThread>,
Sequence<0, 1>, Sequence<0, 1, 2>,
1, 2,
ThreadGemmADataPerRead_M, AThreadCopyScalarPerVector_M1,
AddressSpace::Generic, AddressSpace::Generic,
AddressSpace::Vgpr, AddressSpace::Vgpr,
1>; 1>;
...@@ -70,307 +183,326 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1 ...@@ -70,307 +183,326 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1
using BThreadCopy = using BThreadCopy =
ThreadwiseDynamicTensorSliceTransfer_v4<FloatB, ThreadwiseDynamicTensorSliceTransfer_v4<FloatB,
FloatB, FloatB,
BlockMatrixB, BBlockDesc,
decltype(b_thread_mtx_desc_), decltype(b_thread_desc_),
Sequence<KPerThreadLoop, NPerThreadSubC>, Sequence<KPerThread, N0_, N1PerThread>,
Sequence<0, 1>, Sequence<0, 1, 2>,
1, 2,
ThreadGemmBDataPerRead_N, BThreadCopyScalarPerVector_N1,
AddressSpace::Generic, AddressSpace::Generic,
AddressSpace::Vgpr, AddressSpace::Vgpr,
1>; 1>;
MatrixIndex c_thread_begin_mtx_idx_; CIndex c_thread_origin_data_idx_;
AThreadCopy a_thread_copy_; AThreadCopy a_thread_copy_;
BThreadCopy b_thread_copy_; BThreadCopy b_thread_copy_;
};
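For reference, a standalone sketch of the decomposition CalculateCThreadOriginDataIndex performs (plain integer arithmetic instead of the chained adaptors, equivalent to the raw-index version this commit removes): the flat thread id splits into level-1 and level-0 cluster coordinates, and each thread's C origin sits at m0 = n0 = 0 with m1/n1 offsets scaled by the per-thread sub-tile sizes.

#include <array>
#include <cassert>

std::array<int, 4> c_thread_origin(int tid,
                                   int MLevel0, int NLevel0,
                                   int MLevel1, int NLevel1,
                                   int M1PerThread, int N1PerThread)
{
    assert(tid < MLevel0 * NLevel0 * MLevel1 * NLevel1);

    const int level1_id   = tid / (MLevel0 * NLevel0);
    const int level1_m_id = level1_id / NLevel1;
    const int level1_n_id = level1_id % NLevel1;

    const int level0_id   = tid % (MLevel0 * NLevel0);
    const int level0_m_id = level0_id / NLevel0;
    const int level0_n_id = level0_id % NLevel0;

    // (m0, m1, n0, n1): every thread starts at m0 = n0 = 0 and owns an
    // M1PerThread x N1PerThread sub-tile along the m1 / n1 dimensions
    return {0,
            M1PerThread * (level1_m_id * MLevel0 + level0_m_id),
            0,
            N1PerThread * (level1_n_id * NLevel0 + level0_n_id)};
}

int main()
{
    // e.g. 2x2 level-1 clusters of 4x4 level-0 clusters, 4-wide per-thread sub-tiles
    const auto idx = c_thread_origin(/*tid*/ 21, 4, 4, 2, 2, 4, 4);
    assert(idx[0] == 0 && idx[1] == 4 && idx[2] == 0 && idx[3] == 20);
    return 0;
}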
// C[M0, M1, N0, N1] += transpose(A[K, M0, M1]) * B[K, N0, N1]
// A and B are visible to the whole block, C is distributed among the threads
// Assume:
// 1. A:
// 1. ABlockDesc is known at compile-time
// 2. ABlockBuffer is DynamicBuffer
// 2. B:
// 1. BBlockDesc is known at compile-time
// 2. BBlockBuffer is DynamicBuffer
// 3. C:
// 1. CThreadDesc is known at compile-time
// 2. CThreadBuffer is StaticBuffer
template <index_t BlockSize,
typename FloatA,
typename FloatB,
typename FloatC,
typename ABlockDesc,
typename BBlockDesc,
typename CThreadDesc,
index_t M1PerThread,
index_t N1PerThread,
index_t KPerThread,
index_t MLevel0ThreadCluster,
index_t NLevel0ThreadCluster,
index_t MLevel1ThreadCluster,
index_t NLevel1ThreadCluster,
index_t AThreadCopyScalarPerVector_M1,
index_t BThreadCopyScalarPerVector_N1,
typename std::enable_if<ABlockDesc::IsKnownAtCompileTime() &&
BBlockDesc::IsKnownAtCompileTime() &&
CThreadDesc::IsKnownAtCompileTime(),
bool>::type = false>
struct BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v2_pipeline_2x2
{
using AIndex = MultiIndex<3>;
using BIndex = MultiIndex<3>;
using CIndex = MultiIndex<4>;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
public: public:
__device__ BlockwiseGemm_km_kn_m0m1n0n1_v1r1() __device__ BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v2_pipeline_2x2()
: c_thread_begin_mtx_idx_{GetBeginOfThreadMatrixC(get_thread_local_1d_id())}, : c_thread_origin_data_idx_{CalculateCThreadOriginDataIndex(get_thread_local_1d_id())},
a_thread_copy_{make_tuple(0, c_thread_begin_mtx_idx_.row)}, a_thread_copy_{
b_thread_copy_{make_tuple(0, c_thread_begin_mtx_idx_.col)} make_tuple(0, c_thread_origin_data_idx_[I0], c_thread_origin_data_idx_[I1])},
b_thread_copy_{
make_tuple(0, c_thread_origin_data_idx_[I2], c_thread_origin_data_idx_[I3])}
{ {
static_assert(BlockMatrixA::IsKnownAtCompileTime() && static_assert(ABlockDesc::IsKnownAtCompileTime() && BBlockDesc::IsKnownAtCompileTime() &&
BlockMatrixB::IsKnownAtCompileTime() && CThreadDesc::IsKnownAtCompileTime(),
ThreadMatrixC::IsKnownAtCompileTime(),
"wrong! Desc should be known at compile-time"); "wrong! Desc should be known at compile-time");
constexpr auto I0 = Number<0>{}; static_assert(BlockSize == MLevel0ThreadCluster * MLevel1ThreadCluster *
constexpr auto I1 = Number<1>{}; NLevel0ThreadCluster * NLevel1ThreadCluster,
"wrong! blocksize and cluster size not consistent");
constexpr index_t ThreadPerLevel1Cluster = MLevel0ThreadCluster * NLevel0ThreadCluster * static_assert(ABlockDesc{}.GetLength(I0) == BBlockDesc{}.GetLength(I0),
MLevel1ThreadCluster * NLevel1ThreadCluster;
static_assert(BlockSize == ThreadPerLevel1Cluster, "wrong! wrong blocksize\n");
static_assert(BlockMatrixA{}.GetLength(I0) == BlockMatrixB{}.GetLength(I0),
"wrong! K dimension not consistent"); "wrong! K dimension not consistent");
constexpr index_t M = BlockMatrixA{}.GetLength(I1); // A is transposed // TODO: remove this restriction
constexpr index_t N = BlockMatrixB{}.GetLength(I1); static_assert(ABlockDesc{}.GetLength(I1) == 2 && BBlockDesc{}.GetLength(I1) == 2 &&
CThreadDesc{}.GetLength(I0) == 2 && CThreadDesc{}.GetLength(I2) == 2,
static_assert(M % (MPerThreadSubC * MLevel0ThreadCluster * MLevel1ThreadCluster) == 0 && "wrong");
N % (NPerThreadSubC * NLevel0ThreadCluster * NLevel1ThreadCluster) == 0,
"wrong! Cannot evenly divide work among");
static_assert(ThreadMatrixC{}.GetLength(I0) == GetThreadMatrixCLengths()[I0] &&
ThreadMatrixC{}.GetLength(I1) == GetThreadMatrixCLengths()[I1],
"wrong! ThreadMatrixC lengths is wrong");
}
__device__ static constexpr auto GetThreadMatrixCLengths()
{
constexpr auto I1 = Number<1>{};
constexpr index_t M = BlockMatrixA{}.GetLength(I1); // A is transposed
constexpr index_t N = BlockMatrixB{}.GetLength(I1);
constexpr index_t MRepeat =
M / (MPerThreadSubC * MLevel0ThreadCluster * MLevel1ThreadCluster);
constexpr index_t NRepeat =
N / (NPerThreadSubC * NLevel0ThreadCluster * NLevel1ThreadCluster);
return Sequence<MRepeat * MPerThreadSubC, NRepeat * NPerThreadSubC>{};
} }
__device__ static MatrixIndex GetBeginOfThreadMatrixC(index_t thread_id) __device__ static CIndex CalculateCThreadOriginDataIndex(index_t thread_id)
{ {
constexpr index_t ThreadPerLevel0Cluster = MLevel0ThreadCluster * NLevel0ThreadCluster; constexpr index_t M0 = ABlockDesc{}.GetLength(I1);
constexpr index_t N0 = BBlockDesc{}.GetLength(I1);
index_t level1_id = thread_id / ThreadPerLevel0Cluster; constexpr index_t M1 = ABlockDesc{}.GetLength(I2);
index_t level1_m_id = level1_id / NLevel1ThreadCluster; constexpr index_t N1 = BBlockDesc{}.GetLength(I2);
index_t level1_n_id = level1_id % NLevel1ThreadCluster;
// 4-d data space into 4-d thread space
index_t level0_id = thread_id % ThreadPerLevel0Cluster; constexpr auto adaptor0 = make_single_stage_tensor_adaptor(
index_t level0_m_id = level0_id / NLevel0ThreadCluster; make_tuple(make_vectorize_transform(M0, 1),
index_t level0_n_id = level0_id % NLevel0ThreadCluster; make_vectorize_transform(M1PerThread, M1 / M1PerThread),
make_vectorize_transform(N0, 1),
constexpr index_t MPerLevel0Cluster = MPerThreadSubC * MLevel0ThreadCluster; make_vectorize_transform(N1PerThread, N1 / N1PerThread)),
constexpr index_t NPerLevel0Cluster = NPerThreadSubC * NLevel0ThreadCluster; make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
return MatrixIndex{level1_m_id * MPerLevel0Cluster + level0_m_id * MPerThreadSubC,
level1_n_id * NPerLevel0Cluster + level0_n_id * NPerThreadSubC}; // thread position 4-d thread space
constexpr auto adaptor1 = make_single_stage_tensor_adaptor(
make_tuple(
make_freeze_transform(make_multi_index(0)),
make_unmerge_transform(make_tuple(MLevel1ThreadCluster, MLevel0ThreadCluster)),
make_freeze_transform(make_multi_index(0)),
make_unmerge_transform(make_tuple(NLevel1ThreadCluster, NLevel0ThreadCluster))),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<>{}, Sequence<0, 1>{}, Sequence<>{}, Sequence<2, 3>{}));
// 4-d thread space to 1-d thread space
constexpr auto adaptor2 = make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(MLevel1ThreadCluster,
NLevel1ThreadCluster,
MLevel0ThreadCluster,
NLevel0ThreadCluster))),
make_tuple(Sequence<0, 2, 1, 3>{}),
make_tuple(Sequence<0>{}));
constexpr auto cluster_desc = chain_tensor_adaptors(adaptor0, adaptor1, adaptor2);
return cluster_desc.CalculateBottomIndex(make_multi_index(get_thread_local_1d_id()));
} }
template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer> template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
__device__ void Run_pipelined_2x2(const ABlockBuffer& a_block_buf, __device__ void Run(const ABlockBuffer& a_block_buf,
const BBlockBuffer& b_block_buf, const BBlockBuffer& b_block_buf,
CThreadBuffer& c_thread_buf) const CThreadBuffer& c_thread_buf) const
{ {
static_assert(is_same<remove_cv_t<remove_reference_t<typename ABlockBuffer::type>>, auto a_thread_buf = make_static_buffer<FloatA>(a_thread_desc_.GetElementSpaceSize());
remove_cv_t<remove_reference_t<FloatA>>>::value && auto b_thread_buf = make_static_buffer<FloatB>(b_thread_desc_.GetElementSpaceSize());
is_same<remove_cv_t<remove_reference_t<typename BBlockBuffer::type>>,
remove_cv_t<remove_reference_t<FloatB>>>::value &&
is_same<remove_cv_t<remove_reference_t<typename CThreadBuffer::type>>,
remove_cv_t<remove_reference_t<FloatC>>>::value &&
"wrong! inconsistent type");
constexpr auto I0 = Number<0>{}; constexpr auto threadwise_gemm =
constexpr auto I1 = Number<1>{}; ThreadwiseGemm_km0m1_kn0n1_m0m1n0n1<FloatA,
FloatB,
constexpr auto a_block_mtx = BlockMatrixA{}; FloatC,
constexpr auto b_block_mtx = BlockMatrixB{}; decltype(a_thread_desc_),
constexpr auto c_thread_mtx_desc = ThreadMatrixC{}; decltype(b_thread_desc_),
CThreadDesc,
constexpr auto K = a_block_mtx.GetLength(I0); Sequence<KPerThread>,
Sequence<1, M1PerThread>,
constexpr auto MPerThread = c_thread_mtx_desc.GetLength(I0); Sequence<1, N1PerThread>>{};
constexpr auto NPerThread = c_thread_mtx_desc.GetLength(I1);
constexpr index_t MPerLevel1Cluster =
MPerThreadSubC * MLevel0ThreadCluster * MLevel1ThreadCluster;
constexpr index_t NPerLevel1Cluster =
NPerThreadSubC * NLevel0ThreadCluster * NLevel1ThreadCluster;
constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
constexpr index_t NRepeat = NPerThread / NPerThreadSubC;
static_assert(MRepeat == 2 && NRepeat == 2, "wrong! only support 2x2 pipeline");
// thread A-sub, B-sub
constexpr auto a_thread_sub_mtx = make_dynamic_naive_tensor_descriptor_v2(
make_tuple(Number<KPerThreadLoop>{}, Number<MPerThreadSubC>{}),
make_tuple(Number<MPerThread>{}, Number<1>{}));
constexpr auto b_thread_sub_mtx = make_dynamic_naive_tensor_descriptor_v2(
make_tuple(Number<KPerThreadLoop>{}, Number<NPerThreadSubC>{}),
make_tuple(Number<NPerThread>{}, Number<1>{}));
constexpr auto c_thread_sub_mtx = make_dynamic_naive_tensor_descriptor_v2(
make_tuple(Number<MPerThreadSubC>{}, Number<NPerThreadSubC>{}),
make_tuple(Number<NPerThread>{}, Number<1>{}));
auto a_thread_buf = make_static_buffer<FloatA>(a_thread_mtx_desc_.GetElementSpaceSize());
auto b_thread_buf = make_static_buffer<FloatB>(b_thread_mtx_desc_.GetElementSpaceSize());
constexpr auto threadwise_gemm = ThreadwiseGemm_km_kn_mn_v1r1<FloatA, constexpr index_t K = ABlockDesc{}.GetLength(I0);
FloatB,
FloatC,
decltype(a_thread_sub_mtx),
decltype(b_thread_sub_mtx),
decltype(c_thread_sub_mtx)>{};
// read A_sub_0 // read A_sub_0
a_thread_copy_.Run(BlockMatrixA{}, a_thread_copy_.Run(ABlockDesc{},
make_tuple(I0, I0), make_tuple(I0, I0, I0),
a_block_buf, a_block_buf,
a_thread_mtx_desc_, a_thread_desc_,
make_tuple(I0, I0), make_tuple(I0, I0, I0),
a_thread_buf); a_thread_buf);
// read B_sub_0 // read B_sub_0
b_thread_copy_.Run(BlockMatrixB{}, b_thread_copy_.Run(BBlockDesc{},
make_tuple(I0, I0), make_tuple(I0, I0, I0),
b_block_buf, b_block_buf,
b_thread_mtx_desc_, b_thread_desc_,
make_tuple(I0, I0), make_tuple(I0, I0, I0),
b_thread_buf); b_thread_buf);
// read B_sub_1 // read B_sub_1
b_thread_copy_.Run(BlockMatrixB{}, b_thread_copy_.Run(BBlockDesc{},
make_tuple(I0, Number<NPerLevel1Cluster>{}), make_tuple(I0, I1, I0),
b_block_buf, b_block_buf,
b_thread_mtx_desc_, b_thread_desc_,
make_tuple(I0, Number<NPerThreadSubC>{}), make_tuple(I0, I1, I0),
b_thread_buf); b_thread_buf);
// read A_sub_1 // read A_sub_1
a_thread_copy_.Run(BlockMatrixA{}, a_thread_copy_.Run(ABlockDesc{},
make_tuple(I0, Number<MPerLevel1Cluster>{}), make_tuple(I0, I1, I0),
a_block_buf, a_block_buf,
a_thread_mtx_desc_, a_thread_desc_,
make_tuple(I0, Number<MPerThreadSubC>{}), make_tuple(I0, I1, I0),
a_thread_buf); a_thread_buf);
// C_sub_00 += transpose(A_sub_0) * B_sub_0 // C_sub_00 += transpose(A_sub_0) * B_sub_0
threadwise_gemm.Run(a_thread_buf, threadwise_gemm.Run(a_thread_buf,
make_tuple(I0, I0), make_tuple(I0, I0, I0),
b_thread_buf, b_thread_buf,
make_tuple(I0, I0), make_tuple(I0, I0, I0),
c_thread_buf, c_thread_buf,
make_tuple(I0, I0)); make_tuple(I0, I0, I0, I0));
// C_sub_01 += transpose(A_sub_0) * B_sub_1 // C_sub_01 += transpose(A_sub_0) * B_sub_1
threadwise_gemm.Run(a_thread_buf, threadwise_gemm.Run(a_thread_buf,
make_tuple(I0, I0), make_tuple(I0, I0, I0),
b_thread_buf, b_thread_buf,
make_tuple(I0, Number<NPerThreadSubC>{}), make_tuple(I0, I1, I0),
c_thread_buf, c_thread_buf,
make_tuple(I0, Number<NPerThreadSubC>{})); make_tuple(I0, I0, I1, I0));
// loop over rest of k // loop over rest of k
static_for<KPerThreadLoop, K, KPerThreadLoop>{}([&](auto k) { static_for<KPerThread, K, KPerThread>{}([&](auto k) {
// read A_sub_0 // read A_sub_0
a_thread_copy_.Run(BlockMatrixA{}, a_thread_copy_.Run(ABlockDesc{},
make_tuple(k, I0), make_tuple(k, I0, I0),
a_block_buf, a_block_buf,
a_thread_mtx_desc_, a_thread_desc_,
make_tuple(I0, I0), make_tuple(I0, I0, I0),
a_thread_buf); a_thread_buf);
// C_sub_10 += transpose(A_sub_1) * B_sub_0 // C_sub_10 += transpose(A_sub_1) * B_sub_0
threadwise_gemm.Run(a_thread_buf, threadwise_gemm.Run(a_thread_buf,
make_tuple(I0, Number<MPerThreadSubC>{}), make_tuple(I0, I1, I0),
b_thread_buf, b_thread_buf,
make_tuple(I0, I0), make_tuple(I0, I0, I0),
c_thread_buf, c_thread_buf,
make_tuple(Number<MPerThreadSubC>{}, I0)); make_tuple(I1, I0, I0, I0));
// read B_sub_0 // read B_sub_0
b_thread_copy_.Run(BlockMatrixB{}, b_thread_copy_.Run(BBlockDesc{},
make_tuple(k, I0), make_tuple(k, I0, I0),
b_block_buf, b_block_buf,
b_thread_mtx_desc_, b_thread_desc_,
make_tuple(I0, I0), make_tuple(I0, I0, I0),
b_thread_buf); b_thread_buf);
// C_sub_11 += transpose(A_sub_1) * B_sub_1 // C_sub_11 += transpose(A_sub_1) * B_sub_1
threadwise_gemm.Run(a_thread_buf, threadwise_gemm.Run(a_thread_buf,
make_tuple(I0, Number<MPerThreadSubC>{}), make_tuple(I0, I1, I0),
b_thread_buf, b_thread_buf,
make_tuple(I0, Number<NPerThreadSubC>{}), make_tuple(I0, I1, I0),
c_thread_buf, c_thread_buf,
make_tuple(Number<MPerThreadSubC>{}, Number<NPerThreadSubC>{})); make_tuple(I1, I0, I1, I0));
// read B_sub_1 // read B_sub_1
b_thread_copy_.Run(BlockMatrixB{}, b_thread_copy_.Run(BBlockDesc{},
make_tuple(k, Number<NPerLevel1Cluster>{}), make_tuple(k, I1, I0),
b_block_buf, b_block_buf,
b_thread_mtx_desc_, b_thread_desc_,
make_tuple(I0, Number<NPerThreadSubC>{}), make_tuple(I0, I1, I0),
b_thread_buf); b_thread_buf);
// read A_sub_1 // read A_sub_1
a_thread_copy_.Run(BlockMatrixA{}, a_thread_copy_.Run(ABlockDesc{},
make_tuple(k, Number<MPerLevel1Cluster>{}), make_tuple(k, I1, I0),
a_block_buf, a_block_buf,
a_thread_mtx_desc_, a_thread_desc_,
make_tuple(I0, Number<MPerThreadSubC>{}), make_tuple(I0, I1, I0),
a_thread_buf); a_thread_buf);
// C_sub_00 += transpose(A_sub_0) * B_sub_0 // C_sub_00 += transpose(A_sub_0) * B_sub_0
threadwise_gemm.Run(a_thread_buf, threadwise_gemm.Run(a_thread_buf,
make_tuple(I0, I0), make_tuple(I0, I0, I0),
b_thread_buf, b_thread_buf,
make_tuple(I0, I0), make_tuple(I0, I0, I0),
c_thread_buf, c_thread_buf,
make_tuple(I0, I0)); make_tuple(I0, I0, I0, I0));
// C_sub_01 += transpose(A_sub_0) * B_sub_1 // C_sub_01 += transpose(A_sub_0) * B_sub_1
threadwise_gemm.Run(a_thread_buf, threadwise_gemm.Run(a_thread_buf,
make_tuple(I0, I0), make_tuple(I0, I0, I0),
b_thread_buf, b_thread_buf,
make_tuple(I0, Number<NPerThreadSubC>{}), make_tuple(I0, I1, I0),
c_thread_buf, c_thread_buf,
make_tuple(I0, Number<NPerThreadSubC>{})); make_tuple(I0, I0, I1, I0));
}); });
// C_sub_10 += transpose(A_sub_1) * B_sub_0 // C_sub_10 += transpose(A_sub_1) * B_sub_0
threadwise_gemm.Run(a_thread_buf, threadwise_gemm.Run(a_thread_buf,
make_tuple(I0, Number<MPerThreadSubC>{}), make_tuple(I0, I1, I0),
b_thread_buf, b_thread_buf,
make_tuple(I0, I0), make_tuple(I0, I0, I0),
c_thread_buf, c_thread_buf,
make_tuple(Number<MPerThreadSubC>{}, I0)); make_tuple(I1, I0, I0, I0));
// C_sub_11 += transpose(A_sub_1) * B_sub_1 // C_sub_11 += transpose(A_sub_1) * B_sub_1
threadwise_gemm.Run(a_thread_buf, threadwise_gemm.Run(a_thread_buf,
make_tuple(I0, Number<MPerThreadSubC>{}), make_tuple(I0, I1, I0),
b_thread_buf, b_thread_buf,
make_tuple(I0, Number<NPerThreadSubC>{}), make_tuple(I0, I1, I0),
c_thread_buf, c_thread_buf,
make_tuple(Number<MPerThreadSubC>{}, Number<NPerThreadSubC>{})); make_tuple(I1, I0, I1, I0));
} }
template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer> private:
__device__ void Run(const ABlockBuffer& a_block_buf, static constexpr index_t M0_ = ABlockDesc{}.GetLength(I1);
const BBlockBuffer& b_block_buf, static constexpr index_t N0_ = BBlockDesc{}.GetLength(I1);
CThreadBuffer& c_thread_buf) const
{ // A[K, M0, M1]
#if CK_EXPERIMENTAL_BLOCKWISE_GEMM_USE_PIPELINE static constexpr auto a_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2(
constexpr auto I0 = Number<0>{}; make_tuple(Number<KPerThread>{}, Number<M0_>{}, Number<M1PerThread>{}));
constexpr auto I1 = Number<1>{};
// B[K, N0, N1]
constexpr index_t MPerThread = ThreadMatrixC{}.GetLength(I0); static constexpr auto b_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2(
constexpr index_t NPerThread = ThreadMatrixC{}.GetLength(I1); make_tuple(Number<KPerThread>{}, Number<N0_>{}, Number<N1PerThread>{}));
constexpr index_t MRepeat = MPerThread / MPerThreadSubC; using AThreadCopy =
constexpr index_t NRepeat = NPerThread / NPerThreadSubC; ThreadwiseDynamicTensorSliceTransfer_v4<FloatA,
FloatA,
if constexpr(MRepeat == 2 && NRepeat == 2) ABlockDesc,
{ decltype(a_thread_desc_),
Run_pipelined_2x2(a_block_buf, b_block_buf, c_thread_buf); Sequence<KPerThread, 1, M1PerThread>,
} Sequence<0, 1, 2>,
else 2,
{ AThreadCopyScalarPerVector_M1,
Run_naive(a_block_buf, b_block_buf, c_thread_buf); AddressSpace::Generic,
} AddressSpace::Vgpr,
#else 1>;
Run_naive(a_block_buf, b_block_buf, c_thread_buf);
#endif using BThreadCopy =
} ThreadwiseDynamicTensorSliceTransfer_v4<FloatB,
FloatB,
BBlockDesc,
decltype(b_thread_desc_),
Sequence<KPerThread, 1, N1PerThread>,
Sequence<0, 1, 2>,
2,
BThreadCopyScalarPerVector_N1,
AddressSpace::Generic,
AddressSpace::Vgpr,
1>;
CIndex c_thread_origin_data_idx_;
AThreadCopy a_thread_copy_;
BThreadCopy b_thread_copy_;
}; };
} // namespace ck } // namespace ck
#endif #endif
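To make the indexing convention concrete, here is a plain-loop reference of the per-thread accumulation both blockwise GEMM variants ultimately perform, C[m0, m1, n0, n1] += A[k, m0, m1] * B[k, n0, n1] (a sketch only: C-style arrays stand in for the CK descriptors and static buffers).

#include <cassert>

template <int K, int M0, int M1, int N0, int N1>
void threadwise_gemm_reference(const float (&a)[K][M0][M1],
                               const float (&b)[K][N0][N1],
                               float (&c)[M0][M1][N0][N1])
{
    // C[m0, m1, n0, n1] += transpose(A[k, m0, m1]) * B[k, n0, n1]
    for(int k = 0; k < K; ++k)
        for(int m0 = 0; m0 < M0; ++m0)
            for(int m1 = 0; m1 < M1; ++m1)
                for(int n0 = 0; n0 < N0; ++n0)
                    for(int n1 = 0; n1 < N1; ++n1)
                        c[m0][m1][n0][n1] += a[k][m0][m1] * b[k][n0][n1];
}

int main()
{
    float a[2][2][1] = {}, b[2][2][1] = {}, c[2][1][2][1] = {};
    a[0][0][0] = 1.f;
    b[0][0][0] = 2.f;
    threadwise_gemm_reference(a, b, c);
    assert(c[0][0][0][0] == 2.f);
    return 0;
}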
...@@ -12,7 +12,36 @@ ...@@ -12,7 +12,36 @@
namespace ck { namespace ck {
#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER #if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
template <typename GridwiseGemm,
typename AGlobalDesc,
typename FloatA,
typename BGlobalDesc,
typename FloatB,
typename CGlobalDesc,
typename FloatC,
typename CBlockClusterDesc,
bool HasMainKBlockLoop,
bool HasDoubleTailKBlockLoop>
__global__ void kernel_dynamic_gemm_v1(const AGlobalDesc a_k_m_global_desc,
const FloatA* __restrict__ p_a_global,
const BGlobalDesc b_k_n_global_desc,
const FloatB* __restrict__ p_b_global,
const CGlobalDesc c_m0_m1_n0_n1_global_desc,
FloatC* __restrict__ p_c_global,
const CBlockClusterDesc c_block_cluster_desc)
{
GridwiseGemm{}.Run(a_k_m_global_desc,
p_a_global,
b_k_n_global_desc,
p_b_global,
c_m0_m1_n0_n1_global_desc,
p_c_global,
c_block_cluster_desc,
integral_constant<bool, HasMainKBlockLoop>{},
integral_constant<bool, HasDoubleTailKBlockLoop>{});
}
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
// pass tensor descriptor by __CONSTANT__ void pointer // pass tensor descriptor by __CONSTANT__ void pointer
// __CONSTANT__ is needed to inform the compiler that the void pointers in the kernel signature point to // __CONSTANT__ is needed to inform the compiler that the void pointers in the kernel signature point to
// the non-modifiable parameter address space, so the compiler can enable the corresponding optimization // the non-modifiable parameter address space, so the compiler can enable the corresponding optimization
...@@ -23,16 +52,18 @@ template <typename GridwiseGemm, ...@@ -23,16 +52,18 @@ template <typename GridwiseGemm,
typename FloatB, typename FloatB,
typename CGlobalDesc, typename CGlobalDesc,
typename FloatC, typename FloatC,
typename CBlockClusterDesc,
bool HasMainKBlockLoop, bool HasMainKBlockLoop,
bool HasDoubleTailKBlockLoop> bool HasDoubleTailKBlockLoop>
__global__ void run_gridwise_dynamic_gemm_v1(const void __CONSTANT__* p_a_k_m_global_desc, __global__ void kernel_dynamic_gemm_v1(const void __CONSTANT__* p_a_k_m_global_desc,
const FloatA* __restrict__ p_a_global, const FloatA* __restrict__ p_a_global,
const void __CONSTANT__* p_b_k_n_global_desc, const void __CONSTANT__* p_b_k_n_global_desc,
const FloatB* __restrict__ p_b_global, const FloatB* __restrict__ p_b_global,
const void __CONSTANT__* p_c_m0_m1_n0_n1_global_desc, const void __CONSTANT__* p_c_m0_m1_n0_n1_global_desc,
FloatC* __restrict__ p_c_global) FloatC* __restrict__ p_c_global,
const void __CONSTANT__* p_c_block_cluster_desc)
{ {
// first cast void __CONSTANT__* to void* // first cast __CONSTANT__ void* to void*
// second cast void* to Desc* // second cast void* to Desc*
// the copy constructor of tensor descriptor doesn't take address_space(4) // the copy constructor of tensor descriptor doesn't take address_space(4)
const auto a_k_m_global_desc = const auto a_k_m_global_desc =
...@@ -42,12 +73,16 @@ __global__ void run_gridwise_dynamic_gemm_v1(const void __CONSTANT__* p_a_k_m_gl ...@@ -42,12 +73,16 @@ __global__ void run_gridwise_dynamic_gemm_v1(const void __CONSTANT__* p_a_k_m_gl
const auto c_m0_m1_n0_n1_global_desc = const auto c_m0_m1_n0_n1_global_desc =
*reinterpret_cast<const CGlobalDesc*>((const void*)p_c_m0_m1_n0_n1_global_desc); *reinterpret_cast<const CGlobalDesc*>((const void*)p_c_m0_m1_n0_n1_global_desc);
const auto c_block_cluster_desc =
*reinterpret_cast<const CBlockClusterDesc*>((const void*)p_c_block_cluster_desc);
GridwiseGemm{}.Run(a_k_m_global_desc, GridwiseGemm{}.Run(a_k_m_global_desc,
p_a_global, p_a_global,
b_k_n_global_desc, b_k_n_global_desc,
p_b_global, p_b_global,
c_m0_m1_n0_n1_global_desc, c_m0_m1_n0_n1_global_desc,
p_c_global, p_c_global,
c_block_cluster_desc,
integral_constant<bool, HasMainKBlockLoop>{}, integral_constant<bool, HasMainKBlockLoop>{},
integral_constant<bool, HasDoubleTailKBlockLoop>{}); integral_constant<bool, HasDoubleTailKBlockLoop>{});
} }
...@@ -61,6 +96,7 @@ template <index_t BlockSize, ...@@ -61,6 +96,7 @@ template <index_t BlockSize,
typename AGlobalDesc, typename AGlobalDesc,
typename BGlobalDesc, typename BGlobalDesc,
typename CGlobalDesc, typename CGlobalDesc,
typename CBlockClusterDesc,
index_t MPerBlock, index_t MPerBlock,
index_t NPerBlock, index_t NPerBlock,
index_t KPerBlock, index_t KPerBlock,
...@@ -131,37 +167,30 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1 ...@@ -131,37 +167,30 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
const FloatAB* __restrict__ p_b_global, const FloatAB* __restrict__ p_b_global,
const CGlobalDesc& c_m0_m1_n0_n1_global_desc, const CGlobalDesc& c_m0_m1_n0_n1_global_desc,
FloatC* __restrict__ p_c_global, FloatC* __restrict__ p_c_global,
const CBlockClusterDesc& c_block_cluster_desc,
FloatAB* __restrict__ p_shared_block, FloatAB* __restrict__ p_shared_block,
integral_constant<bool, HasMainKBlockLoop>, integral_constant<bool, HasMainKBlockLoop>,
integral_constant<bool, HasDoubleTailKBlockLoop>) const integral_constant<bool, HasDoubleTailKBlockLoop>) const
{ {
constexpr auto I0 = Number<0>{}; constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{}; constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
const auto K = a_k_m_global_desc.GetLength(I0); const auto K = a_k_m_global_desc.GetLength(I0);
const auto M = a_k_m_global_desc.GetLength(I1); const auto M = a_k_m_global_desc.GetLength(I1);
const auto N = b_k_n_global_desc.GetLength(I1); const auto N = b_k_n_global_desc.GetLength(I1);
// divide block work by [M, N] // divide block work by [M, N]
#if 0 const auto block_work_idx =
const auto m_block_work_num = M / Number<MPerBlock>{}; c_block_cluster_desc.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
const auto n_block_work_num = N / Number<NPerBlock>{};
const index_t m_block_work_id = get_block_1d_id() / n_block_work_num;
const index_t n_block_work_id = get_block_1d_id() - m_block_work_id * n_block_work_num;
#else
// Hack: this force result into SGPR
const index_t m_block_work_num = __builtin_amdgcn_readfirstlane(M / MPerBlock);
const index_t n_block_work_num = __builtin_amdgcn_readfirstlane(N / NPerBlock);
const index_t m_block_work_id = // HACK: this forces m/n_block_data_idx_on_global into SGPR
__builtin_amdgcn_readfirstlane(get_block_1d_id() / n_block_work_num); const index_t m_block_data_idx_on_global =
const index_t n_block_work_id = get_block_1d_id() - m_block_work_id * n_block_work_num; __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
#endif
const index_t m_block_data_on_global = m_block_work_id * MPerBlock; const index_t n_block_data_idx_on_global =
const index_t n_block_data_on_global = n_block_work_id * NPerBlock; __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);
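For clarity, a standalone sketch of the block-id decomposition that c_block_cluster_desc is expected to reproduce (the helper name is hypothetical; the arithmetic mirrors the raw-index version shown on the left).

struct BlockWorkIdx
{
    index_t m_block_data_idx;
    index_t n_block_data_idx;
};

__device__ BlockWorkIdx
calculate_block_work_idx(index_t block_id, index_t N, index_t m_per_block, index_t n_per_block)
{
    const index_t n_block_work_num = N / n_per_block;              // number of tiles along N
    const index_t m_block_work_id  = block_id / n_block_work_num;  // tile row
    const index_t n_block_work_id  = block_id % n_block_work_num;  // tile column

    return {m_block_work_id * m_per_block, n_block_work_id * n_per_block};
}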
// lds max alignment // lds max alignment
constexpr auto max_lds_align = math::lcm(Number<ABlockTransferDstScalarPerVector_M>{}, constexpr auto max_lds_align = math::lcm(Number<ABlockTransferDstScalarPerVector_M>{},
...@@ -204,7 +233,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1 ...@@ -204,7 +233,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
AThreadTransferSrcResetCoordinateAfterRun, AThreadTransferSrcResetCoordinateAfterRun,
true>( true>(
a_k_m_global_desc, a_k_m_global_desc,
make_multi_index(0, m_block_data_on_global), make_multi_index(0, m_block_data_idx_on_global),
a_k_m_block_desc, a_k_m_block_desc,
make_multi_index(0, 0)); make_multi_index(0, 0));
...@@ -233,7 +262,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1 ...@@ -233,7 +262,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
BThreadTransferSrcResetCoordinateAfterRun, BThreadTransferSrcResetCoordinateAfterRun,
true>( true>(
b_k_n_global_desc, b_k_n_global_desc,
make_multi_index(0, n_block_data_on_global), make_multi_index(0, n_block_data_idx_on_global),
b_k_n_block_desc, b_k_n_block_desc,
make_multi_index(0, 0)); make_multi_index(0, 0));
...@@ -251,28 +280,45 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1 ...@@ -251,28 +280,45 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
constexpr index_t MRepeat = MPerBlock / (MPerThread * MLevel0Cluster * MLevel1Cluster); constexpr index_t MRepeat = MPerBlock / (MPerThread * MLevel0Cluster * MLevel1Cluster);
constexpr index_t NRepeat = NPerBlock / (NPerThread * NLevel0Cluster * NLevel1Cluster); constexpr index_t NRepeat = NPerBlock / (NPerThread * NLevel0Cluster * NLevel1Cluster);
constexpr auto a_k_m0_m1_block_desc = transform_dynamic_tensor_descriptor(
    a_k_m_block_desc,
    make_tuple(
        make_pass_through_transform(Number<KPerBlock>{}),
        make_unmerge_transform(make_tuple(
            Number<MRepeat>{}, Number<MPerThread * MLevel0Cluster * MLevel1Cluster>{}))),
    make_tuple(Sequence<0>{}, Sequence<1>{}),
    make_tuple(Sequence<0>{}, Sequence<1, 2>{}));

constexpr auto b_k_n0_n1_block_desc = transform_dynamic_tensor_descriptor(
    b_k_n_block_desc,
    make_tuple(
        make_pass_through_transform(Number<KPerBlock>{}),
        make_unmerge_transform(make_tuple(
            Number<NRepeat>{}, Number<NPerThread * NLevel0Cluster * NLevel1Cluster>{}))),
    make_tuple(Sequence<0>{}, Sequence<1>{}),
    make_tuple(Sequence<0>{}, Sequence<1, 2>{}));

constexpr auto c_m0_m1_n0_n1_thread_desc =
    make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(
        Number<MRepeat>{}, Number<MPerThread>{}, Number<NRepeat>{}, Number<NPerThread>{}));
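// Illustration (not part of this diff): an unmerge transform only re-labels coordinates.
// Splitting the M axis of the LDS tile into (MRepeat, M1) leaves every element at the
// same linear offset; the sizes and the row-major layout below are hypothetical.
constexpr int kMPerBlock = 64, kM1 = 32, kMRepeat = kMPerBlock / kM1;

constexpr int offset_2d(int k, int m) { return k * kMPerBlock + m; }                      // a_k_m view
constexpr int offset_3d(int k, int m0, int m1) { return k * kMPerBlock + m0 * kM1 + m1; } // a_k_m0_m1 view

static_assert(kMRepeat == 2, "M splits into 2 x 32");
static_assert(offset_3d(3, 1, 5) == offset_2d(3, 1 * kM1 + 5), "same element, new coordinates");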
const auto blockwise_gemm =
    BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v2_pipeline_2x2<BlockSize,
                                                       FloatAB,
                                                       FloatAB,
                                                       FloatAcc,
                                                       decltype(a_k_m0_m1_block_desc),
                                                       decltype(b_k_n0_n1_block_desc),
                                                       decltype(c_m0_m1_n0_n1_thread_desc),
                                                       MPerThread,
                                                       NPerThread,
                                                       KPerThread,
                                                       MLevel0Cluster,
                                                       NLevel0Cluster,
                                                       MLevel1Cluster,
                                                       NLevel1Cluster,
                                                       MPerThread,
                                                       NPerThread>{};
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_space_size =
...@@ -286,12 +332,12 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
// register allocation for output
auto c_thread_buf =
    make_static_buffer<FloatAcc>(c_m0_m1_n0_n1_thread_desc.GetElementSpaceSize());

ThreadwiseDynamicTensorSliceSet_v1<FloatAcc,
                                   decltype(c_m0_m1_n0_n1_thread_desc),
                                   Sequence<MRepeat, MPerThread, NRepeat, NPerThread>>{}
    .Run(c_m0_m1_n0_n1_thread_desc, make_tuple(I0, I0, I0, I0), c_thread_buf, FloatAcc{0});
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0);
...@@ -427,30 +473,11 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
constexpr auto M1 = Number<MPerThread * MLevel0Cluster * MLevel1Cluster>{};
constexpr auto N1 = Number<NPerThread * NLevel0Cluster * NLevel1Cluster>{};
// hack to control index calculation when iterating over c_m0_m1_n0_n1_global tensor
constexpr auto c_m0_m1_n0_n1_global_tensor_iterator_hacks = CGlobalIteratorHacks{};

const auto c_thread_data_idx_on_block =
    blockwise_gemm.CalculateCThreadOriginDataIndex(get_thread_local_1d_id());
ThreadwiseDynamicTensorSliceTransfer_v1r3<
    FloatAcc,
...@@ -465,11 +492,12 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
    AddressSpace::Global,
    CGlobalMemoryDataOperation,
    1,
    true>{
    c_m0_m1_n0_n1_global_desc,
    make_multi_index(m_block_data_idx_on_global / M1 + c_thread_data_idx_on_block[I0],
                     c_thread_data_idx_on_block[I1],
                     n_block_data_idx_on_global / N1 + c_thread_data_idx_on_block[I2],
                     c_thread_data_idx_on_block[I3])}
    .Run(c_m0_m1_n0_n1_thread_desc,
         make_tuple(I0, I0, I0, I0),
         c_thread_buf,
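// Illustration (not part of this diff): the write-back origin above combines the block
// tile origin with the thread's offset inside the tile, in (m0, m1, n0, n1) coordinates
// where the flat index is m = m0 * M1 + m1 (and likewise for n). Values are hypothetical;
// the identity only holds because the block origin is a multiple of M1 / N1.
constexpr int cM1 = 32, cN1 = 64;
constexpr int cMBlockOrigin = 128, cNBlockOrigin = 192;
constexpr int cThreadM0 = 1, cThreadM1 = 7, cThreadN0 = 0, cThreadN1 = 13;

constexpr int cM0Global = cMBlockOrigin / cM1 + cThreadM0;
constexpr int cN0Global = cNBlockOrigin / cN1 + cThreadN0;

static_assert(cM0Global * cM1 + cThreadM1 == cMBlockOrigin + cThreadM0 * cM1 + cThreadM1,
              "(m0, m1) addresses the same global row as the flat m index");
static_assert(cN0Global * cN1 + cThreadN1 == cNBlockOrigin + cThreadN0 * cN1 + cThreadN1,
              "(n0, n1) addresses the same global column as the flat n index");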
...@@ -486,6 +514,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
    const FloatAB* __restrict__ p_b_global,
    const CGlobalDesc& c_m0_m1_n0_n1_global_desc,
    FloatC* __restrict__ p_c_global,
    const CBlockClusterDesc& c_block_cluster_desc,
    integral_constant<bool, HasMainKBlockLoop>,
    integral_constant<bool, HasDoubleTailKBlockLoop>) const
{
...@@ -499,6 +528,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
    p_b_global,
    c_m0_m1_n0_n1_global_desc,
    p_c_global,
    c_block_cluster_desc,
    p_shared_block,
    integral_constant<bool, HasMainKBlockLoop>{},
    integral_constant<bool, HasDoubleTailKBlockLoop>{});
......
...@@ -1376,6 +1376,8 @@ struct ThreadwiseDynamicTensorSliceTransfer_v4
{
    static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
                  "wrong! SrcDesc and DstDesc need to known at compile-time");
static_assert(SliceLengths::At(Number<SrcVectorDim>{}) % SrcScalarPerVector == 0, "wrong!");
}
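// Illustration (not part of this diff): the added static_assert only requires that the
// slice length along the vectorized source dimension is a whole number of vectors.
// Hypothetical numbers:
constexpr int kSliceLenOnVectorDim = 8;
constexpr int kSrcScalarPerVector  = 4;

static_assert(kSliceLenOnVectorDim % kSrcScalarPerVector == 0,
              "an 8-element slice is moved as two 4-wide vector transactions");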
template <typename SrcRefToOriginDisplacement,
......
...@@ -140,5 +140,103 @@ struct ThreadwiseGemm_km_kn_mn_v1r1
}
};
// C[M0, M1, N0, N1] += A[K, M0, M1] * B[K, N0, N1]
// Tensor element can be vectorized data
// Assume:
// 1. ADesc, BDesc, CDesc are known at compile-time
// 2. AOriginIdx, BOriginIdx, COriginIdx are known at compile-time
template <typename FloatA,
typename FloatB,
typename FloatC,
typename ADesc,
typename BDesc,
typename CDesc,
typename KLengths,
typename MLengths,
typename NLengths,
typename std::enable_if<ADesc::IsKnownAtCompileTime() && BDesc::IsKnownAtCompileTime() &&
CDesc::IsKnownAtCompileTime(),
bool>::type = false>
struct ThreadwiseGemm_km0m1_kn0n1_m0m1n0n1
{
__device__ constexpr ThreadwiseGemm_km0m1_kn0n1_m0m1n0n1()
{
static_assert(ADesc::IsKnownAtCompileTime() && BDesc::IsKnownAtCompileTime() &&
CDesc::IsKnownAtCompileTime(),
"wrong! Desc should be known at compile-time");
// TODO: sanity-check: compare ADesc, BDesc, CDesc Size with KLengths, MLengths and NLengths
// TODO remove this restriction
static_assert(KLengths::Size() == 1 && MLengths::Size() == 2 && NLengths::Size() == 2,
"wrong!");
}
template <typename ABuffer,
typename AOriginIdx,
typename BBuffer,
typename BOriginIdx,
typename CBuffer,
typename COriginIdx>
__device__ static void Run(const ABuffer& a_buf,
AOriginIdx,
const BBuffer& b_buf,
BOriginIdx,
CBuffer& c_buf,
COriginIdx)
{
static_assert(
is_known_at_compile_time<remove_cv_t<remove_reference_t<AOriginIdx>>>::value &&
is_known_at_compile_time<remove_cv_t<remove_reference_t<BOriginIdx>>>::value &&
is_known_at_compile_time<remove_cv_t<remove_reference_t<COriginIdx>>>::value,
"wrong! AOriginIdx, BOriginIdx, COringinIdx should be known at compile-time");
static_assert(is_same<remove_cv_t<remove_reference_t<typename ABuffer::type>>,
remove_cv_t<remove_reference_t<FloatA>>>::value &&
is_same<remove_cv_t<remove_reference_t<typename BBuffer::type>>,
remove_cv_t<remove_reference_t<FloatB>>>::value &&
is_same<remove_cv_t<remove_reference_t<typename CBuffer::type>>,
remove_cv_t<remove_reference_t<FloatC>>>::value,
"wrong! inconsistent type");
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto K = KLengths{}[I0];
constexpr auto M0 = MLengths{}[I0];
constexpr auto M1 = MLengths{}[I1];
constexpr auto N0 = NLengths{}[I0];
constexpr auto N1 = NLengths{}[I1];
constexpr auto a_origin_idx = to_multi_index(AOriginIdx{});
constexpr auto b_origin_idx = to_multi_index(BOriginIdx{});
constexpr auto c_origin_idx = to_multi_index(COriginIdx{});
static_for<0, K, 1>{}([&](auto k) {
static_for<0, M0, 1>{}([&](auto m0) {
static_for<0, M1, 1>{}([&](auto m1) {
static_for<0, N0, 1>{}([&](auto n0) {
static_for<0, N1, 1>{}([&](auto n1) {
constexpr index_t a_offset =
ADesc{}.CalculateOffset(a_origin_idx + make_multi_index(k, m0, m1));
constexpr index_t b_offset =
BDesc{}.CalculateOffset(b_origin_idx + make_multi_index(k, n0, n1));
constexpr index_t c_offset = CDesc{}.CalculateOffset(
c_origin_idx + make_multi_index(m0, m1, n0, n1));
amd_assembly_inner_product(a_buf[Number<a_offset>{}],
b_buf[Number<b_offset>{}],
c_buf(Number<c_offset>{}));
});
});
});
});
});
}
};
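// Illustration (not part of this diff): the loop nest above computes, for every
// (m0, m1, n0, n1), the inner product over k of A[k][m0][m1] and B[k][n0][n1].
// A plain reference version over small fixed-size arrays (sizes are hypothetical):
template <int K, int M0, int M1, int N0, int N1>
void reference_threadwise_gemm(const float (&a)[K][M0][M1],
                               const float (&b)[K][N0][N1],
                               float (&c)[M0][M1][N0][N1])
{
    for(int k = 0; k < K; ++k)
        for(int m0 = 0; m0 < M0; ++m0)
            for(int m1 = 0; m1 < M1; ++m1)
                for(int n0 = 0; n0 < N0; ++n0)
                    for(int n1 = 0; n1 < N1; ++n1)
                        c[m0][m1][n0][n1] += a[k][m0][m1] * b[k][n0][n1];
}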
} // namespace ck
#endif
...@@ -6,6 +6,7 @@
#include "container_helper.hpp" #include "container_helper.hpp"
#include "statically_indexed_array.hpp" #include "statically_indexed_array.hpp"
#include "container_element_picker.hpp" #include "container_element_picker.hpp"
#include "data_type.hpp"
#include "float_type.hpp" #include "float_type.hpp"
#include "buffer.hpp" #include "buffer.hpp"
#include "functional.hpp" #include "functional.hpp"
......
...@@ -20,7 +20,8 @@ struct ContainerElementPicker
__host__ __device__ constexpr ContainerElementPicker(Arr& array) : mArray{array}
{
    constexpr index_t imax =
        reduce_on_sequence(Picks{}, math::maximize<index_t>{}, Number<0>{});

    static_assert(imax < Arr::Size(), "wrong! exceeding # array element");
}
...@@ -85,7 +86,8 @@ struct ConstantContainerElementPicker
__host__ __device__ constexpr ConstantContainerElementPicker(const Arr& array) : mArray{array}
{
    constexpr index_t imax =
        reduce_on_sequence(Picks{}, math::maximize<index_t>{}, Number<0>{});

    static_assert(imax < Arr::Size(), "wrong! exceeding # array element");
}
......
...@@ -26,13 +26,13 @@ __host__ __device__ constexpr auto container_push_back(const Array<TData, NSize>
template <typename... Ts, typename T>
__host__ __device__ constexpr auto container_push_front(const Tuple<Ts...>& a, const T& x)
{
    return container_concat(make_tuple(x), a);
}

template <typename... Ts, typename T>
__host__ __device__ constexpr auto container_push_back(const Tuple<Ts...>& a, const T& x)
{
    return container_concat(a, make_tuple(x));
}
template <typename TData, index_t NSize, index_t... IRs>
...@@ -158,6 +158,7 @@ __host__ __device__ constexpr auto container_reduce_impl(
}

// rocm-4.1 compiler would crash for recursive lambda
// container reduce with initial value
template <typename Container,
          typename Reduce,
          typename Init,
...@@ -299,27 +300,27 @@ container_reverse_inclusive_scan(const Tuple<Xs...>& x, Reduce f, TData init)
}

template <typename X, typename... Ys>
__host__ __device__ constexpr auto container_concat(const X& x, const Ys&... ys)
{
    return container_concat(x, container_concat(ys...));
}

template <typename T, index_t NX, index_t NY>
__host__ __device__ constexpr auto container_concat(const Array<T, NX>& ax, const Array<T, NY>& ay)
{
    return unpack2(
        [&](auto&&... zs) { return make_array(std::forward<decltype(zs)>(zs)...); }, ax, ay);
}

template <typename... X, typename... Y>
__host__ __device__ constexpr auto container_concat(const Tuple<X...>& tx, const Tuple<Y...>& ty)
{
    return unpack2(
        [&](auto&&... zs) { return make_tuple(std::forward<decltype(zs)>(zs)...); }, tx, ty);
}

template <typename Container>
__host__ __device__ constexpr auto container_concat(const Container& x)
{
    return x;
}
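// Illustration (not part of this diff): the recursive container_concat above has the same
// shape as this std::tuple-based sketch: concatenating one container is the identity, two
// containers are merged directly, and more than two recurse on the tail.
#include <tuple>

template <typename X>
constexpr auto concat(const X& x) { return x; }

template <typename X, typename Y>
constexpr auto concat(const X& x, const Y& y) { return std::tuple_cat(x, y); }

template <typename X, typename... Ys>
constexpr auto concat(const X& x, const Ys&... ys) { return concat(x, concat(ys...)); }

static_assert(std::tuple_size<decltype(concat(std::make_tuple(1),
                                               std::make_tuple(2, 3),
                                               std::make_tuple(4)))>::value == 4,
              "three tuples concatenate into one 4-element tuple");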
......
#ifndef CK_DATA_TYPE_HPP
#define CK_DATA_TYPE_HPP
namespace ck {
template <typename T>
struct NumericLimits;
template <>
struct NumericLimits<int32_t>
{
__host__ __device__ static constexpr int32_t Min()
{
return std::numeric_limits<int32_t>::min();
}
__host__ __device__ static constexpr int32_t Max()
{
return std::numeric_limits<int32_t>::max();
}
};
} // namespace ck
#endif
...@@ -43,11 +43,17 @@ struct multiplies_v2
};

template <class T>
struct maximize
{
    __host__ __device__ constexpr T operator()(T a, T b) const { return a >= b ? a : b; }
};
template <class T>
struct minimize
{
__host__ __device__ constexpr T operator()(T a, T b) const { return a <= b ? a : b; }
};
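// Illustration (not part of this diff): the element pickers earlier in this diff fold
// maximize over their Picks sequence to find the largest picked index and check it is in
// range. A plain constexpr fold with hypothetical picks:
constexpr int fold_max(const int* xs, int n, int init)
{
    return n == 0 ? init : fold_max(xs + 1, n - 1, init >= xs[0] ? init : xs[0]);
}

constexpr int picks[] = {2, 0, 3};
static_assert(fold_max(picks, 3, 0) == 3, "largest picked index is 3");
static_assert(fold_max(picks, 3, 0) < 4, "all picks fit in a 4-element array");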
template <class T>
struct integer_divide_ceiler
{
......
...@@ -46,6 +46,7 @@ void launch_kernel(F kernel,
template <typename... Args, typename F>
float launch_and_time_kernel(F kernel,
int nrepeat,
                             dim3 grid_dim,
                             dim3 block_dim,
                             std::size_t lds_byte,
...@@ -54,15 +55,32 @@ float launch_and_time_kernel(F kernel,
{
    KernelTimer timer;

    printf("%s: grid_dim {%d, %d, %d}, block_dim {%d, %d, %d} \n",
           __func__,
           grid_dim.x,
           grid_dim.y,
           grid_dim.z,
           block_dim.x,
           block_dim.y,
           block_dim.z);

    printf("Warm up\n");

    // warm up
    hipLaunchKernelGGL(kernel, grid_dim, block_dim, lds_byte, stream_id, args...);

    printf("Start running %d times...\n", nrepeat);

    timer.Start();

    for(int i = 0; i < nrepeat; ++i)
    {
        hipLaunchKernelGGL(kernel, grid_dim, block_dim, lds_byte, stream_id, args...);
    }

    timer.End();

    return timer.GetElapsedTime() / nrepeat;
}
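// Illustration (not part of this diff): the change above is the usual warm-up-then-average
// timing pattern. A host-only std::chrono sketch with a hypothetical do_launch() callable;
// real GPU timing must also synchronize the stream before reading the clock, which the
// KernelTimer above is assumed to do.
#include <chrono>

template <typename F>
float average_time_ms(F do_launch, int nrepeat)
{
    do_launch(); // warm up, excluded from the measurement

    const auto t0 = std::chrono::steady_clock::now();
    for(int i = 0; i < nrepeat; ++i)
    {
        do_launch();
    }
    const auto t1 = std::chrono::steady_clock::now();

    const float total_ms = std::chrono::duration<float, std::milli>(t1 - t0).count();
    return total_ms / nrepeat; // average time per launch
}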
#elif CK_DEVICE_BACKEND_NVIDIA
......
...@@ -29,8 +29,17 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(
{
    using namespace ck;

    std::cout << __func__ << std::endl;
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto I4 = Number<4>{};
constexpr auto I5 = Number<5>{};
constexpr auto I6 = Number<6>{};
constexpr auto I7 = Number<7>{};
constexpr auto I8 = Number<8>{};
DeviceMem in_n_c_hi_wi_device_buf(sizeof(TInWei) * in_n_c_hi_wi.mDesc.GetElementSpace());
DeviceMem wei_k_c_y_x_device_buf(sizeof(TInWei) * wei_k_c_y_x.mDesc.GetElementSpace());
...@@ -459,50 +468,94 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 4;
#endif
constexpr index_t GemmM1 = GemmMPerThread * GemmMLevel0Cluster * GemmMLevel1Cluster;
constexpr index_t GemmN1 = GemmNPerThread * GemmNLevel0Cluster * GemmNLevel1Cluster;

const auto descs =
#if 1
    transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_pad
#elif 0
    transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_no_pad
#else
    transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_1x1
#endif
    <GemmMPerBlock, GemmNPerBlock, GemmM1, GemmN1>(wei_k_c_y_x_desc,
                                                   in_n_c_hi_wi_desc,
                                                   out_n_k_ho_wo_desc,
                                                   conv_strides,
                                                   conv_dilations,
                                                   in_left_pads,
                                                   in_right_pads);

for(index_t i = 0; i < 5; ++i)
{
    float ave_time = launch_kernel_dynamic_gemm_v1<
        BlockSize,
        typename vector_type<TInWei, InWeiVectorSize>::type,
        TAcc,
        TOut,
        InMemoryDataOperation::Set,
        decltype(descs[I0]),
        decltype(descs[I1]),
        decltype(descs[I2]),
        decltype(descs[I3]),
        GemmMPerBlock,
        GemmNPerBlock,
        GemmKPerBlock,
        GemmMPerThread,
        GemmNPerThread,
        GemmKPerThread,
        GemmMLevel0Cluster,
        GemmNLevel0Cluster,
        GemmMLevel1Cluster,
        GemmNLevel1Cluster,
        GemmABlockTransferThreadSliceLengths_GemmK_GemmM,
        GemmABlockTransferThreadClusterLengths_GemmK_GemmM,
        Sequence<1, 0>,
        Sequence<1, 0>,
        0,
        GemmABlockTransferSrcScalarPerVector_GemmK,
        GemmABlockTransferDstScalarPerVector_GemmM,
        false, // don't move back src coordinate after threadwise copy
        GemmBBlockTransferThreadSliceLengths_GemmK_GemmN,
        GemmBBlockTransferThreadClusterLengths_GemmK_GemmN,
        Sequence<0, 1>,
        Sequence<0, 1>,
        1,
        GemmBBlockTransferSrcScalarPerVector_GemmN,
        GemmBBlockTransferDstScalarPerVector_GemmN,
        false, // don't move back src coordinate after threadwise copy, which will be fused with
               // MoveSrcSliceWindow() to save addr computation
        Sequence<2, 3, 0, 1>,
        3,
        GemmCThreadTransferDstScalarPerVector_GemmN1,
        decltype(descs[I4]),
        decltype(descs[I5]),
        decltype(descs[I6]),
        decltype(descs[I7]),
        decltype(descs[I8])>(static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>(
                                 wei_k_c_y_x_device_buf.GetDeviceBuffer()),
                             static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>(
                                 in_n_c_hi_wi_device_buf.GetDeviceBuffer()),
                             static_cast<TOut*>(out_n_k_ho_wo_device_buf.GetDeviceBuffer()),
                             descs[I0],
                             descs[I1],
                             descs[I2],
                             descs[I3],
                             descs[I4],
                             descs[I5],
                             descs[I6],
                             descs[I7],
                             descs[I8],
                             nrepeat);
float perf = (float)calculate_convolution_flops(
in_n_c_hi_wi_desc, wei_k_c_y_x_desc, out_n_k_ho_wo_desc) /
(std::size_t(1000) * 1000 * 1000) / ave_time;
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl;
}
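// Illustration (not part of this diff): dividing the FLOP count by 1e9 and then by the time
// in milliseconds yields TFlop/s, since 1 ms = 1e-3 s and 1 TFlop = 1e12 Flop.
// Worked example with hypothetical numbers:
static_assert(1.0e12 / (1000.0 * 1000.0 * 1000.0) / 100.0 == 10.0,
              "1e12 FLOP in 100 ms is 1e13 FLOP/s, i.e. 10 TFlop/s");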
// copy result back to host
out_n_k_ho_wo_device_buf.FromDevice(out_n_k_ho_wo.mData.data());
}