Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
81b79a77
Commit
81b79a77
authored
Nov 24, 2023
by
Bartlomiej Kocot
Browse files
Extend functionality
parent
1e276c57
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
554 additions
and
127 deletions
+554
-127
example/64_tensor_transforms/tensor_transform.cpp
example/64_tensor_transforms/tensor_transform.cpp
+97
-38
example/64_tensor_transforms/tensor_transform_using_wrapper.cpp
...e/64_tensor_transforms/tensor_transform_using_wrapper.cpp
+78
-35
example/64_tensor_transforms/tensor_transform_wrapper.hpp
example/64_tensor_transforms/tensor_transform_wrapper.hpp
+316
-50
include/ck/utility/tuple_helper.hpp
include/ck/utility/tuple_helper.hpp
+63
-4
No files found.
example/64_tensor_transforms/tensor_transform.cpp
View file @
81b79a77
...
@@ -15,12 +15,25 @@
...
@@ -15,12 +15,25 @@
static
constexpr
auto
I0
=
ck
::
Number
<
0
>
{};
static
constexpr
auto
I0
=
ck
::
Number
<
0
>
{};
static
constexpr
auto
I1
=
ck
::
Number
<
1
>
{};
static
constexpr
auto
I1
=
ck
::
Number
<
1
>
{};
static
constexpr
auto
I2
=
ck
::
Number
<
2
>
{};
using
DataType
=
int
;
using
DataType
=
int
;
template
<
typename
Desc
>
template
<
typename
Desc
>
void
Print
(
const
Desc
&
desc
)
void
Print
1d
(
const
Desc
&
desc
)
{
{
std
::
cout
<<
"Print1d"
<<
std
::
endl
;
for
(
ck
::
index_t
w
=
0
;
w
<
desc
.
GetLength
(
I0
);
w
++
)
{
std
::
cout
<<
desc
.
CalculateOffset
(
ck
::
make_tuple
(
w
))
<<
" "
;
}
std
::
cout
<<
std
::
endl
;
}
template
<
typename
Desc
>
void
Print2d
(
const
Desc
&
desc
)
{
std
::
cout
<<
"Print2d"
<<
std
::
endl
;
for
(
ck
::
index_t
h
=
0
;
h
<
desc
.
GetLength
(
I0
);
h
++
)
for
(
ck
::
index_t
h
=
0
;
h
<
desc
.
GetLength
(
I0
);
h
++
)
{
{
for
(
ck
::
index_t
w
=
0
;
w
<
desc
.
GetLength
(
I1
);
w
++
)
for
(
ck
::
index_t
w
=
0
;
w
<
desc
.
GetLength
(
I1
);
w
++
)
...
@@ -31,61 +44,107 @@ void Print(const Desc& desc)
...
@@ -31,61 +44,107 @@ void Print(const Desc& desc)
}
}
}
}
template
<
typename
Desc
>
void
Print3dCustom
(
const
Desc
&
desc
)
{
std
::
cout
<<
"Print3dCustom"
<<
std
::
endl
;
for
(
ck
::
index_t
d
=
0
;
d
<
desc
.
GetLength
(
I0
);
d
++
)
{
for
(
ck
::
index_t
h
=
0
;
h
<
desc
.
GetLength
(
I1
);
h
++
)
{
for
(
ck
::
index_t
w
=
0
;
w
<
desc
.
GetLength
(
I2
);
w
++
)
{
std
::
cout
<<
desc
.
CalculateOffset
(
ck
::
make_tuple
(
d
,
h
,
w
))
<<
" "
;
}
std
::
cout
<<
std
::
endl
;
}
std
::
cout
<<
std
::
endl
;
}
}
int
main
()
int
main
()
{
{
// Tensor descriptor traverse in row-major (need to reverse dims)
std
::
cout
<<
"Note: Tensor descriptor traverse in row-major"
<<
std
::
endl
;
// Basic descriptor 0, 1, 2, ... 30, 31
// Basic descriptor 0, 1, 2, ... 30, 31
// (dims:4,8 strides:1,1)
// (dims:4,8 strides:1,4)
const
auto
desc_4x8_s1x1
=
ck
::
make_naive_tensor_descriptor_packed
(
ck
::
make_tuple
(
4
,
8
));
const
auto
desc_4x8_s1x4
=
std
::
cout
<<
"dims:4,8 strides:1,1"
<<
std
::
endl
;
ck
::
make_naive_tensor_descriptor
(
ck
::
make_tuple
(
ck
::
Number
<
4
>
{},
ck
::
Number
<
8
>
{}),
Print
(
desc_4x8_s1x1
);
ck
::
make_tuple
(
ck
::
Number
<
1
>
{},
ck
::
Number
<
4
>
{}));
std
::
cout
<<
"dims:4,8 strides:1,4"
<<
std
::
endl
;
// Basic descriptor 0, 1, 8, 9, 16, 17, ... 30, 31
Print2d
(
desc_4x8_s1x4
);
// dims:4,(4,2) strides:2,(8,1)
const
auto
desc_4x4x2_s2x8x1
=
using
Cord0x0Type
=
ck
::
Tuple
<
ck
::
Number
<
0
>
,
ck
::
Number
<
0
>>
;
ck
::
make_naive_tensor_descriptor
(
ck
::
make_tuple
(
4
,
4
,
2
),
ck
::
make_tuple
(
2
,
8
,
1
));
constexpr
ck
::
index_t
offset_0x0
=
desc_4x8_s1x4
.
CalculateOffset
(
Cord0x0Type
{});
std
::
cout
<<
"Constexpr calculated [0, 0] offset:"
<<
offset_0x0
<<
std
::
endl
;
// Basic descriptor 0, 1, 8, 9, 16, 17, ... 30, 31 (compile-time descriptor)
// dims:4,(2,4) strides:2,(1,8)
const
auto
desc_4x2x4_s2x1x8
=
ck
::
make_naive_tensor_descriptor
(
ck
::
make_tuple
(
4
,
2
,
4
),
ck
::
make_tuple
(
2
,
1
,
8
));
// Transform to 2d
// Transform to 2d
const
auto
desc_4x
4x2
_s2x
8x1
_merged
=
ck
::
transform_tensor_descriptor
(
const
auto
desc_4x
2x4
_s2x
1x8
_merged
=
ck
::
transform_tensor_descriptor
(
desc_4x
4x2
_s2x
8x1
,
desc_4x
2x4
_s2x
1x8
,
ck
::
make_tuple
(
ck
::
make_pass_through_transform
(
4
),
ck
::
make_tuple
(
ck
::
make_pass_through_transform
(
4
),
ck
::
make_merge_transform
(
ck
::
make_tuple
(
4
,
2
))),
ck
::
make_merge_transform
(
ck
::
make_tuple
(
4
,
2
))),
ck
::
make_tuple
(
ck
::
Sequence
<
0
>
{},
ck
::
Sequence
<
1
,
2
>
{}),
ck
::
make_tuple
(
ck
::
Sequence
<
0
>
{},
ck
::
Sequence
<
2
,
1
>
{}),
ck
::
make_tuple
(
ck
::
Sequence
<
0
>
{},
ck
::
Sequence
<
1
>
{}));
ck
::
make_tuple
(
ck
::
Sequence
<
0
>
{},
ck
::
Sequence
<
1
>
{}));
std
::
cout
<<
"dims:4,(
4,2
) strides:2,(
8,1
)"
<<
std
::
endl
;
std
::
cout
<<
"dims:4,(
2,4
) strides:2,(
1,8
)"
<<
std
::
endl
;
Print
(
desc_4x
4x2
_s2x
8x1
_merged
);
Print
2d
(
desc_4x
2x4
_s2x
1x8
_merged
);
// Basic descriptor 0, 1, 8, 9, 16, 17, ... 30, 31
// Basic descriptor 0, 1, 8, 9, 16, 17, ... 30, 31
(compile-time descriptor)
// dims:(2,2),(
4,2
) strides:(
4,1),(8,2
)
// dims:(2,2),(
2,4
) strides:(
(1,4),(2,8
)
const
auto
desc_2x2x
4x2_s4x
1x
8
x2
=
const
auto
desc_2x2x
2x4_s
1x
4
x2
x8
=
ck
::
make_naive_tensor_descriptor
(
ck
::
make_tuple
(
2
,
2
,
4
,
2
),
ck
::
make_tuple
(
4
,
1
,
8
,
2
));
ck
::
make_naive_tensor_descriptor
(
ck
::
make_tuple
(
2
,
2
,
2
,
4
),
ck
::
make_tuple
(
1
,
4
,
2
,
8
));
// Transform to 2d
// Transform to 2d
const
auto
desc_2x2x
4x2_s
4x
1
x8
x2
_double_merged
=
ck
::
transform_tensor_descriptor
(
const
auto
desc_2x2x
2x4_s1x
4x
2
x8_double_merged
_2d
=
ck
::
transform_tensor_descriptor
(
desc_2x2x
4x2_s4x
1x
8
x2
,
desc_2x2x
2x4_s
1x
4
x2
x8
,
ck
::
make_tuple
(
ck
::
make_merge_transform
(
ck
::
make_tuple
(
2
,
2
)),
ck
::
make_tuple
(
ck
::
make_merge_transform
(
ck
::
make_tuple
(
2
,
2
)),
ck
::
make_merge_transform
(
ck
::
make_tuple
(
4
,
2
))),
ck
::
make_merge_transform
(
ck
::
make_tuple
(
4
,
2
))),
ck
::
make_tuple
(
ck
::
Sequence
<
0
,
1
>
{},
ck
::
Sequence
<
2
,
3
>
{}),
ck
::
make_tuple
(
ck
::
Sequence
<
1
,
0
>
{},
ck
::
Sequence
<
3
,
2
>
{}),
ck
::
make_tuple
(
ck
::
Sequence
<
0
>
{},
ck
::
Sequence
<
1
>
{}));
ck
::
make_tuple
(
ck
::
Sequence
<
0
>
{},
ck
::
Sequence
<
1
>
{}));
std
::
cout
<<
"dims:(2,2),(4,2) strides:(4,1),(8,2)"
<<
std
::
endl
;
// Transform to 3d
Print
(
desc_2x2x4x2_s4x1x8x2_double_merged
);
const
auto
desc_2x2x2x4_s1x4x2x8_double_merged_3d
=
ck
::
transform_tensor_descriptor
(
desc_2x2x2x4_s1x4x2x8
,
ck
::
make_tuple
(
ck
::
make_pass_through_transform
(
2
),
ck
::
make_pass_through_transform
(
2
),
ck
::
make_merge_transform
(
ck
::
make_tuple
(
4
,
2
))),
ck
::
make_tuple
(
ck
::
Sequence
<
0
>
{},
ck
::
Sequence
<
1
>
{},
ck
::
Sequence
<
3
,
2
>
{}),
ck
::
make_tuple
(
ck
::
Sequence
<
0
>
{},
ck
::
Sequence
<
1
>
{},
ck
::
Sequence
<
2
>
{}));
// Basic descriptor 0, 1, 8, 9, 16, 17, ... 30, 31
std
::
cout
<<
"dims:(2,2),(2,4) strides:(1,4),(2,8)"
<<
std
::
endl
;
// dims:((2,2),4),2 strides:((4,1),8),2
Print2d
(
desc_2x2x2x4_s1x4x2x8_double_merged_2d
);
Print3dCustom
(
desc_2x2x2x4_s1x4x2x8_double_merged_3d
);
// Basic descriptor 0, 1, 8, 9, 16, 17, ... 30, 31 (compile-time descriptor)
// dims:((2,2),2),4 strides:((1,4),2),8
// Transform to 2d
// Transform to 2d
const
auto
desc_2x2x4x2_s4x1x8x2_merged
=
ck
::
transform_tensor_descriptor
(
const
auto
desc_2x2x2x4_s1x4x2x8_nested
=
desc_2x2x4x2_s4x1x8x2
,
ck
::
make_naive_tensor_descriptor
(
ck
::
make_tuple
(
2
,
2
,
2
,
4
),
ck
::
make_tuple
(
1
,
4
,
2
,
8
));
const
auto
desc_2x2x2x4_s1x4x2x8_nested_merged_3d
=
ck
::
transform_tensor_descriptor
(
desc_2x2x2x4_s1x4x2x8_nested
,
ck
::
make_tuple
(
ck
::
make_merge_transform
(
ck
::
make_tuple
(
2
,
2
)),
ck
::
make_tuple
(
ck
::
make_merge_transform
(
ck
::
make_tuple
(
2
,
2
)),
ck
::
make_pass_through_transform
(
4
),
ck
::
make_pass_through_transform
(
2
),
ck
::
make_pass_through_transform
(
2
)),
ck
::
make_pass_through_transform
(
4
)),
ck
::
make_tuple
(
ck
::
Sequence
<
0
,
1
>
{},
ck
::
Sequence
<
2
>
{},
ck
::
Sequence
<
3
>
{}),
ck
::
make_tuple
(
ck
::
Sequence
<
1
,
0
>
{},
ck
::
Sequence
<
2
>
{},
ck
::
Sequence
<
3
>
{}),
ck
::
make_tuple
(
ck
::
Sequence
<
0
>
{},
ck
::
Sequence
<
1
>
{},
ck
::
Sequence
<
2
>
{}));
ck
::
make_tuple
(
ck
::
Sequence
<
0
>
{},
ck
::
Sequence
<
1
>
{},
ck
::
Sequence
<
2
>
{}));
const
auto
desc_2x2x4x2_s4x1x8x2_nested_merged
=
ck
::
transform_tensor_descriptor
(
const
auto
desc_2x2x2x4_s1x4x2x8_nested_merged_1d
=
ck
::
transform_tensor_descriptor
(
desc_2x2x4x2_s4x1x8x2_merged
,
desc_2x2x2x4_s1x4x2x8_nested
,
ck
::
make_tuple
(
ck
::
make_merge_transform
(
ck
::
make_tuple
(
4
,
4
)),
ck
::
make_tuple
(
ck
::
make_merge_transform
(
ck
::
make_tuple
(
4
,
2
,
2
,
2
))),
ck
::
make_pass_through_transform
(
2
)),
ck
::
make_tuple
(
ck
::
Sequence
<
3
,
2
,
1
,
0
>
{}),
ck
::
make_tuple
(
ck
::
Sequence
<
0
,
1
>
{},
ck
::
Sequence
<
2
>
{}),
ck
::
make_tuple
(
ck
::
Sequence
<
0
>
{}));
const
auto
desc_2x2x2x4_s1x4x2x8_nested_merged_2d
=
ck
::
transform_tensor_descriptor
(
desc_2x2x2x4_s1x4x2x8_nested_merged_3d
,
ck
::
make_tuple
(
ck
::
make_merge_transform
(
ck
::
make_tuple
(
2
,
4
)),
ck
::
make_pass_through_transform
(
4
)),
ck
::
make_tuple
(
ck
::
Sequence
<
1
,
0
>
{},
ck
::
Sequence
<
2
>
{}),
ck
::
make_tuple
(
ck
::
Sequence
<
0
>
{},
ck
::
Sequence
<
1
>
{}));
ck
::
make_tuple
(
ck
::
Sequence
<
0
>
{},
ck
::
Sequence
<
1
>
{}));
std
::
cout
<<
"dims:((2,2),4),2 strides:((4,1),8),2"
<<
std
::
endl
;
Print
(
desc_2x2x4x2_s4x1x8x2_nested_merged
);
std
::
cout
<<
"dims:((2,2),2),4 strides:((1,4),2),8"
<<
std
::
endl
;
Print1d
(
desc_2x2x2x4_s1x4x2x8_nested_merged_1d
);
Print2d
(
desc_2x2x2x4_s1x4x2x8_nested_merged_2d
);
Print3dCustom
(
desc_2x2x2x4_s1x4x2x8_nested_merged_3d
);
return
0
;
return
0
;
}
}
example/64_tensor_transforms/tensor_transform_using_wrapper.cpp
View file @
81b79a77
...
@@ -14,8 +14,20 @@
...
@@ -14,8 +14,20 @@
using
DataType
=
int
;
using
DataType
=
int
;
template
<
typename
Layout
>
template
<
typename
Layout
>
void
Print
(
const
Layout
&
layout
)
void
Print
1d
(
const
Layout
&
layout
)
{
{
std
::
cout
<<
"Print1d"
<<
std
::
endl
;
for
(
ck
::
index_t
w
=
0
;
w
<
ck
::
tensor_transform_wrapper
::
size
(
layout
);
w
++
)
{
std
::
cout
<<
layout
(
ck
::
make_tuple
(
w
))
<<
" "
;
}
std
::
cout
<<
std
::
endl
;
}
template
<
typename
Layout
>
void
Print2d
(
const
Layout
&
layout
)
{
std
::
cout
<<
"Print2d"
<<
std
::
endl
;
for
(
ck
::
index_t
h
=
0
;
h
<
ck
::
tensor_transform_wrapper
::
size
<
0
>
(
layout
);
h
++
)
for
(
ck
::
index_t
h
=
0
;
h
<
ck
::
tensor_transform_wrapper
::
size
<
0
>
(
layout
);
h
++
)
{
{
for
(
ck
::
index_t
w
=
0
;
w
<
ck
::
tensor_transform_wrapper
::
size
<
1
>
(
layout
);
w
++
)
for
(
ck
::
index_t
w
=
0
;
w
<
ck
::
tensor_transform_wrapper
::
size
<
1
>
(
layout
);
w
++
)
...
@@ -26,53 +38,84 @@ void Print(const Layout& layout)
...
@@ -26,53 +38,84 @@ void Print(const Layout& layout)
}
}
}
}
// Print in (x,y),z pattern
template
<
typename
Layout
>
void
Print3dCustom
(
const
Layout
&
layout
)
{
std
::
cout
<<
"Print3dCustom"
<<
std
::
endl
;
for
(
ck
::
index_t
d
=
0
;
d
<
ck
::
tensor_transform_wrapper
::
size
<
0
>
(
ck
::
tensor_transform_wrapper
::
get
<
0
>
(
layout
));
d
++
)
{
for
(
ck
::
index_t
h
=
0
;
h
<
ck
::
tensor_transform_wrapper
::
size
<
1
>
(
ck
::
tensor_transform_wrapper
::
get
<
0
>
(
layout
));
h
++
)
{
for
(
ck
::
index_t
w
=
0
;
w
<
ck
::
tensor_transform_wrapper
::
size
<
1
>
(
layout
);
w
++
)
{
std
::
cout
<<
layout
(
ck
::
make_tuple
(
ck
::
make_tuple
(
d
,
h
),
w
))
<<
" "
;
}
std
::
cout
<<
std
::
endl
;
}
std
::
cout
<<
std
::
endl
;
}
}
int
main
()
int
main
()
{
{
// Layout traverse in row-major
std
::
cout
<<
"Note: Layout traverse in column-major"
<<
std
::
endl
;
// Basic descriptor 0, 1, 2, ... 30, 31 (runtime descriptor)
// Basic descriptor 0, 1, 2, ... 30, 31 (runtime descriptor)
// (dims:4,8 strides:1,
1
)
// (dims:4,8 strides:1,
4
)
const
auto
shape_4x8
=
ck
::
make_tuple
(
4
,
8
);
const
auto
shape_4x8
=
ck
::
make_tuple
(
4
,
8
);
const
auto
layout_4x8_s1x1
=
ck
::
tensor_transform_wrapper
::
make_layout
(
shape_4x8
);
const
auto
layout_4x8_s1x4
=
ck
::
tensor_transform_wrapper
::
make_layout
(
shape_4x8
);
std
::
cout
<<
"dims:4,8 strides:1,1"
<<
std
::
endl
;
std
::
cout
<<
"dims:4,8 strides:1,4"
<<
std
::
endl
;
Print
(
layout_4x8_s1x1
);
Print2d
(
layout_4x8_s1x4
);
using
Cord0x0Type
=
ck
::
Tuple
<
ck
::
Number
<
0
>
,
ck
::
Number
<
0
>>
;
constexpr
ck
::
index_t
offset_0x0
=
layout_4x8_s1x4
.
template
operator
()
<
Cord0x0Type
>();
std
::
cout
<<
"Constexpr calculated [0, 0] offset:"
<<
offset_0x0
<<
std
::
endl
;
// Basic descriptor 0, 1, 8, 9, 16, 17, ... 30, 31 (compile-time descriptor)
// Basic descriptor 0, 1, 8, 9, 16, 17, ... 30, 31 (compile-time descriptor)
// dims:4,(
4,2
) strides:2,(
8,1
)
// dims:4,(
2,4
) strides:2,(
1,8
)
const
auto
shape_4x
4x2
=
const
auto
shape_4x
2x4
=
ck
::
make_tuple
(
ck
::
Number
<
4
>
{},
ck
::
make_tuple
(
ck
::
Number
<
4
>
{},
ck
::
Number
<
2
>
{}));
ck
::
make_tuple
(
ck
::
Number
<
4
>
{},
ck
::
make_tuple
(
ck
::
Number
<
2
>
{},
ck
::
Number
<
4
>
{}));
const
auto
strides_s2x
8x1
=
const
auto
strides_s2x
1x8
=
ck
::
make_tuple
(
ck
::
Number
<
2
>
{},
ck
::
make_tuple
(
ck
::
Number
<
8
>
{},
ck
::
Number
<
1
>
{}));
ck
::
make_tuple
(
ck
::
Number
<
2
>
{},
ck
::
make_tuple
(
ck
::
Number
<
1
>
{},
ck
::
Number
<
8
>
{}));
const
auto
layout_4x
4x2
_s2x
8x1
=
const
auto
layout_4x
2x4
_s2x
1x8
=
ck
::
tensor_transform_wrapper
::
make_layout
(
shape_4x
4x2
,
strides_s2x
8x1
);
ck
::
tensor_transform_wrapper
::
make_layout
(
shape_4x
2x4
,
strides_s2x
1x8
);
std
::
cout
<<
"dims:4,(
4,2
) strides:2,(
8,1
)"
<<
std
::
endl
;
std
::
cout
<<
"dims:4,(
2,4
) strides:2,(
1,8
)"
<<
std
::
endl
;
Print
(
layout_4x
4x2
_s2x
8x1
);
Print
2d
(
layout_4x
2x4
_s2x
1x8
);
// Basic descriptor 0, 1, 8, 9, 16, 17, ... 30, 31 (compile-time descriptor)
// Basic descriptor 0, 1, 8, 9, 16, 17, ... 30, 31 (compile-time descriptor)
// dims:(2,2),(
4,2
) strides:((
4,1),(8,2
)
// dims:(2,2),(
2,4
) strides:((
1,4),(2,8
)
const
auto
shape_2x2x
4x2
=
ck
::
make_tuple
(
ck
::
make_tuple
(
ck
::
Number
<
2
>
{},
ck
::
Number
<
2
>
{}),
const
auto
shape_2x2x
2x4
=
ck
::
make_tuple
(
ck
::
make_tuple
(
ck
::
Number
<
2
>
{},
ck
::
Number
<
2
>
{}),
ck
::
make_tuple
(
ck
::
Number
<
4
>
{},
ck
::
Number
<
2
>
{}));
ck
::
make_tuple
(
ck
::
Number
<
2
>
{},
ck
::
Number
<
4
>
{}));
const
auto
strides_s
4x
1x
8
x2
=
ck
::
make_tuple
(
ck
::
make_tuple
(
ck
::
Number
<
4
>
{},
ck
::
Number
<
1
>
{}),
const
auto
strides_s1x
4
x2
x8
=
ck
::
make_tuple
(
ck
::
make_tuple
(
ck
::
Number
<
1
>
{},
ck
::
Number
<
4
>
{}),
ck
::
make_tuple
(
ck
::
Number
<
8
>
{},
ck
::
Number
<
2
>
{}));
ck
::
make_tuple
(
ck
::
Number
<
2
>
{},
ck
::
Number
<
8
>
{}));
static
const
auto
layout_2x2x
4x2_s4x
1x
8
x2
=
static
const
auto
layout_2x2x
2x4_s
1x
4
x2
x8
=
ck
::
tensor_transform_wrapper
::
make_layout
(
shape_2x2x
4x2
,
strides_s
4x
1x
8
x2
);
ck
::
tensor_transform_wrapper
::
make_layout
(
shape_2x2x
2x4
,
strides_s1x
4
x2
x8
);
std
::
cout
<<
"dims:(2,2),(4,2) strides:(4,1),(8,2)"
<<
std
::
endl
;
std
::
cout
<<
"dims:(2,2),(2,4) strides:(1,4),(2,8)"
<<
std
::
endl
;
Print
(
layout_2x2x4x2_s4x1x8x2
);
Print2d
(
layout_2x2x2x4_s1x4x2x8
);
Print3dCustom
(
layout_2x2x2x4_s1x4x2x8
);
// Basic descriptor 0, 1, 8, 9, 16, 17, ... 30, 31 (compile-time descriptor)
// Basic descriptor 0, 1, 8, 9, 16, 17, ... 30, 31 (compile-time descriptor)
// dims:((2,2),
4
),
2
strides:((
4,1),8),2
// dims:((2,2),
2
),
4
strides:((
1,4),2),8
// Transform to 2d
// Transform to 2d
const
auto
shape_2x2x4x2_nested
=
ck
::
make_tuple
(
const
auto
shape_2x2x2x4_nested
=
ck
::
make_tuple
(
ck
::
make_tuple
(
ck
::
make_tuple
(
ck
::
Number
<
2
>
{},
ck
::
Number
<
2
>
{}),
ck
::
Number
<
4
>
{}),
ck
::
make_tuple
(
ck
::
make_tuple
(
ck
::
Number
<
2
>
{},
ck
::
Number
<
2
>
{}),
ck
::
Number
<
2
>
{}),
ck
::
Number
<
2
>
{});
ck
::
Number
<
4
>
{});
const
auto
strides_s4x1x8x2_nested
=
ck
::
make_tuple
(
const
auto
strides_s1x4x2x8_nested
=
ck
::
make_tuple
(
ck
::
make_tuple
(
ck
::
make_tuple
(
ck
::
Number
<
4
>
{},
ck
::
Number
<
1
>
{}),
ck
::
Number
<
8
>
{}),
ck
::
make_tuple
(
ck
::
make_tuple
(
ck
::
Number
<
1
>
{},
ck
::
Number
<
4
>
{}),
ck
::
Number
<
2
>
{}),
ck
::
Number
<
2
>
{});
ck
::
Number
<
8
>
{});
static
const
auto
layout_2x2x4x2_s4x1x8x2_nested
=
static
const
auto
layout_2x2x2x4_s1x4x2x8_nested
=
ck
::
tensor_transform_wrapper
::
make_layout
(
shape_2x2x4x2_nested
,
strides_s4x1x8x2_nested
);
ck
::
tensor_transform_wrapper
::
make_layout
(
shape_2x2x2x4_nested
,
strides_s1x4x2x8_nested
);
std
::
cout
<<
"dims:((2,2),4),2 strides:((4,1),8),2"
<<
std
::
endl
;
std
::
cout
<<
"dims:((2,2),2),4 strides:((1,4),2),8"
<<
std
::
endl
;
Print
(
layout_2x2x4x2_s4x1x8x2_nested
);
Print1d
(
layout_2x2x2x4_s1x4x2x8_nested
);
Print2d
(
layout_2x2x2x4_s1x4x2x8_nested
);
Print3dCustom
(
layout_2x2x2x4_s1x4x2x8_nested
);
return
0
;
return
0
;
}
}
example/64_tensor_transforms/tensor_transform_wrapper.hpp
View file @
81b79a77
...
@@ -36,19 +36,44 @@ template <typename Shape, typename Strides = Tuple<>>
...
@@ -36,19 +36,44 @@ template <typename Shape, typename Strides = Tuple<>>
struct
Layout
struct
Layout
{
{
private:
private:
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
template
<
typename
T
>
template
<
typename
T
>
using
is_tuple
=
decltype
(
std
::
declval
<
T
&>
().
IsTuple
());
using
is_tuple
=
decltype
(
std
::
declval
<
T
&>
().
IsTuple
());
template
<
typename
Tuple
,
typename
Idx
>
// Generate packed (column-major) strides if not passed
constexpr
static
auto
GenerateLowerDim
(
Tuple
tuple
)
template
<
typename
...
Ts
>
__host__
__device__
constexpr
static
auto
GenerateColumnMajorPackedStrides
(
const
Tuple
<
Ts
...
>&
tuple
)
{
return
generate_tuple
(
[
&
](
auto
i
)
{
if
constexpr
(
i
.
value
==
0
)
{
return
I1
;
}
else
{
return
TupleReduce
<
I0
.
value
,
i
.
value
>
([](
auto
x
,
auto
y
)
{
return
x
*
y
;
},
tuple
);
}
},
Number
<
Tuple
<
Ts
...
>::
Size
()
>
{});
}
template
<
typename
Idx
,
typename
...
Ts
>
__host__
__device__
constexpr
static
auto
GenerateLowerDim
(
const
Tuple
<
Ts
...
>&
tuple
)
{
{
if
constexpr
(
Idx
::
value
==
0
)
if
constexpr
(
Idx
::
value
==
0
)
{
{
if
constexpr
(
is_detected
<
is_tuple
,
tuple_element_t
<
Idx
::
value
,
Tuple
>>::
value
)
if
constexpr
(
is_detected
<
is_tuple
,
tuple_element_t
<
Idx
::
value
,
Tuple
<
Ts
...
>
>>::
value
)
{
{
constexpr
index_t
merge_nelems
=
constexpr
index_t
merge_nelems
=
decltype
(
UnrollNestedTuple
(
tuple
.
At
(
Idx
{})))
::
Size
();
decltype
(
UnrollNestedTuple
(
tuple
.
At
(
Idx
{})))
::
Size
();
return
typename
arithmetic_sequence_gen
<
0
,
merge_nelems
,
1
>::
type
{};
using
LowerDimsSequence
=
typename
arithmetic_sequence_gen
<
0
,
merge_nelems
,
1
>::
type
;
return
LowerDimsSequence
::
Reverse
();
}
}
else
else
{
{
...
@@ -57,15 +82,16 @@ struct Layout
...
@@ -57,15 +82,16 @@ struct Layout
}
}
else
else
{
{
using
PreviousSeqT
=
decltype
(
GenerateLowerDim
<
Tuple
,
Number
<
Idx
::
value
-
1
>>
(
tuple
));
using
PreviousSeqT
=
decltype
(
GenerateLowerDim
<
Number
<
Idx
::
value
-
1
>>
(
tuple
));
const
auto
next_seq_val
=
PreviousSeqT
::
At
(
PreviousSeqT
::
Size
()
-
1
)
+
1
;
const
auto
next_seq_val
=
PreviousSeqT
::
At
(
I0
)
+
1
;
if
constexpr
(
is_detected
<
is_tuple
,
tuple_element_t
<
Idx
::
value
,
Tuple
>>::
value
)
if
constexpr
(
is_detected
<
is_tuple
,
tuple_element_t
<
Idx
::
value
,
Tuple
<
Ts
...
>
>>::
value
)
{
{
constexpr
index_t
merge_nelems
=
constexpr
index_t
merge_nelems
=
decltype
(
UnrollNestedTuple
(
tuple
.
At
(
Idx
{})))
::
Size
();
decltype
(
UnrollNestedTuple
(
tuple
.
At
(
Idx
{})))
::
Size
();
return
typename
arithmetic_sequence_gen
<
next_seq_val
,
using
LowerDimsSequence
=
next_seq_val
+
merge_nelems
,
typename
arithmetic_sequence_gen
<
next_seq_val
,
next_seq_val
+
merge_nelems
,
1
>::
1
>::
type
{};
type
;
return
LowerDimsSequence
::
Reverse
();
}
}
else
else
{
{
...
@@ -74,54 +100,140 @@ struct Layout
...
@@ -74,54 +100,140 @@ struct Layout
}
}
}
}
template
<
typename
Tuple
,
typename
Descriptor
>
template
<
typename
...
ShapeDims
,
typename
...
IdxDims
>
constexpr
static
auto
MakeMerges
(
const
Tuple
&
tuple
,
Descriptor
&
desc
)
__host__
__device__
constexpr
static
auto
UnrollShapeViaIdx
(
const
Tuple
<
ShapeDims
...
>&
shape
,
const
Tuple
<
IdxDims
...
>&
idx
)
{
if
constexpr
(
!
IsTupleNested
(
Tuple
<
IdxDims
...
>
{}))
{
// Index unrolled to flatten, return shape
return
shape
;
}
else
{
// Iterate over shape tuple elements:
// 1. If coressponding idx element is tuple then return (will be unrolled)
// 2. If no, pack in tuple. It will be restored during unroll.
auto
unrolled_shape_via_idx
=
generate_tuple
(
[
&
](
auto
i
)
{
if
constexpr
(
is_detected
<
is_tuple
,
tuple_element_t
<
i
,
Tuple
<
IdxDims
...
>>>::
value
)
{
return
shape
.
At
(
i
);
}
else
{
return
make_tuple
(
shape
.
At
(
i
));
}
},
Number
<
Tuple
<
IdxDims
...
>::
Size
()
>
{});
// Unroll and process next step
return
UnrollShapeViaIdx
(
UnrollNestedTuple
<
0
,
1
>
(
unrolled_shape_via_idx
),
UnrollNestedTuple
<
0
,
1
>
(
idx
));
}
}
template
<
typename
...
ShapeDims
,
typename
DescriptorToMerge
>
__host__
__device__
constexpr
static
auto
MakeMerge1d
(
const
Tuple
<
ShapeDims
...
>&
shape
,
DescriptorToMerge
&
desc
)
{
// Reverse each element in tuple
using
ReversedUnrolledShape
=
decltype
(
ReverseTuple
(
UnrollNestedTuple
(
shape
)));
const
auto
merge_elems
=
ReversedUnrolledShape
{};
// Generate reverted indexes (column major traverse)
using
MergeElemsSequence
=
typename
arithmetic_sequence_gen
<
0
,
ReversedUnrolledShape
::
Size
(),
1
>::
type
;
const
auto
lower_dims
=
make_tuple
(
MergeElemsSequence
::
Reverse
());
const
auto
upper_dims
=
make_tuple
(
Sequence
<
0
>
{});
// Merge to 1d
return
transform_tensor_descriptor
(
desc
,
make_tuple
(
make_merge_transform
(
merge_elems
)),
lower_dims
,
upper_dims
);
}
template
<
typename
...
ShapeDims
,
typename
...
IdxDims
,
typename
DescriptorToMerge
>
__host__
__device__
constexpr
static
auto
MakeMerges
(
const
Tuple
<
ShapeDims
...
>&
shape
,
const
Tuple
<
IdxDims
...
>&
,
DescriptorToMerge
&
desc
)
{
{
const
auto
transforms
=
generate_tuple
(
const
auto
transforms
=
generate_tuple
(
[
&
](
auto
i
)
{
[
&
](
auto
i
)
{
if
constexpr
(
is_detected
<
is_tuple
,
tuple_element_t
<
i
,
Tuple
>>::
value
)
// Compare Idx with shape
if
constexpr
(
is_detected
<
is_tuple
,
tuple_element_t
<
i
,
Tuple
<
ShapeDims
...
>>>::
value
&&
!
is_detected
<
is_tuple
,
tuple_element_t
<
i
,
Tuple
<
IdxDims
...
>>>::
value
)
{
{
const
auto
merge_elems
=
UnrollNestedTuple
(
tuple
.
At
(
i
));
// If shape element is tuple and idx element is Number, then merge
// Unroll and reverse tuple to traverse column-major
const
auto
merge_elems
=
ReverseTuple
(
UnrollNestedTuple
(
shape
.
At
(
i
)));
return
make_merge_transform
(
merge_elems
);
return
make_merge_transform
(
merge_elems
);
}
}
else
else
{
{
return
make_pass_through_transform
(
tuple
.
At
(
i
));
// If shape element is integer and idx element is tuple, passed idx is wrong
static_assert
(
!
(
!
is_detected
<
is_tuple
,
tuple_element_t
<
i
,
Tuple
<
ShapeDims
...
>>>::
value
&&
is_detected
<
is_tuple
,
tuple_element_t
<
i
,
Tuple
<
IdxDims
...
>>>::
value
),
"Wrong Idx for layout()"
);
// If shape element has the same type as idx element, then pass through
return
make_pass_through_transform
(
shape
.
At
(
i
));
}
}
},
},
Number
<
Tuple
::
Size
()
>
{});
Number
<
Tuple
<
ShapeDims
...
>
::
Size
()
>
{});
const
auto
lower_dims
=
const
auto
lower_dims
=
generate_tuple
([
&
](
auto
i
)
{
return
GenerateLowerDim
<
Tuple
,
Number
<
i
>>
(
tupl
e
);
},
generate_tuple
([
&
](
auto
i
)
{
return
GenerateLowerDim
<
Number
<
i
>>
(
shap
e
);
},
Number
<
Tuple
::
Size
()
>
{});
Number
<
Tuple
<
ShapeDims
...
>
::
Size
()
>
{});
const
auto
upper_dims
=
const
auto
upper_dims
=
generate_tuple
([
&
](
auto
i
)
{
return
Sequence
<
i
.
value
>
{};
},
generate_tuple
([
&
](
auto
i
)
{
return
Sequence
<
i
.
value
>
{};
},
Number
<
Tuple
::
Size
()
>
{});
Number
<
Tuple
<
ShapeDims
...
>
::
Size
()
>
{});
return
transform_tensor_descriptor
(
desc
,
transforms
,
lower_dims
,
upper_dims
);
return
transform_tensor_descriptor
(
desc
,
transforms
,
lower_dims
,
upper_dims
);
}
}
template
<
typename
...
ShapeDims
,
typename
...
IdxDims
>
__host__
__device__
constexpr
auto
TransformDesc
(
const
Tuple
<
ShapeDims
...
>&
shape
,
const
Tuple
<
IdxDims
...
>&
idx
)
const
{
if
constexpr
(
Tuple
<
IdxDims
...
>::
Size
()
==
I1
)
{
// 1d idx path
return
MakeMerge1d
(
shape
,
descriptor_
);
}
else
{
static_assert
(
Tuple
<
ShapeDims
...
>::
Size
()
==
Tuple
<
IdxDims
...
>::
Size
(),
"Idx rank and Shape rank must be the same (except 1d)."
);
// Unroll while IdxDims is nested
const
auto
unrolled_shape_via_idx
=
UnrollShapeViaIdx
(
shape
,
idx
);
// Transform correct form of shape
return
MakeMerges
(
unrolled_shape_via_idx
,
UnrollNestedTuple
(
idx
),
descriptor_
);
}
}
template
<
typename
LayoutShape
,
typename
LayoutStrides
>
template
<
typename
LayoutShape
,
typename
LayoutStrides
>
static
auto
MakeDescriptor
(
const
LayoutShape
shape
,
const
LayoutStrides
strides
)
__host__
__device__
static
auto
MakeNaiveDescriptor
(
const
LayoutShape
&
shape
,
const
LayoutStrides
&
strides
)
{
{
const
auto
unrolled_shape
=
UnrollNestedTuple
(
shape
);
const
auto
unrolled_shape
=
UnrollNestedTuple
(
shape
);
const
auto
unrolled_strides
=
UnrollNestedTuple
(
strides
);
if
constexpr
(
ck
::
is_same_v
<
LayoutStrides
,
Tuple
<>>
)
if
constexpr
(
ck
::
is_same_v
<
LayoutStrides
,
Tuple
<>>
)
{
{
const
auto
desc
=
make_naive_tensor_descriptor_packed
(
unrolled_shape
);
// If shape is packed
return
MakeMerges
(
shape
,
desc
);
const
auto
column_major_packed_strides
=
GenerateColumnMajorPackedStrides
(
unrolled_shape
);
return
make_naive_tensor_descriptor
(
unrolled_shape
,
column_major_packed_strides
);
}
}
else
else
{
{
const
auto
unrolled_strides
=
UnrollNestedTuple
(
strides
);
static_assert
(
unrolled_shape
.
Size
()
==
unrolled_strides
.
Size
(),
static_assert
(
unrolled_shape
.
Size
()
==
unrolled_strides
.
Size
(),
"Size of strides and shape are not consistent."
);
"Size of strides and shape are not consistent."
);
const
auto
desc
=
make_naive_tensor_descriptor
(
unrolled_shape
,
unrolled_strides
);
return
make_naive_tensor_descriptor
(
unrolled_shape
,
unrolled_strides
);
return
MakeMerges
(
shape
,
desc
);
}
}
}
}
public:
public:
using
Descriptor
=
remove_cvref_t
<
decltype
(
MakeDescriptor
(
Shape
{},
Strides
{}))
>
;
using
Naive
Descriptor
Type
=
remove_cvref_t
<
decltype
(
Make
Naive
Descriptor
(
Shape
{},
Strides
{}))
>
;
/**
/**
* \brief Layout constructor.
* \brief Layout constructor.
...
@@ -131,67 +243,221 @@ struct Layout
...
@@ -131,67 +243,221 @@ struct Layout
* \return Layout object.
* \return Layout object.
*/
*/
__host__
__device__
Layout
()
=
delete
;
__host__
__device__
Layout
()
=
delete
;
__host__
__device__
Layout
(
const
Shape
shape
,
const
Strides
strides
)
:
descriptor_
{}
__host__
__device__
Layout
(
const
Shape
&
shape
,
const
Strides
&
strides
)
:
descriptor_
{}
{
{
if
constexpr
(
!
Descriptor
::
IsKnownAtCompileTime
())
// Construct if runtime mode
if
constexpr
(
!
NaiveDescriptorType
::
IsKnownAtCompileTime
())
{
{
descriptor_
=
MakeDescriptor
(
shape
,
strides
);
// Keep only shape, strides are not need for transforms
shape_
=
shape
;
descriptor_
=
MakeNaiveDescriptor
(
shape
,
strides
);
}
}
}
}
__host__
__device__
Layout
(
const
Shape
shape
)
:
descriptor_
{}
__host__
__device__
Layout
(
const
Shape
&
shape
)
:
descriptor_
{}
{
{
if
constexpr
(
!
Descriptor
::
IsKnownAtCompileTime
())
if
constexpr
(
!
Naive
Descriptor
Type
::
IsKnownAtCompileTime
())
{
{
descriptor_
=
MakeDescriptor
(
shape
,
Strides
{});
shape_
=
shape
;
descriptor_
=
MakeNaiveDescriptor
(
shape
,
Strides
{});
}
}
}
}
// Returns real offset to element
/**
template
<
typename
Tuple
>
* \brief Returns real offset to element as const in runtime.
__host__
__device__
constexpr
index_t
operator
()(
const
Tuple
Idx
)
const
*
* \tparam Idxs Tuple of indexes.
* \return Calculated offset as const.
*/
template
<
typename
Idxs
>
__host__
__device__
constexpr
index_t
operator
()()
const
{
using
TransformedDesc
=
decltype
(
TransformDesc
(
Shape
{},
Idxs
{}));
using
UnrolledIdx
=
decltype
(
UnrollNestedTuple
(
Idxs
{}));
return
TransformedDesc
{}.
CalculateOffset
(
UnrolledIdx
{});
}
/**
* \brief Returns real offset to element in runtime.
*
* \tparam Idxs Tuple of indexes.
* \return Calculated offset.
*/
template
<
typename
Idxs
>
__host__
__device__
constexpr
index_t
operator
()()
{
{
return
descriptor_
.
CalculateOffset
(
Idx
);
using
TransformedDesc
=
decltype
(
TransformDesc
(
Shape
{},
Idxs
{}));
using
UnrolledIdx
=
decltype
(
UnrollNestedTuple
(
Idxs
{}));
return
TransformedDesc
{}.
CalculateOffset
(
UnrolledIdx
{});
}
}
template
<
typename
Tuple
>
/**
__host__
__device__
constexpr
index_t
operator
()(
const
Tuple
Idx
)
* \brief Returns real offset to element in compile time.
*
* \param Idx Tuple of indexes.
* \return Calculated offset.
*/
template
<
typename
...
Ts
>
__host__
__device__
index_t
operator
()(
const
Tuple
<
Ts
...
>&
Idx
)
const
{
{
return
descriptor_
.
CalculateOffset
(
Idx
);
// Static to construct transformed_desc only once
static
const
auto
transformed_desc
=
TransformDesc
(
shape_
,
Idx
);
return
transformed_desc
.
CalculateOffset
(
UnrollNestedTuple
(
Idx
));
}
}
// Upper dim getter
/**
* \brief Length getter (product if tuple) as const.
*
* \tparam IDim Tuple of indexes or index.
* \return Calculated size.
*/
template
<
index_t
IDim
>
template
<
index_t
IDim
>
__host__
__device__
constexpr
index_t
GetLength
()
const
__host__
__device__
constexpr
index_t
GetLength
()
const
{
{
return
descriptor_
.
GetLength
(
Number
<
IDim
>
{});
const
auto
elem
=
shape_
.
At
(
Number
<
IDim
>
{});
if
constexpr
(
is_detected
<
is_tuple
,
tuple_element_t
<
IDim
,
Shape
>>::
value
)
{
using
UnrolledElement
=
decltype
(
UnrollNestedTuple
(
elem
));
return
TupleReduce
<
I0
.
value
,
UnrolledElement
::
Size
()
>
(
[](
auto
x
,
auto
y
)
{
return
x
*
y
;
},
UnrolledElement
{});
}
else
{
return
elem
;
}
}
}
/**
* \brief Length getter (product if tuple).
*
* \tparam IDim Tuple of indexes or index.
* \return Calculated size.
*/
template
<
index_t
IDim
>
template
<
index_t
IDim
>
__host__
__device__
constexpr
index_t
GetLength
()
__host__
__device__
constexpr
index_t
GetLength
()
{
{
return
descriptor_
.
GetLength
(
Number
<
IDim
>
{});
const
auto
elem
=
shape_
.
At
(
Number
<
IDim
>
{});
if
constexpr
(
is_detected
<
is_tuple
,
tuple_element_t
<
IDim
,
Shape
>>::
value
)
{
using
UnrolledElement
=
decltype
(
UnrollNestedTuple
(
elem
));
return
TupleReduce
<
I0
.
value
,
UnrolledElement
::
Size
()
>
(
[](
auto
x
,
auto
y
)
{
return
x
*
y
;
},
UnrolledElement
{});
}
else
{
return
elem
;
}
}
/**
* \brief Layout size getter (product of shape) as const.
*
* \return Calculated size.
*/
__host__
__device__
constexpr
index_t
GetLength
()
const
{
using
UnrolledShape
=
decltype
(
UnrollNestedTuple
(
shape_
));
return
TupleReduce
<
I0
.
value
,
UnrolledShape
::
Size
()
>
([](
auto
x
,
auto
y
)
{
return
x
*
y
;
},
UnrolledShape
{});
}
/**
* \brief Layout size getter (product of shape).
*
* \return Calculated size.
*/
__host__
__device__
constexpr
index_t
GetLength
()
{
using
UnrolledShape
=
decltype
(
UnrollNestedTuple
(
shape_
));
return
TupleReduce
<
I0
.
value
,
UnrolledShape
::
Size
()
>
([](
auto
x
,
auto
y
)
{
return
x
*
y
;
},
UnrolledShape
{});
}
/**
 * \brief Dimension getter as const.
 *
 * Returns the shape element as stored; no product is computed here, so a
 * nested dimension comes back as whatever shape_.At() yields for it.
 *
 * \tparam IDim Dimension idx.
 * \return Copy of the shape element at position IDim.
 */
template <index_t IDim>
__host__ __device__ constexpr auto Get() const
{
    return shape_.At(Number<IDim>{});
}
/**
 * \brief Dimension getter.
 *
 * Non-const counterpart of the const getter. Returns the shape element as
 * stored; no product is computed here.
 *
 * \tparam IDim Dimension idx.
 * \return Copy of the shape element at position IDim.
 */
template <index_t IDim>
__host__ __device__ constexpr auto Get()
{
    return shape_.At(Number<IDim>{});
}
}
private:
private:
Descriptor
descriptor_
;
NaiveDescriptorType
descriptor_
;
Shape
shape_
;
};
};
// Upper dim getter
// Layout helpers
template
<
index_t
idx
,
typename
L
>
// Length getter (product if tuple)
index_t
size
(
L
layout
)
template
<
index_t
idx
,
typename
Shape
,
typename
Strides
>
__host__
__device__
constexpr
index_t
size
(
const
Layout
<
Shape
,
Strides
>&
layout
)
{
{
return
layout
.
template
GetLength
<
idx
>();
return
layout
.
template
GetLength
<
idx
>();
}
}
// Get shape size: the product of every dimension of `shape`.
// A nested shape is flattened with UnrollNestedTuple first, then all of its
// elements are multiplied together with TupleReduce.
template <typename... ShapeDims>
__host__ __device__ constexpr index_t size(const Tuple<ShapeDims...>& shape)
{
    using UnrolledShape = decltype(UnrollNestedTuple(shape));
    // Reduce over the flattened *values* of `shape`. A value-initialized
    // UnrolledShape{} does not carry what the caller passed in, so any
    // dimension whose value lives in the object rather than in its type
    // would be lost from the product.
    const UnrolledShape unrolled_shape = UnrollNestedTuple(shape);
    return TupleReduce<0, UnrolledShape::Size()>([](auto x, auto y) { return x * y; },
                                                 unrolled_shape);
}
// Size of a single, non-tuple dimension: the dimension itself.
// Lets size() be applied uniformly to whatever a getter hands back.
template <typename T>
__host__ __device__ constexpr T size(const T& dim)
{
    return dim;
}
// Size of a whole layout; forwards to Layout::GetLength().
template <typename Shape, typename Strides>
__host__ __device__ constexpr index_t size(const Layout<Shape, Strides>& layout)
{
    const index_t total_length = layout.GetLength();
    return total_length;
}
// Size of the idx-th element of a shape.
// The element is handed to the size() overload set, so the result depends on
// which overload matches the element's type.
template <index_t idx, typename... ShapeDims>
__host__ __device__ constexpr index_t size(const Tuple<ShapeDims...>& shape)
{
    const auto& element = shape.At(Number<idx>{});
    return size(element);
}
// Free-function dimension getter; forwards to Layout::Get<idx>().
// The element is returned as stored (a tuple stays a tuple).
template <index_t idx, typename Shape, typename Strides>
__host__ __device__ constexpr auto get(const Layout<Shape, Strides>& layout)
{
    const auto dim = layout.template Get<idx>();
    return dim;
}
template
<
typename
Shape
,
typename
Strides
>
template
<
typename
Shape
,
typename
Strides
>
Layout
<
Shape
,
Strides
>
make_layout
(
const
Shape
&
shape
,
const
Strides
&
strides
)
__host__
__device__
constexpr
Layout
<
Shape
,
Strides
>
make_layout
(
const
Shape
&
shape
,
const
Strides
&
strides
)
{
{
return
Layout
<
Shape
,
Strides
>
(
shape
,
strides
);
return
Layout
<
Shape
,
Strides
>
(
shape
,
strides
);
}
}
template
<
typename
Shape
>
template
<
typename
Shape
>
Layout
<
Shape
>
make_layout
(
const
Shape
&
shape
)
__host__
__device__
constexpr
Layout
<
Shape
>
make_layout
(
const
Shape
&
shape
)
{
{
return
Layout
<
Shape
>
(
shape
);
return
Layout
<
Shape
>
(
shape
);
}
}
...
...
include/ck/utility/tuple_helper.hpp
View file @
81b79a77
...
@@ -5,6 +5,7 @@
...
@@ -5,6 +5,7 @@
#include "functional4.hpp"
#include "functional4.hpp"
#include "tuple.hpp"
#include "tuple.hpp"
#include "is_detected.hpp"
namespace
ck
{
namespace
ck
{
...
@@ -42,6 +43,13 @@ __host__ __device__ constexpr auto concat_tuple(const Tuple<X...>& tx, const Tup
...
@@ -42,6 +43,13 @@ __host__ __device__ constexpr auto concat_tuple(const Tuple<X...>& tx, const Tup
ty
);
ty
);
}
}
// Single-tuple overload of concat_tuple: concatenating one tuple yields a
// copy of that tuple. This lets a pack expansion of any size, including a
// pack of exactly one tuple, go through concat_tuple.
template <typename... X>
__host__ __device__ constexpr auto concat_tuple(const Tuple<X...>& tx)
{
    return tx;
}
template
<
typename
...
X
,
typename
...
Tuples
>
template
<
typename
...
X
,
typename
...
Tuples
>
__host__
__device__
constexpr
auto
concat_tuple
(
const
Tuple
<
X
...
>&
tx
,
const
Tuples
&
...
tuples
)
__host__
__device__
constexpr
auto
concat_tuple
(
const
Tuple
<
X
...
>&
tx
,
const
Tuples
&
...
tuples
)
{
{
...
@@ -93,18 +101,69 @@ __host__ __device__ constexpr auto transform_tuples(F f, const X& x, const Y& y,
...
@@ -93,18 +101,69 @@ __host__ __device__ constexpr auto transform_tuples(F f, const X& x, const Y& y,
f
,
x
,
y
,
z
,
typename
arithmetic_sequence_gen
<
0
,
X
::
Size
(),
1
>::
type
{});
f
,
x
,
y
,
z
,
typename
arithmetic_sequence_gen
<
0
,
X
::
Size
(),
1
>::
type
{});
}
}
template
<
typename
T
>
// UnrollNestedTuple overload for an empty tuple: there is nothing to unroll,
// so the tuple is returned unchanged. Depth and MaxDepth are not used here;
// they are declared with defaults Depth = 0 and MaxDepth = -1.
template <index_t Depth = 0, index_t MaxDepth = -1>
__host__ __device__ constexpr auto UnrollNestedTuple(const Tuple<>& element)
{
    return element;
}
template
<
index_t
Depth
=
0
,
index_t
MaxDepth
=
-
1
,
typename
T
>
__host__
__device__
constexpr
auto
UnrollNestedTuple
(
const
T
&
element
)
__host__
__device__
constexpr
auto
UnrollNestedTuple
(
const
T
&
element
)
{
{
return
make_tuple
(
element
);
return
make_tuple
(
element
);
}
}
// UnrollNestedTuple overload for an empty tuple: nothing to unroll, the
// tuple is returned unchanged.
__host__ __device__ constexpr auto UnrollNestedTuple(const Tuple<>& element)
{
    return element;
}
// Unroll the elements of a tuple and concatenate the results.
// Every element is passed to UnrollNestedTuple with Depth + 1 and the pieces
// are joined with concat_tuple. Once Depth equals MaxDepth the tuple is
// returned untouched. With the defaults (Depth = 0, MaxDepth = -1) the two
// never compare equal while Depth only increases from 0, so the recursion is
// not cut off by the depth check.
template <index_t Depth = 0, index_t MaxDepth = -1, typename... Ts>
__host__ __device__ constexpr auto UnrollNestedTuple(const Tuple<Ts...>& tuple)
{
    if constexpr(Depth != MaxDepth)
    {
        return unpack(
            [&](auto&&... children) {
                return concat_tuple(UnrollNestedTuple<Depth + 1, MaxDepth>(children)...);
            },
            tuple);
    }
    else
    {
        // Depth limit reached: keep the remaining nesting as it is.
        return tuple;
    }
}
// Build a tuple holding the elements of `tuple` in reverse order.
template <typename... Ts>
__host__ __device__ constexpr auto ReverseTuple(const Tuple<Ts...>& tuple)
{
    using TupleType = Tuple<Ts...>;
    return generate_tuple(
        [&](auto i) {
            // Element i of the result is element (Size - i - 1) of the input.
            return tuple.At(Number<TupleType::Size() - i - 1>{});
        },
        Number<TupleType::Size()>{});
}
// Reduce the tuple elements in the index range [Idx, End) with the binary
// function f. The reduction nests to the right:
//   f(t[Idx], f(t[Idx + 1], ... f(t[End - 2], t[End - 1])))
// A range of exactly one element returns that element without calling f.
// An empty or inverted range (Idx >= End) is rejected at compile time.
template <index_t Idx, index_t End, typename F, typename... Ts>
__host__ __device__ constexpr auto TupleReduce(F&& f, const Tuple<Ts...>& tuple)
{
    static_assert(Idx < End, "Wrong parameters for TupleReduce");
    if constexpr(Idx + 1 != End)
    {
        // More than one element left: combine the head with the reduced tail.
        return f(tuple.At(Number<Idx>{}), TupleReduce<Idx + 1, End>(f, tuple));
    }
    else
    {
        // Last element of the range ends the recursion.
        return tuple.At(Number<Idx>{});
    }
}
// Detection expression for use with is_detected: it is well-formed only when
// an lvalue of type T has a callable IsTuple() member, and it names the type
// that call returns.
template <typename T>
using is_tuple = decltype(std::declval<T&>().IsTuple());
template
<
typename
...
Ts
>
template
<
typename
...
Ts
>
__host__
__device__
constexpr
auto
Unroll
Nested
Tuple
(
const
Tuple
<
Ts
...
>&
tuple
)
__host__
__device__
constexpr
auto
IsTuple
Nested
(
const
Tuple
<
Ts
...
>&
)
{
{
return
unpack
([
&
](
auto
&&
...
ts
)
{
return
concat_tuple
(
UnrollNestedTuple
(
ts
)...);
},
tuple
);
return
(
is_detected
<
is_tuple
,
Ts
>::
value
||
...
);
}
}
}
// namespace ck
}
// namespace ck
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment