Commit e17b495d authored by Chao Liu's avatar Chao Liu
Browse files

refactor

parent 8fcf3f1e
......@@ -14,7 +14,7 @@ template <index_t BlockSize,
class DstAccessOrder,
index_t SrcDataPerRead,
index_t DstDataPerRead>
struct BlockwiseTensorSliceCopy_generic_v1
struct BlockwiseGenericTensorSliceCopy_v1
{
static constexpr index_t nDim = SrcDesc::GetNumOfDimension();
......@@ -22,7 +22,7 @@ struct BlockwiseTensorSliceCopy_generic_v1
index_t mDstMyThreadOffset;
__device__
BlockwiseTensorSliceCopy_generic_v1(Array<index_t, nDim> src_block_data_multi_id_begin,
BlockwiseGenericTensorSliceCopy_v1(Array<index_t, nDim> src_block_data_multi_id_begin,
Array<index_t, nDim> dst_block_data_multi_id_begin)
{
// check NDim consistent
......@@ -155,7 +155,7 @@ struct BlockwiseTensorSliceCopy_generic_v1
const index_t clipboard_offset = thread_tensor_desc.GetOffsetFromMultiIndex(
clipboard_data_multi_id_begin); // cannot not constexpr, why?
threadwise_tensor_slice_copy_generic(SrcDesc{},
threadwise_generic_tensor_slice_copy(SrcDesc{},
p_src + src_offset + mSrcMyThreadOffset,
make_zero_array<index_t, nDim>(),
thread_tensor_desc,
......@@ -193,7 +193,7 @@ struct BlockwiseTensorSliceCopy_generic_v1
const index_t dst_offset = DstDesc{}.GetOffsetFromMultiIndex(
dst_data_multi_id_begin); // cannot not constexpr, why?
threadwise_tensor_slice_copy_generic(thread_tensor_desc,
threadwise_generic_tensor_slice_copy(thread_tensor_desc,
p_clipboard + clipboard_offset,
make_zero_array<index_t, nDim>(),
DstDesc{},
......
......@@ -474,7 +474,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_nchw_cyxk_nkhw
map_out_global2thread,
Number<OutThreadCopyDataPerWrite_W>{});
#else
threadwise_tensor_slice_copy_generic(
threadwise_generic_tensor_slice_copy(
out_10d_thread_desc.ReorderGivenNew2Old(map_out_global2thread),
p_out_thread,
make_zero_array<index_t, 10>(),
......
......@@ -423,7 +423,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw
map_out_global2thread,
Number<OutThreadCopyDataPerWrite_W>{});
#else
threadwise_tensor_slice_copy_generic(
threadwise_generic_tensor_slice_copy(
out_10d_thread_desc.ReorderGivenNew2Old(map_out_global2thread),
p_out_thread,
make_zero_array<index_t, 10>(),
......
......@@ -3,7 +3,7 @@
#include "ConstantTensorDescriptor.hip.hpp"
#include "ConstantMergedTensorDescriptor.hip.hpp"
#include "ConstantMatrixDescriptor.hip.hpp"
#include "blockwise_merged_tensor_slice_op.hip.hpp"
#include "blockwise_generic_tensor_slice_op.hip.hpp"
#include "blockwise_gemm.hip.hpp"
#include "threadwise_tensor_slice_op.hip.hpp"
......@@ -123,7 +123,7 @@ struct GridwiseConvolutionImplicitGemm_v3_nchw_cyxk_nkhw
// input blockwise copy
// slice a merged tensor, reorder and copy to a normal tensor
// this copy operator already has blockwise offset built-in
const auto blockwise_in_copy = BlockwiseTensorSliceCopy_generic_v1<
const auto blockwise_in_copy = BlockwiseGenericTensorSliceCopy_v1<
BlockSize,
Float,
decltype(in_c_n1_b_n2_global_merged_desc),
......@@ -152,7 +152,7 @@ struct GridwiseConvolutionImplicitGemm_v3_nchw_cyxk_nkhw
// this copy operator already have blockwise offset built-in
const auto blockwise_wei_copy =
#if 0
BlockwiseTensorSliceCopy_generic_v1<BlockSize,
BlockwiseGenericTensorSliceCopy_v1<BlockSize,
Float,
decltype(wei_c_k_global_desc),
decltype(wei_c_k_block_desc),
......@@ -318,7 +318,7 @@ struct GridwiseConvolutionImplicitGemm_v3_nchw_cyxk_nkhw
out_k_n1_b_n2_global_merged_desc.GetOffsetFromMultiIndex(
k_thread_data_on_global, 0, b_thread_data_on_global, 0);
threadwise_tensor_slice_copy_generic(out_n0_n1_n2_k0_k1_k2_h_w_thread_desc,
threadwise_generic_tensor_slice_copy(out_n0_n1_n2_k0_k1_k2_h_w_thread_desc,
p_out_thread,
{0, 0, 0, 0, 0, 0, 0, 0},
out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc,
......
......@@ -194,7 +194,7 @@ threadwise_tensor_slice_copy_reorder_given_dst2src_v3(SrcDesc,
}
template <class Float, class SrcDesc, class DstDesc, class SliceLengths, class DimAccessOrder>
__device__ void threadwise_tensor_slice_copy_generic(
__device__ void threadwise_generic_tensor_slice_copy(
SrcDesc,
const Float* __restrict__ p_src,
Array<index_t, SrcDesc::GetNumOfDimension()> src_multi_id_begin,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment