Commit f00c1381 authored by Chao Liu's avatar Chao Liu
Browse files

adding logic to judge linear dimension

parent bf7e7d62
...@@ -256,7 +256,7 @@ struct TransformedTensorDescriptor ...@@ -256,7 +256,7 @@ struct TransformedTensorDescriptor
constexpr auto sorted2unsorted_map = typename sort_up_dimension_ids::sorted2unsorted_map{}; constexpr auto sorted2unsorted_map = typename sort_up_dimension_ids::sorted2unsorted_map{};
constexpr auto sorted_up_lengths = constexpr auto sorted_up_lengths =
pick_sequence_elements(mingled_up_lengths, sorted2unsorted_map); pick_sequence_elements_by_ids(mingled_up_lengths, sorted2unsorted_map);
return sorted_up_lengths; return sorted_up_lengths;
} }
...@@ -347,20 +347,60 @@ struct TransformedTensorDescriptor ...@@ -347,20 +347,60 @@ struct TransformedTensorDescriptor
} }
#if 0 #if 0
struct lambda_sequence_logic_or
{
template <typename... Seqs>
__host__ __device__ constexpr auto operator()(Seqs... seqs) const
{
// TODO: should use math::logic_or<bool>, after Sequence can take bool
return typename sequence_reduce<math::logic_or<index_t>, Seqs...>::type{};
}
};
struct lambda_1
{
template <typename Transform>
__host__ __device__ constexpr auto operator()(const Transform& tran) const
{
return tran.GetUpperLengths();
}
};
template <index_t IDim>
__host__ __device__ static constexpr bool GetMaskOfLinearDimensions()
{
// create tuple of linear dimension masks, for all transformations
constexpr auto tuple_of_linear_dimension_mask =
transform_tuple(lambda_1, Transforms{});
// reduce tuple of masks into one mask
constexpr auto linear_dimension_mask =
unpack(lambda_sequence_logic_or{}, tuple_of_linear_dimension_mask);
return linear_dimension_mask;
}
template <index_t IDim> template <index_t IDim>
__host__ __device__ static constexpr bool IsLinearDimension(Number<IDim>) __host__ __device__ static constexpr bool IsLinearDimension(Number<IDim>)
{ {
// not implemented return GetMaskOfLinearDimensions().At(Number<IDim>{});
} }
__host__ __device__ static constexpr auto GetLinearDimensions() __host__ __device__ static constexpr auto GetLinearDimensions()
{ {
// not implemented constexpr auto linear_dimension_mask = GetMaskOfLienarDimensions();
return pick_sequence_elements_by_mask(
typename arithmetic_sequence_gen<0, nDimUp, 1>::type{}, linear_dimension_mask);
} }
__host__ __device__ static constexpr auto GetNonLinearDimensions() __host__ __device__ static constexpr auto GetNonLinearDimensions()
{ {
// not implemented constexpr auto nonlinear_dimension_mask =
GetMaskOfLienarDimensions().Transform(math::logic_not<index_t>{});
return pick_sequence_elements_by_mask(
typename arithmetic_sequence_gen<0, nDimUp, 1>::type{}, nonlinear_dimension_mask);
} }
__host__ __device__ static constexpr auto GetNonLinearIndependentDimensionGroups() __host__ __device__ static constexpr auto GetNonLinearIndependentDimensionGroups()
......
...@@ -311,6 +311,28 @@ struct sequence_reverse<Sequence<I0, I1>> ...@@ -311,6 +311,28 @@ struct sequence_reverse<Sequence<I0, I1>>
using type = Sequence<I1, I0>; using type = Sequence<I1, I0>;
}; };
#if 0
template <typename Reduce, typename Seq, typename... Seqs>
struct sequence_reduce
{
using type = typename sequence_reduce<Reduce,
Seq,
typename sequence_reduce<Reduce, Seqs...>::type>::type;
};
template <typename Reduce, index_t... Xs, index_t... Ys>
struct sequence_reduce<Reduce, Sequence<Xs...>, Sequence<Ys...>>
{
using type = Sequence<Reduce{}(Xs, Ys)...>;
};
template <typename Reduce, typename Seq>
struct sequence_reduce<Reduce, Seq>
{
using type = Seq;
};
#endif
template <typename Values, typename Ids, typename Compare> template <typename Values, typename Ids, typename Compare>
struct sequence_sort_impl struct sequence_sort_impl
{ {
...@@ -728,11 +750,19 @@ __host__ __device__ constexpr auto inclusive_scan_sequence(Seq, Reduce, Number<I ...@@ -728,11 +750,19 @@ __host__ __device__ constexpr auto inclusive_scan_sequence(Seq, Reduce, Number<I
} }
template <typename Seq, index_t... Is> template <typename Seq, index_t... Is>
__host__ __device__ constexpr auto pick_sequence_elements(Seq, Sequence<Is...>) __host__ __device__ constexpr auto pick_sequence_elements_by_ids(Seq, Sequence<Is...> /* ids */)
{ {
return Sequence<Seq::At(Number<Is>{})...>{}; return Sequence<Seq::At(Number<Is>{})...>{};
} }
#if 0
template <typename Seq, typename Mask>
__host__ __device__ constexpr auto pick_sequence_elements_by_mask(Seq, Mask)
{
// not implemented
}
#endif
template <typename Seq, typename Reduce> template <typename Seq, typename Reduce>
struct lambda_accumulate_on_sequence struct lambda_accumulate_on_sequence
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment