Commit e43359fe authored by Bartlomiej Kocot's avatar Bartlomiej Kocot
Browse files

Add comments and remove not needed getters

parent f124c7a1
...@@ -72,6 +72,7 @@ struct Layout ...@@ -72,6 +72,7 @@ struct Layout
{ {
if constexpr(is_detected<is_tuple, tuple_element_t<Idx::value, Tuple<Ts...>>>::value) if constexpr(is_detected<is_tuple, tuple_element_t<Idx::value, Tuple<Ts...>>>::value)
{ {
// Return Sequence for the first tuple
constexpr index_t merge_nelems = decltype(UnrollNestedTuple( constexpr index_t merge_nelems = decltype(UnrollNestedTuple(
tuple_element_t<Idx::value, Tuple<Ts...>>{}))::Size(); tuple_element_t<Idx::value, Tuple<Ts...>>{}))::Size();
using LowerDimsSequence = using LowerDimsSequence =
...@@ -80,11 +81,13 @@ struct Layout ...@@ -80,11 +81,13 @@ struct Layout
} }
else else
{ {
// Return first element
return Sequence<0>{}; return Sequence<0>{};
} }
} }
else else
{ {
// Get previous element using recurence (in compile-time)
using PreviousSeqT = decltype(GenerateLowerDim<Number<Idx::value - 1>>(Tuple<Ts...>{})); using PreviousSeqT = decltype(GenerateLowerDim<Number<Idx::value - 1>>(Tuple<Ts...>{}));
const auto next_seq_val = PreviousSeqT::At(I0) + 1; const auto next_seq_val = PreviousSeqT::At(I0) + 1;
if constexpr(is_detected<is_tuple, tuple_element_t<Idx::value, Tuple<Ts...>>>::value) if constexpr(is_detected<is_tuple, tuple_element_t<Idx::value, Tuple<Ts...>>>::value)
...@@ -105,6 +108,9 @@ struct Layout ...@@ -105,6 +108,9 @@ struct Layout
// Iterate over nested tuples in shape // Iterate over nested tuples in shape
// Unroll nested tuples to align Tuple<ShapeDims...> to Tuple<IdxDims...> // Unroll nested tuples to align Tuple<ShapeDims...> to Tuple<IdxDims...>
// Example idx: (1, 1), 1, 1
// Example shape: (2, (2, 2)), 2, (2, 2)
// Unrolled shape: 2, (2, 2), 2, (2, 2)
template <typename... ShapeDims, typename... IdxDims> template <typename... ShapeDims, typename... IdxDims>
__host__ __device__ constexpr static auto UnrollShapeViaIdx(const Tuple<ShapeDims...>& shape, __host__ __device__ constexpr static auto UnrollShapeViaIdx(const Tuple<ShapeDims...>& shape,
const Tuple<IdxDims...>& idx) const Tuple<IdxDims...>& idx)
...@@ -157,6 +163,11 @@ struct Layout ...@@ -157,6 +163,11 @@ struct Layout
desc, make_tuple(make_merge_transform(merge_elems)), lower_dims, upper_dims); desc, make_tuple(make_merge_transform(merge_elems)), lower_dims, upper_dims);
} }
// Merge nested shape dims
// Input desc shape: 2, 2, 2, 2, 2, 2
// Example idx: 1, 1, 1, 1
// Example shape: 2, (2, 2), 2, (2, 2)
// Merged shape: 2, 4, 2, 4
template <typename... ShapeDims, typename... IdxDims, typename DescriptorToMerge> template <typename... ShapeDims, typename... IdxDims, typename DescriptorToMerge>
__host__ __device__ constexpr static auto __host__ __device__ constexpr static auto
MakeMerges(const Tuple<ShapeDims...>& shape, const Tuple<IdxDims...>&, DescriptorToMerge& desc) MakeMerges(const Tuple<ShapeDims...>& shape, const Tuple<IdxDims...>&, DescriptorToMerge& desc)
...@@ -206,6 +217,10 @@ struct Layout ...@@ -206,6 +217,10 @@ struct Layout
} }
else else
{ {
// Merge nested shape dims
// Example idx: (1, 1), 1, 1
// Example shape: (2, (2, 2)), 2, (2, 2)
// Merged shape: (2, 4), 2, 4
static_assert(Tuple<ShapeDims...>::Size() == Tuple<IdxDims...>::Size(), static_assert(Tuple<ShapeDims...>::Size() == Tuple<IdxDims...>::Size(),
"Idx rank and Shape rank must be the same (except 1d)."); "Idx rank and Shape rank must be the same (except 1d).");
// Unroll while IdxDims is nested // Unroll while IdxDims is nested
...@@ -268,20 +283,6 @@ struct Layout ...@@ -268,20 +283,6 @@ struct Layout
} }
} }
/**
* \brief Returns real offset to element as const in runtime.
*
* \tparam Idxs Tuple of indexes.
* \return Calculated offset as const.
*/
template <typename Idxs>
__host__ __device__ constexpr index_t operator()() const
{
using TransformedDesc = decltype(TransformDesc(Shape{}, Idxs{}));
using UnrolledIdx = decltype(UnrollNestedTuple(Idxs{}));
return TransformedDesc{}.CalculateOffset(UnrolledIdx{});
}
/** /**
* \brief Returns real offset to element in runtime. * \brief Returns real offset to element in runtime.
* *
...@@ -289,7 +290,7 @@ struct Layout ...@@ -289,7 +290,7 @@ struct Layout
* \return Calculated offset. * \return Calculated offset.
*/ */
template <typename Idxs> template <typename Idxs>
__host__ __device__ constexpr index_t operator()() __host__ __device__ constexpr index_t operator()() const
{ {
using TransformedDesc = decltype(TransformDesc(Shape{}, Idxs{})); using TransformedDesc = decltype(TransformDesc(Shape{}, Idxs{}));
using UnrolledIdx = decltype(UnrollNestedTuple(Idxs{})); using UnrolledIdx = decltype(UnrollNestedTuple(Idxs{}));
...@@ -310,28 +311,6 @@ struct Layout ...@@ -310,28 +311,6 @@ struct Layout
return transformed_desc.CalculateOffset(UnrollNestedTuple(Idx)); return transformed_desc.CalculateOffset(UnrollNestedTuple(Idx));
} }
/**
* \brief Length getter (product if tuple) as const.
*
* \tparam IDim Tuple of indexes or index.
* \return Calculated size.
*/
template <index_t IDim>
__host__ __device__ constexpr index_t GetLength() const
{
const auto elem = shape_.At(Number<IDim>{});
if constexpr(is_detected<is_tuple, tuple_element_t<IDim, Shape>>::value)
{
const auto unrolled_element = UnrollNestedTuple(elem);
return TupleReduce<I0.value, unrolled_element.Size()>(
[](auto x, auto y) { return x * y; }, unrolled_element);
}
else
{
return elem;
}
}
/** /**
* \brief Length getter (product if tuple). * \brief Length getter (product if tuple).
* *
...@@ -339,7 +318,7 @@ struct Layout ...@@ -339,7 +318,7 @@ struct Layout
* \return Calculated size. * \return Calculated size.
*/ */
template <index_t IDim> template <index_t IDim>
__host__ __device__ constexpr index_t GetLength() __host__ __device__ constexpr index_t GetLength() const
{ {
const auto elem = shape_.At(Number<IDim>{}); const auto elem = shape_.At(Number<IDim>{});
if constexpr(is_detected<is_tuple, tuple_element_t<IDim, Shape>>::value) if constexpr(is_detected<is_tuple, tuple_element_t<IDim, Shape>>::value)
...@@ -354,43 +333,18 @@ struct Layout ...@@ -354,43 +333,18 @@ struct Layout
} }
} }
/**
* \brief Layout size getter (product of shape) as const.
*
* \return Calculated size.
*/
__host__ __device__ constexpr index_t GetLength() const
{
const auto unrolled_shape = UnrollNestedTuple(shape_);
return TupleReduce<I0.value, unrolled_shape.Size()>([](auto x, auto y) { return x * y; },
unrolled_shape);
}
/** /**
* \brief Layout size getter (product of shape). * \brief Layout size getter (product of shape).
* *
* \return Calculated size. * \return Calculated size.
*/ */
__host__ __device__ constexpr index_t GetLength() __host__ __device__ constexpr index_t GetLength() const
{ {
const auto unrolled_shape = UnrollNestedTuple(shape_); const auto unrolled_shape = UnrollNestedTuple(shape_);
return TupleReduce<I0.value, unrolled_shape.Size()>([](auto x, auto y) { return x * y; }, return TupleReduce<I0.value, unrolled_shape.Size()>([](auto x, auto y) { return x * y; },
unrolled_shape); unrolled_shape);
} }
/**
* \brief Dimension getter as const.
*
* \tparam IDim Dimension idx.
* \return Calculated size.
*/
template <index_t IDim>
__host__ __device__ constexpr auto Get() const
{
const auto elem = shape_.At(Number<IDim>{});
return elem;
}
/** /**
* \brief Dimension getter. * \brief Dimension getter.
* *
...@@ -398,7 +352,7 @@ struct Layout ...@@ -398,7 +352,7 @@ struct Layout
* \return Calculated size. * \return Calculated size.
*/ */
template <index_t IDim> template <index_t IDim>
__host__ __device__ constexpr auto Get() __host__ __device__ constexpr auto Get() const
{ {
const auto elem = shape_.At(Number<IDim>{}); const auto elem = shape_.At(Number<IDim>{});
return elem; return elem;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment