Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
ee40f5a9
Commit
ee40f5a9
authored
Sep 15, 2022
by
Po-Yen, Chen
Browse files
Use type alias to reduce code
parent
f17fa4d7
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
72 additions
and
103 deletions
+72
-103
include/ck/tensor_operation/gpu/device/device_permute.hpp
include/ck/tensor_operation/gpu/device/device_permute.hpp
+4
-71
include/ck/tensor_operation/gpu/device/impl/device_permute_impl.hpp
.../tensor_operation/gpu/device/impl/device_permute_impl.hpp
+68
-32
No files found.
include/ck/tensor_operation/gpu/device/device_permute.hpp
View file @
ee40f5a9
...
@@ -21,10 +21,10 @@ struct DevicePermute : BaseOperator
...
@@ -21,10 +21,10 @@ struct DevicePermute : BaseOperator
using
Strides
=
Lengths
;
using
Strides
=
Lengths
;
virtual
std
::
unique_ptr
<
BaseArgument
>
virtual
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
Lengths
inLengths
,
MakeArgumentPointer
(
const
Lengths
&
inLengths
,
const
Strides
inStrides
,
const
Strides
&
inStrides
,
const
Lengths
outLengths
,
const
Lengths
&
outLengths
,
const
Strides
outStrides
,
const
Strides
&
outStrides
,
const
void
*
in_dev_buffer
,
const
void
*
in_dev_buffer
,
void
*
out_dev_buffer
,
void
*
out_dev_buffer
,
ElementwiseOperation
elementwise_op
)
=
0
;
ElementwiseOperation
elementwise_op
)
=
0
;
...
@@ -32,73 +32,6 @@ struct DevicePermute : BaseOperator
...
@@ -32,73 +32,6 @@ struct DevicePermute : BaseOperator
virtual
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
=
0
;
virtual
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
=
0
;
};
};
template
<
index_t
NumDim
,
typename
InDataType
,
typename
OutDataType
,
typename
ElementwiseOperation
,
typename
DerivedDeviceOperator
>
struct
DevicePermuteCRTP
:
DevicePermute
<
NumDim
,
InDataType
,
OutDataType
,
ElementwiseOperation
>
{
private:
using
BaseType
=
DevicePermute
<
NumDim
,
InDataType
,
OutDataType
,
ElementwiseOperation
>
;
public:
// override methods inherited from 'BaseOperator'
bool
IsSupportedArgument
(
const
BaseArgument
*
arg
)
override
final
{
const
auto
*
const
argument
=
dynamic_cast
<
const
typename
DerivedDeviceOperator
::
Argument
*>
(
arg
);
if
(
!
argument
)
{
return
false
;
}
return
DerivedDeviceOperator
::
IsSupportedArgument
(
*
argument
);
}
// override methods inherited from 'DevicePermute'
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
typename
BaseType
::
Lengths
inLengths
,
const
typename
BaseType
::
Strides
inStrides
,
const
typename
BaseType
::
Lengths
outLengths
,
const
typename
BaseType
::
Strides
outStrides
,
const
void
*
in_dev_buffer
,
void
*
out_dev_buffer
,
ElementwiseOperation
elementwise_op
)
override
final
{
return
std
::
make_unique
<
typename
DerivedDeviceOperator
::
Argument
>
(
inLengths
,
inStrides
,
outLengths
,
outStrides
,
in_dev_buffer
,
out_dev_buffer
,
elementwise_op
);
}
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
override
final
{
return
std
::
make_unique
<
typename
DerivedDeviceOperator
::
Invoker
>
();
};
// generate other utility methods
template
<
typename
...
Args
>
static
auto
MakeArgument
(
Args
&&
...
args
)
noexcept
(
std
::
is_nothrow_constructible_v
<
typename
DerivedDeviceOperator
::
Argument
,
Args
...
>
)
{
static_assert
(
std
::
is_constructible_v
<
typename
DerivedDeviceOperator
::
Argument
,
Args
...
>
);
return
typename
DerivedDeviceOperator
::
Argument
{
std
::
forward
<
Args
>
(
args
)...};
}
static
auto
MakeInvoker
()
noexcept
(
std
::
is_nothrow_default_constructible_v
<
typename
DerivedDeviceOperator
::
Invoker
>
)
{
static_assert
(
std
::
is_default_constructible_v
<
typename
DerivedDeviceOperator
::
Invoker
>
);
return
typename
DerivedDeviceOperator
::
Invoker
{};
}
};
}
// namespace device
}
// namespace device
}
// namespace tensor_operation
}
// namespace tensor_operation
}
// namespace ck
}
// namespace ck
include/ck/tensor_operation/gpu/device/impl/device_permute_impl.hpp
View file @
ee40f5a9
...
@@ -41,27 +41,12 @@ template <index_t NumDim,
...
@@ -41,27 +41,12 @@ template <index_t NumDim,
index_t
DstVectorDim
,
index_t
DstVectorDim
,
index_t
SrcScalarPerVector
,
index_t
SrcScalarPerVector
,
index_t
DstScalarPerVector
>
index_t
DstScalarPerVector
>
struct
DevicePermuteImpl
struct
DevicePermuteImpl
:
DevicePermute
<
NumDim
,
InDataType
,
OutDataType
,
ElementwiseOperation
>
:
DevicePermuteCRTP
<
NumDim
,
InDataType
,
OutDataType
,
ElementwiseOperation
,
DevicePermuteImpl
<
NumDim
,
InDataType
,
OutDataType
,
ElementwiseOperation
,
BlockSize
,
NPerBlock
,
HPerBlock
,
WPerBlock
,
InBlockLdsExtraW
,
InBlockTransferThreadClusterLengths
,
InBlockTransferThreadClusterArrangeOrder
,
SrcVectorDim
,
DstVectorDim
,
SrcScalarPerVector
,
DstScalarPerVector
>>
{
{
using
BaseType
=
DevicePermute
<
NumDim
,
InDataType
,
OutDataType
,
ElementwiseOperation
>
;
using
typename
BaseType
::
Lengths
;
using
typename
BaseType
::
Strides
;
static_assert
(
3
<=
NumDim
,
"Only accept at least 3D dimension tensor"
);
static_assert
(
3
<=
NumDim
,
"Only accept at least 3D dimension tensor"
);
static_assert
((
NumDim
-
2
)
<=
SrcVectorDim
&&
SrcVectorDim
<
NumDim
);
static_assert
((
NumDim
-
2
)
<=
SrcVectorDim
&&
SrcVectorDim
<
NumDim
);
static_assert
((
NumDim
-
2
)
<=
DstVectorDim
&&
DstVectorDim
<
NumDim
);
static_assert
((
NumDim
-
2
)
<=
DstVectorDim
&&
DstVectorDim
<
NumDim
);
...
@@ -75,8 +60,7 @@ struct DevicePermuteImpl
...
@@ -75,8 +60,7 @@ struct DevicePermuteImpl
return
generate_tuple
([
&
](
auto
I
)
{
return
array
[
I
];
},
Number
<
N
>
{});
return
generate_tuple
([
&
](
auto
I
)
{
return
array
[
I
];
},
Number
<
N
>
{});
}
}
static
auto
MakeDescriptor_N_H_W
(
const
std
::
array
<
index_t
,
NumDim
>&
lengths
,
static
auto
MakeDescriptor_N_H_W
(
const
Lengths
&
lengths
,
const
Strides
stride
)
const
std
::
array
<
index_t
,
NumDim
>&
stride
)
{
{
// create nd descriptor, shape: [d[0], d[1], d[2], ..., d[NumDim-3], d[NumDim-2],
// create nd descriptor, shape: [d[0], d[1], d[2], ..., d[NumDim-3], d[NumDim-2],
// d[NumDim-1]]
// d[NumDim-1]]
...
@@ -123,12 +107,14 @@ struct DevicePermuteImpl
...
@@ -123,12 +107,14 @@ struct DevicePermuteImpl
SrcScalarPerVector
,
SrcScalarPerVector
,
DstScalarPerVector
>
;
DstScalarPerVector
>
;
using
Block2TileMap
=
typename
GridwisePermute
::
DefaultBlock2TileMap
;
struct
Argument
:
public
BaseArgument
struct
Argument
:
public
BaseArgument
{
{
Argument
(
const
std
::
array
<
index_t
,
NumDim
>
inLengths
,
Argument
(
const
Lengths
&
inLengths
,
const
std
::
array
<
index_t
,
NumDim
>
inStrides
,
const
Strides
&
inStrides
,
const
std
::
array
<
index_t
,
NumDim
>
outLengths
,
const
Lengths
&
outLengths
,
const
std
::
array
<
index_t
,
NumDim
>
outStrides
,
const
Strides
&
outStrides
,
const
void
*
in_dev_buffer
,
const
void
*
in_dev_buffer
,
void
*
out_dev_buffer
,
void
*
out_dev_buffer
,
ElementwiseOperation
elementwise_op
)
ElementwiseOperation
elementwise_op
)
...
@@ -150,14 +136,14 @@ struct DevicePermuteImpl
...
@@ -150,14 +136,14 @@ struct DevicePermuteImpl
InGridDesc
in_grid_desc_
;
InGridDesc
in_grid_desc_
;
OutGridDesc
out_grid_desc_
;
OutGridDesc
out_grid_desc_
;
std
::
array
<
index_t
,
NumDim
>
inLengths_
;
Lengths
inLengths_
;
std
::
array
<
index_t
,
NumDim
>
inStrides_
;
Strides
inStrides_
;
std
::
array
<
index_t
,
NumDim
>
outLengths_
;
Lengths
outLengths_
;
std
::
array
<
index_t
,
NumDim
>
outStrides_
;
Strides
outStrides_
;
ElementwiseOperation
elementwise_op_
;
ElementwiseOperation
elementwise_op_
;
typename
GridwisePermute
::
Default
Block2TileMap
block_2_tile_map_
;
Block2TileMap
block_2_tile_map_
;
};
};
struct
Invoker
:
BaseInvokerCRTP
<
Invoker
,
Argument
>
struct
Invoker
:
BaseInvokerCRTP
<
Invoker
,
Argument
>
...
@@ -172,7 +158,7 @@ struct DevicePermuteImpl
...
@@ -172,7 +158,7 @@ struct DevicePermuteImpl
InDataType
,
InDataType
,
OutDataType
,
OutDataType
,
ElementwiseOperation
,
ElementwiseOperation
,
typename
GridwisePermute
::
Default
Block2TileMap
>
;
Block2TileMap
>
;
float
elapsed_time
=
launch_and_time_kernel
(
stream_config
,
float
elapsed_time
=
launch_and_time_kernel
(
stream_config
,
kernel
,
kernel
,
...
@@ -227,6 +213,56 @@ struct DevicePermuteImpl
...
@@ -227,6 +213,56 @@ struct DevicePermuteImpl
DstScalarPerVector
)
&&
DstScalarPerVector
)
&&
GridwisePermute
::
CheckValidity
(
arg
.
in_grid_desc_
,
arg
.
out_grid_desc_
);
GridwisePermute
::
CheckValidity
(
arg
.
in_grid_desc_
,
arg
.
out_grid_desc_
);
};
};
// override methods inherited from 'BaseOperator'
bool
IsSupportedArgument
(
const
BaseArgument
*
arg
)
override
final
{
const
auto
*
const
argument
=
dynamic_cast
<
const
Argument
*>
(
arg
);
if
(
!
argument
)
{
return
false
;
}
return
IsSupportedArgument
(
*
argument
);
}
// override methods inherited from 'DevicePermute'
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
Lengths
&
inLengths
,
const
Strides
&
inStrides
,
const
Lengths
&
outLengths
,
const
Strides
&
outStrides
,
const
void
*
in_dev_buffer
,
void
*
out_dev_buffer
,
ElementwiseOperation
elementwise_op
)
override
final
{
return
std
::
make_unique
<
Argument
>
(
inLengths
,
inStrides
,
outLengths
,
outStrides
,
in_dev_buffer
,
out_dev_buffer
,
elementwise_op
);
}
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
override
final
{
return
std
::
make_unique
<
Invoker
>
();
};
// other constructor methods
template
<
typename
...
Args
>
static
std
::
enable_if_t
<
std
::
is_constructible_v
<
Argument
,
Args
...
>
,
Argument
>
MakeArgument
(
Args
&&
...
args
)
noexcept
(
std
::
is_nothrow_constructible_v
<
Argument
,
Args
...
>
)
{
return
Argument
{
std
::
forward
<
Args
>
(
args
)...};
}
static
std
::
enable_if_t
<
std
::
is_default_constructible_v
<
Invoker
>
,
Invoker
>
MakeInvoker
()
noexcept
(
std
::
is_nothrow_default_constructible_v
<
Invoker
>
)
{
return
Invoker
{};
}
};
};
}
// namespace device
}
// namespace device
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment