gaoqiong / composable_kernel_ROCM / Commits / 55a89c74

Commit 55a89c74, authored Dec 16, 2023 by Jun Liu

    Merge branch 'develop' into amd-develop

    parents: 0dacd895, dcedf363

Changes: 61. Showing 20 changed files with 891 additions and 136 deletions (+891 -136).
include/ck/wrapper/utils/tensor_utils.hpp  +290 -0
library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp  +3 -3
library/include/ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp  +20 -4
library/include/ck/library/tensor_operation_instance/gpu/device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_instance.hpp  +3 -17
library/include/ck/library/tensor_operation_instance/gpu/gemm.hpp  +7 -2
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_scaleadd_scaleadd_relu.hpp  +7 -5
library/src/tensor_operation_instance/gpu/gemm/CMakeLists.txt  +2 -1
library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f8_f16_mk_kn_mn_instance.cpp  +1 -0
library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f8_f16_mk_nk_mn_instance.cpp  +1 -0
library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_default_instance.cpp  +26 -0
library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_padded_instance.cpp  +26 -0
library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_scaleadd_relu/xdl/device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp  +4 -4
library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_scaleadd_relu/xdl/device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp  +4 -4
library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_scaleadd_relu/xdl/device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp  +4 -4
library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_scaleadd_relu/xdl/device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_int8_instance.cpp  +4 -4
library/src/tensor_operation_instance/gpu/quantization/conv2d_fwd/conv2d_quantization_common.hpp  +3 -3
profiler/src/profile_transpose.cpp  +0 -85
test/CMakeLists.txt  +1 -0
test/wrapper/CMakeLists.txt  +4 -0
test/wrapper/test_layout.cpp  +481 -0
include/ck/wrapper/utils/tensor_utils.hpp (new file, 0 → 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/ck.hpp"
#include "ck/utility/number.hpp"
#include "ck/utility/tuple.hpp"
#include "ck/utility/tuple_helper.hpp"
#include "ck/utility/dynamic_buffer.hpp"
#include "ck/utility/amd_address_space.hpp"

namespace ck {
namespace wrapper {

/**
 * \brief Memory type, allowed members:
 * - Generic,
 * - Global,
 * - LDS,
 * - SGPR,
 * - VGPR,
 */
using MemoryTypeEnum = AddressSpaceEnum;

// Disable from doxygen docs generation
/// @cond
// forward declarations
template <typename Shape, typename Strides>
struct Layout;

template <MemoryTypeEnum BufferAddressSpace,
          typename ElementType,
          typename Shape,
          typename Strides,
          index_t NumVectors,     // params for Register memory
          index_t ScalarPerVector // param for Register memory
          >
struct Tensor;

template <typename FromType, typename ToType>
struct Slice
{
    __host__ __device__ constexpr Slice() : from_(), to_() {}
    __host__ __device__ constexpr Slice(FromType from, ToType to) : from_(from), to_(to) {}

    template <typename T>
    __host__ __device__ constexpr auto range(const T& dim) const
    {
        if constexpr(is_same_v<FromType, index_t> || is_same_v<ToType, index_t> ||
                     is_same_v<T, index_t>)
        {
            assert(dim >= to_ && from_ >= 0 && (to_ < 0 || to_ > from_) && "Invalid range");
            if(to_ < 0)
            {
                return dim - from_ + to_ + 1;
            }
            else
            {
                // workaround if one end of the interval is index_t and the second one is Number
                return static_cast<index_t>(to_) - static_cast<index_t>(from_);
            }
        }
        else
        {
            static_assert(dim >= to_ && from_ >= Number<0>{} && (to_ < 0 || to_ > from_),
                          "Invalid range");
            if constexpr(to_ < 0)
            {
                return dim - from_ + to_ + Number<1>{};
            }
            else
            {
                return to_ - from_;
            }
        }
    }

    __host__ __device__ static constexpr bool IsSlice() { return true; }

    const FromType from_;
    const ToType to_;
};

template <typename T>
using is_slice = decltype(std::declval<T&>().IsSlice());

template <typename T>
using is_tuple = decltype(std::declval<T&>().IsTuple());
/// @endcond

/**
 * \brief Make tensor function.
 *
 * \tparam MemoryType Type of memory.
 * \param pointer Pointer to the memory.
 * \param layout Tensor layout.
 * \return Constructed tensor.
 */
template <MemoryTypeEnum MemoryType, typename ElementType, typename Shape, typename Strides>
constexpr auto make_tensor(ElementType* pointer, const Layout<Shape, Strides>& layout)
{
    return Tensor<MemoryType, ElementType, Shape, Strides, 0 /*NumVectors*/, 0 /*ScalarPerVector*/>(
        pointer, layout);
}

/**
 * \brief Make SGPR or VGPR tensor function.
 *
 * \tparam MemoryType Type of memory.
 * \tparam NumVectors Number of vectors.
 * \tparam ScalarPerVector Scalars per vector.
 * \tparam ElementType Memory data type.
 * \param layout Tensor layout.
 * \return Constructed tensor.
 */
template <MemoryTypeEnum MemoryType,
          index_t NumVectors,
          index_t ScalarPerVector,
          typename ElementType,
          typename Shape,
          typename Strides>
constexpr auto make_register_tensor(const Layout<Shape, Strides>& layout)
{
    static_assert(!IsNestedTuple(Shape{}), "Register tensor with nested layout is not supported");
    return Tensor<MemoryType, ElementType, Shape, Strides, NumVectors, ScalarPerVector>(layout);
}

/**
 * \brief Get tensor layout.
 *
 * \param tensor Tensor to get the layout of.
 * \return Requested layout.
 */
template <MemoryTypeEnum BufferAddressSpace,
          typename ElementType,
          typename Shape,
          typename Strides,
          index_t NumVectors,
          index_t ScalarPerVector>
__host__ __device__ constexpr const auto&
layout(const Tensor<BufferAddressSpace, ElementType, Shape, Strides, NumVectors, ScalarPerVector>&
           tensor)
{
    return tensor.GetLayout();
}

/**
 * \brief Product of tensor shape dims.
 *
 * \tparam Idxs Indexes to access a specific shape dim (optional).
 * \param tensor Tensor to get the shape of.
 * \return Requested size.
 */
template <index_t... Idxs,
          MemoryTypeEnum BufferAddressSpace,
          typename ElementType,
          typename Shape,
          typename Strides,
          index_t NumVectors,
          index_t ScalarPerVector>
__host__ __device__ constexpr index_t
size(const Tensor<BufferAddressSpace, ElementType, Shape, Strides, NumVectors, ScalarPerVector>&
         tensor)
{
    return size<Idxs...>(tensor.GetLayout());
}

/**
 * \brief Rank of the Shape tuple.
 *
 * \tparam Idxs Indexes to access a specific shape dim (optional).
 * \param tensor Tensor to get the rank of.
 * \return Requested rank.
 */
template <index_t... Idxs,
          MemoryTypeEnum BufferAddressSpace,
          typename ElementType,
          typename Shape,
          typename Strides,
          index_t NumVectors,
          index_t ScalarPerVector>
__host__ __device__ constexpr index_t
rank(const Tensor<BufferAddressSpace, ElementType, Shape, Strides, NumVectors, ScalarPerVector>&
         tensor)
{
    return rank<Idxs...>(tensor.GetLayout());
}

/**
 * \brief Depth of the Shape tuple.
 *
 * \tparam Idxs Indexes to access a specific shape dim (optional).
 * \param tensor Tensor to get the depth of.
 * \return Requested depth.
 */
template <index_t... Idxs,
          MemoryTypeEnum BufferAddressSpace,
          typename ElementType,
          typename Shape,
          typename Strides,
          index_t NumVectors,
          index_t ScalarPerVector>
__host__ __device__ constexpr index_t
depth(const Tensor<BufferAddressSpace, ElementType, Shape, Strides, NumVectors, ScalarPerVector>&
          tensor)
{
    return depth<Idxs...>(tensor.GetLayout());
}

/**
 * \brief Get tensor strides.
 *
 * \param tensor Tensor to get the strides from.
 * \return Requested strides.
 */
template <MemoryTypeEnum BufferAddressSpace,
          typename ElementType,
          typename Shape,
          typename Strides,
          index_t NumVectors,
          index_t ScalarPerVector>
__host__ __device__ constexpr const auto&
stride(const Tensor<BufferAddressSpace, ElementType, Shape, Strides, NumVectors, ScalarPerVector>&
           tensor)
{
    return stride(tensor.GetLayout());
}

/**
 * \brief Get tensor shape.
 *
 * \param tensor Tensor to get the shape from.
 * \return Requested shape.
 */
template <MemoryTypeEnum BufferAddressSpace,
          typename ElementType,
          typename Shape,
          typename Strides,
          index_t NumVectors,
          index_t ScalarPerVector>
__host__ __device__ constexpr const auto&
shape(const Tensor<BufferAddressSpace, ElementType, Shape, Strides, NumVectors, ScalarPerVector>&
          tensor)
{
    return shape(tensor.GetLayout());
}

/**
 * \brief Get dim slice.
 *
 * \param from Beginning of the interval.
 * \param to End of the interval (can also be negative, to index from the end).
 * \return Requested slice. Can be used to create a sliced tensor from another tensor.
 */
template <typename FromType, typename ToType>
constexpr auto slice(const FromType from, const ToType to)
{
    return Slice<FromType, ToType>(from, to);
}

/**
 * \brief Get dim slice (from is assumed to be 0).
 *
 * \param to End of the interval (can also be negative, to index from the end).
 * \return Requested slice. Can be used to create a sliced tensor from another tensor.
 */
template <typename ToType>
constexpr auto slice(const ToType to)
{
    if constexpr(is_same_v<ToType, index_t>)
    {
        return Slice<index_t, ToType>(0, to);
    }
    else
    {
        return Slice<Number<0>, ToType>(Number<0>{}, to);
    }
}

/**
 * \brief Get whole dim slice (from = 0, to = -1).
 *
 * \return Requested slice. Can be used to create a sliced tensor from another tensor.
 */
constexpr auto slice() { return Slice<Number<0>, Number<-1>>(Number<0>{}, Number<-1>{}); }

} // namespace wrapper
} // namespace ck
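Taken together, the new helpers compose roughly as below. This is a minimal sketch, not part of the diff: it assumes ck::wrapper::make_layout from ck/wrapper/layout.hpp (exercised by the new test/wrapper/test_layout.cpp later in this commit) and a pointer to 12 valid floats; the Tensor definition itself lives in a later wrapper header, so treat this as a sketch of intent rather than a compile-checked snippet.

// Sketch only: build a 4 x 3 column-major view over global memory and query it
// through the free functions declared above.
#include "ck/wrapper/layout.hpp" // make_layout (assumption: as used by test_layout.cpp)
#include "ck/wrapper/utils/tensor_utils.hpp"

void tensor_utils_sketch(float* p_global) // p_global: assumed to hold 4 * 3 floats
{
    const auto layout = ck::wrapper::make_layout(ck::make_tuple(4, 3),  // shape
                                                 ck::make_tuple(1, 4)); // strides
    auto tensor =
        ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Global>(p_global, layout);
    const ck::index_t num_elems = ck::wrapper::size(tensor); // product of dims: 12
    const auto whole_dim  = ck::wrapper::slice();     // from = 0, to = -1: the whole dim
    const auto first_rows = ck::wrapper::slice(0, 2); // elements 0 and 1 of a dim
    (void)num_elems;
    (void)whole_dim;
    (void)first_rows;
}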
library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp

...
@@ -86,9 +86,9 @@ using NHWGK = ck::tensor_layout::convolution::NHWGK;
 using NDHWGK = ck::tensor_layout::convolution::NDHWGK;
 //
-using GK          = ck::tensor_layout::convolution::G_K;
-using GK_Tuple    = ck::Tuple<GK>;
-using GK_GK_Tuple = ck::Tuple<GK, GK>;
+using G_K         = ck::tensor_layout::convolution::G_K;
+using GK_Tuple    = ck::Tuple<G_K>;
+using GK_GK_Tuple = ck::Tuple<G_K, G_K>;
 
 // pointwise functor
 using PassThrough = ck::tensor_operation::element_wise::PassThrough;
...
library/include/ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp

...
@@ -61,7 +61,11 @@ using device_contraction_kk_instance = std::tuple<
         DeviceContractionMultipleD_Xdl_CShuffle<2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 128, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4, ComputeDataType>,
         DeviceContractionMultipleD_Xdl_CShuffle<2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 32, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4, ComputeDataType>,
         DeviceContractionMultipleD_Xdl_CShuffle<2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 64, 64, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4, ComputeDataType>,
-        DeviceContractionMultipleD_Xdl_CShuffle<2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 64, 32, 64, 16, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4, ComputeDataType>
+        DeviceContractionMultipleD_Xdl_CShuffle<2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 64, 32, 64, 16, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4, ComputeDataType>,
+        // Small scalar per vector
+        DeviceContractionMultipleD_Xdl_CShuffle<2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 16>, 1, ComputeDataType>,
+        DeviceContractionMultipleD_Xdl_CShuffle<2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 128, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 4, 1, 1, 1, S<1, 16, 1, 8>, 2, ComputeDataType>,
+        DeviceContractionMultipleD_Xdl_CShuffle<2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 64, 64, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 8, 1, 8>, 1, ComputeDataType>
         // clang-format on
         >;
...
@@ -96,7 +100,11 @@ using device_contraction_kn_instance = std::tuple<
         DeviceContractionMultipleD_Xdl_CShuffle<2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 64, 16, 4, 1, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>,
         DeviceContractionMultipleD_Xdl_CShuffle<2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 64, 16, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>,
         DeviceContractionMultipleD_Xdl_CShuffle<2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 64, 128, 16, 4, 1, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>,
-        DeviceContractionMultipleD_Xdl_CShuffle<2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 64, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>
+        DeviceContractionMultipleD_Xdl_CShuffle<2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 64, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>,
+        // Small scalar per vector
+        DeviceContractionMultipleD_Xdl_CShuffle<2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, 1, 1, S<1, 16, 1, 16>, 1, ComputeDataType>,
+        DeviceContractionMultipleD_Xdl_CShuffle<2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 128, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, 1, 1, S<1, 16, 1, 8>, 2, ComputeDataType>,
+        DeviceContractionMultipleD_Xdl_CShuffle<2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 64, 64, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, 1, 1, S<1, 8, 1, 8>, 1, ComputeDataType>
         // clang-format on
         >;
...
@@ -131,7 +139,11 @@ using device_contraction_mk_instance = std::tuple<
         DeviceContractionMultipleD_Xdl_CShuffle<2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 64, 16, 1, 4, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>,
         DeviceContractionMultipleD_Xdl_CShuffle<2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 64, 16, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>,
         DeviceContractionMultipleD_Xdl_CShuffle<2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 64, 128, 16, 1, 4, 32, 32, 1, 2, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>,
-        DeviceContractionMultipleD_Xdl_CShuffle<2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 64, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>
+        DeviceContractionMultipleD_Xdl_CShuffle<2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 64, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>,
+        // Small scalar per vector
+        DeviceContractionMultipleD_Xdl_CShuffle<2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 16>, 1, ComputeDataType>,
+        DeviceContractionMultipleD_Xdl_CShuffle<2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 128, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 4, 1, 1, 1, S<1, 16, 1, 8>, 2, ComputeDataType>,
+        DeviceContractionMultipleD_Xdl_CShuffle<2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 64, 64, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 8, 1, 8>, 1, ComputeDataType>
         // clang-format on
         >;
...
@@ -166,7 +178,11 @@ using device_contraction_mn_instance = std::tuple<
         DeviceContractionMultipleD_Xdl_CShuffle<2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 64, 16, 1, 1, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>,
         DeviceContractionMultipleD_Xdl_CShuffle<2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 64, 16, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>,
         DeviceContractionMultipleD_Xdl_CShuffle<2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 64, 128, 16, 1, 1, 32, 32, 1, 2, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>,
-        DeviceContractionMultipleD_Xdl_CShuffle<2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 64, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>
+        DeviceContractionMultipleD_Xdl_CShuffle<2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 64, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>,
+        // Small scalar per vector
+        DeviceContractionMultipleD_Xdl_CShuffle<2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, 1, 1, S<1, 16, 1, 16>, 1, ComputeDataType>,
+        DeviceContractionMultipleD_Xdl_CShuffle<2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 128, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, 1, 1, S<1, 16, 1, 8>, 2, ComputeDataType>,
+        DeviceContractionMultipleD_Xdl_CShuffle<2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 64, 64, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, 1, 1, S<1, 8, 1, 8>, 1, ComputeDataType>
         // clang-format on
         >;
...
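The added "Small scalar per vector" variants differ from the existing instances mainly in their per-vector scalar counts (the final S<...> shuffle argument is followed by 1 or 2 instead of 4), which loosens the vector-width divisibility requirement that instance selection imposes on the problem extents. A hedged annotation of one added instance follows; the parameter labels are my reading, inferred from CK's usual instance comment tables, and are not part of this file:

// Annotated copy of one added instance (labels are assumptions, not source):
DeviceContractionMultipleD_Xdl_CShuffle<
    2, 2, 2,                          // NumDimM, NumDimN, NumDimK
    ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType,
    AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding,
    1,                                // prefetch stages
    64, 64, 32, 16,                   // BlockSize, MPerBlock, NPerBlock, KPerBlock
    4, 4, 32, 32, 2, 1,               // AK1, BK1, MPerXDL, NPerXDL, M/NXdlPerWave
    S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, // A block transfer, SrcScalarPerVector = 1
    S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, // B block transfer, SrcScalarPerVector = 1
    1, 1, S<1, 8, 1, 8>,              // CShuffle stage
    1,                                // CDE ScalarPerVector = 1 (the "small" case)
    ComputeDataType>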
library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_instance.cpp → library/include/ck/library/tensor_operation_instance/gpu/device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_instance.hpp

...
@@ -25,10 +25,6 @@ using S = ck::Sequence<Is...>;
 using PassThrough = ck::tensor_operation::element_wise::PassThrough;
 
-static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
-static constexpr auto MNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
-
 // Compilation parameters for a[m, k] * b[k, n] = c[m, n]
 template <ck::tensor_operation::device::GemmSpecialization GemmSpec>
 using device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_instances = std::tuple<
...
@@ -37,7 +33,7 @@ using device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_instances = std::tuple<
 //#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | |
 //#####################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | |
 //#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
 // pipeline v1, 1 wave
         DeviceGemm_Xdl_CShuffle< Row, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 256, 128, 64, 16, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 64, 1, 4>, 16, LoopScheduler::Default, PipelineVersion::v1>,
         DeviceGemm_Xdl_CShuffle< Row, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 1, S<1, 64, 1, 4>, 16, LoopScheduler::Default, PipelineVersion::v1>,
         DeviceGemm_Xdl_CShuffle< Row, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 128, 256, 64, 16, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 64, 1, 4>, 16, LoopScheduler::Default, PipelineVersion::v1>,
...
@@ -75,7 +71,8 @@ using device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_instances = std::tuple<
         DeviceGemm_Xdl_CShuffle< Row, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 64, 128, 64, 16, 16, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 1, S<1, 64, 1, 4>, 16, LoopScheduler::Interwave, PipelineVersion::v1>
 #endif
-#if CK_EXPERIMENTAL_PIPELINE_V2_INSTANCES
+#if 0 //CK_EXPERIMENTAL_PIPELINE_V2_INSTANCES
     // pipeline v2, 1 wave
     ,
     DeviceGemm_Xdl_CShuffle< Row, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 256, 128, 64, 16, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 64, 1, 4>, 16, LoopScheduler::Default, PipelineVersion::v2>,
...
@@ -98,17 +95,6 @@ using device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_instances = std::tuple<
     // clang-format on
     >;
 
-void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_instances(
-    std::vector<std::unique_ptr<
-        DeviceGemm<Row, Row, Row, F8, F8, F8, PassThrough, PassThrough, PassThrough>>>& instances)
-{
-    add_device_operation_instances(
-        instances, device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_instances<GemmDefault>{});
-    add_device_operation_instances(
-        instances, device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_instances<MNKPadding>{});
-}
-
 } // namespace instance
 } // namespace device
 } // namespace tensor_operation
...
library/include/ck/library/tensor_operation_instance/gpu/gemm.hpp

...
@@ -345,7 +345,11 @@ void add_device_gemm_xdl_c_shuffle_f8_f8_f8_km_nk_mn_instances(
     std::vector<std::unique_ptr<
         DeviceGemm<Col, Col, Row, F8, F8, F8, PassThrough, PassThrough, PassThrough>>>& instances);
 
-void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_instances(
+void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_default_instances(
     std::vector<std::unique_ptr<
         DeviceGemm<Row, Row, Row, F8, F8, F8, PassThrough, PassThrough, PassThrough>>>& instances);
+
+void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_padded_instances(
+    std::vector<std::unique_ptr<
+        DeviceGemm<Row, Row, Row, F8, F8, F8, PassThrough, PassThrough, PassThrough>>>& instances);
...
@@ -575,7 +579,8 @@ struct DeviceOperationInstanceFactory<
         if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> && is_same_v<CLayout, Row>)
         {
-            add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_instances(op_ptrs);
+            add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_padded_instances(op_ptrs);
+            add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_default_instances(op_ptrs);
         }
         else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> && is_same_v<CLayout, Row>)
...
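For reference, a sketch of how these registrations surface to clients through the instance factory; the pattern follows CK's usual client examples, with the data-type and layout aliases as used in this header:

// Sketch: enumerating the row-major f8 GEMM instances. After this change the
// returned list contains both the GemmDefault and the MNKPadding instances,
// now registered by two separate translation units.
using DeviceOp = ck::tensor_operation::device::DeviceGemm<Row, Row, Row,
                                                          F8, F8, F8,
                                                          PassThrough, PassThrough, PassThrough>;
const auto op_ptrs = ck::tensor_operation::device::instance::
    DeviceOperationInstanceFactory<DeviceOp>::GetInstances();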
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_scaleadd_scaleadd_relu.hpp

...
@@ -27,7 +27,7 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhw
     std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
                                                                 NDHWGC,
                                                                 GKZYXC,
-                                                                ck::Tuple<NDHWGK, NDHWGK>,
+                                                                ck::Tuple<NDHWGK, G_K>,
                                                                 NDHWGK,
                                                                 BF16,
                                                                 BF16,
...
@@ -43,7 +43,7 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhw
     std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
                                                                 NDHWGC,
                                                                 GKZYXC,
-                                                                ck::Tuple<NDHWGK, NDHWGK>,
+                                                                ck::Tuple<NDHWGK, G_K>,
                                                                 NDHWGK,
                                                                 F16,
                                                                 F16,
...
@@ -59,7 +59,7 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhw
     std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
                                                                 NDHWGC,
                                                                 GKZYXC,
-                                                                ck::Tuple<NDHWGK, NDHWGK>,
+                                                                ck::Tuple<NDHWGK, G_K>,
                                                                 NDHWGK,
                                                                 F32,
                                                                 F32,
...
@@ -75,7 +75,7 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhw
     std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
                                                                 NDHWGC,
                                                                 GKZYXC,
-                                                                ck::Tuple<NDHWGK, NDHWGK>,
+                                                                ck::Tuple<NDHWGK, G_K>,
                                                                 NDHWGK,
                                                                 int8_t,
                                                                 int8_t,
...
@@ -130,7 +130,9 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
 {
     std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
 
-    if constexpr(NumDimSpatial == 3 && is_same_v<InLayout, NDHWGC> &&
-                 is_same_v<WeiLayout, GKZYXC> && is_same_v<OutLayout, NDHWGK>)
+    if constexpr(NumDimSpatial == 3 && is_same_v<InLayout, NDHWGC> &&
+                 is_same_v<WeiLayout, GKZYXC> && is_same_v<OutLayout, NDHWGK> &&
+                 DLayouts::Size() == 2 && is_same_v<tuple_element_t<0, DLayouts>, NDHWGK> &&
+                 is_same_v<tuple_element_t<1, DLayouts>, G_K>)
     {
 #ifdef CK_ENABLE_FP32
         if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
...
library/src/tensor_operation_instance/gpu/gemm/CMakeLists.txt

...
@@ -101,7 +101,8 @@ list(APPEND GEMM_INSTANCES
   device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instance.cpp)
 list(APPEND GEMM_INSTANCES
-  device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_instance.cpp
+  device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_default_instance.cpp
+  device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_padded_instance.cpp
   device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_nk_mn_instance.cpp
   device_gemm_xdl_c_shuffle_fp8_fp8_fp8_km_kn_mn_instance.cpp
   device_gemm_xdl_c_shuffle_fp8_fp8_fp8_km_nk_mn_instance.cpp)
...
library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f8_f16_mk_kn_mn_instance.cpp

...
@@ -16,6 +16,7 @@ namespace tensor_operation {
 namespace device {
 namespace instance {
 
+using F8  = ck::f8_t;
 using F16 = ck::half_t;
 using F32 = float;
...
library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f8_f16_mk_nk_mn_instance.cpp

...
@@ -16,6 +16,7 @@ namespace tensor_operation {
 namespace device {
 namespace instance {
 
+using F8  = ck::f8_t;
 using F16 = ck::half_t;
 using F32 = float;
...
library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_default_instance.cpp (new file, 0 → 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#include "ck/library/tensor_operation_instance/gpu/device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_instance.hpp"

#ifdef CK_ENABLE_FP8
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;

void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_default_instances(
    std::vector<std::unique_ptr<
        DeviceGemm<Row, Row, Row, F8, F8, F8, PassThrough, PassThrough, PassThrough>>>& instances)
{
    add_device_operation_instances(
        instances, device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_instances<GemmDefault>{});
}

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
#endif
library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_padded_instance.cpp (new file, 0 → 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#include "ck/library/tensor_operation_instance/gpu/device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_instance.hpp"

#ifdef CK_ENABLE_FP8
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

static constexpr auto MNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;

void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_padded_instances(
    std::vector<std::unique_ptr<
        DeviceGemm<Row, Row, Row, F8, F8, F8, PassThrough, PassThrough, PassThrough>>>& instances)
{
    add_device_operation_instances(
        instances, device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_instances<MNKPadding>{});
}

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
#endif
library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_scaleadd_relu/xdl/device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp

...
@@ -13,7 +13,7 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhw
     std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
                                                                 NDHWGC,
                                                                 GKZYXC,
-                                                                ck::Tuple<NDHWGK, NDHWGK>,
+                                                                ck::Tuple<NDHWGK, G_K>,
                                                                 NDHWGK,
                                                                 BF16,
                                                                 BF16,
...
@@ -28,7 +28,7 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhw
-        device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_bf16_instances<3, NDHWGC, GKZYXC, ck::Tuple<NDHWGK, NDHWGK>, NDHWGK, ConvFwdDefault>{});
+        device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_bf16_instances<3, NDHWGC, GKZYXC, ck::Tuple<NDHWGK, G_K>, NDHWGK, ConvFwdDefault>{});
     add_device_operation_instances(
...
@@ -36,7 +36,7 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhw
-        device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_bf16_instances<3, NDHWGC, GKZYXC, ck::Tuple<NDHWGK, NDHWGK>, NDHWGK, ConvFwd1x1P0>{});
+        device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_bf16_instances<3, NDHWGC, GKZYXC, ck::Tuple<NDHWGK, G_K>, NDHWGK, ConvFwd1x1P0>{});
     add_device_operation_instances(
...
@@ -44,7 +44,7 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhw
-        device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_bf16_instances<3, NDHWGC, GKZYXC, ck::Tuple<NDHWGK, NDHWGK>, NDHWGK, ConvFwd1x1S1P0>{});
+        device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_bf16_instances<3, NDHWGC, GKZYXC, ck::Tuple<NDHWGK, G_K>, NDHWGK, ConvFwd1x1S1P0>{});
 }
...
library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_scaleadd_relu/xdl/device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp

...
@@ -13,7 +13,7 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhw
     std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
                                                                 NDHWGC,
                                                                 GKZYXC,
-                                                                ck::Tuple<NDHWGK, NDHWGK>,
+                                                                ck::Tuple<NDHWGK, G_K>,
                                                                 NDHWGK,
                                                                 F16,
                                                                 F16,
...
@@ -28,7 +28,7 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhw
-        device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_f16_instances<3, NDHWGC, GKZYXC, ck::Tuple<NDHWGK, NDHWGK>, NDHWGK, ConvFwdDefault>{});
+        device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_f16_instances<3, NDHWGC, GKZYXC, ck::Tuple<NDHWGK, G_K>, NDHWGK, ConvFwdDefault>{});
     add_device_operation_instances(
...
@@ -36,7 +36,7 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhw
-        device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_f16_instances<3, NDHWGC, GKZYXC, ck::Tuple<NDHWGK, NDHWGK>, NDHWGK, ConvFwd1x1P0>{});
+        device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_f16_instances<3, NDHWGC, GKZYXC, ck::Tuple<NDHWGK, G_K>, NDHWGK, ConvFwd1x1P0>{});
     add_device_operation_instances(
...
@@ -44,7 +44,7 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhw
-        device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_f16_instances<3, NDHWGC, GKZYXC, ck::Tuple<NDHWGK, NDHWGK>, NDHWGK, ConvFwd1x1S1P0>{});
+        device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_f16_instances<3, NDHWGC, GKZYXC, ck::Tuple<NDHWGK, G_K>, NDHWGK, ConvFwd1x1S1P0>{});
 }
...
library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_scaleadd_relu/xdl/device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp

...
@@ -13,7 +13,7 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhw
     std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
                                                                 NDHWGC,
                                                                 GKZYXC,
-                                                                ck::Tuple<NDHWGK, NDHWGK>,
+                                                                ck::Tuple<NDHWGK, G_K>,
                                                                 NDHWGK,
                                                                 F32,
                                                                 F32,
...
@@ -28,7 +28,7 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhw
-        device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_f32_instances<3, NDHWGC, GKZYXC, ck::Tuple<NDHWGK, NDHWGK>, NDHWGK, ConvFwdDefault>{});
+        device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_f32_instances<3, NDHWGC, GKZYXC, ck::Tuple<NDHWGK, G_K>, NDHWGK, ConvFwdDefault>{});
     add_device_operation_instances(
...
@@ -36,7 +36,7 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhw
-        device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_f32_instances<3, NDHWGC, GKZYXC, ck::Tuple<NDHWGK, NDHWGK>, NDHWGK, ConvFwd1x1P0>{});
+        device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_f32_instances<3, NDHWGC, GKZYXC, ck::Tuple<NDHWGK, G_K>, NDHWGK, ConvFwd1x1P0>{});
     add_device_operation_instances(
...
@@ -44,7 +44,7 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhw
-        device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_f32_instances<3, NDHWGC, GKZYXC, ck::Tuple<NDHWGK, NDHWGK>, NDHWGK, ConvFwd1x1S1P0>{});
+        device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_f32_instances<3, NDHWGC, GKZYXC, ck::Tuple<NDHWGK, G_K>, NDHWGK, ConvFwd1x1S1P0>{});
 }
...
library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_scaleadd_relu/xdl/device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_int8_instance.cpp

...
@@ -12,7 +12,7 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhw
     std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
                                                                 NDHWGC,
                                                                 GKZYXC,
-                                                                ck::Tuple<NDHWGK, NDHWGK>,
+                                                                ck::Tuple<NDHWGK, G_K>,
                                                                 NDHWGK,
                                                                 int8_t,
                                                                 int8_t,
...
@@ -27,7 +27,7 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhw
-        device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_int8_instances<3, NDHWGC, GKZYXC, ck::Tuple<NDHWGK, NDHWGK>, NDHWGK, ConvFwdDefault>{});
+        device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_int8_instances<3, NDHWGC, GKZYXC, ck::Tuple<NDHWGK, G_K>, NDHWGK, ConvFwdDefault>{});
     add_device_operation_instances(
...
@@ -35,7 +35,7 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhw
-        device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_int8_instances<3, NDHWGC, GKZYXC, ck::Tuple<NDHWGK, NDHWGK>, NDHWGK, ConvFwd1x1P0>{});
+        device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_int8_instances<3, NDHWGC, GKZYXC, ck::Tuple<NDHWGK, G_K>, NDHWGK, ConvFwd1x1P0>{});
     add_device_operation_instances(
...
@@ -43,7 +43,7 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhw
-        device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_int8_instances<3, NDHWGC, GKZYXC, ck::Tuple<NDHWGK, NDHWGK>, NDHWGK, ConvFwd1x1S1P0>{});
+        device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_int8_instances<3, NDHWGC, GKZYXC, ck::Tuple<NDHWGK, G_K>, NDHWGK, ConvFwd1x1S1P0>{});
 }
...
library/src/tensor_operation_instance/gpu/quantization/conv2d_fwd/conv2d_quantization_common.hpp

...
@@ -22,13 +22,13 @@ using S = ck::Sequence<Is...>;
 using NHWGC = ck::tensor_layout::convolution::NHWGC;
 using GKYXC = ck::tensor_layout::convolution::GKYXC;
 using NHWGK = ck::tensor_layout::convolution::NHWGK;
-using GK    = ck::tensor_layout::convolution::G_K;
+using G_K   = ck::tensor_layout::convolution::G_K;
 
 using PassThrough = ck::tensor_operation::element_wise::PassThrough;
 using Relu        = ck::tensor_operation::element_wise::Relu;
 using TanH        = ck::tensor_operation::element_wise::TanH;
 
-using GK_Tuple    = ck::Tuple<GK>;
-using GK_GK_Tuple = ck::Tuple<GK, GK>;
+using GK_Tuple    = ck::Tuple<G_K>;
+using GK_GK_Tuple = ck::Tuple<G_K, G_K>;
 using I32_Tuple   = ck::Tuple<int32_t>;
 using F32_Tuple   = ck::Tuple<float>;
 using I32_F32_Tuple = ck::Tuple<int32_t, float>;
...
profiler/src/profile_transpose.cpp (deleted, 100644 → 0)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>

#include "profiler/profile_transpose_impl.hpp"
#include "profiler_operation_registry.hpp"

enum struct MatrixLayout
{
    NCDHW, // 0
    NCHWD, // 1
};

enum struct DataType
{
    F32_F32_F32_F32_F32, // 0
    F16_F16_F16_F16_F16, // 1
};

#define OP_NAME "transpose"
#define OP_DESC "Transpose"

int profile_transpose(int argc, char* argv[])
{
    if(argc != 15)
    {
        printf("arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n");
        printf("arg2: data type (0: fp32; 1: fp16)\n");
        // printf("arg3: matrix layout (NCDHW -> NDCHW);\n");
        printf("arg4: verification (0: no; 1: yes)\n");
        printf("arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n");
        printf("arg6: print tensor value (0: no; 1: yes)\n");
        printf("arg7: time kernel (0=no, 1=yes)\n");
        printf("arg8 to 13: N, C, D, H, W\n");
        exit(1);
    }

    const auto data_type = static_cast<DataType>(std::stoi(argv[2]));
    // const auto layout = static_cast<MatrixLayout>(std::stoi(argv[3]));
    const bool do_verification = std::stoi(argv[3]);
    const int init_method      = std::stoi(argv[4]);
    const bool do_log          = std::stoi(argv[5]);
    const bool time_kernel     = std::stoi(argv[6]);

    std::vector<index_t> lengths = std::stoi(argv[7]);
    /**const int N = std::stoi(argv[7]);
    const int C = std::stoi(argv[8]);
    const int D = std::stoi(argv[9]);
    const int H = std::stoi(argv[10]);
    const int W = std::stoi(argv[11]);**/

    using F32 = float;
    using F16 = ck::half_t;

    auto profile = [&](auto a_type, auto b_type) {
        using ADataType = decltype(a_type);
        using BDataType = decltype(b_type);

        bool pass = ck::profiler::profile_transpose_impl<ADataType, BDataType>(
            do_verification, init_method, do_log, time_kernel, lengths);

        return pass ? 0 : 1;
    };

    if(data_type == GemmDataType::F32_F32_F32_F32_F32)
    {
        return profile(F32{}, F32{});
    }
    else if(data_type == GemmDataType::F16_F16_F16_F16_F16)
    {
        return profile(F16{}, F16{});
    }
    else
    {
        std::cout << "this data_type & layout is not implemented" << std::endl;
        return 1;
    }
}

REGISTER_PROFILER_OPERATION(OP_NAME, OP_DESC, profile_gemm_transpose);
test/CMakeLists.txt

...
@@ -149,6 +149,7 @@ add_subdirectory(batched_gemm_multi_d)
 add_subdirectory(grouped_convnd_bwd_data)
 add_subdirectory(conv_tensor_rearrange)
 add_subdirectory(transpose)
+add_subdirectory(wrapper)
 if(GPU_TARGETS MATCHES "gfx11")
     add_subdirectory(wmma_op)
 endif()
test/wrapper/CMakeLists.txt (new file, 0 → 100644)

add_gtest_executable(test_layout test_layout.cpp)
target_link_libraries(test_layout PRIVATE utility)

add_gtest_executable(test_tensor test_tensor.cpp)
target_link_libraries(test_tensor PRIVATE utility)
test/wrapper/test_layout.cpp
0 → 100644
View file @
55a89c74
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include <iostream>
#include <initializer_list>
#include <vector>
#include <gtest/gtest.h>
#include "ck/utility/common_header.hpp"
#include "ck/wrapper/layout.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/multi_index_transform_helper.hpp"
class
TestWrapperLayout
:
public
::
testing
::
Test
{
protected:
static
constexpr
auto
I0
=
ck
::
Number
<
0
>
{};
static
constexpr
auto
I1
=
ck
::
Number
<
1
>
{};
template
<
typename
Desc
,
typename
Desc1d
,
typename
LayoutRuntime
,
typename
LayoutCompiletime
,
typename
Idxs
>
void
Run
(
Desc
&
desc
,
Desc1d
&
desc_1d
,
LayoutRuntime
&
layout_runtime
,
LayoutCompiletime
&
layout_compiletime
,
const
std
::
vector
<
Idxs
>&
idxs
)
{
// 1d check
EXPECT_EQ
(
desc_1d
.
GetLength
(
I0
),
ck
::
wrapper
::
size
(
layout_runtime
));
// Check layout compiletime and runtime result consistency
EXPECT_EQ
(
ck
::
wrapper
::
size
(
layout_runtime
),
ck
::
wrapper
::
size
(
layout_compiletime
));
for
(
ck
::
index_t
i
=
0
;
i
<
desc_1d
.
GetLength
(
I0
);
i
++
)
{
const
ck
::
index_t
layout_runtime_offset_1d
=
layout_runtime
(
ck
::
make_tuple
(
i
));
const
ck
::
index_t
layout_compiletime_offset_1d
=
layout_compiletime
(
ck
::
make_tuple
(
i
));
const
ck
::
index_t
desc_offset_1d
=
desc_1d
.
CalculateOffset
(
ck
::
make_tuple
(
i
));
EXPECT_EQ
(
layout_runtime_offset_1d
,
desc_offset_1d
);
EXPECT_EQ
(
layout_compiletime_offset_1d
,
layout_runtime_offset_1d
);
}
// size(layout)-d check, don't check if access is hierarchical
if
constexpr
(
!
IsNestedTuple
(
Idxs
{}))
{
ck
::
static_for
<
0
,
Idxs
::
Size
(),
1
>
{}([
&
](
auto
d
)
{
EXPECT_EQ
(
desc
.
GetLength
(
ck
::
Number
<
d
>
{}),
ck
::
wrapper
::
size
<
d
>
(
layout_runtime
));
EXPECT_EQ
(
ck
::
wrapper
::
size
<
d
>
(
layout_runtime
),
ck
::
wrapper
::
size
<
d
>
(
layout_compiletime
));
});
}
for
(
const
auto
idx
:
idxs
)
{
const
ck
::
index_t
layout_runtime_offset
=
layout_runtime
(
idx
);
const
ck
::
index_t
layout_compiletime_offset
=
layout_compiletime
(
idx
);
const
ck
::
index_t
desc_offset
=
desc
.
CalculateOffset
(
UnrollNestedTuple
(
idx
));
// Unroll if nested
EXPECT_EQ
(
layout_runtime_offset
,
desc_offset
);
EXPECT_EQ
(
layout_runtime_offset
,
layout_compiletime_offset
);
}
}
};
TEST_F
(
TestWrapperLayout
,
2
d
)
{
// dims:(4, 3) strides:(1, 4)
constexpr
ck
::
index_t
d1
=
4
;
constexpr
ck
::
index_t
d0
=
3
;
constexpr
ck
::
index_t
s1
=
1
;
constexpr
ck
::
index_t
s0
=
4
;
const
auto
desc
=
ck
::
make_naive_tensor_descriptor
(
ck
::
make_tuple
(
ck
::
Number
<
d1
>
{},
ck
::
Number
<
d0
>
{}),
ck
::
make_tuple
(
ck
::
Number
<
s1
>
{},
ck
::
Number
<
s0
>
{}));
// Reverse due to column major
const
auto
desc_1d
=
transform_tensor_descriptor
(
desc
,
ck
::
make_tuple
(
ck
::
make_merge_transform
(
ck
::
make_tuple
(
d0
,
d1
))),
ck
::
make_tuple
(
ck
::
Sequence
<
1
,
0
>
{}),
ck
::
make_tuple
(
ck
::
Sequence
<
0
>
{}));
const
auto
layout_runtime
=
ck
::
wrapper
::
make_layout
(
ck
::
make_tuple
(
d1
,
d0
));
const
auto
layout_compiletime
=
ck
::
wrapper
::
make_layout
(
ck
::
make_tuple
(
ck
::
Number
<
d1
>
{},
ck
::
Number
<
d0
>
{}));
std
::
vector
<
ck
::
Tuple
<
ck
::
index_t
,
ck
::
index_t
>>
idxs
;
for
(
ck
::
index_t
h
=
0
;
h
<
d1
;
h
++
)
{
for
(
ck
::
index_t
w
=
0
;
w
<
d0
;
w
++
)
{
idxs
.
emplace_back
(
h
,
w
);
}
}
this
->
Run
(
desc
,
desc_1d
,
layout_runtime
,
layout_compiletime
,
idxs
);
}
TEST_F
(
TestWrapperLayout
,
3
d_nested
)
{
// dims:((2, 3), 4, 3) strides:((2, 4), 12, 48)
constexpr
ck
::
index_t
d3
=
2
;
constexpr
ck
::
index_t
d2
=
3
;
constexpr
ck
::
index_t
d1
=
4
;
constexpr
ck
::
index_t
d0
=
3
;
constexpr
ck
::
index_t
s3
=
2
;
constexpr
ck
::
index_t
s2
=
4
;
constexpr
ck
::
index_t
s1
=
12
;
constexpr
ck
::
index_t
s0
=
48
;
const
auto
desc
=
ck
::
make_naive_tensor_descriptor
(
ck
::
make_tuple
(
ck
::
Number
<
d3
>
{},
ck
::
Number
<
d2
>
{},
ck
::
Number
<
d1
>
{},
ck
::
Number
<
d0
>
{}),
ck
::
make_tuple
(
ck
::
Number
<
s3
>
{},
ck
::
Number
<
s2
>
{},
ck
::
Number
<
s1
>
{},
ck
::
Number
<
s0
>
{}));
// Reverse due to column major
const
auto
desc_1d
=
transform_tensor_descriptor
(
desc
,
ck
::
make_tuple
(
ck
::
make_merge_transform
(
ck
::
make_tuple
(
d0
,
d1
,
d2
,
d3
))),
ck
::
make_tuple
(
ck
::
Sequence
<
3
,
2
,
1
,
0
>
{}),
ck
::
make_tuple
(
ck
::
Sequence
<
0
>
{}));
const
auto
desc_3d
=
transform_tensor_descriptor
(
desc
,
ck
::
make_tuple
(
ck
::
make_merge_transform
(
ck
::
make_tuple
(
d2
,
d3
)),
ck
::
make_pass_through_transform
(
d1
),
ck
::
make_pass_through_transform
(
d2
)),
ck
::
make_tuple
(
ck
::
Sequence
<
1
,
0
>
{},
ck
::
Sequence
<
2
>
{},
ck
::
Sequence
<
3
>
{}),
ck
::
make_tuple
(
ck
::
Sequence
<
0
>
{},
ck
::
Sequence
<
1
>
{},
ck
::
Sequence
<
2
>
{}));
const
auto
layout_runtime
=
ck
::
wrapper
::
make_layout
(
ck
::
make_tuple
(
ck
::
make_tuple
(
d3
,
d2
),
d1
,
d0
),
ck
::
make_tuple
(
ck
::
make_tuple
(
s3
,
s2
),
s1
,
s0
));
const
auto
layout_compiletime
=
ck
::
wrapper
::
make_layout
(
ck
::
make_tuple
(
ck
::
make_tuple
(
ck
::
Number
<
d3
>
{},
ck
::
Number
<
d2
>
{}),
ck
::
Number
<
d1
>
{},
ck
::
Number
<
d0
>
{}),
ck
::
make_tuple
(
ck
::
make_tuple
(
ck
::
Number
<
s3
>
{},
ck
::
Number
<
s2
>
{}),
ck
::
Number
<
s1
>
{},
ck
::
Number
<
s0
>
{}));
std
::
vector
<
ck
::
Tuple
<
ck
::
index_t
,
ck
::
index_t
,
ck
::
index_t
>>
idxs_3d
;
for
(
ck
::
index_t
d
=
0
;
d
<
d2
*
d3
;
d
++
)
{
for
(
ck
::
index_t
h
=
0
;
h
<
d1
;
h
++
)
{
for
(
ck
::
index_t
w
=
0
;
w
<
d0
;
w
++
)
{
idxs_3d
.
emplace_back
(
d
,
h
,
w
);
}
}
}
this
->
Run
(
desc_3d
,
desc_1d
,
layout_runtime
,
layout_compiletime
,
idxs_3d
);
// Check also 4d iteration
std
::
vector
<
ck
::
Tuple
<
ck
::
Tuple
<
ck
::
index_t
,
ck
::
index_t
>
,
ck
::
index_t
,
ck
::
index_t
>>
idxs_4d
;
for
(
ck
::
index_t
e
=
0
;
e
<
d3
;
e
++
)
{
for
(
ck
::
index_t
d
=
0
;
d
<
d2
;
d
++
)
{
for
(
ck
::
index_t
h
=
0
;
h
<
d1
;
h
++
)
{
for
(
ck
::
index_t
w
=
0
;
w
<
d0
;
w
++
)
{
idxs_4d
.
emplace_back
(
ck
::
make_tuple
(
e
,
d
),
h
,
w
);
}
}
}
}
this
->
Run
(
desc
,
desc_1d
,
layout_runtime
,
layout_compiletime
,
idxs_4d
);
}
TEST_F
(
TestWrapperLayout
,
2
d_nested
)
{
// dims:((2, 3), (4, 3)) strides:((2, 4), (48, 12))
constexpr
ck
::
index_t
d3
=
2
;
constexpr
ck
::
index_t
d2
=
3
;
constexpr
ck
::
index_t
d1
=
4
;
constexpr
ck
::
index_t
d0
=
3
;
constexpr
ck
::
index_t
s3
=
2
;
constexpr
ck
::
index_t
s2
=
4
;
constexpr
ck
::
index_t
s1
=
48
;
constexpr
ck
::
index_t
s0
=
12
;
const
auto
desc
=
ck
::
make_naive_tensor_descriptor
(
ck
::
make_tuple
(
ck
::
Number
<
d3
>
{},
ck
::
Number
<
d2
>
{},
ck
::
Number
<
d1
>
{},
ck
::
Number
<
d0
>
{}),
ck
::
make_tuple
(
ck
::
Number
<
s3
>
{},
ck
::
Number
<
s2
>
{},
ck
::
Number
<
s1
>
{},
ck
::
Number
<
s0
>
{}));
// Reverse due to column major
const
auto
desc_1d
=
transform_tensor_descriptor
(
desc
,
ck
::
make_tuple
(
ck
::
make_merge_transform
(
ck
::
make_tuple
(
d0
,
d1
,
d2
,
d3
))),
ck
::
make_tuple
(
ck
::
Sequence
<
3
,
2
,
1
,
0
>
{}),
ck
::
make_tuple
(
ck
::
Sequence
<
0
>
{}));
const
auto
desc_2d
=
transform_tensor_descriptor
(
desc
,
ck
::
make_tuple
(
ck
::
make_merge_transform
(
ck
::
make_tuple
(
d2
,
d3
)),
ck
::
make_merge_transform
(
ck
::
make_tuple
(
d0
,
d1
))),
ck
::
make_tuple
(
ck
::
Sequence
<
1
,
0
>
{},
ck
::
Sequence
<
3
,
2
>
{}),
ck
::
make_tuple
(
ck
::
Sequence
<
0
>
{},
ck
::
Sequence
<
1
>
{}));
const
auto
layout_runtime
=
ck
::
wrapper
::
make_layout
(
ck
::
make_tuple
(
ck
::
make_tuple
(
d3
,
d2
),
ck
::
make_tuple
(
d1
,
d0
)),
ck
::
make_tuple
(
ck
::
make_tuple
(
s3
,
s2
),
ck
::
make_tuple
(
s1
,
s0
)));
const
auto
layout_compiletime
=
ck
::
wrapper
::
make_layout
(
ck
::
make_tuple
(
ck
::
make_tuple
(
ck
::
Number
<
d3
>
{},
ck
::
Number
<
d2
>
{}),
ck
::
make_tuple
(
ck
::
Number
<
d1
>
{},
ck
::
Number
<
d0
>
{})),
ck
::
make_tuple
(
ck
::
make_tuple
(
ck
::
Number
<
s3
>
{},
ck
::
Number
<
s2
>
{}),
ck
::
make_tuple
(
ck
::
Number
<
s1
>
{},
ck
::
Number
<
s0
>
{})));
std
::
vector
<
ck
::
Tuple
<
ck
::
index_t
,
ck
::
index_t
>>
idxs_2d
;
for
(
ck
::
index_t
h
=
0
;
h
<
d2
*
d3
;
h
++
)
{
for
(
ck
::
index_t
w
=
0
;
w
<
d0
*
d1
;
w
++
)
{
idxs_2d
.
emplace_back
(
h
,
w
);
}
}
this
->
Run
(
desc_2d
,
desc_1d
,
layout_runtime
,
layout_compiletime
,
idxs_2d
);
// Check also 4d iteration
std
::
vector
<
ck
::
Tuple
<
ck
::
Tuple
<
ck
::
index_t
,
ck
::
index_t
>
,
ck
::
Tuple
<
ck
::
index_t
,
ck
::
index_t
>>>
idxs_4d
;
    for(ck::index_t e = 0; e < d3; e++)
    {
        for(ck::index_t d = 0; d < d2; d++)
        {
            for(ck::index_t h = 0; h < d1; h++)
            {
                for(ck::index_t w = 0; w < d0; w++)
                {
                    idxs_4d.emplace_back(ck::make_tuple(e, d), ck::make_tuple(h, w));
                }
            }
        }
    }
    this->Run(desc, desc_1d, layout_runtime, layout_compiletime, idxs_4d);
}
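
// A sketch of the index packing the loops above assume (inferred from the
// loop nesting order, not from separate documentation): the 2d index (h, w)
// and the fully nested 4d index ((e, d), (h', w')) enumerate elements in the
// same order when
//
//     h = e * d2 + d;   // e in [0, d3), d in [0, d2)
//     w = h' * d0 + w'; // h' in [0, d1), w' in [0, d0)
//
// i.e. the rightmost component of each nested mode varies fastest, so a whole
// mode can be collapsed into one flat index over the product of its extents.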
TEST_F(TestWrapperLayout, 3d_double_nested)
{
    // dims:(((2, 2), 3), (4, 3)) strides:(((2, 4), 8), (96, 24))
    constexpr ck::index_t d4 = 2;
    constexpr ck::index_t d3 = 2;
    constexpr ck::index_t d2 = 3;
    constexpr ck::index_t d1 = 4;
    constexpr ck::index_t d0 = 3;
    constexpr ck::index_t s4 = 2;
    constexpr ck::index_t s3 = 4;
    constexpr ck::index_t s2 = 8;
    constexpr ck::index_t s1 = 96;
    constexpr ck::index_t s0 = 24;
    const auto desc = ck::make_naive_tensor_descriptor(
        ck::make_tuple(
            ck::Number<d4>{}, ck::Number<d3>{}, ck::Number<d2>{}, ck::Number<d1>{}, ck::Number<d0>{}),
        ck::make_tuple(
            ck::Number<s4>{}, ck::Number<s3>{}, ck::Number<s2>{}, ck::Number<s1>{}, ck::Number<s0>{}));
    // Reverse due to column major
    const auto desc_1d = transform_tensor_descriptor(
        desc,
        ck::make_tuple(ck::make_merge_transform(ck::make_tuple(d0, d1, d2, d3, d4))),
        ck::make_tuple(ck::Sequence<4, 3, 2, 1, 0>{}),
        ck::make_tuple(ck::Sequence<0>{}));
    const auto desc_3d = transform_tensor_descriptor(
        desc,
        ck::make_tuple(ck::make_merge_transform(ck::make_tuple(d3, d4)),
                       ck::make_pass_through_transform(d2),
                       ck::make_merge_transform(ck::make_tuple(d0, d1))),
        ck::make_tuple(ck::Sequence<1, 0>{}, ck::Sequence<2>{}, ck::Sequence<4, 3>{}),
        ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}, ck::Sequence<2>{}));
    const auto desc_2d = transform_tensor_descriptor(
        desc_3d,
        ck::make_tuple(ck::make_merge_transform(ck::make_tuple(d2, d3 * d4)),
                       ck::make_pass_through_transform(d1 * d0)),
        ck::make_tuple(ck::Sequence<1, 0>{}, ck::Sequence<2>{}),
        ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}));
    const auto layout_runtime = ck::wrapper::make_layout(
        ck::make_tuple(ck::make_tuple(ck::make_tuple(d4, d3), d2), ck::make_tuple(d1, d0)),
        ck::make_tuple(ck::make_tuple(ck::make_tuple(s4, s3), s2), ck::make_tuple(s1, s0)));
    const auto layout_compiletime = ck::wrapper::make_layout(
        ck::make_tuple(
            ck::make_tuple(ck::make_tuple(ck::Number<d4>{}, ck::Number<d3>{}), ck::Number<d2>{}),
            ck::make_tuple(ck::Number<d1>{}, ck::Number<d0>{})),
        ck::make_tuple(
            ck::make_tuple(ck::make_tuple(ck::Number<s4>{}, ck::Number<s3>{}), ck::Number<s2>{}),
            ck::make_tuple(ck::Number<s1>{}, ck::Number<s0>{})));
    std::vector<ck::Tuple<ck::index_t, ck::index_t>> idxs_2d;
    for(ck::index_t h = 0; h < d2 * d3 * d4; h++)
    {
        for(ck::index_t w = 0; w < d0 * d1; w++)
        {
            idxs_2d.emplace_back(h, w);
        }
    }
    this->Run(desc_2d, desc_1d, layout_runtime, layout_compiletime, idxs_2d);
    // Check also 3d iteration
    std::vector<ck::Tuple<ck::Tuple<ck::index_t, ck::index_t>, ck::index_t>> idxs_3d;
    for(ck::index_t d = 0; d < d3 * d4; d++)
    {
        for(ck::index_t h = 0; h < d2; h++)
        {
            for(ck::index_t w = 0; w < d1 * d0; w++)
            {
                idxs_3d.emplace_back(ck::make_tuple(d, h), w);
            }
        }
    }
    this->Run(desc_3d, desc_1d, layout_runtime, layout_compiletime, idxs_3d);
    // Check also 5d iteration
    std::vector<ck::Tuple<ck::Tuple<ck::Tuple<ck::index_t, ck::index_t>, ck::index_t>,
                          ck::Tuple<ck::index_t, ck::index_t>>>
        idxs_5d;
    for(ck::index_t f = 0; f < d4; f++)
    {
        for(ck::index_t e = 0; e < d3; e++)
        {
            for(ck::index_t d = 0; d < d2; d++)
            {
                for(ck::index_t h = 0; h < d1; h++)
                {
                    for(ck::index_t w = 0; w < d0; w++)
                    {
                        idxs_5d.emplace_back(ck::make_tuple(ck::make_tuple(f, e), d),
                                             ck::make_tuple(h, w));
                    }
                }
            }
        }
    }
    this->Run(desc, desc_1d, layout_runtime, layout_compiletime, idxs_5d);
}
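
// Generalizing the three Run() calls above: a hierarchical layout accepts an
// index at any level of its nesting. Each mode is given either as a tuple
// matching its sub-shape or as one flat index over the product of that mode's
// leaf extents. For shape (((d4, d3), d2), (d1, d0)) the checked variants are
//
//     ck::make_tuple(ck::make_tuple(ck::make_tuple(f, e), d), ck::make_tuple(h, w)) // 5d
//     ck::make_tuple(ck::make_tuple(fe, d), hw)                                     // 3d
//     ck::make_tuple(fed, hw)                                                       // 2d
//
// with fe in [0, d4 * d3), fed in [0, d4 * d3 * d2) and hw in [0, d1 * d0);
// the names fe, fed and hw are ad-hoc shorthands for this comment only.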
TEST(TestLayoutHelpers, SizeAndGet)
{
    // dims:(((2, 2), 3), (4, 3))
    constexpr ck::index_t d4 = 2;
    constexpr ck::index_t d3 = 2;
    constexpr ck::index_t d2 = 3;
    constexpr ck::index_t d1 = 4;
    constexpr ck::index_t d0 = 3;
    const auto layout_runtime = ck::wrapper::make_layout(
        ck::make_tuple(ck::make_tuple(ck::make_tuple(d4, d3), d2), ck::make_tuple(d1, d0)));
    const auto layout_compiletime = ck::wrapper::make_layout(
        ck::make_tuple(
            ck::make_tuple(ck::make_tuple(ck::Number<d4>{}, ck::Number<d3>{}), ck::Number<d2>{}),
            ck::make_tuple(ck::Number<d1>{}, ck::Number<d0>{})));
    // Size of layout
    EXPECT_EQ(ck::wrapper::size(layout_runtime), d4 * d3 * d2 * d1 * d0);
    EXPECT_EQ(ck::wrapper::size(layout_compiletime), d4 * d3 * d2 * d1 * d0);
    // Size of dims
    EXPECT_EQ(ck::wrapper::size<0>(layout_runtime), d4 * d3 * d2);
    EXPECT_EQ(ck::wrapper::size<0>(layout_compiletime), d4 * d3 * d2);
    EXPECT_EQ(ck::wrapper::size<1>(layout_runtime), d1 * d0);
    EXPECT_EQ(ck::wrapper::size<1>(layout_compiletime), d1 * d0);
    // Access through new layout (using get with layout object)
    EXPECT_EQ(ck::wrapper::size<0>(ck::wrapper::get<0>(layout_runtime)), d4 * d3);
    EXPECT_EQ(ck::wrapper::size<0>(ck::wrapper::get<0>(layout_compiletime)), d4 * d3);
    EXPECT_EQ(ck::wrapper::size<1>(ck::wrapper::get<0>(layout_runtime)), d2);
    EXPECT_EQ(ck::wrapper::size<1>(ck::wrapper::get<0>(layout_compiletime)), d2);
    EXPECT_EQ(ck::wrapper::size<0>(ck::wrapper::get<0>(ck::wrapper::get<0>(layout_runtime))), d4);
    EXPECT_EQ(ck::wrapper::size<0>(ck::wrapper::get<0>(ck::wrapper::get<0>(layout_compiletime))), d4);
    EXPECT_EQ(ck::wrapper::size<1>(ck::wrapper::get<0>(ck::wrapper::get<0>(layout_runtime))), d3);
    EXPECT_EQ(ck::wrapper::size<1>(ck::wrapper::get<0>(ck::wrapper::get<0>(layout_compiletime))), d3);
    EXPECT_EQ(ck::wrapper::size<1>(ck::wrapper::get<0>(layout_runtime)), d2);
    EXPECT_EQ(ck::wrapper::size<1>(ck::wrapper::get<0>(layout_compiletime)), d2);
    EXPECT_EQ(ck::wrapper::size<0>(ck::wrapper::get<1>(layout_runtime)), d1);
    EXPECT_EQ(ck::wrapper::size<0>(ck::wrapper::get<1>(layout_compiletime)), d1);
    EXPECT_EQ(ck::wrapper::size<1>(ck::wrapper::get<1>(layout_runtime)), d0);
    EXPECT_EQ(ck::wrapper::size<1>(ck::wrapper::get<1>(layout_compiletime)), d0);
}
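
// Putting concrete numbers on the assertions above for shape
// (((2, 2), 3), (4, 3)):
//
//     ck::wrapper::size(layout)                             // 144 = 2*2*3*4*3
//     ck::wrapper::size<0>(layout)                          // 12  = 2*2*3
//     ck::wrapper::size<1>(layout)                          // 12  = 4*3
//     ck::wrapper::size<1>(ck::wrapper::get<0>(layout))     // 3
//     ck::wrapper::size<0>(
//         ck::wrapper::get<0>(ck::wrapper::get<0>(layout))) // 2
//
// i.e. get<I> descends one level of the nesting (returning a sub-layout) and
// size<I> reduces whatever it selects to the product of its leaf extents.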
TEST(TestLayoutHelpers, DepthAndRank)
{
    // dims:(((2, 2), 3), (4, 3))
    constexpr ck::index_t d4 = 2;
    constexpr ck::index_t d3 = 2;
    constexpr ck::index_t d2 = 3;
    constexpr ck::index_t d1 = 4;
    constexpr ck::index_t d0 = 3;
    const auto layout_runtime = ck::wrapper::make_layout(
        ck::make_tuple(ck::make_tuple(ck::make_tuple(d4, d3), d2), ck::make_tuple(d1, d0)));
    const auto layout_compiletime = ck::wrapper::make_layout(
        ck::make_tuple(
            ck::make_tuple(ck::make_tuple(ck::Number<d4>{}, ck::Number<d3>{}), ck::Number<d2>{}),
            ck::make_tuple(ck::Number<d1>{}, ck::Number<d0>{})));
    EXPECT_EQ(ck::wrapper::depth(layout_runtime), 3);
    EXPECT_EQ(ck::wrapper::depth(layout_compiletime), 3);
    EXPECT_EQ(ck::wrapper::depth(ck::make_tuple(ck::make_tuple(d4, d3), d2)), 2);
    // Check for integer
    EXPECT_EQ(ck::wrapper::depth(d0), 0);
    EXPECT_EQ(ck::wrapper::rank(layout_runtime), 2);
    EXPECT_EQ(ck::wrapper::rank(layout_compiletime), 2);
    EXPECT_EQ(ck::wrapper::rank(ck::make_tuple(ck::make_tuple(d4, d3), d2)), 2);
    // Check for integer
    EXPECT_EQ(ck::wrapper::rank(d0), 1);
}
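
// Worked values for the example shape (((2, 2), 3), (4, 3)), as asserted
// above: depth counts levels of tuple nesting (3 for this layout, 2 for the
// tuple ((d4, d3), d2), 0 for a bare integer), while rank counts top-level
// modes (2 for both the layout and the tuple, and 1 for a bare integer, which
// is treated as a single mode).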
TEST(TestLayoutHelpers, ShapeAndStrides)
{
    // dims:(((2, 2), 3), (4, 3))
    constexpr ck::index_t d4 = 2;
    constexpr ck::index_t d3 = 2;
    constexpr ck::index_t d2 = 3;
    constexpr ck::index_t d1 = 4;
    constexpr ck::index_t d0 = 3;
    constexpr ck::index_t s4 = 2;
    constexpr ck::index_t s3 = 4;
    constexpr ck::index_t s2 = 8;
    constexpr ck::index_t s1 = 96;
    constexpr ck::index_t s0 = 24;
    const auto shape_compiletime = ck::make_tuple(
        ck::make_tuple(ck::make_tuple(ck::Number<d4>{}, ck::Number<d3>{}), ck::Number<d2>{}),
        ck::make_tuple(ck::Number<d1>{}, ck::Number<d0>{}));
    const auto strides_compiletime = ck::make_tuple(
        ck::make_tuple(ck::make_tuple(ck::Number<s4>{}, ck::Number<s3>{}), ck::Number<s2>{}),
        ck::make_tuple(ck::Number<s1>{}, ck::Number<s0>{}));
    const auto shape_runtime =
        ck::make_tuple(ck::make_tuple(ck::make_tuple(d4, d3), d2), ck::make_tuple(d1, d0));
    const auto strides_runtime =
        ck::make_tuple(ck::make_tuple(ck::make_tuple(s4, s3), s2), ck::make_tuple(s1, s0));
    const auto layout_runtime     = ck::wrapper::make_layout(shape_runtime, strides_runtime);
    const auto layout_compiletime = ck::wrapper::make_layout(shape_compiletime, strides_compiletime);
    constexpr bool check_compiletime_shape =
        std::is_same_v<decltype(shape_compiletime),
                       std::remove_reference_t<decltype(shape(layout_compiletime))>>;
    constexpr bool check_compiletime_strides =
        std::is_same_v<decltype(strides_compiletime),
                       std::remove_reference_t<decltype(stride(layout_compiletime))>>;
    constexpr bool check_runtime_shape =
        std::is_same_v<decltype(shape_runtime),
                       std::remove_reference_t<decltype(shape(layout_runtime))>>;
    constexpr bool check_runtime_strides =
        std::is_same_v<decltype(strides_runtime),
                       std::remove_reference_t<decltype(stride(layout_runtime))>>;
    EXPECT_TRUE(check_compiletime_shape);
    EXPECT_TRUE(check_compiletime_strides);
    EXPECT_TRUE(check_runtime_shape);
    EXPECT_TRUE(check_runtime_strides);
}
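
// The four checks above pin down a type-level property rather than runtime
// values: shape() and stride() return (references to) tuples of exactly the
// types that were passed to make_layout, so compile-time ck::Number<> extents
// survive the round trip instead of decaying to runtime ck::index_t. The
// std::remove_reference_t in each check only strips the reference from the
// accessor's return type before comparing it with the stored tuple type.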
TEST(TestLayoutHelpers, Hierarchical)
{
    // dims:(((2, 2), 3), (4, 3))
    constexpr ck::index_t d4 = 2;
    constexpr ck::index_t d3 = 2;
    constexpr ck::index_t d2 = 3;
    constexpr ck::index_t d1 = 4;
    constexpr ck::index_t d0 = 3;
    const auto runtime_shape =
        ck::make_tuple(ck::make_tuple(ck::make_tuple(d4, d3), d2), ck::make_tuple(d1, d0));
    const auto layout_runtime = ck::wrapper::make_layout(runtime_shape);
    const auto layout_compiletime = ck::wrapper::make_layout(
        ck::make_tuple(
            ck::make_tuple(ck::make_tuple(ck::Number<d4>{}, ck::Number<d3>{}), ck::Number<d2>{}),
            ck::make_tuple(ck::Number<d1>{}, ck::Number<d0>{})));
    EXPECT_EQ((ck::wrapper::rank<0, 0>(runtime_shape)), 2);
    EXPECT_EQ((ck::wrapper::rank<0, 0>(layout_runtime)), 2);
    EXPECT_EQ((ck::wrapper::rank<0, 0>(layout_compiletime)), 2);
    EXPECT_EQ((ck::wrapper::depth<0, 0>(runtime_shape)), 1);
    EXPECT_EQ((ck::wrapper::depth<0, 0>(layout_runtime)), 1);
    EXPECT_EQ((ck::wrapper::depth<0, 0>(layout_compiletime)), 1);
    EXPECT_EQ((ck::wrapper::size<0, 0>(runtime_shape)), d4 * d3);
    EXPECT_EQ((ck::wrapper::size<0, 0>(layout_runtime)), d4 * d3);
    EXPECT_EQ((ck::wrapper::size<0, 0>(layout_compiletime)), d4 * d3);
    EXPECT_EQ((ck::wrapper::get<0, 0, 0>(runtime_shape)), d4);
}
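
// Reading of the multi-index helpers used above (inferred from the asserted
// values): each extra template argument descends one nesting level, e.g.
//
//     ck::wrapper::get<0, 0, 0>(runtime_shape) // get<0> applied three times -> d4 = 2
//     ck::wrapper::size<0, 0>(runtime_shape)   // size of the (d4, d3) mode  -> 4
//     ck::wrapper::rank<0, 0>(runtime_shape)   // rank of (d4, d3)           -> 2
//     ck::wrapper::depth<0, 0>(runtime_shape)  // depth of (d4, d3)          -> 1
//
// and the same values hold when the argument is layout_runtime or
// layout_compiletime instead of the raw shape tuple.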