Commit 5d3b9b73 authored by Jing Zhang's avatar Jing Zhang
Browse files

clean code

parent c85e0652
...@@ -146,7 +146,9 @@ struct BlockwiseGenericTensorSliceCopy_v5 ...@@ -146,7 +146,9 @@ struct BlockwiseGenericTensorSliceCopy_v5
BlockDstDesc, BlockDstDesc,
ThreadSliceLengths, ThreadSliceLengths,
SrcDimAccessOrder, SrcDimAccessOrder,
DstDimAccessOrder,
SrcVectoReadDim, SrcVectoReadDim,
DstVectorWriteDim,
SrcDataPerRead, SrcDataPerRead,
DstDataPerWrite, DstDataPerWrite,
SrcAddressSpace, SrcAddressSpace,
......
...@@ -17,8 +17,10 @@ namespace ck { ...@@ -17,8 +17,10 @@ namespace ck {
template <typename SrcDesc, template <typename SrcDesc,
typename DstDesc, typename DstDesc,
typename SliceLengths, typename SliceLengths,
typename SrcDstDimAccessOrder, typename SrcDimAccessOrder,
index_t SrcDstVectorReadWriteDim, typename DstDimAccessOrder,
index_t SrcVectorReadDim,
index_t DstVectorWriteDim,
index_t SrcDataPerRead, index_t SrcDataPerRead,
index_t DstDataPerWrite, index_t DstDataPerWrite,
AddressSpace SrcAddressSpace = AddressSpace::Generic, AddressSpace SrcAddressSpace = AddressSpace::Generic,
...@@ -44,15 +46,19 @@ struct ThreadwiseGenericTensorSliceCopy_v5 ...@@ -44,15 +46,19 @@ struct ThreadwiseGenericTensorSliceCopy_v5
{ {
static_assert(nDim == SrcDesc::GetNumOfDimension() && static_assert(nDim == SrcDesc::GetNumOfDimension() &&
nDim == DstDesc::GetNumOfDimension() && nDim == SliceLengths::Size() && nDim == DstDesc::GetNumOfDimension() && nDim == SliceLengths::Size() &&
nDim == SrcDstDimAccessOrder::Size(), nDim == SrcDimAccessOrder::Size() && nDim == DstDimAccessOrder::Size(),
"wrong! # of dimensions not the same"); "wrong! # of dimensions not the same");
static_assert(is_valid_sequence_map<SrcDstDimAccessOrder>{}, "wrong! map is not valid"); static_assert(is_valid_sequence_map<SrcDimAccessOrder>{}, "wrong! map is not valid");
static_assert(is_valid_sequence_map<DstDimAccessOrder>{}, "wrong! map is not valid");
static_assert(SliceLengths{}[SrcDstVectorReadWriteDim] % static_assert(
math::lcm(SrcDataPerRead, DstDataPerWrite) == SliceLengths{}[SrcVectorReadDim] % math::lcm(SrcDataPerRead, DstDataPerWrite) == 0,
0, "wrong! cannot evenly divide");
"wrong! cannot evenly divide");
static_assert(
SliceLengths{}[DstVectorWriteDim] % math::lcm(SrcDataPerRead, DstDataPerWrite) == 0,
"wrong! cannot evenly divide");
static_assert(ThreadBufferSize == 4, ""); static_assert(ThreadBufferSize == 4, "");
} }
...@@ -117,19 +123,18 @@ struct ThreadwiseGenericTensorSliceCopy_v5 ...@@ -117,19 +123,18 @@ struct ThreadwiseGenericTensorSliceCopy_v5
template <typename SrcData> template <typename SrcData>
__device__ void Load(const SrcData* p_src) __device__ void Load(const SrcData* p_src)
{ {
constexpr auto vector_access_dim = Number<SrcDstVectorReadWriteDim>{}; constexpr auto vector_access_dim = Number<SrcVectorReadDim>{};
constexpr auto src_data_per_access = Number<SrcDataPerRead>{}; constexpr auto src_data_per_access = Number<SrcDataPerRead>{};
constexpr auto dst_data_per_access = Number<DstDataPerWrite>{};
static_assert(SrcDataPerRead == 1 && DstDataPerWrite == 1, ""); static_assert(SrcDataPerRead == 1, "");
constexpr auto long_vector_size = Number<math::lcm(SrcDataPerRead, DstDataPerWrite)>{}; constexpr auto long_vector_size = src_data_per_access;
constexpr auto long_vector_access_lengths = SliceLengths::Modify( constexpr auto long_vector_access_lengths = SliceLengths::Modify(
vector_access_dim, SliceLengths::Get(vector_access_dim) / long_vector_size); vector_access_dim, SliceLengths::Get(vector_access_dim) / long_vector_size);
static_ford<decltype(long_vector_access_lengths), SrcDstDimAccessOrder>{}( static_ford<decltype(long_vector_access_lengths), SrcDimAccessOrder>{}(
[&](auto long_vector_access_id) { [&](auto long_vector_access_id) {
constexpr auto long_vector_data_begin_id = long_vector_access_id.Modify( constexpr auto long_vector_data_begin_id = long_vector_access_id.Modify(
Number<vector_access_dim>{}, Number<vector_access_dim>{},
...@@ -152,19 +157,18 @@ struct ThreadwiseGenericTensorSliceCopy_v5 ...@@ -152,19 +157,18 @@ struct ThreadwiseGenericTensorSliceCopy_v5
template <typename DstData> template <typename DstData>
__device__ void Store(DstData* p_dst) __device__ void Store(DstData* p_dst)
{ {
constexpr auto vector_access_dim = Number<SrcDstVectorReadWriteDim>{}; constexpr auto vector_access_dim = Number<DstVectorWriteDim>{};
constexpr auto src_data_per_access = Number<SrcDataPerRead>{};
constexpr auto dst_data_per_access = Number<DstDataPerWrite>{}; constexpr auto dst_data_per_access = Number<DstDataPerWrite>{};
static_assert(SrcDataPerRead == 1 && DstDataPerWrite == 1, ""); static_assert(DstDataPerWrite == 1, "");
constexpr auto long_vector_size = Number<math::lcm(SrcDataPerRead, DstDataPerWrite)>{}; constexpr auto long_vector_size = dst_data_per_access;
constexpr auto long_vector_access_lengths = SliceLengths::Modify( constexpr auto long_vector_access_lengths = SliceLengths::Modify(
vector_access_dim, SliceLengths::Get(vector_access_dim) / long_vector_size); vector_access_dim, SliceLengths::Get(vector_access_dim) / long_vector_size);
static_ford<decltype(long_vector_access_lengths), SrcDstDimAccessOrder>{}( static_ford<decltype(long_vector_access_lengths), DstDimAccessOrder>{}(
[&](auto long_vector_access_id) { [&](auto long_vector_access_id) {
constexpr auto long_vector_data_begin_id = long_vector_access_id.Modify( constexpr auto long_vector_data_begin_id = long_vector_access_id.Modify(
Number<vector_access_dim>{}, Number<vector_access_dim>{},
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment