"...git@developer.sourcefind.cn:OpenDAS/mmdetection3d.git" did not exist on "df76bd32ae684c6cc818f2c3a930f7d60e3eb365"
Commit 1e276c57 authored by Bartlomiej Kocot
Browse files

Introduce wrapper for layout

parent e8cddfdc
# Example binaries: raw tensor-descriptor transforms, and the same
# descriptors built via the Layout wrapper introduced in this change.
add_example_executable(example_tensor_transform tensor_transform.cpp)
add_example_executable(example_tensor_transform_using_wrapper tensor_transform_using_wrapper.cpp)
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include "ck/ck.hpp"
#include "ck/utility/number.hpp"
#include "ck/utility/tuple.hpp"
#include "ck/utility/sequence.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/multi_index_transform_helper.hpp"
// Compile-time dimension indices passed to Desc::GetLength() below.
static constexpr auto I0 = ck::Number<0>{};
static constexpr auto I1 = ck::Number<1>{};
// Element type for the examples. NOTE(review): not referenced in the visible
// code of this file — presumably kept for symmetry with other examples.
using DataType = int;
template <typename Desc>
void Print(const Desc& desc)
{
for(ck::index_t h = 0; h < desc.GetLength(I0); h++)
{
for(ck::index_t w = 0; w < desc.GetLength(I1); w++)
{
std::cout << desc.CalculateOffset(ck::make_tuple(h, w)) << " ";
}
std::cout << std::endl;
}
}
// Demonstrates raw CK tensor descriptors: each example builds a naive
// (packed or explicitly strided) descriptor, then uses
// transform_tensor_descriptor() to merge/pass-through dimensions so the
// same memory can be addressed as a 2D tensor.
int main()
{
// Basic descriptor 0, 1, 2, ... 30, 31
// (dims:4,8 strides:1,1)
// Packed: strides are derived automatically for contiguous storage.
const auto desc_4x8_s1x1 = ck::make_naive_tensor_descriptor_packed(ck::make_tuple(4, 8));
std::cout << "dims:4,8 strides:1,1" << std::endl;
Print(desc_4x8_s1x1);
// Basic descriptor 0, 1, 8, 9, 16, 17, ... 30, 31
// dims:4,(4,2) strides:2,(8,1)
// 3D naive descriptor with explicit strides.
const auto desc_4x4x2_s2x8x1 =
ck::make_naive_tensor_descriptor(ck::make_tuple(4, 4, 2), ck::make_tuple(2, 8, 1));
// Transform to 2d: keep dim 0 as-is, merge lower dims 1 and 2 into
// upper dim 1 (lengths 4*2 = 8).
const auto desc_4x4x2_s2x8x1_merged = ck::transform_tensor_descriptor(
desc_4x4x2_s2x8x1,
ck::make_tuple(ck::make_pass_through_transform(4),
ck::make_merge_transform(ck::make_tuple(4, 2))),
ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1, 2>{}),
ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}));
std::cout << "dims:4,(4,2) strides:2,(8,1)" << std::endl;
Print(desc_4x4x2_s2x8x1_merged);
// Basic descriptor 0, 1, 8, 9, 16, 17, ... 30, 31
// dims:(2,2),(4,2) strides:(4,1),(8,2)
// 4D naive descriptor; the two dimension pairs will each be merged.
const auto desc_2x2x4x2_s4x1x8x2 =
ck::make_naive_tensor_descriptor(ck::make_tuple(2, 2, 4, 2), ck::make_tuple(4, 1, 8, 2));
// Transform to 2d: merge lower dims (0,1) -> upper dim 0 and
// lower dims (2,3) -> upper dim 1.
const auto desc_2x2x4x2_s4x1x8x2_double_merged = ck::transform_tensor_descriptor(
desc_2x2x4x2_s4x1x8x2,
ck::make_tuple(ck::make_merge_transform(ck::make_tuple(2, 2)),
ck::make_merge_transform(ck::make_tuple(4, 2))),
ck::make_tuple(ck::Sequence<0, 1>{}, ck::Sequence<2, 3>{}),
ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}));
std::cout << "dims:(2,2),(4,2) strides:(4,1),(8,2)" << std::endl;
Print(desc_2x2x4x2_s4x1x8x2_double_merged);
// Basic descriptor 0, 1, 8, 9, 16, 17, ... 30, 31
// dims:((2,2),4),2 strides:((4,1),8),2
// Transform to 2d in two stages, mimicking a nested shape:
// stage 1: merge (0,1) -> dim 0, pass dims 2 and 3 through (-> 4,4,2).
const auto desc_2x2x4x2_s4x1x8x2_merged = ck::transform_tensor_descriptor(
desc_2x2x4x2_s4x1x8x2,
ck::make_tuple(ck::make_merge_transform(ck::make_tuple(2, 2)),
ck::make_pass_through_transform(4),
ck::make_pass_through_transform(2)),
ck::make_tuple(ck::Sequence<0, 1>{}, ck::Sequence<2>{}, ck::Sequence<3>{}),
ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}, ck::Sequence<2>{}));
// stage 2: merge the first two of the intermediate dims (4,4) -> 16,
// pass the trailing 2 through, giving the final 16x2 view.
const auto desc_2x2x4x2_s4x1x8x2_nested_merged = ck::transform_tensor_descriptor(
desc_2x2x4x2_s4x1x8x2_merged,
ck::make_tuple(ck::make_merge_transform(ck::make_tuple(4, 4)),
ck::make_pass_through_transform(2)),
ck::make_tuple(ck::Sequence<0, 1>{}, ck::Sequence<2>{}),
ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}));
std::cout << "dims:((2,2),4),2 strides:((4,1),8),2" << std::endl;
Print(desc_2x2x4x2_s4x1x8x2_nested_merged);
return 0;
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include "ck/ck.hpp"
#include "ck/utility/number.hpp"
#include "ck/utility/tuple.hpp"
#include "ck/utility/sequence.hpp"
#include "tensor_transform_wrapper.hpp"
using DataType = int;
template <typename Layout>
void Print(const Layout& layout)
{
for(ck::index_t h = 0; h < ck::tensor_transform_wrapper::size<0>(layout); h++)
{
for(ck::index_t w = 0; w < ck::tensor_transform_wrapper::size<1>(layout); w++)
{
std::cout << layout(ck::make_tuple(h, w)) << " ";
}
std::cout << std::endl;
}
}
int main()
{
// Basic descriptor 0, 1, 2, ... 30, 31 (runtime descriptor)
// (dims:4,8 strides:1,1)
const auto shape_4x8 = ck::make_tuple(4, 8);
const auto layout_4x8_s1x1 = ck::tensor_transform_wrapper::make_layout(shape_4x8);
std::cout << "dims:4,8 strides:1,1" << std::endl;
Print(layout_4x8_s1x1);
// Basic descriptor 0, 1, 8, 9, 16, 17, ... 30, 31 (compile-time descriptor)
// dims:4,(4,2) strides:2,(8,1)
const auto shape_4x4x2 =
ck::make_tuple(ck::Number<4>{}, ck::make_tuple(ck::Number<4>{}, ck::Number<2>{}));
const auto strides_s2x8x1 =
ck::make_tuple(ck::Number<2>{}, ck::make_tuple(ck::Number<8>{}, ck::Number<1>{}));
const auto layout_4x4x2_s2x8x1 =
ck::tensor_transform_wrapper::make_layout(shape_4x4x2, strides_s2x8x1);
std::cout << "dims:4,(4,2) strides:2,(8,1)" << std::endl;
Print(layout_4x4x2_s2x8x1);
// Basic descriptor 0, 1, 8, 9, 16, 17, ... 30, 31 (compile-time descriptor)
// dims:(2,2),(4,2) strides:((4,1),(8,2)
const auto shape_2x2x4x2 = ck::make_tuple(ck::make_tuple(ck::Number<2>{}, ck::Number<2>{}),
ck::make_tuple(ck::Number<4>{}, ck::Number<2>{}));
const auto strides_s4x1x8x2 = ck::make_tuple(ck::make_tuple(ck::Number<4>{}, ck::Number<1>{}),
ck::make_tuple(ck::Number<8>{}, ck::Number<2>{}));
static const auto layout_2x2x4x2_s4x1x8x2 =
ck::tensor_transform_wrapper::make_layout(shape_2x2x4x2, strides_s4x1x8x2);
std::cout << "dims:(2,2),(4,2) strides:(4,1),(8,2)" << std::endl;
Print(layout_2x2x4x2_s4x1x8x2);
// Basic descriptor 0, 1, 8, 9, 16, 17, ... 30, 31 (compile-time descriptor)
// dims:((2,2),4),2 strides:((4,1),8),2
// Transform to 2d
const auto shape_2x2x4x2_nested = ck::make_tuple(
ck::make_tuple(ck::make_tuple(ck::Number<2>{}, ck::Number<2>{}), ck::Number<4>{}),
ck::Number<2>{});
const auto strides_s4x1x8x2_nested = ck::make_tuple(
ck::make_tuple(ck::make_tuple(ck::Number<4>{}, ck::Number<1>{}), ck::Number<8>{}),
ck::Number<2>{});
static const auto layout_2x2x4x2_s4x1x8x2_nested =
ck::tensor_transform_wrapper::make_layout(shape_2x2x4x2_nested, strides_s4x1x8x2_nested);
std::cout << "dims:((2,2),4),2 strides:((4,1),8),2" << std::endl;
Print(layout_2x2x4x2_s4x1x8x2_nested);
return 0;
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/ck.hpp"
#include "ck/utility/number.hpp"
#include "ck/utility/tuple.hpp"
#include "ck/utility/tuple_helper.hpp"
#include "ck/utility/sequence.hpp"
#include "ck/utility/sequence_helper.hpp"
#include "ck/utility/is_detected.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/multi_index_transform_helper.hpp"
namespace ck {
namespace tensor_transform_wrapper {
/**
* \brief Layout wrapper
*
* \details
* Layout wrapper that performs the tensor descriptor logic.
*
* \tparam Shape Tuple of Number<> (for compile-time layout) or index_t
* (dynamic layout). It is possible to pass nested shapes
* (e.g. ((4, 2), 2)), nested dimensions are merged.
* \tparam Strides Tuple of Number<> (for compile-time layout) or index_t
* (dynamic layout). Stride tuple should be nested if shape tuple is
* nested.
*/
template <typename Shape, typename Strides = Tuple<>>
struct Layout
{
private:
// Detection expression: Tuple-like types in CK expose IsTuple(). Used with
// is_detected below to distinguish a nested group of dimensions from a
// single (scalar) dimension entry.
template <typename T>
using is_tuple = decltype(std::declval<T&>().IsTuple());
// Builds the lower-dimension Sequence<> consumed by the Idx-th entry of the
// (possibly nested) shape tuple: a nested entry consumes as many unrolled
// dimensions as it contains, a scalar entry consumes exactly one. The
// numbering continues where the previous entry (computed recursively via
// Idx - 1) left off, so the sequences partition [0, total_unrolled_dims).
template <typename Tuple, typename Idx>
constexpr static auto GenerateLowerDim(Tuple tuple)
{
if constexpr(Idx::value == 0)
{
if constexpr(is_detected<is_tuple, tuple_element_t<Idx::value, Tuple>>::value)
{
// First entry is nested: its merged dims start at 0.
constexpr index_t merge_nelems =
decltype(UnrollNestedTuple(tuple.At(Idx{})))::Size();
return typename arithmetic_sequence_gen<0, merge_nelems, 1>::type{};
}
else
{
return Sequence<0>{};
}
}
else
{
// Continue numbering one past the last index used by entry Idx - 1.
using PreviousSeqT = decltype(GenerateLowerDim<Tuple, Number<Idx::value - 1>>(tuple));
const auto next_seq_val = PreviousSeqT::At(PreviousSeqT::Size() - 1) + 1;
if constexpr(is_detected<is_tuple, tuple_element_t<Idx::value, Tuple>>::value)
{
constexpr index_t merge_nelems =
decltype(UnrollNestedTuple(tuple.At(Idx{})))::Size();
return typename arithmetic_sequence_gen<next_seq_val,
next_seq_val + merge_nelems,
1>::type{};
}
else
{
return Sequence<next_seq_val>{};
}
}
}
// Transforms the flat (unrolled) naive descriptor back to the user-visible
// rank: each nested tuple in `tuple` becomes one merge transform over its
// unrolled dims, each scalar entry a pass-through. Upper dim i maps 1:1 to
// entry i of the shape tuple.
template <typename Tuple, typename Descriptor>
constexpr static auto MakeMerges(const Tuple& tuple, Descriptor& desc)
{
const auto transforms = generate_tuple(
[&](auto i) {
if constexpr(is_detected<is_tuple, tuple_element_t<i, Tuple>>::value)
{
const auto merge_elems = UnrollNestedTuple(tuple.At(i));
return make_merge_transform(merge_elems);
}
else
{
return make_pass_through_transform(tuple.At(i));
}
},
Number<Tuple::Size()>{});
const auto lower_dims =
generate_tuple([&](auto i) { return GenerateLowerDim<Tuple, Number<i>>(tuple); },
Number<Tuple::Size()>{});
const auto upper_dims =
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<Tuple::Size()>{});
return transform_tensor_descriptor(desc, transforms, lower_dims, upper_dims);
}
// Flattens nested shape/strides, builds the naive descriptor (packed when
// Strides is the empty-tuple default), then re-applies the nesting as merge
// transforms via MakeMerges().
template <typename LayoutShape, typename LayoutStrides>
static auto MakeDescriptor(const LayoutShape shape, const LayoutStrides strides)
{
const auto unrolled_shape = UnrollNestedTuple(shape);
const auto unrolled_strides = UnrollNestedTuple(strides);
if constexpr(ck::is_same_v<LayoutStrides, Tuple<>>)
{
// No strides provided: packed (contiguous) descriptor.
const auto desc = make_naive_tensor_descriptor_packed(unrolled_shape);
return MakeMerges(shape, desc);
}
else
{
static_assert(unrolled_shape.Size() == unrolled_strides.Size(),
"Size of strides and shape are not consistent.");
const auto desc = make_naive_tensor_descriptor(unrolled_shape, unrolled_strides);
return MakeMerges(shape, desc);
}
}
public:
// Concrete descriptor type produced for this Shape/Strides combination;
// fully compile-time when Shape/Strides are tuples of Number<>.
using Descriptor = remove_cvref_t<decltype(MakeDescriptor(Shape{}, Strides{}))>;
/**
 * \brief Layout constructors.
 *
 * For a compile-time layout (Descriptor::IsKnownAtCompileTime()) the
 * default-constructed descriptor already carries all information and the
 * runtime arguments are not needed; otherwise the descriptor is built
 * from the given shape (and strides).
 *
 * \param shape Shape for layout.
 * \param strides Strides for layout (optional if tensor is packed).
 */
__host__ __device__ Layout() = delete;
__host__ __device__ Layout(const Shape shape, const Strides strides) : descriptor_{}
{
if constexpr(!Descriptor::IsKnownAtCompileTime())
{
descriptor_ = MakeDescriptor(shape, strides);
}
}
__host__ __device__ Layout(const Shape shape) : descriptor_{}
{
if constexpr(!Descriptor::IsKnownAtCompileTime())
{
descriptor_ = MakeDescriptor(shape, Strides{});
}
}
// Returns real offset to element: linear offset for a multi-index given as
// a ck tuple, e.g. layout(make_tuple(h, w)).
template <typename Tuple>
__host__ __device__ constexpr index_t operator()(const Tuple Idx) const
{
return descriptor_.CalculateOffset(Idx);
}
// Non-const overload mirroring the const one.
// NOTE(review): appears redundant — the const overload is callable on
// non-const objects too; candidate for removal.
template <typename Tuple>
__host__ __device__ constexpr index_t operator()(const Tuple Idx)
{
return descriptor_.CalculateOffset(Idx);
}
// Upper dim getter: length of the IDim-th user-visible (merged) dimension.
template <index_t IDim>
__host__ __device__ constexpr index_t GetLength() const
{
return descriptor_.GetLength(Number<IDim>{});
}
// Non-const overload mirroring the const one (see note on operator()).
template <index_t IDim>
__host__ __device__ constexpr index_t GetLength()
{
return descriptor_.GetLength(Number<IDim>{});
}
private:
Descriptor descriptor_;
};
// Upper dim getter
/**
 * \brief Free-function accessor for the length of layout dimension \p idx.
 *
 * \tparam idx Index of the (merged, user-visible) dimension.
 * \param layout Layout to query; taken by const reference — the previous
 *        by-value signature copied the whole Layout (including its
 *        descriptor for runtime layouts) on every call.
 * \return Length of the idx-th dimension.
 */
template <index_t idx, typename L>
index_t size(const L& layout)
{
return layout.template GetLength<idx>();
}
/**
 * \brief Create a Layout from a (possibly nested) shape and matching strides.
 *
 * \param shape Shape tuple (Number<> for compile-time, index_t for runtime).
 * \param strides Stride tuple with the same nesting as \p shape.
 * \return Layout wrapping the corresponding tensor descriptor.
 */
template <typename Shape, typename Strides>
Layout<Shape, Strides> make_layout(const Shape& shape, const Strides& strides)
{
using LayoutT = Layout<Shape, Strides>;
return LayoutT{shape, strides};
}
/**
 * \brief Create a packed Layout from a (possibly nested) shape.
 *
 * Strides are omitted, so the underlying descriptor is contiguous (packed).
 *
 * \param shape Shape tuple (Number<> for compile-time, index_t for runtime).
 * \return Packed Layout wrapping the corresponding tensor descriptor.
 */
template <typename Shape>
Layout<Shape> make_layout(const Shape& shape)
{
using LayoutT = Layout<Shape>;
return LayoutT{shape};
}
} // namespace tensor_transform_wrapper
} // namespace ck
...@@ -33,6 +33,21 @@ __host__ __device__ constexpr auto concat_tuple_of_reference(const Tuple<X&...>&
ty);
}
// Concatenates two Tuples into one flat Tuple holding the elements by value.
// (Counterpart of concat_tuple_of_reference above, which keeps references.)
template <typename... X, typename... Y>
__host__ __device__ constexpr auto concat_tuple(const Tuple<X...>& tx, const Tuple<Y...>& ty)
{
return unpack2(
[&](auto... zs) { return Tuple<decltype(zs)...>{std::forward<decltype(zs)>(zs)...}; },
tx,
ty);
}
// Variadic overload: folds right-to-left, concatenating the trailing tuples
// first and then prepending tx via the binary overload above.
template <typename... X, typename... Tuples>
__host__ __device__ constexpr auto concat_tuple(const Tuple<X...>& tx, const Tuples&... tuples)
{
return concat_tuple(tx, concat_tuple(tuples...));
}
namespace detail {
template <typename F, typename X, index_t... Is>
...@@ -78,4 +93,18 @@ __host__ __device__ constexpr auto transform_tuples(F f, const X& x, const Y& y,
f, x, y, z, typename arithmetic_sequence_gen<0, X::Size(), 1>::type{});
}
// Recursively flattens an arbitrarily nested Tuple into a single flat Tuple,
// e.g. Tuple<Tuple<A, B>, C> -> Tuple<A, B, C>.
// Base case: a non-Tuple element becomes a one-element Tuple.
template <typename T>
__host__ __device__ constexpr auto UnrollNestedTuple(const T& element)
{
return make_tuple(element);
}
// Base case: an empty Tuple stays empty.
__host__ __device__ constexpr auto UnrollNestedTuple(const Tuple<>& element) { return element; }
// Recursive case: unroll each element, then concatenate the results.
template <typename... Ts>
__host__ __device__ constexpr auto UnrollNestedTuple(const Tuple<Ts...>& tuple)
{
return unpack([&](auto&&... ts) { return concat_tuple(UnrollNestedTuple(ts)...); }, tuple);
}
} // namespace ck
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment