Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
339e51d1
Commit
339e51d1
authored
Sep 06, 2022
by
Po-Yen, Chen
Browse files
Simplify 'DevicePermute' interface
parent
5ae42120
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
35 additions
and
46 deletions
+35
-46
example/36_permute/run_permute_example.inc
example/36_permute/run_permute_example.inc
+5
-5
include/ck/tensor_operation/gpu/device/device_permute.hpp
include/ck/tensor_operation/gpu/device/device_permute.hpp
+30
-41
No files found.
example/36_permute/run_permute_example.inc
View file @
339e51d1
...
@@ -22,14 +22,14 @@ bool run_permute(const ExecutionConfig& config, const Problem& problem)
...
@@ -22,14 +22,14 @@ bool run_permute(const ExecutionConfig& config, const Problem& problem)
a_device_buf
.
ToDevice
(
a
.
mData
.
data
());
a_device_buf
.
ToDevice
(
a
.
mData
.
data
());
std
::
array
<
ck
::
index_t
,
4
>
ab_lengths
;
std
::
array
<
ck
::
index_t
,
4
>
ab_lengths
;
std
::
array
<
std
::
array
<
ck
::
index_t
,
4
>
,
1
>
a_strides
,
b_strides
;
std
::
array
<
ck
::
index_t
,
4
>
a_strides
,
b_strides
;
std
::
array
<
const
void
*
,
1
>
input
=
{
a_device_buf
.
GetDeviceBuffer
()
}
;
const
void
*
input
=
a_device_buf
.
GetDeviceBuffer
();
std
::
array
<
void
*
,
1
>
output
=
{
b_device_buf
.
GetDeviceBuffer
()
}
;
void
*
output
=
b_device_buf
.
GetDeviceBuffer
();
std
::
copy
(
begin
(
shape
),
end
(
shape
),
begin
(
ab_lengths
));
std
::
copy
(
begin
(
shape
),
end
(
shape
),
begin
(
ab_lengths
));
std
::
copy
(
begin
(
a
.
mDesc
.
GetStrides
()),
end
(
a
.
mDesc
.
GetStrides
()),
begin
(
front
(
a_strides
))
)
;
std
::
copy
(
begin
(
a
.
mDesc
.
GetStrides
()),
end
(
a
.
mDesc
.
GetStrides
()),
begin
(
a_strides
));
std
::
copy
(
begin
(
b
.
mDesc
.
GetStrides
()),
end
(
b
.
mDesc
.
GetStrides
()),
begin
(
front
(
b_strides
))
)
;
std
::
copy
(
begin
(
b
.
mDesc
.
GetStrides
()),
end
(
b
.
mDesc
.
GetStrides
()),
begin
(
b_strides
));
static_assert
(
std
::
is_default_constructible_v
<
DevicePermuteInstance
>
);
static_assert
(
std
::
is_default_constructible_v
<
DevicePermuteInstance
>
);
...
...
include/ck/tensor_operation/gpu/device/device_permute.hpp
View file @
339e51d1
...
@@ -118,25 +118,20 @@ struct DevicePermute : detail::DevicePermuteBase<DevicePermute<InDataType,
...
@@ -118,25 +118,20 @@ struct DevicePermute : detail::DevicePermuteBase<DevicePermute<InDataType,
return
PadDescriptor_M_1d
(
desc
,
gridSize
,
blockSize
);
return
PadDescriptor_M_1d
(
desc
,
gridSize
,
blockSize
);
}
}
template
<
index_t
TupleSize
>
static
auto
GenerateInOutGrid1dDesc
()
static
auto
GenerateInOutGrid1dDescTuple
(
Number
<
TupleSize
>
)
{
{
return
generate_tuple
(
if
constexpr
(
NumDim
>
1
)
[
&
](
auto
)
{
{
if
constexpr
(
NumDim
>
1
)
return
MakeDescriptor_M
({
1
,
1
},
{
1
,
1
},
1
,
1
);
{
}
return
MakeDescriptor_M
({
1
,
1
},
{
1
,
1
},
1
,
1
);
else
}
{
else
return
MakeDescriptor_M
({
1
},
{
1
},
1
,
1
);
{
};
return
MakeDescriptor_M
({
1
},
{
1
},
1
,
1
);
};
},
Number
<
TupleSize
>
{});
};
};
using
InGrid1dDescTuple
=
decltype
(
GenerateInOutGrid1dDesc
Tuple
(
Number
<
NumInput
>
{}
));
using
InGrid1dDescTuple
=
Tuple
<
decltype
(
GenerateInOutGrid1dDesc
(
))
>
;
using
OutGrid1dDescTuple
=
decltype
(
GenerateInOutGrid1dDesc
Tuple
(
Number
<
NumOutput
>
{}
));
using
OutGrid1dDescTuple
=
Tuple
<
decltype
(
GenerateInOutGrid1dDesc
(
))
>
;
using
GridwiseElementwise
=
GridwiseElementwise_1D
<
InGrid1dDescTuple
,
using
GridwiseElementwise
=
GridwiseElementwise_1D
<
InGrid1dDescTuple
,
OutGrid1dDescTuple
,
OutGrid1dDescTuple
,
...
@@ -150,48 +145,44 @@ struct DevicePermute : detail::DevicePermuteBase<DevicePermute<InDataType,
...
@@ -150,48 +145,44 @@ struct DevicePermute : detail::DevicePermuteBase<DevicePermute<InDataType,
struct
Argument
:
public
BaseArgument
struct
Argument
:
public
BaseArgument
{
{
Argument
(
const
std
::
array
<
index_t
,
NumDim
>
lengths
,
Argument
(
const
std
::
array
<
index_t
,
NumDim
>
lengths
,
const
std
::
array
<
std
::
array
<
index_t
,
NumDim
>
,
NumInput
>
inStrides
Array
,
const
std
::
array
<
index_t
,
NumDim
>
inStrides
,
const
std
::
array
<
std
::
array
<
index_t
,
NumDim
>
,
NumOutput
>
outStrides
Array
,
const
std
::
array
<
index_t
,
NumDim
>
outStrides
,
const
std
::
array
<
const
void
*
,
NumInput
>
in_dev_buffer
s
,
const
void
*
in_dev_buffer
,
const
std
::
array
<
void
*
,
NumOutput
>
out_dev_buffer
s
,
void
*
out_dev_buffer
,
ElementwiseOperation
elementwise_op
)
ElementwiseOperation
elementwise_op
)
:
blockSize_
(
256
),
:
lengths_
(
lengths
),
gridSize_
(
120
),
// FIXME - Calculate the grid size by number of CU in the future
inStridesArray_
(
inStridesArray
),
lengths_
(
lengths
),
outStridesArray_
(
outStridesArray
),
inStridesArray_
({
inStrides
}),
elementwise_op_
(
elementwise_op
),
outStridesArray_
({
outStrides
}),
blockSize_
(
256
),
elementwise_op_
(
elementwise_op
)
gridSize_
(
120
)
// FIXME - Calculate the grid size by number of CU in the future
{
{
in_dev_buffers_
=
generate_tuple
(
in_dev_buffers_
=
generate_tuple
(
[
&
](
auto
I
)
{
[
&
](
auto
)
{
using
DataType
=
InDataType
;
using
DataType
=
InDataType
;
return
static_cast
<
const
DataType
*>
(
in_dev_buffer
s
[
I
.
value
]
);
return
static_cast
<
const
DataType
*>
(
in_dev_buffer
);
},
},
Number
<
NumInput
>
{});
Number
<
NumInput
>
{});
out_dev_buffers_
=
generate_tuple
(
out_dev_buffers_
=
generate_tuple
(
[
&
](
auto
I
)
{
[
&
](
auto
)
{
using
DataType
=
OutDataType
;
using
DataType
=
OutDataType
;
return
static_cast
<
DataType
*>
(
out_dev_buffer
s
[
I
.
value
]
);
return
static_cast
<
DataType
*>
(
out_dev_buffer
);
},
},
Number
<
NumOutput
>
{});
Number
<
NumOutput
>
{});
in_grid_1d_desc_tuple_
=
generate_tuple
(
in_grid_1d_desc_tuple_
=
generate_tuple
(
[
&
](
auto
I
)
{
[
&
](
auto
)
{
return
MakeDescriptor_M
(
lengths
,
inStrides
,
gridSize_
,
blockSize_
);
},
return
MakeDescriptor_M
(
lengths
,
inStridesArray
[
I
.
value
],
gridSize_
,
blockSize_
);
},
Number
<
NumInput
>
{});
Number
<
NumInput
>
{});
out_grid_1d_desc_tuple_
=
generate_tuple
(
out_grid_1d_desc_tuple_
=
generate_tuple
(
[
&
](
auto
I
)
{
[
&
](
auto
)
{
return
MakeDescriptor_M
(
lengths
,
outStrides
,
gridSize_
,
blockSize_
);
},
return
MakeDescriptor_M
(
lengths
,
outStridesArray
[
I
.
value
],
gridSize_
,
blockSize_
);
},
Number
<
NumOutput
>
{});
Number
<
NumOutput
>
{});
}
}
index_t
blockSize_
;
index_t
gridSize_
;
InDataTypePointerTuple
in_dev_buffers_
;
InDataTypePointerTuple
in_dev_buffers_
;
OutDataTypePointerTuple
out_dev_buffers_
;
OutDataTypePointerTuple
out_dev_buffers_
;
InGrid1dDescTuple
in_grid_1d_desc_tuple_
;
InGrid1dDescTuple
in_grid_1d_desc_tuple_
;
...
@@ -202,8 +193,6 @@ struct DevicePermute : detail::DevicePermuteBase<DevicePermute<InDataType,
...
@@ -202,8 +193,6 @@ struct DevicePermute : detail::DevicePermuteBase<DevicePermute<InDataType,
std
::
array
<
std
::
array
<
index_t
,
NumDim
>
,
NumOutput
>
outStridesArray_
;
std
::
array
<
std
::
array
<
index_t
,
NumDim
>
,
NumOutput
>
outStridesArray_
;
ElementwiseOperation
elementwise_op_
;
ElementwiseOperation
elementwise_op_
;
index_t
blockSize_
;
index_t
gridSize_
;
};
};
struct
Invoker
:
public
BaseInvoker
struct
Invoker
:
public
BaseInvoker
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment