Commit 65f984f0 authored by Jing Zhang's avatar Jing Zhang
Browse files

move element_op into RunRead

parent 5d73dd3e
...@@ -141,6 +141,7 @@ struct ThreadwiseTensorSliceTransfer_v7r2 ...@@ -141,6 +141,7 @@ struct ThreadwiseTensorSliceTransfer_v7r2
// loop over space-filling curve // loop over space-filling curve
static_for<0, num_access, 1>{}([&](auto iAccess) { static_for<0, num_access, 1>{}([&](auto iAccess) {
auto src_vectors = generate_vectors<SrcDatas, SrcScalarPerVector>(); auto src_vectors = generate_vectors<SrcDatas, SrcScalarPerVector>();
auto dst_vectors = generate_vectors<DstDatas, DstScalarPerVector>();
// copy data from src_bufs into src_vectors // copy data from src_bufs into src_vectors
static_for<0, nSrc, 1>{}([&](auto i) { static_for<0, nSrc, 1>{}([&](auto i) {
...@@ -150,11 +151,46 @@ struct ThreadwiseTensorSliceTransfer_v7r2 ...@@ -150,11 +151,46 @@ struct ThreadwiseTensorSliceTransfer_v7r2
coordinate_has_valid_offset_assuming_visible_index_is_valid(src_descs[i], coordinate_has_valid_offset_assuming_visible_index_is_valid(src_descs[i],
src_coords_[i]); src_coords_[i]);
src_vectors_tuple_(iAccess)(i).template AsType<src_vector_t>()(I0) = src_vectors(i).template AsType<src_vector_t>()(I0) =
src_bufs[i].template Get<src_vector_t>(src_coords_[i].GetOffset(), src_bufs[i].template Get<src_vector_t>(src_coords_[i].GetOffset(),
is_src_valid); is_src_valid);
}); });
// apply pointwise function
static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
// get reference to src data
const auto src_data_refs = generate_tie(
// return type should be lvalue
[&](auto iSrc) -> const auto& {
using SrcData = remove_cvref_t<tuple_element_t<iSrc.value, SrcDatas>>;
return src_vectors[iSrc].template AsType<SrcData>()[i];
},
Number<nSrc>{});
// get reference to dst data
auto dst_data_refs = generate_tie(
// return type should be lvalue
[&](auto iDst) -> auto& {
using DstData = remove_cvref_t<tuple_element_t<iDst.value, DstDatas>>;
return dst_vectors(iDst).template AsType<DstData>()(i);
},
Number<nDst>{});
// apply pointwise function
// pointwise function signature:
// element_op_(dst_data_refs[I0],
// dst_data_refs[I1],
// ...,
// src_data_refs[I0],
// src_data_refs[I1],
// ...)
unpack2(element_op_, dst_data_refs, src_data_refs);
});
dst_vectors_tuple_(iAccess) = dst_vectors;
// move coordinate // move coordinate
if constexpr(iAccess.value != num_access - 1) if constexpr(iAccess.value != num_access - 1)
{ {
...@@ -190,41 +226,7 @@ struct ThreadwiseTensorSliceTransfer_v7r2 ...@@ -190,41 +226,7 @@ struct ThreadwiseTensorSliceTransfer_v7r2
{ {
// loop over space-filling curve // loop over space-filling curve
static_for<0, num_access, 1>{}([&](auto iAccess) { static_for<0, num_access, 1>{}([&](auto iAccess) {
auto src_vectors = src_vectors_tuple_[iAccess]; auto dst_vectors = dst_vectors_tuple_[iAccess];
auto dst_vectors = generate_vectors<DstDatas, DstScalarPerVector>();
// apply pointwise function
static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
// get reference to src data
const auto src_data_refs = generate_tie(
// return type should be lvalue
[&](auto iSrc) -> const auto& {
using SrcData = remove_cvref_t<tuple_element_t<iSrc.value, SrcDatas>>;
return src_vectors[iSrc].template AsType<SrcData>()[i];
},
Number<nSrc>{});
// get reference to dst data
auto dst_data_refs = generate_tie(
// return type should be lvalue
[&](auto iDst) -> auto& {
using DstData = remove_cvref_t<tuple_element_t<iDst.value, DstDatas>>;
return dst_vectors(iDst).template AsType<DstData>()(i);
},
Number<nDst>{});
// apply pointwise function
// pointwise function signature:
// element_op_(dst_data_refs[I0],
// dst_data_refs[I1],
// ...,
// src_data_refs[I0],
// src_data_refs[I1],
// ...)
unpack2(element_op_, dst_data_refs, src_data_refs);
});
// copy data from buf_vectors into dst_bufs // copy data from buf_vectors into dst_bufs
static_for<0, nDst, 1>{}([&](auto i) { static_for<0, nDst, 1>{}([&](auto i) {
...@@ -352,7 +354,7 @@ struct ThreadwiseTensorSliceTransfer_v7r2 ...@@ -352,7 +354,7 @@ struct ThreadwiseTensorSliceTransfer_v7r2
static constexpr auto num_access = SrcSpaceFillingCurve::GetNumOfAccess(); static constexpr auto num_access = SrcSpaceFillingCurve::GetNumOfAccess();
StaticallyIndexedArray<SrcVectorsType, num_access> src_vectors_tuple_; StaticallyIndexedArray<DstVectorsType, num_access> dst_vectors_tuple_;
SrcCoords src_coords_; SrcCoords src_coords_;
DstCoords dst_coords_; DstCoords dst_coords_;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment