Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
017fb2eb
Commit
017fb2eb
authored
Dec 14, 2023
by
muozturk
Browse files
cmake list
parents
7abb7439
3a3b98ef
Changes
119
Hide whitespace changes
Inline
Side-by-side
Showing
19 changed files
with
617 additions
and
65 deletions
+617
-65
library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f8_f16_mk_nk_mn_instance.cpp
...evice_gemm_xdl_c_shuffle_f16_f8_f16_mk_nk_mn_instance.cpp
+1
-0
library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_default_instance.cpp
...m_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_default_instance.cpp
+26
-0
library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_padded_instance.cpp
...mm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_padded_instance.cpp
+26
-0
library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_lds_direct_load_f16_f16_f16_mk_nk_mn_instance.cpp
...shuffle_lds_direct_load_f16_f16_f16_mk_nk_mn_instance.cpp
+15
-1
library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_km_kn_mn_instance.cpp
...shuffle_lds_direct_load_f32_f32_f32_km_kn_mn_instance.cpp
+2
-1
library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_km_nk_mn_instance.cpp
...shuffle_lds_direct_load_f32_f32_f32_km_nk_mn_instance.cpp
+2
-1
library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_mk_kn_mn_instance.cpp
...shuffle_lds_direct_load_f32_f32_f32_mk_kn_mn_instance.cpp
+2
-1
library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_mk_nk_mn_instance.cpp
...shuffle_lds_direct_load_f32_f32_f32_mk_nk_mn_instance.cpp
+2
-3
library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_ab/xdl/device_grouped_conv3d_fwd_xdl_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
...wd_xdl_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
+37
-36
library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_scaleadd_relu/xdl/device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
...eadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
+4
-4
library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_scaleadd_relu/xdl/device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp
...leadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp
+4
-4
library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_scaleadd_relu/xdl/device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp
...leadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp
+4
-4
library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_scaleadd_relu/xdl/device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_int8_instance.cpp
...eadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_int8_instance.cpp
+4
-4
library/src/tensor_operation_instance/gpu/quantization/conv2d_fwd/conv2d_quantization_common.hpp
...pu/quantization/conv2d_fwd/conv2d_quantization_common.hpp
+3
-3
profiler/include/profiler/profile_gemm_impl.hpp
profiler/include/profiler/profile_gemm_impl.hpp
+1
-1
test/CMakeLists.txt
test/CMakeLists.txt
+1
-0
test/transpose/test_transpose_ut_cases.inc
test/transpose/test_transpose_ut_cases.inc
+0
-2
test/wrapper/CMakeLists.txt
test/wrapper/CMakeLists.txt
+2
-0
test/wrapper/test_layout.cpp
test/wrapper/test_layout.cpp
+481
-0
No files found.
library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f8_f16_mk_nk_mn_instance.cpp
View file @
017fb2eb
...
...
@@ -16,6 +16,7 @@ namespace tensor_operation {
namespace
device
{
namespace
instance
{
using
F8
=
ck
::
f8_t
;
using
F16
=
ck
::
half_t
;
using
F32
=
float
;
...
...
library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_default_instance.cpp
0 → 100644
View file @
017fb2eb
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/gpu/device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_instance.hpp"
#ifdef CK_ENABLE_FP8
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
static
constexpr
auto
GemmDefault
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
;
void
add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_default_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemm
<
Row
,
Row
,
Row
,
F8
,
F8
,
F8
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_instances
<
GemmDefault
>
{});
}
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
#endif
library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_padded_instance.cpp
0 → 100644
View file @
017fb2eb
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/gpu/device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_instance.hpp"
#ifdef CK_ENABLE_FP8
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
static
constexpr
auto
MNKPadding
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNKPadding
;
void
add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_padded_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemm
<
Row
,
Row
,
Row
,
F8
,
F8
,
F8
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_instances
<
MNKPadding
>
{});
}
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
#endif
library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_lds_direct_load_f16_f16_f16_mk_nk_mn_instance.cpp
View file @
017fb2eb
...
...
@@ -35,7 +35,21 @@ using device_gemm_xdl_c_shuffle_lds_direct_load_f16_f16_f16_mk_nk_mn_instances =
// ##################################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| | | PerVector| | Lengths_K0_N_K1| | | PerVector| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
// ##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemm_Xdl_CShuffle_LdsDirectLoad
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
64
,
64
,
32
,
8
,
8
,
32
,
32
,
1
,
1
,
S
<
4
,
16
,
4
>
,
S
<
1
,
0
,
2
>
,
2
,
2
,
1
,
S
<
4
,
16
,
4
>
,
S
<
1
,
0
,
2
>
,
2
,
2
,
1
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
4
>
,
DeviceGemm_Xdl_CShuffle_LdsDirectLoad
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNPadding
,
1
,
256
,
64
,
64
,
32
,
8
,
8
,
32
,
32
,
1
,
1
,
S
<
4
,
16
,
4
>
,
S
<
1
,
0
,
2
>
,
2
,
2
,
1
,
S
<
4
,
16
,
4
>
,
S
<
1
,
0
,
2
>
,
2
,
2
,
1
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
4
>
DeviceGemm_Xdl_CShuffle_LdsDirectLoad
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
128
,
128
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
16
,
4
>
,
S
<
1
,
0
,
2
>
,
2
,
2
,
1
,
S
<
4
,
16
,
4
>
,
S
<
1
,
0
,
2
>
,
2
,
2
,
1
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
4
>
,
DeviceGemm_Xdl_CShuffle_LdsDirectLoad
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
128
,
64
,
32
,
8
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
16
,
4
>
,
S
<
1
,
0
,
2
>
,
2
,
2
,
1
,
S
<
4
,
16
,
4
>
,
S
<
1
,
0
,
2
>
,
2
,
2
,
1
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
4
>
,
DeviceGemm_Xdl_CShuffle_LdsDirectLoad
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
64
,
32
,
32
,
64
,
8
,
8
,
32
,
32
,
1
,
1
,
S
<
1
,
16
,
4
>
,
S
<
1
,
0
,
2
>
,
2
,
2
,
1
,
S
<
1
,
16
,
4
>
,
S
<
1
,
0
,
2
>
,
2
,
2
,
1
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
4
>
,
DeviceGemm_Xdl_CShuffle_LdsDirectLoad
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
64
,
64
,
32
,
32
,
8
,
8
,
32
,
32
,
2
,
1
,
S
<
1
,
16
,
4
>
,
S
<
1
,
0
,
2
>
,
2
,
2
,
1
,
S
<
1
,
16
,
4
>
,
S
<
1
,
0
,
2
>
,
2
,
2
,
1
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
4
>
,
DeviceGemm_Xdl_CShuffle_LdsDirectLoad
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
128
,
64
,
32
,
32
,
8
,
8
,
32
,
32
,
1
,
1
,
S
<
2
,
16
,
4
>
,
S
<
1
,
0
,
2
>
,
2
,
2
,
1
,
S
<
2
,
16
,
4
>
,
S
<
1
,
0
,
2
>
,
2
,
2
,
1
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
4
>
,
DeviceGemm_Xdl_CShuffle_LdsDirectLoad
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
2
,
256
,
64
,
64
,
32
,
8
,
8
,
32
,
32
,
1
,
1
,
S
<
4
,
16
,
4
>
,
S
<
1
,
0
,
2
>
,
2
,
2
,
1
,
S
<
4
,
16
,
4
>
,
S
<
1
,
0
,
2
>
,
2
,
2
,
1
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
4
>
,
DeviceGemm_Xdl_CShuffle_LdsDirectLoad
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
2
,
256
,
128
,
64
,
32
,
8
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
16
,
4
>
,
S
<
1
,
0
,
2
>
,
2
,
2
,
0
,
S
<
4
,
16
,
4
>
,
S
<
1
,
0
,
2
>
,
2
,
2
,
0
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
4
>
,
DeviceGemm_Xdl_CShuffle_LdsDirectLoad
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
2
,
256
,
128
,
128
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
16
,
4
>
,
S
<
1
,
0
,
2
>
,
2
,
2
,
0
,
S
<
4
,
16
,
4
>
,
S
<
1
,
0
,
2
>
,
2
,
2
,
0
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
4
>
,
DeviceGemm_Xdl_CShuffle_LdsDirectLoad
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
2
,
256
,
32
,
128
,
32
,
8
,
8
,
32
,
32
,
1
,
1
,
S
<
4
,
16
,
4
>
,
S
<
1
,
0
,
2
>
,
2
,
2
,
1
,
S
<
4
,
16
,
4
>
,
S
<
1
,
0
,
2
>
,
2
,
2
,
1
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
4
>
,
DeviceGemm_Xdl_CShuffle_LdsDirectLoad
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
2
,
64
,
32
,
32
,
64
,
8
,
8
,
32
,
32
,
1
,
1
,
S
<
1
,
16
,
4
>
,
S
<
1
,
0
,
2
>
,
2
,
2
,
1
,
S
<
1
,
16
,
4
>
,
S
<
1
,
0
,
2
>
,
2
,
2
,
1
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
4
>
,
DeviceGemm_Xdl_CShuffle_LdsDirectLoad
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
2
,
128
,
64
,
32
,
32
,
8
,
8
,
32
,
32
,
1
,
1
,
S
<
2
,
16
,
4
>
,
S
<
1
,
0
,
2
>
,
2
,
2
,
1
,
S
<
2
,
16
,
4
>
,
S
<
1
,
0
,
2
>
,
2
,
2
,
1
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
4
>
,
DeviceGemm_Xdl_CShuffle_LdsDirectLoad
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNPadding
,
1
,
256
,
64
,
64
,
32
,
8
,
8
,
32
,
32
,
1
,
1
,
S
<
4
,
16
,
4
>
,
S
<
1
,
0
,
2
>
,
2
,
2
,
1
,
S
<
4
,
16
,
4
>
,
S
<
1
,
0
,
2
>
,
2
,
2
,
1
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
4
>
,
DeviceGemm_Xdl_CShuffle_LdsDirectLoad
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNPadding
,
2
,
256
,
64
,
64
,
32
,
8
,
8
,
32
,
32
,
1
,
1
,
S
<
4
,
16
,
4
>
,
S
<
1
,
0
,
2
>
,
2
,
2
,
1
,
S
<
4
,
16
,
4
>
,
S
<
1
,
0
,
2
>
,
2
,
2
,
1
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
4
>
// clang-format on
>
;
...
...
library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_km_kn_mn_instance.cpp
View file @
017fb2eb
...
...
@@ -32,7 +32,8 @@ using device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_km_kn_mn_instances =
// ##################################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| SrcAccessOrder| SrcVectorDim| Scalar| AddExtraM| ThreadCluster| SrcAccessOrder| SrcVectorDim| Scalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
// ##################################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| | | PerVector| | Lengths_K0_N_K1| | | PerVector| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
// ##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemm_Xdl_CShuffle_LdsDirectLoad
<
Col
,
Row
,
Row
,
F32
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
64
,
64
,
32
,
8
,
8
,
32
,
32
,
1
,
1
,
S
<
4
,
8
,
8
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
1
,
S
<
4
,
8
,
8
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
1
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
4
>
DeviceGemm_Xdl_CShuffle_LdsDirectLoad
<
Col
,
Row
,
Row
,
F32
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
64
,
64
,
32
,
8
,
8
,
32
,
32
,
1
,
1
,
S
<
4
,
8
,
8
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
1
,
S
<
4
,
8
,
8
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
1
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
4
>
,
DeviceGemm_Xdl_CShuffle_LdsDirectLoad
<
Col
,
Row
,
Row
,
F32
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
2
,
256
,
64
,
64
,
32
,
8
,
8
,
32
,
32
,
1
,
1
,
S
<
4
,
8
,
8
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
1
,
S
<
4
,
8
,
8
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
1
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
4
>
// clang-format on
>
;
...
...
library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_km_nk_mn_instance.cpp
View file @
017fb2eb
...
...
@@ -32,7 +32,8 @@ using device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_km_nk_mn_instances =
// ##################################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| SrcAccessOrder| SrcVectorDim| Scalar| AddExtraM| ThreadCluster| SrcAccessOrder| SrcVectorDim| Scalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
// ##################################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| | | PerVector| | Lengths_K0_N_K1| | | PerVector| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
// ##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemm_Xdl_CShuffle_LdsDirectLoad
<
Col
,
Col
,
Row
,
F32
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
64
,
64
,
32
,
8
,
8
,
32
,
32
,
1
,
1
,
S
<
4
,
8
,
8
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
1
,
S
<
4
,
8
,
8
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
1
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
4
>
DeviceGemm_Xdl_CShuffle_LdsDirectLoad
<
Col
,
Col
,
Row
,
F32
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
64
,
64
,
32
,
8
,
8
,
32
,
32
,
1
,
1
,
S
<
4
,
8
,
8
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
1
,
S
<
4
,
8
,
8
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
1
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
4
>
,
DeviceGemm_Xdl_CShuffle_LdsDirectLoad
<
Col
,
Col
,
Row
,
F32
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
2
,
256
,
64
,
64
,
32
,
8
,
8
,
32
,
32
,
1
,
1
,
S
<
4
,
8
,
8
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
1
,
S
<
4
,
8
,
8
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
1
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
4
>
// clang-format on
>
;
...
...
library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_mk_kn_mn_instance.cpp
View file @
017fb2eb
...
...
@@ -31,7 +31,8 @@ using device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_mk_kn_mn_instances =
// ##################################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| SrcAccessOrder| SrcVectorDim| Scalar| AddExtraM| ThreadCluster| SrcAccessOrder| SrcVectorDim| Scalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
// ##################################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| | | PerVector| | Lengths_K0_N_K1| | | PerVector| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
// ##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemm_Xdl_CShuffle_LdsDirectLoad
<
Row
,
Row
,
Row
,
F32
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
64
,
64
,
32
,
8
,
8
,
32
,
32
,
1
,
1
,
S
<
4
,
8
,
8
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
1
,
S
<
4
,
8
,
8
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
1
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
4
>
DeviceGemm_Xdl_CShuffle_LdsDirectLoad
<
Row
,
Row
,
Row
,
F32
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
64
,
64
,
32
,
8
,
8
,
32
,
32
,
1
,
1
,
S
<
4
,
8
,
8
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
1
,
S
<
4
,
8
,
8
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
1
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
4
>
,
DeviceGemm_Xdl_CShuffle_LdsDirectLoad
<
Row
,
Row
,
Row
,
F32
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
2
,
256
,
64
,
64
,
32
,
8
,
8
,
32
,
32
,
1
,
1
,
S
<
4
,
8
,
8
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
1
,
S
<
4
,
8
,
8
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
1
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
4
>
// clang-format on
>
;
...
...
library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_mk_nk_mn_instance.cpp
View file @
017fb2eb
...
...
@@ -24,8 +24,7 @@ using S = ck::Sequence<Is...>;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
static
constexpr
auto
GemmDefault
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
;
static
constexpr
auto
GemmMNPadding
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNPadding
;
static
constexpr
auto
GemmDefault
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
;
using
device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_mk_nk_mn_instances
=
std
::
tuple
<
// clang-format off
...
...
@@ -34,7 +33,7 @@ using device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_mk_nk_mn_instances =
// ##################################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| | | PerVector| | Lengths_K0_N_K1| | | PerVector| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
// ##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemm_Xdl_CShuffle_LdsDirectLoad
<
Row
,
Col
,
Row
,
F32
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
64
,
64
,
32
,
8
,
8
,
32
,
32
,
1
,
1
,
S
<
4
,
8
,
8
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
1
,
S
<
4
,
8
,
8
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
1
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
4
>
,
DeviceGemm_Xdl_CShuffle_LdsDirectLoad
<
Row
,
Col
,
Row
,
F32
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
Gemm
MNPadding
,
1
,
256
,
64
,
64
,
32
,
8
,
8
,
32
,
32
,
1
,
1
,
S
<
4
,
8
,
8
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
1
,
S
<
4
,
8
,
8
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
1
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
4
>
DeviceGemm_Xdl_CShuffle_LdsDirectLoad
<
Row
,
Col
,
Row
,
F32
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
Gemm
Default
,
2
,
256
,
64
,
64
,
32
,
8
,
8
,
32
,
32
,
1
,
1
,
S
<
4
,
8
,
8
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
1
,
S
<
4
,
8
,
8
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
1
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
4
>
// clang-format on
>
;
...
...
library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_ab/xdl/device_grouped_conv3d_fwd_xdl_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
View file @
017fb2eb
...
...
@@ -9,42 +9,43 @@ namespace tensor_operation {
namespace
device
{
namespace
instance
{
void
add_device_grouped_conv3d_fwd_xdl_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_bf16_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvFwdMultipleABD
<
3
,
NDHWGC
,
GKZYXC
,
ck
::
Tuple
<>
,
NDHWGK
,
ck
::
Tuple
<
BF16
,
BF16
>
,
ck
::
Tuple
<
BF16
,
BF16
>
,
ck
::
Tuple
<>
,
BF16
,
ScaleAdd
,
ScaleAdd
,
PassThrough
>>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_grouped_conv_fwd_xdl_scaleadd_ab_bf16_instances
<
3
,
NDHWGC
,
GKZYXC
,
NDHWGK
,
ConvFwdDefault
>
{});
add_device_operation_instances
(
instances
,
device_grouped_conv_fwd_xdl_scaleadd_ab_bf16_instances
<
3
,
NDHWGC
,
GKZYXC
,
NDHWGK
,
ConvFwd1x1P0
>
{});
add_device_operation_instances
(
instances
,
device_grouped_conv_fwd_xdl_scaleadd_ab_bf16_instances
<
3
,
NDHWGC
,
GKZYXC
,
NDHWGK
,
ConvFwd1x1S1P0
>
{});
}
// TODO: Workaround for https://ontrack-internal.amd.com/browse/SWDEV-435347
// void add_device_grouped_conv3d_fwd_xdl_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
// std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
// NDHWGC,
// GKZYXC,
// ck::Tuple<>,
// NDHWGK,
// ck::Tuple<BF16, BF16>,
// ck::Tuple<BF16, BF16>,
// ck::Tuple<>,
// BF16,
// ScaleAdd,
// ScaleAdd,
// PassThrough>>>& instances)
// {
// add_device_operation_instances(
// instances,
// device_grouped_conv_fwd_xdl_scaleadd_ab_bf16_instances<3,
// NDHWGC,
// GKZYXC,
// NDHWGK,
// ConvFwdDefault>{});
// add_device_operation_instances(
// instances,
// device_grouped_conv_fwd_xdl_scaleadd_ab_bf16_instances<3,
// NDHWGC,
// GKZYXC,
// NDHWGK,
// ConvFwd1x1P0>{});
// add_device_operation_instances(
// instances,
// device_grouped_conv_fwd_xdl_scaleadd_ab_bf16_instances<3,
// NDHWGC,
// GKZYXC,
// NDHWGK,
// ConvFwd1x1S1P0>{});
// }
}
// namespace instance
}
// namespace device
...
...
library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_scaleadd_relu/xdl/device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
View file @
017fb2eb
...
...
@@ -13,7 +13,7 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhw
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvFwdMultipleABD
<
3
,
NDHWGC
,
GKZYXC
,
ck
::
Tuple
<
NDHWGK
,
NDHW
GK
>
,
ck
::
Tuple
<
NDHWGK
,
G
_
K
>
,
NDHWGK
,
BF16
,
BF16
,
...
...
@@ -28,7 +28,7 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhw
device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_bf16_instances
<
3
,
NDHWGC
,
GKZYXC
,
ck
::
Tuple
<
NDHWGK
,
NDHW
GK
>
,
ck
::
Tuple
<
NDHWGK
,
G
_
K
>
,
NDHWGK
,
ConvFwdDefault
>
{});
add_device_operation_instances
(
...
...
@@ -36,7 +36,7 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhw
device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_bf16_instances
<
3
,
NDHWGC
,
GKZYXC
,
ck
::
Tuple
<
NDHWGK
,
NDHW
GK
>
,
ck
::
Tuple
<
NDHWGK
,
G
_
K
>
,
NDHWGK
,
ConvFwd1x1P0
>
{});
add_device_operation_instances
(
...
...
@@ -44,7 +44,7 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhw
device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_bf16_instances
<
3
,
NDHWGC
,
GKZYXC
,
ck
::
Tuple
<
NDHWGK
,
NDHW
GK
>
,
ck
::
Tuple
<
NDHWGK
,
G
_
K
>
,
NDHWGK
,
ConvFwd1x1S1P0
>
{});
}
...
...
library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_scaleadd_relu/xdl/device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp
View file @
017fb2eb
...
...
@@ -13,7 +13,7 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhw
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvFwdMultipleABD
<
3
,
NDHWGC
,
GKZYXC
,
ck
::
Tuple
<
NDHWGK
,
NDHW
GK
>
,
ck
::
Tuple
<
NDHWGK
,
G
_
K
>
,
NDHWGK
,
F16
,
F16
,
...
...
@@ -28,7 +28,7 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhw
device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_f16_instances
<
3
,
NDHWGC
,
GKZYXC
,
ck
::
Tuple
<
NDHWGK
,
NDHW
GK
>
,
ck
::
Tuple
<
NDHWGK
,
G
_
K
>
,
NDHWGK
,
ConvFwdDefault
>
{});
add_device_operation_instances
(
...
...
@@ -36,7 +36,7 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhw
device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_f16_instances
<
3
,
NDHWGC
,
GKZYXC
,
ck
::
Tuple
<
NDHWGK
,
NDHW
GK
>
,
ck
::
Tuple
<
NDHWGK
,
G
_
K
>
,
NDHWGK
,
ConvFwd1x1P0
>
{});
add_device_operation_instances
(
...
...
@@ -44,7 +44,7 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhw
device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_f16_instances
<
3
,
NDHWGC
,
GKZYXC
,
ck
::
Tuple
<
NDHWGK
,
NDHW
GK
>
,
ck
::
Tuple
<
NDHWGK
,
G
_
K
>
,
NDHWGK
,
ConvFwd1x1S1P0
>
{});
}
...
...
library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_scaleadd_relu/xdl/device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp
View file @
017fb2eb
...
...
@@ -13,7 +13,7 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhw
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvFwdMultipleABD
<
3
,
NDHWGC
,
GKZYXC
,
ck
::
Tuple
<
NDHWGK
,
NDHW
GK
>
,
ck
::
Tuple
<
NDHWGK
,
G
_
K
>
,
NDHWGK
,
F32
,
F32
,
...
...
@@ -28,7 +28,7 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhw
device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_f32_instances
<
3
,
NDHWGC
,
GKZYXC
,
ck
::
Tuple
<
NDHWGK
,
NDHW
GK
>
,
ck
::
Tuple
<
NDHWGK
,
G
_
K
>
,
NDHWGK
,
ConvFwdDefault
>
{});
add_device_operation_instances
(
...
...
@@ -36,7 +36,7 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhw
device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_f32_instances
<
3
,
NDHWGC
,
GKZYXC
,
ck
::
Tuple
<
NDHWGK
,
NDHW
GK
>
,
ck
::
Tuple
<
NDHWGK
,
G
_
K
>
,
NDHWGK
,
ConvFwd1x1P0
>
{});
add_device_operation_instances
(
...
...
@@ -44,7 +44,7 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhw
device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_f32_instances
<
3
,
NDHWGC
,
GKZYXC
,
ck
::
Tuple
<
NDHWGK
,
NDHW
GK
>
,
ck
::
Tuple
<
NDHWGK
,
G
_
K
>
,
NDHWGK
,
ConvFwd1x1S1P0
>
{});
}
...
...
library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_scaleadd_relu/xdl/device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_int8_instance.cpp
View file @
017fb2eb
...
...
@@ -12,7 +12,7 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhw
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvFwdMultipleABD
<
3
,
NDHWGC
,
GKZYXC
,
ck
::
Tuple
<
NDHWGK
,
NDHW
GK
>
,
ck
::
Tuple
<
NDHWGK
,
G
_
K
>
,
NDHWGK
,
int8_t
,
int8_t
,
...
...
@@ -27,7 +27,7 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhw
device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_int8_instances
<
3
,
NDHWGC
,
GKZYXC
,
ck
::
Tuple
<
NDHWGK
,
NDHW
GK
>
,
ck
::
Tuple
<
NDHWGK
,
G
_
K
>
,
NDHWGK
,
ConvFwdDefault
>
{});
add_device_operation_instances
(
...
...
@@ -35,7 +35,7 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhw
device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_int8_instances
<
3
,
NDHWGC
,
GKZYXC
,
ck
::
Tuple
<
NDHWGK
,
NDHW
GK
>
,
ck
::
Tuple
<
NDHWGK
,
G
_
K
>
,
NDHWGK
,
ConvFwd1x1P0
>
{});
add_device_operation_instances
(
...
...
@@ -43,7 +43,7 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhw
device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_int8_instances
<
3
,
NDHWGC
,
GKZYXC
,
ck
::
Tuple
<
NDHWGK
,
NDHW
GK
>
,
ck
::
Tuple
<
NDHWGK
,
G
_
K
>
,
NDHWGK
,
ConvFwd1x1S1P0
>
{});
}
...
...
library/src/tensor_operation_instance/gpu/quantization/conv2d_fwd/conv2d_quantization_common.hpp
View file @
017fb2eb
...
...
@@ -22,13 +22,13 @@ using S = ck::Sequence<Is...>;
using
NHWGC
=
ck
::
tensor_layout
::
convolution
::
NHWGC
;
using
GKYXC
=
ck
::
tensor_layout
::
convolution
::
GKYXC
;
using
NHWGK
=
ck
::
tensor_layout
::
convolution
::
NHWGK
;
using
GK
=
ck
::
tensor_layout
::
convolution
::
G_K
;
using
G
_
K
=
ck
::
tensor_layout
::
convolution
::
G_K
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
Relu
=
ck
::
tensor_operation
::
element_wise
::
Relu
;
using
TanH
=
ck
::
tensor_operation
::
element_wise
::
TanH
;
using
GK_Tuple
=
ck
::
Tuple
<
GK
>
;
using
GK_GK_Tuple
=
ck
::
Tuple
<
GK
,
GK
>
;
using
GK_Tuple
=
ck
::
Tuple
<
G
_
K
>
;
using
GK_GK_Tuple
=
ck
::
Tuple
<
G
_
K
,
G
_
K
>
;
using
I32_Tuple
=
ck
::
Tuple
<
int32_t
>
;
using
F32_Tuple
=
ck
::
Tuple
<
float
>
;
using
I32_F32_Tuple
=
ck
::
Tuple
<
int32_t
,
float
>
;
...
...
profiler/include/profiler/profile_gemm_impl.hpp
View file @
017fb2eb
...
...
@@ -166,7 +166,7 @@ int profile_gemm_impl(int do_verification,
std
::
string
op_name
=
op_ptr
->
GetTypeString
();
float
avg_time
=
invoker_ptr
->
Run
(
argument_ptr
.
get
(),
StreamConfig
{
nullptr
,
time_kernel
});
invoker_ptr
->
Run
(
argument_ptr
.
get
(),
StreamConfig
{
nullptr
,
time_kernel
,
0
,
10
,
50
});
std
::
size_t
flop
=
std
::
size_t
(
2
)
*
M
*
N
*
K
;
...
...
test/CMakeLists.txt
View file @
017fb2eb
...
...
@@ -151,6 +151,7 @@ add_subdirectory(conv_tensor_rearrange)
add_subdirectory
(
transpose
)
add_subdirectory
(
complex_contraction_bilinear
)
add_subdirectory
(
complex_contraction_scale
)
add_subdirectory
(
wrapper
)
if
(
GPU_TARGETS MATCHES
"gfx11"
)
add_subdirectory
(
wmma_op
)
endif
()
test/transpose/test_transpose_ut_cases.inc
View file @
017fb2eb
...
...
@@ -14,7 +14,6 @@ TYPED_TEST(TestTranspose, Test1)
this
->
Run
();
}
TYPED_TEST
(
TestTranpose
,
Test2
)
{
std
::
vector
<
int
>
Ms
{
127
,
255
,
312
,
799
,
1573
};
...
...
@@ -27,4 +26,3 @@ TYPED_TEST(TestTranpose, Test2)
this
->
Run
();
}
test/wrapper/CMakeLists.txt
0 → 100644
View file @
017fb2eb
add_gtest_executable
(
test_layout test_layout.cpp
)
target_link_libraries
(
test_layout PRIVATE utility
)
test/wrapper/test_layout.cpp
0 → 100644
View file @
017fb2eb
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include <iostream>
#include <initializer_list>
#include <vector>
#include <gtest/gtest.h>
#include "ck/utility/common_header.hpp"
#include "ck/wrapper/layout.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/multi_index_transform_helper.hpp"
// Fixture verifying that ck::wrapper layouts (built from runtime integers and
// from compile-time ck::Number<> constants) compute the same element offsets
// as the reference ck tensor descriptors they are meant to mirror.
class TestWrapperLayout : public ::testing::Test
{
    protected:
    // Frequently used compile-time index constants.
    static constexpr auto I0 = ck::Number<0>{};
    static constexpr auto I1 = ck::Number<1>{};

    // Cross-check a runtime/compile-time layout pair against reference
    // descriptors.
    //   desc               - descriptor whose rank matches the Idxs tuples
    //   desc_1d            - the same tensor flattened to one dimension
    //   layout_runtime     - layout built from runtime integers
    //   layout_compiletime - layout built from ck::Number<> constants
    //   idxs               - set of (possibly nested) indices to probe
    template <typename Desc,
              typename Desc1d,
              typename LayoutRuntime,
              typename LayoutCompiletime,
              typename Idxs>
    void Run(Desc& desc,
             Desc1d& desc_1d,
             LayoutRuntime& layout_runtime,
             LayoutCompiletime& layout_compiletime,
             const std::vector<Idxs>& idxs)
    {
        // 1d check
        EXPECT_EQ(desc_1d.GetLength(I0), ck::wrapper::size(layout_runtime));
        // Check layout compiletime and runtime result consistency
        EXPECT_EQ(ck::wrapper::size(layout_runtime), ck::wrapper::size(layout_compiletime));
        // Every linear (1d) offset must agree between both layouts and the
        // flattened reference descriptor.
        for(ck::index_t i = 0; i < desc_1d.GetLength(I0); i++)
        {
            const ck::index_t layout_runtime_offset_1d = layout_runtime(ck::make_tuple(i));
            const ck::index_t layout_compiletime_offset_1d =
                layout_compiletime(ck::make_tuple(i));
            const ck::index_t desc_offset_1d = desc_1d.CalculateOffset(ck::make_tuple(i));
            EXPECT_EQ(layout_runtime_offset_1d, desc_offset_1d);
            EXPECT_EQ(layout_compiletime_offset_1d, layout_runtime_offset_1d);
        }
        // size(layout)-d check, don't check if access is hierarchical
        if constexpr(!IsNestedTuple(Idxs{}))
        {
            // Per-dimension lengths must match the reference descriptor.
            ck::static_for<0, Idxs::Size(), 1>{}([&](auto d) {
                EXPECT_EQ(desc.GetLength(ck::Number<d>{}), ck::wrapper::size<d>(layout_runtime));
                EXPECT_EQ(ck::wrapper::size<d>(layout_runtime),
                          ck::wrapper::size<d>(layout_compiletime));
            });
        }
        // Probe every provided (possibly nested) multi-dimensional index.
        for(const auto idx : idxs)
        {
            const ck::index_t layout_runtime_offset     = layout_runtime(idx);
            const ck::index_t layout_compiletime_offset = layout_compiletime(idx);
            // Unroll if nested
            const ck::index_t desc_offset = desc.CalculateOffset(UnrollNestedTuple(idx));
            EXPECT_EQ(layout_runtime_offset, desc_offset);
            EXPECT_EQ(layout_runtime_offset, layout_compiletime_offset);
        }
    }
};
// Plain 2d layout: dims (4, 3), column-major packed strides (1, 4).
TEST_F(TestWrapperLayout, 2d)
{
    constexpr ck::index_t d1 = 4;
    constexpr ck::index_t d0 = 3;
    constexpr ck::index_t s1 = 1;
    constexpr ck::index_t s0 = 4;

    // Reference descriptor with compile-time lengths/strides.
    const auto desc = ck::make_naive_tensor_descriptor(
        ck::make_tuple(ck::Number<d1>{}, ck::Number<d0>{}),
        ck::make_tuple(ck::Number<s1>{}, ck::Number<s0>{}));
    // Flatten to 1d; dims are merged in reverse order (column major).
    const auto desc_1d = transform_tensor_descriptor(
        desc,
        ck::make_tuple(ck::make_merge_transform(ck::make_tuple(d0, d1))),
        ck::make_tuple(ck::Sequence<1, 0>{}),
        ck::make_tuple(ck::Sequence<0>{}));

    // Layouts under test: runtime-sized and compile-time-sized.
    const auto layout_runtime = ck::wrapper::make_layout(ck::make_tuple(d1, d0));
    const auto layout_compiletime =
        ck::wrapper::make_layout(ck::make_tuple(ck::Number<d1>{}, ck::Number<d0>{}));

    // Enumerate all 2d indices.
    std::vector<ck::Tuple<ck::index_t, ck::index_t>> idxs;
    for(ck::index_t row = 0; row < d1; row++)
    {
        for(ck::index_t col = 0; col < d0; col++)
        {
            idxs.emplace_back(row, col);
        }
    }

    this->Run(desc, desc_1d, layout_runtime, layout_compiletime, idxs);
}
// Nested 3d layout: dims ((2, 3), 4, 3), strides ((2, 4), 12, 48).
TEST_F(TestWrapperLayout, 3d_nested)
{
    constexpr ck::index_t d3 = 2;
    constexpr ck::index_t d2 = 3;
    constexpr ck::index_t d1 = 4;
    constexpr ck::index_t d0 = 3;
    constexpr ck::index_t s3 = 2;
    constexpr ck::index_t s2 = 4;
    constexpr ck::index_t s1 = 12;
    constexpr ck::index_t s0 = 48;

    // Reference 4d descriptor with compile-time lengths/strides.
    const auto desc = ck::make_naive_tensor_descriptor(
        ck::make_tuple(ck::Number<d3>{}, ck::Number<d2>{}, ck::Number<d1>{}, ck::Number<d0>{}),
        ck::make_tuple(ck::Number<s3>{}, ck::Number<s2>{}, ck::Number<s1>{}, ck::Number<s0>{}));
    // Reverse due to column major
    const auto desc_1d = transform_tensor_descriptor(
        desc,
        ck::make_tuple(ck::make_merge_transform(ck::make_tuple(d0, d1, d2, d3))),
        ck::make_tuple(ck::Sequence<3, 2, 1, 0>{}),
        ck::make_tuple(ck::Sequence<0>{}));
    // 3d view: merge the two outermost dims, pass the remaining two through.
    // BUGFIX: the second pass-through previously used d2; dim 3 of desc has
    // length d0 (it only worked because d2 == d0 == 3 here).
    const auto desc_3d = transform_tensor_descriptor(
        desc,
        ck::make_tuple(ck::make_merge_transform(ck::make_tuple(d2, d3)),
                       ck::make_pass_through_transform(d1),
                       ck::make_pass_through_transform(d0)),
        ck::make_tuple(ck::Sequence<1, 0>{}, ck::Sequence<2>{}, ck::Sequence<3>{}),
        ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}, ck::Sequence<2>{}));

    // Layouts under test with a nested first dimension.
    const auto layout_runtime =
        ck::wrapper::make_layout(ck::make_tuple(ck::make_tuple(d3, d2), d1, d0),
                                 ck::make_tuple(ck::make_tuple(s3, s2), s1, s0));
    const auto layout_compiletime = ck::wrapper::make_layout(
        ck::make_tuple(ck::make_tuple(ck::Number<d3>{}, ck::Number<d2>{}),
                       ck::Number<d1>{},
                       ck::Number<d0>{}),
        ck::make_tuple(ck::make_tuple(ck::Number<s3>{}, ck::Number<s2>{}),
                       ck::Number<s1>{},
                       ck::Number<s0>{}));

    // Probe with flat 3d indices (the merged dim iterated linearly).
    std::vector<ck::Tuple<ck::index_t, ck::index_t, ck::index_t>> idxs_3d;
    for(ck::index_t d = 0; d < d2 * d3; d++)
    {
        for(ck::index_t h = 0; h < d1; h++)
        {
            for(ck::index_t w = 0; w < d0; w++)
            {
                idxs_3d.emplace_back(d, h, w);
            }
        }
    }
    this->Run(desc_3d, desc_1d, layout_runtime, layout_compiletime, idxs_3d);

    // Check also 4d iteration
    std::vector<ck::Tuple<ck::Tuple<ck::index_t, ck::index_t>, ck::index_t, ck::index_t>>
        idxs_4d;
    for(ck::index_t e = 0; e < d3; e++)
    {
        for(ck::index_t d = 0; d < d2; d++)
        {
            for(ck::index_t h = 0; h < d1; h++)
            {
                for(ck::index_t w = 0; w < d0; w++)
                {
                    idxs_4d.emplace_back(ck::make_tuple(e, d), h, w);
                }
            }
        }
    }
    this->Run(desc, desc_1d, layout_runtime, layout_compiletime, idxs_4d);
}
// Doubly nested 2d layout: dims ((2, 3), (4, 3)), strides ((2, 4), (48, 12)).
TEST_F(TestWrapperLayout, 2d_nested)
{
    constexpr ck::index_t d3 = 2;
    constexpr ck::index_t d2 = 3;
    constexpr ck::index_t d1 = 4;
    constexpr ck::index_t d0 = 3;
    constexpr ck::index_t s3 = 2;
    constexpr ck::index_t s2 = 4;
    constexpr ck::index_t s1 = 48;
    constexpr ck::index_t s0 = 12;
    // Reference 4d descriptor with compile-time lengths/strides.
    const auto desc = ck::make_naive_tensor_descriptor(
        ck::make_tuple(ck::Number<d3>{}, ck::Number<d2>{}, ck::Number<d1>{}, ck::Number<d0>{}),
        ck::make_tuple(ck::Number<s3>{}, ck::Number<s2>{}, ck::Number<s1>{}, ck::Number<s0>{}));
    // Reverse due to column major
    const auto desc_1d = transform_tensor_descriptor(
        desc,
        ck::make_tuple(ck::make_merge_transform(ck::make_tuple(d0, d1, d2, d3))),
        ck::make_tuple(ck::Sequence<3, 2, 1, 0>{}),
        ck::make_tuple(ck::Sequence<0>{}));
    // 2d view: each nested pair of dims merged into one.
    const auto desc_2d = transform_tensor_descriptor(
        desc,
        ck::make_tuple(ck::make_merge_transform(ck::make_tuple(d2, d3)),
                       ck::make_merge_transform(ck::make_tuple(d0, d1))),
        ck::make_tuple(ck::Sequence<1, 0>{}, ck::Sequence<3, 2>{}),
        ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}));
    // Layouts under test: both dimensions are nested tuples.
    const auto layout_runtime =
        ck::wrapper::make_layout(ck::make_tuple(ck::make_tuple(d3, d2), ck::make_tuple(d1, d0)),
                                 ck::make_tuple(ck::make_tuple(s3, s2), ck::make_tuple(s1, s0)));
    const auto layout_compiletime = ck::wrapper::make_layout(
        ck::make_tuple(ck::make_tuple(ck::Number<d3>{}, ck::Number<d2>{}),
                       ck::make_tuple(ck::Number<d1>{}, ck::Number<d0>{})),
        ck::make_tuple(ck::make_tuple(ck::Number<s3>{}, ck::Number<s2>{}),
                       ck::make_tuple(ck::Number<s1>{}, ck::Number<s0>{})));
    // Probe with flat 2d indices (each merged dim iterated linearly).
    std::vector<ck::Tuple<ck::index_t, ck::index_t>> idxs_2d;
    for(ck::index_t h = 0; h < d2 * d3; h++)
    {
        for(ck::index_t w = 0; w < d0 * d1; w++)
        {
            idxs_2d.emplace_back(h, w);
        }
    }
    this->Run(desc_2d, desc_1d, layout_runtime, layout_compiletime, idxs_2d);
    // Check also 4d iteration
    std::vector<ck::Tuple<ck::Tuple<ck::index_t, ck::index_t>,
                          ck::Tuple<ck::index_t, ck::index_t>>>
        idxs_4d;
    for(ck::index_t e = 0; e < d3; e++)
    {
        for(ck::index_t d = 0; d < d2; d++)
        {
            for(ck::index_t h = 0; h < d1; h++)
            {
                for(ck::index_t w = 0; w < d0; w++)
                {
                    idxs_4d.emplace_back(ck::make_tuple(e, d), ck::make_tuple(h, w));
                }
            }
        }
    }
    this->Run(desc, desc_1d, layout_runtime, layout_compiletime, idxs_4d);
}
// Doubly nested 3d layout: dims (((2, 2), 3), (4, 3)),
// strides (((2, 4), 8), (96, 24)).
TEST_F(TestWrapperLayout, 3d_double_nested)
{
    constexpr ck::index_t d4 = 2;
    constexpr ck::index_t d3 = 2;
    constexpr ck::index_t d2 = 3;
    constexpr ck::index_t d1 = 4;
    constexpr ck::index_t d0 = 3;
    constexpr ck::index_t s4 = 2;
    constexpr ck::index_t s3 = 4;
    constexpr ck::index_t s2 = 8;
    constexpr ck::index_t s1 = 96;
    constexpr ck::index_t s0 = 24;

    // Reference 5d descriptor with compile-time lengths/strides.
    const auto desc = ck::make_naive_tensor_descriptor(
        ck::make_tuple(ck::Number<d4>{},
                       ck::Number<d3>{},
                       ck::Number<d2>{},
                       ck::Number<d1>{},
                       ck::Number<d0>{}),
        ck::make_tuple(ck::Number<s4>{},
                       ck::Number<s3>{},
                       ck::Number<s2>{},
                       ck::Number<s1>{},
                       ck::Number<s0>{}));
    // Reverse due to column major
    const auto desc_1d = transform_tensor_descriptor(
        desc,
        ck::make_tuple(ck::make_merge_transform(ck::make_tuple(d0, d1, d2, d3, d4))),
        ck::make_tuple(ck::Sequence<4, 3, 2, 1, 0>{}),
        ck::make_tuple(ck::Sequence<0>{}));
    // 3d view: merge innermost-nested pair, pass d2 through, merge last pair.
    const auto desc_3d = transform_tensor_descriptor(
        desc,
        ck::make_tuple(ck::make_merge_transform(ck::make_tuple(d3, d4)),
                       ck::make_pass_through_transform(d2),
                       ck::make_merge_transform(ck::make_tuple(d0, d1))),
        ck::make_tuple(ck::Sequence<1, 0>{}, ck::Sequence<2>{}, ck::Sequence<4, 3>{}),
        ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}, ck::Sequence<2>{}));
    // 2d view on top of the 3d view.
    const auto desc_2d = transform_tensor_descriptor(
        desc_3d,
        ck::make_tuple(ck::make_merge_transform(ck::make_tuple(d2, d3 * d4)),
                       ck::make_pass_through_transform(d1 * d0)),
        ck::make_tuple(ck::Sequence<1, 0>{}, ck::Sequence<2>{}),
        ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}));

    // Layouts under test with a doubly nested first dimension.
    // BUGFIX: the stride tuples previously used the dimension constant d4
    // where the stride s4 was intended (masked by d4 == s4 == 2).
    const auto layout_runtime = ck::wrapper::make_layout(
        ck::make_tuple(ck::make_tuple(ck::make_tuple(d4, d3), d2), ck::make_tuple(d1, d0)),
        ck::make_tuple(ck::make_tuple(ck::make_tuple(s4, s3), s2), ck::make_tuple(s1, s0)));
    const auto layout_compiletime = ck::wrapper::make_layout(
        ck::make_tuple(
            ck::make_tuple(ck::make_tuple(ck::Number<d4>{}, ck::Number<d3>{}), ck::Number<d2>{}),
            ck::make_tuple(ck::Number<d1>{}, ck::Number<d0>{})),
        ck::make_tuple(
            ck::make_tuple(ck::make_tuple(ck::Number<s4>{}, ck::Number<s3>{}), ck::Number<s2>{}),
            ck::make_tuple(ck::Number<s1>{}, ck::Number<s0>{})));

    // Probe with flat 2d indices.
    std::vector<ck::Tuple<ck::index_t, ck::index_t>> idxs_2d;
    for(ck::index_t h = 0; h < d2 * d3 * d4; h++)
    {
        for(ck::index_t w = 0; w < d0 * d1; w++)
        {
            idxs_2d.emplace_back(h, w);
        }
    }
    this->Run(desc_2d, desc_1d, layout_runtime, layout_compiletime, idxs_2d);

    // Check also 3d iteration
    std::vector<ck::Tuple<ck::Tuple<ck::index_t, ck::index_t>, ck::index_t>> idxs_3d;
    for(ck::index_t d = 0; d < d3 * d4; d++)
    {
        for(ck::index_t h = 0; h < d2; h++)
        {
            for(ck::index_t w = 0; w < d1 * d0; w++)
            {
                idxs_3d.emplace_back(ck::make_tuple(d, h), w);
            }
        }
    }
    this->Run(desc_3d, desc_1d, layout_runtime, layout_compiletime, idxs_3d);

    // Check also 5d iteration
    std::vector<ck::Tuple<ck::Tuple<ck::Tuple<ck::index_t, ck::index_t>, ck::index_t>,
                          ck::Tuple<ck::index_t, ck::index_t>>>
        idxs_5d;
    for(ck::index_t f = 0; f < d4; f++)
    {
        for(ck::index_t e = 0; e < d3; e++)
        {
            for(ck::index_t d = 0; d < d2; d++)
            {
                for(ck::index_t h = 0; h < d1; h++)
                {
                    for(ck::index_t w = 0; w < d0; w++)
                    {
                        idxs_5d.emplace_back(ck::make_tuple(ck::make_tuple(f, e), d),
                                             ck::make_tuple(h, w));
                    }
                }
            }
        }
    }
    this->Run(desc, desc_1d, layout_runtime, layout_compiletime, idxs_5d);
}
// Verifies ck::wrapper::size / ck::wrapper::get on a doubly nested layout
// with dims (((2, 2), 3), (4, 3)).
TEST(TestLayoutHelpers, SizeAndGet)
{
    constexpr ck::index_t d4 = 2;
    constexpr ck::index_t d3 = 2;
    constexpr ck::index_t d2 = 3;
    constexpr ck::index_t d1 = 4;
    constexpr ck::index_t d0 = 3;

    const auto layout_runtime = ck::wrapper::make_layout(
        ck::make_tuple(ck::make_tuple(ck::make_tuple(d4, d3), d2), ck::make_tuple(d1, d0)));
    const auto layout_compiletime = ck::wrapper::make_layout(ck::make_tuple(
        ck::make_tuple(ck::make_tuple(ck::Number<d4>{}, ck::Number<d3>{}), ck::Number<d2>{}),
        ck::make_tuple(ck::Number<d1>{}, ck::Number<d0>{})));

    // Size of layout
    EXPECT_EQ(ck::wrapper::size(layout_runtime), d4 * d3 * d2 * d1 * d0);
    EXPECT_EQ(ck::wrapper::size(layout_compiletime), d4 * d3 * d2 * d1 * d0);
    // Size of dims
    EXPECT_EQ(ck::wrapper::size<0>(layout_runtime), d4 * d3 * d2);
    EXPECT_EQ(ck::wrapper::size<0>(layout_compiletime), d4 * d3 * d2);
    EXPECT_EQ(ck::wrapper::size<1>(layout_runtime), d1 * d0);
    EXPECT_EQ(ck::wrapper::size<1>(layout_compiletime), d1 * d0);
    // Access through new layout (using get with layout object)
    EXPECT_EQ(ck::wrapper::size<0>(ck::wrapper::get<0>(layout_runtime)), d4 * d3);
    EXPECT_EQ(ck::wrapper::size<0>(ck::wrapper::get<0>(layout_compiletime)), d4 * d3);
    EXPECT_EQ(ck::wrapper::size<1>(ck::wrapper::get<0>(layout_runtime)), d2);
    EXPECT_EQ(ck::wrapper::size<1>(ck::wrapper::get<0>(layout_compiletime)), d2);
    EXPECT_EQ(ck::wrapper::size<0>(ck::wrapper::get<0>(ck::wrapper::get<0>(layout_runtime))), d4);
    EXPECT_EQ(ck::wrapper::size<0>(ck::wrapper::get<0>(ck::wrapper::get<0>(layout_compiletime))),
              d4);
    EXPECT_EQ(ck::wrapper::size<1>(ck::wrapper::get<0>(ck::wrapper::get<0>(layout_runtime))), d3);
    EXPECT_EQ(ck::wrapper::size<1>(ck::wrapper::get<0>(ck::wrapper::get<0>(layout_compiletime))),
              d3);
    // NOTE(review): a second, identical pair of size<1>(get<0>(...)) == d2
    // checks was removed here as redundant duplication.
    EXPECT_EQ(ck::wrapper::size<0>(ck::wrapper::get<1>(layout_runtime)), d1);
    EXPECT_EQ(ck::wrapper::size<0>(ck::wrapper::get<1>(layout_compiletime)), d1);
    EXPECT_EQ(ck::wrapper::size<1>(ck::wrapper::get<1>(layout_runtime)), d0);
    EXPECT_EQ(ck::wrapper::size<1>(ck::wrapper::get<1>(layout_compiletime)), d0);
}
// Verifies ck::wrapper::depth / ck::wrapper::rank on a doubly nested layout
// with dims (((2, 2), 3), (4, 3)), on a bare tuple, and on a plain integer.
TEST(TestLayoutHelpers, DepthAndRank)
{
    constexpr ck::index_t d4 = 2;
    constexpr ck::index_t d3 = 2;
    constexpr ck::index_t d2 = 3;
    constexpr ck::index_t d1 = 4;
    constexpr ck::index_t d0 = 3;
    const auto layout_runtime = ck::wrapper::make_layout(
        ck::make_tuple(ck::make_tuple(ck::make_tuple(d4, d3), d2), ck::make_tuple(d1, d0)));
    const auto layout_compiletime = ck::wrapper::make_layout(ck::make_tuple(
        ck::make_tuple(ck::make_tuple(ck::Number<d4>{}, ck::Number<d3>{}), ck::Number<d2>{}),
        ck::make_tuple(ck::Number<d1>{}, ck::Number<d0>{})));
    // Depth: maximum nesting level of the shape.
    EXPECT_EQ(ck::wrapper::depth(layout_runtime), 3);
    EXPECT_EQ(ck::wrapper::depth(layout_compiletime), 3);
    EXPECT_EQ(ck::wrapper::depth(ck::make_tuple(ck::make_tuple(d4, d3), d2)), 2);
    // Check for integer
    EXPECT_EQ(ck::wrapper::depth(d0), 0);
    // Rank: number of top-level dimensions.
    EXPECT_EQ(ck::wrapper::rank(layout_runtime), 2);
    EXPECT_EQ(ck::wrapper::rank(layout_compiletime), 2);
    EXPECT_EQ(ck::wrapper::rank(ck::make_tuple(ck::make_tuple(d4, d3), d2)), 2);
    // Check for integer
    EXPECT_EQ(ck::wrapper::rank(d0), 1);
}
// Verifies that shape(layout) / stride(layout) return exactly the tuple types
// the layout was constructed from — for both runtime and compile-time layouts.
// dims:(((2, 2), 3), (4, 3)), strides:(((2, 4), 8), (96, 24)).
TEST(TestLayoutHelpers, ShapeAndStrides)
{
    constexpr ck::index_t d4 = 2;
    constexpr ck::index_t d3 = 2;
    constexpr ck::index_t d2 = 3;
    constexpr ck::index_t d1 = 4;
    constexpr ck::index_t d0 = 3;
    constexpr ck::index_t s4 = 2;
    constexpr ck::index_t s3 = 4;
    constexpr ck::index_t s2 = 8;
    constexpr ck::index_t s1 = 96;
    constexpr ck::index_t s0 = 24;

    const auto shape_compiletime = ck::make_tuple(
        ck::make_tuple(ck::make_tuple(ck::Number<d4>{}, ck::Number<d3>{}), ck::Number<d2>{}),
        ck::make_tuple(ck::Number<d1>{}, ck::Number<d0>{}));
    const auto strides_compiletime = ck::make_tuple(
        ck::make_tuple(ck::make_tuple(ck::Number<s4>{}, ck::Number<s3>{}), ck::Number<s2>{}),
        ck::make_tuple(ck::Number<s1>{}, ck::Number<s0>{}));
    const auto shape_runtime =
        ck::make_tuple(ck::make_tuple(ck::make_tuple(d4, d3), d2), ck::make_tuple(d1, d0));
    const auto strides_runtime =
        ck::make_tuple(ck::make_tuple(ck::make_tuple(s4, s3), s2), ck::make_tuple(s1, s0));

    const auto layout_runtime     = ck::wrapper::make_layout(shape_runtime, strides_runtime);
    const auto layout_compiletime =
        ck::wrapper::make_layout(shape_compiletime, strides_compiletime);

    // Compare types with the const stripped from the local tuples.
    // (Uses std::remove_const_t — the file already relies on C++17 traits.)
    constexpr bool check_compiletime_shape =
        std::is_same_v<std::remove_const_t<decltype(shape_compiletime)>,
                       decltype(shape(layout_compiletime))>;
    constexpr bool check_compiletime_strides =
        std::is_same_v<std::remove_const_t<decltype(strides_compiletime)>,
                       decltype(stride(layout_compiletime))>;
    constexpr bool check_runtime_shape =
        std::is_same_v<std::remove_const_t<decltype(shape_runtime)>,
                       decltype(shape(layout_runtime))>;
    constexpr bool check_runtime_strides =
        std::is_same_v<std::remove_const_t<decltype(strides_runtime)>,
                       decltype(stride(layout_runtime))>;

    EXPECT_TRUE(check_compiletime_shape);
    EXPECT_TRUE(check_compiletime_strides);
    EXPECT_TRUE(check_runtime_shape);
    EXPECT_TRUE(check_runtime_strides);
}
// Verifies the hierarchical (multi-index) overloads of rank/depth/size/get
// on a doubly nested shape: dims (((2, 2), 3), (4, 3)).
TEST(TestLayoutHelpers, Hierarchical)
{
    constexpr ck::index_t d4 = 2;
    constexpr ck::index_t d3 = 2;
    constexpr ck::index_t d2 = 3;
    constexpr ck::index_t d1 = 4;
    constexpr ck::index_t d0 = 3;
    const auto runtime_shape =
        ck::make_tuple(ck::make_tuple(ck::make_tuple(d4, d3), d2), ck::make_tuple(d1, d0));
    const auto layout_runtime = ck::wrapper::make_layout(runtime_shape);
    const auto layout_compiletime = ck::wrapper::make_layout(ck::make_tuple(
        ck::make_tuple(ck::make_tuple(ck::Number<d4>{}, ck::Number<d3>{}), ck::Number<d2>{}),
        ck::make_tuple(ck::Number<d1>{}, ck::Number<d0>{})));
    // rank<0, 0>: rank of the ((d4, d3)) sub-tuple at path [0][0].
    EXPECT_EQ((ck::wrapper::rank<0, 0>(runtime_shape)), 2);
    EXPECT_EQ((ck::wrapper::rank<0, 0>(layout_runtime)), 2);
    EXPECT_EQ((ck::wrapper::rank<0, 0>(layout_compiletime)), 2);
    // depth<0, 0>: nesting depth of the sub-tuple at path [0][0].
    EXPECT_EQ((ck::wrapper::depth<0, 0>(runtime_shape)), 1);
    EXPECT_EQ((ck::wrapper::depth<0, 0>(layout_runtime)), 1);
    EXPECT_EQ((ck::wrapper::depth<0, 0>(layout_compiletime)), 1);
    // size<0, 0>: total element count of the sub-tuple at path [0][0].
    EXPECT_EQ((ck::wrapper::size<0, 0>(runtime_shape)), d4 * d3);
    EXPECT_EQ((ck::wrapper::size<0, 0>(layout_runtime)), d4 * d3);
    EXPECT_EQ((ck::wrapper::size<0, 0>(layout_compiletime)), d4 * d3);
    // get<0, 0, 0>: scalar extent at path [0][0][0].
    EXPECT_EQ((ck::wrapper::get<0, 0, 0>(runtime_shape)), d4);
}
Prev
1
2
3
4
5
6
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment