Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
50f5ce49
"docs/source/en/api/models.mdx" did not exist on "5e6417e9887be8f02ab5b4f5c548dff7f3a4c8f6"
Commit
50f5ce49
authored
Sep 06, 2022
by
Po-Yen, Chen
Browse files
Distinguish input & output shape in 'DevicePermute'
parent
2377c2e8
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
21 additions
and
13 deletions
+21
-13
example/36_permute/run_permute_example.inc
example/36_permute/run_permute_example.inc
+6
-5
include/ck/tensor_operation/gpu/device/device_permute.hpp
include/ck/tensor_operation/gpu/device/device_permute.hpp
+15
-8
No files found.
example/36_permute/run_permute_example.inc
View file @
50f5ce49
...
...
@@ -21,21 +21,22 @@ bool run_permute(const ExecutionConfig& config, const Problem& problem)
a_device_buf
.
ToDevice
(
a
.
mData
.
data
());
std
::
array
<
ck
::
index_t
,
4
>
ab_lengths
;
std
::
array
<
ck
::
index_t
,
4
>
a
_lengths
,
b_lengths
;
std
::
array
<
ck
::
index_t
,
4
>
a_strides
,
b_strides
;
const
void
*
input
=
a_device_buf
.
GetDeviceBuffer
();
void
*
output
=
b_device_buf
.
GetDeviceBuffer
();
std
::
copy
(
begin
(
shape
),
end
(
shape
),
begin
(
ab_lengths
));
std
::
copy
(
begin
(
shape
),
end
(
shape
),
begin
(
a_lengths
));
std
::
copy
(
begin
(
transposed_shape
),
end
(
transposed_shape
),
begin
(
b_lengths
));
std
::
copy
(
begin
(
a
.
mDesc
.
GetStrides
()),
end
(
a
.
mDesc
.
GetStrides
()),
begin
(
a_strides
));
std
::
copy
(
begin
(
b
.
mDesc
.
GetStrides
()),
end
(
b
.
mDesc
.
GetStrides
()),
begin
(
b_strides
));
static_assert
(
std
::
is_default_constructible_v
<
DevicePermuteInstance
>
);
auto
permute
=
DevicePermuteInstance
{};
auto
argument
=
permute
.
MakeArgument
(
ab
_lengths
,
a_strides
,
b_strides
,
input
,
output
,
PassThrough
{});
auto
permute
=
DevicePermuteInstance
{};
auto
argument
=
permute
.
MakeArgument
(
a
_lengths
,
a_strides
,
b_lengths
,
b_strides
,
input
,
output
,
PassThrough
{});
if
(
!
permute
.
IsSupportedArgument
(
argument
))
{
...
...
include/ck/tensor_operation/gpu/device/device_permute.hpp
View file @
50f5ce49
...
...
@@ -161,16 +161,18 @@ struct DevicePermute : detail::DevicePermuteBase<DevicePermute<InDataType,
struct
Argument
:
public
BaseArgument
{
Argument
(
const
std
::
array
<
index_t
,
NumDim
>
l
engths
,
Argument
(
const
std
::
array
<
index_t
,
NumDim
>
inL
engths
,
const
std
::
array
<
index_t
,
NumDim
>
inStrides
,
const
std
::
array
<
index_t
,
NumDim
>
outLengths
,
const
std
::
array
<
index_t
,
NumDim
>
outStrides
,
const
void
*
in_dev_buffer
,
void
*
out_dev_buffer
,
ElementwiseOperation
elementwise_op
)
:
blockSize_
(
256
),
gridSize_
(
120
),
// FIXME - Calculate the grid size by number of CU in the future
l
engths_
(
l
engths
),
inL
engths_
(
inL
engths
),
inStridesArray_
({
inStrides
}),
outLengths_
(
outLengths
),
outStridesArray_
({
outStrides
}),
elementwise_op_
(
elementwise_op
)
{
...
...
@@ -189,11 +191,13 @@ struct DevicePermute : detail::DevicePermuteBase<DevicePermute<InDataType,
Number
<
NumOutput
>
{});
in_grid_1d_desc_tuple_
=
generate_tuple
(
[
&
](
auto
)
{
return
MakeDescriptor_M
(
l
engths
,
inStrides
,
gridSize_
,
blockSize_
);
},
[
&
](
auto
)
{
return
MakeDescriptor_M
(
inL
engths
,
inStrides
,
gridSize_
,
blockSize_
);
},
Number
<
NumInput
>
{});
out_grid_1d_desc_tuple_
=
generate_tuple
(
[
&
](
auto
)
{
return
MakeDescriptor_M
(
lengths
,
outStrides
,
gridSize_
,
blockSize_
);
},
[
&
](
auto
)
{
return
MakeDescriptor_M
(
outLengths
,
outStrides
,
gridSize_
,
blockSize_
);
},
Number
<
NumOutput
>
{});
}
...
...
@@ -205,8 +209,9 @@ struct DevicePermute : detail::DevicePermuteBase<DevicePermute<InDataType,
InGrid1dDescTuple
in_grid_1d_desc_tuple_
;
OutGrid1dDescTuple
out_grid_1d_desc_tuple_
;
std
::
array
<
index_t
,
NumDim
>
l
engths_
;
std
::
array
<
index_t
,
NumDim
>
inL
engths_
;
std
::
array
<
std
::
array
<
index_t
,
NumDim
>
,
NumInput
>
inStridesArray_
;
std
::
array
<
index_t
,
NumDim
>
outLengths_
;
std
::
array
<
std
::
array
<
index_t
,
NumDim
>
,
NumOutput
>
outStridesArray_
;
ElementwiseOperation
elementwise_op_
;
...
...
@@ -239,8 +244,10 @@ struct DevicePermute : detail::DevicePermuteBase<DevicePermute<InDataType,
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
{
if
(
arg
.
lengths_
.
back
()
%
MPerThread
!=
0
)
if
(
!
(
arg
.
inLengths_
.
back
()
%
MPerThread
==
0
&&
arg
.
outLengths_
.
back
()
%
MPerThread
==
0
))
{
return
false
;
}
auto
IsScalarPerVectorValid
=
[
&
](
const
std
::
array
<
index_t
,
NumDim
>&
lengths
,
const
std
::
array
<
index_t
,
NumDim
>&
strides
,
...
...
@@ -257,13 +264,13 @@ struct DevicePermute : detail::DevicePermuteBase<DevicePermute<InDataType,
bool
valid
=
true
;
static_for
<
0
,
NumInput
,
1
>
{}([
&
](
auto
I
)
{
if
(
!
IsScalarPerVectorValid
(
arg
.
l
engths_
,
arg
.
inStridesArray_
[
I
.
value
],
InScalarPerVectorSeq
::
At
(
I
)))
arg
.
inL
engths_
,
arg
.
inStridesArray_
[
I
.
value
],
InScalarPerVectorSeq
::
At
(
I
)))
valid
=
false
;
});
static_for
<
0
,
NumOutput
,
1
>
{}([
&
](
auto
I
)
{
if
(
!
IsScalarPerVectorValid
(
arg
.
l
engths_
,
arg
.
outStridesArray_
[
I
.
value
],
OutScalarPerVectorSeq
::
At
(
I
)))
arg
.
outL
engths_
,
arg
.
outStridesArray_
[
I
.
value
],
OutScalarPerVectorSeq
::
At
(
I
)))
valid
=
false
;
});
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment