Commit f124c7a1 authored by Bartlomiej Kocot's avatar Bartlomiej Kocot
Browse files

Comment fixes

parent 936b1d6c
......@@ -74,15 +74,15 @@ int main()
std::cout << "dims:4,8 strides:1,4" << std::endl;
Print2d(desc_4x8_s1x4);
using Cord0x0Type = ck::Tuple<ck::Number<0>, ck::Number<0>>;
constexpr ck::index_t offset_0x0 = desc_4x8_s1x4.CalculateOffset(Cord0x0Type{});
std::cout << "Constexpr calculated [0, 0] offset:" << offset_0x0 << std::endl;
using Cord1x1Type = ck::Tuple<ck::Number<1>, ck::Number<1>>;
constexpr ck::index_t offset_1x1 = desc_4x8_s1x4.CalculateOffset(Cord1x1Type{});
std::cout << "Constexpr calculated [1, 1] offset:" << offset_1x1 << std::endl;
// Basic descriptor 0, 1, 8, 9, 16, 17, ... 30, 31 (compile-time descriptor)
// dims:4,(2,4) strides:2,(1,8)
const auto desc_4x2x4_s2x1x8 =
ck::make_naive_tensor_descriptor(ck::make_tuple(4, 2, 4), ck::make_tuple(2, 1, 8));
// Transform to 2d
// Transform to 2d (column-major, need to to reverse dims)
const auto desc_4x2x4_s2x1x8_merged = ck::transform_tensor_descriptor(
desc_4x2x4_s2x1x8,
ck::make_tuple(ck::make_pass_through_transform(4),
......
......@@ -71,9 +71,9 @@ int main()
const auto layout_4x8_s1x4 = ck::tensor_transform_wrapper::make_layout(shape_4x8);
std::cout << "dims:4,8 strides:1,4" << std::endl;
Print2d(layout_4x8_s1x4);
using Cord0x0Type = ck::Tuple<ck::Number<0>, ck::Number<0>>;
constexpr ck::index_t offset_0x0 = layout_4x8_s1x4.template operator()<Cord0x0Type>();
std::cout << "Constexpr calculated [0, 0] offset:" << offset_0x0 << std::endl;
using Cord1x1Type = ck::Tuple<ck::Number<1>, ck::Number<1>>;
constexpr ck::index_t offset_1x1 = layout_4x8_s1x4.template operator()<Cord1x1Type>();
std::cout << "Constexpr calculated [1, 1] offset:" << offset_1x1 << std::endl;
// Basic descriptor 0, 1, 8, 9, 16, 17, ... 30, 31 (runtime descriptor)
// dims:4,(2,4) strides:2,(1,8)
......
......@@ -62,15 +62,18 @@ struct Layout
Number<Tuple<Ts...>::Size()>{});
}
// Generate LowerDims in Compile-time for MergeTrasform using passed Type
// If element of Tuple<Ts...> is also tuple, then merge (generate sequence for merge)
// If tuple is element, then pass through (sequence with one element)
template <typename Idx, typename... Ts>
__host__ __device__ constexpr static auto GenerateLowerDim(const Tuple<Ts...>& tuple)
__host__ __device__ constexpr static auto GenerateLowerDim(const Tuple<Ts...>&)
{
if constexpr(Idx::value == 0)
{
if constexpr(is_detected<is_tuple, tuple_element_t<Idx::value, Tuple<Ts...>>>::value)
{
constexpr index_t merge_nelems =
decltype(UnrollNestedTuple(tuple.At(Idx{})))::Size();
constexpr index_t merge_nelems = decltype(UnrollNestedTuple(
tuple_element_t<Idx::value, Tuple<Ts...>>{}))::Size();
using LowerDimsSequence =
typename arithmetic_sequence_gen<0, merge_nelems, 1>::type;
return LowerDimsSequence::Reverse();
......@@ -82,12 +85,12 @@ struct Layout
}
else
{
using PreviousSeqT = decltype(GenerateLowerDim<Number<Idx::value - 1>>(tuple));
using PreviousSeqT = decltype(GenerateLowerDim<Number<Idx::value - 1>>(Tuple<Ts...>{}));
const auto next_seq_val = PreviousSeqT::At(I0) + 1;
if constexpr(is_detected<is_tuple, tuple_element_t<Idx::value, Tuple<Ts...>>>::value)
{
constexpr index_t merge_nelems =
decltype(UnrollNestedTuple(tuple.At(Idx{})))::Size();
constexpr index_t merge_nelems = decltype(UnrollNestedTuple(
tuple_element_t<Idx::value, Tuple<Ts...>>{}))::Size();
using LowerDimsSequence =
typename arithmetic_sequence_gen<next_seq_val, next_seq_val + merge_nelems, 1>::
type;
......@@ -100,11 +103,13 @@ struct Layout
}
}
// Iterate over nested tuples in shape
// Unroll nested tuples to align Tuple<ShapeDims...> to Tuple<IdxDims...>
template <typename... ShapeDims, typename... IdxDims>
__host__ __device__ constexpr static auto UnrollShapeViaIdx(const Tuple<ShapeDims...>& shape,
const Tuple<IdxDims...>& idx)
{
if constexpr(!IsTupleNested(Tuple<IdxDims...>{}))
if constexpr(!IsNestedTuple(Tuple<IdxDims...>{}))
{
// Index unrolled to flatten, return shape
return shape;
......@@ -112,7 +117,7 @@ struct Layout
else
{
// Iterate over shape tuple elements:
// 1. If coressponding idx element is tuple then return (will be unrolled)
// 1. If corresponding idx element is tuple then return (will be unrolled)
// 2. If no, pack in tuple. It will be restored during unroll.
auto unrolled_shape_via_idx = generate_tuple(
[&](auto i) {
......@@ -139,7 +144,7 @@ struct Layout
DescriptorToMerge& desc)
{
// Reverse each element in tuple
using ReversedUnrolledShape = decltype(ReverseTuple(UnrollNestedTuple(shape)));
using ReversedUnrolledShape = decltype(TupleReverse(UnrollNestedTuple(shape)));
const auto merge_elems = ReversedUnrolledShape{};
// Generate reverted indexes (column major traverse)
......@@ -165,7 +170,7 @@ struct Layout
{
// If shape element is tuple and idx element is Number, then merge
// Unroll and reverse tuple to traverse column-major
const auto merge_elems = ReverseTuple(UnrollNestedTuple(shape.At(i)));
const auto merge_elems = TupleReverse(UnrollNestedTuple(shape.At(i)));
return make_merge_transform(merge_elems);
}
else
......
......@@ -132,7 +132,7 @@ __host__ __device__ constexpr auto UnrollNestedTuple(const Tuple<Ts...>& tuple)
}
template <typename... Ts>
__host__ __device__ constexpr auto ReverseTuple(const Tuple<Ts...>& tuple)
__host__ __device__ constexpr auto TupleReverse(const Tuple<Ts...>& tuple)
{
return generate_tuple(
[&](auto i) {
......@@ -161,7 +161,7 @@ template <typename T>
using is_tuple = decltype(std::declval<T&>().IsTuple());
template <typename... Ts>
__host__ __device__ constexpr auto IsTupleNested(const Tuple<Ts...>&)
__host__ __device__ constexpr auto IsNestedTuple(const Tuple<Ts...>&)
{
return (is_detected<is_tuple, Ts>::value || ...);
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment