Commit e6a23d8b authored by Jing Zhang's avatar Jing Zhang
Browse files

add e2

parent a8169558
...@@ -7,20 +7,21 @@ ...@@ -7,20 +7,21 @@
namespace ck { namespace ck {
template <index_t BlockSize, template <index_t BlockSize,
typename FloatAB, typename FloatA,
typename FloatB,
typename FloatC, typename FloatC,
typename BlockMatrixA, typename ABlockDesc_E1_K_E2,
typename BlockMatrixB, typename BBlockDesc_E1_N_Ho_Wo_E2,
typename ThreadMatrixC, typename CThreadDesc_K_N_Ho_Wo,
index_t EPerThreadLoop, index_t EPerThreadLoop,
index_t ThreadGemmADataPerRead_K, index_t ThreadGemmADataPerRead_E2>
index_t ThreadGemmBDataPerRead_W>
struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3 struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
{ {
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{}; static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{}; static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{}; static constexpr auto I3 = Number<3>{};
static constexpr auto I4 = Number<4>{};
struct MatrixIndex struct MatrixIndex
{ {
...@@ -29,36 +30,48 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3 ...@@ -29,36 +30,48 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
index_t w; index_t w;
}; };
static constexpr index_t KPerThreadLoop = 4; static constexpr auto E1 = ABlockDesc_E1_K_E2{}.GetLength(I0);
static constexpr auto K = ABlockDesc_E1_K_E2{}.GetLength(I1);
static constexpr auto E2 = ABlockDesc_E1_K_E2{}.GetLength(I2);
static constexpr auto KPerThread = ThreadMatrixC{}.GetLength(I0); static constexpr auto H = BBlockDesc_E1_N_Ho_Wo_E2{}.GetLength(I2);
static constexpr auto HPerThread = ThreadMatrixC{}.GetLength(I2); static constexpr auto W = BBlockDesc_E1_N_Ho_Wo_E2{}.GetLength(I3);
static constexpr auto WPerThread = ThreadMatrixC{}.GetLength(I3);
static constexpr auto KPerThread = CThreadDesc_K_N_Ho_Wo{}.GetLength(I0);
static constexpr auto HPerThread = CThreadDesc_K_N_Ho_Wo{}.GetLength(I2);
static constexpr auto WPerThread = CThreadDesc_K_N_Ho_Wo{}.GetLength(I3);
static constexpr index_t KPerThreadLoop = KPerThread;
static constexpr auto a_thread_mtx_ = make_naive_tensor_descriptor_packed( static constexpr auto a_thread_mtx_ = make_naive_tensor_descriptor_packed(
make_tuple(Number<EPerThreadLoop>{}, Number<KPerThreadLoop>{})); make_tuple(Number<EPerThreadLoop>{}, Number<KPerThreadLoop>{}, Number<E2>{}));
static constexpr auto b_thread_mtx_ = make_naive_tensor_descriptor_packed(make_tuple( static constexpr auto b_thread_mtx_ =
Number<EPerThreadLoop>{}, Number<1>{}, Number<HPerThread>{}, Number<WPerThread>{})); make_naive_tensor_descriptor_packed(make_tuple(Number<EPerThreadLoop>{},
Number<1>{},
Number<HPerThread>{},
Number<WPerThread>{},
Number<E2>{}));
static constexpr auto c_thread_mtx_ = make_naive_tensor_descriptor_packed(make_tuple( static constexpr auto c_thread_mtx_ = make_naive_tensor_descriptor_packed(make_tuple(
Number<KPerThreadLoop>{}, Number<1>{}, Number<HPerThread>{}, Number<WPerThread>{})); Number<KPerThreadLoop>{}, Number<1>{}, Number<HPerThread>{}, Number<WPerThread>{}));
__device__ BlockwiseGemmDlops_km_kn_m0m1n0n1_v3() __device__ BlockwiseGemmDlops_km_kn_m0m1n0n1_v3()
: c_thread_begin_mtx_idx_{GetBeginOfThreadMatrixC(get_thread_local_1d_id())}, : c_thread_begin_mtx_idx_{GetBeginOfCThreadDesc_K_N_Ho_Wo(get_thread_local_1d_id())},
a_thread_copy_{make_tuple(0, c_thread_begin_mtx_idx_.k * KPerThread)} a_thread_copy_{make_tuple(0, c_thread_begin_mtx_idx_.k * KPerThread, 0)}
{ {
static_assert(BlockMatrixA::IsKnownAtCompileTime() && static_assert(ABlockDesc_E1_K_E2::IsKnownAtCompileTime() &&
BlockMatrixB::IsKnownAtCompileTime() && BBlockDesc_E1_N_Ho_Wo_E2::IsKnownAtCompileTime() &&
ThreadMatrixC::IsKnownAtCompileTime(), CThreadDesc_K_N_Ho_Wo::IsKnownAtCompileTime(),
"wrong! Desc should be known at compile-time"); "wrong! Desc should be known at compile-time");
static_assert(BlockMatrixA{}.GetLength(I0) == BlockMatrixB{}.GetLength(I0), static_assert(
"wrong! K dimension not consistent\n"); ABlockDesc_E1_K_E2{}.GetLength(I0) == BBlockDesc_E1_N_Ho_Wo_E2{}.GetLength(I0) &&
ABlockDesc_E1_K_E2{}.GetLength(I2) == BBlockDesc_E1_N_Ho_Wo_E2{}.GetLength(I4),
"wrong! E dimension not consistent\n");
constexpr index_t K = BlockMatrixA{}.GetLength(I1); // A is transposed static_assert(E1 % EPerThreadLoop == 0, "");
constexpr index_t H = BlockMatrixB{}.GetLength(I2); static_assert(KPerThread % KPerThreadLoop == 0, "");
constexpr index_t W = BlockMatrixB{}.GetLength(I3);
static_assert(K % KPerThread == 0 && H % HPerThread == 0 && W % WPerThread == 0, static_assert(K % KPerThread == 0 && H % HPerThread == 0 && W % WPerThread == 0,
"wrong! Cannot evenly divide work among\n"); "wrong! Cannot evenly divide work among\n");
...@@ -71,15 +84,15 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3 ...@@ -71,15 +84,15 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
"wrong! wrong blocksize\n"); "wrong! wrong blocksize\n");
} }
__device__ static constexpr auto GetThreadMatrixCLengths() __device__ static constexpr auto GetCThreadDesc_K_N_Ho_WoLengths()
{ {
return Sequence<KPerThread, 1, HPerThread, WPerThread>{}; return Sequence<KPerThread, 1, HPerThread, WPerThread>{};
} }
__device__ static MatrixIndex GetBeginOfThreadMatrixC(index_t thread_id) __device__ static MatrixIndex GetBeginOfCThreadDesc_K_N_Ho_Wo(index_t thread_id)
{ {
constexpr index_t HPerBlock = BlockMatrixB{}.GetLength(Number<2>{}); constexpr index_t HPerBlock = BBlockDesc_E1_N_Ho_Wo_E2{}.GetLength(I2);
constexpr index_t WPerBlock = BlockMatrixB{}.GetLength(Number<3>{}); constexpr index_t WPerBlock = BBlockDesc_E1_N_Ho_Wo_E2{}.GetLength(I3);
constexpr auto num_w_threads = WPerBlock / WPerThread; constexpr auto num_w_threads = WPerBlock / WPerThread;
constexpr auto num_h_threads = HPerBlock / HPerThread; constexpr auto num_h_threads = HPerBlock / HPerThread;
...@@ -100,42 +113,37 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3 ...@@ -100,42 +113,37 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
CThreadBuffer& c_thread_buf) const CThreadBuffer& c_thread_buf) const
{ {
static_assert( static_assert(
is_same<remove_cvref_t<typename ABlockBuffer::type>, remove_cvref_t<FloatAB>>::value && is_same<remove_cvref_t<typename ABlockBuffer::type>, remove_cvref_t<FloatA>>::value &&
is_same<remove_cvref_t<typename BThreadBuffer::type>, remove_cvref_t<FloatAB>>::value && is_same<remove_cvref_t<typename BThreadBuffer::type>, remove_cvref_t<FloatB>>::value &&
is_same<remove_cvref_t<typename CThreadBuffer::type>, remove_cvref_t<FloatC>>::value && is_same<remove_cvref_t<typename CThreadBuffer::type>, remove_cvref_t<FloatC>>::value &&
"wrong! inconsistent type"); "wrong! inconsistent type");
constexpr auto a_block_mtx = BlockMatrixA{}; constexpr auto a_block_mtx = ABlockDesc_E1_K_E2{};
constexpr auto EPerBlock = a_block_mtx.GetLength(I0);
static_assert(EPerBlock % EPerThreadLoop == 0, "");
static_assert(KPerThread % KPerThreadLoop == 0, "");
// thread A buffer for GEMM // thread A buffer for GEMM
StaticBuffer<AddressSpaceEnum_t::Vgpr, FloatAB, a_thread_mtx_.GetElementSpaceSize(), true> StaticBuffer<AddressSpaceEnum_t::Vgpr, FloatA, a_thread_mtx_.GetElementSpaceSize(), true>
a_thread_buf; a_thread_buf;
constexpr auto threadwise_gemm = ThreadwiseGemmDlops_km_kn_mn_v3<FloatAB, constexpr auto threadwise_gemm = ThreadwiseGemmDlops_km_kn_mn_v3<FloatA,
FloatAB, FloatB,
FloatC, FloatC,
decltype(a_thread_mtx_), decltype(a_thread_mtx_),
decltype(b_thread_mtx_), decltype(b_thread_mtx_),
decltype(c_thread_mtx_)>{}; decltype(c_thread_mtx_)>{};
static_for<0, EPerBlock, EPerThreadLoop>{}([&](auto e_begin) { static_for<0, E1, EPerThreadLoop>{}([&](auto e_begin) {
static_for<0, KPerThread, KPerThreadLoop>{}([&](auto k_begin) { static_for<0, KPerThread, KPerThreadLoop>{}([&](auto k_begin) {
a_thread_copy_.Run(a_block_mtx, a_thread_copy_.Run(a_block_mtx,
make_tuple(e_begin, k_begin), make_tuple(e_begin, k_begin, I0),
a_block_buf, a_block_buf,
a_thread_mtx_, a_thread_mtx_,
make_tuple(I0, I0), make_tuple(I0, I0, I0),
a_thread_buf); a_thread_buf);
threadwise_gemm.Run(a_thread_buf, threadwise_gemm.Run(a_thread_buf,
make_tuple(I0, I0), make_tuple(I0, I0, I0),
b_thread_buf, b_thread_buf,
make_tuple(e_begin, I0, I0, I0), make_tuple(e_begin, I0, I0, I0, I0),
c_thread_buf, c_thread_buf,
make_tuple(k_begin, I0, I0, I0)); make_tuple(k_begin, I0, I0, I0));
}); });
...@@ -145,21 +153,22 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3 ...@@ -145,21 +153,22 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
template <typename ABlockSliceMoveStepIdx> template <typename ABlockSliceMoveStepIdx>
__device__ void MoveABlockSliceWindow(const ABlockSliceMoveStepIdx& a_block_slice_move_step_idx) __device__ void MoveABlockSliceWindow(const ABlockSliceMoveStepIdx& a_block_slice_move_step_idx)
{ {
a_thread_copy_.MoveSrcSliceWindow(BlockMatrixA{}, a_block_slice_move_step_idx); a_thread_copy_.MoveSrcSliceWindow(ABlockDesc_E1_K_E2{}, a_block_slice_move_step_idx);
} }
private: private:
MatrixIndex c_thread_begin_mtx_idx_; MatrixIndex c_thread_begin_mtx_idx_;
using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB, using AThreadCopy =
FloatAB, ThreadwiseTensorSliceTransfer_v4<FloatA,
BlockMatrixA, FloatB,
decltype(a_thread_mtx_), ABlockDesc_E1_K_E2,
Sequence<EPerThreadLoop, KPerThreadLoop>, decltype(a_thread_mtx_),
Sequence<0, 1>, Sequence<EPerThreadLoop, KPerThreadLoop, E2>,
1, Sequence<0, 1, 2>,
ThreadGemmADataPerRead_K, 2,
1>; ThreadGemmADataPerRead_E2,
ThreadGemmADataPerRead_E2>;
AThreadCopy a_thread_copy_; AThreadCopy a_thread_copy_;
}; };
......
...@@ -15,8 +15,8 @@ namespace ck { ...@@ -15,8 +15,8 @@ namespace ck {
template <typename GridwiseGemm, template <typename GridwiseGemm,
typename FloatAB, typename FloatAB,
typename FloatC, typename FloatC,
typename AGridDesc_E0_E1_K, typename AGridDesc_E0_E1_K_E2,
typename BGridDesc_E_N_Ho_Wo, typename BGridDesc_E0_E1_N_Ho_Wo_E2,
typename CGridDesc_K_N_Ho_Wo, typename CGridDesc_K_N_Ho_Wo,
typename CBlockIdToBlockClusterAdaptor_K_N_Ho_Wo, typename CBlockIdToBlockClusterAdaptor_K_N_Ho_Wo,
bool HasMainKBlockLoop, bool HasMainKBlockLoop,
...@@ -28,8 +28,8 @@ __global__ void ...@@ -28,8 +28,8 @@ __global__ void
kernel_gemm_dlops_v2(const FloatAB* __restrict__ p_a_grid, kernel_gemm_dlops_v2(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid, const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid, FloatC* __restrict__ p_c_grid,
const AGridDesc_E0_E1_K a_e0_e1_k_grid_desc, const AGridDesc_E0_E1_K_E2 a_e0_e1_k_e2_grid_desc,
const BGridDesc_E_N_Ho_Wo b_e0_e1_n_ho_wo_grid_desc, const BGridDesc_E0_E1_N_Ho_Wo_E2 b_e0_e1_n_ho_wo_e2_grid_desc,
const CGridDesc_K_N_Ho_Wo c_k_n_ho_wo_grid_desc, const CGridDesc_K_N_Ho_Wo c_k_n_ho_wo_grid_desc,
const CBlockIdToBlockClusterAdaptor_K_N_Ho_Wo const CBlockIdToBlockClusterAdaptor_K_N_Ho_Wo
c_blockid_to_k_n_ho_wo_block_cluster_adaptor) c_blockid_to_k_n_ho_wo_block_cluster_adaptor)
...@@ -43,8 +43,8 @@ __global__ void ...@@ -43,8 +43,8 @@ __global__ void
p_b_grid, p_b_grid,
p_c_grid, p_c_grid,
p_shared_block, p_shared_block,
a_e0_e1_k_grid_desc, a_e0_e1_k_e2_grid_desc,
b_e0_e1_n_ho_wo_grid_desc, b_e0_e1_n_ho_wo_e2_grid_desc,
c_k_n_ho_wo_grid_desc, c_k_n_ho_wo_grid_desc,
integral_constant<bool, HasMainKBlockLoop>{}, integral_constant<bool, HasMainKBlockLoop>{},
integral_constant<bool, HasDoubleTailKBlockLoop>{}); integral_constant<bool, HasDoubleTailKBlockLoop>{});
...@@ -56,8 +56,8 @@ __global__ void ...@@ -56,8 +56,8 @@ __global__ void
template <typename GridwiseGemm, template <typename GridwiseGemm,
typename FloatAB, typename FloatAB,
typename FloatC, typename FloatC,
typename AGridDesc_E0_E1_K, typename AGridDesc_E0_E1_K_E2,
typename BGridDesc_E_N_Ho_Wo, typename BGridDesc_E0_E1_N_Ho_Wo_E2,
typename CGridDesc_K_N_Ho_Wo, typename CGridDesc_K_N_Ho_Wo,
typename CBlockIdToBlockClusterAdaptor_K_N_Ho_Wo, typename CBlockIdToBlockClusterAdaptor_K_N_Ho_Wo,
bool HasMainKBlockLoop, bool HasMainKBlockLoop,
...@@ -69,18 +69,18 @@ __global__ void ...@@ -69,18 +69,18 @@ __global__ void
kernel_gemm_dlops_v2(const FloatAB* __restrict__ p_a_grid, kernel_gemm_dlops_v2(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid, const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid, FloatC* __restrict__ p_c_grid,
const void CONSTANT* p_a_e0_e1_k_grid_desc, const void CONSTANT* p_a_e0_e1_k_e2_grid_desc,
const void CONSTANT* p_b_e0_e1_n_ho_wo_grid_desc, const void CONSTANT* p_b_e0_e1_n_ho_wo_e2_grid_desc,
const void CONSTANT* p_c_k_n_ho_wo_grid_desc, const void CONSTANT* p_c_k_n_ho_wo_grid_desc,
const void CONSTANT* p_c_blockid_to_k_n_ho_wo_block_cluster_adaptor) const void CONSTANT* p_c_blockid_to_k_n_ho_wo_block_cluster_adaptor)
{ {
// first cast void CONSTANT void* to void* // first cast void CONSTANT void* to void*
// second cast void* to Desc* // second cast void* to Desc*
// the copy constructor of tensor descriptor doesn't take address_space(4) // the copy constructor of tensor descriptor doesn't take address_space(4)
const auto a_e0_e1_k_grid_desc = *reinterpret_cast<const AGridDesc_E0_E1_K*>( const auto a_e0_e1_k_e2_grid_desc = *reinterpret_cast<const AGridDesc_E0_E1_K_E2*>(
cast_pointer_to_generic_address_space(p_a_e0_e1_k_grid_desc)); cast_pointer_to_generic_address_space(p_a_e0_e1_k_e2_grid_desc));
const auto b_e0_e1_n_ho_wo_grid_desc = *reinterpret_cast<const BGridDesc_E_N_Ho_Wo*>( const auto b_e0_e1_n_ho_wo_e2_grid_desc = *reinterpret_cast<const BGridDesc_E0_E1_N_Ho_Wo_E2*>(
cast_pointer_to_generic_address_space(p_b_e0_e1_n_ho_wo_grid_desc)); cast_pointer_to_generic_address_space(p_b_e0_e1_n_ho_wo_e2_grid_desc));
const auto c_k_n_ho_wo_grid_desc = *reinterpret_cast<const CGridDesc_K_N_Ho_Wo*>( const auto c_k_n_ho_wo_grid_desc = *reinterpret_cast<const CGridDesc_K_N_Ho_Wo*>(
cast_pointer_to_generic_address_space(p_c_k_n_ho_wo_grid_desc)); cast_pointer_to_generic_address_space(p_c_k_n_ho_wo_grid_desc));
...@@ -93,8 +93,8 @@ __global__ void ...@@ -93,8 +93,8 @@ __global__ void
p_b_grid, p_b_grid,
p_c_grid, p_c_grid,
p_shared_block, p_shared_block,
a_e0_e1_k_grid_desc, a_e0_e1_k_e2_grid_desc,
b_e0_e1_n_ho_wo_grid_desc, b_e0_e1_n_ho_wo_e2_grid_desc,
c_k_n_ho_wo_grid_desc, c_k_n_ho_wo_grid_desc,
integral_constant<bool, HasMainKBlockLoop>{}, integral_constant<bool, HasMainKBlockLoop>{},
integral_constant<bool, HasDoubleTailKBlockLoop>{}); integral_constant<bool, HasDoubleTailKBlockLoop>{});
...@@ -106,10 +106,11 @@ template <index_t BlockSize, ...@@ -106,10 +106,11 @@ template <index_t BlockSize,
typename FloatAcc, typename FloatAcc,
typename FloatC, typename FloatC,
InMemoryDataOperationEnum_t CGlobalMemoryDataOperation, InMemoryDataOperationEnum_t CGlobalMemoryDataOperation,
typename AGlobalDesc_E0_E1_K, typename AGlobalDesc_E0_E1_K_E2,
typename BGlobalDesc_E0_E1_N_Ho_Wo, typename BGlobalDesc_E0_E1_N_Ho_Wo_E2,
typename CGlobalDesc_K_N_Ho_Wo, typename CGlobalDesc_K_N_Ho_Wo,
index_t E1, index_t E1,
index_t E2,
index_t KPerBlock, index_t KPerBlock,
index_t HoPerBlock, index_t HoPerBlock,
index_t WoPerBlock, index_t WoPerBlock,
...@@ -118,13 +119,13 @@ template <index_t BlockSize, ...@@ -118,13 +119,13 @@ template <index_t BlockSize,
index_t HoPerThread, index_t HoPerThread,
index_t WoPerThread, index_t WoPerThread,
index_t EPerThread, index_t EPerThread,
typename ABlockTransferThreadSliceLengths_E0_E1_K, typename ABlockTransferThreadSliceLengths_E0_E1_K_E2,
typename ABlockTransferThreadClusterLengths_E0_E1_K, typename ABlockTransferThreadClusterLengths_E0_E1_K_E2,
typename ABlockTransferThreadClusterArrangeOrder, typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder, typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim, index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector, index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_K, index_t ABlockTransferDstScalarPerVector_E2,
bool AThreadTransferSrcResetCoordinateAfterRun, bool AThreadTransferSrcResetCoordinateAfterRun,
typename BBlockTransferSrcAccessOrder, typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim, index_t BBlockTransferSrcVectorDim,
...@@ -145,20 +146,20 @@ struct GridwiseGemmDlops_km_kn_mn_v3 ...@@ -145,20 +146,20 @@ struct GridwiseGemmDlops_km_kn_mn_v3
static constexpr auto I2 = Number<2>{}; static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{}; static constexpr auto I3 = Number<3>{};
static constexpr auto I4 = Number<4>{}; static constexpr auto I4 = Number<4>{};
static constexpr auto I5 = Number<5>{};
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{ {
constexpr auto max_lds_align = constexpr auto max_lds_align = Number<ABlockTransferDstScalarPerVector_E2>{};
math::lcm(Number<ABlockTransferDstScalarPerVector_K>{}, Number<KPerBlock>{});
// A matrix in LDS memory, dst of blockwise copy // A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment // be careful of LDS alignment
constexpr auto a_e0_e1_k_block_desc = make_naive_tensor_descriptor_aligned( constexpr auto a_e0_e1_k_e2_block_desc = make_naive_tensor_descriptor_aligned(
make_tuple(I1, Number<E1>{}, Number<KPerBlock>{}), max_lds_align); make_tuple(I1, Number<E1>{}, Number<KPerBlock>{}, Number<E2>{}), max_lds_align);
// LDS allocation for A and B: be careful of alignment // LDS allocation for A and B: be careful of alignment
constexpr auto a_block_space_size = constexpr auto a_block_space_size = math::integer_least_multiple(
math::integer_least_multiple(a_e0_e1_k_block_desc.GetElementSpaceSize(), max_lds_align); a_e0_e1_k_e2_block_desc.GetElementSpaceSize(), max_lds_align);
return a_block_space_size * sizeof(FloatAB); return a_block_space_size * sizeof(FloatAB);
} }
...@@ -168,27 +169,27 @@ struct GridwiseGemmDlops_km_kn_mn_v3 ...@@ -168,27 +169,27 @@ struct GridwiseGemmDlops_km_kn_mn_v3
const FloatAB* __restrict__ p_b_global, const FloatAB* __restrict__ p_b_global,
FloatC* __restrict__ p_c_global, FloatC* __restrict__ p_c_global,
FloatAB* __restrict__ p_shared_block, FloatAB* __restrict__ p_shared_block,
const AGlobalDesc_E0_E1_K& a_e0_e1_k_global_desc, const AGlobalDesc_E0_E1_K_E2& a_e0_e1_k_e2_global_desc,
const BGlobalDesc_E0_E1_N_Ho_Wo& b_e0_e1_n_ho_wo_global_desc, const BGlobalDesc_E0_E1_N_Ho_Wo_E2& b_e0_e1_n_ho_wo_e2_global_desc,
const CGlobalDesc_K_N_Ho_Wo& c_k_n_ho_wo_global_desc, const CGlobalDesc_K_N_Ho_Wo& c_k_n_ho_wo_global_desc,
integral_constant<bool, HasMainKBlockLoop>, integral_constant<bool, HasMainKBlockLoop>,
integral_constant<bool, HasDoubleTailKBlockLoop>) integral_constant<bool, HasDoubleTailKBlockLoop>)
{ {
const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>( const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_a_global, a_e0_e1_k_global_desc.GetElementSpaceSize()); p_a_global, a_e0_e1_k_e2_global_desc.GetElementSpaceSize());
const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>( const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_b_global, b_e0_e1_n_ho_wo_global_desc.GetElementSpaceSize()); p_b_global, b_e0_e1_n_ho_wo_e2_global_desc.GetElementSpaceSize());
auto c_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>( auto c_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_c_global, c_k_n_ho_wo_global_desc.GetElementSpaceSize()); p_c_global, c_k_n_ho_wo_global_desc.GetElementSpaceSize());
static_assert(E1 % EPerBlock == 0, ""); static_assert(E1 % EPerBlock == 0, "");
// const auto E = a_e0_e1_k_global_desc.GetLength(I0); // const auto E = a_e0_e1_k_e2_global_desc.GetLength(I0);
// const auto K = a_e0_e1_k_global_desc.GetLength(I1); // const auto K = a_e0_e1_k_e2_global_desc.GetLength(I1);
// const auto N = b_e0_e1_n_ho_wo_global_desc.GetLength(I1); // const auto N = b_e0_e1_n_ho_wo_e2_global_desc.GetLength(I1);
const auto Ho = b_e0_e1_n_ho_wo_global_desc.GetLength(I3); const auto Ho = b_e0_e1_n_ho_wo_e2_global_desc.GetLength(I3);
const auto Wo = b_e0_e1_n_ho_wo_global_desc.GetLength(I4); const auto Wo = b_e0_e1_n_ho_wo_e2_global_desc.GetLength(I4);
// divide block work by [M, N] // divide block work by [M, N]
#if 1 #if 1
...@@ -217,39 +218,44 @@ struct GridwiseGemmDlops_km_kn_mn_v3 ...@@ -217,39 +218,44 @@ struct GridwiseGemmDlops_km_kn_mn_v3
#endif #endif
// lds max alignment // lds max alignment
constexpr auto max_lds_align = constexpr auto max_lds_align = Number<ABlockTransferDstScalarPerVector_E2>{};
math::lcm(Number<ABlockTransferDstScalarPerVector_K>{}, Number<KPerBlock>{});
// A matrix in LDS memory, dst of blockwise copy // A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment // be careful of LDS alignment
constexpr auto a_e0_e1_k_block_desc = make_naive_tensor_descriptor_aligned( constexpr auto a_e0_e1_k_e2_block_desc = make_naive_tensor_descriptor_aligned(
make_tuple(Number<I1>{}, Number<E1>{}, Number<KPerBlock>{}), max_lds_align); make_tuple(Number<I1>{}, Number<E1>{}, Number<KPerBlock>{}, Number<E2>{}),
max_lds_align);
constexpr auto a_e1_k_block_desc = make_naive_tensor_descriptor_aligned(
make_tuple(Number<EPerBlock>{}, Number<KPerBlock>{}), max_lds_align);
// B matrix in LDS memory, dst of blockwise copy // B matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment // be careful of LDS alignment
constexpr auto b_e1_n_ho_wo_block_desc = make_naive_tensor_descriptor_packed(make_tuple( constexpr auto b_e1_n_ho_wo_e2_block_desc =
Number<EPerBlock>{}, Number<1>{}, Number<HoPerBlock>{}, Number<WoPerBlock>{})); make_naive_tensor_descriptor_packed(make_tuple(Number<EPerBlock>{},
Number<1>{},
Number<HoPerBlock>{},
Number<WoPerBlock>{},
Number<E2>{}));
// c_thread_mtx definition: this is a mess // c_thread_mtx definition: this is a mess
// TODO:: more elegent way of defining c_thread_mtx // TODO:: more elegent way of defining c_thread_mtx
constexpr auto c_k_n_ho_wo_thread_desc = make_naive_tensor_descriptor_packed(make_tuple( constexpr auto c_k_n_ho_wo_thread_desc = make_naive_tensor_descriptor_packed(make_tuple(
Number<KPerThread>{}, Number<1>{}, Number<HoPerThread>{}, Number<WoPerThread>{})); Number<KPerThread>{}, Number<1>{}, Number<HoPerThread>{}, Number<WoPerThread>{}));
constexpr auto a_e1_k_e2_block_desc = make_naive_tensor_descriptor_aligned(
make_tuple(Number<EPerBlock>{}, Number<KPerBlock>{}, Number<E2>{}), max_lds_align);
auto blockwise_gemm = auto blockwise_gemm =
BlockwiseGemmDlops_km_kn_m0m1n0n1_v3<BlockSize, BlockwiseGemmDlops_km_kn_m0m1n0n1_v3<BlockSize,
FloatAB,
FloatAB, FloatAB,
FloatAcc, FloatAcc,
decltype(a_e1_k_block_desc), decltype(a_e1_k_e2_block_desc),
decltype(b_e1_n_ho_wo_block_desc), decltype(b_e1_n_ho_wo_e2_block_desc),
decltype(c_k_n_ho_wo_thread_desc), decltype(c_k_n_ho_wo_thread_desc),
EPerThread, EPerThread,
ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_E2>{};
ABlockTransferDstScalarPerVector_K>{};
auto c_thread_mtx_index = blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id()); auto c_thread_mtx_index =
blockwise_gemm.GetBeginOfCThreadDesc_K_N_Ho_Wo(get_thread_local_1d_id());
const auto k_thread_id = c_thread_mtx_index.k; const auto k_thread_id = c_thread_mtx_index.k;
const auto ho_thread_id = c_thread_mtx_index.h; const auto ho_thread_id = c_thread_mtx_index.h;
...@@ -268,49 +274,53 @@ struct GridwiseGemmDlops_km_kn_mn_v3 ...@@ -268,49 +274,53 @@ struct GridwiseGemmDlops_km_kn_mn_v3
auto a_blockwise_copy = auto a_blockwise_copy =
BlockwiseTensorSliceTransfer_v4<BlockSize, BlockwiseTensorSliceTransfer_v4<BlockSize,
InMemoryDataOperationEnum_t::Set, InMemoryDataOperationEnum_t::Set,
Sequence<I1, E1, KPerBlock>, Sequence<I1, E1, KPerBlock, E2>,
ABlockTransferThreadSliceLengths_E0_E1_K, ABlockTransferThreadSliceLengths_E0_E1_K_E2,
ABlockTransferThreadClusterLengths_E0_E1_K, ABlockTransferThreadClusterLengths_E0_E1_K_E2,
ABlockTransferThreadClusterArrangeOrder, ABlockTransferThreadClusterArrangeOrder,
FloatAB, FloatAB,
FloatAB, FloatAB,
decltype(a_e0_e1_k_global_desc), decltype(a_e0_e1_k_e2_global_desc),
decltype(a_e0_e1_k_block_desc), decltype(a_e0_e1_k_e2_block_desc),
ABlockTransferSrcAccessOrder, ABlockTransferSrcAccessOrder,
Sequence<0, 1, 2>, // ABlockTransferDstAccessOrder Sequence<0, 1, 2, 3>, // ABlockTransferDstAccessOrder
ABlockTransferSrcVectorDim, ABlockTransferSrcVectorDim,
2, // ABlockTransferDstVectorDim 3, // ABlockTransferDstVectorDim
ABlockTransferSrcScalarPerVector, ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_K, ABlockTransferDstScalarPerVector_E2,
1, 1,
1, 1,
AThreadTransferSrcResetCoordinateAfterRun, AThreadTransferSrcResetCoordinateAfterRun,
true>(a_e0_e1_k_global_desc, true>(a_e0_e1_k_e2_global_desc,
make_multi_index(0, 0, k_block_data_on_global), make_multi_index(0, 0, k_block_data_on_global, 0),
a_e0_e1_k_block_desc, a_e0_e1_k_e2_block_desc,
make_multi_index(0, 0, 0)); make_multi_index(0, 0, 0, 0));
constexpr auto a_block_slice_copy_step = make_multi_index(I1, 0, 0); constexpr auto a_block_slice_copy_step = make_multi_index(I1, 0, 0, 0);
constexpr auto b_e0_e1_n_ho_wo_thread_desc = make_naive_tensor_descriptor_packed(make_tuple( constexpr auto b_e0_e1_n_ho_wo_e2_thread_desc =
I1, Number<EPerBlock>{}, Number<1>{}, Number<HoPerThread>{}, Number<WoPerThread>{})); make_naive_tensor_descriptor_packed(make_tuple(I1,
Number<EPerBlock>{},
auto b_threadwise_transfer = Number<1>{},
ThreadwiseTensorSliceTransfer_v2<FloatAB, Number<HoPerThread>{},
FloatAB, Number<WoPerThread>{},
decltype(b_e0_e1_n_ho_wo_global_desc), Number<E2>{}));
decltype(b_e0_e1_n_ho_wo_thread_desc),
Sequence<I1, EPerBlock, 1, HoPerThread, WoPerThread>, auto b_threadwise_transfer = ThreadwiseTensorSliceTransfer_v2<
BBlockTransferSrcAccessOrder, FloatAB,
BBlockTransferSrcVectorDim, FloatAB,
BBlockTransferSrcScalarPerVector, decltype(b_e0_e1_n_ho_wo_e2_global_desc),
1, decltype(b_e0_e1_n_ho_wo_e2_thread_desc),
true>( Sequence<I1, EPerBlock, 1, HoPerThread, WoPerThread, E2>,
b_e0_e1_n_ho_wo_global_desc, BBlockTransferSrcAccessOrder,
make_multi_index(0, 0, 0, ho_thread_data_on_global, wo_thread_data_on_global)); BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
1,
true>(b_e0_e1_n_ho_wo_e2_global_desc,
make_multi_index(0, 0, 0, ho_thread_data_on_global, wo_thread_data_on_global, 0));
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>( auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
p_shared_block, a_e0_e1_k_block_desc.GetElementSpaceSize()); p_shared_block, a_e0_e1_k_e2_block_desc.GetElementSpaceSize());
// register allocation for output // register allocation for output
StaticBuffer<AddressSpaceEnum_t::Vgpr, StaticBuffer<AddressSpaceEnum_t::Vgpr,
...@@ -325,20 +335,20 @@ struct GridwiseGemmDlops_km_kn_mn_v3 ...@@ -325,20 +335,20 @@ struct GridwiseGemmDlops_km_kn_mn_v3
Sequence<KPerThread, 1, HoPerThread, WoPerThread>>{} Sequence<KPerThread, 1, HoPerThread, WoPerThread>>{}
.Run(c_k_n_ho_wo_thread_desc, make_tuple(I0, I0, I0, I0), c_thread_buf, FloatAcc{0}); .Run(c_k_n_ho_wo_thread_desc, make_tuple(I0, I0, I0, I0), c_thread_buf, FloatAcc{0});
constexpr auto b_thread_slice_copy_step = make_multi_index(0, EPerBlock, 0, 0, 0); constexpr auto b_thread_slice_copy_step = make_multi_index(0, EPerBlock, 0, 0, 0, 0);
// hack to control index calculation when iterating over A and B matrix for threadwise copy // hack to control index calculation when iterating over A and B matrix for threadwise copy
constexpr auto a_e0_e1_k_global_step_hacks = AGlobalStepHacks{}; constexpr auto a_e0_e1_k_e2_global_step_hacks = AGlobalStepHacks{};
constexpr auto b_e0_e1_n_ho_wo_global_step_hacks = BGlobalStepHacks{}; constexpr auto b_e0_e1_n_ho_wo_e2_global_step_hacks = BGlobalStepHacks{};
// double regsiter buffer for b // double regsiter buffer for b
StaticBuffer<AddressSpaceEnum_t::Vgpr, StaticBuffer<AddressSpaceEnum_t::Vgpr,
FloatAB, FloatAB,
b_e0_e1_n_ho_wo_thread_desc.GetElementSpaceSize(), b_e0_e1_n_ho_wo_e2_thread_desc.GetElementSpaceSize(),
true> true>
b_thread_even_buf, b_thread_odd_buf; b_thread_even_buf, b_thread_odd_buf;
const auto E0 = b_e0_e1_n_ho_wo_global_desc.GetLength(I0); const auto E0 = b_e0_e1_n_ho_wo_e2_global_desc.GetLength(I0);
index_t e0_block_data_begin = 0; index_t e0_block_data_begin = 0;
...@@ -347,16 +357,16 @@ struct GridwiseGemmDlops_km_kn_mn_v3 ...@@ -347,16 +357,16 @@ struct GridwiseGemmDlops_km_kn_mn_v3
// LDS double buffer: preload data // LDS double buffer: preload data
{ {
a_blockwise_copy.RunRead( a_blockwise_copy.RunRead(
a_e0_e1_k_global_desc, a_global_buf, a_e0_e1_k_global_step_hacks); a_e0_e1_k_e2_global_desc, a_global_buf, a_e0_e1_k_e2_global_step_hacks);
b_threadwise_transfer.Run(b_e0_e1_n_ho_wo_global_desc, b_threadwise_transfer.Run(b_e0_e1_n_ho_wo_e2_global_desc,
b_global_buf, b_global_buf,
b_e0_e1_n_ho_wo_thread_desc, b_e0_e1_n_ho_wo_e2_thread_desc,
make_tuple(I0, I0, I0, I0, I0), make_tuple(I0, I0, I0, I0, I0, I0),
b_thread_even_buf, b_thread_even_buf,
b_e0_e1_n_ho_wo_global_step_hacks); b_e0_e1_n_ho_wo_e2_global_step_hacks);
a_blockwise_copy.RunWrite(a_e0_e1_k_block_desc, a_block_buf); a_blockwise_copy.RunWrite(a_e0_e1_k_e2_block_desc, a_block_buf);
} }
__syncthreads(); __syncthreads();
...@@ -370,36 +380,36 @@ struct GridwiseGemmDlops_km_kn_mn_v3 ...@@ -370,36 +380,36 @@ struct GridwiseGemmDlops_km_kn_mn_v3
do do
{ {
// even iteration // even iteration
b_threadwise_transfer.MoveSrcSliceWindow(b_e0_e1_n_ho_wo_global_desc, b_threadwise_transfer.MoveSrcSliceWindow(b_e0_e1_n_ho_wo_e2_global_desc,
b_thread_slice_copy_step); b_thread_slice_copy_step);
b_threadwise_transfer.Run(b_e0_e1_n_ho_wo_global_desc, b_threadwise_transfer.Run(b_e0_e1_n_ho_wo_e2_global_desc,
b_global_buf, b_global_buf,
b_e0_e1_n_ho_wo_thread_desc, b_e0_e1_n_ho_wo_e2_thread_desc,
make_tuple(I0, I0, I0, I0, I0), make_tuple(I0, I0, I0, I0, I0, I0),
b_thread_odd_buf, b_thread_odd_buf,
b_e0_e1_n_ho_wo_global_step_hacks); b_e0_e1_n_ho_wo_e2_global_step_hacks);
// LDS double buffer: GEMM on current data // LDS double buffer: GEMM on current data
// TODO: @Zhang Jing: blockwise gemm should be able to move slice window // TODO: @Zhang Jing: blockwise gemm should be able to move slice window
blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf); blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf);
blockwise_gemm.MoveABlockSliceWindow(make_tuple(EPerBlock, 0)); blockwise_gemm.MoveABlockSliceWindow(make_tuple(EPerBlock, 0, 0));
b_threadwise_transfer.MoveSrcSliceWindow(b_e0_e1_n_ho_wo_global_desc, b_threadwise_transfer.MoveSrcSliceWindow(b_e0_e1_n_ho_wo_e2_global_desc,
b_thread_slice_copy_step); b_thread_slice_copy_step);
b_threadwise_transfer.Run(b_e0_e1_n_ho_wo_global_desc, b_threadwise_transfer.Run(b_e0_e1_n_ho_wo_e2_global_desc,
b_global_buf, b_global_buf,
b_e0_e1_n_ho_wo_thread_desc, b_e0_e1_n_ho_wo_e2_thread_desc,
make_tuple(I0, I0, I0, I0, I0), make_tuple(I0, I0, I0, I0, I0, I0),
b_thread_even_buf, b_thread_even_buf,
b_e0_e1_n_ho_wo_global_step_hacks); b_e0_e1_n_ho_wo_e2_global_step_hacks);
// LDS double buffer: GEMM on current data // LDS double buffer: GEMM on current data
blockwise_gemm.Run(a_block_buf, b_thread_odd_buf, c_thread_buf); blockwise_gemm.Run(a_block_buf, b_thread_odd_buf, c_thread_buf);
blockwise_gemm.MoveABlockSliceWindow(make_tuple(EPerBlock, 0)); blockwise_gemm.MoveABlockSliceWindow(make_tuple(EPerBlock, 0, 0));
e1_block_data_begin += 2 * EPerBlock; e1_block_data_begin += 2 * EPerBlock;
...@@ -409,20 +419,20 @@ struct GridwiseGemmDlops_km_kn_mn_v3 ...@@ -409,20 +419,20 @@ struct GridwiseGemmDlops_km_kn_mn_v3
// LDS double buffer: tail // LDS double buffer: tail
if constexpr(HasDoubleTailKBlockLoop) // if has 2 iteration left if constexpr(HasDoubleTailKBlockLoop) // if has 2 iteration left
{ {
b_threadwise_transfer.MoveSrcSliceWindow(b_e0_e1_n_ho_wo_global_desc, b_threadwise_transfer.MoveSrcSliceWindow(b_e0_e1_n_ho_wo_e2_global_desc,
b_thread_slice_copy_step); b_thread_slice_copy_step);
b_threadwise_transfer.Run(b_e0_e1_n_ho_wo_global_desc, b_threadwise_transfer.Run(b_e0_e1_n_ho_wo_e2_global_desc,
b_global_buf, b_global_buf,
b_e0_e1_n_ho_wo_thread_desc, b_e0_e1_n_ho_wo_e2_thread_desc,
make_tuple(I0, I0, I0, I0, I0), make_tuple(I0, I0, I0, I0, I0, I0),
b_thread_odd_buf, b_thread_odd_buf,
b_e0_e1_n_ho_wo_global_step_hacks); b_e0_e1_n_ho_wo_e2_global_step_hacks);
// LDS double buffer: GEMM on 2nd-last data // LDS double buffer: GEMM on 2nd-last data
blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf); blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf);
blockwise_gemm.MoveABlockSliceWindow(make_tuple(EPerBlock, 0)); blockwise_gemm.MoveABlockSliceWindow(make_tuple(EPerBlock, 0, 0));
// LDS double buffer: GEMM on last data // LDS double buffer: GEMM on last data
blockwise_gemm.Run(a_block_buf, b_thread_odd_buf, c_thread_buf); blockwise_gemm.Run(a_block_buf, b_thread_odd_buf, c_thread_buf);
...@@ -433,12 +443,13 @@ struct GridwiseGemmDlops_km_kn_mn_v3 ...@@ -433,12 +443,13 @@ struct GridwiseGemmDlops_km_kn_mn_v3
blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf); blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf);
} }
a_blockwise_copy.MoveSrcSliceWindow( a_blockwise_copy.MoveSrcSliceWindow(a_e0_e1_k_e2_global_desc,
a_e0_e1_k_global_desc, a_block_slice_copy_step, AGlobalMoveSliceWindowStepHacks{}); a_block_slice_copy_step,
AGlobalMoveSliceWindowStepHacks{});
blockwise_gemm.MoveABlockSliceWindow(make_tuple(-(E1 - EPerBlock), 0)); blockwise_gemm.MoveABlockSliceWindow(make_tuple(-(E1 - EPerBlock), 0, 0));
b_threadwise_transfer.MoveSrcSliceWindow(b_e0_e1_n_ho_wo_global_desc, b_threadwise_transfer.MoveSrcSliceWindow(b_e0_e1_n_ho_wo_e2_global_desc,
b_thread_slice_copy_step); b_thread_slice_copy_step);
e0_block_data_begin += 1; e0_block_data_begin += 1;
......
...@@ -9,16 +9,17 @@ namespace ck { ...@@ -9,16 +9,17 @@ namespace ck {
// C[M, N] += transpose(A[K, M]) * B[K, N] // C[M, N] += transpose(A[K, M]) * B[K, N]
// Element of matrix can be vectorized data // Element of matrix can be vectorized data
// Assume: // Assume:
// 1. AThreadDesc_E_K, BThreadDesc_E_N_Ho_Wo, CThreadDesc_K_N_Ho_Wo are known at compile-time // 1. AThreadDesc_E1_K_E2, BThreadDesc_E1_N_Ho_Wo_E2, CThreadDesc_K_N_Ho_Wo are known at
// compile-time
// 2. AOriginIdx, BOriginIdx, COriginIdx are known at compile-time // 2. AOriginIdx, BOriginIdx, COriginIdx are known at compile-time
template <typename FloatA, template <typename FloatA,
typename FloatB, typename FloatB,
typename FloatC, typename FloatC,
typename AThreadDesc_E_K, typename AThreadDesc_E1_K_E2,
typename BThreadDesc_E_N_Ho_Wo, typename BThreadDesc_E1_N_Ho_Wo_E2,
typename CThreadDesc_K_N_Ho_Wo, typename CThreadDesc_K_N_Ho_Wo,
typename enable_if<AThreadDesc_E_K::IsKnownAtCompileTime() && typename enable_if<AThreadDesc_E1_K_E2::IsKnownAtCompileTime() &&
BThreadDesc_E_N_Ho_Wo::IsKnownAtCompileTime() && BThreadDesc_E1_N_Ho_Wo_E2::IsKnownAtCompileTime() &&
CThreadDesc_K_N_Ho_Wo::IsKnownAtCompileTime(), CThreadDesc_K_N_Ho_Wo::IsKnownAtCompileTime(),
bool>::type = false> bool>::type = false>
struct ThreadwiseGemmDlops_km_kn_mn_v3 struct ThreadwiseGemmDlops_km_kn_mn_v3
...@@ -38,8 +39,8 @@ struct ThreadwiseGemmDlops_km_kn_mn_v3 ...@@ -38,8 +39,8 @@ struct ThreadwiseGemmDlops_km_kn_mn_v3
COriginIdx) COriginIdx)
{ {
static_assert(AThreadDesc_E_K::IsKnownAtCompileTime() && static_assert(AThreadDesc_E1_K_E2::IsKnownAtCompileTime() &&
BThreadDesc_E_N_Ho_Wo::IsKnownAtCompileTime() && BThreadDesc_E1_N_Ho_Wo_E2::IsKnownAtCompileTime() &&
CThreadDesc_K_N_Ho_Wo::IsKnownAtCompileTime(), CThreadDesc_K_N_Ho_Wo::IsKnownAtCompileTime(),
"wrong! Desc should be known at compile-time"); "wrong! Desc should be known at compile-time");
...@@ -54,18 +55,19 @@ struct ThreadwiseGemmDlops_km_kn_mn_v3 ...@@ -54,18 +55,19 @@ struct ThreadwiseGemmDlops_km_kn_mn_v3
is_same<remove_cvref_t<typename CBuffer::type>, remove_cvref_t<FloatC>>::value && is_same<remove_cvref_t<typename CBuffer::type>, remove_cvref_t<FloatC>>::value &&
"wrong! inconsistent type"); "wrong! inconsistent type");
constexpr index_t Vec = 2;
constexpr auto I0 = Number<0>{}; constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{}; constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{}; constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{}; constexpr auto I3 = Number<3>{};
constexpr auto E = AThreadDesc_E_K{}.GetLength(I0); constexpr auto E1 = AThreadDesc_E1_K_E2{}.GetLength(I0);
constexpr auto K = AThreadDesc_E_K{}.GetLength(I1); constexpr auto K = AThreadDesc_E1_K_E2{}.GetLength(I1);
constexpr auto E2 = AThreadDesc_E1_K_E2{}.GetLength(I2);
static_assert(E1 == 4 && E2 == 4, "");
constexpr auto H = BThreadDesc_E_N_Ho_Wo{}.GetLength(I2); constexpr auto H = BThreadDesc_E1_N_Ho_Wo_E2{}.GetLength(I2);
constexpr auto W = BThreadDesc_E_N_Ho_Wo{}.GetLength(I3); constexpr auto W = BThreadDesc_E1_N_Ho_Wo_E2{}.GetLength(I3);
constexpr auto a_origin_idx = to_multi_index(AOriginIdx{}); constexpr auto a_origin_idx = to_multi_index(AOriginIdx{});
constexpr auto b_origin_idx = to_multi_index(BOriginIdx{}); constexpr auto b_origin_idx = to_multi_index(BOriginIdx{});
...@@ -74,22 +76,23 @@ struct ThreadwiseGemmDlops_km_kn_mn_v3 ...@@ -74,22 +76,23 @@ struct ThreadwiseGemmDlops_km_kn_mn_v3
static_for<0, K, 1>{}([&](auto k) { static_for<0, K, 1>{}([&](auto k) {
static_for<0, H, 1>{}([&](auto h) { static_for<0, H, 1>{}([&](auto h) {
static_for<0, W, 1>{}([&](auto w) { static_for<0, W, 1>{}([&](auto w) {
static_for<0, E, Vec>{}([&](auto e) { static_for<0, E1, 1>{}([&](auto e) {
vector_type<FloatA, Vec> a_vec; vector_type<FloatA, E2> a_vec;
vector_type<FloatB, Vec> b_vec; vector_type<FloatB, E2> b_vec;
static_for<0, Vec, 1>{}([&](auto v) { static_for<0, E2, 1>{}([&](auto v) {
constexpr index_t a_offset = AThreadDesc_E_K{}.CalculateOffset( constexpr index_t a_offset = AThreadDesc_E1_K_E2{}.CalculateOffset(
a_origin_idx + make_tuple(e + v, k)); a_origin_idx + make_tuple(e, k, v));
constexpr index_t b_offset = BThreadDesc_E_N_Ho_Wo{}.CalculateOffset( constexpr index_t b_offset =
b_origin_idx + make_tuple(e + v, 0, h, w)); BThreadDesc_E1_N_Ho_Wo_E2{}.CalculateOffset(
b_origin_idx + make_tuple(e, 0, h, w, v));
a_vec.template AsType<FloatA>()(v) = a_buf[Number<a_offset>{}]; a_vec.template AsType<FloatA>()(v) = a_buf[Number<a_offset>{}];
b_vec.template AsType<FloatB>()(v) = b_buf[Number<b_offset>{}]; b_vec.template AsType<FloatB>()(v) = b_buf[Number<b_offset>{}];
}); });
using a_vector_t = typename vector_type<FloatA, Vec>::type; using a_vector_t = typename vector_type<FloatA, E2>::type;
using b_vector_t = typename vector_type<FloatB, Vec>::type; using b_vector_t = typename vector_type<FloatB, E2>::type;
constexpr index_t c_offset = CThreadDesc_K_N_Ho_Wo{}.CalculateOffset( constexpr index_t c_offset = CThreadDesc_K_N_Ho_Wo{}.CalculateOffset(
c_origin_idx + make_tuple(k, 0, h, w)); c_origin_idx + make_tuple(k, 0, h, w));
......
...@@ -102,26 +102,27 @@ void device_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw( ...@@ -102,26 +102,27 @@ void device_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw(
constexpr index_t BlockSize = 64; constexpr index_t BlockSize = 64;
constexpr index_t KPerBlock = 16; constexpr index_t KPerBlock = 16;
constexpr index_t HoPerBlock = 32; constexpr index_t HoPerBlock = 8;
constexpr index_t WoPerBlock = 8; constexpr index_t WoPerBlock = 8;
constexpr index_t E1 = 16; constexpr index_t E1 = 4;
constexpr index_t EPerBlock = 16; constexpr index_t E2 = 4;
constexpr index_t EPerBlock = 4;
constexpr index_t KPerThread = KPerBlock; constexpr index_t KPerThread = 4;
constexpr index_t HoPerThread = 2; constexpr index_t HoPerThread = 2;
constexpr index_t WoPerThread = 2; constexpr index_t WoPerThread = 2;
constexpr index_t EPerThread = EPerBlock; constexpr index_t EPerThread = 4;
using ABlockTransferThreadSliceLengths_E0_E1_K = Sequence<1, 4, 1>; using ABlockTransferThreadSliceLengths_E0_E1_K_E2 = Sequence<1, 1, 1, 4>;
using ABlockTransferThreadClusterLengths_E0_E1_K = Sequence<1, 4, 16>; using ABlockTransferThreadClusterLengths_E0_E1_K_E2 = Sequence<1, 4, 16, 1>;
constexpr index_t ABlockTransferSrcScalarPerVector_E = 4; constexpr index_t ABlockTransferSrcScalarPerVector_E2 = 1;
constexpr index_t ABlockTransferDstScalarPerVector_K = 1; constexpr index_t ABlockTransferDstScalarPerVector_E2 = 1;
constexpr index_t BThreadTransferSrcScalarPerVector_E = 4; constexpr index_t BThreadTransferSrcScalarPerVector_E2 = 1;
constexpr index_t CThreadTransferDstScalarPerVector_K = 4; constexpr index_t CThreadTransferDstScalarPerVector_K = 1;
#endif #endif
constexpr auto conv_driver = constexpr auto conv_driver =
...@@ -131,6 +132,7 @@ void device_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw( ...@@ -131,6 +132,7 @@ void device_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw(
TAcc, TAcc,
TOut, TOut,
E1, E1,
E2,
KPerBlock, KPerBlock,
HoPerBlock, HoPerBlock,
WoPerBlock, WoPerBlock,
...@@ -139,11 +141,11 @@ void device_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw( ...@@ -139,11 +141,11 @@ void device_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw(
HoPerThread, HoPerThread,
WoPerThread, WoPerThread,
EPerThread, EPerThread,
ABlockTransferThreadSliceLengths_E0_E1_K, ABlockTransferThreadSliceLengths_E0_E1_K_E2,
ABlockTransferThreadClusterLengths_E0_E1_K, ABlockTransferThreadClusterLengths_E0_E1_K_E2,
ABlockTransferSrcScalarPerVector_E, ABlockTransferSrcScalarPerVector_E2,
ABlockTransferDstScalarPerVector_K, ABlockTransferDstScalarPerVector_E2,
BThreadTransferSrcScalarPerVector_E, BThreadTransferSrcScalarPerVector_E2,
CThreadTransferDstScalarPerVector_K>{}; CThreadTransferDstScalarPerVector_K>{};
const auto ave_time = const auto ave_time =
......
...@@ -11,6 +11,7 @@ template <ck::index_t BlockSize, ...@@ -11,6 +11,7 @@ template <ck::index_t BlockSize,
typename FloatAcc, typename FloatAcc,
typename FloatC, typename FloatC,
ck::index_t E1, ck::index_t E1,
ck::index_t E2,
ck::index_t KPerBlock, ck::index_t KPerBlock,
ck::index_t HoPerBlock, ck::index_t HoPerBlock,
ck::index_t WoPerBlock, ck::index_t WoPerBlock,
...@@ -19,11 +20,11 @@ template <ck::index_t BlockSize, ...@@ -19,11 +20,11 @@ template <ck::index_t BlockSize,
ck::index_t HoPerThread, ck::index_t HoPerThread,
ck::index_t WoPerThread, ck::index_t WoPerThread,
ck::index_t EPerThread, ck::index_t EPerThread,
typename ABlockTransferThreadSliceLengths_E_K, typename ABlockTransferThreadSliceLengths_E0_E1_K_E2,
typename ABlockTransferThreadClusterLengths_E_K, typename ABlockTransferThreadClusterLengths_E0_E1_K_E2,
ck::index_t ABlockTransferSrcScalarPerVector_E, ck::index_t ABlockTransferSrcScalarPerVector_E2,
ck::index_t ABlockTransferDstScalarPerVector_K, ck::index_t ABlockTransferDstScalarPerVector_E2,
ck::index_t BThreadTransferSrcScalarPerVector_E, ck::index_t BThreadTransferSrcScalarPerVector_E2,
ck::index_t CThreadTransferDstScalarPerVector_K> ck::index_t CThreadTransferDstScalarPerVector_K>
struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outpad struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outpad
{ {
...@@ -93,7 +94,7 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp ...@@ -93,7 +94,7 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
<< std::endl; << std::endl;
const auto E = C0 * Y * X * C1; const auto E = C0 * Y * X * C1;
const auto E0 = E / E1; const auto E0 = E / (E1 * E2);
// weight tensor // weight tensor
const auto a_e_k_grid_desc = transform_tensor_descriptor( const auto a_e_k_grid_desc = transform_tensor_descriptor(
...@@ -103,11 +104,12 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp ...@@ -103,11 +104,12 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<1>{}, Sequence<0>{})); make_tuple(Sequence<1>{}, Sequence<0>{}));
const auto a_e0_e1_k_grid_desc = transform_tensor_descriptor( const auto a_e0_e1_k_e2_grid_desc =
a_e_k_grid_desc, transform_tensor_descriptor(a_e_k_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(E0, E1)), make_pass_through_transform(K)), make_tuple(make_unmerge_transform(make_tuple(E0, E1, E2)),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_pass_through_transform(K)),
make_tuple(Sequence<0, 1>{}, Sequence<2>{})); make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
// input tensor // input tensor
const auto in_n_c0_hip_wip_c1_global_desc = transform_tensor_descriptor( const auto in_n_c0_hip_wip_c1_global_desc = transform_tensor_descriptor(
...@@ -141,14 +143,14 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp ...@@ -141,14 +143,14 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
make_tuple(Sequence<1, 2, 4, 6>{}, Sequence<0>{}, Sequence<3>{}, Sequence<5>{}), make_tuple(Sequence<1, 2, 4, 6>{}, Sequence<0>{}, Sequence<3>{}, Sequence<5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto b_e0_e1_n_ho_wo_grid_desc = transform_tensor_descriptor( const auto b_e0_e1_n_ho_wo_e2_grid_desc = transform_tensor_descriptor(
b_e_n_ho_wo_grid_desc, b_e_n_ho_wo_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(E0, E1)), make_tuple(make_unmerge_transform(make_tuple(E0, E1, E2)),
make_pass_through_transform(N), make_pass_through_transform(N),
make_pass_through_transform(Hop), make_pass_through_transform(Hop),
make_pass_through_transform(Wop)), make_pass_through_transform(Wop)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); make_tuple(Sequence<0, 1, 5>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
// output tensor // output tensor
const auto c_k_n_hop_wop_grid_desc = transform_tensor_descriptor( const auto c_k_n_hop_wop_grid_desc = transform_tensor_descriptor(
...@@ -169,27 +171,33 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp ...@@ -169,27 +171,33 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
} }
// hack to control index calculation when iterating over a_k_m_global tensor // hack to control index calculation when iterating over a_k_m_global tensor
constexpr auto a_e0_e1_k_global_step_hacks = make_tuple( constexpr auto a_e0_e1_k_e2_global_step_hacks =
make_tuple( make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}), Sequence<0, 0, 0, 0, 0>{},
make_tuple( Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{})); Sequence<0, 0, 0, 0, 0>{}),
make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{}));
constexpr auto a_e0_e1_k_global_move_slice_window_step_hack = Sequence<0, 0, 0, 0, 0>{}; constexpr auto a_e0_e1_k_e2_global_move_slice_window_step_hack = Sequence<0, 0, 0, 0, 0>{};
constexpr auto b_e0_e1_n_ho_wo_global_step_hacks = make_tuple( constexpr auto b_e0_e1_n_ho_wo_e2_global_step_hacks = make_tuple(
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0>{}, make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0>{}),
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0>{}, make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0>{}));
constexpr auto b_e0_e1_n_ho_wo_global_move_slice_window_step_hack = constexpr auto b_e0_e1_n_ho_wo_e2_global_move_slice_window_step_hack =
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0>{}; Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0>{};
// hack to control index calculation when iterating over c_m0_m1_n0_n1_global tensor // hack to control index calculation when iterating over c_m0_m1_n0_n1_global tensor
...@@ -211,10 +219,11 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp ...@@ -211,10 +219,11 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
FloatAcc, FloatAcc,
FloatC, FloatC,
InMemoryDataOperationEnum_t::Set, InMemoryDataOperationEnum_t::Set,
decltype(a_e0_e1_k_grid_desc), decltype(a_e0_e1_k_e2_grid_desc),
decltype(b_e0_e1_n_ho_wo_grid_desc), decltype(b_e0_e1_n_ho_wo_e2_grid_desc),
decltype(c_k_n_hop_wop_grid_desc), decltype(c_k_n_hop_wop_grid_desc),
E1, E1,
E2,
KPerBlock, KPerBlock,
HoPerBlock, HoPerBlock,
WoPerBlock, WoPerBlock,
...@@ -223,31 +232,31 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp ...@@ -223,31 +232,31 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
HoPerThread, HoPerThread,
WoPerThread, WoPerThread,
EPerThread, EPerThread,
ABlockTransferThreadSliceLengths_E_K, ABlockTransferThreadSliceLengths_E0_E1_K_E2,
ABlockTransferThreadClusterLengths_E_K, ABlockTransferThreadClusterLengths_E0_E1_K_E2,
Sequence<2, 0, 1>, Sequence<2, 0, 1, 3>,
Sequence<2, 0, 1>, Sequence<2, 0, 1, 3>,
1, 3,
ABlockTransferSrcScalarPerVector_E, ABlockTransferSrcScalarPerVector_E2,
ABlockTransferDstScalarPerVector_K, ABlockTransferDstScalarPerVector_E2,
false, // don't move back src coordinate after threadwise copy false, // don't move back src coordinate after threadwise copy
Sequence<0, 2, 3, 4, 1>, Sequence<0, 2, 3, 4, 1, 5>,
1, 5,
BThreadTransferSrcScalarPerVector_E, BThreadTransferSrcScalarPerVector_E2,
false, // don't move back src coordinate after threadwise copy, which will be fused with false, // don't move back src coordinate after threadwise copy, which will be fused with
// MoveSrcSliceWindow() to save addr computation // MoveSrcSliceWindow() to save addr computation
Sequence<0, 2, 3, 1>, Sequence<2, 3, 1, 0>,
0, 0,
CThreadTransferDstScalarPerVector_K, CThreadTransferDstScalarPerVector_K,
decltype(a_e0_e1_k_global_step_hacks), decltype(a_e0_e1_k_e2_global_step_hacks),
decltype(b_e0_e1_n_ho_wo_global_step_hacks), decltype(b_e0_e1_n_ho_wo_e2_global_step_hacks),
decltype(c_k_n_ho_wo_global_tensor_step_hacks), decltype(c_k_n_ho_wo_global_tensor_step_hacks),
decltype(a_e0_e1_k_global_move_slice_window_step_hack), decltype(a_e0_e1_k_e2_global_move_slice_window_step_hack),
decltype(b_e0_e1_n_ho_wo_global_move_slice_window_step_hack)>; decltype(b_e0_e1_n_ho_wo_e2_global_move_slice_window_step_hack)>;
using AGridDesc_E0_E1_K = decltype(a_e0_e1_k_grid_desc); using AGridDesc_E0_E1_K_E2 = decltype(a_e0_e1_k_e2_grid_desc);
using BGridDesc_E0_E1_N_Ho_Wo = decltype(b_e0_e1_n_ho_wo_grid_desc); using BGridDesc_E0_E1_N_Ho_Wo_E2 = decltype(b_e0_e1_n_ho_wo_e2_grid_desc);
using CGridDesc_K_N_Ho_Wo = decltype(c_k_n_hop_wop_grid_desc); using CGridDesc_K_N_Ho_Wo = decltype(c_k_n_hop_wop_grid_desc);
const auto grid_size = (K / KPerBlock) * (Hop / HoPerBlock) * (Wop / WoPerBlock) * N; const auto grid_size = (K / KPerBlock) * (Hop / HoPerBlock) * (Wop / WoPerBlock) * N;
...@@ -276,8 +285,8 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp ...@@ -276,8 +285,8 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
kernel_gemm_dlops_v2<GridwiseGemm, kernel_gemm_dlops_v2<GridwiseGemm,
FloatAB, FloatAB,
FloatC, FloatC,
remove_reference_t<AGridDesc_E0_E1_K>, remove_reference_t<AGridDesc_E0_E1_K_E2>,
remove_reference_t<BGridDesc_E0_E1_N_Ho_Wo>, remove_reference_t<BGridDesc_E0_E1_N_Ho_Wo_E2>,
remove_reference_t<CGridDesc_K_N_Ho_Wo>, remove_reference_t<CGridDesc_K_N_Ho_Wo>,
remove_reference_t<CBlockIdToBlockClusterAdaptor_K_N_Ho_Wo>, remove_reference_t<CBlockIdToBlockClusterAdaptor_K_N_Ho_Wo>,
true, true,
...@@ -291,8 +300,8 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp ...@@ -291,8 +300,8 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
p_a_grid, p_a_grid,
p_b_grid, p_b_grid,
p_c_grid, p_c_grid,
a_e0_e1_k_grid_desc, a_e0_e1_k_e2_grid_desc,
b_e0_e1_n_ho_wo_grid_desc, b_e0_e1_n_ho_wo_e2_grid_desc,
c_k_n_hop_wop_grid_desc, c_k_n_hop_wop_grid_desc,
c_blockid_to_k_n_ho_wo_block_cluster_adaptor); c_blockid_to_k_n_ho_wo_block_cluster_adaptor);
} }
...@@ -302,8 +311,8 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp ...@@ -302,8 +311,8 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
kernel_gemm_dlops_v2<GridwiseGemm, kernel_gemm_dlops_v2<GridwiseGemm,
FloatAB, FloatAB,
FloatC, FloatC,
remove_reference_t<AGridDesc_E0_E1_K>, remove_reference_t<AGridDesc_E0_E1_K_E2>,
remove_reference_t<BGridDesc_E0_E1_N_Ho_Wo>, remove_reference_t<BGridDesc_E0_E1_N_Ho_Wo_E2>,
remove_reference_t<CGridDesc_K_N_Ho_Wo>, remove_reference_t<CGridDesc_K_N_Ho_Wo>,
remove_reference_t<CBlockIdToBlockClusterAdaptor_K_N_Ho_Wo>, remove_reference_t<CBlockIdToBlockClusterAdaptor_K_N_Ho_Wo>,
true, true,
...@@ -317,8 +326,8 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp ...@@ -317,8 +326,8 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
p_a_grid, p_a_grid,
p_b_grid, p_b_grid,
p_c_grid, p_c_grid,
a_e0_e1_k_grid_desc, a_e0_e1_k_e2_grid_desc,
b_e0_e1_n_ho_wo_grid_desc, b_e0_e1_n_ho_wo_e2_grid_desc,
c_k_n_hop_wop_grid_desc, c_k_n_hop_wop_grid_desc,
c_blockid_to_k_n_ho_wo_block_cluster_adaptor); c_blockid_to_k_n_ho_wo_block_cluster_adaptor);
} }
...@@ -328,8 +337,8 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp ...@@ -328,8 +337,8 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
kernel_gemm_dlops_v2<GridwiseGemm, kernel_gemm_dlops_v2<GridwiseGemm,
FloatAB, FloatAB,
FloatC, FloatC,
remove_reference_t<AGridDesc_E0_E1_K>, remove_reference_t<AGridDesc_E0_E1_K_E2>,
remove_reference_t<BGridDesc_E0_E1_N_Ho_Wo>, remove_reference_t<BGridDesc_E0_E1_N_Ho_Wo_E2>,
remove_reference_t<CGridDesc_K_N_Ho_Wo>, remove_reference_t<CGridDesc_K_N_Ho_Wo>,
remove_reference_t<CBlockIdToBlockClusterAdaptor_K_N_Ho_Wo>, remove_reference_t<CBlockIdToBlockClusterAdaptor_K_N_Ho_Wo>,
false, false,
...@@ -343,8 +352,8 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp ...@@ -343,8 +352,8 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
p_a_grid, p_a_grid,
p_b_grid, p_b_grid,
p_c_grid, p_c_grid,
a_e0_e1_k_grid_desc, a_e0_e1_k_e2_grid_desc,
b_e0_e1_n_ho_wo_grid_desc, b_e0_e1_n_ho_wo_e2_grid_desc,
c_k_n_hop_wop_grid_desc, c_k_n_hop_wop_grid_desc,
c_blockid_to_k_n_ho_wo_block_cluster_adaptor); c_blockid_to_k_n_ho_wo_block_cluster_adaptor);
} }
...@@ -354,8 +363,8 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp ...@@ -354,8 +363,8 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
kernel_gemm_dlops_v2<GridwiseGemm, kernel_gemm_dlops_v2<GridwiseGemm,
FloatAB, FloatAB,
FloatC, FloatC,
remove_reference_t<AGridDesc_E0_E1_K>, remove_reference_t<AGridDesc_E0_E1_K_E2>,
remove_reference_t<BGridDesc_E0_E1_N_Ho_Wo>, remove_reference_t<BGridDesc_E0_E1_N_Ho_Wo_E2>,
remove_reference_t<CGridDesc_K_N_Ho_Wo>, remove_reference_t<CGridDesc_K_N_Ho_Wo>,
remove_reference_t<CBlockIdToBlockClusterAdaptor_K_N_Ho_Wo>, remove_reference_t<CBlockIdToBlockClusterAdaptor_K_N_Ho_Wo>,
false, false,
...@@ -369,22 +378,22 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp ...@@ -369,22 +378,22 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
p_a_grid, p_a_grid,
p_b_grid, p_b_grid,
p_c_grid, p_c_grid,
a_e0_e1_k_grid_desc, a_e0_e1_k_e2_grid_desc,
b_e0_e1_n_ho_wo_grid_desc, b_e0_e1_n_ho_wo_e2_grid_desc,
c_k_n_hop_wop_grid_desc, c_k_n_hop_wop_grid_desc,
c_blockid_to_k_n_ho_wo_block_cluster_adaptor); c_blockid_to_k_n_ho_wo_block_cluster_adaptor);
} }
return ave_time; return ave_time;
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER #elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
DeviceMem a_e0_e1_k_grid_desc_dev_buf(sizeof(AGridDesc_E0_E1_K)); DeviceMem a_e0_e1_k_e2_grid_desc_dev_buf(sizeof(AGridDesc_E0_E1_K_E2));
DeviceMem b_e0_e1_n_ho_wo_grid_desc_dev_buf(sizeof(BGridDesc_E0_E1_N_Ho_Wo)); DeviceMem b_e0_e1_n_ho_wo_e2_grid_desc_dev_buf(sizeof(BGridDesc_E0_E1_N_Ho_Wo_E2));
DeviceMem c_k_n_hop_wop_grid_desc_dev_buf(sizeof(CGridDesc_K_N_Ho_Wo)); DeviceMem c_k_n_hop_wop_grid_desc_dev_buf(sizeof(CGridDesc_K_N_Ho_Wo));
DeviceMem c_blockid_to_k_n_ho_wo_block_cluster_adaptor_dev_buf( DeviceMem c_blockid_to_k_n_ho_wo_block_cluster_adaptor_dev_buf(
sizeof(CBlockIdToBlockClusterAdaptor_K_N_Ho_Wo)); sizeof(CBlockIdToBlockClusterAdaptor_K_N_Ho_Wo));
a_e0_e1_k_grid_desc_dev_buf.ToDevice(&a_e0_e1_k_grid_desc); a_e0_e1_k_e2_grid_desc_dev_buf.ToDevice(&a_e0_e1_k_e2_grid_desc);
b_e0_e1_n_ho_wo_grid_desc_dev_buf.ToDevice(&b_e0_e1_n_ho_wo_grid_desc); b_e0_e1_n_ho_wo_e2_grid_desc_dev_buf.ToDevice(&b_e0_e1_n_ho_wo_e2_grid_desc);
c_k_n_hop_wop_grid_desc_dev_buf.ToDevice(&c_k_n_hop_wop_grid_desc); c_k_n_hop_wop_grid_desc_dev_buf.ToDevice(&c_k_n_hop_wop_grid_desc);
c_blockid_to_k_n_ho_wo_block_cluster_adaptor_dev_buf.ToDevice( c_blockid_to_k_n_ho_wo_block_cluster_adaptor_dev_buf.ToDevice(
&c_blockid_to_k_n_ho_wo_block_cluster_adaptor); &c_blockid_to_k_n_ho_wo_block_cluster_adaptor);
...@@ -397,8 +406,8 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp ...@@ -397,8 +406,8 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
kernel_gemm_dlops_v2<GridwiseGemm, kernel_gemm_dlops_v2<GridwiseGemm,
FloatAB, FloatAB,
FloatC, FloatC,
remove_reference_t<AGridDesc_E0_E1_K>, remove_reference_t<AGridDesc_E0_E1_K_E2>,
remove_reference_t<BGridDesc_E0_E1_N_Ho_Wo>, remove_reference_t<BGridDesc_E0_E1_N_Ho_Wo_E2>,
remove_reference_t<CGridDesc_K_N_Ho_Wo>, remove_reference_t<CGridDesc_K_N_Ho_Wo>,
remove_reference_t<CBlockIdToBlockClusterAdaptor_K_N_Ho_Wo>, remove_reference_t<CBlockIdToBlockClusterAdaptor_K_N_Ho_Wo>,
true, true,
...@@ -414,9 +423,9 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp ...@@ -414,9 +423,9 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
p_b_grid, p_b_grid,
p_c_grid, p_c_grid,
cast_pointer_to_constant_address_space( cast_pointer_to_constant_address_space(
a_e0_e1_k_grid_desc_dev_buf.GetDeviceBuffer()), a_e0_e1_k_e2_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space( cast_pointer_to_constant_address_space(
b_e0_e1_n_ho_wo_grid_desc_dev_buf.GetDeviceBuffer()), b_e0_e1_n_ho_wo_e2_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space( cast_pointer_to_constant_address_space(
c_k_n_hop_wop_grid_desc_dev_buf.GetDeviceBuffer()), c_k_n_hop_wop_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space( cast_pointer_to_constant_address_space(
...@@ -428,8 +437,8 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp ...@@ -428,8 +437,8 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
kernel_gemm_dlops_v2<GridwiseGemm, kernel_gemm_dlops_v2<GridwiseGemm,
FloatAB, FloatAB,
FloatC, FloatC,
remove_reference_t<AGridDesc_E0_E1_K>, remove_reference_t<AGridDesc_E0_E1_K_E2>,
remove_reference_t<BGridDesc_E0_E1_N_Ho_Wo>, remove_reference_t<BGridDesc_E0_E1_N_Ho_Wo_E2>,
remove_reference_t<CGridDesc_K_N_Ho_Wo>, remove_reference_t<CGridDesc_K_N_Ho_Wo>,
remove_reference_t<CBlockIdToBlockClusterAdaptor_K_N_Ho_Wo>, remove_reference_t<CBlockIdToBlockClusterAdaptor_K_N_Ho_Wo>,
true, true,
...@@ -445,9 +454,9 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp ...@@ -445,9 +454,9 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
p_b_grid, p_b_grid,
p_c_grid, p_c_grid,
cast_pointer_to_constant_address_space( cast_pointer_to_constant_address_space(
a_e0_e1_k_grid_desc_dev_buf.GetDeviceBuffer()), a_e0_e1_k_e2_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space( cast_pointer_to_constant_address_space(
b_e0_e1_n_ho_wo_grid_desc_dev_buf.GetDeviceBuffer()), b_e0_e1_n_ho_wo_e2_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space( cast_pointer_to_constant_address_space(
c_k_n_hop_wop_grid_desc_dev_buf.GetDeviceBuffer()), c_k_n_hop_wop_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space( cast_pointer_to_constant_address_space(
...@@ -459,8 +468,8 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp ...@@ -459,8 +468,8 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
kernel_gemm_dlops_v2<GridwiseGemm, kernel_gemm_dlops_v2<GridwiseGemm,
FloatAB, FloatAB,
FloatC, FloatC,
remove_reference_t<AGridDesc_E0_E1_K>, remove_reference_t<AGridDesc_E0_E1_K_E2>,
remove_reference_t<BGridDesc_E0_E1_N_Ho_Wo>, remove_reference_t<BGridDesc_E0_E1_N_Ho_Wo_E2>,
remove_reference_t<CGridDesc_K_N_Ho_Wo>, remove_reference_t<CGridDesc_K_N_Ho_Wo>,
remove_reference_t<CBlockIdToBlockClusterAdaptor_K_N_Ho_Wo>, remove_reference_t<CBlockIdToBlockClusterAdaptor_K_N_Ho_Wo>,
false, false,
...@@ -476,9 +485,9 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp ...@@ -476,9 +485,9 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
p_b_grid, p_b_grid,
p_c_grid, p_c_grid,
cast_pointer_to_constant_address_space( cast_pointer_to_constant_address_space(
a_e0_e1_k_grid_desc_dev_buf.GetDeviceBuffer()), a_e0_e1_k_e2_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space( cast_pointer_to_constant_address_space(
b_e0_e1_n_ho_wo_grid_desc_dev_buf.GetDeviceBuffer()), b_e0_e1_n_ho_wo_e2_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space( cast_pointer_to_constant_address_space(
c_k_n_hop_wop_grid_desc_dev_buf.GetDeviceBuffer()), c_k_n_hop_wop_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space( cast_pointer_to_constant_address_space(
...@@ -490,8 +499,8 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp ...@@ -490,8 +499,8 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
kernel_gemm_dlops_v2<GridwiseGemm, kernel_gemm_dlops_v2<GridwiseGemm,
FloatAB, FloatAB,
FloatC, FloatC,
remove_reference_t<AGridDesc_E0_E1_K>, remove_reference_t<AGridDesc_E0_E1_K_E2>,
remove_reference_t<BGridDesc_E0_E1_N_Ho_Wo>, remove_reference_t<BGridDesc_E0_E1_N_Ho_Wo_E2>,
remove_reference_t<CGridDesc_K_N_Ho_Wo>, remove_reference_t<CGridDesc_K_N_Ho_Wo>,
remove_reference_t<CBlockIdToBlockClusterAdaptor_K_N_Ho_Wo>, remove_reference_t<CBlockIdToBlockClusterAdaptor_K_N_Ho_Wo>,
false, false,
...@@ -507,9 +516,9 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp ...@@ -507,9 +516,9 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
p_b_grid, p_b_grid,
p_c_grid, p_c_grid,
cast_pointer_to_constant_address_space( cast_pointer_to_constant_address_space(
a_e0_e1_k_grid_desc_dev_buf.GetDeviceBuffer()), a_e0_e1_k_e2_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space( cast_pointer_to_constant_address_space(
b_e0_e1_n_ho_wo_grid_desc_dev_buf.GetDeviceBuffer()), b_e0_e1_n_ho_wo_e2_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space( cast_pointer_to_constant_address_space(
c_k_n_hop_wop_grid_desc_dev_buf.GetDeviceBuffer()), c_k_n_hop_wop_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space( cast_pointer_to_constant_address_space(
......
...@@ -52,6 +52,7 @@ REPEAT=$6 ...@@ -52,6 +52,7 @@ REPEAT=$6
#./host/driver_online/conv_fwd_driver_online $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 192 3 3 71 71 2 2 1 1 1 1 1 1 #./host/driver_online/conv_fwd_driver_online $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 192 3 3 71 71 2 2 1 1 1 1 1 1
./host/driver_offline/conv_fwd_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 1 16 16 3 3 1080 1920 1 1 1 1 1 1 1 1 ./host/driver_offline/conv_fwd_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 1 16 16 3 3 1080 1920 1 1 1 1 1 1 1 1
#./host/driver_offline/conv_fwd_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 1 16 16 1 1 8 8 1 1 1 1 1 1 1 1
################################################ layout algo verify init log repeat M___ N___ K___ ################################################ layout algo verify init log repeat M___ N___ K___
#./host/driver_offline/gemm_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 960 1024 1024 #./host/driver_offline/gemm_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 960 1024 1024
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment