Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
7ebb1cbf
Commit
7ebb1cbf
authored
Sep 06, 2022
by
Po-Yen, Chen
Browse files
Add checks in helper functions
parent
e1f959fd
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
66 additions
and
18 deletions
+66
-18
example/36_elementwise_permute/common.hpp
example/36_elementwise_permute/common.hpp
+65
-17
example/36_elementwise_permute/run_elementwise_permute_example.inc
...6_elementwise_permute/run_elementwise_permute_example.inc
+1
-1
No files found.
example/36_elementwise_permute/common.hpp
View file @
7ebb1cbf
...
...
@@ -203,7 +203,10 @@ inline bool parse_cmd_args(int argc, char* argv[], ExecutionConfig& config, Prob
constexpr
int
num_execution_config_args
=
2
;
constexpr
int
num_problem_args
=
8
;
assert
(
num_problem_args
==
size
(
problem
.
shape
)
+
size
(
problem
.
axes
));
if
(
!
(
num_problem_args
==
size
(
problem
.
shape
)
+
size
(
problem
.
axes
)))
{
return
false
;
}
if
(
argc
==
1
)
{
...
...
@@ -265,7 +268,10 @@ template <typename Shape, typename Indices>
inline
std
::
enable_if_t
<
detail
::
is_sized_range_v
<
Shape
>
&&
detail
::
is_sized_range_v
<
Indices
>
,
bool
>
is_valid_indices
(
const
Shape
&
shape
,
const
Indices
&
indices
)
{
assert
(
is_valid_shape
(
shape
));
if
(
!
is_valid_shape
(
shape
))
{
return
false
;
}
using
std
::
empty
;
if
(
empty
(
indices
))
...
...
@@ -320,9 +326,10 @@ std::enable_if_t<detail::is_bidirectional_range_v<Shape> && detail::is_sized_ran
advance_indices
(
const
Shape
&
shape
,
Indices
&
indices
)
{
using
std
::
size
;
assert
(
is_valid_shape
(
shape
));
assert
(
is_valid_indices
(
indices
));
assert
(
size
(
shape
)
==
size
(
indices
));
if
(
!
(
is_valid_shape
(
shape
)
&&
is_valid_indices
(
shape
,
indices
)
&&
size
(
shape
)
==
size
(
indices
)))
{
return
false
;
}
bool
carry
=
true
;
...
...
@@ -340,29 +347,70 @@ advance_indices(const Shape& shape, Indices& indices)
return
!
carry
;
}
template
<
typename
Src
,
typename
Functor
,
typename
Dest
>
std
::
enable_if_t
<
std
::
is_invocable_v
<
Functor
,
std
::
add_lvalue_reference_t
<
Dest
>
,
std
::
add_lvalue_reference_t
<
Src
>>>
host_elementwise_permute
(
const
Tensor
<
Src
>&
src
,
Functor
functor
,
Tensor
<
Dest
>&
dest
)
template
<
typename
Src
,
typename
Axes
,
typename
Functor
,
typename
Dest
>
std
::
enable_if_t
<
detail
::
is_random_access_range_v
<
Axes
>
&&
detail
::
is_sized_range_v
<
Axes
>
&&
std
::
is_invocable_v
<
Functor
,
std
::
add_lvalue_reference_t
<
Dest
>
,
std
::
add_lvalue_reference_t
<
Src
>>
,
bool
>
host_elementwise_permute
(
const
Tensor
<
Src
>&
src
,
const
Axes
&
axes
,
Functor
functor
,
Tensor
<
Dest
>&
dest
)
{
const
auto
&
shape
=
src
.
mDesc
.
GetLengths
();
const
auto
&
transposed_shape
=
dest
.
mDesc
.
GetLengths
();
assert
(
is_valid_shape
(
shape
)
&&
is_valid_shape
(
transposed_shape
));
std
::
copy
(
begin
(
shape
),
end
(
shape
),
std
::
ostream_iterator
<
std
::
size_t
>
(
std
::
cerr
,
" "
));
std
::
cerr
<<
std
::
endl
;
std
::
copy
(
begin
(
transposed_shape
),
end
(
transposed_shape
),
std
::
ostream_iterator
<
std
::
size_t
>
(
std
::
cerr
,
" "
));
std
::
cerr
<<
std
::
endl
;
if
(
!
(
is_valid_shape
(
shape
)
&&
is_valid_shape
(
transposed_shape
)))
{
return
false
;
}
using
std
::
size
;
if
(
!
(
is_valid_axes
(
axes
)
&&
size
(
axes
)
==
4
))
{
return
false
;
}
static_assert
(
detail
::
is_sized_range_v
<
ck
::
remove_cvref_t
<
decltype
(
shape
)
>>
&&
detail
::
is_sized_range_v
<
ck
::
remove_cvref_t
<
decltype
(
transposed_shape
)
>>
);
using
std
::
size
;
assert
(
size
(
shape
)
==
4
&&
size
(
transposed_shape
)
==
4
);
if
(
!
(
size
(
shape
)
==
4
&&
size
(
transposed_shape
)
==
4
))
{
return
false
;
}
static_assert
(
detail
::
is_random_access_range_v
<
ck
::
remove_cvref_t
<
decltype
(
shape
)
>>
&&
detail
::
is_random_access_range_v
<
ck
::
remove_cvref_t
<
decltype
(
transposed_shape
)
>>
);
{
for
(
std
::
size_t
idx
=
0
;
idx
<
size
(
shape
);
++
idx
)
{
if
(
transposed_shape
[
idx
]
!=
shape
[
axes
[
idx
]])
{
return
false
;
}
}
}
std
::
array
<
std
::
size_t
,
4
>
indices
{};
assert
(
is_valid_indices
(
indices
));
if
(
!
is_valid_indices
(
shape
,
indices
))
{
return
false
;
}
do
{
Dest
b_val
=
0
;
functor
(
b_val
,
src
(
indices
[
0
],
indices
[
1
],
indices
[
2
],
indices
[
3
]));
dest
(
indices
[
0
],
indices
[
2
],
indices
[
3
],
indices
[
1
])
=
b_val
;
Dest
output
=
0
;
functor
(
output
,
src
(
indices
[
0
],
indices
[
1
],
indices
[
2
],
indices
[
3
]));
dest
(
indices
[
axes
[
0
]
],
indices
[
axes
[
1
]
],
indices
[
axes
[
2
]
],
indices
[
axes
[
3
]])
=
output
;
}
while
(
advance_indices
(
shape
,
indices
));
return
true
;
}
example/36_elementwise_permute/run_elementwise_permute_example.inc
View file @
7ebb1cbf
...
...
@@ -49,7 +49,7 @@ bool run_elementwise_permute(const ExecutionConfig& config, const Problem& probl
if
(
config
.
do_verification
)
{
Tensor
<
BDataType
>
host_b
(
nhwc
);
host_elementwise_permute
(
a
,
PassThrough
{},
host_b
);
host_elementwise_permute
(
a
,
problem
.
axes
,
PassThrough
{},
host_b
);
b_device_buf
.
FromDevice
(
b
.
mData
.
data
());
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment