Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
3e605990
Commit
3e605990
authored
Sep 12, 2022
by
Po-Yen, Chen
Browse files
Use helper functions to simplify example code
parent
a41132b5
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
53 additions
and
11 deletions
+53
-11
example/36_permute/common.hpp
example/36_permute/common.hpp
+47
-0
example/36_permute/run_permute_example.inc
example/36_permute/run_permute_example.inc
+6
-11
No files found.
example/36_permute/common.hpp
View file @
3e605990
...
...
@@ -87,6 +87,29 @@ struct get_bundled<F32, 2>
/// Shorthand for the nested `type` of `get_bundled<Bundle, Divisor>`.
template <class Bundle, std::size_t Divisor>
using get_bundled_t = typename get_bundled<Bundle, Divisor>::type;
/// Maps an array type to the same array type with `Difference` more elements.
/// Only `std::array` is supported; the primary template is left undefined so
/// any other argument is rejected at compile time.
template <class Array, std::size_t Difference>
struct enlarge_array_size;

/// `std::array<T, Size>` -> `std::array<T, Size + Difference>`.
template <class T, std::size_t Size, std::size_t Difference>
struct enlarge_array_size<std::array<T, Size>, Difference>
{
    using type = std::array<T, Size + Difference>;
};

/// Shorthand for `enlarge_array_size<Array, Difference>::type`.
template <class Array, std::size_t Difference>
using enlarge_array_size_t = typename enlarge_array_size<Array, Difference>::type;
/// Exposes the element count of a `std::array` type as a compile-time constant.
/// The primary template is left undefined, so non-`std::array` arguments are
/// rejected at compile time.
template <class Array>
struct get_array_size;

/// For `std::array<T, Size>` the constant is `Size`.
template <class T, std::size_t Size>
struct get_array_size<std::array<T, Size>> : std::integral_constant<std::size_t, Size>
{
};

/// Shorthand for `get_array_size<Array>::value`.
template <class Array>
inline constexpr std::size_t get_array_size_v = get_array_size<Array>::value;
template
<
typename
T
,
typename
=
void
>
struct
is_iterator
:
std
::
false_type
{
...
...
@@ -371,6 +394,30 @@ transpose_shape(const Shape& shape, const Axes& axes, OutputIterator iter)
return
iter
;
}
/// Returns a copy of `shape` with one additional trailing element set to
/// `new_dim`.
///
/// The result type is `Problem::Shape` enlarged by one element
/// (`detail::enlarge_array_size_t<Problem::Shape, 1>`).
///
/// Declared `inline` because this is a non-template function defined in a
/// header; without it, including the header from more than one translation
/// unit would produce multiple definitions.
///
/// @param shape   Source shape; all of its elements are copied in order.
/// @param new_dim Value stored in the last element of the result.
inline auto extend_shape(const Problem::Shape& shape, std::size_t new_dim)
{
    detail::enlarge_array_size_t<Problem::Shape, 1> extended_shape;

    using std::begin, std::end;

    // Copy the existing elements to the front; the one remaining slot at the
    // back is filled right after, so every element ends up initialized.
    std::copy(begin(shape), end(shape), begin(extended_shape));
    extended_shape.back() = new_dim;

    return extended_shape;
}
/// Returns a copy of `axes` with one additional trailing element equal to the
/// number of elements in `Problem::Axes`.
///
/// The appended value is the index of the newly added last position, so that
/// position maps to itself in the extended array.
/// The result type is `detail::enlarge_array_size_t<Problem::Axes, 1>`.
///
/// Declared `inline` because this is a non-template function defined in a
/// header; without it, including the header from more than one translation
/// unit would produce multiple definitions.
///
/// @param axes Source axes; all of its elements are copied in order.
inline auto extend_axes(const Problem::Axes& axes)
{
    detail::enlarge_array_size_t<Problem::Axes, 1> extended_axes;

    using std::begin, std::end;

    // Copy the existing elements to the front; the one remaining slot at the
    // back is filled right after, so every element ends up initialized.
    std::copy(begin(axes), end(axes), begin(extended_axes));
    extended_axes.back() = detail::get_array_size_v<Problem::Axes>;

    return extended_axes;
}
template
<
typename
Shape
,
typename
Indices
>
std
::
enable_if_t
<
detail
::
is_bidirectional_range_v
<
Shape
>
&&
detail
::
is_sized_range_v
<
Shape
>
&&
detail
::
is_bidirectional_range_v
<
Indices
>
&&
detail
::
is_sized_range_v
<
Indices
>
,
...
...
example/36_permute/run_permute_example.inc
View file @
3e605990
...
...
@@ -78,24 +78,19 @@ bool run_permute(const ExecutionConfig& config, const Problem& problem)
b
.
mData
,
host_b
.
mData
,
"Error: incorrect results in output tensor"
,
1
e
-
10
,
1
e
-
10
);
#else
// extend tensor shape from [N, H, W] to [N, H, W, NUM_ELEMS_IN_BUNDLE]
std
::
array
<
std
::
size_t
,
Problem
::
NumDim
+
1
>
extended_shape
;
std
::
copy
(
begin
(
shape
),
end
(
shape
),
begin
(
extended_shape
));
extended_shape
.
back
()
=
NUM_ELEMS_IN_BUNDLE
;
using
DataType
=
detail
::
get_bundled_t
<
ADataType
,
NUM_ELEMS_IN_BUNDLE
>
;
const
auto
extended_shape
=
extend_shape
(
shape
,
NUM_ELEMS_IN_BUNDLE
);
const
auto
extended_axes
=
extend_axes
(
problem
.
axes
);
ck
::
remove_cvref_t
<
decltype
(
extended_shape
)
>
transposed_extended_shape
;
transpose_shape
(
extended_shape
,
extended_axes
,
begin
(
transposed_extended_shape
));
Tensor
<
DataType
>
extended_a
(
extended_shape
);
std
::
memcpy
(
data
(
extended_a
.
mData
),
data
(
a
.
mData
),
sizeof
(
ADataType
)
*
a
.
mDesc
.
GetElementSpaceSize
());
std
::
array
<
std
::
size_t
,
Problem
::
NumDim
+
1
>
extended_axes
;
std
::
copy
(
begin
(
problem
.
axes
),
end
(
problem
.
axes
),
begin
(
extended_axes
));
extended_axes
.
back
()
=
Problem
::
NumDim
;
std
::
array
<
std
::
size_t
,
Problem
::
NumDim
+
1
>
transposed_extended_shape
;
transpose_shape
(
extended_shape
,
extended_axes
,
begin
(
transposed_extended_shape
));
Tensor
<
DataType
>
extended_host_b
(
transposed_extended_shape
);
if
(
!
host_permute
(
extended_a
,
extended_axes
,
PassThrough
{},
extended_host_b
))
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment