Commit 2cb05d6d authored by Chao Liu's avatar Chao Liu
Browse files

removing use of reference from DynamicTensorCoordinate

parent 3abe105f
...@@ -167,10 +167,7 @@ struct DynamicTensorDescriptor ...@@ -167,10 +167,7 @@ struct DynamicTensorDescriptor
constexpr auto up_dim_ids = UpperDimensionIdss{}.At(itran); constexpr auto up_dim_ids = UpperDimensionIdss{}.At(itran);
// lengths_hidden_pick_up contains a reference to lengths_hidden set_container_subset(hidden_lengths, up_dim_ids, tran.GetUpperLengths());
auto hidden_lengths_pick_up = pick_container_element(hidden_lengths, up_dim_ids);
hidden_lengths_pick_up = tran.GetUpperLengths();
}); });
return hidden_lengths; return hidden_lengths;
...@@ -225,7 +222,7 @@ struct DynamicTensorCoordinateStep ...@@ -225,7 +222,7 @@ struct DynamicTensorCoordinateStep
public: public:
__host__ __device__ explicit constexpr DynamicTensorCoordinateStep( __host__ __device__ explicit constexpr DynamicTensorCoordinateStep(
const VisibleIndex& idx_diff_visible, const Array<bool, NTransform>& do_transforms) const VisibleIndex& idx_diff_visible, const MultiIndex<NTransform>& do_transforms)
: idx_diff_visible_{idx_diff_visible}, do_transforms_{do_transforms} : idx_diff_visible_{idx_diff_visible}, do_transforms_{do_transforms}
{ {
} }
...@@ -239,7 +236,7 @@ struct DynamicTensorCoordinateStep ...@@ -239,7 +236,7 @@ struct DynamicTensorCoordinateStep
} }
const VisibleIndex idx_diff_visible_; const VisibleIndex idx_diff_visible_;
const Array<bool, NTransform> do_transforms_; const MultiIndex<NTransform> do_transforms_;
}; };
// TODO: How to fix this? It uses an struct instead of lambda because lambda // TODO: How to fix this? It uses an struct instead of lambda because lambda
...@@ -344,8 +341,7 @@ __host__ __device__ constexpr auto make_dynamic_tensor_coordinate(const TensorDe ...@@ -344,8 +341,7 @@ __host__ __device__ constexpr auto make_dynamic_tensor_coordinate(const TensorDe
MultiIndex<ndim_hidden> idx_hidden; MultiIndex<ndim_hidden> idx_hidden;
// initialize visible index // initialize visible index
auto idx_hidden_pick_visible = pick_container_element(idx_hidden, visible_dim_ids); set_container_subset(idx_hidden, visible_dim_ids, idx_visible);
idx_hidden_pick_visible = idx_visible;
// calculate hidden index // calculate hidden index
static_for<ntransform, 0, -1>{}([&tensor_desc, &idx_hidden](auto itran_p1) { static_for<ntransform, 0, -1>{}([&tensor_desc, &idx_hidden](auto itran_p1) {
...@@ -354,13 +350,15 @@ __host__ __device__ constexpr auto make_dynamic_tensor_coordinate(const TensorDe ...@@ -354,13 +350,15 @@ __host__ __device__ constexpr auto make_dynamic_tensor_coordinate(const TensorDe
constexpr auto dims_low = TensorDesc::GetLowerDimensionIdss().At(itran); constexpr auto dims_low = TensorDesc::GetLowerDimensionIdss().At(itran);
constexpr auto dims_up = TensorDesc::GetUpperDimensionIdss().At(itran); constexpr auto dims_up = TensorDesc::GetUpperDimensionIdss().At(itran);
const auto idx_up = pick_container_element(idx_hidden, dims_up); const auto idx_up = get_container_subset(idx_hidden, dims_up);
auto idx_low = pick_container_element(idx_hidden, dims_low);
MultiIndex<dims_low.Size()> idx_low;
tran.CalculateLowerIndex(idx_low, idx_up); tran.CalculateLowerIndex(idx_low, idx_up);
set_container_subset(idx_hidden, dims_low, idx_low);
}); });
// better to use std::move?
return DynamicTensorCoordinate<ndim_hidden, decltype(visible_dim_ids)>{idx_hidden}; return DynamicTensorCoordinate<ndim_hidden, decltype(visible_dim_ids)>{idx_hidden};
} }
...@@ -376,22 +374,25 @@ make_dynamic_tensor_coordinate_step(const TensorDesc&, const VisibleIndex& idx_d ...@@ -376,22 +374,25 @@ make_dynamic_tensor_coordinate_step(const TensorDesc&, const VisibleIndex& idx_d
constexpr index_t ndim_visible = TensorDesc::GetNumOfVisibleDimension(); constexpr index_t ndim_visible = TensorDesc::GetNumOfVisibleDimension();
constexpr auto visible_dim_ids = TensorDesc::GetVisibleDimensionIds(); constexpr auto visible_dim_ids = TensorDesc::GetVisibleDimensionIds();
Array<bool, ntransform> do_transforms{false}; // use index_t for boolean type
auto do_transforms = make_zero_multi_index<ntransform>();
auto is_non_zero_diff = make_zero_multi_index<ndim_hidden>();
Array<bool, ndim_hidden> non_zero_diff{false}; // decide do_transform by checkout non-zero index diff components
MultiIndex<VisibleIndex::Size()> non_zero_diff_pick_visible;
auto non_zero_diff_pick_visible = pick_container_element(non_zero_diff, visible_dim_ids); static_for<0, ndim_visible, 1>{}(
[&](auto i) { non_zero_diff_pick_visible(i) = (idx_diff_visible[i] != 0); });
static_for<0, ndim_visible, 1>{}([&non_zero_diff_pick_visible, &idx_diff_visible](auto i) { set_container_subset(is_non_zero_diff, visible_dim_ids, non_zero_diff_pick_visible);
non_zero_diff_pick_visible(i) = (idx_diff_visible[i] != 0);
});
static_for<ntransform - 1, -1, -1>{}([&do_transforms, &non_zero_diff](auto itran) { static_for<ntransform - 1, -1, -1>{}([&](auto itran) {
constexpr auto dims_low = TensorDesc::GetLowerDimensionIdss().At(itran); constexpr auto dims_low = TensorDesc::GetLowerDimensionIdss().At(itran);
constexpr auto dims_up = TensorDesc::GetUpperDimensionIdss().At(itran); constexpr auto dims_up = TensorDesc::GetUpperDimensionIdss().At(itran);
const auto non_zero_diff_pick_up = pick_container_element(non_zero_diff, dims_up); const auto non_zero_diff_pick_up = get_container_subset(is_non_zero_diff, dims_up);
auto non_zero_diff_pick_low = pick_container_element(non_zero_diff, dims_low);
MultiIndex<dims_low.Size()> non_zero_diff_pick_low;
// if any of upper index diff components is non-zero, then // if any of upper index diff components is non-zero, then
// 1) Need to do this transform // 1) Need to do this transform
...@@ -403,9 +404,9 @@ make_dynamic_tensor_coordinate_step(const TensorDesc&, const VisibleIndex& idx_d ...@@ -403,9 +404,9 @@ make_dynamic_tensor_coordinate_step(const TensorDesc&, const VisibleIndex& idx_d
do_transforms(itran) = idx_diff_up_has_non_zero; do_transforms(itran) = idx_diff_up_has_non_zero;
static_for<0, dims_low.Size(), 1>{}( static_for<0, dims_low.Size(), 1>{}(
[&non_zero_diff_pick_low, &idx_diff_up_has_non_zero](auto i) { [&](auto i) { non_zero_diff_pick_low(i) = idx_diff_up_has_non_zero; });
non_zero_diff_pick_low(i) = idx_diff_up_has_non_zero;
}); set_container_subset(is_non_zero_diff, dims_low, non_zero_diff_pick_low);
}); });
return DynamicTensorCoordinateStep<ntransform, ndim_visible>{idx_diff_visible, do_transforms}; return DynamicTensorCoordinateStep<ntransform, ndim_visible>{idx_diff_visible, do_transforms};
...@@ -426,20 +427,20 @@ __host__ __device__ constexpr void move_dynamic_tensor_coordinate(const TensorDe ...@@ -426,20 +427,20 @@ __host__ __device__ constexpr void move_dynamic_tensor_coordinate(const TensorDe
auto idx_diff_hidden = make_zero_multi_index<ndim_hidden>(); auto idx_diff_hidden = make_zero_multi_index<ndim_hidden>();
// initialize visible index diff // initialize visible index diff
// idx_diff_hidden_pick_visible contains reference to idx_diff_hidden set_container_subset(
auto idx_diff_hidden_pick_visible = idx_diff_hidden, TensorDesc::GetVisibleDimensionIds(), coord_step.GetVisibleIndexDiff());
pick_container_element(idx_diff_hidden, TensorDesc::GetVisibleDimensionIds());
idx_diff_hidden_pick_visible = coord_step.GetVisibleIndexDiff();
// this is what needs to be updated // this is what needs to be updated
auto& idx_hidden = coord.GetHiddenIndex(); auto& idx_hidden = coord.GetHiddenIndex();
// update visible index // update visible index
auto idx_hidden_pick_visible = auto idx_hidden_pick_visible =
pick_container_element(idx_hidden, TensorDesc::GetVisibleDimensionIds()); get_container_subset(idx_hidden, TensorDesc::GetVisibleDimensionIds());
idx_hidden_pick_visible += coord_step.GetIndexDiff(); idx_hidden_pick_visible += coord_step.GetIndexDiff();
set_container_subset(idx_hidden, TensorDesc::GetVisibleDimensionIds(), idx_hidden_pick_visible);
// update rest of hidden index // update rest of hidden index
static_for<ntransform - 1, -1, -1>{}([&](auto itran) { static_for<ntransform - 1, -1, -1>{}([&](auto itran) {
if(coord_step.do_transforms_[itran]) if(coord_step.do_transforms_[itran])
...@@ -448,17 +449,20 @@ __host__ __device__ constexpr void move_dynamic_tensor_coordinate(const TensorDe ...@@ -448,17 +449,20 @@ __host__ __device__ constexpr void move_dynamic_tensor_coordinate(const TensorDe
constexpr auto dims_low = TensorDesc::GetLowerDimensionIdss().At(itran); constexpr auto dims_low = TensorDesc::GetLowerDimensionIdss().At(itran);
constexpr auto dims_up = TensorDesc::GetUpperDimensionIdss().At(itran); constexpr auto dims_up = TensorDesc::GetUpperDimensionIdss().At(itran);
// this const is for ContainerElementPicker, Array itself may not be const const auto idx_up = get_container_subset(idx_hidden, dims_up);
const auto idx_up = pick_container_element(idx_hidden, dims_up); auto idx_low = get_container_subset(idx_hidden, dims_low);
auto idx_low = pick_container_element(idx_hidden, dims_low); const auto idx_diff_up = get_container_subset(idx_diff_hidden, dims_up);
const auto idx_diff_up = pick_container_element(idx_diff_hidden, dims_up); MultiIndex<dims_low.Size()> idx_diff_low;
auto idx_diff_low = pick_container_element(idx_diff_hidden, dims_low);
// calculate idx_diff_low
tran.CalculateLowerIndexDiff(idx_diff_low, idx_diff_up, idx_low, idx_up); tran.CalculateLowerIndexDiff(idx_diff_low, idx_diff_up, idx_low, idx_up);
// update idx_low // update idx_low
idx_low += idx_diff_low; idx_low += idx_diff_low;
set_container_subset(idx_diff_hidden, dims_low, idx_diff_low);
set_container_subset(idx_hidden, dims_low, idx_low);
} }
}); });
} }
...@@ -481,7 +485,7 @@ coordinate_has_valid_offset_assuming_visible_index_is_valid(const TensorDesc& te ...@@ -481,7 +485,7 @@ coordinate_has_valid_offset_assuming_visible_index_is_valid(const TensorDesc& te
if constexpr(!decltype(tran)::IsValidUpperIndexAlwaysMappedToValidLowerIndex()) if constexpr(!decltype(tran)::IsValidUpperIndexAlwaysMappedToValidLowerIndex())
{ {
const auto idx_up = const auto idx_up =
pick_container_element(idx_hidden, TensorDesc::GetUpperDimensionIdss().At(itran)); get_container_subset(idx_hidden, TensorDesc::GetUpperDimensionIdss().At(itran));
valid = valid && tran.IsValidUpperIndexMappedToValidLowerIndex(idx_up); valid = valid && tran.IsValidUpperIndexMappedToValidLowerIndex(idx_up);
} }
......
...@@ -194,5 +194,39 @@ __host__ __device__ constexpr auto container_cat(const Container& x) ...@@ -194,5 +194,39 @@ __host__ __device__ constexpr auto container_cat(const Container& x)
return x; return x;
} }
template <typename T, index_t N, index_t... Is>
__host__ __device__ constexpr auto get_container_subset(const Array<T, N>& arr, Sequence<Is...>)
{
static_assert(N >= sizeof...(Is), "wrong! size");
return make_array(arr[Number<Is>{}]...);
}
template <typename... Ts, index_t... Is>
__host__ __device__ constexpr auto get_container_subset(const Tuple<Ts...>& tup, Sequence<Is...>)
{
static_assert(sizeof...(Ts) >= sizeof...(Is), "wrong! size");
return make_tuple(tup[Number<Is>{}]...);
}
template <typename T, index_t N, index_t... Is>
__host__ __device__ constexpr void
set_container_subset(Array<T, N>& y, Sequence<Is...> picks, const Array<T, sizeof...(Is)>& x)
{
static_assert(N >= sizeof...(Is), "wrong! size");
static_for<0, sizeof...(Is), 1>{}([&](auto i) { y(picks[i]) = x[i]; });
}
template <typename... Ys, index_t... Is, typename... Xs>
__host__ __device__ constexpr void
set_container_subset(Tuple<Ys...>& y, Sequence<Is...> picks, const Tuple<Xs...>& x)
{
static_assert(sizeof...(Ys) >= sizeof...(Is) && sizeof...(Is) == sizeof...(Xs), "wrong! size");
static_for<0, sizeof...(Is), 1>{}([&](auto i) { y(picks[i]) = x[i]; });
}
} // namespace ck } // namespace ck
#endif #endif
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment