Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
734a12da
Commit
734a12da
authored
Sep 15, 2022
by
Po-Yen, Chen
Browse files
Rename 'DevicePermute' to 'DevicePermuteImpl'
parent
16b116a9
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
303 additions
and
303 deletions
+303
-303
example/37_permute/common.hpp
example/37_permute/common.hpp
+1
-1
example/37_permute/permute_1xHxW_fp16.cpp
example/37_permute/permute_1xHxW_fp16.cpp
+1
-1
example/37_permute/permute_HxWx4_fp16.cpp
example/37_permute/permute_HxWx4_fp16.cpp
+1
-1
example/37_permute/permute_NxHxW_fp16.cpp
example/37_permute/permute_NxHxW_fp16.cpp
+1
-1
include/ck/tensor_operation/gpu/device/device_permute.hpp
include/ck/tensor_operation/gpu/device/device_permute.hpp
+65
-195
include/ck/tensor_operation/gpu/device/device_permute_base.hpp
...de/ck/tensor_operation/gpu/device/device_permute_base.hpp
+0
-104
include/ck/tensor_operation/gpu/device/impl/device_permute_impl.hpp
.../tensor_operation/gpu/device/impl/device_permute_impl.hpp
+234
-0
No files found.
example/37_permute/common.hpp
View file @
734a12da
...
@@ -15,8 +15,8 @@
...
@@ -15,8 +15,8 @@
#include <utility>
#include <utility>
#include "ck/ck.hpp"
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_permute_impl.hpp"
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/device_permute.hpp"
#include "ck/utility/type.hpp"
#include "ck/utility/type.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/check_err.hpp"
...
...
example/37_permute/permute_1xHxW_fp16.cpp
View file @
734a12da
...
@@ -7,7 +7,7 @@ using InDataType = F16;
...
@@ -7,7 +7,7 @@ using InDataType = F16;
using
OutDataType
=
F16
;
using
OutDataType
=
F16
;
// clang-format off
// clang-format off
using
DevicePermuteInstance
=
ck
::
tensor_operation
::
device
::
DevicePermute
using
DevicePermuteInstance
=
ck
::
tensor_operation
::
device
::
DevicePermute
Impl
// ######| NumDim| InData| OutData| Elementwise| Block| NPer| HPer| WPer| InBlock| InBlockTransfer| InBlockTransfer| Src| Dst| Src| Dst|
// ######| NumDim| InData| OutData| Elementwise| Block| NPer| HPer| WPer| InBlock| InBlockTransfer| InBlockTransfer| Src| Dst| Src| Dst|
// ######| | Type| Type| Operation| Size| Block| Block| Block| LdsExtraW| ThreadClusterLengths| ThreadClusterArrangeOrder| VectorDim| VectorDim| ScalarPerVector| ScalarPerVector|
// ######| | Type| Type| Operation| Size| Block| Block| Block| LdsExtraW| ThreadClusterLengths| ThreadClusterArrangeOrder| VectorDim| VectorDim| ScalarPerVector| ScalarPerVector|
// ######| | | | | | | | | | | | | | | |
// ######| | | | | | | | | | | | | | | |
...
...
example/37_permute/permute_HxWx4_fp16.cpp
View file @
734a12da
...
@@ -9,7 +9,7 @@ using BundleType = F64;
...
@@ -9,7 +9,7 @@ using BundleType = F64;
static_assert
(
sizeof
(
BundleType
)
%
sizeof
(
DataType
)
==
0
);
static_assert
(
sizeof
(
BundleType
)
%
sizeof
(
DataType
)
==
0
);
// clang-format off
// clang-format off
using
DevicePermuteInstance
=
ck
::
tensor_operation
::
device
::
DevicePermute
using
DevicePermuteInstance
=
ck
::
tensor_operation
::
device
::
DevicePermute
Impl
// ######| NumDim| InData| OutData| Elementwise| Block| NPer| HPer| WPer| InBlock| InBlockTransfer| InBlockTransfer| Src| Dst| Src| Dst|
// ######| NumDim| InData| OutData| Elementwise| Block| NPer| HPer| WPer| InBlock| InBlockTransfer| InBlockTransfer| Src| Dst| Src| Dst|
// ######| | Type| Type| Operation| Size| Block| Block| Block| LdsExtraW| ThreadClusterLengths| ThreadClusterArrangeOrder| VectorDim| VectorDim| ScalarPerVector| ScalarPerVector|
// ######| | Type| Type| Operation| Size| Block| Block| Block| LdsExtraW| ThreadClusterLengths| ThreadClusterArrangeOrder| VectorDim| VectorDim| ScalarPerVector| ScalarPerVector|
// ######| | | | | | | | | | | | | | | |
// ######| | | | | | | | | | | | | | | |
...
...
example/37_permute/permute_NxHxW_fp16.cpp
View file @
734a12da
...
@@ -7,7 +7,7 @@ using InDataType = F16;
...
@@ -7,7 +7,7 @@ using InDataType = F16;
using
OutDataType
=
F16
;
using
OutDataType
=
F16
;
// clang-format off
// clang-format off
using
DevicePermuteInstance
=
ck
::
tensor_operation
::
device
::
DevicePermute
using
DevicePermuteInstance
=
ck
::
tensor_operation
::
device
::
DevicePermute
Impl
// ######| NumDim| InData| OutData| Elementwise| Block| NPer| HPer| WPer| InBlock| InBlockTransfer| InBlockTransfer| Src| Dst| Src| Dst|
// ######| NumDim| InData| OutData| Elementwise| Block| NPer| HPer| WPer| InBlock| InBlockTransfer| InBlockTransfer| Src| Dst| Src| Dst|
// ######| | Type| Type| Operation| Size| Block| Block| Block| LdsExtraW| ThreadClusterLengths| ThreadClusterArrangeOrder| VectorDim| VectorDim| ScalarPerVector| ScalarPerVector|
// ######| | Type| Type| Operation| Size| Block| Block| Block| LdsExtraW| ThreadClusterLengths| ThreadClusterArrangeOrder| VectorDim| VectorDim| ScalarPerVector| ScalarPerVector|
// ######| | | | | | | | | | | | | | | |
// ######| | | | | | | | | | | | | | | |
...
...
include/ck/tensor_operation/gpu/device/device_permute.hpp
View file @
734a12da
...
@@ -4,228 +4,98 @@
...
@@ -4,228 +4,98 @@
#pragma once
#pragma once
#include <array>
#include <array>
#include <cmath>
#include <memory>
#include <memory>
#include <
util
it
y
>
#include <
type_tra
it
s
>
#include "ck/utility/math.hpp"
#include "ck/utility/sequence.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/tensor_operation/gpu/device/device_permute_base.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_permute.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace
ck
{
namespace
ck
{
namespace
tensor_operation
{
namespace
tensor_operation
{
namespace
device
{
namespace
device
{
// Swap last 2 dimensions
template
<
index_t
NumDim
,
typename
InDataType
,
typename
OutDataType
,
typename
ElementwiseOperation
>
// input shape: [d[0], d[1], d[2], ..., d[NumDim-3], d[NumDim-2], d[NumDim-1]]
struct
DevicePermute
:
BaseOperator
// ^^^^^^^^^^^
{
// output shape: [d[0], d[1], d[2], ..., d[NumDim-3], d[NumDim-1], d[NumDim-2]]
using
Lengths
=
std
::
array
<
index_t
,
NumDim
>
;
// ^^^^^^^^^^^
using
Strides
=
Lengths
;
virtual
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
Lengths
inLengths
,
const
Strides
inStrides
,
const
Lengths
outLengths
,
const
Strides
outStrides
,
const
void
*
in_dev_buffer
,
void
*
out_dev_buffer
,
ElementwiseOperation
elementwise_op
)
=
0
;
virtual
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
=
0
;
};
template
<
index_t
NumDim
,
template
<
index_t
NumDim
,
typename
InDataType
,
typename
InDataType
,
typename
OutDataType
,
typename
OutDataType
,
typename
ElementwiseOperation
,
typename
ElementwiseOperation
,
index_t
BlockSize
,
typename
DerivedDeviceOperator
>
index_t
NPerBlock
,
struct
DevicePermuteCRTP
:
DevicePermute
<
NumDim
,
InDataType
,
OutDataType
,
ElementwiseOperation
>
index_t
HPerBlock
,
index_t
WPerBlock
,
index_t
InBlockLdsExtraW
,
typename
InBlockTransferThreadClusterLengths
,
typename
InBlockTransferThreadClusterArrangeOrder
,
index_t
SrcVectorDim
,
index_t
DstVectorDim
,
index_t
SrcScalarPerVector
,
index_t
DstScalarPerVector
>
struct
DevicePermute
:
DevicePermuteBaseCRTP
<
NumDim
,
InDataType
,
OutDataType
,
ElementwiseOperation
,
DevicePermute
<
NumDim
,
InDataType
,
OutDataType
,
ElementwiseOperation
,
BlockSize
,
NPerBlock
,
HPerBlock
,
WPerBlock
,
InBlockLdsExtraW
,
InBlockTransferThreadClusterLengths
,
InBlockTransferThreadClusterArrangeOrder
,
SrcVectorDim
,
DstVectorDim
,
SrcScalarPerVector
,
DstScalarPerVector
>>
{
{
static_assert
(
3
<=
NumDim
,
"Only accept at least 3D dimension tensor"
);
private:
static_assert
((
NumDim
-
2
)
<=
SrcVectorDim
&&
SrcVectorDim
<
NumDim
);
using
BaseType
=
DevicePermute
<
NumDim
,
InDataType
,
OutDataType
,
ElementwiseOperation
>
;
static_assert
((
NumDim
-
2
)
<=
DstVectorDim
&&
DstVectorDim
<
NumDim
);
static_assert
(
SrcVectorDim
!=
DstVectorDim
);
template
<
index_t
N
=
NumDim
>
public:
static
auto
ConvertArrayToTuple
(
const
std
::
array
<
index_t
,
NumDim
>&
array
)
// override methods inherited from 'BaseOperator'
bool
IsSupportedArgument
(
const
BaseArgument
*
arg
)
override
final
{
{
static_assert
(
1
<=
N
&&
N
<=
NumDim
);
const
auto
*
const
argument
=
dynamic_cast
<
const
typename
DerivedDeviceOperator
::
Argument
*>
(
arg
);
if
(
!
argument
)
{
return
false
;
}
return
generate_tuple
([
&
](
auto
I
)
{
return
array
[
I
];
},
Number
<
N
>
{}
);
return
DerivedDeviceOperator
::
IsSupportedArgument
(
*
argument
);
}
}
static
auto
MakeDescriptor_N_H_W
(
const
std
::
array
<
index_t
,
NumDim
>&
lengths
,
// override methods inherited from 'DevicePermute'
const
std
::
array
<
index_t
,
NumDim
>&
stride
)
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
typename
BaseType
::
Lengths
inLengths
,
const
typename
BaseType
::
Strides
inStrides
,
const
typename
BaseType
::
Lengths
outLengths
,
const
typename
BaseType
::
Strides
outStrides
,
const
void
*
in_dev_buffer
,
void
*
out_dev_buffer
,
ElementwiseOperation
elementwise_op
)
override
final
{
{
// create nd descriptor, shape: [d[0], d[1], d[2], ..., d[NumDim-3], d[NumDim-2],
return
std
::
make_unique
<
typename
DerivedDeviceOperator
::
Argument
>
(
inLengths
,
// d[NumDim-1]]
inStrides
,
const
auto
desc
=
outLengths
,
make_naive_tensor_descriptor
(
ConvertArrayToTuple
(
lengths
),
ConvertArrayToTuple
(
stride
));
outStrides
,
in_dev_buffer
,
// merge nd to 3d descriptor, shape: [(d[0] * d[1] * d[2] * ... * d[NumDim-3]), d[NumDim-2],
out_dev_buffer
,
// d[NumDim-1]]
elementwise_op
);
// => [N, H, W]
const
index_t
H
=
*
std
::
next
(
rbegin
(
lengths
));
const
index_t
W
=
*
rbegin
(
lengths
);
const
auto
desc_n_h_w
=
transform_tensor_descriptor
(
desc
,
make_tuple
(
make_merge_transform
(
ConvertArrayToTuple
<
NumDim
-
2
>
(
lengths
)),
make_pass_through_transform
(
H
),
make_pass_through_transform
(
W
)),
make_tuple
(
generate_sequence_v2
([
&
](
auto
I
)
{
return
I
;
},
Number
<
NumDim
-
2
>
{}),
Sequence
<
NumDim
-
2
>
{},
Sequence
<
NumDim
-
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}));
return
PadTensorDescriptor
(
desc_n_h_w
,
make_tuple
(
NPerBlock
,
HPerBlock
,
WPerBlock
),
Sequence
<
true
,
true
,
true
>
{});
}
}
using
InGridDesc
=
decltype
(
MakeDescriptor_N_H_W
({
1
,
1
},
{
1
,
1
}));
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
override
final
using
OutGridDesc
=
InGridDesc
;
using
GridwisePermute
=
GridwisePermute
<
InGridDesc
,
OutGridDesc
,
InDataType
,
OutDataType
,
ElementwiseOperation
,
BlockSize
,
NPerBlock
,
HPerBlock
,
WPerBlock
,
InBlockLdsExtraW
,
InBlockTransferThreadClusterLengths
,
InBlockTransferThreadClusterArrangeOrder
,
SrcVectorDim
-
(
NumDim
-
3
),
// calculate new SrcVectorDim for the merged descriptor
DstVectorDim
-
(
NumDim
-
3
),
// calculate new DstVectorDim for the merged descriptor
SrcScalarPerVector
,
DstScalarPerVector
>
;
struct
Argument
:
public
BaseArgument
{
{
Argument
(
const
std
::
array
<
index_t
,
NumDim
>
inLengths
,
return
std
::
make_unique
<
typename
DerivedDeviceOperator
::
Invoker
>
();
const
std
::
array
<
index_t
,
NumDim
>
inStrides
,
const
std
::
array
<
index_t
,
NumDim
>
outLengths
,
const
std
::
array
<
index_t
,
NumDim
>
outStrides
,
const
void
*
in_dev_buffer
,
void
*
out_dev_buffer
,
ElementwiseOperation
elementwise_op
)
:
in_dev_buffer_
(
static_cast
<
const
InDataType
*>
(
in_dev_buffer
)),
out_dev_buffer_
(
static_cast
<
OutDataType
*>
(
out_dev_buffer
)),
in_grid_desc_
(
MakeDescriptor_N_H_W
(
inLengths
,
inStrides
)),
out_grid_desc_
(
MakeDescriptor_N_H_W
(
outLengths
,
outStrides
)),
inLengths_
(
inLengths
),
inStrides_
(
inStrides
),
outLengths_
(
outLengths
),
outStrides_
(
outStrides
),
elementwise_op_
(
elementwise_op
),
block_2_tile_map_
(
GridwisePermute
::
MakeDefaultBlock2TileMap
(
in_grid_desc_
))
{
}
const
InDataType
*
in_dev_buffer_
;
OutDataType
*
out_dev_buffer_
;
InGridDesc
in_grid_desc_
;
OutGridDesc
out_grid_desc_
;
std
::
array
<
index_t
,
NumDim
>
inLengths_
;
std
::
array
<
index_t
,
NumDim
>
inStrides_
;
std
::
array
<
index_t
,
NumDim
>
outLengths_
;
std
::
array
<
index_t
,
NumDim
>
outStrides_
;
ElementwiseOperation
elementwise_op_
;
typename
GridwisePermute
::
DefaultBlock2TileMap
block_2_tile_map_
;
};
};
struct
Invoker
:
BaseInvokerCRTP
<
Invoker
,
Argument
>
// generate other utility methods
template
<
typename
...
Args
>
static
auto
MakeArgument
(
Args
&&
...
args
)
{
{
static
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
static_assert
(
std
::
is_constructible_v
<
typename
DerivedDeviceOperator
::
Argument
,
Args
...
>
);
{
const
index_t
grid_size
=
arg
.
block_2_tile_map_
.
CalculateGridSize
(
arg
.
in_grid_desc_
);
const
auto
kernel
=
kernel_nd_permute
<
GridwisePermute
,
InGridDesc
,
OutGridDesc
,
InDataType
,
OutDataType
,
ElementwiseOperation
,
typename
GridwisePermute
::
DefaultBlock2TileMap
>
;
float
elapsed_time
=
launch_and_time_kernel
(
stream_config
,
kernel
,
dim3
(
grid_size
),
dim3
(
BlockSize
),
0
,
arg
.
in_grid_desc_
,
arg
.
out_grid_desc_
,
arg
.
in_dev_buffer_
,
arg
.
out_dev_buffer_
,
arg
.
elementwise_op_
,
arg
.
block_2_tile_map_
);
return
elapsed_time
;
}
};
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
return
typename
DerivedDeviceOperator
::
Argument
{
std
::
forward
<
Args
>
(
args
)...};
}
static
auto
MakeInvoker
()
noexcept
(
std
::
is_nothrow_default_constructible_v
<
typename
DerivedDeviceOperator
::
Invoker
>
)
{
{
constexpr
auto
GetPaddedLength
=
[](
index_t
length
,
index_t
tile_length
)
{
static_assert
(
std
::
is_default_constructible_v
<
typename
DerivedDeviceOperator
::
Invoker
>
);
return
math
::
integer_divide_ceil
(
length
,
tile_length
)
*
tile_length
;
};
return
typename
DerivedDeviceOperator
::
Invoker
{};
}
constexpr
auto
IsScalarPerVectorValid
=
[](
index_t
length
,
index_t
stride
,
index_t
scalar_per_vector
)
{
if
(
stride
==
1
&&
length
%
scalar_per_vector
==
0
)
{
return
true
;
}
else
if
(
stride
!=
1
&&
scalar_per_vector
==
1
)
{
return
true
;
}
return
false
;
};
return
IsScalarPerVectorValid
(
arg
.
inLengths_
[
SrcVectorDim
],
arg
.
inStrides_
[
SrcVectorDim
],
SrcScalarPerVector
)
&&
IsScalarPerVectorValid
(
GetPaddedLength
(
arg
.
inLengths_
[
SrcVectorDim
],
(
SrcVectorDim
==
NumDim
-
2
?
HPerBlock
:
WPerBlock
)),
arg
.
inStrides_
[
SrcVectorDim
],
SrcScalarPerVector
)
&&
IsScalarPerVectorValid
(
arg
.
outLengths_
[
DstVectorDim
],
arg
.
outStrides_
[
DstVectorDim
],
DstScalarPerVector
)
&&
IsScalarPerVectorValid
(
GetPaddedLength
(
arg
.
outLengths_
[
DstVectorDim
],
(
DstVectorDim
==
NumDim
-
2
?
HPerBlock
:
WPerBlock
)),
arg
.
inStrides_
[
DstVectorDim
],
DstScalarPerVector
)
&&
GridwisePermute
::
CheckValidity
(
arg
.
in_grid_desc_
,
arg
.
out_grid_desc_
);
};
};
};
}
// namespace device
}
// namespace device
...
...
include/ck/tensor_operation/gpu/device/device_permute_base.hpp
deleted
100644 → 0
View file @
16b116a9
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <array>
#include <cmath>
#include <memory>
#include <type_traits>
#include "ck/tensor_operation/gpu/device/device_base.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
template
<
index_t
NumDim
,
typename
InDataType
,
typename
OutDataType
,
typename
ElementwiseOperation
>
struct
DevicePermuteBase
:
BaseOperator
{
using
Lengths
=
std
::
array
<
index_t
,
NumDim
>
;
using
Strides
=
Lengths
;
virtual
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
Lengths
inLengths
,
const
Strides
inStrides
,
const
Lengths
outLengths
,
const
Strides
outStrides
,
const
void
*
in_dev_buffer
,
void
*
out_dev_buffer
,
ElementwiseOperation
elementwise_op
)
=
0
;
virtual
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
=
0
;
};
template
<
index_t
NumDim
,
typename
InDataType
,
typename
OutDataType
,
typename
ElementwiseOperation
,
typename
DerivedDeviceOperator
>
struct
DevicePermuteBaseCRTP
:
DevicePermuteBase
<
NumDim
,
InDataType
,
OutDataType
,
ElementwiseOperation
>
{
private:
using
BaseType
=
DevicePermuteBase
<
NumDim
,
InDataType
,
OutDataType
,
ElementwiseOperation
>
;
public:
// override methods inherited from 'BaseOperator'
bool
IsSupportedArgument
(
const
BaseArgument
*
arg
)
override
final
{
const
auto
*
const
argument
=
dynamic_cast
<
const
typename
DerivedDeviceOperator
::
Argument
*>
(
arg
);
if
(
!
argument
)
{
return
false
;
}
return
DerivedDeviceOperator
::
IsSupportedArgument
(
*
argument
);
}
// override methods inherited from 'DevicePermuteBase'
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
typename
BaseType
::
Lengths
inLengths
,
const
typename
BaseType
::
Strides
inStrides
,
const
typename
BaseType
::
Lengths
outLengths
,
const
typename
BaseType
::
Strides
outStrides
,
const
void
*
in_dev_buffer
,
void
*
out_dev_buffer
,
ElementwiseOperation
elementwise_op
)
override
final
{
return
std
::
make_unique
<
typename
DerivedDeviceOperator
::
Argument
>
(
inLengths
,
inStrides
,
outLengths
,
outStrides
,
in_dev_buffer
,
out_dev_buffer
,
elementwise_op
);
}
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
override
final
{
return
std
::
make_unique
<
typename
DerivedDeviceOperator
::
Invoker
>
();
};
// generate other utility methods
template
<
typename
...
Args
>
static
auto
MakeArgument
(
Args
&&
...
args
)
{
static_assert
(
std
::
is_constructible_v
<
typename
DerivedDeviceOperator
::
Argument
,
Args
...
>
);
return
typename
DerivedDeviceOperator
::
Argument
{
std
::
forward
<
Args
>
(
args
)...};
}
static
auto
MakeInvoker
()
noexcept
(
std
::
is_nothrow_default_constructible_v
<
typename
DerivedDeviceOperator
::
Invoker
>
)
{
static_assert
(
std
::
is_default_constructible_v
<
typename
DerivedDeviceOperator
::
Invoker
>
);
return
typename
DerivedDeviceOperator
::
Invoker
{};
}
};
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
include/ck/tensor_operation/gpu/device/impl/device_permute_impl.hpp
0 → 100644
View file @
734a12da
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <array>
#include <memory>
#include <utility>
#include "ck/utility/math.hpp"
#include "ck/utility/sequence.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/tensor_operation/gpu/device/device_permute.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_permute.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
// Swap last 2 dimensions
// input shape: [d[0], d[1], d[2], ..., d[NumDim-3], d[NumDim-2], d[NumDim-1]]
// ^^^^^^^^^^^
// output shape: [d[0], d[1], d[2], ..., d[NumDim-3], d[NumDim-1], d[NumDim-2]]
// ^^^^^^^^^^^
template
<
index_t
NumDim
,
typename
InDataType
,
typename
OutDataType
,
typename
ElementwiseOperation
,
index_t
BlockSize
,
index_t
NPerBlock
,
index_t
HPerBlock
,
index_t
WPerBlock
,
index_t
InBlockLdsExtraW
,
typename
InBlockTransferThreadClusterLengths
,
typename
InBlockTransferThreadClusterArrangeOrder
,
index_t
SrcVectorDim
,
index_t
DstVectorDim
,
index_t
SrcScalarPerVector
,
index_t
DstScalarPerVector
>
struct
DevicePermuteImpl
:
DevicePermuteCRTP
<
NumDim
,
InDataType
,
OutDataType
,
ElementwiseOperation
,
DevicePermuteImpl
<
NumDim
,
InDataType
,
OutDataType
,
ElementwiseOperation
,
BlockSize
,
NPerBlock
,
HPerBlock
,
WPerBlock
,
InBlockLdsExtraW
,
InBlockTransferThreadClusterLengths
,
InBlockTransferThreadClusterArrangeOrder
,
SrcVectorDim
,
DstVectorDim
,
SrcScalarPerVector
,
DstScalarPerVector
>>
{
static_assert
(
3
<=
NumDim
,
"Only accept at least 3D dimension tensor"
);
static_assert
((
NumDim
-
2
)
<=
SrcVectorDim
&&
SrcVectorDim
<
NumDim
);
static_assert
((
NumDim
-
2
)
<=
DstVectorDim
&&
DstVectorDim
<
NumDim
);
static_assert
(
SrcVectorDim
!=
DstVectorDim
);
template
<
index_t
N
=
NumDim
>
static
auto
ConvertArrayToTuple
(
const
std
::
array
<
index_t
,
NumDim
>&
array
)
{
static_assert
(
1
<=
N
&&
N
<=
NumDim
);
return
generate_tuple
([
&
](
auto
I
)
{
return
array
[
I
];
},
Number
<
N
>
{});
}
static
auto
MakeDescriptor_N_H_W
(
const
std
::
array
<
index_t
,
NumDim
>&
lengths
,
const
std
::
array
<
index_t
,
NumDim
>&
stride
)
{
// create nd descriptor, shape: [d[0], d[1], d[2], ..., d[NumDim-3], d[NumDim-2],
// d[NumDim-1]]
const
auto
desc
=
make_naive_tensor_descriptor
(
ConvertArrayToTuple
(
lengths
),
ConvertArrayToTuple
(
stride
));
// merge nd to 3d descriptor, shape: [(d[0] * d[1] * d[2] * ... * d[NumDim-3]), d[NumDim-2],
// d[NumDim-1]]
// => [N, H, W]
const
index_t
H
=
*
std
::
next
(
rbegin
(
lengths
));
const
index_t
W
=
*
rbegin
(
lengths
);
const
auto
desc_n_h_w
=
transform_tensor_descriptor
(
desc
,
make_tuple
(
make_merge_transform
(
ConvertArrayToTuple
<
NumDim
-
2
>
(
lengths
)),
make_pass_through_transform
(
H
),
make_pass_through_transform
(
W
)),
make_tuple
(
generate_sequence_v2
([
&
](
auto
I
)
{
return
I
;
},
Number
<
NumDim
-
2
>
{}),
Sequence
<
NumDim
-
2
>
{},
Sequence
<
NumDim
-
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}));
return
PadTensorDescriptor
(
desc_n_h_w
,
make_tuple
(
NPerBlock
,
HPerBlock
,
WPerBlock
),
Sequence
<
true
,
true
,
true
>
{});
}
using
InGridDesc
=
decltype
(
MakeDescriptor_N_H_W
({
1
,
1
},
{
1
,
1
}));
using
OutGridDesc
=
InGridDesc
;
using
GridwisePermute
=
GridwisePermute
<
InGridDesc
,
OutGridDesc
,
InDataType
,
OutDataType
,
ElementwiseOperation
,
BlockSize
,
NPerBlock
,
HPerBlock
,
WPerBlock
,
InBlockLdsExtraW
,
InBlockTransferThreadClusterLengths
,
InBlockTransferThreadClusterArrangeOrder
,
SrcVectorDim
-
(
NumDim
-
3
),
// calculate new SrcVectorDim for the merged descriptor
DstVectorDim
-
(
NumDim
-
3
),
// calculate new DstVectorDim for the merged descriptor
SrcScalarPerVector
,
DstScalarPerVector
>
;
struct
Argument
:
public
BaseArgument
{
Argument
(
const
std
::
array
<
index_t
,
NumDim
>
inLengths
,
const
std
::
array
<
index_t
,
NumDim
>
inStrides
,
const
std
::
array
<
index_t
,
NumDim
>
outLengths
,
const
std
::
array
<
index_t
,
NumDim
>
outStrides
,
const
void
*
in_dev_buffer
,
void
*
out_dev_buffer
,
ElementwiseOperation
elementwise_op
)
:
in_dev_buffer_
(
static_cast
<
const
InDataType
*>
(
in_dev_buffer
)),
out_dev_buffer_
(
static_cast
<
OutDataType
*>
(
out_dev_buffer
)),
in_grid_desc_
(
MakeDescriptor_N_H_W
(
inLengths
,
inStrides
)),
out_grid_desc_
(
MakeDescriptor_N_H_W
(
outLengths
,
outStrides
)),
inLengths_
(
inLengths
),
inStrides_
(
inStrides
),
outLengths_
(
outLengths
),
outStrides_
(
outStrides
),
elementwise_op_
(
elementwise_op
),
block_2_tile_map_
(
GridwisePermute
::
MakeDefaultBlock2TileMap
(
in_grid_desc_
))
{
}
const
InDataType
*
in_dev_buffer_
;
OutDataType
*
out_dev_buffer_
;
InGridDesc
in_grid_desc_
;
OutGridDesc
out_grid_desc_
;
std
::
array
<
index_t
,
NumDim
>
inLengths_
;
std
::
array
<
index_t
,
NumDim
>
inStrides_
;
std
::
array
<
index_t
,
NumDim
>
outLengths_
;
std
::
array
<
index_t
,
NumDim
>
outStrides_
;
ElementwiseOperation
elementwise_op_
;
typename
GridwisePermute
::
DefaultBlock2TileMap
block_2_tile_map_
;
};
struct
Invoker
:
BaseInvokerCRTP
<
Invoker
,
Argument
>
{
static
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
{
const
index_t
grid_size
=
arg
.
block_2_tile_map_
.
CalculateGridSize
(
arg
.
in_grid_desc_
);
const
auto
kernel
=
kernel_nd_permute
<
GridwisePermute
,
InGridDesc
,
OutGridDesc
,
InDataType
,
OutDataType
,
ElementwiseOperation
,
typename
GridwisePermute
::
DefaultBlock2TileMap
>
;
float
elapsed_time
=
launch_and_time_kernel
(
stream_config
,
kernel
,
dim3
(
grid_size
),
dim3
(
BlockSize
),
0
,
arg
.
in_grid_desc_
,
arg
.
out_grid_desc_
,
arg
.
in_dev_buffer_
,
arg
.
out_dev_buffer_
,
arg
.
elementwise_op_
,
arg
.
block_2_tile_map_
);
return
elapsed_time
;
}
};
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
{
constexpr
auto
GetPaddedLength
=
[](
index_t
length
,
index_t
tile_length
)
{
return
math
::
integer_divide_ceil
(
length
,
tile_length
)
*
tile_length
;
};
constexpr
auto
IsScalarPerVectorValid
=
[](
index_t
length
,
index_t
stride
,
index_t
scalar_per_vector
)
{
if
(
stride
==
1
&&
length
%
scalar_per_vector
==
0
)
{
return
true
;
}
else
if
(
stride
!=
1
&&
scalar_per_vector
==
1
)
{
return
true
;
}
return
false
;
};
return
IsScalarPerVectorValid
(
arg
.
inLengths_
[
SrcVectorDim
],
arg
.
inStrides_
[
SrcVectorDim
],
SrcScalarPerVector
)
&&
IsScalarPerVectorValid
(
GetPaddedLength
(
arg
.
inLengths_
[
SrcVectorDim
],
(
SrcVectorDim
==
NumDim
-
2
?
HPerBlock
:
WPerBlock
)),
arg
.
inStrides_
[
SrcVectorDim
],
SrcScalarPerVector
)
&&
IsScalarPerVectorValid
(
arg
.
outLengths_
[
DstVectorDim
],
arg
.
outStrides_
[
DstVectorDim
],
DstScalarPerVector
)
&&
IsScalarPerVectorValid
(
GetPaddedLength
(
arg
.
outLengths_
[
DstVectorDim
],
(
DstVectorDim
==
NumDim
-
2
?
HPerBlock
:
WPerBlock
)),
arg
.
inStrides_
[
DstVectorDim
],
DstScalarPerVector
)
&&
GridwisePermute
::
CheckValidity
(
arg
.
in_grid_desc_
,
arg
.
out_grid_desc_
);
};
};
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment