Unverified Commit 836b7e55 authored by Bartłomiej Kocot's avatar Bartłomiej Kocot Committed by GitHub
Browse files

Introduce wrapper library (#1071)

* Introduce wrapper library

* Update cmake files

* Revert "Update cmake files"

This reverts commit c27f88b56590c11a88e26d5d0df7aca51a08133d.

* Fix comments
parent f60cd9d7
...@@ -19,6 +19,7 @@ None ...@@ -19,6 +19,7 @@ None
- Support for NHWGC (2D and 3D) grouped convolution backward weight (#769 #804) - Support for NHWGC (2D and 3D) grouped convolution backward weight (#769 #804)
- Support for bf16/f32/f16 and NHWGC (2D and 3D) grouped convolution backward data (#757 #799) - Support for bf16/f32/f16 and NHWGC (2D and 3D) grouped convolution backward data (#757 #799)
- Support for Batched Gemm DL (#732) - Support for Batched Gemm DL (#732)
- Introduce wrapper sublibrary (limited functionality) (#1071)
### Changes ### Changes
- Changed the grouped convolution API to maintain consistency with other convolution kernels (#817) - Changed the grouped convolution API to maintain consistency with other convolution kernels (#817)
......
# Client example built from tensor_transform.cpp.
add_executable(client_tensor_transform tensor_transform.cpp)
target_link_libraries(client_tensor_transform PRIVATE composable_kernel::device_other_operations)
# Client example built from tensor_transform_using_wrapper.cpp; it links the
# same library target as the example above.
add_executable(client_tensor_transform_using_wrapper tensor_transform_using_wrapper.cpp)
target_link_libraries(client_tensor_transform_using_wrapper PRIVATE composable_kernel::device_other_operations)
...@@ -9,7 +9,7 @@ ...@@ -9,7 +9,7 @@
#include "ck/utility/tuple.hpp" #include "ck/utility/tuple.hpp"
#include "ck/utility/sequence.hpp" #include "ck/utility/sequence.hpp"
#include "tensor_transform_wrapper.hpp" #include "ck/wrapper/layout.hpp"
using DataType = int; using DataType = int;
...@@ -17,7 +17,7 @@ template <typename Layout> ...@@ -17,7 +17,7 @@ template <typename Layout>
void Print1d(const Layout& layout) void Print1d(const Layout& layout)
{ {
std::cout << "Print1d" << std::endl; std::cout << "Print1d" << std::endl;
for(ck::index_t w = 0; w < ck::tensor_transform_wrapper::size(layout); w++) for(ck::index_t w = 0; w < ck::wrapper::size(layout); w++)
{ {
std::cout << layout(ck::make_tuple(w)) << " "; std::cout << layout(ck::make_tuple(w)) << " ";
} }
...@@ -28,9 +28,9 @@ template <typename Layout> ...@@ -28,9 +28,9 @@ template <typename Layout>
void Print2d(const Layout& layout) void Print2d(const Layout& layout)
{ {
std::cout << "Print2d" << std::endl; std::cout << "Print2d" << std::endl;
for(ck::index_t h = 0; h < ck::tensor_transform_wrapper::size<0>(layout); h++) for(ck::index_t h = 0; h < ck::wrapper::size<0>(layout); h++)
{ {
for(ck::index_t w = 0; w < ck::tensor_transform_wrapper::size<1>(layout); w++) for(ck::index_t w = 0; w < ck::wrapper::size<1>(layout); w++)
{ {
std::cout << layout(ck::make_tuple(h, w)) << " "; std::cout << layout(ck::make_tuple(h, w)) << " ";
} }
...@@ -43,15 +43,11 @@ template <typename Layout> ...@@ -43,15 +43,11 @@ template <typename Layout>
void Print3dCustom(const Layout& layout) void Print3dCustom(const Layout& layout)
{ {
std::cout << "Print3dCustom" << std::endl; std::cout << "Print3dCustom" << std::endl;
for(ck::index_t d = 0; for(ck::index_t d = 0; d < ck::wrapper::size<0>(ck::wrapper::get<0>(layout)); d++)
d < ck::tensor_transform_wrapper::size<0>(ck::tensor_transform_wrapper::get<0>(layout));
d++)
{ {
for(ck::index_t h = 0; for(ck::index_t h = 0; h < ck::wrapper::size<1>(ck::wrapper::get<0>(layout)); h++)
h < ck::tensor_transform_wrapper::size<1>(ck::tensor_transform_wrapper::get<0>(layout));
h++)
{ {
for(ck::index_t w = 0; w < ck::tensor_transform_wrapper::size<1>(layout); w++) for(ck::index_t w = 0; w < ck::wrapper::size<1>(layout); w++)
{ {
std::cout << layout(ck::make_tuple(ck::make_tuple(d, h), w)) << " "; std::cout << layout(ck::make_tuple(ck::make_tuple(d, h), w)) << " ";
} }
...@@ -68,7 +64,7 @@ int main() ...@@ -68,7 +64,7 @@ int main()
// Basic descriptor 0, 1, 2, ... 30, 31 (compile-time descriptor) // Basic descriptor 0, 1, 2, ... 30, 31 (compile-time descriptor)
// (dims:4,8 strides:1,4) // (dims:4,8 strides:1,4)
const auto shape_4x8 = ck::make_tuple(ck::Number<4>{}, ck::Number<8>{}); const auto shape_4x8 = ck::make_tuple(ck::Number<4>{}, ck::Number<8>{});
const auto layout_4x8_s1x4 = ck::tensor_transform_wrapper::make_layout(shape_4x8); const auto layout_4x8_s1x4 = ck::wrapper::make_layout(shape_4x8);
std::cout << "dims:4,8 strides:1,4" << std::endl; std::cout << "dims:4,8 strides:1,4" << std::endl;
Print2d(layout_4x8_s1x4); Print2d(layout_4x8_s1x4);
using Cord1x1Type = ck::Tuple<ck::Number<1>, ck::Number<1>>; using Cord1x1Type = ck::Tuple<ck::Number<1>, ck::Number<1>>;
...@@ -79,8 +75,7 @@ int main() ...@@ -79,8 +75,7 @@ int main()
// dims:4,(2,4) strides:2,(1,8) // dims:4,(2,4) strides:2,(1,8)
const auto shape_4x2x4 = ck::make_tuple(4, ck::make_tuple(2, 4)); const auto shape_4x2x4 = ck::make_tuple(4, ck::make_tuple(2, 4));
const auto strides_s2x1x8 = ck::make_tuple(2, ck::make_tuple(1, 8)); const auto strides_s2x1x8 = ck::make_tuple(2, ck::make_tuple(1, 8));
const auto layout_4x2x4_s2x1x8 = const auto layout_4x2x4_s2x1x8 = ck::wrapper::make_layout(shape_4x2x4, strides_s2x1x8);
ck::tensor_transform_wrapper::make_layout(shape_4x2x4, strides_s2x1x8);
std::cout << "dims:4,(2,4) strides:2,(1,8)" << std::endl; std::cout << "dims:4,(2,4) strides:2,(1,8)" << std::endl;
Print2d(layout_4x2x4_s2x1x8); Print2d(layout_4x2x4_s2x1x8);
...@@ -92,7 +87,7 @@ int main() ...@@ -92,7 +87,7 @@ int main()
const auto strides_s1x4x2x8 = ck::make_tuple(ck::make_tuple(ck::Number<1>{}, ck::Number<4>{}), const auto strides_s1x4x2x8 = ck::make_tuple(ck::make_tuple(ck::Number<1>{}, ck::Number<4>{}),
ck::make_tuple(ck::Number<2>{}, ck::Number<8>{})); ck::make_tuple(ck::Number<2>{}, ck::Number<8>{}));
static const auto layout_2x2x2x4_s1x4x2x8 = static const auto layout_2x2x2x4_s1x4x2x8 =
ck::tensor_transform_wrapper::make_layout(shape_2x2x2x4, strides_s1x4x2x8); ck::wrapper::make_layout(shape_2x2x2x4, strides_s1x4x2x8);
std::cout << "dims:(2,2),(2,4) strides:(1,4),(2,8)" << std::endl; std::cout << "dims:(2,2),(2,4) strides:(1,4),(2,8)" << std::endl;
Print2d(layout_2x2x2x4_s1x4x2x8); Print2d(layout_2x2x2x4_s1x4x2x8);
...@@ -108,7 +103,7 @@ int main() ...@@ -108,7 +103,7 @@ int main()
ck::make_tuple(ck::make_tuple(ck::Number<1>{}, ck::Number<4>{}), ck::Number<2>{}), ck::make_tuple(ck::make_tuple(ck::Number<1>{}, ck::Number<4>{}), ck::Number<2>{}),
ck::Number<8>{}); ck::Number<8>{});
static const auto layout_2x2x2x4_s1x4x2x8_nested = static const auto layout_2x2x2x4_s1x4x2x8_nested =
ck::tensor_transform_wrapper::make_layout(shape_2x2x2x4_nested, strides_s1x4x2x8_nested); ck::wrapper::make_layout(shape_2x2x2x4_nested, strides_s1x4x2x8_nested);
std::cout << "dims:((2,2),2),4 strides:((1,4),2),8" << std::endl; std::cout << "dims:((2,2),2),4 strides:((1,4),2),8" << std::endl;
Print1d(layout_2x2x2x4_s1x4x2x8_nested); Print1d(layout_2x2x2x4_s1x4x2x8_nested);
......
...@@ -778,7 +778,9 @@ WARN_LOGFILE = ...@@ -778,7 +778,9 @@ WARN_LOGFILE =
INPUT = ../../include/ck/tensor_operation/gpu/grid \ INPUT = ../../include/ck/tensor_operation/gpu/grid \
../../include/ck/tensor_operation/gpu/block \ ../../include/ck/tensor_operation/gpu/block \
../../include/ck/tensor_operation/gpu/thread \ ../../include/ck/tensor_operation/gpu/thread \
../../library/include/ck/library/utility ../../library/include/ck/library/utility \
../../include/ck/wrapper
# This tag can be used to specify the character encoding of the source files # This tag can be used to specify the character encoding of the source files
# that doxygen parses. Internally doxygen uses the UTF-8 encoding. Doxygen uses # that doxygen parses. Internally doxygen uses the UTF-8 encoding. Doxygen uses
......
...@@ -34,6 +34,7 @@ Current CK library are structured into 4 layers: ...@@ -34,6 +34,7 @@ Current CK library are structured into 4 layers:
* "Templated Tile Operators" layer * "Templated Tile Operators" layer
* "Templated Kernel and Invoker" layer * "Templated Kernel and Invoker" layer
* "Instantiated Kernel and Invoker" layer * "Instantiated Kernel and Invoker" layer
* "Wrapper for tensor transform operations"
* "Client API" layer * "Client API" layer
.. image:: data/ck_layer.png .. image:: data/ck_layer.png
...@@ -50,6 +51,7 @@ The following is a list of CK documents in the suggested reading order: ...@@ -50,6 +51,7 @@ The following is a list of CK documents in the suggested reading order:
tutorial_hello_world tutorial_hello_world
dockerhub dockerhub
wrapper
Supported_Primitives_Guide Supported_Primitives_Guide
API_Reference_Guide API_Reference_Guide
Contributors_Guide Contributors_Guide
===============
Wrapper
===============
-------------------------------------
Description
-------------------------------------
.. note::
The wrapper is under development and its functionality is limited.
CK provides a lightweight wrapper for more complex operations implemented in
the library. It allows indexing of nested layouts using a simple interface
(avoiding complex descriptor transformations).
Example:
.. code-block:: cpp
const auto shape_4x2x4 = ck::make_tuple(4, ck::make_tuple(2, 4));
const auto strides_s2x1x8 = ck::make_tuple(2, ck::make_tuple(1, 8));
const auto layout = ck::wrapper::make_layout(shape_4x2x4, strides_s2x1x8);
std::cout << "dims:4,(2,4) strides:2,(1,8)" << std::endl;
for(ck::index_t h = 0; h < ck::wrapper::size<0>(layout); h++)
{
for(ck::index_t w = 0; w < ck::wrapper::size<1>(layout); w++)
{
std::cout << layout(ck::make_tuple(h, w)) << " ";
}
std::cout << std::endl;
}
Output::
dims:4,(2,4) strides:2,(1,8)
0 1 8 9 16 17 24 25
2 3 10 11 18 19 26 27
4 5 12 13 20 21 28 29
6 7 14 15 22 23 30 31
-------------------------------------
Layout
-------------------------------------
.. doxygenstruct:: ck::wrapper::Layout
-------------------------------------
Layout helpers
-------------------------------------
.. doxygenfile:: layout_utils.hpp
# Example target built from tensor_transform.cpp.
add_example_executable(example_tensor_transform tensor_transform.cpp)
# Example target built from tensor_transform_using_wrapper.cpp.
add_example_executable(example_tensor_transform_using_wrapper tensor_transform_using_wrapper.cpp)
...@@ -166,4 +166,16 @@ __host__ __device__ constexpr auto IsNestedTuple(const Tuple<Ts...>&) ...@@ -166,4 +166,16 @@ __host__ __device__ constexpr auto IsNestedTuple(const Tuple<Ts...>&)
return (is_detected<is_tuple, Ts>::value || ...); return (is_detected<is_tuple, Ts>::value || ...);
} }
/**
 * \brief Generic overload of TupleDepth: ends the recursion and reports the
 *        nesting level accumulated so far.
 *
 * \tparam depth Nesting level reached by the caller (0 on a direct call).
 * \tparam T Type of the inspected element; the value itself is not read.
 * \return The value of \p depth, unchanged.
 */
template <index_t depth = 0, typename T>
__host__ __device__ constexpr auto TupleDepth(const T&)
{
    return depth;
}
/**
 * \brief Tuple overload of TupleDepth: descends one nesting level.
 *
 * Every element type is default-constructed and passed to TupleDepth with the
 * level increased by one; the per-element results are combined with
 * math::max.
 *
 * NOTE(review): for an empty Tuple<> the pack expansion calls math::max with
 * no arguments - confirm this case is supported or never instantiated.
 *
 * \tparam depth Nesting level reached by the caller (0 on a direct call).
 * \tparam Ts Element types of the inspected tuple.
 * \return Result of math::max over the depths reported for the elements.
 */
template <index_t depth = 0, typename... Ts>
__host__ __device__ constexpr auto TupleDepth(const Tuple<Ts...>&)
{
    // Elements of this tuple live one level below the tuple itself.
    constexpr index_t child_depth = depth + 1;
    return math::max(TupleDepth<child_depth>(Ts{})...);
}
} // namespace ck } // namespace ck
...@@ -3,27 +3,13 @@ ...@@ -3,27 +3,13 @@
#pragma once #pragma once
#include "ck/ck.hpp" #include "ck/wrapper/layout_utils.hpp"
#include "ck/utility/number.hpp"
#include "ck/utility/tuple.hpp"
#include "ck/utility/tuple_helper.hpp"
#include "ck/utility/sequence.hpp"
#include "ck/utility/sequence_helper.hpp"
#include "ck/utility/is_detected.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/multi_index_transform_helper.hpp"
namespace ck { namespace ck {
namespace tensor_transform_wrapper { namespace wrapper {
/** /**
* \brief Layout wrapper * \brief Layout wrapper that performs the tensor descriptor logic.
*
* \details
* Layout wrapper that performs the tensor descriptor logic.
* *
* \tparam Shape Tuple of Number<> (for compile-time layout) or index_t * \tparam Shape Tuple of Number<> (for compile-time layout) or index_t
* (dynamic layout). It is possible to pass nested shapes * (dynamic layout). It is possible to pass nested shapes
...@@ -32,21 +18,19 @@ namespace tensor_transform_wrapper { ...@@ -32,21 +18,19 @@ namespace tensor_transform_wrapper {
* (dynamic layout). Stride tuple should be nested if shape tuple is * (dynamic layout). Stride tuple should be nested if shape tuple is
* nested. * nested.
*/ */
template <typename Shape, typename Strides = Tuple<>> template <typename Shape, typename Strides>
struct Layout struct Layout
{ {
private: private:
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{}; static constexpr auto I1 = Number<1>{};
template <typename T>
using is_tuple = decltype(std::declval<T&>().IsTuple());
// Generate packed (column-major) strides if not passed // Generate packed (column-major) strides if not passed
template <typename... Ts> template <typename... Ts>
__host__ __device__ constexpr static auto __host__ __device__ constexpr static auto
GenerateColumnMajorPackedStrides(const Tuple<Ts...>& tuple) GenerateColumnMajorPackedStrides(const Tuple<Ts...>& shape)
{ {
const auto unrolled_shape = UnrollNestedTuple(shape);
return generate_tuple( return generate_tuple(
[&](auto i) { [&](auto i) {
if constexpr(i.value == 0) if constexpr(i.value == 0)
...@@ -56,10 +40,10 @@ struct Layout ...@@ -56,10 +40,10 @@ struct Layout
else else
{ {
return TupleReduce<I0.value, i.value>([](auto x, auto y) { return x * y; }, return TupleReduce<I0.value, i.value>([](auto x, auto y) { return x * y; },
tuple); unrolled_shape);
} }
}, },
Number<Tuple<Ts...>::Size()>{}); Number<decltype(unrolled_shape)::Size()>{});
} }
// Generate LowerDims in Compile-time for MergeTrasform using passed Type // Generate LowerDims in Compile-time for MergeTrasform using passed Type
...@@ -112,7 +96,7 @@ struct Layout ...@@ -112,7 +96,7 @@ struct Layout
// Example shape: (2, (2, 2)), 2, (2, 2) // Example shape: (2, (2, 2)), 2, (2, 2)
// Unrolled shape: 2, (2, 2), 2, (2, 2) // Unrolled shape: 2, (2, 2), 2, (2, 2)
template <typename... ShapeDims, typename... IdxDims> template <typename... ShapeDims, typename... IdxDims>
__host__ __device__ constexpr static auto UnrollShapeViaIdx(const Tuple<ShapeDims...>& shape, __host__ __device__ constexpr static auto AlignShapeToIdx(const Tuple<ShapeDims...>& shape,
const Tuple<IdxDims...>& idx) const Tuple<IdxDims...>& idx)
{ {
if constexpr(!IsNestedTuple(Tuple<IdxDims...>{})) if constexpr(!IsNestedTuple(Tuple<IdxDims...>{}))
...@@ -125,7 +109,7 @@ struct Layout ...@@ -125,7 +109,7 @@ struct Layout
// Iterate over shape tuple elements: // Iterate over shape tuple elements:
// 1. If corresponding idx element is tuple then return (will be unrolled) // 1. If corresponding idx element is tuple then return (will be unrolled)
// 2. If no, pack in tuple. It will be restored during unroll. // 2. If no, pack in tuple. It will be restored during unroll.
auto unrolled_shape_via_idx = generate_tuple( auto aligned_shape = generate_tuple(
[&](auto i) { [&](auto i) {
if constexpr(is_detected<is_tuple, if constexpr(is_detected<is_tuple,
tuple_element_t<i, Tuple<IdxDims...>>>::value) tuple_element_t<i, Tuple<IdxDims...>>>::value)
...@@ -140,7 +124,7 @@ struct Layout ...@@ -140,7 +124,7 @@ struct Layout
Number<Tuple<IdxDims...>::Size()>{}); Number<Tuple<IdxDims...>::Size()>{});
// Unroll and process next step // Unroll and process next step
return UnrollShapeViaIdx(UnrollNestedTuple<0, 1>(unrolled_shape_via_idx), return AlignShapeToIdx(UnrollNestedTuple<0, 1>(aligned_shape),
UnrollNestedTuple<0, 1>(idx)); UnrollNestedTuple<0, 1>(idx));
} }
} }
...@@ -150,12 +134,9 @@ struct Layout ...@@ -150,12 +134,9 @@ struct Layout
DescriptorToMerge& desc) DescriptorToMerge& desc)
{ {
// Reverse each element in tuple // Reverse each element in tuple
using ReversedUnrolledShape = decltype(TupleReverse(UnrollNestedTuple(shape))); const auto merge_elems = TupleReverse(UnrollNestedTuple(shape));
const auto merge_elems = ReversedUnrolledShape{};
// Generate reverted indexes (column major traverse) // Generate reverted indexes (column major traverse)
using MergeElemsSequence = using MergeElemsSequence = typename arithmetic_sequence_gen<0, merge_elems.Size(), 1>::type;
typename arithmetic_sequence_gen<0, ReversedUnrolledShape::Size(), 1>::type;
const auto lower_dims = make_tuple(MergeElemsSequence::Reverse()); const auto lower_dims = make_tuple(MergeElemsSequence::Reverse());
const auto upper_dims = make_tuple(Sequence<0>{}); const auto upper_dims = make_tuple(Sequence<0>{});
// Merge to 1d // Merge to 1d
...@@ -163,14 +144,14 @@ struct Layout ...@@ -163,14 +144,14 @@ struct Layout
desc, make_tuple(make_merge_transform(merge_elems)), lower_dims, upper_dims); desc, make_tuple(make_merge_transform(merge_elems)), lower_dims, upper_dims);
} }
// Merge nested shape dims // Merge nested shape dims. Merge nested shape dims when idx is also nested.
// Input desc shape: 2, 2, 2, 2, 2, 2 // Input desc shape: 2, 2, 2, 2, 2, 2
// Example idx: 1, 1, 1, 1 // Example idx: 1, 1, 1, 1
// Example shape: 2, (2, 2), 2, (2, 2) // Example shape: 2, (2, 2), 2, (2, 2)
// Merged shape: 2, 4, 2, 4 // Merged shape: 2, 4, 2, 4
template <typename... ShapeDims, typename... IdxDims, typename DescriptorToMerge> template <typename... ShapeDims, typename... IdxDims, typename DescriptorToMerge>
__host__ __device__ constexpr static auto __host__ __device__ constexpr static auto CreateMergedDescriptor(
MakeMerges(const Tuple<ShapeDims...>& shape, const Tuple<IdxDims...>&, DescriptorToMerge& desc) const Tuple<ShapeDims...>& shape, const Tuple<IdxDims...>&, DescriptorToMerge& desc)
{ {
const auto transforms = generate_tuple( const auto transforms = generate_tuple(
[&](auto i) { [&](auto i) {
...@@ -224,9 +205,9 @@ struct Layout ...@@ -224,9 +205,9 @@ struct Layout
static_assert(Tuple<ShapeDims...>::Size() == Tuple<IdxDims...>::Size(), static_assert(Tuple<ShapeDims...>::Size() == Tuple<IdxDims...>::Size(),
"Idx rank and Shape rank must be the same (except 1d)."); "Idx rank and Shape rank must be the same (except 1d).");
// Unroll while IdxDims is nested // Unroll while IdxDims is nested
const auto unrolled_shape_via_idx = UnrollShapeViaIdx(shape, idx); const auto aligned_shape = AlignShapeToIdx(shape, idx);
// Transform correct form of shape // Transform correct form of shape
return MakeMerges(unrolled_shape_via_idx, UnrollNestedTuple(idx), descriptor_); return CreateMergedDescriptor(aligned_shape, UnrollNestedTuple(idx), descriptor_);
} }
} }
...@@ -235,25 +216,20 @@ struct Layout ...@@ -235,25 +216,20 @@ struct Layout
const LayoutStrides& strides) const LayoutStrides& strides)
{ {
const auto unrolled_shape = UnrollNestedTuple(shape); const auto unrolled_shape = UnrollNestedTuple(shape);
if constexpr(ck::is_same_v<LayoutStrides, Tuple<>>)
{
// If shape is packed
const auto column_major_packed_strides =
GenerateColumnMajorPackedStrides(unrolled_shape);
return make_naive_tensor_descriptor(unrolled_shape, column_major_packed_strides);
}
else
{
const auto unrolled_strides = UnrollNestedTuple(strides); const auto unrolled_strides = UnrollNestedTuple(strides);
static_assert(unrolled_shape.Size() == unrolled_strides.Size(), static_assert(unrolled_shape.Size() == unrolled_strides.Size(),
"Size of strides and shape are not consistent."); "Size of strides and shape are not consistent.");
return make_naive_tensor_descriptor(unrolled_shape, unrolled_strides); return make_naive_tensor_descriptor(unrolled_shape, unrolled_strides);
} }
}
public: public:
using NaiveDescriptorType = remove_cvref_t<decltype(MakeNaiveDescriptor(Shape{}, Strides{}))>; // If the stride is not passed, you can infer it from `GenerateColumnMajorPackedStrides`.
using DeducedStrides =
std::conditional_t<is_same_v<Strides, Tuple<>>,
remove_cvref_t<decltype(GenerateColumnMajorPackedStrides(Shape{}))>,
Strides>;
using NaiveDescriptorType =
remove_cvref_t<decltype(MakeNaiveDescriptor(Shape{}, DeducedStrides{}))>;
/** /**
* \brief Layout constructor. * \brief Layout constructor.
...@@ -268,9 +244,9 @@ struct Layout ...@@ -268,9 +244,9 @@ struct Layout
// Construct if runtime mode // Construct if runtime mode
if constexpr(!NaiveDescriptorType::IsKnownAtCompileTime()) if constexpr(!NaiveDescriptorType::IsKnownAtCompileTime())
{ {
// Keep only shape, strides are not need for transforms
shape_ = shape; shape_ = shape;
descriptor_ = MakeNaiveDescriptor(shape, strides); strides_ = strides;
descriptor_ = MakeNaiveDescriptor(shape_, strides_);
} }
} }
...@@ -279,7 +255,8 @@ struct Layout ...@@ -279,7 +255,8 @@ struct Layout
if constexpr(!NaiveDescriptorType::IsKnownAtCompileTime()) if constexpr(!NaiveDescriptorType::IsKnownAtCompileTime())
{ {
shape_ = shape; shape_ = shape;
descriptor_ = MakeNaiveDescriptor(shape, Strides{}); strides_ = GenerateColumnMajorPackedStrides(shape_);
descriptor_ = MakeNaiveDescriptor(shape_, strides_);
} }
} }
...@@ -338,7 +315,7 @@ struct Layout ...@@ -338,7 +315,7 @@ struct Layout
* *
* \return Calculated size. * \return Calculated size.
*/ */
__host__ __device__ constexpr index_t GetLength() const __host__ __device__ constexpr index_t GetLengths() const
{ {
const auto unrolled_shape = UnrollNestedTuple(shape_); const auto unrolled_shape = UnrollNestedTuple(shape_);
return TupleReduce<I0.value, unrolled_shape.Size()>([](auto x, auto y) { return x * y; }, return TupleReduce<I0.value, unrolled_shape.Size()>([](auto x, auto y) { return x * y; },
...@@ -346,80 +323,24 @@ struct Layout ...@@ -346,80 +323,24 @@ struct Layout
} }
/** /**
* \brief Dimension getter. * \brief Shape getter.
* *
* \tparam IDim Dimension idx. * \return Shape.
* \return Calculated size.
*/ */
template <index_t IDim> __host__ __device__ constexpr Shape GetShape() const { return shape_; }
__host__ __device__ constexpr auto Get() const
{ /**
const auto elem = shape_.At(Number<IDim>{}); * \brief Strides getter.
return elem; *
} * \return Strides.
*/
__host__ __device__ constexpr DeducedStrides GetStrides() const { return strides_; }
private: private:
NaiveDescriptorType descriptor_; NaiveDescriptorType descriptor_;
Shape shape_; Shape shape_;
DeducedStrides strides_;
}; };
// Layout helpers } // namespace wrapper
// Length getter (product if tuple)
template <index_t idx, typename Shape, typename Strides>
__host__ __device__ constexpr index_t size(const Layout<Shape, Strides>& layout)
{
return layout.template GetLength<idx>();
}
// Get shape size (product of dims if tuple)
template <typename... ShapeDims>
__host__ __device__ constexpr index_t size(const Tuple<ShapeDims...>& shape)
{
using UnrolledShape = decltype(UnrollNestedTuple(shape));
return TupleReduce<0, UnrolledShape::Size()>([](auto x, auto y) { return x * y; },
UnrolledShape{});
}
// Get dim size (could be returned from get function)
template <typename T>
__host__ __device__ T constexpr size(const T& dim)
{
return dim;
}
// Get layout size (product of shapes)
template <typename Shape, typename Strides>
__host__ __device__ constexpr index_t size(const Layout<Shape, Strides>& layout)
{
return layout.GetLength();
}
// Get shape element size
template <index_t idx, typename... ShapeDims>
__host__ __device__ constexpr index_t size(const Tuple<ShapeDims...>& shape)
{
return size(shape.At(Number<idx>{}));
}
// Dim getter (tuple if tuple)
template <index_t idx, typename Shape, typename Strides>
__host__ __device__ constexpr auto get(const Layout<Shape, Strides>& layout)
{
return layout.template Get<idx>();
}
template <typename Shape, typename Strides>
__host__ __device__ constexpr Layout<Shape, Strides> make_layout(const Shape& shape,
const Strides& strides)
{
return Layout<Shape, Strides>(shape, strides);
}
template <typename Shape>
__host__ __device__ constexpr Layout<Shape> make_layout(const Shape& shape)
{
return Layout<Shape>(shape);
}
} // namespace tensor_transform_wrapper
} // namespace ck } // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/ck.hpp"
#include "ck/utility/number.hpp"
#include "ck/utility/tuple.hpp"
#include "ck/utility/tuple_helper.hpp"
#include "ck/utility/sequence.hpp"
#include "ck/utility/sequence_helper.hpp"
#include "ck/utility/is_detected.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/multi_index_transform_helper.hpp"
namespace ck {
namespace wrapper {
// Disable from doxygen docs generation
/// @cond
// Forward declaration: the helpers in this header only name Layout in
// signatures and dependent expressions. Strides defaults to Tuple<>, which
// get() on a Layout treats as "no strides were passed".
template <typename Shape, typename Strides = Tuple<>>
struct Layout;
// Detection archetype for is_detected: the alias is well-formed only when an
// lvalue of type T has a callable IsTuple() member.
template <typename T>
using is_tuple = decltype(std::declval<T&>().IsTuple());
/// @endcond
// make_*
/**
 * \brief Build a layout from an explicit shape and explicit strides.
 *
 * \tparam Shape Type of the shape tuple.
 * \tparam Strides Type of the strides tuple.
 * \param shape Shape of the layout.
 * \param strides Strides of the layout.
 * \return Constructed layout.
 */
template <typename Shape, typename Strides>
__host__ __device__ constexpr Layout<Shape, Strides> make_layout(const Shape& shape,
                                                                 const Strides& strides)
{
    using LayoutType = Layout<Shape, Strides>;
    return LayoutType(shape, strides);
}
/**
 * \brief Build a layout from a shape only; the layout uses packed
 *        (column-major) strides.
 *
 * \tparam Shape Type of the shape tuple.
 * \param shape Shape of the layout.
 * \return Constructed layout (strides template argument left at its default).
 */
template <typename Shape>
__host__ __device__ constexpr Layout<Shape> make_layout(const Shape& shape)
{
    using PackedLayout = Layout<Shape>;
    return PackedLayout(shape);
}
// Layout helpers
// get
/**
 * \brief Get an element of a tuple (Shape/Strides/Idxs).
 *
 * \tparam idx Position of the element to return.
 * \param tuple Tuple to read from.
 * \return Copy of the requested element (a scalar or a nested tuple).
 */
template <index_t idx, typename... Dims>
__host__ __device__ constexpr auto get(const Tuple<Dims...>& tuple)
{
    constexpr auto position = Number<idx>{};
    return tuple.At(position);
}
/**
 * \brief Get a sub layout.
 *
 * The element \p idx of the layout shape must itself be a tuple; it becomes
 * the shape of the returned layout. When the layout was declared with
 * explicit strides, the element \p idx of the strides (which must also be a
 * tuple) becomes the strides of the returned layout.
 *
 * NOTE(review): when Strides is Tuple<> the sub layout is created from the sub
 * shape alone, so its strides are derived again from that sub shape instead
 * of being taken from the parent layout - confirm this is intended.
 *
 * \tparam idx Position of the nested dimension to extract.
 * \tparam Shape Shape type of the parent layout.
 * \tparam Strides Strides type of the parent layout.
 * \param layout Layout to create the sub layout from.
 * \return Requested sub layout.
 */
template <index_t idx, typename Shape, typename Strides>
__host__ __device__ constexpr auto get(const Layout<Shape, Strides>& layout)
{
    const auto sub_shape = get<idx>(layout.GetShape());
    static_assert(is_detected<is_tuple, decltype(sub_shape)>::value,
                  "Shape of sub layout must be tuple");
    if constexpr(!is_same_v<Strides, Tuple<>>)
    {
        // Explicit strides were passed: slice them the same way as the shape.
        const auto sub_strides = get<idx>(layout.GetStrides());
        static_assert(is_detected<is_tuple, decltype(sub_strides)>::value,
                      "Strides of sub layout must be tuple");
        return make_layout(sub_shape, sub_strides);
    }
    else
    {
        // No strides were passed: create the sub layout without strides.
        return make_layout(sub_shape);
    }
}
/**
 * \brief Hierarchical get.
 *
 * Applies get<Idx> to \p elem and then the remaining indexes, in order, to
 * the result of each previous step.
 *
 * \tparam Idx First index to look up.
 * \tparam Idxs Remaining indexes, applied from left to right.
 * \param elem Element to look up (tuple or layout).
 * \return Requested element.
 */
template <index_t Idx, index_t... Idxs, typename T>
__host__ __device__ constexpr auto get(const T& elem)
{
    const auto selected = get<Idx>(elem);
    return get<Idxs...>(selected);
}
// size
/**
 * \brief Length of one layout dimension (product if that dimension is a
 *        tuple).
 *
 * \tparam idx Index of the dimension to look up.
 * \param layout Layout whose shape is inspected.
 * \return Requested length.
 */
template <index_t idx, typename Shape, typename Strides>
__host__ __device__ constexpr index_t size(const Layout<Shape, Strides>& layout)
{
    const index_t length = layout.template GetLength<idx>();
    return length;
}
/**
 * \brief Shape size (product of dims).
 *
 * The shape is unrolled with UnrollNestedTuple first and the elements of the
 * unrolled tuple are then multiplied together.
 *
 * \param shape Shape to look up.
 * \return Requested size.
 */
template <typename... ShapeDims>
__host__ __device__ constexpr index_t size(const Tuple<ShapeDims...>& shape)
{
    const auto flat_dims = UnrollNestedTuple(shape);
    return TupleReduce<0, decltype(flat_dims)::Size()>(
        [](auto acc, auto dim) { return acc * dim; }, flat_dims);
}
// Get dim size (could be returned from get function)
// Generic overload: the argument is returned unchanged, so size() can be
// applied to the result of get<idx>() when that result is a single dimension.
/**
 * \private
 */
template <typename T>
__host__ __device__ T constexpr size(const T& dim)
{
    return dim;
}
/**
 * \brief Layout size (product of dims).
 *
 * \param layout Layout whose size is calculated.
 * \return Requested size, as reported by Layout::GetLengths().
 */
template <typename Shape, typename Strides>
__host__ __device__ constexpr index_t size(const Layout<Shape, Strides>& layout)
{
    const index_t total = layout.GetLengths();
    return total;
}
/**
 * \brief Length of one tuple element (product if that element is a tuple).
 *
 * \tparam idx Index of the element to look up.
 * \param tuple Tuple to look up.
 * \return Requested length.
 */
template <index_t idx, typename... Ts>
__host__ __device__ constexpr index_t size(const Tuple<Ts...>& tuple)
{
    // The element may be a scalar or a nested tuple; size() handles both.
    const auto& element = tuple.At(Number<idx>{});
    return size(element);
}
/**
 * \brief Hierarchical size.
 *
 * Selects an element with get<Idxs...> and returns the size of that element.
 *
 * \tparam Idxs Indexes to look up, applied from left to right.
 * \param elem Element to look up (tuple or layout).
 * \return Requested size.
 */
template <index_t... Idxs, typename T>
__host__ __device__ constexpr auto size(const T& elem)
{
    const auto selected = get<Idxs...>(elem);
    return size(selected);
}
// rank
/**
 * \brief Get layout rank (number of top-level elements in the shape).
 *
 * Only the Shape type is used; the layout object itself is not read.
 *
 * \param layout Layout to calculate rank for.
 * \return Requested rank.
 */
template <typename Shape, typename Strides>
__host__ __device__ constexpr auto rank([[maybe_unused]] const Layout<Shape, Strides>& layout)
{
    constexpr auto top_level_dims = Shape::Size();
    return top_level_dims;
}
/**
 * \brief Get tuple rank (number of top-level elements in the tuple).
 *        Return 1 if scalar passed.
 *
 * Only the tuple type is used; the tuple object itself is not read.
 *
 * \param tuple Tuple to calculate rank for.
 * \return Requested rank.
 */
template <typename... Dims>
__host__ __device__ constexpr auto rank([[maybe_unused]] const Tuple<Dims...>& tuple)
{
    using TupleType = Tuple<Dims...>;
    return TupleType::Size();
}
/**
 * \private
 * Scalar case: a compile-time Number<> dimension always has rank 1.
 */
template <index_t IDim>
__host__ __device__ constexpr index_t rank(const Number<IDim>&)
{
    return 1;
}
/**
 * \private
 * Scalar case: a runtime index_t dimension always has rank 1.
 */
__host__ __device__ constexpr index_t rank(const index_t&) { return 1; }
/**
 * \brief Hierarchical rank.
 *
 * Selects an element with get<Idxs...> and returns the rank of that element.
 *
 * \tparam Idxs Indexes to look up, applied from left to right.
 * \param elem Element to look up (tuple or layout).
 * \return Requested rank.
 */
template <index_t... Idxs, typename T>
__host__ __device__ constexpr auto rank(const T& elem)
{
    const auto selected = get<Idxs...>(elem);
    return rank(selected);
}
// depth
/**
 * \brief Get depth of the layout shape (return 0 if scalar).
 *
 * \param layout Layout to calculate depth for.
 * \return Requested depth, i.e. TupleDepth of the layout shape.
 */
template <typename Shape, typename Strides>
__host__ __device__ constexpr auto depth(const Layout<Shape, Strides>& layout)
{
    const auto layout_shape = layout.GetShape();
    return TupleDepth(layout_shape);
}
/**
 * \brief Get depth of the tuple (return 0 if scalar).
 *
 * \param tuple Tuple to calculate depth for.
 * \return Requested depth, i.e. TupleDepth of the tuple counted from level 0.
 */
template <typename... Dims>
__host__ __device__ constexpr auto depth(const Tuple<Dims...>& tuple)
{
    // Spell out the starting level that TupleDepth uses by default.
    return TupleDepth<0>(tuple);
}
/**
 * \private
 * Scalar case: a compile-time Number<> dimension has no nesting, so depth 0.
 */
template <index_t IDim>
__host__ __device__ constexpr index_t depth(const Number<IDim>&)
{
    return 0;
}
/**
 * \private
 * Scalar case: a runtime index_t dimension has no nesting, so depth 0.
 */
__host__ __device__ constexpr index_t depth(const index_t&) { return 0; }
/**
 * \brief Hierarchical depth.
 *
 * Selects an element with get<Idxs...> and returns the depth of that element.
 *
 * \tparam Idxs Indexes to look up, applied from left to right.
 * \param elem Element to look up (tuple or layout).
 * \return Requested depth.
 */
template <index_t... Idxs, typename T>
__host__ __device__ constexpr auto depth(const T& elem)
{
    const auto selected = get<Idxs...>(elem);
    return depth(selected);
}
/**
 * \brief Get layout strides.
 *
 * \param layout Layout to read the strides from.
 * \return Requested strides, as reported by Layout::GetStrides().
 */
template <typename Shape, typename Strides>
__host__ __device__ constexpr auto stride(const Layout<Shape, Strides>& layout)
{
    auto layout_strides = layout.GetStrides();
    return layout_strides;
}
/**
 * \brief Get layout shape.
 *
 * \param layout Layout to read the shape from.
 * \return Requested shape, as reported by Layout::GetShape().
 */
template <typename Shape, typename Strides>
__host__ __device__ constexpr auto shape(const Layout<Shape, Strides>& layout)
{
    auto layout_shape = layout.GetShape();
    return layout_shape;
}
} // namespace wrapper
} // namespace ck
...@@ -149,6 +149,7 @@ add_subdirectory(batched_gemm_multi_d) ...@@ -149,6 +149,7 @@ add_subdirectory(batched_gemm_multi_d)
add_subdirectory(grouped_convnd_bwd_data) add_subdirectory(grouped_convnd_bwd_data)
add_subdirectory(conv_tensor_rearrange) add_subdirectory(conv_tensor_rearrange)
add_subdirectory(transpose) add_subdirectory(transpose)
add_subdirectory(wrapper)
if(GPU_TARGETS MATCHES "gfx11") if(GPU_TARGETS MATCHES "gfx11")
add_subdirectory(wmma_op) add_subdirectory(wmma_op)
endif() endif()
# Test target built from test_layout.cpp and linked against the utility target.
# NOTE(review): add_gtest_executable is a project helper defined elsewhere -
# presumably it also registers the binary with the test runner; confirm.
add_gtest_executable(test_layout test_layout.cpp)
target_link_libraries(test_layout PRIVATE utility)
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment