Commit 81b79a77 authored by Bartlomiej Kocot

Extend functionality

parent 1e276c57
@@ -15,12 +15,25 @@
static constexpr auto I0 = ck::Number<0>{};
static constexpr auto I1 = ck::Number<1>{};
static constexpr auto I2 = ck::Number<2>{};
using DataType = int;
template <typename Desc>
void Print(const Desc& desc)
void Print1d(const Desc& desc)
{
std::cout << "Print1d" << std::endl;
for(ck::index_t w = 0; w < desc.GetLength(I0); w++)
{
std::cout << desc.CalculateOffset(ck::make_tuple(w)) << " ";
}
std::cout << std::endl;
}
template <typename Desc>
void Print2d(const Desc& desc)
{
std::cout << "Print2d" << std::endl;
for(ck::index_t h = 0; h < desc.GetLength(I0); h++)
{
for(ck::index_t w = 0; w < desc.GetLength(I1); w++)
@@ -31,61 +44,107 @@ void Print(const Desc& desc)
}
}
template <typename Desc>
void Print3dCustom(const Desc& desc)
{
std::cout << "Print3dCustom" << std::endl;
for(ck::index_t d = 0; d < desc.GetLength(I0); d++)
{
for(ck::index_t h = 0; h < desc.GetLength(I1); h++)
{
for(ck::index_t w = 0; w < desc.GetLength(I2); w++)
{
std::cout << desc.CalculateOffset(ck::make_tuple(d, h, w)) << " ";
}
std::cout << std::endl;
}
std::cout << std::endl;
}
}
int main()
{
// Tensor descriptor traverses in row-major order (need to reverse dims)
std::cout << "Note: Tensor descriptor traverses in row-major order" << std::endl;
// Basic descriptor 0, 1, 2, ... 30, 31
// (dims:4,8 strides:1,1)
const auto desc_4x8_s1x1 = ck::make_naive_tensor_descriptor_packed(ck::make_tuple(4, 8));
std::cout << "dims:4,8 strides:1,1" << std::endl;
Print(desc_4x8_s1x1);
// Basic descriptor 0, 1, 8, 9, 16, 17, ... 30, 31
// dims:4,(4,2) strides:2,(8,1)
const auto desc_4x4x2_s2x8x1 =
ck::make_naive_tensor_descriptor(ck::make_tuple(4, 4, 2), ck::make_tuple(2, 8, 1));
// (dims:4,8 strides:1,4)
const auto desc_4x8_s1x4 =
ck::make_naive_tensor_descriptor(ck::make_tuple(ck::Number<4>{}, ck::Number<8>{}),
ck::make_tuple(ck::Number<1>{}, ck::Number<4>{}));
std::cout << "dims:4,8 strides:1,4" << std::endl;
Print2d(desc_4x8_s1x4);
using Cord0x0Type = ck::Tuple<ck::Number<0>, ck::Number<0>>;
constexpr ck::index_t offset_0x0 = desc_4x8_s1x4.CalculateOffset(Cord0x0Type{});
std::cout << "Constexpr calculated [0, 0] offset:" << offset_0x0 << std::endl;
// Basic descriptor 0, 1, 8, 9, 16, 17, ... 30, 31 (compile-time descriptor)
// dims:4,(2,4) strides:2,(1,8)
const auto desc_4x2x4_s2x1x8 =
ck::make_naive_tensor_descriptor(ck::make_tuple(4, 2, 4), ck::make_tuple(2, 1, 8));
// Transform to 2d
const auto desc_4x4x2_s2x8x1_merged = ck::transform_tensor_descriptor(
desc_4x4x2_s2x8x1,
const auto desc_4x2x4_s2x1x8_merged = ck::transform_tensor_descriptor(
desc_4x2x4_s2x1x8,
ck::make_tuple(ck::make_pass_through_transform(4),
ck::make_merge_transform(ck::make_tuple(4, 2))),
ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1, 2>{}),
ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<2, 1>{}),
ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}));
std::cout << "dims:4,(4,2) strides:2,(8,1)" << std::endl;
Print(desc_4x4x2_s2x8x1_merged);
std::cout << "dims:4,(2,4) strides:2,(1,8)" << std::endl;
Print2d(desc_4x2x4_s2x1x8_merged);
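// Sketch of the merge above (assuming the first length listed in make_merge_transform is the
// slowest-varying): the merged index w in [0, 8) decomposes as (w / 2, w % 2) onto lower dims
// 2 and 1, so offset(h, w) = h * 2 + (w / 2) * 8 + (w % 2) * 1, which reproduces the
// 0, 1, 8, 9, 16, 17, ... pattern printed above.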
// Basic descriptor 0, 1, 8, 9, 16, 17, ... 30, 31
// dims:(2,2),(4,2) strides:(4,1),(8,2)
const auto desc_2x2x4x2_s4x1x8x2 =
ck::make_naive_tensor_descriptor(ck::make_tuple(2, 2, 4, 2), ck::make_tuple(4, 1, 8, 2));
// Basic descriptor 0, 1, 8, 9, 16, 17, ... 30, 31 (compile-time descriptor)
// dims:(2,2),(2,4) strides:(1,4),(2,8)
const auto desc_2x2x2x4_s1x4x2x8 =
ck::make_naive_tensor_descriptor(ck::make_tuple(2, 2, 2, 4), ck::make_tuple(1, 4, 2, 8));
// Transform to 2d
const auto desc_2x2x4x2_s4x1x8x2_double_merged = ck::transform_tensor_descriptor(
desc_2x2x4x2_s4x1x8x2,
const auto desc_2x2x2x4_s1x4x2x8_double_merged_2d = ck::transform_tensor_descriptor(
desc_2x2x2x4_s1x4x2x8,
ck::make_tuple(ck::make_merge_transform(ck::make_tuple(2, 2)),
ck::make_merge_transform(ck::make_tuple(4, 2))),
ck::make_tuple(ck::Sequence<0, 1>{}, ck::Sequence<2, 3>{}),
ck::make_tuple(ck::Sequence<1, 0>{}, ck::Sequence<3, 2>{}),
ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}));
std::cout << "dims:(2,2),(4,2) strides:(4,1),(8,2)" << std::endl;
Print(desc_2x2x4x2_s4x1x8x2_double_merged);
// Transform to 3d
const auto desc_2x2x2x4_s1x4x2x8_double_merged_3d = ck::transform_tensor_descriptor(
desc_2x2x2x4_s1x4x2x8,
ck::make_tuple(ck::make_pass_through_transform(2),
ck::make_pass_through_transform(2),
ck::make_merge_transform(ck::make_tuple(4, 2))),
ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}, ck::Sequence<3, 2>{}),
ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}, ck::Sequence<2>{}));
// Basic descriptor 0, 1, 8, 9, 16, 17, ... 30, 31
// dims:((2,2),4),2 strides:((4,1),8),2
std::cout << "dims:(2,2),(2,4) strides:(1,4),(2,8)" << std::endl;
Print2d(desc_2x2x2x4_s1x4x2x8_double_merged_2d);
Print3dCustom(desc_2x2x2x4_s1x4x2x8_double_merged_3d);
// Basic descriptor 0, 1, 8, 9, 16, 17, ... 30, 31 (compile-time descriptor)
// dims:((2,2),2),4 strides:((1,4),2),8
// Transform to 2d
const auto desc_2x2x4x2_s4x1x8x2_merged = ck::transform_tensor_descriptor(
desc_2x2x4x2_s4x1x8x2,
const auto desc_2x2x2x4_s1x4x2x8_nested =
ck::make_naive_tensor_descriptor(ck::make_tuple(2, 2, 2, 4), ck::make_tuple(1, 4, 2, 8));
const auto desc_2x2x2x4_s1x4x2x8_nested_merged_3d = ck::transform_tensor_descriptor(
desc_2x2x2x4_s1x4x2x8_nested,
ck::make_tuple(ck::make_merge_transform(ck::make_tuple(2, 2)),
ck::make_pass_through_transform(4),
ck::make_pass_through_transform(2)),
ck::make_tuple(ck::Sequence<0, 1>{}, ck::Sequence<2>{}, ck::Sequence<3>{}),
ck::make_pass_through_transform(2),
ck::make_pass_through_transform(4)),
ck::make_tuple(ck::Sequence<1, 0>{}, ck::Sequence<2>{}, ck::Sequence<3>{}),
ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}, ck::Sequence<2>{}));
const auto desc_2x2x4x2_s4x1x8x2_nested_merged = ck::transform_tensor_descriptor(
desc_2x2x4x2_s4x1x8x2_merged,
ck::make_tuple(ck::make_merge_transform(ck::make_tuple(4, 4)),
ck::make_pass_through_transform(2)),
ck::make_tuple(ck::Sequence<0, 1>{}, ck::Sequence<2>{}),
const auto desc_2x2x2x4_s1x4x2x8_nested_merged_1d = ck::transform_tensor_descriptor(
desc_2x2x2x4_s1x4x2x8_nested,
ck::make_tuple(ck::make_merge_transform(ck::make_tuple(4, 2, 2, 2))),
ck::make_tuple(ck::Sequence<3, 2, 1, 0>{}),
ck::make_tuple(ck::Sequence<0>{}));
const auto desc_2x2x2x4_s1x4x2x8_nested_merged_2d = ck::transform_tensor_descriptor(
desc_2x2x2x4_s1x4x2x8_nested_merged_3d,
ck::make_tuple(ck::make_merge_transform(ck::make_tuple(2, 4)),
ck::make_pass_through_transform(4)),
ck::make_tuple(ck::Sequence<1, 0>{}, ck::Sequence<2>{}),
ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}));
std::cout << "dims:((2,2),4),2 strides:((4,1),8),2" << std::endl;
Print(desc_2x2x4x2_s4x1x8x2_nested_merged);
std::cout << "dims:((2,2),2),4 strides:((1,4),2),8" << std::endl;
Print1d(desc_2x2x2x4_s1x4x2x8_nested_merged_1d);
Print2d(desc_2x2x2x4_s1x4x2x8_nested_merged_2d);
Print3dCustom(desc_2x2x2x4_s1x4x2x8_nested_merged_3d);
return 0;
}
@@ -14,8 +14,20 @@
using DataType = int;
template <typename Layout>
void Print(const Layout& layout)
void Print1d(const Layout& layout)
{
std::cout << "Print1d" << std::endl;
for(ck::index_t w = 0; w < ck::tensor_transform_wrapper::size(layout); w++)
{
std::cout << layout(ck::make_tuple(w)) << " ";
}
std::cout << std::endl;
}
template <typename Layout>
void Print2d(const Layout& layout)
{
std::cout << "Print2d" << std::endl;
for(ck::index_t h = 0; h < ck::tensor_transform_wrapper::size<0>(layout); h++)
{
for(ck::index_t w = 0; w < ck::tensor_transform_wrapper::size<1>(layout); w++)
@@ -26,53 +38,84 @@ void Print(const Layout& layout)
}
}
// Print in (x,y),z pattern
template <typename Layout>
void Print3dCustom(const Layout& layout)
{
std::cout << "Print3dCustom" << std::endl;
for(ck::index_t d = 0;
d < ck::tensor_transform_wrapper::size<0>(ck::tensor_transform_wrapper::get<0>(layout));
d++)
{
for(ck::index_t h = 0;
h < ck::tensor_transform_wrapper::size<1>(ck::tensor_transform_wrapper::get<0>(layout));
h++)
{
for(ck::index_t w = 0; w < ck::tensor_transform_wrapper::size<1>(layout); w++)
{
std::cout << layout(ck::make_tuple(ck::make_tuple(d, h), w)) << " ";
}
std::cout << std::endl;
}
std::cout << std::endl;
}
}
int main()
{
// Layout traverses in row-major order
std::cout << "Note: Layout traverses in column-major order" << std::endl;
// Basic descriptor 0, 1, 2, ... 30, 31 (runtime descriptor)
// (dims:4,8 strides:1,1)
// (dims:4,8 strides:1,4)
const auto shape_4x8 = ck::make_tuple(4, 8);
const auto layout_4x8_s1x1 = ck::tensor_transform_wrapper::make_layout(shape_4x8);
std::cout << "dims:4,8 strides:1,1" << std::endl;
Print(layout_4x8_s1x1);
const auto layout_4x8_s1x4 = ck::tensor_transform_wrapper::make_layout(shape_4x8);
std::cout << "dims:4,8 strides:1,4" << std::endl;
Print2d(layout_4x8_s1x4);
using Cord0x0Type = ck::Tuple<ck::Number<0>, ck::Number<0>>;
constexpr ck::index_t offset_0x0 = layout_4x8_s1x4.template operator()<Cord0x0Type>();
std::cout << "Constexpr calculated [0, 0] offset:" << offset_0x0 << std::endl;
// Basic descriptor 0, 1, 8, 9, 16, 17, ... 30, 31 (compile-time descriptor)
// dims:4,(4,2) strides:2,(8,1)
const auto shape_4x4x2 =
ck::make_tuple(ck::Number<4>{}, ck::make_tuple(ck::Number<4>{}, ck::Number<2>{}));
const auto strides_s2x8x1 =
ck::make_tuple(ck::Number<2>{}, ck::make_tuple(ck::Number<8>{}, ck::Number<1>{}));
const auto layout_4x4x2_s2x8x1 =
ck::tensor_transform_wrapper::make_layout(shape_4x4x2, strides_s2x8x1);
// dims:4,(2,4) strides:2,(1,8)
const auto shape_4x2x4 =
ck::make_tuple(ck::Number<4>{}, ck::make_tuple(ck::Number<2>{}, ck::Number<4>{}));
const auto strides_s2x1x8 =
ck::make_tuple(ck::Number<2>{}, ck::make_tuple(ck::Number<1>{}, ck::Number<8>{}));
const auto layout_4x2x4_s2x1x8 =
ck::tensor_transform_wrapper::make_layout(shape_4x2x4, strides_s2x1x8);
std::cout << "dims:4,(4,2) strides:2,(8,1)" << std::endl;
Print(layout_4x4x2_s2x8x1);
std::cout << "dims:4,(2,4) strides:2,(1,8)" << std::endl;
Print2d(layout_4x2x4_s2x1x8);
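// Size helpers (illustrative sketch): size<1> returns the product of the nested (2, 4)
// group and size with no index returns the total element count (8 and 32 here).
std::cout << "size<1>: " << ck::tensor_transform_wrapper::size<1>(layout_4x2x4_s2x1x8)
<< " size: " << ck::tensor_transform_wrapper::size(layout_4x2x4_s2x1x8) << std::endl;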
// Basic descriptor 0, 1, 8, 9, 16, 17, ... 30, 31 (compile-time descriptor)
// dims:(2,2),(4,2) strides:(4,1),(8,2)
const auto shape_2x2x4x2 = ck::make_tuple(ck::make_tuple(ck::Number<2>{}, ck::Number<2>{}),
ck::make_tuple(ck::Number<4>{}, ck::Number<2>{}));
const auto strides_s4x1x8x2 = ck::make_tuple(ck::make_tuple(ck::Number<4>{}, ck::Number<1>{}),
ck::make_tuple(ck::Number<8>{}, ck::Number<2>{}));
static const auto layout_2x2x4x2_s4x1x8x2 =
ck::tensor_transform_wrapper::make_layout(shape_2x2x4x2, strides_s4x1x8x2);
// dims:(2,2),(2,4) strides:(1,4),(2,8)
const auto shape_2x2x2x4 = ck::make_tuple(ck::make_tuple(ck::Number<2>{}, ck::Number<2>{}),
ck::make_tuple(ck::Number<2>{}, ck::Number<4>{}));
const auto strides_s1x4x2x8 = ck::make_tuple(ck::make_tuple(ck::Number<1>{}, ck::Number<4>{}),
ck::make_tuple(ck::Number<2>{}, ck::Number<8>{}));
static const auto layout_2x2x2x4_s1x4x2x8 =
ck::tensor_transform_wrapper::make_layout(shape_2x2x2x4, strides_s1x4x2x8);
std::cout << "dims:(2,2),(4,2) strides:(4,1),(8,2)" << std::endl;
Print(layout_2x2x4x2_s4x1x8x2);
std::cout << "dims:(2,2),(2,4) strides:(1,4),(2,8)" << std::endl;
Print2d(layout_2x2x2x4_s1x4x2x8);
Print3dCustom(layout_2x2x2x4_s1x4x2x8);
// Basic descriptor 0, 1, 8, 9, 16, 17, ... 30, 31 (compile-time descriptor)
// dims:((2,2),4),2 strides:((4,1),8),2
// dims:((2,2),2),4 strides:((1,4),2),8
// Transform to 2d
const auto shape_2x2x4x2_nested = ck::make_tuple(
ck::make_tuple(ck::make_tuple(ck::Number<2>{}, ck::Number<2>{}), ck::Number<4>{}),
ck::Number<2>{});
const auto strides_s4x1x8x2_nested = ck::make_tuple(
ck::make_tuple(ck::make_tuple(ck::Number<4>{}, ck::Number<1>{}), ck::Number<8>{}),
ck::Number<2>{});
static const auto layout_2x2x4x2_s4x1x8x2_nested =
ck::tensor_transform_wrapper::make_layout(shape_2x2x4x2_nested, strides_s4x1x8x2_nested);
std::cout << "dims:((2,2),4),2 strides:((4,1),8),2" << std::endl;
Print(layout_2x2x4x2_s4x1x8x2_nested);
const auto shape_2x2x2x4_nested = ck::make_tuple(
ck::make_tuple(ck::make_tuple(ck::Number<2>{}, ck::Number<2>{}), ck::Number<2>{}),
ck::Number<4>{});
const auto strides_s1x4x2x8_nested = ck::make_tuple(
ck::make_tuple(ck::make_tuple(ck::Number<1>{}, ck::Number<4>{}), ck::Number<2>{}),
ck::Number<8>{});
static const auto layout_2x2x2x4_s1x4x2x8_nested =
ck::tensor_transform_wrapper::make_layout(shape_2x2x2x4_nested, strides_s1x4x2x8_nested);
std::cout << "dims:((2,2),2),4 strides:((1,4),2),8" << std::endl;
Print1d(layout_2x2x2x4_s1x4x2x8_nested);
Print2d(layout_2x2x2x4_s1x4x2x8_nested);
Print3dCustom(layout_2x2x2x4_s1x4x2x8_nested);
return 0;
}
@@ -36,19 +36,44 @@ template <typename Shape, typename Strides = Tuple<>>
struct Layout
{
private:
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
template <typename T>
using is_tuple = decltype(std::declval<T&>().IsTuple());
template <typename Tuple, typename Idx>
constexpr static auto GenerateLowerDim(Tuple tuple)
// Generate packed (column-major) strides if not passed
template <typename... Ts>
__host__ __device__ constexpr static auto
GenerateColumnMajorPackedStrides(const Tuple<Ts...>& tuple)
{
return generate_tuple(
[&](auto i) {
if constexpr(i.value == 0)
{
return I1;
}
else
{
return TupleReduce<I0.value, i.value>([](auto x, auto y) { return x * y; },
tuple);
}
},
Number<Tuple<Ts...>::Size()>{});
}
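// Example (sketch): for an unrolled shape (4, 2, 4) this yields strides (1, 4, 8), i.e.
// stride_0 = 1 and stride_i = product of the lengths 0 .. i-1 (column-major packing).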
template <typename Idx, typename... Ts>
__host__ __device__ constexpr static auto GenerateLowerDim(const Tuple<Ts...>& tuple)
{
if constexpr(Idx::value == 0)
{
if constexpr(is_detected<is_tuple, tuple_element_t<Idx::value, Tuple>>::value)
if constexpr(is_detected<is_tuple, tuple_element_t<Idx::value, Tuple<Ts...>>>::value)
{
constexpr index_t merge_nelems =
decltype(UnrollNestedTuple(tuple.At(Idx{})))::Size();
return typename arithmetic_sequence_gen<0, merge_nelems, 1>::type{};
using LowerDimsSequence =
typename arithmetic_sequence_gen<0, merge_nelems, 1>::type;
return LowerDimsSequence::Reverse();
}
else
{
@@ -57,15 +82,16 @@ struct Layout
}
else
{
using PreviousSeqT = decltype(GenerateLowerDim<Tuple, Number<Idx::value - 1>>(tuple));
const auto next_seq_val = PreviousSeqT::At(PreviousSeqT::Size() - 1) + 1;
if constexpr(is_detected<is_tuple, tuple_element_t<Idx::value, Tuple>>::value)
using PreviousSeqT = decltype(GenerateLowerDim<Number<Idx::value - 1>>(tuple));
const auto next_seq_val = PreviousSeqT::At(I0) + 1;
if constexpr(is_detected<is_tuple, tuple_element_t<Idx::value, Tuple<Ts...>>>::value)
{
constexpr index_t merge_nelems =
decltype(UnrollNestedTuple(tuple.At(Idx{})))::Size();
return typename arithmetic_sequence_gen<next_seq_val,
next_seq_val + merge_nelems,
1>::type{};
using LowerDimsSequence =
typename arithmetic_sequence_gen<next_seq_val, next_seq_val + merge_nelems, 1>::
type;
return LowerDimsSequence::Reverse();
}
else
{
@@ -74,54 +100,140 @@ struct Layout
}
}
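// Example (sketch, assuming the elided non-tuple branches return a single-element sequence
// that continues the numbering): for shape (4, (2, 4)) the lower dims become Sequence<0>
// for dim 0 and the reversed Sequence<2, 1> for the nested dim 1, i.e. column-major order.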
template <typename Tuple, typename Descriptor>
constexpr static auto MakeMerges(const Tuple& tuple, Descriptor& desc)
template <typename... ShapeDims, typename... IdxDims>
__host__ __device__ constexpr static auto UnrollShapeViaIdx(const Tuple<ShapeDims...>& shape,
const Tuple<IdxDims...>& idx)
{
if constexpr(!IsTupleNested(Tuple<IdxDims...>{}))
{
// Index unrolled to flatten, return shape
return shape;
}
else
{
// Iterate over shape tuple elements:
// 1. If the corresponding idx element is a tuple, return it as-is (it will be unrolled).
// 2. If not, pack it in a tuple; it will be restored during the unroll.
auto unrolled_shape_via_idx = generate_tuple(
[&](auto i) {
if constexpr(is_detected<is_tuple,
tuple_element_t<i, Tuple<IdxDims...>>>::value)
{
return shape.At(i);
}
else
{
return make_tuple(shape.At(i));
}
},
Number<Tuple<IdxDims...>::Size()>{});
// Unroll and process next step
return UnrollShapeViaIdx(UnrollNestedTuple<0, 1>(unrolled_shape_via_idx),
UnrollNestedTuple<0, 1>(idx));
}
}
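// Example (sketch): shape ((2, 2), (2, 4)) with idx ((d, h), w) unrolls to (2, 2, (2, 4)):
// the group addressed element-wise by the idx is split, the rest stays grouped for merging.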
template <typename... ShapeDims, typename DescriptorToMerge>
__host__ __device__ constexpr static auto MakeMerge1d(const Tuple<ShapeDims...>& shape,
DescriptorToMerge& desc)
{
// Unroll the nested shape and reverse it (column-major traversal)
using ReversedUnrolledShape = decltype(ReverseTuple(UnrollNestedTuple(shape)));
const auto merge_elems = ReversedUnrolledShape{};
// Generate reversed indices (column-major traversal)
using MergeElemsSequence =
typename arithmetic_sequence_gen<0, ReversedUnrolledShape::Size(), 1>::type;
const auto lower_dims = make_tuple(MergeElemsSequence::Reverse());
const auto upper_dims = make_tuple(Sequence<0>{});
// Merge to 1d
return transform_tensor_descriptor(
desc, make_tuple(make_merge_transform(merge_elems)), lower_dims, upper_dims);
}
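// Example (sketch): for shape (((2, 2), 2), 4) the unrolled-and-reversed lengths are
// (4, 2, 2, 2) with lower dims Sequence<3, 2, 1, 0>, the same 1d merge that the
// tensor_descriptor example in this commit builds by hand.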
template <typename... ShapeDims, typename... IdxDims, typename DescriptorToMerge>
__host__ __device__ constexpr static auto
MakeMerges(const Tuple<ShapeDims...>& shape, const Tuple<IdxDims...>&, DescriptorToMerge& desc)
{
const auto transforms = generate_tuple(
[&](auto i) {
if constexpr(is_detected<is_tuple, tuple_element_t<i, Tuple>>::value)
// Compare Idx with shape
if constexpr(is_detected<is_tuple,
tuple_element_t<i, Tuple<ShapeDims...>>>::value &&
!is_detected<is_tuple, tuple_element_t<i, Tuple<IdxDims...>>>::value)
{
const auto merge_elems = UnrollNestedTuple(tuple.At(i));
// If shape element is tuple and idx element is Number, then merge
// Unroll and reverse tuple to traverse column-major
const auto merge_elems = ReverseTuple(UnrollNestedTuple(shape.At(i)));
return make_merge_transform(merge_elems);
}
else
{
return make_pass_through_transform(tuple.At(i));
// If the shape element is an integer but the idx element is a tuple, the passed idx is invalid
static_assert(
!(!is_detected<is_tuple, tuple_element_t<i, Tuple<ShapeDims...>>>::value &&
is_detected<is_tuple, tuple_element_t<i, Tuple<IdxDims...>>>::value),
"Wrong Idx for layout()");
// If shape element has the same type as idx element, then pass through
return make_pass_through_transform(shape.At(i));
}
},
Number<Tuple::Size()>{});
Number<Tuple<ShapeDims...>::Size()>{});
const auto lower_dims =
generate_tuple([&](auto i) { return GenerateLowerDim<Tuple, Number<i>>(tuple); },
Number<Tuple::Size()>{});
const auto upper_dims =
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<Tuple::Size()>{});
generate_tuple([&](auto i) { return GenerateLowerDim<Number<i>>(shape); },
Number<Tuple<ShapeDims...>::Size()>{});
const auto upper_dims = generate_tuple([&](auto i) { return Sequence<i.value>{}; },
Number<Tuple<ShapeDims...>::Size()>{});
return transform_tensor_descriptor(desc, transforms, lower_dims, upper_dims);
}
template <typename... ShapeDims, typename... IdxDims>
__host__ __device__ constexpr auto TransformDesc(const Tuple<ShapeDims...>& shape,
const Tuple<IdxDims...>& idx) const
{
if constexpr(Tuple<IdxDims...>::Size() == I1)
{
// 1d idx path
return MakeMerge1d(shape, descriptor_);
}
else
{
static_assert(Tuple<ShapeDims...>::Size() == Tuple<IdxDims...>::Size(),
"Idx rank and Shape rank must be the same (except 1d).");
// Unroll while IdxDims is nested
const auto unrolled_shape_via_idx = UnrollShapeViaIdx(shape, idx);
// Transform using the adjusted shape
return MakeMerges(unrolled_shape_via_idx, UnrollNestedTuple(idx), descriptor_);
}
}
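// Example (sketch): for Shape ((2, 2), (2, 4)) a 1d idx takes the MakeMerge1d path, a 2d idx
// merges both nested groups column-major, and a 3d idx ((d, h), w) keeps the first group
// split via UnrollShapeViaIdx and merges only (2, 4), matching the three access patterns
// used by Print1d/Print2d/Print3dCustom in the wrapper example.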
template <typename LayoutShape, typename LayoutStrides>
static auto MakeDescriptor(const LayoutShape shape, const LayoutStrides strides)
__host__ __device__ static auto MakeNaiveDescriptor(const LayoutShape& shape,
const LayoutStrides& strides)
{
const auto unrolled_shape = UnrollNestedTuple(shape);
const auto unrolled_strides = UnrollNestedTuple(strides);
const auto unrolled_shape = UnrollNestedTuple(shape);
if constexpr(ck::is_same_v<LayoutStrides, Tuple<>>)
{
const auto desc = make_naive_tensor_descriptor_packed(unrolled_shape);
return MakeMerges(shape, desc);
// If shape is packed
const auto column_major_packed_strides =
GenerateColumnMajorPackedStrides(unrolled_shape);
return make_naive_tensor_descriptor(unrolled_shape, column_major_packed_strides);
}
else
{
const auto unrolled_strides = UnrollNestedTuple(strides);
static_assert(unrolled_shape.Size() == unrolled_strides.Size(),
"Size of strides and shape are not consistent.");
const auto desc = make_naive_tensor_descriptor(unrolled_shape, unrolled_strides);
return MakeMerges(shape, desc);
return make_naive_tensor_descriptor(unrolled_shape, unrolled_strides);
}
}
public:
using Descriptor = remove_cvref_t<decltype(MakeDescriptor(Shape{}, Strides{}))>;
using NaiveDescriptorType = remove_cvref_t<decltype(MakeNaiveDescriptor(Shape{}, Strides{}))>;
/**
* \brief Layout constructor.
@@ -131,67 +243,221 @@ struct Layout
* \return Layout object.
*/
__host__ __device__ Layout() = delete;
__host__ __device__ Layout(const Shape shape, const Strides strides) : descriptor_{}
__host__ __device__ Layout(const Shape& shape, const Strides& strides) : descriptor_{}
{
if constexpr(!Descriptor::IsKnownAtCompileTime())
// Construct the descriptor only in runtime mode
if constexpr(!NaiveDescriptorType::IsKnownAtCompileTime())
{
descriptor_ = MakeDescriptor(shape, strides);
// Keep only the shape; strides are not needed for the transforms
shape_ = shape;
descriptor_ = MakeNaiveDescriptor(shape, strides);
}
}
__host__ __device__ Layout(const Shape shape) : descriptor_{}
__host__ __device__ Layout(const Shape& shape) : descriptor_{}
{
if constexpr(!Descriptor::IsKnownAtCompileTime())
if constexpr(!NaiveDescriptorType::IsKnownAtCompileTime())
{
descriptor_ = MakeDescriptor(shape, Strides{});
shape_ = shape;
descriptor_ = MakeNaiveDescriptor(shape, Strides{});
}
}
// Returns real offset to element
template <typename Tuple>
__host__ __device__ constexpr index_t operator()(const Tuple Idx) const
/**
* \brief Returns the real offset to an element, calculated at compile time (const overload).
*
* \tparam Idxs Tuple of indexes.
* \return Calculated offset.
*/
template <typename Idxs>
__host__ __device__ constexpr index_t operator()() const
{
using TransformedDesc = decltype(TransformDesc(Shape{}, Idxs{}));
using UnrolledIdx = decltype(UnrollNestedTuple(Idxs{}));
return TransformedDesc{}.CalculateOffset(UnrolledIdx{});
}
/**
* \brief Returns the real offset to an element, calculated at compile time.
*
* \tparam Idxs Tuple of indexes.
* \return Calculated offset.
*/
template <typename Idxs>
__host__ __device__ constexpr index_t operator()()
{
return descriptor_.CalculateOffset(Idx);
using TransformedDesc = decltype(TransformDesc(Shape{}, Idxs{}));
using UnrolledIdx = decltype(UnrollNestedTuple(Idxs{}));
return TransformedDesc{}.CalculateOffset(UnrolledIdx{});
}
template <typename Tuple>
__host__ __device__ constexpr index_t operator()(const Tuple Idx)
/**
* \brief Returns the real offset to an element, calculated at runtime.
*
* \param Idx Tuple of indexes.
* \return Calculated offset.
*/
template <typename... Ts>
__host__ __device__ index_t operator()(const Tuple<Ts...>& Idx) const
{
return descriptor_.CalculateOffset(Idx);
// Static to construct transformed_desc only once
static const auto transformed_desc = TransformDesc(shape_, Idx);
return transformed_desc.CalculateOffset(UnrollNestedTuple(Idx));
}
// Upper dim getter
/**
* \brief Length getter as const (product of nested lengths if the dimension is a tuple).
*
* \tparam IDim Dimension index.
* \return Calculated length.
*/
template <index_t IDim>
__host__ __device__ constexpr index_t GetLength() const
{
return descriptor_.GetLength(Number<IDim>{});
const auto elem = shape_.At(Number<IDim>{});
if constexpr(is_detected<is_tuple, tuple_element_t<IDim, Shape>>::value)
{
using UnrolledElement = decltype(UnrollNestedTuple(elem));
return TupleReduce<I0.value, UnrolledElement::Size()>(
[](auto x, auto y) { return x * y; }, UnrolledElement{});
}
else
{
return elem;
}
}
/**
* \brief Length getter (product of nested lengths if the dimension is a tuple).
*
* \tparam IDim Dimension index.
* \return Calculated length.
*/
template <index_t IDim>
__host__ __device__ constexpr index_t GetLength()
{
return descriptor_.GetLength(Number<IDim>{});
const auto elem = shape_.At(Number<IDim>{});
if constexpr(is_detected<is_tuple, tuple_element_t<IDim, Shape>>::value)
{
using UnrolledElement = decltype(UnrollNestedTuple(elem));
return TupleReduce<I0.value, UnrolledElement::Size()>(
[](auto x, auto y) { return x * y; }, UnrolledElement{});
}
else
{
return elem;
}
}
/**
* \brief Layout size getter (product of shape) as const.
*
* \return Calculated size.
*/
__host__ __device__ constexpr index_t GetLength() const
{
using UnrolledShape = decltype(UnrollNestedTuple(shape_));
return TupleReduce<I0.value, UnrolledShape::Size()>([](auto x, auto y) { return x * y; },
UnrolledShape{});
}
/**
* \brief Layout size getter (product of shape).
*
* \return Calculated size.
*/
__host__ __device__ constexpr index_t GetLength()
{
using UnrolledShape = decltype(UnrollNestedTuple(shape_));
return TupleReduce<I0.value, UnrolledShape::Size()>([](auto x, auto y) { return x * y; },
UnrolledShape{});
}
/**
* \brief Dimension getter as const.
*
* \tparam IDim Dimension index.
* \return Shape element at IDim (scalar, or nested tuple).
*/
template <index_t IDim>
__host__ __device__ constexpr auto Get() const
{
const auto elem = shape_.At(Number<IDim>{});
return elem;
}
/**
* \brief Dimension getter.
*
* \tparam IDim Dimension index.
* \return Shape element at IDim (scalar, or nested tuple).
*/
template <index_t IDim>
__host__ __device__ constexpr auto Get()
{
const auto elem = shape_.At(Number<IDim>{});
return elem;
}
private:
Descriptor descriptor_;
NaiveDescriptorType descriptor_;
Shape shape_;
};
// Upper dim getter
template <index_t idx, typename L>
index_t size(L layout)
// Layout helpers
// Length getter (product if tuple)
template <index_t idx, typename Shape, typename Strides>
__host__ __device__ constexpr index_t size(const Layout<Shape, Strides>& layout)
{
return layout.template GetLength<idx>();
}
// Get shape size (product of dims if tuple)
template <typename... ShapeDims>
__host__ __device__ constexpr index_t size(const Tuple<ShapeDims...>& shape)
{
using UnrolledShape = decltype(UnrollNestedTuple(shape));
return TupleReduce<0, UnrolledShape::Size()>([](auto x, auto y) { return x * y; },
UnrolledShape{});
}
// Get the size of a single dim (e.g. a value returned from the get function)
template <typename T>
__host__ __device__ T constexpr size(const T& dim)
{
return dim;
}
// Get layout size (product of shapes)
template <typename Shape, typename Strides>
__host__ __device__ constexpr index_t size(const Layout<Shape, Strides>& layout)
{
return layout.GetLength();
}
// Get shape element size
template <index_t idx, typename... ShapeDims>
__host__ __device__ constexpr index_t size(const Tuple<ShapeDims...>& shape)
{
return size(shape.At(Number<idx>{}));
}
// Dim getter (tuple if tuple)
template <index_t idx, typename Shape, typename Strides>
__host__ __device__ constexpr auto get(const Layout<Shape, Strides>& layout)
{
return layout.template Get<idx>();
}
template <typename Shape, typename Strides>
Layout<Shape, Strides> make_layout(const Shape& shape, const Strides& strides)
__host__ __device__ constexpr Layout<Shape, Strides> make_layout(const Shape& shape,
const Strides& strides)
{
return Layout<Shape, Strides>(shape, strides);
}
template <typename Shape>
Layout<Shape> make_layout(const Shape& shape)
__host__ __device__ constexpr Layout<Shape> make_layout(const Shape& shape)
{
return Layout<Shape>(shape);
}
@@ -5,6 +5,7 @@
#include "functional4.hpp"
#include "tuple.hpp"
#include "is_detected.hpp"
namespace ck {
@@ -42,6 +43,13 @@ __host__ __device__ constexpr auto concat_tuple(const Tuple<X...>& tx, const Tup
ty);
}
// Support concatenating any number of tuples (including a single one)
template <typename... X>
__host__ __device__ constexpr auto concat_tuple(const Tuple<X...>& tx)
{
return tx;
}
template <typename... X, typename... Tuples>
__host__ __device__ constexpr auto concat_tuple(const Tuple<X...>& tx, const Tuples&... tuples)
{
@@ -93,18 +101,69 @@ __host__ __device__ constexpr auto transform_tuples(F f, const X& x, const Y& y,
f, x, y, z, typename arithmetic_sequence_gen<0, X::Size(), 1>::type{});
}
template <typename T>
// By default unroll until the tuple is fully flattened
template <index_t Depth = 0, index_t MaxDepth = -1>
__host__ __device__ constexpr auto UnrollNestedTuple(const Tuple<>& element)
{
return element;
}
template <index_t Depth = 0, index_t MaxDepth = -1, typename T>
__host__ __device__ constexpr auto UnrollNestedTuple(const T& element)
{
return make_tuple(element);
}
__host__ __device__ constexpr auto UnrollNestedTuple(const Tuple<>& element) { return element; }
template <index_t Depth = 0, index_t MaxDepth = -1, typename... Ts>
__host__ __device__ constexpr auto UnrollNestedTuple(const Tuple<Ts...>& tuple)
{
if constexpr(Depth == MaxDepth)
{
return tuple;
}
else
{
return unpack(
[&](auto&&... ts) {
return concat_tuple(UnrollNestedTuple<Depth + 1, MaxDepth>(ts)...);
},
tuple);
}
}
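// Examples (sketch): UnrollNestedTuple(make_tuple(make_tuple(2, 3), 4)) yields Tuple(2, 3, 4);
// UnrollNestedTuple<0, 1>(make_tuple(make_tuple(make_tuple(2, 3), 4), 5)) unrolls one level
// only and yields Tuple(make_tuple(2, 3), 4, 5).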
template <typename... Ts>
__host__ __device__ constexpr auto ReverseTuple(const Tuple<Ts...>& tuple)
{
return generate_tuple(
[&](auto i) {
using Idx = Number<Tuple<Ts...>::Size() - i - 1>;
return tuple.At(Idx{});
},
Number<Tuple<Ts...>::Size()>{});
}
// Reduce tuple values in the range [Idx, End) using function F
template <index_t Idx, index_t End, typename F, typename... Ts>
__host__ __device__ constexpr auto TupleReduce(F&& f, const Tuple<Ts...>& tuple)
{
static_assert(Idx < End, "Wrong parameters for TupleReduce");
if constexpr(Idx + 1 == End)
{
return tuple.At(Number<Idx>{});
}
else
{
return f(tuple.At(Number<Idx>{}), TupleReduce<Idx + 1, End>(f, tuple));
}
}
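// Examples (sketch): ReverseTuple(make_tuple(1, 2, 3)) yields Tuple(3, 2, 1);
// TupleReduce<0, 3>([](auto x, auto y) { return x * y; }, make_tuple(2, 3, 4)) folds the
// range [0, 3) from the right and yields 2 * (3 * 4) = 24.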
template <typename T>
using is_tuple = decltype(std::declval<T&>().IsTuple());
template <typename... Ts>
__host__ __device__ constexpr auto UnrollNestedTuple(const Tuple<Ts...>& tuple)
__host__ __device__ constexpr auto IsTupleNested(const Tuple<Ts...>&)
{
return unpack([&](auto&&... ts) { return concat_tuple(UnrollNestedTuple(ts)...); }, tuple);
return (is_detected<is_tuple, Ts>::value || ...);
}
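// Example (sketch): IsTupleNested(make_tuple(1, make_tuple(2, 3))) is true while
// IsTupleNested(make_tuple(1, 2)) is false; only the top-level element types are inspected.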
} // namespace ck