Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
b56ddad3
"...composable_kernel-1.git" did not exist on "b491ebf38480bc0d6cb329ba6825dee610c59097"
Commit
b56ddad3
authored
Sep 15, 2022
by
Po-Yen, Chen
Browse files
Create new base type for 'DervicePermute' implementations
parent
b4e2b28c
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
125 additions
and
66 deletions
+125
-66
include/ck/tensor_operation/gpu/device/device_permute.hpp
include/ck/tensor_operation/gpu/device/device_permute.hpp
+21
-66
include/ck/tensor_operation/gpu/device/device_permute_base.hpp
...de/ck/tensor_operation/gpu/device/device_permute_base.hpp
+104
-0
No files found.
include/ck/tensor_operation/gpu/device/device_permute.hpp
View file @
b56ddad3
...
@@ -10,6 +10,7 @@
...
@@ -10,6 +10,7 @@
#include "ck/utility/math.hpp"
#include "ck/utility/math.hpp"
#include "ck/utility/sequence.hpp"
#include "ck/utility/sequence.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/tensor_operation/gpu/device/device_permute_base.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_permute.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_permute.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
...
@@ -20,55 +21,6 @@ namespace ck {
...
@@ -20,55 +21,6 @@ namespace ck {
namespace
tensor_operation
{
namespace
tensor_operation
{
namespace
device
{
namespace
device
{
namespace
detail
{
template
<
typename
Derived
>
struct
DevicePermuteBase
:
BaseOperator
{
bool
IsSupportedArgument
(
const
BaseArgument
*
arg
)
override
final
{
const
auto
*
argument
=
dynamic_cast
<
const
typename
Derived
::
Argument
*>
(
arg
);
if
(
!
argument
)
{
return
false
;
}
return
Derived
::
IsSupportedArgument
(
*
argument
);
}
template
<
typename
...
Args
>
static
auto
MakeArgument
(
Args
&&
...
args
)
{
return
typename
Derived
::
Argument
{
std
::
forward
<
Args
>
(
args
)...};
}
template
<
typename
...
Args
>
static
auto
MakeArgumentPointer
(
Args
&&
...
args
)
{
return
std
::
make_unique
<
typename
Derived
::
Argument
>
(
std
::
forward
<
Args
>
(
args
)...);
}
static
auto
MakeInvoker
()
{
return
typename
Derived
::
Invoker
{};
}
static
auto
MakeInvokerPointer
()
{
return
std
::
make_unique
<
typename
Derived
::
Invoker
>
();
};
};
template
<
typename
Derived
,
typename
Argument
>
struct
InvokerBase
:
BaseInvoker
{
float
Run
(
const
BaseArgument
*
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
override
final
{
const
auto
*
argument
=
dynamic_cast
<
const
Argument
*>
(
arg
);
if
(
!
argument
)
{
return
0.
f
;
}
return
Derived
::
Run
(
*
argument
,
stream_config
);
}
};
}
// namespace detail
// Swap last 2 dimensions
// Swap last 2 dimensions
// input shape: [d[0], d[1], d[2], ..., d[NumDim-3], d[NumDim-2], d[NumDim-1]]
// input shape: [d[0], d[1], d[2], ..., d[NumDim-3], d[NumDim-2], d[NumDim-1]]
// ^^^^^^^^^^^
// ^^^^^^^^^^^
...
@@ -89,22 +41,25 @@ template <typename InDataType,
...
@@ -89,22 +41,25 @@ template <typename InDataType,
index_t
DstVectorDim
,
index_t
DstVectorDim
,
index_t
SrcScalarPerVector
,
index_t
SrcScalarPerVector
,
index_t
DstScalarPerVector
>
index_t
DstScalarPerVector
>
struct
DevicePermute
struct
DevicePermute
:
DevicePermuteBaseCRTP
<
NumDim
,
:
detail
::
DevicePermuteBase
<
DevicePermute
<
InDataType
,
InDataType
,
OutDataType
,
OutDataType
,
ElementwiseOperation
,
ElementwiseOperation
,
NumDim
,
DevicePermute
<
InDataType
,
BlockSize
,
OutDataType
,
NPerBlock
,
ElementwiseOperation
,
HPerBlock
,
NumDim
,
WPerBlock
,
BlockSize
,
InBlockLdsExtraW
,
NPerBlock
,
InBlockTransferThreadClusterLengths
,
HPerBlock
,
InBlockTransferThreadClusterArrangeOrder
,
WPerBlock
,
SrcVectorDim
,
InBlockLdsExtraW
,
DstVectorDim
,
InBlockTransferThreadClusterLengths
,
SrcScalarPerVector
,
InBlockTransferThreadClusterArrangeOrder
,
DstScalarPerVector
>>
SrcVectorDim
,
DstVectorDim
,
SrcScalarPerVector
,
DstScalarPerVector
>>
{
{
static_assert
(
3
<=
NumDim
,
"Only accept at least 3D dimension tensor"
);
static_assert
(
3
<=
NumDim
,
"Only accept at least 3D dimension tensor"
);
static_assert
((
NumDim
-
2
)
<=
SrcVectorDim
&&
SrcVectorDim
<
NumDim
);
static_assert
((
NumDim
-
2
)
<=
SrcVectorDim
&&
SrcVectorDim
<
NumDim
);
...
@@ -204,7 +159,7 @@ struct DevicePermute
...
@@ -204,7 +159,7 @@ struct DevicePermute
typename
GridwisePermute
::
DefaultBlock2TileMap
block_2_tile_map_
;
typename
GridwisePermute
::
DefaultBlock2TileMap
block_2_tile_map_
;
};
};
struct
Invoker
:
detail
::
Invoker
Base
<
Invoker
,
Argument
>
struct
Invoker
:
Base
Invoker
CRTP
<
Invoker
,
Argument
>
{
{
static
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
static
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
{
{
...
...
include/ck/tensor_operation/gpu/device/device_permute_base.hpp
0 → 100644
View file @
b56ddad3
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <array>
#include <cmath>
#include <memory>
#include <type_traits>
#include "ck/tensor_operation/gpu/device/device_base.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
template
<
index_t
NumDim
,
typename
InDataType
,
typename
OutDataType
,
typename
ElementwiseOperation
>
struct
DevicePermuteBase
:
BaseOperator
{
using
Lengths
=
std
::
array
<
index_t
,
NumDim
>
;
using
Strides
=
Lengths
;
virtual
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
Lengths
inLengths
,
const
Strides
inStrides
,
const
Lengths
outLengths
,
const
Strides
outStrides
,
const
void
*
in_dev_buffer
,
void
*
out_dev_buffer
,
ElementwiseOperation
elementwise_op
)
=
0
;
virtual
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
=
0
;
};
template
<
index_t
NumDim
,
typename
InDataType
,
typename
OutDataType
,
typename
ElementwiseOperation
,
typename
DerivedDeviceOperator
>
struct
DevicePermuteBaseCRTP
:
DevicePermuteBase
<
NumDim
,
InDataType
,
OutDataType
,
ElementwiseOperation
>
{
private:
using
BaseType
=
DevicePermuteBase
<
NumDim
,
InDataType
,
OutDataType
,
ElementwiseOperation
>
;
public:
// override methods inherited from 'BaseOperator'
bool
IsSupportedArgument
(
const
BaseArgument
*
arg
)
override
final
{
const
auto
*
const
argument
=
dynamic_cast
<
const
typename
DerivedDeviceOperator
::
Argument
*>
(
arg
);
if
(
!
argument
)
{
return
false
;
}
return
DerivedDeviceOperator
::
IsSupportedArgument
(
*
argument
);
}
// override methods inherited from 'DevicePermuteBase'
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
typename
BaseType
::
Lengths
inLengths
,
const
typename
BaseType
::
Strides
inStrides
,
const
typename
BaseType
::
Lengths
outLengths
,
const
typename
BaseType
::
Strides
outStrides
,
const
void
*
in_dev_buffer
,
void
*
out_dev_buffer
,
ElementwiseOperation
elementwise_op
)
override
final
{
return
std
::
make_unique
<
typename
DerivedDeviceOperator
::
Argument
>
(
inLengths
,
inStrides
,
outLengths
,
outStrides
,
in_dev_buffer
,
out_dev_buffer
,
elementwise_op
);
}
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
override
final
{
return
std
::
make_unique
<
typename
DerivedDeviceOperator
::
Invoker
>
();
};
// generate other utility methods
template
<
typename
...
Args
>
static
auto
MakeArgument
(
Args
&&
...
args
)
{
static_assert
(
std
::
is_constructible_v
<
typename
DerivedDeviceOperator
::
Argument
,
Args
...
>
);
return
typename
DerivedDeviceOperator
::
Argument
{
std
::
forward
<
Args
>
(
args
)...};
}
static
auto
MakeInvoker
()
noexcept
(
std
::
is_nothrow_default_constructible_v
<
typename
DerivedDeviceOperator
::
Invoker
>
)
{
static_assert
(
std
::
is_default_constructible_v
<
typename
DerivedDeviceOperator
::
Invoker
>
);
return
typename
DerivedDeviceOperator
::
Invoker
{};
}
};
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment