Commit 1d478b9e authored by ltqin

add tail for pipeline

parent c5c32b4d
@@ -29,8 +29,6 @@ using Col = ck::tensor_layout::gemm::ColumnMajor;
 using PassThrough = ck::tensor_operation::element_wise::PassThrough;
 using ALayout = ck::tensor_layout::gemm::RowMajor;
 using BLayout = ck::tensor_layout::gemm::ColumnMajor;
 using CLayout = ck::tensor_layout::gemm::RowMajor;
@@ -40,30 +38,29 @@ using BElementOp = ck::tensor_operation::element_wise::PassThrough;
 using CElementOp = ck::tensor_operation::element_wise::PassThrough;
 static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
-#define NORMAL_CONFIG 1
 // clang-format off
 using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdlSkipLds
 //###########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
 //###########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
 //###########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
 //###########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
-#if 1
+#if 0
 < F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>;
 using ADataType = ck::half_t;
 using BDataType = ck::half_t;
 using CDataType = ck::half_t;
 using AccDataType = float;
 #else
-< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>;
+< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 16, 16, 4, 1, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 7, 1>;
 using ADataType = float;
 using BDataType = float;
 using CDataType = float;
 using AccDataType = float;
 #endif
 // clang-format on
 using ReferenceGemmInstance = ck::tensor_operation::host::
     ReferenceGemm<ADataType, BDataType, CDataType, AElementOp, BElementOp, CElementOp>;
 template <typename DataType>
 std::ostream& show_2d_matrix(std::ostream& os, Tensor<DataType>& matrix)
@@ -88,10 +85,10 @@ int main(int argc, char* argv[])
     int nrepeat = 5;
     // GEMM shape
-#if NORMAL_CONFIG
+#if 0
     ck::index_t M = 256;
     ck::index_t N = 4096;
-    ck::index_t K = 64;
+    ck::index_t K = 32;
     ck::index_t StrideA = 64;
     ck::index_t StrideB = 64;
@@ -99,10 +96,10 @@ int main(int argc, char* argv[])
 #else
     ck::index_t M = 16;
     ck::index_t N = 16;
-    ck::index_t K = 16;
-    ck::index_t StrideA = 16;
-    ck::index_t StrideB = 16;
+    ck::index_t K = 8;
+    ck::index_t StrideA = 8;
+    ck::index_t StrideB = 8;
     ck::index_t StrideC = 16;
 #endif
@@ -234,7 +231,7 @@ int main(int argc, char* argv[])
     ref_invoker.Run(ref_argument);
-#if 0
+#if 1
     {
         show_2d_matrix(std::cout << "a : ", a_m_k) << std::endl;
         show_2d_matrix(std::cout << "b: ", b_k_n) << std::endl;
...
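For orientation, the example now runs the `#else` branch above: an F32 problem with M = N = 16, K = 8, row-major A, column-major B, row-major C, and matrix printing enabled. A minimal standalone analogue of what `ReferenceGemm` computes for that layout combination (not part of the commit; it only assumes CK's convention that the strides are the leading dimensions of the three matrices):

```cpp
#include <cstdio>
#include <vector>

int main()
{
    const int M = 16, N = 16, K = 8;
    const int StrideA = 8, StrideB = 8, StrideC = 16;

    std::vector<float> a(M * StrideA, 1.0f); // row-major A: A(m,k) at m * StrideA + k
    std::vector<float> b(N * StrideB, 2.0f); // col-major B: B(k,n) at n * StrideB + k
    std::vector<float> c(M * StrideC, 0.0f); // row-major C: C(m,n) at m * StrideC + n

    for(int m = 0; m < M; ++m)
        for(int n = 0; n < N; ++n)
        {
            float acc = 0.f;
            for(int k = 0; k < K; ++k)
                acc += a[m * StrideA + k] * b[n * StrideB + k];
            c[m * StrideC + n] = acc;
        }

    printf("c(0,0) = %f (expect %f)\n", c[0][0 * 1], 2.0f * K);
    return 0;
}
```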
@@ -12,8 +12,10 @@ template <index_t BlockSize,
           typename FloatAB,
           typename FloatAcc,
           typename AK0MK1BlockDesc,
-          typename BK0NK1BlockDesc,
           typename BK0K0BN0N1N2N3K1BlockDesc,
+          index_t MPerBlock,
+          index_t NPerBlock,
+          index_t K0PerBlock,
           index_t MPerXDL,
           index_t NPerXDL,
           index_t MRepeat,
@@ -28,21 +30,15 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1r1
     static constexpr index_t WaveSize = 64;
-    static constexpr index_t MPerBlock = AK0MK1BlockDesc{}.GetLength(I1);
-    static constexpr index_t NPerBlock = BK0NK1BlockDesc{}.GetLength(I1);
-    static constexpr index_t KPerBlock =
-        BK0NK1BlockDesc{}.GetLength(I0) * BK0NK1BlockDesc{}.GetLength(I2);
+    static constexpr index_t KPerBlock = K0PerBlock * KPack;
     static constexpr index_t A_K0 = AK0MK1BlockDesc{}.GetLength(I0);
-    static constexpr index_t B_K0 = BK0NK1BlockDesc{}.GetLength(I0);
     static constexpr index_t A_K1 = AK0MK1BlockDesc{}.GetLength(I2);
-    static constexpr index_t B_K1 = BK0NK1BlockDesc{}.GetLength(I2);
     static constexpr auto xdlops_gemm = XdlopsGemm<FloatAB, MPerXDL, NPerXDL, KPack>{};
     static constexpr index_t KPerThread = KPerBlock / xdlops_gemm.K0PerXdlops;
-    static constexpr index_t K0PerThread =
-        BK0NK1BlockDesc{}.GetLength(I0) / xdlops_gemm.K0PerXdlops;
+    static constexpr index_t K0PerThread = K0PerBlock / xdlops_gemm.K0PerXdlops;
     static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL);
     static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL);
@@ -122,7 +118,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1r1
     __host__ __device__ BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1r1()
     {
         static_assert(AK0MK1BlockDesc::IsKnownAtCompileTime() &&
-                          BK0NK1BlockDesc::IsKnownAtCompileTime(),
+                          BK0K0BN0N1N2N3K1BlockDesc::IsKnownAtCompileTime(),
                       "wrong! Desc should be known at compile-time");
         static_assert(BlockSize == MWaves * NWaves * WaveSize,
@@ -235,6 +231,16 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1r1
                        make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}));
     }
+    __device__ void MoveABlockSliceWindow()
+    {
+        a_thread_copy_.MoveSrcSliceWindow(a_block_desc_m0_m1_m2_k,
+                                          make_multi_index(0, 0, 0, K0PerBlock * KPack));
+    }
+
+    __device__ void ResetABlockStartWindow()
+    {
+        a_thread_copy_.SetSrcCoord(CalculateAThreadOriginDataIndex());
+    }
     static constexpr auto a_block_desc_m0_m1_m2_k = MakeABlockDescriptor_M0_M1_M2_K();
     template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
@@ -244,8 +250,6 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1r1
     {
         auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAB>(
             a_thread_desc_.GetElementSpaceSize());
-        // auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAB>(
-        //     b_thread_desc_.GetElementSpaceSize());
         static_for<0, MRepeat, 1>{}([&](auto m0) {
             // read A
...
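The two helpers added above are the blockwise-GEMM side of the new tail: `MoveABlockSliceWindow` advances the per-thread read window over the A tile in LDS by one `K0PerBlock * KPack` chunk, and `ResetABlockStartWindow` snaps it back to the thread's origin so the tail can re-walk the tile. A host-side model of that bookkeeping (names hypothetical; the real code moves a 4-D `(m0, m1, m2, k)` coordinate, collapsed here to the k component that the pipeline cares about):

```cpp
#include <cassert>

// Hypothetical 1-D model of the A read window introduced in this commit.
struct AWindowModel
{
    int origin;   // analogue of CalculateAThreadOriginDataIndex()
    int k_offset; // current window position along K within the LDS tile

    void MoveABlockSliceWindow(int k0_per_block, int kpack)
    {
        k_offset += k0_per_block * kpack; // advance one K0 sub-block
    }
    void ResetABlockStartWindow() { k_offset = 0; } // SetSrcCoord(origin) analogue
    int CurrentOffset() const { return origin + k_offset; }
};

int main()
{
    AWindowModel w{/*origin=*/3, /*k_offset=*/0};
    w.MoveABlockSliceWindow(/*K0PerBlock=*/4, /*KPack=*/1); // after the even sub-block
    assert(w.CurrentOffset() == 7);
    w.ResetABlockStartWindow(); // tail re-reads the tile from its start
    assert(w.CurrentOffset() == 3);
    return 0;
}
```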
@@ -58,17 +58,17 @@ __global__ void
                                                   c_element_op,
                                                   block_2_ctile_map);
 #else
     ignore = p_a_grid;
     ignore = p_b_grid;
     ignore = p_c_grid;
     ignore = a_grid_desc_k0_m_k1;
     ignore = b_grid_desc_k0_n_k1;
     ignore = b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3;
     ignore = c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2;
     ignore = a_element_op;
     ignore = b_element_op;
     ignore = c_element_op;
     ignore = block_2_ctile_map;
 #endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
 }
@@ -146,11 +146,11 @@ __global__ void
                                                 block_id_grp);
 #endif
 #else
     ignore = gemm_desc_;
     ignore = group_count;
     ignore = a_element_op;
     ignore = b_element_op;
     ignore = c_element_op;
 #endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
 }
@@ -203,7 +203,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_lds_v2r3
     static constexpr auto I5 = Number<5>{};
     static constexpr auto I6 = Number<6>{};
     static constexpr auto I7 = Number<7>{};
+#if USING_SKIP_LDS
+    static constexpr auto MultiK0 = 2;
+#else
+    static constexpr auto MultiK0 = 1;
+#endif
     // K1 should be Number<...>
     static constexpr auto K1 = Number<K1Value>{};
@@ -223,13 +227,14 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_lds_v2r3
         if constexpr(ABlockLdsExtraM)
         {
             return make_naive_tensor_descriptor(
-                make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
+                make_tuple(Number<K0PerBlock * MultiK0>{}, Number<MPerBlock>{}, K1),
                 make_tuple(Number<MPerBlock + 1>{} * K1, K1, I1));
         }
         else
         {
             return make_naive_tensor_descriptor_aligned(
-                make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
+                make_tuple(Number<K0PerBlock * MultiK0>{}, Number<MPerBlock>{}, K1),
+                max_lds_align);
         }
     }();
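With `USING_SKIP_LDS`, `MultiK0 = 2` doubles the K0 depth of the A tile in LDS, so two `K0PerBlock` sub-blocks are resident at once and the blockwise GEMM can step its A window between them (and reset it in the tail) without an intervening reload. A back-of-the-envelope size check, not from the commit, using the F16 instance in the example for concreteness (K0PerBlock = 4, MPerBlock = 64, K1 = 8; the padded `ABlockLdsExtraM` variant adds one K1-row per M line):

```cpp
#include <cstdio>

int main()
{
    const int K0PerBlock = 4, MPerBlock = 64, K1 = 8, MultiK0 = 2;

    // Unpadded: (K0 * MultiK0) * M * K1 elements; padded: (K0 * MultiK0) * (M + 1) * K1.
    const int unpadded = K0PerBlock * MultiK0 * MPerBlock * K1;
    const int padded   = K0PerBlock * MultiK0 * (MPerBlock + 1) * K1;

    printf("A tile in LDS: %d elements unpadded, %d padded (%zu / %zu bytes as fp16)\n",
           unpadded, padded, unpadded * sizeof(short), padded * sizeof(short));
    return 0;
}
```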
@@ -352,7 +357,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_lds_v2r3
     // TODO move this function into GEMM-pipeline class
     __host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0)
     {
-        const bool has_main_k0_block_loop = (K0 / (NumPrefetch * K0PerBlock)) > 1;
+        const bool has_main_k0_block_loop = (K0 / (2 * K0PerBlock)) > 1;
         return has_main_k0_block_loop;
     }
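The pipeline now consumes an even/odd pair of `K0PerBlock` sub-blocks per main-loop pass, and the tail consumes the final pair, so a main loop only exists when K0 spans more than one such pair. A worked restatement (assuming, as in other CK gridwise GEMMs, that the do/while is skipped when this predicate is false):

```cpp
#include <cstdio>

// Standalone restatement of CalculateHasMainK0BlockLoop after this commit.
constexpr bool has_main_k0_block_loop(int K0, int K0PerBlock)
{
    return (K0 / (2 * K0PerBlock)) > 1;
}

int main()
{
    // Debug shape from the example: K = 8, K1 = 1 -> K0 = 8, K0PerBlock = 4.
    // One even/odd pair covers all of K0, so only the tail runs.
    printf("K0=8:  main loop? %d\n", has_main_k0_block_loop(8, 4));  // 0
    // A larger K0, e.g. 32: main-loop pairs plus the tail pair.
    printf("K0=32: main loop? %d\n", has_main_k0_block_loop(32, 4)); // 1
    return 0;
}
```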
@@ -530,16 +535,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_lds_v2r3
         // A matrix in LDS memory, dst of blockwise copy
         constexpr auto a_block_desc_k0_m_k1 = GetABlockDescriptor_K0PerBlock_MPerBlock_K1();
-        // B matrix in LDS memory, dst of blockwise copy
-        constexpr auto b_block_desc_k0_n_k1 = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1();
         // A matrix blockwise copy
         auto a_blockwise_copy =
             BlockwiseTensorSliceTransfer_v4r1<BlockSize,
                                               AElementwiseOperation,
                                               ck::tensor_operation::element_wise::PassThrough,
                                               InMemoryDataOperationEnum::Set,
-                                              Sequence<K0PerBlock, MPerBlock, K1>,
+                                              Sequence<K0PerBlock * MultiK0, MPerBlock, K1>,
                                               ABlockTransferThreadClusterLengths_K0_M_K1,
                                               ABlockTransferThreadClusterArrangeOrder,
                                               FloatAB,
@@ -638,8 +640,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_lds_v2r3
                 FloatAB,
                 FloatAcc,
                 decltype(a_block_desc_k0_m_k1),
-                decltype(b_block_desc_k0_n_k1),
                 decltype(b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3),
+                MPerBlock,
+                NPerBlock,
+                K0PerBlock,
                 MPerXDL,
                 NPerXDL,
                 MXdlPerWave,
@@ -653,7 +657,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_lds_v2r3
             static_cast<FloatAB*>(p_shared), a_block_desc_k0_m_k1.GetElementSpaceSize());
         // gridwise GEMM pipeline
-        constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
+        constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock * MultiK0, 0, 0);
         constexpr auto b_thread_slice_copy_step = make_multi_index(1, 0, 0, 0, 0, 0, 0, 0);
         // preload data to regiester and LDS
         {
@@ -682,9 +686,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_lds_v2r3
             index_t i = 0;
             do
             {
-                a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf);
-                block_sync_lds();
+                // block_sync_lds();
                 b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
                                       b_grid_buf,
                                       b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
@@ -693,45 +695,57 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_lds_v2r3
                 blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf);
-                block_sync_lds();
-                a_blockwise_copy.RunWrite(a_block_desc_k0_m_k1, a_block_buf);
                 // move windows
-                a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m_k1,
-                                                    a_block_slice_copy_step);
                 b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
                                                      b_thread_slice_copy_step);
-                a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf);
-                block_sync_lds();
                 b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
                                       b_grid_buf,
                                       b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
                                       make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
                                       b_thread_even_buf);
+                // block_sync_lds();
+                blockwise_gemm.MoveABlockSliceWindow();
                 blockwise_gemm.Run(a_block_buf, b_thread_odd_buf, c_thread_buf);
-                block_sync_lds();
-                a_blockwise_copy.RunWrite(a_block_desc_k0_m_k1, a_block_buf);
-                a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m_k1,
-                                                    a_block_slice_copy_step);
                 b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
                                                      b_thread_slice_copy_step);
                 i += 2;
-            } while(i < (K0BlockMainLoop - 1));
+            } while(i < (K0BlockMainLoop - 2));
         }
         // tail
         {
-            // block_sync_lds();
-            // blockwise_gemm.Run(a_block_buf, b_thread_buf, c_thread_buf);
+            block_sync_lds();
+            // block_sync_lds();
+            b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
+                                  b_grid_buf,
+                                  b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
+                                  make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
+                                  b_thread_odd_buf);
+            blockwise_gemm.ResetABlockStartWindow();
+            blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf);
+            // move windows
+            b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
+                                                 b_thread_slice_copy_step);
+            b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
+                                  b_grid_buf,
+                                  b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
+                                  make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
+                                  b_thread_even_buf);
+            // block_sync_lds();
+            blockwise_gemm.MoveABlockSliceWindow();
+            blockwise_gemm.Run(a_block_buf, b_thread_odd_buf, c_thread_buf);
         }
     }
 #else
     ignore = b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3;
+    // B matrix in LDS memory, dst of blockwise copy
+    constexpr auto b_block_desc_k0_n_k1 = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1();
     // B matrix blockwise copy
     auto b_blockwise_copy =
         BlockwiseTensorSliceTransfer_v4r1<BlockSize,
...
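Putting the loop-bound change and the tail together: each main-loop pass consumes an even/odd pair of K0 sub-blocks, and the new tail consumes the final pair with the A window reset to the front of the doubled LDS tile, so the do/while bound shrinks from `K0BlockMainLoop - 1` to `K0BlockMainLoop - 2`. A host-side sketch of the schedule (assuming `K0BlockMainLoop = K0 / K0PerBlock` and that the loop is skipped when `CalculateHasMainK0BlockLoop` is false):

```cpp
#include <cstdio>

// Hedged host-side model of the pipeline schedule after this commit.
int main()
{
    const int K0BlockMainLoop = 8; // e.g. K0 = 32, K0PerBlock = 4
    int consumed = 0;

    if((K0BlockMainLoop / 2) > 1) // CalculateHasMainK0BlockLoop analogue
    {
        int i = 0;
        do
        {
            consumed += 2; // even + odd sub-block per pass
            i += 2;
        } while(i < (K0BlockMainLoop - 2)); // was (K0BlockMainLoop - 1)
    }

    consumed += 2; // tail: the final even + odd pair
    printf("consumed %d of %d K0 sub-blocks\n", consumed, K0BlockMainLoop);
    return 0;
}
```

With the new bound this prints `consumed 8 of 8`; keeping the old `- 1` bound alongside the compute tail would run one pass too many and read past the K extent.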
@@ -1163,6 +1163,10 @@ struct ThreadwiseTensorSliceTransfer_v4
         move_tensor_coordinate(SrcDesc{}, src_ref_coord_, src_slice_move_step_iter);
     }
+    __device__ void SetSrcCoord(const Index& src_ref_idx)
+    {
+        src_ref_coord_ = make_tensor_coordinate(SrcDesc{}, src_ref_idx);
+    }
     private:
     SrcCoord src_ref_coord_;
...