Commit 10fdada7 authored by Jing Zhang
Browse files

rename e0_e1

parent 95228cd7
...@@ -12,14 +12,16 @@ template <index_t BlockSize, ...@@ -12,14 +12,16 @@ template <index_t BlockSize,
typename BlockMatrixA, typename BlockMatrixA,
typename BlockMatrixB, typename BlockMatrixB,
typename ThreadMatrixC, typename ThreadMatrixC,
index_t KPerThread,
index_t HPerThread,
index_t WPerThread,
index_t EPerThreadLoop, index_t EPerThreadLoop,
index_t ThreadGemmADataPerRead_K, index_t ThreadGemmADataPerRead_K,
index_t ThreadGemmBDataPerRead_W> index_t ThreadGemmBDataPerRead_W>
struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3 struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
{ {
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
struct MatrixIndex struct MatrixIndex
{ {
index_t k; index_t k;
...@@ -27,6 +29,10 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3 ...@@ -27,6 +29,10 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
index_t w; index_t w;
}; };
static constexpr auto KPerThread = ThreadMatrixC{}.GetLength(I0);
static constexpr auto HPerThread = ThreadMatrixC{}.GetLength(I2);
static constexpr auto WPerThread = ThreadMatrixC{}.GetLength(I3);
// HACK: fix this @Jing Zhang // HACK: fix this @Jing Zhang
static constexpr index_t KPerThreadSubC = 4; static constexpr index_t KPerThreadSubC = 4;
...@@ -39,16 +45,6 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3 ...@@ -39,16 +45,6 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
static constexpr auto c_thread_mtx_ = make_naive_tensor_descriptor_packed(make_tuple( static constexpr auto c_thread_mtx_ = make_naive_tensor_descriptor_packed(make_tuple(
Number<KPerThreadSubC>{}, Number<1>{}, Number<HPerThread>{}, Number<WPerThread>{})); Number<KPerThreadSubC>{}, Number<1>{}, Number<HPerThread>{}, Number<WPerThread>{}));
using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB,
FloatAB,
BlockMatrixA,
decltype(a_thread_mtx_),
Sequence<EPerThreadLoop, KPerThreadSubC>,
Sequence<0, 1>,
1,
ThreadGemmADataPerRead_K,
1>;
__device__ BlockwiseGemmDlops_km_kn_m0m1n0n1_v3() __device__ BlockwiseGemmDlops_km_kn_m0m1n0n1_v3()
: c_thread_begin_mtx_idx_{GetBeginOfThreadMatrixC(get_thread_local_1d_id())}, : c_thread_begin_mtx_idx_{GetBeginOfThreadMatrixC(get_thread_local_1d_id())},
a_thread_copy_{make_tuple(0, c_thread_begin_mtx_idx_.k * KPerThread)} a_thread_copy_{make_tuple(0, c_thread_begin_mtx_idx_.k * KPerThread)}
...@@ -58,11 +54,6 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3 ...@@ -58,11 +54,6 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
ThreadMatrixC::IsKnownAtCompileTime(), ThreadMatrixC::IsKnownAtCompileTime(),
"wrong! Desc should be known at compile-time"); "wrong! Desc should be known at compile-time");
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
static_assert(BlockMatrixA{}.GetLength(I0) == BlockMatrixB{}.GetLength(I0), static_assert(BlockMatrixA{}.GetLength(I0) == BlockMatrixB{}.GetLength(I0),
"wrong! K dimension not consistent\n"); "wrong! K dimension not consistent\n");
...@@ -88,11 +79,11 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3 ...@@ -88,11 +79,11 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
__device__ static MatrixIndex GetBeginOfThreadMatrixC(index_t thread_id) __device__ static MatrixIndex GetBeginOfThreadMatrixC(index_t thread_id)
{ {
constexpr index_t H = BlockMatrixB{}.GetLength(Number<2>{}); constexpr index_t HPerBlock = BlockMatrixB{}.GetLength(Number<2>{});
constexpr index_t W = BlockMatrixB{}.GetLength(Number<3>{}); constexpr index_t WPerBlock = BlockMatrixB{}.GetLength(Number<3>{});
constexpr auto num_w_threads = W / WPerThread; constexpr auto num_w_threads = WPerBlock / WPerThread;
constexpr auto num_h_threads = H / HPerThread; constexpr auto num_h_threads = HPerBlock / HPerThread;
constexpr auto num_hw_threads = num_w_threads * num_h_threads; constexpr auto num_hw_threads = num_w_threads * num_h_threads;
index_t k_thread_id = thread_id / num_hw_threads; index_t k_thread_id = thread_id / num_hw_threads;
...@@ -115,8 +106,6 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3 ...@@ -115,8 +106,6 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
is_same<remove_cvref_t<typename CThreadBuffer::type>, remove_cvref_t<FloatC>>::value && is_same<remove_cvref_t<typename CThreadBuffer::type>, remove_cvref_t<FloatC>>::value &&
"wrong! inconsistent type"); "wrong! inconsistent type");
constexpr auto I0 = Number<0>{};
constexpr auto a_block_mtx = BlockMatrixA{}; constexpr auto a_block_mtx = BlockMatrixA{};
constexpr auto EPerBlock = a_block_mtx.GetLength(I0); constexpr auto EPerBlock = a_block_mtx.GetLength(I0);
...@@ -166,8 +155,7 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3 ...@@ -166,8 +155,7 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
} }
template <typename ABlockSliceMoveStepIdx> template <typename ABlockSliceMoveStepIdx>
__device__ void MoveASliceWindow(const BlockMatrixA&, __device__ void MoveABlockSliceWindow(const ABlockSliceMoveStepIdx& a_block_slice_move_step_idx)
const ABlockSliceMoveStepIdx& a_block_slice_move_step_idx)
{ {
a_thread_copy_.MoveSrcSliceWindow(BlockMatrixA{}, a_block_slice_move_step_idx); a_thread_copy_.MoveSrcSliceWindow(BlockMatrixA{}, a_block_slice_move_step_idx);
} }
...@@ -175,6 +163,16 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3 ...@@ -175,6 +163,16 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
private: private:
MatrixIndex c_thread_begin_mtx_idx_; MatrixIndex c_thread_begin_mtx_idx_;
using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB,
FloatAB,
BlockMatrixA,
decltype(a_thread_mtx_),
Sequence<EPerThreadLoop, KPerThreadSubC>,
Sequence<0, 1>,
1,
ThreadGemmADataPerRead_K,
1>;
AThreadCopy a_thread_copy_; AThreadCopy a_thread_copy_;
}; };
......
...@@ -15,24 +15,24 @@ namespace ck { ...@@ -15,24 +15,24 @@ namespace ck {
template <typename GridwiseGemm, template <typename GridwiseGemm,
typename FloatAB, typename FloatAB,
typename FloatC, typename FloatC,
typename AEKGridDesc, typename AGridDesc_E0_E1_K,
typename BENHoWoGridDesc, typename BGridDesc_E_N_Ho_Wo,
typename CKNHoWoGridDesc, typename CGridDesc_K_N_Ho_Wo,
typename CBlockIdToKNHoWoBlockClusterAdaptor, typename CBlockIdToBlockClusterAdaptor_K_N_Ho_Wo,
bool HasMainKBlockLoop, bool HasMainKBlockLoop,
bool HasDoubleTailKBlockLoop> bool HasDoubleTailKBlockLoop>
__global__ void __global__ void
#if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif #endif
kernel_gemm_dlops_v2( kernel_gemm_dlops_v2(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid, const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid, FloatC* __restrict__ p_c_grid,
const AEKGridDesc a_e_k_grid_desc, const AGridDesc_E0_E1_K a_e0_e1_k_grid_desc,
const BENHoWoGridDesc b_e_n_ho_wo_grid_desc, const BGridDesc_E_N_Ho_Wo b_e0_e1_n_ho_wo_grid_desc,
const CKNHoWoGridDesc c_k_n_ho_wo_grid_desc, const CGridDesc_K_N_Ho_Wo c_k_n_ho_wo_grid_desc,
const CBlockIdToKNHoWoBlockClusterAdaptor c_blockid_to_k_n_ho_wo_block_cluster_adaptor) const CBlockIdToBlockClusterAdaptor_K_N_Ho_Wo
c_blockid_to_k_n_ho_wo_block_cluster_adaptor)
{ {
constexpr index_t shared_block_size = constexpr index_t shared_block_size =
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
...@@ -43,8 +43,8 @@ __global__ void ...@@ -43,8 +43,8 @@ __global__ void
p_b_grid, p_b_grid,
p_c_grid, p_c_grid,
p_shared_block, p_shared_block,
a_e_k_grid_desc, a_e0_e1_k_grid_desc,
b_e_n_ho_wo_grid_desc, b_e0_e1_n_ho_wo_grid_desc,
c_k_n_ho_wo_grid_desc, c_k_n_ho_wo_grid_desc,
integral_constant<bool, HasMainKBlockLoop>{}, integral_constant<bool, HasMainKBlockLoop>{},
integral_constant<bool, HasDoubleTailKBlockLoop>{}); integral_constant<bool, HasDoubleTailKBlockLoop>{});
...@@ -56,10 +56,10 @@ __global__ void ...@@ -56,10 +56,10 @@ __global__ void
template <typename GridwiseGemm, template <typename GridwiseGemm,
typename FloatAB, typename FloatAB,
typename FloatC, typename FloatC,
typename AEKGridDesc, typename AGridDesc_E0_E1_K,
typename BENHoWoGridDesc, typename BGridDesc_E_N_Ho_Wo,
typename CKNHoWoGridDesc, typename CGridDesc_K_N_Ho_Wo,
typename CBlockIdToKNHoWoBlockClusterAdaptor, typename CBlockIdToBlockClusterAdaptor_K_N_Ho_Wo,
bool HasMainKBlockLoop, bool HasMainKBlockLoop,
bool HasDoubleTailKBlockLoop> bool HasDoubleTailKBlockLoop>
__global__ void __global__ void
...@@ -69,19 +69,19 @@ __global__ void ...@@ -69,19 +69,19 @@ __global__ void
kernel_gemm_dlops_v2(const FloatAB* __restrict__ p_a_grid, kernel_gemm_dlops_v2(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid, const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid, FloatC* __restrict__ p_c_grid,
const void CONSTANT* p_a_e_k_grid_desc, const void CONSTANT* p_a_e0_e1_k_grid_desc,
const void CONSTANT* p_b_e_n_ho_wo_grid_desc, const void CONSTANT* p_b_e0_e1_n_ho_wo_grid_desc,
const void CONSTANT* p_c_k_n_ho_wo_grid_desc, const void CONSTANT* p_c_k_n_ho_wo_grid_desc,
const void CONSTANT* p_c_blockid_to_k_n_ho_wo_block_cluster_adaptor) const void CONSTANT* p_c_blockid_to_k_n_ho_wo_block_cluster_adaptor)
{ {
// first cast the void CONSTANT* pointer to void* // first cast the void CONSTANT* pointer to void*
// second cast void* to Desc* // second cast void* to Desc*
// the copy constructor of tensor descriptor doesn't take address_space(4) // the copy constructor of tensor descriptor doesn't take address_space(4)
const auto a_e_k_grid_desc = *reinterpret_cast<const AEKGridDesc*>( const auto a_e0_e1_k_grid_desc = *reinterpret_cast<const AGridDesc_E0_E1_K*>(
cast_pointer_to_generic_address_space(p_a_e_k_grid_desc)); cast_pointer_to_generic_address_space(p_a_e0_e1_k_grid_desc));
const auto b_e_n_ho_wo_grid_desc = *reinterpret_cast<const BENHoWoGridDesc*>( const auto b_e0_e1_n_ho_wo_grid_desc = *reinterpret_cast<const BGridDesc_E_N_Ho_Wo*>(
cast_pointer_to_generic_address_space(p_b_e_n_ho_wo_grid_desc)); cast_pointer_to_generic_address_space(p_b_e0_e1_n_ho_wo_grid_desc));
const auto c_k_n_ho_wo_grid_desc = *reinterpret_cast<const CKNHoWoGridDesc*>( const auto c_k_n_ho_wo_grid_desc = *reinterpret_cast<const CGridDesc_K_N_Ho_Wo*>(
cast_pointer_to_generic_address_space(p_c_k_n_ho_wo_grid_desc)); cast_pointer_to_generic_address_space(p_c_k_n_ho_wo_grid_desc));
constexpr index_t shared_block_size = constexpr index_t shared_block_size =
...@@ -93,8 +93,8 @@ __global__ void ...@@ -93,8 +93,8 @@ __global__ void
p_b_grid, p_b_grid,
p_c_grid, p_c_grid,
p_shared_block, p_shared_block,
a_e_k_grid_desc, a_e0_e1_k_grid_desc,
b_e_n_ho_wo_grid_desc, b_e0_e1_n_ho_wo_grid_desc,
c_k_n_ho_wo_grid_desc, c_k_n_ho_wo_grid_desc,
integral_constant<bool, HasMainKBlockLoop>{}, integral_constant<bool, HasMainKBlockLoop>{},
integral_constant<bool, HasDoubleTailKBlockLoop>{}); integral_constant<bool, HasDoubleTailKBlockLoop>{});
...@@ -106,9 +106,9 @@ template <index_t BlockSize, ...@@ -106,9 +106,9 @@ template <index_t BlockSize,
typename FloatAcc, typename FloatAcc,
typename FloatC, typename FloatC,
InMemoryDataOperationEnum_t CGlobalMemoryDataOperation, InMemoryDataOperationEnum_t CGlobalMemoryDataOperation,
typename AGlobalDesc, typename AGlobalDesc_E0_E1_K,
typename BGlobalDesc, typename BGlobalDesc_E0_E1_N_Ho_Wo,
typename CGlobalDesc, typename CGlobalDesc_K_N_Ho_Wo,
index_t KPerBlock, index_t KPerBlock,
index_t HoPerBlock, index_t HoPerBlock,
index_t WoPerBlock, index_t WoPerBlock,
...@@ -148,12 +148,12 @@ struct GridwiseGemmDlops_km_kn_mn_v3 ...@@ -148,12 +148,12 @@ struct GridwiseGemmDlops_km_kn_mn_v3
// A matrix in LDS memory, dst of blockwise copy // A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment // be careful of LDS alignment
constexpr auto a_e_k_desc = make_naive_tensor_descriptor_aligned( constexpr auto a_e1_k_block_desc = make_naive_tensor_descriptor_aligned(
make_tuple(Number<E>{}, Number<KPerBlock>{}), max_lds_align); make_tuple(Number<E>{}, Number<KPerBlock>{}), max_lds_align);
// LDS allocation for A and B: be careful of alignment // LDS allocation for A and B: be careful of alignment
constexpr auto a_block_space_size = constexpr auto a_block_space_size =
math::integer_least_multiple(a_e_k_desc.GetElementSpaceSize(), max_lds_align); math::integer_least_multiple(a_e1_k_block_desc.GetElementSpaceSize(), max_lds_align);
return a_block_space_size * sizeof(FloatAB); return a_block_space_size * sizeof(FloatAB);
} }
...@@ -163,9 +163,9 @@ struct GridwiseGemmDlops_km_kn_mn_v3 ...@@ -163,9 +163,9 @@ struct GridwiseGemmDlops_km_kn_mn_v3
const FloatAB* __restrict__ p_b_global, const FloatAB* __restrict__ p_b_global,
FloatC* __restrict__ p_c_global, FloatC* __restrict__ p_c_global,
FloatAB* __restrict__ p_shared_block, FloatAB* __restrict__ p_shared_block,
const AGlobalDesc& a_e_k_global_desc, const AGlobalDesc_E0_E1_K& a_e0_e1_k_global_desc,
const BGlobalDesc& b_e_n_ho_wo_global_desc, const BGlobalDesc_E0_E1_N_Ho_Wo& b_e0_e1_n_ho_wo_global_desc,
const CGlobalDesc& c_k_n_ho_wo_global_desc, const CGlobalDesc_K_N_Ho_Wo& c_k_n_ho_wo_global_desc,
integral_constant<bool, HasMainKBlockLoop>, integral_constant<bool, HasMainKBlockLoop>,
integral_constant<bool, HasDoubleTailKBlockLoop>) integral_constant<bool, HasDoubleTailKBlockLoop>)
{ {
...@@ -175,18 +175,18 @@ struct GridwiseGemmDlops_km_kn_mn_v3 ...@@ -175,18 +175,18 @@ struct GridwiseGemmDlops_km_kn_mn_v3
constexpr auto I3 = Number<3>{}; constexpr auto I3 = Number<3>{};
const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>( const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_a_global, a_e_k_global_desc.GetElementSpaceSize()); p_a_global, a_e0_e1_k_global_desc.GetElementSpaceSize());
const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>( const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_b_global, b_e_n_ho_wo_global_desc.GetElementSpaceSize()); p_b_global, b_e0_e1_n_ho_wo_global_desc.GetElementSpaceSize());
auto c_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>( auto c_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_c_global, c_k_n_ho_wo_global_desc.GetElementSpaceSize()); p_c_global, c_k_n_ho_wo_global_desc.GetElementSpaceSize());
// const auto E = a_e_k_global_desc.GetLength(I0); // const auto E = a_e0_e1_k_global_desc.GetLength(I0);
// const auto K = a_e_k_global_desc.GetLength(I1); // const auto K = a_e0_e1_k_global_desc.GetLength(I1);
// const auto N = b_e_n_ho_wo_global_desc.GetLength(I1); // const auto N = b_e0_e1_n_ho_wo_global_desc.GetLength(I1);
const auto Ho = b_e_n_ho_wo_global_desc.GetLength(I2); const auto Ho = b_e0_e1_n_ho_wo_global_desc.GetLength(I2);
const auto Wo = b_e_n_ho_wo_global_desc.GetLength(I3); const auto Wo = b_e0_e1_n_ho_wo_global_desc.GetLength(I3);
// divide block work by [M, N] // divide block work by [M, N]
#if 0 #if 0
...@@ -220,15 +220,15 @@ struct GridwiseGemmDlops_km_kn_mn_v3 ...@@ -220,15 +220,15 @@ struct GridwiseGemmDlops_km_kn_mn_v3
// A matrix in LDS memory, dst of blockwise copy // A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment // be careful of LDS alignment
constexpr auto a_e_k_block_desc = make_naive_tensor_descriptor_aligned( constexpr auto a_e1_k_block_desc = make_naive_tensor_descriptor_aligned(
make_tuple(Number<EPerBlock>{}, Number<KPerBlock>{}), max_lds_align);
constexpr auto a_e_k_desc = make_naive_tensor_descriptor_aligned(
make_tuple(Number<E>{}, Number<KPerBlock>{}), max_lds_align); make_tuple(Number<E>{}, Number<KPerBlock>{}), max_lds_align);
constexpr auto a_e2_k_block_desc = make_naive_tensor_descriptor_aligned(
make_tuple(Number<EPerBlock>{}, Number<KPerBlock>{}), max_lds_align);
// B matrix in LDS memory, dst of blockwise copy // B matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment // be careful of LDS alignment
constexpr auto b_e_n_ho_wo_block_desc = make_naive_tensor_descriptor_packed(make_tuple( constexpr auto b_e2_n_ho_wo_block_desc = make_naive_tensor_descriptor_packed(make_tuple(
Number<EPerBlock>{}, Number<1>{}, Number<HoPerBlock>{}, Number<WoPerBlock>{})); Number<EPerBlock>{}, Number<1>{}, Number<HoPerBlock>{}, Number<WoPerBlock>{}));
// c_thread_mtx definition: this is a mess // c_thread_mtx definition: this is a mess
...@@ -240,12 +240,9 @@ struct GridwiseGemmDlops_km_kn_mn_v3 ...@@ -240,12 +240,9 @@ struct GridwiseGemmDlops_km_kn_mn_v3
BlockwiseGemmDlops_km_kn_m0m1n0n1_v3<BlockSize, BlockwiseGemmDlops_km_kn_m0m1n0n1_v3<BlockSize,
FloatAB, FloatAB,
FloatAcc, FloatAcc,
decltype(a_e_k_block_desc), decltype(a_e2_k_block_desc),
decltype(b_e_n_ho_wo_block_desc), decltype(b_e2_n_ho_wo_block_desc),
decltype(c_k_n_ho_wo_thread_desc), decltype(c_k_n_ho_wo_thread_desc),
KPerThread,
HoPerThread,
WoPerThread,
EPerThread, EPerThread,
ABlockTransferSrcScalarPerVector, ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_K>{}; ABlockTransferDstScalarPerVector_K>{};
...@@ -275,8 +272,8 @@ struct GridwiseGemmDlops_km_kn_mn_v3 ...@@ -275,8 +272,8 @@ struct GridwiseGemmDlops_km_kn_mn_v3
ABlockTransferThreadClusterArrangeOrder, ABlockTransferThreadClusterArrangeOrder,
FloatAB, FloatAB,
FloatAB, FloatAB,
decltype(a_e_k_global_desc), decltype(a_e0_e1_k_global_desc),
decltype(a_e_k_desc), decltype(a_e1_k_block_desc),
ABlockTransferSrcAccessOrder, ABlockTransferSrcAccessOrder,
Sequence<0, 1>, Sequence<0, 1>,
ABlockTransferSrcVectorDim, ABlockTransferSrcVectorDim,
...@@ -286,30 +283,30 @@ struct GridwiseGemmDlops_km_kn_mn_v3 ...@@ -286,30 +283,30 @@ struct GridwiseGemmDlops_km_kn_mn_v3
1, 1,
1, 1,
AThreadTransferSrcResetCoordinateAfterRun, AThreadTransferSrcResetCoordinateAfterRun,
true>(a_e_k_global_desc, true>(a_e0_e1_k_global_desc,
make_multi_index(0, k_block_data_on_global), make_multi_index(0, k_block_data_on_global),
a_e_k_desc, a_e1_k_block_desc,
make_multi_index(0, 0)); make_multi_index(0, 0));
constexpr auto b_e_n_ho_wo_thread_desc = make_naive_tensor_descriptor_packed(make_tuple( constexpr auto b_e2_n_ho_wo_thread_desc = make_naive_tensor_descriptor_packed(make_tuple(
Number<EPerBlock>{}, Number<1>{}, Number<HoPerThread>{}, Number<WoPerThread>{})); Number<EPerBlock>{}, Number<1>{}, Number<HoPerThread>{}, Number<WoPerThread>{}));
auto b_threadwise_transfer = auto b_threadwise_transfer =
ThreadwiseTensorSliceTransfer_v2<FloatAB, ThreadwiseTensorSliceTransfer_v2<FloatAB,
FloatAB, FloatAB,
decltype(b_e_n_ho_wo_global_desc), decltype(b_e0_e1_n_ho_wo_global_desc),
decltype(b_e_n_ho_wo_thread_desc), decltype(b_e2_n_ho_wo_thread_desc),
Sequence<EPerBlock, 1, HoPerThread, WoPerThread>, Sequence<EPerBlock, 1, HoPerThread, WoPerThread>,
BBlockTransferSrcAccessOrder, BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim, BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector,
1, 1,
true>( true>(
b_e_n_ho_wo_global_desc, b_e0_e1_n_ho_wo_global_desc,
make_multi_index(0, 0, ho_thread_data_on_global, wo_thread_data_on_global)); make_multi_index(0, 0, ho_thread_data_on_global, wo_thread_data_on_global));
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>( auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
p_shared_block, a_e_k_desc.GetElementSpaceSize()); p_shared_block, a_e1_k_block_desc.GetElementSpaceSize());
// register allocation for output // register allocation for output
StaticBuffer<AddressSpaceEnum_t::Vgpr, StaticBuffer<AddressSpaceEnum_t::Vgpr,
...@@ -327,34 +324,29 @@ struct GridwiseGemmDlops_km_kn_mn_v3 ...@@ -327,34 +324,29 @@ struct GridwiseGemmDlops_km_kn_mn_v3
constexpr auto b_thread_slice_copy_step = make_multi_index(EPerBlock, 0, 0, 0); constexpr auto b_thread_slice_copy_step = make_multi_index(EPerBlock, 0, 0, 0);
// hack to control index calculation when iterating over A and B matrix for threadwise copy // hack to control index calculation when iterating over A and B matrix for threadwise copy
constexpr auto a_e_k_global_step_hacks = AGlobalStepHacks{}; constexpr auto a_e0_e1_k_global_step_hacks = AGlobalStepHacks{};
constexpr auto b_e_n_ho_wo_global_step_hacks = BGlobalStepHacks{}; constexpr auto b_e0_e1_n_ho_wo_global_step_hacks = BGlobalStepHacks{};
// hack to control index calculation when move slice window for A and B matrix for
// threadwise copy
// constexpr auto a_e_k_global_move_slice_window_step_hack =
// AGlobalMoveSliceWindowStepHacks{}; constexpr auto
// b_e_n_ho_wo_global_move_slice_window_step_hack = BGlobalMoveSliceWindowStepHacks{};
// double register buffer for b // double register buffer for b
StaticBuffer<AddressSpaceEnum_t::Vgpr, StaticBuffer<AddressSpaceEnum_t::Vgpr,
FloatAB, FloatAB,
b_e_n_ho_wo_thread_desc.GetElementSpaceSize(), b_e2_n_ho_wo_thread_desc.GetElementSpaceSize(),
true> true>
b_thread_even_buf, b_thread_odd_buf; b_thread_even_buf, b_thread_odd_buf;
// LDS double buffer: preload data // LDS double buffer: preload data
{ {
a_blockwise_copy.RunRead(a_e_k_global_desc, a_global_buf, a_e_k_global_step_hacks); a_blockwise_copy.RunRead(
a_e0_e1_k_global_desc, a_global_buf, a_e0_e1_k_global_step_hacks);
b_threadwise_transfer.Run(b_e_n_ho_wo_global_desc, b_threadwise_transfer.Run(b_e0_e1_n_ho_wo_global_desc,
b_global_buf, b_global_buf,
b_e_n_ho_wo_thread_desc, b_e2_n_ho_wo_thread_desc,
make_tuple(I0, I0, I0, I0), make_tuple(I0, I0, I0, I0),
b_thread_even_buf, b_thread_even_buf,
b_e_n_ho_wo_global_step_hacks); b_e0_e1_n_ho_wo_global_step_hacks);
a_blockwise_copy.RunWrite(a_e_k_desc, a_block_buf); a_blockwise_copy.RunWrite(a_e1_k_block_desc, a_block_buf);
} }
__syncthreads(); __syncthreads();
...@@ -368,36 +360,36 @@ struct GridwiseGemmDlops_km_kn_mn_v3 ...@@ -368,36 +360,36 @@ struct GridwiseGemmDlops_km_kn_mn_v3
do do
{ {
// even iteration // even iteration
b_threadwise_transfer.MoveSrcSliceWindow(b_e_n_ho_wo_global_desc, b_threadwise_transfer.MoveSrcSliceWindow(b_e0_e1_n_ho_wo_global_desc,
b_thread_slice_copy_step); b_thread_slice_copy_step);
b_threadwise_transfer.Run(b_e_n_ho_wo_global_desc, b_threadwise_transfer.Run(b_e0_e1_n_ho_wo_global_desc,
b_global_buf, b_global_buf,
b_e_n_ho_wo_thread_desc, b_e2_n_ho_wo_thread_desc,
make_tuple(I0, I0, I0, I0), make_tuple(I0, I0, I0, I0),
b_thread_odd_buf, b_thread_odd_buf,
b_e_n_ho_wo_global_step_hacks); b_e0_e1_n_ho_wo_global_step_hacks);
// LDS double buffer: GEMM on current data // LDS double buffer: GEMM on current data
// TODO: @Zhang Jing: blockwise gemm should be able to move slice window // TODO: @Zhang Jing: blockwise gemm should be able to move slice window
blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf); blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf);
blockwise_gemm.MoveASliceWindow(a_e_k_block_desc, make_tuple(EPerBlock, 0)); blockwise_gemm.MoveABlockSliceWindow(make_tuple(EPerBlock, 0));
b_threadwise_transfer.MoveSrcSliceWindow(b_e_n_ho_wo_global_desc, b_threadwise_transfer.MoveSrcSliceWindow(b_e0_e1_n_ho_wo_global_desc,
b_thread_slice_copy_step); b_thread_slice_copy_step);
b_threadwise_transfer.Run(b_e_n_ho_wo_global_desc, b_threadwise_transfer.Run(b_e0_e1_n_ho_wo_global_desc,
b_global_buf, b_global_buf,
b_e_n_ho_wo_thread_desc, b_e2_n_ho_wo_thread_desc,
make_tuple(I0, I0, I0, I0), make_tuple(I0, I0, I0, I0),
b_thread_even_buf, b_thread_even_buf,
b_e_n_ho_wo_global_step_hacks); b_e0_e1_n_ho_wo_global_step_hacks);
// LDS double buffer: GEMM on current data // LDS double buffer: GEMM on current data
blockwise_gemm.Run(a_block_buf, b_thread_odd_buf, c_thread_buf); blockwise_gemm.Run(a_block_buf, b_thread_odd_buf, c_thread_buf);
blockwise_gemm.MoveASliceWindow(a_e_k_block_desc, make_tuple(EPerBlock, 0)); blockwise_gemm.MoveABlockSliceWindow(make_tuple(EPerBlock, 0));
e_block_data_begin += 2 * EPerBlock; e_block_data_begin += 2 * EPerBlock;
...@@ -407,20 +399,20 @@ struct GridwiseGemmDlops_km_kn_mn_v3 ...@@ -407,20 +399,20 @@ struct GridwiseGemmDlops_km_kn_mn_v3
// LDS double buffer: tail // LDS double buffer: tail
if constexpr(HasDoubleTailKBlockLoop) // if has 2 iteration left if constexpr(HasDoubleTailKBlockLoop) // if has 2 iteration left
{ {
b_threadwise_transfer.MoveSrcSliceWindow(b_e_n_ho_wo_global_desc, b_threadwise_transfer.MoveSrcSliceWindow(b_e0_e1_n_ho_wo_global_desc,
b_thread_slice_copy_step); b_thread_slice_copy_step);
b_threadwise_transfer.Run(b_e_n_ho_wo_global_desc, b_threadwise_transfer.Run(b_e0_e1_n_ho_wo_global_desc,
b_global_buf, b_global_buf,
b_e_n_ho_wo_thread_desc, b_e2_n_ho_wo_thread_desc,
make_tuple(I0, I0, I0, I0), make_tuple(I0, I0, I0, I0),
b_thread_odd_buf, b_thread_odd_buf,
b_e_n_ho_wo_global_step_hacks); b_e0_e1_n_ho_wo_global_step_hacks);
// LDS double buffer: GEMM on 2nd-last data // LDS double buffer: GEMM on 2nd-last data
blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf); blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf);
blockwise_gemm.MoveASliceWindow(a_e_k_block_desc, make_tuple(EPerBlock, 0)); blockwise_gemm.MoveABlockSliceWindow(make_tuple(EPerBlock, 0));
// LDS double buffer: GEMM on last data // LDS double buffer: GEMM on last data
blockwise_gemm.Run(a_block_buf, b_thread_odd_buf, c_thread_buf); blockwise_gemm.Run(a_block_buf, b_thread_odd_buf, c_thread_buf);
......
...@@ -26,14 +26,6 @@ template <typename FloatA, ...@@ -26,14 +26,6 @@ template <typename FloatA,
struct ThreadwiseGemmDlops_km_kn_mn_v3 struct ThreadwiseGemmDlops_km_kn_mn_v3
{ {
__device__ ThreadwiseGemmDlops_km_kn_mn_v3()
{
static_assert(AThreadDesc_E_K::IsKnownAtCompileTime() &&
BThreadDesc_E_N_Ho_Wo::IsKnownAtCompileTime() &&
CThreadDesc_K_N_Ho_Wo::IsKnownAtCompileTime(),
"wrong! Desc should be known at compile-time");
}
template <typename ABuffer, template <typename ABuffer,
typename AOriginIdx, typename AOriginIdx,
typename BBuffer, typename BBuffer,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment