Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
dc1c2bf8
Commit
dc1c2bf8
authored
Oct 20, 2024
by
carlushuang
Browse files
Merge remote-tracking branch 'origin/develop' into letaoqin/update_layernorm
parents
5cfd751b
a285d6f9
Changes
27
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
1406 additions
and
208 deletions
+1406
-208
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp
.../fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp
+73
-164
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp
...a/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp
+28
-43
include/ck_tile/ops/gemm.hpp
include/ck_tile/ops/gemm.hpp
+1
-0
include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp
...gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp
+424
-0
library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn.hpp
...f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn.hpp
+1
-1
test/data_type/CMakeLists.txt
test/data_type/CMakeLists.txt
+5
-0
test/data_type/test_custom_type.cpp
test/data_type/test_custom_type.cpp
+874
-0
No files found.
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp
View file @
dc1c2bf8
...
...
@@ -5,9 +5,8 @@
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/gemm/
pipeline/gemm_pipeline
_problem.hpp"
#include "ck_tile/ops/gemm/
block/block_gemm
_problem.hpp"
#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp"
#include "ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp"
...
...
@@ -27,20 +26,15 @@ struct BlockFmhaBwdPipelineDefaultPolicy
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetQKBlockGemm
()
{
using
GemmProblem
=
GemmPipelineProblem
<
typename
Problem
::
QDataType
,
typename
Problem
::
KDataType
,
typename
Problem
::
AccDataType
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
kM0
,
Problem
::
BlockFmhaShape
::
kN0
,
Problem
::
BlockFmhaShape
::
kK0
>
,
typename
Problem
::
BlockFmhaShape
::
Gemm0BlockWarps
,
typename
Problem
::
BlockFmhaShape
::
Gemm0WarpTile
>
,
TileGemmTraits
<
Problem
::
kPadSeqLenQ
,
Problem
::
kPadSeqLenK
,
Problem
::
kPadHeadDimQ
,
typename
tensor_layout
::
gemm
::
RowMajor
,
typename
tensor_layout
::
gemm
::
ColumnMajor
,
typename
tensor_layout
::
gemm
::
RowMajor
>>
;
BlockGemmProblem
<
typename
Problem
::
QDataType
,
typename
Problem
::
KDataType
,
typename
Problem
::
AccDataType
,
Problem
::
kBlockSize
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
kM0
,
Problem
::
BlockFmhaShape
::
kN0
,
Problem
::
BlockFmhaShape
::
kK0
>
,
typename
Problem
::
BlockFmhaShape
::
Gemm0BlockWarps
,
typename
Problem
::
BlockFmhaShape
::
Gemm0WarpTile
>>
;
using
WarpGemm
=
WarpGemmMfmaDispatcher
<
typename
Problem
::
QDataType
,
...
...
@@ -66,20 +60,15 @@ struct BlockFmhaBwdPipelineDefaultPolicy
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetPTOGradTBlockGemm
()
{
using
GemmProblem
=
GemmPipelineProblem
<
typename
Problem
::
GemmDataType
,
typename
Problem
::
OGradDataType
,
typename
Problem
::
AccDataType
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
kN0
,
Problem
::
BlockFmhaShape
::
kVHeaddim
,
Problem
::
BlockFmhaShape
::
kK1
>
,
typename
Problem
::
BlockFmhaShape
::
Gemm1BlockWarps
,
typename
Problem
::
BlockFmhaShape
::
Gemm1WarpTile
>
,
TileGemmTraits
<
Problem
::
kPadSeqLenQ
,
Problem
::
kPadHeadDimV
,
Problem
::
kPadHeadDimV
,
typename
tensor_layout
::
gemm
::
RowMajor
,
typename
tensor_layout
::
gemm
::
ColumnMajor
,
typename
tensor_layout
::
gemm
::
RowMajor
>>
;
BlockGemmProblem
<
typename
Problem
::
GemmDataType
,
typename
Problem
::
OGradDataType
,
typename
Problem
::
AccDataType
,
Problem
::
kBlockSize
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
kN0
,
Problem
::
BlockFmhaShape
::
kVHeaddim
,
Problem
::
BlockFmhaShape
::
kK1
>
,
typename
Problem
::
BlockFmhaShape
::
Gemm1BlockWarps
,
typename
Problem
::
BlockFmhaShape
::
Gemm1WarpTile
>>
;
using
WarpGemm
=
WarpGemmMfmaDispatcher
<
typename
Problem
::
GemmDataType
,
...
...
@@ -104,20 +93,15 @@ struct BlockFmhaBwdPipelineDefaultPolicy
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetOGradVBlockGemm
()
{
using
GemmProblem
=
GemmPipelineProblem
<
typename
Problem
::
OGradDataType
,
typename
Problem
::
VDataType
,
typename
Problem
::
AccDataType
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
kM0
,
Problem
::
BlockFmhaShape
::
kN0
,
Problem
::
BlockFmhaShape
::
kK2
>
,
typename
Problem
::
BlockFmhaShape
::
Gemm2BlockWarps
,
typename
Problem
::
BlockFmhaShape
::
Gemm2WarpTile
>
,
TileGemmTraits
<
Problem
::
kPadSeqLenQ
,
Problem
::
kPadSeqLenK
,
Problem
::
kPadHeadDimQ
,
typename
tensor_layout
::
gemm
::
RowMajor
,
typename
tensor_layout
::
gemm
::
ColumnMajor
,
typename
tensor_layout
::
gemm
::
RowMajor
>>
;
BlockGemmProblem
<
typename
Problem
::
OGradDataType
,
typename
Problem
::
VDataType
,
typename
Problem
::
AccDataType
,
Problem
::
kBlockSize
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
kM0
,
Problem
::
BlockFmhaShape
::
kN0
,
Problem
::
BlockFmhaShape
::
kK2
>
,
typename
Problem
::
BlockFmhaShape
::
Gemm2BlockWarps
,
typename
Problem
::
BlockFmhaShape
::
Gemm2WarpTile
>>
;
using
WarpGemm
=
WarpGemmMfmaDispatcher
<
typename
Problem
::
OGradDataType
,
...
...
@@ -143,20 +127,15 @@ struct BlockFmhaBwdPipelineDefaultPolicy
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetSGradTQTBlockGemm
()
{
using
GemmProblem
=
GemmPipelineProblem
<
typename
Problem
::
GemmDataType
,
typename
Problem
::
QDataType
,
typename
Problem
::
AccDataType
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
kN0
,
Problem
::
BlockFmhaShape
::
kQKHeaddim
,
Problem
::
BlockFmhaShape
::
kK3
>
,
typename
Problem
::
BlockFmhaShape
::
Gemm3BlockWarps
,
typename
Problem
::
BlockFmhaShape
::
Gemm3WarpTile
>
,
TileGemmTraits
<
Problem
::
kPadSeqLenK
,
Problem
::
kPadHeadDimQ
,
Problem
::
kPadSeqLenK
,
typename
tensor_layout
::
gemm
::
RowMajor
,
typename
tensor_layout
::
gemm
::
ColumnMajor
,
typename
tensor_layout
::
gemm
::
RowMajor
>>
;
BlockGemmProblem
<
typename
Problem
::
GemmDataType
,
typename
Problem
::
QDataType
,
typename
Problem
::
AccDataType
,
Problem
::
kBlockSize
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
kN0
,
Problem
::
BlockFmhaShape
::
kQKHeaddim
,
Problem
::
BlockFmhaShape
::
kK3
>
,
typename
Problem
::
BlockFmhaShape
::
Gemm3BlockWarps
,
typename
Problem
::
BlockFmhaShape
::
Gemm3WarpTile
>>
;
using
WarpGemm
=
WarpGemmMfmaDispatcher
<
typename
Problem
::
GemmDataType
,
...
...
@@ -181,20 +160,15 @@ struct BlockFmhaBwdPipelineDefaultPolicy
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetSGradKTBlockGemm
()
{
using
GemmProblem
=
GemmPipelineProblem
<
typename
Problem
::
GemmDataType
,
typename
Problem
::
KDataType
,
typename
Problem
::
AccDataType
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
kM0
,
Problem
::
BlockFmhaShape
::
kQKHeaddim
,
Problem
::
BlockFmhaShape
::
kK4
>
,
typename
Problem
::
BlockFmhaShape
::
Gemm4BlockWarps
,
typename
Problem
::
BlockFmhaShape
::
Gemm4WarpTile
>
,
TileGemmTraits
<
Problem
::
kPadSeqLenQ
,
Problem
::
kPadHeadDimQ
,
Problem
::
kPadSeqLenK
,
typename
tensor_layout
::
gemm
::
RowMajor
,
typename
tensor_layout
::
gemm
::
ColumnMajor
,
typename
tensor_layout
::
gemm
::
RowMajor
>>
;
BlockGemmProblem
<
typename
Problem
::
GemmDataType
,
typename
Problem
::
KDataType
,
typename
Problem
::
AccDataType
,
Problem
::
kBlockSize
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
kM0
,
Problem
::
BlockFmhaShape
::
kQKHeaddim
,
Problem
::
BlockFmhaShape
::
kK4
>
,
typename
Problem
::
BlockFmhaShape
::
Gemm4BlockWarps
,
typename
Problem
::
BlockFmhaShape
::
Gemm4WarpTile
>>
;
using
WarpGemm
=
WarpGemmMfmaDispatcher
<
typename
Problem
::
GemmDataType
,
...
...
@@ -222,7 +196,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
using
QDataType
=
remove_cvref_t
<
typename
Problem
::
QDataType
>
;
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kMNPerBlock
=
Problem
::
BlockFmhaShape
::
kM0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
K0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
QKHeaddim
;
constexpr
index_t
kMaxVecLoad
=
16
/
sizeof
(
QDataType
);
constexpr
index_t
kMinVecLoad
=
4
/
sizeof
(
QDataType
);
...
...
@@ -241,7 +215,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
using
KDataType
=
remove_cvref_t
<
typename
Problem
::
KDataType
>
;
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kMNPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
K0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
QKHeaddim
;
constexpr
index_t
kMaxVecLoad
=
16
/
sizeof
(
KDataType
);
constexpr
index_t
kMinVecLoad
=
4
/
sizeof
(
KDataType
);
...
...
@@ -260,7 +234,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
using
VDataType
=
remove_cvref_t
<
typename
Problem
::
VDataType
>
;
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kMNPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
K2
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
VHeaddim
;
constexpr
index_t
kMaxVecLoad
=
16
/
sizeof
(
VDataType
);
constexpr
index_t
total_pixels
=
kMNPerBlock
*
kKPerBlock
/
kBlockSize
;
...
...
@@ -280,7 +254,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
using
OGradDataType
=
remove_cvref_t
<
typename
Problem
::
OGradDataType
>
;
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kMNPerBlock
=
Problem
::
BlockFmhaShape
::
kM0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
K2
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
VHeaddim
;
constexpr
index_t
kMaxVecLoad
=
16
/
sizeof
(
OGradDataType
);
constexpr
index_t
kMinVecLoad
=
4
/
sizeof
(
OGradDataType
);
...
...
@@ -341,7 +315,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
{
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kM0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
K0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
QKHeaddim
;
constexpr
index_t
total_pixels
=
kNPerBlock
*
kKPerBlock
/
kBlockSize
;
...
...
@@ -353,7 +327,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
{
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
K0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
QKHeaddim
;
constexpr
index_t
total_pixels
=
kNPerBlock
*
kKPerBlock
/
kBlockSize
;
return
total_pixels
/
GetAlignmentK
<
Problem
>
();
...
...
@@ -364,7 +338,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
{
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kM0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
K2
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
VHeaddim
;
constexpr
index_t
total_pixels
=
kNPerBlock
*
kKPerBlock
/
kBlockSize
;
...
...
@@ -402,7 +376,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
K0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
QKHeaddim
;
constexpr
index_t
K1
=
GetAlignmentK
<
Problem
>
();
constexpr
index_t
K0
=
kKPerBlock
/
K1
;
...
...
@@ -425,7 +399,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
K2
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
VHeaddim
;
constexpr
index_t
K1
=
GetAlignmentV
<
Problem
>
();
constexpr
index_t
K0
=
kKPerBlock
/
K1
;
...
...
@@ -448,7 +422,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kMPerBlock
=
Problem
::
BlockFmhaShape
::
kM0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
K0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
QKHeaddim
;
constexpr
index_t
K1
=
GetAlignmentQ
<
Problem
>
();
constexpr
index_t
K0
=
kKPerBlock
/
K1
;
...
...
@@ -471,7 +445,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kMPerBlock
=
Problem
::
BlockFmhaShape
::
kM0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
K2
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
VHeaddim
;
constexpr
index_t
K1
=
GetAlignmentOGrad
<
Problem
>
();
constexpr
index_t
K0
=
kKPerBlock
/
K1
;
...
...
@@ -842,44 +816,12 @@ struct BlockFmhaBwdPipelineDefaultPolicy
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeKLdsWriteBlockDescriptor
()
{
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
K0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
QKHeaddim
;
constexpr
index_t
kKPack
=
GetSmemKPackK
<
Problem
>
();
return
MakeXLdsBlockDescriptor
<
kNPerBlock
,
kKPerBlock
,
kKPack
>
();
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeKRegSliceBlockDescriptor
()
{
using
BlockGemm
=
remove_cvref_t
<
decltype
(
GetQKBlockGemm
<
Problem
>
())
>
;
constexpr
auto
config
=
BlockGemm
::
Policy
::
template
GetWarpGemmMWarpNWarp
<
Problem
>();
using
WarpGemm
=
remove_cvref_t
<
decltype
(
config
.
template
at
<
0
>())
>
;
constexpr
index_t
MWarp
=
Problem
::
BlockFmhaShape
::
Gemm0BlockWarps
::
at
(
number
<
0
>
{});
constexpr
index_t
NWarp
=
Problem
::
BlockFmhaShape
::
Gemm0BlockWarps
::
at
(
number
<
1
>
{});
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kK0
;
constexpr
index_t
NIterPerWarp
=
kNPerBlock
/
(
NWarp
*
WarpGemm
::
kN
);
constexpr
index_t
KIterPerWarp
=
kKPerBlock
/
WarpGemm
::
kK
;
constexpr
auto
k_block_outer_dstr_encoding
=
tile_distribution_encoding
<
sequence
<
MWarp
>
,
tuple
<
sequence
<
NIterPerWarp
,
NWarp
>
,
sequence
<
KIterPerWarp
>>
,
tuple
<
sequence
<
0
,
1
>>
,
tuple
<
sequence
<
0
,
1
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
0
>>
{};
constexpr
auto
k_block_dstr_encode
=
detail
::
make_embed_tile_distribution_encoding
(
k_block_outer_dstr_encoding
,
typename
WarpGemm
::
BWarpDstrEncoding
{});
constexpr
auto
k_block_dstr
=
make_static_tile_distribution
(
k_block_dstr_encode
);
return
k_block_dstr
;
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeKRegBlockDescriptor
()
{
...
...
@@ -891,7 +833,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
constexpr
index_t
NWarp
=
Problem
::
BlockFmhaShape
::
Gemm0BlockWarps
::
at
(
number
<
1
>
{});
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
QKHeaddim
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
K0
;
constexpr
index_t
NIterPerWarp
=
kNPerBlock
/
(
NWarp
*
WarpGemm
::
kN
);
constexpr
index_t
KIterPerWarp
=
kKPerBlock
/
WarpGemm
::
kK
;
...
...
@@ -916,45 +858,13 @@ struct BlockFmhaBwdPipelineDefaultPolicy
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeVLdsWriteBlockDescriptor
()
{
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
K2
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
VHeaddim
;
constexpr
index_t
kVPack
=
GetSmemKPackV
<
Problem
>
();
return
MakeXLdsBlockDescriptor
<
kNPerBlock
,
kKPerBlock
,
kVPack
>
();
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeVRegSliceBlockDescriptor
()
{
using
BlockGemm
=
remove_cvref_t
<
decltype
(
GetOGradVBlockGemm
<
Problem
>
())
>
;
constexpr
auto
config
=
BlockGemm
::
Policy
::
template
GetWarpGemmMWarpNWarp
<
Problem
>();
using
WarpGemm
=
remove_cvref_t
<
decltype
(
config
.
template
at
<
0
>())
>
;
constexpr
index_t
MWarp
=
Problem
::
BlockFmhaShape
::
Gemm2BlockWarps
::
at
(
number
<
0
>
{});
constexpr
index_t
NWarp
=
Problem
::
BlockFmhaShape
::
Gemm2BlockWarps
::
at
(
number
<
1
>
{});
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kK2
;
constexpr
index_t
NIterPerWarp
=
kNPerBlock
/
(
NWarp
*
WarpGemm
::
kN
);
constexpr
index_t
KIterPerWarp
=
kKPerBlock
/
WarpGemm
::
kK
;
constexpr
auto
v_block_outer_dstr_encoding
=
tile_distribution_encoding
<
sequence
<
MWarp
>
,
tuple
<
sequence
<
NIterPerWarp
,
NWarp
>
,
sequence
<
KIterPerWarp
>>
,
tuple
<
sequence
<
0
,
1
>>
,
tuple
<
sequence
<
0
,
1
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
0
>>
{};
constexpr
auto
v_block_dstr_encode
=
detail
::
make_embed_tile_distribution_encoding
(
v_block_outer_dstr_encoding
,
typename
WarpGemm
::
BWarpDstrEncoding
{});
constexpr
auto
v_block_dstr
=
make_static_tile_distribution
(
v_block_dstr_encode
);
return
v_block_dstr
;
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeVRegBlockDescriptor
()
{
...
...
@@ -966,7 +876,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
constexpr
index_t
NWarp
=
Problem
::
BlockFmhaShape
::
Gemm2BlockWarps
::
at
(
number
<
1
>
{});
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
VHeaddim
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
K2
;
constexpr
index_t
NIterPerWarp
=
kNPerBlock
/
(
NWarp
*
WarpGemm
::
kN
);
constexpr
index_t
KIterPerWarp
=
kKPerBlock
/
WarpGemm
::
kK
;
...
...
@@ -992,7 +902,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
{
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
K0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
QKHeaddim
;
constexpr
index_t
K1
=
GetAlignmentK
<
Problem
>
();
constexpr
index_t
K0
=
kKPerBlock
/
K1
;
...
...
@@ -1074,7 +984,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeQLdsBlockDescriptor
()
{
constexpr
index_t
kMPerBlock
=
Problem
::
BlockFmhaShape
::
kM0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
K0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
QKHeaddim
;
constexpr
index_t
kKPack
=
GetSmemKPackQ
<
Problem
>
();
...
...
@@ -1118,7 +1028,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
{
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
K0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
QKHeaddim
;
constexpr
index_t
K1
=
GetAlignmentQ
<
Problem
>
();
constexpr
index_t
K0
=
kKPerBlock
/
K1
;
...
...
@@ -1281,7 +1191,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
{
// Hold full block data
constexpr
index_t
kMPerBlock
=
Problem
::
BlockFmhaShape
::
kM0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
K2
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
VHeaddim
;
constexpr
index_t
kKPack
=
GetSmemKPackOGrad
<
Problem
>
();
...
...
@@ -1325,7 +1235,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
{
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
K2
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
VHeaddim
;
constexpr
index_t
K1
=
GetAlignmentOGrad
<
Problem
>
();
constexpr
index_t
K0
=
kKPerBlock
/
K1
;
...
...
@@ -1885,6 +1795,8 @@ struct BlockFmhaBwdPipelineDefaultPolicy
static
constexpr
index_t
kN0
=
Problem
::
BlockFmhaShape
::
kN0
;
static
constexpr
index_t
kQKHeaddim
=
Problem
::
BlockFmhaShape
::
kQKHeaddim
;
static
constexpr
index_t
kVHeaddim
=
Problem
::
BlockFmhaShape
::
kVHeaddim
;
static
constexpr
index_t
kK0
=
Problem
::
BlockFmhaShape
::
kK0
;
static
constexpr
index_t
kK2
=
Problem
::
BlockFmhaShape
::
kK2
;
static
constexpr
index_t
kK4
=
Problem
::
BlockFmhaShape
::
kK4
;
static
constexpr
index_t
WarpGemmM
=
...
...
@@ -1899,14 +1811,12 @@ struct BlockFmhaBwdPipelineDefaultPolicy
// Compute
static
constexpr
index_t
Gemm0MFMA
=
kM0
*
kN0
*
kQKHeaddim
/
(
kBlockSize
/
get_warp_size
()
*
WarpGemmM
*
WarpGemmN
*
WarpGemmK
);
kM0
*
kN0
*
kK0
/
(
kBlockSize
/
get_warp_size
()
*
WarpGemmM
*
WarpGemmN
*
WarpGemmK
);
static
constexpr
index_t
Gemm1MFMA
=
kM0
*
kN0
*
kVHeaddim
/
(
kBlockSize
/
get_warp_size
()
*
WarpGemmM
*
WarpGemmN
*
WarpGemmK
);
static
constexpr
index_t
Gemm2MFMA
=
kN0
*
kVHeaddim
*
kM0
/
(
kBlockSize
/
get_warp_size
()
*
WarpGemmM
*
WarpGemmN
*
WarpGemmK
);
static
constexpr
index_t
Gemm2MFMA
=
kM0
*
kN0
*
kK2
/
(
kBlockSize
/
get_warp_size
()
*
WarpGemmM
*
WarpGemmN
*
WarpGemmK
);
static
constexpr
index_t
Gemm3MFMA
=
kN0
*
kQKHeaddim
*
kM0
/
(
kBlockSize
/
get_warp_size
()
*
WarpGemmM
*
WarpGemmN
*
WarpGemmK
);
...
...
@@ -1929,13 +1839,12 @@ struct BlockFmhaBwdPipelineDefaultPolicy
kM0
*
kQKHeaddim
/
get_warp_size
()
/
GetTransposedAlignmentQ
<
Problem
>
();
static
constexpr
index_t
SGradT_LDS_READ_P1
=
kM0
*
kK4
/
(
get_warp_size
()
*
Gemm4MWarp
)
/
GetSmemKPackSGrad
<
Problem
>
();
static
constexpr
index_t
Q_LDS_READ
=
kM0
*
kQKHeaddim
/
kBlockSize
/
GetAlignmentQ
<
Problem
>
();
static
constexpr
index_t
Q_LDS_READ
=
kM0
*
kK0
/
kBlockSize
/
GetAlignmentQ
<
Problem
>
();
static
constexpr
index_t
LSE_LDS_READ
=
WarpGemmM
==
16
?
kM0
/
(
4
*
4
)
:
kM0
/
(
2
*
4
);
static
constexpr
index_t
SGradT_LDS_READ_P2
=
kM0
*
(
kN0
-
kK4
)
/
(
get_warp_size
()
*
Gemm4MWarp
)
/
GetSmemKPackSGrad
<
Problem
>
();
static
constexpr
index_t
OGrad_LDS_READ
=
kM0
*
k
VHeaddim
/
kBlockSize
/
GetAlignmentOGrad
<
Problem
>
();
kM0
*
k
K2
/
kBlockSize
/
GetAlignmentOGrad
<
Problem
>
();
static
constexpr
index_t
D_LDS_READ
=
WarpGemmM
==
16
?
kM0
/
(
4
*
4
)
:
kM0
/
(
2
*
4
);
// LDS Write
...
...
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp
View file @
dc1c2bf8
...
...
@@ -5,9 +5,9 @@
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_problem.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp"
#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp"
#include "ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp"
...
...
@@ -77,20 +77,15 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ true>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetQKBlockGemm
()
{
using
GemmProblem
=
GemmPipelineProblem
<
typename
Problem
::
QDataType
,
typename
Problem
::
KDataType
,
typename
Problem
::
SaccDataType
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
kM0
,
Problem
::
BlockFmhaShape
::
kN0
,
Problem
::
BlockFmhaShape
::
kK0
>
,
typename
Problem
::
BlockFmhaShape
::
Gemm0BlockWarps
,
typename
Problem
::
BlockFmhaShape
::
Gemm0WarpTile
>
,
TileGemmTraits
<
Problem
::
kPadSeqLenQ
,
Problem
::
kPadSeqLenK
,
Problem
::
kPadHeadDimQ
,
typename
tensor_layout
::
gemm
::
RowMajor
,
typename
tensor_layout
::
gemm
::
ColumnMajor
,
typename
tensor_layout
::
gemm
::
RowMajor
>>
;
BlockGemmProblem
<
typename
Problem
::
QDataType
,
typename
Problem
::
KDataType
,
typename
Problem
::
SaccDataType
,
Problem
::
kBlockSize
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
kM0
,
Problem
::
BlockFmhaShape
::
kN0
,
Problem
::
BlockFmhaShape
::
kK0
>
,
typename
Problem
::
BlockFmhaShape
::
Gemm0BlockWarps
,
typename
Problem
::
BlockFmhaShape
::
Gemm0WarpTile
>>
;
constexpr
auto
warp_gemm
=
[]()
{
if
constexpr
(
std
::
is_same_v
<
typename
Problem
::
QDataType
,
half_t
>
&&
...
...
@@ -207,20 +202,15 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ false>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetQKBlockGemm
()
{
using
GemmProblem
=
GemmPipelineProblem
<
typename
Problem
::
QDataType
,
typename
Problem
::
KDataType
,
typename
Problem
::
SaccDataType
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
kM0
,
Problem
::
BlockFmhaShape
::
kN0
,
Problem
::
BlockFmhaShape
::
kK0
>
,
typename
Problem
::
BlockFmhaShape
::
Gemm0BlockWarps
,
typename
Problem
::
BlockFmhaShape
::
Gemm0WarpTile
>
,
TileGemmTraits
<
Problem
::
kPadSeqLenQ
,
Problem
::
kPadSeqLenK
,
Problem
::
kPadHeadDimQ
,
typename
tensor_layout
::
gemm
::
RowMajor
,
typename
tensor_layout
::
gemm
::
ColumnMajor
,
typename
tensor_layout
::
gemm
::
RowMajor
>>
;
BlockGemmProblem
<
typename
Problem
::
QDataType
,
typename
Problem
::
KDataType
,
typename
Problem
::
SaccDataType
,
Problem
::
kBlockSize
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
kM0
,
Problem
::
BlockFmhaShape
::
kN0
,
Problem
::
BlockFmhaShape
::
kK0
>
,
typename
Problem
::
BlockFmhaShape
::
Gemm0BlockWarps
,
typename
Problem
::
BlockFmhaShape
::
Gemm0WarpTile
>>
;
constexpr
auto
warp_gemm
=
[]()
{
if
constexpr
(
std
::
is_same_v
<
typename
Problem
::
QDataType
,
half_t
>
&&
...
...
@@ -968,20 +958,15 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetKVBlockGemm
()
{
using
GemmProblem
=
GemmPipelineProblem
<
typename
Problem
::
PDataType
,
typename
Problem
::
VDataType
,
typename
Problem
::
OaccDataType
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
kM0
,
Problem
::
BlockFmhaShape
::
kN1
,
Problem
::
BlockFmhaShape
::
kK1
>
,
typename
Problem
::
BlockFmhaShape
::
Gemm1BlockWarps
,
typename
Problem
::
BlockFmhaShape
::
Gemm1WarpTile
>
,
TileGemmTraits
<
Problem
::
kPadSeqLenQ
,
Problem
::
kPadSeqLenK
,
Problem
::
kPadHeadDimQ
,
typename
tensor_layout
::
gemm
::
RowMajor
,
typename
tensor_layout
::
gemm
::
ColumnMajor
,
typename
tensor_layout
::
gemm
::
RowMajor
>>
;
BlockGemmProblem
<
typename
Problem
::
PDataType
,
typename
Problem
::
VDataType
,
typename
Problem
::
OaccDataType
,
Problem
::
kBlockSize
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
kM0
,
Problem
::
BlockFmhaShape
::
kN1
,
Problem
::
BlockFmhaShape
::
kK1
>
,
typename
Problem
::
BlockFmhaShape
::
Gemm1BlockWarps
,
typename
Problem
::
BlockFmhaShape
::
Gemm1WarpTile
>>
;
auto
warp_gemm
=
[
&
]()
{
if
constexpr
(
std
::
is_same_v
<
typename
Problem
::
KDataType
,
fp8_t
>
&&
...
...
include/ck_tile/ops/gemm.hpp
View file @
dc1c2bf8
...
...
@@ -23,6 +23,7 @@
#include "ck_tile/ops/gemm/block/block_gemm_problem.hpp"
#include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp"
#include "ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp"
...
...
include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp
0 → 100644
View file @
dc1c2bf8
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
namespace
ck_tile
{
// UniversalGemm Policy
template
<
typename
LayoutA_
,
typename
LayoutB_
,
typename
LayoutC_
>
struct
UniversalGemmPipelineAgBgCrPolicy
{
using
LayoutA
=
remove_cvref_t
<
LayoutA_
>
;
using
LayoutB
=
remove_cvref_t
<
LayoutB_
>
;
using
LayoutC
=
remove_cvref_t
<
LayoutC_
>
;
static
constexpr
auto
I0
=
number
<
0
>
{};
static
constexpr
auto
I1
=
number
<
1
>
{};
static
constexpr
auto
I2
=
number
<
2
>
{};
static
constexpr
bool
TransposeC
=
true
;
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeALdsBlockDescriptor
()
{
using
WarpGemm
=
WarpGemmMfmaDispatcher
<
typename
Problem
::
ADataType
,
typename
Problem
::
BDataType
,
typename
Problem
::
CDataType
,
Problem
::
BlockGemmShape
::
WarpTile
::
at
(
I0
),
Problem
::
BlockGemmShape
::
WarpTile
::
at
(
I1
),
Problem
::
BlockGemmShape
::
WarpTile
::
at
(
I2
),
TransposeC
>
;
using
ADataType
=
remove_cvref_t
<
typename
Problem
::
ADataType
>
;
constexpr
index_t
MPerBlock
=
Problem
::
BlockGemmShape
::
kM
;
constexpr
index_t
KPerBlock
=
Problem
::
BlockGemmShape
::
kK
;
constexpr
index_t
K1
=
WarpGemm
::
kK
;
constexpr
index_t
K0
=
KPerBlock
/
K1
;
if
constexpr
(
std
::
is_same
<
tensor_layout
::
gemm
::
RowMajor
,
LayoutA
>::
value
)
{
constexpr
auto
MLdsLayer
=
32
*
4
/
KPerBlock
/
sizeof
(
ADataType
)
<
1
?
1
:
32
*
4
/
KPerBlock
/
sizeof
(
ADataType
);
constexpr
auto
a_lds_block_desc
=
make_naive_tensor_descriptor
(
make_tuple
(
K0
*
number
<
MLdsLayer
>
{},
number
<
MPerBlock
/
MLdsLayer
>
{},
K1
),
make_tuple
(
K1
,
number
<
KPerBlock
*
MLdsLayer
>
{},
I1
));
constexpr
auto
a_lds_block_desc_permuted
=
transform_tensor_descriptor
(
a_lds_block_desc
,
make_tuple
(
make_xor_transform
(
make_tuple
(
number
<
MPerBlock
/
MLdsLayer
>
{},
number
<
K0
*
MLdsLayer
>
{})),
make_pass_through_transform
(
K1
)),
make_tuple
(
sequence
<
1
,
0
>
{},
sequence
<
2
>
{}),
make_tuple
(
sequence
<
1
,
0
>
{},
sequence
<
2
>
{}));
constexpr
auto
a_lds_block_desc_ak0_kMLdsLayer_m_ak1
=
transform_tensor_descriptor
(
a_lds_block_desc_permuted
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
K0
,
number
<
MLdsLayer
>
{})),
make_pass_through_transform
(
number
<
MPerBlock
/
MLdsLayer
>
{}),
make_pass_through_transform
(
K1
)),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{},
sequence
<
2
>
{}),
make_tuple
(
sequence
<
0
,
2
>
{},
sequence
<
1
>
{},
sequence
<
3
>
{}));
constexpr
auto
a_lds_block_desc_m_k
=
transform_tensor_descriptor
(
a_lds_block_desc_ak0_kMLdsLayer_m_ak1
,
make_tuple
(
make_merge_transform_v3_division_mod
(
make_tuple
(
K0
,
K1
)),
make_merge_transform_v3_division_mod
(
make_tuple
(
number
<
MPerBlock
/
MLdsLayer
>
{},
number
<
MLdsLayer
>
{}))),
make_tuple
(
sequence
<
0
,
3
>
{},
sequence
<
1
,
2
>
{}),
make_tuple
(
sequence
<
1
>
{},
sequence
<
0
>
{}));
return
a_lds_block_desc_m_k
;
}
else
// ColumnMajor A
{
// kfold and mpair dimension is not always required.
// more dimension in merge_transform increase the difficulty of generating immarg offset
// for compiler.
constexpr
auto
M0
=
get_warp_size
()
*
Problem
::
BlockGemmShape
::
BlockWarps
::
at
(
I0
);
constexpr
auto
M1
=
MPerBlock
/
M0
;
constexpr
auto
KThreadWrite
=
Problem
::
kBlockSize
/
M0
;
constexpr
auto
K0PerThreadWrite
=
K0
/
KThreadWrite
;
constexpr
auto
KThreadRead
=
64
/
WarpGemm
::
kM
;
constexpr
auto
K0PerThreadRead
=
K0
/
KThreadRead
;
constexpr
auto
kfold
=
(
K1
*
M0
*
sizeof
(
ADataType
)
>
128
)
?
1
:
128
/
(
K1
*
M0
*
sizeof
(
ADataType
));
constexpr
auto
KThreadReadPerm
=
(
kfold
*
K0PerThreadWrite
/
K0PerThreadRead
)
>
1
?
KThreadRead
/
(
kfold
*
K0PerThreadWrite
/
K0PerThreadRead
)
:
KThreadRead
;
// 1<=mpair<=kN0
constexpr
auto
mpair
=
(
K1
*
WarpGemm
::
kM
*
sizeof
(
ADataType
)
>
128
)
?
1
:
((
128
/
(
K1
*
WarpGemm
::
kM
*
sizeof
(
ADataType
)))
>
M0
?
M0
:
128
/
(
K1
*
WarpGemm
::
kM
*
sizeof
(
ADataType
)));
constexpr
auto
a_lds_block_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
number
<
KThreadWrite
/
kfold
/
KThreadReadPerm
>
{},
number
<
K0PerThreadWrite
>
{},
number
<
KThreadReadPerm
*
M1
>
{},
number
<
kfold
*
M0
/
mpair
>
{},
number
<
mpair
>
{},
K1
));
constexpr
auto
a_lds_block_desc_permuted
=
transform_tensor_descriptor
(
a_lds_block_desc
,
make_tuple
(
make_pass_through_transform
(
number
<
KThreadWrite
/
kfold
/
KThreadReadPerm
>
{}),
make_pass_through_transform
(
number
<
K0PerThreadWrite
>
{}),
make_xor_transform
(
make_tuple
(
number
<
KThreadReadPerm
*
M1
>
{},
number
<
kfold
*
M0
/
mpair
>
{})),
make_pass_through_transform
(
number
<
mpair
>
{}),
make_pass_through_transform
(
K1
)),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{},
sequence
<
2
,
3
>
{},
sequence
<
4
>
{},
sequence
<
5
>
{}),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{},
sequence
<
2
,
3
>
{},
sequence
<
4
>
{},
sequence
<
5
>
{}));
constexpr
auto
a_lds_block_desc_unmerged
=
transform_tensor_descriptor
(
a_lds_block_desc_permuted
,
make_tuple
(
make_pass_through_transform
(
number
<
KThreadWrite
/
kfold
/
KThreadReadPerm
>
{}),
make_pass_through_transform
(
number
<
K0PerThreadWrite
>
{}),
make_unmerge_transform
(
make_tuple
(
number
<
KThreadReadPerm
>
{},
number
<
M1
>
{})),
make_unmerge_transform
(
make_tuple
(
number
<
kfold
>
{},
number
<
M0
/
mpair
>
{})),
make_pass_through_transform
(
number
<
mpair
>
{}),
make_pass_through_transform
(
K1
)),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{},
sequence
<
2
>
{},
sequence
<
3
>
{},
sequence
<
4
>
{},
sequence
<
5
>
{}),
make_tuple
(
sequence
<
1
>
{},
sequence
<
2
>
{},
sequence
<
0
,
3
>
{},
sequence
<
4
,
5
>
{},
sequence
<
6
>
{},
sequence
<
7
>
{}));
constexpr
auto
a_lds_block_desc_m_k
=
transform_tensor_descriptor
(
a_lds_block_desc_unmerged
,
make_tuple
(
make_merge_transform_v3_division_mod
(
make_tuple
(
number
<
KThreadReadPerm
>
{},
number
<
KThreadWrite
/
kfold
/
KThreadReadPerm
>
{},
number
<
kfold
>
{},
number
<
K0PerThreadWrite
>
{},
K1
)),
make_merge_transform_v3_division_mod
(
make_tuple
(
number
<
M0
/
mpair
>
{},
number
<
mpair
>
{},
number
<
M1
>
{}))),
make_tuple
(
sequence
<
0
,
1
,
4
,
2
,
7
>
{},
sequence
<
5
,
6
,
3
>
{}),
make_tuple
(
sequence
<
1
>
{},
sequence
<
0
>
{}));
return
a_lds_block_desc_m_k
;
}
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeBLdsBlockDescriptor
()
{
using
WarpGemm
=
WarpGemmMfmaDispatcher
<
typename
Problem
::
ADataType
,
typename
Problem
::
BDataType
,
typename
Problem
::
CDataType
,
Problem
::
BlockGemmShape
::
WarpTile
::
at
(
I0
),
Problem
::
BlockGemmShape
::
WarpTile
::
at
(
I1
),
Problem
::
BlockGemmShape
::
WarpTile
::
at
(
I2
),
TransposeC
>
;
using
BDataType
=
remove_cvref_t
<
typename
Problem
::
BDataType
>
;
constexpr
index_t
NPerBlock
=
Problem
::
BlockGemmShape
::
kN
;
constexpr
index_t
KPerBlock
=
Problem
::
BlockGemmShape
::
kK
;
constexpr
index_t
K1
=
WarpGemm
::
kK
;
constexpr
index_t
K0
=
KPerBlock
/
K1
;
if
constexpr
(
std
::
is_same
<
tensor_layout
::
gemm
::
ColumnMajor
,
LayoutB
>::
value
)
{
// NLdsLayer * K0 as logical Bank
constexpr
auto
NLdsLayer
=
32
*
4
/
KPerBlock
/
sizeof
(
BDataType
)
<
1
?
1
:
32
*
4
/
KPerBlock
/
sizeof
(
BDataType
);
;
constexpr
auto
b_lds_block_desc
=
make_naive_tensor_descriptor
(
make_tuple
(
K0
*
number
<
NLdsLayer
>
{},
number
<
NPerBlock
/
NLdsLayer
>
{},
K1
),
make_tuple
(
K1
,
number
<
KPerBlock
*
NLdsLayer
>
{},
I1
));
constexpr
auto
b_lds_block_desc_permuted
=
transform_tensor_descriptor
(
b_lds_block_desc
,
make_tuple
(
make_xor_transform
(
make_tuple
(
number
<
NPerBlock
/
NLdsLayer
>
{},
number
<
K0
*
NLdsLayer
>
{})),
make_pass_through_transform
(
K1
)),
make_tuple
(
sequence
<
1
,
0
>
{},
sequence
<
2
>
{}),
make_tuple
(
sequence
<
1
,
0
>
{},
sequence
<
2
>
{}));
constexpr
auto
b_lds_block_desc_bk0_kNLdsLayer_n_bk1
=
transform_tensor_descriptor
(
b_lds_block_desc_permuted
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
K0
,
number
<
NLdsLayer
>
{})),
make_pass_through_transform
(
number
<
NPerBlock
/
NLdsLayer
>
{}),
make_pass_through_transform
(
K1
)),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{},
sequence
<
2
>
{}),
make_tuple
(
sequence
<
0
,
2
>
{},
sequence
<
1
>
{},
sequence
<
3
>
{}));
constexpr
auto
b_lds_block_desc_n_k
=
transform_tensor_descriptor
(
b_lds_block_desc_bk0_kNLdsLayer_n_bk1
,
make_tuple
(
make_merge_transform_v3_division_mod
(
make_tuple
(
K0
,
K1
)),
make_merge_transform_v3_division_mod
(
make_tuple
(
number
<
NPerBlock
/
NLdsLayer
>
{},
number
<
NLdsLayer
>
{}))),
make_tuple
(
sequence
<
0
,
3
>
{},
sequence
<
1
,
2
>
{}),
make_tuple
(
sequence
<
1
>
{},
sequence
<
0
>
{}));
return
b_lds_block_desc_n_k
;
}
else
// RowMajor B
{
constexpr
auto
N0
=
get_warp_size
()
*
Problem
::
BlockGemmShape
::
BlockWarps
::
at
(
I1
);
constexpr
auto
N1
=
NPerBlock
/
N0
;
constexpr
auto
KThreadWrite
=
Problem
::
kBlockSize
/
N0
;
constexpr
auto
K0PerThreadWrite
=
K0
/
KThreadWrite
;
constexpr
auto
KThreadRead
=
64
/
WarpGemm
::
kN
;
constexpr
auto
K0PerThreadRead
=
K0
/
KThreadRead
;
constexpr
auto
kfold
=
(
K1
*
N0
*
sizeof
(
BDataType
)
>
128
)
?
1
:
128
/
(
K1
*
N0
*
sizeof
(
BDataType
));
constexpr
auto
KThreadReadPerm
=
(
kfold
*
K0PerThreadWrite
/
K0PerThreadRead
)
>
1
?
KThreadRead
/
(
kfold
*
K0PerThreadWrite
/
K0PerThreadRead
)
:
KThreadRead
;
// 1<=npair<=kN0
constexpr
auto
npair
=
(
K1
*
WarpGemm
::
kN
*
sizeof
(
BDataType
)
>
128
)
?
1
:
((
128
/
(
K1
*
WarpGemm
::
kN
*
sizeof
(
BDataType
)))
>
N0
?
N0
:
128
/
(
K1
*
WarpGemm
::
kN
*
sizeof
(
BDataType
)));
constexpr
auto
b_lds_block_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
number
<
KThreadWrite
/
kfold
/
KThreadReadPerm
>
{},
number
<
K0PerThreadWrite
>
{},
number
<
KThreadReadPerm
*
N1
>
{},
number
<
kfold
*
N0
/
npair
>
{},
number
<
npair
>
{},
K1
));
constexpr
auto
b_lds_block_desc_permuted
=
transform_tensor_descriptor
(
b_lds_block_desc
,
make_tuple
(
make_pass_through_transform
(
number
<
KThreadWrite
/
kfold
/
KThreadReadPerm
>
{}),
make_pass_through_transform
(
number
<
K0PerThreadWrite
>
{}),
make_xor_transform
(
make_tuple
(
number
<
KThreadReadPerm
*
N1
>
{},
number
<
kfold
*
N0
/
npair
>
{})),
make_pass_through_transform
(
number
<
npair
>
{}),
make_pass_through_transform
(
K1
)),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{},
sequence
<
2
,
3
>
{},
sequence
<
4
>
{},
sequence
<
5
>
{}),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{},
sequence
<
2
,
3
>
{},
sequence
<
4
>
{},
sequence
<
5
>
{}));
constexpr
auto
b_lds_block_desc_unmerged
=
transform_tensor_descriptor
(
b_lds_block_desc_permuted
,
make_tuple
(
make_pass_through_transform
(
number
<
KThreadWrite
/
kfold
/
KThreadReadPerm
>
{}),
make_pass_through_transform
(
number
<
K0PerThreadWrite
>
{}),
make_unmerge_transform
(
make_tuple
(
number
<
KThreadReadPerm
>
{},
number
<
N1
>
{})),
make_unmerge_transform
(
make_tuple
(
number
<
kfold
>
{},
number
<
N0
/
npair
>
{})),
make_pass_through_transform
(
number
<
npair
>
{}),
make_pass_through_transform
(
K1
)),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{},
sequence
<
2
>
{},
sequence
<
3
>
{},
sequence
<
4
>
{},
sequence
<
5
>
{}),
make_tuple
(
sequence
<
1
>
{},
sequence
<
2
>
{},
sequence
<
0
,
3
>
{},
sequence
<
4
,
5
>
{},
sequence
<
6
>
{},
sequence
<
7
>
{}));
constexpr
auto
b_lds_block_desc_n_k
=
transform_tensor_descriptor
(
b_lds_block_desc_unmerged
,
make_tuple
(
make_merge_transform_v3_division_mod
(
make_tuple
(
number
<
KThreadReadPerm
>
{},
number
<
KThreadWrite
/
kfold
/
KThreadReadPerm
>
{},
number
<
kfold
>
{},
number
<
K0PerThreadWrite
>
{},
K1
)),
make_merge_transform_v3_division_mod
(
make_tuple
(
number
<
N0
/
npair
>
{},
number
<
npair
>
{},
number
<
N1
>
{}))),
make_tuple
(
sequence
<
0
,
1
,
4
,
2
,
7
>
{},
sequence
<
5
,
6
,
3
>
{}),
make_tuple
(
sequence
<
1
>
{},
sequence
<
0
>
{}));
return
b_lds_block_desc_n_k
;
}
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetSmemSizeA
()
{
constexpr
index_t
smem_size_a
=
sizeof
(
typename
Problem
::
ADataType
)
*
MakeALdsBlockDescriptor
<
Problem
>
().
get_element_space_size
();
return
smem_size_a
;
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetSmemSizeB
()
{
constexpr
index_t
smem_size_b
=
sizeof
(
typename
Problem
::
BDataType
)
*
MakeBLdsBlockDescriptor
<
Problem
>
().
get_element_space_size
();
return
smem_size_b
;
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetSmemSize
()
{
constexpr
index_t
smem_size_a
=
GetSmemSizeA
<
Problem
>
();
constexpr
index_t
smem_size_b
=
GetSmemSizeB
<
Problem
>
();
index_t
smem_size
=
0
;
smem_size
+=
smem_size_a
+
smem_size_b
;
return
smem_size
;
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeADramTileDistribution
()
{
using
WarpGemm
=
WarpGemmMfmaDispatcher
<
typename
Problem
::
ADataType
,
typename
Problem
::
BDataType
,
typename
Problem
::
CDataType
,
Problem
::
BlockGemmShape
::
WarpTile
::
at
(
I0
),
Problem
::
BlockGemmShape
::
WarpTile
::
at
(
I1
),
Problem
::
BlockGemmShape
::
WarpTile
::
at
(
I2
),
TransposeC
>
;
constexpr
index_t
BlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
MPerBlock
=
Problem
::
BlockGemmShape
::
kM
;
constexpr
index_t
KPerBlock
=
Problem
::
BlockGemmShape
::
kK
;
constexpr
index_t
K1
=
WarpGemm
::
kK
;
constexpr
index_t
K0
=
KPerBlock
/
K1
;
constexpr
index_t
M2
=
get_warp_size
()
/
K0
;
constexpr
index_t
M1
=
BlockSize
/
get_warp_size
();
static_assert
(
M2
!=
0
,
"M2 is zero, which will lead to a division by zero error."
);
static_assert
(
M1
!=
0
,
"M1 is zero, which will lead to a division by zero error."
);
constexpr
index_t
M0
=
MPerBlock
/
(
M2
*
M1
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
M0
,
M1
,
M2
>
,
sequence
<
K0
,
K1
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
2
,
0
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
1
>>
{});
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeBDramTileDistribution
()
{
using
WarpGemm
=
WarpGemmMfmaDispatcher
<
typename
Problem
::
ADataType
,
typename
Problem
::
BDataType
,
typename
Problem
::
CDataType
,
Problem
::
BlockGemmShape
::
WarpTile
::
at
(
I0
),
Problem
::
BlockGemmShape
::
WarpTile
::
at
(
I1
),
Problem
::
BlockGemmShape
::
WarpTile
::
at
(
I2
),
TransposeC
>
;
constexpr
index_t
BlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
NPerBlock
=
Problem
::
BlockGemmShape
::
kN
;
constexpr
index_t
KPerBlock
=
Problem
::
BlockGemmShape
::
kK
;
constexpr
index_t
K1
=
WarpGemm
::
kK
;
constexpr
index_t
K0
=
KPerBlock
/
K1
;
constexpr
index_t
N2
=
get_warp_size
()
/
K0
;
constexpr
index_t
N1
=
BlockSize
/
get_warp_size
();
static_assert
(
N2
!=
0
,
"M2 is zero, which will lead to a division by zero error."
);
static_assert
(
N1
!=
0
,
"M1 is zero, which will lead to a division by zero error."
);
constexpr
index_t
N0
=
NPerBlock
/
(
N2
*
N1
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
N0
,
N1
,
N2
>
,
sequence
<
K0
,
K1
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
2
,
0
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
1
>>
{});
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetBlockGemm
()
{
using
AccDataType
=
float
;
using
BlockWarps
=
typename
Problem
::
BlockGemmShape
::
BlockWarps
;
using
WarpTile
=
typename
Problem
::
BlockGemmShape
::
WarpTile
;
using
WarpGemm
=
WarpGemmMfmaDispatcher
<
typename
Problem
::
ADataType
,
typename
Problem
::
BDataType
,
AccDataType
,
WarpTile
::
at
(
I0
),
WarpTile
::
at
(
I1
),
WarpTile
::
at
(
I2
),
TransposeC
>
;
using
BlockGemmPolicy
=
BlockGemmASmemBSmemCRegV1CustomPolicy
<
typename
Problem
::
ADataType
,
typename
Problem
::
BDataType
,
typename
Problem
::
CDataType
,
BlockWarps
,
WarpGemm
>
;
return
BlockGemmASmemBSmemCRegV1
<
Problem
,
BlockGemmPolicy
>
{};
}
};
}
// namespace ck_tile
library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn.hpp
View file @
dc1c2bf8
...
...
@@ -46,7 +46,7 @@ using device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_instances = std::tuple<
DeviceGemm_Xdl_CShuffleV3
<
Row
,
Row
,
Row
,
F8
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
192
,
256
,
64
,
16
,
8
,
32
,
32
,
3
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
>
,
DeviceGemm_Xdl_CShuffleV3
<
Row
,
Row
,
Row
,
F8
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
128
,
128
,
128
,
16
,
8
,
32
,
32
,
2
,
2
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
>
,
DeviceGemm_Xdl_CShuffleV3
<
Row
,
Row
,
Row
,
F8
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
128
,
128
,
64
,
16
,
4
,
32
,
32
,
2
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
8
,
4
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
>
,
DeviceGemm_Xdl_CShuffleV3
<
Row
,
Row
,
Row
,
F8
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
128
,
128
,
64
,
16
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v5
>
,
//
DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 16, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>,
// We prefer following instance, however, existing compiler bug cause it failed to generate sanity code.
// DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 16, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>,
DeviceGemm_Xdl_CShuffleV3
<
Row
,
Row
,
Row
,
F8
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
128
,
128
,
64
,
16
,
4
,
32
,
32
,
2
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
8
,
4
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
BlockGemmPipelineScheduler
::
Interwave
,
BlockGemmPipelineVersion
::
v1
>
...
...
test/data_type/CMakeLists.txt
View file @
dc1c2bf8
...
...
@@ -18,4 +18,9 @@ if(result EQUAL 0)
target_link_libraries
(
test_bf8 PRIVATE utility
)
endif
()
add_gtest_executable
(
test_custom_type test_custom_type.cpp
)
if
(
result EQUAL 0
)
target_link_libraries
(
test_custom_type PRIVATE utility
)
endif
()
add_gtest_executable
(
test_type_convert_const type_convert_const.cpp
)
test/data_type/test_custom_type.cpp
0 → 100644
View file @
dc1c2bf8
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "gtest/gtest.h"
#include "ck/utility/data_type.hpp"
#include "ck/utility/type_convert.hpp"
using
ck
::
bf8_t
;
using
ck
::
bhalf_t
;
using
ck
::
f8_t
;
using
ck
::
half_t
;
using
ck
::
Number
;
using
ck
::
type_convert
;
using
ck
::
vector_type
;
TEST
(
Custom_bool
,
TestSize
)
{
struct
custom_bool_t
{
bool
data
;
};
ASSERT_EQ
(
sizeof
(
custom_bool_t
),
sizeof
(
bool
));
ASSERT_EQ
(
sizeof
(
vector_type
<
custom_bool_t
,
2
>
),
sizeof
(
vector_type
<
bool
,
2
>
));
ASSERT_EQ
(
sizeof
(
vector_type
<
custom_bool_t
,
4
>
),
sizeof
(
vector_type
<
bool
,
4
>
));
ASSERT_EQ
(
sizeof
(
vector_type
<
custom_bool_t
,
8
>
),
sizeof
(
vector_type
<
bool
,
8
>
));
ASSERT_EQ
(
sizeof
(
vector_type
<
custom_bool_t
,
16
>
),
sizeof
(
vector_type
<
bool
,
16
>
));
ASSERT_EQ
(
sizeof
(
vector_type
<
custom_bool_t
,
32
>
),
sizeof
(
vector_type
<
bool
,
32
>
));
ASSERT_EQ
(
sizeof
(
vector_type
<
custom_bool_t
,
64
>
),
sizeof
(
vector_type
<
bool
,
64
>
));
}
TEST
(
Custom_bool
,
TestAsType
)
{
struct
custom_bool_t
{
using
type
=
bool
;
type
data
;
custom_bool_t
()
:
data
{
type
{}}
{}
custom_bool_t
(
type
init
)
:
data
{
init
}
{}
};
// test size
const
int
size
=
4
;
std
::
vector
<
bool
>
test_vec
=
{
false
,
true
,
false
,
true
};
// reference vector
vector_type
<
custom_bool_t
,
size
>
right_vec
;
// check default CTOR
ck
::
static_for
<
0
,
size
,
1
>
{}([
&
](
auto
i
)
{
ASSERT_EQ
(
right_vec
.
template
AsType
<
custom_bool_t
>()(
Number
<
i
>
{}).
data
,
false
);
});
// assign test values to the vector
ck
::
static_for
<
0
,
size
,
1
>
{}([
&
](
auto
i
)
{
right_vec
.
template
AsType
<
custom_bool_t
>()(
Number
<
i
>
{})
=
custom_bool_t
{
test_vec
.
at
(
i
)};
});
// copy the vector
vector_type
<
custom_bool_t
,
size
>
left_vec
{
right_vec
};
// check if values were copied correctly
ck
::
static_for
<
0
,
size
,
1
>
{}([
&
](
auto
i
)
{
ASSERT_EQ
(
left_vec
.
template
AsType
<
custom_bool_t
>()(
Number
<
i
>
{}).
data
,
test_vec
.
at
(
i
));
});
}
TEST
(
Custom_bool
,
TestAsTypeReshape
)
{
struct
custom_bool_t
{
using
type
=
bool
;
type
data
;
custom_bool_t
()
:
data
{
type
{}}
{}
custom_bool_t
(
type
init
)
:
data
{
init
}
{}
};
// test size
const
int
size
=
4
;
std
::
vector
<
bool
>
test_vec
=
{
false
,
true
,
false
,
true
};
// reference vector
vector_type
<
custom_bool_t
,
size
>
right_vec
;
// check default CTOR
ck
::
static_for
<
0
,
size
,
1
>
{}([
&
](
auto
i
)
{
ASSERT_EQ
(
right_vec
.
template
AsType
<
custom_bool_t
>()(
Number
<
i
>
{}).
data
,
false
);
});
// assign test values to the vector
ck
::
static_for
<
0
,
size
,
1
>
{}([
&
](
auto
i
)
{
right_vec
.
template
AsType
<
custom_bool_t
>()(
Number
<
i
>
{})
=
custom_bool_t
{
test_vec
.
at
(
i
)};
});
// copy the first half of a vector
vector_type
<
custom_bool_t
,
size
/
2
>
left_vec
{
right_vec
.
template
AsType
<
vector_type
<
custom_bool_t
,
size
/
2
>
::
type
>
()(
Number
<
0
>
{})};
// check if values were copied correctly
ck
::
static_for
<
0
,
size
/
2
,
1
>
{}([
&
](
auto
i
)
{
ASSERT_EQ
(
left_vec
.
template
AsType
<
custom_bool_t
>()(
Number
<
i
>
{}).
data
,
test_vec
.
at
(
i
));
});
}
TEST
(
Custom_int8
,
TestSize
)
{
struct
custom_int8_t
{
int8_t
data
;
};
ASSERT_EQ
(
sizeof
(
custom_int8_t
),
sizeof
(
int8_t
));
ASSERT_EQ
(
sizeof
(
vector_type
<
custom_int8_t
,
2
>
),
sizeof
(
vector_type
<
int8_t
,
2
>
));
ASSERT_EQ
(
sizeof
(
vector_type
<
custom_int8_t
,
4
>
),
sizeof
(
vector_type
<
int8_t
,
4
>
));
ASSERT_EQ
(
sizeof
(
vector_type
<
custom_int8_t
,
8
>
),
sizeof
(
vector_type
<
int8_t
,
8
>
));
ASSERT_EQ
(
sizeof
(
vector_type
<
custom_int8_t
,
16
>
),
sizeof
(
vector_type
<
int8_t
,
16
>
));
ASSERT_EQ
(
sizeof
(
vector_type
<
custom_int8_t
,
32
>
),
sizeof
(
vector_type
<
int8_t
,
32
>
));
ASSERT_EQ
(
sizeof
(
vector_type
<
custom_int8_t
,
64
>
),
sizeof
(
vector_type
<
int8_t
,
64
>
));
}
TEST
(
Custom_int8
,
TestAsType
)
{
struct
custom_int8_t
{
using
type
=
int8_t
;
type
data
;
custom_int8_t
()
:
data
{
type
{}}
{}
custom_int8_t
(
type
init
)
:
data
{
init
}
{}
};
// test size
const
int
size
=
4
;
std
::
vector
<
int8_t
>
test_vec
=
{
3
,
-
6
,
8
,
-
2
};
// reference vector
vector_type
<
custom_int8_t
,
size
>
right_vec
;
// check default CTOR
ck
::
static_for
<
0
,
size
,
1
>
{}([
&
](
auto
i
)
{
ASSERT_EQ
(
right_vec
.
template
AsType
<
custom_int8_t
>()(
Number
<
i
>
{}).
data
,
0
);
});
// assign test values to the vector
ck
::
static_for
<
0
,
size
,
1
>
{}([
&
](
auto
i
)
{
right_vec
.
template
AsType
<
custom_int8_t
>()(
Number
<
i
>
{})
=
custom_int8_t
{
test_vec
.
at
(
i
)};
});
// copy the vector
vector_type
<
custom_int8_t
,
size
>
left_vec
{
right_vec
};
// check if values were copied correctly
ck
::
static_for
<
0
,
size
,
1
>
{}([
&
](
auto
i
)
{
ASSERT_EQ
(
left_vec
.
template
AsType
<
custom_int8_t
>()(
Number
<
i
>
{}).
data
,
test_vec
.
at
(
i
));
});
}
TEST
(
Custom_int8
,
TestAsTypeReshape
)
{
struct
custom_int8_t
{
using
type
=
int8_t
;
type
data
;
custom_int8_t
()
:
data
{
type
{}}
{}
custom_int8_t
(
type
init
)
:
data
{
init
}
{}
};
// test size
const
int
size
=
4
;
std
::
vector
<
int8_t
>
test_vec
=
{
3
,
-
6
,
8
,
-
2
};
// reference vector
vector_type
<
custom_int8_t
,
size
>
right_vec
;
// check default CTOR
ck
::
static_for
<
0
,
size
,
1
>
{}([
&
](
auto
i
)
{
ASSERT_EQ
(
right_vec
.
template
AsType
<
custom_int8_t
>()(
Number
<
i
>
{}).
data
,
0
);
});
// assign test values to the vector
ck
::
static_for
<
0
,
size
,
1
>
{}([
&
](
auto
i
)
{
right_vec
.
template
AsType
<
custom_int8_t
>()(
Number
<
i
>
{})
=
custom_int8_t
{
test_vec
.
at
(
i
)};
});
// copy the first half of a vector
vector_type
<
custom_int8_t
,
size
/
2
>
left_vec
{
right_vec
.
template
AsType
<
vector_type
<
custom_int8_t
,
size
/
2
>
::
type
>
()(
Number
<
0
>
{})};
// check if values were copied correctly
ck
::
static_for
<
0
,
size
/
2
,
1
>
{}([
&
](
auto
i
)
{
ASSERT_EQ
(
left_vec
.
template
AsType
<
custom_int8_t
>()(
Number
<
i
>
{}).
data
,
test_vec
.
at
(
i
));
});
}
TEST
(
Custom_uint8
,
TestSize
)
{
struct
custom_uint8_t
{
uint8_t
data
;
};
ASSERT_EQ
(
sizeof
(
custom_uint8_t
),
sizeof
(
uint8_t
));
ASSERT_EQ
(
sizeof
(
vector_type
<
custom_uint8_t
,
2
>
),
sizeof
(
vector_type
<
uint8_t
,
2
>
));
ASSERT_EQ
(
sizeof
(
vector_type
<
custom_uint8_t
,
4
>
),
sizeof
(
vector_type
<
uint8_t
,
4
>
));
ASSERT_EQ
(
sizeof
(
vector_type
<
custom_uint8_t
,
8
>
),
sizeof
(
vector_type
<
uint8_t
,
8
>
));
ASSERT_EQ
(
sizeof
(
vector_type
<
custom_uint8_t
,
16
>
),
sizeof
(
vector_type
<
uint8_t
,
16
>
));
ASSERT_EQ
(
sizeof
(
vector_type
<
custom_uint8_t
,
32
>
),
sizeof
(
vector_type
<
uint8_t
,
32
>
));
ASSERT_EQ
(
sizeof
(
vector_type
<
custom_uint8_t
,
64
>
),
sizeof
(
vector_type
<
uint8_t
,
64
>
));
}
TEST
(
Custom_uint8
,
TestAsType
)
{
struct
custom_uint8_t
{
using
type
=
uint8_t
;
type
data
;
custom_uint8_t
()
:
data
{
type
{}}
{}
custom_uint8_t
(
type
init
)
:
data
{
init
}
{}
};
// test size
const
int
size
=
4
;
std
::
vector
<
uint8_t
>
test_vec
=
{
3
,
6
,
8
,
2
};
// reference vector
vector_type
<
custom_uint8_t
,
size
>
right_vec
;
// check default CTOR
ck
::
static_for
<
0
,
size
,
1
>
{}([
&
](
auto
i
)
{
ASSERT_EQ
(
right_vec
.
template
AsType
<
custom_uint8_t
>()(
Number
<
i
>
{}).
data
,
0
);
});
// assign test values to the vector
ck
::
static_for
<
0
,
size
,
1
>
{}([
&
](
auto
i
)
{
right_vec
.
template
AsType
<
custom_uint8_t
>()(
Number
<
i
>
{})
=
custom_uint8_t
{
test_vec
.
at
(
i
)};
});
// copy the vector
vector_type
<
custom_uint8_t
,
size
>
left_vec
{
right_vec
};
// check if values were copied correctly
ck
::
static_for
<
0
,
size
,
1
>
{}([
&
](
auto
i
)
{
ASSERT_EQ
(
left_vec
.
template
AsType
<
custom_uint8_t
>()(
Number
<
i
>
{}).
data
,
test_vec
.
at
(
i
));
});
}
TEST
(
Custom_uint8
,
TestAsTypeReshape
)
{
struct
custom_uint8_t
{
using
type
=
uint8_t
;
type
data
;
custom_uint8_t
()
:
data
{
type
{}}
{}
custom_uint8_t
(
type
init
)
:
data
{
init
}
{}
};
// test size
const
int
size
=
4
;
std
::
vector
<
uint8_t
>
test_vec
=
{
3
,
6
,
8
,
2
};
// reference vector
vector_type
<
custom_uint8_t
,
size
>
right_vec
;
// check default CTOR
ck
::
static_for
<
0
,
size
,
1
>
{}([
&
](
auto
i
)
{
ASSERT_EQ
(
right_vec
.
template
AsType
<
custom_uint8_t
>()(
Number
<
i
>
{}).
data
,
0
);
});
// assign test values to the vector
ck
::
static_for
<
0
,
size
,
1
>
{}([
&
](
auto
i
)
{
right_vec
.
template
AsType
<
custom_uint8_t
>()(
Number
<
i
>
{})
=
custom_uint8_t
{
test_vec
.
at
(
i
)};
});
// copy the first half of a vector
vector_type
<
custom_uint8_t
,
size
/
2
>
left_vec
{
right_vec
.
template
AsType
<
vector_type
<
custom_uint8_t
,
size
/
2
>
::
type
>
()(
Number
<
0
>
{})};
// check if values were copied correctly
ck
::
static_for
<
0
,
size
/
2
,
1
>
{}([
&
](
auto
i
)
{
ASSERT_EQ
(
left_vec
.
template
AsType
<
custom_uint8_t
>()(
Number
<
i
>
{}).
data
,
test_vec
.
at
(
i
));
});
}
TEST
(
Custom_f8
,
TestSize
)
{
struct
custom_f8_t
{
_BitInt
(
8
)
data
;
};
ASSERT_EQ
(
sizeof
(
custom_f8_t
),
sizeof
(
_BitInt
(
8
)));
ASSERT_EQ
(
sizeof
(
vector_type
<
custom_f8_t
,
2
>
),
sizeof
(
vector_type
<
_BitInt
(
8
),
2
>
));
ASSERT_EQ
(
sizeof
(
vector_type
<
custom_f8_t
,
4
>
),
sizeof
(
vector_type
<
_BitInt
(
8
),
4
>
));
ASSERT_EQ
(
sizeof
(
vector_type
<
custom_f8_t
,
8
>
),
sizeof
(
vector_type
<
_BitInt
(
8
),
8
>
));
ASSERT_EQ
(
sizeof
(
vector_type
<
custom_f8_t
,
16
>
),
sizeof
(
vector_type
<
_BitInt
(
8
),
16
>
));
ASSERT_EQ
(
sizeof
(
vector_type
<
custom_f8_t
,
32
>
),
sizeof
(
vector_type
<
_BitInt
(
8
),
32
>
));
ASSERT_EQ
(
sizeof
(
vector_type
<
custom_f8_t
,
64
>
),
sizeof
(
vector_type
<
_BitInt
(
8
),
64
>
));
}
TEST
(
Custom_f8
,
TestAsType
)
{
struct
custom_f8_t
{
using
type
=
_BitInt
(
8
);
type
data
;
custom_f8_t
()
:
data
{
type
{}}
{}
custom_f8_t
(
type
init
)
:
data
{
init
}
{}
};
// test size
const
int
size
=
4
;
std
::
vector
<
_BitInt
(
8
)
>
test_vec
=
{
type_convert
<
_BitInt
(
8
)
>
(
0.3
f
),
type_convert
<
_BitInt
(
8
)
>
(
-
0.6
f
),
type_convert
<
_BitInt
(
8
)
>
(
0.8
f
),
type_convert
<
_BitInt
(
8
)
>
(
-
0.2
f
)};
// reference vector
vector_type
<
custom_f8_t
,
size
>
right_vec
;
// check default CTOR
ck
::
static_for
<
0
,
size
,
1
>
{}(
[
&
](
auto
i
)
{
ASSERT_EQ
(
right_vec
.
template
AsType
<
custom_f8_t
>()(
Number
<
i
>
{}).
data
,
0
);
});
// assign test values to the vector
ck
::
static_for
<
0
,
size
,
1
>
{}([
&
](
auto
i
)
{
right_vec
.
template
AsType
<
custom_f8_t
>()(
Number
<
i
>
{})
=
custom_f8_t
{
test_vec
.
at
(
i
)};
});
// copy the vector
vector_type
<
custom_f8_t
,
size
>
left_vec
{
right_vec
};
// check if values were copied correctly
ck
::
static_for
<
0
,
size
,
1
>
{}([
&
](
auto
i
)
{
ASSERT_EQ
(
left_vec
.
template
AsType
<
custom_f8_t
>()(
Number
<
i
>
{}).
data
,
test_vec
.
at
(
i
));
});
}
TEST
(
Custom_f8
,
TestAsTypeReshape
)
{
struct
custom_f8_t
{
using
type
=
_BitInt
(
8
);
type
data
;
custom_f8_t
()
:
data
{
type
{}}
{}
custom_f8_t
(
type
init
)
:
data
{
init
}
{}
};
// test size
const
int
size
=
4
;
std
::
vector
<
_BitInt
(
8
)
>
test_vec
=
{
type_convert
<
_BitInt
(
8
)
>
(
0.3
f
),
type_convert
<
_BitInt
(
8
)
>
(
-
0.6
f
),
type_convert
<
_BitInt
(
8
)
>
(
0.8
f
),
type_convert
<
_BitInt
(
8
)
>
(
-
0.2
f
)};
// reference vector
vector_type
<
custom_f8_t
,
size
>
right_vec
;
// check default CTOR
ck
::
static_for
<
0
,
size
,
1
>
{}(
[
&
](
auto
i
)
{
ASSERT_EQ
(
right_vec
.
template
AsType
<
custom_f8_t
>()(
Number
<
i
>
{}).
data
,
0
);
});
// assign test values to the vector
ck
::
static_for
<
0
,
size
,
1
>
{}([
&
](
auto
i
)
{
right_vec
.
template
AsType
<
custom_f8_t
>()(
Number
<
i
>
{})
=
custom_f8_t
{
test_vec
.
at
(
i
)};
});
// copy the first half of a vector
vector_type
<
custom_f8_t
,
size
/
2
>
left_vec
{
right_vec
.
template
AsType
<
vector_type
<
custom_f8_t
,
size
/
2
>
::
type
>
()(
Number
<
0
>
{})};
// check if values were copied correctly
ck
::
static_for
<
0
,
size
/
2
,
1
>
{}([
&
](
auto
i
)
{
ASSERT_EQ
(
left_vec
.
template
AsType
<
custom_f8_t
>()(
Number
<
i
>
{}).
data
,
test_vec
.
at
(
i
));
});
}
TEST
(
Custom_bf8
,
TestSize
)
{
struct
custom_bf8_t
{
unsigned
_BitInt
(
8
)
data
;
};
ASSERT_EQ
(
sizeof
(
custom_bf8_t
),
sizeof
(
unsigned
_BitInt
(
8
)));
ASSERT_EQ
(
sizeof
(
vector_type
<
custom_bf8_t
,
2
>
),
sizeof
(
vector_type
<
unsigned
_BitInt
(
8
),
2
>
));
ASSERT_EQ
(
sizeof
(
vector_type
<
custom_bf8_t
,
4
>
),
sizeof
(
vector_type
<
unsigned
_BitInt
(
8
),
4
>
));
ASSERT_EQ
(
sizeof
(
vector_type
<
custom_bf8_t
,
8
>
),
sizeof
(
vector_type
<
unsigned
_BitInt
(
8
),
8
>
));
ASSERT_EQ
(
sizeof
(
vector_type
<
custom_bf8_t
,
16
>
),
sizeof
(
vector_type
<
unsigned
_BitInt
(
8
),
16
>
));
ASSERT_EQ
(
sizeof
(
vector_type
<
custom_bf8_t
,
32
>
),
sizeof
(
vector_type
<
unsigned
_BitInt
(
8
),
32
>
));
ASSERT_EQ
(
sizeof
(
vector_type
<
custom_bf8_t
,
64
>
),
sizeof
(
vector_type
<
unsigned
_BitInt
(
8
),
64
>
));
}
TEST
(
Custom_bf8
,
TestAsType
)
{
struct
custom_bf8_t
{
using
type
=
unsigned
_BitInt
(
8
);
type
data
;
custom_bf8_t
()
:
data
{
type
{}}
{}
custom_bf8_t
(
type
init
)
:
data
{
init
}
{}
};
// test size
const
int
size
=
4
;
std
::
vector
<
unsigned
_BitInt
(
8
)
>
test_vec
=
{
type_convert
<
unsigned
_BitInt
(
8
)
>
(
0.3
f
),
type_convert
<
unsigned
_BitInt
(
8
)
>
(
-
0.6
f
),
type_convert
<
unsigned
_BitInt
(
8
)
>
(
0.8
f
),
type_convert
<
unsigned
_BitInt
(
8
)
>
(
-
0.2
f
)};
// reference vector
vector_type
<
custom_bf8_t
,
size
>
right_vec
;
// check default CTOR
ck
::
static_for
<
0
,
size
,
1
>
{}(
[
&
](
auto
i
)
{
ASSERT_EQ
(
right_vec
.
template
AsType
<
custom_bf8_t
>()(
Number
<
i
>
{}).
data
,
0
);
});
// assign test values to the vector
ck
::
static_for
<
0
,
size
,
1
>
{}([
&
](
auto
i
)
{
right_vec
.
template
AsType
<
custom_bf8_t
>()(
Number
<
i
>
{})
=
custom_bf8_t
{
test_vec
.
at
(
i
)};
});
// copy the vector
vector_type
<
custom_bf8_t
,
size
>
left_vec
{
right_vec
};
// check if values were copied correctly
ck
::
static_for
<
0
,
size
,
1
>
{}([
&
](
auto
i
)
{
ASSERT_EQ
(
left_vec
.
template
AsType
<
custom_bf8_t
>()(
Number
<
i
>
{}).
data
,
test_vec
.
at
(
i
));
});
}
TEST
(
Custom_bf8
,
TestAsTypeReshape
)
{
struct
custom_bf8_t
{
using
type
=
unsigned
_BitInt
(
8
);
type
data
;
custom_bf8_t
()
:
data
{
type
{}}
{}
custom_bf8_t
(
type
init
)
:
data
{
init
}
{}
};
// test size
const
int
size
=
4
;
std
::
vector
<
unsigned
_BitInt
(
8
)
>
test_vec
=
{
type_convert
<
unsigned
_BitInt
(
8
)
>
(
0.3
f
),
type_convert
<
unsigned
_BitInt
(
8
)
>
(
-
0.6
f
),
type_convert
<
unsigned
_BitInt
(
8
)
>
(
0.8
f
),
type_convert
<
unsigned
_BitInt
(
8
)
>
(
-
0.2
f
)};
// reference vector
vector_type
<
custom_bf8_t
,
size
>
right_vec
;
// check default CTOR
ck
::
static_for
<
0
,
size
,
1
>
{}(
[
&
](
auto
i
)
{
ASSERT_EQ
(
right_vec
.
template
AsType
<
custom_bf8_t
>()(
Number
<
i
>
{}).
data
,
0
);
});
// assign test values to the vector
ck
::
static_for
<
0
,
size
,
1
>
{}([
&
](
auto
i
)
{
right_vec
.
template
AsType
<
custom_bf8_t
>()(
Number
<
i
>
{})
=
custom_bf8_t
{
test_vec
.
at
(
i
)};
});
// copy the first half of a vector
vector_type
<
custom_bf8_t
,
size
/
2
>
left_vec
{
right_vec
.
template
AsType
<
vector_type
<
custom_bf8_t
,
size
/
2
>
::
type
>
()(
Number
<
0
>
{})};
// check if values were copied correctly
ck
::
static_for
<
0
,
size
/
2
,
1
>
{}([
&
](
auto
i
)
{
ASSERT_EQ
(
left_vec
.
template
AsType
<
custom_bf8_t
>()(
Number
<
i
>
{}).
data
,
test_vec
.
at
(
i
));
});
}
TEST
(
Custom_half
,
TestSize
)
{
struct
custom_half_t
{
half_t
data
;
};
ASSERT_EQ
(
sizeof
(
custom_half_t
),
sizeof
(
half_t
));
ASSERT_EQ
(
sizeof
(
vector_type
<
custom_half_t
,
2
>
),
sizeof
(
vector_type
<
half_t
,
2
>
));
ASSERT_EQ
(
sizeof
(
vector_type
<
custom_half_t
,
4
>
),
sizeof
(
vector_type
<
half_t
,
4
>
));
ASSERT_EQ
(
sizeof
(
vector_type
<
custom_half_t
,
8
>
),
sizeof
(
vector_type
<
half_t
,
8
>
));
ASSERT_EQ
(
sizeof
(
vector_type
<
custom_half_t
,
16
>
),
sizeof
(
vector_type
<
half_t
,
16
>
));
ASSERT_EQ
(
sizeof
(
vector_type
<
custom_half_t
,
32
>
),
sizeof
(
vector_type
<
half_t
,
32
>
));
ASSERT_EQ
(
sizeof
(
vector_type
<
custom_half_t
,
64
>
),
sizeof
(
vector_type
<
half_t
,
64
>
));
}
TEST
(
Custom_half
,
TestAsType
)
{
struct
custom_half_t
{
using
type
=
half_t
;
type
data
;
custom_half_t
()
:
data
{
type
{}}
{}
custom_half_t
(
type
init
)
:
data
{
init
}
{}
};
// test size
const
int
size
=
4
;
std
::
vector
<
half_t
>
test_vec
=
{
half_t
{
0.3
f
},
half_t
{
-
0.6
f
},
half_t
{
0.8
f
},
half_t
{
-
0.2
f
}};
// reference vector
vector_type
<
custom_half_t
,
size
>
right_vec
;
// check default CTOR
ck
::
static_for
<
0
,
size
,
1
>
{}([
&
](
auto
i
)
{
ASSERT_EQ
(
right_vec
.
template
AsType
<
custom_half_t
>()(
Number
<
i
>
{}).
data
,
type_convert
<
half_t
>
(
0.0
f
));
});
// assign test values to the vector
ck
::
static_for
<
0
,
size
,
1
>
{}([
&
](
auto
i
)
{
right_vec
.
template
AsType
<
custom_half_t
>()(
Number
<
i
>
{})
=
custom_half_t
{
test_vec
.
at
(
i
)};
});
// copy the vector
vector_type
<
custom_half_t
,
size
>
left_vec
{
right_vec
};
// check if values were copied correctly
ck
::
static_for
<
0
,
size
,
1
>
{}([
&
](
auto
i
)
{
ASSERT_EQ
(
left_vec
.
template
AsType
<
custom_half_t
>()(
Number
<
i
>
{}).
data
,
test_vec
.
at
(
i
));
});
}
TEST
(
Custom_half
,
TestAsTypeReshape
)
{
struct
custom_half_t
{
using
type
=
half_t
;
type
data
;
custom_half_t
()
:
data
{
type
{}}
{}
custom_half_t
(
type
init
)
:
data
{
init
}
{}
};
// test size
const
int
size
=
4
;
std
::
vector
<
half_t
>
test_vec
=
{
half_t
{
0.3
f
},
half_t
{
-
0.6
f
},
half_t
{
0.8
f
},
half_t
{
-
0.2
f
}};
// reference vector
vector_type
<
custom_half_t
,
size
>
right_vec
;
// check default CTOR
ck
::
static_for
<
0
,
size
,
1
>
{}([
&
](
auto
i
)
{
ASSERT_EQ
(
right_vec
.
template
AsType
<
custom_half_t
>()(
Number
<
i
>
{}).
data
,
type_convert
<
half_t
>
(
0.0
f
));
});
// assign test values to the vector
ck
::
static_for
<
0
,
size
,
1
>
{}([
&
](
auto
i
)
{
right_vec
.
template
AsType
<
custom_half_t
>()(
Number
<
i
>
{})
=
custom_half_t
{
test_vec
.
at
(
i
)};
});
// copy the first half of a vector
vector_type
<
custom_half_t
,
size
/
2
>
left_vec
{
right_vec
.
template
AsType
<
vector_type
<
custom_half_t
,
size
/
2
>
::
type
>
()(
Number
<
0
>
{})};
// check if values were copied correctly
ck
::
static_for
<
0
,
size
/
2
,
1
>
{}([
&
](
auto
i
)
{
ASSERT_EQ
(
left_vec
.
template
AsType
<
custom_half_t
>()(
Number
<
i
>
{}).
data
,
test_vec
.
at
(
i
));
});
}
TEST
(
Custom_bhalf
,
TestSize
)
{
struct
custom_bhalf_t
{
bhalf_t
data
;
};
ASSERT_EQ
(
sizeof
(
custom_bhalf_t
),
sizeof
(
bhalf_t
));
ASSERT_EQ
(
sizeof
(
vector_type
<
custom_bhalf_t
,
2
>
),
sizeof
(
vector_type
<
bhalf_t
,
2
>
));
ASSERT_EQ
(
sizeof
(
vector_type
<
custom_bhalf_t
,
4
>
),
sizeof
(
vector_type
<
bhalf_t
,
4
>
));
ASSERT_EQ
(
sizeof
(
vector_type
<
custom_bhalf_t
,
8
>
),
sizeof
(
vector_type
<
bhalf_t
,
8
>
));
ASSERT_EQ
(
sizeof
(
vector_type
<
custom_bhalf_t
,
16
>
),
sizeof
(
vector_type
<
bhalf_t
,
16
>
));
ASSERT_EQ
(
sizeof
(
vector_type
<
custom_bhalf_t
,
32
>
),
sizeof
(
vector_type
<
bhalf_t
,
32
>
));
ASSERT_EQ
(
sizeof
(
vector_type
<
custom_bhalf_t
,
64
>
),
sizeof
(
vector_type
<
bhalf_t
,
64
>
));
}
TEST
(
Custom_bhalf
,
TestAsType
)
{
struct
custom_bhalf_t
{
using
type
=
bhalf_t
;
type
data
;
custom_bhalf_t
()
:
data
{
type
{}}
{}
custom_bhalf_t
(
type
init
)
:
data
{
init
}
{}
};
// test size
const
int
size
=
4
;
std
::
vector
<
bhalf_t
>
test_vec
=
{
type_convert
<
bhalf_t
>
(
0.3
f
),
type_convert
<
bhalf_t
>
(
-
0.6
f
),
type_convert
<
bhalf_t
>
(
0.8
f
),
type_convert
<
bhalf_t
>
(
-
0.2
f
)};
// reference vector
vector_type
<
custom_bhalf_t
,
size
>
right_vec
;
// check default CTOR
ck
::
static_for
<
0
,
size
,
1
>
{}([
&
](
auto
i
)
{
ASSERT_EQ
(
right_vec
.
template
AsType
<
custom_bhalf_t
>()(
Number
<
i
>
{}).
data
,
type_convert
<
bhalf_t
>
(
0.0
f
));
});
// assign test values to the vector
ck
::
static_for
<
0
,
size
,
1
>
{}([
&
](
auto
i
)
{
right_vec
.
template
AsType
<
custom_bhalf_t
>()(
Number
<
i
>
{})
=
custom_bhalf_t
{
test_vec
.
at
(
i
)};
});
// copy the vector
vector_type
<
custom_bhalf_t
,
size
>
left_vec
{
right_vec
};
// check if values were copied correctly
ck
::
static_for
<
0
,
size
,
1
>
{}([
&
](
auto
i
)
{
ASSERT_EQ
(
left_vec
.
template
AsType
<
custom_bhalf_t
>()(
Number
<
i
>
{}).
data
,
test_vec
.
at
(
i
));
});
}
TEST
(
Custom_bhalf
,
TestAsTypeReshape
)
{
struct
custom_bhalf_t
{
using
type
=
bhalf_t
;
type
data
;
custom_bhalf_t
()
:
data
{
type
{}}
{}
custom_bhalf_t
(
type
init
)
:
data
{
init
}
{}
};
// test size
const
int
size
=
4
;
std
::
vector
<
bhalf_t
>
test_vec
=
{
type_convert
<
bhalf_t
>
(
0.3
f
),
type_convert
<
bhalf_t
>
(
-
0.6
f
),
type_convert
<
bhalf_t
>
(
0.8
f
),
type_convert
<
bhalf_t
>
(
-
0.2
f
)};
// reference vector
vector_type
<
custom_bhalf_t
,
size
>
right_vec
;
// check default CTOR
ck
::
static_for
<
0
,
size
,
1
>
{}([
&
](
auto
i
)
{
ASSERT_EQ
(
right_vec
.
template
AsType
<
custom_bhalf_t
>()(
Number
<
i
>
{}).
data
,
type_convert
<
bhalf_t
>
(
0.0
f
));
});
// assign test values to the vector
ck
::
static_for
<
0
,
size
,
1
>
{}([
&
](
auto
i
)
{
right_vec
.
template
AsType
<
custom_bhalf_t
>()(
Number
<
i
>
{})
=
custom_bhalf_t
{
test_vec
.
at
(
i
)};
});
// copy the first half of a vector
vector_type
<
custom_bhalf_t
,
size
/
2
>
left_vec
{
right_vec
.
template
AsType
<
vector_type
<
custom_bhalf_t
,
size
/
2
>
::
type
>
()(
Number
<
0
>
{})};
// check if values were copied correctly
ck
::
static_for
<
0
,
size
/
2
,
1
>
{}([
&
](
auto
i
)
{
ASSERT_EQ
(
left_vec
.
template
AsType
<
custom_bhalf_t
>()(
Number
<
i
>
{}).
data
,
test_vec
.
at
(
i
));
});
}
TEST
(
Custom_float
,
TestSize
)
{
struct
custom_float_t
{
float
data
;
};
ASSERT_EQ
(
sizeof
(
custom_float_t
),
sizeof
(
float
));
ASSERT_EQ
(
sizeof
(
vector_type
<
custom_float_t
,
2
>
),
sizeof
(
vector_type
<
float
,
2
>
));
ASSERT_EQ
(
sizeof
(
vector_type
<
custom_float_t
,
4
>
),
sizeof
(
vector_type
<
float
,
4
>
));
ASSERT_EQ
(
sizeof
(
vector_type
<
custom_float_t
,
8
>
),
sizeof
(
vector_type
<
float
,
8
>
));
ASSERT_EQ
(
sizeof
(
vector_type
<
custom_float_t
,
16
>
),
sizeof
(
vector_type
<
float
,
16
>
));
ASSERT_EQ
(
sizeof
(
vector_type
<
custom_float_t
,
32
>
),
sizeof
(
vector_type
<
float
,
32
>
));
ASSERT_EQ
(
sizeof
(
vector_type
<
custom_float_t
,
64
>
),
sizeof
(
vector_type
<
float
,
64
>
));
}
TEST
(
Custom_float
,
TestAsType
)
{
struct
custom_float_t
{
using
type
=
float
;
type
data
;
custom_float_t
()
:
data
{
type
{}}
{}
custom_float_t
(
type
init
)
:
data
{
init
}
{}
};
// test size
const
int
size
=
4
;
std
::
vector
<
float
>
test_vec
=
{
0.3
f
,
-
0.6
f
,
0.8
f
,
-
0.2
f
};
// reference vector
vector_type
<
custom_float_t
,
size
>
right_vec
;
// check default CTOR
ck
::
static_for
<
0
,
size
,
1
>
{}([
&
](
auto
i
)
{
ASSERT_EQ
(
right_vec
.
template
AsType
<
custom_float_t
>()(
Number
<
i
>
{}).
data
,
0.0
f
);
});
// assign test values to the vector
ck
::
static_for
<
0
,
size
,
1
>
{}([
&
](
auto
i
)
{
right_vec
.
template
AsType
<
custom_float_t
>()(
Number
<
i
>
{})
=
custom_float_t
{
test_vec
.
at
(
i
)};
});
// copy the vector
vector_type
<
custom_float_t
,
size
>
left_vec
{
right_vec
};
// check if values were copied correctly
ck
::
static_for
<
0
,
size
,
1
>
{}([
&
](
auto
i
)
{
ASSERT_EQ
(
left_vec
.
template
AsType
<
custom_float_t
>()(
Number
<
i
>
{}).
data
,
test_vec
.
at
(
i
));
});
}
TEST
(
Custom_float
,
TestAsTypeReshape
)
{
struct
custom_float_t
{
using
type
=
float
;
type
data
;
custom_float_t
()
:
data
{
type
{}}
{}
custom_float_t
(
type
init
)
:
data
{
init
}
{}
};
// test size
const
int
size
=
4
;
std
::
vector
<
float
>
test_vec
=
{
0.3
f
,
-
0.6
f
,
0.8
f
,
-
0.2
f
};
// reference vector
vector_type
<
custom_float_t
,
size
>
right_vec
;
// check default CTOR
ck
::
static_for
<
0
,
size
,
1
>
{}([
&
](
auto
i
)
{
ASSERT_EQ
(
right_vec
.
template
AsType
<
custom_float_t
>()(
Number
<
i
>
{}).
data
,
0.0
f
);
});
// assign test values to the vector
ck
::
static_for
<
0
,
size
,
1
>
{}([
&
](
auto
i
)
{
right_vec
.
template
AsType
<
custom_float_t
>()(
Number
<
i
>
{})
=
custom_float_t
{
test_vec
.
at
(
i
)};
});
// copy the first half of a vector
vector_type
<
custom_float_t
,
size
/
2
>
left_vec
{
right_vec
.
template
AsType
<
vector_type
<
custom_float_t
,
size
/
2
>
::
type
>
()(
Number
<
0
>
{})};
// check if values were copied correctly
ck
::
static_for
<
0
,
size
/
2
,
1
>
{}([
&
](
auto
i
)
{
ASSERT_EQ
(
left_vec
.
template
AsType
<
custom_float_t
>()(
Number
<
i
>
{}).
data
,
test_vec
.
at
(
i
));
});
}
TEST
(
Custom_double
,
TestSize
)
{
struct
custom_double_t
{
double
data
;
};
ASSERT_EQ
(
sizeof
(
custom_double_t
),
sizeof
(
double
));
ASSERT_EQ
(
sizeof
(
vector_type
<
custom_double_t
,
2
>
),
sizeof
(
vector_type
<
double
,
2
>
));
ASSERT_EQ
(
sizeof
(
vector_type
<
custom_double_t
,
4
>
),
sizeof
(
vector_type
<
double
,
4
>
));
ASSERT_EQ
(
sizeof
(
vector_type
<
custom_double_t
,
8
>
),
sizeof
(
vector_type
<
double
,
8
>
));
ASSERT_EQ
(
sizeof
(
vector_type
<
custom_double_t
,
16
>
),
sizeof
(
vector_type
<
double
,
16
>
));
ASSERT_EQ
(
sizeof
(
vector_type
<
custom_double_t
,
32
>
),
sizeof
(
vector_type
<
double
,
32
>
));
ASSERT_EQ
(
sizeof
(
vector_type
<
custom_double_t
,
64
>
),
sizeof
(
vector_type
<
double
,
64
>
));
}
TEST
(
Custom_double
,
TestAsType
)
{
struct
custom_double_t
{
using
type
=
double
;
type
data
;
custom_double_t
()
:
data
{
type
{}}
{}
custom_double_t
(
type
init
)
:
data
{
init
}
{}
};
// test size
const
int
size
=
4
;
std
::
vector
<
double
>
test_vec
=
{
0.3
,
0.6
,
0.8
,
0.2
};
// reference vector
vector_type
<
custom_double_t
,
size
>
right_vec
;
// check default CTOR
ck
::
static_for
<
0
,
size
,
1
>
{}([
&
](
auto
i
)
{
ASSERT_EQ
(
right_vec
.
template
AsType
<
custom_double_t
>()(
Number
<
i
>
{}).
data
,
0.0
);
});
// assign test values to the vector
ck
::
static_for
<
0
,
size
,
1
>
{}([
&
](
auto
i
)
{
right_vec
.
template
AsType
<
custom_double_t
>()(
Number
<
i
>
{})
=
custom_double_t
{
test_vec
.
at
(
i
)};
});
// copy the vector
vector_type
<
custom_double_t
,
size
>
left_vec
{
right_vec
};
// check if values were copied correctly
ck
::
static_for
<
0
,
size
,
1
>
{}([
&
](
auto
i
)
{
ASSERT_EQ
(
left_vec
.
template
AsType
<
custom_double_t
>()(
Number
<
i
>
{}).
data
,
test_vec
.
at
(
i
));
});
}
TEST
(
Custom_double
,
TestAsTypeReshape
)
{
struct
custom_double_t
{
using
type
=
double
;
type
data
;
custom_double_t
()
:
data
{
type
{}}
{}
custom_double_t
(
type
init
)
:
data
{
init
}
{}
};
// test size
const
int
size
=
4
;
std
::
vector
<
double
>
test_vec
=
{
0.3
,
0.6
,
0.8
,
0.2
};
// reference vector
vector_type
<
custom_double_t
,
size
>
right_vec
;
// check default CTOR
ck
::
static_for
<
0
,
size
,
1
>
{}([
&
](
auto
i
)
{
ASSERT_EQ
(
right_vec
.
template
AsType
<
custom_double_t
>()(
Number
<
i
>
{}).
data
,
0.0
);
});
// assign test values to the vector
ck
::
static_for
<
0
,
size
,
1
>
{}([
&
](
auto
i
)
{
right_vec
.
template
AsType
<
custom_double_t
>()(
Number
<
i
>
{})
=
custom_double_t
{
test_vec
.
at
(
i
)};
});
// copy the first half of a vector
vector_type
<
custom_double_t
,
size
/
2
>
left_vec
{
right_vec
.
template
AsType
<
vector_type
<
custom_double_t
,
size
/
2
>
::
type
>
()(
Number
<
0
>
{})};
// check if values were copied correctly
ck
::
static_for
<
0
,
size
/
2
,
1
>
{}([
&
](
auto
i
)
{
ASSERT_EQ
(
left_vec
.
template
AsType
<
custom_double_t
>()(
Number
<
i
>
{}).
data
,
test_vec
.
at
(
i
));
});
}
TEST
(
Complex_half
,
TestSize
)
{
struct
complex_half_t
{
half_t
real
;
half_t
img
;
};
ASSERT_EQ
(
sizeof
(
complex_half_t
),
sizeof
(
half_t
)
+
sizeof
(
half_t
));
ASSERT_EQ
(
sizeof
(
vector_type
<
complex_half_t
,
2
>
),
sizeof
(
vector_type
<
half_t
,
2
>
)
+
sizeof
(
vector_type
<
half_t
,
2
>
));
ASSERT_EQ
(
sizeof
(
vector_type
<
complex_half_t
,
4
>
),
sizeof
(
vector_type
<
half_t
,
4
>
)
+
sizeof
(
vector_type
<
half_t
,
4
>
));
ASSERT_EQ
(
sizeof
(
vector_type
<
complex_half_t
,
8
>
),
sizeof
(
vector_type
<
half_t
,
8
>
)
+
sizeof
(
vector_type
<
half_t
,
8
>
));
ASSERT_EQ
(
sizeof
(
vector_type
<
complex_half_t
,
16
>
),
sizeof
(
vector_type
<
half_t
,
16
>
)
+
sizeof
(
vector_type
<
half_t
,
16
>
));
ASSERT_EQ
(
sizeof
(
vector_type
<
complex_half_t
,
32
>
),
sizeof
(
vector_type
<
half_t
,
32
>
)
+
sizeof
(
vector_type
<
half_t
,
32
>
));
ASSERT_EQ
(
sizeof
(
vector_type
<
complex_half_t
,
64
>
),
sizeof
(
vector_type
<
half_t
,
64
>
)
+
sizeof
(
vector_type
<
half_t
,
64
>
));
}
TEST
(
Complex_half
,
TestAlignment
)
{
struct
complex_half_t
{
half_t
real
;
half_t
img
;
};
ASSERT_EQ
(
alignof
(
vector_type
<
complex_half_t
,
2
>
),
alignof
(
vector_type
<
half_t
,
2
>
)
+
alignof
(
vector_type
<
half_t
,
2
>
));
ASSERT_EQ
(
alignof
(
vector_type
<
complex_half_t
,
4
>
),
alignof
(
vector_type
<
half_t
,
4
>
)
+
alignof
(
vector_type
<
half_t
,
4
>
));
ASSERT_EQ
(
alignof
(
vector_type
<
complex_half_t
,
8
>
),
alignof
(
vector_type
<
half_t
,
8
>
)
+
alignof
(
vector_type
<
half_t
,
8
>
));
ASSERT_EQ
(
alignof
(
vector_type
<
complex_half_t
,
16
>
),
alignof
(
vector_type
<
half_t
,
16
>
)
+
alignof
(
vector_type
<
half_t
,
16
>
));
ASSERT_EQ
(
alignof
(
vector_type
<
complex_half_t
,
32
>
),
alignof
(
vector_type
<
half_t
,
32
>
)
+
alignof
(
vector_type
<
half_t
,
32
>
));
ASSERT_EQ
(
alignof
(
vector_type
<
complex_half_t
,
64
>
),
alignof
(
vector_type
<
half_t
,
64
>
)
+
alignof
(
vector_type
<
half_t
,
64
>
));
}
TEST
(
Complex_half
,
TestAsType
)
{
struct
complex_half_t
{
using
type
=
half_t
;
type
real
;
type
img
;
complex_half_t
()
:
real
{
type
{}},
img
{
type
{}}
{}
complex_half_t
(
type
real_init
,
type
img_init
)
:
real
{
real_init
},
img
{
img_init
}
{}
};
// test size
const
int
size
=
4
;
// custom type number of elements
const
int
num_elem
=
sizeof
(
complex_half_t
)
/
sizeof
(
complex_half_t
::
type
);
std
::
vector
<
half_t
>
test_vec
=
{
half_t
{
0.3
f
},
half_t
{
-
0.6
f
},
half_t
{
0.8
f
},
half_t
{
-
0.2
f
},
half_t
{
0.5
f
},
half_t
{
-
0.7
f
},
half_t
{
0.9
f
},
half_t
{
-
0.3
f
}};
// reference vector
vector_type
<
complex_half_t
,
size
>
right_vec
;
// check default CTOR
ck
::
static_for
<
0
,
size
,
1
>
{}([
&
](
auto
i
)
{
ASSERT_EQ
(
right_vec
.
template
AsType
<
complex_half_t
>()(
Number
<
i
>
{}).
real
,
type_convert
<
half_t
>
(
0.0
f
));
ASSERT_EQ
(
right_vec
.
template
AsType
<
complex_half_t
>()(
Number
<
i
>
{}).
img
,
type_convert
<
half_t
>
(
0.0
f
));
});
// assign test values to the vector
ck
::
static_for
<
0
,
size
,
1
>
{}([
&
](
auto
i
)
{
right_vec
.
template
AsType
<
complex_half_t
>()(
Number
<
i
>
{})
=
complex_half_t
{
test_vec
.
at
(
num_elem
*
i
),
test_vec
.
at
(
num_elem
*
i
+
1
)};
});
// copy the vector
vector_type
<
complex_half_t
,
size
>
left_vec
{
right_vec
};
// check if values were copied correctly
ck
::
static_for
<
0
,
size
,
1
>
{}([
&
](
auto
i
)
{
ASSERT_EQ
(
left_vec
.
template
AsType
<
complex_half_t
>()(
Number
<
i
>
{}).
real
,
test_vec
.
at
(
num_elem
*
i
));
ASSERT_EQ
(
left_vec
.
template
AsType
<
complex_half_t
>()(
Number
<
i
>
{}).
img
,
test_vec
.
at
(
num_elem
*
i
+
1
));
});
}
TEST
(
Complex_half
,
TestAsTypeReshape
)
{
struct
complex_half_t
{
using
type
=
half_t
;
type
real
;
type
img
;
complex_half_t
()
:
real
{
type
{}},
img
{
type
{}}
{}
complex_half_t
(
type
real_init
,
type
img_init
)
:
real
{
real_init
},
img
{
img_init
}
{}
};
// test size
const
int
size
=
4
;
// custom type number of elements
const
int
num_elem
=
sizeof
(
complex_half_t
)
/
sizeof
(
complex_half_t
::
type
);
std
::
vector
<
half_t
>
test_vec
=
{
half_t
{
0.3
f
},
half_t
{
-
0.6
f
},
half_t
{
0.8
f
},
half_t
{
-
0.2
f
},
half_t
{
0.5
f
},
half_t
{
-
0.7
f
},
half_t
{
0.9
f
},
half_t
{
-
0.3
f
}};
// reference vector
vector_type
<
complex_half_t
,
size
>
right_vec
;
// check default CTOR
ck
::
static_for
<
0
,
size
,
1
>
{}([
&
](
auto
i
)
{
ASSERT_EQ
(
right_vec
.
template
AsType
<
complex_half_t
>()(
Number
<
i
>
{}).
real
,
type_convert
<
half_t
>
(
0.0
f
));
ASSERT_EQ
(
right_vec
.
template
AsType
<
complex_half_t
>()(
Number
<
i
>
{}).
img
,
type_convert
<
half_t
>
(
0.0
f
));
});
// assign test values to the vector
ck
::
static_for
<
0
,
size
,
1
>
{}([
&
](
auto
i
)
{
right_vec
.
template
AsType
<
complex_half_t
>()(
Number
<
i
>
{})
=
complex_half_t
{
test_vec
.
at
(
num_elem
*
i
),
test_vec
.
at
(
num_elem
*
i
+
1
)};
});
// copy the first half of a vector
vector_type
<
complex_half_t
,
size
/
2
>
left_vec
{
right_vec
.
template
AsType
<
vector_type
<
complex_half_t
,
size
/
2
>
::
type
>
()(
Number
<
0
>
{})};
// check if values were copied correctly
ck
::
static_for
<
0
,
size
/
2
,
1
>
{}([
&
](
auto
i
)
{
ASSERT_EQ
(
left_vec
.
template
AsType
<
complex_half_t
>()(
Number
<
i
>
{}).
real
,
test_vec
.
at
(
num_elem
*
i
));
ASSERT_EQ
(
left_vec
.
template
AsType
<
complex_half_t
>()(
Number
<
i
>
{}).
img
,
test_vec
.
at
(
num_elem
*
i
+
1
));
});
}
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment