Commit f00c1381 authored by Chao Liu

adding logic to judge linear dimension

parent bf7e7d62
@@ -256,7 +256,7 @@ struct TransformedTensorDescriptor
constexpr auto sorted2unsorted_map = typename sort_up_dimension_ids::sorted2unsorted_map{};
constexpr auto sorted_up_lengths =
pick_sequence_elements(mingled_up_lengths, sorted2unsorted_map);
pick_sequence_elements_by_ids(mingled_up_lengths, sorted2unsorted_map);
return sorted_up_lengths;
}
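For reference, the rename makes the gather semantics explicit: the function picks elements of a sequence at the positions given by an id sequence (here, the sorted-to-unsorted map). A minimal stand-alone sketch of that behavior, using simplified stand-in types rather than the library's Sequence/Number (requires C++14):

// Sketch only: simplified stand-in for Sequence<...>.
template <int... Is>
struct Seq
{
    static constexpr int At(int i)
    {
        constexpr int data[sizeof...(Is)] = {Is...};
        return data[i];
    }
};

// result[i] = seq[ids[i]], mirroring pick_sequence_elements_by_ids
template <typename S, int... Ids>
constexpr auto pick_by_ids(S, Seq<Ids...>)
{
    return Seq<S::At(Ids)...>{};
}

// e.g. mingled lengths {8, 4, 2} gathered by ids {2, 0, 1} -> {2, 8, 4}
static_assert(pick_by_ids(Seq<8, 4, 2>{}, Seq<2, 0, 1>{}).At(0) == 2, "");
static_assert(pick_by_ids(Seq<8, 4, 2>{}, Seq<2, 0, 1>{}).At(1) == 8, "");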
@@ -347,20 +347,60 @@ struct TransformedTensorDescriptor
}
#if 0
struct lambda_sequence_logic_or
{
template <typename... Seqs>
__host__ __device__ constexpr auto operator()(Seqs... seqs) const
{
// TODO: should use math::logic_or<bool>, after Sequence can take bool
return typename sequence_reduce<math::logic_or<index_t>, Seqs...>::type{};
}
};
struct lambda_1
{
template <typename Transform>
__host__ __device__ constexpr auto operator()(const Transform& tran) const
{
return tran.GetUpperLengths();
}
};
__host__ __device__ static constexpr auto GetMaskOfLinearDimensions()
{
// create tuple of linear dimension masks, one mask per transformation
constexpr auto tuple_of_linear_dimension_mask =
transform_tuple(lambda_1{}, Transforms{});
// reduce the tuple of masks into a single mask, element-wise logical OR
constexpr auto linear_dimension_mask =
unpack(lambda_sequence_logic_or{}, tuple_of_linear_dimension_mask);
return linear_dimension_mask;
}
template <index_t IDim>
__host__ __device__ static constexpr bool IsLinearDimension(Number<IDim>)
{
// a dimension is linear iff its entry in the linear-dimension mask is set
return GetMaskOfLinearDimensions().At(Number<IDim>{});
}
__host__ __device__ static constexpr auto GetLinearDimensions()
{
// ids of the dimensions whose mask entry is set
constexpr auto linear_dimension_mask = GetMaskOfLinearDimensions();
return pick_sequence_elements_by_mask(
typename arithmetic_sequence_gen<0, nDimUp, 1>::type{}, linear_dimension_mask);
}
__host__ __device__ static constexpr auto GetNonLinearDimensions()
{
// ids of the dimensions whose mask entry is not set
constexpr auto nonlinear_dimension_mask =
GetMaskOfLinearDimensions().Transform(math::logic_not<index_t>{});
return pick_sequence_elements_by_mask(
typename arithmetic_sequence_gen<0, nDimUp, 1>::type{}, nonlinear_dimension_mask);
}
__host__ __device__ static constexpr auto GetNonLinearIndependentDimensionGroups()
......
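Taken together, the disabled block sketches the intended flow: each transform contributes a 0/1 mask over the upper dimensions, the masks are OR-reduced into one mask, and dimension ids are then split by that mask into linear and non-linear groups. A value-level illustration under that assumption (the two per-transform masks below are made up for the example):

#include <cstdio>

int main()
{
    constexpr int nDimUp = 4;

    // hypothetical per-transform masks over the upper dimensions
    constexpr int mask_a[nDimUp] = {1, 1, 0, 0};
    constexpr int mask_b[nDimUp] = {0, 0, 0, 1};

    // element-wise logical OR, mirroring lambda_sequence_logic_or
    int linear_mask[nDimUp];
    for(int d = 0; d < nDimUp; ++d)
        linear_mask[d] = mask_a[d] || mask_b[d];

    // split dimension ids by the mask, mirroring GetLinearDimensions / GetNonLinearDimensions
    printf("linear dimensions   :");
    for(int d = 0; d < nDimUp; ++d)
        if(linear_mask[d] != 0)
            printf(" %d", d);

    printf("\nnonlinear dimensions:");
    for(int d = 0; d < nDimUp; ++d)
        if(linear_mask[d] == 0)
            printf(" %d", d);

    printf("\n");
    return 0;
}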
@@ -311,6 +311,28 @@ struct sequence_reverse<Sequence<I0, I1>>
using type = Sequence<I1, I0>;
};
#if 0
template <typename Reduce, typename Seq, typename... Seqs>
struct sequence_reduce
{
using type = typename sequence_reduce<Reduce,
Seq,
typename sequence_reduce<Reduce, Seqs...>::type>::type;
};
template <typename Reduce, index_t... Xs, index_t... Ys>
struct sequence_reduce<Reduce, Sequence<Xs...>, Sequence<Ys...>>
{
using type = Sequence<Reduce{}(Xs, Ys)...>;
};
template <typename Reduce, typename Seq>
struct sequence_reduce<Reduce, Seq>
{
using type = Seq;
};
#endif
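The disabled sequence_reduce is the element-wise fold that lambda_sequence_logic_or in the descriptor relies on: reduce the tail sequences first, then combine with the head. A self-contained sketch of the same recursion with stand-in types and a logical-or reduce:

#include <type_traits>

// Stand-in types; the real code uses Sequence<...> and math::logic_or<index_t>.
template <int... Is>
struct Seq
{
};

struct logic_or
{
    constexpr int operator()(int x, int y) const { return (x != 0) || (y != 0); }
};

// element-wise fold of N sequences: reduce the tail, then combine with the head
template <typename Reduce, typename SeqHead, typename... SeqTail>
struct seq_reduce
{
    using type =
        typename seq_reduce<Reduce, SeqHead, typename seq_reduce<Reduce, SeqTail...>::type>::type;
};

template <typename Reduce, int... Xs, int... Ys>
struct seq_reduce<Reduce, Seq<Xs...>, Seq<Ys...>>
{
    using type = Seq<Reduce{}(Xs, Ys)...>;
};

template <typename Reduce, typename S>
struct seq_reduce<Reduce, S>
{
    using type = S;
};

// e.g. OR-reducing three masks: {1,0,0}, {0,0,1}, {0,0,0} -> {1,0,1}
static_assert(std::is_same<seq_reduce<logic_or, Seq<1, 0, 0>, Seq<0, 0, 1>, Seq<0, 0, 0>>::type,
                           Seq<1, 0, 1>>::value,
              "");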
template <typename Values, typename Ids, typename Compare>
struct sequence_sort_impl
{
@@ -728,11 +750,19 @@ __host__ __device__ constexpr auto inclusive_scan_sequence(Seq, Reduce, Number<I
}
template <typename Seq, index_t... Is>
__host__ __device__ constexpr auto pick_sequence_elements(Seq, Sequence<Is...>)
__host__ __device__ constexpr auto pick_sequence_elements_by_ids(Seq, Sequence<Is...> /* ids */)
{
return Sequence<Seq::At(Number<Is>{})...>{};
}
#if 0
template <typename Seq, typename Mask>
__host__ __device__ constexpr auto pick_sequence_elements_by_mask(Seq, Mask)
{
// not implemented
}
#endif
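pick_sequence_elements_by_mask is still only a placeholder above, though the descriptor code calls it with the id sequence 0..nDimUp-1 and the linear-dimension mask. One plausible shape for it, sketched with hypothetical stand-in helpers rather than the library's implementation: walk the value sequence and the 0/1 mask in lockstep, keeping the values whose mask entry is set.

#include <type_traits>

// Sketch only: simplified stand-in for Sequence<...>.
template <int... Is>
struct Seq
{
};

// append one value to a Seq
template <typename S, int I>
struct seq_push_back;

template <int... Is, int I>
struct seq_push_back<Seq<Is...>, I>
{
    using type = Seq<Is..., I>;
};

// walk values and mask together, keeping values whose mask entry is nonzero
template <typename Picked, typename SeqRest, typename MaskRest>
struct pick_by_mask_impl;

template <typename Picked>
struct pick_by_mask_impl<Picked, Seq<>, Seq<>>
{
    using type = Picked;
};

template <typename Picked, int X, int... Xs, int M, int... Ms>
struct pick_by_mask_impl<Picked, Seq<X, Xs...>, Seq<M, Ms...>>
{
    using picked_next =
        typename std::conditional<(M != 0), typename seq_push_back<Picked, X>::type, Picked>::type;
    using type = typename pick_by_mask_impl<picked_next, Seq<Xs...>, Seq<Ms...>>::type;
};

template <typename S, typename Mask>
using pick_by_mask = typename pick_by_mask_impl<Seq<>, S, Mask>::type;

// e.g. picking {10, 20, 30, 40} by mask {1, 0, 1, 1} -> {10, 30, 40}
static_assert(std::is_same<pick_by_mask<Seq<10, 20, 30, 40>, Seq<1, 0, 1, 1>>,
                           Seq<10, 30, 40>>::value,
              "");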
template <typename Seq, typename Reduce>
struct lambda_accumulate_on_sequence
{
......