Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
32a2d78b
Commit
32a2d78b
authored
Sep 06, 2022
by
Po-Yen, Chen
Browse files
Use indirect base type to generate methods
parent
e53b50e8
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
48 additions
and
54 deletions
+48
-54
example/36_permute/common.hpp
example/36_permute/common.hpp
+6
-0
example/36_permute/run_permute_example.inc
example/36_permute/run_permute_example.inc
+7
-7
include/ck/tensor_operation/gpu/device/device_permute.hpp
include/ck/tensor_operation/gpu/device/device_permute.hpp
+35
-47
No files found.
example/36_permute/common.hpp
View file @
32a2d78b
...
...
@@ -207,6 +207,12 @@ inline constexpr bool is_device_op_v = is_device_op<T>::value;
}
// namespace detail
template
<
typename
Range
>
auto
front
(
Range
&&
range
)
->
decltype
(
std
::
forward
<
Range
>
(
range
).
front
())
{
return
std
::
forward
<
Range
>
(
range
).
front
();
}
template
<
typename
Axes
>
inline
std
::
enable_if_t
<
detail
::
is_random_access_range_v
<
Axes
>
,
bool
>
is_valid_axes
(
const
Axes
&
axes
)
...
...
example/36_permute/run_permute_example.inc
View file @
32a2d78b
...
...
@@ -21,22 +21,22 @@ bool run_permute(const ExecutionConfig& config, const Problem& problem)
a_device_buf
.
ToDevice
(
a
.
mData
.
data
());
std
::
array
<
ck
::
index_t
,
4
>
ab_lengths
;
std
::
array
<
std
::
array
<
ck
::
index_t
,
4
>
,
1
>
a_strides
,
b_strides
;
std
::
array
<
const
void
*
,
1
>
input
=
{
a_device_buf
.
GetDeviceBuffer
()};
std
::
array
<
void
*
,
1
>
output
=
{
b_device_buf
.
GetDeviceBuffer
()};
std
::
array
<
ck
::
index_t
,
4
>
ab_lengths
;
std
::
array
<
ck
::
index_t
,
4
>
a_strides
,
b_strides
;
std
::
copy
(
begin
(
shape
),
end
(
shape
),
begin
(
ab_lengths
));
std
::
copy
(
begin
(
a
.
mDesc
.
GetStrides
()),
end
(
a
.
mDesc
.
GetStrides
()),
begin
(
a_strides
));
std
::
copy
(
begin
(
b
.
mDesc
.
GetStrides
()),
end
(
b
.
mDesc
.
GetStrides
()),
begin
(
b_strides
));
std
::
copy
(
begin
(
a
.
mDesc
.
GetStrides
()),
end
(
a
.
mDesc
.
GetStrides
()),
begin
(
front
(
a_strides
))
)
;
std
::
copy
(
begin
(
b
.
mDesc
.
GetStrides
()),
end
(
b
.
mDesc
.
GetStrides
()),
begin
(
front
(
b_strides
))
)
;
static_assert
(
std
::
is_default_constructible_v
<
DevicePermuteInstance
>
);
static_assert
(
detail
::
is_device_op_v
<
DevicePermuteInstance
>
);
//
static_assert(detail::is_device_op_v<DevicePermuteInstance>);
auto
permute
=
DevicePermuteInstance
{};
auto
argument
=
permute
.
MakeArgument
(
ab_lengths
,
{
a_strides
}
,
{
b_strides
}
,
input
,
output
,
PassThrough
{});
permute
.
MakeArgument
(
ab_lengths
,
a_strides
,
b_strides
,
input
,
output
,
PassThrough
{});
if
(
!
permute
.
IsSupportedArgument
(
argument
))
{
...
...
include/ck/tensor_operation/gpu/device/device_permute.hpp
View file @
32a2d78b
...
...
@@ -18,10 +18,38 @@ namespace ck {
namespace
tensor_operation
{
namespace
device
{
namespace
detail
{
template
<
typename
Derived
>
struct
DevicePermuteBase
:
BaseOperator
{
bool
IsSupportedArgument
(
const
BaseArgument
*
arg
)
override
final
{
const
auto
*
argument
=
dynamic_cast
<
const
typename
Derived
::
Argument
*>
(
arg
);
if
(
!
argument
)
{
return
false
;
}
return
Derived
::
IsSupportedArgument
(
*
argument
);
}
template
<
typename
...
Args
>
static
auto
MakeArgument
(
Args
&&
...
args
)
{
return
typename
Derived
::
Argument
{
std
::
forward
<
Args
>
(
args
)...};
}
template
<
typename
...
Args
>
static
auto
MakeArgumentPointer
(
Args
&&
...
args
)
{
return
std
::
make_unique
<
typename
Derived
::
Argument
>
(
std
::
forward
<
Args
>
(
args
)...);
}
static
auto
MakeInvoker
()
{
return
typename
Derived
::
Invoker
{};
}
static
auto
MakeInvokerPointer
()
{
return
std
::
make_unique
<
typename
Derived
::
Invoker
>
();
};
};
}
// namespace detail
template
<
typename
InDataTypeTuple
,
typename
OutDataTypeTuple
,
...
...
@@ -30,13 +58,13 @@ template <typename InDataTypeTuple,
index_t
MPerThread
,
typename
InScalarPerVectorSeq
,
typename
OutScalarPerVectorSeq
>
struct
DevicePermute
:
DevicePermuteBase
<
DevicePermute
<
InDataTypeTuple
,
OutDataTypeTuple
,
ElementwiseOperation
,
NumDim
,
MPerThread
,
InScalarPerVectorSeq
,
OutScalarPerVectorSeq
>>
struct
DevicePermute
:
detail
::
DevicePermuteBase
<
DevicePermute
<
InDataTypeTuple
,
OutDataTypeTuple
,
ElementwiseOperation
,
NumDim
,
MPerThread
,
InScalarPerVectorSeq
,
OutScalarPerVectorSeq
>>
{
static
constexpr
int
NumInput
=
InDataTypeTuple
::
Size
();
static
constexpr
int
NumOutput
=
OutDataTypeTuple
::
Size
();
...
...
@@ -264,46 +292,6 @@ struct DevicePermute : DevicePermuteBase<DevicePermute<InDataTypeTuple,
return
valid
;
};
bool
IsSupportedArgument
(
const
BaseArgument
*
p_arg
)
override
{
return
IsSupportedArgument
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
));
}
static
auto
MakeArgument
(
const
std
::
array
<
index_t
,
NumDim
>
lengths
,
const
std
::
array
<
std
::
array
<
index_t
,
NumDim
>
,
NumInput
>
inStridesArray
,
const
std
::
array
<
std
::
array
<
index_t
,
NumDim
>
,
NumOutput
>
outStridesArray
,
const
std
::
array
<
const
void
*
,
NumInput
>
in_dev_buffers
,
const
std
::
array
<
void
*
,
NumOutput
>
out_dev_buffers
,
ElementwiseOperation
elementwise_op
)
{
return
Argument
{
lengths
,
inStridesArray
,
outStridesArray
,
in_dev_buffers
,
out_dev_buffers
,
elementwise_op
};
}
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
std
::
array
<
index_t
,
NumDim
>
lengths
,
const
std
::
array
<
std
::
array
<
index_t
,
NumDim
>
,
NumInput
>
inStridesArray
,
const
std
::
array
<
std
::
array
<
index_t
,
NumDim
>
,
NumOutput
>
outStridesArray
,
const
std
::
array
<
const
void
*
,
NumInput
>
in_dev_buffers
,
const
std
::
array
<
void
*
,
NumOutput
>
out_dev_buffers
,
ElementwiseOperation
elementwise_op
)
{
return
std
::
make_unique
<
Argument
>
(
lengths
,
inStridesArray
,
outStridesArray
,
in_dev_buffers
,
out_dev_buffers
,
elementwise_op
);
}
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
{
return
std
::
make_unique
<
Invoker
>
();
};
};
// namespace device
}
// namespace device
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment