Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
5b63400a
Commit
5b63400a
authored
Sep 06, 2022
by
Po-Yen, Chen
Browse files
Passing 'axes' to 'DevicePermute'
parent
50f5ce49
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
12 additions
and
11 deletions
+12
-11
example/36_permute/run_permute_example.inc
example/36_permute/run_permute_example.inc
+6
-5
include/ck/tensor_operation/gpu/device/device_permute.hpp
include/ck/tensor_operation/gpu/device/device_permute.hpp
+6
-6
No files found.
example/36_permute/run_permute_example.inc
View file @
5b63400a
...
@@ -21,22 +21,23 @@ bool run_permute(const ExecutionConfig& config, const Problem& problem)
...
@@ -21,22 +21,23 @@ bool run_permute(const ExecutionConfig& config, const Problem& problem)
a_device_buf
.
ToDevice
(
a
.
mData
.
data
());
a_device_buf
.
ToDevice
(
a
.
mData
.
data
());
std
::
array
<
ck
::
index_t
,
4
>
a_lengths
,
b_lengths
;
std
::
array
<
ck
::
index_t
,
4
>
a_lengths
;
std
::
array
<
ck
::
index_t
,
4
>
axes
;
std
::
array
<
ck
::
index_t
,
4
>
a_strides
,
b_strides
;
std
::
array
<
ck
::
index_t
,
4
>
a_strides
,
b_strides
;
const
void
*
input
=
a_device_buf
.
GetDeviceBuffer
();
const
void
*
input
=
a_device_buf
.
GetDeviceBuffer
();
void
*
output
=
b_device_buf
.
GetDeviceBuffer
();
void
*
output
=
b_device_buf
.
GetDeviceBuffer
();
std
::
copy
(
begin
(
shape
),
end
(
shape
),
begin
(
a_lengths
));
std
::
copy
(
begin
(
shape
),
end
(
shape
),
begin
(
a_lengths
));
std
::
copy
(
begin
(
transposed_shape
),
end
(
transposed_shape
),
begin
(
b_length
s
));
std
::
copy
(
begin
(
problem
.
axes
),
end
(
problem
.
axes
),
begin
(
axe
s
));
std
::
copy
(
begin
(
a
.
mDesc
.
GetStrides
()),
end
(
a
.
mDesc
.
GetStrides
()),
begin
(
a_strides
));
std
::
copy
(
begin
(
a
.
mDesc
.
GetStrides
()),
end
(
a
.
mDesc
.
GetStrides
()),
begin
(
a_strides
));
std
::
copy
(
begin
(
b
.
mDesc
.
GetStrides
()),
end
(
b
.
mDesc
.
GetStrides
()),
begin
(
b_strides
));
std
::
copy
(
begin
(
b
.
mDesc
.
GetStrides
()),
end
(
b
.
mDesc
.
GetStrides
()),
begin
(
b_strides
));
static_assert
(
std
::
is_default_constructible_v
<
DevicePermuteInstance
>
);
static_assert
(
std
::
is_default_constructible_v
<
DevicePermuteInstance
>
);
auto
permute
=
DevicePermuteInstance
{};
auto
permute
=
DevicePermuteInstance
{};
auto
argument
=
permute
.
MakeArgument
(
auto
argument
=
a_lengths
,
a_strides
,
b_lengths
,
b_strides
,
input
,
output
,
PassThrough
{});
permute
.
MakeArgument
(
a_lengths
,
axes
,
a_strides
,
b_strides
,
input
,
output
,
PassThrough
{});
if
(
!
permute
.
IsSupportedArgument
(
argument
))
if
(
!
permute
.
IsSupportedArgument
(
argument
))
{
{
...
...
include/ck/tensor_operation/gpu/device/device_permute.hpp
View file @
5b63400a
...
@@ -162,8 +162,8 @@ struct DevicePermute : detail::DevicePermuteBase<DevicePermute<InDataType,
...
@@ -162,8 +162,8 @@ struct DevicePermute : detail::DevicePermuteBase<DevicePermute<InDataType,
struct
Argument
:
public
BaseArgument
struct
Argument
:
public
BaseArgument
{
{
Argument
(
const
std
::
array
<
index_t
,
NumDim
>
inLengths
,
Argument
(
const
std
::
array
<
index_t
,
NumDim
>
inLengths
,
const
std
::
array
<
index_t
,
NumDim
>
axes
,
const
std
::
array
<
index_t
,
NumDim
>
inStrides
,
const
std
::
array
<
index_t
,
NumDim
>
inStrides
,
const
std
::
array
<
index_t
,
NumDim
>
outLengths
,
const
std
::
array
<
index_t
,
NumDim
>
outStrides
,
const
std
::
array
<
index_t
,
NumDim
>
outStrides
,
const
void
*
in_dev_buffer
,
const
void
*
in_dev_buffer
,
void
*
out_dev_buffer
,
void
*
out_dev_buffer
,
...
@@ -171,8 +171,8 @@ struct DevicePermute : detail::DevicePermuteBase<DevicePermute<InDataType,
...
@@ -171,8 +171,8 @@ struct DevicePermute : detail::DevicePermuteBase<DevicePermute<InDataType,
:
blockSize_
(
256
),
:
blockSize_
(
256
),
gridSize_
(
120
),
// FIXME - Calculate the grid size by number of CU in the future
gridSize_
(
120
),
// FIXME - Calculate the grid size by number of CU in the future
inLengths_
(
inLengths
),
inLengths_
(
inLengths
),
axes_
(
axes
),
inStridesArray_
({
inStrides
}),
inStridesArray_
({
inStrides
}),
outLengths_
(
outLengths
),
outStridesArray_
({
outStrides
}),
outStridesArray_
({
outStrides
}),
elementwise_op_
(
elementwise_op
)
elementwise_op_
(
elementwise_op
)
{
{
...
@@ -196,7 +196,7 @@ struct DevicePermute : detail::DevicePermuteBase<DevicePermute<InDataType,
...
@@ -196,7 +196,7 @@ struct DevicePermute : detail::DevicePermuteBase<DevicePermute<InDataType,
out_grid_1d_desc_tuple_
=
generate_tuple
(
out_grid_1d_desc_tuple_
=
generate_tuple
(
[
&
](
auto
)
{
[
&
](
auto
)
{
return
MakeDescriptor_M
(
out
Lengths
,
outStrides
,
gridSize_
,
blockSize_
);
return
MakeDescriptor_M
(
in
Lengths
,
outStrides
,
gridSize_
,
blockSize_
);
},
},
Number
<
NumOutput
>
{});
Number
<
NumOutput
>
{});
}
}
...
@@ -210,8 +210,8 @@ struct DevicePermute : detail::DevicePermuteBase<DevicePermute<InDataType,
...
@@ -210,8 +210,8 @@ struct DevicePermute : detail::DevicePermuteBase<DevicePermute<InDataType,
OutGrid1dDescTuple
out_grid_1d_desc_tuple_
;
OutGrid1dDescTuple
out_grid_1d_desc_tuple_
;
std
::
array
<
index_t
,
NumDim
>
inLengths_
;
std
::
array
<
index_t
,
NumDim
>
inLengths_
;
std
::
array
<
index_t
,
NumDim
>
axes_
;
std
::
array
<
std
::
array
<
index_t
,
NumDim
>
,
NumInput
>
inStridesArray_
;
std
::
array
<
std
::
array
<
index_t
,
NumDim
>
,
NumInput
>
inStridesArray_
;
std
::
array
<
index_t
,
NumDim
>
outLengths_
;
std
::
array
<
std
::
array
<
index_t
,
NumDim
>
,
NumOutput
>
outStridesArray_
;
std
::
array
<
std
::
array
<
index_t
,
NumDim
>
,
NumOutput
>
outStridesArray_
;
ElementwiseOperation
elementwise_op_
;
ElementwiseOperation
elementwise_op_
;
...
@@ -244,7 +244,7 @@ struct DevicePermute : detail::DevicePermuteBase<DevicePermute<InDataType,
...
@@ -244,7 +244,7 @@ struct DevicePermute : detail::DevicePermuteBase<DevicePermute<InDataType,
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
{
{
if
(
!
(
arg
.
inLengths_
.
back
()
%
MPerThread
=
=
0
&&
arg
.
outLengths_
.
back
()
%
MPerThread
==
0
)
)
if
(
arg
.
inLengths_
.
back
()
%
MPerThread
!
=
0
)
{
{
return
false
;
return
false
;
}
}
...
@@ -270,7 +270,7 @@ struct DevicePermute : detail::DevicePermuteBase<DevicePermute<InDataType,
...
@@ -270,7 +270,7 @@ struct DevicePermute : detail::DevicePermuteBase<DevicePermute<InDataType,
static_for
<
0
,
NumOutput
,
1
>
{}([
&
](
auto
I
)
{
static_for
<
0
,
NumOutput
,
1
>
{}([
&
](
auto
I
)
{
if
(
!
IsScalarPerVectorValid
(
if
(
!
IsScalarPerVectorValid
(
arg
.
out
Lengths_
,
arg
.
outStridesArray_
[
I
.
value
],
OutScalarPerVectorSeq
::
At
(
I
)))
arg
.
in
Lengths_
,
arg
.
outStridesArray_
[
I
.
value
],
OutScalarPerVectorSeq
::
At
(
I
)))
valid
=
false
;
valid
=
false
;
});
});
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment