Commit 619661f8 authored by Chao Liu's avatar Chao Liu
Browse files

clean up

parent b3ab0e12
...@@ -15,6 +15,7 @@ namespace ck { ...@@ -15,6 +15,7 @@ namespace ck {
// 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate // 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate
template <index_t BlockSize, template <index_t BlockSize,
typename SrcElementwiseOperation, typename SrcElementwiseOperation,
typename DstElementwiseOperation,
InMemoryDataOperationEnum_t DstInMemOp, InMemoryDataOperationEnum_t DstInMemOp,
typename BlockSliceLengths, typename BlockSliceLengths,
typename ThreadSliceLengths, typename ThreadSliceLengths,
...@@ -43,14 +44,16 @@ struct BlockwiseTensorSliceTransfer_v4 ...@@ -43,14 +44,16 @@ struct BlockwiseTensorSliceTransfer_v4
__device__ constexpr BlockwiseTensorSliceTransfer_v4( __device__ constexpr BlockwiseTensorSliceTransfer_v4(
const SrcDesc& src_desc, const SrcDesc& src_desc,
const Index& src_block_slice_origin, const Index& src_block_slice_origin,
const SrcElementwiseOperation& src_element_op,
const DstDesc& dst_desc, const DstDesc& dst_desc,
const Index& dst_block_slice_origin, const Index& dst_block_slice_origin,
const SrcElementwiseOperation& src_element_op) const DstElementwiseOperation& dst_element_op)
: threadwise_transfer_(src_desc, : threadwise_transfer_(src_desc,
make_zero_multi_index<nDim>(), make_zero_multi_index<nDim>(),
src_element_op,
dst_desc, dst_desc,
make_zero_multi_index<nDim>(), make_zero_multi_index<nDim>(),
src_element_op) dst_element_op)
{ {
static_assert(nDim == remove_reference_t<remove_cv_t<SrcDesc>>::GetNumOfDimension() && static_assert(nDim == remove_reference_t<remove_cv_t<SrcDesc>>::GetNumOfDimension() &&
...@@ -164,6 +167,7 @@ struct BlockwiseTensorSliceTransfer_v4 ...@@ -164,6 +167,7 @@ struct BlockwiseTensorSliceTransfer_v4
using ThreadwiseTransfer = using ThreadwiseTransfer =
ThreadwiseTensorSliceTransfer_v3r2<ThreadSliceLengths, ThreadwiseTensorSliceTransfer_v3r2<ThreadSliceLengths,
SrcElementwiseOperation, SrcElementwiseOperation,
DstElementwiseOperation,
DstInMemOp, DstInMemOp,
SrcData, SrcData,
DstData, DstData,
......
...@@ -355,11 +355,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1 ...@@ -355,11 +355,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
// A matrix blockwise copy // A matrix blockwise copy
auto a_blockwise_copy = auto a_blockwise_copy =
BlockwiseTensorSliceTransfer_v4<BlockSize, BlockwiseTensorSliceTransfer_v4<BlockSize,
#if 0
AElementwiseOperation, AElementwiseOperation,
#else
ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough,
#endif
InMemoryDataOperationEnum_t::Set, InMemoryDataOperationEnum_t::Set,
Sequence<K0PerBlock, MPerBlock, K1>, Sequence<K0PerBlock, MPerBlock, K1>,
ABlockTransferThreadSliceLengths_K0_M_K1, ABlockTransferThreadSliceLengths_K0_M_K1,
...@@ -378,20 +375,19 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1 ...@@ -378,20 +375,19 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
1, 1,
1, 1,
AThreadTransferSrcResetCoordinateAfterRun, AThreadTransferSrcResetCoordinateAfterRun,
true>(a_grid_desc_k0_m_k1, true>(
a_grid_desc_k0_m_k1,
make_multi_index(0, m_block_data_idx_on_grid, 0), make_multi_index(0, m_block_data_idx_on_grid, 0),
a_element_op,
a_block_desc_k0_m_k1, a_block_desc_k0_m_k1,
make_multi_index(0, 0, 0), make_multi_index(0, 0, 0),
a_element_op); ck::tensor_operation::element_wise::PassThrough{});
// B matrix blockwise copy // B matrix blockwise copy
auto b_blockwise_copy = auto b_blockwise_copy =
BlockwiseTensorSliceTransfer_v4<BlockSize, BlockwiseTensorSliceTransfer_v4<BlockSize,
#if 0
BElementwiseOperation, BElementwiseOperation,
#else
ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough,
#endif
InMemoryDataOperationEnum_t::Set, InMemoryDataOperationEnum_t::Set,
Sequence<K0PerBlock, NPerBlock, K1>, Sequence<K0PerBlock, NPerBlock, K1>,
BBlockTransferThreadSliceLengths_K0_N_K1, BBlockTransferThreadSliceLengths_K0_N_K1,
...@@ -410,11 +406,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1 ...@@ -410,11 +406,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
1, 1,
1, 1,
BThreadTransferSrcResetCoordinateAfterRun, BThreadTransferSrcResetCoordinateAfterRun,
true>(b_grid_desc_k0_n_k1, true>(
b_grid_desc_k0_n_k1,
make_multi_index(0, n_block_data_idx_on_grid, 0), make_multi_index(0, n_block_data_idx_on_grid, 0),
b_element_op,
b_block_desc_k0_n_k1, b_block_desc_k0_n_k1,
make_multi_index(0, 0, 0), make_multi_index(0, 0, 0),
b_element_op); ck::tensor_operation::element_wise::PassThrough{});
// GEMM definition // GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx // c_mtx += transpose(a_mtx) * b_mtx
...@@ -461,13 +459,14 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1 ...@@ -461,13 +459,14 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
b_blockwise_copy.RunWrite(b_block_desc_k0_n_k1, b_block_buf); b_blockwise_copy.RunWrite(b_block_desc_k0_n_k1, b_block_buf);
} }
// main body // clear C
index_t k0_block_data_begin = 0;
c_thread_buf.Clear(); c_thread_buf.Clear();
// main body
if constexpr(HasMainKBlockLoop) if constexpr(HasMainKBlockLoop)
{ {
index_t k0_block_data_begin = 0;
do do
{ {
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m_k1, a_block_slice_copy_step); a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m_k1, a_block_slice_copy_step);
...@@ -658,6 +657,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1 ...@@ -658,6 +657,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
auto c_block_copy_lds_to_global = BlockwiseTensorSliceTransfer_v4< auto c_block_copy_lds_to_global = BlockwiseTensorSliceTransfer_v4<
BlockSize, // index_t BlockSize, BlockSize, // index_t BlockSize,
ck::tensor_operation::element_wise::PassThrough, // SrcElementwiseOperation, ck::tensor_operation::element_wise::PassThrough, // SrcElementwiseOperation,
CElementwiseOperation, // DstElementwiseOperation,
CGlobalMemoryDataOperation, // DstInMemOp, CGlobalMemoryDataOperation, // DstInMemOp,
Sequence<1, Sequence<1,
MRepeatPerShuffle_CCopy, MRepeatPerShuffle_CCopy,
...@@ -694,9 +694,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1 ...@@ -694,9 +694,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
false> // bool ThreadTransferDstResetCoordinateAfterRun> false> // bool ThreadTransferDstResetCoordinateAfterRun>
{c_block_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl, {c_block_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl,
make_multi_index(0, 0, 0, 0, 0, 0), make_multi_index(0, 0, 0, 0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{},
c_grid_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl, c_grid_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl,
make_multi_index(block_work_idx[I0], 0, 0, block_work_idx[I1], 0, 0), make_multi_index(block_work_idx[I0], 0, 0, block_work_idx[I1], 0, 0),
ck::tensor_operation::element_wise::PassThrough{}}; c_element_op};
constexpr auto mrepeat_forward_step = constexpr auto mrepeat_forward_step =
make_multi_index(0, MRepeatPerShuffle_CCopy, 0, 0, 0, 0); make_multi_index(0, MRepeatPerShuffle_CCopy, 0, 0, 0, 0);
......
...@@ -47,6 +47,7 @@ struct lambda_scalar_per_access_for_src_and_dst ...@@ -47,6 +47,7 @@ struct lambda_scalar_per_access_for_src_and_dst
// 4. Use thread buffer // 4. Use thread buffer
template <typename SliceLengths, template <typename SliceLengths,
typename SrcElementwiseOperation, typename SrcElementwiseOperation,
typename DstElementwiseOperation,
InMemoryDataOperationEnum_t DstInMemOp, InMemoryDataOperationEnum_t DstInMemOp,
typename SrcData, typename SrcData,
typename DstData, typename DstData,
...@@ -80,12 +81,14 @@ struct ThreadwiseTensorSliceTransfer_v3r2 ...@@ -80,12 +81,14 @@ struct ThreadwiseTensorSliceTransfer_v3r2
__device__ constexpr ThreadwiseTensorSliceTransfer_v3r2( __device__ constexpr ThreadwiseTensorSliceTransfer_v3r2(
const SrcDesc& src_desc, const SrcDesc& src_desc,
const Index& src_slice_origin, const Index& src_slice_origin,
const SrcElementwiseOperation& src_element_op,
const DstDesc& dst_desc, const DstDesc& dst_desc,
const Index& dst_slice_origin, const Index& dst_slice_origin,
const SrcElementwiseOperation& src_element_op) const DstElementwiseOperation& dst_element_op)
: src_coord_(make_tensor_coordinate(src_desc, src_slice_origin)), : src_coord_(make_tensor_coordinate(src_desc, src_slice_origin)),
dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin)), dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin)),
src_element_op_(src_element_op) src_element_op_(src_element_op),
dst_element_op_(dst_element_op)
{ {
} }
...@@ -816,6 +819,7 @@ struct ThreadwiseTensorSliceTransfer_v3r2 ...@@ -816,6 +819,7 @@ struct ThreadwiseTensorSliceTransfer_v3r2
SrcCoord src_coord_; SrcCoord src_coord_;
DstCoord dst_coord_; DstCoord dst_coord_;
const SrcElementwiseOperation src_element_op_; const SrcElementwiseOperation src_element_op_;
const DstElementwiseOperation dst_element_op_;
}; };
} // namespace ck } // namespace ck
......
...@@ -592,7 +592,7 @@ struct DeviceConv2dFwdXdl_Output_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N ...@@ -592,7 +592,7 @@ struct DeviceConv2dFwdXdl_Output_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N
return str.str(); return str.str();
} }
}; // namespace device };
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment