Commit 4b456610 authored by root
Browse files

merge

parents 1014e6c9 4d93ce0e
...@@ -13,8 +13,9 @@ namespace ck { ...@@ -13,8 +13,9 @@ namespace ck {
// GemmN = N * Ho * Wo // GemmN = N * Ho * Wo
// GemmK = C * Y * X // GemmK = C * Y * X
template <index_t BlockSize, template <index_t BlockSize,
typename Float, typename FloatAB,
typename AccFloat, typename FloatAcc,
typename FloatC,
index_t GemmMPerBlock, index_t GemmMPerBlock,
index_t GemmNPerBlock, index_t GemmNPerBlock,
index_t GemmKPerBlock, index_t GemmKPerBlock,
...@@ -50,9 +51,9 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_pad ...@@ -50,9 +51,9 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_pad
const ConvDilations& conv_dilations, const ConvDilations& conv_dilations,
const InLeftPads& in_left_pads, const InLeftPads& in_left_pads,
const InRightPads& in_right_pads, const InRightPads& in_right_pads,
const Float* __restrict__ p_wei_global, const FloatAB* __restrict__ p_wei_global,
const Float* __restrict__ p_in_global, const FloatAB* __restrict__ p_in_global,
Float* __restrict__ p_out_global) const FloatC* __restrict__ p_out_global) const
{ {
constexpr auto I0 = Number<0>{}; constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{}; constexpr auto I1 = Number<1>{};
...@@ -180,10 +181,11 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_pad ...@@ -180,10 +181,11 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_pad
Sequence<0, 0, 2, 0, 0>{})); Sequence<0, 0, 2, 0, 0>{}));
// GEMM // GEMM
using gridwise_gemm = GridwiseDynamicGemm_km_kn_mn_v1< using gridwise_gemm = GridwiseDynamicGemm_km_kn_m0m1n0n1_v1<
BlockSize, BlockSize,
Float, FloatAB,
AccFloat, FloatAcc,
FloatC,
InMemoryDataOperation::Set, InMemoryDataOperation::Set,
decltype(wei_gemmk_gemmm_global_desc), decltype(wei_gemmk_gemmm_global_desc),
decltype(in_gemmk_gemmn_global_desc), decltype(in_gemmk_gemmn_global_desc),
...@@ -230,7 +232,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_pad ...@@ -230,7 +232,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_pad
const bool has_double_tail_k_block_loop = (GemmK / GemmKPerBlock) % 2 == 0; const bool has_double_tail_k_block_loop = (GemmK / GemmKPerBlock) % 2 == 0;
#if 1 // pass tensor descriptors by their reference #if 1 // pass tensor descriptors by value
index_t nrepeat = 100; index_t nrepeat = 100;
for(index_t i = 0; i < 5; ++i) for(index_t i = 0; i < 5; ++i)
...@@ -247,12 +249,12 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_pad ...@@ -247,12 +249,12 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_pad
const auto kernel = const auto kernel =
run_gridwise_operation<gridwise_gemm, run_gridwise_operation<gridwise_gemm,
decltype(wei_gemmk_gemmm_global_desc), decltype(wei_gemmk_gemmm_global_desc),
const Float*, const FloatAB*,
decltype(in_gemmk_gemmn_global_desc), decltype(in_gemmk_gemmn_global_desc),
const Float*, const FloatAB*,
decltype( decltype(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc), out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc),
Float*, FloatC*,
integral_constant<bool, true>, integral_constant<bool, true>,
integral_constant<bool, true>>; integral_constant<bool, true>>;
...@@ -275,12 +277,12 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_pad ...@@ -275,12 +277,12 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_pad
const auto kernel = const auto kernel =
run_gridwise_operation<gridwise_gemm, run_gridwise_operation<gridwise_gemm,
decltype(wei_gemmk_gemmm_global_desc), decltype(wei_gemmk_gemmm_global_desc),
const Float*, const FloatAB*,
decltype(in_gemmk_gemmn_global_desc), decltype(in_gemmk_gemmn_global_desc),
const Float*, const FloatAB*,
decltype( decltype(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc), out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc),
Float*, FloatC*,
integral_constant<bool, true>, integral_constant<bool, true>,
integral_constant<bool, false>>; integral_constant<bool, false>>;
...@@ -303,12 +305,12 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_pad ...@@ -303,12 +305,12 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_pad
const auto kernel = const auto kernel =
run_gridwise_operation<gridwise_gemm, run_gridwise_operation<gridwise_gemm,
decltype(wei_gemmk_gemmm_global_desc), decltype(wei_gemmk_gemmm_global_desc),
const Float*, const FloatAB*,
decltype(in_gemmk_gemmn_global_desc), decltype(in_gemmk_gemmn_global_desc),
const Float*, const FloatAB*,
decltype( decltype(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc), out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc),
Float*, FloatC*,
integral_constant<bool, false>, integral_constant<bool, false>,
integral_constant<bool, true>>; integral_constant<bool, true>>;
...@@ -331,12 +333,12 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_pad ...@@ -331,12 +333,12 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_pad
const auto kernel = const auto kernel =
run_gridwise_operation<gridwise_gemm, run_gridwise_operation<gridwise_gemm,
decltype(wei_gemmk_gemmm_global_desc), decltype(wei_gemmk_gemmm_global_desc),
const Float*, const FloatAB*,
decltype(in_gemmk_gemmn_global_desc), decltype(in_gemmk_gemmn_global_desc),
const Float*, const FloatAB*,
decltype( decltype(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc), out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc),
Float*, FloatC*,
integral_constant<bool, false>, integral_constant<bool, false>,
integral_constant<bool, false>>; integral_constant<bool, false>>;
...@@ -368,7 +370,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_pad ...@@ -368,7 +370,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_pad
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
<< std::endl; << std::endl;
} }
#elif 1 // pass tensor descriptors by their pointers #elif 1 // pass tensor descriptors by pointers
using ADesc = decltype(wei_gemmk_gemmm_global_desc); using ADesc = decltype(wei_gemmk_gemmm_global_desc);
using BDesc = decltype(in_gemmk_gemmn_global_desc); using BDesc = decltype(in_gemmk_gemmn_global_desc);
using CDesc = decltype(out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc); using CDesc = decltype(out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc);
...@@ -397,13 +399,13 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_pad ...@@ -397,13 +399,13 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_pad
{ {
const auto kernel = const auto kernel =
run_gridwise_operation<gridwise_gemm, run_gridwise_operation<gridwise_gemm,
decltype(wei_gemmk_gemmm_global_desc)*, decltype(wei_gemmk_gemmm_global_desc),
const Float*, const FloatAB*,
decltype(in_gemmk_gemmn_global_desc)*, decltype(in_gemmk_gemmn_global_desc),
const Float*, const FloatAB*,
decltype( decltype(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc)*, out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc),
Float*, FloatC*,
integral_constant<bool, true>, integral_constant<bool, true>,
integral_constant<bool, true>>; integral_constant<bool, true>>;
...@@ -429,13 +431,13 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_pad ...@@ -429,13 +431,13 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_pad
{ {
const auto kernel = const auto kernel =
run_gridwise_operation<gridwise_gemm, run_gridwise_operation<gridwise_gemm,
decltype(wei_gemmk_gemmm_global_desc)*, decltype(wei_gemmk_gemmm_global_desc),
const Float*, const FloatAB*,
decltype(in_gemmk_gemmn_global_desc)*, decltype(in_gemmk_gemmn_global_desc),
const Float*, const FloatAB*,
decltype( decltype(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc)*, out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc),
Float*, FloatC*,
integral_constant<bool, true>, integral_constant<bool, true>,
integral_constant<bool, false>>; integral_constant<bool, false>>;
...@@ -461,13 +463,13 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_pad ...@@ -461,13 +463,13 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_pad
{ {
const auto kernel = const auto kernel =
run_gridwise_operation<gridwise_gemm, run_gridwise_operation<gridwise_gemm,
decltype(wei_gemmk_gemmm_global_desc)*, decltype(wei_gemmk_gemmm_global_desc),
const Float*, const FloatAB*,
decltype(in_gemmk_gemmn_global_desc)*, decltype(in_gemmk_gemmn_global_desc),
const Float*, const FloatAB*,
decltype( decltype(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc)*, out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc),
Float*, FloatC*,
integral_constant<bool, false>, integral_constant<bool, false>,
integral_constant<bool, true>>; integral_constant<bool, true>>;
...@@ -493,13 +495,13 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_pad ...@@ -493,13 +495,13 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_pad
{ {
const auto kernel = const auto kernel =
run_gridwise_operation<gridwise_gemm, run_gridwise_operation<gridwise_gemm,
decltype(wei_gemmk_gemmm_global_desc)*, decltype(wei_gemmk_gemmm_global_desc),
const Float*, const FloatAB*,
decltype(in_gemmk_gemmn_global_desc)*, decltype(in_gemmk_gemmn_global_desc),
const Float*, const FloatAB*,
decltype( decltype(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc)*, out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc),
Float*, FloatC*,
integral_constant<bool, false>, integral_constant<bool, false>,
integral_constant<bool, false>>; integral_constant<bool, false>>;
...@@ -564,11 +566,11 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_pad ...@@ -564,11 +566,11 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_pad
{ {
const auto kernel = run_gridwise_operation<gridwise_gemm, const auto kernel = run_gridwise_operation<gridwise_gemm,
const void*, const void*,
const Float*, const FloatAB*,
const void*, const void*,
const Float*, const FloatAB*,
const void*, const void*,
Float*, FloatC*,
integral_constant<bool, true>, integral_constant<bool, true>,
integral_constant<bool, true>>; integral_constant<bool, true>>;
...@@ -591,11 +593,11 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_pad ...@@ -591,11 +593,11 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_pad
{ {
const auto kernel = run_gridwise_operation<gridwise_gemm, const auto kernel = run_gridwise_operation<gridwise_gemm,
const void*, const void*,
const Float*, const FloatAB*,
const void*, const void*,
const Float*, const FloatAB*,
const void*, const void*,
Float*, FloatC*,
integral_constant<bool, true>, integral_constant<bool, true>,
integral_constant<bool, false>>; integral_constant<bool, false>>;
...@@ -618,11 +620,11 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_pad ...@@ -618,11 +620,11 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_pad
{ {
const auto kernel = run_gridwise_operation<gridwise_gemm, const auto kernel = run_gridwise_operation<gridwise_gemm,
const void*, const void*,
const Float*, const FloatAB*,
const void*, const void*,
const Float*, const FloatAB*,
const void*, const void*,
Float*, FloatC*,
integral_constant<bool, false>, integral_constant<bool, false>,
integral_constant<bool, true>>; integral_constant<bool, true>>;
...@@ -645,11 +647,11 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_pad ...@@ -645,11 +647,11 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_pad
{ {
const auto kernel = run_gridwise_operation<gridwise_gemm, const auto kernel = run_gridwise_operation<gridwise_gemm,
const void*, const void*,
const Float*, const FloatAB*,
const void*, const void*,
const Float*, const FloatAB*,
const void*, const void*,
Float*, FloatC*,
integral_constant<bool, false>, integral_constant<bool, false>,
integral_constant<bool, false>>; integral_constant<bool, false>>;
...@@ -690,8 +692,9 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_pad ...@@ -690,8 +692,9 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_pad
// GemmN = N * Ho * Wo // GemmN = N * Ho * Wo
// GemmK = C * Y * X // GemmK = C * Y * X
template <index_t BlockSize, template <index_t BlockSize,
typename Float, typename FloatAB,
typename AccFloat, typename FloatAcc,
typename FloatC,
index_t GemmMPerBlock, index_t GemmMPerBlock,
index_t GemmNPerBlock, index_t GemmNPerBlock,
index_t GemmKPerBlock, index_t GemmKPerBlock,
...@@ -727,9 +730,9 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_no_pad ...@@ -727,9 +730,9 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_no_pad
const ConvDilations& conv_dilations, const ConvDilations& conv_dilations,
const InLeftPads& in_left_pads, const InLeftPads& in_left_pads,
const InRightPads& in_right_pads, const InRightPads& in_right_pads,
const Float* __restrict__ p_wei_global, const FloatAB* __restrict__ p_wei_global,
const Float* __restrict__ p_in_global, const FloatAB* __restrict__ p_in_global,
Float* __restrict__ p_out_global) const FloatC* __restrict__ p_out_global) const
{ {
constexpr auto I0 = Number<0>{}; constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{}; constexpr auto I1 = Number<1>{};
...@@ -851,10 +854,11 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_no_pad ...@@ -851,10 +854,11 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_no_pad
Sequence<0, 0, 2, 0, 0>{})); Sequence<0, 0, 2, 0, 0>{}));
// GEMM // GEMM
using gridwise_gemm = GridwiseDynamicGemm_km_kn_mn_v1< using gridwise_gemm = GridwiseDynamicGemm_km_kn_m0m1n0n1_v1<
BlockSize, BlockSize,
Float, FloatAB,
AccFloat, FloatAcc,
FloatC,
InMemoryDataOperation::Set, InMemoryDataOperation::Set,
decltype(wei_gemmk_gemmm_global_desc), decltype(wei_gemmk_gemmm_global_desc),
decltype(in_gemmk_gemmn_global_desc), decltype(in_gemmk_gemmn_global_desc),
...@@ -901,7 +905,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_no_pad ...@@ -901,7 +905,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_no_pad
const bool has_double_tail_k_block_loop = (GemmK / GemmKPerBlock) % 2 == 0; const bool has_double_tail_k_block_loop = (GemmK / GemmKPerBlock) % 2 == 0;
#if 1 // pass tensor descriptors by their reference #if 1 // pass tensor descriptors by value
index_t nrepeat = 100; index_t nrepeat = 100;
for(index_t i = 0; i < 5; ++i) for(index_t i = 0; i < 5; ++i)
...@@ -918,12 +922,12 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_no_pad ...@@ -918,12 +922,12 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_no_pad
const auto kernel = const auto kernel =
run_gridwise_operation<gridwise_gemm, run_gridwise_operation<gridwise_gemm,
decltype(wei_gemmk_gemmm_global_desc), decltype(wei_gemmk_gemmm_global_desc),
const Float*, const FloatAB*,
decltype(in_gemmk_gemmn_global_desc), decltype(in_gemmk_gemmn_global_desc),
const Float*, const FloatAB*,
decltype( decltype(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc), out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc),
Float*, FloatC*,
integral_constant<bool, true>, integral_constant<bool, true>,
integral_constant<bool, true>>; integral_constant<bool, true>>;
...@@ -946,12 +950,12 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_no_pad ...@@ -946,12 +950,12 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_no_pad
const auto kernel = const auto kernel =
run_gridwise_operation<gridwise_gemm, run_gridwise_operation<gridwise_gemm,
decltype(wei_gemmk_gemmm_global_desc), decltype(wei_gemmk_gemmm_global_desc),
const Float*, const FloatAB*,
decltype(in_gemmk_gemmn_global_desc), decltype(in_gemmk_gemmn_global_desc),
const Float*, const FloatAB*,
decltype( decltype(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc), out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc),
Float*, FloatC*,
integral_constant<bool, true>, integral_constant<bool, true>,
integral_constant<bool, false>>; integral_constant<bool, false>>;
...@@ -974,12 +978,12 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_no_pad ...@@ -974,12 +978,12 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_no_pad
const auto kernel = const auto kernel =
run_gridwise_operation<gridwise_gemm, run_gridwise_operation<gridwise_gemm,
decltype(wei_gemmk_gemmm_global_desc), decltype(wei_gemmk_gemmm_global_desc),
const Float*, const FloatAB*,
decltype(in_gemmk_gemmn_global_desc), decltype(in_gemmk_gemmn_global_desc),
const Float*, const FloatAB*,
decltype( decltype(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc), out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc),
Float*, FloatC*,
integral_constant<bool, false>, integral_constant<bool, false>,
integral_constant<bool, true>>; integral_constant<bool, true>>;
...@@ -1002,12 +1006,12 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_no_pad ...@@ -1002,12 +1006,12 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_no_pad
const auto kernel = const auto kernel =
run_gridwise_operation<gridwise_gemm, run_gridwise_operation<gridwise_gemm,
decltype(wei_gemmk_gemmm_global_desc), decltype(wei_gemmk_gemmm_global_desc),
const Float*, const FloatAB*,
decltype(in_gemmk_gemmn_global_desc), decltype(in_gemmk_gemmn_global_desc),
const Float*, const FloatAB*,
decltype( decltype(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc), out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc),
Float*, FloatC*,
integral_constant<bool, false>, integral_constant<bool, false>,
integral_constant<bool, false>>; integral_constant<bool, false>>;
...@@ -1039,7 +1043,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_no_pad ...@@ -1039,7 +1043,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_no_pad
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
<< std::endl; << std::endl;
} }
#elif 1 // pass tensor descriptors by their pointers #elif 1 // pass tensor descriptors by pointers
using ADesc = decltype(wei_gemmk_gemmm_global_desc); using ADesc = decltype(wei_gemmk_gemmm_global_desc);
using BDesc = decltype(in_gemmk_gemmn_global_desc); using BDesc = decltype(in_gemmk_gemmn_global_desc);
using CDesc = decltype(out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc); using CDesc = decltype(out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc);
...@@ -1069,12 +1073,12 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_no_pad ...@@ -1069,12 +1073,12 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_no_pad
const auto kernel = const auto kernel =
run_gridwise_operation<gridwise_gemm, run_gridwise_operation<gridwise_gemm,
decltype(wei_gemmk_gemmm_global_desc)*, decltype(wei_gemmk_gemmm_global_desc)*,
const Float*, const FloatAB*,
decltype(in_gemmk_gemmn_global_desc)*, decltype(in_gemmk_gemmn_global_desc)*,
const Float*, const FloatAB*,
decltype( decltype(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc)*, out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc)*,
Float*, FloatC*,
integral_constant<bool, true>, integral_constant<bool, true>,
integral_constant<bool, true>>; integral_constant<bool, true>>;
...@@ -1101,12 +1105,12 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_no_pad ...@@ -1101,12 +1105,12 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_no_pad
const auto kernel = const auto kernel =
run_gridwise_operation<gridwise_gemm, run_gridwise_operation<gridwise_gemm,
decltype(wei_gemmk_gemmm_global_desc)*, decltype(wei_gemmk_gemmm_global_desc)*,
const Float*, const FloatAB*,
decltype(in_gemmk_gemmn_global_desc)*, decltype(in_gemmk_gemmn_global_desc)*,
const Float*, const FloatAB*,
decltype( decltype(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc)*, out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc)*,
Float*, FloatC*,
integral_constant<bool, true>, integral_constant<bool, true>,
integral_constant<bool, false>>; integral_constant<bool, false>>;
...@@ -1133,12 +1137,12 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_no_pad ...@@ -1133,12 +1137,12 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_no_pad
const auto kernel = const auto kernel =
run_gridwise_operation<gridwise_gemm, run_gridwise_operation<gridwise_gemm,
decltype(wei_gemmk_gemmm_global_desc)*, decltype(wei_gemmk_gemmm_global_desc)*,
const Float*, const FloatAB*,
decltype(in_gemmk_gemmn_global_desc)*, decltype(in_gemmk_gemmn_global_desc)*,
const Float*, const FloatAB*,
decltype( decltype(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc)*, out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc)*,
Float*, FloatC*,
integral_constant<bool, false>, integral_constant<bool, false>,
integral_constant<bool, true>>; integral_constant<bool, true>>;
...@@ -1165,12 +1169,12 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_no_pad ...@@ -1165,12 +1169,12 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_no_pad
const auto kernel = const auto kernel =
run_gridwise_operation<gridwise_gemm, run_gridwise_operation<gridwise_gemm,
decltype(wei_gemmk_gemmm_global_desc)*, decltype(wei_gemmk_gemmm_global_desc)*,
const Float*, const FloatAB*,
decltype(in_gemmk_gemmn_global_desc)*, decltype(in_gemmk_gemmn_global_desc)*,
const Float*, const FloatAB*,
decltype( decltype(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc)*, out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc)*,
Float*, FloatC*,
integral_constant<bool, false>, integral_constant<bool, false>,
integral_constant<bool, false>>; integral_constant<bool, false>>;
...@@ -1235,11 +1239,11 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_no_pad ...@@ -1235,11 +1239,11 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_no_pad
{ {
const auto kernel = run_gridwise_operation<gridwise_gemm, const auto kernel = run_gridwise_operation<gridwise_gemm,
const void*, const void*,
const Float*, const FloatAB*,
const void*, const void*,
const Float*, const FloatAB*,
const void*, const void*,
Float*, FloatC*,
integral_constant<bool, true>, integral_constant<bool, true>,
integral_constant<bool, true>>; integral_constant<bool, true>>;
...@@ -1262,11 +1266,11 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_no_pad ...@@ -1262,11 +1266,11 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_no_pad
{ {
const auto kernel = run_gridwise_operation<gridwise_gemm, const auto kernel = run_gridwise_operation<gridwise_gemm,
const void*, const void*,
const Float*, const FloatAB*,
const void*, const void*,
const Float*, const FloatAB*,
const void*, const void*,
Float*, FloatC*,
integral_constant<bool, true>, integral_constant<bool, true>,
integral_constant<bool, false>>; integral_constant<bool, false>>;
...@@ -1289,11 +1293,11 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_no_pad ...@@ -1289,11 +1293,11 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_no_pad
{ {
const auto kernel = run_gridwise_operation<gridwise_gemm, const auto kernel = run_gridwise_operation<gridwise_gemm,
const void*, const void*,
const Float*, const FloatAB*,
const void*, const void*,
const Float*, const FloatAB*,
const void*, const void*,
Float*, FloatC*,
integral_constant<bool, false>, integral_constant<bool, false>,
integral_constant<bool, true>>; integral_constant<bool, true>>;
...@@ -1316,11 +1320,11 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_no_pad ...@@ -1316,11 +1320,11 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_no_pad
{ {
const auto kernel = run_gridwise_operation<gridwise_gemm, const auto kernel = run_gridwise_operation<gridwise_gemm,
const void*, const void*,
const Float*, const FloatAB*,
const void*, const void*,
const Float*, const FloatAB*,
const void*, const void*,
Float*, FloatC*,
integral_constant<bool, false>, integral_constant<bool, false>,
integral_constant<bool, false>>; integral_constant<bool, false>>;
...@@ -1358,8 +1362,9 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_no_pad ...@@ -1358,8 +1362,9 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_no_pad
}; };
template <index_t BlockSize, template <index_t BlockSize,
typename Float, typename FloatAB,
typename AccFloat, typename FloatAcc,
typename FloatC,
index_t GemmMPerBlock, index_t GemmMPerBlock,
index_t GemmNPerBlock, index_t GemmNPerBlock,
index_t GemmKPerBlock, index_t GemmKPerBlock,
...@@ -1395,9 +1400,9 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_1x1 ...@@ -1395,9 +1400,9 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_1x1
const ConvDilations& conv_dilations, const ConvDilations& conv_dilations,
const InLeftPads& in_left_pads, const InLeftPads& in_left_pads,
const InRightPads& in_right_pads, const InRightPads& in_right_pads,
const Float* __restrict__ p_wei_global, const FloatAB* __restrict__ p_wei_global,
const Float* __restrict__ p_in_global, const FloatAB* __restrict__ p_in_global,
Float* __restrict__ p_out_global) const FloatC* __restrict__ p_out_global) const
{ {
constexpr auto I0 = Number<0>{}; constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{}; constexpr auto I1 = Number<1>{};
...@@ -1508,10 +1513,11 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_1x1 ...@@ -1508,10 +1513,11 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_1x1
Sequence<0, 0, 2, 0, 0>{})); Sequence<0, 0, 2, 0, 0>{}));
// GEMM // GEMM
using gridwise_gemm = GridwiseDynamicGemm_km_kn_mn_v1< using gridwise_gemm = GridwiseDynamicGemm_km_kn_m0m1n0n1_v1<
BlockSize, BlockSize,
Float, FloatAB,
AccFloat, FloatAcc,
FloatC,
InMemoryDataOperation::Set, InMemoryDataOperation::Set,
decltype(wei_gemmk_gemmm_global_desc), decltype(wei_gemmk_gemmm_global_desc),
decltype(in_gemmk_gemmn_global_desc), decltype(in_gemmk_gemmn_global_desc),
...@@ -1558,7 +1564,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_1x1 ...@@ -1558,7 +1564,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_1x1
const bool has_double_tail_k_block_loop = (GemmK / GemmKPerBlock) % 2 == 0; const bool has_double_tail_k_block_loop = (GemmK / GemmKPerBlock) % 2 == 0;
#if 1 // pass tensor descriptors by their reference #if 1 // pass tensor descriptors by value
index_t nrepeat = 100; index_t nrepeat = 100;
for(index_t i = 0; i < 5; ++i) for(index_t i = 0; i < 5; ++i)
...@@ -1575,12 +1581,12 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_1x1 ...@@ -1575,12 +1581,12 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_1x1
const auto kernel = const auto kernel =
run_gridwise_operation<gridwise_gemm, run_gridwise_operation<gridwise_gemm,
decltype(wei_gemmk_gemmm_global_desc), decltype(wei_gemmk_gemmm_global_desc),
const Float*, const FloatAB*,
decltype(in_gemmk_gemmn_global_desc), decltype(in_gemmk_gemmn_global_desc),
const Float*, const FloatAB*,
decltype( decltype(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc), out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc),
Float*, FloatC*,
integral_constant<bool, true>, integral_constant<bool, true>,
integral_constant<bool, true>>; integral_constant<bool, true>>;
...@@ -1603,12 +1609,12 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_1x1 ...@@ -1603,12 +1609,12 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_1x1
const auto kernel = const auto kernel =
run_gridwise_operation<gridwise_gemm, run_gridwise_operation<gridwise_gemm,
decltype(wei_gemmk_gemmm_global_desc), decltype(wei_gemmk_gemmm_global_desc),
const Float*, const FloatAB*,
decltype(in_gemmk_gemmn_global_desc), decltype(in_gemmk_gemmn_global_desc),
const Float*, const FloatAB*,
decltype( decltype(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc), out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc),
Float*, FloatC*,
integral_constant<bool, true>, integral_constant<bool, true>,
integral_constant<bool, false>>; integral_constant<bool, false>>;
...@@ -1631,12 +1637,12 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_1x1 ...@@ -1631,12 +1637,12 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_1x1
const auto kernel = const auto kernel =
run_gridwise_operation<gridwise_gemm, run_gridwise_operation<gridwise_gemm,
decltype(wei_gemmk_gemmm_global_desc), decltype(wei_gemmk_gemmm_global_desc),
const Float*, const FloatAB*,
decltype(in_gemmk_gemmn_global_desc), decltype(in_gemmk_gemmn_global_desc),
const Float*, const FloatAB*,
decltype( decltype(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc), out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc),
Float*, FloatC*,
integral_constant<bool, false>, integral_constant<bool, false>,
integral_constant<bool, true>>; integral_constant<bool, true>>;
...@@ -1659,12 +1665,12 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_1x1 ...@@ -1659,12 +1665,12 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_1x1
const auto kernel = const auto kernel =
run_gridwise_operation<gridwise_gemm, run_gridwise_operation<gridwise_gemm,
decltype(wei_gemmk_gemmm_global_desc), decltype(wei_gemmk_gemmm_global_desc),
const Float*, const FloatAB*,
decltype(in_gemmk_gemmn_global_desc), decltype(in_gemmk_gemmn_global_desc),
const Float*, const FloatAB*,
decltype( decltype(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc), out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc),
Float*, FloatC*,
integral_constant<bool, false>, integral_constant<bool, false>,
integral_constant<bool, false>>; integral_constant<bool, false>>;
...@@ -1696,7 +1702,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_1x1 ...@@ -1696,7 +1702,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_1x1
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
<< std::endl; << std::endl;
} }
#elif 1 // pass tensor descriptors by their pointers #elif 1 // pass tensor descriptors by pointers
using ADesc = decltype(wei_gemmk_gemmm_global_desc); using ADesc = decltype(wei_gemmk_gemmm_global_desc);
using BDesc = decltype(in_gemmk_gemmn_global_desc); using BDesc = decltype(in_gemmk_gemmn_global_desc);
using CDesc = decltype(out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc); using CDesc = decltype(out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc);
...@@ -1726,12 +1732,12 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_1x1 ...@@ -1726,12 +1732,12 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_1x1
const auto kernel = const auto kernel =
run_gridwise_operation<gridwise_gemm, run_gridwise_operation<gridwise_gemm,
decltype(wei_gemmk_gemmm_global_desc)*, decltype(wei_gemmk_gemmm_global_desc)*,
const Float*, const FloatAB*,
decltype(in_gemmk_gemmn_global_desc)*, decltype(in_gemmk_gemmn_global_desc)*,
const Float*, const FloatAB*,
decltype( decltype(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc)*, out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc)*,
Float*, FloatC*,
integral_constant<bool, true>, integral_constant<bool, true>,
integral_constant<bool, true>>; integral_constant<bool, true>>;
...@@ -1758,12 +1764,12 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_1x1 ...@@ -1758,12 +1764,12 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_1x1
const auto kernel = const auto kernel =
run_gridwise_operation<gridwise_gemm, run_gridwise_operation<gridwise_gemm,
decltype(wei_gemmk_gemmm_global_desc)*, decltype(wei_gemmk_gemmm_global_desc)*,
const Float*, const FloatAB*,
decltype(in_gemmk_gemmn_global_desc)*, decltype(in_gemmk_gemmn_global_desc)*,
const Float*, const FloatAB*,
decltype( decltype(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc)*, out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc)*,
Float*, FloatC*,
integral_constant<bool, true>, integral_constant<bool, true>,
integral_constant<bool, false>>; integral_constant<bool, false>>;
...@@ -1790,12 +1796,12 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_1x1 ...@@ -1790,12 +1796,12 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_1x1
const auto kernel = const auto kernel =
run_gridwise_operation<gridwise_gemm, run_gridwise_operation<gridwise_gemm,
decltype(wei_gemmk_gemmm_global_desc)*, decltype(wei_gemmk_gemmm_global_desc)*,
const Float*, const FloatAB*,
decltype(in_gemmk_gemmn_global_desc)*, decltype(in_gemmk_gemmn_global_desc)*,
const Float*, const FloatAB*,
decltype( decltype(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc)*, out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc)*,
Float*, FloatC*,
integral_constant<bool, false>, integral_constant<bool, false>,
integral_constant<bool, true>>; integral_constant<bool, true>>;
...@@ -1822,12 +1828,12 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_1x1 ...@@ -1822,12 +1828,12 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_1x1
const auto kernel = const auto kernel =
run_gridwise_operation<gridwise_gemm, run_gridwise_operation<gridwise_gemm,
decltype(wei_gemmk_gemmm_global_desc)*, decltype(wei_gemmk_gemmm_global_desc)*,
const Float*, const FloatAB*,
decltype(in_gemmk_gemmn_global_desc)*, decltype(in_gemmk_gemmn_global_desc)*,
const Float*, const FloatAB*,
decltype( decltype(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc)*, out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc)*,
Float*, FloatC*,
integral_constant<bool, false>, integral_constant<bool, false>,
integral_constant<bool, false>>; integral_constant<bool, false>>;
...@@ -1892,11 +1898,11 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_1x1 ...@@ -1892,11 +1898,11 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_1x1
{ {
const auto kernel = run_gridwise_operation<gridwise_gemm, const auto kernel = run_gridwise_operation<gridwise_gemm,
const void*, const void*,
const Float*, const FloatAB*,
const void*, const void*,
const Float*, const FloatAB*,
const void*, const void*,
Float*, FloatC*,
integral_constant<bool, true>, integral_constant<bool, true>,
integral_constant<bool, true>>; integral_constant<bool, true>>;
...@@ -1919,11 +1925,11 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_1x1 ...@@ -1919,11 +1925,11 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_1x1
{ {
const auto kernel = run_gridwise_operation<gridwise_gemm, const auto kernel = run_gridwise_operation<gridwise_gemm,
const void*, const void*,
const Float*, const FloatAB*,
const void*, const void*,
const Float*, const FloatAB*,
const void*, const void*,
Float*, FloatC*,
integral_constant<bool, true>, integral_constant<bool, true>,
integral_constant<bool, false>>; integral_constant<bool, false>>;
...@@ -1946,11 +1952,11 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_1x1 ...@@ -1946,11 +1952,11 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_1x1
{ {
const auto kernel = run_gridwise_operation<gridwise_gemm, const auto kernel = run_gridwise_operation<gridwise_gemm,
const void*, const void*,
const Float*, const FloatAB*,
const void*, const void*,
const Float*, const FloatAB*,
const void*, const void*,
Float*, FloatC*,
integral_constant<bool, false>, integral_constant<bool, false>,
integral_constant<bool, true>>; integral_constant<bool, true>>;
...@@ -1973,11 +1979,11 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_1x1 ...@@ -1973,11 +1979,11 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_1x1
{ {
const auto kernel = run_gridwise_operation<gridwise_gemm, const auto kernel = run_gridwise_operation<gridwise_gemm,
const void*, const void*,
const Float*, const FloatAB*,
const void*, const void*,
const Float*, const FloatAB*,
const void*, const void*,
Float*, FloatC*,
integral_constant<bool, false>, integral_constant<bool, false>,
integral_constant<bool, false>>; integral_constant<bool, false>>;
......
...@@ -13,8 +13,9 @@ namespace ck { ...@@ -13,8 +13,9 @@ namespace ck {
// GemmN = N * Ho * Wo // GemmN = N * Ho * Wo
// GemmK = Y * X * C // GemmK = Y * X * C
template <index_t BlockSize, template <index_t BlockSize,
typename Float, typename FloatAB,
typename AccFloat, typename FloatAcc,
typename FloatC,
index_t GemmMPerBlock, index_t GemmMPerBlock,
index_t GemmNPerBlock, index_t GemmNPerBlock,
index_t GemmKPerBlock, index_t GemmKPerBlock,
...@@ -50,9 +51,9 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_pad ...@@ -50,9 +51,9 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_pad
const ConvDilations& conv_dilations, const ConvDilations& conv_dilations,
const InLeftPads& in_left_pads, const InLeftPads& in_left_pads,
const InRightPads& in_right_pads, const InRightPads& in_right_pads,
const Float* __restrict__ p_wei_global, const FloatAB* __restrict__ p_wei_global,
const Float* __restrict__ p_in_global, const FloatAB* __restrict__ p_in_global,
Float* __restrict__ p_out_global) const FloatC* __restrict__ p_out_global) const
{ {
constexpr auto I0 = Number<0>{}; constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{}; constexpr auto I1 = Number<1>{};
...@@ -179,10 +180,11 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_pad ...@@ -179,10 +180,11 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_pad
Sequence<0, 0, 2, 0, 0>{})); Sequence<0, 0, 2, 0, 0>{}));
// GEMM // GEMM
using gridwise_gemm = GridwiseDynamicGemm_km_kn_mn_v1< using gridwise_gemm = GridwiseDynamicGemm_km_kn_m0m1n0n1_v1<
BlockSize, BlockSize,
Float, FloatAB,
AccFloat, FloatAcc,
FloatC,
InMemoryDataOperation::Set, InMemoryDataOperation::Set,
decltype(wei_gemmk_gemmm_global_desc), decltype(wei_gemmk_gemmm_global_desc),
decltype(in_gemmk_gemmn_global_desc), decltype(in_gemmk_gemmn_global_desc),
...@@ -231,7 +233,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_pad ...@@ -231,7 +233,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_pad
printf("%s: BlockSize %d, GridSize %d \n", __func__, BlockSize, GridSize); printf("%s: BlockSize %d, GridSize %d \n", __func__, BlockSize, GridSize);
#if 1 // pass tensor descriptors by their reference #if 1 // pass tensor descriptors by value
index_t nrepeat = 100; index_t nrepeat = 100;
for(index_t i = 0; i < 5; ++i) for(index_t i = 0; i < 5; ++i)
...@@ -248,12 +250,12 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_pad ...@@ -248,12 +250,12 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_pad
const auto kernel = const auto kernel =
run_gridwise_operation<gridwise_gemm, run_gridwise_operation<gridwise_gemm,
decltype(wei_gemmk_gemmm_global_desc), decltype(wei_gemmk_gemmm_global_desc),
const Float*, const FloatAB*,
decltype(in_gemmk_gemmn_global_desc), decltype(in_gemmk_gemmn_global_desc),
const Float*, const FloatAB*,
decltype( decltype(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc), out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc),
Float*, FloatC*,
integral_constant<bool, true>, integral_constant<bool, true>,
integral_constant<bool, true>>; integral_constant<bool, true>>;
...@@ -276,12 +278,12 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_pad ...@@ -276,12 +278,12 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_pad
const auto kernel = const auto kernel =
run_gridwise_operation<gridwise_gemm, run_gridwise_operation<gridwise_gemm,
decltype(wei_gemmk_gemmm_global_desc), decltype(wei_gemmk_gemmm_global_desc),
const Float*, const FloatAB*,
decltype(in_gemmk_gemmn_global_desc), decltype(in_gemmk_gemmn_global_desc),
const Float*, const FloatAB*,
decltype( decltype(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc), out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc),
Float*, FloatC*,
integral_constant<bool, true>, integral_constant<bool, true>,
integral_constant<bool, false>>; integral_constant<bool, false>>;
...@@ -304,12 +306,12 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_pad ...@@ -304,12 +306,12 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_pad
const auto kernel = const auto kernel =
run_gridwise_operation<gridwise_gemm, run_gridwise_operation<gridwise_gemm,
decltype(wei_gemmk_gemmm_global_desc), decltype(wei_gemmk_gemmm_global_desc),
const Float*, const FloatAB*,
decltype(in_gemmk_gemmn_global_desc), decltype(in_gemmk_gemmn_global_desc),
const Float*, const FloatAB*,
decltype( decltype(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc), out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc),
Float*, FloatC*,
integral_constant<bool, false>, integral_constant<bool, false>,
integral_constant<bool, true>>; integral_constant<bool, true>>;
...@@ -332,12 +334,12 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_pad ...@@ -332,12 +334,12 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_pad
const auto kernel = const auto kernel =
run_gridwise_operation<gridwise_gemm, run_gridwise_operation<gridwise_gemm,
decltype(wei_gemmk_gemmm_global_desc), decltype(wei_gemmk_gemmm_global_desc),
const Float*, const FloatAB*,
decltype(in_gemmk_gemmn_global_desc), decltype(in_gemmk_gemmn_global_desc),
const Float*, const FloatAB*,
decltype( decltype(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc), out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc),
Float*, FloatC*,
integral_constant<bool, false>, integral_constant<bool, false>,
integral_constant<bool, false>>; integral_constant<bool, false>>;
...@@ -367,7 +369,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_pad ...@@ -367,7 +369,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_pad
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
<< std::endl; << std::endl;
} }
#elif 1 // pass tensor descriptors by their pointers #elif 1 // pass tensor descriptors by pointers
using ADesc = decltype(wei_gemmk_gemmm_global_desc); using ADesc = decltype(wei_gemmk_gemmm_global_desc);
using BDesc = decltype(in_gemmk_gemmn_global_desc); using BDesc = decltype(in_gemmk_gemmn_global_desc);
using CDesc = decltype(out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc); using CDesc = decltype(out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc);
...@@ -397,12 +399,12 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_pad ...@@ -397,12 +399,12 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_pad
const auto kernel = const auto kernel =
run_gridwise_operation<gridwise_gemm, run_gridwise_operation<gridwise_gemm,
decltype(wei_gemmk_gemmm_global_desc)*, decltype(wei_gemmk_gemmm_global_desc)*,
const Float*, const FloatAB*,
decltype(in_gemmk_gemmn_global_desc)*, decltype(in_gemmk_gemmn_global_desc)*,
const Float*, const FloatAB*,
decltype( decltype(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc)*, out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc)*,
Float*, FloatC*,
integral_constant<bool, true>, integral_constant<bool, true>,
integral_constant<bool, true>>; integral_constant<bool, true>>;
...@@ -429,12 +431,12 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_pad ...@@ -429,12 +431,12 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_pad
const auto kernel = const auto kernel =
run_gridwise_operation<gridwise_gemm, run_gridwise_operation<gridwise_gemm,
decltype(wei_gemmk_gemmm_global_desc)*, decltype(wei_gemmk_gemmm_global_desc)*,
const Float*, const FloatAB*,
decltype(in_gemmk_gemmn_global_desc)*, decltype(in_gemmk_gemmn_global_desc)*,
const Float*, const FloatAB*,
decltype( decltype(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc)*, out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc)*,
Float*, FloatC*,
integral_constant<bool, true>, integral_constant<bool, true>,
integral_constant<bool, false>>; integral_constant<bool, false>>;
...@@ -461,12 +463,12 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_pad ...@@ -461,12 +463,12 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_pad
const auto kernel = const auto kernel =
run_gridwise_operation<gridwise_gemm, run_gridwise_operation<gridwise_gemm,
decltype(wei_gemmk_gemmm_global_desc)*, decltype(wei_gemmk_gemmm_global_desc)*,
const Float*, const FloatAB*,
decltype(in_gemmk_gemmn_global_desc)*, decltype(in_gemmk_gemmn_global_desc)*,
const Float*, const FloatAB*,
decltype( decltype(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc)*, out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc)*,
Float*, FloatC*,
integral_constant<bool, false>, integral_constant<bool, false>,
integral_constant<bool, true>>; integral_constant<bool, true>>;
...@@ -493,12 +495,12 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_pad ...@@ -493,12 +495,12 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_pad
const auto kernel = const auto kernel =
run_gridwise_operation<gridwise_gemm, run_gridwise_operation<gridwise_gemm,
decltype(wei_gemmk_gemmm_global_desc)*, decltype(wei_gemmk_gemmm_global_desc)*,
const Float*, const FloatAB*,
decltype(in_gemmk_gemmn_global_desc)*, decltype(in_gemmk_gemmn_global_desc)*,
const Float*, const FloatAB*,
decltype( decltype(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc)*, out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc)*,
Float*, FloatC*,
integral_constant<bool, false>, integral_constant<bool, false>,
integral_constant<bool, false>>; integral_constant<bool, false>>;
...@@ -561,11 +563,11 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_pad ...@@ -561,11 +563,11 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_pad
{ {
const auto kernel = run_gridwise_operation<gridwise_gemm, const auto kernel = run_gridwise_operation<gridwise_gemm,
const void*, const void*,
const Float*, const FloatAB*,
const void*, const void*,
const Float*, const FloatAB*,
const void*, const void*,
Float*, FloatC*,
integral_constant<bool, true>, integral_constant<bool, true>,
integral_constant<bool, true>>; integral_constant<bool, true>>;
...@@ -588,11 +590,11 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_pad ...@@ -588,11 +590,11 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_pad
{ {
const auto kernel = run_gridwise_operation<gridwise_gemm, const auto kernel = run_gridwise_operation<gridwise_gemm,
const void*, const void*,
const Float*, const FloatAB*,
const void*, const void*,
const Float*, const FloatAB*,
const void*, const void*,
Float*, FloatC*,
integral_constant<bool, true>, integral_constant<bool, true>,
integral_constant<bool, false>>; integral_constant<bool, false>>;
...@@ -615,11 +617,11 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_pad ...@@ -615,11 +617,11 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_pad
{ {
const auto kernel = run_gridwise_operation<gridwise_gemm, const auto kernel = run_gridwise_operation<gridwise_gemm,
const void*, const void*,
const Float*, const FloatAB*,
const void*, const void*,
const Float*, const FloatAB*,
const void*, const void*,
Float*, FloatC*,
integral_constant<bool, false>, integral_constant<bool, false>,
integral_constant<bool, true>>; integral_constant<bool, true>>;
...@@ -642,11 +644,11 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_pad ...@@ -642,11 +644,11 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_pad
{ {
const auto kernel = run_gridwise_operation<gridwise_gemm, const auto kernel = run_gridwise_operation<gridwise_gemm,
const void*, const void*,
const Float*, const FloatAB*,
const void*, const void*,
const Float*, const FloatAB*,
const void*, const void*,
Float*, FloatC*,
integral_constant<bool, false>, integral_constant<bool, false>,
integral_constant<bool, false>>; integral_constant<bool, false>>;
...@@ -682,8 +684,9 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_pad ...@@ -682,8 +684,9 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_pad
}; };
template <index_t BlockSize, template <index_t BlockSize,
typename Float, typename FloatAB,
typename AccFloat, typename FloatAcc,
typename FloatC,
index_t GemmMPerBlock, index_t GemmMPerBlock,
index_t GemmNPerBlock, index_t GemmNPerBlock,
index_t GemmKPerBlock, index_t GemmKPerBlock,
...@@ -719,9 +722,9 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_1x1 ...@@ -719,9 +722,9 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_1x1
const ConvDilations& conv_dilations, const ConvDilations& conv_dilations,
const InLeftPads& in_left_pads, const InLeftPads& in_left_pads,
const InRightPads& in_right_pads, const InRightPads& in_right_pads,
const Float* __restrict__ p_wei_global, const FloatAB* __restrict__ p_wei_global,
const Float* __restrict__ p_in_global, const FloatAB* __restrict__ p_in_global,
Float* __restrict__ p_out_global) const FloatC* __restrict__ p_out_global) const
{ {
constexpr auto I0 = Number<0>{}; constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{}; constexpr auto I1 = Number<1>{};
...@@ -831,10 +834,11 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_1x1 ...@@ -831,10 +834,11 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_1x1
Sequence<0, 0, 0, 0, 0>{})); Sequence<0, 0, 0, 0, 0>{}));
// GEMM // GEMM
using gridwise_gemm = GridwiseDynamicGemm_km_kn_mn_v1< using gridwise_gemm = GridwiseDynamicGemm_km_kn_m0m1n0n1_v1<
BlockSize, BlockSize,
Float, FloatAB,
AccFloat, FloatAcc,
FloatC,
InMemoryDataOperation::Set, InMemoryDataOperation::Set,
decltype(wei_gemmk_gemmm_global_desc), decltype(wei_gemmk_gemmm_global_desc),
decltype(in_gemmk_gemmn_global_desc), decltype(in_gemmk_gemmn_global_desc),
...@@ -883,7 +887,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_1x1 ...@@ -883,7 +887,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_1x1
printf("%s: BlockSize %d, GridSize %d \n", __func__, BlockSize, GridSize); printf("%s: BlockSize %d, GridSize %d \n", __func__, BlockSize, GridSize);
#if 1 // pass tensor descriptors by their reference #if 1 // pass tensor descriptors by value
index_t nrepeat = 100; index_t nrepeat = 100;
for(index_t i = 0; i < 5; ++i) for(index_t i = 0; i < 5; ++i)
...@@ -900,12 +904,12 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_1x1 ...@@ -900,12 +904,12 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_1x1
const auto kernel = const auto kernel =
run_gridwise_operation<gridwise_gemm, run_gridwise_operation<gridwise_gemm,
decltype(wei_gemmk_gemmm_global_desc), decltype(wei_gemmk_gemmm_global_desc),
const Float*, const FloatAB*,
decltype(in_gemmk_gemmn_global_desc), decltype(in_gemmk_gemmn_global_desc),
const Float*, const FloatAB*,
decltype( decltype(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc), out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc),
Float*, FloatC*,
integral_constant<bool, true>, integral_constant<bool, true>,
integral_constant<bool, true>>; integral_constant<bool, true>>;
...@@ -928,12 +932,12 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_1x1 ...@@ -928,12 +932,12 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_1x1
const auto kernel = const auto kernel =
run_gridwise_operation<gridwise_gemm, run_gridwise_operation<gridwise_gemm,
decltype(wei_gemmk_gemmm_global_desc), decltype(wei_gemmk_gemmm_global_desc),
const Float*, const FloatAB*,
decltype(in_gemmk_gemmn_global_desc), decltype(in_gemmk_gemmn_global_desc),
const Float*, const FloatAB*,
decltype( decltype(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc), out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc),
Float*, FloatC*,
integral_constant<bool, true>, integral_constant<bool, true>,
integral_constant<bool, false>>; integral_constant<bool, false>>;
...@@ -956,12 +960,12 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_1x1 ...@@ -956,12 +960,12 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_1x1
const auto kernel = const auto kernel =
run_gridwise_operation<gridwise_gemm, run_gridwise_operation<gridwise_gemm,
decltype(wei_gemmk_gemmm_global_desc), decltype(wei_gemmk_gemmm_global_desc),
const Float*, const FloatAB*,
decltype(in_gemmk_gemmn_global_desc), decltype(in_gemmk_gemmn_global_desc),
const Float*, const FloatAB*,
decltype( decltype(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc), out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc),
Float*, FloatC*,
integral_constant<bool, false>, integral_constant<bool, false>,
integral_constant<bool, true>>; integral_constant<bool, true>>;
...@@ -984,12 +988,12 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_1x1 ...@@ -984,12 +988,12 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_1x1
const auto kernel = const auto kernel =
run_gridwise_operation<gridwise_gemm, run_gridwise_operation<gridwise_gemm,
decltype(wei_gemmk_gemmm_global_desc), decltype(wei_gemmk_gemmm_global_desc),
const Float*, const FloatAB*,
decltype(in_gemmk_gemmn_global_desc), decltype(in_gemmk_gemmn_global_desc),
const Float*, const FloatAB*,
decltype( decltype(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc), out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc),
Float*, FloatC*,
integral_constant<bool, false>, integral_constant<bool, false>,
integral_constant<bool, false>>; integral_constant<bool, false>>;
...@@ -1019,7 +1023,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_1x1 ...@@ -1019,7 +1023,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_1x1
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
<< std::endl; << std::endl;
} }
#elif 1 // pass tensor descriptors by their pointers #elif 1 // pass tensor descriptors by pointers
using ADesc = decltype(wei_gemmk_gemmm_global_desc); using ADesc = decltype(wei_gemmk_gemmm_global_desc);
using BDesc = decltype(in_gemmk_gemmn_global_desc); using BDesc = decltype(in_gemmk_gemmn_global_desc);
using CDesc = decltype(out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc); using CDesc = decltype(out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc);
...@@ -1049,12 +1053,12 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_1x1 ...@@ -1049,12 +1053,12 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_1x1
const auto kernel = const auto kernel =
run_gridwise_operation<gridwise_gemm, run_gridwise_operation<gridwise_gemm,
decltype(wei_gemmk_gemmm_global_desc)*, decltype(wei_gemmk_gemmm_global_desc)*,
const Float*, const FloatAB*,
decltype(in_gemmk_gemmn_global_desc)*, decltype(in_gemmk_gemmn_global_desc)*,
const Float*, const FloatAB*,
decltype( decltype(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc)*, out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc)*,
Float*, FloatC*,
integral_constant<bool, true>, integral_constant<bool, true>,
integral_constant<bool, true>>; integral_constant<bool, true>>;
...@@ -1081,12 +1085,12 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_1x1 ...@@ -1081,12 +1085,12 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_1x1
const auto kernel = const auto kernel =
run_gridwise_operation<gridwise_gemm, run_gridwise_operation<gridwise_gemm,
decltype(wei_gemmk_gemmm_global_desc)*, decltype(wei_gemmk_gemmm_global_desc)*,
const Float*, const FloatAB*,
decltype(in_gemmk_gemmn_global_desc)*, decltype(in_gemmk_gemmn_global_desc)*,
const Float*, const FloatAB*,
decltype( decltype(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc)*, out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc)*,
Float*, FloatC*,
integral_constant<bool, true>, integral_constant<bool, true>,
integral_constant<bool, false>>; integral_constant<bool, false>>;
...@@ -1113,12 +1117,12 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_1x1 ...@@ -1113,12 +1117,12 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_1x1
const auto kernel = const auto kernel =
run_gridwise_operation<gridwise_gemm, run_gridwise_operation<gridwise_gemm,
decltype(wei_gemmk_gemmm_global_desc)*, decltype(wei_gemmk_gemmm_global_desc)*,
const Float*, const FloatAB*,
decltype(in_gemmk_gemmn_global_desc)*, decltype(in_gemmk_gemmn_global_desc)*,
const Float*, const FloatAB*,
decltype( decltype(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc)*, out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc)*,
Float*, FloatC*,
integral_constant<bool, false>, integral_constant<bool, false>,
integral_constant<bool, true>>; integral_constant<bool, true>>;
...@@ -1145,12 +1149,12 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_1x1 ...@@ -1145,12 +1149,12 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_1x1
const auto kernel = const auto kernel =
run_gridwise_operation<gridwise_gemm, run_gridwise_operation<gridwise_gemm,
decltype(wei_gemmk_gemmm_global_desc)*, decltype(wei_gemmk_gemmm_global_desc)*,
const Float*, const FloatAB*,
decltype(in_gemmk_gemmn_global_desc)*, decltype(in_gemmk_gemmn_global_desc)*,
const Float*, const FloatAB*,
decltype( decltype(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc)*, out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc)*,
Float*, FloatC*,
integral_constant<bool, false>, integral_constant<bool, false>,
integral_constant<bool, false>>; integral_constant<bool, false>>;
...@@ -1213,11 +1217,11 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_1x1 ...@@ -1213,11 +1217,11 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_1x1
{ {
const auto kernel = run_gridwise_operation<gridwise_gemm, const auto kernel = run_gridwise_operation<gridwise_gemm,
const void*, const void*,
const Float*, const FloatAB*,
const void*, const void*,
const Float*, const FloatAB*,
const void*, const void*,
Float*, FloatC*,
integral_constant<bool, true>, integral_constant<bool, true>,
integral_constant<bool, true>>; integral_constant<bool, true>>;
...@@ -1240,11 +1244,11 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_1x1 ...@@ -1240,11 +1244,11 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_1x1
{ {
const auto kernel = run_gridwise_operation<gridwise_gemm, const auto kernel = run_gridwise_operation<gridwise_gemm,
const void*, const void*,
const Float*, const FloatAB*,
const void*, const void*,
const Float*, const FloatAB*,
const void*, const void*,
Float*, FloatC*,
integral_constant<bool, true>, integral_constant<bool, true>,
integral_constant<bool, false>>; integral_constant<bool, false>>;
...@@ -1267,11 +1271,11 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_1x1 ...@@ -1267,11 +1271,11 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_1x1
{ {
const auto kernel = run_gridwise_operation<gridwise_gemm, const auto kernel = run_gridwise_operation<gridwise_gemm,
const void*, const void*,
const Float*, const FloatAB*,
const void*, const void*,
const Float*, const FloatAB*,
const void*, const void*,
Float*, FloatC*,
integral_constant<bool, false>, integral_constant<bool, false>,
integral_constant<bool, true>>; integral_constant<bool, true>>;
...@@ -1294,11 +1298,11 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_1x1 ...@@ -1294,11 +1298,11 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_1x1
{ {
const auto kernel = run_gridwise_operation<gridwise_gemm, const auto kernel = run_gridwise_operation<gridwise_gemm,
const void*, const void*,
const Float*, const FloatAB*,
const void*, const void*,
const Float*, const FloatAB*,
const void*, const void*,
Float*, FloatC*,
integral_constant<bool, false>, integral_constant<bool, false>,
integral_constant<bool, false>>; integral_constant<bool, false>>;
......
...@@ -219,6 +219,9 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad ...@@ -219,6 +219,9 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
KernelTimer timer; KernelTimer timer;
timer.Start(); timer.Start();
std::cout << "has_main_k_block_loop: " << has_main_k_block_loop
<< " has_double_tail_k_block_loop: " << has_double_tail_k_block_loop
<< std::endl;
for(index_t j = 0; j < nrepeat; ++j) for(index_t j = 0; j < nrepeat; ++j)
{ {
......
...@@ -3,10 +3,10 @@ ...@@ -3,10 +3,10 @@
// Generic GPU kernel entry point (declared __global__).
// It default-constructs a GridwiseOp and forwards every kernel argument, received by value
// as the pack `xs`, to GridwiseOp::Run.
//
// The old column of this diff has the __launch_bounds__ attribute disabled (#if 0) with
// (256, 2); the new column enables it (#if 1) with (64, 2).
// NOTE(review): the bound is a fixed literal, not derived from GridwiseOp. Please confirm
// that every launch of this kernel uses at most 64 threads per block; the drivers in this
// diff pass a BlockSize template parameter whose value is not visible here.
template <typename GridwiseOp, typename... Xs> template <typename GridwiseOp, typename... Xs>
__global__ void __global__ void
#if 0 #if 1
__launch_bounds__(256, 2) __launch_bounds__(64, 2)
#endif #endif
run_gridwise_operation(Xs... xs) run_gridwise_operation(Xs... xs)
{ {
GridwiseOp{}.Run(xs...); GridwiseOp{}.Run(xs...);
} }
......
...@@ -154,6 +154,7 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v3 ...@@ -154,6 +154,7 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v3
decltype(b_thread_mtx), decltype(b_thread_mtx),
decltype(c_thread_mtx)>{}; decltype(c_thread_mtx)>{};
// loop over k // loop over k
#pragma unroll
for(index_t cyx_begin = 0; cyx_begin < CYXPerBlock; cyx_begin += CYXPerThreadLoop) for(index_t cyx_begin = 0; cyx_begin < CYXPerBlock; cyx_begin += CYXPerThreadLoop)
{ {
a_thread_copy.Run(p_a_block + a_block_mtx.CalculateOffset(make_tuple(cyx_begin, 0)) + a_thread_copy.Run(p_a_block + a_block_mtx.CalculateOffset(make_tuple(cyx_begin, 0)) +
......
...@@ -12,8 +12,9 @@ ...@@ -12,8 +12,9 @@
namespace ck { namespace ck {
template <index_t BlockSize, template <index_t BlockSize,
typename Float, typename FloatAB,
typename AccFloat, typename FloatAcc,
typename FloatC,
InMemoryDataOperation CGlobalMemoryDataOperation, InMemoryDataOperation CGlobalMemoryDataOperation,
typename AGlobalDesc, typename AGlobalDesc,
typename BGlobalDesc, typename BGlobalDesc,
...@@ -52,7 +53,7 @@ template <index_t BlockSize, ...@@ -52,7 +53,7 @@ template <index_t BlockSize,
typename CGlobalIteratorHacks, typename CGlobalIteratorHacks,
typename AGlobalMoveSliceWindowIteratorHacks, typename AGlobalMoveSliceWindowIteratorHacks,
typename BGlobalMoveSliceWindowIteratorHacks> typename BGlobalMoveSliceWindowIteratorHacks>
struct GridwiseDynamicGemm_km_kn_mn_v1 struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
{ {
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{ {
...@@ -78,17 +79,17 @@ struct GridwiseDynamicGemm_km_kn_mn_v1 ...@@ -78,17 +79,17 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
constexpr auto b_block_space_size = constexpr auto b_block_space_size =
math::integer_least_multiple(b_k_n_block_desc.GetElementSpaceSize(), max_lds_align); math::integer_least_multiple(b_k_n_block_desc.GetElementSpaceSize(), max_lds_align);
return 2 * (a_block_space_size + b_block_space_size) * sizeof(Float); return 2 * (a_block_space_size + b_block_space_size) * sizeof(FloatAB);
} }
template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop> template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
__device__ void Run(const AGlobalDesc& a_k_m_global_desc, __device__ void Run(const AGlobalDesc& a_k_m_global_desc,
const Float* __restrict__ p_a_global, const FloatAB* __restrict__ p_a_global,
const BGlobalDesc& b_k_n_global_desc, const BGlobalDesc& b_k_n_global_desc,
const Float* __restrict__ p_b_global, const FloatAB* __restrict__ p_b_global,
const CGlobalDesc& c_m0_m1_n0_n1_global_desc, const CGlobalDesc& c_m0_m1_n0_n1_global_desc,
Float* __restrict__ p_c_global, FloatC* __restrict__ p_c_global,
Float* __restrict__ p_shared_block, FloatAB* __restrict__ p_shared_block,
integral_constant<bool, HasMainKBlockLoop>, integral_constant<bool, HasMainKBlockLoop>,
integral_constant<bool, HasDoubleTailKBlockLoop>) const integral_constant<bool, HasDoubleTailKBlockLoop>) const
{ {
...@@ -144,8 +145,8 @@ struct GridwiseDynamicGemm_km_kn_mn_v1 ...@@ -144,8 +145,8 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
ABlockTransferThreadSliceLengths_K_M, ABlockTransferThreadSliceLengths_K_M,
ABlockTransferThreadClusterLengths_K_M, ABlockTransferThreadClusterLengths_K_M,
ABlockTransferThreadClusterArrangeOrder, ABlockTransferThreadClusterArrangeOrder,
Float, FloatAB,
Float, FloatAB,
decltype(a_k_m_global_desc), decltype(a_k_m_global_desc),
decltype(a_k_m_block_desc), decltype(a_k_m_block_desc),
ABlockTransferSrcAccessOrder, ABlockTransferSrcAccessOrder,
...@@ -173,8 +174,8 @@ struct GridwiseDynamicGemm_km_kn_mn_v1 ...@@ -173,8 +174,8 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
BBlockTransferThreadSliceLengths_K_N, BBlockTransferThreadSliceLengths_K_N,
BBlockTransferThreadClusterLengths_K_N, BBlockTransferThreadClusterLengths_K_N,
BBlockTransferThreadClusterArrangeOrder, BBlockTransferThreadClusterArrangeOrder,
Float, FloatAB,
Float, FloatAB,
decltype(b_k_n_global_desc), decltype(b_k_n_global_desc),
decltype(b_k_n_block_desc), decltype(b_k_n_block_desc),
BBlockTransferSrcAccessOrder, BBlockTransferSrcAccessOrder,
...@@ -235,11 +236,11 @@ struct GridwiseDynamicGemm_km_kn_mn_v1 ...@@ -235,11 +236,11 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
constexpr auto b_block_space_size = constexpr auto b_block_space_size =
math::integer_least_multiple(b_k_n_block_desc.GetElementSpaceSize(), max_lds_align); math::integer_least_multiple(b_k_n_block_desc.GetElementSpaceSize(), max_lds_align);
Float* p_a_block_double = p_shared_block; FloatAB* p_a_block_double = p_shared_block;
Float* p_b_block_double = p_shared_block + 2 * a_block_space_size; FloatAB* p_b_block_double = p_shared_block + 2 * a_block_space_size;
// register allocation for output // register allocation for output
AccFloat p_c_thread[c_m0m1_n0n1_thread_desc.GetElementSpaceSize()]; FloatAcc p_c_thread[c_m0m1_n0n1_thread_desc.GetElementSpaceSize()];
// zero out threadwise output // zero out threadwise output
threadwise_matrix_set_zero_v2(c_m0m1_n0n1_thread_desc, p_c_thread); threadwise_matrix_set_zero_v2(c_m0m1_n0n1_thread_desc, p_c_thread);
...@@ -269,11 +270,11 @@ struct GridwiseDynamicGemm_km_kn_mn_v1 ...@@ -269,11 +270,11 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
if constexpr(HasMainKBlockLoop) if constexpr(HasMainKBlockLoop)
{ {
Float* p_a_block_even = p_a_block_double; FloatAB* p_a_block_even = p_a_block_double;
Float* p_b_block_even = p_b_block_double; FloatAB* p_b_block_even = p_b_block_double;
Float* p_a_block_odd = p_a_block_double + a_block_space_size; FloatAB* p_a_block_odd = p_a_block_double + a_block_space_size;
Float* p_b_block_odd = p_b_block_double + b_block_space_size; FloatAB* p_b_block_odd = p_b_block_double + b_block_space_size;
index_t k_block_data_begin = 0; index_t k_block_data_begin = 0;
...@@ -400,8 +401,8 @@ struct GridwiseDynamicGemm_km_kn_mn_v1 ...@@ -400,8 +401,8 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
Number<MRepeat>{}, Number<MPerThread>{}, Number<NRepeat>{}, Number<NPerThread>{})); Number<MRepeat>{}, Number<MPerThread>{}, Number<NRepeat>{}, Number<NPerThread>{}));
ThreadwiseDynamicTensorSliceTransfer_v1r3< ThreadwiseDynamicTensorSliceTransfer_v1r3<
AccFloat, FloatAcc,
Float, FloatC,
decltype(c_m0_m1_n0_n1_thread_desc), decltype(c_m0_m1_n0_n1_thread_desc),
decltype(c_m0_m1_n0_n1_global_desc), decltype(c_m0_m1_n0_n1_global_desc),
Sequence<MRepeat, MPerThread, NRepeat, NPerThread>, Sequence<MRepeat, MPerThread, NRepeat, NPerThread>,
...@@ -429,17 +430,17 @@ struct GridwiseDynamicGemm_km_kn_mn_v1 ...@@ -429,17 +430,17 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
// pass tensor descriptor by reference // pass tensor descriptor by reference
template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop> template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
__device__ void Run(const AGlobalDesc& a_k_m_global_desc, __device__ void Run(const AGlobalDesc& a_k_m_global_desc,
const Float* __restrict__ p_a_global, const FloatAB* __restrict__ p_a_global,
const BGlobalDesc& b_k_n_global_desc, const BGlobalDesc& b_k_n_global_desc,
const Float* __restrict__ p_b_global, const FloatAB* __restrict__ p_b_global,
const CGlobalDesc& c_m0_m1_n0_n1_global_desc, const CGlobalDesc& c_m0_m1_n0_n1_global_desc,
Float* __restrict__ p_c_global, FloatC* __restrict__ p_c_global,
integral_constant<bool, HasMainKBlockLoop>, integral_constant<bool, HasMainKBlockLoop>,
integral_constant<bool, HasDoubleTailKBlockLoop>) const integral_constant<bool, HasDoubleTailKBlockLoop>) const
{ {
constexpr index_t shared_block_size = GetSharedMemoryNumberOfByte() / sizeof(Float); constexpr index_t shared_block_size = GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
__shared__ Float p_shared_block[shared_block_size]; __shared__ FloatAB p_shared_block[shared_block_size];
Run(a_k_m_global_desc, Run(a_k_m_global_desc,
p_a_global, p_a_global,
...@@ -452,14 +453,14 @@ struct GridwiseDynamicGemm_km_kn_mn_v1 ...@@ -452,14 +453,14 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
integral_constant<bool, HasDoubleTailKBlockLoop>{}); integral_constant<bool, HasDoubleTailKBlockLoop>{});
} }
// pass tensor descriptors by their pointers // pass tensor descriptors by pointers
template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop> template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
__device__ void Run(const AGlobalDesc* p_a_k_m_global_desc, __device__ void Run(const AGlobalDesc* p_a_k_m_global_desc,
const Float* __restrict__ p_a_global, const FloatAB* __restrict__ p_a_global,
const BGlobalDesc* p_b_k_n_global_desc, const BGlobalDesc* p_b_k_n_global_desc,
const Float* __restrict__ p_b_global, const FloatAB* __restrict__ p_b_global,
const CGlobalDesc* p_c_m0_m1_n0_n1_global_desc, const CGlobalDesc* p_c_m0_m1_n0_n1_global_desc,
Float* __restrict__ p_c_global, FloatC* __restrict__ p_c_global,
integral_constant<bool, HasMainKBlockLoop>, integral_constant<bool, HasMainKBlockLoop>,
integral_constant<bool, HasDoubleTailKBlockLoop>) const integral_constant<bool, HasDoubleTailKBlockLoop>) const
{ {
...@@ -480,11 +481,11 @@ struct GridwiseDynamicGemm_km_kn_mn_v1 ...@@ -480,11 +481,11 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
// pass tensor descriptors by void* // pass tensor descriptors by void*
template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop> template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
__device__ void Run(const void* p_a_k_m_global_desc, __device__ void Run(const void* p_a_k_m_global_desc,
const Float* __restrict__ p_a_global, const FloatAB* __restrict__ p_a_global,
const void* p_b_k_n_global_desc, const void* p_b_k_n_global_desc,
const Float* __restrict__ p_b_global, const FloatAB* __restrict__ p_b_global,
const void* p_c_m0_m1_n0_n1_global_desc, const void* p_c_m0_m1_n0_n1_global_desc,
Float* __restrict__ p_c_global, FloatC* __restrict__ p_c_global,
integral_constant<bool, HasMainKBlockLoop>, integral_constant<bool, HasMainKBlockLoop>,
integral_constant<bool, HasDoubleTailKBlockLoop>) const integral_constant<bool, HasDoubleTailKBlockLoop>) const
{ {
......
...@@ -537,12 +537,12 @@ struct GridwiseDynamicGemm_km_kn_mn_v3 ...@@ -537,12 +537,12 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
// A matrix in LDS memory, dst of blockwise copy // A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment // be careful of LDS alignment
constexpr auto a_cyx_k_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2( constexpr auto a_cyx_k_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
make_tuple(Number<CYX>{}, Number<K>{}), max_lds_align); make_tuple(Number<CYX>{}, Number<K>{}), max_lds_align);
// LDS allocation for A and B: be careful of alignment // LDS allocation for A and B: be careful of alignment
constexpr auto a_block_space_size = constexpr auto a_block_space_size =
math::integer_least_multiple(a_cyx_k_block_desc.GetElementSpaceSize(), max_lds_align); math::integer_least_multiple(a_cyx_k_desc.GetElementSpaceSize(), max_lds_align);
return a_block_space_size * sizeof(Float); return a_block_space_size * sizeof(Float);
} }
......
...@@ -181,7 +181,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3 ...@@ -181,7 +181,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3
src_desc.CalculateOffset(to_multi_index(src_slice_origin_idx) + dst_data_idx + src_desc.CalculateOffset(to_multi_index(src_slice_origin_idx) + dst_data_idx +
i * dst_scalar_step_in_vector); i * dst_scalar_step_in_vector);
dst_vector.Scalars()(i) = p_src[Number<src_offset>{}]; dst_vector.Scalars()(i) = type_convert<DstData>{}(p_src[Number<src_offset>{}]);
}); });
const bool is_dst_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid( const bool is_dst_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid(
......
...@@ -161,19 +161,7 @@ struct ThreadwiseGemm_km_kn_mn_v1 ...@@ -161,19 +161,7 @@ struct ThreadwiseGemm_km_kn_mn_v1
__device__ static void Run(const FloatA* p_a, const FloatB* p_b, FloatC* p_c) __device__ static void Run(const FloatA* p_a, const FloatB* p_b, FloatC* p_c)
{ {
#if CK_THREADWISE_GEMM_USE_AMD_INLINE_ASM #if CK_THREADWISE_GEMM_USE_AMD_INLINE_ASM
constexpr bool has_amd_asm = is_same<FloatC, float>{} &&
((is_same<FloatA, float>{} && is_same<FloatB, float>{}) ||
(is_same<FloatA, half2_t>{} && is_same<FloatB, half2_t>{}) ||
(is_same<FloatA, half4_t>{} && is_same<FloatB, half4_t>{}));
if constexpr(has_amd_asm)
{
Run_amd_asm(p_a, p_b, p_c); Run_amd_asm(p_a, p_b, p_c);
}
else
{
Run_source(p_a, p_b, p_c);
}
#else #else
Run_source(p_a, p_b, p_c); Run_source(p_a, p_b, p_c);
#endif #endif
......
...@@ -31,6 +31,35 @@ __device__ int32x4_t make_wave_buffer_resource(T* p_wave, index_t data_space_siz ...@@ -31,6 +31,35 @@ __device__ int32x4_t make_wave_buffer_resource(T* p_wave, index_t data_space_siz
return wave_buffer_resource.data; return wave_buffer_resource.data;
} }
// load
// Declaration of a raw buffer-load intrinsic for 8-bit integer data. It has no
// body: the __asm label binds the declaration to the LLVM intrinsic named in
// the string (llvm.amdgcn.raw.buffer.load.i8).
// Operands: the int32x4_t buffer resource, then three index_t values
// (voffset, soffset, glc_slc). Returns the loaded int8_t.
__device__ int8_t
__llvm_amdgcn_raw_buffer_load_i8(int32x4_t srsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.i8");
// Declaration of a raw buffer-load intrinsic for 16-bit integer data. It has no
// body: the __asm label binds the declaration to the LLVM intrinsic named in
// the string. The label must name the .i16 variant so that it agrees with the
// int16_t return type (and mirrors the .i16 label used by the 16-bit store
// declaration); naming the .i32 variant here would bind a 16-bit return type
// to the 32-bit load.
// Operands: the int32x4_t buffer resource, then three index_t values
// (voffset, soffset, glc_slc). Returns the loaded int16_t.
__device__ int16_t
__llvm_amdgcn_raw_buffer_load_i16(int32x4_t srsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.i16");
// Raw buffer-load declarations for 32-bit integer data, in scalar, 2-lane and
// 4-lane form. None has a body: each __asm label binds the declaration to the
// LLVM intrinsic named in its string. All take the int32x4_t buffer resource
// followed by three index_t values (voffset, soffset, glc_slc).

// Scalar load: returns int32_t (llvm.amdgcn.raw.buffer.load.i32).
__device__ int32_t
__llvm_amdgcn_raw_buffer_load_i32(int32x4_t srsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.i32");
// 2-lane load: returns int32x2_t (llvm.amdgcn.raw.buffer.load.v2i32).
__device__ int32x2_t
__llvm_amdgcn_raw_buffer_load_i32x2(int32x4_t srsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2i32");
// 4-lane load: returns int32x4_t (llvm.amdgcn.raw.buffer.load.v4i32).
__device__ int32x4_t
__llvm_amdgcn_raw_buffer_load_i32x4(int32x4_t srsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4i32");
__device__ float __device__ float
__llvm_amdgcn_raw_buffer_load_fp32(int32x4_t srsrc, __llvm_amdgcn_raw_buffer_load_fp32(int32x4_t srsrc,
index_t voffset, index_t voffset,
...@@ -49,6 +78,42 @@ __llvm_amdgcn_raw_buffer_load_fp32x4(int32x4_t srsrc, ...@@ -49,6 +78,42 @@ __llvm_amdgcn_raw_buffer_load_fp32x4(int32x4_t srsrc,
index_t soffset, index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4f32"); index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4f32");
// store
// Raw buffer-store declarations for integer data. None has a body: each __asm
// label binds the declaration to the LLVM intrinsic named in its string. All
// take the value to store first (vdata), then the int32x4_t buffer resource,
// then three index_t values (voffset, soffset, glc_slc), and return nothing.

// 8-bit store (llvm.amdgcn.raw.buffer.store.i8).
__device__ void
__llvm_amdgcn_raw_buffer_store_i8(int8_t vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.i8");
// 16-bit store (llvm.amdgcn.raw.buffer.store.i16).
__device__ void
__llvm_amdgcn_raw_buffer_store_i16(int16_t vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.i16");
// 32-bit scalar store (llvm.amdgcn.raw.buffer.store.i32).
__device__ void
__llvm_amdgcn_raw_buffer_store_i32(int32_t vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.i32");
// 2-lane 32-bit store (llvm.amdgcn.raw.buffer.store.v2i32).
__device__ void
__llvm_amdgcn_raw_buffer_store_i32x2(int32x2_t vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2i32");
// 4-lane 32-bit store (llvm.amdgcn.raw.buffer.store.v4i32).
__device__ void
__llvm_amdgcn_raw_buffer_store_i32x4(int32x4_t vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4i32");
__device__ void __device__ void
__llvm_amdgcn_raw_buffer_store_fp32(float vdata, __llvm_amdgcn_raw_buffer_store_fp32(float vdata,
int32x4_t rsrc, int32x4_t rsrc,
...@@ -70,213 +135,228 @@ __llvm_amdgcn_raw_buffer_store_fp32x4(float4_t vdata, ...@@ -70,213 +135,228 @@ __llvm_amdgcn_raw_buffer_store_fp32x4(float4_t vdata,
index_t soffset, index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4f32"); index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4f32");
// buffer_load requires: template <typename T, index_t N>
// 1) p_src_wave must be in global memory space __device__ typename vector_type<T, N>::type
// 2) p_src_wave to be a wavewise pointer. amd_buffer_load_impl_v2(int32x4_t src_wave_buffer_resource,
// It is user's responsibility to make sure that is true. index_t src_thread_addr_offset,
template <typename T, index_t VectorSize> index_t src_wave_addr_offset)
__device__ typename vector_type<T, VectorSize>::type
amd_buffer_load_v2(const T* p_src_wave,
index_t src_thread_data_offset,
bool src_thread_data_valid,
index_t src_elemenst_space);
// buffer_store requires:
// 1) p_dst_wave must be global memory
// 2) p_dst_wave to be a wavewise pointer.
// It is user's responsibility to make sure that is true.
template <typename T, index_t VectorSize>
__device__ void amd_buffer_store_v2(const typename vector_type<T, VectorSize>::type src_thread_data,
T* p_dst_wave,
const index_t dst_thread_data_offset,
const bool dst_thread_data_valid,
const index_t dst_data_range);
template <>
__device__ float amd_buffer_load_v2<float, 1>(const float* p_src_wave,
index_t src_thread_data_offset,
bool src_thread_data_valid,
index_t src_data_range)
{ {
const int32x4_t src_wave_buffer_resource = static_assert((is_same<T, float>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
make_wave_buffer_resource(p_src_wave, src_data_range); (is_same<T, int32_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)),
"wrong! not implemented");
index_t src_thread_addr_offset = src_thread_data_offset * sizeof(float);
#if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
uint32_t src_addr_shift = src_thread_data_valid ? 0 : 0x7fffffff;
if constexpr(is_same<T, float>::value)
{
if constexpr(N == 1)
{
return __llvm_amdgcn_raw_buffer_load_fp32( return __llvm_amdgcn_raw_buffer_load_fp32(
src_wave_buffer_resource, src_addr_shift + src_thread_addr_offset, 0, 0); src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
#else }
float tmp = else if constexpr(N == 2)
__llvm_amdgcn_raw_buffer_load_fp32(src_wave_buffer_resource, src_thread_addr_offset, 0, 0); {
return src_thread_data_valid ? tmp : float(0);
#endif
}
template <>
__device__ float2_t amd_buffer_load_v2<float, 2>(const float* p_src_wave,
index_t src_thread_data_offset,
bool src_thread_data_valid,
index_t src_data_range)
{
const int32x4_t src_wave_buffer_resource =
make_wave_buffer_resource(p_src_wave, src_data_range);
index_t src_thread_addr_offset = src_thread_data_offset * sizeof(float);
#if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
uint32_t src_addr_shift = src_thread_data_valid ? 0 : 0x7fffffff;
return __llvm_amdgcn_raw_buffer_load_fp32x2( return __llvm_amdgcn_raw_buffer_load_fp32x2(
src_wave_buffer_resource, src_addr_shift + src_thread_addr_offset, 0, 0); src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
#else }
float2_t tmp = __llvm_amdgcn_raw_buffer_load_fp32x2( else if constexpr(N == 4)
src_wave_buffer_resource, src_thread_addr_offset, 0, 0); {
return src_thread_data_valid ? tmp : float2_t(0);
#endif
}
template <>
__device__ float4_t amd_buffer_load_v2<float, 4>(const float* p_src_wave,
index_t src_thread_data_offset,
bool src_thread_data_valid,
index_t src_data_range)
{
const int32x4_t src_wave_buffer_resource =
make_wave_buffer_resource(p_src_wave, src_data_range);
index_t src_thread_addr_offset = src_thread_data_offset * sizeof(float);
#if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
uint32_t src_addr_shift = src_thread_data_valid ? 0 : 0x7fffffff;
return __llvm_amdgcn_raw_buffer_load_fp32x4( return __llvm_amdgcn_raw_buffer_load_fp32x4(
src_wave_buffer_resource, src_addr_shift + src_thread_addr_offset, 0, 0); src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
#else }
float4_t tmp = __llvm_amdgcn_raw_buffer_load_fp32x4( else if constexpr(N == 8)
src_wave_buffer_resource, src_thread_addr_offset, 0, 0); {
return src_thread_data_valid ? tmp : float4_t(0);
#endif
}
template <>
__device__ float8_t amd_buffer_load_v2<float, 8>(const float* p_src_wave,
index_t src_thread_data_offset,
bool src_thread_data_valid,
index_t src_data_range)
{
const int32x4_t src_wave_buffer_resource =
make_wave_buffer_resource(p_src_wave, src_data_range);
index_t src_thread_addr_offset = src_thread_data_offset * sizeof(float);
#if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
uint32_t src_addr_shift = src_thread_data_valid ? 0 : 0x7fffffff;
vector_type<float, 8> tmp; vector_type<float, 8> tmp;
tmp.Vectors(Number<4>{})(Number<0>{}) = __llvm_amdgcn_raw_buffer_load_fp32x4( tmp.Vectors(Number<4>{})(Number<0>{}) = __llvm_amdgcn_raw_buffer_load_fp32x4(
src_wave_buffer_resource, src_addr_shift + src_thread_addr_offset, 0, 0); src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
tmp.Vectors(Number<4>{})(Number<1>{}) = __llvm_amdgcn_raw_buffer_load_fp32x4( tmp.Vectors(Number<4>{})(Number<1>{}) = __llvm_amdgcn_raw_buffer_load_fp32x4(
src_wave_buffer_resource, src_wave_buffer_resource, src_thread_addr_offset, 4 * sizeof(float), 0);
src_addr_shift + src_thread_addr_offset + 4 * sizeof(float),
0,
0);
return tmp.Vector(); return tmp.Vector();
#else }
vector_type<float, 8> tmp; }
else if constexpr(is_same<T, int32_t>::value)
{
if constexpr(N == 1)
{
return __llvm_amdgcn_raw_buffer_load_i32(
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
}
else if constexpr(N == 2)
{
return __llvm_amdgcn_raw_buffer_load_i32x2(
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
}
else if constexpr(N == 4)
{
return __llvm_amdgcn_raw_buffer_load_i32x4(
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
}
else if constexpr(N == 8)
{
vector_type<int32_t, 8> tmp;
tmp.Vectors(Number<4>{})(Number<0>{}) = __llvm_amdgcn_raw_buffer_load_fp32x4( tmp.Vectors(Number<4>{})(Number<0>{}) = __llvm_amdgcn_raw_buffer_load_i32x4(
src_wave_buffer_resource, src_thread_addr_offset, 0, 0); src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
tmp.Vectors(Number<4>{})(Number<1>{}) = __llvm_amdgcn_raw_buffer_load_fp32x4( tmp.Vectors(Number<4>{})(Number<1>{}) = __llvm_amdgcn_raw_buffer_load_i32x4(
src_wave_buffer_resource, src_thread_addr_offset + 4 * sizeof(float), 0, 0); src_wave_buffer_resource, src_thread_addr_offset, 4 * sizeof(int32_t), 0);
return src_thread_data_valid ? tmp.Vector() : float8_t(0); return tmp.Vector();
#endif }
}
} }
template <> template <typename T, index_t N>
__device__ void amd_buffer_store_v2<float, 1>(const float src_thread_data, __device__ void amd_buffer_store_impl_v2(const typename vector_type<T, N>::type src_thread_data,
float* p_dst_wave, int32x4_t dst_wave_buffer_resource,
const index_t dst_thread_data_offset, index_t dst_thread_addr_offset,
const bool dst_thread_data_valid, index_t dst_wave_addr_offset)
const index_t dst_data_range)
{ {
const int32x4_t dst_wave_buffer_resource = static_assert((is_same<T, float>::value && (N == 1 || N == 2 || N == 4)) ||
make_wave_buffer_resource(p_dst_wave, dst_data_range); (is_same<T, int32_t>::value && (N == 1 || N == 2 || N == 4)) ||
(is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4)),
index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(float); "wrong! not implemented");
#if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK
uint32_t dst_addr_shift = dst_thread_data_valid ? 0 : 0x7fffffff;
__llvm_amdgcn_raw_buffer_store_fp32( if constexpr(is_same<T, float>::value)
src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0, 0); {
#else if constexpr(N == 1)
if(dst_thread_data_valid) {
__llvm_amdgcn_raw_buffer_store_fp32(src_thread_data,
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
0);
}
else if constexpr(N == 2)
{ {
__llvm_amdgcn_buffer_store_fp32( __llvm_amdgcn_raw_buffer_store_fp32x2(src_thread_data,
src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0, 0); dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
0);
}
else if constexpr(N == 4)
{
__llvm_amdgcn_raw_buffer_store_fp32x4(src_thread_data,
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
0);
}
}
else if constexpr(is_same<T, int32_t>::value)
{
if constexpr(N == 1)
{
__llvm_amdgcn_raw_buffer_store_i32(src_thread_data,
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
0);
}
else if constexpr(N == 2)
{
__llvm_amdgcn_raw_buffer_store_i32x2(src_thread_data,
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
0);
}
else if constexpr(N == 4)
{
__llvm_amdgcn_raw_buffer_store_i32x4(src_thread_data,
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
0);
}
}
else if constexpr(is_same<T, int8_t>::value)
{
if constexpr(N == 1)
{
__llvm_amdgcn_raw_buffer_store_i8(src_thread_data,
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
0);
}
else if constexpr(N == 2)
{
__llvm_amdgcn_raw_buffer_store_i16(src_thread_data,
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
0);
}
else if constexpr(N == 4)
{
__llvm_amdgcn_raw_buffer_store_i32(src_thread_data,
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
0);
}
} }
#endif
} }
template <> // buffer_load requires:
__device__ void amd_buffer_store_v2<float, 2>(const float2_t src_thread_data, // 1) p_src_wave must be in global memory space
float* p_dst_wave, // 2) p_src_wave to be a wavewise pointer.
const index_t dst_thread_data_offset, // It is user's responsibility to make sure that is true.
const bool dst_thread_data_valid, template <typename T, index_t N>
const index_t dst_data_range) __device__ typename vector_type<T, N>::type amd_buffer_load_v2(const T* p_src_wave,
index_t src_thread_data_offset,
bool src_thread_data_valid,
index_t src_element_space)
{ {
const int32x4_t dst_wave_buffer_resource = const int32x4_t src_wave_buffer_resource =
make_wave_buffer_resource(p_dst_wave, dst_data_range); make_wave_buffer_resource(p_src_wave, src_element_space);
index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(float); index_t src_thread_addr_offset = src_thread_data_offset * sizeof(T);
#if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK #if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
uint32_t dst_addr_shift = dst_thread_data_valid ? 0 : 0x7fffffff; uint32_t src_addr_shift = src_thread_data_valid ? 0 : 0x7fffffff;
__llvm_amdgcn_raw_buffer_store_fp32x2( return amd_buffer_load_impl_v2<T, N>(
src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0, 0); src_wave_buffer_resource, src_addr_shift + src_thread_addr_offset, 0);
#else #else
if(dst_thread_data_valid) using vector_t = typename vector_type<T, N>::type;
{
__llvm_amdgcn_raw_buffer_store_fp32x2( vector_t tmp =
src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0, 0); amd_buffer_load_impl_v2<T, N>(src_wave_buffer_resource, src_thread_addr_offset, 0);
}
return src_thread_data_valid ? tmp : vector_t(0);
#endif #endif
} }
template <> // buffer_store requires:
__device__ void amd_buffer_store_v2<float, 4>(const float4_t src_thread_data, // 1) p_dst_wave must be global memory
float* p_dst_wave, // 2) p_dst_wave to be a wavewise pointer.
// It is user's responsibility to make sure that is true.
template <typename T, index_t N>
__device__ void amd_buffer_store_v2(const typename vector_type<T, N>::type src_thread_data,
T* p_dst_wave,
const index_t dst_thread_data_offset, const index_t dst_thread_data_offset,
const bool dst_thread_data_valid, const bool dst_thread_data_valid,
const index_t dst_data_range) const index_t dst_element_space)
{ {
const int32x4_t dst_wave_buffer_resource = const int32x4_t dst_wave_buffer_resource =
make_wave_buffer_resource(p_dst_wave, dst_data_range); make_wave_buffer_resource(p_dst_wave, dst_element_space);
index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(float); index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(T);
#if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK #if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK
uint32_t dst_addr_shift = dst_thread_data_valid ? 0 : 0x7fffffff; uint32_t dst_addr_shift = dst_thread_data_valid ? 0 : 0x7fffffff;
__llvm_amdgcn_raw_buffer_store_fp32x4( amd_buffer_store_impl_v2<T, N>(
src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0, 0); src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0);
#else #else
if(dst_thread_data_valid) if(dst_thread_data_valid)
{ {
__llvm_amdgcn_raw_buffer_store_fp32x4( amd_buffer_store_impl_v2<T, N>(
src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0, 0); src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0);
} }
#endif #endif
} }
......
...@@ -5,7 +5,8 @@ ...@@ -5,7 +5,8 @@
namespace ck { namespace ck {
// outer-product: c[i,j] += inner_product(a[i], b[j]) // c0 += inner_product(a, b0)
// c1 += inner_product(a, b1)
__device__ void amd_assembly_outer_product_1x2(float a, float b0, float b1, float& c0, float& c1) __device__ void amd_assembly_outer_product_1x2(float a, float b0, float b1, float& c0, float& c1)
{ {
#if CK_USE_AMD_V_FMAC_F32 #if CK_USE_AMD_V_FMAC_F32
...@@ -25,7 +26,10 @@ __device__ void amd_assembly_outer_product_1x2(float a, float b0, float b1, floa ...@@ -25,7 +26,10 @@ __device__ void amd_assembly_outer_product_1x2(float a, float b0, float b1, floa
#endif #endif
} }
// outer-product: c[i,j] += inner_product(a[i], b[j]) // c0 += inner_product(a, b0)
// c1 += inner_product(a, b1)
// c2 += inner_product(a, b2)
// c3 += inner_product(a, b3)
__device__ void amd_assembly_outer_product_1x4( __device__ void amd_assembly_outer_product_1x4(
float a, float b0, float b1, float b2, float b3, float& c0, float& c1, float& c2, float& c3) float a, float b0, float b1, float b2, float b3, float& c0, float& c1, float& c2, float& c3)
{ {
...@@ -50,7 +54,8 @@ __device__ void amd_assembly_outer_product_1x4( ...@@ -50,7 +54,8 @@ __device__ void amd_assembly_outer_product_1x4(
#endif #endif
} }
// outer-product: c[i,j] += inner_product(a[i], b[j]) // c0 += inner_product(a, b0)
// c1 += inner_product(a, b1)
__device__ void __device__ void
amd_assembly_outer_product_1x2(half2_t a, half2_t b0, half2_t b1, float& c0, float& c1) amd_assembly_outer_product_1x2(half2_t a, half2_t b0, half2_t b1, float& c0, float& c1)
{ {
...@@ -58,15 +63,12 @@ amd_assembly_outer_product_1x2(half2_t a, half2_t b0, half2_t b1, float& c0, flo ...@@ -58,15 +63,12 @@ amd_assembly_outer_product_1x2(half2_t a, half2_t b0, half2_t b1, float& c0, flo
v_dot2_f32_f16 %0, %2, %3, %0\n \ v_dot2_f32_f16 %0, %2, %3, %0\n \
v_dot2_f32_f16 %1, %2, %4, %1\n \ v_dot2_f32_f16 %1, %2, %4, %1\n \
" "
: "=v"(c0), "=v"(c1) // Dest registers : "=v"(c0), "=v"(c1)
: "v"(a), // 1st Src register for 1 half2 registers : "v"(a), "v"(b0), "v"(b1), "0"(c0), "1"(c1));
"v"(b0), // 2nd Src register
"v"(b1),
"0"(c0), // 3rd Src register
"1"(c1));
} }
// outer-product: c[i,j] += inner_product(a[i], b[j]) // c0 += inner_product(a, b0)
// c1 += inner_product(a, b1)
__device__ void __device__ void
amd_assembly_outer_product_1x2(half4_t a, half4_t b0, half4_t b1, float& c0, float& c1) amd_assembly_outer_product_1x2(half4_t a, half4_t b0, half4_t b1, float& c0, float& c1)
{ {
...@@ -81,18 +83,21 @@ amd_assembly_outer_product_1x2(half4_t a, half4_t b0, half4_t b1, float& c0, flo ...@@ -81,18 +83,21 @@ amd_assembly_outer_product_1x2(half4_t a, half4_t b0, half4_t b1, float& c0, flo
v_dot2_f32_f16 %0, %3, %5, %0\n \ v_dot2_f32_f16 %0, %3, %5, %0\n \
v_dot2_f32_f16 %1, %3, %7, %1\n \ v_dot2_f32_f16 %1, %3, %7, %1\n \
" "
: "=v"(c0), "=v"(c1) // Dest registers : "=v"(c0), "=v"(c1)
: "v"(p_a_half2[0]), : "v"(p_a_half2[0]),
"v"(p_a_half2[1]), // 1st Src registers for 2 half2 registers "v"(p_a_half2[1]),
"v"(p_b0_half2[0]), "v"(p_b0_half2[0]),
"v"(p_b0_half2[1]), "v"(p_b0_half2[1]),
"v"(p_b1_half2[0]), "v"(p_b1_half2[0]),
"v"(p_b1_half2[1]), // 2nd Src registers for 2 half2 registers "v"(p_b1_half2[1]),
"0"(c0), "0"(c0),
"1"(c1)); // 3rd Src Acc registers for 2 half2 registers "1"(c1));
} }
// outer-product: c[i,j] += inner_product(a[i], b[j]) // c0 += inner_product(a, b0)
// c1 += inner_product(a, b1)
// c2 += inner_product(a, b2)
// c3 += inner_product(a, b3)
__device__ void amd_assembly_outer_product_1x4(half2_t a, __device__ void amd_assembly_outer_product_1x4(half2_t a,
half2_t b0, half2_t b0,
half2_t b1, half2_t b1,
...@@ -109,19 +114,14 @@ __device__ void amd_assembly_outer_product_1x4(half2_t a, ...@@ -109,19 +114,14 @@ __device__ void amd_assembly_outer_product_1x4(half2_t a,
v_dot2_f32_f16 %2, %4, %7, %2\n \ v_dot2_f32_f16 %2, %4, %7, %2\n \
v_dot2_f32_f16 %3, %4, %8, %3\n \ v_dot2_f32_f16 %3, %4, %8, %3\n \
" "
: "=v"(c0), "=v"(c1), "=v"(c2), "=v"(c3) // Dest registers : "=v"(c0), "=v"(c1), "=v"(c2), "=v"(c3)
: "v"(a), // 1st Src register for 1 half2 registers : "v"(a), "v"(b0), "v"(b1), "v"(b2), "v"(b3), "0"(c0), "1"(c1), "2"(c2), "3"(c3));
"v"(b0), // 2nd Src register
"v"(b1),
"v"(b2),
"v"(b3),
"0"(c0), // 3rd Src register
"1"(c1),
"2"(c2),
"3"(c3));
} }
// outer-product: c[i,j] += inner_product(a[i], b[j]) // c0 += inner_product(a, b0)
// c1 += inner_product(a, b1)
// c2 += inner_product(a, b2)
// c3 += inner_product(a, b3)
__device__ void amd_assembly_outer_product_1x4(half4_t a, __device__ void amd_assembly_outer_product_1x4(half4_t a,
half4_t b0, half4_t b0,
half4_t b1, half4_t b1,
...@@ -149,21 +149,70 @@ __device__ void amd_assembly_outer_product_1x4(half4_t a, ...@@ -149,21 +149,70 @@ __device__ void amd_assembly_outer_product_1x4(half4_t a,
v_dot2_f32_f16 %2, %5, %11, %2\n \ v_dot2_f32_f16 %2, %5, %11, %2\n \
v_dot2_f32_f16 %3, %5, %13, %3\n \ v_dot2_f32_f16 %3, %5, %13, %3\n \
" "
: "=v"(c0), "=v"(c1), "=v"(c2), "=v"(c3) // Dest registers : "=v"(c0), "=v"(c1), "=v"(c2), "=v"(c3)
: "v"(p_a_half2[0]), : "v"(p_a_half2[0]),
"v"(p_a_half2[1]), // 1st Src registers for 2 half2 registers "v"(p_a_half2[1]),
"v"(p_b0_half2[0]), "v"(p_b0_half2[0]),
"v"(p_b0_half2[1]), "v"(p_b0_half2[1]),
"v"(p_b1_half2[0]), "v"(p_b1_half2[0]),
"v"(p_b1_half2[1]), // 2nd Src registers for 2 half2 registers "v"(p_b1_half2[1]),
"v"(p_b2_half2[0]), "v"(p_b2_half2[0]),
"v"(p_b2_half2[1]), "v"(p_b2_half2[1]),
"v"(p_b3_half2[0]), "v"(p_b3_half2[0]),
"v"(p_b3_half2[1]), // 2nd Src registers for 2 half2 registers "v"(p_b3_half2[1]),
"0"(c0), "0"(c0),
"1"(c1), "1"(c1),
"2"(c2), "2"(c2),
"3"(c3)); // 3rd Src Acc registers for 2 half2 registers "3"(c3));
}
// c0 += inner_product(a, b0)
// c1 += inner_product(a, b1)
// int8x4_t overload: one packed `a` operand is combined with two packed `b`
// operands, and each result is accumulated into its own int32_t reference.
// The accumulators are updated in place; nothing is returned.
__device__ void
amd_assembly_outer_product_1x2(int8x4_t a, int8x4_t b0, int8x4_t b1, int32_t& c0, int32_t& c1)
{
// `#if 1` selects the inline-assembly path; the builtin-based path under
// #else is currently compiled out.
#if 1
// Operand map: %0 = c0, %1 = c1 (outputs); %2 = a, %3 = b0, %4 = b1 (inputs).
// The trailing "0"(c0) / "1"(c1) inputs are matching constraints: they place
// the incoming accumulator values in the same registers as outputs %0 / %1,
// which is what lets each instruction read and write its accumulator.
asm volatile("\n \
v_dot4_i32_i8 %0, %2, %3, %0\n \
v_dot4_i32_i8 %1, %2, %4, %1\n \
"
: "=v"(c0), "=v"(c1)
: "v"(a), "v"(b0), "v"(b1), "0"(c0), "1"(c1));
#else
// Alternative formulation through the compiler builtin (disabled by `#if 1`).
c0 = __builtin_amdgcn_sdot4(a, b0, c0, false);
c1 = __builtin_amdgcn_sdot4(a, b1, c1, false);
#endif
}
// c0 += inner_product(a, b0)
// c1 += inner_product(a, b1)
// c2 += inner_product(a, b2)
// c3 += inner_product(a, b3)
// int8x4_t overload: one packed `a` operand is combined with four packed `b`
// operands, and each result is accumulated into its own int32_t reference.
// The accumulators are updated in place; nothing is returned.
__device__ void amd_assembly_outer_product_1x4(int8x4_t a,
int8x4_t b0,
int8x4_t b1,
int8x4_t b2,
int8x4_t b3,
int32_t& c0,
int32_t& c1,
int32_t& c2,
int32_t& c3)
{
// `#if 1` selects the inline-assembly path; the builtin-based path under
// #else is currently compiled out.
#if 1
// Operand map: %0..%3 = c0..c3 (outputs); %4 = a; %5..%8 = b0..b3 (inputs).
// The trailing "0".."3" inputs are matching constraints: they place the
// incoming accumulator values in the same registers as outputs %0..%3.
asm volatile("\n \
v_dot4_i32_i8 %0, %4, %5, %0\n \
v_dot4_i32_i8 %1, %4, %6, %1\n \
v_dot4_i32_i8 %2, %4, %7, %2\n \
v_dot4_i32_i8 %3, %4, %8, %3\n \
"
: "=v"(c0), "=v"(c1), "=v"(c2), "=v"(c3)
: "v"(a), "v"(b0), "v"(b1), "v"(b2), "v"(b3), "0"(c0), "1"(c1), "2"(c2), "3"(c3));
#else
// Alternative formulation through the compiler builtin (disabled by `#if 1`).
c0 = __builtin_amdgcn_sdot4(a, b0, c0, false);
c1 = __builtin_amdgcn_sdot4(a, b1, c1, false);
c2 = __builtin_amdgcn_sdot4(a, b2, c2, false);
c3 = __builtin_amdgcn_sdot4(a, b3, c3, false);
#endif
} }
} // namespace ck } // namespace ck
......
...@@ -140,10 +140,5 @@ enum InMemoryDataOperation ...@@ -140,10 +140,5 @@ enum InMemoryDataOperation
// index type // index type
using index_t = int32_t; using index_t = int32_t;
typedef int32_t int32x2_t __attribute__((ext_vector_type(2)));
// int32x4_t use by buffer_load and buffer_store llvm intrinsic
typedef int32_t int32x4_t __attribute__((ext_vector_type(4)));
} // namespace ck } // namespace ck
#endif #endif
...@@ -3,172 +3,6 @@ ...@@ -3,172 +3,6 @@
namespace ck { namespace ck {
// For some reason, HIP compiler need this definition to generate optimal ISA
// fp32
typedef float float2_t __attribute__((ext_vector_type(2)));
typedef float float4_t __attribute__((ext_vector_type(4)));
typedef float float8_t __attribute__((ext_vector_type(8)));
typedef float float16_t __attribute__((ext_vector_type(16)));
typedef float float32_t __attribute__((ext_vector_type(32)));
// fp16
typedef _Float16 half_t;
typedef _Float16 half2_t __attribute__((ext_vector_type(2)));
typedef _Float16 half4_t __attribute__((ext_vector_type(4)));
typedef _Float16 half8_t __attribute__((ext_vector_type(8)));
// bfp16
typedef ushort ushort2_t __attribute__((ext_vector_type(2)));
typedef ushort ushort4_t __attribute__((ext_vector_type(4)));
typedef ushort ushort8_t __attribute__((ext_vector_type(8)));
// Storage for four float32_t vectors (4 x 32 floats) that can also be viewed
// as a flat array of 128 floats.
struct c_vec32_4_t
{
    // Overlays the named vector members (s) with a per-element float view (n).
    union VecType
    {
        struct
        {
            float32_t x;
            float32_t y;
            float32_t z;
            float32_t w;
        } s;
        float n[128];
    };

    // Builds a VecType whose four vector members are all assigned 0.
    __host__ __device__ static VecType CreateVecZero()
    {
        VecType zeroed;

        zeroed.s.x = 0;
        zeroed.s.y = 0;
        zeroed.s.z = 0;
        zeroed.s.w = 0;

        return zeroed;
    }
};
// Storage for two float32_t vectors (2 x 32 floats) that can also be viewed
// as a flat array of 64 floats. The struct itself holds one VecType named l.
struct c_vec32_2_t
{
    // Overlays the named vector members (s) with a per-element float view (n).
    union VecType
    {
        struct
        {
            float32_t x;
            float32_t y;
        } s;
        float n[64];
    } l;

    // Builds a VecType whose two vector members are both assigned 0.
    __host__ __device__ static VecType CreateVecZero()
    {
        VecType zeroed;

        zeroed.s.x = 0;
        zeroed.s.y = 0;

        return zeroed;
    }
};
// Storage for two c_vec32_2_t objects (2 x 2 x 32 floats) that can also be
// viewed as a flat array of 128 floats.
struct c_vec32_2_2_t
{
    // Overlays the pair of c_vec32_2_t members (s) with a per-element float view (n).
    union VecType
    {
        struct
        {
            c_vec32_2_t x;
            c_vec32_2_t y;
        } s;
        float n[128];
    };

    // Builds a VecType in which every float32_t reachable through s.x and s.y
    // is assigned 0.
    __host__ __device__ static VecType CreateVecZero()
    {
        VecType zeroed;

        // first c_vec32_2_t
        zeroed.s.x.l.s.x = 0;
        zeroed.s.x.l.s.y = 0;

        // second c_vec32_2_t
        zeroed.s.y.l.s.x = 0;
        zeroed.s.y.l.s.y = 0;

        return zeroed;
    }
};
// Storage for a single float32_t vector (32 floats) that can also be viewed
// as a flat array of 32 floats.
struct c_vec32_1_t
{
    // Overlays the named vector member (s) with a per-element float view (n).
    union VecType
    {
        struct
        {
            float32_t x;
        } s;
        float n[32];
    };

    // Builds a VecType whose vector member is assigned 0.
    __host__ __device__ static VecType CreateVecZero()
    {
        VecType zeroed;

        zeroed.s.x = 0;

        return zeroed;
    }
};
// Storage for a single float16_t vector (16 floats) that can also be viewed
// as a flat array of 16 floats.
struct c_vec16_1_t
{
    // Overlays the named vector member (s) with a per-element float view (n).
    union VecType
    {
        struct
        {
            float16_t x;
        } s;
        float n[16];
    };

    // Builds a VecType whose vector member is assigned 0.
    __host__ __device__ static VecType CreateVecZero()
    {
        VecType zeroed;

        zeroed.s.x = 0;

        return zeroed;
    }
};
// Storage for two float4_t vectors (2 x 4 floats) that can also be viewed
// as a flat array of 8 floats.
struct c_vec4_2_t
{
    // Overlays the named vector members (s) with a per-element float view (n).
    union VecType
    {
        struct
        {
            float4_t x;
            float4_t y;
        } s;
        float n[8];
    };

    // Builds a VecType whose two vector members are both assigned 0.
    __host__ __device__ static VecType CreateVecZero()
    {
        VecType zeroed;

        zeroed.s.x = 0;
        zeroed.s.y = 0;

        return zeroed;
    }
};
// Storage for a single float4_t vector (4 floats) that can also be viewed
// as a flat array of 4 floats.
struct c_vec4_1_t
{
    // Overlays the named vector member (s) with a per-element float view (n).
    union VecType
    {
        struct
        {
            float4_t x;
        } s;
        float n[4];
    };

    // Builds a VecType whose vector member is assigned 0.
    __host__ __device__ static VecType CreateVecZero()
    {
        VecType zeroed;

        zeroed.s.x = 0;

        return zeroed;
    }
};
template <typename T, index_t N> template <typename T, index_t N>
struct vector_type; struct vector_type;
...@@ -183,7 +17,9 @@ struct vector_type<T, 1> ...@@ -183,7 +17,9 @@ struct vector_type<T, 1>
StaticallyIndexedArray<T, 1> d1x1_; StaticallyIndexedArray<T, 1> d1x1_;
} data_; } data_;
__host__ __device__ constexpr vector_type() : data_{T{0}} {} __host__ __device__ constexpr vector_type() : data_{type{0}} {}
__host__ __device__ constexpr vector_type(type v) : data_{v} {}
__host__ __device__ static constexpr index_t Size() { return 1; } __host__ __device__ static constexpr index_t Size() { return 1; }
...@@ -215,7 +51,9 @@ struct vector_type<T, 2> ...@@ -215,7 +51,9 @@ struct vector_type<T, 2>
StaticallyIndexedArray<d2_t, 1> d2x1_; StaticallyIndexedArray<d2_t, 1> d2x1_;
} data_; } data_;
__host__ __device__ constexpr vector_type() : data_{d2_t{0}} {} __host__ __device__ constexpr vector_type() : data_{type{0}} {}
__host__ __device__ constexpr vector_type(type v) : data_{v} {}
__host__ __device__ static constexpr index_t Size() { return 2; } __host__ __device__ static constexpr index_t Size() { return 2; }
...@@ -253,7 +91,9 @@ struct vector_type<T, 4> ...@@ -253,7 +91,9 @@ struct vector_type<T, 4>
StaticallyIndexedArray<d4_t, 1> d4x1_; StaticallyIndexedArray<d4_t, 1> d4x1_;
} data_; } data_;
__host__ __device__ constexpr vector_type() : data_{d4_t{0}} {} __host__ __device__ constexpr vector_type() : data_{type{0}} {}
__host__ __device__ constexpr vector_type(type v) : data_{v} {}
__host__ __device__ static constexpr index_t Size() { return 4; } __host__ __device__ static constexpr index_t Size() { return 4; }
...@@ -297,7 +137,9 @@ struct vector_type<T, 8> ...@@ -297,7 +137,9 @@ struct vector_type<T, 8>
StaticallyIndexedArray<d8_t, 1> d8x1_; StaticallyIndexedArray<d8_t, 1> d8x1_;
} data_; } data_;
__host__ __device__ constexpr vector_type() : data_{d8_t{0}} {} __host__ __device__ constexpr vector_type() : data_{type{0}} {}
__host__ __device__ constexpr vector_type(type v) : data_{v} {}
__host__ __device__ static constexpr index_t Size() { return 8; } __host__ __device__ static constexpr index_t Size() { return 8; }
...@@ -326,6 +168,114 @@ struct vector_type<T, 8> ...@@ -326,6 +168,114 @@ struct vector_type<T, 8>
__host__ __device__ constexpr auto& Vectors(Number<8>) { return data_.d8x1_; } __host__ __device__ constexpr auto& Vectors(Number<8>) { return data_.d8x1_; }
}; };
// Specialization of vector_type for two int8_t elements. The pair is stored
// in a single int16_t, which is also the exposed vector type.
template <>
struct vector_type<int8_t, 2>
{
    using d1_t = int8_t;
    using d2_t = int16_t;

    using type = d2_t;

    // d2_ stays the first member: both constructors brace-initialize the
    // union, which initializes its first member.
    union
    {
        d2_t d2_;
        StaticallyIndexedArray<d1_t, 2> d1x2_;
        StaticallyIndexedArray<d2_t, 1> d2x1_;
    } data_;

    // Zero-initializes the packed value.
    __host__ __device__ constexpr vector_type() : data_{type{0}} {}

    // Wraps an already packed value.
    __host__ __device__ constexpr vector_type(type v) : data_{v} {}

    // Number of int8_t elements held.
    __host__ __device__ static constexpr index_t Size() { return 2; }

    // Whole-vector access (the packed int16_t).
    __host__ __device__ constexpr const auto& Vector() const { return data_.d2_; }
    __host__ __device__ constexpr auto& Vector() { return data_.d2_; }

    // Element-wise access (array of two int8_t).
    __host__ __device__ constexpr const auto& Scalars() const { return data_.d1x2_; }
    __host__ __device__ constexpr auto& Scalars() { return data_.d1x2_; }

    // Access as an array of sub-vectors of width 1.
    __host__ __device__ constexpr const auto& Vectors(Number<1>) const { return data_.d1x2_; }
    __host__ __device__ constexpr auto& Vectors(Number<1>) { return data_.d1x2_; }

    // Access as an array of sub-vectors of width 2.
    __host__ __device__ constexpr const auto& Vectors(Number<2>) const { return data_.d2x1_; }
    __host__ __device__ constexpr auto& Vectors(Number<2>) { return data_.d2x1_; }
};
// Specialization of vector_type for four int8_t elements. The four values are
// stored in a single int32_t, which is also the exposed vector type.
template <>
struct vector_type<int8_t, 4>
{
    using d1_t = int8_t;
    using d2_t = int16_t;
    using d4_t = int32_t;

    using type = d4_t;

    // d4_ stays the first member: both constructors brace-initialize the
    // union, which initializes its first member.
    union
    {
        d4_t d4_;
        StaticallyIndexedArray<d1_t, 4> d1x4_;
        StaticallyIndexedArray<d2_t, 2> d2x2_;
        StaticallyIndexedArray<d4_t, 1> d4x1_;
    } data_;

    // Zero-initializes the packed value.
    __host__ __device__ constexpr vector_type() : data_{type{0}} {}

    // Wraps an already packed value.
    __host__ __device__ constexpr vector_type(type v) : data_{v} {}

    // Number of int8_t elements held.
    __host__ __device__ static constexpr index_t Size() { return 4; }

    // Whole-vector access (the packed int32_t).
    __host__ __device__ constexpr const auto& Vector() const { return data_.d4_; }
    __host__ __device__ constexpr auto& Vector() { return data_.d4_; }

    // Element-wise access (array of four int8_t).
    __host__ __device__ constexpr const auto& Scalars() const { return data_.d1x4_; }
    __host__ __device__ constexpr auto& Scalars() { return data_.d1x4_; }

    // Access as an array of sub-vectors of width 1.
    __host__ __device__ constexpr const auto& Vectors(Number<1>) const { return data_.d1x4_; }
    __host__ __device__ constexpr auto& Vectors(Number<1>) { return data_.d1x4_; }

    // Access as an array of sub-vectors of width 2.
    __host__ __device__ constexpr const auto& Vectors(Number<2>) const { return data_.d2x2_; }
    __host__ __device__ constexpr auto& Vectors(Number<2>) { return data_.d2x2_; }

    // Access as an array of sub-vectors of width 4.
    __host__ __device__ constexpr const auto& Vectors(Number<4>) const { return data_.d4x1_; }
    __host__ __device__ constexpr auto& Vectors(Number<4>) { return data_.d4x1_; }
};
// Short names for vector_type<T, N>::type, grouped by element type.
// Each alias is whatever `type` the matching vector_type specialization exposes.

// fp32
using float2_t = typename vector_type<float, 2>::type;
using float4_t = typename vector_type<float, 4>::type;
using float8_t = typename vector_type<float, 8>::type;
// fp16 (half_t is the compiler's _Float16)
using half_t = _Float16;
using half2_t = typename vector_type<half_t, 2>::type;
using half4_t = typename vector_type<half_t, 4>::type;
using half8_t = typename vector_type<half_t, 8>::type;
// bfp16 (elements are declared as ushort)
using ushort2_t = typename vector_type<ushort, 2>::type;
using ushort4_t = typename vector_type<ushort, 4>::type;
using ushort8_t = typename vector_type<ushort, 8>::type;
// i32
using int32x2_t = typename vector_type<int32_t, 2>::type;
using int32x4_t = typename vector_type<int32_t, 4>::type;
using int32x8_t = typename vector_type<int32_t, 8>::type;
// i8
// Hack for int8x4_t: the compiler has no native int8x4_t vector type, so the
// vector_type<int8_t, 4> specialization exposes int32_t as its `type`.
// int8x4_t is therefore an alias of int32_t holding four packed int8 values.
using int8x4_t = typename vector_type<int8_t, 4>::type;
// data type conversion // data type conversion
template <typename T> template <typename T>
struct type_convert struct type_convert
...@@ -356,113 +306,37 @@ struct inner_product_with_conversion ...@@ -356,113 +306,37 @@ struct inner_product_with_conversion
{ {
static constexpr auto convert = type_convert<T>(); static constexpr auto convert = type_convert<T>();
__device__ T operator()(float4_t a, float4_t b) const template <typename X, index_t N>
{ __device__ T operator()(typename vector_type<X, N>::type a,
const float* p_a_float = reinterpret_cast<const float*>(&a); typename vector_type<X, N>::type b) const
const float* p_b_float = reinterpret_cast<const float*>(&b);
T acc = 0;
for(index_t v = 0; v < 4; ++v)
{
acc += convert(p_a_float[v]) * convert(p_b_float[v]);
}
return acc;
}
__device__ T operator()(float2_t a, float2_t b) const
{ {
const float* p_a_float = reinterpret_cast<const float*>(&a); const vector_type<X, N> a_vector{a};
const float* p_b_float = reinterpret_cast<const float*>(&b); const vector_type<X, N> b_vector{b};
T acc = 0; T acc = 0;
for(index_t v = 0; v < 2; ++v)
{
acc += convert(p_a_float[v]) * convert(p_b_float[v]);
}
return acc;
}
__device__ T operator()(float a, float b) const { return convert(a) * convert(b); }
__device__ T operator()(half2_t a, half2_t b) const static_for<0, N, 1>{}([&](auto i) {
{ acc += convert(a_vector.Scalars()[i]) * convert(b_vector.Scalars()[i]);
const half_t* p_a_half = reinterpret_cast<const half_t*>(&a); });
const half_t* p_b_half = reinterpret_cast<const half_t*>(&b);
T acc = 0;
for(index_t v = 0; v < 2; ++v)
{
acc += convert(p_a_half[v]) * convert(p_b_half[v]);
}
return acc; return acc;
} }
__device__ T operator()(half4_t a, half4_t b) const __device__ T operator()(float_t a, float_t b) const { return convert(a) * convert(b); }
{
const half_t* p_a_half = reinterpret_cast<const half_t*>(&a);
const half_t* p_b_half = reinterpret_cast<const half_t*>(&b);
T acc = 0;
for(index_t v = 0; v < 4; ++v)
{
acc += convert(p_a_half[v]) * convert(p_b_half[v]);
}
return acc;
}
__device__ T operator()(half8_t a, half8_t b) const // hack for int8x4_t, because compiler does not have native support for int8x4_t
// int8x4_t is defined as int32_t
__device__ T operator()(int8x4_t a, int8x4_t b) const
{ {
const half_t* p_a_half = reinterpret_cast<const half_t*>(&a); const vector_type<int8_t, 4> a_vector{a};
const half_t* p_b_half = reinterpret_cast<const half_t*>(&b); const vector_type<int8_t, 4> b_vector{b};
T acc = 0; T acc = 0;
for(index_t v = 0; v < 8; ++v)
{
acc += convert(p_a_half[v]) * convert(p_b_half[v]);
}
return acc;
}
__device__ T operator()(ushort2_t a, ushort2_t b) const
{
const ushort* p_a_bfloat16 = reinterpret_cast<const ushort*>(&a);
const ushort* p_b_bfloat16 = reinterpret_cast<const ushort*>(&b);
T acc = 0;
for(index_t v = 0; v < 2; ++v)
{
acc += convert(p_a_bfloat16[v]) * convert(p_b_bfloat16[v]);
}
return acc;
}
__device__ T operator()(ushort4_t a, ushort4_t b) const static_for<0, 4, 1>{}([&](auto i) {
{ acc += convert(a_vector.Scalars()[i]) * convert(b_vector.Scalars()[i]);
const ushort* p_a_bfloat16 = reinterpret_cast<const ushort*>(&a); });
const ushort* p_b_bfloat16 = reinterpret_cast<const ushort*>(&b);
T acc = 0;
for(index_t v = 0; v < 4; ++v)
{
acc += convert(p_a_bfloat16[v]) * convert(p_b_bfloat16[v]);
}
return acc;
}
__device__ T operator()(ushort8_t a, ushort8_t b) const
{
const ushort* p_a_bfloat16 = reinterpret_cast<const ushort*>(&a);
const ushort* p_b_bfloat16 = reinterpret_cast<const ushort*>(&b);
T acc = 0;
for(index_t v = 0; v < 8; ++v)
{
acc += convert(p_a_bfloat16[v]) * convert(p_b_bfloat16[v]);
}
return acc; return acc;
} }
}; };
......
...@@ -39,7 +39,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc ...@@ -39,7 +39,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc
wei_kcyx_device_buf.ToDevice(wei_kcyx.mData.data()); wei_kcyx_device_buf.ToDevice(wei_kcyx.mData.data());
out_nkhw_device_buf.ToDevice(out_nkhw.mData.data()); out_nkhw_device_buf.ToDevice(out_nkhw.mData.data());
#if 0 #if 1
// run-time variables // run-time variables
const auto in_n_c_hi_wi_desc = const auto in_n_c_hi_wi_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(to_multi_index(InDesc::GetLengths())); make_dynamic_naive_tensor_descriptor_packed_v2(to_multi_index(InDesc::GetLengths()));
...@@ -368,6 +368,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc ...@@ -368,6 +368,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc
DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_1x1 DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_1x1
#endif #endif
<BlockSize, <BlockSize,
TDevice,
TDevice, TDevice,
TDevice, TDevice,
GemmMPerBlock, GemmMPerBlock,
......
...@@ -3,20 +3,25 @@ ...@@ -3,20 +3,25 @@
#include "host_tensor.hpp" #include "host_tensor.hpp"
#include "driver_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.hpp" #include "driver_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.hpp"
template <class T, template <class TInWei,
ck::index_t InWeiVectorSize,
class TAcc,
class TOut,
class InDesc, class InDesc,
class WeiDesc, class WeiDesc,
class OutDesc, class OutDesc,
class ConvStrides, class ConvStrides,
class ConvDilations, class ConvDilations,
class InLeftPads, class InLeftPads,
class InRightPads> class InRightPads,
void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(InDesc, class T>
const Tensor<T>& in_nchw, void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(
InDesc,
const Tensor<T>& in_n_c_hi_wi,
WeiDesc, WeiDesc,
const Tensor<T>& wei_kcyx, const Tensor<T>& wei_k_c_y_x,
OutDesc, OutDesc,
Tensor<T>& out_nkhw, Tensor<T>& out_n_k_ho_wo,
ConvStrides, ConvStrides,
ConvDilations, ConvDilations,
InLeftPads, InLeftPads,
...@@ -28,8 +33,6 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(InDesc ...@@ -28,8 +33,6 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(InDesc
using namespace ck; using namespace ck;
using TDevice = typename conditional<is_same<half_float::half, T>::value, half_t, T>::type;
constexpr auto I0 = Number<0>{}; constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{}; constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{}; constexpr auto I2 = Number<2>{};
...@@ -48,12 +51,15 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(InDesc ...@@ -48,12 +51,15 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(InDesc
constexpr auto Y = WeiDesc::GetLengths()[I2]; constexpr auto Y = WeiDesc::GetLengths()[I2];
constexpr auto X = WeiDesc::GetLengths()[I3]; constexpr auto X = WeiDesc::GetLengths()[I3];
constexpr auto C0 = C / Number<InWeiVectorSize>{};
constexpr auto C1 = Number<InWeiVectorSize>{};
#if 0 #if 0
// run-time variables // run-time variables
constexpr auto in_n_hi_wi_c_desc = constexpr auto in_n_hi_wi_c0_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(make_multi_index(N, Hi, Wi, C)); make_dynamic_naive_tensor_descriptor_packed_v2(make_multi_index(N, Hi, Wi, C0));
constexpr auto wei_k_y_x_c_desc = constexpr auto wei_k_y_x_c0_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(make_multi_index(K, Y, X, C)); make_dynamic_naive_tensor_descriptor_packed_v2(make_multi_index(K, Y, X, C0));
constexpr auto out_n_ho_wo_k_desc = constexpr auto out_n_ho_wo_k_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(make_multi_index(N, Ho, Wo, K)); make_dynamic_naive_tensor_descriptor_packed_v2(make_multi_index(N, Ho, Wo, K));
...@@ -63,10 +69,10 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(InDesc ...@@ -63,10 +69,10 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(InDesc
const auto in_right_pads = to_multi_index(InRightPads{}); const auto in_right_pads = to_multi_index(InRightPads{});
#else #else
// compile-time variables // compile-time variables
constexpr auto in_n_hi_wi_c_desc = constexpr auto in_n_hi_wi_c0_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, Hi, Wi, C)); make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, Hi, Wi, C0));
constexpr auto wei_k_y_x_c_desc = constexpr auto wei_k_y_x_c0_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, Y, X, C)); make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, Y, X, C0));
constexpr auto out_n_ho_wo_k_desc = constexpr auto out_n_ho_wo_k_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, Ho, Wo, K)); make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, Ho, Wo, K));
...@@ -76,38 +82,36 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(InDesc ...@@ -76,38 +82,36 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(InDesc
const auto in_right_pads = sequence_to_tuple_of_number(InRightPads{}); const auto in_right_pads = sequence_to_tuple_of_number(InRightPads{});
#endif #endif
Tensor<float> in_nhwc( Tensor<TInWei> in_n_hi_wi_c(
make_HostTensorDescriptor(make_native_tensor_descriptor_packed(Sequence<N, Hi, Wi, C>{}))); make_HostTensorDescriptor(make_native_tensor_descriptor_packed(Sequence<N, Hi, Wi, C>{})));
Tensor<float> wei_kyxc( Tensor<TInWei> wei_k_y_x_c(
make_HostTensorDescriptor(make_native_tensor_descriptor_packed(Sequence<K, Y, X, C>{}))); make_HostTensorDescriptor(make_native_tensor_descriptor_packed(Sequence<K, Y, X, C>{})));
Tensor<float> out_nhwk( Tensor<TOut> out_n_ho_wo_k(
make_HostTensorDescriptor(make_native_tensor_descriptor_packed(Sequence<N, Ho, Wo, K>{}))); make_HostTensorDescriptor(make_native_tensor_descriptor_packed(Sequence<N, Ho, Wo, K>{})));
auto f_nchw2nhwc = [&](auto n, auto hi, auto wi, auto c) { auto f_nchw2nhwc = [&](auto n, auto hi, auto wi, auto c) {
in_nhwc(n, hi, wi, c) = in_nchw(n, c, hi, wi); in_n_hi_wi_c(n, hi, wi, c) = in_n_c_hi_wi(n, c, hi, wi);
}; };
auto f_kcyx2kyxc = [&](auto k, auto y, auto x, auto c) { auto f_kcyx2kyxc = [&](auto k, auto y, auto x, auto c) {
wei_kyxc(k, y, x, c) = wei_kcyx(k, c, y, x); wei_k_y_x_c(k, y, x, c) = wei_k_c_y_x(k, c, y, x);
}; };
auto f_nkhw2nhwk = [&](auto n, auto ho, auto wo, auto k) { auto f_nkhw2nhwk = [&](auto n, auto ho, auto wo, auto k) {
out_nhwk(n, ho, wo, k) = out_nkhw(n, k, ho, wo); out_n_ho_wo_k(n, ho, wo, k) = out_n_k_ho_wo(n, k, ho, wo);
}; };
make_ParallelTensorFunctor(f_nchw2nhwc, N, Hi, Wi, C)(std::thread::hardware_concurrency()); make_ParallelTensorFunctor(f_nchw2nhwc, N, Hi, Wi, C)();
make_ParallelTensorFunctor(f_kcyx2kyxc, K, Y, X, C)(std::thread::hardware_concurrency()); make_ParallelTensorFunctor(f_kcyx2kyxc, K, Y, X, C)();
make_ParallelTensorFunctor(f_nkhw2nhwk, N, Ho, Wo, K)(std::thread::hardware_concurrency()); make_ParallelTensorFunctor(f_nkhw2nhwk, N, Ho, Wo, K)();
std::size_t data_sz = sizeof(T);
DeviceMem in_nhwc_device_buf(data_sz * in_nhwc.mDesc.GetElementSpace()); DeviceMem in_n_hi_wi_c_device_buf(sizeof(TInWei) * in_n_hi_wi_c.mDesc.GetElementSpace());
DeviceMem wei_kyxc_device_buf(data_sz * wei_kyxc.mDesc.GetElementSpace()); DeviceMem wei_k_y_x_c_device_buf(sizeof(TInWei) * wei_k_y_x_c.mDesc.GetElementSpace());
DeviceMem out_nhwk_device_buf(data_sz * out_nhwk.mDesc.GetElementSpace()); DeviceMem out_n_ho_wo_k_device_buf(sizeof(TOut) * out_n_ho_wo_k.mDesc.GetElementSpace());
in_nhwc_device_buf.ToDevice(in_nhwc.mData.data()); in_n_hi_wi_c_device_buf.ToDevice(in_n_hi_wi_c.mData.data());
wei_kyxc_device_buf.ToDevice(wei_kyxc.mData.data()); wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data());
out_nhwk_device_buf.ToDevice(out_nhwk.mData.data()); out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data());
#if 1 #if 1
// cdata = 16, BlockSize = 64, 16x64x4 // cdata = 16, BlockSize = 64, 16x64x4
...@@ -378,8 +382,9 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(InDesc ...@@ -378,8 +382,9 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(InDesc
DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_1x1 DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_1x1
#endif #endif
<BlockSize, <BlockSize,
TDevice, typename vector_type<TInWei, InWeiVectorSize>::type,
TDevice, TAcc,
TOut,
GemmMPerBlock, GemmMPerBlock,
GemmNPerBlock, GemmNPerBlock,
GemmKPerBlock, GemmKPerBlock,
...@@ -400,22 +405,26 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(InDesc ...@@ -400,22 +405,26 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(InDesc
GemmBBlockTransferDstScalarPerVector_GemmN, GemmBBlockTransferDstScalarPerVector_GemmN,
GemmCThreadTransferDstScalarPerVector_GemmM1>{}; GemmCThreadTransferDstScalarPerVector_GemmM1>{};
conv_driver.Run(wei_k_y_x_c_desc, conv_driver.Run(wei_k_y_x_c0_desc,
in_n_hi_wi_c_desc, in_n_hi_wi_c0_desc,
out_n_ho_wo_k_desc, out_n_ho_wo_k_desc,
conv_strides, conv_strides,
conv_dilations, conv_dilations,
in_left_pads, in_left_pads,
in_right_pads, in_right_pads,
static_cast<TDevice*>(wei_kyxc_device_buf.GetDeviceBuffer()), static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>(
static_cast<TDevice*>(in_nhwc_device_buf.GetDeviceBuffer()), wei_k_y_x_c_device_buf.GetDeviceBuffer()),
static_cast<TDevice*>(out_nhwk_device_buf.GetDeviceBuffer())); static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>(
in_n_hi_wi_c_device_buf.GetDeviceBuffer()),
static_cast<TOut*>(out_n_ho_wo_k_device_buf.GetDeviceBuffer()));
out_nhwk_device_buf.FromDevice(out_nhwk.mData.data()); #if 1
out_n_ho_wo_k_device_buf.FromDevice(out_n_ho_wo_k.mData.data());
#endif
auto f_nhwk2nkhw = [&](auto n, auto k, auto ho, auto wo) { auto f_nhwk2nkhw = [&](auto n, auto k, auto ho, auto wo) {
out_nkhw(n, k, ho, wo) = out_nhwk(n, ho, wo, k); out_n_k_ho_wo(n, k, ho, wo) = out_n_ho_wo_k(n, ho, wo, k);
}; };
make_ParallelTensorFunctor(f_nhwk2nkhw, N, K, Ho, Wo)(std::thread::hardware_concurrency()); make_ParallelTensorFunctor(f_nhwk2nkhw, N, K, Ho, Wo)();
} }
...@@ -68,16 +68,16 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(InDesc ...@@ -68,16 +68,16 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(InDesc
#endif #endif
// cdata = 16, BlockSize = 64, 16x64x4 // cdata = 16, BlockSize = 64, 16x64x4
constexpr index_t BlockSize = 128; constexpr index_t BlockSize = 64;
constexpr index_t KPerBlock = 16; constexpr index_t KPerBlock = 16;
constexpr index_t HPerBlock = 8; constexpr index_t HPerBlock = 16;
constexpr index_t WPerBlock = 8; constexpr index_t WPerBlock = 16;
constexpr index_t CYXPerBlock = 4; constexpr index_t CYXPerBlock = 4;
constexpr index_t KPerThread = 8; constexpr index_t KPerThread = 16;
constexpr index_t HPerThread = 1; constexpr index_t HPerThread = 2;
constexpr index_t WPerThread = 1; constexpr index_t WPerThread = 2;
constexpr index_t CYXPerThread = 4; constexpr index_t CYXPerThread = 4;
using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<1, 1>; using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<1, 1>;
......
...@@ -158,7 +158,7 @@ struct ParallelTensorFunctor ...@@ -158,7 +158,7 @@ struct ParallelTensorFunctor
return indices; return indices;
} }
void operator()(std::size_t num_thread) const void operator()(std::size_t num_thread = std::thread::hardware_concurrency()) const
{ {
std::size_t work_per_thread = (mN1d + num_thread - 1) / num_thread; std::size_t work_per_thread = (mN1d + num_thread - 1) / num_thread;
......
...@@ -25,7 +25,21 @@ int main(int argc, char* argv[]) ...@@ -25,7 +25,21 @@ int main(int argc, char* argv[])
#if 0 #if 0
constexpr index_t N = 1; constexpr index_t N = 1;
constexpr index_t C = 4; constexpr index_t C = 16;
constexpr index_t HI = 1;
constexpr index_t WI = 64;
constexpr index_t K = 16;
constexpr index_t Y = 3;
constexpr index_t X = 3;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<1, 1>;
using RightPads = Sequence<1, 1>;
#elif 0
constexpr index_t N = 1;
constexpr index_t C = 16;
constexpr index_t HI = 1080; constexpr index_t HI = 1080;
constexpr index_t WI = 1920; constexpr index_t WI = 1920;
constexpr index_t K = 16; constexpr index_t K = 16;
...@@ -39,7 +53,7 @@ int main(int argc, char* argv[]) ...@@ -39,7 +53,7 @@ int main(int argc, char* argv[])
using RightPads = Sequence<0, 0>; using RightPads = Sequence<0, 0>;
#elif 0 #elif 0
constexpr index_t N = 1; constexpr index_t N = 1;
constexpr index_t C = 4; constexpr index_t C = 16;
constexpr index_t HI = 540; constexpr index_t HI = 540;
constexpr index_t WI = 960; constexpr index_t WI = 960;
constexpr index_t K = 16; constexpr index_t K = 16;
...@@ -53,7 +67,7 @@ int main(int argc, char* argv[]) ...@@ -53,7 +67,7 @@ int main(int argc, char* argv[])
using RightPads = Sequence<0, 0>; using RightPads = Sequence<0, 0>;
#elif 0 #elif 0
constexpr index_t N = 1; constexpr index_t N = 1;
constexpr index_t C = 4; constexpr index_t C = 16;
constexpr index_t HI = 270; constexpr index_t HI = 270;
constexpr index_t WI = 480; constexpr index_t WI = 480;
constexpr index_t K = 16; constexpr index_t K = 16;
...@@ -65,20 +79,6 @@ int main(int argc, char* argv[]) ...@@ -65,20 +79,6 @@ int main(int argc, char* argv[])
using LeftPads = Sequence<0, 0>; using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>; using RightPads = Sequence<0, 0>;
#elif 0
constexpr index_t N = 1;
constexpr index_t C = 4;
constexpr index_t HI = 1080;
constexpr index_t WI = 1920;
constexpr index_t K = 16;
constexpr index_t Y = 3;
constexpr index_t X = 3;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<1, 1>;
using RightPads = Sequence<1, 1>;
#elif 1 #elif 1
constexpr index_t N = 1; constexpr index_t N = 1;
constexpr index_t C = 4; constexpr index_t C = 4;
...@@ -95,7 +95,7 @@ int main(int argc, char* argv[]) ...@@ -95,7 +95,7 @@ int main(int argc, char* argv[])
using RightPads = Sequence<1, 1>; using RightPads = Sequence<1, 1>;
#elif 0 #elif 0
constexpr index_t N = 1; constexpr index_t N = 1;
constexpr index_t C = 4; constexpr index_t C = 16;
constexpr index_t HI = 540; constexpr index_t HI = 540;
constexpr index_t WI = 960; constexpr index_t WI = 960;
constexpr index_t K = 16; constexpr index_t K = 16;
...@@ -109,7 +109,7 @@ int main(int argc, char* argv[]) ...@@ -109,7 +109,7 @@ int main(int argc, char* argv[])
using RightPads = Sequence<1, 1>; using RightPads = Sequence<1, 1>;
#elif 0 #elif 0
constexpr index_t N = 1; constexpr index_t N = 1;
constexpr index_t C = 4; constexpr index_t C = 16;
constexpr index_t HI = 270; constexpr index_t HI = 270;
constexpr index_t WI = 480; constexpr index_t WI = 480;
constexpr index_t K = 16; constexpr index_t K = 16;
...@@ -631,12 +631,16 @@ int main(int argc, char* argv[]) ...@@ -631,12 +631,16 @@ int main(int argc, char* argv[])
print_array("ConvStrides", to_multi_index(ConvStrides{})); print_array("ConvStrides", to_multi_index(ConvStrides{}));
print_array("ConvDilations", to_multi_index(ConvDilations{})); print_array("ConvDilations", to_multi_index(ConvDilations{}));
#if 1 #if 0
using in_data_t = float; using in_data_t = float;
constexpr index_t in_vector_size = 1;
using out_data_t = float; using out_data_t = float;
using acc_data_t = float;
#else #else
using in_data_t = half_float::half; using in_data_t = int8_t;
using out_data_t = half_float::half; constexpr index_t in_vector_size = 4;
using acc_data_t = int32_t;
using out_data_t = int8_t;
#endif #endif
Tensor<in_data_t> in_nchw(make_HostTensorDescriptor(in_nchw_desc)); Tensor<in_data_t> in_nchw(make_HostTensorDescriptor(in_nchw_desc));
...@@ -646,14 +650,15 @@ int main(int argc, char* argv[]) ...@@ -646,14 +650,15 @@ int main(int argc, char* argv[])
std::size_t num_thread = std::thread::hardware_concurrency(); std::size_t num_thread = std::thread::hardware_concurrency();
if(argc != 3) if(argc != 4)
{ {
printf("arg1: do_verification, arg2: nrepeat\n"); printf("arg1: do_verification, arg2: do_log, arg3: nrepeat\n");
exit(1); exit(1);
} }
bool do_verification = atoi(argv[1]); bool do_verification = atoi(argv[1]);
index_t nrepeat = atoi(argv[2]); bool do_log = atoi(argv[2]);
index_t nrepeat = atoi(argv[3]);
if(do_verification) if(do_verification)
{ {
...@@ -662,7 +667,7 @@ int main(int argc, char* argv[]) ...@@ -662,7 +667,7 @@ int main(int argc, char* argv[])
wei_kcyx.GenerateTensorValue(GeneratorTensor_1{}, num_thread); wei_kcyx.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
#elif 0 #elif 0
in_nchw.GenerateTensorValue(GeneratorTensor_1{}, num_thread); in_nchw.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
wei_kcyx.GenerateTensorValue(GeneratorTensor_3{}, num_thread); wei_kcyx.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
#elif 0 #elif 0
in_nchw.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); in_nchw.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
wei_kcyx.GenerateTensorValue(GeneratorTensor_1{}, num_thread); wei_kcyx.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
...@@ -751,20 +756,25 @@ int main(int argc, char* argv[]) ...@@ -751,20 +756,25 @@ int main(int argc, char* argv[])
LeftPads{}, LeftPads{},
RightPads{}, RightPads{},
nrepeat); nrepeat);
#elif 1
device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk<in_data_t,
in_vector_size,
acc_data_t,
out_data_t>(
in_nchw_desc,
in_nchw,
wei_kcyx_desc,
wei_kcyx,
out_nkhw_desc,
out_nkhw_device,
ConvStrides{},
ConvDilations{},
LeftPads{},
RightPads{},
nrepeat);
#endif #endif
if(do_verification) if(do_verification)
{
#if 0
if(Y == 3 && X == 3 && ConvStrides{}[0] == 1 && ConvStrides{}[1] == 1 &&
ConvDilations{}[0] == 1 && ConvDilations{}[1] == 1)
{
host_winograd_3x3_convolution(
in_nchw, wei_kcyx, out_nkhw_host, LeftPads{}, RightPads{});
}
else
#endif
{ {
host_direct_convolution(in_nchw, host_direct_convolution(in_nchw,
wei_kcyx, wei_kcyx,
...@@ -773,14 +783,15 @@ int main(int argc, char* argv[]) ...@@ -773,14 +783,15 @@ int main(int argc, char* argv[])
ConvDilations{}, ConvDilations{},
LeftPads{}, LeftPads{},
RightPads{}); RightPads{});
}
check_error(out_nkhw_host, out_nkhw_device); check_error(out_nkhw_host, out_nkhw_device);
#if 0 if(do_log)
// LogRange(std::cout << "in_nchw : ", in_nchw.mData, ",") << std::endl; {
// LogRange(std::cout << "wei_kcyx: ", wei_kcyx.mData, ",") << std::endl; LogRange(std::cout << "in_nchw : ", in_nchw.mData, ",") << std::endl;
LogRange(std::cout << "wei_kcyx: ", wei_kcyx.mData, ",") << std::endl;
LogRange(std::cout << "out_nkhw_host : ", out_nkhw_host.mData, ",") << std::endl; LogRange(std::cout << "out_nkhw_host : ", out_nkhw_host.mData, ",") << std::endl;
LogRange(std::cout << "out_nkhw_device: ", out_nkhw_device.mData, ",") << std::endl; LogRange(std::cout << "out_nkhw_device: ", out_nkhw_device.mData, ",") << std::endl;
#endif }
} }
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment