Commit 2cb05d6d authored by Chao Liu's avatar Chao Liu
Browse files

removing use of reference from DynamicTensorCoordinate

parent 3abe105f
......@@ -167,10 +167,7 @@ struct DynamicTensorDescriptor
constexpr auto up_dim_ids = UpperDimensionIdss{}.At(itran);
// lengths_hidden_pick_up contains a reference to lengths_hidden
auto hidden_lengths_pick_up = pick_container_element(hidden_lengths, up_dim_ids);
hidden_lengths_pick_up = tran.GetUpperLengths();
set_container_subset(hidden_lengths, up_dim_ids, tran.GetUpperLengths());
});
return hidden_lengths;
......@@ -225,7 +222,7 @@ struct DynamicTensorCoordinateStep
public:
__host__ __device__ explicit constexpr DynamicTensorCoordinateStep(
const VisibleIndex& idx_diff_visible, const Array<bool, NTransform>& do_transforms)
const VisibleIndex& idx_diff_visible, const MultiIndex<NTransform>& do_transforms)
: idx_diff_visible_{idx_diff_visible}, do_transforms_{do_transforms}
{
}
......@@ -239,7 +236,7 @@ struct DynamicTensorCoordinateStep
}
const VisibleIndex idx_diff_visible_;
const Array<bool, NTransform> do_transforms_;
const MultiIndex<NTransform> do_transforms_;
};
// TODO: How to fix this? It uses a struct instead of a lambda because lambda
......@@ -344,8 +341,7 @@ __host__ __device__ constexpr auto make_dynamic_tensor_coordinate(const TensorDe
MultiIndex<ndim_hidden> idx_hidden;
// initialize visible index
auto idx_hidden_pick_visible = pick_container_element(idx_hidden, visible_dim_ids);
idx_hidden_pick_visible = idx_visible;
set_container_subset(idx_hidden, visible_dim_ids, idx_visible);
// calculate hidden index
static_for<ntransform, 0, -1>{}([&tensor_desc, &idx_hidden](auto itran_p1) {
......@@ -354,13 +350,15 @@ __host__ __device__ constexpr auto make_dynamic_tensor_coordinate(const TensorDe
constexpr auto dims_low = TensorDesc::GetLowerDimensionIdss().At(itran);
constexpr auto dims_up = TensorDesc::GetUpperDimensionIdss().At(itran);
const auto idx_up = pick_container_element(idx_hidden, dims_up);
auto idx_low = pick_container_element(idx_hidden, dims_low);
const auto idx_up = get_container_subset(idx_hidden, dims_up);
MultiIndex<dims_low.Size()> idx_low;
tran.CalculateLowerIndex(idx_low, idx_up);
set_container_subset(idx_hidden, dims_low, idx_low);
});
// better to use std::move?
return DynamicTensorCoordinate<ndim_hidden, decltype(visible_dim_ids)>{idx_hidden};
}
......@@ -376,22 +374,25 @@ make_dynamic_tensor_coordinate_step(const TensorDesc&, const VisibleIndex& idx_d
constexpr index_t ndim_visible = TensorDesc::GetNumOfVisibleDimension();
constexpr auto visible_dim_ids = TensorDesc::GetVisibleDimensionIds();
Array<bool, ntransform> do_transforms{false};
// use index_t for boolean type
auto do_transforms = make_zero_multi_index<ntransform>();
auto is_non_zero_diff = make_zero_multi_index<ndim_hidden>();
Array<bool, ndim_hidden> non_zero_diff{false};
// decide do_transforms by checking for non-zero index diff components
MultiIndex<VisibleIndex::Size()> non_zero_diff_pick_visible;
auto non_zero_diff_pick_visible = pick_container_element(non_zero_diff, visible_dim_ids);
static_for<0, ndim_visible, 1>{}(
[&](auto i) { non_zero_diff_pick_visible(i) = (idx_diff_visible[i] != 0); });
static_for<0, ndim_visible, 1>{}([&non_zero_diff_pick_visible, &idx_diff_visible](auto i) {
non_zero_diff_pick_visible(i) = (idx_diff_visible[i] != 0);
});
set_container_subset(is_non_zero_diff, visible_dim_ids, non_zero_diff_pick_visible);
static_for<ntransform - 1, -1, -1>{}([&do_transforms, &non_zero_diff](auto itran) {
static_for<ntransform - 1, -1, -1>{}([&](auto itran) {
constexpr auto dims_low = TensorDesc::GetLowerDimensionIdss().At(itran);
constexpr auto dims_up = TensorDesc::GetUpperDimensionIdss().At(itran);
const auto non_zero_diff_pick_up = pick_container_element(non_zero_diff, dims_up);
auto non_zero_diff_pick_low = pick_container_element(non_zero_diff, dims_low);
const auto non_zero_diff_pick_up = get_container_subset(is_non_zero_diff, dims_up);
MultiIndex<dims_low.Size()> non_zero_diff_pick_low;
// if any of upper index diff components is non-zero, then
// 1) Need to do this transform
......@@ -403,9 +404,9 @@ make_dynamic_tensor_coordinate_step(const TensorDesc&, const VisibleIndex& idx_d
do_transforms(itran) = idx_diff_up_has_non_zero;
static_for<0, dims_low.Size(), 1>{}(
[&non_zero_diff_pick_low, &idx_diff_up_has_non_zero](auto i) {
non_zero_diff_pick_low(i) = idx_diff_up_has_non_zero;
});
[&](auto i) { non_zero_diff_pick_low(i) = idx_diff_up_has_non_zero; });
set_container_subset(is_non_zero_diff, dims_low, non_zero_diff_pick_low);
});
return DynamicTensorCoordinateStep<ntransform, ndim_visible>{idx_diff_visible, do_transforms};
......@@ -426,20 +427,20 @@ __host__ __device__ constexpr void move_dynamic_tensor_coordinate(const TensorDe
auto idx_diff_hidden = make_zero_multi_index<ndim_hidden>();
// initialize visible index diff
// idx_diff_hidden_pick_visible contains reference to idx_diff_hidden
auto idx_diff_hidden_pick_visible =
pick_container_element(idx_diff_hidden, TensorDesc::GetVisibleDimensionIds());
idx_diff_hidden_pick_visible = coord_step.GetVisibleIndexDiff();
set_container_subset(
idx_diff_hidden, TensorDesc::GetVisibleDimensionIds(), coord_step.GetVisibleIndexDiff());
// this is what needs to be updated
auto& idx_hidden = coord.GetHiddenIndex();
// update visible index
auto idx_hidden_pick_visible =
pick_container_element(idx_hidden, TensorDesc::GetVisibleDimensionIds());
get_container_subset(idx_hidden, TensorDesc::GetVisibleDimensionIds());
idx_hidden_pick_visible += coord_step.GetIndexDiff();
set_container_subset(idx_hidden, TensorDesc::GetVisibleDimensionIds(), idx_hidden_pick_visible);
// update rest of hidden index
static_for<ntransform - 1, -1, -1>{}([&](auto itran) {
if(coord_step.do_transforms_[itran])
......@@ -448,17 +449,20 @@ __host__ __device__ constexpr void move_dynamic_tensor_coordinate(const TensorDe
constexpr auto dims_low = TensorDesc::GetLowerDimensionIdss().At(itran);
constexpr auto dims_up = TensorDesc::GetUpperDimensionIdss().At(itran);
// this const is for ContainerElementPicker, Array itself may not be const
const auto idx_up = pick_container_element(idx_hidden, dims_up);
auto idx_low = pick_container_element(idx_hidden, dims_low);
const auto idx_up = get_container_subset(idx_hidden, dims_up);
auto idx_low = get_container_subset(idx_hidden, dims_low);
const auto idx_diff_up = get_container_subset(idx_diff_hidden, dims_up);
const auto idx_diff_up = pick_container_element(idx_diff_hidden, dims_up);
auto idx_diff_low = pick_container_element(idx_diff_hidden, dims_low);
MultiIndex<dims_low.Size()> idx_diff_low;
// calculate idx_diff_low
tran.CalculateLowerIndexDiff(idx_diff_low, idx_diff_up, idx_low, idx_up);
// update idx_low
idx_low += idx_diff_low;
set_container_subset(idx_diff_hidden, dims_low, idx_diff_low);
set_container_subset(idx_hidden, dims_low, idx_low);
}
});
}
......@@ -481,7 +485,7 @@ coordinate_has_valid_offset_assuming_visible_index_is_valid(const TensorDesc& te
if constexpr(!decltype(tran)::IsValidUpperIndexAlwaysMappedToValidLowerIndex())
{
const auto idx_up =
pick_container_element(idx_hidden, TensorDesc::GetUpperDimensionIdss().At(itran));
get_container_subset(idx_hidden, TensorDesc::GetUpperDimensionIdss().At(itran));
valid = valid && tran.IsValidUpperIndexMappedToValidLowerIndex(idx_up);
}
......
......@@ -194,5 +194,39 @@ __host__ __device__ constexpr auto container_cat(const Container& x)
return x;
}
template <typename T, index_t N, index_t... Is>
__host__ __device__ constexpr auto get_container_subset(const Array<T, N>& arr, Sequence<Is...>)
{
static_assert(N >= sizeof...(Is), "wrong! size");
return make_array(arr[Number<Is>{}]...);
}
template <typename... Ts, index_t... Is>
__host__ __device__ constexpr auto get_container_subset(const Tuple<Ts...>& tup, Sequence<Is...>)
{
static_assert(sizeof...(Ts) >= sizeof...(Is), "wrong! size");
return make_tuple(tup[Number<Is>{}]...);
}
template <typename T, index_t N, index_t... Is>
__host__ __device__ constexpr void
set_container_subset(Array<T, N>& y, Sequence<Is...> picks, const Array<T, sizeof...(Is)>& x)
{
static_assert(N >= sizeof...(Is), "wrong! size");
static_for<0, sizeof...(Is), 1>{}([&](auto i) { y(picks[i]) = x[i]; });
}
template <typename... Ys, index_t... Is, typename... Xs>
__host__ __device__ constexpr void
set_container_subset(Tuple<Ys...>& y, Sequence<Is...> picks, const Tuple<Xs...>& x)
{
static_assert(sizeof...(Ys) >= sizeof...(Is) && sizeof...(Is) == sizeof...(Xs), "wrong! size");
static_for<0, sizeof...(Is), 1>{}([&](auto i) { y(picks[i]) = x[i]; });
}
} // namespace ck
#endif
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment