Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
632bfff0
Commit
632bfff0
authored
Sep 12, 2022
by
Po-Yen, Chen
Browse files
Specify problem in each examples
parent
1989df75
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
38 additions
and
20 deletions
+38
-20
example/36_permute/common.hpp
example/36_permute/common.hpp
+24
-12
example/36_permute/permute_fp16.cpp
example/36_permute/permute_fp16.cpp
+7
-4
example/36_permute/run_permute_example.inc
example/36_permute/run_permute_example.inc
+7
-4
No files found.
example/36_permute/common.hpp
View file @
632bfff0
...
...
@@ -22,6 +22,7 @@
#include "ck/library/utility/host_tensor_generator.hpp"
using
F16
=
ck
::
half_t
;
using
F32
=
float
;
struct
ExecutionConfig
final
{
...
...
@@ -31,8 +32,18 @@ struct ExecutionConfig final
struct
Problem
final
{
std
::
array
<
std
::
size_t
,
4
>
shape
=
{
4
,
8
,
16
,
32
};
std
::
array
<
std
::
size_t
,
4
>
axes
=
{
0
,
1
,
3
,
2
};
using
Shape
=
std
::
array
<
std
::
size_t
,
3
>
;
using
Axes
=
Shape
;
Problem
()
=
delete
;
explicit
Problem
(
const
Shape
&
default_shape
,
const
Axes
&
default_axes
)
:
shape
(
default_shape
),
axes
(
default_axes
)
{
}
Shape
shape
;
Axes
axes
;
};
template
<
ck
::
index_t
...
Is
>
...
...
@@ -207,7 +218,7 @@ is_valid_axes(const Axes& axes)
inline
bool
parse_cmd_args
(
int
argc
,
char
*
argv
[],
ExecutionConfig
&
config
,
Problem
&
problem
)
{
constexpr
int
num_execution_config_args
=
2
;
constexpr
int
num_problem_args
=
8
;
constexpr
int
num_problem_args
=
3
+
3
;
if
(
!
(
num_problem_args
==
size
(
problem
.
shape
)
+
size
(
problem
.
axes
)))
{
...
...
@@ -231,13 +242,14 @@ inline bool parse_cmd_args(int argc, char* argv[], ExecutionConfig& config, Prob
// read shape
for
(
std
::
size_t
idx
=
0
;
idx
<
size
(
problem
.
shape
);
++
idx
)
{
problem
.
shape
[
idx
]
=
std
::
stoi
(
argv
[
idx
+
3
]);
problem
.
shape
[
idx
]
=
std
::
stoi
(
argv
[
idx
+
(
1
+
num_execution_config_args
)
]);
}
// read axes
for
(
std
::
size_t
idx
=
0
;
idx
<
size
(
problem
.
axes
);
++
idx
)
{
problem
.
axes
[
idx
]
=
std
::
stoi
(
argv
[
idx
+
size
(
problem
.
shape
)
+
3
]);
problem
.
axes
[
idx
]
=
std
::
stoi
(
argv
[
idx
+
(
1
+
num_execution_config_args
+
size
(
problem
.
shape
))]);
}
if
(
!
is_valid_axes
(
problem
.
axes
))
...
...
@@ -254,8 +266,8 @@ inline bool parse_cmd_args(int argc, char* argv[], ExecutionConfig& config, Prob
{
std
::
cerr
<<
"arg1: verification (0=no, 1=yes)"
<<
std
::
endl
<<
"arg2: time kernel (0=no, 1=yes)"
<<
std
::
endl
<<
"arg3 ~ arg
6
: shape for
4
D tensor"
<<
std
::
endl
<<
"arg
7
~ arg
10
: axes to permute"
<<
std
::
endl
;
<<
"arg3 ~ arg
5
: shape for
3
D tensor"
<<
std
::
endl
<<
"arg
6
~ arg
8
: axes to permute"
<<
std
::
endl
;
return
false
;
}
...
...
@@ -369,7 +381,7 @@ host_permute(const Tensor<Src>& src, const Axes& axes, Functor functor, Tensor<D
}
using
std
::
size
;
if
(
!
(
is_valid_axes
(
axes
)
&&
size
(
axes
)
==
4
))
if
(
!
(
is_valid_axes
(
axes
)
&&
size
(
axes
)
==
3
))
{
return
false
;
}
...
...
@@ -377,7 +389,7 @@ host_permute(const Tensor<Src>& src, const Axes& axes, Functor functor, Tensor<D
static_assert
(
detail
::
is_sized_range_v
<
ck
::
remove_cvref_t
<
decltype
(
shape
)
>>
&&
detail
::
is_sized_range_v
<
ck
::
remove_cvref_t
<
decltype
(
transposed_shape
)
>>
);
if
(
!
(
size
(
shape
)
==
4
&&
size
(
transposed_shape
)
==
4
))
if
(
!
(
size
(
shape
)
==
3
&&
size
(
transposed_shape
)
==
3
))
{
return
false
;
}
...
...
@@ -394,7 +406,7 @@ host_permute(const Tensor<Src>& src, const Axes& axes, Functor functor, Tensor<D
}
}
std
::
array
<
std
::
size_t
,
4
>
indices
{};
std
::
array
<
std
::
size_t
,
3
>
indices
{};
if
(
!
is_valid_indices
(
shape
,
indices
))
{
return
false
;
...
...
@@ -403,8 +415,8 @@ host_permute(const Tensor<Src>& src, const Axes& axes, Functor functor, Tensor<D
do
{
Dest
output
=
0
;
functor
(
output
,
src
(
indices
[
0
],
indices
[
1
],
indices
[
2
]
,
indices
[
3
]
));
dest
(
indices
[
axes
[
0
]],
indices
[
axes
[
1
]],
indices
[
axes
[
2
]]
,
indices
[
axes
[
3
]]
)
=
output
;
functor
(
output
,
src
(
indices
[
0
],
indices
[
1
],
indices
[
2
]));
dest
(
indices
[
axes
[
0
]],
indices
[
axes
[
1
]],
indices
[
axes
[
2
]])
=
output
;
}
while
(
advance_indices
(
shape
,
indices
));
return
true
;
...
...
example/36_permute/permute_fp16.cpp
View file @
632bfff0
...
...
@@ -3,8 +3,8 @@
#include "common.hpp"
using
ADataType
=
F
16
;
using
BDataType
=
F
16
;
using
ADataType
=
F
32
;
using
BDataType
=
F
32
;
// clang-format off
using
DevicePermuteInstance
=
ck
::
tensor_operation
::
device
::
DevicePermute
...
...
@@ -12,9 +12,12 @@ using DevicePermuteInstance = ck::tensor_operation::device::DevicePermute
// ######| Type| Type| Operation| | Size| Block| Block| LdsExtraW| ThreadClusterLengths| ThreadClusterArrangeOrder| VectorDim| VectorDim| ScalarPerVector| ScalarPerVector|
// ######| | | | | | | | | | | | | | |
// ######| | | | | | | | | | | | | | |
<
ADataType
,
BDataType
,
PassThrough
,
4
,
256
,
128
,
128
,
0
,
S
<
1
,
16
,
16
>
,
S
<
0
,
1
,
2
>
,
3
,
2
,
1
,
1
>
;
<
ADataType
,
BDataType
,
PassThrough
,
3
,
256
,
128
,
128
,
0
,
S
<
1
,
16
,
16
>
,
S
<
0
,
1
,
2
>
,
2
,
1
,
1
,
1
>
;
// clang-format on
#include "run_permute_example.inc"
int
main
(
int
argc
,
char
*
argv
[])
{
return
!
run_permute_example
(
argc
,
argv
);
}
int
main
(
int
argc
,
char
*
argv
[])
{
return
!
run_permute_example
(
argc
,
argv
,
{
1
,
16000
,
80
},
{
0
,
2
,
1
});
}
example/36_permute/run_permute_example.inc
View file @
632bfff0
...
...
@@ -21,8 +21,8 @@ bool run_permute(const ExecutionConfig& config, const Problem& problem)
a_device_buf
.
ToDevice
(
a
.
mData
.
data
());
std
::
array
<
ck
::
index_t
,
4
>
a_lengths
,
b_lengths
;
std
::
array
<
ck
::
index_t
,
4
>
a_strides
,
b_strides
;
std
::
array
<
ck
::
index_t
,
3
>
a_lengths
,
b_lengths
;
std
::
array
<
ck
::
index_t
,
3
>
a_strides
,
b_strides
;
const
void
*
input
=
a_device_buf
.
GetDeviceBuffer
();
void
*
output
=
b_device_buf
.
GetDeviceBuffer
();
...
...
@@ -64,10 +64,13 @@ bool run_permute(const ExecutionConfig& config, const Problem& problem)
return
true
;
}
bool
run_permute_example
(
int
argc
,
char
*
argv
[])
bool
run_permute_example
(
int
argc
,
char
*
argv
[],
const
Problem
::
Shape
&
default_shape
,
const
Problem
::
Axes
&
default_axes
)
{
ExecutionConfig
config
;
Problem
problem
;
Problem
problem
(
default_shape
,
default_axes
)
;
return
parse_cmd_args
(
argc
,
argv
,
config
,
problem
)
&&
run_permute
(
config
,
problem
);
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment