Commit 619661f8 authored by Chao Liu
Browse files

clean up

parent b3ab0e12
......@@ -15,6 +15,7 @@ namespace ck {
// 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate
template <index_t BlockSize,
typename SrcElementwiseOperation,
typename DstElementwiseOperation,
InMemoryDataOperationEnum_t DstInMemOp,
typename BlockSliceLengths,
typename ThreadSliceLengths,
......@@ -43,14 +44,16 @@ struct BlockwiseTensorSliceTransfer_v4
__device__ constexpr BlockwiseTensorSliceTransfer_v4(
const SrcDesc& src_desc,
const Index& src_block_slice_origin,
const SrcElementwiseOperation& src_element_op,
const DstDesc& dst_desc,
const Index& dst_block_slice_origin,
const SrcElementwiseOperation& src_element_op)
const DstElementwiseOperation& dst_element_op)
: threadwise_transfer_(src_desc,
make_zero_multi_index<nDim>(),
src_element_op,
dst_desc,
make_zero_multi_index<nDim>(),
src_element_op)
dst_element_op)
{
static_assert(nDim == remove_reference_t<remove_cv_t<SrcDesc>>::GetNumOfDimension() &&
......@@ -164,6 +167,7 @@ struct BlockwiseTensorSliceTransfer_v4
using ThreadwiseTransfer =
ThreadwiseTensorSliceTransfer_v3r2<ThreadSliceLengths,
SrcElementwiseOperation,
DstElementwiseOperation,
DstInMemOp,
SrcData,
DstData,
......
......@@ -355,11 +355,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
// A matrix blockwise copy
auto a_blockwise_copy =
BlockwiseTensorSliceTransfer_v4<BlockSize,
#if 0
AElementwiseOperation,
#else
ck::tensor_operation::element_wise::PassThrough,
#endif
InMemoryDataOperationEnum_t::Set,
Sequence<K0PerBlock, MPerBlock, K1>,
ABlockTransferThreadSliceLengths_K0_M_K1,
......@@ -378,20 +375,19 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
1,
1,
AThreadTransferSrcResetCoordinateAfterRun,
true>(a_grid_desc_k0_m_k1,
make_multi_index(0, m_block_data_idx_on_grid, 0),
a_block_desc_k0_m_k1,
make_multi_index(0, 0, 0),
a_element_op);
true>(
a_grid_desc_k0_m_k1,
make_multi_index(0, m_block_data_idx_on_grid, 0),
a_element_op,
a_block_desc_k0_m_k1,
make_multi_index(0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{});
// B matrix blockwise copy
auto b_blockwise_copy =
BlockwiseTensorSliceTransfer_v4<BlockSize,
#if 0
BElementwiseOperation,
#else
ck::tensor_operation::element_wise::PassThrough,
#endif
InMemoryDataOperationEnum_t::Set,
Sequence<K0PerBlock, NPerBlock, K1>,
BBlockTransferThreadSliceLengths_K0_N_K1,
......@@ -410,11 +406,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
1,
1,
BThreadTransferSrcResetCoordinateAfterRun,
true>(b_grid_desc_k0_n_k1,
make_multi_index(0, n_block_data_idx_on_grid, 0),
b_block_desc_k0_n_k1,
make_multi_index(0, 0, 0),
b_element_op);
true>(
b_grid_desc_k0_n_k1,
make_multi_index(0, n_block_data_idx_on_grid, 0),
b_element_op,
b_block_desc_k0_n_k1,
make_multi_index(0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{});
// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
......@@ -461,13 +459,14 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
b_blockwise_copy.RunWrite(b_block_desc_k0_n_k1, b_block_buf);
}
// main body
index_t k0_block_data_begin = 0;
// clear C
c_thread_buf.Clear();
// main body
if constexpr(HasMainKBlockLoop)
{
index_t k0_block_data_begin = 0;
do
{
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m_k1, a_block_slice_copy_step);
......@@ -658,6 +657,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
auto c_block_copy_lds_to_global = BlockwiseTensorSliceTransfer_v4<
BlockSize, // index_t BlockSize,
ck::tensor_operation::element_wise::PassThrough, // SrcElementwiseOperation,
CElementwiseOperation, // DstElementwiseOperation,
CGlobalMemoryDataOperation, // DstInMemOp,
Sequence<1,
MRepeatPerShuffle_CCopy,
......@@ -694,9 +694,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
false> // bool ThreadTransferDstResetCoordinateAfterRun>
{c_block_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl,
make_multi_index(0, 0, 0, 0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{},
c_grid_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl,
make_multi_index(block_work_idx[I0], 0, 0, block_work_idx[I1], 0, 0),
ck::tensor_operation::element_wise::PassThrough{}};
c_element_op};
constexpr auto mrepeat_forward_step =
make_multi_index(0, MRepeatPerShuffle_CCopy, 0, 0, 0, 0);
......
......@@ -47,6 +47,7 @@ struct lambda_scalar_per_access_for_src_and_dst
// 4. Use thread buffer
template <typename SliceLengths,
typename SrcElementwiseOperation,
typename DstElementwiseOperation,
InMemoryDataOperationEnum_t DstInMemOp,
typename SrcData,
typename DstData,
......@@ -80,12 +81,14 @@ struct ThreadwiseTensorSliceTransfer_v3r2
__device__ constexpr ThreadwiseTensorSliceTransfer_v3r2(
const SrcDesc& src_desc,
const Index& src_slice_origin,
const SrcElementwiseOperation& src_element_op,
const DstDesc& dst_desc,
const Index& dst_slice_origin,
const SrcElementwiseOperation& src_element_op)
const DstElementwiseOperation& dst_element_op)
: src_coord_(make_tensor_coordinate(src_desc, src_slice_origin)),
dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin)),
src_element_op_(src_element_op)
src_element_op_(src_element_op),
dst_element_op_(dst_element_op)
{
}
......@@ -816,6 +819,7 @@ struct ThreadwiseTensorSliceTransfer_v3r2
SrcCoord src_coord_;
DstCoord dst_coord_;
const SrcElementwiseOperation src_element_op_;
const DstElementwiseOperation dst_element_op_;
};
} // namespace ck
......
......@@ -592,7 +592,7 @@ struct DeviceConv2dFwdXdl_Output_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N
return str.str();
}
}; // namespace device
};
} // namespace device
} // namespace tensor_operation
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment