Commit f124c7a1 authored by Bartlomiej Kocot's avatar Bartlomiej Kocot
Browse files

Comment fixes

parent 936b1d6c
...@@ -74,15 +74,15 @@ int main() ...@@ -74,15 +74,15 @@ int main()
std::cout << "dims:4,8 strides:1,4" << std::endl; std::cout << "dims:4,8 strides:1,4" << std::endl;
Print2d(desc_4x8_s1x4); Print2d(desc_4x8_s1x4);
using Cord0x0Type = ck::Tuple<ck::Number<0>, ck::Number<0>>; using Cord1x1Type = ck::Tuple<ck::Number<1>, ck::Number<1>>;
constexpr ck::index_t offset_0x0 = desc_4x8_s1x4.CalculateOffset(Cord0x0Type{}); constexpr ck::index_t offset_1x1 = desc_4x8_s1x4.CalculateOffset(Cord1x1Type{});
std::cout << "Constexpr calculated [0, 0] offset:" << offset_0x0 << std::endl; std::cout << "Constexpr calculated [1, 1] offset:" << offset_1x1 << std::endl;
// Basic descriptor 0, 1, 8, 9, 16, 17, ... 30, 31 (compile-time descriptor) // Basic descriptor 0, 1, 8, 9, 16, 17, ... 30, 31 (compile-time descriptor)
// dims:4,(2,4) strides:2,(1,8) // dims:4,(2,4) strides:2,(1,8)
const auto desc_4x2x4_s2x1x8 = const auto desc_4x2x4_s2x1x8 =
ck::make_naive_tensor_descriptor(ck::make_tuple(4, 2, 4), ck::make_tuple(2, 1, 8)); ck::make_naive_tensor_descriptor(ck::make_tuple(4, 2, 4), ck::make_tuple(2, 1, 8));
// Transform to 2d // Transform to 2d (column-major, need to to reverse dims)
const auto desc_4x2x4_s2x1x8_merged = ck::transform_tensor_descriptor( const auto desc_4x2x4_s2x1x8_merged = ck::transform_tensor_descriptor(
desc_4x2x4_s2x1x8, desc_4x2x4_s2x1x8,
ck::make_tuple(ck::make_pass_through_transform(4), ck::make_tuple(ck::make_pass_through_transform(4),
......
...@@ -71,9 +71,9 @@ int main() ...@@ -71,9 +71,9 @@ int main()
const auto layout_4x8_s1x4 = ck::tensor_transform_wrapper::make_layout(shape_4x8); const auto layout_4x8_s1x4 = ck::tensor_transform_wrapper::make_layout(shape_4x8);
std::cout << "dims:4,8 strides:1,4" << std::endl; std::cout << "dims:4,8 strides:1,4" << std::endl;
Print2d(layout_4x8_s1x4); Print2d(layout_4x8_s1x4);
using Cord0x0Type = ck::Tuple<ck::Number<0>, ck::Number<0>>; using Cord1x1Type = ck::Tuple<ck::Number<1>, ck::Number<1>>;
constexpr ck::index_t offset_0x0 = layout_4x8_s1x4.template operator()<Cord0x0Type>(); constexpr ck::index_t offset_1x1 = layout_4x8_s1x4.template operator()<Cord1x1Type>();
std::cout << "Constexpr calculated [0, 0] offset:" << offset_0x0 << std::endl; std::cout << "Constexpr calculated [1, 1] offset:" << offset_1x1 << std::endl;
// Basic descriptor 0, 1, 8, 9, 16, 17, ... 30, 31 (runtime descriptor) // Basic descriptor 0, 1, 8, 9, 16, 17, ... 30, 31 (runtime descriptor)
// dims:4,(2,4) strides:2,(1,8) // dims:4,(2,4) strides:2,(1,8)
......
...@@ -62,15 +62,18 @@ struct Layout ...@@ -62,15 +62,18 @@ struct Layout
Number<Tuple<Ts...>::Size()>{}); Number<Tuple<Ts...>::Size()>{});
} }
// Generate LowerDims in Compile-time for MergeTrasform using passed Type
// If element of Tuple<Ts...> is also tuple, then merge (generate sequence for merge)
// If tuple is element, then pass through (sequence with one element)
template <typename Idx, typename... Ts> template <typename Idx, typename... Ts>
__host__ __device__ constexpr static auto GenerateLowerDim(const Tuple<Ts...>& tuple) __host__ __device__ constexpr static auto GenerateLowerDim(const Tuple<Ts...>&)
{ {
if constexpr(Idx::value == 0) if constexpr(Idx::value == 0)
{ {
if constexpr(is_detected<is_tuple, tuple_element_t<Idx::value, Tuple<Ts...>>>::value) if constexpr(is_detected<is_tuple, tuple_element_t<Idx::value, Tuple<Ts...>>>::value)
{ {
constexpr index_t merge_nelems = constexpr index_t merge_nelems = decltype(UnrollNestedTuple(
decltype(UnrollNestedTuple(tuple.At(Idx{})))::Size(); tuple_element_t<Idx::value, Tuple<Ts...>>{}))::Size();
using LowerDimsSequence = using LowerDimsSequence =
typename arithmetic_sequence_gen<0, merge_nelems, 1>::type; typename arithmetic_sequence_gen<0, merge_nelems, 1>::type;
return LowerDimsSequence::Reverse(); return LowerDimsSequence::Reverse();
...@@ -82,12 +85,12 @@ struct Layout ...@@ -82,12 +85,12 @@ struct Layout
} }
else else
{ {
using PreviousSeqT = decltype(GenerateLowerDim<Number<Idx::value - 1>>(tuple)); using PreviousSeqT = decltype(GenerateLowerDim<Number<Idx::value - 1>>(Tuple<Ts...>{}));
const auto next_seq_val = PreviousSeqT::At(I0) + 1; const auto next_seq_val = PreviousSeqT::At(I0) + 1;
if constexpr(is_detected<is_tuple, tuple_element_t<Idx::value, Tuple<Ts...>>>::value) if constexpr(is_detected<is_tuple, tuple_element_t<Idx::value, Tuple<Ts...>>>::value)
{ {
constexpr index_t merge_nelems = constexpr index_t merge_nelems = decltype(UnrollNestedTuple(
decltype(UnrollNestedTuple(tuple.At(Idx{})))::Size(); tuple_element_t<Idx::value, Tuple<Ts...>>{}))::Size();
using LowerDimsSequence = using LowerDimsSequence =
typename arithmetic_sequence_gen<next_seq_val, next_seq_val + merge_nelems, 1>:: typename arithmetic_sequence_gen<next_seq_val, next_seq_val + merge_nelems, 1>::
type; type;
...@@ -100,11 +103,13 @@ struct Layout ...@@ -100,11 +103,13 @@ struct Layout
} }
} }
// Iterate over nested tuples in shape
// Unroll nested tuples to align Tuple<ShapeDims...> to Tuple<IdxDims...>
template <typename... ShapeDims, typename... IdxDims> template <typename... ShapeDims, typename... IdxDims>
__host__ __device__ constexpr static auto UnrollShapeViaIdx(const Tuple<ShapeDims...>& shape, __host__ __device__ constexpr static auto UnrollShapeViaIdx(const Tuple<ShapeDims...>& shape,
const Tuple<IdxDims...>& idx) const Tuple<IdxDims...>& idx)
{ {
if constexpr(!IsTupleNested(Tuple<IdxDims...>{})) if constexpr(!IsNestedTuple(Tuple<IdxDims...>{}))
{ {
// Index unrolled to flatten, return shape // Index unrolled to flatten, return shape
return shape; return shape;
...@@ -112,7 +117,7 @@ struct Layout ...@@ -112,7 +117,7 @@ struct Layout
else else
{ {
// Iterate over shape tuple elements: // Iterate over shape tuple elements:
// 1. If coressponding idx element is tuple then return (will be unrolled) // 1. If corresponding idx element is tuple then return (will be unrolled)
// 2. If no, pack in tuple. It will be restored during unroll. // 2. If no, pack in tuple. It will be restored during unroll.
auto unrolled_shape_via_idx = generate_tuple( auto unrolled_shape_via_idx = generate_tuple(
[&](auto i) { [&](auto i) {
...@@ -139,7 +144,7 @@ struct Layout ...@@ -139,7 +144,7 @@ struct Layout
DescriptorToMerge& desc) DescriptorToMerge& desc)
{ {
// Reverse each element in tuple // Reverse each element in tuple
using ReversedUnrolledShape = decltype(ReverseTuple(UnrollNestedTuple(shape))); using ReversedUnrolledShape = decltype(TupleReverse(UnrollNestedTuple(shape)));
const auto merge_elems = ReversedUnrolledShape{}; const auto merge_elems = ReversedUnrolledShape{};
// Generate reverted indexes (column major traverse) // Generate reverted indexes (column major traverse)
...@@ -165,7 +170,7 @@ struct Layout ...@@ -165,7 +170,7 @@ struct Layout
{ {
// If shape element is tuple and idx element is Number, then merge // If shape element is tuple and idx element is Number, then merge
// Unroll and reverse tuple to traverse column-major // Unroll and reverse tuple to traverse column-major
const auto merge_elems = ReverseTuple(UnrollNestedTuple(shape.At(i))); const auto merge_elems = TupleReverse(UnrollNestedTuple(shape.At(i)));
return make_merge_transform(merge_elems); return make_merge_transform(merge_elems);
} }
else else
......
...@@ -132,7 +132,7 @@ __host__ __device__ constexpr auto UnrollNestedTuple(const Tuple<Ts...>& tuple) ...@@ -132,7 +132,7 @@ __host__ __device__ constexpr auto UnrollNestedTuple(const Tuple<Ts...>& tuple)
} }
template <typename... Ts> template <typename... Ts>
__host__ __device__ constexpr auto ReverseTuple(const Tuple<Ts...>& tuple) __host__ __device__ constexpr auto TupleReverse(const Tuple<Ts...>& tuple)
{ {
return generate_tuple( return generate_tuple(
[&](auto i) { [&](auto i) {
...@@ -161,7 +161,7 @@ template <typename T> ...@@ -161,7 +161,7 @@ template <typename T>
using is_tuple = decltype(std::declval<T&>().IsTuple()); using is_tuple = decltype(std::declval<T&>().IsTuple());
template <typename... Ts> template <typename... Ts>
__host__ __device__ constexpr auto IsTupleNested(const Tuple<Ts...>&) __host__ __device__ constexpr auto IsNestedTuple(const Tuple<Ts...>&)
{ {
return (is_detected<is_tuple, Ts>::value || ...); return (is_detected<is_tuple, Ts>::value || ...);
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment