Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
f124c7a1
Commit
f124c7a1
authored
Nov 28, 2023
by
Bartlomiej Kocot
Browse files
Comment fixes
parent
936b1d6c
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
24 additions
and
19 deletions
+24
-19
example/64_tensor_transforms/tensor_transform.cpp
example/64_tensor_transforms/tensor_transform.cpp
+4
-4
example/64_tensor_transforms/tensor_transform_using_wrapper.cpp
...e/64_tensor_transforms/tensor_transform_using_wrapper.cpp
+3
-3
example/64_tensor_transforms/tensor_transform_wrapper.hpp
example/64_tensor_transforms/tensor_transform_wrapper.hpp
+15
-10
include/ck/utility/tuple_helper.hpp
include/ck/utility/tuple_helper.hpp
+2
-2
No files found.
example/64_tensor_transforms/tensor_transform.cpp
View file @
f124c7a1
...
...
@@ -74,15 +74,15 @@ int main()
std
::
cout
<<
"dims:4,8 strides:1,4"
<<
std
::
endl
;
Print2d
(
desc_4x8_s1x4
);
using
Cord
0x0
Type
=
ck
::
Tuple
<
ck
::
Number
<
0
>
,
ck
::
Number
<
0
>>
;
constexpr
ck
::
index_t
offset_
0x0
=
desc_4x8_s1x4
.
CalculateOffset
(
Cord
0x0
Type
{});
std
::
cout
<<
"Constexpr calculated [
0
,
0
] offset:"
<<
offset_
0x0
<<
std
::
endl
;
using
Cord
1x1
Type
=
ck
::
Tuple
<
ck
::
Number
<
1
>
,
ck
::
Number
<
1
>>
;
constexpr
ck
::
index_t
offset_
1x1
=
desc_4x8_s1x4
.
CalculateOffset
(
Cord
1x1
Type
{});
std
::
cout
<<
"Constexpr calculated [
1
,
1
] offset:"
<<
offset_
1x1
<<
std
::
endl
;
// Basic descriptor 0, 1, 8, 9, 16, 17, ... 30, 31 (compile-time descriptor)
// dims:4,(2,4) strides:2,(1,8)
const
auto
desc_4x2x4_s2x1x8
=
ck
::
make_naive_tensor_descriptor
(
ck
::
make_tuple
(
4
,
2
,
4
),
ck
::
make_tuple
(
2
,
1
,
8
));
// Transform to 2d
// Transform to 2d
(column-major, need to reverse dims)
const
auto
desc_4x2x4_s2x1x8_merged
=
ck
::
transform_tensor_descriptor
(
desc_4x2x4_s2x1x8
,
ck
::
make_tuple
(
ck
::
make_pass_through_transform
(
4
),
...
...
example/64_tensor_transforms/tensor_transform_using_wrapper.cpp
View file @
f124c7a1
...
...
@@ -71,9 +71,9 @@ int main()
const
auto
layout_4x8_s1x4
=
ck
::
tensor_transform_wrapper
::
make_layout
(
shape_4x8
);
std
::
cout
<<
"dims:4,8 strides:1,4"
<<
std
::
endl
;
Print2d
(
layout_4x8_s1x4
);
using
Cord
0x0
Type
=
ck
::
Tuple
<
ck
::
Number
<
0
>
,
ck
::
Number
<
0
>>
;
constexpr
ck
::
index_t
offset_
0x0
=
layout_4x8_s1x4
.
template
operator
()
<
Cord
0x0
Type
>();
std
::
cout
<<
"Constexpr calculated [
0
,
0
] offset:"
<<
offset_
0x0
<<
std
::
endl
;
using
Cord
1x1
Type
=
ck
::
Tuple
<
ck
::
Number
<
1
>
,
ck
::
Number
<
1
>>
;
constexpr
ck
::
index_t
offset_
1x1
=
layout_4x8_s1x4
.
template
operator
()
<
Cord
1x1
Type
>();
std
::
cout
<<
"Constexpr calculated [
1
,
1
] offset:"
<<
offset_
1x1
<<
std
::
endl
;
// Basic descriptor 0, 1, 8, 9, 16, 17, ... 30, 31 (runtime descriptor)
// dims:4,(2,4) strides:2,(1,8)
...
...
example/64_tensor_transforms/tensor_transform_wrapper.hpp
View file @
f124c7a1
...
...
@@ -62,15 +62,18 @@ struct Layout
Number
<
Tuple
<
Ts
...
>::
Size
()
>
{});
}
// Generate LowerDims at compile time for MergeTransform using the passed Type
// If element of Tuple<Ts...> is also tuple, then merge (generate sequence for merge)
// If the element is not a tuple, then pass through (sequence with one element)
template
<
typename
Idx
,
typename
...
Ts
>
__host__
__device__
constexpr
static
auto
GenerateLowerDim
(
const
Tuple
<
Ts
...
>&
tuple
)
__host__
__device__
constexpr
static
auto
GenerateLowerDim
(
const
Tuple
<
Ts
...
>&
)
{
if
constexpr
(
Idx
::
value
==
0
)
{
if
constexpr
(
is_detected
<
is_tuple
,
tuple_element_t
<
Idx
::
value
,
Tuple
<
Ts
...
>>>::
value
)
{
constexpr
index_t
merge_nelems
=
decltype
(
UnrollNestedTuple
(
tuple
.
At
(
Idx
{}))
)
::
Size
();
constexpr
index_t
merge_nelems
=
decltype
(
UnrollNestedTuple
(
tuple_element_t
<
Idx
::
value
,
Tuple
<
Ts
...
>>
{}))
::
Size
();
using
LowerDimsSequence
=
typename
arithmetic_sequence_gen
<
0
,
merge_nelems
,
1
>::
type
;
return
LowerDimsSequence
::
Reverse
();
...
...
@@ -82,12 +85,12 @@ struct Layout
}
else
{
using
PreviousSeqT
=
decltype
(
GenerateLowerDim
<
Number
<
Idx
::
value
-
1
>>
(
t
uple
));
using
PreviousSeqT
=
decltype
(
GenerateLowerDim
<
Number
<
Idx
::
value
-
1
>>
(
T
uple
<
Ts
...
>
{}
));
const
auto
next_seq_val
=
PreviousSeqT
::
At
(
I0
)
+
1
;
if
constexpr
(
is_detected
<
is_tuple
,
tuple_element_t
<
Idx
::
value
,
Tuple
<
Ts
...
>>>::
value
)
{
constexpr
index_t
merge_nelems
=
decltype
(
UnrollNestedTuple
(
tuple
.
At
(
Idx
{}))
)
::
Size
();
constexpr
index_t
merge_nelems
=
decltype
(
UnrollNestedTuple
(
tuple_element_t
<
Idx
::
value
,
Tuple
<
Ts
...
>>
{}))
::
Size
();
using
LowerDimsSequence
=
typename
arithmetic_sequence_gen
<
next_seq_val
,
next_seq_val
+
merge_nelems
,
1
>::
type
;
...
...
@@ -100,11 +103,13 @@ struct Layout
}
}
// Iterate over nested tuples in shape
// Unroll nested tuples to align Tuple<ShapeDims...> to Tuple<IdxDims...>
template
<
typename
...
ShapeDims
,
typename
...
IdxDims
>
__host__
__device__
constexpr
static
auto
UnrollShapeViaIdx
(
const
Tuple
<
ShapeDims
...
>&
shape
,
const
Tuple
<
IdxDims
...
>&
idx
)
{
if
constexpr
(
!
Is
Tuple
Nested
(
Tuple
<
IdxDims
...
>
{}))
if
constexpr
(
!
IsNested
Tuple
(
Tuple
<
IdxDims
...
>
{}))
{
// Index unrolled to flatten, return shape
return
shape
;
...
...
@@ -112,7 +117,7 @@ struct Layout
else
{
// Iterate over shape tuple elements:
// 1. If core
s
sponding idx element is tuple then return (will be unrolled)
// 1. If cor
r
esponding idx element is tuple then return (will be unrolled)
// 2. If no, pack in tuple. It will be restored during unroll.
auto
unrolled_shape_via_idx
=
generate_tuple
(
[
&
](
auto
i
)
{
...
...
@@ -139,7 +144,7 @@ struct Layout
DescriptorToMerge
&
desc
)
{
// Reverse each element in tuple
using
ReversedUnrolledShape
=
decltype
(
Reverse
Tuple
(
UnrollNestedTuple
(
shape
)));
using
ReversedUnrolledShape
=
decltype
(
Tuple
Reverse
(
UnrollNestedTuple
(
shape
)));
const
auto
merge_elems
=
ReversedUnrolledShape
{};
// Generate reversed indexes (column-major traversal)
...
...
@@ -165,7 +170,7 @@ struct Layout
{
// If shape element is tuple and idx element is Number, then merge
// Unroll and reverse tuple to traverse column-major
const
auto
merge_elems
=
Reverse
Tuple
(
UnrollNestedTuple
(
shape
.
At
(
i
)));
const
auto
merge_elems
=
Tuple
Reverse
(
UnrollNestedTuple
(
shape
.
At
(
i
)));
return
make_merge_transform
(
merge_elems
);
}
else
...
...
include/ck/utility/tuple_helper.hpp
View file @
f124c7a1
...
...
@@ -132,7 +132,7 @@ __host__ __device__ constexpr auto UnrollNestedTuple(const Tuple<Ts...>& tuple)
}
template
<
typename
...
Ts
>
__host__
__device__
constexpr
auto
Reverse
Tuple
(
const
Tuple
<
Ts
...
>&
tuple
)
__host__
__device__
constexpr
auto
Tuple
Reverse
(
const
Tuple
<
Ts
...
>&
tuple
)
{
return
generate_tuple
(
[
&
](
auto
i
)
{
...
...
@@ -161,7 +161,7 @@ template <typename T>
using
is_tuple
=
decltype
(
std
::
declval
<
T
&>
().
IsTuple
());
template
<
typename
...
Ts
>
__host__
__device__
constexpr
auto
Is
Tuple
Nested
(
const
Tuple
<
Ts
...
>&
)
__host__
__device__
constexpr
auto
IsNested
Tuple
(
const
Tuple
<
Ts
...
>&
)
{
return
(
is_detected
<
is_tuple
,
Ts
>::
value
||
...);
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment