"...composable_kernel_rocm.git" did not exist on "9e33fe70c34de4816928a0d8bdf2458fe411a589"
Commit d8c89b68 authored by Chao Liu's avatar Chao Liu
Browse files

refactor driver for conv

parent fd160c63
...@@ -4,1547 +4,441 @@ ...@@ -4,1547 +4,441 @@
#include "common_header.hpp" #include "common_header.hpp"
#include "dynamic_tensor_descriptor.hpp" #include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp" #include "dynamic_tensor_descriptor_helper.hpp"
#include "gridwise_dynamic_gemm.hpp" #include "driver_dynamic_gemm_v1.hpp"
#include "gridwise_operation_wrapper.hpp"
namespace ck { namespace ck {
// GemmM = K // GemmM = K
// GemmN = N * Ho * Wo // GemmN = N * Ho * Wo
// GemmK = C * Y * X // GemmK = C * Y * X
template <index_t BlockSize, template <index_t GemmMPerBlock,
typename FloatAB,
typename FloatAcc,
typename FloatC,
index_t GemmMPerBlock,
index_t GemmNPerBlock, index_t GemmNPerBlock,
index_t GemmKPerBlock, index_t GemmM1,
index_t GemmMPerThread, index_t GemmN1,
index_t GemmNPerThread, typename... Wei,
index_t GemmKPerThread, typename... In,
index_t GemmMLevel0Cluster, typename... Out,
index_t GemmNLevel0Cluster, typename ConvStrides,
index_t GemmMLevel1Cluster, typename ConvDilations,
index_t GemmNLevel1Cluster, typename InLeftPads,
typename GemmABlockTransferThreadSliceLengths_GemmK_GemmM, typename InRightPads>
typename GemmABlockTransferThreadClusterLengths_GemmK_GemmM, __host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_pad(
index_t GemmABlockTransferSrcScalarPerVector_GemmK, const DynamicTensorDescriptor<Wei...>& wei_k_c_y_x_global_desc,
index_t GemmABlockTransferDstScalarPerVector_GemmM, const DynamicTensorDescriptor<In...>& in_n_c_hi_wi_global_desc,
typename GemmBBlockTransferThreadSliceLengths_GemmK_GemmN, const DynamicTensorDescriptor<Out...>& out_n_k_ho_wo_global_desc,
typename GemmBBlockTransferThreadClusterLengths_GemmK_GemmN, const ConvStrides& conv_strides,
index_t GemmBBlockTransferSrcScalarPerVector_GemmN, const ConvDilations& conv_dilations,
index_t GemmBBlockTransferDstScalarPerVector_GemmN, const InLeftPads& in_left_pads,
index_t GemmCThreadTransferDstScalarPerVector_GemmN1> const InRightPads& in_right_pads)
struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_pad
{ {
template <typename... Wei, constexpr auto I0 = Number<0>{};
typename... In, constexpr auto I1 = Number<1>{};
typename... Out, constexpr auto I2 = Number<2>{};
typename ConvStrides, constexpr auto I3 = Number<3>{};
typename ConvDilations,
typename InLeftPads, const auto N = in_n_c_hi_wi_global_desc.GetLength(I0);
typename InRightPads> const auto C = in_n_c_hi_wi_global_desc.GetLength(I1);
__host__ void Run(const DynamicTensorDescriptor<Wei...>& wei_k_c_y_x_global_desc, const auto K = out_n_k_ho_wo_global_desc.GetLength(I1);
const DynamicTensorDescriptor<In...>& in_n_c_hi_wi_global_desc,
const DynamicTensorDescriptor<Out...>& out_n_k_ho_wo_global_desc, const auto Hi = in_n_c_hi_wi_global_desc.GetLength(I2);
const ConvStrides& conv_strides, const auto Wi = in_n_c_hi_wi_global_desc.GetLength(I3);
const ConvDilations& conv_dilations,
const InLeftPads& in_left_pads, const auto Ho = out_n_k_ho_wo_global_desc.GetLength(I2);
const InRightPads& in_right_pads, const auto Wo = out_n_k_ho_wo_global_desc.GetLength(I3);
const FloatAB* __restrict__ p_wei_global,
const FloatAB* __restrict__ p_in_global, const auto Y = wei_k_c_y_x_global_desc.GetLength(I2);
FloatC* __restrict__ p_out_global) const const auto X = wei_k_c_y_x_global_desc.GetLength(I3);
{
constexpr auto I0 = Number<0>{}; const auto ConvStrideH = conv_strides[I0];
constexpr auto I1 = Number<1>{}; const auto ConvStrideW = conv_strides[I1];
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{}; const auto ConvDilationH = conv_dilations[I0];
const auto ConvDilationW = conv_dilations[I1];
const auto N = in_n_c_hi_wi_global_desc.GetLength(I0);
const auto C = in_n_c_hi_wi_global_desc.GetLength(I1); const auto InLeftPadH = in_left_pads[I0];
const auto K = out_n_k_ho_wo_global_desc.GetLength(I1); const auto InLeftPadW = in_left_pads[I1];
const auto Hi = in_n_c_hi_wi_global_desc.GetLength(I2); const auto InRightPadH = in_right_pads[I0];
const auto Wi = in_n_c_hi_wi_global_desc.GetLength(I3); const auto InRightPadW = in_right_pads[I1];
const auto Ho = out_n_k_ho_wo_global_desc.GetLength(I2); // weight tensor
const auto Wo = out_n_k_ho_wo_global_desc.GetLength(I3); const auto wei_gemmk_gemmm_global_desc = transform_dynamic_tensor_descriptor(
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, C * Y * X)),
const auto Y = wei_k_c_y_x_global_desc.GetLength(I2); make_tuple(make_pass_through_transform(K), make_pass_through_transform(C * Y * X)),
const auto X = wei_k_c_y_x_global_desc.GetLength(I3); make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<1>{}, Sequence<0>{}));
const auto ConvStrideH = conv_strides[I0];
const auto ConvStrideW = conv_strides[I1]; // input tensor
const auto in_n_c_hip_wip_global_desc = transform_dynamic_tensor_descriptor(
const auto ConvDilationH = conv_dilations[I0]; in_n_c_hi_wi_global_desc,
const auto ConvDilationW = conv_dilations[I1]; make_tuple(make_pass_through_transform(N),
make_pass_through_transform(C),
const auto InLeftPadH = in_left_pads[I0]; make_pad_transform(Hi, InLeftPadH, InRightPadH),
const auto InLeftPadW = in_left_pads[I1]; make_pad_transform(Wi, InLeftPadW, InRightPadW)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
const auto InRightPadH = in_right_pads[I0]; make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto InRightPadW = in_right_pads[I1];
const auto in_n_c_y_ho_x_wo_global_desc = transform_dynamic_tensor_descriptor(
// weight tensor in_n_c_hip_wip_global_desc,
const auto wei_gemmk_gemmm_global_desc = transform_dynamic_tensor_descriptor( make_tuple(make_pass_through_transform(N),
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, C * Y * X)), make_pass_through_transform(C),
make_tuple(make_pass_through_transform(K), make_pass_through_transform(C * Y * X)), make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW))),
make_tuple(Sequence<1>{}, Sequence<0>{})); make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
// input tensor
const auto in_n_c_hip_wip_global_desc = transform_dynamic_tensor_descriptor( const auto in_gemmk_gemmn_global_desc =
in_n_c_hi_wi_global_desc, transform_dynamic_tensor_descriptor(in_n_c_y_ho_x_wo_global_desc,
make_tuple(make_pass_through_transform(N), make_tuple(make_merge_transform(make_tuple(C, Y, X)),
make_pass_through_transform(C), make_merge_transform(make_tuple(N, Ho, Wo))),
make_pad_transform(Hi, InLeftPadH, InRightPadH), make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
make_pad_transform(Wi, InLeftPadW, InRightPadW)), make_tuple(Sequence<0>{}, Sequence<1>{}));
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); // output tensor
const auto out_gemmm_gemmn_global_desc = transform_dynamic_tensor_descriptor(
const auto in_n_c_y_ho_x_wo_global_desc = transform_dynamic_tensor_descriptor( make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K, Ho * Wo)),
in_n_c_hip_wip_global_desc, make_tuple(make_pass_through_transform(K), make_merge_transform(make_tuple(N, Ho * Wo))),
make_tuple( make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
make_pass_through_transform(N), make_tuple(Sequence<0>{}, Sequence<1>{}));
make_pass_through_transform(C),
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), const auto GemmM = out_gemmm_gemmn_global_desc.GetLength(I0);
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW))), const auto GemmN = out_gemmm_gemmn_global_desc.GetLength(I1);
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), const auto GemmK = wei_gemmk_gemmm_global_desc.GetLength(I0);
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
assert(GemmM % GemmMPerBlock == 0 && GemmN % GemmNPerBlock == 0 && GemmK % GemmKPerBlock == 0);
const auto in_gemmk_gemmn_global_desc = transform_dynamic_tensor_descriptor(
in_n_c_y_ho_x_wo_global_desc, const auto GemmM0 = GemmM / Number<GemmM1>{};
make_tuple(make_merge_transform(make_tuple(C, Y, X)), const auto GemmN0 = GemmN / Number<GemmN1>{};
make_merge_transform(make_tuple(N, Ho, Wo))),
make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}), const auto out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc = transform_dynamic_tensor_descriptor(
make_tuple(Sequence<0>{}, Sequence<1>{})); out_gemmm_gemmn_global_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmM0, GemmM1)),
// output tensor make_unmerge_transform(make_tuple(GemmN0, GemmN1))),
const auto out_gemmm_gemmn_global_desc = transform_dynamic_tensor_descriptor( make_tuple(Sequence<0>{}, Sequence<1>{}),
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K, Ho * Wo)), make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
make_tuple(make_pass_through_transform(K),
make_merge_transform(make_tuple(N, Ho * Wo))), // out_gemm_block_cluster_desc
make_tuple(Sequence<1>{}, Sequence<0, 2>{}), const auto out_gemm_block_cluster_desc = make_cluster_descriptor_v2(
make_tuple(Sequence<0>{}, Sequence<1>{})); make_tuple(GemmM / Number<GemmMPerBlock>{}, GemmN / Number<GemmNPerBlock>{}));
const auto GemmM = out_gemmm_gemmn_global_desc.GetLength(I0); // hack to control index calculation when iterating over wei_gemmk_gemmm_global tensor
const auto GemmN = out_gemmm_gemmn_global_desc.GetLength(I1); constexpr auto wei_gemmk_gemmm_global_iterator_hacks =
const auto GemmK = wei_gemmk_gemmm_global_desc.GetLength(I0); make_tuple(make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}),
make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}));
if(!(GemmM % GemmMPerBlock == 0 && GemmN % GemmNPerBlock == 0 &&
GemmK % GemmKPerBlock == 0)) constexpr auto wei_gemmk_gemmm_global_move_slice_window_iterator_hacks = Sequence<0, 0, 0>{};
{
throw std::runtime_error("wrong! GEMM size no divisible"); // hack to control index calculation when iterating over in_gemmk_gemmn_global tensor
} constexpr auto in_gemmk_gemmn_global_iterator_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0>{},
constexpr auto GemmM1 = Number<GemmMPerThread * GemmMLevel0Cluster * GemmMLevel1Cluster>{}; Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1>{}),
constexpr auto GemmN1 = Number<GemmNPerThread * GemmNLevel0Cluster * GemmNLevel1Cluster>{}; make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2>{}));
const auto GemmM0 = GemmM / GemmM1;
const auto GemmN0 = GemmN / GemmN1; constexpr auto in_gemmk_gemmn_global_move_slice_window_iterator_hacks =
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2>{};
const auto out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc =
transform_dynamic_tensor_descriptor( // hack to control index calculation when iterating over out_gemmm0_gemmm1_gemmn0_gemmn1_global
out_gemmm_gemmn_global_desc, // tensor hack for NKHW format
make_tuple(make_unmerge_transform(make_tuple(GemmM0, GemmM1)), constexpr auto out_gemmm0_gemmm1_gemmn0_gemmn1_global_iterator_hacks =
make_unmerge_transform(make_tuple(GemmN0, GemmN1))), make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
make_tuple(Sequence<0>{}, Sequence<1>{}), Sequence<0, 0, 0, 0, 0>{},
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); Sequence<0, 0, 1, 0, 0>{},
Sequence<0, 0, 1, 0, 0>{}),
// c_block_cluster_desc make_tuple(Sequence<0, 0, 0, 0, 0>{},
const auto gemm_block_cluster_desc = make_cluster_descriptor_v2( Sequence<0, 0, 0, 0, 0>{},
make_tuple(GemmM / Number<GemmMPerBlock>{}, GemmN / Number<GemmNPerBlock>{})); Sequence<0, 0, 2, 0, 0>{},
Sequence<0, 0, 2, 0, 0>{}));
// hack to control index calculation when iterating over a_k_m_global tensor
constexpr auto a_k_m_global_iterator_hacks = return make_tuple(wei_gemmk_gemmm_global_desc,
make_tuple(make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}), in_gemmk_gemmn_global_desc,
make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{})); out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc,
out_gemm_block_cluster_desc,
constexpr auto a_k_m_global_move_slice_window_iterator_hack = Sequence<0, 0, 0>{}; wei_gemmk_gemmm_global_iterator_hacks,
in_gemmk_gemmn_global_iterator_hacks,
// hack to control index calculation when iterating over b_k_n_global tensor out_gemmm0_gemmm1_gemmn0_gemmn1_global_iterator_hacks,
constexpr auto b_k_n_global_iterator_hacks = wei_gemmk_gemmm_global_move_slice_window_iterator_hacks,
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0>{}, in_gemmk_gemmn_global_move_slice_window_iterator_hacks);
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1>{}), }
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2>{}));
constexpr auto b_k_n_global_move_slice_window_iterator_hack =
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2>{};
// hack to control index calculation when iterating over c_m0_m1_n0_n1_global tensor
// hack for NKHW format
constexpr auto c_m0_m1_n0_n1_global_tensor_iterator_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 1, 0, 0>{},
Sequence<0, 0, 1, 0, 0>{}),
make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 2, 0, 0>{},
Sequence<0, 0, 2, 0, 0>{}));
// GEMM
using gridwise_gemm = GridwiseDynamicGemm_km_kn_m0m1n0n1_v1<
BlockSize,
FloatAB,
FloatAcc,
FloatC,
InMemoryDataOperation::Set,
decltype(wei_gemmk_gemmm_global_desc),
decltype(in_gemmk_gemmn_global_desc),
decltype(out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc),
decltype(gemm_block_cluster_desc),
GemmMPerBlock,
GemmNPerBlock,
GemmKPerBlock,
GemmMPerThread,
GemmNPerThread,
GemmKPerThread,
GemmMLevel0Cluster,
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmABlockTransferThreadSliceLengths_GemmK_GemmM,
GemmABlockTransferThreadClusterLengths_GemmK_GemmM,
Sequence<1, 0>,
Sequence<1, 0>,
0,
GemmABlockTransferSrcScalarPerVector_GemmK,
GemmABlockTransferDstScalarPerVector_GemmM,
false, // don't move back src coordinate after threadwise copy
GemmBBlockTransferThreadSliceLengths_GemmK_GemmN,
GemmBBlockTransferThreadClusterLengths_GemmK_GemmN,
Sequence<0, 1>,
Sequence<0, 1>,
1,
GemmBBlockTransferSrcScalarPerVector_GemmN,
GemmBBlockTransferDstScalarPerVector_GemmN,
false, // don't move back src coordinate after threadwise copy, which will be fused with
// MoveSrcSliceWindow() to save addr computation
Sequence<2, 3, 0, 1>,
3,
GemmCThreadTransferDstScalarPerVector_GemmN1,
decltype(a_k_m_global_iterator_hacks),
decltype(b_k_n_global_iterator_hacks),
decltype(c_m0_m1_n0_n1_global_tensor_iterator_hacks),
decltype(a_k_m_global_move_slice_window_iterator_hack),
decltype(b_k_n_global_move_slice_window_iterator_hack)>;
const auto GridSize = (GemmM / GemmMPerBlock) * (GemmN / GemmNPerBlock);
const bool has_main_k_block_loop = (GemmK + GemmKPerBlock) / (2 * GemmKPerBlock) > 1;
const bool has_double_tail_k_block_loop = (GemmK / GemmKPerBlock) % 2 == 0;
#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
index_t nrepeat = 100;
for(index_t i = 0; i < 5; ++i)
{
std::cout << "Start running " << nrepeat << " times..." << std::endl;
KernelTimer timer;
timer.Start();
for(index_t j = 0; j < nrepeat; ++j)
{
if(has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel =
run_gridwise_operation<gridwise_gemm,
decltype(wei_gemmk_gemmm_global_desc),
const FloatAB*,
decltype(in_gemmk_gemmn_global_desc),
const FloatAB*,
decltype(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc),
FloatC*,
decltype(gemm_block_cluster_desc),
integral_constant<bool, true>,
integral_constant<bool, true>>;
launch_kernel(kernel,
dim3(GridSize),
dim3(BlockSize),
0,
0,
wei_gemmk_gemmm_global_desc,
p_wei_global,
in_gemmk_gemmn_global_desc,
p_in_global,
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc,
p_out_global,
gemm_block_cluster_desc,
integral_constant<bool, true>{},
integral_constant<bool, true>{});
}
else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
{
const auto kernel =
run_gridwise_operation<gridwise_gemm,
decltype(wei_gemmk_gemmm_global_desc),
const FloatAB*,
decltype(in_gemmk_gemmn_global_desc),
const FloatAB*,
decltype(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc),
FloatC*,
decltype(gemm_block_cluster_desc),
integral_constant<bool, true>,
integral_constant<bool, false>>;
launch_kernel(kernel,
dim3(GridSize),
dim3(BlockSize),
0,
0,
wei_gemmk_gemmm_global_desc,
p_wei_global,
in_gemmk_gemmn_global_desc,
p_in_global,
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc,
p_out_global,
gemm_block_cluster_desc,
integral_constant<bool, true>{},
integral_constant<bool, false>{});
}
else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel =
run_gridwise_operation<gridwise_gemm,
decltype(wei_gemmk_gemmm_global_desc),
const FloatAB*,
decltype(in_gemmk_gemmn_global_desc),
const FloatAB*,
decltype(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc),
FloatC*,
decltype(gemm_block_cluster_desc),
integral_constant<bool, false>,
integral_constant<bool, true>>;
launch_kernel(kernel,
dim3(GridSize),
dim3(BlockSize),
0,
0,
wei_gemmk_gemmm_global_desc,
p_wei_global,
in_gemmk_gemmn_global_desc,
p_in_global,
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc,
p_out_global,
gemm_block_cluster_desc,
integral_constant<bool, false>{},
integral_constant<bool, true>{});
}
else
{
const auto kernel =
run_gridwise_operation<gridwise_gemm,
decltype(wei_gemmk_gemmm_global_desc),
const FloatAB*,
decltype(in_gemmk_gemmn_global_desc),
const FloatAB*,
decltype(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc),
FloatC*,
decltype(gemm_block_cluster_desc),
integral_constant<bool, false>,
integral_constant<bool, false>>;
launch_kernel(kernel,
dim3(GridSize),
dim3(BlockSize),
0,
0,
wei_gemmk_gemmm_global_desc,
p_wei_global,
in_gemmk_gemmn_global_desc,
p_in_global,
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc,
p_out_global,
gemm_block_cluster_desc,
integral_constant<bool, false>{},
integral_constant<bool, false>{});
}
}
timer.End();
float ave_time = timer.GetElapsedTime() / nrepeat;
float perf = (float)calculate_convolution_flops(in_n_c_hi_wi_global_desc,
wei_k_c_y_x_global_desc,
out_n_k_ho_wo_global_desc) /
(std::size_t(1000) * 1000 * 1000) / ave_time;
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
<< std::endl;
}
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
using ADesc = decltype(wei_gemmk_gemmm_global_desc);
using BDesc = decltype(in_gemmk_gemmn_global_desc);
using CDesc = decltype(out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc);
DeviceMem wei_gemmk_gemmm_global_desc_device_buf(sizeof(ADesc));
DeviceMem in_gemmk_gemmn_global_desc_device_buf(sizeof(BDesc));
DeviceMem out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf(sizeof(CDesc));
wei_gemmk_gemmm_global_desc_device_buf.ToDevice(&wei_gemmk_gemmm_global_desc);
in_gemmk_gemmn_global_desc_device_buf.ToDevice(&in_gemmk_gemmn_global_desc);
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf.ToDevice(
&out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc);
index_t nrepeat = 100;
for(index_t i = 0; i < 5; ++i)
{
std::cout << "Start running " << nrepeat << " times..." << std::endl;
KernelTimer timer;
timer.Start();
for(index_t j = 0; j < nrepeat; ++j)
{
if(has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel = run_gridwise_dynamic_gemm_v1<gridwise_gemm,
ADesc,
FloatAB,
BDesc,
FloatAB,
CDesc,
FloatC,
true,
true>;
launch_kernel(
kernel,
dim3(GridSize),
dim3(BlockSize),
0,
0,
(void __CONSTANT__*)
wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer(),
p_wei_global,
(void __CONSTANT__*)in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer(),
p_in_global,
(void __CONSTANT__*)
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
.GetDeviceBuffer(),
p_out_global);
}
else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
{
const auto kernel = run_gridwise_dynamic_gemm_v1<gridwise_gemm,
ADesc,
FloatAB,
BDesc,
FloatAB,
CDesc,
FloatC,
true,
false>;
launch_kernel(
kernel,
dim3(GridSize),
dim3(BlockSize),
0,
0,
(void __CONSTANT__*)
wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer(),
p_wei_global,
(void __CONSTANT__*)in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer(),
p_in_global,
(void __CONSTANT__*)
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
.GetDeviceBuffer(),
p_out_global);
}
else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel = run_gridwise_dynamic_gemm_v1<gridwise_gemm,
ADesc,
FloatAB,
BDesc,
FloatAB,
CDesc,
FloatC,
false,
true>;
launch_kernel(
kernel,
dim3(GridSize),
dim3(BlockSize),
0,
0,
(void __CONSTANT__*)
wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer(),
p_wei_global,
(void __CONSTANT__*)in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer(),
p_in_global,
(void __CONSTANT__*)
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
.GetDeviceBuffer(),
p_out_global);
}
else
{
const auto kernel = run_gridwise_dynamic_gemm_v1<gridwise_gemm,
ADesc,
FloatAB,
BDesc,
FloatAB,
CDesc,
FloatC,
false,
false>;
launch_kernel(
kernel,
dim3(GridSize),
dim3(BlockSize),
0,
0,
(void __CONSTANT__*)
wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer(),
p_wei_global,
(void __CONSTANT__*)in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer(),
p_in_global,
(void __CONSTANT__*)
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
.GetDeviceBuffer(),
p_out_global);
}
}
timer.End();
float ave_time = timer.GetElapsedTime() / nrepeat;
float perf = (float)calculate_convolution_flops(in_n_c_hi_wi_global_desc,
wei_k_c_y_x_global_desc,
out_n_k_ho_wo_global_desc) /
(std::size_t(1000) * 1000 * 1000) / ave_time;
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
<< std::endl;
}
#endif
}
};
#if 0
// GemmM = K // GemmM = K
// GemmN = N * Ho * Wo // GemmN = N * Ho * Wo
// GemmK = C * Y * X // GemmK = C * Y * X
template <index_t BlockSize, template <index_t GemmMPerBlock,
typename FloatAB,
typename FloatAcc,
typename FloatC,
index_t GemmMPerBlock,
index_t GemmNPerBlock, index_t GemmNPerBlock,
index_t GemmKPerBlock, index_t GemmM1,
index_t GemmMPerThread, index_t GemmN1,
index_t GemmNPerThread, typename... Wei,
index_t GemmKPerThread, typename... In,
index_t GemmMLevel0Cluster, typename... Out,
index_t GemmNLevel0Cluster, typename ConvStrides,
index_t GemmMLevel1Cluster, typename ConvDilations,
index_t GemmNLevel1Cluster, typename InLeftPads,
typename GemmABlockTransferThreadSliceLengths_GemmK_GemmM, typename InRightPads>
typename GemmABlockTransferThreadClusterLengths_GemmK_GemmM, __host__ __device__ constexpr auto
index_t GemmABlockTransferSrcScalarPerVector_GemmK, transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_no_pad(
index_t GemmABlockTransferDstScalarPerVector_GemmM, const DynamicTensorDescriptor<Wei...>& wei_k_c_y_x_global_desc,
typename GemmBBlockTransferThreadSliceLengths_GemmK_GemmN, const DynamicTensorDescriptor<In...>& in_n_c_hi_wi_global_desc,
typename GemmBBlockTransferThreadClusterLengths_GemmK_GemmN, const DynamicTensorDescriptor<Out...>& out_n_k_ho_wo_global_desc,
index_t GemmBBlockTransferSrcScalarPerVector_GemmN, const ConvStrides& conv_strides,
index_t GemmBBlockTransferDstScalarPerVector_GemmN, const ConvDilations& conv_dilations,
index_t GemmCThreadTransferDstScalarPerVector_GemmN1> const InLeftPads& in_left_pads,
struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_no_pad const InRightPads& in_right_pads)
{ {
template <typename... Wei, constexpr auto I0 = Number<0>{};
typename... In, constexpr auto I1 = Number<1>{};
typename... Out, constexpr auto I2 = Number<2>{};
typename ConvStrides, constexpr auto I3 = Number<3>{};
typename ConvDilations,
typename InLeftPads, const auto N = in_n_c_hi_wi_global_desc.GetLength(I0);
typename InRightPads> const auto C = in_n_c_hi_wi_global_desc.GetLength(I1);
__host__ void Run(const DynamicTensorDescriptor<Wei...>& wei_k_c_y_x_global_desc, const auto K = out_n_k_ho_wo_global_desc.GetLength(I1);
const DynamicTensorDescriptor<In...>& in_n_c_hi_wi_global_desc,
const DynamicTensorDescriptor<Out...>& out_n_k_ho_wo_global_desc, const auto Hi = in_n_c_hi_wi_global_desc.GetLength(I2);
const ConvStrides& conv_strides, const auto Wi = in_n_c_hi_wi_global_desc.GetLength(I3);
const ConvDilations& conv_dilations,
const InLeftPads& in_left_pads, const auto Ho = out_n_k_ho_wo_global_desc.GetLength(I2);
const InRightPads& in_right_pads, const auto Wo = out_n_k_ho_wo_global_desc.GetLength(I3);
const FloatAB* __restrict__ p_wei_global,
const FloatAB* __restrict__ p_in_global, const auto Y = wei_k_c_y_x_global_desc.GetLength(I2);
FloatC* __restrict__ p_out_global) const const auto X = wei_k_c_y_x_global_desc.GetLength(I3);
{
constexpr auto I0 = Number<0>{}; const auto ConvStrideH = conv_strides[I0];
constexpr auto I1 = Number<1>{}; const auto ConvStrideW = conv_strides[I1];
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{}; const auto ConvDilationH = conv_dilations[I0];
const auto ConvDilationW = conv_dilations[I1];
const auto N = in_n_c_hi_wi_global_desc.GetLength(I0);
const auto C = in_n_c_hi_wi_global_desc.GetLength(I1); const auto InLeftPadH = in_left_pads[I0];
const auto K = out_n_k_ho_wo_global_desc.GetLength(I1); const auto InLeftPadW = in_left_pads[I1];
const auto Hi = in_n_c_hi_wi_global_desc.GetLength(I2); const auto InRightPadH = in_right_pads[I0];
const auto Wi = in_n_c_hi_wi_global_desc.GetLength(I3); const auto InRightPadW = in_right_pads[I1];
const auto Ho = out_n_k_ho_wo_global_desc.GetLength(I2); assert(InLeftPadH == 0 && InLeftPadW == 0 && InRightPadH == 0 && InRightPadW == 0);
const auto Wo = out_n_k_ho_wo_global_desc.GetLength(I3);
// weight tensor
const auto Y = wei_k_c_y_x_global_desc.GetLength(I2); const auto wei_gemmk_gemmm_global_desc = transform_dynamic_tensor_descriptor(
const auto X = wei_k_c_y_x_global_desc.GetLength(I3); make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, C * Y * X)),
make_tuple(make_pass_through_transform(K), make_pass_through_transform(C * Y * X)),
const auto ConvStrideH = conv_strides[I0]; make_tuple(Sequence<0>{}, Sequence<1>{}),
const auto ConvStrideW = conv_strides[I1]; make_tuple(Sequence<1>{}, Sequence<0>{}));
const auto ConvDilationH = conv_dilations[I0]; // input tensor
const auto ConvDilationW = conv_dilations[I1]; const auto in_n_c_y_ho_x_wo_global_desc = transform_dynamic_tensor_descriptor(
in_n_c_hi_wi_global_desc,
const auto InLeftPadH = in_left_pads[I0]; make_tuple(make_pass_through_transform(N),
const auto InLeftPadW = in_left_pads[I1]; make_pass_through_transform(C),
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
const auto InRightPadH = in_right_pads[I0]; make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW))),
const auto InRightPadW = in_right_pads[I1]; make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
if(!(InLeftPadH == 0 && InLeftPadW == 0 && InRightPadH == 0 && InRightPadW == 0))
{ const auto in_gemmk_gemmn_global_desc =
throw std::runtime_error("wrong! no padding"); transform_dynamic_tensor_descriptor(in_n_c_y_ho_x_wo_global_desc,
} make_tuple(make_merge_transform(make_tuple(C, Y, X)),
make_merge_transform(make_tuple(N, Ho, Wo))),
// weight tensor make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
const auto wei_gemmk_gemmm_global_desc = transform_dynamic_tensor_descriptor( make_tuple(Sequence<0>{}, Sequence<1>{}));
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, C * Y * X)),
make_tuple(make_pass_through_transform(K), make_pass_through_transform(C * Y * X)), // output tensor
make_tuple(Sequence<0>{}, Sequence<1>{}), const auto out_gemmm_gemmn_global_desc = transform_dynamic_tensor_descriptor(
make_tuple(Sequence<1>{}, Sequence<0>{})); make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K, Ho * Wo)),
make_tuple(make_pass_through_transform(K), make_merge_transform(make_tuple(N, Ho * Wo))),
// input tensor make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
const auto in_n_c_y_ho_x_wo_global_desc = transform_dynamic_tensor_descriptor( make_tuple(Sequence<0>{}, Sequence<1>{}));
in_n_c_hi_wi_global_desc,
make_tuple( const auto GemmM = out_gemmm_gemmn_global_desc.GetLength(I0);
make_pass_through_transform(N), const auto GemmN = out_gemmm_gemmn_global_desc.GetLength(I1);
make_pass_through_transform(C), const auto GemmK = wei_gemmk_gemmm_global_desc.GetLength(I0);
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW))), assert(GemmM % GemmMPerBlock == 0 && GemmN % GemmNPerBlock == 0 && GemmK % GemmKPerBlock == 0);
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{})); const auto GemmM0 = GemmM / Number<GemmM1>{};
const auto GemmN0 = GemmN / Number<GemmN1>{};
const auto in_gemmk_gemmn_global_desc = transform_dynamic_tensor_descriptor(
in_n_c_y_ho_x_wo_global_desc, const auto out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc = transform_dynamic_tensor_descriptor(
make_tuple(make_merge_transform(make_tuple(C, Y, X)), out_gemmm_gemmn_global_desc,
make_merge_transform(make_tuple(N, Ho, Wo))), make_tuple(make_unmerge_transform(make_tuple(GemmM0, GemmM1)),
make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}), make_unmerge_transform(make_tuple(GemmN0, GemmN1))),
make_tuple(Sequence<0>{}, Sequence<1>{})); make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
// output tensor
const auto out_gemmm_gemmn_global_desc = transform_dynamic_tensor_descriptor( // out_gemm_block_cluster_desc
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K, Ho * Wo)), const auto out_gemm_block_cluster_desc = make_cluster_descriptor_v2(
make_tuple(make_pass_through_transform(K), make_tuple(GemmM / Number<GemmMPerBlock>{}, GemmN / Number<GemmNPerBlock>{}));
make_merge_transform(make_tuple(N, Ho * Wo))),
make_tuple(Sequence<1>{}, Sequence<0, 2>{}), // hack to control index calculation when iterating over a_k_m_global tensor
make_tuple(Sequence<0>{}, Sequence<1>{})); constexpr auto wei_gemmk_gemmm_global_iterator_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}),
const auto GemmM = out_gemmm_gemmn_global_desc.GetLength(I0); make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}));
const auto GemmN = out_gemmm_gemmn_global_desc.GetLength(I1);
const auto GemmK = wei_gemmk_gemmm_global_desc.GetLength(I0); constexpr auto wei_gemmk_gemmm_global_move_slice_window_iterator_hacks = Sequence<0, 0, 0>{};
if(!(GemmM % GemmMPerBlock == 0 && GemmN % GemmNPerBlock == 0 && // hack to control index calculation when iterating over b_k_n_global tensor
GemmK % GemmKPerBlock == 0)) constexpr auto in_gemmk_gemmn_global_iterator_hacks =
{ make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 1, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 1>{}),
throw std::runtime_error("wrong! GEMM size no divisible"); make_tuple(Sequence<0, 0, 0, 0, 0, 2, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 2>{}));
}
constexpr auto in_gemmk_gemmn_global_move_slice_window_iterator_hacks =
constexpr auto GemmM1 = Number<GemmMPerThread * GemmMLevel0Cluster * GemmMLevel1Cluster>{}; Sequence<0, 0, 0, 0, 0, 1, 2>{};
constexpr auto GemmN1 = Number<GemmNPerThread * GemmNLevel0Cluster * GemmNLevel1Cluster>{};
// hack to control index calculation when iterating over c_m0_m1_n0_n1_global tensor
const auto GemmM0 = GemmM / GemmM1; // hack for NKHW format
const auto GemmN0 = GemmN / GemmN1; constexpr auto out_gemmm0_gemmm1_gemmn0_gemmn1_global_iterator_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
const auto out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc = Sequence<0, 0, 0, 0, 0>{},
transform_dynamic_tensor_descriptor( Sequence<0, 0, 1, 0, 0>{},
out_gemmm_gemmn_global_desc, Sequence<0, 0, 1, 0, 0>{}),
make_tuple(make_unmerge_transform(make_tuple(GemmM0, GemmM1)), make_tuple(Sequence<0, 0, 0, 0, 0>{},
make_unmerge_transform(make_tuple(GemmN0, GemmN1))), Sequence<0, 0, 0, 0, 0>{},
make_tuple(Sequence<0>{}, Sequence<1>{}), Sequence<0, 0, 2, 0, 0>{},
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); Sequence<0, 0, 2, 0, 0>{}));
// hack to control index calculation when iterating over a_k_m_global tensor return make_tuple(wei_gemmk_gemmm_global_desc,
constexpr auto a_k_m_global_iterator_hacks = in_gemmk_gemmn_global_desc,
make_tuple(make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}), out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc,
make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{})); out_gemm_block_cluster_desc,
wei_gemmk_gemmm_global_iterator_hacks,
constexpr auto a_k_m_global_move_slice_window_iterator_hack = Sequence<0, 0, 0>{}; in_gemmk_gemmn_global_iterator_hacks,
out_gemmm0_gemmm1_gemmn0_gemmn1_global_iterator_hacks,
// hack to control index calculation when iterating over b_k_n_global tensor wei_gemmk_gemmm_global_move_slice_window_iterator_hacks,
constexpr auto b_k_n_global_iterator_hacks = make_tuple( in_gemmk_gemmn_global_move_slice_window_iterator_hacks);
make_tuple(Sequence<0, 0, 0, 0, 0, 1, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 1>{}), }
make_tuple(Sequence<0, 0, 0, 0, 0, 2, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 2>{}));
// GemmM = K
constexpr auto b_k_n_global_move_slice_window_iterator_hack = // GemmN = N * Ho * Wo
Sequence<0, 0, 0, 0, 0, 1, 2>{}; // GemmK = C * Y * X
template <index_t GemmMPerBlock,
// hack to control index calculation when iterating over c_m0_m1_n0_n1_global tensor
// hack for NKHW format
constexpr auto c_m0_m1_n0_n1_global_tensor_iterator_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 1, 0, 0>{},
Sequence<0, 0, 1, 0, 0>{}),
make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 2, 0, 0>{},
Sequence<0, 0, 2, 0, 0>{}));
// GEMM
using gridwise_gemm = GridwiseDynamicGemm_km_kn_m0m1n0n1_v1<
BlockSize,
FloatAB,
FloatAcc,
FloatC,
InMemoryDataOperation::Set,
decltype(wei_gemmk_gemmm_global_desc),
decltype(in_gemmk_gemmn_global_desc),
decltype(out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc),
GemmMPerBlock,
GemmNPerBlock,
GemmKPerBlock,
GemmMPerThread,
GemmNPerThread,
GemmKPerThread,
GemmMLevel0Cluster,
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmABlockTransferThreadSliceLengths_GemmK_GemmM,
GemmABlockTransferThreadClusterLengths_GemmK_GemmM,
Sequence<1, 0>,
Sequence<1, 0>,
0,
GemmABlockTransferSrcScalarPerVector_GemmK,
GemmABlockTransferDstScalarPerVector_GemmM,
false, // don't move back src coordinate after threadwise copy
GemmBBlockTransferThreadSliceLengths_GemmK_GemmN,
GemmBBlockTransferThreadClusterLengths_GemmK_GemmN,
Sequence<0, 1>,
Sequence<0, 1>,
1,
GemmBBlockTransferSrcScalarPerVector_GemmN,
GemmBBlockTransferDstScalarPerVector_GemmN,
false, // don't move back src coordinate after threadwise copy, which will be fused with
// MoveSrcSliceWindow() to save addr computation
Sequence<2, 3, 0, 1>,
3,
GemmCThreadTransferDstScalarPerVector_GemmN1,
decltype(a_k_m_global_iterator_hacks),
decltype(b_k_n_global_iterator_hacks),
decltype(c_m0_m1_n0_n1_global_tensor_iterator_hacks),
decltype(a_k_m_global_move_slice_window_iterator_hack),
decltype(b_k_n_global_move_slice_window_iterator_hack)>;
const auto GridSize = (GemmM / GemmMPerBlock) * (GemmN / GemmNPerBlock);
const bool has_main_k_block_loop = (GemmK + GemmKPerBlock) / (2 * GemmKPerBlock) > 1;
const bool has_double_tail_k_block_loop = (GemmK / GemmKPerBlock) % 2 == 0;
#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
index_t nrepeat = 100;
for(index_t i = 0; i < 5; ++i)
{
std::cout << "Start running " << nrepeat << " times..." << std::endl;
KernelTimer timer;
timer.Start();
for(index_t j = 0; j < nrepeat; ++j)
{
if(has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel =
run_gridwise_operation<gridwise_gemm,
decltype(wei_gemmk_gemmm_global_desc),
const FloatAB*,
decltype(in_gemmk_gemmn_global_desc),
const FloatAB*,
decltype(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc),
FloatC*,
integral_constant<bool, true>,
integral_constant<bool, true>>;
launch_kernel(kernel,
dim3(GridSize),
dim3(BlockSize),
0,
0,
wei_gemmk_gemmm_global_desc,
p_wei_global,
in_gemmk_gemmn_global_desc,
p_in_global,
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc,
p_out_global,
integral_constant<bool, true>{},
integral_constant<bool, true>{});
}
else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
{
const auto kernel =
run_gridwise_operation<gridwise_gemm,
decltype(wei_gemmk_gemmm_global_desc),
const FloatAB*,
decltype(in_gemmk_gemmn_global_desc),
const FloatAB*,
decltype(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc),
FloatC*,
integral_constant<bool, true>,
integral_constant<bool, false>>;
launch_kernel(kernel,
dim3(GridSize),
dim3(BlockSize),
0,
0,
wei_gemmk_gemmm_global_desc,
p_wei_global,
in_gemmk_gemmn_global_desc,
p_in_global,
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc,
p_out_global,
integral_constant<bool, true>{},
integral_constant<bool, false>{});
}
else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel =
run_gridwise_operation<gridwise_gemm,
decltype(wei_gemmk_gemmm_global_desc),
const FloatAB*,
decltype(in_gemmk_gemmn_global_desc),
const FloatAB*,
decltype(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc),
FloatC*,
integral_constant<bool, false>,
integral_constant<bool, true>>;
launch_kernel(kernel,
dim3(GridSize),
dim3(BlockSize),
0,
0,
wei_gemmk_gemmm_global_desc,
p_wei_global,
in_gemmk_gemmn_global_desc,
p_in_global,
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc,
p_out_global,
integral_constant<bool, false>{},
integral_constant<bool, true>{});
}
else
{
const auto kernel =
run_gridwise_operation<gridwise_gemm,
decltype(wei_gemmk_gemmm_global_desc),
const FloatAB*,
decltype(in_gemmk_gemmn_global_desc),
const FloatAB*,
decltype(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc),
FloatC*,
integral_constant<bool, false>,
integral_constant<bool, false>>;
launch_kernel(kernel,
dim3(GridSize),
dim3(BlockSize),
0,
0,
wei_gemmk_gemmm_global_desc,
p_wei_global,
in_gemmk_gemmn_global_desc,
p_in_global,
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc,
p_out_global,
integral_constant<bool, false>{},
integral_constant<bool, false>{});
}
}
timer.End();
float ave_time = timer.GetElapsedTime() / nrepeat;
float perf = (float)calculate_convolution_flops(in_n_c_hi_wi_global_desc,
wei_k_c_y_x_global_desc,
out_n_k_ho_wo_global_desc) /
(std::size_t(1000) * 1000 * 1000) / ave_time;
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
<< std::endl;
}
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
using ADesc = decltype(wei_gemmk_gemmm_global_desc);
using BDesc = decltype(in_gemmk_gemmn_global_desc);
using CDesc = decltype(out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc);
DeviceMem wei_gemmk_gemmm_global_desc_device_buf(sizeof(ADesc));
DeviceMem in_gemmk_gemmn_global_desc_device_buf(sizeof(BDesc));
DeviceMem out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf(sizeof(CDesc));
wei_gemmk_gemmm_global_desc_device_buf.ToDevice(&wei_gemmk_gemmm_global_desc);
in_gemmk_gemmn_global_desc_device_buf.ToDevice(&in_gemmk_gemmn_global_desc);
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf.ToDevice(
&out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc);
index_t nrepeat = 100;
for(index_t i = 0; i < 5; ++i)
{
std::cout << "Start running " << nrepeat << " times..." << std::endl;
KernelTimer timer;
timer.Start();
for(index_t j = 0; j < nrepeat; ++j)
{
if(has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel = run_gridwise_dynamic_gemm_v1<gridwise_gemm,
ADesc,
FloatAB,
BDesc,
FloatAB,
CDesc,
FloatC,
true,
true>;
launch_kernel(
kernel,
dim3(GridSize),
dim3(BlockSize),
0,
0,
(void __CONSTANT__*)
wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer(),
p_wei_global,
(void __CONSTANT__*)in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer(),
p_in_global,
(void __CONSTANT__*)
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
.GetDeviceBuffer(),
p_out_global);
}
else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
{
const auto kernel = run_gridwise_dynamic_gemm_v1<gridwise_gemm,
ADesc,
FloatAB,
BDesc,
FloatAB,
CDesc,
FloatC,
true,
false>;
launch_kernel(
kernel,
dim3(GridSize),
dim3(BlockSize),
0,
0,
(void __CONSTANT__*)
wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer(),
p_wei_global,
(void __CONSTANT__*)in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer(),
p_in_global,
(void __CONSTANT__*)
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
.GetDeviceBuffer(),
p_out_global);
}
else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel = run_gridwise_dynamic_gemm_v1<gridwise_gemm,
ADesc,
FloatAB,
BDesc,
FloatAB,
CDesc,
FloatC,
false,
true>;
launch_kernel(
kernel,
dim3(GridSize),
dim3(BlockSize),
0,
0,
(void __CONSTANT__*)
wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer(),
p_wei_global,
(void __CONSTANT__*)in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer(),
p_in_global,
(void __CONSTANT__*)
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
.GetDeviceBuffer(),
p_out_global);
}
else
{
const auto kernel = run_gridwise_dynamic_gemm_v1<gridwise_gemm,
ADesc,
FloatAB,
BDesc,
FloatAB,
CDesc,
FloatC,
false,
false>;
launch_kernel(
kernel,
dim3(GridSize),
dim3(BlockSize),
0,
0,
(void __CONSTANT__*)
wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer(),
p_wei_global,
(void __CONSTANT__*)in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer(),
p_in_global,
(void __CONSTANT__*)
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
.GetDeviceBuffer(),
p_out_global);
}
}
timer.End();
float ave_time = timer.GetElapsedTime() / nrepeat;
float perf = (float)calculate_convolution_flops(in_n_c_hi_wi_global_desc,
wei_k_c_y_x_global_desc,
out_n_k_ho_wo_global_desc) /
(std::size_t(1000) * 1000 * 1000) / ave_time;
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
<< std::endl;
}
#endif
}
};
template <index_t BlockSize,
typename FloatAB,
typename FloatAcc,
typename FloatC,
index_t GemmMPerBlock,
index_t GemmNPerBlock, index_t GemmNPerBlock,
index_t GemmKPerBlock, index_t GemmM1,
index_t GemmMPerThread, index_t GemmN1,
index_t GemmNPerThread, typename... Wei,
index_t GemmKPerThread, typename... In,
index_t GemmMLevel0Cluster, typename... Out,
index_t GemmNLevel0Cluster, typename ConvStrides,
index_t GemmMLevel1Cluster, typename ConvDilations,
index_t GemmNLevel1Cluster, typename InLeftPads,
typename GemmABlockTransferThreadSliceLengths_GemmK_GemmM, typename InRightPads>
typename GemmABlockTransferThreadClusterLengths_GemmK_GemmM, __host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_1x1(
index_t GemmABlockTransferSrcScalarPerVector_GemmK, const DynamicTensorDescriptor<Wei...>& wei_k_c_y_x_global_desc,
index_t GemmABlockTransferDstScalarPerVector_GemmM, const DynamicTensorDescriptor<In...>& in_n_c_hi_wi_global_desc,
typename GemmBBlockTransferThreadSliceLengths_GemmK_GemmN, const DynamicTensorDescriptor<Out...>& out_n_k_ho_wo_global_desc,
typename GemmBBlockTransferThreadClusterLengths_GemmK_GemmN, const ConvStrides& conv_strides,
index_t GemmBBlockTransferSrcScalarPerVector_GemmN, const ConvDilations& conv_dilations,
index_t GemmBBlockTransferDstScalarPerVector_GemmN, const InLeftPads& in_left_pads,
index_t GemmCThreadTransferDstScalarPerVector_GemmN1> const InRightPads& in_right_pads)
struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_1x1
{ {
template <typename... Wei, constexpr auto I0 = Number<0>{};
typename... In, constexpr auto I1 = Number<1>{};
typename... Out, constexpr auto I2 = Number<2>{};
typename ConvStrides, constexpr auto I3 = Number<3>{};
typename ConvDilations,
typename InLeftPads, const auto N = in_n_c_hi_wi_global_desc.GetLength(I0);
typename InRightPads> const auto C = in_n_c_hi_wi_global_desc.GetLength(I1);
__host__ void Run(const DynamicTensorDescriptor<Wei...>& wei_k_c_y_x_global_desc, const auto K = out_n_k_ho_wo_global_desc.GetLength(I1);
const DynamicTensorDescriptor<In...>& in_n_c_hi_wi_global_desc,
const DynamicTensorDescriptor<Out...>& out_n_k_ho_wo_global_desc, const auto Hi = in_n_c_hi_wi_global_desc.GetLength(I2);
const ConvStrides& conv_strides, const auto Wi = in_n_c_hi_wi_global_desc.GetLength(I3);
const ConvDilations& conv_dilations,
const InLeftPads& in_left_pads, const auto Ho = out_n_k_ho_wo_global_desc.GetLength(I2);
const InRightPads& in_right_pads, const auto Wo = out_n_k_ho_wo_global_desc.GetLength(I3);
const FloatAB* __restrict__ p_wei_global,
const FloatAB* __restrict__ p_in_global, const auto Y = wei_k_c_y_x_global_desc.GetLength(I2);
FloatC* __restrict__ p_out_global) const const auto X = wei_k_c_y_x_global_desc.GetLength(I3);
{
constexpr auto I0 = Number<0>{}; const auto ConvStrideH = conv_strides[I0];
constexpr auto I1 = Number<1>{}; const auto ConvStrideW = conv_strides[I1];
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{}; const auto ConvDilationH = conv_dilations[I0];
const auto ConvDilationW = conv_dilations[I1];
const auto N = in_n_c_hi_wi_global_desc.GetLength(I0);
const auto C = in_n_c_hi_wi_global_desc.GetLength(I1); const auto InLeftPadH = in_left_pads[I0];
const auto K = out_n_k_ho_wo_global_desc.GetLength(I1); const auto InLeftPadW = in_left_pads[I1];
const auto Hi = in_n_c_hi_wi_global_desc.GetLength(I2); const auto InRightPadH = in_right_pads[I0];
const auto Wi = in_n_c_hi_wi_global_desc.GetLength(I3); const auto InRightPadW = in_right_pads[I1];
const auto Ho = out_n_k_ho_wo_global_desc.GetLength(I2); assert(Y == 1 && X == 1 && ConvStrideH == 1 && ConvStrideW == 1 && ConvDilationH == 1 &&
const auto Wo = out_n_k_ho_wo_global_desc.GetLength(I3); ConvDilationW == 1 && InLeftPadH == 0 && InLeftPadW == 0 && InRightPadH == 0 &&
InRightPadW == 0);
const auto Y = wei_k_c_y_x_global_desc.GetLength(I2);
const auto X = wei_k_c_y_x_global_desc.GetLength(I3); // weight tensor
const auto wei_gemmk_gemmm_global_desc = transform_dynamic_tensor_descriptor(
const auto ConvStrideH = conv_strides[I0]; make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, C)),
const auto ConvStrideW = conv_strides[I1]; make_tuple(make_pass_through_transform(K), make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
const auto ConvDilationH = conv_dilations[I0]; make_tuple(Sequence<1>{}, Sequence<0>{}));
const auto ConvDilationW = conv_dilations[I1];
// input tensor
const auto InLeftPadH = in_left_pads[I0]; const auto in_gemmk_gemmn_global_desc = transform_dynamic_tensor_descriptor(
const auto InLeftPadW = in_left_pads[I1]; in_n_c_hi_wi_global_desc,
make_tuple(make_pass_through_transform(C), make_merge_transform(make_tuple(N, Ho, Wo))),
const auto InRightPadH = in_right_pads[I0]; make_tuple(Sequence<1>{}, Sequence<0, 2, 3>{}),
const auto InRightPadW = in_right_pads[I1]; make_tuple(Sequence<0>{}, Sequence<1>{}));
if(!(Y == 1 && X == 1 && ConvStrideH == 1 && ConvStrideW == 1 && ConvDilationH == 1 && // output tensor
ConvDilationW == 1 && InLeftPadH == 0 && InLeftPadW == 0 && InRightPadH == 0 && const auto out_gemmm_gemmn_global_desc = transform_dynamic_tensor_descriptor(
InRightPadW == 0)) make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K, Ho * Wo)),
{ make_tuple(make_pass_through_transform(K), make_merge_transform(make_tuple(N, Ho * Wo))),
throw std::runtime_error("wrong! 1x1, stride 1, no padding"); make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
} make_tuple(Sequence<0>{}, Sequence<1>{}));
// weight tensor const auto GemmM = out_gemmm_gemmn_global_desc.GetLength(I0);
const auto wei_gemmk_gemmm_global_desc = transform_dynamic_tensor_descriptor( const auto GemmN = out_gemmm_gemmn_global_desc.GetLength(I1);
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, C)), const auto GemmK = wei_gemmk_gemmm_global_desc.GetLength(I0);
make_tuple(make_pass_through_transform(K), make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}), assert(GemmM % GemmMPerBlock == 0 && GemmN % GemmNPerBlock == 0 && GemmK % GemmKPerBlock == 0);
make_tuple(Sequence<1>{}, Sequence<0>{}));
const auto GemmM0 = GemmM / Number<GemmM1>{};
// input tensor const auto GemmN0 = GemmN / Number<GemmN1>{};
const auto in_gemmk_gemmn_global_desc = transform_dynamic_tensor_descriptor(
in_n_c_hi_wi_global_desc, const auto out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc = transform_dynamic_tensor_descriptor(
make_tuple(make_pass_through_transform(C), make_merge_transform(make_tuple(N, Ho, Wo))), out_gemmm_gemmn_global_desc,
make_tuple(Sequence<1>{}, Sequence<0, 2, 3>{}), make_tuple(make_unmerge_transform(make_tuple(GemmM0, GemmM1)),
make_tuple(Sequence<0>{}, Sequence<1>{})); make_unmerge_transform(make_tuple(GemmN0, GemmN1))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
// output tensor make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
const auto out_gemmm_gemmn_global_desc = transform_dynamic_tensor_descriptor(
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K, Ho * Wo)), // out_gemm_block_cluster_desc
make_tuple(make_pass_through_transform(K), const auto out_gemm_block_cluster_desc = make_cluster_descriptor_v2(
make_merge_transform(make_tuple(N, Ho * Wo))), make_tuple(GemmM / Number<GemmMPerBlock>{}, GemmN / Number<GemmNPerBlock>{}));
make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{})); // hack to control index calculation when iterating over a_k_m_global tensor
constexpr auto wei_gemmk_gemmm_global_iterator_hacks =
const auto GemmM = out_gemmm_gemmn_global_desc.GetLength(I0); make_tuple(make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}),
const auto GemmN = out_gemmm_gemmn_global_desc.GetLength(I1); make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}));
const auto GemmK = wei_gemmk_gemmm_global_desc.GetLength(I0);
constexpr auto wei_gemmk_gemmm_global_move_slice_window_iterator_hacks = Sequence<0, 0, 0>{};
if(!(GemmM % GemmMPerBlock == 0 && GemmN % GemmNPerBlock == 0 &&
GemmK % GemmKPerBlock == 0)) // hack to control index calculation when iterating over b_k_n_global tensor
{ constexpr auto in_gemmk_gemmn_global_iterator_hacks =
throw std::runtime_error("wrong! GEMM size no divisible"); make_tuple(make_tuple(Sequence<0, 1, 0>{}, Sequence<0, 0, 1>{}),
} make_tuple(Sequence<0, 2, 0>{}, Sequence<0, 0, 2>{}));
constexpr auto GemmM1 = Number<GemmMPerThread * GemmMLevel0Cluster * GemmMLevel1Cluster>{}; constexpr auto in_gemmk_gemmn_global_move_slice_window_iterator_hacks = Sequence<0, 1, 2>{};
constexpr auto GemmN1 = Number<GemmNPerThread * GemmNLevel0Cluster * GemmNLevel1Cluster>{};
// hack to control index calculation when iterating over c_m0_m1_n0_n1_global tensor
const auto GemmM0 = GemmM / GemmM1; constexpr auto out_gemmm0_gemmm1_gemmn0_gemmn1_global_iterator_hacks =
const auto GemmN0 = GemmN / GemmN1; make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
const auto out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc = Sequence<0, 0, 1, 0, 0>{},
transform_dynamic_tensor_descriptor( Sequence<0, 0, 1, 0, 0>{}),
out_gemmm_gemmn_global_desc, make_tuple(Sequence<0, 0, 0, 0, 0>{},
make_tuple(make_unmerge_transform(make_tuple(GemmM0, GemmM1)), Sequence<0, 0, 0, 0, 0>{},
make_unmerge_transform(make_tuple(GemmN0, GemmN1))), Sequence<0, 0, 2, 0, 0>{},
make_tuple(Sequence<0>{}, Sequence<1>{}), Sequence<0, 0, 2, 0, 0>{}));
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
return make_tuple(wei_gemmk_gemmm_global_desc,
// hack to control index calculation when iterating over a_k_m_global tensor in_gemmk_gemmn_global_desc,
constexpr auto a_k_m_global_iterator_hacks = out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc,
make_tuple(make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}), out_gemm_block_cluster_desc,
make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{})); wei_gemmk_gemmm_global_iterator_hacks,
in_gemmk_gemmn_global_iterator_hacks,
constexpr auto a_k_m_global_move_slice_window_iterator_hack = Sequence<0, 0, 0>{}; out_gemmm0_gemmm1_gemmn0_gemmn1_global_iterator_hacks,
wei_gemmk_gemmm_global_move_slice_window_iterator_hacks,
// hack to control index calculation when iterating over b_k_n_global tensor in_gemmk_gemmn_global_move_slice_window_iterator_hacks);
constexpr auto b_k_n_global_iterator_hacks = }
make_tuple(make_tuple(Sequence<0, 1, 0>{}, Sequence<0, 0, 1>{}),
make_tuple(Sequence<0, 2, 0>{}, Sequence<0, 0, 2>{}));
constexpr auto b_k_n_global_move_slice_window_iterator_hack = Sequence<0, 1, 2>{};
// hack to control index calculation when iterating over c_m0_m1_n0_n1_global tensor
constexpr auto c_m0_m1_n0_n1_global_tensor_iterator_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 1, 0, 0>{},
Sequence<0, 0, 1, 0, 0>{}),
make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 2, 0, 0>{},
Sequence<0, 0, 2, 0, 0>{}));
// GEMM
using gridwise_gemm = GridwiseDynamicGemm_km_kn_m0m1n0n1_v1<
BlockSize,
FloatAB,
FloatAcc,
FloatC,
InMemoryDataOperation::Set,
decltype(wei_gemmk_gemmm_global_desc),
decltype(in_gemmk_gemmn_global_desc),
decltype(out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc),
GemmMPerBlock,
GemmNPerBlock,
GemmKPerBlock,
GemmMPerThread,
GemmNPerThread,
GemmKPerThread,
GemmMLevel0Cluster,
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmABlockTransferThreadSliceLengths_GemmK_GemmM,
GemmABlockTransferThreadClusterLengths_GemmK_GemmM,
Sequence<1, 0>,
Sequence<1, 0>,
0,
GemmABlockTransferSrcScalarPerVector_GemmK,
GemmABlockTransferDstScalarPerVector_GemmM,
false, // don't move back src coordinate after threadwise copy
GemmBBlockTransferThreadSliceLengths_GemmK_GemmN,
GemmBBlockTransferThreadClusterLengths_GemmK_GemmN,
Sequence<0, 1>,
Sequence<0, 1>,
1,
GemmBBlockTransferSrcScalarPerVector_GemmN,
GemmBBlockTransferDstScalarPerVector_GemmN,
false, // don't move back src coordinate after threadwise copy, which will be fused with
// MoveSrcSliceWindow() to save addr computation
Sequence<2, 3, 0, 1>,
3,
GemmCThreadTransferDstScalarPerVector_GemmN1,
decltype(a_k_m_global_iterator_hacks),
decltype(b_k_n_global_iterator_hacks),
decltype(c_m0_m1_n0_n1_global_tensor_iterator_hacks),
decltype(a_k_m_global_move_slice_window_iterator_hack),
decltype(b_k_n_global_move_slice_window_iterator_hack)>;
const auto GridSize = (GemmM / GemmMPerBlock) * (GemmN / GemmNPerBlock);
const bool has_main_k_block_loop = (GemmK + GemmKPerBlock) / (2 * GemmKPerBlock) > 1;
const bool has_double_tail_k_block_loop = (GemmK / GemmKPerBlock) % 2 == 0;
#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
index_t nrepeat = 100;
for(index_t i = 0; i < 5; ++i)
{
std::cout << "Start running " << nrepeat << " times..." << std::endl;
KernelTimer timer;
timer.Start();
for(index_t j = 0; j < nrepeat; ++j)
{
if(has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel =
run_gridwise_operation<gridwise_gemm,
decltype(wei_gemmk_gemmm_global_desc),
const FloatAB*,
decltype(in_gemmk_gemmn_global_desc),
const FloatAB*,
decltype(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc),
FloatC*,
integral_constant<bool, true>,
integral_constant<bool, true>>;
launch_kernel(kernel,
dim3(GridSize),
dim3(BlockSize),
0,
0,
wei_gemmk_gemmm_global_desc,
p_wei_global,
in_gemmk_gemmn_global_desc,
p_in_global,
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc,
p_out_global,
integral_constant<bool, true>{},
integral_constant<bool, true>{});
}
else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
{
const auto kernel =
run_gridwise_operation<gridwise_gemm,
decltype(wei_gemmk_gemmm_global_desc),
const FloatAB*,
decltype(in_gemmk_gemmn_global_desc),
const FloatAB*,
decltype(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc),
FloatC*,
integral_constant<bool, true>,
integral_constant<bool, false>>;
launch_kernel(kernel,
dim3(GridSize),
dim3(BlockSize),
0,
0,
wei_gemmk_gemmm_global_desc,
p_wei_global,
in_gemmk_gemmn_global_desc,
p_in_global,
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc,
p_out_global,
integral_constant<bool, true>{},
integral_constant<bool, false>{});
}
else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel =
run_gridwise_operation<gridwise_gemm,
decltype(wei_gemmk_gemmm_global_desc),
const FloatAB*,
decltype(in_gemmk_gemmn_global_desc),
const FloatAB*,
decltype(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc),
FloatC*,
integral_constant<bool, false>,
integral_constant<bool, true>>;
launch_kernel(kernel,
dim3(GridSize),
dim3(BlockSize),
0,
0,
wei_gemmk_gemmm_global_desc,
p_wei_global,
in_gemmk_gemmn_global_desc,
p_in_global,
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc,
p_out_global,
integral_constant<bool, false>{},
integral_constant<bool, true>{});
}
else
{
const auto kernel =
run_gridwise_operation<gridwise_gemm,
decltype(wei_gemmk_gemmm_global_desc),
const FloatAB*,
decltype(in_gemmk_gemmn_global_desc),
const FloatAB*,
decltype(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc),
FloatC*,
integral_constant<bool, false>,
integral_constant<bool, false>>;
launch_kernel(kernel,
dim3(GridSize),
dim3(BlockSize),
0,
0,
wei_gemmk_gemmm_global_desc,
p_wei_global,
in_gemmk_gemmn_global_desc,
p_in_global,
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc,
p_out_global,
integral_constant<bool, false>{},
integral_constant<bool, false>{});
}
}
timer.End();
float ave_time = timer.GetElapsedTime() / nrepeat;
float perf = (float)calculate_convolution_flops(in_n_c_hi_wi_global_desc,
wei_k_c_y_x_global_desc,
out_n_k_ho_wo_global_desc) /
(std::size_t(1000) * 1000 * 1000) / ave_time;
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
<< std::endl;
}
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
using ADesc = decltype(wei_gemmk_gemmm_global_desc);
using BDesc = decltype(in_gemmk_gemmn_global_desc);
using CDesc = decltype(out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc);
DeviceMem wei_gemmk_gemmm_global_desc_device_buf(sizeof(ADesc));
DeviceMem in_gemmk_gemmn_global_desc_device_buf(sizeof(BDesc));
DeviceMem out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf(sizeof(CDesc));
wei_gemmk_gemmm_global_desc_device_buf.ToDevice(&wei_gemmk_gemmm_global_desc);
in_gemmk_gemmn_global_desc_device_buf.ToDevice(&in_gemmk_gemmn_global_desc);
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf.ToDevice(
&out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc);
index_t nrepeat = 100;
for(index_t i = 0; i < 5; ++i)
{
std::cout << "Start running " << nrepeat << " times..." << std::endl;
KernelTimer timer;
timer.Start();
for(index_t j = 0; j < nrepeat; ++j)
{
if(has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel = run_gridwise_dynamic_gemm_v1<gridwise_gemm,
ADesc,
FloatAB,
BDesc,
FloatAB,
CDesc,
FloatC,
true,
true>;
launch_kernel(
kernel,
dim3(GridSize),
dim3(BlockSize),
0,
0,
(void __CONSTANT__*)
wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer(),
p_wei_global,
(void __CONSTANT__*)in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer(),
p_in_global,
(void __CONSTANT__*)
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
.GetDeviceBuffer(),
p_out_global);
}
else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
{
const auto kernel = run_gridwise_dynamic_gemm_v1<gridwise_gemm,
ADesc,
FloatAB,
BDesc,
FloatAB,
CDesc,
FloatC,
true,
false>;
launch_kernel(
kernel,
dim3(GridSize),
dim3(BlockSize),
0,
0,
(void __CONSTANT__*)
wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer(),
p_wei_global,
(void __CONSTANT__*)in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer(),
p_in_global,
(void __CONSTANT__*)
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
.GetDeviceBuffer(),
p_out_global);
}
else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel = run_gridwise_dynamic_gemm_v1<gridwise_gemm,
ADesc,
FloatAB,
BDesc,
FloatAB,
CDesc,
FloatC,
false,
true>;
launch_kernel(
kernel,
dim3(GridSize),
dim3(BlockSize),
0,
0,
(void __CONSTANT__*)
wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer(),
p_wei_global,
(void __CONSTANT__*)in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer(),
p_in_global,
(void __CONSTANT__*)
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
.GetDeviceBuffer(),
p_out_global);
}
else
{
const auto kernel = run_gridwise_dynamic_gemm_v1<gridwise_gemm,
ADesc,
FloatAB,
BDesc,
FloatAB,
CDesc,
FloatC,
false,
false>;
launch_kernel(
kernel,
dim3(GridSize),
dim3(BlockSize),
0,
0,
(void __CONSTANT__*)
wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer(),
p_wei_global,
(void __CONSTANT__*)in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer(),
p_in_global,
(void __CONSTANT__*)
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
.GetDeviceBuffer(),
p_out_global);
}
}
timer.End();
float ave_time = timer.GetElapsedTime() / nrepeat;
float perf = (float)calculate_convolution_flops(in_n_c_hi_wi_global_desc,
wei_k_c_y_x_global_desc,
out_n_k_ho_wo_global_desc) /
(std::size_t(1000) * 1000 * 1000) / ave_time;
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
<< std::endl;
}
#endif
}
};
#endif
} // namespace ck } // namespace ck
#endif #endif
...@@ -4,1030 +4,297 @@ ...@@ -4,1030 +4,297 @@
#include "common_header.hpp" #include "common_header.hpp"
#include "dynamic_tensor_descriptor.hpp" #include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp" #include "dynamic_tensor_descriptor_helper.hpp"
#include "gridwise_dynamic_gemm.hpp" #include "driver_dynamic_gemm_v1.hpp"
#include "gridwise_operation_wrapper.hpp"
namespace ck { namespace ck {
// GemmM = K // GemmM = K
// GemmN = N * Ho * Wo // GemmN = N * Ho * Wo
// GemmK = Y * X * C // GemmK = C * Y * X
template <index_t BlockSize, template <index_t GemmMPerBlock,
typename FloatAB,
typename FloatAcc,
typename FloatC,
index_t GemmMPerBlock,
index_t GemmNPerBlock, index_t GemmNPerBlock,
index_t GemmKPerBlock, index_t GemmM1,
index_t GemmMPerThread, index_t GemmN1,
index_t GemmNPerThread, typename... Wei,
index_t GemmKPerThread, typename... In,
index_t GemmMLevel0Cluster, typename... Out,
index_t GemmNLevel0Cluster, typename ConvStrides,
index_t GemmMLevel1Cluster, typename ConvDilations,
index_t GemmNLevel1Cluster, typename InLeftPads,
typename GemmABlockTransferThreadSliceLengths_GemmK_GemmM, typename InRightPads>
typename GemmABlockTransferThreadClusterLengths_GemmK_GemmM, __host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_nhwc_kyxc_nhwk_pad(
index_t GemmABlockTransferSrcScalarPerVector_GemmK, const DynamicTensorDescriptor<Wei...>& wei_k_y_x_c_global_desc,
index_t GemmABlockTransferDstScalarPerVector_GemmM, const DynamicTensorDescriptor<In...>& in_n_hi_wi_c_global_desc,
typename GemmBBlockTransferThreadSliceLengths_GemmK_GemmN, const DynamicTensorDescriptor<Out...>& out_n_ho_wo_k_global_desc,
typename GemmBBlockTransferThreadClusterLengths_GemmK_GemmN, const ConvStrides& conv_strides,
index_t GemmBBlockTransferSrcScalarPerVector_GemmK, const ConvDilations& conv_dilations,
index_t GemmBBlockTransferDstScalarPerVector_GemmN, const InLeftPads& in_left_pads,
index_t GemmCThreadTransferDstScalarPerVector_GemmM1> const InRightPads& in_right_pads)
struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_pad
{ {
template <typename... Wei, constexpr auto I0 = Number<0>{};
typename... In, constexpr auto I1 = Number<1>{};
typename... Out, constexpr auto I2 = Number<2>{};
typename ConvStrides, constexpr auto I3 = Number<3>{};
typename ConvDilations,
typename InLeftPads, const auto N = in_n_hi_wi_c_global_desc.GetLength(I0);
typename InRightPads> const auto C = in_n_hi_wi_c_global_desc.GetLength(I3);
__host__ void Run(const DynamicTensorDescriptor<Wei...>& wei_k_y_x_c_global_desc, const auto K = out_n_ho_wo_k_global_desc.GetLength(I3);
const DynamicTensorDescriptor<In...>& in_n_hi_wi_c_global_desc,
const DynamicTensorDescriptor<Out...>& out_n_ho_wo_k_global_desc, const auto Hi = in_n_hi_wi_c_global_desc.GetLength(I1);
const ConvStrides& conv_strides, const auto Wi = in_n_hi_wi_c_global_desc.GetLength(I2);
const ConvDilations& conv_dilations,
const InLeftPads& in_left_pads, const auto Ho = out_n_ho_wo_k_global_desc.GetLength(I1);
const InRightPads& in_right_pads, const auto Wo = out_n_ho_wo_k_global_desc.GetLength(I2);
const FloatAB* __restrict__ p_wei_global,
const FloatAB* __restrict__ p_in_global, const auto Y = wei_k_y_x_c_global_desc.GetLength(I1);
FloatC* __restrict__ p_out_global) const const auto X = wei_k_y_x_c_global_desc.GetLength(I2);
{
constexpr auto I0 = Number<0>{}; const auto ConvStrideH = conv_strides[I0];
constexpr auto I1 = Number<1>{}; const auto ConvStrideW = conv_strides[I1];
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{}; const auto ConvDilationH = conv_dilations[I0];
const auto ConvDilationW = conv_dilations[I1];
const auto InLeftPadH = in_left_pads[I0];
const auto InLeftPadW = in_left_pads[I1];
const auto InRightPadH = in_right_pads[I0];
const auto InRightPadW = in_right_pads[I1];
// weight tensor
const auto wei_gemmk_gemmm_global_desc = transform_dynamic_tensor_descriptor(
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, Y * X * C)),
make_tuple(make_pass_through_transform(K), make_pass_through_transform(Y * X * C)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<1>{}, Sequence<0>{}));
// input tensor
const auto in_n_hip_wip_c_global_desc = transform_dynamic_tensor_descriptor(
in_n_hi_wi_c_global_desc,
make_tuple(make_pass_through_transform(N),
make_pad_transform(Hi, InLeftPadH, InRightPadH),
make_pad_transform(Wi, InLeftPadW, InRightPadW),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto in_n_y_ho_x_wo_c_global_desc = transform_dynamic_tensor_descriptor(
in_n_hip_wip_c_global_desc,
make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
const auto in_gemmk_gemmn_global_desc =
transform_dynamic_tensor_descriptor(in_n_y_ho_x_wo_c_global_desc,
make_tuple(make_merge_transform(make_tuple(Y, X, C)),
make_merge_transform(make_tuple(N, Ho, Wo))),
make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
// output tensor
const auto out_gemmm_gemmn_global_desc = transform_dynamic_tensor_descriptor(
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N * Ho * Wo, K)),
make_tuple(make_pass_through_transform(N * Ho * Wo), make_pass_through_transform(K)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<1>{}, Sequence<0>{}));
const auto GemmM = out_gemmm_gemmn_global_desc.GetLength(I0);
const auto GemmN = out_gemmm_gemmn_global_desc.GetLength(I1);
const auto GemmK = wei_gemmk_gemmm_global_desc.GetLength(I0);
assert(GemmM % GemmMPerBlock == 0 && GemmN % GemmNPerBlock == 0 && GemmK % GemmKPerBlock == 0);
const auto GemmM0 = GemmM / Number<GemmM1>{};
const auto GemmN0 = GemmN / Number<GemmN1>{};
const auto out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc = transform_dynamic_tensor_descriptor(
out_gemmm_gemmn_global_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmM0, GemmM1)),
make_unmerge_transform(make_tuple(GemmN0, GemmN1))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
// out_gemm_block_cluster_desc
const auto out_gemm_block_cluster_desc = make_cluster_descriptor_v2(
make_tuple(GemmM / Number<GemmMPerBlock>{}, GemmN / Number<GemmNPerBlock>{}));
// hack to control index calculation when iterating over a_k_m_global tensor
constexpr auto wei_gemmk_gemmm_global_iterator_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}),
make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}));
constexpr auto wei_gemmk_gemmm_global_move_slice_window_iterator_hacks = Sequence<0, 0, 0>{};
// hack to control index calculation when iterating over b_k_n_global tensor
constexpr auto in_gemmk_gemmn_global_iterator_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1>{}),
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2>{}));
constexpr auto in_gemmk_gemmn_global_move_slice_window_iterator_hacks =
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2>{};
// hack to control index calculation when iterating over c_m0_m1_n0_n1_global tensor
// hack for NKHW format
constexpr auto out_gemmm0_gemmm1_gemmn0_gemmn1_global_iterator_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 1, 0, 0>{},
Sequence<0, 0, 1, 0, 0>{}),
make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 2, 0, 0>{},
Sequence<0, 0, 2, 0, 0>{}));
return make_tuple(wei_gemmk_gemmm_global_desc,
in_gemmk_gemmn_global_desc,
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc,
out_gemm_block_cluster_desc,
wei_gemmk_gemmm_global_iterator_hacks,
in_gemmk_gemmn_global_iterator_hacks,
out_gemmm0_gemmm1_gemmn0_gemmn1_global_iterator_hacks,
wei_gemmk_gemmm_global_move_slice_window_iterator_hacks,
in_gemmk_gemmn_global_move_slice_window_iterator_hacks);
}
const auto N = in_n_hi_wi_c_global_desc.GetLength(I0); // GemmM = K
const auto C = in_n_hi_wi_c_global_desc.GetLength(I3); // GemmN = N * Ho * Wo
const auto K = out_n_ho_wo_k_global_desc.GetLength(I3); // GemmK = C * Y * X
template <index_t GemmMPerBlock,
const auto Hi = in_n_hi_wi_c_global_desc.GetLength(I1);
const auto Wi = in_n_hi_wi_c_global_desc.GetLength(I2);
const auto Ho = out_n_ho_wo_k_global_desc.GetLength(I1);
const auto Wo = out_n_ho_wo_k_global_desc.GetLength(I2);
const auto Y = wei_k_y_x_c_global_desc.GetLength(I1);
const auto X = wei_k_y_x_c_global_desc.GetLength(I2);
const auto ConvStrideH = conv_strides[I0];
const auto ConvStrideW = conv_strides[I1];
const auto ConvDilationH = conv_dilations[I0];
const auto ConvDilationW = conv_dilations[I1];
const auto InLeftPadH = in_left_pads[I0];
const auto InLeftPadW = in_left_pads[I1];
const auto InRightPadH = in_right_pads[I0];
const auto InRightPadW = in_right_pads[I1];
// weight tensor
const auto wei_gemmk_gemmm_global_desc = transform_dynamic_tensor_descriptor(
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, Y * X * C)),
make_tuple(make_pass_through_transform(K), make_pass_through_transform(Y * X * C)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<1>{}, Sequence<0>{}));
// input tensor
const auto in_n_hip_wip_c_global_desc = transform_dynamic_tensor_descriptor(
in_n_hi_wi_c_global_desc,
make_tuple(make_pass_through_transform(N),
make_pad_transform(Hi, InLeftPadH, InRightPadH),
make_pad_transform(Wi, InLeftPadW, InRightPadW),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto in_n_y_ho_x_wo_c_global_desc = transform_dynamic_tensor_descriptor(
in_n_hip_wip_c_global_desc,
make_tuple(
make_pass_through_transform(N),
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
const auto in_gemmk_gemmn_global_desc = transform_dynamic_tensor_descriptor(
in_n_y_ho_x_wo_c_global_desc,
make_tuple(make_merge_transform(make_tuple(Y, X, C)),
make_merge_transform(make_tuple(N, Ho, Wo))),
make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
// output tensor
const auto out_gemmm_gemmn_global_desc = transform_dynamic_tensor_descriptor(
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N * Ho * Wo, K)),
make_tuple(make_pass_through_transform(N * Ho * Wo), make_pass_through_transform(K)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<1>{}, Sequence<0>{}));
const auto GemmM = out_gemmm_gemmn_global_desc.GetLength(I0);
const auto GemmN = out_gemmm_gemmn_global_desc.GetLength(I1);
const auto GemmK = wei_gemmk_gemmm_global_desc.GetLength(I0);
if(!(GemmM % GemmMPerBlock == 0 && GemmN % GemmNPerBlock == 0 &&
GemmK % GemmKPerBlock == 0))
{
throw std::runtime_error("wrong! GEMM size no divisible");
}
constexpr auto GemmM1 = Number<GemmMPerThread * GemmMLevel0Cluster * GemmMLevel1Cluster>{};
constexpr auto GemmN1 = Number<GemmNPerThread * GemmNLevel0Cluster * GemmNLevel1Cluster>{};
const auto GemmM0 = GemmM / GemmM1;
const auto GemmN0 = GemmN / GemmN1;
const auto out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc =
transform_dynamic_tensor_descriptor(
out_gemmm_gemmn_global_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmM0, GemmM1)),
make_unmerge_transform(make_tuple(GemmN0, GemmN1))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
// c_block_cluster_desc
const auto gemm_block_cluster_desc = make_cluster_descriptor_v2(
make_tuple(GemmM / Number<GemmMPerBlock>{}, GemmN / Number<GemmNPerBlock>{}));
// hack to control index calculation when iterating over a_k_m_global tensor
constexpr auto a_k_m_global_iterator_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}),
make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}));
constexpr auto a_k_m_global_move_slice_window_iterator_hack = Sequence<0, 0, 0>{};
// hack to control index calculation when iterating over b_k_n_global tensor
constexpr auto b_k_n_global_iterator_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1>{}),
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2>{}));
constexpr auto b_k_n_global_move_slice_window_iterator_hack =
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2>{};
// hack to control index calculation when iterating over c_m0_m1_n0_n1_global tensor
// hack for NKHW format
constexpr auto c_m0_m1_n0_n1_global_tensor_iterator_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 1, 0, 0>{},
Sequence<0, 0, 1, 0, 0>{}),
make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 2, 0, 0>{},
Sequence<0, 0, 2, 0, 0>{}));
// GEMM
using gridwise_gemm = GridwiseDynamicGemm_km_kn_m0m1n0n1_v1<
BlockSize,
FloatAB,
FloatAcc,
FloatC,
InMemoryDataOperation::Set,
decltype(wei_gemmk_gemmm_global_desc),
decltype(in_gemmk_gemmn_global_desc),
decltype(out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc),
decltype(gemm_block_cluster_desc),
GemmMPerBlock,
GemmNPerBlock,
GemmKPerBlock,
GemmMPerThread,
GemmNPerThread,
GemmKPerThread,
GemmMLevel0Cluster,
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmABlockTransferThreadSliceLengths_GemmK_GemmM,
GemmABlockTransferThreadClusterLengths_GemmK_GemmM,
Sequence<1, 0>,
Sequence<1, 0>,
0,
GemmABlockTransferSrcScalarPerVector_GemmK,
GemmABlockTransferDstScalarPerVector_GemmM,
false, // don't move back src coordinate after threadwise copy
GemmBBlockTransferThreadSliceLengths_GemmK_GemmN,
GemmBBlockTransferThreadClusterLengths_GemmK_GemmN,
Sequence<1, 0>,
Sequence<1, 0>,
0,
GemmBBlockTransferSrcScalarPerVector_GemmK,
GemmBBlockTransferDstScalarPerVector_GemmN,
false, // don't move back src coordinate after threadwise copy, which will be fused with
// MoveSrcSliceWindow() to save addr computation
Sequence<2, 3, 0, 1>,
1,
GemmCThreadTransferDstScalarPerVector_GemmM1,
decltype(a_k_m_global_iterator_hacks),
decltype(b_k_n_global_iterator_hacks),
decltype(c_m0_m1_n0_n1_global_tensor_iterator_hacks),
decltype(a_k_m_global_move_slice_window_iterator_hack),
decltype(b_k_n_global_move_slice_window_iterator_hack)>;
const auto GridSize = (GemmM / GemmMPerBlock) * (GemmN / GemmNPerBlock);
const bool has_main_k_block_loop = (GemmK + GemmKPerBlock) / (2 * GemmKPerBlock) > 1;
const bool has_double_tail_k_block_loop = (GemmK / GemmKPerBlock) % 2 == 0;
printf("%s: BlockSize %d, GridSize %d \n", __func__, BlockSize, GridSize);
#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
index_t nrepeat = 100;
for(index_t i = 0; i < 5; ++i)
{
std::cout << "Start running " << nrepeat << " times..." << std::endl;
KernelTimer timer;
timer.Start();
for(index_t j = 0; j < nrepeat; ++j)
{
if(has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel =
run_gridwise_operation<gridwise_gemm,
decltype(wei_gemmk_gemmm_global_desc),
const FloatAB*,
decltype(in_gemmk_gemmn_global_desc),
const FloatAB*,
decltype(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc),
FloatC*,
decltype(gemm_block_cluster_desc),
integral_constant<bool, true>,
integral_constant<bool, true>>;
launch_kernel(kernel,
dim3(GridSize),
dim3(BlockSize),
0,
0,
wei_gemmk_gemmm_global_desc,
p_wei_global,
in_gemmk_gemmn_global_desc,
p_in_global,
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc,
p_out_global,
gemm_block_cluster_desc,
integral_constant<bool, true>{},
integral_constant<bool, true>{});
}
else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
{
const auto kernel =
run_gridwise_operation<gridwise_gemm,
decltype(wei_gemmk_gemmm_global_desc),
const FloatAB*,
decltype(in_gemmk_gemmn_global_desc),
const FloatAB*,
decltype(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc),
FloatC*,
decltype(gemm_block_cluster_desc),
integral_constant<bool, true>,
integral_constant<bool, false>>;
launch_kernel(kernel,
dim3(GridSize),
dim3(BlockSize),
0,
0,
wei_gemmk_gemmm_global_desc,
p_wei_global,
in_gemmk_gemmn_global_desc,
p_in_global,
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc,
p_out_global,
gemm_block_cluster_desc,
integral_constant<bool, true>{},
integral_constant<bool, false>{});
}
else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel =
run_gridwise_operation<gridwise_gemm,
decltype(wei_gemmk_gemmm_global_desc),
const FloatAB*,
decltype(in_gemmk_gemmn_global_desc),
const FloatAB*,
decltype(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc),
FloatC*,
decltype(gemm_block_cluster_desc),
integral_constant<bool, false>,
integral_constant<bool, true>>;
launch_kernel(kernel,
dim3(GridSize),
dim3(BlockSize),
0,
0,
wei_gemmk_gemmm_global_desc,
p_wei_global,
in_gemmk_gemmn_global_desc,
p_in_global,
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc,
p_out_global,
gemm_block_cluster_desc,
integral_constant<bool, false>{},
integral_constant<bool, true>{});
}
else
{
const auto kernel =
run_gridwise_operation<gridwise_gemm,
decltype(wei_gemmk_gemmm_global_desc),
const FloatAB*,
decltype(in_gemmk_gemmn_global_desc),
const FloatAB*,
decltype(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc),
FloatC*,
decltype(gemm_block_cluster_desc),
integral_constant<bool, false>,
integral_constant<bool, false>>;
launch_kernel(kernel,
dim3(GridSize),
dim3(BlockSize),
0,
0,
wei_gemmk_gemmm_global_desc,
p_wei_global,
in_gemmk_gemmn_global_desc,
p_in_global,
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc,
p_out_global,
gemm_block_cluster_desc,
integral_constant<bool, false>{},
integral_constant<bool, false>{});
}
}
timer.End();
float ave_time = timer.GetElapsedTime() / nrepeat;
float perf = (float)(std::size_t(2) * N * K * Ho * Wo * C * Y * X) /
(std::size_t(1000) * 1000 * 1000) / ave_time;
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
<< std::endl;
}
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
using ADesc = decltype(wei_gemmk_gemmm_global_desc);
using BDesc = decltype(in_gemmk_gemmn_global_desc);
using CDesc = decltype(out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc);
DeviceMem wei_gemmk_gemmm_global_desc_device_buf(sizeof(ADesc));
DeviceMem in_gemmk_gemmn_global_desc_device_buf(sizeof(BDesc));
DeviceMem out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf(sizeof(CDesc));
wei_gemmk_gemmm_global_desc_device_buf.ToDevice(&wei_gemmk_gemmm_global_desc);
in_gemmk_gemmn_global_desc_device_buf.ToDevice(&in_gemmk_gemmn_global_desc);
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf.ToDevice(
&out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc);
index_t nrepeat = 100;
for(index_t i = 0; i < 5; ++i)
{
std::cout << "Start running " << nrepeat << " times..." << std::endl;
KernelTimer timer;
timer.Start();
for(index_t j = 0; j < nrepeat; ++j)
{
if(has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel = run_gridwise_dynamic_gemm_v1<gridwise_gemm,
ADesc,
FloatAB,
BDesc,
FloatAB,
CDesc,
FloatC,
true,
true>;
launch_kernel(
kernel,
dim3(GridSize),
dim3(BlockSize),
0,
0,
(void __CONSTANT__*)
wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer(),
p_wei_global,
(void __CONSTANT__*)in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer(),
p_in_global,
(void __CONSTANT__*)
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
.GetDeviceBuffer(),
p_out_global);
}
else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
{
const auto kernel = run_gridwise_dynamic_gemm_v1<gridwise_gemm,
ADesc,
FloatAB,
BDesc,
FloatAB,
CDesc,
FloatC,
true,
false>;
launch_kernel(
kernel,
dim3(GridSize),
dim3(BlockSize),
0,
0,
(void __CONSTANT__*)
wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer(),
p_wei_global,
(void __CONSTANT__*)in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer(),
p_in_global,
(void __CONSTANT__*)
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
.GetDeviceBuffer(),
p_out_global);
}
else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel = run_gridwise_dynamic_gemm_v1<gridwise_gemm,
ADesc,
FloatAB,
BDesc,
FloatAB,
CDesc,
FloatC,
false,
true>;
launch_kernel(
kernel,
dim3(GridSize),
dim3(BlockSize),
0,
0,
(void __CONSTANT__*)
wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer(),
p_wei_global,
(void __CONSTANT__*)in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer(),
p_in_global,
(void __CONSTANT__*)
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
.GetDeviceBuffer(),
p_out_global);
}
else
{
const auto kernel = run_gridwise_dynamic_gemm_v1<gridwise_gemm,
ADesc,
FloatAB,
BDesc,
FloatAB,
CDesc,
FloatC,
false,
false>;
launch_kernel(
kernel,
dim3(GridSize),
dim3(BlockSize),
0,
0,
(void __CONSTANT__*)
wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer(),
p_wei_global,
(void __CONSTANT__*)in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer(),
p_in_global,
(void __CONSTANT__*)
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
.GetDeviceBuffer(),
p_out_global);
}
}
timer.End();
float ave_time = timer.GetElapsedTime() / nrepeat;
float perf = (float)(std::size_t(2) * N * K * Ho * Wo * C * Y * X) /
(std::size_t(1000) * 1000 * 1000) / ave_time;
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
<< std::endl;
}
#endif
}
};
#if 0
template <index_t BlockSize,
typename FloatAB,
typename FloatAcc,
typename FloatC,
index_t GemmMPerBlock,
index_t GemmNPerBlock, index_t GemmNPerBlock,
index_t GemmKPerBlock, index_t GemmM1,
index_t GemmMPerThread, index_t GemmN1,
index_t GemmNPerThread, typename... Wei,
index_t GemmKPerThread, typename... In,
index_t GemmMLevel0Cluster, typename... Out,
index_t GemmNLevel0Cluster, typename ConvStrides,
index_t GemmMLevel1Cluster, typename ConvDilations,
index_t GemmNLevel1Cluster, typename InLeftPads,
typename GemmABlockTransferThreadSliceLengths_GemmK_GemmM, typename InRightPads>
typename GemmABlockTransferThreadClusterLengths_GemmK_GemmM, __host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_nhwc_kyxc_nhwk_1x1(
index_t GemmABlockTransferSrcScalarPerVector_GemmK, const DynamicTensorDescriptor<Wei...>& wei_k_y_x_c_global_desc,
index_t GemmABlockTransferDstScalarPerVector_GemmM, const DynamicTensorDescriptor<In...>& in_n_hi_wi_c_global_desc,
typename GemmBBlockTransferThreadSliceLengths_GemmK_GemmN, const DynamicTensorDescriptor<Out...>& out_n_ho_wo_k_global_desc,
typename GemmBBlockTransferThreadClusterLengths_GemmK_GemmN, const ConvStrides& conv_strides,
index_t GemmBBlockTransferSrcScalarPerVector_GemmK, const ConvDilations& conv_dilations,
index_t GemmBBlockTransferDstScalarPerVector_GemmN, const InLeftPads& in_left_pads,
index_t GemmCThreadTransferDstScalarPerVector_GemmM1> const InRightPads& in_right_pads)
struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_1x1
{ {
template <typename... Wei, constexpr auto I0 = Number<0>{};
typename... In, constexpr auto I1 = Number<1>{};
typename... Out, constexpr auto I2 = Number<2>{};
typename ConvStrides, constexpr auto I3 = Number<3>{};
typename ConvDilations,
typename InLeftPads, const auto N = in_n_hi_wi_c_global_desc.GetLength(I0);
typename InRightPads> const auto C = in_n_hi_wi_c_global_desc.GetLength(I3);
__host__ void Run(const DynamicTensorDescriptor<Wei...>& wei_k_y_x_c_global_desc, const auto K = out_n_ho_wo_k_global_desc.GetLength(I3);
const DynamicTensorDescriptor<In...>& in_n_hi_wi_c_global_desc,
const DynamicTensorDescriptor<Out...>& out_n_ho_wo_k_global_desc, const auto Hi = in_n_hi_wi_c_global_desc.GetLength(I1);
const ConvStrides& conv_strides, const auto Wi = in_n_hi_wi_c_global_desc.GetLength(I2);
const ConvDilations& conv_dilations,
const InLeftPads& in_left_pads, const auto Ho = out_n_ho_wo_k_global_desc.GetLength(I1);
const InRightPads& in_right_pads, const auto Wo = out_n_ho_wo_k_global_desc.GetLength(I2);
const FloatAB* __restrict__ p_wei_global,
const FloatAB* __restrict__ p_in_global, const auto Y = wei_k_y_x_c_global_desc.GetLength(I1);
FloatC* __restrict__ p_out_global) const const auto X = wei_k_y_x_c_global_desc.GetLength(I2);
{
constexpr auto I0 = Number<0>{}; const auto ConvStrideH = conv_strides[I0];
constexpr auto I1 = Number<1>{}; const auto ConvStrideW = conv_strides[I1];
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{}; const auto ConvDilationH = conv_dilations[I0];
const auto ConvDilationW = conv_dilations[I1];
const auto N = in_n_hi_wi_c_global_desc.GetLength(I0);
const auto C = in_n_hi_wi_c_global_desc.GetLength(I3); const auto InLeftPadH = in_left_pads[I0];
const auto K = out_n_ho_wo_k_global_desc.GetLength(I3); const auto InLeftPadW = in_left_pads[I1];
const auto Hi = in_n_hi_wi_c_global_desc.GetLength(I1); const auto InRightPadH = in_right_pads[I0];
const auto Wi = in_n_hi_wi_c_global_desc.GetLength(I2); const auto InRightPadW = in_right_pads[I1];
const auto Ho = out_n_ho_wo_k_global_desc.GetLength(I1); assert(Y == 1 && X == 1 && ConvStrideH == 1 && ConvStrideW == 1 && ConvDilationH == 1 &&
const auto Wo = out_n_ho_wo_k_global_desc.GetLength(I2); ConvDilationW == 1 && InLeftPadH == 0 && InLeftPadW == 0 && InRightPadH == 0 &&
InRightPadW == 0);
const auto Y = wei_k_y_x_c_global_desc.GetLength(I1);
const auto X = wei_k_y_x_c_global_desc.GetLength(I2); // weight tensor
const auto wei_gemmk_gemmm_global_desc = transform_dynamic_tensor_descriptor(
const auto ConvStrideH = conv_strides[I0]; make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, C)),
const auto ConvStrideW = conv_strides[I1]; make_tuple(make_pass_through_transform(K), make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
const auto ConvDilationH = conv_dilations[I0]; make_tuple(Sequence<1>{}, Sequence<0>{}));
const auto ConvDilationW = conv_dilations[I1];
// input tensor
const auto InLeftPadH = in_left_pads[I0]; const auto in_gemmk_gemmn_global_desc = transform_dynamic_tensor_descriptor(
const auto InLeftPadW = in_left_pads[I1]; make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N * Ho * Wo, C)),
make_tuple(make_pass_through_transform(N * Ho * Wo), make_pass_through_transform(C)),
const auto InRightPadH = in_right_pads[I0]; make_tuple(Sequence<0>{}, Sequence<1>{}),
const auto InRightPadW = in_right_pads[I1]; make_tuple(Sequence<1>{}, Sequence<0>{}));
if(!(Y == 1 && X == 1 && ConvStrideH == 1 && ConvStrideW == 1 && ConvDilationH == 1 && // output tensor
ConvDilationW == 1 && InLeftPadH == 0 && InLeftPadW == 0 && InRightPadH == 0 && const auto out_gemmm_gemmn_global_desc = transform_dynamic_tensor_descriptor(
InRightPadW == 0)) make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N * Ho * Wo, K)),
{ make_tuple(make_pass_through_transform(N * Ho * Wo), make_pass_through_transform(K)),
throw std::runtime_error("wrong! 1x1, stride 1, no padding"); make_tuple(Sequence<0>{}, Sequence<1>{}),
} make_tuple(Sequence<1>{}, Sequence<0>{}));
// weight tensor const auto GemmM = out_gemmm_gemmn_global_desc.GetLength(I0);
const auto wei_gemmk_gemmm_global_desc = transform_dynamic_tensor_descriptor( const auto GemmN = out_gemmm_gemmn_global_desc.GetLength(I1);
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, C)), const auto GemmK = wei_gemmk_gemmm_global_desc.GetLength(I0);
make_tuple(make_pass_through_transform(K), make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}), assert(GemmM % GemmMPerBlock == 0 && GemmN % GemmNPerBlock == 0 && GemmK % GemmKPerBlock == 0);
make_tuple(Sequence<1>{}, Sequence<0>{}));
const auto GemmM0 = GemmM / Number<GemmM1>{};
// input tensor const auto GemmN0 = GemmN / Number<GemmN1>{};
const auto in_gemmk_gemmn_global_desc = transform_dynamic_tensor_descriptor(
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N * Ho * Wo, C)), const auto out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc = transform_dynamic_tensor_descriptor(
make_tuple(make_pass_through_transform(N * Ho * Wo), make_pass_through_transform(C)), out_gemmm_gemmn_global_desc,
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(make_unmerge_transform(make_tuple(GemmM0, GemmM1)),
make_tuple(Sequence<1>{}, Sequence<0>{})); make_unmerge_transform(make_tuple(GemmN0, GemmN1))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
// output tensor make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
const auto out_gemmm_gemmn_global_desc = transform_dynamic_tensor_descriptor(
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N * Ho * Wo, K)), // out_gemm_block_cluster_desc
make_tuple(make_pass_through_transform(N * Ho * Wo), make_pass_through_transform(K)), const auto out_gemm_block_cluster_desc = make_cluster_descriptor_v2(
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(GemmM / Number<GemmMPerBlock>{}, GemmN / Number<GemmNPerBlock>{}));
make_tuple(Sequence<1>{}, Sequence<0>{}));
// hack to control index calculation when iterating over wei_gemmk_gemmm_global_iterator_hacks
const auto GemmM = out_gemmm_gemmn_global_desc.GetLength(I0); // tensor
const auto GemmN = out_gemmm_gemmn_global_desc.GetLength(I1); constexpr auto wei_gemmk_gemmm_global_iterator_hacks =
const auto GemmK = wei_gemmk_gemmm_global_desc.GetLength(I0); make_tuple(make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}),
make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}));
if(!(GemmM % GemmMPerBlock == 0 && GemmN % GemmNPerBlock == 0 &&
GemmK % GemmKPerBlock == 0)) constexpr auto wei_gemmk_gemmm_global_move_slice_window_iterator_hacks = Sequence<0, 0, 0>{};
{
throw std::runtime_error("wrong! GEMM size no divisible"); // hack to control index calculation when iterating over b_k_n_global tensor
} constexpr auto in_gemmk_gemmn_global_iterator_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}),
constexpr auto GemmM1 = Number<GemmMPerThread * GemmMLevel0Cluster * GemmMLevel1Cluster>{}; make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}));
constexpr auto GemmN1 = Number<GemmNPerThread * GemmNLevel0Cluster * GemmNLevel1Cluster>{};
constexpr auto in_gemmk_gemmn_global_move_slice_window_iterator_hacks = Sequence<0, 0, 0>{};
const auto GemmM0 = GemmM / GemmM1;
const auto GemmN0 = GemmN / GemmN1; // hack to control index calculation when iterating over c_m0_m1_n0_n1_global tensor
constexpr auto out_gemmm0_gemmm1_gemmn0_gemmn1_global_iterator_hacks =
const auto out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc = make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
transform_dynamic_tensor_descriptor( Sequence<0, 0, 0, 0, 0>{},
out_gemmm_gemmn_global_desc, Sequence<0, 0, 0, 0, 0>{},
make_tuple(make_unmerge_transform(make_tuple(GemmM0, GemmM1)), Sequence<0, 0, 0, 0, 0>{}),
make_unmerge_transform(make_tuple(GemmN0, GemmN1))), make_tuple(Sequence<0, 0, 0, 0, 0>{},
make_tuple(Sequence<0>{}, Sequence<1>{}), Sequence<0, 0, 0, 0, 0>{},
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{}));
// hack to control index calculation when iterating over a_k_m_global tensor
constexpr auto a_k_m_global_iterator_hacks = return make_tuple(wei_gemmk_gemmm_global_desc,
make_tuple(make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}), in_gemmk_gemmn_global_desc,
make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{})); out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc,
out_gemm_block_cluster_desc,
constexpr auto a_k_m_global_move_slice_window_iterator_hack = Sequence<0, 0, 0>{}; wei_gemmk_gemmm_global_iterator_hacks,
in_gemmk_gemmn_global_iterator_hacks,
// hack to control index calculation when iterating over b_k_n_global tensor out_gemmm0_gemmm1_gemmn0_gemmn1_global_iterator_hacks,
constexpr auto b_k_n_global_iterator_hacks = wei_gemmk_gemmm_global_move_slice_window_iterator_hacks,
make_tuple(make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}), in_gemmk_gemmn_global_move_slice_window_iterator_hacks);
make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{})); }
constexpr auto b_k_n_global_move_slice_window_iterator_hack = Sequence<0, 0, 0>{};
// hack to control index calculation when iterating over c_m0_m1_n0_n1_global tensor
constexpr auto c_m0_m1_n0_n1_global_tensor_iterator_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{}),
make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{}));
// GEMM
using gridwise_gemm = GridwiseDynamicGemm_km_kn_m0m1n0n1_v1<
BlockSize,
FloatAB,
FloatAcc,
FloatC,
InMemoryDataOperation::Set,
decltype(wei_gemmk_gemmm_global_desc),
decltype(in_gemmk_gemmn_global_desc),
decltype(out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc),
GemmMPerBlock,
GemmNPerBlock,
GemmKPerBlock,
GemmMPerThread,
GemmNPerThread,
GemmKPerThread,
GemmMLevel0Cluster,
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmABlockTransferThreadSliceLengths_GemmK_GemmM,
GemmABlockTransferThreadClusterLengths_GemmK_GemmM,
Sequence<1, 0>,
Sequence<1, 0>,
0,
GemmABlockTransferSrcScalarPerVector_GemmK,
GemmABlockTransferDstScalarPerVector_GemmM,
false, // don't move back src coordinate after threadwise copy
GemmBBlockTransferThreadSliceLengths_GemmK_GemmN,
GemmBBlockTransferThreadClusterLengths_GemmK_GemmN,
Sequence<1, 0>,
Sequence<1, 0>,
0,
GemmBBlockTransferSrcScalarPerVector_GemmK,
GemmBBlockTransferDstScalarPerVector_GemmN,
false, // don't move back src coordinate after threadwise copy, which will be fused with
// MoveSrcSliceWindow() to save addr computation
Sequence<2, 3, 0, 1>,
1,
GemmCThreadTransferDstScalarPerVector_GemmM1,
decltype(a_k_m_global_iterator_hacks),
decltype(b_k_n_global_iterator_hacks),
decltype(c_m0_m1_n0_n1_global_tensor_iterator_hacks),
decltype(a_k_m_global_move_slice_window_iterator_hack),
decltype(b_k_n_global_move_slice_window_iterator_hack)>;
const auto GridSize = (GemmM / GemmMPerBlock) * (GemmN / GemmNPerBlock);
const bool has_main_k_block_loop = (GemmK + GemmKPerBlock) / (2 * GemmKPerBlock) > 1;
const bool has_double_tail_k_block_loop = (GemmK / GemmKPerBlock) % 2 == 0;
printf("%s: BlockSize %d, GridSize %d \n", __func__, BlockSize, GridSize);
#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
index_t nrepeat = 100;
for(index_t i = 0; i < 5; ++i)
{
std::cout << "Start running " << nrepeat << " times..." << std::endl;
KernelTimer timer;
timer.Start();
for(index_t j = 0; j < nrepeat; ++j)
{
if(has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel =
run_gridwise_operation<gridwise_gemm,
decltype(wei_gemmk_gemmm_global_desc),
const FloatAB*,
decltype(in_gemmk_gemmn_global_desc),
const FloatAB*,
decltype(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc),
FloatC*,
integral_constant<bool, true>,
integral_constant<bool, true>>;
launch_kernel(kernel,
dim3(GridSize),
dim3(BlockSize),
0,
0,
wei_gemmk_gemmm_global_desc,
p_wei_global,
in_gemmk_gemmn_global_desc,
p_in_global,
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc,
p_out_global,
integral_constant<bool, true>{},
integral_constant<bool, true>{});
}
else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
{
const auto kernel =
run_gridwise_operation<gridwise_gemm,
decltype(wei_gemmk_gemmm_global_desc),
const FloatAB*,
decltype(in_gemmk_gemmn_global_desc),
const FloatAB*,
decltype(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc),
FloatC*,
integral_constant<bool, true>,
integral_constant<bool, false>>;
launch_kernel(kernel,
dim3(GridSize),
dim3(BlockSize),
0,
0,
wei_gemmk_gemmm_global_desc,
p_wei_global,
in_gemmk_gemmn_global_desc,
p_in_global,
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc,
p_out_global,
integral_constant<bool, true>{},
integral_constant<bool, false>{});
}
else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel =
run_gridwise_operation<gridwise_gemm,
decltype(wei_gemmk_gemmm_global_desc),
const FloatAB*,
decltype(in_gemmk_gemmn_global_desc),
const FloatAB*,
decltype(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc),
FloatC*,
integral_constant<bool, false>,
integral_constant<bool, true>>;
launch_kernel(kernel,
dim3(GridSize),
dim3(BlockSize),
0,
0,
wei_gemmk_gemmm_global_desc,
p_wei_global,
in_gemmk_gemmn_global_desc,
p_in_global,
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc,
p_out_global,
integral_constant<bool, false>{},
integral_constant<bool, true>{});
}
else
{
const auto kernel =
run_gridwise_operation<gridwise_gemm,
decltype(wei_gemmk_gemmm_global_desc),
const FloatAB*,
decltype(in_gemmk_gemmn_global_desc),
const FloatAB*,
decltype(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc),
FloatC*,
integral_constant<bool, false>,
integral_constant<bool, false>>;
launch_kernel(kernel,
dim3(GridSize),
dim3(BlockSize),
0,
0,
wei_gemmk_gemmm_global_desc,
p_wei_global,
in_gemmk_gemmn_global_desc,
p_in_global,
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc,
p_out_global,
integral_constant<bool, false>{},
integral_constant<bool, false>{});
}
}
timer.End();
float ave_time = timer.GetElapsedTime() / nrepeat;
float perf = (float)(std::size_t(2) * N * K * Ho * Wo * C * Y * X) /
(std::size_t(1000) * 1000 * 1000) / ave_time;
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
<< std::endl;
}
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
using ADesc = decltype(wei_gemmk_gemmm_global_desc);
using BDesc = decltype(in_gemmk_gemmn_global_desc);
using CDesc = decltype(out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc);
DeviceMem wei_gemmk_gemmm_global_desc_device_buf(sizeof(ADesc));
DeviceMem in_gemmk_gemmn_global_desc_device_buf(sizeof(BDesc));
DeviceMem out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf(sizeof(CDesc));
wei_gemmk_gemmm_global_desc_device_buf.ToDevice(&wei_gemmk_gemmm_global_desc);
in_gemmk_gemmn_global_desc_device_buf.ToDevice(&in_gemmk_gemmn_global_desc);
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf.ToDevice(
&out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc);
index_t nrepeat = 100;
for(index_t i = 0; i < 5; ++i)
{
std::cout << "Start running " << nrepeat << " times..." << std::endl;
KernelTimer timer;
timer.Start();
for(index_t j = 0; j < nrepeat; ++j)
{
if(has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel = run_gridwise_dynamic_gemm_v1<gridwise_gemm,
ADesc,
FloatAB,
BDesc,
FloatAB,
CDesc,
FloatC,
true,
true>;
launch_kernel(
kernel,
dim3(GridSize),
dim3(BlockSize),
0,
0,
(void __CONSTANT__*)
wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer(),
p_wei_global,
(void __CONSTANT__*)in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer(),
p_in_global,
(void __CONSTANT__*)
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
.GetDeviceBuffer(),
p_out_global);
}
else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
{
const auto kernel = run_gridwise_dynamic_gemm_v1<gridwise_gemm,
ADesc,
FloatAB,
BDesc,
FloatAB,
CDesc,
FloatC,
true,
false>;
launch_kernel(
kernel,
dim3(GridSize),
dim3(BlockSize),
0,
0,
(void __CONSTANT__*)
wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer(),
p_wei_global,
(void __CONSTANT__*)in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer(),
p_in_global,
(void __CONSTANT__*)
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
.GetDeviceBuffer(),
p_out_global);
}
else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel = run_gridwise_dynamic_gemm_v1<gridwise_gemm,
ADesc,
FloatAB,
BDesc,
FloatAB,
CDesc,
FloatC,
false,
true>;
launch_kernel(
kernel,
dim3(GridSize),
dim3(BlockSize),
0,
0,
(void __CONSTANT__*)
wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer(),
p_wei_global,
(void __CONSTANT__*)in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer(),
p_in_global,
(void __CONSTANT__*)
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
.GetDeviceBuffer(),
p_out_global);
}
else
{
const auto kernel = run_gridwise_dynamic_gemm_v1<gridwise_gemm,
ADesc,
FloatAB,
BDesc,
FloatAB,
CDesc,
FloatC,
false,
false>;
launch_kernel(
kernel,
dim3(GridSize),
dim3(BlockSize),
0,
0,
(void __CONSTANT__*)
wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer(),
p_wei_global,
(void __CONSTANT__*)in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer(),
p_in_global,
(void __CONSTANT__*)
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
.GetDeviceBuffer(),
p_out_global);
}
}
timer.End();
float ave_time = timer.GetElapsedTime() / nrepeat;
float perf = (float)(std::size_t(2) * N * K * Ho * Wo * C * Y * X) /
(std::size_t(1000) * 1000 * 1000) / ave_time;
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
<< std::endl;
}
#endif
}
};
#endif
} // namespace ck } // namespace ck
#endif #endif
#ifndef CK_DRIVER_DYNAMIC_GEMM_V1
#define CK_DRIVER_DYNAMIC_GEMM_V1
#include "common_header.hpp"
#include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp"
#include "gridwise_dynamic_gemm.hpp"
#include "gridwise_operation_wrapper.hpp"
namespace ck {
// Host-side driver for GridwiseDynamicGemm_km_kn_m0m1n0n1_v1.
//
// Instantiates the gridwise GEMM with the given tiling/transfer parameters,
// picks the kernel specialization matching the K-loop structure
// (main-loop / double-tail combinations), launches it "nrepeat" times via
// launch_and_time_kernel, and returns the average kernel time in ms.
//
// Two descriptor-passing modes are selected at compile time:
//   - CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE: tensor descriptors are
//     passed to the kernel by value as kernel arguments.
//   - CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER: descriptors are
//     copied into device buffers and passed as __CONSTANT__ void pointers.
//
// Throws std::runtime_error when the GEMM sizes are not divisible by the
// block/thread tiling parameters.
//
// NOTE(review): previously the (has_main && has_double_tail) case of the
// BY_VOID_POINTER branch instantiated run_gridwise_operation (the by-value
// wrapper) with pointer/bool template arguments, inconsistent with its three
// sibling cases; all four cases now use kernel_dynamic_gemm_v1.
template <index_t BlockSize,
          typename FloatAB,
          typename FloatAcc,
          typename FloatC,
          InMemoryDataOperation CGlobalMemoryDataOperation,
          typename AGlobalDesc,
          typename BGlobalDesc,
          typename CGlobalDesc,
          typename CBlockClusterDesc,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t KPerBlock,
          index_t MPerThread,
          index_t NPerThread,
          index_t KPerThread,
          index_t MLevel0Cluster,
          index_t NLevel0Cluster,
          index_t MLevel1Cluster,
          index_t NLevel1Cluster,
          typename ABlockTransferThreadSliceLengths_K_M,
          typename ABlockTransferThreadClusterLengths_K_M,
          typename ABlockTransferThreadClusterArrangeOrder,
          typename ABlockTransferSrcAccessOrder,
          index_t ABlockTransferSrcVectorDim,
          index_t ABlockTransferSrcScalarPerVector,
          index_t ABlockTransferDstScalarPerVector_M,
          bool AThreadTransferSrcResetCoordinateAfterRun,
          typename BBlockTransferThreadSliceLengths_K_N,
          typename BBlockTransferThreadClusterLengths_K_N,
          typename BBlockTransferThreadClusterArrangeOrder,
          typename BBlockTransferSrcAccessOrder,
          index_t BBlockTransferSrcVectorDim,
          index_t BBlockTransferSrcScalarPerVector,
          index_t BBlockTransferDstScalarPerVector_N,
          bool BThreadTransferSrcResetCoordinateAfterRun,
          typename CThreadTransferSrcDstAccessOrder,
          index_t CThreadTransferSrcDstVectorDim,
          index_t CThreadTransferDstScalarPerVector,
          typename AGlobalIteratorHacks,
          typename BGlobalIteratorHacks,
          typename CGlobalIteratorHacks,
          typename AGlobalMoveSliceWindowIteratorHacks,
          typename BGlobalMoveSliceWindowIteratorHacks>
__host__ float launch_kernel_dynamic_gemm_v1(const FloatAB* p_a_global,
                                             const FloatAB* p_b_global,
                                             FloatC* p_c_global,
                                             const AGlobalDesc& a_k_m_global_desc,
                                             const BGlobalDesc& b_k_n_global_desc,
                                             const CGlobalDesc& c_m0_m1_n0_n1_global_desc,
                                             const CBlockClusterDesc& c_block_cluster_desc,
                                             AGlobalIteratorHacks,
                                             BGlobalIteratorHacks,
                                             CGlobalIteratorHacks,
                                             AGlobalMoveSliceWindowIteratorHacks,
                                             BGlobalMoveSliceWindowIteratorHacks,
                                             index_t nrepeat)
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};

    // GEMM problem size: A is [K, M], B is [K, N], C is [M, N] (split as M0/M1/N0/N1)
    const auto M = a_k_m_global_desc.GetLength(I1);
    const auto N = b_k_n_global_desc.GetLength(I1);
    const auto K = a_k_m_global_desc.GetLength(I0);

    if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0))
    {
        throw std::runtime_error("wrong! GEMM size not divisible");
    }

    // M1/N1: sub-tile of a block tile covered by one level-0/level-1 thread cluster
    constexpr auto M1 = Number<MPerThread * MLevel0Cluster * MLevel1Cluster>{};
    constexpr auto N1 = Number<NPerThread * NLevel0Cluster * NLevel1Cluster>{};

    if(!(MPerBlock % M1 == 0 && NPerBlock % N1 == 0))
    {
        throw std::runtime_error("wrong! GEMM size not divisible");
    }

    // GEMM
    using gridwise_gemm =
        GridwiseDynamicGemm_km_kn_m0m1n0n1_v1<BlockSize,
                                              FloatAB,
                                              FloatAcc,
                                              FloatC,
                                              CGlobalMemoryDataOperation,
                                              AGlobalDesc,
                                              BGlobalDesc,
                                              CGlobalDesc,
                                              CBlockClusterDesc,
                                              MPerBlock,
                                              NPerBlock,
                                              KPerBlock,
                                              MPerThread,
                                              NPerThread,
                                              KPerThread,
                                              MLevel0Cluster,
                                              NLevel0Cluster,
                                              MLevel1Cluster,
                                              NLevel1Cluster,
                                              ABlockTransferThreadSliceLengths_K_M,
                                              ABlockTransferThreadClusterLengths_K_M,
                                              ABlockTransferThreadClusterArrangeOrder,
                                              ABlockTransferSrcAccessOrder,
                                              ABlockTransferSrcVectorDim,
                                              ABlockTransferSrcScalarPerVector,
                                              ABlockTransferDstScalarPerVector_M,
                                              AThreadTransferSrcResetCoordinateAfterRun,
                                              BBlockTransferThreadSliceLengths_K_N,
                                              BBlockTransferThreadClusterLengths_K_N,
                                              BBlockTransferThreadClusterArrangeOrder,
                                              BBlockTransferSrcAccessOrder,
                                              BBlockTransferSrcVectorDim,
                                              BBlockTransferSrcScalarPerVector,
                                              BBlockTransferDstScalarPerVector_N,
                                              BThreadTransferSrcResetCoordinateAfterRun,
                                              CThreadTransferSrcDstAccessOrder,
                                              CThreadTransferSrcDstVectorDim,
                                              CThreadTransferDstScalarPerVector,
                                              AGlobalIteratorHacks,
                                              BGlobalIteratorHacks,
                                              CGlobalIteratorHacks,
                                              AGlobalMoveSliceWindowIteratorHacks,
                                              BGlobalMoveSliceWindowIteratorHacks>;

    const auto GridSize = (M / MPerBlock) * (N / NPerBlock);

    // K-loop structure of the double-buffered pipeline: whether a main loop is
    // executed at all, and whether the tail drains two K-blocks or one.
    const bool has_main_k_block_loop        = (K + KPerBlock) / (2 * KPerBlock) > 1;
    const bool has_double_tail_k_block_loop = (K / KPerBlock) % 2 == 0;

    float ave_time = 0;

#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
    // Descriptors are forwarded by value as kernel arguments.
    // has_main/has_double_tail are integral_constant<bool, ?> values so the
    // kernel specialization is selected at compile time.
    const auto launch = [&](auto has_main, auto has_double_tail) {
        const auto kernel = run_gridwise_operation<gridwise_gemm,
                                                   remove_reference_t<AGlobalDesc>,
                                                   const FloatAB*,
                                                   remove_reference_t<BGlobalDesc>,
                                                   const FloatAB*,
                                                   remove_reference_t<CGlobalDesc>,
                                                   FloatC*,
                                                   remove_reference_t<CBlockClusterDesc>,
                                                   decltype(has_main),
                                                   decltype(has_double_tail)>;

        return launch_and_time_kernel(kernel,
                                      nrepeat,
                                      dim3(GridSize),
                                      dim3(BlockSize),
                                      0,
                                      0,
                                      a_k_m_global_desc,
                                      p_a_global,
                                      b_k_n_global_desc,
                                      p_b_global,
                                      c_m0_m1_n0_n1_global_desc,
                                      p_c_global,
                                      c_block_cluster_desc,
                                      has_main,
                                      has_double_tail);
    };
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
    // Descriptors are staged in device memory and passed to the kernel as
    // __CONSTANT__ void pointers (the kernel casts them back to Desc*).
    DeviceMem a_k_m_global_desc_device_buf(sizeof(AGlobalDesc));
    DeviceMem b_k_n_global_desc_device_buf(sizeof(BGlobalDesc));
    DeviceMem c_m0_m1_n0_n1_global_desc_device_buf(sizeof(CGlobalDesc));
    DeviceMem c_block_cluster_desc_device_buf(sizeof(c_block_cluster_desc));

    a_k_m_global_desc_device_buf.ToDevice(&a_k_m_global_desc);
    b_k_n_global_desc_device_buf.ToDevice(&b_k_n_global_desc);
    c_m0_m1_n0_n1_global_desc_device_buf.ToDevice(&c_m0_m1_n0_n1_global_desc);
    c_block_cluster_desc_device_buf.ToDevice(&c_block_cluster_desc);

    const auto launch = [&](auto has_main, auto has_double_tail) {
        const auto kernel = kernel_dynamic_gemm_v1<gridwise_gemm,
                                                   remove_reference_t<AGlobalDesc>,
                                                   FloatAB,
                                                   remove_reference_t<BGlobalDesc>,
                                                   FloatAB,
                                                   remove_reference_t<CGlobalDesc>,
                                                   FloatC,
                                                   remove_reference_t<CBlockClusterDesc>,
                                                   decltype(has_main)::value,
                                                   decltype(has_double_tail)::value>;

        return launch_and_time_kernel(
            kernel,
            nrepeat,
            dim3(GridSize),
            dim3(BlockSize),
            0,
            0,
            (void __CONSTANT__*)a_k_m_global_desc_device_buf.GetDeviceBuffer(),
            p_a_global,
            (void __CONSTANT__*)b_k_n_global_desc_device_buf.GetDeviceBuffer(),
            p_b_global,
            (void __CONSTANT__*)c_m0_m1_n0_n1_global_desc_device_buf.GetDeviceBuffer(),
            p_c_global,
            (void __CONSTANT__*)c_block_cluster_desc_device_buf.GetDeviceBuffer());
    };
#endif

    // Dispatch on the runtime K-loop structure to the matching compile-time
    // kernel specialization.
    if(has_main_k_block_loop && has_double_tail_k_block_loop)
    {
        ave_time = launch(integral_constant<bool, true>{}, integral_constant<bool, true>{});
    }
    else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
    {
        ave_time = launch(integral_constant<bool, true>{}, integral_constant<bool, false>{});
    }
    else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
    {
        ave_time = launch(integral_constant<bool, false>{}, integral_constant<bool, true>{});
    }
    else
    {
        ave_time = launch(integral_constant<bool, false>{}, integral_constant<bool, false>{});
    }

    return ave_time;
}
} // namespace ck
#endif
...@@ -184,6 +184,27 @@ struct TensorAdaptor ...@@ -184,6 +184,27 @@ struct TensorAdaptor
return get_container_subset(idx_hidden, BottomDimensionHiddenIds{}); return get_container_subset(idx_hidden, BottomDimensionHiddenIds{});
} }
__host__ __device__ void Print() const
{
printf("{");
printf("TensorAdaptor, ");
static_for<0, ntransform_, 1>{}([&](auto i) {
printf("transforms: ");
transforms_[i].Print();
printf("LowerDimensionHiddenIds:");
LowerDimensionHiddenIdss{}.At(i).Print();
printf("UpperDimensionHiddenIds:");
UpperDimensionHiddenIdss{}.At(i).Print();
});
printf("BottomDimensionHiddenIds:");
BottomDimensionHiddenIds::Print();
printf("TopDimensionHiddenIds:");
TopDimensionHiddenIds::Print();
printf("}");
}
private: private:
Transforms transforms_; Transforms transforms_;
ElementSize element_size_; ElementSize element_size_;
......
...@@ -12,7 +12,36 @@ ...@@ -12,7 +12,36 @@
namespace ck { namespace ck {
#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER #if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
// BY_VALUE flavor of the dynamic GEMM v1 kernel entry point: the tensor
// descriptors arrive as plain by-value kernel arguments and are forwarded,
// together with the data pointers and the compile-time K-loop flags, to
// GridwiseGemm::Run().
template <typename GridwiseGemm,
          typename AGlobalDesc,
          typename FloatA,
          typename BGlobalDesc,
          typename FloatB,
          typename CGlobalDesc,
          typename FloatC,
          typename CBlockClusterDesc,
          bool HasMainKBlockLoop,
          bool HasDoubleTailKBlockLoop>
__global__ void kernel_dynamic_gemm_v1(const AGlobalDesc a_k_m_global_desc,
                                       const FloatA* __restrict__ p_a_global,
                                       const BGlobalDesc b_k_n_global_desc,
                                       const FloatB* __restrict__ p_b_global,
                                       const CGlobalDesc c_m0_m1_n0_n1_global_desc,
                                       FloatC* __restrict__ p_c_global,
                                       const CBlockClusterDesc c_block_cluster_desc)
{
    // Materialize the compile-time loop-structure flags as integral_constant
    // values so GridwiseGemm::Run can dispatch on them by overload/type.
    constexpr auto has_main_k_block_loop        = integral_constant<bool, HasMainKBlockLoop>{};
    constexpr auto has_double_tail_k_block_loop = integral_constant<bool, HasDoubleTailKBlockLoop>{};

    GridwiseGemm{}.Run(a_k_m_global_desc,
                       p_a_global,
                       b_k_n_global_desc,
                       p_b_global,
                       c_m0_m1_n0_n1_global_desc,
                       p_c_global,
                       c_block_cluster_desc,
                       has_main_k_block_loop,
                       has_double_tail_k_block_loop);
}
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
// pass tensor descriptor by __CONSTANT__ void pointer // pass tensor descriptor by __CONSTANT__ void pointer
// __CONSTANT__ is needed to inform compiler void pointers in the kernel signature are pointing to // __CONSTANT__ is needed to inform compiler void pointers in the kernel signature are pointing to
// non-modifiable parameter address space, so compiler can enable corresponding optimization // non-modifiable parameter address space, so compiler can enable corresponding optimization
...@@ -26,13 +55,13 @@ template <typename GridwiseGemm, ...@@ -26,13 +55,13 @@ template <typename GridwiseGemm,
typename CBlockClusterDesc, typename CBlockClusterDesc,
bool HasMainKBlockLoop, bool HasMainKBlockLoop,
bool HasDoubleTailKBlockLoop> bool HasDoubleTailKBlockLoop>
__global__ void run_gridwise_dynamic_gemm_v1(const void __CONSTANT__* p_a_k_m_global_desc, __global__ void kernel_dynamic_gemm_v1(const void __CONSTANT__* p_a_k_m_global_desc,
const FloatA* __restrict__ p_a_global, const FloatA* __restrict__ p_a_global,
const void __CONSTANT__* p_b_k_n_global_desc, const void __CONSTANT__* p_b_k_n_global_desc,
const FloatB* __restrict__ p_b_global, const FloatB* __restrict__ p_b_global,
const void __CONSTANT__* p_c_m0_m1_n0_n1_global_desc, const void __CONSTANT__* p_c_m0_m1_n0_n1_global_desc,
FloatC* __restrict__ p_c_global, FloatC* __restrict__ p_c_global,
const void __CONSTANT__* p_c_block_cluster_desc) const void __CONSTANT__* p_c_block_cluster_desc)
{ {
// first cast void __CONSTANT__ void* to void* // first cast void __CONSTANT__ void* to void*
// second cast void* to Desc* // second cast void* to Desc*
......
...@@ -46,6 +46,7 @@ void launch_kernel(F kernel, ...@@ -46,6 +46,7 @@ void launch_kernel(F kernel,
template <typename... Args, typename F> template <typename... Args, typename F>
float launch_and_time_kernel(F kernel, float launch_and_time_kernel(F kernel,
int nrepeat,
dim3 grid_dim, dim3 grid_dim,
dim3 block_dim, dim3 block_dim,
std::size_t lds_byte, std::size_t lds_byte,
...@@ -54,15 +55,32 @@ float launch_and_time_kernel(F kernel, ...@@ -54,15 +55,32 @@ float launch_and_time_kernel(F kernel,
{ {
KernelTimer timer; KernelTimer timer;
timer.Start(); printf("%s: block_dim {%d, %d, %d}, grid_dim {%d, %d, %d} \n",
__func__,
grid_dim.x,
grid_dim.y,
grid_dim.z,
block_dim.x,
block_dim.y,
block_dim.z);
printf("Warm up\n");
// warm up
hipLaunchKernelGGL(kernel, grid_dim, block_dim, lds_byte, stream_id, args...); hipLaunchKernelGGL(kernel, grid_dim, block_dim, lds_byte, stream_id, args...);
timer.End(); printf("Start running %d times...\n", nrepeat);
timer.Start();
hipGetLastError(); for(int i = 0; i < nrepeat; ++i)
{
hipLaunchKernelGGL(kernel, grid_dim, block_dim, lds_byte, stream_id, args...);
}
return timer.GetElapsedTime(); timer.End();
return timer.GetElapsedTime() / nrepeat;
} }
#elif CK_DEVICE_BACKEND_NVIDIA #elif CK_DEVICE_BACKEND_NVIDIA
......
...@@ -29,8 +29,17 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw( ...@@ -29,8 +29,17 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(
{ {
using namespace ck; using namespace ck;
std::cout << "device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw" std::cout << __func__ << std::endl;
<< std::endl;
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto I4 = Number<4>{};
constexpr auto I5 = Number<5>{};
constexpr auto I6 = Number<6>{};
constexpr auto I7 = Number<7>{};
constexpr auto I8 = Number<8>{};
DeviceMem in_n_c_hi_wi_device_buf(sizeof(TInWei) * in_n_c_hi_wi.mDesc.GetElementSpace()); DeviceMem in_n_c_hi_wi_device_buf(sizeof(TInWei) * in_n_c_hi_wi.mDesc.GetElementSpace());
DeviceMem wei_k_c_y_x_device_buf(sizeof(TInWei) * wei_k_c_y_x.mDesc.GetElementSpace()); DeviceMem wei_k_c_y_x_device_buf(sizeof(TInWei) * wei_k_c_y_x.mDesc.GetElementSpace());
...@@ -459,50 +468,91 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw( ...@@ -459,50 +468,91 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 4; constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 4;
#endif #endif
constexpr auto conv_driver = constexpr index_t GemmM1 = GemmMPerThread * GemmMLevel0Cluster * GemmMLevel1Cluster;
constexpr index_t GemmN1 = GemmNPerThread * GemmNLevel0Cluster * GemmNLevel1Cluster;
const auto descs =
#if 1 #if 1
DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_pad transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_pad
#elif 0 #elif 0
DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_no_pad transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_no_pad
#elif 1 #else
DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_1x1 transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_1x1
#endif #endif
<BlockSize, <GemmMPerBlock, GemmNPerBlock, GemmM1, GemmN1>(wei_k_c_y_x_desc,
typename vector_type<TInWei, InWeiVectorSize>::type, in_n_c_hi_wi_desc,
TAcc, out_n_k_ho_wo_desc,
TOut, conv_strides,
GemmMPerBlock, conv_dilations,
GemmNPerBlock, in_left_pads,
GemmKPerBlock, in_right_pads);
GemmMPerThread,
GemmNPerThread, float ave_time = launch_kernel_dynamic_gemm_v1<
GemmKPerThread, BlockSize,
GemmMLevel0Cluster, typename vector_type<TInWei, InWeiVectorSize>::type,
GemmNLevel0Cluster, TAcc,
GemmMLevel1Cluster, TOut,
GemmNLevel1Cluster, InMemoryDataOperation::Set,
GemmABlockTransferThreadSliceLengths_GemmK_GemmM, decltype(descs[I0]),
GemmABlockTransferThreadClusterLengths_GemmK_GemmM, decltype(descs[I1]),
GemmABlockTransferSrcScalarPerVector_GemmK, decltype(descs[I2]),
GemmABlockTransferDstScalarPerVector_GemmM, decltype(descs[I3]),
GemmBBlockTransferThreadSliceLengths_GemmK_GemmN, GemmMPerBlock,
GemmBBlockTransferThreadClusterLengths_GemmK_GemmN, GemmNPerBlock,
GemmBBlockTransferSrcScalarPerVector_GemmN, GemmKPerBlock,
GemmBBlockTransferDstScalarPerVector_GemmN, GemmMPerThread,
GemmCThreadTransferDstScalarPerVector_GemmN1>{}; GemmNPerThread,
GemmKPerThread,
conv_driver.Run(wei_k_c_y_x_desc, GemmMLevel0Cluster,
in_n_c_hi_wi_desc, GemmNLevel0Cluster,
out_n_k_ho_wo_desc, GemmMLevel1Cluster,
conv_strides, GemmNLevel1Cluster,
conv_dilations, GemmABlockTransferThreadSliceLengths_GemmK_GemmM,
in_left_pads, GemmABlockTransferThreadClusterLengths_GemmK_GemmM,
in_right_pads, Sequence<1, 0>,
static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>( Sequence<1, 0>,
wei_k_c_y_x_device_buf.GetDeviceBuffer()), 0,
static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>( GemmABlockTransferSrcScalarPerVector_GemmK,
in_n_c_hi_wi_device_buf.GetDeviceBuffer()), GemmABlockTransferDstScalarPerVector_GemmM,
static_cast<TOut*>(out_n_k_ho_wo_device_buf.GetDeviceBuffer())); false, // don't move back src coordinate after threadwise copy
GemmBBlockTransferThreadSliceLengths_GemmK_GemmN,
GemmBBlockTransferThreadClusterLengths_GemmK_GemmN,
Sequence<0, 1>,
Sequence<0, 1>,
1,
GemmBBlockTransferSrcScalarPerVector_GemmN,
GemmBBlockTransferDstScalarPerVector_GemmN,
false, // don't move back src coordinate after threadwise copy, which will be fused with
// MoveSrcSliceWindow() to save addr computation
Sequence<2, 3, 0, 1>,
3,
GemmCThreadTransferDstScalarPerVector_GemmN1,
decltype(descs[I4]),
decltype(descs[I5]),
decltype(descs[I6]),
decltype(descs[I7]),
decltype(descs[I8])>(static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>(
wei_k_c_y_x_device_buf.GetDeviceBuffer()),
static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>(
in_n_c_hi_wi_device_buf.GetDeviceBuffer()),
static_cast<TOut*>(out_n_k_ho_wo_device_buf.GetDeviceBuffer()),
descs[I0],
descs[I1],
descs[I2],
descs[I3],
descs[I4],
descs[I5],
descs[I6],
descs[I7],
descs[I8],
nrepeat);
float perf = (float)calculate_convolution_flops(
in_n_c_hi_wi_desc, wei_k_c_y_x_desc, out_n_k_ho_wo_desc) /
(std::size_t(1000) * 1000 * 1000) / ave_time;
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl;
// copy result back to host
out_n_k_ho_wo_device_buf.FromDevice(out_n_k_ho_wo.mData.data()); out_n_k_ho_wo_device_buf.FromDevice(out_n_k_ho_wo.mData.data());
} }
...@@ -29,13 +29,17 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk( ...@@ -29,13 +29,17 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(
{ {
using namespace ck; using namespace ck;
std::cout << "device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk" std::cout << __func__ << std::endl;
<< std::endl;
constexpr auto I0 = Number<0>{}; constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{}; constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{}; constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{}; constexpr auto I3 = Number<3>{};
constexpr auto I4 = Number<4>{};
constexpr auto I5 = Number<5>{};
constexpr auto I6 = Number<6>{};
constexpr auto I7 = Number<7>{};
constexpr auto I8 = Number<8>{};
constexpr auto N = OutDesc::GetLengths()[I0]; constexpr auto N = OutDesc::GetLengths()[I0];
constexpr auto K = OutDesc::GetLengths()[I1]; constexpr auto K = OutDesc::GetLengths()[I1];
...@@ -372,51 +376,89 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk( ...@@ -372,51 +376,89 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmM1 = 4; constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmM1 = 4;
#endif #endif
constexpr auto conv_driver = constexpr index_t GemmM1 = GemmMPerThread * GemmMLevel0Cluster * GemmMLevel1Cluster;
constexpr index_t GemmN1 = GemmNPerThread * GemmNLevel0Cluster * GemmNLevel1Cluster;
const auto descs =
#if 1 #if 1
DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_pad transform_forward_convolution_into_gemm_v4r4_nhwc_kyxc_nhwk_pad
#elif 0 #else
DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_no_pad transform_forward_convolution_into_gemm_v4r4_nhwc_kyxc_nhwk_1x1
#elif 1
DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_1x1
#endif #endif
<BlockSize, <GemmMPerBlock, GemmNPerBlock, GemmM1, GemmN1>(wei_k_y_x_c0_desc,
typename vector_type<TInWei, InWeiVectorSize>::type, in_n_hi_wi_c0_desc,
TAcc, out_n_ho_wo_k_desc,
TOut, conv_strides,
GemmMPerBlock, conv_dilations,
GemmNPerBlock, in_left_pads,
GemmKPerBlock, in_right_pads);
GemmMPerThread,
GemmNPerThread, float ave_time = launch_kernel_dynamic_gemm_v1<
GemmKPerThread, BlockSize,
GemmMLevel0Cluster, typename vector_type<TInWei, InWeiVectorSize>::type,
GemmNLevel0Cluster, TAcc,
GemmMLevel1Cluster, TOut,
GemmNLevel1Cluster, InMemoryDataOperation::Set,
GemmABlockTransferThreadSliceLengths_GemmK_GemmM, decltype(descs[I0]),
GemmABlockTransferThreadClusterLengths_GemmK_GemmM, decltype(descs[I1]),
GemmABlockTransferSrcScalarPerVector_GemmK, decltype(descs[I2]),
GemmABlockTransferDstScalarPerVector_GemmM, decltype(descs[I3]),
GemmBBlockTransferThreadSliceLengths_GemmK_GemmN, GemmMPerBlock,
GemmBBlockTransferThreadClusterLengths_GemmK_GemmN, GemmNPerBlock,
GemmBBlockTransferSrcScalarPerVector_GemmK, GemmKPerBlock,
GemmBBlockTransferDstScalarPerVector_GemmN, GemmMPerThread,
GemmCThreadTransferDstScalarPerVector_GemmM1>{}; GemmNPerThread,
GemmKPerThread,
conv_driver.Run(wei_k_y_x_c0_desc, GemmMLevel0Cluster,
in_n_hi_wi_c0_desc, GemmNLevel0Cluster,
out_n_ho_wo_k_desc, GemmMLevel1Cluster,
conv_strides, GemmNLevel1Cluster,
conv_dilations, GemmABlockTransferThreadSliceLengths_GemmK_GemmM,
in_left_pads, GemmABlockTransferThreadClusterLengths_GemmK_GemmM,
in_right_pads, Sequence<1, 0>,
static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>( Sequence<1, 0>,
wei_k_y_x_c_device_buf.GetDeviceBuffer()), 0,
static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>( GemmABlockTransferSrcScalarPerVector_GemmK,
in_n_hi_wi_c_device_buf.GetDeviceBuffer()), GemmABlockTransferDstScalarPerVector_GemmM,
static_cast<TOut*>(out_n_ho_wo_k_device_buf.GetDeviceBuffer())); false, // don't move back src coordinate after threadwise copy
GemmBBlockTransferThreadSliceLengths_GemmK_GemmN,
GemmBBlockTransferThreadClusterLengths_GemmK_GemmN,
Sequence<1, 0>,
Sequence<1, 0>,
0,
GemmBBlockTransferSrcScalarPerVector_GemmK,
GemmBBlockTransferDstScalarPerVector_GemmN,
false, // don't move back src coordinate after threadwise copy, which will be fused with
// MoveSrcSliceWindow() to save addr computation
Sequence<2, 3, 0, 1>,
1,
GemmCThreadTransferDstScalarPerVector_GemmM1,
decltype(descs[I4]),
decltype(descs[I5]),
decltype(descs[I6]),
decltype(descs[I7]),
decltype(descs[I8])>(static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>(
wei_k_y_x_c_device_buf.GetDeviceBuffer()),
static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>(
in_n_hi_wi_c_device_buf.GetDeviceBuffer()),
static_cast<TOut*>(out_n_ho_wo_k_device_buf.GetDeviceBuffer()),
descs[I0],
descs[I1],
descs[I2],
descs[I3],
descs[I4],
descs[I5],
descs[I6],
descs[I7],
descs[I8],
nrepeat);
float perf = (float)(std::size_t(2) * N * K * Ho * Wo * C * Y * X) /
(std::size_t(1000) * 1000 * 1000) / ave_time;
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl;
// copy result back to host
out_n_ho_wo_k_device_buf.FromDevice(out_n_ho_wo_k.mData.data()); out_n_ho_wo_k_device_buf.FromDevice(out_n_ho_wo_k.mData.data());
auto f_nhwk2nkhw = [&](auto n, auto k, auto ho, auto wo) { auto f_nhwk2nkhw = [&](auto n, auto k, auto ho, auto wo) {
......
...@@ -210,7 +210,7 @@ int main(int argc, char* argv[]) ...@@ -210,7 +210,7 @@ int main(int argc, char* argv[])
using LeftPads = Sequence<1, 1>; using LeftPads = Sequence<1, 1>;
using RightPads = Sequence<1, 1>; using RightPads = Sequence<1, 1>;
#elif 0 #elif 1
// 3x3, 71x71 // 3x3, 71x71
constexpr index_t N = 128; constexpr index_t N = 128;
constexpr index_t C = 192; constexpr index_t C = 192;
...@@ -225,7 +225,7 @@ int main(int argc, char* argv[]) ...@@ -225,7 +225,7 @@ int main(int argc, char* argv[])
using LeftPads = Sequence<1, 1>; using LeftPads = Sequence<1, 1>;
using RightPads = Sequence<1, 1>; using RightPads = Sequence<1, 1>;
#elif 1 #elif 0
// 7x1, 17x17 // 7x1, 17x17
constexpr index_t N = 128; constexpr index_t N = 128;
constexpr index_t C = 128; constexpr index_t C = 128;
...@@ -724,7 +724,7 @@ int main(int argc, char* argv[]) ...@@ -724,7 +724,7 @@ int main(int argc, char* argv[])
LeftPads{}, LeftPads{},
RightPads{}, RightPads{},
nrepeat); nrepeat);
#elif 0 #elif 1
device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw<in_data_t, device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw<in_data_t,
in_vector_size, in_vector_size,
acc_data_t, acc_data_t,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment