"include/ck/utility/common_header.hpp" did not exist on "6d4450ef155c39af9ede2cd171be40ee06db9939"
Commit 491437a7 authored by Chao Liu's avatar Chao Liu
Browse files

remove coordinate step hack from all gridwise gemm xdl

parent 619661f8
......@@ -156,11 +156,6 @@ template <index_t BlockSize,
typename CThreadTransferSrcDstAccessOrder,
index_t CThreadTransferSrcDstVectorDim,
index_t CThreadTransferDstScalarPerVector,
typename AGridStepHacks,
typename BGridStepHacks,
typename CGridStepHacks,
typename AGridMoveSliceWindowStepHacks,
typename BGridMoveSliceWindowStepHacks,
bool CAccessOrderMRepeatNRepeat,
bool ABlockLdsExtraM,
bool BBlockLdsExtraN>
......@@ -442,6 +437,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
auto a_blockwise_copy =
BlockwiseTensorSliceTransfer_v4<BlockSize,
AElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum_t::Set,
Sequence<K0PerBlock, MPerBlock, K1>,
ABlockTransferThreadSliceLengths_K0_M_K1,
......@@ -460,16 +456,19 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
1,
1,
AThreadTransferSrcResetCoordinateAfterRun,
true>(a_grid_desc_k0_m_k1,
true>(
a_grid_desc_k0_m_k1,
make_multi_index(0, m_block_data_idx_on_grid, 0),
a_element_op,
a_block_desc_k0_m_k1,
make_multi_index(0, 0, 0),
a_element_op);
ck::tensor_operation::element_wise::PassThrough{});
// B matrix blockwise copy
auto b_blockwise_copy =
BlockwiseTensorSliceTransfer_v4<BlockSize,
BElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum_t::Set,
Sequence<K0PerBlock, NPerBlock, K1>,
BBlockTransferThreadSliceLengths_K0_N_K1,
......@@ -488,11 +487,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
1,
1,
BThreadTransferSrcResetCoordinateAfterRun,
true>(b_grid_desc_k0_n_k1,
true>(
b_grid_desc_k0_n_k1,
make_multi_index(0, n_block_data_idx_on_grid, 0),
b_element_op,
b_block_desc_k0_n_k1,
make_multi_index(0, 0, 0),
b_element_op);
ck::tensor_operation::element_wise::PassThrough{});
// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
......@@ -526,15 +527,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
// hack to control index calculation when iterating over A and B matrix for threadwise copy
constexpr auto a_k0_m_k1_grid_step_hacks = AGridStepHacks{};
constexpr auto b_k0_n_k1_grid_step_hacks = BGridStepHacks{};
// hack to control index calculation when move slice window for A and B matrix for
// threadwise copy
constexpr auto a_k0_m_k1_grid_move_slice_window_step_hack = AGridMoveSliceWindowStepHacks{};
constexpr auto b_k0_n_k1_grid_move_slice_window_step_hack = BGridMoveSliceWindowStepHacks{};
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
p_a_block, a_block_desc_k0_m_k1.GetElementSpaceSize());
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
......@@ -542,36 +534,31 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
// preload data into LDS
{
a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf, a_k0_m_k1_grid_step_hacks);
b_blockwise_copy.RunRead(b_grid_desc_k0_n_k1, b_grid_buf, b_k0_n_k1_grid_step_hacks);
a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf);
b_blockwise_copy.RunRead(b_grid_desc_k0_n_k1, b_grid_buf);
a_blockwise_copy.RunWrite(a_block_desc_k0_m_k1, a_block_buf);
b_blockwise_copy.RunWrite(b_block_desc_k0_n_k1, b_block_buf);
}
// main body
index_t k0_block_data_begin = 0;
// Initialize C
c_thread_buf.Clear();
// main body
if constexpr(HasMainKBlockLoop)
{
index_t k0_block_data_begin = 0;
do
{
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m_k1,
a_block_slice_copy_step,
a_k0_m_k1_grid_move_slice_window_step_hack);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n_k1,
b_block_slice_copy_step,
b_k0_n_k1_grid_move_slice_window_step_hack);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m_k1, a_block_slice_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n_k1, b_block_slice_copy_step);
a_blockwise_copy.RunRead(
a_grid_desc_k0_m_k1, a_grid_buf, a_k0_m_k1_grid_step_hacks);
a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf);
block_sync_lds();
b_blockwise_copy.RunRead(
b_grid_desc_k0_n_k1, b_grid_buf, b_k0_n_k1_grid_step_hacks);
b_blockwise_copy.RunRead(b_grid_desc_k0_n_k1, b_grid_buf);
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
......@@ -619,8 +606,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
const index_t n_thread_data_on_grid =
n_block_data_idx_on_grid + c_thread_mtx_on_block[I1];
constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks = CGridStepHacks{};
const auto m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
......@@ -668,8 +653,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
c_thread_buf,
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
c_grid_buf,
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks);
c_grid_buf);
}
}
};
......
......@@ -19,6 +19,9 @@ template <typename GridwiseGemm,
typename ABK0MK1GridDesc,
typename BBK0NK1GridDesc,
typename CM0N0M1N1M2M3M4N2GridDesc,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename CBlockClusterAdaptor,
bool HasMainKBlockLoop>
__global__ void
......@@ -31,6 +34,9 @@ __global__ void
const ABK0MK1GridDesc a_b_k0_m_k1_grid_desc,
const BBK0NK1GridDesc b_b_k0_n_k1_grid_desc,
const CM0N0M1N1M2M3M4N2GridDesc c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
const CElementwiseOperation c_element_op,
const CBlockClusterAdaptor c_block_cluster_adaptor)
{
constexpr index_t shared_block_size =
......@@ -45,6 +51,9 @@ __global__ void
a_b_k0_m_k1_grid_desc,
b_b_k0_n_k1_grid_desc,
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
a_element_op,
b_element_op,
c_element_op,
c_block_cluster_adaptor);
}
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
......@@ -129,11 +138,6 @@ template <index_t BlockSize,
typename CThreadTransferSrcDstAccessOrder,
index_t CThreadTransferSrcDstVectorDim,
index_t CThreadTransferDstScalarPerVector,
typename AGridStepHacks,
typename BGridStepHacks,
typename CGridStepHacks,
typename AGridMoveSliceWindowStepHacks,
typename BGridMoveSliceWindowStepHacks,
bool CAccessOrderMRepeatNRepeat,
bool ABlockLdsExtraM,
bool BBlockLdsExtraN>
......@@ -371,6 +375,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
c_block_cluster_adaptor.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
const index_t k_batch_id = block_work_idx[I0];
// HACK: this force m/n_block_data_idx_on_grid into SGPR
const index_t m_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I1] * MPerBlock);
......@@ -448,6 +453,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
// A matrix blockwise copy
auto a_blockwise_copy =
BlockwiseTensorSliceTransfer_v4<BlockSize,
AElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum_t::Set,
Sequence<1, K0PerBlock, MPerBlock, K1>,
ABlockTransferThreadSliceLengths_K0_M_K1,
......@@ -469,12 +476,16 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
true>(
a_b_k0_m_k1_grid_desc,
make_multi_index(k_batch_id, 0, m_block_data_idx_on_grid, 0),
a_element_op,
a_b_k0_m_k1_block_desc,
make_multi_index(0, 0, 0, 0));
make_multi_index(0, 0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{});
// B matrix blockwise copy
auto b_blockwise_copy =
BlockwiseTensorSliceTransfer_v4<BlockSize,
BElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum_t::Set,
Sequence<1, K0PerBlock, NPerBlock, K1>,
BBlockTransferThreadSliceLengths_K0_N_K1,
......@@ -496,8 +507,10 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
true>(
b_b_k0_n_k1_grid_desc,
make_multi_index(k_batch_id, 0, n_block_data_idx_on_grid, 0),
b_element_op,
b_b_k0_n_k1_block_desc,
make_multi_index(0, 0, 0, 0));
make_multi_index(0, 0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{});
// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
......@@ -531,15 +544,6 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
constexpr auto a_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0);
// hack to control index calculation when iterating over A and B matrix for threadwise copy
constexpr auto a_k0_m_k1_grid_step_hacks = AGridStepHacks{};
constexpr auto b_k0_n_k1_grid_step_hacks = BGridStepHacks{};
// hack to control index calculation when move slice window for A and B matrix for
// threadwise copy
constexpr auto a_k0_m_k1_grid_move_slice_window_step_hack = AGridMoveSliceWindowStepHacks{};
constexpr auto b_k0_n_k1_grid_move_slice_window_step_hack = BGridMoveSliceWindowStepHacks{};
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
p_a_block, a_k0_m_k1_block_desc.GetElementSpaceSize());
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
......@@ -547,33 +551,31 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
// preload data into LDS
{
a_blockwise_copy.RunRead(a_b_k0_m_k1_grid_desc, a_grid_buf, a_k0_m_k1_grid_step_hacks);
b_blockwise_copy.RunRead(b_b_k0_n_k1_grid_desc, b_grid_buf, b_k0_n_k1_grid_step_hacks);
a_blockwise_copy.RunRead(a_b_k0_m_k1_grid_desc, a_grid_buf);
b_blockwise_copy.RunRead(b_b_k0_n_k1_grid_desc, b_grid_buf);
a_blockwise_copy.RunWrite(a_b_k0_m_k1_block_desc, a_block_buf);
b_blockwise_copy.RunWrite(b_b_k0_n_k1_block_desc, b_block_buf);
}
// Initialize C
c_thread_buf.Clear();
// main body
index_t k_block_data_begin = 0;
if constexpr(HasMainKBlockLoop)
{
index_t k0_block_data_begin = 0;
do
{
a_blockwise_copy.MoveSrcSliceWindow(a_b_k0_m_k1_grid_desc,
a_block_slice_copy_step,
a_k0_m_k1_grid_move_slice_window_step_hack);
b_blockwise_copy.MoveSrcSliceWindow(b_b_k0_n_k1_grid_desc,
b_block_slice_copy_step,
b_k0_n_k1_grid_move_slice_window_step_hack);
a_blockwise_copy.MoveSrcSliceWindow(a_b_k0_m_k1_grid_desc, a_block_slice_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_b_k0_n_k1_grid_desc, b_block_slice_copy_step);
a_blockwise_copy.RunRead(
a_b_k0_m_k1_grid_desc, a_grid_buf, a_k0_m_k1_grid_step_hacks);
a_blockwise_copy.RunRead(a_b_k0_m_k1_grid_desc, a_grid_buf);
block_sync_lds();
b_blockwise_copy.RunRead(
b_b_k0_n_k1_grid_desc, b_grid_buf, b_k0_n_k1_grid_step_hacks);
b_blockwise_copy.RunRead(b_b_k0_n_k1_grid_desc, b_grid_buf);
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
......@@ -622,8 +624,6 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
const index_t n_thread_data_on_grid =
n_block_data_idx_on_grid + c_thread_mtx_on_block[I1];
constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks = CGridStepHacks{};
const auto m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
......@@ -648,6 +648,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
FloatC,
decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc),
decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc),
CElementwiseOperation,
Sequence<M0, N0, I1, I1, M2, I1, M4, I1>,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
......@@ -664,14 +665,14 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
m_thread_data_on_grid_idx[I2],
m_thread_data_on_grid_idx[I3],
m_thread_data_on_grid_idx[I4],
n_thread_data_on_grid_idx[I2])};
n_thread_data_on_grid_idx[I2]),
c_element_op};
c_thread_copy.Run(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
c_thread_buf,
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
c_grid_buf,
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks);
c_grid_buf);
}
}
}; // namespace ck
......
......@@ -412,6 +412,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r5
auto a_blockwise_copy =
BlockwiseTensorSliceTransfer_v4<BlockSize,
AElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum_t::Set,
Sequence<K0PerBlock, MPerBlock, K1>,
ABlockTransferThreadSliceLengths_K0_M_K1,
......@@ -430,16 +431,19 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r5
1,
1,
AThreadTransferSrcResetCoordinateAfterRun,
true>(a_grid_desc_k0_m_k1,
true>(
a_grid_desc_k0_m_k1,
make_multi_index(0, m_block_data_idx_on_grid, 0),
a_element_op,
a_block_desc_k0_m_k1,
make_multi_index(0, 0, 0),
a_element_op);
ck::tensor_operation::element_wise::PassThrough{});
// B matrix blockwise copy
auto b_blockwise_copy =
BlockwiseTensorSliceTransfer_v4<BlockSize,
BElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum_t::Set,
Sequence<K0PerBlock, NPerBlock, K1>,
BBlockTransferThreadSliceLengths_K0_N_K1,
......@@ -458,11 +462,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r5
1,
1,
BThreadTransferSrcResetCoordinateAfterRun,
true>(b_grid_desc_k0_n_k1,
true>(
b_grid_desc_k0_n_k1,
make_multi_index(0, n_block_data_idx_on_grid, 0),
b_element_op,
b_block_desc_k0_n_k1,
make_multi_index(0, 0, 0),
b_element_op);
ck::tensor_operation::element_wise::PassThrough{});
// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
......@@ -519,11 +525,14 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r5
b_blockwise_copy.RunWrite(b_block_desc_k0_n_k1, b_block_buf);
}
// main body
index_t k0_block_data_begin = 0;
// Initialize C
c_thread_buf.Clear();
// main body
if constexpr(HasMainKBlockLoop)
{
index_t k0_block_data_begin = 0;
do
{
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m_k1,
......
......@@ -398,6 +398,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r6
auto a_blockwise_copy =
BlockwiseTensorSliceTransfer_v4<BlockSize,
AElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum_t::Set,
Sequence<K0PerBlock, MPerBlock, K1>,
ABlockTransferThreadSliceLengths_K0_M_K1,
......@@ -416,16 +417,19 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r6
1,
1,
AThreadTransferSrcResetCoordinateAfterRun,
true>(a_grid_desc_k0_m_k1,
true>(
a_grid_desc_k0_m_k1,
make_multi_index(0, m_block_data_idx_on_grid, 0),
a_element_op,
a_block_desc_k0_m_k1,
make_multi_index(0, 0, 0),
a_element_op);
ck::tensor_operation::element_wise::PassThrough{});
// B matrix blockwise copy
auto b_blockwise_copy =
BlockwiseTensorSliceTransfer_v4<BlockSize,
BElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum_t::Set,
Sequence<K0PerBlock, NPerBlock, K1>,
BBlockTransferThreadSliceLengths_K0_N_K1,
......@@ -444,11 +448,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r6
1,
1,
BThreadTransferSrcResetCoordinateAfterRun,
true>(b_grid_desc_k0_n_k1,
true>(
b_grid_desc_k0_n_k1,
make_multi_index(0, n_block_data_idx_on_grid, 0),
b_element_op,
b_block_desc_k0_n_k1,
make_multi_index(0, 0, 0),
b_element_op);
ck::tensor_operation::element_wise::PassThrough{});
// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
......@@ -505,11 +511,14 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r6
b_blockwise_copy.RunWrite(b_block_desc_k0_n_k1, b_block_buf);
}
// main body
index_t k0_block_data_begin = 0;
// Initialize C
c_thread_buf.Clear();
// main body
if constexpr(HasMainKBlockLoop)
{
index_t k0_block_data_begin = 0;
do
{
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m_k1,
......
......@@ -459,7 +459,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
b_blockwise_copy.RunWrite(b_block_desc_k0_n_k1, b_block_buf);
}
// clear C
// Initialize C
c_thread_buf.Clear();
// main body
......
......@@ -290,7 +290,7 @@ struct ThreadwiseTensorSliceTransfer_v1r3
const DstDesc& dst_desc,
DstBuffer& dst_buf)
{
constexpr index_t ntransform_dst = DstDesc::GetNumOfTransform();
constexpr index_t ntransform_dst = remove_cvref_t<DstDesc>::GetNumOfTransform();
constexpr auto zeros = typename uniform_sequence_gen<ntransform_dst, 0>::type{};
......
......@@ -209,46 +209,6 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
using BGridDesc_K0_N_K1 = remove_cvref_t<decltype(ABCGridDescs{}[I1])>;
using CGridDesc_M_N = remove_cvref_t<decltype(ABCGridDescs{}[I2])>;
// TODO remove these hacks
static constexpr auto a_k0_m_k1_grid_step_hacks = make_tuple(
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{}, // 0+: K0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 1+: M
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{}), // 2+: K1
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0>{}, // 0-: K0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 1-: M
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0>{})); // 2-: K1
static constexpr auto b_k0_n_k1_grid_step_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0+: K0
Sequence<0, 0, 0, 0, 0>{}, // 1+: N
Sequence<0, 0, 0, 0, 0>{}), // 2+: K1
make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0-: K0
Sequence<0, 0, 0, 0, 0>{}, // 1-: N
Sequence<0, 0, 0, 0, 0>{})); // 2-: K1
static constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2
static constexpr auto a_k0_m_k1_grid_move_slice_window_step_hacks =
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0, 0, 0, 0>{};
static constexpr auto b_k0_n_k1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0, 0, 0>{};
// GridwiseGemm
using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3<
BlockSize,
......@@ -289,19 +249,16 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
Sequence<2, 3, 0, 1, 7, 5, 4, 6>, // CThreadTransferSrcDstAccessOrder,
7, // CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
decltype(a_k0_m_k1_grid_step_hacks), // AGridStepHacks,
decltype(b_k0_n_k1_grid_step_hacks), // BGridStepHacks,
decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks), // CGridStepHacks,
decltype(a_k0_m_k1_grid_move_slice_window_step_hacks), // AGridMoveSliceWindowStepHacks,
decltype(b_k0_n_k1_grid_move_slice_window_step_hacks), // BGridMoveSliceWindowStepHacks,
false, // CAccessOrderMRepeatNRepeat,
ABlockLdsAddExtraM,
BBlockLdsAddExtraN>;
#if 0
using CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 =
decltype(GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(CGridDesc_M_N{}));
using Block2CTileMap = decltype(GridwiseGemm::MakeBlock2CTileMap(CGridDesc_M_N{}, 1, 1));
#endif
// Argument
struct Argument : public BaseArgument
......@@ -371,8 +328,9 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_;
BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_;
CGridDesc_M_N c_grid_desc_m_n_;
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_;
Block2CTileMap block_2_ctile_map_;
typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_;
typename GridwiseGemm::Block2CTileMap block_2_ctile_map_;
index_t M01_;
index_t N01_;
InElementwiseOperation in_element_op_;
......@@ -426,11 +384,11 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
CDataType,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
remove_reference_t<DeviceOp::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
remove_reference_t<typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
InElementwiseOperation,
WeiElementwiseOperation,
OutElementwiseOperation,
remove_reference_t<DeviceOp::Block2CTileMap>,
remove_reference_t<typename GridwiseGemm::Block2CTileMap>,
true>;
ave_time = launch_and_time_kernel(kernel,
......@@ -457,11 +415,11 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
CDataType,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
remove_reference_t<DeviceOp::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
remove_reference_t<typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
InElementwiseOperation,
WeiElementwiseOperation,
OutElementwiseOperation,
remove_reference_t<DeviceOp::Block2CTileMap>,
remove_reference_t<typename GridwiseGemm::Block2CTileMap>,
false>;
ave_time = launch_and_time_kernel(kernel,
......
......@@ -168,45 +168,6 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K_1x1_P0
using BGridDesc_K0_N_K1 = remove_cvref_t<decltype(ABCGridDescs{}[I1])>;
using CGridDesc_M_N = remove_cvref_t<decltype(ABCGridDescs{}[I2])>;
// TODO remove these hacks
static constexpr auto a_k0_m_k1_grid_step_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: K0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: M
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 2+: K1
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: K0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: M
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 2-: K1
static constexpr auto b_k0_n_k1_grid_step_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0>{}, // 0+: K0
Sequence<0, 0, 0>{}, // 1+: N
Sequence<0, 0, 0>{}), // 2+: K1
make_tuple(Sequence<0, 0, 0>{}, // 0-: K0
Sequence<0, 0, 0>{}, // 1-: N
Sequence<0, 0, 0>{})); // 2-: K1
static constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2
static constexpr auto a_k0_m_k1_grid_move_slice_window_step_hacks =
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{};
static constexpr auto b_k0_n_k1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0>{};
// GridwiseGemm
using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3<
BlockSize,
......@@ -247,20 +208,10 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K_1x1_P0
Sequence<2, 3, 0, 1, 7, 5, 4, 6>, // CThreadTransferSrcDstAccessOrder,
7, // CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
decltype(a_k0_m_k1_grid_step_hacks), // AGridStepHacks,
decltype(b_k0_n_k1_grid_step_hacks), // BGridStepHacks,
decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks), // CGridStepHacks,
decltype(a_k0_m_k1_grid_move_slice_window_step_hacks), // AGridMoveSliceWindowStepHacks,
decltype(b_k0_n_k1_grid_move_slice_window_step_hacks), // BGridMoveSliceWindowStepHacks,
false, // CAccessOrderMRepeatNRepeat,
ABlockLdsAddExtraM,
BBlockLdsAddExtraN>;
using CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 =
decltype(GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(CGridDesc_M_N{}));
using Block2CTileMap = decltype(GridwiseGemm::MakeBlock2CTileMap(CGridDesc_M_N{}, 1, 1));
// Argument
struct Argument : public BaseArgument
{
......@@ -332,8 +283,9 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K_1x1_P0
AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_;
BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_;
CGridDesc_M_N c_grid_desc_m_n_;
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_;
Block2CTileMap block_2_ctile_map_;
typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_;
typename GridwiseGemm::Block2CTileMap block_2_ctile_map_;
index_t M01_;
index_t N01_;
InElementwiseOperation in_element_op_;
......@@ -391,11 +343,11 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K_1x1_P0
CDataType,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
remove_reference_t<DeviceOp::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
remove_reference_t<typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
InElementwiseOperation,
WeiElementwiseOperation,
OutElementwiseOperation,
remove_reference_t<DeviceOp::Block2CTileMap>,
remove_reference_t<typename GridwiseGemm::Block2CTileMap>,
true>;
ave_time = launch_and_time_kernel(kernel,
......@@ -422,11 +374,11 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K_1x1_P0
CDataType,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
remove_reference_t<DeviceOp::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
remove_reference_t<typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
InElementwiseOperation,
WeiElementwiseOperation,
OutElementwiseOperation,
remove_reference_t<DeviceOp::Block2CTileMap>,
remove_reference_t<typename GridwiseGemm::Block2CTileMap>,
false>;
ave_time = launch_and_time_kernel(kernel,
......
......@@ -145,44 +145,6 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K_1x1_S1
using BGridDesc_K0_N_K1 = remove_cvref_t<decltype(ABCGridDescs{}[I1])>;
using CGridDesc_M_N = remove_cvref_t<decltype(ABCGridDescs{}[I2])>;
// TODO remove these hacks
static constexpr auto a_k0_m_k1_grid_step_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0>{}, // 0+: K0
Sequence<0, 0, 0>{}, // 1+: M
Sequence<0, 0, 0>{}), // 2+: K1
make_tuple(Sequence<0, 0, 0>{}, // 0-: K0
Sequence<0, 0, 0>{}, // 1-: M
Sequence<0, 0, 0>{})); // 2-: K1
static constexpr auto b_k0_n_k1_grid_step_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0>{}, // 0+: K0
Sequence<0, 0, 0>{}, // 1+: N
Sequence<0, 0, 0>{}), // 2+: K1
make_tuple(Sequence<0, 0, 0>{}, // 0-: K0
Sequence<0, 0, 0>{}, // 1-: N
Sequence<0, 0, 0>{})); // 2-: K1
static constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2
static constexpr auto a_k0_m_k1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0>{};
static constexpr auto b_k0_n_k1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0>{};
// GridwiseGemm
using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3<
BlockSize,
......@@ -223,20 +185,10 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K_1x1_S1
Sequence<2, 3, 0, 1, 7, 5, 4, 6>, // CThreadTransferSrcDstAccessOrder,
7, // CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
decltype(a_k0_m_k1_grid_step_hacks), // AGridStepHacks,
decltype(b_k0_n_k1_grid_step_hacks), // BGridStepHacks,
decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks), // CGridStepHacks,
decltype(a_k0_m_k1_grid_move_slice_window_step_hacks), // AGridMoveSliceWindowStepHacks,
decltype(b_k0_n_k1_grid_move_slice_window_step_hacks), // BGridMoveSliceWindowStepHacks,
false, // CAccessOrderMRepeatNRepeat,
ABlockLdsAddExtraM,
BBlockLdsAddExtraN>;
using CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 =
decltype(GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(CGridDesc_M_N{}));
using Block2CTileMap = decltype(GridwiseGemm::MakeBlock2CTileMap(CGridDesc_M_N{}, 1, 1));
// Argument
struct Argument : public BaseArgument
{
......@@ -309,8 +261,9 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K_1x1_S1
AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_;
BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_;
CGridDesc_M_N c_grid_desc_m_n_;
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_;
Block2CTileMap block_2_ctile_map_;
typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_;
typename GridwiseGemm::Block2CTileMap block_2_ctile_map_;
index_t M01_;
index_t N01_;
InElementwiseOperation in_element_op_;
......@@ -369,11 +322,11 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K_1x1_S1
CDataType,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
remove_reference_t<DeviceOp::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
remove_reference_t<typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
InElementwiseOperation,
WeiElementwiseOperation,
OutElementwiseOperation,
remove_reference_t<DeviceOp::Block2CTileMap>,
remove_reference_t<typename GridwiseGemm::Block2CTileMap>,
true>;
ave_time = launch_and_time_kernel(kernel,
......@@ -400,11 +353,11 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K_1x1_S1
CDataType,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
remove_reference_t<DeviceOp::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
remove_reference_t<typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
InElementwiseOperation,
WeiElementwiseOperation,
OutElementwiseOperation,
remove_reference_t<DeviceOp::Block2CTileMap>,
remove_reference_t<typename GridwiseGemm::Block2CTileMap>,
false>;
ave_time = launch_and_time_kernel(kernel,
......
......@@ -132,45 +132,6 @@ struct DeviceGemmXdl
using BGridDesc_K0_N_K1 = decltype(MakeBGridDescriptor_K0_N_K1(1, 1, 1));
using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
// TODO remove these hacks
static constexpr auto a_k0_m_k1_grid_step_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0>{}, // 0+: K0
Sequence<0, 0, 0>{}, // 1+: M
Sequence<0, 0, 0>{}), // 2+: K1
make_tuple(Sequence<0, 0, 0>{}, // 0-: K0
Sequence<0, 0, 0>{}, // 1-: M
Sequence<0, 0, 0>{})); // 2-: K1
static constexpr auto b_k0_n_k1_grid_step_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0>{}, // 0+: K0
Sequence<0, 0, 0>{}, // 1+: N
Sequence<0, 0, 0>{}), // 2+: K1
make_tuple(Sequence<0, 0, 0>{}, // 0-: K0
Sequence<0, 0, 0>{}, // 1-: N
Sequence<0, 0, 0>{})); // 2-: K1
static constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2
static constexpr auto a_k0_m_k1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0>{};
static constexpr auto b_k0_n_k1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0>{};
// GridwiseGemm
using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3<
BlockSize,
......@@ -211,20 +172,10 @@ struct DeviceGemmXdl
Sequence<0, 2, 4, 5, 6, 1, 3, 7>, // CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
decltype(a_k0_m_k1_grid_step_hacks), // AGridStepHacks,
decltype(b_k0_n_k1_grid_step_hacks), // BGridStepHacks,
decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks), // CGridStepHacks,
decltype(a_k0_m_k1_grid_move_slice_window_step_hacks), // AGridMoveSliceWindowStepHacks,
decltype(b_k0_n_k1_grid_move_slice_window_step_hacks), // BGridMoveSliceWindowStepHacks,
false, // CAccessOrderMRepeatNRepeat,
ABlockLdsAddExtraM,
BBlockLdsAddExtraN>;
using CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 =
decltype(GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(CGridDesc_M_N{}));
using Block2CTileMap = decltype(GridwiseGemm::MakeBlock2CTileMap(CGridDesc_M_N{}, 1, 1));
// Argument
struct Argument : public BaseArgument
{
......@@ -277,8 +228,9 @@ struct DeviceGemmXdl
AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_;
BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_;
CGridDesc_M_N c_grid_desc_m_n_;
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_;
Block2CTileMap block_2_ctile_map_;
typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_;
typename GridwiseGemm::Block2CTileMap block_2_ctile_map_;
index_t M01_;
index_t N01_;
AElementwiseOperation a_element_op_;
......@@ -332,11 +284,11 @@ struct DeviceGemmXdl
CDataType,
remove_reference_t<DeviceGemmXdl::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceGemmXdl::BGridDesc_K0_N_K1>,
remove_reference_t<DeviceGemmXdl::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
remove_reference_t<typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
remove_reference_t<DeviceGemmXdl::Block2CTileMap>,
remove_reference_t<typename GridwiseGemm::Block2CTileMap>,
true>;
ave_time = launch_and_time_kernel(kernel,
......@@ -363,11 +315,11 @@ struct DeviceGemmXdl
CDataType,
remove_reference_t<DeviceGemmXdl::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceGemmXdl::BGridDesc_K0_N_K1>,
remove_reference_t<DeviceGemmXdl::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
remove_reference_t<typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
remove_reference_t<DeviceGemmXdl::Block2CTileMap>,
remove_reference_t<typename GridwiseGemm::Block2CTileMap>,
false>;
ave_time = launch_and_time_kernel(kernel,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment