Commit e43359fe authored by Bartlomiej Kocot's avatar Bartlomiej Kocot
Browse files

Add comments and remove not needed getters

parent f124c7a1
......@@ -72,6 +72,7 @@ struct Layout
{
if constexpr(is_detected<is_tuple, tuple_element_t<Idx::value, Tuple<Ts...>>>::value)
{
// Return Sequence for the first tuple
constexpr index_t merge_nelems = decltype(UnrollNestedTuple(
tuple_element_t<Idx::value, Tuple<Ts...>>{}))::Size();
using LowerDimsSequence =
......@@ -80,11 +81,13 @@ struct Layout
}
else
{
// Return first element
return Sequence<0>{};
}
}
else
{
// Get previous element using recurence (in compile-time)
using PreviousSeqT = decltype(GenerateLowerDim<Number<Idx::value - 1>>(Tuple<Ts...>{}));
const auto next_seq_val = PreviousSeqT::At(I0) + 1;
if constexpr(is_detected<is_tuple, tuple_element_t<Idx::value, Tuple<Ts...>>>::value)
......@@ -105,6 +108,9 @@ struct Layout
// Iterate over nested tuples in shape
// Unroll nested tuples to align Tuple<ShapeDims...> to Tuple<IdxDims...>
// Example idx: (1, 1), 1, 1
// Example shape: (2, (2, 2)), 2, (2, 2)
// Unrolled shape: 2, (2, 2), 2, (2, 2)
template <typename... ShapeDims, typename... IdxDims>
__host__ __device__ constexpr static auto UnrollShapeViaIdx(const Tuple<ShapeDims...>& shape,
const Tuple<IdxDims...>& idx)
......@@ -157,6 +163,11 @@ struct Layout
desc, make_tuple(make_merge_transform(merge_elems)), lower_dims, upper_dims);
}
// Merge nested shape dims
// Input desc shape: 2, 2, 2, 2, 2, 2
// Example idx: 1, 1, 1, 1
// Example shape: 2, (2, 2), 2, (2, 2)
// Merged shape: 2, 4, 2, 4
template <typename... ShapeDims, typename... IdxDims, typename DescriptorToMerge>
__host__ __device__ constexpr static auto
MakeMerges(const Tuple<ShapeDims...>& shape, const Tuple<IdxDims...>&, DescriptorToMerge& desc)
......@@ -206,6 +217,10 @@ struct Layout
}
else
{
// Merge nested shape dims
// Example idx: (1, 1), 1, 1
// Example shape: (2, (2, 2)), 2, (2, 2)
// Merged shape: (2, 4), 2, 4
static_assert(Tuple<ShapeDims...>::Size() == Tuple<IdxDims...>::Size(),
"Idx rank and Shape rank must be the same (except 1d).");
// Unroll while IdxDims is nested
......@@ -268,20 +283,6 @@ struct Layout
}
}
/**
* \brief Returns real offset to element as const in runtime.
*
* \tparam Idxs Tuple of indexes.
* \return Calculated offset as const.
*/
template <typename Idxs>
__host__ __device__ constexpr index_t operator()() const
{
using TransformedDesc = decltype(TransformDesc(Shape{}, Idxs{}));
using UnrolledIdx = decltype(UnrollNestedTuple(Idxs{}));
return TransformedDesc{}.CalculateOffset(UnrolledIdx{});
}
/**
* \brief Returns real offset to element in runtime.
*
......@@ -289,7 +290,7 @@ struct Layout
* \return Calculated offset.
*/
template <typename Idxs>
__host__ __device__ constexpr index_t operator()()
__host__ __device__ constexpr index_t operator()() const
{
using TransformedDesc = decltype(TransformDesc(Shape{}, Idxs{}));
using UnrolledIdx = decltype(UnrollNestedTuple(Idxs{}));
......@@ -310,28 +311,6 @@ struct Layout
return transformed_desc.CalculateOffset(UnrollNestedTuple(Idx));
}
/**
* \brief Length getter (product if tuple) as const.
*
* \tparam IDim Tuple of indexes or index.
* \return Calculated size.
*/
template <index_t IDim>
__host__ __device__ constexpr index_t GetLength() const
{
const auto elem = shape_.At(Number<IDim>{});
if constexpr(is_detected<is_tuple, tuple_element_t<IDim, Shape>>::value)
{
const auto unrolled_element = UnrollNestedTuple(elem);
return TupleReduce<I0.value, unrolled_element.Size()>(
[](auto x, auto y) { return x * y; }, unrolled_element);
}
else
{
return elem;
}
}
/**
* \brief Length getter (product if tuple).
*
......@@ -339,7 +318,7 @@ struct Layout
* \return Calculated size.
*/
template <index_t IDim>
__host__ __device__ constexpr index_t GetLength()
__host__ __device__ constexpr index_t GetLength() const
{
const auto elem = shape_.At(Number<IDim>{});
if constexpr(is_detected<is_tuple, tuple_element_t<IDim, Shape>>::value)
......@@ -354,43 +333,18 @@ struct Layout
}
}
/**
* \brief Layout size getter (product of shape) as const.
*
* \return Calculated size.
*/
__host__ __device__ constexpr index_t GetLength() const
{
const auto unrolled_shape = UnrollNestedTuple(shape_);
return TupleReduce<I0.value, unrolled_shape.Size()>([](auto x, auto y) { return x * y; },
unrolled_shape);
}
/**
* \brief Layout size getter (product of shape).
*
* \return Calculated size.
*/
__host__ __device__ constexpr index_t GetLength()
__host__ __device__ constexpr index_t GetLength() const
{
const auto unrolled_shape = UnrollNestedTuple(shape_);
return TupleReduce<I0.value, unrolled_shape.Size()>([](auto x, auto y) { return x * y; },
unrolled_shape);
}
/**
* \brief Dimension getter as const.
*
* \tparam IDim Dimension idx.
* \return Calculated size.
*/
template <index_t IDim>
__host__ __device__ constexpr auto Get() const
{
const auto elem = shape_.At(Number<IDim>{});
return elem;
}
/**
* \brief Dimension getter.
*
......@@ -398,7 +352,7 @@ struct Layout
* \return Calculated size.
*/
template <index_t IDim>
__host__ __device__ constexpr auto Get()
__host__ __device__ constexpr auto Get() const
{
const auto elem = shape_.At(Number<IDim>{});
return elem;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment