Commit 81b79a77 authored by Bartlomiej Kocot's avatar Bartlomiej Kocot
Browse files

Extend functionality

parent 1e276c57
...@@ -15,12 +15,25 @@ ...@@ -15,12 +15,25 @@
static constexpr auto I0 = ck::Number<0>{}; static constexpr auto I0 = ck::Number<0>{};
static constexpr auto I1 = ck::Number<1>{}; static constexpr auto I1 = ck::Number<1>{};
static constexpr auto I2 = ck::Number<2>{};
using DataType = int; using DataType = int;
template <typename Desc> template <typename Desc>
void Print(const Desc& desc) void Print1d(const Desc& desc)
{ {
std::cout << "Print1d" << std::endl;
for(ck::index_t w = 0; w < desc.GetLength(I0); w++)
{
std::cout << desc.CalculateOffset(ck::make_tuple(w)) << " ";
}
std::cout << std::endl;
}
template <typename Desc>
void Print2d(const Desc& desc)
{
std::cout << "Print2d" << std::endl;
for(ck::index_t h = 0; h < desc.GetLength(I0); h++) for(ck::index_t h = 0; h < desc.GetLength(I0); h++)
{ {
for(ck::index_t w = 0; w < desc.GetLength(I1); w++) for(ck::index_t w = 0; w < desc.GetLength(I1); w++)
...@@ -31,61 +44,107 @@ void Print(const Desc& desc) ...@@ -31,61 +44,107 @@ void Print(const Desc& desc)
} }
} }
template <typename Desc>
void Print3dCustom(const Desc& desc)
{
std::cout << "Print3dCustom" << std::endl;
for(ck::index_t d = 0; d < desc.GetLength(I0); d++)
{
for(ck::index_t h = 0; h < desc.GetLength(I1); h++)
{
for(ck::index_t w = 0; w < desc.GetLength(I2); w++)
{
std::cout << desc.CalculateOffset(ck::make_tuple(d, h, w)) << " ";
}
std::cout << std::endl;
}
std::cout << std::endl;
}
}
int main() int main()
{ {
// Tensor descriptor traverse in row-major (need to reverse dims)
std::cout << "Note: Tensor descriptor traverse in row-major" << std::endl;
// Basic descriptor 0, 1, 2, ... 30, 31 // Basic descriptor 0, 1, 2, ... 30, 31
// (dims:4,8 strides:1,1) // (dims:4,8 strides:1,4)
const auto desc_4x8_s1x1 = ck::make_naive_tensor_descriptor_packed(ck::make_tuple(4, 8)); const auto desc_4x8_s1x4 =
std::cout << "dims:4,8 strides:1,1" << std::endl; ck::make_naive_tensor_descriptor(ck::make_tuple(ck::Number<4>{}, ck::Number<8>{}),
Print(desc_4x8_s1x1); ck::make_tuple(ck::Number<1>{}, ck::Number<4>{}));
std::cout << "dims:4,8 strides:1,4" << std::endl;
// Basic descriptor 0, 1, 8, 9, 16, 17, ... 30, 31 Print2d(desc_4x8_s1x4);
// dims:4,(4,2) strides:2,(8,1)
const auto desc_4x4x2_s2x8x1 = using Cord0x0Type = ck::Tuple<ck::Number<0>, ck::Number<0>>;
ck::make_naive_tensor_descriptor(ck::make_tuple(4, 4, 2), ck::make_tuple(2, 8, 1)); constexpr ck::index_t offset_0x0 = desc_4x8_s1x4.CalculateOffset(Cord0x0Type{});
std::cout << "Constexpr calculated [0, 0] offset:" << offset_0x0 << std::endl;
// Basic descriptor 0, 1, 8, 9, 16, 17, ... 30, 31 (compile-time descriptor)
// dims:4,(2,4) strides:2,(1,8)
const auto desc_4x2x4_s2x1x8 =
ck::make_naive_tensor_descriptor(ck::make_tuple(4, 2, 4), ck::make_tuple(2, 1, 8));
// Transform to 2d // Transform to 2d
const auto desc_4x4x2_s2x8x1_merged = ck::transform_tensor_descriptor( const auto desc_4x2x4_s2x1x8_merged = ck::transform_tensor_descriptor(
desc_4x4x2_s2x8x1, desc_4x2x4_s2x1x8,
ck::make_tuple(ck::make_pass_through_transform(4), ck::make_tuple(ck::make_pass_through_transform(4),
ck::make_merge_transform(ck::make_tuple(4, 2))), ck::make_merge_transform(ck::make_tuple(4, 2))),
ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1, 2>{}), ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<2, 1>{}),
ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{})); ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}));
std::cout << "dims:4,(4,2) strides:2,(8,1)" << std::endl; std::cout << "dims:4,(2,4) strides:2,(1,8)" << std::endl;
Print(desc_4x4x2_s2x8x1_merged); Print2d(desc_4x2x4_s2x1x8_merged);
// Basic descriptor 0, 1, 8, 9, 16, 17, ... 30, 31 // Basic descriptor 0, 1, 8, 9, 16, 17, ... 30, 31 (compile-time descriptor)
// dims:(2,2),(4,2) strides:(4,1),(8,2) // dims:(2,2),(2,4) strides:((1,4),(2,8)
const auto desc_2x2x4x2_s4x1x8x2 = const auto desc_2x2x2x4_s1x4x2x8 =
ck::make_naive_tensor_descriptor(ck::make_tuple(2, 2, 4, 2), ck::make_tuple(4, 1, 8, 2)); ck::make_naive_tensor_descriptor(ck::make_tuple(2, 2, 2, 4), ck::make_tuple(1, 4, 2, 8));
// Transform to 2d // Transform to 2d
const auto desc_2x2x4x2_s4x1x8x2_double_merged = ck::transform_tensor_descriptor( const auto desc_2x2x2x4_s1x4x2x8_double_merged_2d = ck::transform_tensor_descriptor(
desc_2x2x4x2_s4x1x8x2, desc_2x2x2x4_s1x4x2x8,
ck::make_tuple(ck::make_merge_transform(ck::make_tuple(2, 2)), ck::make_tuple(ck::make_merge_transform(ck::make_tuple(2, 2)),
ck::make_merge_transform(ck::make_tuple(4, 2))), ck::make_merge_transform(ck::make_tuple(4, 2))),
ck::make_tuple(ck::Sequence<0, 1>{}, ck::Sequence<2, 3>{}), ck::make_tuple(ck::Sequence<1, 0>{}, ck::Sequence<3, 2>{}),
ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{})); ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}));
std::cout << "dims:(2,2),(4,2) strides:(4,1),(8,2)" << std::endl; // Transform to 3d
Print(desc_2x2x4x2_s4x1x8x2_double_merged); const auto desc_2x2x2x4_s1x4x2x8_double_merged_3d = ck::transform_tensor_descriptor(
desc_2x2x2x4_s1x4x2x8,
ck::make_tuple(ck::make_pass_through_transform(2),
ck::make_pass_through_transform(2),
ck::make_merge_transform(ck::make_tuple(4, 2))),
ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}, ck::Sequence<3, 2>{}),
ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}, ck::Sequence<2>{}));
// Basic descriptor 0, 1, 8, 9, 16, 17, ... 30, 31 std::cout << "dims:(2,2),(2,4) strides:(1,4),(2,8)" << std::endl;
// dims:((2,2),4),2 strides:((4,1),8),2 Print2d(desc_2x2x2x4_s1x4x2x8_double_merged_2d);
Print3dCustom(desc_2x2x2x4_s1x4x2x8_double_merged_3d);
// Basic descriptor 0, 1, 8, 9, 16, 17, ... 30, 31 (compile-time descriptor)
// dims:((2,2),2),4 strides:((1,4),2),8
// Transform to 2d // Transform to 2d
const auto desc_2x2x4x2_s4x1x8x2_merged = ck::transform_tensor_descriptor( const auto desc_2x2x2x4_s1x4x2x8_nested =
desc_2x2x4x2_s4x1x8x2, ck::make_naive_tensor_descriptor(ck::make_tuple(2, 2, 2, 4), ck::make_tuple(1, 4, 2, 8));
const auto desc_2x2x2x4_s1x4x2x8_nested_merged_3d = ck::transform_tensor_descriptor(
desc_2x2x2x4_s1x4x2x8_nested,
ck::make_tuple(ck::make_merge_transform(ck::make_tuple(2, 2)), ck::make_tuple(ck::make_merge_transform(ck::make_tuple(2, 2)),
ck::make_pass_through_transform(4), ck::make_pass_through_transform(2),
ck::make_pass_through_transform(2)), ck::make_pass_through_transform(4)),
ck::make_tuple(ck::Sequence<0, 1>{}, ck::Sequence<2>{}, ck::Sequence<3>{}), ck::make_tuple(ck::Sequence<1, 0>{}, ck::Sequence<2>{}, ck::Sequence<3>{}),
ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}, ck::Sequence<2>{})); ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}, ck::Sequence<2>{}));
const auto desc_2x2x4x2_s4x1x8x2_nested_merged = ck::transform_tensor_descriptor( const auto desc_2x2x2x4_s1x4x2x8_nested_merged_1d = ck::transform_tensor_descriptor(
desc_2x2x4x2_s4x1x8x2_merged, desc_2x2x2x4_s1x4x2x8_nested,
ck::make_tuple(ck::make_merge_transform(ck::make_tuple(4, 4)), ck::make_tuple(ck::make_merge_transform(ck::make_tuple(4, 2, 2, 2))),
ck::make_pass_through_transform(2)), ck::make_tuple(ck::Sequence<3, 2, 1, 0>{}),
ck::make_tuple(ck::Sequence<0, 1>{}, ck::Sequence<2>{}), ck::make_tuple(ck::Sequence<0>{}));
const auto desc_2x2x2x4_s1x4x2x8_nested_merged_2d = ck::transform_tensor_descriptor(
desc_2x2x2x4_s1x4x2x8_nested_merged_3d,
ck::make_tuple(ck::make_merge_transform(ck::make_tuple(2, 4)),
ck::make_pass_through_transform(4)),
ck::make_tuple(ck::Sequence<1, 0>{}, ck::Sequence<2>{}),
ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{})); ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}));
std::cout << "dims:((2,2),4),2 strides:((4,1),8),2" << std::endl;
Print(desc_2x2x4x2_s4x1x8x2_nested_merged); std::cout << "dims:((2,2),2),4 strides:((1,4),2),8" << std::endl;
Print1d(desc_2x2x2x4_s1x4x2x8_nested_merged_1d);
Print2d(desc_2x2x2x4_s1x4x2x8_nested_merged_2d);
Print3dCustom(desc_2x2x2x4_s1x4x2x8_nested_merged_3d);
return 0; return 0;
} }
...@@ -14,8 +14,20 @@ ...@@ -14,8 +14,20 @@
using DataType = int; using DataType = int;
template <typename Layout> template <typename Layout>
void Print(const Layout& layout) void Print1d(const Layout& layout)
{ {
std::cout << "Print1d" << std::endl;
for(ck::index_t w = 0; w < ck::tensor_transform_wrapper::size(layout); w++)
{
std::cout << layout(ck::make_tuple(w)) << " ";
}
std::cout << std::endl;
}
template <typename Layout>
void Print2d(const Layout& layout)
{
std::cout << "Print2d" << std::endl;
for(ck::index_t h = 0; h < ck::tensor_transform_wrapper::size<0>(layout); h++) for(ck::index_t h = 0; h < ck::tensor_transform_wrapper::size<0>(layout); h++)
{ {
for(ck::index_t w = 0; w < ck::tensor_transform_wrapper::size<1>(layout); w++) for(ck::index_t w = 0; w < ck::tensor_transform_wrapper::size<1>(layout); w++)
...@@ -26,53 +38,84 @@ void Print(const Layout& layout) ...@@ -26,53 +38,84 @@ void Print(const Layout& layout)
} }
} }
// Print in (x,y),z pattern
template <typename Layout>
void Print3dCustom(const Layout& layout)
{
std::cout << "Print3dCustom" << std::endl;
for(ck::index_t d = 0;
d < ck::tensor_transform_wrapper::size<0>(ck::tensor_transform_wrapper::get<0>(layout));
d++)
{
for(ck::index_t h = 0;
h < ck::tensor_transform_wrapper::size<1>(ck::tensor_transform_wrapper::get<0>(layout));
h++)
{
for(ck::index_t w = 0; w < ck::tensor_transform_wrapper::size<1>(layout); w++)
{
std::cout << layout(ck::make_tuple(ck::make_tuple(d, h), w)) << " ";
}
std::cout << std::endl;
}
std::cout << std::endl;
}
}
int main() int main()
{ {
// Layout traverse in column-major
std::cout << "Note: Layout traverse in column-major" << std::endl;
// Basic descriptor 0, 1, 2, ... 30, 31 (runtime descriptor) // Basic descriptor 0, 1, 2, ... 30, 31 (runtime descriptor)
// (dims:4,8 strides:1,1) // (dims:4,8 strides:1,4)
const auto shape_4x8 = ck::make_tuple(4, 8); const auto shape_4x8 = ck::make_tuple(4, 8);
const auto layout_4x8_s1x1 = ck::tensor_transform_wrapper::make_layout(shape_4x8); const auto layout_4x8_s1x4 = ck::tensor_transform_wrapper::make_layout(shape_4x8);
std::cout << "dims:4,8 strides:1,1" << std::endl; std::cout << "dims:4,8 strides:1,4" << std::endl;
Print(layout_4x8_s1x1); Print2d(layout_4x8_s1x4);
using Cord0x0Type = ck::Tuple<ck::Number<0>, ck::Number<0>>;
constexpr ck::index_t offset_0x0 = layout_4x8_s1x4.template operator()<Cord0x0Type>();
std::cout << "Constexpr calculated [0, 0] offset:" << offset_0x0 << std::endl;
// Basic descriptor 0, 1, 8, 9, 16, 17, ... 30, 31 (compile-time descriptor) // Basic descriptor 0, 1, 8, 9, 16, 17, ... 30, 31 (compile-time descriptor)
// dims:4,(4,2) strides:2,(8,1) // dims:4,(2,4) strides:2,(1,8)
const auto shape_4x4x2 = const auto shape_4x2x4 =
ck::make_tuple(ck::Number<4>{}, ck::make_tuple(ck::Number<4>{}, ck::Number<2>{})); ck::make_tuple(ck::Number<4>{}, ck::make_tuple(ck::Number<2>{}, ck::Number<4>{}));
const auto strides_s2x8x1 = const auto strides_s2x1x8 =
ck::make_tuple(ck::Number<2>{}, ck::make_tuple(ck::Number<8>{}, ck::Number<1>{})); ck::make_tuple(ck::Number<2>{}, ck::make_tuple(ck::Number<1>{}, ck::Number<8>{}));
const auto layout_4x4x2_s2x8x1 = const auto layout_4x2x4_s2x1x8 =
ck::tensor_transform_wrapper::make_layout(shape_4x4x2, strides_s2x8x1); ck::tensor_transform_wrapper::make_layout(shape_4x2x4, strides_s2x1x8);
std::cout << "dims:4,(4,2) strides:2,(8,1)" << std::endl; std::cout << "dims:4,(2,4) strides:2,(1,8)" << std::endl;
Print(layout_4x4x2_s2x8x1); Print2d(layout_4x2x4_s2x1x8);
// Basic descriptor 0, 1, 8, 9, 16, 17, ... 30, 31 (compile-time descriptor) // Basic descriptor 0, 1, 8, 9, 16, 17, ... 30, 31 (compile-time descriptor)
// dims:(2,2),(4,2) strides:((4,1),(8,2) // dims:(2,2),(2,4) strides:((1,4),(2,8)
const auto shape_2x2x4x2 = ck::make_tuple(ck::make_tuple(ck::Number<2>{}, ck::Number<2>{}), const auto shape_2x2x2x4 = ck::make_tuple(ck::make_tuple(ck::Number<2>{}, ck::Number<2>{}),
ck::make_tuple(ck::Number<4>{}, ck::Number<2>{})); ck::make_tuple(ck::Number<2>{}, ck::Number<4>{}));
const auto strides_s4x1x8x2 = ck::make_tuple(ck::make_tuple(ck::Number<4>{}, ck::Number<1>{}), const auto strides_s1x4x2x8 = ck::make_tuple(ck::make_tuple(ck::Number<1>{}, ck::Number<4>{}),
ck::make_tuple(ck::Number<8>{}, ck::Number<2>{})); ck::make_tuple(ck::Number<2>{}, ck::Number<8>{}));
static const auto layout_2x2x4x2_s4x1x8x2 = static const auto layout_2x2x2x4_s1x4x2x8 =
ck::tensor_transform_wrapper::make_layout(shape_2x2x4x2, strides_s4x1x8x2); ck::tensor_transform_wrapper::make_layout(shape_2x2x2x4, strides_s1x4x2x8);
std::cout << "dims:(2,2),(4,2) strides:(4,1),(8,2)" << std::endl; std::cout << "dims:(2,2),(2,4) strides:(1,4),(2,8)" << std::endl;
Print(layout_2x2x4x2_s4x1x8x2); Print2d(layout_2x2x2x4_s1x4x2x8);
Print3dCustom(layout_2x2x2x4_s1x4x2x8);
// Basic descriptor 0, 1, 8, 9, 16, 17, ... 30, 31 (compile-time descriptor) // Basic descriptor 0, 1, 8, 9, 16, 17, ... 30, 31 (compile-time descriptor)
// dims:((2,2),4),2 strides:((4,1),8),2 // dims:((2,2),2),4 strides:((1,4),2),8
// Transform to 2d // Transform to 2d
const auto shape_2x2x4x2_nested = ck::make_tuple( const auto shape_2x2x2x4_nested = ck::make_tuple(
ck::make_tuple(ck::make_tuple(ck::Number<2>{}, ck::Number<2>{}), ck::Number<4>{}), ck::make_tuple(ck::make_tuple(ck::Number<2>{}, ck::Number<2>{}), ck::Number<2>{}),
ck::Number<2>{}); ck::Number<4>{});
const auto strides_s4x1x8x2_nested = ck::make_tuple( const auto strides_s1x4x2x8_nested = ck::make_tuple(
ck::make_tuple(ck::make_tuple(ck::Number<4>{}, ck::Number<1>{}), ck::Number<8>{}), ck::make_tuple(ck::make_tuple(ck::Number<1>{}, ck::Number<4>{}), ck::Number<2>{}),
ck::Number<2>{}); ck::Number<8>{});
static const auto layout_2x2x4x2_s4x1x8x2_nested = static const auto layout_2x2x2x4_s1x4x2x8_nested =
ck::tensor_transform_wrapper::make_layout(shape_2x2x4x2_nested, strides_s4x1x8x2_nested); ck::tensor_transform_wrapper::make_layout(shape_2x2x2x4_nested, strides_s1x4x2x8_nested);
std::cout << "dims:((2,2),4),2 strides:((4,1),8),2" << std::endl; std::cout << "dims:((2,2),2),4 strides:((1,4),2),8" << std::endl;
Print(layout_2x2x4x2_s4x1x8x2_nested); Print1d(layout_2x2x2x4_s1x4x2x8_nested);
Print2d(layout_2x2x2x4_s1x4x2x8_nested);
Print3dCustom(layout_2x2x2x4_s1x4x2x8_nested);
return 0; return 0;
} }
...@@ -36,19 +36,44 @@ template <typename Shape, typename Strides = Tuple<>> ...@@ -36,19 +36,44 @@ template <typename Shape, typename Strides = Tuple<>>
struct Layout struct Layout
{ {
private: private:
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
template <typename T> template <typename T>
using is_tuple = decltype(std::declval<T&>().IsTuple()); using is_tuple = decltype(std::declval<T&>().IsTuple());
template <typename Tuple, typename Idx> // Generate packed (column-major) strides if not passed
constexpr static auto GenerateLowerDim(Tuple tuple) template <typename... Ts>
__host__ __device__ constexpr static auto
GenerateColumnMajorPackedStrides(const Tuple<Ts...>& tuple)
{
return generate_tuple(
[&](auto i) {
if constexpr(i.value == 0)
{
return I1;
}
else
{
return TupleReduce<I0.value, i.value>([](auto x, auto y) { return x * y; },
tuple);
}
},
Number<Tuple<Ts...>::Size()>{});
}
template <typename Idx, typename... Ts>
__host__ __device__ constexpr static auto GenerateLowerDim(const Tuple<Ts...>& tuple)
{ {
if constexpr(Idx::value == 0) if constexpr(Idx::value == 0)
{ {
if constexpr(is_detected<is_tuple, tuple_element_t<Idx::value, Tuple>>::value) if constexpr(is_detected<is_tuple, tuple_element_t<Idx::value, Tuple<Ts...>>>::value)
{ {
constexpr index_t merge_nelems = constexpr index_t merge_nelems =
decltype(UnrollNestedTuple(tuple.At(Idx{})))::Size(); decltype(UnrollNestedTuple(tuple.At(Idx{})))::Size();
return typename arithmetic_sequence_gen<0, merge_nelems, 1>::type{}; using LowerDimsSequence =
typename arithmetic_sequence_gen<0, merge_nelems, 1>::type;
return LowerDimsSequence::Reverse();
} }
else else
{ {
...@@ -57,15 +82,16 @@ struct Layout ...@@ -57,15 +82,16 @@ struct Layout
} }
else else
{ {
using PreviousSeqT = decltype(GenerateLowerDim<Tuple, Number<Idx::value - 1>>(tuple)); using PreviousSeqT = decltype(GenerateLowerDim<Number<Idx::value - 1>>(tuple));
const auto next_seq_val = PreviousSeqT::At(PreviousSeqT::Size() - 1) + 1; const auto next_seq_val = PreviousSeqT::At(I0) + 1;
if constexpr(is_detected<is_tuple, tuple_element_t<Idx::value, Tuple>>::value) if constexpr(is_detected<is_tuple, tuple_element_t<Idx::value, Tuple<Ts...>>>::value)
{ {
constexpr index_t merge_nelems = constexpr index_t merge_nelems =
decltype(UnrollNestedTuple(tuple.At(Idx{})))::Size(); decltype(UnrollNestedTuple(tuple.At(Idx{})))::Size();
return typename arithmetic_sequence_gen<next_seq_val, using LowerDimsSequence =
next_seq_val + merge_nelems, typename arithmetic_sequence_gen<next_seq_val, next_seq_val + merge_nelems, 1>::
1>::type{}; type;
return LowerDimsSequence::Reverse();
} }
else else
{ {
...@@ -74,54 +100,140 @@ struct Layout ...@@ -74,54 +100,140 @@ struct Layout
} }
} }
template <typename Tuple, typename Descriptor> template <typename... ShapeDims, typename... IdxDims>
constexpr static auto MakeMerges(const Tuple& tuple, Descriptor& desc) __host__ __device__ constexpr static auto UnrollShapeViaIdx(const Tuple<ShapeDims...>& shape,
const Tuple<IdxDims...>& idx)
{
if constexpr(!IsTupleNested(Tuple<IdxDims...>{}))
{
// Index unrolled to flatten, return shape
return shape;
}
else
{
// Iterate over shape tuple elements:
// 1. If corresponding idx element is tuple then return (will be unrolled)
// 2. If no, pack in tuple. It will be restored during unroll.
auto unrolled_shape_via_idx = generate_tuple(
[&](auto i) {
if constexpr(is_detected<is_tuple,
tuple_element_t<i, Tuple<IdxDims...>>>::value)
{
return shape.At(i);
}
else
{
return make_tuple(shape.At(i));
}
},
Number<Tuple<IdxDims...>::Size()>{});
// Unroll and process next step
return UnrollShapeViaIdx(UnrollNestedTuple<0, 1>(unrolled_shape_via_idx),
UnrollNestedTuple<0, 1>(idx));
}
}
// Merge every dimension of `desc` into a single 1d dimension.
//
// shape: (possibly nested) shape the descriptor was built from; it is flattened here.
// desc:  descriptor to transform.
// Returns the descriptor produced by transform_tensor_descriptor with one merge
// transform over all flattened dimensions.
template <typename... ShapeDims, typename DescriptorToMerge>
__host__ __device__ constexpr static auto MakeMerge1d(const Tuple<ShapeDims...>& shape,
                                                      DescriptorToMerge& desc)
{
    // Flatten the shape and reverse the element order (column-major traverse).
    // The lengths are taken from the values held by `shape`, not from a
    // default-constructed tuple type, so lengths that are plain integers keep their
    // values (same approach as the per-dimension merges in MakeMerges).
    using ReversedUnrolledShape = decltype(ReverseTuple(UnrollNestedTuple(shape)));
    const ReversedUnrolledShape merge_elems = ReverseTuple(UnrollNestedTuple(shape));
    // Generate reverted indexes (column major traverse)
    using MergeElemsSequence =
        typename arithmetic_sequence_gen<0, ReversedUnrolledShape::Size(), 1>::type;
    const auto lower_dims = make_tuple(MergeElemsSequence::Reverse());
    const auto upper_dims = make_tuple(Sequence<0>{});
    // Merge to 1d
    return transform_tensor_descriptor(
        desc, make_tuple(make_merge_transform(merge_elems)), lower_dims, upper_dims);
}
template <typename... ShapeDims, typename... IdxDims, typename DescriptorToMerge>
__host__ __device__ constexpr static auto
MakeMerges(const Tuple<ShapeDims...>& shape, const Tuple<IdxDims...>&, DescriptorToMerge& desc)
{ {
const auto transforms = generate_tuple( const auto transforms = generate_tuple(
[&](auto i) { [&](auto i) {
if constexpr(is_detected<is_tuple, tuple_element_t<i, Tuple>>::value) // Compare Idx with shape
if constexpr(is_detected<is_tuple,
tuple_element_t<i, Tuple<ShapeDims...>>>::value &&
!is_detected<is_tuple, tuple_element_t<i, Tuple<IdxDims...>>>::value)
{ {
const auto merge_elems = UnrollNestedTuple(tuple.At(i)); // If shape element is tuple and idx element is Number, then merge
// Unroll and reverse tuple to traverse column-major
const auto merge_elems = ReverseTuple(UnrollNestedTuple(shape.At(i)));
return make_merge_transform(merge_elems); return make_merge_transform(merge_elems);
} }
else else
{ {
return make_pass_through_transform(tuple.At(i)); // If shape element is integer and idx element is tuple, passed idx is wrong
static_assert(
!(!is_detected<is_tuple, tuple_element_t<i, Tuple<ShapeDims...>>>::value &&
is_detected<is_tuple, tuple_element_t<i, Tuple<IdxDims...>>>::value),
"Wrong Idx for layout()");
// If shape element has the same type as idx element, then pass through
return make_pass_through_transform(shape.At(i));
} }
}, },
Number<Tuple::Size()>{}); Number<Tuple<ShapeDims...>::Size()>{});
const auto lower_dims = const auto lower_dims =
generate_tuple([&](auto i) { return GenerateLowerDim<Tuple, Number<i>>(tuple); }, generate_tuple([&](auto i) { return GenerateLowerDim<Number<i>>(shape); },
Number<Tuple::Size()>{}); Number<Tuple<ShapeDims...>::Size()>{});
const auto upper_dims = const auto upper_dims = generate_tuple([&](auto i) { return Sequence<i.value>{}; },
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<Tuple::Size()>{}); Number<Tuple<ShapeDims...>::Size()>{});
return transform_tensor_descriptor(desc, transforms, lower_dims, upper_dims); return transform_tensor_descriptor(desc, transforms, lower_dims, upper_dims);
} }
// Build the descriptor used to resolve an index of the given form.
// A single-element idx merges the whole shape into 1d; otherwise idx must have the
// same top-level rank as shape, and shape is regrouped to follow the nesting of idx
// before the merge / pass-through transforms are created.
template <typename... ShapeDims, typename... IdxDims>
__host__ __device__ constexpr auto TransformDesc(const Tuple<ShapeDims...>& shape,
                                                 const Tuple<IdxDims...>& idx) const
{
    using ShapeT = Tuple<ShapeDims...>;
    using IdxT   = Tuple<IdxDims...>;
    if constexpr(IdxT::Size() == I1)
    {
        // 1d idx path
        return MakeMerge1d(shape, descriptor_);
    }
    else
    {
        static_assert(ShapeT::Size() == IdxT::Size(),
                      "Idx rank and Shape rank must be the same (except 1d).");
        // Flattened index, used only to decide merge vs. pass-through per dimension.
        const auto flat_idx = UnrollNestedTuple(idx);
        // Shape unrolled for as long as the index is nested.
        const auto shape_matching_idx = UnrollShapeViaIdx(shape, idx);
        return MakeMerges(shape_matching_idx, flat_idx, descriptor_);
    }
}
template <typename LayoutShape, typename LayoutStrides> template <typename LayoutShape, typename LayoutStrides>
static auto MakeDescriptor(const LayoutShape shape, const LayoutStrides strides) __host__ __device__ static auto MakeNaiveDescriptor(const LayoutShape& shape,
const LayoutStrides& strides)
{ {
const auto unrolled_shape = UnrollNestedTuple(shape); const auto unrolled_shape = UnrollNestedTuple(shape);
const auto unrolled_strides = UnrollNestedTuple(strides);
if constexpr(ck::is_same_v<LayoutStrides, Tuple<>>) if constexpr(ck::is_same_v<LayoutStrides, Tuple<>>)
{ {
const auto desc = make_naive_tensor_descriptor_packed(unrolled_shape); // If shape is packed
return MakeMerges(shape, desc); const auto column_major_packed_strides =
GenerateColumnMajorPackedStrides(unrolled_shape);
return make_naive_tensor_descriptor(unrolled_shape, column_major_packed_strides);
} }
else else
{ {
const auto unrolled_strides = UnrollNestedTuple(strides);
static_assert(unrolled_shape.Size() == unrolled_strides.Size(), static_assert(unrolled_shape.Size() == unrolled_strides.Size(),
"Size of strides and shape are not consistent."); "Size of strides and shape are not consistent.");
const auto desc = make_naive_tensor_descriptor(unrolled_shape, unrolled_strides); return make_naive_tensor_descriptor(unrolled_shape, unrolled_strides);
return MakeMerges(shape, desc);
} }
} }
public: public:
using Descriptor = remove_cvref_t<decltype(MakeDescriptor(Shape{}, Strides{}))>; using NaiveDescriptorType = remove_cvref_t<decltype(MakeNaiveDescriptor(Shape{}, Strides{}))>;
/** /**
* \brief Layout constructor. * \brief Layout constructor.
...@@ -131,67 +243,221 @@ struct Layout ...@@ -131,67 +243,221 @@ struct Layout
* \return Layout object. * \return Layout object.
*/ */
__host__ __device__ Layout() = delete; __host__ __device__ Layout() = delete;
__host__ __device__ Layout(const Shape shape, const Strides strides) : descriptor_{} __host__ __device__ Layout(const Shape& shape, const Strides& strides) : descriptor_{}
{ {
if constexpr(!Descriptor::IsKnownAtCompileTime()) // Construct if runtime mode
if constexpr(!NaiveDescriptorType::IsKnownAtCompileTime())
{ {
descriptor_ = MakeDescriptor(shape, strides); // Keep only shape, strides are not need for transforms
shape_ = shape;
descriptor_ = MakeNaiveDescriptor(shape, strides);
} }
} }
__host__ __device__ Layout(const Shape shape) : descriptor_{} __host__ __device__ Layout(const Shape& shape) : descriptor_{}
{ {
if constexpr(!Descriptor::IsKnownAtCompileTime()) if constexpr(!NaiveDescriptorType::IsKnownAtCompileTime())
{ {
descriptor_ = MakeDescriptor(shape, Strides{}); shape_ = shape;
descriptor_ = MakeNaiveDescriptor(shape, Strides{});
} }
} }
// Returns real offset to element /**
template <typename Tuple> * \brief Returns real offset to element as const in runtime.
__host__ __device__ constexpr index_t operator()(const Tuple Idx) const *
* \tparam Idxs Tuple of indexes.
* \return Calculated offset as const.
*/
template <typename Idxs>
__host__ __device__ constexpr index_t operator()() const
{
    // Only types are involved here: the transformed descriptor and the flattened
    // index are both default-constructed from the types deduced below.
    using FlattenedIdx = decltype(UnrollNestedTuple(Idxs{}));
    using DescForIdxs  = decltype(TransformDesc(Shape{}, Idxs{}));
    const DescForIdxs transformed_desc{};
    const FlattenedIdx flattened_idx{};
    return transformed_desc.CalculateOffset(flattened_idx);
}
/**
 * \brief Returns real offset to element, computed from the index types at compile time.
*
* \tparam Idxs Tuple of indexes.
* \return Calculated offset.
*/
template <typename Idxs>
__host__ __device__ constexpr index_t operator()()
{ {
return descriptor_.CalculateOffset(Idx); using TransformedDesc = decltype(TransformDesc(Shape{}, Idxs{}));
using UnrolledIdx = decltype(UnrollNestedTuple(Idxs{}));
return TransformedDesc{}.CalculateOffset(UnrolledIdx{});
} }
template <typename Tuple> /**
__host__ __device__ constexpr index_t operator()(const Tuple Idx) * \brief Returns real offset to element in compile time.
*
* \param Idx Tuple of indexes.
* \return Calculated offset.
*/
template <typename... Ts>
__host__ __device__ index_t operator()(const Tuple<Ts...>& Idx) const
{ {
return descriptor_.CalculateOffset(Idx); // Static to construct transformed_desc only once
static const auto transformed_desc = TransformDesc(shape_, Idx);
return transformed_desc.CalculateOffset(UnrollNestedTuple(Idx));
} }
// Upper dim getter /**
* \brief Length getter (product if tuple) as const.
*
* \tparam IDim Tuple of indexes or index.
* \return Calculated size.
*/
template <index_t IDim> template <index_t IDim>
__host__ __device__ constexpr index_t GetLength() const __host__ __device__ constexpr index_t GetLength() const
{ {
return descriptor_.GetLength(Number<IDim>{}); const auto elem = shape_.At(Number<IDim>{});
if constexpr(is_detected<is_tuple, tuple_element_t<IDim, Shape>>::value)
{
using UnrolledElement = decltype(UnrollNestedTuple(elem));
return TupleReduce<I0.value, UnrolledElement::Size()>(
[](auto x, auto y) { return x * y; }, UnrolledElement{});
}
else
{
return elem;
}
} }
/**
* \brief Length getter (product if tuple).
*
* \tparam IDim Tuple of indexes or index.
* \return Calculated size.
*/
template <index_t IDim> template <index_t IDim>
__host__ __device__ constexpr index_t GetLength() __host__ __device__ constexpr index_t GetLength()
{ {
return descriptor_.GetLength(Number<IDim>{}); const auto elem = shape_.At(Number<IDim>{});
if constexpr(is_detected<is_tuple, tuple_element_t<IDim, Shape>>::value)
{
using UnrolledElement = decltype(UnrollNestedTuple(elem));
return TupleReduce<I0.value, UnrolledElement::Size()>(
[](auto x, auto y) { return x * y; }, UnrolledElement{});
}
else
{
return elem;
}
}
/**
 * \brief Layout size getter (product of shape) as const.
 *
 * The product is taken over the values stored in the layout's shape, so lengths
 * that are plain integers contribute their stored value rather than the value of a
 * default-constructed integer.
 *
 * \return Calculated size.
 */
__host__ __device__ constexpr index_t GetLength() const
{
    using UnrolledShape = decltype(UnrollNestedTuple(shape_));
    // Flatten the stored shape; reduce over these values, not over UnrolledShape{}.
    const UnrolledShape unrolled_shape = UnrollNestedTuple(shape_);
    return TupleReduce<I0.value, UnrolledShape::Size()>([](auto x, auto y) { return x * y; },
                                                        unrolled_shape);
}
/**
 * \brief Layout size getter (product of shape).
 *
 * The product is taken over the values stored in the layout's shape, so lengths
 * that are plain integers contribute their stored value rather than the value of a
 * default-constructed integer.
 *
 * \return Calculated size.
 */
__host__ __device__ constexpr index_t GetLength()
{
    using UnrolledShape = decltype(UnrollNestedTuple(shape_));
    // Flatten the stored shape; reduce over these values, not over UnrolledShape{}.
    const UnrolledShape unrolled_shape = UnrollNestedTuple(shape_);
    return TupleReduce<I0.value, UnrolledShape::Size()>([](auto x, auto y) { return x * y; },
                                                        unrolled_shape);
}
/**
 * \brief Dimension getter as const.
 *
 * \tparam IDim Dimension idx.
 * \return Copy of the shape element stored at position IDim.
 */
template <index_t IDim>
__host__ __device__ constexpr auto Get() const
{
    return shape_.At(Number<IDim>{});
}
/**
* \brief Dimension getter.
*
* \tparam IDim Dimension idx.
* \return Calculated size.
*/
template <index_t IDim>
__host__ __device__ constexpr auto Get()
{
const auto elem = shape_.At(Number<IDim>{});
return elem;
} }
private: private:
Descriptor descriptor_; NaiveDescriptorType descriptor_;
Shape shape_;
}; };
// Upper dim getter // Layout helpers
template <index_t idx, typename L> // Length getter (product if tuple)
index_t size(L layout) template <index_t idx, typename Shape, typename Strides>
__host__ __device__ constexpr index_t size(const Layout<Shape, Strides>& layout)
{ {
return layout.template GetLength<idx>(); return layout.template GetLength<idx>();
} }
// Get shape size (product of dims if tuple).
// The shape is flattened first, then all of its values are multiplied together.
// The values come from the `shape` argument itself, so lengths stored as plain
// integers are honoured (a default-constructed tuple would lose them).
template <typename... ShapeDims>
__host__ __device__ constexpr index_t size(const Tuple<ShapeDims...>& shape)
{
    using UnrolledShape = decltype(UnrollNestedTuple(shape));
    const UnrolledShape unrolled_shape = UnrollNestedTuple(shape);
    return TupleReduce<0, UnrolledShape::Size()>([](auto x, auto y) { return x * y; },
                                                 unrolled_shape);
}
// Get dim size (could be returned from get function).
// Identity overload: returns `dim` unchanged, for any argument the more specialized
// size overloads do not match.
template <typename T>
__host__ __device__ T constexpr size(const T& dim)
{
return dim;
}
// Get layout size (product of shapes).
// Forwards to the argument-less Layout::GetLength().
template <typename Shape, typename Strides>
__host__ __device__ constexpr index_t size(const Layout<Shape, Strides>& layout)
{
return layout.GetLength();
}
// Get shape element size.
// Selects element `idx` of `shape` and returns size() of that element, so overload
// resolution on the element type decides how it is measured.
template <index_t idx, typename... ShapeDims>
__host__ __device__ constexpr index_t size(const Tuple<ShapeDims...>& shape)
{
    const auto& selected_dim = shape.At(Number<idx>{});
    return size(selected_dim);
}
// Dim getter (tuple if tuple).
// Forwards to Layout::Get<idx>() and returns its result by value.
template <index_t idx, typename Shape, typename Strides>
__host__ __device__ constexpr auto get(const Layout<Shape, Strides>& layout)
{
return layout.template Get<idx>();
}
template <typename Shape, typename Strides> template <typename Shape, typename Strides>
Layout<Shape, Strides> make_layout(const Shape& shape, const Strides& strides) __host__ __device__ constexpr Layout<Shape, Strides> make_layout(const Shape& shape,
const Strides& strides)
{ {
return Layout<Shape, Strides>(shape, strides); return Layout<Shape, Strides>(shape, strides);
} }
template <typename Shape> template <typename Shape>
Layout<Shape> make_layout(const Shape& shape) __host__ __device__ constexpr Layout<Shape> make_layout(const Shape& shape)
{ {
return Layout<Shape>(shape); return Layout<Shape>(shape);
} }
......
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
#include "functional4.hpp" #include "functional4.hpp"
#include "tuple.hpp" #include "tuple.hpp"
#include "is_detected.hpp"
namespace ck { namespace ck {
...@@ -42,6 +43,13 @@ __host__ __device__ constexpr auto concat_tuple(const Tuple<X...>& tx, const Tup ...@@ -42,6 +43,13 @@ __host__ __device__ constexpr auto concat_tuple(const Tuple<X...>& tx, const Tup
ty); ty);
} }
// Support any number of tuples to concat (also 1).
// With a single tuple there is nothing to join, so it is returned unchanged. This
// overload is needed when UnrollNestedTuple expands a one-element tuple into
// concat_tuple(...).
template <typename... X>
__host__ __device__ constexpr auto concat_tuple(const Tuple<X...>& tx)
{
return tx;
}
template <typename... X, typename... Tuples> template <typename... X, typename... Tuples>
__host__ __device__ constexpr auto concat_tuple(const Tuple<X...>& tx, const Tuples&... tuples) __host__ __device__ constexpr auto concat_tuple(const Tuple<X...>& tx, const Tuples&... tuples)
{ {
...@@ -93,18 +101,69 @@ __host__ __device__ constexpr auto transform_tuples(F f, const X& x, const Y& y, ...@@ -93,18 +101,69 @@ __host__ __device__ constexpr auto transform_tuples(F f, const X& x, const Y& y,
f, x, y, z, typename arithmetic_sequence_gen<0, X::Size(), 1>::type{}); f, x, y, z, typename arithmetic_sequence_gen<0, X::Size(), 1>::type{});
} }
template <typename T> // By default unroll to the flatten
// Empty-tuple overload: there is nothing to flatten, so the empty tuple is returned
// unchanged. Depth and MaxDepth are accepted but not used here.
template <index_t Depth = 0, index_t MaxDepth = -1>
__host__ __device__ constexpr auto UnrollNestedTuple(const Tuple<>& element)
{
return element;
}
template <index_t Depth = 0, index_t MaxDepth = -1, typename T>
__host__ __device__ constexpr auto UnrollNestedTuple(const T& element) __host__ __device__ constexpr auto UnrollNestedTuple(const T& element)
{ {
return make_tuple(element); return make_tuple(element);
} }
__host__ __device__ constexpr auto UnrollNestedTuple(const Tuple<>& element) { return element; } template <index_t Depth = 0, index_t MaxDepth = -1, typename... Ts>
__host__ __device__ constexpr auto UnrollNestedTuple(const Tuple<Ts...>& tuple)
{
if constexpr(Depth == MaxDepth)
{
return tuple;
}
else
{
return unpack(
[&](auto&&... ts) {
return concat_tuple(UnrollNestedTuple<Depth + 1, MaxDepth>(ts)...);
},
tuple);
}
}
// Return a new tuple holding the elements of `tuple` in reverse order.
template <typename... Ts>
__host__ __device__ constexpr auto ReverseTuple(const Tuple<Ts...>& tuple)
{
    const auto mirrored_element = [&](auto i) {
        // Element i of the result is element (Size - i - 1) of the input.
        using MirroredIdx = Number<Tuple<Ts...>::Size() - i - 1>;
        return tuple.At(MirroredIdx{});
    };
    return generate_tuple(mirrored_element, Number<Tuple<Ts...>::Size()>{});
}
// Reduce tuple values in the index range [Idx, End) using function f.
// The reduction nests to the right: f(t[Idx], f(t[Idx + 1], ... t[End - 1])).
// An empty range is rejected at compile time.
template <index_t Idx, index_t End, typename F, typename... Ts>
__host__ __device__ constexpr auto TupleReduce(F&& f, const Tuple<Ts...>& tuple)
{
    static_assert(Idx < End, "Wrong parameters for TupleReduce");
    const auto& head = tuple.At(Number<Idx>{});
    if constexpr(Idx + 1 != End)
    {
        // Combine the current element with the reduction of the remaining range.
        return f(head, TupleReduce<Idx + 1, End>(f, tuple));
    }
    else
    {
        // Last element of the range: nothing left to combine with.
        return head;
    }
}
template <typename T>
using is_tuple = decltype(std::declval<T&>().IsTuple());
template <typename... Ts> template <typename... Ts>
__host__ __device__ constexpr auto UnrollNestedTuple(const Tuple<Ts...>& tuple) __host__ __device__ constexpr auto IsTupleNested(const Tuple<Ts...>&)
{ {
return unpack([&](auto&&... ts) { return concat_tuple(UnrollNestedTuple(ts)...); }, tuple); return (is_detected<is_tuple, Ts>::value || ...);
} }
} // namespace ck } // namespace ck
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment