Commit 4978c9e7 authored by Chao Liu

replacing array with vector for tensor data

parent e38c1b73
@@ -358,5 +358,322 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1
}
};
// blockwise GEMM: C[M, N] += transpose(A[K, M]) * B[K, N]
// A and B are visible to the whole block, C is distributed among the threads
// If the following numbers are powers of 2, index calculation cost is greatly reduced:
// MPerThreadSubC, NPerThreadSubC, MLevel0ThreadCluster, NLevel0ThreadCluster,
// MLevel1ThreadCluster, NLevel1ThreadCluster
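// Example (hypothetical numbers): with MPerThreadSubC = NPerThreadSubC = 4,
// MLevel0ThreadCluster = NLevel0ThreadCluster = 2 and
// MLevel1ThreadCluster = NLevel1ThreadCluster = 4, one level-1 cluster has
// 2 * 2 * 4 * 4 = 64 threads (the required BlockSize) and covers a
// (4*2*4) x (4*2*4) = 32 x 32 tile of C per repeat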
template <index_t BlockSize,
typename BlockMatrixA,
typename BlockMatrixB,
typename ThreadMatrixC,
index_t MPerThreadSubC,
index_t NPerThreadSubC,
index_t KPerThreadLoop,
index_t MLevel0ThreadCluster,
index_t NLevel0ThreadCluster,
index_t MLevel1ThreadCluster,
index_t NLevel1ThreadCluster,
index_t ThreadGemmADataPerRead_M,
index_t ThreadGemmBDataPerRead_N>
struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1
{
struct MatrixIndex
{
index_t row;
index_t col;
};
index_t mMyThreadOffsetA;
index_t mMyThreadOffsetB;
__device__ BlockwiseGemm_km_kn_m0m1n0n1_v1r1()
{
static_assert(BlockMatrixA::IsKnownAtCompileTime() &&
BlockMatrixB::IsKnownAtCompileTime() &&
ThreadMatrixC::IsKnownAtCompileTime(),
"wrong! Desc should be known at compile-time");
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr index_t ThreadPerLevel1Cluster = MLevel0ThreadCluster * NLevel0ThreadCluster *
MLevel1ThreadCluster * NLevel1ThreadCluster;
static_assert(BlockSize == ThreadPerLevel1Cluster, "wrong! BlockSize does not match the thread cluster size\n");
static_assert(BlockMatrixA{}.GetLength(I0) == BlockMatrixB{}.GetLength(I0),
"wrong! K dimension not consistent\n");
constexpr index_t M = BlockMatrixA{}.GetLength(I1); // A is transposed
constexpr index_t N = BlockMatrixB{}.GetLength(I1);
static_assert(M % (MPerThreadSubC * MLevel0ThreadCluster * MLevel1ThreadCluster) == 0 &&
N % (NPerThreadSubC * NLevel0ThreadCluster * NLevel1ThreadCluster) == 0,
"wrong! Cannot evenly divide work among\n");
static_assert(ThreadMatrixC{}.GetLength(I0) == GetThreadMatrixCLengths()[I0] &&
ThreadMatrixC{}.GetLength(I1) == GetThreadMatrixCLengths()[I1],
"wrong! ThreadMatrixC lengths is wrong");
auto c_thread_mtx_index = GetBeginOfThreadMatrixC(get_thread_local_1d_id());
mMyThreadOffsetA = BlockMatrixA{}.CalculateOffset(make_tuple(0, c_thread_mtx_index.row));
mMyThreadOffsetB = BlockMatrixB{}.CalculateOffset(make_tuple(0, c_thread_mtx_index.col));
}
__device__ static constexpr auto GetThreadMatrixCLengths()
{
constexpr auto I1 = Number<1>{};
constexpr index_t M = BlockMatrixA{}.GetLength(I1); // A is transposed
constexpr index_t N = BlockMatrixB{}.GetLength(I1);
constexpr index_t MRepeat =
M / (MPerThreadSubC * MLevel0ThreadCluster * MLevel1ThreadCluster);
constexpr index_t NRepeat =
N / (NPerThreadSubC * NLevel0ThreadCluster * NLevel1ThreadCluster);
return Sequence<MRepeat * MPerThreadSubC, NRepeat * NPerThreadSubC>{};
}
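// Example (hypothetical numbers, continuing the example above): with M = N = 128,
// MRepeat = NRepeat = 128 / (4 * 2 * 4) = 4, so ThreadMatrixC is
// (4 * 4) x (4 * 4) = 16 x 16 per thread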
__device__ static MatrixIndex GetBeginOfThreadMatrixC(index_t thread_id)
{
constexpr index_t ThreadPerLevel0Cluster = MLevel0ThreadCluster * NLevel0ThreadCluster;
index_t level1_id = thread_id / ThreadPerLevel0Cluster;
index_t level1_m_id = level1_id / NLevel1ThreadCluster;
index_t level1_n_id = level1_id % NLevel1ThreadCluster;
index_t level0_id = thread_id % ThreadPerLevel0Cluster;
index_t level0_m_id = level0_id / NLevel0ThreadCluster;
index_t level0_n_id = level0_id % NLevel0ThreadCluster;
constexpr index_t MPerLevel0Cluster = MPerThreadSubC * MLevel0ThreadCluster;
constexpr index_t NPerLevel0Cluster = NPerThreadSubC * NLevel0ThreadCluster;
return MatrixIndex{level1_m_id * MPerLevel0Cluster + level0_m_id * MPerThreadSubC,
level1_n_id * NPerLevel0Cluster + level0_n_id * NPerThreadSubC};
}
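// Worked example for GetBeginOfThreadMatrixC (hypothetical numbers): with
// MPerThreadSubC = NPerThreadSubC = 4, MLevel0ThreadCluster = NLevel0ThreadCluster = 2
// and NLevel1ThreadCluster = 4, thread_id 13 gives level1_id = 13 / 4 = 3,
// level1_m_id = 0, level1_n_id = 3, level0_id = 13 % 4 = 1, level0_m_id = 0,
// level0_n_id = 1, i.e. MatrixIndex{0 * 8 + 0 * 4, 3 * 8 + 1 * 4} = {0, 28}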
template <typename FloatA, typename FloatB, typename FloatC>
__device__ void
Run_pipelined_2x2(const FloatA* p_a_block, const FloatB* p_b_block, FloatC* p_c_thread) const
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto a_block_mtx = BlockMatrixA{};
constexpr auto b_block_mtx = BlockMatrixB{};
constexpr auto c_thread_mtx = ThreadMatrixC{};
constexpr auto K = a_block_mtx.GetLength(I0);
constexpr auto MPerThread = c_thread_mtx.GetLength(I0);
constexpr auto NPerThread = c_thread_mtx.GetLength(I1);
constexpr index_t MPerLevel1Cluster =
MPerThreadSubC * MLevel0ThreadCluster * MLevel1ThreadCluster;
constexpr index_t NPerLevel1Cluster =
NPerThreadSubC * NLevel0ThreadCluster * NLevel1ThreadCluster;
constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
constexpr index_t NRepeat = NPerThread / NPerThreadSubC;
static_assert(MRepeat == 2 && NRepeat == 2,
"wrong! inline asm cannot deal with this GEMM config yet");
// thread A, B
constexpr auto a_thread_mtx = make_dynamic_naive_tensor_descriptor_packed_v2(
make_tuple(Number<KPerThreadLoop>{}, Number<MPerThread>{}));
constexpr auto b_thread_mtx = make_dynamic_naive_tensor_descriptor_packed_v2(
make_tuple(Number<KPerThreadLoop>{}, Number<NPerThread>{}));
// thread A-sub, B-sub
constexpr auto a_thread_sub_mtx = make_dynamic_naive_tensor_descriptor_v2(
make_tuple(Number<KPerThreadLoop>{}, Number<MPerThreadSubC>{}),
make_tuple(Number<MPerThread>{}, Number<1>{}));
constexpr auto b_thread_sub_mtx = make_dynamic_naive_tensor_descriptor_v2(
make_tuple(Number<KPerThreadLoop>{}, Number<NPerThreadSubC>{}),
make_tuple(Number<NPerThread>{}, Number<1>{}));
constexpr auto c_thread_sub_mtx = make_dynamic_naive_tensor_descriptor_v2(
make_tuple(Number<MPerThreadSubC>{}, Number<NPerThreadSubC>{}),
make_tuple(Number<NPerThread>{}, Number<1>{}));
FloatA p_a_thread[a_thread_mtx.GetElementSpaceSize()];
FloatB p_b_thread[b_thread_mtx.GetElementSpaceSize()];
constexpr auto a_thread_copy =
ThreadwiseDynamicTensorSliceTransfer_v2<FloatA,
FloatA,
BlockMatrixA,
decltype(a_thread_mtx),
Sequence<KPerThreadLoop, MPerThreadSubC>,
Sequence<0, 1>,
1,
ThreadGemmADataPerRead_M,
AddressSpace::Generic,
AddressSpace::Vgpr,
1,
true>{BlockMatrixA{}, make_tuple()};
constexpr auto b_thread_copy =
ThreadwiseDynamicTensorSliceTransfer_v2<FloatB,
FloatB,
BlockMatrixB,
decltype(b_thread_mtx),
Sequence<KPerThreadLoop, NPerThreadSubC>,
Sequence<0, 1>,
1,
ThreadGemmBDataPerRead_N,
AddressSpace::Generic,
AddressSpace::Vgpr,
1,
true>{};
constexpr auto threadwise_gemm = ThreadwiseGemm_km_kn_mn_v1<decltype(a_thread_sub_mtx),
decltype(b_thread_sub_mtx),
decltype(c_thread_sub_mtx)>{};
const FloatA* p_a_block_off = p_a_block + mMyThreadOffsetA;
const FloatB* p_b_block_off = p_b_block + mMyThreadOffsetB;
// read A_sub_0
a_thread_copy.Run(BlockMatrixA{},
p_a_block_off,
a_thread_mtx,
make_tuple(Number<0>{}, Number<0>{}),
p_a_thread);
// read B_sub_0
b_thread_copy.Run(BlockMatrixB{},
p_b_block_off,
b_thread_mtx,
make_tuple(Number<0>{}, Number<0>{}),
p_b_thread);
// read B_sub_1
b_thread_copy.Run(BlockMatrixB{},
p_b_block_off +
b_block_mtx.CalculateOffset(make_tuple(0, NPerLevel1Cluster)),
b_thread_mtx,
make_tuple(Number<0>{}, Number<NPerThreadSubC>{}),
p_b_thread);
// read A_sub_1
a_thread_copy.Run(BlockMatrixA{},
p_a_block_off +
a_block_mtx.CalculateOffset(make_tuple(0, MPerLevel1Cluster)),
a_thread_mtx,
make_tuple(Number<0>{}, Number<MPerThreadSubC>{}),
p_a_thread);
// C_sub_00 += transpose(A_sub_0) * B_sub_0
threadwise_gemm.Run(p_a_thread, p_b_thread, p_c_thread);
// C_sub_01 += transpose(A_sub_0) * B_sub_1
threadwise_gemm.Run(
p_a_thread,
p_b_thread + b_thread_mtx.CalculateOffset(make_tuple(0, NPerThreadSubC)),
p_c_thread + c_thread_mtx.CalculateOffset(make_tuple(0, NPerThreadSubC)));
// loop over rest of k
static_for<KPerThreadLoop, K, KPerThreadLoop>{}([&](auto k) {
// read A_sub_0
a_thread_copy.Run(BlockMatrixA{},
p_a_block_off + a_block_mtx.CalculateOffset(make_tuple(k, 0)),
a_thread_mtx,
make_tuple(Number<0>{}, Number<0>{}),
p_a_thread);
// C_sub_10 += transpose(A_sub_1) * B_sub_0
threadwise_gemm.Run(
p_a_thread + a_thread_mtx.CalculateOffset(make_tuple(0, MPerThreadSubC)),
p_b_thread,
p_c_thread + c_thread_mtx.CalculateOffset(make_tuple(MPerThreadSubC, 0)));
// read B_sub_0
b_thread_copy.Run(BlockMatrixB{},
p_b_block_off + b_block_mtx.CalculateOffset(make_tuple(k, 0)),
b_thread_mtx,
make_tuple(Number<0>{}, Number<0>{}),
p_b_thread);
// C_sub_11 += transpose(A_sub_1) * B_sub_1
threadwise_gemm.Run(
p_a_thread + a_thread_mtx.CalculateOffset(make_tuple(0, MPerThreadSubC)),
p_b_thread + b_thread_mtx.CalculateOffset(make_tuple(0, NPerThreadSubC)),
p_c_thread +
c_thread_mtx.CalculateOffset(make_tuple(MPerThreadSubC, NPerThreadSubC)));
// read B_sub_1
b_thread_copy.Run(BlockMatrixB{},
p_b_block_off +
b_block_mtx.CalculateOffset(make_tuple(k, NPerLevel1Cluster)),
b_thread_mtx,
make_tuple(Number<0>{}, Number<NPerThreadSubC>{}),
p_b_thread);
// read A_sub_1
a_thread_copy.Run(BlockMatrixA{},
p_a_block_off +
a_block_mtx.CalculateOffset(make_tuple(k, MPerLevel1Cluster)),
a_thread_mtx,
make_tuple(Number<0>{}, Number<MPerThreadSubC>{}),
p_a_thread);
// C_sub_00 += transpose(A_sub_0) * B_sub_0
threadwise_gemm.Run(p_a_thread, p_b_thread, p_c_thread);
// C_sub_01 += transpose(A_sub_0) * B_sub_1
threadwise_gemm.Run(
p_a_thread,
p_b_thread + b_thread_mtx.CalculateOffset(make_tuple(0, NPerThreadSubC)),
p_c_thread + c_thread_mtx.CalculateOffset(make_tuple(0, NPerThreadSubC)));
});
// C_sub_10 += transpose(A_sub_1) * B_sub_0
threadwise_gemm.Run(
p_a_thread + a_thread_mtx.CalculateOffset(make_tuple(0, MPerThreadSubC)),
p_b_thread,
p_c_thread + c_thread_mtx.CalculateOffset(make_tuple(MPerThreadSubC, 0)));
// C_sub_11 += transpose(A_sub_1) * B_sub_1
threadwise_gemm.Run(
p_a_thread + a_thread_mtx.CalculateOffset(make_tuple(0, MPerThreadSubC)),
p_b_thread + b_thread_mtx.CalculateOffset(make_tuple(0, NPerThreadSubC)),
p_c_thread + c_thread_mtx.CalculateOffset(make_tuple(MPerThreadSubC, NPerThreadSubC)));
}
template <typename FloatA, typename FloatB, typename FloatC>
__device__ void Run(const FloatA* p_a_block, const FloatB* p_b_block, FloatC* p_c_thread) const
{
#if CK_EXPERIMENTAL_BLOCKWISE_GEMM_USE_PIPELINE
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr index_t MPerThread = ThreadMatrixC{}.GetLength(I0);
constexpr index_t NPerThread = ThreadMatrixC{}.GetLength(I1);
constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
constexpr index_t NRepeat = NPerThread / NPerThreadSubC;
if constexpr(MRepeat == 2 && NRepeat == 2)
{
Run_pipelined_2x2(p_a_block, p_b_block, p_c_thread);
}
else
{
Run_naive(p_a_block, p_b_block, p_c_thread);
}
#else
Run_naive(p_a_block, p_b_block, p_c_thread);
#endif
}
};
} // namespace ck
#endif
@@ -7,6 +7,15 @@
namespace ck {
// Do the following things to avoid "alloca" in LLVM-IR, which would cause scratch memory
// usage and sometimes useless instructions:
// 1. Don't save a reference to a tensor descriptor in the class; pass the tensor
//    descriptor in as an argument instead
// 2. Don't construct a new tensor coordinate every time it is used; update and reuse
//    the same tensor coordinate instead
// 3. Don't use a pointer to a VGPR buffer; use a vector instead
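// Illustration of rule 3 (hypothetical sketch, not code from this commit): instead of
//   float p_buf[4];          // raw array; passing its address to a helper
//   helper(p_buf);           // may force an "alloca" and scratch memory usage
// keep the data in a vector-typed register value:
//   typename vector_type_maker<float, 4>::type::type buf; // stays in VGPRs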
namespace detail {
// TODO: How to fix this? It uses a struct instead of a lambda because a lambda
// doesn't have a constructor
template <index_t VectorDim, index_t ScalarPerVector>
@@ -26,15 +35,13 @@ struct lambda_scalar_step_in_vector
return (i == VectorDim) ? 1 : 0;
}
};
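// Example (hypothetical values): with nDim = 4, VectorDim = 3 and ScalarPerVector = 4,
// generate_sequence(lambda_scalar_step_in_vector<3>{}, Number<4>{}) yields
// Sequence<0, 0, 0, 1>; lambda_scalar_per_access (defined above, presumably returning
// ScalarPerVector on the vector dim and 1 elsewhere) yields Sequence<1, 1, 1, 4>, so
// SliceLengths of Sequence<8, 1, 1, 8> give access lengths Sequence<8, 1, 1, 2>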
} // namespace detail
// Assume:
// 1. src_desc is known at compile-time
// 2. dst_desc is not known at compile-time
// 3. src_slice_origin_idx is known at compile-time and it's 0
// 4. dst_slice_origin_idx is not-known at compile time
// this version is less likely to have a scratch memory issue, because:
// 1. It does not keep a reference to the tensor descriptor
// 2. It does not construct a new tensor coordinate for this->Run()
// TODO: support non-zero src_slice_origin_idx
template <typename SrcData,
typename DstData,
@@ -98,10 +105,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3
// scalar per access on each dim
// TODO: don't use lambda_scalar_per_access
constexpr auto dst_scalar_per_access = generate_sequence(
lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});
detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});
constexpr auto dst_scalar_step_in_vector =
generate_sequence(lambda_scalar_step_in_vector<DstVectorDim>{}, Number<nDim>{});
generate_sequence(detail::lambda_scalar_step_in_vector<DstVectorDim>{}, Number<nDim>{});
constexpr auto access_lengths = SliceLengths{} / dst_scalar_per_access;
@@ -288,7 +295,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3
// scalar per access on each dim
// TODO: don't use lambda_scalar_per_access
constexpr auto dst_scalar_per_access = generate_sequence(
lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});
detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});
constexpr auto access_lengths = SliceLengths{} / dst_scalar_per_access;
@@ -368,9 +375,6 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3
// 2. dst_desc is known at compile-time
// 3. src_slice_origin_idx is not known at compile-time
// 4. dst_slice_origin_idx is known at compile-time and it's 0
// this version is less likely to have a scratch memory issue, because:
// 1. It does not keep a reference to the tensor descriptor
// 2. It does not construct a new tensor coordinate for this->Run()
template <typename SrcData,
typename DstData,
typename SrcDesc,
@@ -432,10 +436,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2
// scalar per access on each dim
// TODO: don't use lambda_scalar_per_access
constexpr auto src_scalar_per_access = generate_sequence(
lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});
detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});
constexpr auto src_scalar_step_in_vector =
generate_sequence(lambda_scalar_step_in_vector<SrcVectorDim>{}, Number<nDim>{});
generate_sequence(detail::lambda_scalar_step_in_vector<SrcVectorDim>{}, Number<nDim>{});
constexpr auto access_lengths = SliceLengths{} / src_scalar_per_access;
@@ -623,7 +627,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2
// scalar per access on each dim
// TODO: don't use lambda_scalar_per_access
constexpr auto src_scalar_per_access = generate_sequence(
lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});
detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});
constexpr auto access_lengths = SliceLengths{} / src_scalar_per_access;
@@ -702,12 +706,6 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2
// 1. src_desc and dst_desc are not known at compile-time
// 2. src_slice_origin and dst_slice_origin are not known at compile-time,
// 3. Use thread buffer
// this version does the following things to avoid "alloca" in LLVM-IR, which would cause
// scratch memory usage and sometimes useless instructions:
// 1. It does not keep a reference to the tensor descriptor
// 2. It does not construct a new tensor coordinate for this->Run()
// 3. It does not use a pointer to the VGPR thread buffer
// 4. It calculates offsets into the thread buffer directly, instead of moving a coordinate
template <typename SliceLengths,
InMemoryDataOperation DstInMemOp,
typename SrcData,
@@ -777,10 +775,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
// scalar per access on each dim
// TODO: don't use lambda_scalar_per_access
constexpr auto src_scalar_per_access = generate_sequence(
lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});
detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});
constexpr auto src_scalar_step_in_vector =
generate_sequence(lambda_scalar_step_in_vector<SrcVectorDim>{}, Number<nDim>{});
generate_sequence(detail::lambda_scalar_step_in_vector<SrcVectorDim>{}, Number<nDim>{});
constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;
@@ -955,10 +953,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
// src scalar per access on each dim
// TODO: don't use this
constexpr auto dst_scalar_per_access = generate_sequence(
lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});
detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});
constexpr auto dst_scalar_step_in_vector =
generate_sequence(lambda_scalar_step_in_vector<DstVectorDim>{}, Number<nDim>{});
generate_sequence(detail::lambda_scalar_step_in_vector<DstVectorDim>{}, Number<nDim>{});
constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access;
@@ -1142,7 +1140,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
// scalar per access on each dim
// TODO: don't use lambda_scalar_per_access
constexpr auto src_scalar_per_access = generate_sequence(
lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});
detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});
constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;
@@ -1204,7 +1202,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
// scalar per access on each dim
// TODO: don't use lambda_scalar_per_access
constexpr auto dst_scalar_per_access = generate_sequence(
lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});
detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});
constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access;
@@ -1293,7 +1291,6 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
move_dynamic_tensor_coordinate(src_desc, src_slice_origin_coord_, adjusted_step);
}
// dst_slice_origin_step_idx needs to be known at compile-time, for performance reasons
__device__ void MoveDstSliceWindow(const DstDesc& dst_desc,
const Index& dst_slice_origin_step_idx)
@@ -1322,5 +1319,102 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
DstCoord dst_slice_origin_coord_;
};
// Assume:
// 1. src:
//     1. src_desc is known at compile-time
//     2. a reference src_reference_idx is given at run-time; src_slice_origin_idx has a
//        compile-time distance to src_reference_idx
//     3. moves between accesses using a tensor coordinate iterator
// 2. dst:
//     1. dst_desc is known at compile-time
//     2. a reference dst_reference_idx is given at compile-time; dst_slice_origin_idx has a
//        compile-time distance to dst_reference_idx
//     3. uses direct address calculation (lowering of the coordinate)
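// Usage sketch (hypothetical template arguments and indices, not code from this commit):
//   auto transfer = ThreadwiseDynamicTensorSliceTransfer_v4<...>{src_ref_idx};
//   // slice origins are given as compile-time distances from the reference indices
//   transfer.Run(SrcDesc{}, Sequence<0, 4>{}, p_src, DstDesc{}, Sequence<0, 0>{}, p_dst);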
template <
typename SrcData,
typename DstData,
typename SrcDesc,
typename DstDesc,
typename SliceLengths,
typename DimAccessOrder,
index_t SrcVectorDim,
index_t SrcScalarPerVector,
AddressSpace SrcAddressSpace,
AddressSpace DstAddressSpace,
index_t SrcScalarStrideInVector,
typename std::enable_if<SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
bool>::type = false>
struct ThreadwiseDynamicTensorSliceTransfer_v4
{
static constexpr index_t nDim = SliceLengths::Size();
using Index = MultiIndex<nDim>;
using SrcCoord = decltype(make_dynamic_tensor_coordinate(SrcDesc{}, Index{}));
using SrcCoordIterator = decltype(make_dynamic_tensor_coordinate_iterator(SrcDesc{}, Index{}));
__device__ constexpr ThreadwiseDynamicTensorSliceTransfer_v4(const Index& src_ref_idx)
: src_ref_coord_(make_dynamic_tensor_coordinate(SrcDesc{}, src_ref_idx))
{
static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
"wrong! SrcDesc and DstDesc need to be known at compile-time");
}
template <typename SrcOriginToRefDistance, typename DstOriginToRefDistance>
__device__ void Run(const SrcDesc&,
const SrcOriginToRefDistance&,
const SrcData* p_src,
const DstDesc&,
const DstOriginToRefDistance&,
DstData* p_dst)
{
// scalar per access, dimension access order, and ordered access lengths follow
// the same pattern as the other transfer versions above
constexpr auto src_scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});
constexpr auto access_lengths = SliceLengths{} / src_scalar_per_access;
constexpr auto dim_access_order = DimAccessOrder{};
constexpr auto ordered_access_lengths =
container_reorder_given_new2old(access_lengths, dim_access_order);
static_ford<decltype(ordered_access_lengths)>{}([&](auto ordered_access_idx) {
// position in slice window
constexpr auto data_to_origin_dist_idx =
container_reorder_given_old2new(ordered_access_idx, dim_access_order) *
src_scalar_per_access;
// src coordinate
constexpr auto src_data_to_ref_dist_idx =
SrcOriginToRefDistance{} + data_to_origin_dist_idx;
constexpr auto src_data_to_ref_dist_coord_iterator =
make_dynamic_tensor_coordinate_iterator(SrcDesc{}, src_data_to_ref_dist_idx);
auto src_data_coord = src_ref_coord_;
move_dynamic_tensor_coordinate(
SrcDesc{}, src_data_coord, src_data_to_ref_dist_coord_iterator);
// copy data from src into buffer
StaticBuffer<SrcData, SrcScalarPerVector> src_buf;
using src_vector_t =
typename vector_type_maker<SrcData, SrcScalarPerVector>::type::type;
const bool is_src_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid(
SrcDesc{}, src_data_coord);
src_buf.template AsType<src_vector_t>()(Number<0>{}) =
is_src_valid ? *reinterpret_cast<const src_vector_t*>(
&p_src[src_data_coord.GetOffset()])
: src_vector_t{0};
// copy data from buffer into dst
static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
constexpr auto src_scalar_step_in_vector = generate_sequence(
detail::lambda_scalar_step_in_vector<SrcVectorDim>{}, Number<nDim>{});
constexpr index_t dst_offset = DstDesc{}.CalculateOffset(
to_multi_index(DstOriginToRefDistance{}) + data_to_origin_dist_idx +
i * src_scalar_step_in_vector);
p_dst[Number<dst_offset>{}] = src_buf.template AsType<SrcData>()[i];
});
});
}
private:
SrcCoord src_ref_coord_;
};
} // namespace ck
#endif