Commit 0e77b53e authored by Jing Zhang's avatar Jing Zhang
Browse files

split k0 k1 in c_thread_grid

parent 40694062
...@@ -10,7 +10,7 @@ template <index_t BlockSize, ...@@ -10,7 +10,7 @@ template <index_t BlockSize,
typename FloatA, typename FloatA,
typename FloatB, typename FloatB,
typename FloatC, typename FloatC,
typename ABlockDesc_E1_K_E2, typename ABlockDesc_E1_K1_E2,
typename BBlockDesc_E1_N_Ho_Wo_E2, typename BBlockDesc_E1_N_Ho_Wo_E2,
typename CThreadDesc_K_N_Ho_Wo, typename CThreadDesc_K_N_Ho_Wo,
index_t EPerThreadLoop, index_t EPerThreadLoop,
...@@ -27,16 +27,16 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3 ...@@ -27,16 +27,16 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
using BIndex = MultiIndex<3>; using BIndex = MultiIndex<3>;
using CIndex = MultiIndex<4>; using CIndex = MultiIndex<4>;
static constexpr auto E1 = ABlockDesc_E1_K_E2{}.GetLength(I0); static constexpr auto E1 = ABlockDesc_E1_K1_E2{}.GetLength(I0);
static constexpr auto KPerBlock = ABlockDesc_E1_K_E2{}.GetLength(I1); static constexpr auto KPerBlock = ABlockDesc_E1_K1_E2{}.GetLength(I1);
static constexpr auto E2 = ABlockDesc_E1_K_E2{}.GetLength(I2); static constexpr auto E2 = ABlockDesc_E1_K1_E2{}.GetLength(I2);
static constexpr auto HPerBlock = BBlockDesc_E1_N_Ho_Wo_E2{}.GetLength(I2); static constexpr auto HoPerBlock = BBlockDesc_E1_N_Ho_Wo_E2{}.GetLength(I2);
static constexpr auto WPerBlock = BBlockDesc_E1_N_Ho_Wo_E2{}.GetLength(I3); static constexpr auto WoPerBlock = BBlockDesc_E1_N_Ho_Wo_E2{}.GetLength(I3);
static constexpr auto KPerThread = CThreadDesc_K_N_Ho_Wo{}.GetLength(I0); static constexpr auto KPerThread = CThreadDesc_K_N_Ho_Wo{}.GetLength(I0);
static constexpr auto HPerThread = CThreadDesc_K_N_Ho_Wo{}.GetLength(I2); static constexpr auto HoPerThread = CThreadDesc_K_N_Ho_Wo{}.GetLength(I2);
static constexpr auto WPerThread = CThreadDesc_K_N_Ho_Wo{}.GetLength(I3); static constexpr auto WoPerThread = CThreadDesc_K_N_Ho_Wo{}.GetLength(I3);
static constexpr auto a_thread_mtx_ = make_naive_tensor_descriptor_packed( static constexpr auto a_thread_mtx_ = make_naive_tensor_descriptor_packed(
make_tuple(Number<EPerThreadLoop>{}, Number<KPerThreadLoop>{}, Number<E2>{})); make_tuple(Number<EPerThreadLoop>{}, Number<KPerThreadLoop>{}, Number<E2>{}));
...@@ -44,37 +44,37 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3 ...@@ -44,37 +44,37 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
static constexpr auto b_thread_mtx_ = static constexpr auto b_thread_mtx_ =
make_naive_tensor_descriptor_packed(make_tuple(Number<EPerThreadLoop>{}, make_naive_tensor_descriptor_packed(make_tuple(Number<EPerThreadLoop>{},
Number<1>{}, Number<1>{},
Number<HPerThread>{}, Number<HoPerThread>{},
Number<WPerThread>{}, Number<WoPerThread>{},
Number<E2>{})); Number<E2>{}));
static constexpr auto c_thread_mtx_ = make_naive_tensor_descriptor_packed(make_tuple( static constexpr auto c_thread_mtx_ = make_naive_tensor_descriptor_packed(make_tuple(
Number<KPerThreadLoop>{}, Number<1>{}, Number<HPerThread>{}, Number<WPerThread>{})); Number<KPerThreadLoop>{}, Number<1>{}, Number<HoPerThread>{}, Number<WoPerThread>{}));
__device__ BlockwiseGemmDlops_km_kn_m0m1n0n1_v3() __device__ BlockwiseGemmDlops_km_kn_m0m1n0n1_v3()
: c_thread_origin_data_idx_{GetBeginOfCThreadDesc_K_N_Ho_Wo(get_thread_local_1d_id())}, : c_thread_origin_data_idx_{GetBeginOfCThreadDesc_K_N_Ho_Wo(get_thread_local_1d_id())},
a_thread_copy_{make_tuple(0, c_thread_origin_data_idx_[I0] * KPerThread, 0)} a_thread_copy_{make_tuple(0, c_thread_origin_data_idx_[I0] * KPerThread, 0)}
{ {
static_assert(ABlockDesc_E1_K_E2::IsKnownAtCompileTime() && static_assert(ABlockDesc_E1_K1_E2::IsKnownAtCompileTime() &&
BBlockDesc_E1_N_Ho_Wo_E2::IsKnownAtCompileTime() && BBlockDesc_E1_N_Ho_Wo_E2::IsKnownAtCompileTime() &&
CThreadDesc_K_N_Ho_Wo::IsKnownAtCompileTime(), CThreadDesc_K_N_Ho_Wo::IsKnownAtCompileTime(),
"wrong! Desc should be known at compile-time"); "wrong! Desc should be known at compile-time");
static_assert( static_assert(
ABlockDesc_E1_K_E2{}.GetLength(I0) == BBlockDesc_E1_N_Ho_Wo_E2{}.GetLength(I0) && ABlockDesc_E1_K1_E2{}.GetLength(I0) == BBlockDesc_E1_N_Ho_Wo_E2{}.GetLength(I0) &&
ABlockDesc_E1_K_E2{}.GetLength(I2) == BBlockDesc_E1_N_Ho_Wo_E2{}.GetLength(I4), ABlockDesc_E1_K1_E2{}.GetLength(I2) == BBlockDesc_E1_N_Ho_Wo_E2{}.GetLength(I4),
"wrong! E dimension not consistent\n"); "wrong! E dimension not consistent\n");
static_assert(E1 % EPerThreadLoop == 0, ""); static_assert(E1 % EPerThreadLoop == 0, "");
static_assert(KPerThread % KPerThreadLoop == 0, ""); static_assert(KPerThread % KPerThreadLoop == 0, "");
static_assert(KPerBlock % KPerThread == 0 && HPerBlock % HPerThread == 0 && static_assert(KPerBlock % KPerThread == 0 && HoPerBlock % HoPerThread == 0 &&
WPerBlock % WPerThread == 0, WoPerBlock % WoPerThread == 0,
"wrong! Cannot evenly divide work among\n"); "wrong! Cannot evenly divide work among\n");
constexpr auto KThreadCluster = KPerBlock / KPerThread; constexpr auto KThreadCluster = KPerBlock / KPerThread;
constexpr auto HThreadCluster = HPerBlock / HPerThread; constexpr auto HThreadCluster = HoPerBlock / HoPerThread;
constexpr auto WThreadCluster = WPerBlock / WPerThread; constexpr auto WThreadCluster = WoPerBlock / WoPerThread;
static_assert(BlockSize == KThreadCluster * HThreadCluster * WThreadCluster, static_assert(BlockSize == KThreadCluster * HThreadCluster * WThreadCluster,
"wrong! wrong blocksize\n"); "wrong! wrong blocksize\n");
...@@ -82,15 +82,15 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3 ...@@ -82,15 +82,15 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
__device__ static constexpr auto GetCThreadDesc_K_N_Ho_WoLengths() __device__ static constexpr auto GetCThreadDesc_K_N_Ho_WoLengths()
{ {
return Sequence<KPerThread, I1, HPerThread, WPerThread>{}; return Sequence<KPerThread, I1, HoPerThread, WoPerThread>{};
} }
__device__ static CIndex GetBeginOfCThreadDesc_K_N_Ho_Wo(index_t thread_id) __device__ static CIndex GetBeginOfCThreadDesc_K_N_Ho_Wo(index_t thread_id)
{ {
constexpr auto K0 = KPerBlock / KPerThread; constexpr auto K0 = KPerBlock / KPerThread;
constexpr auto N0 = I1; constexpr auto N0 = I1;
constexpr auto H0 = HPerBlock / HPerThread; constexpr auto H0 = HoPerBlock / HoPerThread;
constexpr auto W0 = WPerBlock / WPerThread; constexpr auto W0 = WoPerBlock / WoPerThread;
constexpr auto c_threadid_to_k_n_h_w_thread_cluster_adaptor = constexpr auto c_threadid_to_k_n_h_w_thread_cluster_adaptor =
make_single_stage_tensor_adaptor( make_single_stage_tensor_adaptor(
...@@ -116,7 +116,7 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3 ...@@ -116,7 +116,7 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
is_same<remove_cvref_t<typename CThreadBuffer::type>, remove_cvref_t<FloatC>>::value && is_same<remove_cvref_t<typename CThreadBuffer::type>, remove_cvref_t<FloatC>>::value &&
"wrong! inconsistent type"); "wrong! inconsistent type");
constexpr auto a_block_mtx = ABlockDesc_E1_K_E2{}; constexpr auto a_block_mtx = ABlockDesc_E1_K1_E2{};
// thread A buffer for GEMM // thread A buffer for GEMM
StaticBuffer<AddressSpaceEnum_t::Vgpr, FloatA, a_thread_mtx_.GetElementSpaceSize(), true> StaticBuffer<AddressSpaceEnum_t::Vgpr, FloatA, a_thread_mtx_.GetElementSpaceSize(), true>
...@@ -151,14 +151,14 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3 ...@@ -151,14 +151,14 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
template <typename ABlockSliceMoveStepIdx> template <typename ABlockSliceMoveStepIdx>
__device__ void MoveABlockSliceWindow(const ABlockSliceMoveStepIdx& a_block_slice_move_step_idx) __device__ void MoveABlockSliceWindow(const ABlockSliceMoveStepIdx& a_block_slice_move_step_idx)
{ {
a_thread_copy_.MoveSrcSliceWindow(ABlockDesc_E1_K_E2{}, a_block_slice_move_step_idx); a_thread_copy_.MoveSrcSliceWindow(ABlockDesc_E1_K1_E2{}, a_block_slice_move_step_idx);
} }
private: private:
using AThreadCopy = using AThreadCopy =
ThreadwiseTensorSliceTransfer_v4<FloatA, ThreadwiseTensorSliceTransfer_v4<FloatA,
FloatA, FloatA,
ABlockDesc_E1_K_E2, ABlockDesc_E1_K1_E2,
decltype(a_thread_mtx_), decltype(a_thread_mtx_),
Sequence<EPerThreadLoop, KPerThreadLoop, E2>, Sequence<EPerThreadLoop, KPerThreadLoop, E2>,
Sequence<0, 1, 2>, Sequence<0, 1, 2>,
......
...@@ -99,15 +99,15 @@ void device_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw( ...@@ -99,15 +99,15 @@ void device_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw(
constexpr index_t E1 = C0 * 9; constexpr index_t E1 = C0 * 9;
constexpr index_t E2 = 1; constexpr index_t E2 = 1;
constexpr index_t EPerBlock = C0; constexpr index_t E1PerBlock = C0;
constexpr index_t KPerThread = 16; constexpr index_t KPerThread = 16;
constexpr index_t HoPerThread = 2; constexpr index_t HoPerThread = 2;
constexpr index_t WoPerThread = 2; constexpr index_t WoPerThread = 2;
constexpr index_t EPerThread = 1; constexpr index_t EPerThread = 1;
using ABlockTransferThreadSliceLengths_E0_E1_K_E2 = Sequence<1, 9, 1, E2>; using ABlockTransferThreadSliceLengths_E0_E1_K0_K1_E2 = Sequence<1, 9, 1, E2>;
using ABlockTransferThreadClusterLengths_E0_E1_K_E2 = Sequence<1, EPerBlock, KPerBlock, 1>; using ABlockTransferThreadClusterLengths_E0_E1_K0_K1_E2 = Sequence<1, E1PerBlock, KPerBlock, 1>;
constexpr index_t ABlockTransferSrcScalarPerVector_E2 = E2; constexpr index_t ABlockTransferSrcScalarPerVector_E2 = E2;
constexpr index_t ABlockTransferDstScalarPerVector_E2 = E2; constexpr index_t ABlockTransferDstScalarPerVector_E2 = E2;
...@@ -122,17 +122,18 @@ void device_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw( ...@@ -122,17 +122,18 @@ void device_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw(
constexpr index_t HoPerBlock = 8; constexpr index_t HoPerBlock = 8;
constexpr index_t WoPerBlock = 32; constexpr index_t WoPerBlock = 32;
constexpr index_t E1 = 2 * 9; constexpr index_t E1 = 2 * 9;
constexpr index_t E2 = 1; constexpr index_t E2 = 1;
constexpr index_t EPerBlock = 2; constexpr index_t E1PerBlock = 2;
constexpr index_t KPerThread = 16; constexpr index_t KPerThread = 16;
constexpr index_t HoPerThread = 2; constexpr index_t HoPerThread = 2;
constexpr index_t WoPerThread = 2; constexpr index_t WoPerThread = 2;
constexpr index_t EPerThread = 1; constexpr index_t EPerThread = 1;
using ABlockTransferThreadSliceLengths_E0_E1_K_E2 = Sequence<1, 9, 1, E2>; using ABlockTransferThreadSliceLengths_E0_E1_K0_K1_E2 = Sequence<1, 9, 1, 1, E2>;
using ABlockTransferThreadClusterLengths_E0_E1_K_E2 = Sequence<1, EPerBlock, KPerBlock, 1>; using ABlockTransferThreadClusterLengths_E0_E1_K0_K1_E2 =
Sequence<1, E1PerBlock, 1, KPerBlock, 1>;
constexpr index_t ABlockTransferSrcScalarPerVector_E2 = E2; constexpr index_t ABlockTransferSrcScalarPerVector_E2 = E2;
constexpr index_t ABlockTransferDstScalarPerVector_E2 = E2; constexpr index_t ABlockTransferDstScalarPerVector_E2 = E2;
...@@ -153,13 +154,13 @@ void device_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw( ...@@ -153,13 +154,13 @@ void device_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw(
KPerBlock, KPerBlock,
HoPerBlock, HoPerBlock,
WoPerBlock, WoPerBlock,
EPerBlock, E1PerBlock,
KPerThread, KPerThread,
HoPerThread, HoPerThread,
WoPerThread, WoPerThread,
EPerThread, EPerThread,
ABlockTransferThreadSliceLengths_E0_E1_K_E2, ABlockTransferThreadSliceLengths_E0_E1_K0_K1_E2,
ABlockTransferThreadClusterLengths_E0_E1_K_E2, ABlockTransferThreadClusterLengths_E0_E1_K0_K1_E2,
ABlockTransferSrcScalarPerVector_E2, ABlockTransferSrcScalarPerVector_E2,
ABlockTransferDstScalarPerVector_E2, ABlockTransferDstScalarPerVector_E2,
BThreadTransferSrcScalarPerVector_E2, BThreadTransferSrcScalarPerVector_E2,
......
...@@ -20,8 +20,8 @@ template <ck::index_t BlockSize, ...@@ -20,8 +20,8 @@ template <ck::index_t BlockSize,
ck::index_t HoPerThread, ck::index_t HoPerThread,
ck::index_t WoPerThread, ck::index_t WoPerThread,
ck::index_t EPerThread, ck::index_t EPerThread,
typename ABlockTransferThreadSliceLengths_E0_E1_K_E2, typename ABlockTransferThreadSliceLengths_E0_E1_K0_K1_E2,
typename ABlockTransferThreadClusterLengths_E0_E1_K_E2, typename ABlockTransferThreadClusterLengths_E0_E1_K0_K1_E2,
ck::index_t ABlockTransferSrcScalarPerVector_E2, ck::index_t ABlockTransferSrcScalarPerVector_E2,
ck::index_t ABlockTransferDstScalarPerVector_E2, ck::index_t ABlockTransferDstScalarPerVector_E2,
ck::index_t BThreadTransferSrcScalarPerVector_E2, ck::index_t BThreadTransferSrcScalarPerVector_E2,
...@@ -77,11 +77,11 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp ...@@ -77,11 +77,11 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
const auto ConvDilationH = conv_dilations[I0]; const auto ConvDilationH = conv_dilations[I0];
const auto ConvDilationW = conv_dilations[I1]; const auto ConvDilationW = conv_dilations[I1];
// const auto Hop = Number<(Ho + HoPerBlock - 1) / HoPerBlock * HoPerBlock>{}; const auto Hop = Number<(Ho + HoPerBlock - 1) / HoPerBlock * HoPerBlock>{};
// const auto Wop = Number<(Wo + WoPerBlock - 1) / WoPerBlock * WoPerBlock>{}; const auto Wop = Number<(Wo + WoPerBlock - 1) / WoPerBlock * WoPerBlock>{};
const auto Hop = (Ho + HoPerBlock - 1) / HoPerBlock * HoPerBlock; // const auto Hop = (Ho + HoPerBlock - 1) / HoPerBlock * HoPerBlock;
const auto Wop = (Wo + WoPerBlock - 1) / WoPerBlock * WoPerBlock; // const auto Wop = (Wo + WoPerBlock - 1) / WoPerBlock * WoPerBlock;
const auto OutRightPadH = Hop - Ho; const auto OutRightPadH = Hop - Ho;
const auto OutRightPadW = Wop - Wo; const auto OutRightPadW = Wop - Wo;
...@@ -92,11 +92,6 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp ...@@ -92,11 +92,6 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
const auto InRightPadH = in_right_pads[I0] + OutRightPadH * ConvStrideH; const auto InRightPadH = in_right_pads[I0] + OutRightPadH * ConvStrideH;
const auto InRightPadW = in_right_pads[I1] + OutRightPadW * ConvStrideW; const auto InRightPadW = in_right_pads[I1] + OutRightPadW * ConvStrideW;
std::cerr << "OutRightPadH = " << OutRightPadH << " OutRightPadW = " << OutRightPadW
<< std::endl;
std::cerr << "InRightPadH = " << InRightPadH << " InRightPadW = " << InRightPadW
<< std::endl;
const auto E = C0 * Y * X; const auto E = C0 * Y * X;
constexpr auto E1 = Number<E1_>{}; constexpr auto E1 = Number<E1_>{};
...@@ -188,17 +183,19 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp ...@@ -188,17 +183,19 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
// hack to control index calculation when iterating over a_k_m_global tensor // hack to control index calculation when iterating over a_k_m_global tensor
constexpr auto a_e0_e1_k_e2_global_step_hacks = constexpr auto a_e0_e1_k_e2_global_step_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0>{}, make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0>{}), Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}),
Sequence<0, 0, 0, 0, 0, 0, 0>{}, make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0>{})); Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}));
constexpr auto a_e0_e1_k_e2_global_move_slice_window_step_hack = constexpr auto a_e0_e1_k_e2_global_move_slice_window_step_hack =
Sequence<0, 0, 0, 0, 0, 0, 0>{}; Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{};
constexpr auto b_e0_e1_n_ho_wo_e2_global_step_hacks = make_tuple( constexpr auto b_e0_e1_n_ho_wo_e2_global_step_hacks = make_tuple(
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
...@@ -220,18 +217,20 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp ...@@ -220,18 +217,20 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
// hack to control index calculation when iterating over c_m0_m1_n0_n1_global tensor // hack to control index calculation when iterating over c_m0_m1_n0_n1_global tensor
// hack for NKHW format // hack for NKHW format
constexpr auto c_k_n_ho_wo_global_tensor_step_hacks = constexpr auto c_k_n_ho_wo_global_tensor_step_hacks =
make_tuple(make_tuple(Sequence<0, 1, 0, 0, 0>{}, make_tuple(make_tuple(Sequence<0, 1, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 1, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{}), Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
make_tuple(Sequence<0, 2, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}),
Sequence<0, 0, 0, 0, 0>{}, make_tuple(Sequence<0, 2, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 2, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{})); Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
// static_assert(a_e0_e1_k_e2_grid_desc.IsKnownAtCompileTime(), ""); Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}));
// static_assert(b_e0_e1_n_ho_wo_e2_grid_desc.IsKnownAtCompileTime(), "");
// static_assert(c_k_n_hop_wop_grid_desc.IsKnownAtCompileTime(), ""); static_assert(a_e0_e1_k_e2_grid_desc.IsKnownAtCompileTime(), "");
static_assert(b_e0_e1_n_ho_wo_e2_grid_desc.IsKnownAtCompileTime(), "");
static_assert(c_k_n_hop_wop_grid_desc.IsKnownAtCompileTime(), "");
// GEMM // GEMM
using GridwiseGemm = GridwiseGemmDlops_km_kn_mn_v3< using GridwiseGemm = GridwiseGemmDlops_km_kn_mn_v3<
...@@ -253,11 +252,11 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp ...@@ -253,11 +252,11 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
HoPerThread, HoPerThread,
WoPerThread, WoPerThread,
EPerThread, EPerThread,
ABlockTransferThreadSliceLengths_E0_E1_K_E2, ABlockTransferThreadSliceLengths_E0_E1_K0_K1_E2,
ABlockTransferThreadClusterLengths_E0_E1_K_E2, ABlockTransferThreadClusterLengths_E0_E1_K0_K1_E2,
Sequence<2, 0, 1, 3>, Sequence<2, 3, 0, 1, 4>,
Sequence<0, 1, 2, 3>, Sequence<0, 1, 2, 3, 4>,
3, 4,
ABlockTransferSrcScalarPerVector_E2, ABlockTransferSrcScalarPerVector_E2,
ABlockTransferDstScalarPerVector_E2, ABlockTransferDstScalarPerVector_E2,
false, // don't move back src coordinate after threadwise copy false, // don't move back src coordinate after threadwise copy
...@@ -266,8 +265,8 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp ...@@ -266,8 +265,8 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
BThreadTransferSrcScalarPerVector_E2, BThreadTransferSrcScalarPerVector_E2,
false, // don't move back src coordinate after threadwise copy, which will be fused with false, // don't move back src coordinate after threadwise copy, which will be fused with
// MoveSrcSliceWindow() to save addr computation // MoveSrcSliceWindow() to save addr computation
Sequence<0, 1, 2, 3>, Sequence<0, 1, 2, 3, 4>,
0, 1,
CThreadTransferDstScalarPerVector_K, CThreadTransferDstScalarPerVector_K,
decltype(a_e0_e1_k_e2_global_step_hacks), decltype(a_e0_e1_k_e2_global_step_hacks),
decltype(b_e0_e1_n_ho_wo_e2_global_step_hacks), decltype(b_e0_e1_n_ho_wo_e2_global_step_hacks),
...@@ -276,9 +275,14 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp ...@@ -276,9 +275,14 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
decltype(b_e0_e1_n_ho_wo_e2_global_move_slice_window_step_hack), decltype(b_e0_e1_n_ho_wo_e2_global_move_slice_window_step_hack),
activ_type>; activ_type>;
using AGridDesc_E0_E1_K_E2 = decltype(a_e0_e1_k_e2_grid_desc); const auto a_e0_e1_k0_k1_e2_grid_desc =
GridwiseGemm::MakeAE0E1K0K1E2GridDescriptor(a_e0_e1_k_e2_grid_desc);
const auto c_k0_k1_n_hop_wop_grid_desc =
GridwiseGemm::MakeCK0K1NHoWoGridDescriptor(c_k_n_hop_wop_grid_desc);
using AGridDesc_E0_E1_K0_K1_E2 = decltype(a_e0_e1_k0_k1_e2_grid_desc);
using BGridDesc_E0_E1_N_Ho_Wo_E2 = decltype(b_e0_e1_n_ho_wo_e2_grid_desc); using BGridDesc_E0_E1_N_Ho_Wo_E2 = decltype(b_e0_e1_n_ho_wo_e2_grid_desc);
using CGridDesc_K_N_Ho_Wo = decltype(c_k_n_hop_wop_grid_desc); using CGridDesc_K0_K1_N_Ho_Wo = decltype(c_k0_k1_n_hop_wop_grid_desc);
const auto grid_size = (K / KPerBlock) * (Hop / HoPerBlock) * (Wop / WoPerBlock) * N; const auto grid_size = (K / KPerBlock) * (Hop / HoPerBlock) * (Wop / WoPerBlock) * N;
...@@ -299,9 +303,9 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp ...@@ -299,9 +303,9 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
kernel_gemm_dlops_v2<GridwiseGemm, kernel_gemm_dlops_v2<GridwiseGemm,
FloatAB, FloatAB,
FloatC, FloatC,
remove_reference_t<AGridDesc_E0_E1_K_E2>, remove_reference_t<AGridDesc_E0_E1_K0_K1_E2>,
remove_reference_t<BGridDesc_E0_E1_N_Ho_Wo_E2>, remove_reference_t<BGridDesc_E0_E1_N_Ho_Wo_E2>,
remove_reference_t<CGridDesc_K_N_Ho_Wo>, remove_reference_t<CGridDesc_K0_K1_N_Ho_Wo>,
remove_reference_t<CBlockIdToBlockClusterAdaptor_K_N_Ho_Wo>, remove_reference_t<CBlockIdToBlockClusterAdaptor_K_N_Ho_Wo>,
has_main_e0_block_loop, has_main_e0_block_loop,
has_main_e1_block_loop, has_main_e1_block_loop,
...@@ -315,21 +319,21 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp ...@@ -315,21 +319,21 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
p_a_grid, p_a_grid,
p_b_grid, p_b_grid,
p_c_grid, p_c_grid,
a_e0_e1_k_e2_grid_desc, a_e0_e1_k0_k1_e2_grid_desc,
b_e0_e1_n_ho_wo_e2_grid_desc, b_e0_e1_n_ho_wo_e2_grid_desc,
c_k_n_hop_wop_grid_desc, c_k0_k1_n_hop_wop_grid_desc,
c_blockid_to_k_n_ho_wo_block_cluster_adaptor); c_blockid_to_k_n_ho_wo_block_cluster_adaptor);
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER #elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
DeviceMem a_e0_e1_k_e2_grid_desc_dev_buf(sizeof(AGridDesc_E0_E1_K_E2)); DeviceMem a_e0_e1_k0_k1_e2_grid_desc_dev_buf(sizeof(AGridDesc_E0_E1_K0_K1_E2));
DeviceMem b_e0_e1_n_ho_wo_e2_grid_desc_dev_buf(sizeof(BGridDesc_E0_E1_N_Ho_Wo_E2)); DeviceMem b_e0_e1_n_ho_wo_e2_grid_desc_dev_buf(sizeof(BGridDesc_E0_E1_N_Ho_Wo_E2));
DeviceMem c_k_n_hop_wop_grid_desc_dev_buf(sizeof(CGridDesc_K_N_Ho_Wo)); DeviceMem c_k0_k1_n_hop_wop_grid_desc_dev_buf(sizeof(CGridDesc_K0_K1_N_Ho_Wo));
DeviceMem c_blockid_to_k_n_ho_wo_block_cluster_adaptor_dev_buf( DeviceMem c_blockid_to_k_n_ho_wo_block_cluster_adaptor_dev_buf(
sizeof(CBlockIdToBlockClusterAdaptor_K_N_Ho_Wo)); sizeof(CBlockIdToBlockClusterAdaptor_K_N_Ho_Wo));
a_e0_e1_k_e2_grid_desc_dev_buf.ToDevice(&a_e0_e1_k_e2_grid_desc); a_e0_e1_k0_k1_e2_grid_desc_dev_buf.ToDevice(&a_e0_e1_k0_k1_e2_grid_desc);
b_e0_e1_n_ho_wo_e2_grid_desc_dev_buf.ToDevice(&b_e0_e1_n_ho_wo_e2_grid_desc); b_e0_e1_n_ho_wo_e2_grid_desc_dev_buf.ToDevice(&b_e0_e1_n_ho_wo_e2_grid_desc);
c_k_n_hop_wop_grid_desc_dev_buf.ToDevice(&c_k_n_hop_wop_grid_desc); c_k0_k1_n_hop_wop_grid_desc_dev_buf.ToDevice(&c_k0_k1_n_hop_wop_grid_desc);
c_blockid_to_k_n_ho_wo_block_cluster_adaptor_dev_buf.ToDevice( c_blockid_to_k_n_ho_wo_block_cluster_adaptor_dev_buf.ToDevice(
&c_blockid_to_k_n_ho_wo_block_cluster_adaptor); &c_blockid_to_k_n_ho_wo_block_cluster_adaptor);
...@@ -340,9 +344,9 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp ...@@ -340,9 +344,9 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
kernel_gemm_dlops_v2<GridwiseGemm, kernel_gemm_dlops_v2<GridwiseGemm,
FloatAB, FloatAB,
FloatC, FloatC,
remove_reference_t<AGridDesc_E0_E1_K_E2>, remove_reference_t<AGridDesc_E0_E1_K0_K1_E2>,
remove_reference_t<BGridDesc_E0_E1_N_Ho_Wo_E2>, remove_reference_t<BGridDesc_E0_E1_N_Ho_Wo_E2>,
remove_reference_t<CGridDesc_K_N_Ho_Wo>, remove_reference_t<CGridDesc_K0_K1_N_Ho_Wo>,
remove_reference_t<CBlockIdToBlockClusterAdaptor_K_N_Ho_Wo>, remove_reference_t<CBlockIdToBlockClusterAdaptor_K_N_Ho_Wo>,
true>; true>;
...@@ -356,11 +360,11 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp ...@@ -356,11 +360,11 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
p_b_grid, p_b_grid,
p_c_grid, p_c_grid,
cast_pointer_to_constant_address_space( cast_pointer_to_constant_address_space(
a_e0_e1_k_e2_grid_desc_dev_buf.GetDeviceBuffer()), a_e0_e1_k0_k1_e2_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space( cast_pointer_to_constant_address_space(
b_e0_e1_n_ho_wo_e2_grid_desc_dev_buf.GetDeviceBuffer()), b_e0_e1_n_ho_wo_e2_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space( cast_pointer_to_constant_address_space(
c_k_n_hop_wop_grid_desc_dev_buf.GetDeviceBuffer()), c_k0_k1_n_hop_wop_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space( cast_pointer_to_constant_address_space(
c_blockid_to_k_n_ho_wo_block_cluster_adaptor_dev_buf.GetDeviceBuffer())); c_blockid_to_k_n_ho_wo_block_cluster_adaptor_dev_buf.GetDeviceBuffer()));
} }
...@@ -371,9 +375,9 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp ...@@ -371,9 +375,9 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
kernel_gemm_dlops_v2<GridwiseGemm, kernel_gemm_dlops_v2<GridwiseGemm,
FloatAB, FloatAB,
FloatC, FloatC,
remove_reference_t<AGridDesc_E0_E1_K_E2>, remove_reference_t<AGridDesc_E0_E1_K0_K1_E2>,
remove_reference_t<BGridDesc_E0_E1_N_Ho_Wo_E2>, remove_reference_t<BGridDesc_E0_E1_N_Ho_Wo_E2>,
remove_reference_t<CGridDesc_K_N_Ho_Wo>, remove_reference_t<CGridDesc_K0_K1_N_Ho_Wo>,
remove_reference_t<CBlockIdToBlockClusterAdaptor_K_N_Ho_Wo>, remove_reference_t<CBlockIdToBlockClusterAdaptor_K_N_Ho_Wo>,
false>; false>;
...@@ -387,11 +391,11 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp ...@@ -387,11 +391,11 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
p_b_grid, p_b_grid,
p_c_grid, p_c_grid,
cast_pointer_to_constant_address_space( cast_pointer_to_constant_address_space(
a_e0_e1_k_e2_grid_desc_dev_buf.GetDeviceBuffer()), a_e0_e1_k0_k1_e2_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space( cast_pointer_to_constant_address_space(
b_e0_e1_n_ho_wo_e2_grid_desc_dev_buf.GetDeviceBuffer()), b_e0_e1_n_ho_wo_e2_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space( cast_pointer_to_constant_address_space(
c_k_n_hop_wop_grid_desc_dev_buf.GetDeviceBuffer()), c_k0_k1_n_hop_wop_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space( cast_pointer_to_constant_address_space(
c_blockid_to_k_n_ho_wo_block_cluster_adaptor_dev_buf.GetDeviceBuffer())); c_blockid_to_k_n_ho_wo_block_cluster_adaptor_dev_buf.GetDeviceBuffer()));
} }
......
...@@ -20,7 +20,7 @@ ...@@ -20,7 +20,7 @@
#include "device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp" #include "device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp"
#include "device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp" #include "device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp"
#define USE_DYNAMIC_MODE 1 #define USE_DYNAMIC_MODE 0
#define USE_CONV_FWD_V4R4_NCHW 0 #define USE_CONV_FWD_V4R4_NCHW 0
#define USE_CONV_FWD_V4R4R2_NHWC 0 #define USE_CONV_FWD_V4R4R2_NHWC 0
#define USE_CONV_FWD_V6R1_NCHW 0 #define USE_CONV_FWD_V6R1_NCHW 0
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment