Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
e43359fe
Commit
e43359fe
authored
Nov 29, 2023
by
Bartlomiej Kocot
Browse files
Add comments and remove not needed getters
parent
f124c7a1
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
19 additions
and
65 deletions
+19
-65
example/64_tensor_transforms/tensor_transform_wrapper.hpp
example/64_tensor_transforms/tensor_transform_wrapper.hpp
+19
-65
No files found.
example/64_tensor_transforms/tensor_transform_wrapper.hpp
View file @
e43359fe
...
@@ -72,6 +72,7 @@ struct Layout
...
@@ -72,6 +72,7 @@ struct Layout
{
{
if
constexpr
(
is_detected
<
is_tuple
,
tuple_element_t
<
Idx
::
value
,
Tuple
<
Ts
...
>>>::
value
)
if
constexpr
(
is_detected
<
is_tuple
,
tuple_element_t
<
Idx
::
value
,
Tuple
<
Ts
...
>>>::
value
)
{
{
// Return Sequence for the first tuple
constexpr
index_t
merge_nelems
=
decltype
(
UnrollNestedTuple
(
constexpr
index_t
merge_nelems
=
decltype
(
UnrollNestedTuple
(
tuple_element_t
<
Idx
::
value
,
Tuple
<
Ts
...
>>
{}))
::
Size
();
tuple_element_t
<
Idx
::
value
,
Tuple
<
Ts
...
>>
{}))
::
Size
();
using
LowerDimsSequence
=
using
LowerDimsSequence
=
...
@@ -80,11 +81,13 @@ struct Layout
...
@@ -80,11 +81,13 @@ struct Layout
}
}
else
else
{
{
// Return first element
return
Sequence
<
0
>
{};
return
Sequence
<
0
>
{};
}
}
}
}
else
else
{
{
// Get previous element using recurence (in compile-time)
using
PreviousSeqT
=
decltype
(
GenerateLowerDim
<
Number
<
Idx
::
value
-
1
>>
(
Tuple
<
Ts
...
>
{}));
using
PreviousSeqT
=
decltype
(
GenerateLowerDim
<
Number
<
Idx
::
value
-
1
>>
(
Tuple
<
Ts
...
>
{}));
const
auto
next_seq_val
=
PreviousSeqT
::
At
(
I0
)
+
1
;
const
auto
next_seq_val
=
PreviousSeqT
::
At
(
I0
)
+
1
;
if
constexpr
(
is_detected
<
is_tuple
,
tuple_element_t
<
Idx
::
value
,
Tuple
<
Ts
...
>>>::
value
)
if
constexpr
(
is_detected
<
is_tuple
,
tuple_element_t
<
Idx
::
value
,
Tuple
<
Ts
...
>>>::
value
)
...
@@ -105,6 +108,9 @@ struct Layout
...
@@ -105,6 +108,9 @@ struct Layout
// Iterate over nested tuples in shape
// Iterate over nested tuples in shape
// Unroll nested tuples to align Tuple<ShapeDims...> to Tuple<IdxDims...>
// Unroll nested tuples to align Tuple<ShapeDims...> to Tuple<IdxDims...>
// Example idx: (1, 1), 1, 1
// Example shape: (2, (2, 2)), 2, (2, 2)
// Unrolled shape: 2, (2, 2), 2, (2, 2)
template
<
typename
...
ShapeDims
,
typename
...
IdxDims
>
template
<
typename
...
ShapeDims
,
typename
...
IdxDims
>
__host__
__device__
constexpr
static
auto
UnrollShapeViaIdx
(
const
Tuple
<
ShapeDims
...
>&
shape
,
__host__
__device__
constexpr
static
auto
UnrollShapeViaIdx
(
const
Tuple
<
ShapeDims
...
>&
shape
,
const
Tuple
<
IdxDims
...
>&
idx
)
const
Tuple
<
IdxDims
...
>&
idx
)
...
@@ -157,6 +163,11 @@ struct Layout
...
@@ -157,6 +163,11 @@ struct Layout
desc
,
make_tuple
(
make_merge_transform
(
merge_elems
)),
lower_dims
,
upper_dims
);
desc
,
make_tuple
(
make_merge_transform
(
merge_elems
)),
lower_dims
,
upper_dims
);
}
}
// Merge nested shape dims
// Input desc shape: 2, 2, 2, 2, 2, 2
// Example idx: 1, 1, 1, 1
// Example shape: 2, (2, 2), 2, (2, 2)
// Merged shape: 2, 4, 2, 4
template
<
typename
...
ShapeDims
,
typename
...
IdxDims
,
typename
DescriptorToMerge
>
template
<
typename
...
ShapeDims
,
typename
...
IdxDims
,
typename
DescriptorToMerge
>
__host__
__device__
constexpr
static
auto
__host__
__device__
constexpr
static
auto
MakeMerges
(
const
Tuple
<
ShapeDims
...
>&
shape
,
const
Tuple
<
IdxDims
...
>&
,
DescriptorToMerge
&
desc
)
MakeMerges
(
const
Tuple
<
ShapeDims
...
>&
shape
,
const
Tuple
<
IdxDims
...
>&
,
DescriptorToMerge
&
desc
)
...
@@ -206,6 +217,10 @@ struct Layout
...
@@ -206,6 +217,10 @@ struct Layout
}
}
else
else
{
{
// Merge nested shape dims
// Example idx: (1, 1), 1, 1
// Example shape: (2, (2, 2)), 2, (2, 2)
// Merged shape: (2, 4), 2, 4
static_assert
(
Tuple
<
ShapeDims
...
>::
Size
()
==
Tuple
<
IdxDims
...
>::
Size
(),
static_assert
(
Tuple
<
ShapeDims
...
>::
Size
()
==
Tuple
<
IdxDims
...
>::
Size
(),
"Idx rank and Shape rank must be the same (except 1d)."
);
"Idx rank and Shape rank must be the same (except 1d)."
);
// Unroll while IdxDims is nested
// Unroll while IdxDims is nested
...
@@ -268,20 +283,6 @@ struct Layout
...
@@ -268,20 +283,6 @@ struct Layout
}
}
}
}
/**
* \brief Returns real offset to element as const in runtime.
*
* \tparam Idxs Tuple of indexes.
* \return Calculated offset as const.
*/
template
<
typename
Idxs
>
__host__
__device__
constexpr
index_t
operator
()()
const
{
using
TransformedDesc
=
decltype
(
TransformDesc
(
Shape
{},
Idxs
{}));
using
UnrolledIdx
=
decltype
(
UnrollNestedTuple
(
Idxs
{}));
return
TransformedDesc
{}.
CalculateOffset
(
UnrolledIdx
{});
}
/**
/**
* \brief Returns real offset to element in runtime.
* \brief Returns real offset to element in runtime.
*
*
...
@@ -289,7 +290,7 @@ struct Layout
...
@@ -289,7 +290,7 @@ struct Layout
* \return Calculated offset.
* \return Calculated offset.
*/
*/
template
<
typename
Idxs
>
template
<
typename
Idxs
>
__host__
__device__
constexpr
index_t
operator
()()
__host__
__device__
constexpr
index_t
operator
()()
const
{
{
using
TransformedDesc
=
decltype
(
TransformDesc
(
Shape
{},
Idxs
{}));
using
TransformedDesc
=
decltype
(
TransformDesc
(
Shape
{},
Idxs
{}));
using
UnrolledIdx
=
decltype
(
UnrollNestedTuple
(
Idxs
{}));
using
UnrolledIdx
=
decltype
(
UnrollNestedTuple
(
Idxs
{}));
...
@@ -310,28 +311,6 @@ struct Layout
...
@@ -310,28 +311,6 @@ struct Layout
return
transformed_desc
.
CalculateOffset
(
UnrollNestedTuple
(
Idx
));
return
transformed_desc
.
CalculateOffset
(
UnrollNestedTuple
(
Idx
));
}
}
/**
* \brief Length getter (product if tuple) as const.
*
* \tparam IDim Tuple of indexes or index.
* \return Calculated size.
*/
template
<
index_t
IDim
>
__host__
__device__
constexpr
index_t
GetLength
()
const
{
const
auto
elem
=
shape_
.
At
(
Number
<
IDim
>
{});
if
constexpr
(
is_detected
<
is_tuple
,
tuple_element_t
<
IDim
,
Shape
>>::
value
)
{
const
auto
unrolled_element
=
UnrollNestedTuple
(
elem
);
return
TupleReduce
<
I0
.
value
,
unrolled_element
.
Size
()
>
(
[](
auto
x
,
auto
y
)
{
return
x
*
y
;
},
unrolled_element
);
}
else
{
return
elem
;
}
}
/**
/**
* \brief Length getter (product if tuple).
* \brief Length getter (product if tuple).
*
*
...
@@ -339,7 +318,7 @@ struct Layout
...
@@ -339,7 +318,7 @@ struct Layout
* \return Calculated size.
* \return Calculated size.
*/
*/
template
<
index_t
IDim
>
template
<
index_t
IDim
>
__host__
__device__
constexpr
index_t
GetLength
()
__host__
__device__
constexpr
index_t
GetLength
()
const
{
{
const
auto
elem
=
shape_
.
At
(
Number
<
IDim
>
{});
const
auto
elem
=
shape_
.
At
(
Number
<
IDim
>
{});
if
constexpr
(
is_detected
<
is_tuple
,
tuple_element_t
<
IDim
,
Shape
>>::
value
)
if
constexpr
(
is_detected
<
is_tuple
,
tuple_element_t
<
IDim
,
Shape
>>::
value
)
...
@@ -354,43 +333,18 @@ struct Layout
...
@@ -354,43 +333,18 @@ struct Layout
}
}
}
}
/**
* \brief Layout size getter (product of shape) as const.
*
* \return Calculated size.
*/
__host__
__device__
constexpr
index_t
GetLength
()
const
{
const
auto
unrolled_shape
=
UnrollNestedTuple
(
shape_
);
return
TupleReduce
<
I0
.
value
,
unrolled_shape
.
Size
()
>
([](
auto
x
,
auto
y
)
{
return
x
*
y
;
},
unrolled_shape
);
}
/**
/**
* \brief Layout size getter (product of shape).
* \brief Layout size getter (product of shape).
*
*
* \return Calculated size.
* \return Calculated size.
*/
*/
__host__
__device__
constexpr
index_t
GetLength
()
__host__
__device__
constexpr
index_t
GetLength
()
const
{
{
const
auto
unrolled_shape
=
UnrollNestedTuple
(
shape_
);
const
auto
unrolled_shape
=
UnrollNestedTuple
(
shape_
);
return
TupleReduce
<
I0
.
value
,
unrolled_shape
.
Size
()
>
([](
auto
x
,
auto
y
)
{
return
x
*
y
;
},
return
TupleReduce
<
I0
.
value
,
unrolled_shape
.
Size
()
>
([](
auto
x
,
auto
y
)
{
return
x
*
y
;
},
unrolled_shape
);
unrolled_shape
);
}
}
/**
* \brief Dimension getter as const.
*
* \tparam IDim Dimension idx.
* \return Calculated size.
*/
template
<
index_t
IDim
>
__host__
__device__
constexpr
auto
Get
()
const
{
const
auto
elem
=
shape_
.
At
(
Number
<
IDim
>
{});
return
elem
;
}
/**
/**
* \brief Dimension getter.
* \brief Dimension getter.
*
*
...
@@ -398,7 +352,7 @@ struct Layout
...
@@ -398,7 +352,7 @@ struct Layout
* \return Calculated size.
* \return Calculated size.
*/
*/
template
<
index_t
IDim
>
template
<
index_t
IDim
>
__host__
__device__
constexpr
auto
Get
()
__host__
__device__
constexpr
auto
Get
()
const
{
{
const
auto
elem
=
shape_
.
At
(
Number
<
IDim
>
{});
const
auto
elem
=
shape_
.
At
(
Number
<
IDim
>
{});
return
elem
;
return
elem
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment