Commit ef6933a2 authored by ltqin's avatar ltqin
Browse files

merge develop and remove hacks

parent 7506342c
......@@ -6,9 +6,8 @@
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "blockwise_gemm_xdlops.hpp"
#include "blockwise_tensor_slice_transfer.hpp"
#include "blockwise_tensor_slice_transfer_v4r1.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
#include "threadwise_tensor_slice_set.hpp"
namespace ck {
......@@ -140,9 +139,8 @@ template <index_t BlockSize,
index_t MPerXDL,
index_t NPerXDL,
index_t K1Value,
index_t MRepeat,
index_t NRepeat,
typename ABlockTransferThreadSliceLengths_K0_M_K1,
index_t MXdlPerWave,
index_t NXdlPerWave,
typename ABlockTransferThreadClusterLengths_K0_M_K1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
......@@ -150,7 +148,7 @@ template <index_t BlockSize,
index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_K1,
bool AThreadTransferSrcResetCoordinateAfterRun,
typename BBlockTransferThreadSliceLengths_K0_N_K1,
bool ABlockLdsExtraM,
typename BBlockTransferThreadClusterLengths_K0_N_K1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
......@@ -158,17 +156,10 @@ template <index_t BlockSize,
index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_K1,
bool BThreadTransferSrcResetCoordinateAfterRun,
bool BBlockLdsExtraN,
typename CThreadTransferSrcDstAccessOrder,
index_t CThreadTransferSrcDstVectorDim,
index_t CThreadTransferDstScalarPerVector,
typename AGridStepHacks,
typename BGridStepHacks,
typename CGridStepHacks,
typename AGridMoveSliceWindowStepHacks,
typename BGridMoveSliceWindowStepHacks,
bool CAccessOrderMRepeatNRepeat,
bool ABlockLdsExtraM,
bool BBlockLdsExtraN>
index_t CThreadTransferDstScalarPerVector>
struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3r1
{
static constexpr auto I0 = Number<0>{};
......@@ -238,8 +229,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3r1
static_assert(is_known_at_compile_time<remove_cv_t<decltype(K1)>>::value,
"wrong! K1 need to be known at compile-time");
static_assert((MPerBlock % (MPerXDL * MRepeat) == 0) &&
(NPerBlock % (NRepeat * NPerXDL)) == 0,
static_assert((MPerBlock % (MPerXDL * MXdlPerWave) == 0) &&
(NPerBlock % (NXdlPerWave * NPerXDL)) == 0,
"Invalid tuning param!");
const auto M = a_grid_desc_k0_m_k1.GetLength(I1);
......@@ -336,8 +327,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3r1
decltype(b_block_desc_k0_n_k1),
MPerXDL,
NPerXDL,
MRepeat,
NRepeat,
MXdlPerWave,
NXdlPerWave,
K1>;
return BlockwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n);
......@@ -452,59 +443,63 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3r1
// A matrix blockwise copy
auto a_blockwise_copy =
BlockwiseTensorSliceTransfer_v4<BlockSize,
AElementwiseOperation,
InMemoryDataOperationEnum_t::Set,
Sequence<K0PerBlock, MPerBlock, K1>,
ABlockTransferThreadSliceLengths_K0_M_K1,
ABlockTransferThreadClusterLengths_K0_M_K1,
ABlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatAB,
decltype(a_grid_desc_k0_m_k1),
decltype(a_block_desc_k0_m_k1),
ABlockTransferSrcAccessOrder,
Sequence<1, 0, 2>,
ABlockTransferSrcVectorDim,
2,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_K1,
1,
1,
AThreadTransferSrcResetCoordinateAfterRun,
true>(a_grid_desc_k0_m_k1,
make_multi_index(0, m_block_data_idx_on_grid, 0),
a_block_desc_k0_m_k1,
make_multi_index(0, 0, 0),
a_element_op);
BlockwiseTensorSliceTransfer_v4r1<BlockSize,
AElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum_t::Set,
Sequence<K0PerBlock, MPerBlock, K1>,
ABlockTransferThreadClusterLengths_K0_M_K1,
ABlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatAB,
decltype(a_grid_desc_k0_m_k1),
decltype(a_block_desc_k0_m_k1),
ABlockTransferSrcAccessOrder,
Sequence<1, 0, 2>,
ABlockTransferSrcVectorDim,
2,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_K1,
1,
1,
AThreadTransferSrcResetCoordinateAfterRun,
true>(
a_grid_desc_k0_m_k1,
make_multi_index(0, m_block_data_idx_on_grid, 0),
a_element_op,
a_block_desc_k0_m_k1,
make_multi_index(0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{});
// B matrix blockwise copy
auto b_blockwise_copy =
BlockwiseTensorSliceTransfer_v4<BlockSize,
BElementwiseOperation,
InMemoryDataOperationEnum_t::Set,
Sequence<K0PerBlock, NPerBlock, K1>,
BBlockTransferThreadSliceLengths_K0_N_K1,
BBlockTransferThreadClusterLengths_K0_N_K1,
BBlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatAB,
decltype(b_grid_desc_k0_n_k1),
decltype(b_block_desc_k0_n_k1),
BBlockTransferSrcAccessOrder,
Sequence<1, 0, 2>,
BBlockTransferSrcVectorDim,
2,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_K1,
1,
1,
BThreadTransferSrcResetCoordinateAfterRun,
true>(b_grid_desc_k0_n_k1,
make_multi_index(0, n_block_data_idx_on_grid, 0),
b_block_desc_k0_n_k1,
make_multi_index(0, 0, 0),
b_element_op);
BlockwiseTensorSliceTransfer_v4r1<BlockSize,
BElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum_t::Set,
Sequence<K0PerBlock, NPerBlock, K1>,
BBlockTransferThreadClusterLengths_K0_N_K1,
BBlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatAB,
decltype(b_grid_desc_k0_n_k1),
decltype(b_block_desc_k0_n_k1),
BBlockTransferSrcAccessOrder,
Sequence<1, 0, 2>,
BBlockTransferSrcVectorDim,
2,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_K1,
1,
1,
BThreadTransferSrcResetCoordinateAfterRun,
true>(
b_grid_desc_k0_n_k1,
make_multi_index(0, n_block_data_idx_on_grid, 0),
b_element_op,
b_block_desc_k0_n_k1,
make_multi_index(0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{});
// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
......@@ -522,8 +517,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3r1
decltype(b_block_desc_k0_n_k1),
MPerXDL,
NPerXDL,
MRepeat,
NRepeat,
MXdlPerWave,
NXdlPerWave,
K1>{};
auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
......@@ -541,15 +536,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3r1
constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
// hack to control index calculation when iterating over A and B matrix for threadwise copy
constexpr auto a_k0_m_k1_grid_step_hacks = AGridStepHacks{};
constexpr auto b_k0_n_k1_grid_step_hacks = BGridStepHacks{};
// hack to control index calculation when move slice window for A and B matrix for
// threadwise copy
constexpr auto a_k0_m_k1_grid_move_slice_window_step_hack = AGridMoveSliceWindowStepHacks{};
constexpr auto b_k0_n_k1_grid_move_slice_window_step_hack = BGridMoveSliceWindowStepHacks{};
auto a_block_even_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
p_a_block_double, a_block_desc_k0_m_k1.GetElementSpaceSize());
auto b_block_even_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
......@@ -562,8 +548,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3r1
// preload data into LDS
{
a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf, a_k0_m_k1_grid_step_hacks);
b_blockwise_copy.RunRead(b_grid_desc_k0_n_k1, b_grid_buf, b_k0_n_k1_grid_step_hacks);
a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf);
b_blockwise_copy.RunRead(b_grid_desc_k0_n_k1, b_grid_buf);
a_blockwise_copy.RunWrite(a_block_desc_k0_m_k1, a_block_even_buf);
b_blockwise_copy.RunWrite(b_block_desc_k0_n_k1, b_block_even_buf);
......@@ -579,21 +565,15 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3r1
do
{
// iteration for odd
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m_k1,
a_block_slice_copy_step,
a_k0_m_k1_grid_move_slice_window_step_hack);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n_k1,
b_block_slice_copy_step,
b_k0_n_k1_grid_move_slice_window_step_hack);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m_k1, a_block_slice_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n_k1, b_block_slice_copy_step);
// LDS double buffer: load last data from device mem
a_blockwise_copy.RunRead(
a_grid_desc_k0_m_k1, a_grid_buf, a_k0_m_k1_grid_step_hacks);
a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf);
block_sync_lds();
b_blockwise_copy.RunRead(
b_grid_desc_k0_n_k1, b_grid_buf, b_k0_n_k1_grid_step_hacks);
b_blockwise_copy.RunRead(b_grid_desc_k0_n_k1, b_grid_buf);
// gemm even data
blockwise_gemm.Run(a_block_even_buf, b_block_even_buf, c_thread_buf);
......@@ -602,21 +582,15 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3r1
b_blockwise_copy.RunWrite(b_block_desc_k0_n_k1, b_block_odd_buf);
// iteration for even
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m_k1,
a_block_slice_copy_step,
a_k0_m_k1_grid_move_slice_window_step_hack);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n_k1,
b_block_slice_copy_step,
b_k0_n_k1_grid_move_slice_window_step_hack);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m_k1, a_block_slice_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n_k1, b_block_slice_copy_step);
// LDS double buffer: load last data from device mem
a_blockwise_copy.RunRead(
a_grid_desc_k0_m_k1, a_grid_buf, a_k0_m_k1_grid_step_hacks);
a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf);
block_sync_lds();
b_blockwise_copy.RunRead(
b_grid_desc_k0_n_k1, b_grid_buf, b_k0_n_k1_grid_step_hacks);
b_blockwise_copy.RunRead(b_grid_desc_k0_n_k1, b_grid_buf);
// gemm odd data
blockwise_gemm.Run(a_block_odd_buf, b_block_odd_buf, c_thread_buf);
......@@ -632,17 +606,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3r1
if constexpr(HasDoubleTailKBlockLoop) // if has 2 iteration left
{
// iteration for odd
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m_k1,
a_block_slice_copy_step,
a_k0_m_k1_grid_move_slice_window_step_hack);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n_k1,
b_block_slice_copy_step,
b_k0_n_k1_grid_move_slice_window_step_hack);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m_k1, a_block_slice_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n_k1, b_block_slice_copy_step);
// LDS double buffer: load last data from device mem
a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf, a_k0_m_k1_grid_step_hacks);
a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf);
block_sync_lds();
b_blockwise_copy.RunRead(b_grid_desc_k0_n_k1, b_grid_buf, b_k0_n_k1_grid_step_hacks);
b_blockwise_copy.RunRead(b_grid_desc_k0_n_k1, b_grid_buf);
// gemm even data
blockwise_gemm.Run(a_block_even_buf, b_block_even_buf, c_thread_buf);
......@@ -689,8 +659,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3r1
const index_t n_thread_data_on_grid =
n_block_data_idx_on_grid + c_thread_mtx_on_block[I1];
constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks = CGridStepHacks{};
const auto m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
......@@ -738,8 +706,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3r1
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
c_thread_buf,
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
c_grid_buf,
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks);
c_grid_buf);
}
}
}; // namespace ck
......
......@@ -157,7 +157,6 @@ struct DeviceGemmXdl
K1,
MXdlPerWave,
NXdlPerWave,
ABlockTransferThreadSliceLengths_K0_M_K1,
ABlockTransferThreadClusterLengths_K0_M_K1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
......@@ -165,25 +164,18 @@ struct DeviceGemmXdl
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_K1,
false, // AThreadTransferSrcResetCoordinateAfterRun,
BBlockTransferThreadSliceLengths_K0_N_K1,
ABlockLdsAddExtraM,
BBlockTransferThreadClusterLengths_K0_N_K1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_K1,
false, // BThreadTransferSrcResetCoordinateAfterRun,
false, // BThreadTransferSrcResetCoordinateAfterRun,
BBlockLdsAddExtraN,
Sequence<0, 2, 4, 5, 6, 1, 3, 7>, // CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
decltype(a_k0_m_k1_grid_step_hacks), // AGridStepHacks,
decltype(b_k0_n_k1_grid_step_hacks), // BGridStepHacks,
decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks), // CGridStepHacks,
decltype(a_k0_m_k1_grid_move_slice_window_step_hacks), // AGridMoveSliceWindowStepHacks,
decltype(b_k0_n_k1_grid_move_slice_window_step_hacks), // BGridMoveSliceWindowStepHacks,
false, // CAccessOrderMRepeatNRepeat,
ABlockLdsAddExtraM,
BBlockLdsAddExtraN>;
CThreadTransferDstScalarPerVector>;
#else
using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3<
BlockSize,
......@@ -224,7 +216,7 @@ struct DeviceGemmXdl
Sequence<0, 2, 4, 5, 6, 1, 3, 7>, // CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector>;
#endif
// Argument
struct Argument : public BaseArgument
{
......@@ -337,11 +329,12 @@ struct DeviceGemmXdl
CDataType,
remove_reference_t<DeviceGemmXdl::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceGemmXdl::BGridDesc_K0_N_K1>,
remove_reference_t<DeviceGemmXdl::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
remove_reference_t<
typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
remove_reference_t<DeviceGemmXdl::Block2CTileMap>,
remove_reference_t<typename GridwiseGemm::Block2CTileMap>,
true,
true>;
......@@ -369,11 +362,12 @@ struct DeviceGemmXdl
CDataType,
remove_reference_t<DeviceGemmXdl::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceGemmXdl::BGridDesc_K0_N_K1>,
remove_reference_t<DeviceGemmXdl::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
remove_reference_t<
typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
remove_reference_t<DeviceGemmXdl::Block2CTileMap>,
remove_reference_t<typename GridwiseGemm::Block2CTileMap>,
true,
false>;
......@@ -404,11 +398,12 @@ struct DeviceGemmXdl
CDataType,
remove_reference_t<DeviceGemmXdl::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceGemmXdl::BGridDesc_K0_N_K1>,
remove_reference_t<DeviceGemmXdl::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
remove_reference_t<
typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
remove_reference_t<DeviceGemmXdl::Block2CTileMap>,
remove_reference_t<typename GridwiseGemm::Block2CTileMap>,
false,
true>;
......@@ -436,11 +431,12 @@ struct DeviceGemmXdl
CDataType,
remove_reference_t<DeviceGemmXdl::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceGemmXdl::BGridDesc_K0_N_K1>,
remove_reference_t<DeviceGemmXdl::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
remove_reference_t<
typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
remove_reference_t<DeviceGemmXdl::Block2CTileMap>,
remove_reference_t<typename GridwiseGemm::Block2CTileMap>,
false,
false>;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment