Unverified commit 792a20fa authored by zjing14, committed by GitHub

Hybrid direct + implicit GEMM forward convolution NCHWc v5r1 (#25)

* Hybrid direct + implicit GEMM forward convolution NCHWc v5r1. The input tensor bypasses LDS. Supports fp32/fp16/int8.
parent d2217f30
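
For orientation before the diffs: the commit reshapes the output tensor from plain NKHW to N, K0, Ho, Wo, K1 with K = K0 * K1, i.e. a vectorized-channel (NCHWc-style) layout. A minimal sketch of the index mapping, with hypothetical helper names that are not part of the commit:

    #include <cstddef>

    // Flat offset of element (n, k, ho, wo) in a packed NKHW tensor.
    std::size_t offset_nkhw(std::size_t n, std::size_t k, std::size_t ho, std::size_t wo,
                            std::size_t K, std::size_t Ho, std::size_t Wo)
    {
        return ((n * K + k) * Ho + ho) * Wo + wo;
    }

    // Same element in the packed N,K0,Ho,Wo,K1 layout, assuming K = K0 * K1.
    std::size_t offset_nk0hwk1(std::size_t n, std::size_t k, std::size_t ho, std::size_t wo,
                               std::size_t K0, std::size_t K1, std::size_t Ho, std::size_t Wo)
    {
        const std::size_t k0 = k / K1; // which K1-wide channel vector
        const std::size_t k1 = k % K1; // lane within that vector
        return (((n * K0 + k0) * Ho + ho) * Wo + wo) * K1 + k1;
    }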
@@ -38,7 +38,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
              typename InRightPads>
    __host__ void Run(const DynamicTensorDescriptor<Wei...>& wei_k_c_y_x_global_desc,
                      const DynamicTensorDescriptor<In...>& in_n_c_hi_wi_global_desc,
                      const DynamicTensorDescriptor<Out...>& out_n_k0_ho_wo_k1_global_desc,
                      const ConvStrides& conv_strides,
                      const ConvDilations& conv_dilations,
                      const InLeftPads& in_left_pads,
@@ -51,17 +51,21 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
        constexpr auto I1 = Number<1>{};
        constexpr auto I2 = Number<2>{};
        constexpr auto I3 = Number<3>{};
        constexpr auto I4 = Number<4>{};

        const auto N  = in_n_c_hi_wi_global_desc.GetLength(I0);
        const auto C  = in_n_c_hi_wi_global_desc.GetLength(I1);
        const auto K0 = out_n_k0_ho_wo_k1_global_desc.GetLength(I1);
        const auto Hi = in_n_c_hi_wi_global_desc.GetLength(I2);
        const auto Wi = in_n_c_hi_wi_global_desc.GetLength(I3);
        const auto Ho = out_n_k0_ho_wo_k1_global_desc.GetLength(I2);
        const auto Wo = out_n_k0_ho_wo_k1_global_desc.GetLength(I3);
        const auto K1 = out_n_k0_ho_wo_k1_global_desc.GetLength(I4);
        const auto K  = wei_k_c_y_x_global_desc.GetLength(I0);

        const auto Y = wei_k_c_y_x_global_desc.GetLength(I2);
        const auto X = wei_k_c_y_x_global_desc.GetLength(I3);
@@ -78,7 +82,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
        const auto InRightPadW = in_right_pads[I1];

        // weight tensor
        const auto wei_e_k_global_desc = transform_dynamic_tensor_descriptor(
            make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, C * Y * X)),
            make_tuple(make_pass_through_transform(K), make_pass_through_transform(C * Y * X)),
            make_tuple(Sequence<0>{}, Sequence<1>{}),
@@ -104,7 +108,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));

        const auto in_e_n_ho_wo_global_desc = transform_dynamic_tensor_descriptor(
            in_n_c_y_ho_x_wo_global_desc,
            make_tuple(make_merge_transform(make_tuple(C, Y, X)),
                       make_pass_through_transform(N),
@@ -114,31 +118,31 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));

        // output tensor
        const auto out_k_n_ho_wo_global_desc = transform_dynamic_tensor_descriptor(
            make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K0, Ho, Wo, K1)),
            make_tuple(make_merge_transform(make_tuple(K0, K1)),
                       make_pass_through_transform(N),
                       make_pass_through_transform(Ho),
                       make_pass_through_transform(Wo)),
            make_tuple(Sequence<1, 4>{}, Sequence<0>{}, Sequence<2>{}, Sequence<3>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));

        const auto E = C * Y * X;

        if(!((K % KPerBlock) == 0 && (Ho % HoPerBlock) == 0 && (Wo % WoPerBlock) == 0 &&
             (E % EPerBlock) == 0))
        {
            throw std::runtime_error("wrong! GEMM size not divisible");
        }

        // hack to control index calculation when iterating over a_k_m_global tensor
        constexpr auto a_e_k_global_iterator_hacks =
            make_tuple(make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}),
                       make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}));

        constexpr auto a_e_k_global_move_slice_window_iterator_hack = Sequence<0, 0, 0>{};

        constexpr auto b_e_n_ho_wo_global_iterator_hacks =
            make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
                                  Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
                                  Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
@@ -148,17 +152,17 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
                                  Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
                                  Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}));

        constexpr auto b_e_n_ho_wo_global_move_slice_window_iterator_hack =
            Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{};

        // hack to control index calculation when iterating over c_m0_m1_n0_n1_global tensor
        // hack for NKHW format
        constexpr auto c_k_n_ho_wo_global_tensor_iterator_hacks =
            make_tuple(make_tuple(Sequence<0, 1, 0, 0, 0>{},
                                  Sequence<0, 0, 0, 0, 0>{},
                                  Sequence<0, 0, 0, 0, 0>{},
                                  Sequence<0, 0, 0, 0, 0>{}),
                       make_tuple(Sequence<0, 2, 0, 0, 0>{},
                                  Sequence<0, 0, 0, 0, 0>{},
                                  Sequence<0, 0, 0, 0, 0>{},
                                  Sequence<0, 0, 0, 0, 0>{}));
@@ -171,9 +175,9 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
            FloatAcc,
            FloatC,
            InMemoryDataOperation::Set,
            decltype(wei_e_k_global_desc),
            decltype(in_e_n_ho_wo_global_desc),
            decltype(out_k_n_ho_wo_global_desc),
            KPerBlock,
            HoPerBlock,
            WoPerBlock,
@@ -196,13 +200,13 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
            false, // don't move back src coordinate after threadwise copy, which will be fused with
                   // MoveSrcSliceWindow() to save addr computation
            Sequence<0, 2, 3, 1>,
            0,
            CThreadTransferDstScalarPerVector_W,
            decltype(a_e_k_global_iterator_hacks),
            decltype(b_e_n_ho_wo_global_iterator_hacks),
            decltype(c_k_n_ho_wo_global_tensor_iterator_hacks),
            decltype(a_e_k_global_move_slice_window_iterator_hack),
            decltype(b_e_n_ho_wo_global_move_slice_window_iterator_hack)>;

        const auto GridSize = (K / KPerBlock) * (Ho / HoPerBlock) * (Wo / WoPerBlock) * N;
@@ -226,108 +230,104 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
        {
            if(has_main_k_block_loop && has_double_tail_k_block_loop)
            {
                const auto kernel = run_gridwise_operation<gridwise_gemm,
                                                           decltype(wei_e_k_global_desc),
                                                           const FloatAB*,
                                                           decltype(in_e_n_ho_wo_global_desc),
                                                           const FloatAB*,
                                                           decltype(out_k_n_ho_wo_global_desc),
                                                           FloatC*,
                                                           integral_constant<bool, true>,
                                                           integral_constant<bool, true>>;

                launch_kernel(kernel,
                              dim3(GridSize),
                              dim3(BlockSize),
                              0,
                              0,
                              wei_e_k_global_desc,
                              p_wei_global,
                              in_e_n_ho_wo_global_desc,
                              p_in_global,
                              out_k_n_ho_wo_global_desc,
                              p_out_global,
                              integral_constant<bool, true>{},
                              integral_constant<bool, true>{});
            }
            else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
            {
                const auto kernel = run_gridwise_operation<gridwise_gemm,
                                                           decltype(wei_e_k_global_desc),
                                                           const FloatAB*,
                                                           decltype(in_e_n_ho_wo_global_desc),
                                                           const FloatAB*,
                                                           decltype(out_k_n_ho_wo_global_desc),
                                                           FloatC*,
                                                           integral_constant<bool, true>,
                                                           integral_constant<bool, false>>;

                launch_kernel(kernel,
                              dim3(GridSize),
                              dim3(BlockSize),
                              0,
                              0,
                              wei_e_k_global_desc,
                              p_wei_global,
                              in_e_n_ho_wo_global_desc,
                              p_in_global,
                              out_k_n_ho_wo_global_desc,
                              p_out_global,
                              integral_constant<bool, true>{},
                              integral_constant<bool, false>{});
            }
            else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
            {
                const auto kernel = run_gridwise_operation<gridwise_gemm,
                                                           decltype(wei_e_k_global_desc),
                                                           const FloatAB*,
                                                           decltype(in_e_n_ho_wo_global_desc),
                                                           const FloatAB*,
                                                           decltype(out_k_n_ho_wo_global_desc),
                                                           FloatC*,
                                                           integral_constant<bool, false>,
                                                           integral_constant<bool, true>>;

                launch_kernel(kernel,
                              dim3(GridSize),
                              dim3(BlockSize),
                              0,
                              0,
                              wei_e_k_global_desc,
                              p_wei_global,
                              in_e_n_ho_wo_global_desc,
                              p_in_global,
                              out_k_n_ho_wo_global_desc,
                              p_out_global,
                              integral_constant<bool, false>{},
                              integral_constant<bool, true>{});
            }
            else
            {
                const auto kernel = run_gridwise_operation<gridwise_gemm,
                                                           decltype(wei_e_k_global_desc),
                                                           const FloatAB*,
                                                           decltype(in_e_n_ho_wo_global_desc),
                                                           const FloatAB*,
                                                           decltype(out_k_n_ho_wo_global_desc),
                                                           FloatC*,
                                                           integral_constant<bool, false>,
                                                           integral_constant<bool, false>>;

                launch_kernel(kernel,
                              dim3(GridSize),
                              dim3(BlockSize),
                              0,
                              0,
                              wei_e_k_global_desc,
                              p_wei_global,
                              in_e_n_ho_wo_global_desc,
                              p_in_global,
                              out_k_n_ho_wo_global_desc,
                              p_out_global,
                              integral_constant<bool, false>{},
                              integral_constant<bool, false>{});
@@ -340,7 +340,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
            float perf = (float)calculate_convolution_flops(in_n_c_hi_wi_global_desc,
                                                            wei_k_c_y_x_global_desc,
                                                            out_n_k0_ho_wo_k1_global_desc) /
                         (std::size_t(1000) * 1000 * 1000) / ave_time;

            std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
......
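
The driver above launches one workgroup per output tile. A sketch of the grid-size arithmetic it uses, assuming the divisibility check passed (names illustrative, not the library's):

    // One block computes a KPerBlock x HoPerBlock x WoPerBlock output tile of
    // one image, so the grid covers all tiles of all N images.
    int grid_size(int N, int K, int Ho, int Wo,
                  int KPerBlock, int HoPerBlock, int WoPerBlock)
    {
        return (K / KPerBlock) * (Ho / HoPerBlock) * (Wo / WoPerBlock) * N;
    }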
#ifndef CK_DRIVER_DYNAMIC_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V5R1_NCHW_KCYX_NKHW_OUTPAD_HPP
#define CK_DRIVER_DYNAMIC_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V5R1_NCHW_KCYX_NKHW_OUTPAD_HPP
#include "common_header.hpp"
#include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp"
#include "gridwise_dynamic_gemm_v2.hpp"
#include "gridwise_operation_wrapper.hpp"
namespace ck {
template <index_t BlockSize,
typename FloatAB,
typename FloatAcc,
typename FloatC,
index_t KPerBlock,
index_t HoPerBlock,
index_t WoPerBlock,
index_t EPerBlock,
index_t KPerThread,
index_t HoPerThread,
index_t WoPerThread,
index_t EPerThread,
typename ABlockTransferThreadSliceLengths_E_K,
typename ABlockTransferThreadClusterLengths_E_K,
index_t ABlockTransferSrcScalarPerVector_E,
index_t ABlockTransferDstScalarPerVector_K,
index_t BThreadTransferSrcScalarPerVector_W,
index_t CThreadTransferDstScalarPerVector_W>
struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
{
template <typename... Wei,
typename... In,
typename... Out,
typename ConvStrides,
typename ConvDilations,
typename InLeftPads,
typename InRightPads>
__host__ void Run(const DynamicTensorDescriptor<Wei...>& wei_k_c_y_x_global_desc,
const DynamicTensorDescriptor<In...>& in_n_c_hi_wi_global_desc,
const DynamicTensorDescriptor<Out...>& out_n_k0_ho_wo_k1_global_desc,
const ConvStrides& conv_strides,
const ConvDilations& conv_dilations,
const InLeftPads& in_left_pads,
const InRightPads& in_right_pads,
const FloatAB* __restrict__ p_wei_global,
const FloatAB* __restrict__ p_in_global,
FloatC* __restrict__ p_out_global) const
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto I4 = Number<4>{};
const auto N = in_n_c_hi_wi_global_desc.GetLength(I0);
const auto C = in_n_c_hi_wi_global_desc.GetLength(I1);
const auto K0 = out_n_k0_ho_wo_k1_global_desc.GetLength(I1);
const auto Hi = in_n_c_hi_wi_global_desc.GetLength(I2);
const auto Wi = in_n_c_hi_wi_global_desc.GetLength(I3);
const auto Ho = out_n_k0_ho_wo_k1_global_desc.GetLength(I2);
const auto Wo = out_n_k0_ho_wo_k1_global_desc.GetLength(I3);
const auto K1 = out_n_k0_ho_wo_k1_global_desc.GetLength(I4);
const auto K = wei_k_c_y_x_global_desc.GetLength(I0);
const auto Y = wei_k_c_y_x_global_desc.GetLength(I2);
const auto X = wei_k_c_y_x_global_desc.GetLength(I3);
const auto ConvStrideH = conv_strides[I0];
const auto ConvStrideW = conv_strides[I1];
const auto ConvDilationH = conv_dilations[I0];
const auto ConvDilationW = conv_dilations[I1];
const auto Hop = (Ho + HoPerBlock - 1) / HoPerBlock * HoPerBlock;
const auto Wop = (Wo + WoPerBlock - 1) / WoPerBlock * WoPerBlock;
const auto OutRightPadH = Hop - Ho;
const auto OutRightPadW = Wop - Wo;
const auto InLeftPadH = in_left_pads[I0];
const auto InLeftPadW = in_left_pads[I1];
const auto InRightPadH = in_right_pads[I0] + OutRightPadH * ConvStrideH;
const auto InRightPadW = in_right_pads[I1] + OutRightPadW * ConvStrideW;
std::cerr << "OutRightPadH = " << OutRightPadH << " OutRightPadW = " << OutRightPadW
<< std::endl;
std::cerr << "InRightPadH = " << InRightPadH << " InRightPadW = " << InRightPadW
<< std::endl;
// weight tensor
const auto wei_e_k_global_desc = transform_dynamic_tensor_descriptor(
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, C * Y * X)),
make_tuple(make_pass_through_transform(K), make_pass_through_transform(C * Y * X)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<1>{}, Sequence<0>{}));
// input tensor
const auto in_n_c_hip_wip_global_desc = transform_dynamic_tensor_descriptor(
in_n_c_hi_wi_global_desc,
make_tuple(make_pass_through_transform(N),
make_pass_through_transform(C),
make_pad_transform(Hi, InLeftPadH, InRightPadH),
make_pad_transform(Wi, InLeftPadW, InRightPadW)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto in_n_c_y_ho_x_wo_global_desc = transform_dynamic_tensor_descriptor(
in_n_c_hip_wip_global_desc,
make_tuple(
make_pass_through_transform(N),
make_pass_through_transform(C),
make_embed_transform(make_tuple(Y, Hop), make_tuple(ConvDilationH, ConvStrideH)),
make_embed_transform(make_tuple(X, Wop), make_tuple(ConvDilationW, ConvStrideW))),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
const auto in_e_n_ho_wo_global_desc = transform_dynamic_tensor_descriptor(
in_n_c_y_ho_x_wo_global_desc,
make_tuple(make_merge_transform(make_tuple(C, Y, X)),
make_pass_through_transform(N),
make_pass_through_transform(Hop),
make_pass_through_transform(Wop)),
make_tuple(Sequence<1, 2, 4>{}, Sequence<0>{}, Sequence<3>{}, Sequence<5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
// output tensor
const auto out_k_n_hop_wop_global_desc = transform_dynamic_tensor_descriptor(
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K0, Ho, Wo, K1)),
make_tuple(make_merge_transform(make_tuple(K0, K1)),
make_pass_through_transform(N),
make_pad_transform(Ho, 0, OutRightPadH),
make_pad_transform(Wo, 0, OutRightPadW)),
make_tuple(Sequence<1, 4>{}, Sequence<0>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto E = C * Y * X;
std::cerr << "Hop = " << Hop << " Wop = " << Wop << std::endl;
if(!((K % KPerBlock) == 0 && (Hop % HoPerBlock) == 0 && (Wop % WoPerBlock) == 0 &&
(E % EPerBlock) == 0))
{
throw std::runtime_error("wrong! GEMM size not divisible");
}
// hack to control index calculation when iterating over a_k_m_global tensor
constexpr auto a_e_k_global_iterator_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}),
make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}));
constexpr auto a_e_k_global_move_slice_window_iterator_hack = Sequence<0, 0, 0>{};
constexpr auto b_e_n_ho_wo_global_iterator_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}),
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}));
constexpr auto b_e_n_ho_wo_global_move_slice_window_iterator_hack =
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{};
// hack to control index calculation when iterating over c_m0_m1_n0_n1_global tensor
// hack for NKHW format
constexpr auto c_k_n_ho_wo_global_tensor_iterator_hacks =
make_tuple(make_tuple(Sequence<0, 1, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{}),
make_tuple(Sequence<0, 2, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{}));
// GEMM
using gridwise_gemm = GridwiseDynamicGemm_km_kn_mn_v3<
BlockSize,
FloatAB,
FloatAcc,
FloatC,
InMemoryDataOperation::Set,
decltype(wei_e_k_global_desc),
decltype(in_e_n_ho_wo_global_desc),
decltype(out_k_n_hop_wop_global_desc),
KPerBlock,
HoPerBlock,
WoPerBlock,
EPerBlock,
KPerThread,
HoPerThread,
WoPerThread,
EPerThread,
ABlockTransferThreadSliceLengths_E_K,
ABlockTransferThreadClusterLengths_E_K,
Sequence<1, 0>,
Sequence<1, 0>,
0,
ABlockTransferSrcScalarPerVector_E,
ABlockTransferDstScalarPerVector_K,
false, // don't move back src coordinate after threadwise copy
Sequence<0, 2, 3, 1>,
3,
BThreadTransferSrcScalarPerVector_W,
false, // don't move back src coordinate after threadwise copy, which will be fused with
// MoveSrcSliceWindow() to save addr computation
Sequence<0, 2, 3, 1>,
0,
CThreadTransferDstScalarPerVector_W,
decltype(a_e_k_global_iterator_hacks),
decltype(b_e_n_ho_wo_global_iterator_hacks),
decltype(c_k_n_ho_wo_global_tensor_iterator_hacks),
decltype(a_e_k_global_move_slice_window_iterator_hack),
decltype(b_e_n_ho_wo_global_move_slice_window_iterator_hack)>;
const auto GridSize = (K / KPerBlock) * (Hop / HoPerBlock) * (Wop / WoPerBlock) * N;
const bool has_main_k_block_loop = (E + EPerBlock) / (2 * EPerBlock) > 1;
const bool has_double_tail_k_block_loop = (E / EPerBlock) % 2 == 0;
index_t nrepeat = 100;
for(index_t i = 0; i < 5; ++i)
{
std::cout << "Start running " << nrepeat << " times..." << std::endl;
KernelTimer timer;
timer.Start();
std::cout << "has_main_k_block_loop: " << has_main_k_block_loop
<< " has_double_tail_k_block_loop: " << has_double_tail_k_block_loop
<< std::endl;
for(index_t j = 0; j < nrepeat; ++j)
{
if(has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel =
run_gridwise_operation<gridwise_gemm,
decltype(wei_e_k_global_desc),
const FloatAB*,
decltype(in_e_n_ho_wo_global_desc),
const FloatAB*,
decltype(out_k_n_hop_wop_global_desc),
FloatC*,
integral_constant<bool, true>,
integral_constant<bool, true>>;
launch_kernel(kernel,
dim3(GridSize),
dim3(BlockSize),
0,
0,
wei_e_k_global_desc,
p_wei_global,
in_e_n_ho_wo_global_desc,
p_in_global,
out_k_n_hop_wop_global_desc,
p_out_global,
integral_constant<bool, true>{},
integral_constant<bool, true>{});
}
else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
{
const auto kernel =
run_gridwise_operation<gridwise_gemm,
decltype(wei_e_k_global_desc),
const FloatAB*,
decltype(in_e_n_ho_wo_global_desc),
const FloatAB*,
decltype(out_k_n_hop_wop_global_desc),
FloatC*,
integral_constant<bool, true>,
integral_constant<bool, false>>;
launch_kernel(kernel,
dim3(GridSize),
dim3(BlockSize),
0,
0,
wei_e_k_global_desc,
p_wei_global,
in_e_n_ho_wo_global_desc,
p_in_global,
out_k_n_hop_wop_global_desc,
p_out_global,
integral_constant<bool, true>{},
integral_constant<bool, false>{});
}
else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel =
run_gridwise_operation<gridwise_gemm,
decltype(wei_e_k_global_desc),
const FloatAB*,
decltype(in_e_n_ho_wo_global_desc),
const FloatAB*,
decltype(out_k_n_hop_wop_global_desc),
FloatC*,
integral_constant<bool, false>,
integral_constant<bool, true>>;
launch_kernel(kernel,
dim3(GridSize),
dim3(BlockSize),
0,
0,
wei_e_k_global_desc,
p_wei_global,
in_e_n_ho_wo_global_desc,
p_in_global,
out_k_n_hop_wop_global_desc,
p_out_global,
integral_constant<bool, false>{},
integral_constant<bool, true>{});
}
else
{
const auto kernel =
run_gridwise_operation<gridwise_gemm,
decltype(wei_e_k_global_desc),
const FloatAB*,
decltype(in_e_n_ho_wo_global_desc),
const FloatAB*,
decltype(out_k_n_hop_wop_global_desc),
FloatC*,
integral_constant<bool, false>,
integral_constant<bool, false>>;
launch_kernel(kernel,
dim3(GridSize),
dim3(BlockSize),
0,
0,
wei_e_k_global_desc,
p_wei_global,
in_e_n_ho_wo_global_desc,
p_in_global,
out_k_n_hop_wop_global_desc,
p_out_global,
integral_constant<bool, false>{},
integral_constant<bool, false>{});
}
}
timer.End();
float ave_time = timer.GetElapsedTime() / nrepeat;
float perf = (float)calculate_convolution_flops(in_n_c_hi_wi_global_desc,
wei_k_c_y_x_global_desc,
out_n_k0_ho_wo_k1_global_desc) /
(std::size_t(1000) * 1000 * 1000) / ave_time;
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
<< std::endl;
}
}
};
} // namespace ck
#endif
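
The _outpad variant above differs from the _pad driver in how it sizes the problem: Ho and Wo are rounded up to whole blocks, and the overhang is absorbed as extra right padding on the output and as stride-scaled extra right padding on the input. A value-only sketch of that arithmetic (struct and function names are illustrative):

    struct OutPadSizes
    {
        int Hop, Wop, InRightPadH, InRightPadW;
    };

    OutPadSizes compute_outpad(int Ho, int Wo, int HoPerBlock, int WoPerBlock,
                               int in_right_pad_h, int in_right_pad_w,
                               int ConvStrideH, int ConvStrideW)
    {
        OutPadSizes s;
        s.Hop = (Ho + HoPerBlock - 1) / HoPerBlock * HoPerBlock; // ceil to whole blocks
        s.Wop = (Wo + WoPerBlock - 1) / WoPerBlock * WoPerBlock;
        // one extra output row/column consumes ConvStride extra input rows/columns
        s.InRightPadH = in_right_pad_h + (s.Hop - Ho) * ConvStrideH;
        s.InRightPadW = in_right_pad_w + (s.Wop - Wo) * ConvStrideW;
        return s;
    }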
@@ -134,9 +134,12 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v3
        constexpr auto KPerThreadSubC = 4;
        constexpr auto HoPerThreadSubC = 2;
        constexpr auto WoPerThreadSubC = 2;

        static_assert(KPerThread % KPerThreadSubC == 0, "");
        static_assert(HPerThread % HoPerThreadSubC == 0, "");
        static_assert(WPerThread % WoPerThreadSubC == 0, "");

        // thread A, B for GEMM
        constexpr auto a_thread_mtx = make_dynamic_naive_tensor_descriptor_packed_v2(
@@ -158,7 +161,9 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v3
        constexpr auto threadwise_gemm = ThreadwiseGemm_km_kn_mn_v3<decltype(a_thread_mtx),
                                                                    decltype(b_thread_mtx),
                                                                    decltype(c_thread_mtx),
                                                                    HoPerThreadSubC,
                                                                    WoPerThreadSubC>{};

        // loop over k
#pragma unroll
        for(index_t e_begin = 0; e_begin < EPerBlock; e_begin += EPerThreadLoop)
@@ -171,10 +176,11 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v3
                                    mMyThreadOffsetA,
                                    p_a_thread);

#pragma unroll
            for(index_t h_begin = 0; h_begin < HPerThread; h_begin += HoPerThreadSubC)
            {
                for(index_t w_begin = 0; w_begin < WPerThread; w_begin += WoPerThreadSubC)
                {
                    threadwise_gemm.Run(p_a_thread,
                                        p_b_thread + b_thread_mtx.CalculateOffset(make_tuple(
......
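
The blockwise change above fixes the threadwise sub-tile to HoPerThreadSubC x WoPerThreadSubC = 2 x 2 and walks each thread's HPerThread x WPerThread output tile in those steps, one threadwise-GEMM call per sub-tile. A minimal sketch of that loop structure, outside the library's types (illustrative):

    constexpr int HoPerThreadSubC = 2;
    constexpr int WoPerThreadSubC = 2;

    template <int HPerThread, int WPerThread, typename F>
    void for_each_subtile(F&& run_subtile) // F: (h_begin, w_begin) -> void
    {
        static_assert(HPerThread % HoPerThreadSubC == 0, "");
        static_assert(WPerThread % WoPerThreadSubC == 0, "");
        for(int h = 0; h < HPerThread; h += HoPerThreadSubC)
            for(int w = 0; w < WPerThread; w += WoPerThreadSubC)
                run_subtile(h, w); // one threadwise GEMM per 2x2 sub-tile
    }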
@@ -37,6 +37,8 @@ __device__ void threadwise_matrix_set_zero_v3(Desc, Float* __restrict__ p_thread
template <typename ADesc,
          typename BDesc,
          typename CDesc,
          index_t H,
          index_t W,
          typename std::enable_if<ADesc::IsKnownAtCompileTime() && BDesc::IsKnownAtCompileTime() &&
                                      CDesc::IsKnownAtCompileTime(),
                                  bool>::type = false>
@@ -54,11 +56,6 @@ struct ThreadwiseGemm_km_kn_mn_v3
        constexpr auto I2 = Number<2>{};
        constexpr auto I3 = Number<3>{};

        constexpr auto E = ADesc{}.GetLength(I0);
        constexpr auto K = ADesc{}.GetLength(I1);
......
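
With H and W promoted from hardcoded constants (previously 2 and 2) to template parameters of ThreadwiseGemm_km_kn_mn_v3, the work per call can be modeled in plain scalars as below; this is an illustrative model, not the library's implementation:

    // C[k][h][w] += sum_e A[e][k] * B[e][h][w] over an H x W output sub-tile
    // whose extents now come from the caller instead of being fixed at 2 x 2.
    template <int E, int K, int H, int W>
    void threadwise_gemm_model(const float (&a)[E][K],
                               const float (&b)[E][H][W],
                               float (&c)[K][H][W])
    {
        for(int e = 0; e < E; ++e)
            for(int k = 0; k < K; ++k)
                for(int h = 0; h < H; ++h)
                    for(int w = 0; w < W; ++w)
                        c[k][h][w] += a[e][k] * b[e][h][w];
    }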
@@ -59,7 +59,26 @@ __llvm_amdgcn_raw_buffer_load_i32x4(int32x4_t srsrc,
                                    index_t voffset,
                                    index_t soffset,
                                    index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4i32");
// half
__device__ half_t
__llvm_amdgcn_raw_buffer_load_fp16(int32x4_t srsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.f16");
__device__ half2_t
__llvm_amdgcn_raw_buffer_load_fp16x2(int32x4_t srsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2f16");
__device__ half4_t
__llvm_amdgcn_raw_buffer_load_fp16x4(int32x4_t srsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4f16");
// float
__device__ float
__llvm_amdgcn_raw_buffer_load_fp32(int32x4_t srsrc,
                                   index_t voffset,
@@ -114,6 +133,28 @@ __llvm_amdgcn_raw_buffer_store_i32x4(int32x4_t vdata,
                                     index_t soffset,
                                     index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4i32");
// half
__device__ void
__llvm_amdgcn_raw_buffer_store_fp16(half_t vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.f16");
__device__ void
__llvm_amdgcn_raw_buffer_store_fp16x2(half2_t vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2f16");
__device__ void
__llvm_amdgcn_raw_buffer_store_fp16x4(half4_t vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4f16");
// float
__device__ void
__llvm_amdgcn_raw_buffer_store_fp32(float vdata,
                                    int32x4_t rsrc,
@@ -142,7 +183,13 @@ amd_buffer_load_impl_v2(int32x4_t src_wave_buffer_resource,
                        index_t src_wave_addr_offset)
{
    static_assert((is_same<T, float>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
                      (is_same<T, half_t>::value && (N == 1 || N == 2 || N == 4)) ||
                      (is_same<T, half2_t>::value && (N == 1)) ||
                      (is_same<T, half4_t>::value && (N == 1)) ||
                      (is_same<T, half8_t>::value && (N == 1)) ||
                      (is_same<T, int32_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
                      (is_same<T, int32x2_t>::value && (N == 1)) ||
                      (is_same<T, int32x4_t>::value && (N == 1)),
                  "wrong! not implemented");

    if constexpr(is_same<T, float>::value)
@@ -169,8 +216,63 @@ amd_buffer_load_impl_v2(int32x4_t src_wave_buffer_resource,
            tmp.Vectors(Number<4>{})(Number<0>{}) = __llvm_amdgcn_raw_buffer_load_fp32x4(
                src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);

            tmp.Vectors(Number<4>{})(Number<1>{}) =
                __llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource,
                                                     src_thread_addr_offset,
                                                     src_wave_addr_offset + 4 * sizeof(float),
                                                     0);

            return tmp.Vector();
        }
    }
else if constexpr(is_same<T, half_t>::value)
{
if constexpr(N == 1)
{
return __llvm_amdgcn_raw_buffer_load_fp16(
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
}
else if constexpr(N == 2)
{
return __llvm_amdgcn_raw_buffer_load_fp16x2(
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
}
else if constexpr(N == 4)
{
return __llvm_amdgcn_raw_buffer_load_fp16x4(
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
}
}
else if constexpr(is_same<T, half2_t>::value)
{
if constexpr(N == 1)
{
return __llvm_amdgcn_raw_buffer_load_fp16x2(
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
}
}
else if constexpr(is_same<T, half4_t>::value)
{
if constexpr(N == 1)
{
return __llvm_amdgcn_raw_buffer_load_fp16x4(
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
}
}
else if constexpr(is_same<T, half8_t>::value)
{
if constexpr(N == 1)
{
vector_type<half_t, 8> tmp;
tmp.Vectors(Number<4>{})(Number<0>{}) = __llvm_amdgcn_raw_buffer_load_fp16x4(
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
tmp.Vectors(Number<4>{})(Number<1>{}) =
__llvm_amdgcn_raw_buffer_load_fp16x4(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset + 4 * sizeof(half_t),
0);
            return tmp.Vector();
        }
@@ -199,12 +301,31 @@ amd_buffer_load_impl_v2(int32x4_t src_wave_buffer_resource,
            tmp.Vectors(Number<4>{})(Number<0>{}) = __llvm_amdgcn_raw_buffer_load_i32x4(
                src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);

            tmp.Vectors(Number<4>{})(Number<1>{}) =
                __llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource,
                                                    src_thread_addr_offset,
                                                    src_wave_addr_offset + 4 * sizeof(int32_t),
                                                    0);

            return tmp.Vector();
        }
    }
else if constexpr(is_same<T, int32x2_t>::value)
{
if constexpr(N == 1)
{
return __llvm_amdgcn_raw_buffer_load_i32x2(
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
}
}
else if constexpr(is_same<T, int32x4_t>::value)
{
if constexpr(N == 1)
{
return __llvm_amdgcn_raw_buffer_load_i32x4(
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
}
}
}

template <typename T, index_t N>
@@ -213,10 +334,12 @@ __device__ void amd_buffer_store_impl_v2(const typename vector_type<T, N>::type
                                         index_t dst_thread_addr_offset,
                                         index_t dst_wave_addr_offset)
{
    static_assert(
        (is_same<T, float>::value && (N == 1 || N == 2 || N == 4)) ||
            (is_same<T, int32_t>::value && (N == 1 || N == 2 || N == 4)) ||
            (is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
            (is_same<T, half_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)),
        "wrong! not implemented");

    if constexpr(is_same<T, float>::value)
    {
@@ -298,6 +421,65 @@ __device__ void amd_buffer_store_impl_v2(const typename vector_type<T, N>::type
                                            dst_wave_addr_offset,
                                            0);
        }
else if constexpr(N == 8)
{
__llvm_amdgcn_raw_buffer_store_i32x2(src_thread_data,
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
0);
}
else if constexpr(N == 16)
{
__llvm_amdgcn_raw_buffer_store_i32x4(src_thread_data,
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
0);
}
}
else if constexpr(is_same<T, half_t>::value)
{
if constexpr(N == 1)
{
__llvm_amdgcn_raw_buffer_store_fp16(src_thread_data,
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
0);
}
else if constexpr(N == 2)
{
__llvm_amdgcn_raw_buffer_store_fp16x2(src_thread_data,
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
0);
}
else if constexpr(N == 4)
{
__llvm_amdgcn_raw_buffer_store_fp16x4(src_thread_data,
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
0);
}
else if constexpr(N == 8)
{
vector_type<half_t, 8> tmp{src_thread_data};
__llvm_amdgcn_raw_buffer_store_fp16x4(tmp.Vectors(Number<4>{})[Number<0>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
0);
__llvm_amdgcn_raw_buffer_store_fp16x4(tmp.Vectors(Number<4>{})[Number<1>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset + 4 * sizeof(half_t),
0);
}
    }
}
......
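
The half8 load/store and int32x4 paths above split one wide access into two 4-element buffer transactions, the second shifted by four elements (8 bytes for half, 16 bytes for float/int32). A trivial scalar model of the composition (illustrative only, not the intrinsic-based implementation):

    // An 8-element vector access composed from two 4-element accesses; the
    // second starts at a byte offset of 4 * sizeof(T) past the first.
    template <typename T>
    void load8_as_two_load4(const T* base, T (&dst)[8])
    {
        for(int i = 0; i < 4; ++i) dst[i] = base[i];         // first 4-wide transaction
        for(int i = 0; i < 4; ++i) dst[4 + i] = base[4 + i]; // second, 4 elements later
    }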
@@ -166,6 +166,53 @@ __device__ void amd_assembly_outer_product_1x4(half4_t a,
                                               "3"(c3));
}
__device__ void amd_assembly_outer_product_1x4(half8_t a,
half8_t b0,
half8_t b1,
half8_t b2,
half8_t b3,
float& c0,
float& c1,
float& c2,
float& c3)
{
const half4_t* p_a_half4 = reinterpret_cast<const half4_t*>(&a);
const half4_t* p_b0_half4 = reinterpret_cast<const half4_t*>(&b0);
const half4_t* p_b1_half4 = reinterpret_cast<const half4_t*>(&b1);
const half4_t* p_b2_half4 = reinterpret_cast<const half4_t*>(&b2);
const half4_t* p_b3_half4 = reinterpret_cast<const half4_t*>(&b3);
amd_assembly_outer_product_1x4(
p_a_half4[0], p_b0_half4[0], p_b1_half4[0], p_b2_half4[0], p_b3_half4[0], c0, c1, c2, c3);
amd_assembly_outer_product_1x4(
p_a_half4[1], p_b0_half4[1], p_b1_half4[1], p_b2_half4[1], p_b3_half4[1], c0, c1, c2, c3);
}
__device__ void amd_assembly_outer_product_1x4(half16_t a,
half16_t b0,
half16_t b1,
half16_t b2,
half16_t b3,
float& c0,
float& c1,
float& c2,
float& c3)
{
const half8_t* p_a_half8 = reinterpret_cast<const half8_t*>(&a);
const half8_t* p_b0_half8 = reinterpret_cast<const half8_t*>(&b0);
const half8_t* p_b1_half8 = reinterpret_cast<const half8_t*>(&b1);
const half8_t* p_b2_half8 = reinterpret_cast<const half8_t*>(&b2);
const half8_t* p_b3_half8 = reinterpret_cast<const half8_t*>(&b3);
amd_assembly_outer_product_1x4(
p_a_half8[0], p_b0_half8[0], p_b1_half8[0], p_b2_half8[0], p_b3_half8[0], c0, c1, c2, c3);
amd_assembly_outer_product_1x4(
p_a_half8[1], p_b0_half8[1], p_b1_half8[1], p_b2_half8[1], p_b3_half8[1], c0, c1, c2, c3);
}
// c0 += inner_product(a, b0)
// c1 += inner_product(a, b1)
__device__ void
@@ -215,5 +262,82 @@ __device__ void amd_assembly_outer_product_1x4(int8x4_t a,
#endif
}
__device__ void amd_assembly_outer_product_1x4(int8x8_t a,
int8x8_t b0,
int8x8_t b1,
int8x8_t b2,
int8x8_t b3,
int32_t& c0,
int32_t& c1,
int32_t& c2,
int32_t& c3)
{
const int8x4_t* p_a_int8x4_t = reinterpret_cast<const int8x4_t*>(&a);
const int8x4_t* p_b0_int8x4_t = reinterpret_cast<const int8x4_t*>(&b0);
const int8x4_t* p_b1_int8x4_t = reinterpret_cast<const int8x4_t*>(&b1);
const int8x4_t* p_b2_int8x4_t = reinterpret_cast<const int8x4_t*>(&b2);
const int8x4_t* p_b3_int8x4_t = reinterpret_cast<const int8x4_t*>(&b3);
amd_assembly_outer_product_1x4(p_a_int8x4_t[0],
p_b0_int8x4_t[0],
p_b1_int8x4_t[0],
p_b2_int8x4_t[0],
p_b3_int8x4_t[0],
c0,
c1,
c2,
c3);
amd_assembly_outer_product_1x4(p_a_int8x4_t[1],
p_b0_int8x4_t[1],
p_b1_int8x4_t[1],
p_b2_int8x4_t[1],
p_b3_int8x4_t[1],
c0,
c1,
c2,
c3);
}
__device__ void amd_assembly_outer_product_1x4(int8x16_t a,
int8x16_t b0,
int8x16_t b1,
int8x16_t b2,
int8x16_t b3,
int32_t& c0,
int32_t& c1,
int32_t& c2,
int32_t& c3)
{
const int8x8_t* p_a_int8x8_t = reinterpret_cast<const int8x8_t*>(&a);
const int8x8_t* p_b0_int8x8_t = reinterpret_cast<const int8x8_t*>(&b0);
const int8x8_t* p_b1_int8x8_t = reinterpret_cast<const int8x8_t*>(&b1);
const int8x8_t* p_b2_int8x8_t = reinterpret_cast<const int8x8_t*>(&b2);
const int8x8_t* p_b3_int8x8_t = reinterpret_cast<const int8x8_t*>(&b3);
amd_assembly_outer_product_1x4(p_a_int8x8_t[0],
p_b0_int8x8_t[0],
p_b1_int8x8_t[0],
p_b2_int8x8_t[0],
p_b3_int8x8_t[0],
c0,
c1,
c2,
c3);
amd_assembly_outer_product_1x4(p_a_int8x8_t[1],
p_b0_int8x8_t[1],
p_b1_int8x8_t[1],
p_b2_int8x8_t[1],
p_b3_int8x8_t[1],
c0,
c1,
c2,
c3);
}
} // namespace ck
#endif
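
The new int8x8 and int8x16 overloads above reduce to the int8x4 v_dot4 path by splitting their operands in half. In plain scalar terms each overload computes the following (illustrative model, not the assembly implementation):

    #include <cstdint>

    // Each accumulator gets the 8-wide dot product of `a` with one `b` column;
    // the real code realizes this as two 4-wide dot products back to back.
    void outer_product_1x4_model(const int8_t (&a)[8],
                                 const int8_t (&b0)[8], const int8_t (&b1)[8],
                                 const int8_t (&b2)[8], const int8_t (&b3)[8],
                                 int32_t& c0, int32_t& c1, int32_t& c2, int32_t& c3)
    {
        for(int i = 0; i < 8; ++i)
        {
            c0 += int32_t(a[i]) * int32_t(b0[i]);
            c1 += int32_t(a[i]) * int32_t(b1[i]);
            c2 += int32_t(a[i]) * int32_t(b2[i]);
            c3 += int32_t(a[i]) * int32_t(b3[i]);
        }
    }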
@@ -168,6 +168,84 @@ struct vector_type<T, 8>
    __host__ __device__ constexpr auto& Vectors(Number<8>) { return data_.d8x1_; }
};
template <typename T>
struct vector_type<T, 16>
{
using d1_t = T;
typedef T d2_t __attribute__((ext_vector_type(2)));
typedef T d4_t __attribute__((ext_vector_type(4)));
typedef T d8_t __attribute__((ext_vector_type(8)));
typedef T d16_t __attribute__((ext_vector_type(16)));
using type = d16_t;
union
{
d16_t d16_;
StaticallyIndexedArray<d1_t, 16> d1x16_;
StaticallyIndexedArray<d2_t, 8> d2x8_;
StaticallyIndexedArray<d4_t, 4> d4x4_;
StaticallyIndexedArray<d8_t, 2> d8x2_;
StaticallyIndexedArray<d16_t, 1> d16x1_;
} data_;
__host__ __device__ constexpr vector_type() : data_{type{0}} {}
__host__ __device__ constexpr vector_type(type v) : data_{v} {}
__host__ __device__ static constexpr index_t Size() { return 16; }
__host__ __device__ constexpr const auto& Vector() const { return data_.d16_; }
__host__ __device__ constexpr auto& Vector() { return data_.d16_; }
__host__ __device__ constexpr const auto& Scalars() const { return data_.d1x16_; }
__host__ __device__ constexpr auto& Scalars() { return data_.d1x16_; }
__host__ __device__ constexpr const auto& Vectors(Number<1>) const { return data_.d1x16_; }
__host__ __device__ constexpr const auto& Vectors(Number<2>) const { return data_.d2x8_; }
__host__ __device__ constexpr const auto& Vectors(Number<4>) const { return data_.d4x4_; }
__host__ __device__ constexpr const auto& Vectors(Number<8>) const { return data_.d8x2_; }
__host__ __device__ constexpr const auto& Vectors(Number<16>) const { return data_.d16x1_; }
__host__ __device__ constexpr auto& Vectors(Number<1>) { return data_.d1x16_; }
__host__ __device__ constexpr auto& Vectors(Number<2>) { return data_.d2x8_; }
__host__ __device__ constexpr auto& Vectors(Number<4>) { return data_.d4x4_; }
__host__ __device__ constexpr auto& Vectors(Number<8>) { return data_.d8x2_; }
__host__ __device__ constexpr auto& Vectors(Number<16>) { return data_.d16x1_; }
};
// fp32
using float2_t = typename vector_type<float, 2>::type;
using float4_t = typename vector_type<float, 4>::type;
using float8_t = typename vector_type<float, 8>::type;
// fp16
using half_t = _Float16;
using half2_t = typename vector_type<half_t, 2>::type;
using half4_t = typename vector_type<half_t, 4>::type;
using half8_t = typename vector_type<half_t, 8>::type;
using half16_t = typename vector_type<half_t, 16>::type;
// bfp16
using ushort2_t = typename vector_type<ushort, 2>::type;
using ushort4_t = typename vector_type<ushort, 4>::type;
using ushort8_t = typename vector_type<ushort, 8>::type;
// i32
using int32x2_t = typename vector_type<int32_t, 2>::type;
using int32x4_t = typename vector_type<int32_t, 4>::type;
using int32x8_t = typename vector_type<int32_t, 8>::type;
template <>
struct vector_type<int8_t, 2>
{
@@ -250,31 +328,118 @@ struct vector_type<int8_t, 4>
    __host__ __device__ constexpr auto& Vectors(Number<4>) { return data_.d4x1_; }
};
template <>
struct vector_type<int8_t, 8>
{
    using d1_t = int8_t;
    typedef int16_t d2_t;
    typedef int32_t d4_t;
    typedef int32x2_t d8_t;

    using type = d8_t;

    union
    {
        d8_t d8_;
        StaticallyIndexedArray<d1_t, 8> d1x8_;
        StaticallyIndexedArray<d2_t, 4> d2x4_;
        StaticallyIndexedArray<d4_t, 2> d4x2_;
        StaticallyIndexedArray<d8_t, 1> d8x1_;
    } data_;

    __host__ __device__ constexpr vector_type() : data_{type{0}} {}

    __host__ __device__ constexpr vector_type(type v) : data_{v} {}

    __host__ __device__ static constexpr index_t Size() { return 8; }

    __host__ __device__ constexpr const auto& Vector() const { return data_.d8_; }
    __host__ __device__ constexpr auto& Vector() { return data_.d8_; }

    __host__ __device__ constexpr const auto& Scalars() const { return data_.d1x8_; }
    __host__ __device__ constexpr auto& Scalars() { return data_.d1x8_; }

    __host__ __device__ constexpr const auto& Vectors(Number<1>) const { return data_.d1x8_; }
    __host__ __device__ constexpr const auto& Vectors(Number<2>) const { return data_.d2x4_; }
    __host__ __device__ constexpr const auto& Vectors(Number<4>) const { return data_.d4x2_; }
    __host__ __device__ constexpr const auto& Vectors(Number<8>) const { return data_.d8x1_; }

    __host__ __device__ constexpr auto& Vectors(Number<1>) { return data_.d1x8_; }
    __host__ __device__ constexpr auto& Vectors(Number<2>) { return data_.d2x4_; }
    __host__ __device__ constexpr auto& Vectors(Number<4>) { return data_.d4x2_; }
    __host__ __device__ constexpr auto& Vectors(Number<8>) { return data_.d8x1_; }
};
template <>
struct vector_type<int8_t, 16>
{
using d1_t = int8_t;
typedef int16_t d2_t;
typedef int32_t d4_t;
typedef int32x2_t d8_t;
typedef int32x4_t d16_t;
using type = d16_t;
union
{
d16_t d16_;
StaticallyIndexedArray<d1_t, 16> d1x16_;
StaticallyIndexedArray<d2_t, 8> d2x8_;
StaticallyIndexedArray<d4_t, 4> d4x4_;
StaticallyIndexedArray<d8_t, 2> d8x2_;
StaticallyIndexedArray<d16_t, 1> d16x1_;
} data_;
__host__ __device__ constexpr vector_type() : data_{type{0}} {}
__host__ __device__ constexpr vector_type(type v) : data_{v} {}
__host__ __device__ static constexpr index_t Size() { return 16; }
__host__ __device__ constexpr const auto& Vector() const { return data_.d16_; }
__host__ __device__ constexpr auto& Vector() { return data_.d16_; }
__host__ __device__ constexpr const auto& Scalars() const { return data_.d1x16_; }
__host__ __device__ constexpr auto& Scalars() { return data_.d1x16_; }
__host__ __device__ constexpr const auto& Vectors(Number<1>) const { return data_.d1x16_; }
__host__ __device__ constexpr const auto& Vectors(Number<2>) const { return data_.d2x8_; }
__host__ __device__ constexpr const auto& Vectors(Number<4>) const { return data_.d4x4_; }
__host__ __device__ constexpr const auto& Vectors(Number<8>) const { return data_.d8x2_; }
__host__ __device__ constexpr const auto& Vectors(Number<16>) const { return data_.d16x1_; }
__host__ __device__ constexpr auto& Vectors(Number<1>) { return data_.d1x16_; }
__host__ __device__ constexpr auto& Vectors(Number<2>) { return data_.d2x8_; }
__host__ __device__ constexpr auto& Vectors(Number<4>) { return data_.d4x4_; }
__host__ __device__ constexpr auto& Vectors(Number<8>) { return data_.d8x2_; }
__host__ __device__ constexpr auto& Vectors(Number<16>) { return data_.d16x1_; }
};
// i8
// hack for int8x4_t, because compiler does not have native support for int8x4_t
// int8x4_t is defined as int32_t
using int8x4_t = typename vector_type<int8_t, 4>::type;
using int8x8_t = typename vector_type<int8_t, 8>::type;
using int8x16_t = typename vector_type<int8_t, 16>::type;
// data type conversion
template <typename T>

@@ -339,6 +504,34 @@ struct inner_product_with_conversion
        return acc;
    }
__device__ T operator()(int8x8_t a, int8x8_t b) const
{
const vector_type<int8_t, 8> a_vector{a};
const vector_type<int8_t, 8> b_vector{b};
T acc = 0;
static_for<0, 8, 1>{}([&](auto i) {
acc += convert(a_vector.Scalars()[i]) * convert(b_vector.Scalars()[i]);
});
return acc;
}
__device__ T operator()(int8x16_t a, int8x16_t b) const
{
const vector_type<int8_t, 16> a_vector{a};
const vector_type<int8_t, 16> b_vector{b};
T acc = 0;
static_for<0, 16, 1>{}([&](auto i) {
acc += convert(a_vector.Scalars()[i]) * convert(b_vector.Scalars()[i]);
});
return acc;
}
};

} // namespace ck
......
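
The vector_type unions above let the same bytes be viewed at several granularities; for example, a vector_type<int8_t, 16> can be read back as 16 scalars, four int32s, or two int32x2_t halves via Vectors(Number<n>{}). A standard-C++ analogue of the aliasing idea (illustrative only; the real code uses StaticallyIndexedArray members sharing one union):

    #include <cstdint>

    union Int8x16View
    {
        int8_t  s8[16]; // Scalars(): 16 int8 lanes
        int16_t s16[8]; // Vectors(Number<2>): 8 two-byte groups
        int32_t s32[4]; // Vectors(Number<4>): 4 four-byte groups
    };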
@@ -2,6 +2,7 @@
#include "device.hpp"
#include "host_tensor.hpp"
#include "driver_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw.hpp"
#include "driver_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw_outpad.hpp"
template <class TInWei,
          ck::index_t InWeiVectorSize,

@@ -57,6 +58,9 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(
    constexpr auto C0 = C / Number<InWeiVectorSize>{};
    constexpr auto C1 = Number<InWeiVectorSize>{};
constexpr auto K0 = K / Number<InWeiVectorSize>{};
constexpr auto K1 = Number<InWeiVectorSize>{};
#if 0
    // run-time variables
    const auto in_n_c_hi_wi_desc =

@@ -76,19 +80,21 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(
        make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, C0, Hi, Wi));
    const auto wei_k_c0_y_x_desc =
        make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, C0, Y, X));
    const auto out_n_k0_ho_wo_k1_desc =
        make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K0, Ho, Wo, K1));

    const auto conv_strides   = sequence_to_tuple_of_number(ConvStrides{});
    const auto conv_dilations = sequence_to_tuple_of_number(ConvDilations{});
    const auto in_left_pads   = sequence_to_tuple_of_number(InLeftPads{});
    const auto in_right_pads  = sequence_to_tuple_of_number(InRightPads{});
#endif
    Tensor<TInWei> in_n_c0_hi_wi_c1(make_HostTensorDescriptor(
        make_native_tensor_descriptor_packed(Sequence<N, C0, Hi, Wi, C1>{})));
    Tensor<TInWei> wei_k_c0_y_x_c1(make_HostTensorDescriptor(
        make_native_tensor_descriptor_packed(Sequence<K, C0, Y, X, C1>{})));
Tensor<TOut> out_n_k0_ho_wo_k1(make_HostTensorDescriptor(
make_native_tensor_descriptor_packed(Sequence<N, K0, Ho, Wo, K1>{})));
    auto f_nchw2nc0hwc1 = [&](auto n, auto hi, auto wi, auto c) {
        in_n_c0_hi_wi_c1(n, c / InWeiVectorSize, hi, wi, c % InWeiVectorSize) =

@@ -106,13 +112,38 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(
    in_n_c_hi_wi_device_buf.ToDevice(in_n_c0_hi_wi_c1.mData.data());
    wei_k_c_y_x_device_buf.ToDevice(wei_k_c0_y_x_c1.mData.data());
#if 1
    // cdata = 64, BlockSize = 64, 16x8x32x4
    constexpr index_t BlockSize = 64;
constexpr index_t KPerBlock = K;
constexpr index_t HoPerBlock = 8;
constexpr index_t WoPerBlock = 32;
constexpr index_t EPerBlock = C0;
constexpr index_t KPerThread = KPerBlock;
constexpr index_t HoPerThread = 2;
constexpr index_t WoPerThread = 2;
constexpr index_t EPerThread = EPerBlock;
using ABlockTransferThreadSliceLengths_E_K = Sequence<3, 1>;
using ABlockTransferThreadClusterLengths_E_K = Sequence<3 * EPerBlock, KPerBlock>;
constexpr index_t ABlockTransferSrcScalarPerVector_E = 1;
constexpr index_t ABlockTransferDstScalarPerVector_K = 1;
constexpr index_t BThreadTransferSrcScalarPerVector_W = 1;
constexpr index_t CThreadTransferDstScalarPerVector_W = K1;
static_assert(KPerThread % CThreadTransferDstScalarPerVector_W == 0, "");
#else
constexpr index_t BlockSize = 64;
    constexpr index_t KPerBlock = 16;
    constexpr index_t HoPerBlock = 8;
    constexpr index_t WoPerBlock = 32;
    constexpr index_t EPerBlock = 1;

    constexpr index_t KPerThread = 16;
    constexpr index_t HoPerThread = 2;
@@ -127,32 +158,28 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(
    constexpr index_t BThreadTransferSrcScalarPerVector_W = 1;

    constexpr index_t CThreadTransferDstScalarPerVector_W = K1;
static_assert(KPerThread % CThreadTransferDstScalarPerVector_W == 0, "");
#endif
    constexpr auto conv_driver =
#if 0
        DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad<
#else
        DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad<
#endif
            BlockSize,
            typename vector_type<TInWei, InWeiVectorSize>::type,
            TAcc,
            TOut,
            KPerBlock,
            HoPerBlock,
            WoPerBlock,
            EPerBlock,
            KPerThread,
            HoPerThread,
            WoPerThread,
            EPerThread,
            ABlockTransferThreadSliceLengths_E_K,
            ABlockTransferThreadClusterLengths_E_K,
            ABlockTransferSrcScalarPerVector_E,
            ABlockTransferDstScalarPerVector_K,
            BThreadTransferSrcScalarPerVector_W,
            CThreadTransferDstScalarPerVector_W>{};
    conv_driver.Run(wei_k_c0_y_x_desc,
                    in_n_c0_hi_wi_desc,
                    out_n_k0_ho_wo_k1_desc,
                    conv_strides,
                    conv_dilations,
                    in_left_pads,

@@ -163,5 +190,12 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(
                        in_n_c_hi_wi_device_buf.GetDeviceBuffer()),
                    static_cast<TOut*>(out_n_k_ho_wo_device_buf.GetDeviceBuffer()));

    out_n_k_ho_wo_device_buf.FromDevice(out_n_k0_ho_wo_k1.mData.data());
auto f_nk0hwk1_to_nkhw = [&](auto n, auto k, auto ho, auto wo) {
out_n_k_ho_wo(n, k, ho, wo) =
out_n_k0_ho_wo_k1(n, k / InWeiVectorSize, ho, wo, k % InWeiVectorSize);
};
make_ParallelTensorFunctor(f_nk0hwk1_to_nkhw, N, K, Ho, Wo)();
}
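
The tail of the device driver converts the vectorized output back to plain NKHW for verification. Written as explicit loops, the repack is equivalent to the f_nk0hwk1_to_nkhw functor above (illustrative; Src5D/Dst4D stand for any callable tensors indexed like the host Tensor type):

    template <typename Src5D, typename Dst4D>
    void repack_nk0hwk1_to_nkhw(const Src5D& src, Dst4D& dst,
                                int N, int K, int Ho, int Wo, int K1)
    {
        for(int n = 0; n < N; ++n)
            for(int k = 0; k < K; ++k)
                for(int ho = 0; ho < Ho; ++ho)
                    for(int wo = 0; wo < Wo; ++wo)
                        dst(n, k, ho, wo) = src(n, k / K1, ho, wo, k % K1);
    }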
@@ -48,8 +48,8 @@ int main(int argc, char* argv[])
    using ConvStrides = Sequence<1, 1>;
    using ConvDilations = Sequence<1, 1>;

    using LeftPads = Sequence<0, 0>;
    using RightPads = Sequence<0, 0>;
#elif 0
    constexpr index_t N = 1;
    constexpr index_t C = 16;

@@ -62,8 +62,8 @@ int main(int argc, char* argv[])
    using ConvStrides = Sequence<1, 1>;
    using ConvDilations = Sequence<1, 1>;

    using LeftPads = Sequence<0, 0>;
    using RightPads = Sequence<0, 0>;
#elif 1
    constexpr index_t N = 1;
    constexpr index_t C = 16;

@@ -642,7 +642,7 @@ int main(int argc, char* argv[])
    using out_data_t = int8_t;
#elif 1
    using in_data_t = int8_t;
    constexpr index_t in_vector_size = 16;
    using acc_data_t = int32_t;
    using out_data_t = int8_t;
#endif

@@ -741,7 +741,7 @@ int main(int argc, char* argv[])
                                       LeftPads{},
                                       RightPads{},
                                       nrepeat);
#elif 0
    device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk<in_data_t,
                                                                         in_vector_size,
                                                                         acc_data_t,
......