Commit 0e77b53e authored by Jing Zhang's avatar Jing Zhang
Browse files

split k0 k1 in c_thread_grid

parent 40694062
......@@ -10,7 +10,7 @@ template <index_t BlockSize,
typename FloatA,
typename FloatB,
typename FloatC,
typename ABlockDesc_E1_K_E2,
typename ABlockDesc_E1_K1_E2,
typename BBlockDesc_E1_N_Ho_Wo_E2,
typename CThreadDesc_K_N_Ho_Wo,
index_t EPerThreadLoop,
......@@ -27,16 +27,16 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
using BIndex = MultiIndex<3>;
using CIndex = MultiIndex<4>;
static constexpr auto E1 = ABlockDesc_E1_K_E2{}.GetLength(I0);
static constexpr auto KPerBlock = ABlockDesc_E1_K_E2{}.GetLength(I1);
static constexpr auto E2 = ABlockDesc_E1_K_E2{}.GetLength(I2);
static constexpr auto E1 = ABlockDesc_E1_K1_E2{}.GetLength(I0);
static constexpr auto KPerBlock = ABlockDesc_E1_K1_E2{}.GetLength(I1);
static constexpr auto E2 = ABlockDesc_E1_K1_E2{}.GetLength(I2);
static constexpr auto HPerBlock = BBlockDesc_E1_N_Ho_Wo_E2{}.GetLength(I2);
static constexpr auto WPerBlock = BBlockDesc_E1_N_Ho_Wo_E2{}.GetLength(I3);
static constexpr auto HoPerBlock = BBlockDesc_E1_N_Ho_Wo_E2{}.GetLength(I2);
static constexpr auto WoPerBlock = BBlockDesc_E1_N_Ho_Wo_E2{}.GetLength(I3);
static constexpr auto KPerThread = CThreadDesc_K_N_Ho_Wo{}.GetLength(I0);
static constexpr auto HPerThread = CThreadDesc_K_N_Ho_Wo{}.GetLength(I2);
static constexpr auto WPerThread = CThreadDesc_K_N_Ho_Wo{}.GetLength(I3);
static constexpr auto HoPerThread = CThreadDesc_K_N_Ho_Wo{}.GetLength(I2);
static constexpr auto WoPerThread = CThreadDesc_K_N_Ho_Wo{}.GetLength(I3);
static constexpr auto a_thread_mtx_ = make_naive_tensor_descriptor_packed(
make_tuple(Number<EPerThreadLoop>{}, Number<KPerThreadLoop>{}, Number<E2>{}));
......@@ -44,37 +44,37 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
static constexpr auto b_thread_mtx_ =
make_naive_tensor_descriptor_packed(make_tuple(Number<EPerThreadLoop>{},
Number<1>{},
Number<HPerThread>{},
Number<WPerThread>{},
Number<HoPerThread>{},
Number<WoPerThread>{},
Number<E2>{}));
static constexpr auto c_thread_mtx_ = make_naive_tensor_descriptor_packed(make_tuple(
Number<KPerThreadLoop>{}, Number<1>{}, Number<HPerThread>{}, Number<WPerThread>{}));
Number<KPerThreadLoop>{}, Number<1>{}, Number<HoPerThread>{}, Number<WoPerThread>{}));
__device__ BlockwiseGemmDlops_km_kn_m0m1n0n1_v3()
: c_thread_origin_data_idx_{GetBeginOfCThreadDesc_K_N_Ho_Wo(get_thread_local_1d_id())},
a_thread_copy_{make_tuple(0, c_thread_origin_data_idx_[I0] * KPerThread, 0)}
{
static_assert(ABlockDesc_E1_K_E2::IsKnownAtCompileTime() &&
static_assert(ABlockDesc_E1_K1_E2::IsKnownAtCompileTime() &&
BBlockDesc_E1_N_Ho_Wo_E2::IsKnownAtCompileTime() &&
CThreadDesc_K_N_Ho_Wo::IsKnownAtCompileTime(),
"wrong! Desc should be known at compile-time");
static_assert(
ABlockDesc_E1_K_E2{}.GetLength(I0) == BBlockDesc_E1_N_Ho_Wo_E2{}.GetLength(I0) &&
ABlockDesc_E1_K_E2{}.GetLength(I2) == BBlockDesc_E1_N_Ho_Wo_E2{}.GetLength(I4),
ABlockDesc_E1_K1_E2{}.GetLength(I0) == BBlockDesc_E1_N_Ho_Wo_E2{}.GetLength(I0) &&
ABlockDesc_E1_K1_E2{}.GetLength(I2) == BBlockDesc_E1_N_Ho_Wo_E2{}.GetLength(I4),
"wrong! E dimension not consistent\n");
static_assert(E1 % EPerThreadLoop == 0, "");
static_assert(KPerThread % KPerThreadLoop == 0, "");
static_assert(KPerBlock % KPerThread == 0 && HPerBlock % HPerThread == 0 &&
WPerBlock % WPerThread == 0,
static_assert(KPerBlock % KPerThread == 0 && HoPerBlock % HoPerThread == 0 &&
WoPerBlock % WoPerThread == 0,
"wrong! Cannot evenly divide work among\n");
constexpr auto KThreadCluster = KPerBlock / KPerThread;
constexpr auto HThreadCluster = HPerBlock / HPerThread;
constexpr auto WThreadCluster = WPerBlock / WPerThread;
constexpr auto HThreadCluster = HoPerBlock / HoPerThread;
constexpr auto WThreadCluster = WoPerBlock / WoPerThread;
static_assert(BlockSize == KThreadCluster * HThreadCluster * WThreadCluster,
"wrong! wrong blocksize\n");
......@@ -82,15 +82,15 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
__device__ static constexpr auto GetCThreadDesc_K_N_Ho_WoLengths()
{
return Sequence<KPerThread, I1, HPerThread, WPerThread>{};
return Sequence<KPerThread, I1, HoPerThread, WoPerThread>{};
}
__device__ static CIndex GetBeginOfCThreadDesc_K_N_Ho_Wo(index_t thread_id)
{
constexpr auto K0 = KPerBlock / KPerThread;
constexpr auto N0 = I1;
constexpr auto H0 = HPerBlock / HPerThread;
constexpr auto W0 = WPerBlock / WPerThread;
constexpr auto H0 = HoPerBlock / HoPerThread;
constexpr auto W0 = WoPerBlock / WoPerThread;
constexpr auto c_threadid_to_k_n_h_w_thread_cluster_adaptor =
make_single_stage_tensor_adaptor(
......@@ -116,7 +116,7 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
is_same<remove_cvref_t<typename CThreadBuffer::type>, remove_cvref_t<FloatC>>::value &&
"wrong! inconsistent type");
constexpr auto a_block_mtx = ABlockDesc_E1_K_E2{};
constexpr auto a_block_mtx = ABlockDesc_E1_K1_E2{};
// thread A buffer for GEMM
StaticBuffer<AddressSpaceEnum_t::Vgpr, FloatA, a_thread_mtx_.GetElementSpaceSize(), true>
......@@ -151,14 +151,14 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
template <typename ABlockSliceMoveStepIdx>
__device__ void MoveABlockSliceWindow(const ABlockSliceMoveStepIdx& a_block_slice_move_step_idx)
{
a_thread_copy_.MoveSrcSliceWindow(ABlockDesc_E1_K_E2{}, a_block_slice_move_step_idx);
a_thread_copy_.MoveSrcSliceWindow(ABlockDesc_E1_K1_E2{}, a_block_slice_move_step_idx);
}
private:
using AThreadCopy =
ThreadwiseTensorSliceTransfer_v4<FloatA,
FloatA,
ABlockDesc_E1_K_E2,
ABlockDesc_E1_K1_E2,
decltype(a_thread_mtx_),
Sequence<EPerThreadLoop, KPerThreadLoop, E2>,
Sequence<0, 1, 2>,
......
......@@ -15,7 +15,7 @@ namespace ck {
template <typename GridwiseGemm,
typename FloatAB,
typename FloatC,
typename AGridDesc_E0_E1_K_E2,
typename AGridDesc_E0_E1_K0_K1_E2,
typename BGridDesc_E0_E1_N_Ho_Wo_E2,
typename CGridDesc_K_N_Ho_Wo,
typename CBlockIdToBlockClusterAdaptor_K_N_Ho_Wo,
......@@ -27,7 +27,7 @@ __global__ void
kernel_gemm_dlops_v2(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
const AGridDesc_E0_E1_K_E2 a_e0_e1_k_e2_grid_desc,
const AGridDesc_E0_E1_K0_K1_E2 a_e0_e1_k0_k1_e2_grid_desc,
const BGridDesc_E0_E1_N_Ho_Wo_E2 b_e0_e1_n_ho_wo_e2_grid_desc,
const CGridDesc_K_N_Ho_Wo c_k_n_ho_wo_grid_desc,
const CBlockIdToBlockClusterAdaptor_K_N_Ho_Wo
......@@ -42,7 +42,7 @@ __global__ void
p_b_grid,
p_c_grid,
p_shared_block,
a_e0_e1_k_e2_grid_desc,
a_e0_e1_k0_k1_e2_grid_desc,
b_e0_e1_n_ho_wo_e2_grid_desc,
c_k_n_ho_wo_grid_desc,
c_blockid_to_k_n_ho_wo_block_cluster_adaptor,
......@@ -55,7 +55,7 @@ __global__ void
template <typename GridwiseGemm,
typename FloatAB,
typename FloatC,
typename AGridDesc_E0_E1_K_E2,
typename AGridDesc_E0_E1_K0_K1_E2,
typename BGridDesc_E0_E1_N_Ho_Wo_E2,
typename CGridDesc_K_N_Ho_Wo,
typename CBlockIdToBlockClusterAdaptor_K_N_Ho_Wo,
......@@ -67,7 +67,7 @@ __global__ void
kernel_gemm_dlops_v2(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
const void CONSTANT* p_a_e0_e1_k_e2_grid_desc,
const void CONSTANT* p_a_e0_e1_k0_k1_e2_grid_desc,
const void CONSTANT* p_b_e0_e1_n_ho_wo_e2_grid_desc,
const void CONSTANT* p_c_k_n_ho_wo_grid_desc,
const void CONSTANT* p_c_blockid_to_k_n_ho_wo_block_cluster_adaptor)
......@@ -75,8 +75,8 @@ __global__ void
// first cast void CONSTANT void* to void*
// second cast void* to Desc*
// the copy constructor of tensor descriptor doesn't take address_space(4)
const auto a_e0_e1_k_e2_grid_desc = *reinterpret_cast<const AGridDesc_E0_E1_K_E2*>(
cast_pointer_to_generic_address_space(p_a_e0_e1_k_e2_grid_desc));
const auto a_e0_e1_k0_k1_e2_grid_desc = *reinterpret_cast<const AGridDesc_E0_E1_K0_K1_E2*>(
cast_pointer_to_generic_address_space(p_a_e0_e1_k0_k1_e2_grid_desc));
const auto b_e0_e1_n_ho_wo_e2_grid_desc = *reinterpret_cast<const BGridDesc_E0_E1_N_Ho_Wo_E2*>(
cast_pointer_to_generic_address_space(p_b_e0_e1_n_ho_wo_e2_grid_desc));
const auto c_k_n_ho_wo_grid_desc = *reinterpret_cast<const CGridDesc_K_N_Ho_Wo*>(
......@@ -94,7 +94,7 @@ __global__ void
p_b_grid,
p_c_grid,
p_shared_block,
a_e0_e1_k_e2_grid_desc,
a_e0_e1_k0_k1_e2_grid_desc,
b_e0_e1_n_ho_wo_e2_grid_desc,
c_k_n_ho_wo_grid_desc,
c_blockid_to_k_n_ho_wo_block_cluster_adaptor,
......@@ -120,8 +120,8 @@ template <index_t BlockSize,
index_t HoPerThread,
index_t WoPerThread,
index_t EPerThread,
typename ABlockTransferThreadSliceLengths_E0_E1_K_E2,
typename ABlockTransferThreadClusterLengths_E0_E1_K_E2,
typename ABlockTransferThreadSliceLengths_E0_E1_K0_K1_E2,
typename ABlockTransferThreadClusterLengths_E0_E1_K0_K1_E2,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim,
......@@ -161,12 +161,12 @@ struct GridwiseGemmDlops_km_kn_mn_v3
// A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto a_e0_e1_k_e2_block_desc = make_naive_tensor_descriptor_aligned(
constexpr auto a_e0_e1_k1_e2_block_desc = make_naive_tensor_descriptor_aligned(
make_tuple(I1, Number<E1>{}, Number<KPerBlock>{}, Number<E2>{}), max_lds_align);
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_space_size = math::integer_least_multiple(
a_e0_e1_k_e2_block_desc.GetElementSpaceSize(), max_lds_align);
a_e0_e1_k1_e2_block_desc.GetElementSpaceSize(), max_lds_align);
return a_block_space_size * sizeof(FloatAB);
}
......@@ -181,10 +181,10 @@ struct GridwiseGemmDlops_km_kn_mn_v3
const auto K0 = K / KPerBlock;
const auto N0 = N / NPerBlock;
const auto Ho_0 = Ho / HoPerBlock;
const auto Wo_0 = Wo / WoPerBlock;
const auto Ho0 = Ho / HoPerBlock;
const auto Wo0 = Wo / WoPerBlock;
const index_t grid_size = K0 * N0 * Ho_0 * Wo_0;
const index_t grid_size = K0 * N0 * Ho0 * Wo0;
return grid_size;
}
......@@ -214,9 +214,7 @@ struct GridwiseGemmDlops_km_kn_mn_v3
MakeAE0E1K0K1E2GridDescriptor(const AGridDesc_E0_E1_K_E2& a_e0_e1_k_e2_grid_desc)
{
const auto E0 = a_e0_e1_k_e2_grid_desc.GetLength(I0);
// const auto E1 = a_e0_e1_k_e2_grid_desc.GetLength(I1);
const auto K = a_e0_e1_k_e2_grid_desc.GetLength(I2);
// const auto E2 = a_e0_e1_k_e2_grid_desc.GetLength(I3);
const auto K1 = Number<KPerBlock>{};
const auto K0 = K / K1;
......@@ -233,6 +231,29 @@ struct GridwiseGemmDlops_km_kn_mn_v3
return a_e0_e1_k0_k1_e2_grid_desc;
}
__host__ __device__ static constexpr auto
MakeCK0K1NHoWoGridDescriptor(const CGridDesc_K_N_Ho_Wo& c_k_n_ho_wo_grid_desc)
{
const auto K = c_k_n_ho_wo_grid_desc.GetLength(I0);
const auto N = c_k_n_ho_wo_grid_desc.GetLength(I1);
const auto Ho = c_k_n_ho_wo_grid_desc.GetLength(I2);
const auto Wo = c_k_n_ho_wo_grid_desc.GetLength(I3);
const auto K1 = Number<KPerBlock>{};
const auto K0 = K / K1;
const auto c_k0_k1_n_ho_wo_grid_desc = transform_tensor_descriptor(
c_k_n_ho_wo_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(K0, K1)),
make_pass_through_transform(N),
make_pass_through_transform(Ho),
make_pass_through_transform(Wo)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
return c_k0_k1_n_ho_wo_grid_desc;
}
__host__ __device__ static constexpr auto MakeBE0E1NH0H1W0W1E2GridDescriptor(
const BGridDesc_E0_E1_N_Ho_Wo_E2& b_e0_e1_n_ho_wo_e2_grid_desc)
{
......@@ -283,17 +304,20 @@ struct GridwiseGemmDlops_km_kn_mn_v3
const auto K0 = K / KPerBlock;
const auto N0 = N / NPerBlock;
const auto Ho_0 = Ho / HoPerBlock;
const auto Wo_0 = Wo / WoPerBlock;
const auto Ho0 = Ho / HoPerBlock;
const auto Wo0 = Wo / WoPerBlock;
const auto c_blockid_to_k_n_ho_wo_block_cluster_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(K0, N0, Ho_0, Wo_0))),
make_tuple(make_merge_transform(make_tuple(K0, N0, Ho0, Wo0))),
make_tuple(Sequence<0, 1, 2, 3>{}),
make_tuple(Sequence<0>{}));
return c_blockid_to_k_n_ho_wo_block_cluster_adaptor;
}
using AGridDesc_E0_E1_K0_K1_E2 =
decltype(MakeAE0E1K0K1E2GridDescriptor(AGridDesc_E0_E1_K_E2{}));
using CGridDesc_K0_K1_N_Ho_Wo = decltype(MakeCK0K1NHoWoGridDescriptor(CGridDesc_K_N_Ho_Wo{}));
using CBlockIdToBlockClusterAdaptor_K_N_Ho_Wo =
decltype(MakeCBlockIdToKNHoWoBlockClusterAdaptor(CGridDesc_K_N_Ho_Wo{}));
......@@ -303,24 +327,24 @@ struct GridwiseGemmDlops_km_kn_mn_v3
const FloatAB* __restrict__ p_b_global,
FloatC* __restrict__ p_c_global,
FloatAB* __restrict__ p_shared_block,
const AGridDesc_E0_E1_K_E2& a_e0_e1_k_e2_global_desc,
const BGridDesc_E0_E1_N_Ho_Wo_E2& b_e0_e1_n_ho_wo_e2_global_desc,
const CGridDesc_K_N_Ho_Wo& c_k_n_ho_wo_global_desc,
const AGridDesc_E0_E1_K0_K1_E2& a_e0_e1_k0_k1_e2_grid_desc,
const BGridDesc_E0_E1_N_Ho_Wo_E2& b_e0_e1_n_ho_wo_e2_grid_desc,
const CGridDesc_K0_K1_N_Ho_Wo& c_k0_k1_n_ho_wo_grid_desc,
const CBlockIdToBlockClusterAdaptor_K_N_Ho_Wo& c_blockid_to_k_n_ho_wo_block_cluster_adaptor,
integral_constant<bool, HasMainE0BlockLoop>)
{
const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_a_global, a_e0_e1_k_e2_global_desc.GetElementSpaceSize());
p_a_global, a_e0_e1_k0_k1_e2_grid_desc.GetElementSpaceSize());
const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_b_global, b_e0_e1_n_ho_wo_e2_global_desc.GetElementSpaceSize());
p_b_global, b_e0_e1_n_ho_wo_e2_grid_desc.GetElementSpaceSize());
auto c_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_c_global, c_k_n_ho_wo_global_desc.GetElementSpaceSize());
p_c_global, c_k0_k1_n_ho_wo_grid_desc.GetElementSpaceSize());
constexpr auto HasMainE1BlockLoop = CalculateHasMainE1BlockLoop();
constexpr auto HasDoubleTailE1BlockLoop = CalculateHasDoubleTailE1BlockLoop();
// const auto Ho = b_e0_e1_n_ho_wo_e2_global_desc.GetLength(I3);
// const auto Wo = b_e0_e1_n_ho_wo_e2_global_desc.GetLength(I4);
// const auto Ho = b_e0_e1_n_ho_wo_e2_grid_desc.GetLength(I3);
// const auto Wo = b_e0_e1_n_ho_wo_e2_grid_desc.GetLength(I4);
const auto c_k_n_ho_wo_block_cluster_idx =
c_blockid_to_k_n_ho_wo_block_cluster_adaptor.CalculateBottomIndex(
......@@ -335,15 +359,8 @@ struct GridwiseGemmDlops_km_kn_mn_v3
const index_t wo_block_work_id =
__builtin_amdgcn_readfirstlane(c_k_n_ho_wo_block_cluster_idx[I3]);
// lds max alignment
constexpr auto max_lds_align = Number<ABlockTransferDstScalarPerVector_E2>{};
// A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto a_e0_e1_k_e2_block_desc = make_naive_tensor_descriptor_aligned(
make_tuple(Number<I1>{}, Number<E1>{}, Number<KPerBlock>{}, Number<E2>{}),
max_lds_align);
// B matrix in thread, dst of blockwise copy
constexpr auto b_e1_n_ho_wo_e2_block_desc =
make_naive_tensor_descriptor_packed(make_tuple(Number<E1PerBlock>{},
......@@ -354,10 +371,10 @@ struct GridwiseGemmDlops_km_kn_mn_v3
// c_thread_mtx definition: this is a mess
// TODO:: more elegent way of defining c_thread_mtx
constexpr auto c_k_n_ho_wo_thread_desc = make_naive_tensor_descriptor_packed(make_tuple(
Number<KPerThread>{}, Number<1>{}, Number<HoPerThread>{}, Number<WoPerThread>{}));
constexpr auto c_k1_n_ho_wo_thread_gemm_desc = make_naive_tensor_descriptor_packed(
make_tuple(Number<KPerThread>{}, I1, Number<HoPerThread>{}, Number<WoPerThread>{}));
constexpr auto a_e1_k_e2_block_desc = make_naive_tensor_descriptor_aligned(
constexpr auto a_e1_k1_e2_block_gemm_desc = make_naive_tensor_descriptor_aligned(
make_tuple(Number<E1PerBlock>{}, Number<KPerBlock>{}, Number<E2>{}), max_lds_align);
auto blockwise_gemm =
......@@ -365,21 +382,21 @@ struct GridwiseGemmDlops_km_kn_mn_v3
FloatAB,
FloatAB,
FloatAcc,
decltype(a_e1_k_e2_block_desc),
decltype(a_e1_k1_e2_block_gemm_desc),
decltype(b_e1_n_ho_wo_e2_block_desc),
decltype(c_k_n_ho_wo_thread_desc),
decltype(c_k1_n_ho_wo_thread_gemm_desc),
EPerThread,
K2>{};
auto c_thread_mtx_index =
blockwise_gemm.GetBeginOfCThreadDesc_K_N_Ho_Wo(get_thread_local_1d_id());
const auto k_thread_id = c_thread_mtx_index[I0];
// const auto k_thread_id = c_thread_mtx_index[I0];
const auto ho_thread_id = c_thread_mtx_index[I2];
const auto wo_thread_id = c_thread_mtx_index[I3];
const index_t k_block_data_on_global = k_block_work_id * KPerBlock;
const index_t n_block_data_on_global = n_block_work_id * HoPerBlock;
// const index_t k_block_data_on_global = k_block_work_id * KPerBlock;
// const index_t n_block_data_on_global = n_block_work_id * HoPerBlock;
const index_t ho_block_data_on_global = ho_block_work_id * HoPerBlock;
const index_t wo_block_data_on_global = wo_block_work_id * WoPerBlock;
......@@ -389,34 +406,37 @@ struct GridwiseGemmDlops_km_kn_mn_v3
const index_t wo_thread_data_on_global =
wo_block_data_on_global + wo_thread_id * WoPerThread;
constexpr auto a_e0_e1_k0_k1_e2_block_copy_desc = make_naive_tensor_descriptor_aligned(
make_tuple(Number<I1>{}, Number<E1>{}, I1, Number<KPerBlock>{}, Number<E2>{}),
max_lds_align);
// A matrix blockwise copy
auto a_blockwise_copy =
BlockwiseTensorSliceTransfer_v4<BlockSize,
InMemoryDataOperationEnum_t::Set,
Sequence<I1, E1, KPerBlock, E2>,
ABlockTransferThreadSliceLengths_E0_E1_K_E2,
ABlockTransferThreadClusterLengths_E0_E1_K_E2,
Sequence<I1, E1, I1, KPerBlock, E2>,
ABlockTransferThreadSliceLengths_E0_E1_K0_K1_E2,
ABlockTransferThreadClusterLengths_E0_E1_K0_K1_E2,
ABlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatAB,
decltype(a_e0_e1_k_e2_global_desc),
decltype(a_e0_e1_k_e2_block_desc),
decltype(a_e0_e1_k0_k1_e2_grid_desc),
decltype(a_e0_e1_k0_k1_e2_block_copy_desc),
ABlockTransferSrcAccessOrder,
Sequence<0, 1, 2, 3>, // ABlockTransferDstAccessOrder
Sequence<0, 1, 2, 3, 4>, // ABlockTransferDstAccessOrder
ABlockTransferSrcVectorDim,
3, // ABlockTransferDstVectorDim
4, // ABlockTransferDstVectorDim
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_E2,
1,
1,
AThreadTransferSrcResetCoordinateAfterRun,
false>(
a_e0_e1_k_e2_global_desc,
make_multi_index(0, 0, k_block_data_on_global, 0),
a_e0_e1_k_e2_block_desc,
make_multi_index(0, 0, 0, 0));
false>(a_e0_e1_k0_k1_e2_grid_desc,
make_multi_index(0, 0, k_block_work_id, 0, 0),
a_e0_e1_k0_k1_e2_block_copy_desc,
make_multi_index(0, 0, 0, 0, 0));
constexpr auto a_block_slice_copy_step = make_multi_index(I1, 0, 0, 0);
constexpr auto a_block_slice_copy_step = make_multi_index(I1, 0, 0, 0, 0);
constexpr auto b_e0_e1_n_ho_wo_e2_thread_desc =
make_naive_tensor_descriptor_packed(make_tuple(I1,
......@@ -429,14 +449,14 @@ struct GridwiseGemmDlops_km_kn_mn_v3
auto b_threadwise_transfer = ThreadwiseTensorSliceTransfer_v2<
FloatAB,
FloatAB,
decltype(b_e0_e1_n_ho_wo_e2_global_desc),
decltype(b_e0_e1_n_ho_wo_e2_grid_desc),
decltype(b_e0_e1_n_ho_wo_e2_thread_desc),
Sequence<I1, E1PerBlock, NPerBlock, HoPerThread, WoPerThread, E2>,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BThreadTransferSrcResetCoordinateAfterRun,
true>(b_e0_e1_n_ho_wo_e2_global_desc,
true>(b_e0_e1_n_ho_wo_e2_grid_desc,
make_multi_index(0,
0,
n_thread_data_on_global,
......@@ -445,20 +465,23 @@ struct GridwiseGemmDlops_km_kn_mn_v3
0));
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
p_shared_block, a_e0_e1_k_e2_block_desc.GetElementSpaceSize());
p_shared_block, a_e0_e1_k0_k1_e2_block_copy_desc.GetElementSpaceSize());
// register allocation for output
StaticBuffer<AddressSpaceEnum_t::Vgpr,
FloatAcc,
c_k_n_ho_wo_thread_desc.GetElementSpaceSize(),
c_k1_n_ho_wo_thread_gemm_desc.GetElementSpaceSize(),
true>
c_thread_buf;
// initialize output thread tensor
ThreadwiseTensorSliceSet_v1<FloatAcc,
decltype(c_k_n_ho_wo_thread_desc),
decltype(c_k1_n_ho_wo_thread_gemm_desc),
Sequence<KPerThread, NPerBlock, HoPerThread, WoPerThread>>{}
.Run(c_k_n_ho_wo_thread_desc, make_tuple(I0, I0, I0, I0), c_thread_buf, FloatAcc{0});
.Run(c_k1_n_ho_wo_thread_gemm_desc,
make_tuple(I0, I0, I0, I0),
c_thread_buf,
FloatAcc{0});
constexpr auto b_thread_slice_copy_step = make_multi_index(0, E1PerBlock, 0, 0, 0, 0);
......@@ -475,7 +498,7 @@ struct GridwiseGemmDlops_km_kn_mn_v3
if constexpr(HasMainE0BlockLoop)
{
const auto E0 = b_e0_e1_n_ho_wo_e2_global_desc.GetLength(I0);
const auto E0 = b_e0_e1_n_ho_wo_e2_grid_desc.GetLength(I0);
index_t e0_block_data_begin = 0;
......@@ -484,16 +507,16 @@ struct GridwiseGemmDlops_km_kn_mn_v3
// LDS double buffer: preload data
{
a_blockwise_copy.RunRead(
a_e0_e1_k_e2_global_desc, a_global_buf, a_e0_e1_k_e2_global_step_hacks);
a_e0_e1_k0_k1_e2_grid_desc, a_global_buf, a_e0_e1_k_e2_global_step_hacks);
b_threadwise_transfer.Run(b_e0_e1_n_ho_wo_e2_global_desc,
b_threadwise_transfer.Run(b_e0_e1_n_ho_wo_e2_grid_desc,
b_global_buf,
b_e0_e1_n_ho_wo_e2_thread_desc,
make_tuple(I0, I0, I0, I0, I0, I0),
b_thread_even_buf,
b_e0_e1_n_ho_wo_e2_global_step_hacks);
a_blockwise_copy.RunWrite(a_e0_e1_k_e2_block_desc, a_block_buf);
a_blockwise_copy.RunWrite(a_e0_e1_k0_k1_e2_block_copy_desc, a_block_buf);
}
__syncthreads();
......@@ -507,11 +530,11 @@ struct GridwiseGemmDlops_km_kn_mn_v3
do
{
// even iteration
b_threadwise_transfer.MoveSrcSliceWindow(b_e0_e1_n_ho_wo_e2_global_desc,
b_threadwise_transfer.MoveSrcSliceWindow(b_e0_e1_n_ho_wo_e2_grid_desc,
b_thread_slice_copy_step,
BGlobalMoveSliceWindowStepHacks{});
b_threadwise_transfer.Run(b_e0_e1_n_ho_wo_e2_global_desc,
b_threadwise_transfer.Run(b_e0_e1_n_ho_wo_e2_grid_desc,
b_global_buf,
b_e0_e1_n_ho_wo_e2_thread_desc,
make_tuple(I0, I0, I0, I0, I0, I0),
......@@ -523,11 +546,11 @@ struct GridwiseGemmDlops_km_kn_mn_v3
blockwise_gemm.MoveABlockSliceWindow(make_tuple(E1PerBlock, 0, 0));
b_threadwise_transfer.MoveSrcSliceWindow(b_e0_e1_n_ho_wo_e2_global_desc,
b_threadwise_transfer.MoveSrcSliceWindow(b_e0_e1_n_ho_wo_e2_grid_desc,
b_thread_slice_copy_step,
BGlobalMoveSliceWindowStepHacks{});
b_threadwise_transfer.Run(b_e0_e1_n_ho_wo_e2_global_desc,
b_threadwise_transfer.Run(b_e0_e1_n_ho_wo_e2_grid_desc,
b_global_buf,
b_e0_e1_n_ho_wo_e2_thread_desc,
make_tuple(I0, I0, I0, I0, I0, I0),
......@@ -547,11 +570,11 @@ struct GridwiseGemmDlops_km_kn_mn_v3
// LDS double buffer: tail
if constexpr(HasDoubleTailE1BlockLoop) // if has 2 iteration left
{
b_threadwise_transfer.MoveSrcSliceWindow(b_e0_e1_n_ho_wo_e2_global_desc,
b_threadwise_transfer.MoveSrcSliceWindow(b_e0_e1_n_ho_wo_e2_grid_desc,
b_thread_slice_copy_step,
BGlobalMoveSliceWindowStepHacks{});
b_threadwise_transfer.Run(b_e0_e1_n_ho_wo_e2_global_desc,
b_threadwise_transfer.Run(b_e0_e1_n_ho_wo_e2_grid_desc,
b_global_buf,
b_e0_e1_n_ho_wo_e2_thread_desc,
make_tuple(I0, I0, I0, I0, I0, I0),
......@@ -572,13 +595,13 @@ struct GridwiseGemmDlops_km_kn_mn_v3
blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf);
}
a_blockwise_copy.MoveSrcSliceWindow(a_e0_e1_k_e2_global_desc,
a_blockwise_copy.MoveSrcSliceWindow(a_e0_e1_k0_k1_e2_grid_desc,
a_block_slice_copy_step,
AGlobalMoveSliceWindowStepHacks{});
blockwise_gemm.MoveABlockSliceWindow(make_tuple(-(E1 - E1PerBlock), 0, 0));
b_threadwise_transfer.MoveSrcSliceWindow(b_e0_e1_n_ho_wo_e2_global_desc,
b_threadwise_transfer.MoveSrcSliceWindow(b_e0_e1_n_ho_wo_e2_grid_desc,
b_thread_slice_copy_step,
BGlobalMoveSliceWindowStepHacks{});
......@@ -591,16 +614,16 @@ struct GridwiseGemmDlops_km_kn_mn_v3
// LDS double buffer: preload data
{
a_blockwise_copy.RunRead(
a_e0_e1_k_e2_global_desc, a_global_buf, a_e0_e1_k_e2_global_step_hacks);
a_e0_e1_k0_k1_e2_grid_desc, a_global_buf, a_e0_e1_k_e2_global_step_hacks);
b_threadwise_transfer.Run(b_e0_e1_n_ho_wo_e2_global_desc,
b_threadwise_transfer.Run(b_e0_e1_n_ho_wo_e2_grid_desc,
b_global_buf,
b_e0_e1_n_ho_wo_e2_thread_desc,
make_tuple(I0, I0, I0, I0, I0, I0),
b_thread_even_buf,
b_e0_e1_n_ho_wo_e2_global_step_hacks);
a_blockwise_copy.RunWrite(a_e0_e1_k_e2_block_desc, a_block_buf);
a_blockwise_copy.RunWrite(a_e0_e1_k0_k1_e2_block_copy_desc, a_block_buf);
}
__syncthreads();
......@@ -614,11 +637,11 @@ struct GridwiseGemmDlops_km_kn_mn_v3
do
{
// even iteration
b_threadwise_transfer.MoveSrcSliceWindow(b_e0_e1_n_ho_wo_e2_global_desc,
b_threadwise_transfer.MoveSrcSliceWindow(b_e0_e1_n_ho_wo_e2_grid_desc,
b_thread_slice_copy_step,
BGlobalMoveSliceWindowStepHacks{});
b_threadwise_transfer.Run(b_e0_e1_n_ho_wo_e2_global_desc,
b_threadwise_transfer.Run(b_e0_e1_n_ho_wo_e2_grid_desc,
b_global_buf,
b_e0_e1_n_ho_wo_e2_thread_desc,
make_tuple(I0, I0, I0, I0, I0, I0),
......@@ -630,11 +653,11 @@ struct GridwiseGemmDlops_km_kn_mn_v3
blockwise_gemm.MoveABlockSliceWindow(make_tuple(E1PerBlock, 0, 0));
b_threadwise_transfer.MoveSrcSliceWindow(b_e0_e1_n_ho_wo_e2_global_desc,
b_threadwise_transfer.MoveSrcSliceWindow(b_e0_e1_n_ho_wo_e2_grid_desc,
b_thread_slice_copy_step,
BGlobalMoveSliceWindowStepHacks{});
b_threadwise_transfer.Run(b_e0_e1_n_ho_wo_e2_global_desc,
b_threadwise_transfer.Run(b_e0_e1_n_ho_wo_e2_grid_desc,
b_global_buf,
b_e0_e1_n_ho_wo_e2_thread_desc,
make_tuple(I0, I0, I0, I0, I0, I0),
......@@ -654,11 +677,11 @@ struct GridwiseGemmDlops_km_kn_mn_v3
// LDS double buffer: tail
if constexpr(HasDoubleTailE1BlockLoop) // if has 2 iteration left
{
b_threadwise_transfer.MoveSrcSliceWindow(b_e0_e1_n_ho_wo_e2_global_desc,
b_threadwise_transfer.MoveSrcSliceWindow(b_e0_e1_n_ho_wo_e2_grid_desc,
b_thread_slice_copy_step,
BGlobalMoveSliceWindowStepHacks{});
b_threadwise_transfer.Run(b_e0_e1_n_ho_wo_e2_global_desc,
b_threadwise_transfer.Run(b_e0_e1_n_ho_wo_e2_grid_desc,
b_global_buf,
b_e0_e1_n_ho_wo_e2_thread_desc,
make_tuple(I0, I0, I0, I0, I0, I0),
......@@ -682,7 +705,7 @@ struct GridwiseGemmDlops_km_kn_mn_v3
// activ
{
static_for<0, c_k_n_ho_wo_thread_desc.GetElementSpaceSize(), 1>{}([&](auto i) {
static_for<0, c_k1_n_ho_wo_thread_gemm_desc.GetElementSpaceSize(), 1>{}([&](auto i) {
if constexpr(activ_type == 1)
{
c_thread_buf(i) = c_thread_buf[i] >= 0 ? c_thread_buf[i] : 0.0;
......@@ -706,28 +729,31 @@ struct GridwiseGemmDlops_km_kn_mn_v3
// hack to control index calculation when iterating over c_k_n_ho_wo_global tensor
constexpr auto c_k_n_ho_wo_global_tensor_step_hacks = CGlobalStepHacks{};
const index_t k_thread_data_on_global =
k_block_data_on_global + k_thread_id * KPerThread;
constexpr auto c_k0_k1_n_ho_wo_thread_copy_desc =
make_naive_tensor_descriptor_packed(make_tuple(
I1, Number<KPerThread>{}, I1, Number<HoPerThread>{}, Number<WoPerThread>{}));
ThreadwiseTensorSliceTransfer_v1r3<FloatAcc,
ThreadwiseTensorSliceTransfer_v1r3<
FloatAcc,
FloatC,
decltype(c_k_n_ho_wo_thread_desc),
decltype(c_k_n_ho_wo_global_desc),
Sequence<KPerThread, 1, HoPerThread, WoPerThread>,
decltype(c_k0_k1_n_ho_wo_thread_copy_desc),
decltype(c_k0_k1_n_ho_wo_grid_desc),
Sequence<I1, KPerThread, I1, HoPerThread, WoPerThread>,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
CGlobalMemoryDataOperation,
1,
true>(c_k_n_ho_wo_global_desc,
make_multi_index(k_thread_data_on_global,
true>(c_k0_k1_n_ho_wo_grid_desc,
make_multi_index(k_block_work_id,
0,
n_thread_data_on_global,
ho_thread_data_on_global,
wo_thread_data_on_global))
.Run(c_k_n_ho_wo_thread_desc,
make_tuple(I0, I0, I0, I0),
.Run(c_k0_k1_n_ho_wo_thread_copy_desc,
make_tuple(I0, I0, I0, I0, I0),
c_thread_buf,
c_k_n_ho_wo_global_desc,
c_k0_k1_n_ho_wo_grid_desc,
c_global_buf,
c_k_n_ho_wo_global_tensor_step_hacks);
}
......
......@@ -99,15 +99,15 @@ void device_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw(
constexpr index_t E1 = C0 * 9;
constexpr index_t E2 = 1;
constexpr index_t EPerBlock = C0;
constexpr index_t E1PerBlock = C0;
constexpr index_t KPerThread = 16;
constexpr index_t HoPerThread = 2;
constexpr index_t WoPerThread = 2;
constexpr index_t EPerThread = 1;
using ABlockTransferThreadSliceLengths_E0_E1_K_E2 = Sequence<1, 9, 1, E2>;
using ABlockTransferThreadClusterLengths_E0_E1_K_E2 = Sequence<1, EPerBlock, KPerBlock, 1>;
using ABlockTransferThreadSliceLengths_E0_E1_K0_K1_E2 = Sequence<1, 9, 1, E2>;
using ABlockTransferThreadClusterLengths_E0_E1_K0_K1_E2 = Sequence<1, E1PerBlock, KPerBlock, 1>;
constexpr index_t ABlockTransferSrcScalarPerVector_E2 = E2;
constexpr index_t ABlockTransferDstScalarPerVector_E2 = E2;
......@@ -124,15 +124,16 @@ void device_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw(
constexpr index_t E1 = 2 * 9;
constexpr index_t E2 = 1;
constexpr index_t EPerBlock = 2;
constexpr index_t E1PerBlock = 2;
constexpr index_t KPerThread = 16;
constexpr index_t HoPerThread = 2;
constexpr index_t WoPerThread = 2;
constexpr index_t EPerThread = 1;
using ABlockTransferThreadSliceLengths_E0_E1_K_E2 = Sequence<1, 9, 1, E2>;
using ABlockTransferThreadClusterLengths_E0_E1_K_E2 = Sequence<1, EPerBlock, KPerBlock, 1>;
using ABlockTransferThreadSliceLengths_E0_E1_K0_K1_E2 = Sequence<1, 9, 1, 1, E2>;
using ABlockTransferThreadClusterLengths_E0_E1_K0_K1_E2 =
Sequence<1, E1PerBlock, 1, KPerBlock, 1>;
constexpr index_t ABlockTransferSrcScalarPerVector_E2 = E2;
constexpr index_t ABlockTransferDstScalarPerVector_E2 = E2;
......@@ -153,13 +154,13 @@ void device_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw(
KPerBlock,
HoPerBlock,
WoPerBlock,
EPerBlock,
E1PerBlock,
KPerThread,
HoPerThread,
WoPerThread,
EPerThread,
ABlockTransferThreadSliceLengths_E0_E1_K_E2,
ABlockTransferThreadClusterLengths_E0_E1_K_E2,
ABlockTransferThreadSliceLengths_E0_E1_K0_K1_E2,
ABlockTransferThreadClusterLengths_E0_E1_K0_K1_E2,
ABlockTransferSrcScalarPerVector_E2,
ABlockTransferDstScalarPerVector_E2,
BThreadTransferSrcScalarPerVector_E2,
......
......@@ -20,8 +20,8 @@ template <ck::index_t BlockSize,
ck::index_t HoPerThread,
ck::index_t WoPerThread,
ck::index_t EPerThread,
typename ABlockTransferThreadSliceLengths_E0_E1_K_E2,
typename ABlockTransferThreadClusterLengths_E0_E1_K_E2,
typename ABlockTransferThreadSliceLengths_E0_E1_K0_K1_E2,
typename ABlockTransferThreadClusterLengths_E0_E1_K0_K1_E2,
ck::index_t ABlockTransferSrcScalarPerVector_E2,
ck::index_t ABlockTransferDstScalarPerVector_E2,
ck::index_t BThreadTransferSrcScalarPerVector_E2,
......@@ -77,11 +77,11 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
const auto ConvDilationH = conv_dilations[I0];
const auto ConvDilationW = conv_dilations[I1];
// const auto Hop = Number<(Ho + HoPerBlock - 1) / HoPerBlock * HoPerBlock>{};
// const auto Wop = Number<(Wo + WoPerBlock - 1) / WoPerBlock * WoPerBlock>{};
const auto Hop = Number<(Ho + HoPerBlock - 1) / HoPerBlock * HoPerBlock>{};
const auto Wop = Number<(Wo + WoPerBlock - 1) / WoPerBlock * WoPerBlock>{};
const auto Hop = (Ho + HoPerBlock - 1) / HoPerBlock * HoPerBlock;
const auto Wop = (Wo + WoPerBlock - 1) / WoPerBlock * WoPerBlock;
// const auto Hop = (Ho + HoPerBlock - 1) / HoPerBlock * HoPerBlock;
// const auto Wop = (Wo + WoPerBlock - 1) / WoPerBlock * WoPerBlock;
const auto OutRightPadH = Hop - Ho;
const auto OutRightPadW = Wop - Wo;
......@@ -92,11 +92,6 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
const auto InRightPadH = in_right_pads[I0] + OutRightPadH * ConvStrideH;
const auto InRightPadW = in_right_pads[I1] + OutRightPadW * ConvStrideW;
std::cerr << "OutRightPadH = " << OutRightPadH << " OutRightPadW = " << OutRightPadW
<< std::endl;
std::cerr << "InRightPadH = " << InRightPadH << " InRightPadW = " << InRightPadW
<< std::endl;
const auto E = C0 * Y * X;
constexpr auto E1 = Number<E1_>{};
......@@ -188,17 +183,19 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
// hack to control index calculation when iterating over a_k_m_global tensor
constexpr auto a_e0_e1_k_e2_global_step_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0>{}),
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0>{}));
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}),
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}));
constexpr auto a_e0_e1_k_e2_global_move_slice_window_step_hack =
Sequence<0, 0, 0, 0, 0, 0, 0>{};
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{};
constexpr auto b_e0_e1_n_ho_wo_e2_global_step_hacks = make_tuple(
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
......@@ -220,18 +217,20 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
// hack to control index calculation when iterating over c_m0_m1_n0_n1_global tensor
// hack for NKHW format
constexpr auto c_k_n_ho_wo_global_tensor_step_hacks =
make_tuple(make_tuple(Sequence<0, 1, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{}),
make_tuple(Sequence<0, 2, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{}));
// static_assert(a_e0_e1_k_e2_grid_desc.IsKnownAtCompileTime(), "");
// static_assert(b_e0_e1_n_ho_wo_e2_grid_desc.IsKnownAtCompileTime(), "");
// static_assert(c_k_n_hop_wop_grid_desc.IsKnownAtCompileTime(), "");
make_tuple(make_tuple(Sequence<0, 1, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 1, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}),
make_tuple(Sequence<0, 2, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 2, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}));
static_assert(a_e0_e1_k_e2_grid_desc.IsKnownAtCompileTime(), "");
static_assert(b_e0_e1_n_ho_wo_e2_grid_desc.IsKnownAtCompileTime(), "");
static_assert(c_k_n_hop_wop_grid_desc.IsKnownAtCompileTime(), "");
// GEMM
using GridwiseGemm = GridwiseGemmDlops_km_kn_mn_v3<
......@@ -253,11 +252,11 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
HoPerThread,
WoPerThread,
EPerThread,
ABlockTransferThreadSliceLengths_E0_E1_K_E2,
ABlockTransferThreadClusterLengths_E0_E1_K_E2,
Sequence<2, 0, 1, 3>,
Sequence<0, 1, 2, 3>,
3,
ABlockTransferThreadSliceLengths_E0_E1_K0_K1_E2,
ABlockTransferThreadClusterLengths_E0_E1_K0_K1_E2,
Sequence<2, 3, 0, 1, 4>,
Sequence<0, 1, 2, 3, 4>,
4,
ABlockTransferSrcScalarPerVector_E2,
ABlockTransferDstScalarPerVector_E2,
false, // don't move back src coordinate after threadwise copy
......@@ -266,8 +265,8 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
BThreadTransferSrcScalarPerVector_E2,
false, // don't move back src coordinate after threadwise copy, which will be fused with
// MoveSrcSliceWindow() to save addr computation
Sequence<0, 1, 2, 3>,
0,
Sequence<0, 1, 2, 3, 4>,
1,
CThreadTransferDstScalarPerVector_K,
decltype(a_e0_e1_k_e2_global_step_hacks),
decltype(b_e0_e1_n_ho_wo_e2_global_step_hacks),
......@@ -276,9 +275,14 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
decltype(b_e0_e1_n_ho_wo_e2_global_move_slice_window_step_hack),
activ_type>;
using AGridDesc_E0_E1_K_E2 = decltype(a_e0_e1_k_e2_grid_desc);
const auto a_e0_e1_k0_k1_e2_grid_desc =
GridwiseGemm::MakeAE0E1K0K1E2GridDescriptor(a_e0_e1_k_e2_grid_desc);
const auto c_k0_k1_n_hop_wop_grid_desc =
GridwiseGemm::MakeCK0K1NHoWoGridDescriptor(c_k_n_hop_wop_grid_desc);
using AGridDesc_E0_E1_K0_K1_E2 = decltype(a_e0_e1_k0_k1_e2_grid_desc);
using BGridDesc_E0_E1_N_Ho_Wo_E2 = decltype(b_e0_e1_n_ho_wo_e2_grid_desc);
using CGridDesc_K_N_Ho_Wo = decltype(c_k_n_hop_wop_grid_desc);
using CGridDesc_K0_K1_N_Ho_Wo = decltype(c_k0_k1_n_hop_wop_grid_desc);
const auto grid_size = (K / KPerBlock) * (Hop / HoPerBlock) * (Wop / WoPerBlock) * N;
......@@ -299,9 +303,9 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
kernel_gemm_dlops_v2<GridwiseGemm,
FloatAB,
FloatC,
remove_reference_t<AGridDesc_E0_E1_K_E2>,
remove_reference_t<AGridDesc_E0_E1_K0_K1_E2>,
remove_reference_t<BGridDesc_E0_E1_N_Ho_Wo_E2>,
remove_reference_t<CGridDesc_K_N_Ho_Wo>,
remove_reference_t<CGridDesc_K0_K1_N_Ho_Wo>,
remove_reference_t<CBlockIdToBlockClusterAdaptor_K_N_Ho_Wo>,
has_main_e0_block_loop,
has_main_e1_block_loop,
......@@ -315,21 +319,21 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
p_a_grid,
p_b_grid,
p_c_grid,
a_e0_e1_k_e2_grid_desc,
a_e0_e1_k0_k1_e2_grid_desc,
b_e0_e1_n_ho_wo_e2_grid_desc,
c_k_n_hop_wop_grid_desc,
c_k0_k1_n_hop_wop_grid_desc,
c_blockid_to_k_n_ho_wo_block_cluster_adaptor);
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
DeviceMem a_e0_e1_k_e2_grid_desc_dev_buf(sizeof(AGridDesc_E0_E1_K_E2));
DeviceMem a_e0_e1_k0_k1_e2_grid_desc_dev_buf(sizeof(AGridDesc_E0_E1_K0_K1_E2));
DeviceMem b_e0_e1_n_ho_wo_e2_grid_desc_dev_buf(sizeof(BGridDesc_E0_E1_N_Ho_Wo_E2));
DeviceMem c_k_n_hop_wop_grid_desc_dev_buf(sizeof(CGridDesc_K_N_Ho_Wo));
DeviceMem c_k0_k1_n_hop_wop_grid_desc_dev_buf(sizeof(CGridDesc_K0_K1_N_Ho_Wo));
DeviceMem c_blockid_to_k_n_ho_wo_block_cluster_adaptor_dev_buf(
sizeof(CBlockIdToBlockClusterAdaptor_K_N_Ho_Wo));
a_e0_e1_k_e2_grid_desc_dev_buf.ToDevice(&a_e0_e1_k_e2_grid_desc);
a_e0_e1_k0_k1_e2_grid_desc_dev_buf.ToDevice(&a_e0_e1_k0_k1_e2_grid_desc);
b_e0_e1_n_ho_wo_e2_grid_desc_dev_buf.ToDevice(&b_e0_e1_n_ho_wo_e2_grid_desc);
c_k_n_hop_wop_grid_desc_dev_buf.ToDevice(&c_k_n_hop_wop_grid_desc);
c_k0_k1_n_hop_wop_grid_desc_dev_buf.ToDevice(&c_k0_k1_n_hop_wop_grid_desc);
c_blockid_to_k_n_ho_wo_block_cluster_adaptor_dev_buf.ToDevice(
&c_blockid_to_k_n_ho_wo_block_cluster_adaptor);
......@@ -340,9 +344,9 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
kernel_gemm_dlops_v2<GridwiseGemm,
FloatAB,
FloatC,
remove_reference_t<AGridDesc_E0_E1_K_E2>,
remove_reference_t<AGridDesc_E0_E1_K0_K1_E2>,
remove_reference_t<BGridDesc_E0_E1_N_Ho_Wo_E2>,
remove_reference_t<CGridDesc_K_N_Ho_Wo>,
remove_reference_t<CGridDesc_K0_K1_N_Ho_Wo>,
remove_reference_t<CBlockIdToBlockClusterAdaptor_K_N_Ho_Wo>,
true>;
......@@ -356,11 +360,11 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
p_b_grid,
p_c_grid,
cast_pointer_to_constant_address_space(
a_e0_e1_k_e2_grid_desc_dev_buf.GetDeviceBuffer()),
a_e0_e1_k0_k1_e2_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space(
b_e0_e1_n_ho_wo_e2_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space(
c_k_n_hop_wop_grid_desc_dev_buf.GetDeviceBuffer()),
c_k0_k1_n_hop_wop_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space(
c_blockid_to_k_n_ho_wo_block_cluster_adaptor_dev_buf.GetDeviceBuffer()));
}
......@@ -371,9 +375,9 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
kernel_gemm_dlops_v2<GridwiseGemm,
FloatAB,
FloatC,
remove_reference_t<AGridDesc_E0_E1_K_E2>,
remove_reference_t<AGridDesc_E0_E1_K0_K1_E2>,
remove_reference_t<BGridDesc_E0_E1_N_Ho_Wo_E2>,
remove_reference_t<CGridDesc_K_N_Ho_Wo>,
remove_reference_t<CGridDesc_K0_K1_N_Ho_Wo>,
remove_reference_t<CBlockIdToBlockClusterAdaptor_K_N_Ho_Wo>,
false>;
......@@ -387,11 +391,11 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
p_b_grid,
p_c_grid,
cast_pointer_to_constant_address_space(
a_e0_e1_k_e2_grid_desc_dev_buf.GetDeviceBuffer()),
a_e0_e1_k0_k1_e2_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space(
b_e0_e1_n_ho_wo_e2_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space(
c_k_n_hop_wop_grid_desc_dev_buf.GetDeviceBuffer()),
c_k0_k1_n_hop_wop_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space(
c_blockid_to_k_n_ho_wo_block_cluster_adaptor_dev_buf.GetDeviceBuffer()));
}
......
......@@ -20,7 +20,7 @@
#include "device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp"
#include "device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp"
#define USE_DYNAMIC_MODE 1
#define USE_DYNAMIC_MODE 0
#define USE_CONV_FWD_V4R4_NCHW 0
#define USE_CONV_FWD_V4R4R2_NHWC 0
#define USE_CONV_FWD_V6R1_NCHW 0
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment