Commit ef6933a2 authored by ltqin

merge develop and remove hacks

parent 7506342c
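The diff below switches the LDS loads from BlockwiseTensorSliceTransfer_v4 to BlockwiseTensorSliceTransfer_v4r1, which carries a separate destination-side element-wise operation (PassThrough here), and removes the AGridStepHacks/BGridStepHacks/CGridStepHacks and move-slice-window hack parameters, so RunRead and MoveSrcSliceWindow are called with only the descriptor, buffer, and step arguments. As a rough, self-contained sketch of that simplified call pattern (not the composable_kernel API: BlockwiseCopy, PassThrough, and the buffers below are invented stand-ins, and the tensor descriptors are omitted):

    // Illustrative mock of the simplified copy interface used after this commit:
    // the copy object is built from a source element-wise op plus a PassThrough
    // destination op, and RunRead/MoveSrcSliceWindow no longer take step-hack
    // arguments. All names here are hypothetical stand-ins.
    #include <cstddef>
    #include <cstdio>
    #include <vector>

    struct PassThrough
    {
        float operator()(float x) const { return x; }
    };

    template <typename SrcElementwiseOp, typename DstElementwiseOp>
    struct BlockwiseCopy
    {
        SrcElementwiseOp src_op;
        DstElementwiseOp dst_op;
        int window_offset = 0;
        std::vector<float> thread_buf; // stand-in for the per-thread staging registers

        BlockwiseCopy(SrcElementwiseOp s, DstElementwiseOp d, int slice_size)
            : src_op(s), dst_op(d), thread_buf(slice_size)
        {
        }

        // read one K0 slice from "global" memory; no *_grid_step_hacks argument
        void RunRead(const std::vector<float>& grid_buf)
        {
            for(std::size_t i = 0; i < thread_buf.size(); ++i)
                thread_buf[i] = src_op(grid_buf[window_offset + i]);
        }

        // write the staged slice to "LDS", applying the destination-side op
        void RunWrite(std::vector<float>& lds_buf) const
        {
            for(std::size_t i = 0; i < thread_buf.size(); ++i)
                lds_buf[i] = dst_op(thread_buf[i]);
        }

        // advance the source window; no move-slice-window hack argument
        void MoveSrcSliceWindow(int slice_copy_step) { window_offset += slice_copy_step; }
    };

    int main()
    {
        const int k0_per_block = 4;
        std::vector<float> a_grid(16, 1.0f);          // fake global-memory tile
        std::vector<float> a_lds(k0_per_block, 0.0f); // fake LDS buffer slot

        BlockwiseCopy<PassThrough, PassThrough> a_copy{PassThrough{}, PassThrough{}, k0_per_block};

        // preload, then step the window and read the next slice,
        // mirroring the double-buffered main loop in the kernel
        a_copy.RunRead(a_grid);
        a_copy.RunWrite(a_lds);
        a_copy.MoveSrcSliceWindow(k0_per_block);
        a_copy.RunRead(a_grid);

        std::printf("window offset after one step: %d, lds[0] = %f\n", a_copy.window_offset, a_lds[0]);
        return 0;
    }

Dropping the hack parameters also shrinks the gridwise GEMM's template interface, which is why DeviceGemmXdl below no longer forwards the decltype(..._step_hacks) arguments or CAccessOrderMRepeatNRepeat.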
@@ -6,9 +6,8 @@
 #include "tensor_descriptor.hpp"
 #include "tensor_descriptor_helper.hpp"
 #include "blockwise_gemm_xdlops.hpp"
-#include "blockwise_tensor_slice_transfer.hpp"
+#include "blockwise_tensor_slice_transfer_v4r1.hpp"
 #include "threadwise_tensor_slice_transfer.hpp"
-#include "threadwise_tensor_slice_set.hpp"

 namespace ck {
@@ -140,9 +139,8 @@ template <index_t BlockSize,
           index_t MPerXDL,
           index_t NPerXDL,
           index_t K1Value,
-          index_t MRepeat,
-          index_t NRepeat,
-          typename ABlockTransferThreadSliceLengths_K0_M_K1,
+          index_t MXdlPerWave,
+          index_t NXdlPerWave,
           typename ABlockTransferThreadClusterLengths_K0_M_K1,
           typename ABlockTransferThreadClusterArrangeOrder,
           typename ABlockTransferSrcAccessOrder,
@@ -150,7 +148,7 @@ template <index_t BlockSize,
           index_t ABlockTransferSrcScalarPerVector,
           index_t ABlockTransferDstScalarPerVector_K1,
           bool AThreadTransferSrcResetCoordinateAfterRun,
-          typename BBlockTransferThreadSliceLengths_K0_N_K1,
+          bool ABlockLdsExtraM,
           typename BBlockTransferThreadClusterLengths_K0_N_K1,
           typename BBlockTransferThreadClusterArrangeOrder,
           typename BBlockTransferSrcAccessOrder,
@@ -158,17 +156,10 @@ template <index_t BlockSize,
           index_t BBlockTransferSrcScalarPerVector,
           index_t BBlockTransferDstScalarPerVector_K1,
           bool BThreadTransferSrcResetCoordinateAfterRun,
+          bool BBlockLdsExtraN,
           typename CThreadTransferSrcDstAccessOrder,
           index_t CThreadTransferSrcDstVectorDim,
-          index_t CThreadTransferDstScalarPerVector,
-          typename AGridStepHacks,
-          typename BGridStepHacks,
-          typename CGridStepHacks,
-          typename AGridMoveSliceWindowStepHacks,
-          typename BGridMoveSliceWindowStepHacks,
-          bool CAccessOrderMRepeatNRepeat,
-          bool ABlockLdsExtraM,
-          bool BBlockLdsExtraN>
+          index_t CThreadTransferDstScalarPerVector>
 struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3r1
 {
     static constexpr auto I0 = Number<0>{};
@@ -238,8 +229,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3r1
         static_assert(is_known_at_compile_time<remove_cv_t<decltype(K1)>>::value,
                       "wrong! K1 need to be known at compile-time");

-        static_assert((MPerBlock % (MPerXDL * MRepeat) == 0) &&
-                          (NPerBlock % (NRepeat * NPerXDL)) == 0,
+        static_assert((MPerBlock % (MPerXDL * MXdlPerWave) == 0) &&
+                          (NPerBlock % (NXdlPerWave * NPerXDL)) == 0,
                       "Invalid tuning param!");

         const auto M = a_grid_desc_k0_m_k1.GetLength(I1);
@@ -336,8 +327,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3r1
                                       decltype(b_block_desc_k0_n_k1),
                                       MPerXDL,
                                       NPerXDL,
-                                      MRepeat,
-                                      NRepeat,
+                                      MXdlPerWave,
+                                      NXdlPerWave,
                                       K1>;

         return BlockwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n);
@@ -452,11 +443,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3r1
         // A matrix blockwise copy
         auto a_blockwise_copy =
-            BlockwiseTensorSliceTransfer_v4<BlockSize,
+            BlockwiseTensorSliceTransfer_v4r1<BlockSize,
                                             AElementwiseOperation,
+                                            ck::tensor_operation::element_wise::PassThrough,
                                             InMemoryDataOperationEnum_t::Set,
                                             Sequence<K0PerBlock, MPerBlock, K1>,
-                                            ABlockTransferThreadSliceLengths_K0_M_K1,
                                             ABlockTransferThreadClusterLengths_K0_M_K1,
                                             ABlockTransferThreadClusterArrangeOrder,
                                             FloatAB,
@@ -472,19 +463,21 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3r1
                                             1,
                                             1,
                                             AThreadTransferSrcResetCoordinateAfterRun,
-                                            true>(a_grid_desc_k0_m_k1,
+                                            true>(
+                a_grid_desc_k0_m_k1,
                 make_multi_index(0, m_block_data_idx_on_grid, 0),
+                a_element_op,
                 a_block_desc_k0_m_k1,
                 make_multi_index(0, 0, 0),
-                a_element_op);
+                ck::tensor_operation::element_wise::PassThrough{});

         // B matrix blockwise copy
         auto b_blockwise_copy =
-            BlockwiseTensorSliceTransfer_v4<BlockSize,
+            BlockwiseTensorSliceTransfer_v4r1<BlockSize,
                                             BElementwiseOperation,
+                                            ck::tensor_operation::element_wise::PassThrough,
                                             InMemoryDataOperationEnum_t::Set,
                                             Sequence<K0PerBlock, NPerBlock, K1>,
-                                            BBlockTransferThreadSliceLengths_K0_N_K1,
                                             BBlockTransferThreadClusterLengths_K0_N_K1,
                                             BBlockTransferThreadClusterArrangeOrder,
                                             FloatAB,
@@ -500,11 +493,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3r1
                                             1,
                                             1,
                                             BThreadTransferSrcResetCoordinateAfterRun,
-                                            true>(b_grid_desc_k0_n_k1,
+                                            true>(
+                b_grid_desc_k0_n_k1,
                 make_multi_index(0, n_block_data_idx_on_grid, 0),
+                b_element_op,
                 b_block_desc_k0_n_k1,
                 make_multi_index(0, 0, 0),
-                b_element_op);
+                ck::tensor_operation::element_wise::PassThrough{});

         // GEMM definition
         // c_mtx += transpose(a_mtx) * b_mtx
@@ -522,8 +517,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3r1
             decltype(b_block_desc_k0_n_k1),
             MPerXDL,
             NPerXDL,
-            MRepeat,
-            NRepeat,
+            MXdlPerWave,
+            NXdlPerWave,
             K1>{};

         auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
@@ -541,15 +536,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3r1
         constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
         constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);

-        // hack to control index calculation when iterating over A and B matrix for threadwise copy
-        constexpr auto a_k0_m_k1_grid_step_hacks = AGridStepHacks{};
-        constexpr auto b_k0_n_k1_grid_step_hacks = BGridStepHacks{};
-
-        // hack to control index calculation when move slice window for A and B matrix for
-        // threadwise copy
-        constexpr auto a_k0_m_k1_grid_move_slice_window_step_hack = AGridMoveSliceWindowStepHacks{};
-        constexpr auto b_k0_n_k1_grid_move_slice_window_step_hack = BGridMoveSliceWindowStepHacks{};
-
         auto a_block_even_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
             p_a_block_double, a_block_desc_k0_m_k1.GetElementSpaceSize());
         auto b_block_even_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
@@ -562,8 +548,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3r1
         // preload data into LDS
         {
-            a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf, a_k0_m_k1_grid_step_hacks);
-            b_blockwise_copy.RunRead(b_grid_desc_k0_n_k1, b_grid_buf, b_k0_n_k1_grid_step_hacks);
+            a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf);
+            b_blockwise_copy.RunRead(b_grid_desc_k0_n_k1, b_grid_buf);

             a_blockwise_copy.RunWrite(a_block_desc_k0_m_k1, a_block_even_buf);
             b_blockwise_copy.RunWrite(b_block_desc_k0_n_k1, b_block_even_buf);
@@ -579,21 +565,15 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3r1
             do
             {
                 // iteration for odd
-                a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m_k1,
-                                                    a_block_slice_copy_step,
-                                                    a_k0_m_k1_grid_move_slice_window_step_hack);
-                b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n_k1,
-                                                    b_block_slice_copy_step,
-                                                    b_k0_n_k1_grid_move_slice_window_step_hack);
+                a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m_k1, a_block_slice_copy_step);
+                b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n_k1, b_block_slice_copy_step);

                 // LDS double buffer: load last data from device mem
-                a_blockwise_copy.RunRead(
-                    a_grid_desc_k0_m_k1, a_grid_buf, a_k0_m_k1_grid_step_hacks);
+                a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf);

                 block_sync_lds();

-                b_blockwise_copy.RunRead(
-                    b_grid_desc_k0_n_k1, b_grid_buf, b_k0_n_k1_grid_step_hacks);
+                b_blockwise_copy.RunRead(b_grid_desc_k0_n_k1, b_grid_buf);

                 // gemm even data
                 blockwise_gemm.Run(a_block_even_buf, b_block_even_buf, c_thread_buf);
@@ -602,21 +582,15 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3r1
                 b_blockwise_copy.RunWrite(b_block_desc_k0_n_k1, b_block_odd_buf);

                 // iteration for even
-                a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m_k1,
-                                                    a_block_slice_copy_step,
-                                                    a_k0_m_k1_grid_move_slice_window_step_hack);
-                b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n_k1,
-                                                    b_block_slice_copy_step,
-                                                    b_k0_n_k1_grid_move_slice_window_step_hack);
+                a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m_k1, a_block_slice_copy_step);
+                b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n_k1, b_block_slice_copy_step);

                 // LDS double buffer: load last data from device mem
-                a_blockwise_copy.RunRead(
-                    a_grid_desc_k0_m_k1, a_grid_buf, a_k0_m_k1_grid_step_hacks);
+                a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf);

                 block_sync_lds();

-                b_blockwise_copy.RunRead(
-                    b_grid_desc_k0_n_k1, b_grid_buf, b_k0_n_k1_grid_step_hacks);
+                b_blockwise_copy.RunRead(b_grid_desc_k0_n_k1, b_grid_buf);

                 // gemm odd data
                 blockwise_gemm.Run(a_block_odd_buf, b_block_odd_buf, c_thread_buf);
@@ -632,17 +606,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3r1
         if constexpr(HasDoubleTailKBlockLoop) // if has 2 iteration left
         {
             // iteration for odd
-            a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m_k1,
-                                                a_block_slice_copy_step,
-                                                a_k0_m_k1_grid_move_slice_window_step_hack);
-            b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n_k1,
-                                                b_block_slice_copy_step,
-                                                b_k0_n_k1_grid_move_slice_window_step_hack);
+            a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m_k1, a_block_slice_copy_step);
+            b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n_k1, b_block_slice_copy_step);

             // LDS double buffer: load last data from device mem
-            a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf, a_k0_m_k1_grid_step_hacks);
+            a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf);

             block_sync_lds();

-            b_blockwise_copy.RunRead(b_grid_desc_k0_n_k1, b_grid_buf, b_k0_n_k1_grid_step_hacks);
+            b_blockwise_copy.RunRead(b_grid_desc_k0_n_k1, b_grid_buf);

             // gemm even data
             blockwise_gemm.Run(a_block_even_buf, b_block_even_buf, c_thread_buf);
@@ -689,8 +659,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3r1
         const index_t n_thread_data_on_grid =
             n_block_data_idx_on_grid + c_thread_mtx_on_block[I1];

-        constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks = CGridStepHacks{};
-
         const auto m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor =
             make_single_stage_tensor_adaptor(
                 make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
@@ -738,8 +706,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3r1
             make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
             c_thread_buf,
             c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
-            c_grid_buf,
-            c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks);
+            c_grid_buf);
         }
     }
 }; // namespace ck
...
@@ -157,7 +157,6 @@ struct DeviceGemmXdl
         K1,
         MXdlPerWave,
         NXdlPerWave,
-        ABlockTransferThreadSliceLengths_K0_M_K1,
         ABlockTransferThreadClusterLengths_K0_M_K1,
         ABlockTransferThreadClusterArrangeOrder,
         ABlockTransferSrcAccessOrder,
@@ -165,7 +164,7 @@ struct DeviceGemmXdl
         ABlockTransferSrcScalarPerVector,
         ABlockTransferDstScalarPerVector_K1,
         false, // AThreadTransferSrcResetCoordinateAfterRun,
-        BBlockTransferThreadSliceLengths_K0_N_K1,
+        ABlockLdsAddExtraM,
         BBlockTransferThreadClusterLengths_K0_N_K1,
         BBlockTransferThreadClusterArrangeOrder,
         BBlockTransferSrcAccessOrder,
@@ -173,17 +172,10 @@ struct DeviceGemmXdl
         BBlockTransferSrcScalarPerVector,
         BBlockTransferDstScalarPerVector_K1,
         false, // BThreadTransferSrcResetCoordinateAfterRun,
+        BBlockLdsAddExtraN,
         Sequence<0, 2, 4, 5, 6, 1, 3, 7>, // CThreadTransferSrcDstAccessOrder,
         CThreadTransferSrcDstVectorDim,
-        CThreadTransferDstScalarPerVector,
-        decltype(a_k0_m_k1_grid_step_hacks),                   // AGridStepHacks,
-        decltype(b_k0_n_k1_grid_step_hacks),                   // BGridStepHacks,
-        decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks),   // CGridStepHacks,
-        decltype(a_k0_m_k1_grid_move_slice_window_step_hacks), // AGridMoveSliceWindowStepHacks,
-        decltype(b_k0_n_k1_grid_move_slice_window_step_hacks), // BGridMoveSliceWindowStepHacks,
-        false, // CAccessOrderMRepeatNRepeat,
-        ABlockLdsAddExtraM,
-        BBlockLdsAddExtraN>;
+        CThreadTransferDstScalarPerVector>;
 #else
     using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3<
         BlockSize,
@@ -224,7 +216,7 @@ struct DeviceGemmXdl
         Sequence<0, 2, 4, 5, 6, 1, 3, 7>, // CThreadTransferSrcDstAccessOrder,
         CThreadTransferSrcDstVectorDim,
         CThreadTransferDstScalarPerVector>;
+#endif

     // Argument
     struct Argument : public BaseArgument
@@ -337,11 +329,12 @@ struct DeviceGemmXdl
                     CDataType,
                     remove_reference_t<DeviceGemmXdl::AGridDesc_K0_M_K1>,
                     remove_reference_t<DeviceGemmXdl::BGridDesc_K0_N_K1>,
-                    remove_reference_t<DeviceGemmXdl::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
+                    remove_reference_t<
+                        typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
                     AElementwiseOperation,
                     BElementwiseOperation,
                     CElementwiseOperation,
-                    remove_reference_t<DeviceGemmXdl::Block2CTileMap>,
+                    remove_reference_t<typename GridwiseGemm::Block2CTileMap>,
                     true,
                     true>;
@@ -369,11 +362,12 @@ struct DeviceGemmXdl
                     CDataType,
                     remove_reference_t<DeviceGemmXdl::AGridDesc_K0_M_K1>,
                     remove_reference_t<DeviceGemmXdl::BGridDesc_K0_N_K1>,
-                    remove_reference_t<DeviceGemmXdl::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
+                    remove_reference_t<
+                        typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
                     AElementwiseOperation,
                     BElementwiseOperation,
                     CElementwiseOperation,
-                    remove_reference_t<DeviceGemmXdl::Block2CTileMap>,
+                    remove_reference_t<typename GridwiseGemm::Block2CTileMap>,
                     true,
                     false>;
@@ -404,11 +398,12 @@ struct DeviceGemmXdl
                     CDataType,
                     remove_reference_t<DeviceGemmXdl::AGridDesc_K0_M_K1>,
                     remove_reference_t<DeviceGemmXdl::BGridDesc_K0_N_K1>,
-                    remove_reference_t<DeviceGemmXdl::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
+                    remove_reference_t<
+                        typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
                     AElementwiseOperation,
                     BElementwiseOperation,
                     CElementwiseOperation,
-                    remove_reference_t<DeviceGemmXdl::Block2CTileMap>,
+                    remove_reference_t<typename GridwiseGemm::Block2CTileMap>,
                     false,
                     true>;
@@ -436,11 +431,12 @@ struct DeviceGemmXdl
                     CDataType,
                     remove_reference_t<DeviceGemmXdl::AGridDesc_K0_M_K1>,
                     remove_reference_t<DeviceGemmXdl::BGridDesc_K0_N_K1>,
-                    remove_reference_t<DeviceGemmXdl::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
+                    remove_reference_t<
+                        typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
                     AElementwiseOperation,
                     BElementwiseOperation,
                     CElementwiseOperation,
-                    remove_reference_t<DeviceGemmXdl::Block2CTileMap>,
+                    remove_reference_t<typename GridwiseGemm::Block2CTileMap>,
                     false,
                     false>;
...