Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
8d2f2f8c
Commit
8d2f2f8c
authored
Dec 05, 2024
by
coderfeli
Browse files
Merge branch 'develop' into ck_tile/gemm_debug_alias
parents
99c8123f
4cb3d7d7
Changes
91
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
1201 additions
and
136 deletions
+1201
-136
test/ck_tile/grouped_gemm/test_grouped_gemm.cpp
test/ck_tile/grouped_gemm/test_grouped_gemm.cpp
+29
-0
test/ck_tile/grouped_gemm/test_grouped_gemm_ut_cases.inc
test/ck_tile/grouped_gemm/test_grouped_gemm_ut_cases.inc
+25
-0
test/ck_tile/grouped_gemm/test_grouped_gemm_util.hpp
test/ck_tile/grouped_gemm/test_grouped_gemm_util.hpp
+282
-0
test/data_type/CMakeLists.txt
test/data_type/CMakeLists.txt
+31
-6
test/data_type/test_bf8_fnuz.cpp
test/data_type/test_bf8_fnuz.cpp
+73
-62
test/data_type/test_bf8_ocp.cpp
test/data_type/test_bf8_ocp.cpp
+268
-0
test/data_type/test_custom_type.cpp
test/data_type/test_custom_type.cpp
+158
-0
test/data_type/test_fp8_fnuz.cpp
test/data_type/test_fp8_fnuz.cpp
+83
-66
test/data_type/test_fp8_ocp.cpp
test/data_type/test_fp8_ocp.cpp
+250
-0
test/pool/test_avg_pool2d_fwd.cpp
test/pool/test_avg_pool2d_fwd.cpp
+1
-1
test/pool/test_max_pool2d_fwd.cpp
test/pool/test_max_pool2d_fwd.cpp
+1
-1
No files found.
test/ck_tile/grouped_gemm/test_grouped_gemm.cpp
0 → 100644
View file @
8d2f2f8c
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include <tuple>
#include "gtest/gtest.h"
#include "ck_tile/host.hpp"
#include "test_grouped_gemm_util.hpp"
using
F16
=
ck_tile
::
half_t
;
using
F32
=
float
;
using
Row
=
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck_tile
::
tensor_layout
::
gemm
::
ColumnMajor
;
// clang-format off
using
KernelTypes
=
::
testing
::
Types
<
// ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, CDataType
std
::
tuple
<
Row
,
Row
,
Row
,
F16
,
F16
,
F32
,
F16
>
,
//std::tuple< Col, Row, Row, F16, F16, F32, F16>,
std
::
tuple
<
Row
,
Col
,
Row
,
F16
,
F16
,
F32
,
F16
>
//,
//std::tuple< Col, Col, Row, F16, F16, F32, F16>
>
;
// clang-format on
TYPED_TEST_SUITE
(
TestCkTileGroupedGemm
,
KernelTypes
);
#include "test_grouped_gemm_ut_cases.inc"
test/ck_tile/grouped_gemm/test_grouped_gemm_ut_cases.inc
0 → 100644
View file @
8d2f2f8c
#pragma once
TYPED_TEST
(
TestCkTileGroupedGemm
,
Basic
)
{
const
int
group_count
=
16
;
std
::
vector
<
int
>
Ms
;
std
::
vector
<
int
>
Ns
;
std
::
vector
<
int
>
Ks
;
std
::
vector
<
int
>
stride_As
;
std
::
vector
<
int
>
stride_Bs
;
std
::
vector
<
int
>
stride_Cs
;
for
(
int
i
=
0
;
i
<
group_count
;
i
++
)
{
Ms
.
push_back
(
256
+
256
*
i
);
Ns
.
push_back
(
128
+
128
*
i
);
Ks
.
push_back
(
128
+
64
*
i
);
stride_As
.
push_back
(
Ks
[
i
]);
stride_Bs
.
push_back
(
Ks
[
i
]);
stride_Cs
.
push_back
(
Ns
[
i
]);
}
this
->
Run
(
Ms
,
Ns
,
Ks
,
stride_As
,
stride_Bs
,
stride_Cs
,
group_count
);
}
test/ck_tile/grouped_gemm/test_grouped_gemm_util.hpp
0 → 100644
View file @
8d2f2f8c
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <sstream>
#include <gtest/gtest.h>
#include "ck_tile/core.hpp"
#include "ck_tile/host.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp"
template
<
typename
Tuple
>
class
TestCkTileGroupedGemm
:
public
::
testing
::
Test
{
protected:
using
ALayout
=
std
::
tuple_element_t
<
0
,
Tuple
>
;
using
BLayout
=
std
::
tuple_element_t
<
1
,
Tuple
>
;
using
CLayout
=
std
::
tuple_element_t
<
2
,
Tuple
>
;
using
ADataType
=
std
::
tuple_element_t
<
3
,
Tuple
>
;
using
BDataType
=
std
::
tuple_element_t
<
4
,
Tuple
>
;
using
AccDataType
=
std
::
tuple_element_t
<
5
,
Tuple
>
;
using
CDataType
=
std
::
tuple_element_t
<
6
,
Tuple
>
;
struct
GroupedGemKernelParam
{
static
const
bool
kPadM
=
false
;
static
const
bool
kPadN
=
false
;
static
const
bool
kPadK
=
false
;
static
const
bool
kTilePermute
=
false
;
static
const
ck_tile
::
index_t
kOutputRank
=
2
;
static
const
int
kBlockPerCu
=
1
;
static
const
ck_tile
::
index_t
M_Tile
=
128
;
static
const
ck_tile
::
index_t
N_Tile
=
128
;
static
const
ck_tile
::
index_t
K_Tile
=
32
;
static
const
ck_tile
::
index_t
M_Warp
=
2
;
static
const
ck_tile
::
index_t
N_Warp
=
2
;
static
const
ck_tile
::
index_t
K_Warp
=
1
;
static
const
ck_tile
::
index_t
M_Warp_Tile
=
32
;
static
const
ck_tile
::
index_t
N_Warp_Tile
=
32
;
static
const
ck_tile
::
index_t
K_Warp_Tile
=
8
;
};
using
CodegenGemmShape
=
ck_tile
::
TileGemmShape
<
ck_tile
::
sequence
<
GroupedGemKernelParam
::
M_Tile
,
GroupedGemKernelParam
::
N_Tile
,
GroupedGemKernelParam
::
K_Tile
>
,
ck_tile
::
sequence
<
GroupedGemKernelParam
::
M_Warp
,
GroupedGemKernelParam
::
N_Warp
,
GroupedGemKernelParam
::
K_Warp
>
,
ck_tile
::
sequence
<
GroupedGemKernelParam
::
M_Warp_Tile
,
GroupedGemKernelParam
::
N_Warp_Tile
,
GroupedGemKernelParam
::
K_Warp_Tile
>>
;
using
TilePartitioner
=
ck_tile
::
GemmTile1DPartitioner
<
CodegenGemmShape
>
;
template
<
typename
CLayout
>
using
GemmEpilogue
=
std
::
conditional_t
<
std
::
is_same_v
<
CLayout
,
ck_tile
::
tensor_layout
::
gemm
::
ColumnMajor
>
,
ck_tile
::
CShuffleEpilogue
<
ck_tile
::
CShuffleEpilogueProblem
<
AccDataType
,
CDataType
,
GroupedGemKernelParam
::
kPadM
,
GroupedGemKernelParam
::
kPadN
,
GroupedGemKernelParam
::
kTilePermute
,
GroupedGemKernelParam
::
kOutputRank
,
1
,
0
,
TilePartitioner
::
MPerBlock
,
TilePartitioner
::
NPerBlock
>>
,
ck_tile
::
Default2DEpilogue
<
ck_tile
::
Default2DEpilogueProblem
<
AccDataType
,
CDataType
,
GroupedGemKernelParam
::
kPadM
,
GroupedGemKernelParam
::
kPadN
>>>
;
template
<
typename
ALayout
,
typename
BLayout
,
typename
CLayout
>
using
CodegenGemmTraits
=
ck_tile
::
TileGemmTraits
<
GroupedGemKernelParam
::
kPadM
,
GroupedGemKernelParam
::
kPadN
,
GroupedGemKernelParam
::
kPadK
,
ALayout
,
BLayout
,
CLayout
>
;
template
<
typename
ALayout
,
typename
BLayout
,
typename
CLayout
>
using
CodegenPipelineProblem
=
ck_tile
::
GemmPipelineProblem
<
ADataType
,
BDataType
,
AccDataType
,
CodegenGemmShape
,
CodegenGemmTraits
<
ALayout
,
BLayout
,
CLayout
>>
;
using
CodegenGemmPolicy
=
ck_tile
::
UniversalGemmPipelineAgBgCrPolicy
;
template
<
typename
ALayout
,
typename
BLayout
,
typename
CLayout
>
using
CodegenGemmPipeline
=
ck_tile
::
GemmPipelineAGmemBGmemCRegV1
<
CodegenPipelineProblem
<
ALayout
,
BLayout
,
CLayout
>
,
CodegenGemmPolicy
>
;
template
<
typename
ALayout
,
typename
BLayout
,
typename
CLayout
>
using
Kernel
=
ck_tile
::
GroupedGemmKernel
<
TilePartitioner
,
CodegenGemmPipeline
<
ALayout
,
BLayout
,
CLayout
>
,
GemmEpilogue
<
CLayout
>>
;
using
grouped_gemm_kargs
=
ck_tile
::
GroupedGemmHostArgs
;
std
::
size_t
GetWorkspaceSize
(
const
std
::
vector
<
grouped_gemm_kargs
>&
gemm_descs
)
{
return
Kernel
<
std
::
nullptr_t
,
std
::
nullptr_t
,
std
::
nullptr_t
>::
GetWorkSpaceSize
(
gemm_descs
);
}
template
<
typename
ALayout
,
typename
BLayout
,
typename
CLayout
>
void
invoke_grouped_gemm
(
const
std
::
vector
<
grouped_gemm_kargs
>&
gemm_descs
,
const
ck_tile
::
stream_config
&
s
,
void
*
p_workspace_
)
{
using
GroupedGemmKernel
=
Kernel
<
ALayout
,
BLayout
,
CLayout
>
;
auto
arguments
=
GroupedGemmKernel
::
MakeKargs
(
gemm_descs
);
const
dim3
grids
=
GroupedGemmKernel
::
GridSize
(
gemm_descs
);
constexpr
dim3
blocks
=
GroupedGemmKernel
::
BlockSize
();
ck_tile
::
hip_check_error
(
hipMemcpyWithStream
(
p_workspace_
,
arguments
.
data
(),
arguments
.
size
()
*
sizeof
(
typename
GroupedGemmKernel
::
GemmTransKernelArg
),
hipMemcpyHostToDevice
,
s
.
stream_id_
));
if
(
s
.
log_level_
>
0
)
{
std
::
cout
<<
"Launching kernel with args:"
<<
" grid: {"
<<
grids
.
x
<<
", "
<<
grids
.
y
<<
", "
<<
grids
.
z
<<
"}"
<<
", blocks: {"
<<
blocks
.
x
<<
", "
<<
blocks
.
y
<<
", "
<<
blocks
.
z
<<
"}"
<<
std
::
endl
;
}
ck_tile
::
launch_kernel
(
s
,
ck_tile
::
make_kernel
<
blocks
.
x
,
GroupedGemKernelParam
::
kBlockPerCu
>
(
GroupedGemmKernel
{},
grids
,
blocks
,
0
,
ck_tile
::
cast_pointer_to_constant_address_space
(
p_workspace_
),
gemm_descs
.
size
()));
}
public:
void
Run
(
const
std
::
vector
<
int
>&
Ms
,
const
std
::
vector
<
int
>&
Ns
,
const
std
::
vector
<
int
>&
Ks
,
std
::
vector
<
int
>&
stride_As
,
std
::
vector
<
int
>&
stride_Bs
,
std
::
vector
<
int
>&
stride_Cs
,
const
int
group_count
=
16
)
{
using
namespace
ck_tile
::
literals
;
auto
f_host_tensor_descriptor
=
[](
std
::
size_t
row
,
std
::
size_t
col
,
std
::
size_t
stride
,
auto
layout
)
{
if
constexpr
(
std
::
is_same_v
<
decltype
(
layout
),
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
)
{
return
ck_tile
::
HostTensorDescriptor
({
row
,
col
},
{
stride
,
1
_uz
});
}
else
{
return
ck_tile
::
HostTensorDescriptor
({
row
,
col
},
{
1
_uz
,
stride
});
}
};
auto
f_get_default_stride
=
[](
std
::
size_t
row
,
std
::
size_t
col
,
std
::
size_t
stride
,
auto
layout
)
{
if
(
stride
==
0
)
{
if
constexpr
(
std
::
is_same_v
<
decltype
(
layout
),
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
)
{
return
col
;
}
else
{
return
row
;
}
}
else
return
stride
;
};
std
::
vector
<
ck_tile
::
HostTensor
<
ADataType
>>
a_m_k_tensors
;
std
::
vector
<
ck_tile
::
HostTensor
<
BDataType
>>
b_k_n_tensors
;
std
::
vector
<
ck_tile
::
HostTensor
<
CDataType
>>
c_m_n_tensors
;
a_m_k_tensors
.
reserve
(
group_count
);
b_k_n_tensors
.
reserve
(
group_count
);
c_m_n_tensors
.
reserve
(
group_count
);
std
::
vector
<
std
::
unique_ptr
<
ck_tile
::
DeviceMem
>>
a_m_k_dev_buf
;
std
::
vector
<
std
::
unique_ptr
<
ck_tile
::
DeviceMem
>>
b_k_n_dev_buf
;
std
::
vector
<
std
::
unique_ptr
<
ck_tile
::
DeviceMem
>>
c_m_n_dev_buf
;
a_m_k_dev_buf
.
reserve
(
group_count
);
b_k_n_dev_buf
.
reserve
(
group_count
);
c_m_n_dev_buf
.
reserve
(
group_count
);
std
::
vector
<
grouped_gemm_kargs
>
gemm_descs
;
gemm_descs
.
reserve
(
group_count
);
for
(
int
i
=
0
;
i
<
group_count
;
++
i
)
{
const
ck_tile
::
index_t
M
=
Ms
[
i
];
const
ck_tile
::
index_t
N
=
Ns
[
i
];
const
ck_tile
::
index_t
K
=
Ks
[
i
];
stride_As
[
i
]
=
f_get_default_stride
(
M
,
N
,
stride_As
[
i
],
ALayout
{});
stride_Bs
[
i
]
=
f_get_default_stride
(
K
,
N
,
stride_Bs
[
i
],
BLayout
{});
stride_Cs
[
i
]
=
f_get_default_stride
(
M
,
N
,
stride_Cs
[
i
],
CLayout
{});
a_m_k_tensors
.
push_back
(
ck_tile
::
HostTensor
<
ADataType
>
(
f_host_tensor_descriptor
(
M
,
K
,
stride_As
[
i
],
ALayout
{})));
b_k_n_tensors
.
push_back
(
ck_tile
::
HostTensor
<
BDataType
>
(
f_host_tensor_descriptor
(
K
,
N
,
stride_Bs
[
i
],
BLayout
{})));
c_m_n_tensors
.
push_back
(
ck_tile
::
HostTensor
<
CDataType
>
(
f_host_tensor_descriptor
(
M
,
N
,
stride_Cs
[
i
],
CLayout
{})));
std
::
cout
<<
"gemm["
<<
i
<<
"]"
<<
" a_m_k: "
<<
a_m_k_tensors
[
i
].
mDesc
<<
" b_k_n: "
<<
b_k_n_tensors
[
i
].
mDesc
<<
" c_m_n: "
<<
c_m_n_tensors
[
i
].
mDesc
<<
std
::
endl
;
ck_tile
::
FillUniformDistribution
<
ADataType
>
{
-
5.
f
,
5.
f
}(
a_m_k_tensors
[
i
]);
ck_tile
::
FillUniformDistribution
<
BDataType
>
{
-
5.
f
,
5.
f
}(
b_k_n_tensors
[
i
]);
a_m_k_dev_buf
.
push_back
(
std
::
make_unique
<
ck_tile
::
DeviceMem
>
(
a_m_k_tensors
[
i
].
get_element_space_size_in_bytes
()));
b_k_n_dev_buf
.
push_back
(
std
::
make_unique
<
ck_tile
::
DeviceMem
>
(
b_k_n_tensors
[
i
].
get_element_space_size_in_bytes
()));
c_m_n_dev_buf
.
push_back
(
std
::
make_unique
<
ck_tile
::
DeviceMem
>
(
c_m_n_tensors
[
i
].
get_element_space_size_in_bytes
()));
a_m_k_dev_buf
[
i
]
->
ToDevice
(
a_m_k_tensors
[
i
].
data
());
b_k_n_dev_buf
[
i
]
->
ToDevice
(
b_k_n_tensors
[
i
].
data
());
c_m_n_dev_buf
[
i
]
->
SetZero
();
c_m_n_tensors
[
i
].
SetZero
();
const
void
*
p_a
=
a_m_k_dev_buf
[
i
]
->
GetDeviceBuffer
();
const
void
*
p_b
=
b_k_n_dev_buf
[
i
]
->
GetDeviceBuffer
();
void
*
p_c
=
c_m_n_dev_buf
[
i
]
->
GetDeviceBuffer
();
gemm_descs
.
push_back
(
{
p_a
,
p_b
,
p_c
,
M
,
N
,
K
,
stride_As
[
i
],
stride_Bs
[
i
],
stride_Cs
[
i
]});
}
ck_tile
::
DeviceMem
gemm_workspace
;
gemm_workspace
.
Realloc
(
GetWorkspaceSize
(
gemm_descs
));
invoke_grouped_gemm
<
ALayout
,
BLayout
,
CLayout
>
(
gemm_descs
,
ck_tile
::
stream_config
{
nullptr
,
false
},
gemm_workspace
.
GetDeviceBuffer
());
for
(
int
i
=
0
;
i
<
group_count
;
i
++
)
{
c_m_n_dev_buf
[
i
]
->
FromDevice
(
c_m_n_tensors
[
i
].
data
());
}
bool
pass
{
true
};
for
(
int
i
=
0
;
i
<
group_count
;
++
i
)
{
ck_tile
::
HostTensor
<
CDataType
>
c_m_n_host_ref
(
f_host_tensor_descriptor
(
Ms
[
i
],
Ns
[
i
],
stride_Cs
[
i
],
CLayout
{}));
c_m_n_host_ref
.
SetZero
();
ck_tile
::
reference_gemm
<
ADataType
,
BDataType
,
AccDataType
,
CDataType
>
(
a_m_k_tensors
[
i
],
b_k_n_tensors
[
i
],
c_m_n_host_ref
);
pass
&=
ck_tile
::
check_err
(
c_m_n_tensors
[
i
],
c_m_n_host_ref
);
}
EXPECT_TRUE
(
pass
);
}
};
test/data_type/CMakeLists.txt
View file @
8d2f2f8c
...
@@ -9,13 +9,38 @@ if (USE_BITINT_EXTENSION_INT4)
...
@@ -9,13 +9,38 @@ if (USE_BITINT_EXTENSION_INT4)
endif
()
endif
()
endif
()
endif
()
add_gtest_executable
(
test_fp8 test_fp8.cpp
)
if
(
result EQUAL 0
)
target_link_libraries
(
test_fp8 PRIVATE utility
)
add_custom_target
(
test_fp8
)
if
(
CK_USE_OCP_FP8
)
add_gtest_executable
(
test_fp8_ocp test_fp8_ocp.cpp
)
if
(
result EQUAL 0
)
target_link_libraries
(
test_fp8_ocp PRIVATE utility
)
endif
()
add_gtest_executable
(
test_bf8_ocp test_bf8_ocp.cpp
)
if
(
result EQUAL 0
)
target_link_libraries
(
test_bf8_ocp PRIVATE utility
)
endif
()
add_dependencies
(
test_fp8 test_fp8_ocp
)
add_dependencies
(
test_fp8 test_bf8_ocp
)
endif
()
endif
()
add_gtest_executable
(
test_bf8 test_bf8.cpp
)
if
(
result EQUAL 0
)
if
(
CK_USE_FNUZ_FP8
)
target_link_libraries
(
test_bf8 PRIVATE utility
)
add_gtest_executable
(
test_fp8_fnuz test_fp8_fnuz.cpp
)
if
(
result EQUAL 0
)
target_link_libraries
(
test_fp8_fnuz PRIVATE utility
)
endif
()
add_gtest_executable
(
test_bf8_fnuz test_bf8_fnuz.cpp
)
if
(
result EQUAL 0
)
target_link_libraries
(
test_bf8_fnuz PRIVATE utility
)
endif
()
add_dependencies
(
test_fp8 test_fp8_fnuz
)
add_dependencies
(
test_fp8 test_bf8_fnuz
)
endif
()
endif
()
add_gtest_executable
(
test_custom_type test_custom_type.cpp
)
add_gtest_executable
(
test_custom_type test_custom_type.cpp
)
...
...
test/data_type/test_bf8.cpp
→
test/data_type/test_bf8
_fnuz
.cpp
View file @
8d2f2f8c
...
@@ -5,158 +5,169 @@
...
@@ -5,158 +5,169 @@
#include "ck/utility/data_type.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/utility/type_convert.hpp"
#include "ck/utility/type_convert.hpp"
using
ck
::
bf8_t
;
using
ck
::
bf8_
fnuz_
t
;
using
ck
::
f8_convert_rne
;
using
ck
::
f8_convert_rne
;
using
ck
::
f8_convert_sr
;
using
ck
::
f8_convert_sr
;
using
ck
::
half_t
;
using
ck
::
half_t
;
using
ck
::
type_convert
;
using
ck
::
type_convert
;
TEST
(
BF8
,
NumericLimits
)
TEST
(
BF8
FNUZ
,
NumericLimits
)
{
{
// constants given for negative zero nan mode
// constants given for negative zero nan mode
EXPECT_EQ
(
ck
::
NumericLimits
<
bf8_t
>::
Min
(),
type_convert
<
bf8_t
>
(
0x04
));
EXPECT_EQ
(
ck
::
NumericLimits
<
bf8_
fnuz_
t
>::
Min
(),
type_convert
<
bf8_
fnuz_
t
>
(
0x04
));
EXPECT_EQ
(
ck
::
NumericLimits
<
bf8_t
>::
Max
(),
type_convert
<
bf8_t
>
(
0x7F
));
EXPECT_EQ
(
ck
::
NumericLimits
<
bf8_
fnuz_
t
>::
Max
(),
type_convert
<
bf8_
fnuz_
t
>
(
0x7F
));
EXPECT_EQ
(
ck
::
NumericLimits
<
bf8_t
>::
Lowest
(),
type_convert
<
bf8_t
>
(
0xFF
));
EXPECT_EQ
(
ck
::
NumericLimits
<
bf8_
fnuz_
t
>::
Lowest
(),
type_convert
<
bf8_
fnuz_
t
>
(
0xFF
));
EXPECT_EQ
(
ck
::
NumericLimits
<
bf8_t
>::
QuietNaN
(),
type_convert
<
bf8_t
>
(
0x80
));
EXPECT_EQ
(
ck
::
NumericLimits
<
bf8_
fnuz_
t
>::
QuietNaN
(),
type_convert
<
bf8_
fnuz_
t
>
(
0x80
));
}
}
TEST
(
BF8
,
ConvertFP32Nearest
)
TEST
(
BF8
FNUZ
,
ConvertFP32Nearest
)
{
{
// fix the tolerance value
// fix the tolerance value
float
abs_tol
=
1e-6
;
float
abs_tol
=
1e-6
;
// convert 0 float to bf8 and back, check if holds
// convert 0 float to bf8 and back, check if holds
ASSERT_NEAR
(
0.0
f
,
type_convert
<
float
>
(
f8_convert_rne
<
bf8_t
>
(
0.0
f
)),
abs_tol
);
ASSERT_NEAR
(
0.0
f
,
type_convert
<
float
>
(
f8_convert_rne
<
bf8_
fnuz_
t
>
(
0.0
f
)),
abs_tol
);
// don't run the next test on gfx11 devices
// don't run the next test on gfx11 devices
#ifndef CK_SKIP_FLAKY_F8_TEST
#ifndef CK_SKIP_FLAKY_F8_TEST
// convert minimal float to bf8 and back, check if holds
// convert minimal float to bf8 and back, check if holds
ASSERT_NEAR
(
std
::
numeric_limits
<
float
>::
min
(),
ASSERT_NEAR
(
std
::
numeric_limits
<
float
>::
min
(),
type_convert
<
float
>
(
f8_convert_rne
<
bf8_t
>
(
std
::
numeric_limits
<
float
>::
min
())),
type_convert
<
float
>
(
f8_convert_rne
<
bf8_
fnuz_
t
>
(
std
::
numeric_limits
<
float
>::
min
())),
abs_tol
);
abs_tol
);
#endif
#endif
// convert maximal bf8_t to float and check if equal to 57344.0
ASSERT_NEAR
(
57344.0
f
,
type_convert
<
float
>
(
f8_convert_rne
<
bf8_t
>
(
57344.0
f
)),
abs_tol
);
const
auto
max_bf8_t_float
=
type_convert
<
float
>
(
ck
::
NumericLimits
<
bf8_fnuz_t
>::
Max
());
// convert maximal bf8_fnuz_t to float and check if equal to 57344.0
ASSERT_NEAR
(
max_bf8_t_float
,
type_convert
<
float
>
(
f8_convert_rne
<
bf8_fnuz_t
>
(
max_bf8_t_float
)),
abs_tol
);
// convert maximal float to bf8 and back, check if clipped to 57344.0
// convert maximal float to bf8 and back, check if clipped to 57344.0
ASSERT_NEAR
(
57344.0
f
,
ASSERT_NEAR
(
max_bf8_t_float
,
type_convert
<
float
>
(
f8_convert_rne
<
bf8_t
>
(
std
::
numeric_limits
<
float
>::
max
())),
type_convert
<
float
>
(
f8_convert_rne
<
bf8_
fnuz_
t
>
(
std
::
numeric_limits
<
float
>::
max
())),
abs_tol
);
abs_tol
);
// convert inf float to bf8_t and check if it is qNan
// convert inf float to bf8_
fnuz_
t and check if it is qNan
ASSERT_NEAR
(
type_convert
<
bf8_t
>
(
0x80
),
ASSERT_NEAR
(
ck
::
NumericLimits
<
bf8_fnuz_t
>::
QuietNaN
(
),
f8_convert_rne
<
bf8_t
>
(
std
::
numeric_limits
<
float
>::
infinity
()),
f8_convert_rne
<
bf8_
fnuz_
t
>
(
std
::
numeric_limits
<
float
>::
infinity
()),
abs_tol
);
abs_tol
);
// positive norm float value to bf8 and back, check if holds
// positive norm float value to bf8 and back, check if holds
float
pos_float
=
0.0000762939
f
;
float
pos_float
=
0.0000762939
f
;
ASSERT_NEAR
(
pos_float
,
type_convert
<
float
>
(
f8_convert_rne
<
bf8_t
>
(
pos_float
)),
abs_tol
);
ASSERT_NEAR
(
pos_float
,
type_convert
<
float
>
(
f8_convert_rne
<
bf8_
fnuz_
t
>
(
pos_float
)),
abs_tol
);
// negative norm float value to bf8 and back, check if holds
// negative norm float value to bf8 and back, check if holds
float
neg_float
=
-
0.0000610351
f
;
float
neg_float
=
-
0.0000610351
f
;
ASSERT_NEAR
(
neg_float
,
type_convert
<
float
>
(
f8_convert_rne
<
bf8_t
>
(
neg_float
)),
abs_tol
);
ASSERT_NEAR
(
neg_float
,
type_convert
<
float
>
(
f8_convert_rne
<
bf8_
fnuz_
t
>
(
neg_float
)),
abs_tol
);
// positive subnorm float value to bf8 and back, check if holds
// positive subnorm float value to bf8 and back, check if holds
pos_float
=
0.0000305175
f
;
pos_float
=
0.0000305175
f
;
ASSERT_NEAR
(
pos_float
,
type_convert
<
float
>
(
f8_convert_rne
<
bf8_t
>
(
pos_float
)),
abs_tol
);
ASSERT_NEAR
(
pos_float
,
type_convert
<
float
>
(
f8_convert_rne
<
bf8_
fnuz_
t
>
(
pos_float
)),
abs_tol
);
// negative subnorm float value to bf8 and back, check if holds
// negative subnorm float value to bf8 and back, check if holds
neg_float
=
-
0.0000152587
f
;
neg_float
=
-
0.0000152587
f
;
ASSERT_NEAR
(
neg_float
,
type_convert
<
float
>
(
f8_convert_rne
<
bf8_t
>
(
neg_float
)),
abs_tol
);
ASSERT_NEAR
(
neg_float
,
type_convert
<
float
>
(
f8_convert_rne
<
bf8_
fnuz_
t
>
(
neg_float
)),
abs_tol
);
}
}
TEST
(
BF8
,
ConvertFP32Stochastic
)
TEST
(
BF8
FNUZ
,
ConvertFP32Stochastic
)
{
{
// fix the tolerance value
// fix the tolerance value
float
abs_tol
=
1e-6
;
float
abs_tol
=
1e-6
;
// convert 0 float to bf8 and back, check if holds
// convert 0 float to bf8 and back, check if holds
ASSERT_NEAR
(
0.0
f
,
type_convert
<
float
>
(
f8_convert_sr
<
bf8_t
>
(
0.0
f
)),
abs_tol
);
ASSERT_NEAR
(
0.0
f
,
type_convert
<
float
>
(
f8_convert_sr
<
bf8_
fnuz_
t
>
(
0.0
f
)),
abs_tol
);
// convert minimal float to bf8 and back, check if holds
// convert minimal float to bf8 and back, check if holds
ASSERT_NEAR
(
std
::
numeric_limits
<
float
>::
min
(),
ASSERT_NEAR
(
std
::
numeric_limits
<
float
>::
min
(),
type_convert
<
float
>
(
f8_convert_sr
<
bf8_t
>
(
std
::
numeric_limits
<
float
>::
min
())),
type_convert
<
float
>
(
f8_convert_sr
<
bf8_
fnuz_
t
>
(
std
::
numeric_limits
<
float
>::
min
())),
abs_tol
);
abs_tol
);
// convert maximal bf8_t to float and check if equal to 57344.0
ASSERT_NEAR
(
57344.0
f
,
type_convert
<
float
>
(
f8_convert_sr
<
bf8_t
>
(
57344.0
f
)),
abs_tol
);
const
auto
max_bf8_t_float
=
type_convert
<
float
>
(
ck
::
NumericLimits
<
bf8_fnuz_t
>::
Max
());
// convert maximal bf8_fnuz_t to float and check if equal to 57344.0
ASSERT_NEAR
(
max_bf8_t_float
,
type_convert
<
float
>
(
f8_convert_sr
<
bf8_fnuz_t
>
(
max_bf8_t_float
)),
abs_tol
);
// convert maximal float to bf8 and back, check if clipped to 57344.0
// convert maximal float to bf8 and back, check if clipped to 57344.0
ASSERT_NEAR
(
57344.0
f
,
ASSERT_NEAR
(
max_bf8_t_float
,
type_convert
<
float
>
(
f8_convert_sr
<
bf8_t
>
(
std
::
numeric_limits
<
float
>::
max
())),
type_convert
<
float
>
(
f8_convert_sr
<
bf8_
fnuz_
t
>
(
std
::
numeric_limits
<
float
>::
max
())),
abs_tol
);
abs_tol
);
// convert inf float to bf8_t and check if it is qNan
// convert inf float to bf8_
fnuz_
t and check if it is qNan
ASSERT_NEAR
(
type_convert
<
bf8_t
>
(
0x80
),
ASSERT_NEAR
(
ck
::
NumericLimits
<
bf8_fnuz_t
>::
QuietNaN
(
),
f8_convert_sr
<
bf8_t
>
(
std
::
numeric_limits
<
float
>::
infinity
()),
f8_convert_sr
<
bf8_
fnuz_
t
>
(
std
::
numeric_limits
<
float
>::
infinity
()),
abs_tol
);
abs_tol
);
// positive norm float value to bf8 and back, check if holds
// positive norm float value to bf8 and back, check if holds
float
pos_float
=
0.0000762939
f
;
float
pos_float
=
0.0000762939
f
;
ASSERT_NEAR
(
pos_float
,
type_convert
<
float
>
(
f8_convert_sr
<
bf8_t
>
(
pos_float
)),
abs_tol
);
ASSERT_NEAR
(
pos_float
,
type_convert
<
float
>
(
f8_convert_sr
<
bf8_
fnuz_
t
>
(
pos_float
)),
abs_tol
);
// negative norm float value to bf8 and back, check if holds
// negative norm float value to bf8 and back, check if holds
float
neg_float
=
-
0.0000610351
f
;
float
neg_float
=
-
0.0000610351
f
;
ASSERT_NEAR
(
neg_float
,
type_convert
<
float
>
(
f8_convert_sr
<
bf8_t
>
(
neg_float
)),
abs_tol
);
ASSERT_NEAR
(
neg_float
,
type_convert
<
float
>
(
f8_convert_sr
<
bf8_
fnuz_
t
>
(
neg_float
)),
abs_tol
);
// positive subnorm float value to bf8 and back, check if holds
// positive subnorm float value to bf8 and back, check if holds
pos_float
=
0.0000305175
f
;
pos_float
=
0.0000305175
f
;
ASSERT_NEAR
(
pos_float
,
type_convert
<
float
>
(
f8_convert_sr
<
bf8_t
>
(
pos_float
)),
abs_tol
);
ASSERT_NEAR
(
pos_float
,
type_convert
<
float
>
(
f8_convert_sr
<
bf8_
fnuz_
t
>
(
pos_float
)),
abs_tol
);
// negative subnorm float value to bf8 and back, check if holds
// negative subnorm float value to bf8 and back, check if holds
neg_float
=
-
0.0000152587
f
;
neg_float
=
-
0.0000152587
f
;
ASSERT_NEAR
(
neg_float
,
type_convert
<
float
>
(
f8_convert_sr
<
bf8_t
>
(
neg_float
)),
abs_tol
);
ASSERT_NEAR
(
neg_float
,
type_convert
<
float
>
(
f8_convert_sr
<
bf8_
fnuz_
t
>
(
neg_float
)),
abs_tol
);
}
}
TEST
(
BF8
,
ConvertFP16Nearest
)
TEST
(
BF8
FNUZ
,
ConvertFP16Nearest
)
{
{
// fix the tolerance value
// fix the tolerance value
float
abs_tol
=
1e-3
;
float
abs_tol
=
1e-3
;
// convert 0 fp16 to bf8 and back, check if holds
// convert 0 fp16 to bf8 and back, check if holds
ASSERT_NEAR
(
half_t
{
0.0
},
type_convert
<
half_t
>
(
f8_convert_rne
<
bf8_t
>
(
half_t
{
0.0
})),
abs_tol
);
ASSERT_NEAR
(
half_t
{
0.0
},
type_convert
<
half_t
>
(
f8_convert_rne
<
bf8_fnuz_t
>
(
half_t
{
0.0
})),
abs_tol
);
// convert minimal fp16 to bf8 and back, check if holds
// convert minimal fp16 to bf8 and back, check if holds
ASSERT_NEAR
(
ck
::
NumericLimits
<
half_t
>::
Min
(),
ASSERT_NEAR
(
ck
::
NumericLimits
<
half_t
>::
Min
(),
type_convert
<
half_t
>
(
f8_convert_rne
<
bf8_t
>
(
ck
::
NumericLimits
<
half_t
>::
Min
())),
type_convert
<
half_t
>
(
f8_convert_rne
<
bf8_
fnuz_
t
>
(
ck
::
NumericLimits
<
half_t
>::
Min
())),
abs_tol
);
abs_tol
);
// convert maximal bf8_t to fp16 and check if equal to 57344.0
const
auto
max_bf8_t_half
=
type_convert
<
half_t
>
(
ck
::
NumericLimits
<
bf8_fnuz_t
>::
Max
());
// convert maximal bf8_fnuz_t to fp16 and check if equal to 57344.0
ASSERT_NEAR
(
ASSERT_NEAR
(
half_t
{
57344.0
}
,
type_convert
<
half_t
>
(
f8_convert_rne
<
bf8_
t
>
(
half_t
{
57344.0
}
)),
abs_tol
);
max_bf8_t_half
,
type_convert
<
half_t
>
(
f8_convert_rne
<
bf8_
fnuz_t
>
(
max_bf8_t_half
)),
abs_tol
);
// convert maximal fp16 to bf8 and back, check if clipped to 57344.0
// convert maximal fp16 to bf8 and back, check if clipped to 57344.0
ASSERT_NEAR
(
half_t
{
57344.0
}
,
ASSERT_NEAR
(
max_bf8_t_half
,
type_convert
<
half_t
>
(
f8_convert_rne
<
bf8_t
>
(
ck
::
NumericLimits
<
half_t
>::
Max
())),
type_convert
<
half_t
>
(
f8_convert_rne
<
bf8_
fnuz_
t
>
(
ck
::
NumericLimits
<
half_t
>::
Max
())),
abs_tol
);
abs_tol
);
// convert QuietNaN fp16 to bf8_t and check if it is QuietNaN
// convert QuietNaN fp16 to bf8_
fnuz_
t and check if it is QuietNaN
ASSERT_NEAR
(
type_convert
<
bf8_t
>
(
0x80
),
ASSERT_NEAR
(
ck
::
NumericLimits
<
bf8_fnuz_t
>::
QuietNaN
(
),
f8_convert_rne
<
bf8_t
>
(
ck
::
NumericLimits
<
half_t
>::
QuietNaN
()),
f8_convert_rne
<
bf8_
fnuz_
t
>
(
ck
::
NumericLimits
<
half_t
>::
QuietNaN
()),
abs_tol
);
abs_tol
);
// positive norm fp16 value to bf8 and back, check if holds
// positive norm fp16 value to bf8 and back, check if holds
half_t
pos_half
=
half_t
{
0.0000762939
};
half_t
pos_half
=
half_t
{
0.0000762939
};
ASSERT_NEAR
(
pos_half
,
type_convert
<
half_t
>
(
f8_convert_rne
<
bf8_t
>
(
pos_half
)),
abs_tol
);
ASSERT_NEAR
(
pos_half
,
type_convert
<
half_t
>
(
f8_convert_rne
<
bf8_
fnuz_
t
>
(
pos_half
)),
abs_tol
);
// negative norm fp16 value to bf8 and back, check if holds
// negative norm fp16 value to bf8 and back, check if holds
half_t
neg_half
=
half_t
{
-
0.0000610351
};
half_t
neg_half
=
half_t
{
-
0.0000610351
};
ASSERT_NEAR
(
neg_half
,
type_convert
<
half_t
>
(
f8_convert_rne
<
bf8_t
>
(
neg_half
)),
abs_tol
);
ASSERT_NEAR
(
neg_half
,
type_convert
<
half_t
>
(
f8_convert_rne
<
bf8_
fnuz_
t
>
(
neg_half
)),
abs_tol
);
// positive subnorm fp16 value to bf8 and back, check if holds
// positive subnorm fp16 value to bf8 and back, check if holds
pos_half
=
half_t
{
0.0000305175
};
pos_half
=
half_t
{
0.0000305175
};
ASSERT_NEAR
(
pos_half
,
type_convert
<
half_t
>
(
f8_convert_rne
<
bf8_t
>
(
pos_half
)),
abs_tol
);
ASSERT_NEAR
(
pos_half
,
type_convert
<
half_t
>
(
f8_convert_rne
<
bf8_
fnuz_
t
>
(
pos_half
)),
abs_tol
);
// negative subnorm fp16 value to bf8 and back, check if holds
// negative subnorm fp16 value to bf8 and back, check if holds
neg_half
=
half_t
{
-
0.0000152587
};
neg_half
=
half_t
{
-
0.0000152587
};
ASSERT_NEAR
(
neg_half
,
type_convert
<
half_t
>
(
f8_convert_rne
<
bf8_t
>
(
neg_half
)),
abs_tol
);
ASSERT_NEAR
(
neg_half
,
type_convert
<
half_t
>
(
f8_convert_rne
<
bf8_
fnuz_
t
>
(
neg_half
)),
abs_tol
);
}
}
TEST
(
BF8
,
ConvertFP16Stochastic
)
TEST
(
BF8
FNUZ
,
ConvertFP16Stochastic
)
{
{
// fix the tolerance value
// fix the tolerance value
float
abs_tol
=
1e-3
;
float
abs_tol
=
1e-3
;
// convert 0 fp16 to bf8 and back, check if holds
// convert 0 fp16 to bf8 and back, check if holds
ASSERT_NEAR
(
half_t
{
0.0
},
type_convert
<
half_t
>
(
f8_convert_sr
<
bf8_t
>
(
half_t
{
0.0
})),
abs_tol
);
ASSERT_NEAR
(
half_t
{
0.0
},
type_convert
<
half_t
>
(
f8_convert_sr
<
bf8_
fnuz_
t
>
(
half_t
{
0.0
})),
abs_tol
);
// convert minimal fp16 to bf8 and back, check if holds
// convert minimal fp16 to bf8 and back, check if holds
ASSERT_NEAR
(
ck
::
NumericLimits
<
half_t
>::
Min
(),
ASSERT_NEAR
(
ck
::
NumericLimits
<
half_t
>::
Min
(),
type_convert
<
half_t
>
(
f8_convert_sr
<
bf8_t
>
(
ck
::
NumericLimits
<
half_t
>::
Min
())),
type_convert
<
half_t
>
(
f8_convert_sr
<
bf8_
fnuz_
t
>
(
ck
::
NumericLimits
<
half_t
>::
Min
())),
abs_tol
);
abs_tol
);
// convert maximal bf8_t to fp16 and check if equal to 57344.0
const
auto
max_bf8_t_half
=
type_convert
<
half_t
>
(
ck
::
NumericLimits
<
bf8_fnuz_t
>::
Max
());
// convert maximal bf8_fnuz_t to fp16 and check if equal to 57344.0
ASSERT_NEAR
(
ASSERT_NEAR
(
half_t
{
57344.0
}
,
type_convert
<
half_t
>
(
f8_convert_sr
<
bf8_
t
>
(
half_t
{
57344.0
}
)),
abs_tol
);
max_bf8_t_half
,
type_convert
<
half_t
>
(
f8_convert_sr
<
bf8_
fnuz_t
>
(
max_bf8_t_half
)),
abs_tol
);
// convert maximal fp16 to bf8 and back, check if clipped to 57344.0
// convert maximal fp16 to bf8 and back, check if clipped to 57344.0
ASSERT_NEAR
(
half_t
{
57344.0
}
,
ASSERT_NEAR
(
max_bf8_t_half
,
type_convert
<
half_t
>
(
f8_convert_sr
<
bf8_t
>
(
ck
::
NumericLimits
<
half_t
>::
Max
())),
type_convert
<
half_t
>
(
f8_convert_sr
<
bf8_
fnuz_
t
>
(
ck
::
NumericLimits
<
half_t
>::
Max
())),
abs_tol
);
abs_tol
);
// convert QuietNaN fp16 to bf8_t and check if it is QuietNaN
// convert QuietNaN fp16 to bf8_
fnuz_
t and check if it is QuietNaN
ASSERT_NEAR
(
type_convert
<
bf8_t
>
(
0x80
),
ASSERT_NEAR
(
ck
::
NumericLimits
<
bf8_fnuz_t
>::
QuietNaN
(
),
f8_convert_sr
<
bf8_t
>
(
ck
::
NumericLimits
<
half_t
>::
QuietNaN
()),
f8_convert_sr
<
bf8_
fnuz_
t
>
(
ck
::
NumericLimits
<
half_t
>::
QuietNaN
()),
abs_tol
);
abs_tol
);
// positive norm fp16 value to bf8 and back, check if holds
// positive norm fp16 value to bf8 and back, check if holds
half_t
pos_half
=
half_t
{
0.0000762939
};
half_t
pos_half
=
half_t
{
0.0000762939
};
ASSERT_NEAR
(
pos_half
,
type_convert
<
half_t
>
(
f8_convert_sr
<
bf8_t
>
(
pos_half
)),
abs_tol
);
ASSERT_NEAR
(
pos_half
,
type_convert
<
half_t
>
(
f8_convert_sr
<
bf8_
fnuz_
t
>
(
pos_half
)),
abs_tol
);
// negative norm fp16 value to bf8 and back, check if holds
// negative norm fp16 value to bf8 and back, check if holds
half_t
neg_half
=
half_t
{
-
0.0000610351
};
half_t
neg_half
=
half_t
{
-
0.0000610351
};
ASSERT_NEAR
(
neg_half
,
type_convert
<
half_t
>
(
f8_convert_sr
<
bf8_t
>
(
neg_half
)),
abs_tol
);
ASSERT_NEAR
(
neg_half
,
type_convert
<
half_t
>
(
f8_convert_sr
<
bf8_
fnuz_
t
>
(
neg_half
)),
abs_tol
);
// positive subnorm fp16 value to bf8 and back, check if holds
// positive subnorm fp16 value to bf8 and back, check if holds
pos_half
=
half_t
{
0.0000305175
};
pos_half
=
half_t
{
0.0000305175
};
ASSERT_NEAR
(
pos_half
,
type_convert
<
half_t
>
(
f8_convert_sr
<
bf8_t
>
(
pos_half
)),
abs_tol
);
ASSERT_NEAR
(
pos_half
,
type_convert
<
half_t
>
(
f8_convert_sr
<
bf8_
fnuz_
t
>
(
pos_half
)),
abs_tol
);
// negative subnorm fp16 value to bf8 and back, check if holds
// negative subnorm fp16 value to bf8 and back, check if holds
neg_half
=
half_t
{
-
0.0000152587
};
neg_half
=
half_t
{
-
0.0000152587
};
ASSERT_NEAR
(
neg_half
,
type_convert
<
half_t
>
(
f8_convert_sr
<
bf8_t
>
(
neg_half
)),
abs_tol
);
ASSERT_NEAR
(
neg_half
,
type_convert
<
half_t
>
(
f8_convert_sr
<
bf8_
fnuz_
t
>
(
neg_half
)),
abs_tol
);
}
}
test/data_type/test_bf8_ocp.cpp
0 → 100644
View file @
8d2f2f8c
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "gtest/gtest.h"
#include "ck/utility/data_type.hpp"
#include "ck/utility/type_convert.hpp"
using
ck
::
bf8_ocp_t
;
using
ck
::
f8_convert_rne
;
using
ck
::
f8_convert_sr
;
using
ck
::
half_t
;
using
ck
::
type_convert
;
TEST
(
BF8OCP
,
NumericLimits
)
{
// constants given for OCP FP8
EXPECT_EQ
(
ck
::
NumericLimits
<
bf8_ocp_t
>::
Min
(),
type_convert
<
bf8_ocp_t
>
(
0x04
));
// 0b00000100 = 2^-14
EXPECT_EQ
(
ck
::
NumericLimits
<
bf8_ocp_t
>::
Max
(),
type_convert
<
bf8_ocp_t
>
(
0x7B
));
// 0b01111011 = 57344
EXPECT_EQ
(
ck
::
NumericLimits
<
bf8_ocp_t
>::
Lowest
(),
type_convert
<
bf8_ocp_t
>
(
0xFB
));
// 0b11111011 = -57344
EXPECT_EQ
(
ck
::
NumericLimits
<
bf8_ocp_t
>::
QuietNaN
().
data
,
type_convert
<
bf8_ocp_t
>
(
0x7D
).
data
);
// 0b01111101
EXPECT_FALSE
(
ck
::
NumericLimits
<
bf8_ocp_t
>::
QuietNaN
()
==
ck
::
NumericLimits
<
bf8_ocp_t
>::
QuietNaN
());
EXPECT_TRUE
(
ck
::
fp8_is_inf
(
type_convert
<
bf8_ocp_t
>
(
0xFC
))
&&
ck
::
fp8_is_inf
(
type_convert
<
bf8_ocp_t
>
(
0x7C
)));
}
TEST
(
BF8OCP
,
ConvertFP32Nearest
)
{
// fix the tolerance value
float
abs_tol
=
1e-6
;
// convert 0 float to bfp8 and back, check if holds
ASSERT_NEAR
(
0.0
f
,
type_convert
<
float
>
(
f8_convert_rne
<
bf8_ocp_t
>
(
0.0
f
)),
0.0
f
);
// convert minimal float to bf8 and back, check if holds
ASSERT_NEAR
(
std
::
numeric_limits
<
float
>::
min
(),
type_convert
<
float
>
(
f8_convert_rne
<
bf8_ocp_t
>
(
std
::
numeric_limits
<
float
>::
min
())),
abs_tol
);
const
auto
max_bf8_t_float
=
type_convert
<
float
>
(
ck
::
NumericLimits
<
bf8_ocp_t
>::
Max
());
// convert maximal bf8_ocp_t to float and check if equal to bf8 max
ASSERT_NEAR
(
max_bf8_t_float
,
type_convert
<
float
>
(
f8_convert_rne
<
bf8_ocp_t
>
(
max_bf8_t_float
)),
0.0
f
);
// convert maximal float to bf8 and back, check if clipped to bf8 max (saturation to finite)
ASSERT_NEAR
(
max_bf8_t_float
,
type_convert
<
float
>
(
f8_convert_rne
<
bf8_ocp_t
>
(
std
::
numeric_limits
<
float
>::
max
())),
0.0
f
);
// convert float infinity to bf8_ocp_t and check if it is max value (saturation to finite)
ASSERT_EQ
(
ck
::
NumericLimits
<
bf8_ocp_t
>::
Max
(),
f8_convert_rne
<
bf8_ocp_t
>
(
std
::
numeric_limits
<
float
>::
infinity
()));
// positive normal float value to bf8 and back, check if holds
float
pos_float
=
0.0000762939
f
;
// 10*2^-17
ASSERT_NEAR
(
pos_float
,
type_convert
<
float
>
(
f8_convert_rne
<
bf8_ocp_t
>
(
pos_float
)),
abs_tol
);
// negative smallest normal bf8 value to bf8 and back, check if holds
constexpr
auto
neg_min_bf8
=
-
0.00006103515625
f
;
//-2^-14
ASSERT_NEAR
(
neg_min_bf8
,
type_convert
<
float
>
(
f8_convert_rne
<
bf8_ocp_t
>
(
neg_min_bf8
)),
0.0
f
);
// positive subnorm float value to bf8 and back, check if holds
constexpr
auto
pos_subnorm_bf8
=
0.000030517578125
f
;
// 2^-15
ASSERT_NEAR
(
pos_subnorm_bf8
,
type_convert
<
float
>
(
f8_convert_rne
<
bf8_ocp_t
>
(
pos_subnorm_bf8
)),
0.0
f
);
// min subnorm bf8 value to bf8 and back, check if holds
constexpr
auto
min_subnorm_bf8
=
-
0.0000152587890625
f
;
//-2^-16
ASSERT_NEAR
(
min_subnorm_bf8
,
type_convert
<
float
>
(
f8_convert_rne
<
bf8_ocp_t
>
(
min_subnorm_bf8
)),
0.0
f
);
// smaller than min subnorm bf8 value to bf8 must be zero
constexpr
auto
less_than_min_subnorm
=
0.00000762939453125
f
;
// 2^-17
ASSERT_EQ
(
0.0
f
,
type_convert
<
float
>
(
f8_convert_rne
<
bf8_ocp_t
>
(
less_than_min_subnorm
)));
// convert quiet NaN to bf8_ocp_t and check if it is quiet NaN
const
auto
bf8_nan
=
f8_convert_rne
<
bf8_ocp_t
>
(
std
::
numeric_limits
<
float
>::
quiet_NaN
());
ASSERT_TRUE
(
ck
::
fp8_impl
::
ocp_bf8_is_nan
(
bf8_nan
.
data
));
}
TEST
(
BF8OCP
,
ConvertFP32Stochastic
)
{
// fix the tolerance value
float
abs_tol
=
1e-6
;
// convert 0 float to bfp8 and back, check if holds
ASSERT_NEAR
(
0.0
f
,
type_convert
<
float
>
(
f8_convert_sr
<
bf8_ocp_t
>
(
0.0
f
)),
0.0
f
);
// convert minimal float to bf8 and back, check if holds
ASSERT_NEAR
(
std
::
numeric_limits
<
float
>::
min
(),
type_convert
<
float
>
(
f8_convert_sr
<
bf8_ocp_t
>
(
std
::
numeric_limits
<
float
>::
min
())),
abs_tol
);
const
auto
max_bf8_t_float
=
type_convert
<
float
>
(
ck
::
NumericLimits
<
bf8_ocp_t
>::
Max
());
// convert maximal bf8_ocp_t to float and check if equal to bf8 max
ASSERT_NEAR
(
max_bf8_t_float
,
type_convert
<
float
>
(
f8_convert_sr
<
bf8_ocp_t
>
(
max_bf8_t_float
)),
0.0
f
);
// convert maximal float to bf8 and back, check if clipped to bf8 max (saturation to finite)
ASSERT_NEAR
(
max_bf8_t_float
,
type_convert
<
float
>
(
f8_convert_sr
<
bf8_ocp_t
>
(
std
::
numeric_limits
<
float
>::
max
())),
0.0
f
);
// convert float infinity to bf8_ocp_t and check if it is max value (saturation to finite)
ASSERT_EQ
(
ck
::
NumericLimits
<
bf8_ocp_t
>::
Max
(),
f8_convert_sr
<
bf8_ocp_t
>
(
std
::
numeric_limits
<
float
>::
infinity
()));
// positive normal float value to bf8 and back, check if holds
float
pos_float
=
0.0000762939
f
;
// 10*2^-17
ASSERT_NEAR
(
pos_float
,
type_convert
<
float
>
(
f8_convert_sr
<
bf8_ocp_t
>
(
pos_float
)),
abs_tol
);
// negative smallest normal bf8 value to bf8 and back, check if holds
constexpr
auto
neg_min_bf8
=
-
0.00006103515625
f
;
//-2^-14
ASSERT_NEAR
(
neg_min_bf8
,
type_convert
<
float
>
(
f8_convert_sr
<
bf8_ocp_t
>
(
neg_min_bf8
)),
0.0
f
);
// positive subnorm float value to bf8 and back, check if holds
constexpr
auto
pos_subnorm_bf8
=
0.000030517578125
f
;
// 2^-15
ASSERT_NEAR
(
pos_subnorm_bf8
,
type_convert
<
float
>
(
f8_convert_sr
<
bf8_ocp_t
>
(
pos_subnorm_bf8
)),
0.0
f
);
// min subnorm bf8 value to bf8 and back, check if holds
constexpr
auto
min_subnorm_bf8
=
-
0.0000152587890625
f
;
//-2^-16
ASSERT_NEAR
(
min_subnorm_bf8
,
type_convert
<
float
>
(
f8_convert_sr
<
bf8_ocp_t
>
(
min_subnorm_bf8
)),
0.0
f
);
// smaller than min subnorm bf8 value to bf8 alternates between 0 and 2^-16
constexpr
auto
less_than_min_subnorm
=
0.00000762939453125
f
;
// 2^-17
ASSERT_NEAR
(
0.0
f
,
type_convert
<
float
>
(
f8_convert_sr
<
bf8_ocp_t
>
(
less_than_min_subnorm
)),
0.0000152587890625
f
);
// convert quiet NaN to bf8_ocp_t and check if it is quiet NaN
const
auto
bf8_nan
=
f8_convert_sr
<
bf8_ocp_t
>
(
std
::
numeric_limits
<
float
>::
quiet_NaN
());
ASSERT_TRUE
(
ck
::
fp8_impl
::
ocp_bf8_is_nan
(
bf8_nan
.
data
));
}
TEST
(
BF8OCP
,
ConvertFP16Nearest
)
{
// fix the tolerance value
constexpr
half_t
half_t_tol
=
1e-3
;
constexpr
half_t
half_t_zero
=
0.0
;
// convert 0 half_t to bfp8 and back, check if holds
ASSERT_NEAR
(
half_t_zero
,
type_convert
<
half_t
>
(
f8_convert_rne
<
bf8_ocp_t
>
(
half_t_zero
)),
half_t_zero
);
// convert minimal half_t to bf8 and back, check if holds
ASSERT_NEAR
(
ck
::
NumericLimits
<
half_t
>::
Min
(),
type_convert
<
half_t
>
(
f8_convert_rne
<
bf8_ocp_t
>
(
ck
::
NumericLimits
<
half_t
>::
Min
())),
half_t_tol
);
const
auto
max_bf8_t_half_t
=
type_convert
<
half_t
>
(
ck
::
NumericLimits
<
bf8_ocp_t
>::
Max
());
// convert maximal bf8_ocp_t to half_t and check if equal to bf8 max
ASSERT_NEAR
(
max_bf8_t_half_t
,
type_convert
<
half_t
>
(
f8_convert_rne
<
bf8_ocp_t
>
(
max_bf8_t_half_t
)),
half_t_zero
);
// convert maximal half_t to bf8 and back, check if clipped to bf8 max (saturation to finite)
ASSERT_NEAR
(
max_bf8_t_half_t
,
type_convert
<
half_t
>
(
f8_convert_rne
<
bf8_ocp_t
>
(
ck
::
NumericLimits
<
half_t
>::
Max
())),
half_t_zero
);
// convert half_t infinity to bf8_ocp_t and check if it is max value (saturation to finite)
ASSERT_EQ
(
ck
::
NumericLimits
<
bf8_ocp_t
>::
Max
(),
f8_convert_rne
<
bf8_ocp_t
>
(
type_convert
<
half_t
>
(
std
::
numeric_limits
<
float
>::
infinity
())));
// positive normal bf8 value to bf8 and back, check if holds
constexpr
half_t
pos_norm_bf8
{
0.0000762939
f
};
// 10*2^-17
ASSERT_NEAR
(
pos_norm_bf8
,
type_convert
<
half_t
>
(
f8_convert_rne
<
bf8_ocp_t
>
(
pos_norm_bf8
)),
half_t_tol
);
// negative smallest normal bf8 value to bf8 and back, check if holds
constexpr
half_t
neg_min_bf8
{
-
0.00006103515625
f
};
//-2^-14
ASSERT_NEAR
(
neg_min_bf8
,
type_convert
<
half_t
>
(
f8_convert_rne
<
bf8_ocp_t
>
(
neg_min_bf8
)),
half_t_zero
);
// positive subnorm bf8 value to bf8 and back, check if holds
constexpr
half_t
pos_subnorm_bf8
{
0.000030517578125
f
};
// 2^-15
ASSERT_NEAR
(
pos_subnorm_bf8
,
type_convert
<
half_t
>
(
f8_convert_rne
<
bf8_ocp_t
>
(
pos_subnorm_bf8
)),
half_t_zero
);
// min subnorm bf8 value to bf8 and back, check if holds
constexpr
half_t
min_subnorm_bf8
{
-
0.0000152587890625
f
};
//-2^-16
ASSERT_NEAR
(
min_subnorm_bf8
,
type_convert
<
half_t
>
(
f8_convert_rne
<
bf8_ocp_t
>
(
min_subnorm_bf8
)),
half_t_zero
);
// smaller than min subnorm bf8 value to bf8 must be zero
constexpr
half_t
less_than_min_subnorm
{
0.00000762939453125
f
};
// 2^-17
ASSERT_EQ
(
half_t_zero
,
type_convert
<
half_t
>
(
f8_convert_rne
<
bf8_ocp_t
>
(
less_than_min_subnorm
)));
// convert quiet NaN to bf8_ocp_t and check if it is quiet NaN
const
auto
bf8_nan
=
f8_convert_rne
<
bf8_ocp_t
>
(
ck
::
NumericLimits
<
half_t
>::
QuietNaN
());
ASSERT_TRUE
(
ck
::
fp8_impl
::
ocp_bf8_is_nan
(
bf8_nan
.
data
));
}
TEST
(
BF8OCP
,
ConvertFP16Stochastic
)
{
// fix the tolerance value
constexpr
half_t
half_t_tol
=
1e-3
;
constexpr
half_t
half_t_zero
=
0.0
;
constexpr
auto
min_subnorm_bf8
=
0.0000152587890625
f
;
// 2^-16
// convert 0 half_t to bfp8 and back, check if holds
ASSERT_NEAR
(
half_t_zero
,
type_convert
<
half_t
>
(
f8_convert_sr
<
bf8_ocp_t
>
(
half_t_zero
)),
half_t_zero
);
// convert minimal half_t (6.103515625e-05) to fp8 and back
ASSERT_NEAR
(
ck
::
NumericLimits
<
half_t
>::
Min
(),
type_convert
<
half_t
>
(
f8_convert_sr
<
bf8_ocp_t
>
(
ck
::
NumericLimits
<
half_t
>::
Min
())),
half_t_zero
);
const
auto
max_bf8_t_half_t
=
type_convert
<
half_t
>
(
ck
::
NumericLimits
<
bf8_ocp_t
>::
Max
());
// convert maximal bf8_ocp_t to half_t and check if equal to bf8 max
ASSERT_NEAR
(
max_bf8_t_half_t
,
type_convert
<
half_t
>
(
f8_convert_sr
<
bf8_ocp_t
>
(
max_bf8_t_half_t
)),
half_t_zero
);
// convert maximal half_t to bf8 and back, check if clipped to bf8 max (saturation to finite)
ASSERT_NEAR
(
max_bf8_t_half_t
,
type_convert
<
half_t
>
(
f8_convert_sr
<
bf8_ocp_t
>
(
ck
::
NumericLimits
<
half_t
>::
Max
())),
half_t_zero
);
// convert half_t infinity to bf8_ocp_t and check if it is max value (saturation to finite)
ASSERT_EQ
(
ck
::
NumericLimits
<
bf8_ocp_t
>::
Max
(),
f8_convert_sr
<
bf8_ocp_t
>
(
type_convert
<
half_t
>
(
std
::
numeric_limits
<
float
>::
infinity
())));
// positive normal bf8 value to bf8 and back, check if holds
constexpr
half_t
pos_norm_bf8
{
0.0000762939
f
};
// 10*2^-17
ASSERT_NEAR
(
pos_norm_bf8
,
type_convert
<
half_t
>
(
f8_convert_sr
<
bf8_ocp_t
>
(
pos_norm_bf8
)),
half_t_tol
);
// negative smallest normal bf8 value to bf8 and back, check if holds
constexpr
half_t
neg_min_bf8
{
-
0.00006103515625
f
};
//-2^-14
ASSERT_NEAR
(
neg_min_bf8
,
type_convert
<
half_t
>
(
f8_convert_sr
<
bf8_ocp_t
>
(
neg_min_bf8
)),
half_t_zero
);
// positive subnorm bf8 value to bf8 and back, check if holds
constexpr
half_t
pos_subnorm_bf8
{
0.000030517578125
f
};
// 2^-15
ASSERT_NEAR
(
pos_subnorm_bf8
,
type_convert
<
half_t
>
(
f8_convert_sr
<
bf8_ocp_t
>
(
pos_subnorm_bf8
)),
half_t_zero
);
// min subnorm bf8 value to bf8 and back, check if holds
ASSERT_NEAR
(
half_t
{
-
min_subnorm_bf8
},
type_convert
<
half_t
>
(
f8_convert_sr
<
bf8_ocp_t
>
(
half_t
{
-
min_subnorm_bf8
})),
half_t_zero
);
// smaller than min subnorm bf8 value to bf8 alternates between 0 and 2^-16
constexpr
half_t
less_than_min_subnorm
{
0.00000762939453125
f
};
// 2^-17
ASSERT_NEAR
(
half_t_zero
,
type_convert
<
half_t
>
(
f8_convert_sr
<
bf8_ocp_t
>
(
less_than_min_subnorm
)),
half_t
{
min_subnorm_bf8
});
// convert quiet NaN to bf8_ocp_t and check if it is quiet NaN
const
auto
bf8_nan
=
f8_convert_sr
<
bf8_ocp_t
>
(
ck
::
NumericLimits
<
half_t
>::
QuietNaN
());
ASSERT_TRUE
(
ck
::
fp8_impl
::
ocp_bf8_is_nan
(
bf8_nan
.
data
));
}
test/data_type/test_custom_type.cpp
View file @
8d2f2f8c
...
@@ -872,3 +872,161 @@ TEST(Complex_half, TestAsTypeReshape)
...
@@ -872,3 +872,161 @@ TEST(Complex_half, TestAsTypeReshape)
test_vec
.
at
(
num_elem
*
i
+
1
));
test_vec
.
at
(
num_elem
*
i
+
1
));
});
});
}
}
#if CK_USE_OCP_FP8
TEST
(
FP8OCP
,
TestSize
)
{
static_assert
(
std
::
is_same_v
<
f8_t
,
ck
::
f8_ocp_t
>
,
"OCP FP8 is not enabled"
);
ASSERT_EQ
(
sizeof
(
f8_t
),
sizeof
(
ck
::
fp8_storage_t
));
ASSERT_EQ
(
sizeof
(
vector_type
<
f8_t
,
2
>
),
sizeof
(
vector_type
<
ck
::
fp8_storage_t
,
2
>
));
ASSERT_EQ
(
sizeof
(
vector_type
<
f8_t
,
4
>
),
sizeof
(
vector_type
<
ck
::
fp8_storage_t
,
4
>
));
ASSERT_EQ
(
sizeof
(
vector_type
<
f8_t
,
8
>
),
sizeof
(
vector_type
<
ck
::
fp8_storage_t
,
8
>
));
ASSERT_EQ
(
sizeof
(
vector_type
<
f8_t
,
16
>
),
sizeof
(
vector_type
<
ck
::
fp8_storage_t
,
16
>
));
ASSERT_EQ
(
sizeof
(
vector_type
<
f8_t
,
32
>
),
sizeof
(
vector_type
<
ck
::
fp8_storage_t
,
32
>
));
ASSERT_EQ
(
sizeof
(
vector_type
<
f8_t
,
64
>
),
sizeof
(
vector_type
<
ck
::
fp8_storage_t
,
64
>
));
}
TEST
(
FP8OCP
,
TestAsType
)
{
static_assert
(
std
::
is_same_v
<
f8_t
,
ck
::
f8_ocp_t
>
,
"OCP FP8 is not enabled"
);
// test size
std
::
array
<
float
,
8
>
test_vec
=
{
-
4
,
-
2
,
-
0.5
,
-
0.25
,
1.0
/
8.0
,
1
,
1.5
,
16
};
constexpr
int
size
=
test_vec
.
size
();
// reference vector
vector_type
<
f8_t
,
size
>
right_vec
;
// check default CTOR
ck
::
static_for
<
0
,
size
,
1
>
{}(
[
&
](
auto
i
)
{
ASSERT_EQ
(
right_vec
.
template
AsType
<
f8_t
>()(
Number
<
i
>
{}),
f8_t
{
0
});
});
// assign test values to the vector
ck
::
static_for
<
0
,
size
,
1
>
{}([
&
](
auto
i
)
{
right_vec
.
template
AsType
<
f8_t
>()(
Number
<
i
>
{})
=
ck
::
type_convert
<
f8_t
>
(
test_vec
.
at
(
i
));
});
// copy the vector
vector_type
<
f8_t
,
size
>
left_vec
{
right_vec
};
// check if values were copied correctly
ck
::
static_for
<
0
,
size
,
1
>
{}([
&
](
auto
i
)
{
ASSERT_EQ
(
left_vec
.
template
AsType
<
f8_t
>()(
Number
<
i
>
{}),
ck
::
type_convert
<
f8_t
>
(
test_vec
.
at
(
i
)));
});
ck
::
non_native_vector_base
<
ck
::
f8_ocp_t
,
2
>
nnvb_f8x2
(
ck
::
type_convert
<
f8_t
>
(
-
10.0
f
));
ASSERT_EQ
(
nnvb_f8x2
.
template
AsType
<
f8_t
>()(
Number
<
0
>
{}),
ck
::
type_convert
<
f8_t
>
(
-
10.0
f
));
ASSERT_EQ
(
nnvb_f8x2
.
template
AsType
<
f8_t
>()(
Number
<
1
>
{}),
ck
::
type_convert
<
f8_t
>
(
-
10.0
f
));
}
TEST
(
FP8OCP
,
TestAsTypeReshape
)
{
static_assert
(
std
::
is_same_v
<
f8_t
,
ck
::
f8_ocp_t
>
,
"OCP FP8 is not enabled"
);
// test size
std
::
array
<
float
,
8
>
test_vec
=
{
-
8
,
-
0.5
,
-
0.25
,
1.0
/
8.0
,
1
/
256
,
1
,
1.5
,
16
};
constexpr
int
size
=
test_vec
.
size
();
// reference vector
vector_type
<
f8_t
,
size
>
right_vec
;
// check default CTOR
ck
::
static_for
<
0
,
size
,
1
>
{}(
[
&
](
auto
i
)
{
ASSERT_EQ
(
right_vec
.
template
AsType
<
f8_t
>()(
Number
<
i
>
{}),
f8_t
{
0
});
});
// assign test values to the vector
ck
::
static_for
<
0
,
size
,
1
>
{}([
&
](
auto
i
)
{
right_vec
.
template
AsType
<
f8_t
>()(
Number
<
i
>
{})
=
ck
::
type_convert
<
f8_t
>
(
test_vec
.
at
(
i
));
});
// copy the first half of a vector
vector_type
<
f8_t
,
size
/
2
>
left_vec
{
right_vec
.
template
AsType
<
vector_type
<
f8_t
,
size
/
2
>
::
type
>
()(
Number
<
0
>
{})};
// check if values were copied correctly
ck
::
static_for
<
0
,
size
/
2
,
1
>
{}([
&
](
auto
i
)
{
ASSERT_EQ
(
left_vec
.
template
AsType
<
f8_t
>()(
Number
<
i
>
{}),
ck
::
type_convert
<
f8_t
>
(
test_vec
.
at
(
i
)));
});
}
TEST
(
BF8OCP
,
TestSize
)
{
static_assert
(
std
::
is_same_v
<
bf8_t
,
ck
::
bf8_ocp_t
>
,
"OCP BF8 is not enabled"
);
ASSERT_EQ
(
sizeof
(
bf8_t
),
sizeof
(
ck
::
fp8_storage_t
));
ASSERT_EQ
(
sizeof
(
vector_type
<
bf8_t
,
2
>
),
sizeof
(
vector_type
<
ck
::
fp8_storage_t
,
2
>
));
ASSERT_EQ
(
sizeof
(
vector_type
<
bf8_t
,
4
>
),
sizeof
(
vector_type
<
ck
::
fp8_storage_t
,
4
>
));
ASSERT_EQ
(
sizeof
(
vector_type
<
bf8_t
,
8
>
),
sizeof
(
vector_type
<
ck
::
fp8_storage_t
,
8
>
));
ASSERT_EQ
(
sizeof
(
vector_type
<
bf8_t
,
16
>
),
sizeof
(
vector_type
<
ck
::
fp8_storage_t
,
16
>
));
ASSERT_EQ
(
sizeof
(
vector_type
<
bf8_t
,
32
>
),
sizeof
(
vector_type
<
ck
::
fp8_storage_t
,
32
>
));
ASSERT_EQ
(
sizeof
(
vector_type
<
bf8_t
,
64
>
),
sizeof
(
vector_type
<
ck
::
fp8_storage_t
,
64
>
));
}
TEST
(
BF8OCP
,
TestAsType
)
{
static_assert
(
std
::
is_same_v
<
bf8_t
,
ck
::
bf8_ocp_t
>
,
"OCP BF8 is not enabled"
);
// test size
std
::
array
<
float
,
8
>
test_vec
=
{
-
4
,
-
2
,
-
0.5
,
-
0.25
,
1.0
/
8.0
,
1
,
1.5
,
16
};
constexpr
int
size
=
test_vec
.
size
();
// reference vector
vector_type
<
bf8_t
,
size
>
right_vec
;
// check default CTOR
ck
::
static_for
<
0
,
size
,
1
>
{}(
[
&
](
auto
i
)
{
ASSERT_EQ
(
right_vec
.
template
AsType
<
bf8_t
>()(
Number
<
i
>
{}),
bf8_t
{
0
});
});
// assign test values to the vector
ck
::
static_for
<
0
,
size
,
1
>
{}([
&
](
auto
i
)
{
right_vec
.
template
AsType
<
bf8_t
>()(
Number
<
i
>
{})
=
ck
::
type_convert
<
bf8_t
>
(
test_vec
.
at
(
i
));
});
// copy the vector
vector_type
<
bf8_t
,
size
>
left_vec
{
right_vec
};
// check if values were copied correctly
ck
::
static_for
<
0
,
size
,
1
>
{}([
&
](
auto
i
)
{
ASSERT_EQ
(
left_vec
.
template
AsType
<
bf8_t
>()(
Number
<
i
>
{}),
ck
::
type_convert
<
bf8_t
>
(
test_vec
.
at
(
i
)));
});
ck
::
non_native_vector_base
<
bf8_t
,
2
>
nnvb_bf8x2
(
ck
::
type_convert
<
bf8_t
>
(
-
10.0
f
));
ASSERT_EQ
(
nnvb_bf8x2
.
template
AsType
<
bf8_t
>()(
Number
<
0
>
{}),
ck
::
type_convert
<
bf8_t
>
(
-
10.0
f
));
ASSERT_EQ
(
nnvb_bf8x2
.
template
AsType
<
bf8_t
>()(
Number
<
1
>
{}),
ck
::
type_convert
<
bf8_t
>
(
-
10.0
f
));
}
TEST
(
BF8OCP
,
TestAsTypeReshape
)
{
static_assert
(
std
::
is_same_v
<
bf8_t
,
ck
::
bf8_ocp_t
>
,
"OCP BF8 is not enabled"
);
// test size
std
::
array
<
float
,
8
>
test_vec
=
{
-
8
,
-
0.5
,
-
0.25
,
1.0
/
8.0
,
1
/
256
,
1
,
1.5
,
16
};
constexpr
int
size
=
test_vec
.
size
();
// reference vector
vector_type
<
bf8_t
,
size
>
right_vec
;
// check default CTOR
ck
::
static_for
<
0
,
size
,
1
>
{}(
[
&
](
auto
i
)
{
ASSERT_EQ
(
right_vec
.
template
AsType
<
bf8_t
>()(
Number
<
i
>
{}),
bf8_t
{
0
});
});
// assign test values to the vector
ck
::
static_for
<
0
,
size
,
1
>
{}([
&
](
auto
i
)
{
right_vec
.
template
AsType
<
bf8_t
>()(
Number
<
i
>
{})
=
ck
::
type_convert
<
bf8_t
>
(
test_vec
.
at
(
i
));
});
// copy the first half of a vector
vector_type
<
bf8_t
,
size
/
2
>
left_vec
{
right_vec
.
template
AsType
<
vector_type
<
bf8_t
,
size
/
2
>
::
type
>
()(
Number
<
0
>
{})};
// check if values were copied correctly
ck
::
static_for
<
0
,
size
/
2
,
1
>
{}([
&
](
auto
i
)
{
ASSERT_EQ
(
left_vec
.
template
AsType
<
bf8_t
>()(
Number
<
i
>
{}),
ck
::
type_convert
<
bf8_t
>
(
test_vec
.
at
(
i
)));
});
}
#endif
test/data_type/test_fp8.cpp
→
test/data_type/test_fp8
_fnuz
.cpp
View file @
8d2f2f8c
This diff is collapsed.
Click to expand it.
test/data_type/test_fp8_ocp.cpp
0 → 100644
View file @
8d2f2f8c
This diff is collapsed.
Click to expand it.
test/pool/test_avg_pool2d_fwd.cpp
View file @
8d2f2f8c
...
@@ -138,7 +138,7 @@ TYPED_TEST_SUITE(AvgPool2D_BF16, AvgPool2D_BF16_Types);
...
@@ -138,7 +138,7 @@ TYPED_TEST_SUITE(AvgPool2D_BF16, AvgPool2D_BF16_Types);
TYPED_TEST_SUITE
(
AvgPool2D_I8
,
AvgPool2D_I8_Types
);
TYPED_TEST_SUITE
(
AvgPool2D_I8
,
AvgPool2D_I8_Types
);
TYPED_TEST_SUITE
(
AvgPool2D_F8
,
AvgPool2D_F8_Types
);
TYPED_TEST_SUITE
(
AvgPool2D_F8
,
AvgPool2D_F8_Types
);
TYPED_TEST
(
AvgPool2D_F32
,
AvgPool2D_
I8
_Test
)
{
this
->
Run
();
}
TYPED_TEST
(
AvgPool2D_F32
,
AvgPool2D_
F32
_Test
)
{
this
->
Run
();
}
TYPED_TEST
(
AvgPool2D_F16
,
AvgPool2D_F16_Test
)
{
this
->
Run
();
}
TYPED_TEST
(
AvgPool2D_F16
,
AvgPool2D_F16_Test
)
{
this
->
Run
();
}
TYPED_TEST
(
AvgPool2D_BF16
,
AvgPool2D_BF16_Test
)
{
this
->
Run
();
}
TYPED_TEST
(
AvgPool2D_BF16
,
AvgPool2D_BF16_Test
)
{
this
->
Run
();
}
TYPED_TEST
(
AvgPool2D_I8
,
AvgPool2D_I8_Test
)
{
this
->
Run
();
}
TYPED_TEST
(
AvgPool2D_I8
,
AvgPool2D_I8_Test
)
{
this
->
Run
();
}
...
...
test/pool/test_max_pool2d_fwd.cpp
View file @
8d2f2f8c
...
@@ -143,7 +143,7 @@ TYPED_TEST_SUITE(MaxPool2D_BF16, MaxPool2D_BF16_Types);
...
@@ -143,7 +143,7 @@ TYPED_TEST_SUITE(MaxPool2D_BF16, MaxPool2D_BF16_Types);
TYPED_TEST_SUITE
(
MaxPool2D_I8
,
MaxPool2D_I8_Types
);
TYPED_TEST_SUITE
(
MaxPool2D_I8
,
MaxPool2D_I8_Types
);
TYPED_TEST_SUITE
(
MaxPool2D_F8
,
MaxPool2D_F8_Types
);
TYPED_TEST_SUITE
(
MaxPool2D_F8
,
MaxPool2D_F8_Types
);
TYPED_TEST
(
MaxPool2D_F32
,
MaxPool2D_
I8
_Test
)
{
this
->
Run
();
}
TYPED_TEST
(
MaxPool2D_F32
,
MaxPool2D_
F32
_Test
)
{
this
->
Run
();
}
TYPED_TEST
(
MaxPool2D_F16
,
MaxPool2D_F16_Test
)
{
this
->
Run
();
}
TYPED_TEST
(
MaxPool2D_F16
,
MaxPool2D_F16_Test
)
{
this
->
Run
();
}
TYPED_TEST
(
MaxPool2D_BF16
,
MaxPool2D_BF16_Test
)
{
this
->
Run
();
}
TYPED_TEST
(
MaxPool2D_BF16
,
MaxPool2D_BF16_Test
)
{
this
->
Run
();
}
TYPED_TEST
(
MaxPool2D_I8
,
MaxPool2D_I8_Test
)
{
this
->
Run
();
}
TYPED_TEST
(
MaxPool2D_I8
,
MaxPool2D_I8_Test
)
{
this
->
Run
();
}
...
...
Prev
1
2
3
4
5
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment