Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
2e5d4f91
Commit
2e5d4f91
authored
Sep 07, 2022
by
Po-Yen, Chen
Browse files
Check if input/output shape meet the requirement
parent
b41e6019
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
26 additions
and
12 deletions
+26
-12
example/36_permute/common.hpp
example/36_permute/common.hpp
+2
-2
example/36_permute/run_permute_example.inc
example/36_permute/run_permute_example.inc
+5
-6
include/ck/tensor_operation/gpu/device/device_permute.hpp
include/ck/tensor_operation/gpu/device/device_permute.hpp
+19
-4
No files found.
example/36_permute/common.hpp
View file @
2e5d4f91
...
...
@@ -31,8 +31,8 @@ struct ExecutionConfig final
struct
Problem
final
{
std
::
array
<
std
::
size_t
,
4
>
shape
=
{
4
,
16
,
32
,
32
};
std
::
array
<
std
::
size_t
,
4
>
axes
=
{
0
,
2
,
3
,
1
};
std
::
array
<
std
::
size_t
,
4
>
shape
=
{
4
,
8
,
16
,
32
};
std
::
array
<
std
::
size_t
,
4
>
axes
=
{
0
,
1
,
3
,
2
};
};
template
<
ck
::
index_t
...
Is
>
...
...
example/36_permute/run_permute_example.inc
View file @
2e5d4f91
...
...
@@ -21,23 +21,22 @@ bool run_permute(const ExecutionConfig& config, const Problem& problem)
a_device_buf
.
ToDevice
(
a
.
mData
.
data
());
std
::
array
<
ck
::
index_t
,
4
>
a_lengths
;
std
::
array
<
ck
::
index_t
,
4
>
axes
;
std
::
array
<
ck
::
index_t
,
4
>
a_lengths
,
b_lengths
;
std
::
array
<
ck
::
index_t
,
4
>
a_strides
,
b_strides
;
const
void
*
input
=
a_device_buf
.
GetDeviceBuffer
();
void
*
output
=
b_device_buf
.
GetDeviceBuffer
();
std
::
copy
(
begin
(
shape
),
end
(
shape
),
begin
(
a_lengths
));
std
::
copy
(
begin
(
problem
.
axes
),
end
(
problem
.
axes
),
begin
(
axes
));
std
::
copy
(
begin
(
a
.
mDesc
.
GetStrides
()),
end
(
a
.
mDesc
.
GetStrides
()),
begin
(
a_strides
));
std
::
copy
(
begin
(
transposed_shape
),
end
(
transposed_shape
),
begin
(
b_lengths
));
std
::
copy
(
begin
(
b
.
mDesc
.
GetStrides
()),
end
(
b
.
mDesc
.
GetStrides
()),
begin
(
b_strides
));
static_assert
(
std
::
is_default_constructible_v
<
DevicePermuteInstance
>
);
auto
permute
=
DevicePermuteInstance
{};
auto
argument
=
permute
.
MakeArgument
(
a_lengths
,
axes
,
a_strides
,
b_strides
,
input
,
output
,
PassThrough
{});
auto
permute
=
DevicePermuteInstance
{};
auto
argument
=
permute
.
MakeArgument
(
a_lengths
,
a_strides
,
b_lengths
,
b_strides
,
input
,
output
,
PassThrough
{});
if
(
!
permute
.
IsSupportedArgument
(
argument
))
{
...
...
include/ck/tensor_operation/gpu/device/device_permute.hpp
View file @
2e5d4f91
...
...
@@ -68,6 +68,9 @@ struct InvokerBase : BaseInvoker
};
}
// namespace detail
// Swap last 2 dimensions
// input: [d0, d1, d2, ..., d, dn-2, dn-1]
// output: [d0, d1, d2, ..., d, dn-1, dn-2]
template
<
typename
InDataType
,
typename
OutDataType
,
typename
ElementwiseOperation
,
...
...
@@ -83,6 +86,8 @@ struct DevicePermute : detail::DevicePermuteBase<DevicePermute<InDataType,
InScalarPerVector
,
OutScalarPerVector
>>
{
static_assert
(
3
<=
NumDim
,
"Only accept at least 3D dimension tensor"
);
using
InDataTypePointer
=
const
InDataType
*
;
using
OutDataTypePointer
=
OutDataType
*
;
...
...
@@ -155,8 +160,8 @@ struct DevicePermute : detail::DevicePermuteBase<DevicePermute<InDataType,
struct
Argument
:
public
BaseArgument
{
Argument
(
const
std
::
array
<
index_t
,
NumDim
>
inLengths
,
const
std
::
array
<
index_t
,
NumDim
>
axes
,
const
std
::
array
<
index_t
,
NumDim
>
inStrides
,
const
std
::
array
<
index_t
,
NumDim
>
outLengths
,
const
std
::
array
<
index_t
,
NumDim
>
outStrides
,
const
void
*
in_dev_buffer
,
void
*
out_dev_buffer
,
...
...
@@ -168,8 +173,8 @@ struct DevicePermute : detail::DevicePermuteBase<DevicePermute<InDataType,
in_grid_1d_desc_
(
MakeDescriptor_M
(
inLengths
,
inStrides
,
gridSize_
,
blockSize_
)),
out_grid_1d_desc_
(
MakeDescriptor_M
(
inLengths
,
inStrides
,
gridSize_
,
blockSize_
)),
inLengths_
(
inLengths
),
axes_
(
axes
),
inStrides_
(
inStrides
),
outLengths_
(
outLengths
),
outStrides_
(
outStrides
),
elementwise_op_
(
elementwise_op
)
{
...
...
@@ -184,8 +189,8 @@ struct DevicePermute : detail::DevicePermuteBase<DevicePermute<InDataType,
OutGrid1dDesc
out_grid_1d_desc_
;
std
::
array
<
index_t
,
NumDim
>
inLengths_
;
std
::
array
<
index_t
,
NumDim
>
axes_
;
std
::
array
<
index_t
,
NumDim
>
inStrides_
;
std
::
array
<
index_t
,
NumDim
>
outLengths_
;
std
::
array
<
index_t
,
NumDim
>
outStrides_
;
ElementwiseOperation
elementwise_op_
;
...
...
@@ -223,6 +228,16 @@ struct DevicePermute : detail::DevicePermuteBase<DevicePermute<InDataType,
return
false
;
}
// check if only swap last 2 dimensions
if
(
!
(
std
::
equal
(
begin
(
arg
.
inLengths_
),
std
::
prev
(
end
(
arg
.
inLengths_
),
2
),
begin
(
arg
.
outLengths_
))
&&
std
::
tie
(
*
rbegin
(
arg
.
inLengths_
),
*
std
::
next
(
rbegin
(
arg
.
inLengths_
)))
==
std
::
tie
(
*
std
::
next
(
rbegin
(
arg
.
outLengths_
)),
*
rbegin
(
arg
.
outLengths_
))))
{
return
false
;
}
auto
IsScalarPerVectorValid
=
[
&
](
const
std
::
array
<
index_t
,
NumDim
>&
lengths
,
const
std
::
array
<
index_t
,
NumDim
>&
strides
,
index_t
scalarPerVector
)
{
...
...
@@ -241,7 +256,7 @@ struct DevicePermute : detail::DevicePermuteBase<DevicePermute<InDataType,
valid
=
false
;
}
if
(
!
IsScalarPerVectorValid
(
arg
.
in
Lengths_
,
arg
.
outStrides_
,
OutScalarPerVector
))
if
(
!
IsScalarPerVectorValid
(
arg
.
out
Lengths_
,
arg
.
outStrides_
,
OutScalarPerVector
))
{
valid
=
false
;
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment