Commit 0d475c27 authored by Chao Liu's avatar Chao Liu
Browse files

refactor

parent 1c704471
...@@ -193,7 +193,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_pad ...@@ -193,7 +193,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_pad
const bool has_double_tail_k_block_loop = (GemmK / GemmKPerBlock) % 2 == 0; const bool has_double_tail_k_block_loop = (GemmK / GemmKPerBlock) % 2 == 0;
#if 0 // pass tensor descriptors by their reference #if 1 // pass tensor descriptors by their reference
index_t nrepeat = 100; index_t nrepeat = 100;
for(index_t i = 0; i < 5; ++i) for(index_t i = 0; i < 5; ++i)
...@@ -850,389 +850,1079 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_no_pad ...@@ -850,389 +850,1079 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_no_pad
const bool has_double_tail_k_block_loop = (GemmK / GemmKPerBlock) % 2 == 0; const bool has_double_tail_k_block_loop = (GemmK / GemmKPerBlock) % 2 == 0;
if(has_main_k_block_loop && has_double_tail_k_block_loop) #if 1 // pass tensor descriptors by their reference
{ index_t nrepeat = 100;
const auto kernel =
run_gridwise_operation<gridwise_gemm, for(index_t i = 0; i < 5; ++i)
decltype(wei_gemmk_gemmm_global_desc),
const Float*,
decltype(in_gemmk_gemmn_global_desc),
const Float*,
decltype(out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc),
Float*,
integral_constant<bool, true>,
integral_constant<bool, true>>;
launch_kernel(kernel,
dim3(GridSize),
dim3(BlockSize),
0,
0,
wei_gemmk_gemmm_global_desc,
p_wei_global,
in_gemmk_gemmn_global_desc,
p_in_global,
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc,
p_out_global,
integral_constant<bool, true>{},
integral_constant<bool, true>{});
}
else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
{
const auto kernel =
run_gridwise_operation<gridwise_gemm,
decltype(wei_gemmk_gemmm_global_desc),
const Float*,
decltype(in_gemmk_gemmn_global_desc),
const Float*,
decltype(out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc),
Float*,
integral_constant<bool, true>,
integral_constant<bool, false>>;
launch_kernel(kernel,
dim3(GridSize),
dim3(BlockSize),
0,
0,
wei_gemmk_gemmm_global_desc,
p_wei_global,
in_gemmk_gemmn_global_desc,
p_in_global,
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc,
p_out_global,
integral_constant<bool, true>{},
integral_constant<bool, false>{});
}
else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
{ {
const auto kernel = std::cout << "Start running " << nrepeat << " times..." << std::endl;
run_gridwise_operation<gridwise_gemm,
decltype(wei_gemmk_gemmm_global_desc), KernelTimer timer;
const Float*, timer.Start();
decltype(in_gemmk_gemmn_global_desc),
const Float*, for(index_t j = 0; j < nrepeat; ++j)
decltype(out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc), {
Float*, if(has_main_k_block_loop && has_double_tail_k_block_loop)
integral_constant<bool, false>, {
integral_constant<bool, true>>; const auto kernel =
run_gridwise_operation<gridwise_gemm,
launch_kernel(kernel, decltype(wei_gemmk_gemmm_global_desc),
dim3(GridSize), const Float*,
dim3(BlockSize), decltype(in_gemmk_gemmn_global_desc),
0, const Float*,
0, decltype(
wei_gemmk_gemmm_global_desc, out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc),
p_wei_global, Float*,
in_gemmk_gemmn_global_desc, integral_constant<bool, true>,
p_in_global, integral_constant<bool, true>>;
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc,
p_out_global, launch_kernel(kernel,
integral_constant<bool, false>{}, dim3(GridSize),
integral_constant<bool, true>{}); dim3(BlockSize),
0,
0,
wei_gemmk_gemmm_global_desc,
p_wei_global,
in_gemmk_gemmn_global_desc,
p_in_global,
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc,
p_out_global,
integral_constant<bool, true>{},
integral_constant<bool, true>{});
}
else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
{
const auto kernel =
run_gridwise_operation<gridwise_gemm,
decltype(wei_gemmk_gemmm_global_desc),
const Float*,
decltype(in_gemmk_gemmn_global_desc),
const Float*,
decltype(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc),
Float*,
integral_constant<bool, true>,
integral_constant<bool, false>>;
launch_kernel(kernel,
dim3(GridSize),
dim3(BlockSize),
0,
0,
wei_gemmk_gemmm_global_desc,
p_wei_global,
in_gemmk_gemmn_global_desc,
p_in_global,
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc,
p_out_global,
integral_constant<bool, true>{},
integral_constant<bool, false>{});
}
else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel =
run_gridwise_operation<gridwise_gemm,
decltype(wei_gemmk_gemmm_global_desc),
const Float*,
decltype(in_gemmk_gemmn_global_desc),
const Float*,
decltype(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc),
Float*,
integral_constant<bool, false>,
integral_constant<bool, true>>;
launch_kernel(kernel,
dim3(GridSize),
dim3(BlockSize),
0,
0,
wei_gemmk_gemmm_global_desc,
p_wei_global,
in_gemmk_gemmn_global_desc,
p_in_global,
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc,
p_out_global,
integral_constant<bool, false>{},
integral_constant<bool, true>{});
}
else
{
const auto kernel =
run_gridwise_operation<gridwise_gemm,
decltype(wei_gemmk_gemmm_global_desc),
const Float*,
decltype(in_gemmk_gemmn_global_desc),
const Float*,
decltype(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc),
Float*,
integral_constant<bool, false>,
integral_constant<bool, false>>;
launch_kernel(kernel,
dim3(GridSize),
dim3(BlockSize),
0,
0,
wei_gemmk_gemmm_global_desc,
p_wei_global,
in_gemmk_gemmn_global_desc,
p_in_global,
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc,
p_out_global,
integral_constant<bool, false>{},
integral_constant<bool, false>{});
}
}
timer.End();
float ave_time = timer.GetElapsedTime() / nrepeat;
float perf = (float)calculate_convolution_flops(in_n_c_hi_wi_global_desc,
wei_k_c_y_x_global_desc,
out_n_k_ho_wo_global_desc) /
(std::size_t(1000) * 1000 * 1000) / ave_time;
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
<< std::endl;
} }
else #elif 1 // pass tensor descriptors by their pointers
using ADesc = decltype(wei_gemmk_gemmm_global_desc);
using BDesc = decltype(in_gemmk_gemmn_global_desc);
using CDesc = decltype(out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc);
DeviceMem wei_gemmk_gemmm_global_desc_device_buf(sizeof(ADesc));
DeviceMem in_gemmk_gemmn_global_desc_device_buf(sizeof(BDesc));
DeviceMem out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf(sizeof(CDesc));
wei_gemmk_gemmm_global_desc_device_buf.ToDevice(&wei_gemmk_gemmm_global_desc);
in_gemmk_gemmn_global_desc_device_buf.ToDevice(&in_gemmk_gemmn_global_desc);
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf.ToDevice(
&out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc);
index_t nrepeat = 100;
for(index_t i = 0; i < 5; ++i)
{ {
const auto kernel = std::cout << "Start running " << nrepeat << " times..." << std::endl;
run_gridwise_operation<gridwise_gemm,
decltype(wei_gemmk_gemmm_global_desc),
const Float*,
decltype(in_gemmk_gemmn_global_desc),
const Float*,
decltype(out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc),
Float*,
integral_constant<bool, false>,
integral_constant<bool, false>>;
launch_kernel(kernel,
dim3(GridSize),
dim3(BlockSize),
0,
0,
wei_gemmk_gemmm_global_desc,
p_wei_global,
in_gemmk_gemmn_global_desc,
p_in_global,
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc,
p_out_global,
integral_constant<bool, false>{},
integral_constant<bool, false>{});
}
}
};
template <index_t BlockSize, KernelTimer timer;
typename Float, timer.Start();
typename AccFloat,
index_t GemmMPerBlock,
index_t GemmNPerBlock,
index_t GemmKPerBlock,
index_t GemmMPerThread,
index_t GemmNPerThread,
index_t GemmKPerThread,
index_t GemmMLevel0Cluster,
index_t GemmNLevel0Cluster,
index_t GemmMLevel1Cluster,
index_t GemmNLevel1Cluster,
typename GemmABlockTransferThreadSliceLengths_GemmK_GemmM,
typename GemmABlockTransferThreadClusterLengths_GemmK_GemmM,
index_t GemmABlockTransferSrcScalarPerVector_GemmK,
index_t GemmABlockTransferDstScalarPerVector_GemmM,
typename GemmBBlockTransferThreadSliceLengths_GemmK_GemmN,
typename GemmBBlockTransferThreadClusterLengths_GemmK_GemmN,
index_t GemmBBlockTransferSrcScalarPerVector_GemmN,
index_t GemmBBlockTransferDstScalarPerVector_GemmN,
index_t GemmCThreadTransferDstScalarPerVector_GemmN1>
struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_1x1
{
template <typename... Wei, typename... In, typename... Out>
__host__ void Run(const DynamicTensorDescriptor<Wei...>& wei_k_c_y_x_global_desc,
const DynamicTensorDescriptor<In...>& in_n_c_hi_wi_global_desc,
const DynamicTensorDescriptor<Out...>& out_n_k_ho_wo_global_desc,
const MultiIndex<2> conv_strides,
const MultiIndex<2> conv_dilations,
const MultiIndex<2> in_left_pads,
const MultiIndex<2> in_right_pads,
const Float* __restrict__ p_wei_global,
const Float* __restrict__ p_in_global,
Float* __restrict__ p_out_global) const
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
const index_t N = in_n_c_hi_wi_global_desc.GetLength(I0); for(index_t j = 0; j < nrepeat; ++j)
const index_t C = in_n_c_hi_wi_global_desc.GetLength(I1); {
const index_t K = out_n_k_ho_wo_global_desc.GetLength(I1); if(has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel =
run_gridwise_operation<gridwise_gemm,
decltype(wei_gemmk_gemmm_global_desc)*,
const Float*,
decltype(in_gemmk_gemmn_global_desc)*,
const Float*,
decltype(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc)*,
Float*,
integral_constant<bool, true>,
integral_constant<bool, true>>;
launch_kernel(kernel,
dim3(GridSize),
dim3(BlockSize),
0,
0,
reinterpret_cast<const ADesc*>(
wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer()),
p_wei_global,
reinterpret_cast<const BDesc*>(
in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer()),
p_in_global,
reinterpret_cast<const CDesc*>(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
.GetDeviceBuffer()),
p_out_global,
integral_constant<bool, true>{},
integral_constant<bool, true>{});
}
else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
{
const auto kernel =
run_gridwise_operation<gridwise_gemm,
decltype(wei_gemmk_gemmm_global_desc)*,
const Float*,
decltype(in_gemmk_gemmn_global_desc)*,
const Float*,
decltype(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc)*,
Float*,
integral_constant<bool, true>,
integral_constant<bool, false>>;
launch_kernel(kernel,
dim3(GridSize),
dim3(BlockSize),
0,
0,
reinterpret_cast<const ADesc*>(
wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer()),
p_wei_global,
reinterpret_cast<const BDesc*>(
in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer()),
p_in_global,
reinterpret_cast<const CDesc*>(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
.GetDeviceBuffer()),
p_out_global,
integral_constant<bool, true>{},
integral_constant<bool, false>{});
}
else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel =
run_gridwise_operation<gridwise_gemm,
decltype(wei_gemmk_gemmm_global_desc)*,
const Float*,
decltype(in_gemmk_gemmn_global_desc)*,
const Float*,
decltype(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc)*,
Float*,
integral_constant<bool, false>,
integral_constant<bool, true>>;
launch_kernel(kernel,
dim3(GridSize),
dim3(BlockSize),
0,
0,
reinterpret_cast<const ADesc*>(
wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer()),
p_wei_global,
reinterpret_cast<const BDesc*>(
in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer()),
p_in_global,
reinterpret_cast<const CDesc*>(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
.GetDeviceBuffer()),
p_out_global,
integral_constant<bool, false>{},
integral_constant<bool, true>{});
}
else
{
const auto kernel =
run_gridwise_operation<gridwise_gemm,
decltype(wei_gemmk_gemmm_global_desc)*,
const Float*,
decltype(in_gemmk_gemmn_global_desc)*,
const Float*,
decltype(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc)*,
Float*,
integral_constant<bool, false>,
integral_constant<bool, false>>;
launch_kernel(kernel,
dim3(GridSize),
dim3(BlockSize),
0,
0,
reinterpret_cast<const ADesc*>(
wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer()),
p_wei_global,
reinterpret_cast<const BDesc*>(
in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer()),
p_in_global,
reinterpret_cast<const CDesc*>(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
.GetDeviceBuffer()),
p_out_global,
integral_constant<bool, false>{},
integral_constant<bool, false>{});
}
}
timer.End();
float ave_time = timer.GetElapsedTime() / nrepeat;
float perf = (float)calculate_convolution_flops(in_n_c_hi_wi_global_desc,
wei_k_c_y_x_global_desc,
out_n_k_ho_wo_global_desc) /
(std::size_t(1000) * 1000 * 1000) / ave_time;
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
<< std::endl;
}
#elif 1 // pass tensor descriptor by void*
using ADesc = decltype(wei_gemmk_gemmm_global_desc);
using BDesc = decltype(in_gemmk_gemmn_global_desc);
using CDesc = decltype(out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc);
DeviceMem wei_gemmk_gemmm_global_desc_device_buf(sizeof(ADesc));
DeviceMem in_gemmk_gemmn_global_desc_device_buf(sizeof(BDesc));
DeviceMem out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf(sizeof(CDesc));
wei_gemmk_gemmm_global_desc_device_buf.ToDevice(&wei_gemmk_gemmm_global_desc);
in_gemmk_gemmn_global_desc_device_buf.ToDevice(&in_gemmk_gemmn_global_desc);
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf.ToDevice(
&out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc);
index_t nrepeat = 100;
for(index_t i = 0; i < 5; ++i)
{
std::cout << "Start running " << nrepeat << " times..." << std::endl;
KernelTimer timer;
timer.Start();
for(index_t j = 0; j < nrepeat; ++j)
{
if(has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel = run_gridwise_operation<gridwise_gemm,
const void*,
const Float*,
const void*,
const Float*,
const void*,
Float*,
integral_constant<bool, true>,
integral_constant<bool, true>>;
launch_kernel(kernel,
dim3(GridSize),
dim3(BlockSize),
0,
0,
wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer(),
p_wei_global,
in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer(),
p_in_global,
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
.GetDeviceBuffer(),
p_out_global,
integral_constant<bool, true>{},
integral_constant<bool, true>{});
}
else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
{
const auto kernel = run_gridwise_operation<gridwise_gemm,
const void*,
const Float*,
const void*,
const Float*,
const void*,
Float*,
integral_constant<bool, true>,
integral_constant<bool, false>>;
launch_kernel(kernel,
dim3(GridSize),
dim3(BlockSize),
0,
0,
wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer(),
p_wei_global,
in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer(),
p_in_global,
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
.GetDeviceBuffer(),
p_out_global,
integral_constant<bool, true>{},
integral_constant<bool, false>{});
}
else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel = run_gridwise_operation<gridwise_gemm,
const void*,
const Float*,
const void*,
const Float*,
const void*,
Float*,
integral_constant<bool, false>,
integral_constant<bool, true>>;
launch_kernel(kernel,
dim3(GridSize),
dim3(BlockSize),
0,
0,
wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer(),
p_wei_global,
in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer(),
p_in_global,
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
.GetDeviceBuffer(),
p_out_global,
integral_constant<bool, false>{},
integral_constant<bool, true>{});
}
else
{
const auto kernel = run_gridwise_operation<gridwise_gemm,
const void*,
const Float*,
const void*,
const Float*,
const void*,
Float*,
integral_constant<bool, false>,
integral_constant<bool, false>>;
launch_kernel(kernel,
dim3(GridSize),
dim3(BlockSize),
0,
0,
wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer(),
p_wei_global,
in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer(),
p_in_global,
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
.GetDeviceBuffer(),
p_out_global,
integral_constant<bool, false>{},
integral_constant<bool, false>{});
}
}
timer.End();
float ave_time = timer.GetElapsedTime() / nrepeat;
float perf = (float)calculate_convolution_flops(in_n_c_hi_wi_global_desc,
wei_k_c_y_x_global_desc,
out_n_k_ho_wo_global_desc) /
(std::size_t(1000) * 1000 * 1000) / ave_time;
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
<< std::endl;
}
#endif
}
};
template <index_t BlockSize,
typename Float,
typename AccFloat,
index_t GemmMPerBlock,
index_t GemmNPerBlock,
index_t GemmKPerBlock,
index_t GemmMPerThread,
index_t GemmNPerThread,
index_t GemmKPerThread,
index_t GemmMLevel0Cluster,
index_t GemmNLevel0Cluster,
index_t GemmMLevel1Cluster,
index_t GemmNLevel1Cluster,
typename GemmABlockTransferThreadSliceLengths_GemmK_GemmM,
typename GemmABlockTransferThreadClusterLengths_GemmK_GemmM,
index_t GemmABlockTransferSrcScalarPerVector_GemmK,
index_t GemmABlockTransferDstScalarPerVector_GemmM,
typename GemmBBlockTransferThreadSliceLengths_GemmK_GemmN,
typename GemmBBlockTransferThreadClusterLengths_GemmK_GemmN,
index_t GemmBBlockTransferSrcScalarPerVector_GemmN,
index_t GemmBBlockTransferDstScalarPerVector_GemmN,
index_t GemmCThreadTransferDstScalarPerVector_GemmN1>
struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_1x1
{
template <typename... Wei, typename... In, typename... Out>
__host__ void Run(const DynamicTensorDescriptor<Wei...>& wei_k_c_y_x_global_desc,
const DynamicTensorDescriptor<In...>& in_n_c_hi_wi_global_desc,
const DynamicTensorDescriptor<Out...>& out_n_k_ho_wo_global_desc,
const MultiIndex<2> conv_strides,
const MultiIndex<2> conv_dilations,
const MultiIndex<2> in_left_pads,
const MultiIndex<2> in_right_pads,
const Float* __restrict__ p_wei_global,
const Float* __restrict__ p_in_global,
Float* __restrict__ p_out_global) const
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
const index_t N = in_n_c_hi_wi_global_desc.GetLength(I0);
const index_t C = in_n_c_hi_wi_global_desc.GetLength(I1);
const index_t K = out_n_k_ho_wo_global_desc.GetLength(I1);
const index_t Hi = in_n_c_hi_wi_global_desc.GetLength(I2);
const index_t Wi = in_n_c_hi_wi_global_desc.GetLength(I3);
const index_t Ho = out_n_k_ho_wo_global_desc.GetLength(I2);
const index_t Wo = out_n_k_ho_wo_global_desc.GetLength(I3);
const index_t Y = wei_k_c_y_x_global_desc.GetLength(I2);
const index_t X = wei_k_c_y_x_global_desc.GetLength(I3);
const index_t ConvStrideH = conv_strides[I0];
const index_t ConvStrideW = conv_strides[I1];
const index_t ConvDilationH = conv_dilations[I0];
const index_t ConvDilationW = conv_dilations[I1];
const index_t InLeftPadH = in_left_pads[I0];
const index_t InLeftPadW = in_left_pads[I1];
const index_t InRightPadH = in_right_pads[I0];
const index_t InRightPadW = in_right_pads[I1];
if(!(Y == 1 && X == 1 && ConvStrideH == 1 && ConvStrideW == 1 && ConvDilationH == 1 &&
ConvDilationW == 1 && InLeftPadH == 0 && InLeftPadW == 0 && InRightPadH == 0 &&
InRightPadW == 0))
{
throw std::runtime_error("wrong! 1x1, stride 1, no padding");
}
// weight tensor
const auto wei_gemmk_gemmm_global_desc = transform_dynamic_tensor_descriptor(
make_dynamic_naive_tensor_descriptor_packed<2>(make_multi_index(K, C)),
make_tuple(DynamicPassThrough{K}, DynamicPassThrough{C}),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<1>{}, Sequence<0>{}));
// input tensor
const auto in_gemmk_gemmn_global_desc = transform_dynamic_tensor_descriptor(
in_n_c_hi_wi_global_desc,
make_tuple(DynamicPassThrough{C}, DynamicMerge<3>{make_multi_index(N, Ho, Wo)}),
make_tuple(Sequence<1>{}, Sequence<0, 2, 3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
// output tensor
const auto out_gemmm_gemmn_global_desc = transform_dynamic_tensor_descriptor(
make_dynamic_naive_tensor_descriptor_packed<3>(make_multi_index(N, K, Ho * Wo)),
make_tuple(DynamicPassThrough{K}, DynamicMerge<2>{make_multi_index(N, Ho * Wo)}),
make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const index_t GemmM = out_gemmm_gemmn_global_desc.GetLength(I0);
const index_t GemmN = out_gemmm_gemmn_global_desc.GetLength(I1);
const index_t GemmK = wei_gemmk_gemmm_global_desc.GetLength(I0);
if(!(GemmM % GemmMPerBlock == 0 && GemmN % GemmNPerBlock == 0 &&
GemmK % GemmKPerBlock == 0))
{
throw std::runtime_error("wrong! GEMM size no divisible");
}
constexpr index_t GemmM1 = GemmMPerThread * GemmMLevel0Cluster * GemmMLevel1Cluster;
constexpr index_t GemmN1 = GemmNPerThread * GemmNLevel0Cluster * GemmNLevel1Cluster;
const index_t GemmM0 = GemmM / GemmM1;
const index_t GemmN0 = GemmN / GemmN1;
const auto out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc =
transform_dynamic_tensor_descriptor(
out_gemmm_gemmn_global_desc,
make_tuple(DynamicUnMerge<2>{make_multi_index(GemmM0, GemmM1)},
DynamicUnMerge<2>{make_multi_index(GemmN0, GemmN1)}),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
// GEMM
using gridwise_gemm = GridwiseDynamicGemm_km_kn_mn_v1<
BlockSize,
Float,
AccFloat,
InMemoryDataOperation::Set,
decltype(wei_gemmk_gemmm_global_desc),
decltype(in_gemmk_gemmn_global_desc),
decltype(out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc),
GemmMPerBlock,
GemmNPerBlock,
GemmKPerBlock,
GemmMPerThread,
GemmNPerThread,
GemmKPerThread,
GemmMLevel0Cluster,
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmABlockTransferThreadSliceLengths_GemmK_GemmM,
GemmABlockTransferThreadClusterLengths_GemmK_GemmM,
Sequence<1, 0>,
Sequence<1, 0>,
0,
GemmABlockTransferSrcScalarPerVector_GemmK,
GemmABlockTransferDstScalarPerVector_GemmM,
false, // don't move back src coordinate after threadwise copy
GemmBBlockTransferThreadSliceLengths_GemmK_GemmN,
GemmBBlockTransferThreadClusterLengths_GemmK_GemmN,
Sequence<0, 1>,
Sequence<0, 1>,
1,
GemmBBlockTransferSrcScalarPerVector_GemmN,
GemmBBlockTransferDstScalarPerVector_GemmN,
false, // don't move back src coordinate after threadwise copy, which will be fused with
// MoveSrcSliceWindow() to save addr computation
Sequence<2, 3, 0, 1>,
3,
GemmCThreadTransferDstScalarPerVector_GemmN1>;
const index_t GridSize = (GemmM / GemmMPerBlock) * (GemmN / GemmNPerBlock);
const bool has_main_k_block_loop = (GemmK + GemmKPerBlock) / (2 * GemmKPerBlock) > 1;
const bool has_double_tail_k_block_loop = (GemmK / GemmKPerBlock) % 2 == 0;
#if 1 // pass tensor descriptors by their reference
index_t nrepeat = 100;
for(index_t i = 0; i < 5; ++i)
{
std::cout << "Start running " << nrepeat << " times..." << std::endl;
KernelTimer timer;
timer.Start();
for(index_t j = 0; j < nrepeat; ++j)
{
if(has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel =
run_gridwise_operation<gridwise_gemm,
decltype(wei_gemmk_gemmm_global_desc),
const Float*,
decltype(in_gemmk_gemmn_global_desc),
const Float*,
decltype(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc),
Float*,
integral_constant<bool, true>,
integral_constant<bool, true>>;
launch_kernel(kernel,
dim3(GridSize),
dim3(BlockSize),
0,
0,
wei_gemmk_gemmm_global_desc,
p_wei_global,
in_gemmk_gemmn_global_desc,
p_in_global,
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc,
p_out_global,
integral_constant<bool, true>{},
integral_constant<bool, true>{});
}
else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
{
const auto kernel =
run_gridwise_operation<gridwise_gemm,
decltype(wei_gemmk_gemmm_global_desc),
const Float*,
decltype(in_gemmk_gemmn_global_desc),
const Float*,
decltype(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc),
Float*,
integral_constant<bool, true>,
integral_constant<bool, false>>;
launch_kernel(kernel,
dim3(GridSize),
dim3(BlockSize),
0,
0,
wei_gemmk_gemmm_global_desc,
p_wei_global,
in_gemmk_gemmn_global_desc,
p_in_global,
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc,
p_out_global,
integral_constant<bool, true>{},
integral_constant<bool, false>{});
}
else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel =
run_gridwise_operation<gridwise_gemm,
decltype(wei_gemmk_gemmm_global_desc),
const Float*,
decltype(in_gemmk_gemmn_global_desc),
const Float*,
decltype(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc),
Float*,
integral_constant<bool, false>,
integral_constant<bool, true>>;
launch_kernel(kernel,
dim3(GridSize),
dim3(BlockSize),
0,
0,
wei_gemmk_gemmm_global_desc,
p_wei_global,
in_gemmk_gemmn_global_desc,
p_in_global,
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc,
p_out_global,
integral_constant<bool, false>{},
integral_constant<bool, true>{});
}
else
{
const auto kernel =
run_gridwise_operation<gridwise_gemm,
decltype(wei_gemmk_gemmm_global_desc),
const Float*,
decltype(in_gemmk_gemmn_global_desc),
const Float*,
decltype(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc),
Float*,
integral_constant<bool, false>,
integral_constant<bool, false>>;
launch_kernel(kernel,
dim3(GridSize),
dim3(BlockSize),
0,
0,
wei_gemmk_gemmm_global_desc,
p_wei_global,
in_gemmk_gemmn_global_desc,
p_in_global,
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc,
p_out_global,
integral_constant<bool, false>{},
integral_constant<bool, false>{});
}
}
timer.End();
float ave_time = timer.GetElapsedTime() / nrepeat;
float perf = (float)calculate_convolution_flops(in_n_c_hi_wi_global_desc,
wei_k_c_y_x_global_desc,
out_n_k_ho_wo_global_desc) /
(std::size_t(1000) * 1000 * 1000) / ave_time;
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
<< std::endl;
}
#elif 1 // pass tensor descriptors by their pointers
using ADesc = decltype(wei_gemmk_gemmm_global_desc);
using BDesc = decltype(in_gemmk_gemmn_global_desc);
using CDesc = decltype(out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc);
DeviceMem wei_gemmk_gemmm_global_desc_device_buf(sizeof(ADesc));
DeviceMem in_gemmk_gemmn_global_desc_device_buf(sizeof(BDesc));
DeviceMem out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf(sizeof(CDesc));
wei_gemmk_gemmm_global_desc_device_buf.ToDevice(&wei_gemmk_gemmm_global_desc);
in_gemmk_gemmn_global_desc_device_buf.ToDevice(&in_gemmk_gemmn_global_desc);
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf.ToDevice(
&out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc);
index_t nrepeat = 100;
for(index_t i = 0; i < 5; ++i)
{
std::cout << "Start running " << nrepeat << " times..." << std::endl;
KernelTimer timer;
timer.Start();
for(index_t j = 0; j < nrepeat; ++j)
{
if(has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel =
run_gridwise_operation<gridwise_gemm,
decltype(wei_gemmk_gemmm_global_desc)*,
const Float*,
decltype(in_gemmk_gemmn_global_desc)*,
const Float*,
decltype(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc)*,
Float*,
integral_constant<bool, true>,
integral_constant<bool, true>>;
launch_kernel(kernel,
dim3(GridSize),
dim3(BlockSize),
0,
0,
reinterpret_cast<const ADesc*>(
wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer()),
p_wei_global,
reinterpret_cast<const BDesc*>(
in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer()),
p_in_global,
reinterpret_cast<const CDesc*>(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
.GetDeviceBuffer()),
p_out_global,
integral_constant<bool, true>{},
integral_constant<bool, true>{});
}
else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
{
const auto kernel =
run_gridwise_operation<gridwise_gemm,
decltype(wei_gemmk_gemmm_global_desc)*,
const Float*,
decltype(in_gemmk_gemmn_global_desc)*,
const Float*,
decltype(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc)*,
Float*,
integral_constant<bool, true>,
integral_constant<bool, false>>;
launch_kernel(kernel,
dim3(GridSize),
dim3(BlockSize),
0,
0,
reinterpret_cast<const ADesc*>(
wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer()),
p_wei_global,
reinterpret_cast<const BDesc*>(
in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer()),
p_in_global,
reinterpret_cast<const CDesc*>(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
.GetDeviceBuffer()),
p_out_global,
integral_constant<bool, true>{},
integral_constant<bool, false>{});
}
else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel =
run_gridwise_operation<gridwise_gemm,
decltype(wei_gemmk_gemmm_global_desc)*,
const Float*,
decltype(in_gemmk_gemmn_global_desc)*,
const Float*,
decltype(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc)*,
Float*,
integral_constant<bool, false>,
integral_constant<bool, true>>;
launch_kernel(kernel,
dim3(GridSize),
dim3(BlockSize),
0,
0,
reinterpret_cast<const ADesc*>(
wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer()),
p_wei_global,
reinterpret_cast<const BDesc*>(
in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer()),
p_in_global,
reinterpret_cast<const CDesc*>(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
.GetDeviceBuffer()),
p_out_global,
integral_constant<bool, false>{},
integral_constant<bool, true>{});
}
else
{
const auto kernel =
run_gridwise_operation<gridwise_gemm,
decltype(wei_gemmk_gemmm_global_desc)*,
const Float*,
decltype(in_gemmk_gemmn_global_desc)*,
const Float*,
decltype(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc)*,
Float*,
integral_constant<bool, false>,
integral_constant<bool, false>>;
launch_kernel(kernel,
dim3(GridSize),
dim3(BlockSize),
0,
0,
reinterpret_cast<const ADesc*>(
wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer()),
p_wei_global,
reinterpret_cast<const BDesc*>(
in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer()),
p_in_global,
reinterpret_cast<const CDesc*>(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
.GetDeviceBuffer()),
p_out_global,
integral_constant<bool, false>{},
integral_constant<bool, false>{});
}
}
const index_t Hi = in_n_c_hi_wi_global_desc.GetLength(I2); timer.End();
const index_t Wi = in_n_c_hi_wi_global_desc.GetLength(I3);
const index_t Ho = out_n_k_ho_wo_global_desc.GetLength(I2); float ave_time = timer.GetElapsedTime() / nrepeat;
const index_t Wo = out_n_k_ho_wo_global_desc.GetLength(I3);
const index_t Y = wei_k_c_y_x_global_desc.GetLength(I2); float perf = (float)calculate_convolution_flops(in_n_c_hi_wi_global_desc,
const index_t X = wei_k_c_y_x_global_desc.GetLength(I3); wei_k_c_y_x_global_desc,
out_n_k_ho_wo_global_desc) /
(std::size_t(1000) * 1000 * 1000) / ave_time;
const index_t ConvStrideH = conv_strides[I0]; std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
const index_t ConvStrideW = conv_strides[I1]; << std::endl;
}
#elif 1 // pass tensor descriptor by void*
using ADesc = decltype(wei_gemmk_gemmm_global_desc);
using BDesc = decltype(in_gemmk_gemmn_global_desc);
using CDesc = decltype(out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc);
const index_t ConvDilationH = conv_dilations[I0]; DeviceMem wei_gemmk_gemmm_global_desc_device_buf(sizeof(ADesc));
const index_t ConvDilationW = conv_dilations[I1]; DeviceMem in_gemmk_gemmn_global_desc_device_buf(sizeof(BDesc));
DeviceMem out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf(sizeof(CDesc));
const index_t InLeftPadH = in_left_pads[I0]; wei_gemmk_gemmm_global_desc_device_buf.ToDevice(&wei_gemmk_gemmm_global_desc);
const index_t InLeftPadW = in_left_pads[I1]; in_gemmk_gemmn_global_desc_device_buf.ToDevice(&in_gemmk_gemmn_global_desc);
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf.ToDevice(
&out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc);
const index_t InRightPadH = in_right_pads[I0]; index_t nrepeat = 100;
const index_t InRightPadW = in_right_pads[I1];
if(!(Y == 1 && X == 1 && ConvStrideH == 1 && ConvStrideW == 1 && ConvDilationH == 1 && for(index_t i = 0; i < 5; ++i)
ConvDilationW == 1 && InLeftPadH == 0 && InLeftPadW == 0 && InRightPadH == 0 &&
InRightPadW == 0))
{ {
throw std::runtime_error("wrong! 1x1, stride 1, no padding"); std::cout << "Start running " << nrepeat << " times..." << std::endl;
}
// weight tensor
const auto wei_gemmk_gemmm_global_desc = transform_dynamic_tensor_descriptor(
make_dynamic_naive_tensor_descriptor_packed<2>(make_multi_index(K, C)),
make_tuple(DynamicPassThrough{K}, DynamicPassThrough{C}),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<1>{}, Sequence<0>{}));
// input tensor
const auto in_gemmk_gemmn_global_desc = transform_dynamic_tensor_descriptor(
in_n_c_hi_wi_global_desc,
make_tuple(DynamicPassThrough{C}, DynamicMerge<3>{make_multi_index(N, Ho, Wo)}),
make_tuple(Sequence<1>{}, Sequence<0, 2, 3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
// output tensor
const auto out_gemmm_gemmn_global_desc = transform_dynamic_tensor_descriptor(
make_dynamic_naive_tensor_descriptor_packed<3>(make_multi_index(N, K, Ho * Wo)),
make_tuple(DynamicPassThrough{K}, DynamicMerge<2>{make_multi_index(N, Ho * Wo)}),
make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const index_t GemmM = out_gemmm_gemmn_global_desc.GetLength(I0); KernelTimer timer;
const index_t GemmN = out_gemmm_gemmn_global_desc.GetLength(I1); timer.Start();
const index_t GemmK = wei_gemmk_gemmm_global_desc.GetLength(I0);
if(!(GemmM % GemmMPerBlock == 0 && GemmN % GemmNPerBlock == 0 && for(index_t j = 0; j < nrepeat; ++j)
GemmK % GemmKPerBlock == 0)) {
{ if(has_main_k_block_loop && has_double_tail_k_block_loop)
throw std::runtime_error("wrong! GEMM size no divisible"); {
} const auto kernel = run_gridwise_operation<gridwise_gemm,
const void*,
const Float*,
const void*,
const Float*,
const void*,
Float*,
integral_constant<bool, true>,
integral_constant<bool, true>>;
constexpr index_t GemmM1 = GemmMPerThread * GemmMLevel0Cluster * GemmMLevel1Cluster; launch_kernel(kernel,
constexpr index_t GemmN1 = GemmNPerThread * GemmNLevel0Cluster * GemmNLevel1Cluster; dim3(GridSize),
dim3(BlockSize),
0,
0,
wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer(),
p_wei_global,
in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer(),
p_in_global,
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
.GetDeviceBuffer(),
p_out_global,
integral_constant<bool, true>{},
integral_constant<bool, true>{});
}
else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
{
const auto kernel = run_gridwise_operation<gridwise_gemm,
const void*,
const Float*,
const void*,
const Float*,
const void*,
Float*,
integral_constant<bool, true>,
integral_constant<bool, false>>;
const index_t GemmM0 = GemmM / GemmM1; launch_kernel(kernel,
const index_t GemmN0 = GemmN / GemmN1; dim3(GridSize),
dim3(BlockSize),
0,
0,
wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer(),
p_wei_global,
in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer(),
p_in_global,
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
.GetDeviceBuffer(),
p_out_global,
integral_constant<bool, true>{},
integral_constant<bool, false>{});
}
else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel = run_gridwise_operation<gridwise_gemm,
const void*,
const Float*,
const void*,
const Float*,
const void*,
Float*,
integral_constant<bool, false>,
integral_constant<bool, true>>;
const auto out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc = launch_kernel(kernel,
transform_dynamic_tensor_descriptor( dim3(GridSize),
out_gemmm_gemmn_global_desc, dim3(BlockSize),
make_tuple(DynamicUnMerge<2>{make_multi_index(GemmM0, GemmM1)}, 0,
DynamicUnMerge<2>{make_multi_index(GemmN0, GemmN1)}), 0,
make_tuple(Sequence<0>{}, Sequence<1>{}), wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer(),
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); p_wei_global,
in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer(),
p_in_global,
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
.GetDeviceBuffer(),
p_out_global,
integral_constant<bool, false>{},
integral_constant<bool, true>{});
}
else
{
const auto kernel = run_gridwise_operation<gridwise_gemm,
const void*,
const Float*,
const void*,
const Float*,
const void*,
Float*,
integral_constant<bool, false>,
integral_constant<bool, false>>;
// GEMM launch_kernel(kernel,
using gridwise_gemm = GridwiseDynamicGemm_km_kn_mn_v1< dim3(GridSize),
BlockSize, dim3(BlockSize),
Float, 0,
AccFloat, 0,
InMemoryDataOperation::Set, wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer(),
decltype(wei_gemmk_gemmm_global_desc), p_wei_global,
decltype(in_gemmk_gemmn_global_desc), in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer(),
decltype(out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc), p_in_global,
GemmMPerBlock, out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
GemmNPerBlock, .GetDeviceBuffer(),
GemmKPerBlock, p_out_global,
GemmMPerThread, integral_constant<bool, false>{},
GemmNPerThread, integral_constant<bool, false>{});
GemmKPerThread, }
GemmMLevel0Cluster, }
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmABlockTransferThreadSliceLengths_GemmK_GemmM,
GemmABlockTransferThreadClusterLengths_GemmK_GemmM,
Sequence<1, 0>,
Sequence<1, 0>,
0,
GemmABlockTransferSrcScalarPerVector_GemmK,
GemmABlockTransferDstScalarPerVector_GemmM,
false, // don't move back src coordinate after threadwise copy
GemmBBlockTransferThreadSliceLengths_GemmK_GemmN,
GemmBBlockTransferThreadClusterLengths_GemmK_GemmN,
Sequence<0, 1>,
Sequence<0, 1>,
1,
GemmBBlockTransferSrcScalarPerVector_GemmN,
GemmBBlockTransferDstScalarPerVector_GemmN,
false, // don't move back src coordinate after threadwise copy, which will be fused with
// MoveSrcSliceWindow() to save addr computation
Sequence<2, 3, 0, 1>,
3,
GemmCThreadTransferDstScalarPerVector_GemmN1>;
const index_t GridSize = (GemmM / GemmMPerBlock) * (GemmN / GemmNPerBlock); timer.End();
const bool has_main_k_block_loop = (GemmK + GemmKPerBlock) / (2 * GemmKPerBlock) > 1; float ave_time = timer.GetElapsedTime() / nrepeat;
const bool has_double_tail_k_block_loop = (GemmK / GemmKPerBlock) % 2 == 0; float perf = (float)calculate_convolution_flops(in_n_c_hi_wi_global_desc,
wei_k_c_y_x_global_desc,
out_n_k_ho_wo_global_desc) /
(std::size_t(1000) * 1000 * 1000) / ave_time;
if(has_main_k_block_loop && has_double_tail_k_block_loop) std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
{ << std::endl;
const auto kernel =
run_gridwise_operation<gridwise_gemm,
decltype(wei_gemmk_gemmm_global_desc),
const Float*,
decltype(in_gemmk_gemmn_global_desc),
const Float*,
decltype(out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc),
Float*,
integral_constant<bool, true>,
integral_constant<bool, true>>;
launch_kernel(kernel,
dim3(GridSize),
dim3(BlockSize),
0,
0,
wei_gemmk_gemmm_global_desc,
p_wei_global,
in_gemmk_gemmn_global_desc,
p_in_global,
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc,
p_out_global,
integral_constant<bool, true>{},
integral_constant<bool, true>{});
}
else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
{
const auto kernel =
run_gridwise_operation<gridwise_gemm,
decltype(wei_gemmk_gemmm_global_desc),
const Float*,
decltype(in_gemmk_gemmn_global_desc),
const Float*,
decltype(out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc),
Float*,
integral_constant<bool, true>,
integral_constant<bool, false>>;
launch_kernel(kernel,
dim3(GridSize),
dim3(BlockSize),
0,
0,
wei_gemmk_gemmm_global_desc,
p_wei_global,
in_gemmk_gemmn_global_desc,
p_in_global,
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc,
p_out_global,
integral_constant<bool, true>{},
integral_constant<bool, false>{});
}
else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel =
run_gridwise_operation<gridwise_gemm,
decltype(wei_gemmk_gemmm_global_desc),
const Float*,
decltype(in_gemmk_gemmn_global_desc),
const Float*,
decltype(out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc),
Float*,
integral_constant<bool, false>,
integral_constant<bool, true>>;
launch_kernel(kernel,
dim3(GridSize),
dim3(BlockSize),
0,
0,
wei_gemmk_gemmm_global_desc,
p_wei_global,
in_gemmk_gemmn_global_desc,
p_in_global,
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc,
p_out_global,
integral_constant<bool, false>{},
integral_constant<bool, true>{});
}
else
{
const auto kernel =
run_gridwise_operation<gridwise_gemm,
decltype(wei_gemmk_gemmm_global_desc),
const Float*,
decltype(in_gemmk_gemmn_global_desc),
const Float*,
decltype(out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc),
Float*,
integral_constant<bool, false>,
integral_constant<bool, false>>;
launch_kernel(kernel,
dim3(GridSize),
dim3(BlockSize),
0,
0,
wei_gemmk_gemmm_global_desc,
p_wei_global,
in_gemmk_gemmn_global_desc,
p_in_global,
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc,
p_out_global,
integral_constant<bool, false>{},
integral_constant<bool, false>{});
} }
#endif
} }
}; };
......
...@@ -70,6 +70,14 @@ struct DynamicPassThrough ...@@ -70,6 +70,14 @@ struct DynamicPassThrough
{ {
return true; return true;
} }
__host__ __device__ void Print() const
{
printf("{");
printf("DynamicPassThrough, ");
print_multi_index(up_lengths_);
printf("}");
}
}; };
template <bool SkipIsValidCheck = false> template <bool SkipIsValidCheck = false>
...@@ -145,6 +153,17 @@ struct DynamicPad ...@@ -145,6 +153,17 @@ struct DynamicPad
return SkipIsValidCheck || ((idx_up[Number<0>{}] >= left_pad_) && return SkipIsValidCheck || ((idx_up[Number<0>{}] >= left_pad_) &&
(idx_up[Number<0>{}] < up_lengths_[Number<0>{}] - right_pad_)); (idx_up[Number<0>{}] < up_lengths_[Number<0>{}] - right_pad_));
} }
__host__ __device__ void Print() const
{
printf("{");
printf("DynamicPad, ");
print_multi_index(up_lengths_);
printf("left_pad_ %d", left_pad_);
printf(", ");
printf("right_pad_ %d", right_pad_);
printf("}");
}
}; };
template <bool SkipIsValidCheck = false> template <bool SkipIsValidCheck = false>
...@@ -214,6 +233,15 @@ struct DynamicLeftPad ...@@ -214,6 +233,15 @@ struct DynamicLeftPad
{ {
return SkipIsValidCheck || (idx_up[Number<0>{}] >= left_pad_); return SkipIsValidCheck || (idx_up[Number<0>{}] >= left_pad_);
} }
__host__ __device__ void Print() const
{
printf("{");
printf("DynamicLeftPad, ");
print_multi_index(up_lengths_);
printf("left_pad_ %d", left_pad_);
printf("}");
}
}; };
template <bool SkipIsValidCheck = false> template <bool SkipIsValidCheck = false>
...@@ -287,6 +315,15 @@ struct DynamicRightPad ...@@ -287,6 +315,15 @@ struct DynamicRightPad
{ {
return SkipIsValidCheck || (idx_up[Number<0>{}] < low_length_); return SkipIsValidCheck || (idx_up[Number<0>{}] < low_length_);
} }
__host__ __device__ void Print() const
{
printf("{");
printf("DynamicRightPad, ");
print_multi_index(up_lengths_);
printf("left_pad_ %d", right_pad_);
printf("}");
}
}; };
// idx_low = coefficients[0, ...nDimUp-1] * idx_up[0, ...nDimUp-1] // idx_low = coefficients[0, ...nDimUp-1] * idx_up[0, ...nDimUp-1]
...@@ -364,6 +401,17 @@ struct DynamicEmbed ...@@ -364,6 +401,17 @@ struct DynamicEmbed
{ {
return true; return true;
} }
__host__ __device__ void Print() const
{
printf("{");
printf("DynamicEmbed, ");
printf("up_lengths_ ");
print_multi_index(up_lengths_);
printf("coefficients_ ");
print_multi_index(coefficients_);
printf("}");
}
}; };
template <index_t NDimLow> template <index_t NDimLow>
...@@ -859,7 +907,20 @@ struct DynamicMerge ...@@ -859,7 +907,20 @@ struct DynamicMerge
{ {
return true; return true;
} }
}; // namespace ck
__host__ __device__ void Print() const
{
printf("{");
printf("DynamicMerge, ");
printf("low_lengths_ ");
print_multi_index(low_lengths_);
printf("low_lengths_scan_ ");
print_multi_index(low_lengths_scan_);
printf("up_lengths_ ");
print_multi_index(up_lengths_);
printf("}");
}
};
template <index_t NDimUp, bool Use24BitIntegerCalculation = false> template <index_t NDimUp, bool Use24BitIntegerCalculation = false>
struct DynamicUnMerge struct DynamicUnMerge
...@@ -938,6 +999,15 @@ struct DynamicUnMerge ...@@ -938,6 +999,15 @@ struct DynamicUnMerge
{ {
return true; return true;
} }
__host__ __device__ void Print() const
{
printf("{");
printf("DynamicUnMerge, ");
print_multi_index(up_lengths_);
print_multi_index(up_lengths_scan_);
printf("}");
}
}; };
struct DynamicFreeze struct DynamicFreeze
...@@ -997,6 +1067,8 @@ struct DynamicFreeze ...@@ -997,6 +1067,8 @@ struct DynamicFreeze
{ {
return true; return true;
} }
__host__ __device__ void Print() const { printf("DynamicFreeze"); }
}; };
} // namespace ck } // namespace ck
......
...@@ -146,6 +146,23 @@ struct DynamicTensorDescriptor ...@@ -146,6 +146,23 @@ struct DynamicTensorDescriptor
return hidden_lengths; return hidden_lengths;
} }
__host__ __device__ void Print() const
{
printf("{");
printf("DynamicTensorDescriptor, ");
static_for<0, ntransform_, 1>{}([&](auto i) {
printf("transforms: ");
transforms_[i].Print();
printf("LowerDimensionIds:");
LowerDimensionIdss{}.At(i).Print();
printf("UpperDimensionIds:");
UpperDimensionIdss{}.At(i).Print();
});
printf("}");
VisibleDimensionIds::Print();
}
// TODO make these private // TODO make these private
Transforms transforms_; Transforms transforms_;
// TODO maybe hidden_lengths_ should use reference_wrapper (reference to transforms_'s member // TODO maybe hidden_lengths_ should use reference_wrapper (reference to transforms_'s member
......
...@@ -163,6 +163,16 @@ __host__ __device__ constexpr auto operator*(index_t a, const Tuple<Xs...>& x) ...@@ -163,6 +163,16 @@ __host__ __device__ constexpr auto operator*(index_t a, const Tuple<Xs...>& x)
return r; return r;
} }
template <typename... Xs>
__host__ __device__ void print_multi_index(const Tuple<Xs...>& x)
{
printf("{");
printf("MultiIndex, ");
printf("size %d,", index_t{sizeof...(Xs)});
static_for<0, sizeof...(Xs), 1>{}([&](auto i) { printf("%d ", x.At(i)); });
printf("}");
}
#endif #endif
} // namespace ck } // namespace ck
#endif #endif
...@@ -278,7 +278,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v1 ...@@ -278,7 +278,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
constexpr auto b_k_n_global_move_slice_window_iterator_hack = constexpr auto b_k_n_global_move_slice_window_iterator_hack =
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2>{}; Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2>{};
#elif 1 #elif 0
// for non-padded input // for non-padded input
constexpr auto b_k_n_global_iterator_hacks = make_tuple( constexpr auto b_k_n_global_iterator_hacks = make_tuple(
make_tuple(Sequence<0, 0, 0, 0, 0, 1, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 1>{}), make_tuple(Sequence<0, 0, 0, 0, 0, 1, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 1>{}),
......
...@@ -168,6 +168,14 @@ struct Sequence ...@@ -168,6 +168,14 @@ struct Sequence
{ {
return Sequence<f(Is)...>{}; return Sequence<f(Is)...>{};
} }
__host__ __device__ static void Print()
{
printf("{");
printf("size %d, ", index_t{Size()});
static_for<0, Size(), 1>{}([&](auto i) { printf("%d ", At(i).value); });
printf("}");
}
}; };
// merge sequence // merge sequence
......
...@@ -235,7 +235,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc ...@@ -235,7 +235,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc
constexpr auto conv_driver = constexpr auto conv_driver =
#if 1 #if 1
DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_pad DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_pad
#elif 1 #elif 0
DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_no_pad DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_no_pad
#elif 1 #elif 1
DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_1x1 DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_1x1
......
...@@ -67,7 +67,7 @@ int main(int argc, char* argv[]) ...@@ -67,7 +67,7 @@ int main(int argc, char* argv[])
using LeftPads = Sequence<1, 1>; using LeftPads = Sequence<1, 1>;
using RightPads = Sequence<1, 1>; using RightPads = Sequence<1, 1>;
#elif 0 #elif 1
// 1x1, 8x8 // 1x1, 8x8
constexpr index_t N = 128; constexpr index_t N = 128;
constexpr index_t C = 1536; constexpr index_t C = 1536;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment