Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
8b98d7d2
Commit
8b98d7d2
authored
Sep 12, 2022
by
Po-Yen, Chen
Browse files
Add support to permute multiple elements together
parent
5f50ed89
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
82 additions
and
22 deletions
+82
-22
example/36_permute/CMakeLists.txt
example/36_permute/CMakeLists.txt
+2
-2
example/36_permute/common.hpp
example/36_permute/common.hpp
+29
-10
example/36_permute/permute_HxWx2_fp16.cpp
example/36_permute/permute_HxWx2_fp16.cpp
+4
-7
example/36_permute/run_permute_example.inc
example/36_permute/run_permute_example.inc
+47
-3
No files found.
example/36_permute/CMakeLists.txt
View file @
8b98d7d2
...
@@ -2,8 +2,8 @@ add_custom_target(example_permute)
...
@@ -2,8 +2,8 @@ add_custom_target(example_permute)
add_example_executable
(
example_permute_1xHxW_fp32 permute_1xHxW_fp32.cpp
)
add_example_executable
(
example_permute_1xHxW_fp32 permute_1xHxW_fp32.cpp
)
add_example_executable
(
example_permute_NxHxW_fp32 permute_NxHxW_fp32.cpp
)
add_example_executable
(
example_permute_NxHxW_fp32 permute_NxHxW_fp32.cpp
)
add_example_executable
(
example_permute_HxWx
4
_fp16 permute_HxWx
4
_fp16.cpp
)
add_example_executable
(
example_permute_HxWx
2
_fp16 permute_HxWx
2
_fp16.cpp
)
add_dependencies
(
example_permute example_permute_1xHxW_fp32
)
add_dependencies
(
example_permute example_permute_1xHxW_fp32
)
add_dependencies
(
example_permute example_permute_NxHxW_fp32
)
add_dependencies
(
example_permute example_permute_NxHxW_fp32
)
add_dependencies
(
example_permute example_permute_HxWx
4
_fp16
)
add_dependencies
(
example_permute example_permute_HxWx
2
_fp16
)
example/36_permute/common.hpp
View file @
8b98d7d2
...
@@ -6,6 +6,7 @@
...
@@ -6,6 +6,7 @@
#include <cassert>
#include <cassert>
#include <cstddef>
#include <cstddef>
#include <cstdlib>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <iostream>
#include <iterator>
#include <iterator>
#include <numeric>
#include <numeric>
...
@@ -33,7 +34,9 @@ struct ExecutionConfig final
...
@@ -33,7 +34,9 @@ struct ExecutionConfig final
struct
Problem
final
struct
Problem
final
{
{
using
Shape
=
std
::
array
<
std
::
size_t
,
3
>
;
static
constexpr
std
::
size_t
NumDim
=
3
;
using
Shape
=
std
::
array
<
std
::
size_t
,
NumDim
>
;
using
Axes
=
Shape
;
using
Axes
=
Shape
;
Problem
()
=
delete
;
Problem
()
=
delete
;
...
@@ -249,7 +252,7 @@ is_valid_axes(const Axes& axes)
...
@@ -249,7 +252,7 @@ is_valid_axes(const Axes& axes)
inline
bool
parse_cmd_args
(
int
argc
,
char
*
argv
[],
ExecutionConfig
&
config
,
Problem
&
problem
)
inline
bool
parse_cmd_args
(
int
argc
,
char
*
argv
[],
ExecutionConfig
&
config
,
Problem
&
problem
)
{
{
constexpr
int
num_execution_config_args
=
2
;
constexpr
int
num_execution_config_args
=
2
;
constexpr
int
num_problem_args
=
3
+
3
;
constexpr
int
num_problem_args
=
2
*
Problem
::
NumDim
;
if
(
!
(
num_problem_args
==
size
(
problem
.
shape
)
+
size
(
problem
.
axes
)))
if
(
!
(
num_problem_args
==
size
(
problem
.
shape
)
+
size
(
problem
.
axes
)))
{
{
...
@@ -412,7 +415,7 @@ host_permute(const Tensor<Src>& src, const Axes& axes, Functor functor, Tensor<D
...
@@ -412,7 +415,7 @@ host_permute(const Tensor<Src>& src, const Axes& axes, Functor functor, Tensor<D
}
}
using
std
::
size
;
using
std
::
size
;
if
(
!
(
is_valid_axes
(
axes
)
&&
size
(
axes
)
==
3
)
)
if
(
!
is_valid_axes
(
axes
))
{
{
return
false
;
return
false
;
}
}
...
@@ -420,7 +423,7 @@ host_permute(const Tensor<Src>& src, const Axes& axes, Functor functor, Tensor<D
...
@@ -420,7 +423,7 @@ host_permute(const Tensor<Src>& src, const Axes& axes, Functor functor, Tensor<D
static_assert
(
detail
::
is_sized_range_v
<
ck
::
remove_cvref_t
<
decltype
(
shape
)
>>
&&
static_assert
(
detail
::
is_sized_range_v
<
ck
::
remove_cvref_t
<
decltype
(
shape
)
>>
&&
detail
::
is_sized_range_v
<
ck
::
remove_cvref_t
<
decltype
(
transposed_shape
)
>>
);
detail
::
is_sized_range_v
<
ck
::
remove_cvref_t
<
decltype
(
transposed_shape
)
>>
);
if
(
!
(
size
(
shape
)
=
=
3
&&
size
(
transposed_shape
)
==
3
)
)
if
(
size
(
shape
)
!
=
size
(
transposed_shape
))
{
{
return
false
;
return
false
;
}
}
...
@@ -437,18 +440,34 @@ host_permute(const Tensor<Src>& src, const Axes& axes, Functor functor, Tensor<D
...
@@ -437,18 +440,34 @@ host_permute(const Tensor<Src>& src, const Axes& axes, Functor functor, Tensor<D
}
}
}
}
std
::
array
<
std
::
size_t
,
3
>
indices
{}
;
std
::
vector
<
std
::
size_t
>
indices
(
size
(
shape
),
0
)
;
if
(
!
is_valid_indices
(
shape
,
indices
))
if
(
!
is_valid_indices
(
shape
,
indices
))
{
{
return
false
;
return
false
;
}
}
do
if
(
size
(
shape
)
==
3
)
{
do
{
Dest
output
=
0
;
functor
(
output
,
src
(
indices
[
0
],
indices
[
1
],
indices
[
2
]));
dest
(
indices
[
axes
[
0
]],
indices
[
axes
[
1
]],
indices
[
axes
[
2
]])
=
output
;
}
while
(
advance_indices
(
shape
,
indices
));
}
else
if
(
size
(
shape
)
==
4
)
{
do
{
Dest
output
=
0
;
functor
(
output
,
src
(
indices
[
0
],
indices
[
1
],
indices
[
2
],
indices
[
3
]));
dest
(
indices
[
axes
[
0
]],
indices
[
axes
[
1
]],
indices
[
axes
[
2
]],
indices
[
axes
[
3
]])
=
output
;
}
while
(
advance_indices
(
shape
,
indices
));
}
else
{
{
Dest
output
=
0
;
return
false
;
functor
(
output
,
src
(
indices
[
0
],
indices
[
1
],
indices
[
2
]));
}
dest
(
indices
[
axes
[
0
]],
indices
[
axes
[
1
]],
indices
[
axes
[
2
]])
=
output
;
}
while
(
advance_indices
(
shape
,
indices
));
return
true
;
return
true
;
}
}
example/36_permute/permute_HxWx
4
_fp16.cpp
→
example/36_permute/permute_HxWx
2
_fp16.cpp
View file @
8b98d7d2
...
@@ -3,8 +3,8 @@
...
@@ -3,8 +3,8 @@
#include "common.hpp"
#include "common.hpp"
using
ADataType
=
F
64
;
using
ADataType
=
F
32
;
using
BDataType
=
F
64
;
using
BDataType
=
F
32
;
// clang-format off
// clang-format off
using
DevicePermuteInstance
=
ck
::
tensor_operation
::
device
::
DevicePermute
using
DevicePermuteInstance
=
ck
::
tensor_operation
::
device
::
DevicePermute
...
@@ -15,10 +15,7 @@ using DevicePermuteInstance = ck::tensor_operation::device::DevicePermute
...
@@ -15,10 +15,7 @@ using DevicePermuteInstance = ck::tensor_operation::device::DevicePermute
<
ADataType
,
BDataType
,
PassThrough
,
3
,
256
,
128
,
128
,
0
,
S
<
1
,
16
,
16
>
,
S
<
0
,
1
,
2
>
,
2
,
1
,
1
,
1
>
;
<
ADataType
,
BDataType
,
PassThrough
,
3
,
256
,
128
,
128
,
0
,
S
<
1
,
16
,
16
>
,
S
<
0
,
1
,
2
>
,
2
,
1
,
1
,
1
>
;
// clang-format on
// clang-format on
#define NUM_ELEMS_IN_BUNDLE
4
#define NUM_ELEMS_IN_BUNDLE
2
#include "run_permute_example.inc"
#include "run_permute_example.inc"
int
main
(
int
argc
,
char
*
argv
[])
int
main
(
int
argc
,
char
*
argv
[])
{
return
!
run_permute_example
(
argc
,
argv
,
{
1
,
3
,
4
},
{
0
,
2
,
1
});
}
{
return
!
run_permute_example
(
argc
,
argv
,
{
1
,
160
,
80
},
{
0
,
2
,
1
});
}
example/36_permute/run_permute_example.inc
View file @
8b98d7d2
...
@@ -9,6 +9,10 @@
...
@@ -9,6 +9,10 @@
bool
run_permute
(
const
ExecutionConfig
&
config
,
const
Problem
&
problem
)
bool
run_permute
(
const
ExecutionConfig
&
config
,
const
Problem
&
problem
)
{
{
#if 1 < NUM_ELEMS_IN_BUNDLE
static_assert
(
std
::
is_same_v
<
ADataType
,
BDataType
>
);
#endif
using
std
::
begin
,
std
::
end
;
using
std
::
begin
,
std
::
end
;
const
auto
&
shape
=
problem
.
shape
;
const
auto
&
shape
=
problem
.
shape
;
...
@@ -61,13 +65,53 @@ bool run_permute(const ExecutionConfig& config, const Problem& problem)
...
@@ -61,13 +65,53 @@ bool run_permute(const ExecutionConfig& config, const Problem& problem)
if
(
config
.
do_verification
)
if
(
config
.
do_verification
)
{
{
Tensor
<
BDataType
>
host_b
(
transposed_shape
);
host_permute
(
a
,
problem
.
axes
,
PassThrough
{},
host_b
);
b_device_buf
.
FromDevice
(
data
(
b
.
mData
));
b_device_buf
.
FromDevice
(
data
(
b
.
mData
));
#if NUM_ELEMS_IN_BUNDLE == 1
Tensor
<
BDataType
>
host_b
(
transposed_shape
);
if
(
!
host_permute
(
a
,
problem
.
axes
,
PassThrough
{},
host_b
))
{
return
false
;
}
return
ck
::
utils
::
check_err
(
return
ck
::
utils
::
check_err
(
b
.
mData
,
host_b
.
mData
,
"Error: incorrect results in output tensor"
,
1
e
-
10
,
1
e
-
10
);
b
.
mData
,
host_b
.
mData
,
"Error: incorrect results in output tensor"
,
1
e
-
10
,
1
e
-
10
);
#else
// extend tensor shape from [N, H, W] to [N, H, W, NUM_ELEMS_IN_BUNDLE]
std
::
array
<
std
::
size_t
,
Problem
::
NumDim
+
1
>
extended_shape
;
std
::
copy
(
begin
(
shape
),
end
(
shape
),
begin
(
extended_shape
));
extended_shape
.
back
()
=
NUM_ELEMS_IN_BUNDLE
;
using
DataType
=
detail
::
get_bundled_t
<
ADataType
,
NUM_ELEMS_IN_BUNDLE
>
;
Tensor
<
DataType
>
extended_a
(
extended_shape
);
std
::
memcpy
(
data
(
extended_a
.
mData
),
data
(
a
.
mData
),
sizeof
(
ADataType
)
*
a
.
mDesc
.
GetElementSpaceSize
());
std
::
array
<
std
::
size_t
,
Problem
::
NumDim
+
1
>
extended_axes
;
std
::
copy
(
begin
(
problem
.
axes
),
end
(
problem
.
axes
),
begin
(
extended_axes
));
extended_axes
.
back
()
=
Problem
::
NumDim
;
std
::
array
<
std
::
size_t
,
Problem
::
NumDim
+
1
>
transposed_extended_shape
;
transpose_shape
(
extended_shape
,
extended_axes
,
begin
(
transposed_extended_shape
));
Tensor
<
DataType
>
extended_host_b
(
transposed_extended_shape
);
if
(
!
host_permute
(
extended_a
,
extended_axes
,
PassThrough
{},
extended_host_b
))
{
return
false
;
}
std
::
vector
<
DataType
>
extended_b
(
reinterpret_cast
<
DataType
*>
(
data
(
b
.
mData
)),
reinterpret_cast
<
DataType
*>
(
data
(
b
.
mData
))
+
b
.
mDesc
.
GetElementSpaceSize
()
*
NUM_ELEMS_IN_BUNDLE
);
return
ck
::
utils
::
check_err
(
extended_b
,
extended_host_b
.
mData
,
"Error: incorrect results in output tensor"
,
1
e
-
10
,
1
e
-
10
);
#endif
}
}
return
true
;
return
true
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment