Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
175a17f8
Commit
175a17f8
authored
Nov 23, 2024
by
Rostyslav Geyyer
Browse files
Merge branch 'gfx950' of
https://github.com/ROCm/composable_kernel-internal
into lwpck-2390
parents
3e520bbd
1504c3e8
Changes
138
Hide whitespace changes
Inline
Side-by-side
Showing
18 changed files
with
900 additions
and
164 deletions
+900
-164
profiler/include/profiler/profile_batched_gemm_gemm_impl.hpp
profiler/include/profiler/profile_batched_gemm_gemm_impl.hpp
+2
-2
profiler/include/profiler/profile_batched_gemm_softmax_gemm_impl.hpp
...clude/profiler/profile_batched_gemm_softmax_gemm_impl.hpp
+2
-2
profiler/include/profiler/profile_batched_gemm_softmax_gemm_permute_impl.hpp
...ofiler/profile_batched_gemm_softmax_gemm_permute_impl.hpp
+2
-2
profiler/include/profiler/profile_gemm_impl.hpp
profiler/include/profiler/profile_gemm_impl.hpp
+3
-3
profiler/include/profiler/profile_pool3d_fwd_impl.hpp
profiler/include/profiler/profile_pool3d_fwd_impl.hpp
+16
-2
profiler/src/profile_gemm_universal.cpp
profiler/src/profile_gemm_universal.cpp
+3
-3
profiler/src/profile_layernorm_fwd.cpp
profiler/src/profile_layernorm_fwd.cpp
+1
-1
script/process_perf_data.py
script/process_perf_data.py
+2
-2
script/process_qa_data.sh
script/process_qa_data.sh
+1
-0
test/CMakeLists.txt
test/CMakeLists.txt
+5
-5
test/ck_tile/gemm/test_gemm_mem_pipeline_util.hpp
test/ck_tile/gemm/test_gemm_mem_pipeline_util.hpp
+6
-6
test/data_type/CMakeLists.txt
test/data_type/CMakeLists.txt
+31
-6
test/data_type/test_bf8_fnuz.cpp
test/data_type/test_bf8_fnuz.cpp
+73
-62
test/data_type/test_bf8_ocp.cpp
test/data_type/test_bf8_ocp.cpp
+268
-0
test/data_type/test_custom_type.cpp
test/data_type/test_custom_type.cpp
+150
-0
test/data_type/test_fp8_fnuz.cpp
test/data_type/test_fp8_fnuz.cpp
+83
-66
test/data_type/test_fp8_ocp.cpp
test/data_type/test_fp8_ocp.cpp
+250
-0
test/gemm_universal/test_gemm_universal_xdl.cpp
test/gemm_universal/test_gemm_universal_xdl.cpp
+2
-2
No files found.
profiler/include/profiler/profile_batched_gemm_gemm_impl.hpp
View file @
175a17f8
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
@@ -157,7 +157,7 @@ bool profile_batched_gemm_gemm_impl(bool do_verification,
break
;
default:
a_g_m_k
.
GenerateTensorValue
(
GeneratorTensor_1
<
ADataType
>
{
1
});
b0_g_k_n
.
GenerateTensorValue
(
GeneratorTensor_Sequential
<
1
>
{});
b0_g_k_n
.
GenerateTensorValue
(
GeneratorTensor_Sequential
<
B0DataType
,
1
>
{});
b1_g_n_o
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
B1DataType
>
{});
}
...
...
profiler/include/profiler/profile_batched_gemm_softmax_gemm_impl.hpp
View file @
175a17f8
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
@@ -174,7 +174,7 @@ bool profile_batched_gemm_softmax_gemm_impl(bool do_verification,
break
;
default:
a_g_m_k
.
GenerateTensorValue
(
GeneratorTensor_1
<
ADataType
>
{
1
});
b0_g_k_n
.
GenerateTensorValue
(
GeneratorTensor_Sequential
<
1
>
{});
b0_g_k_n
.
GenerateTensorValue
(
GeneratorTensor_Sequential
<
B0DataType
,
1
>
{});
b1_g_n_o
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
B1DataType
>
{});
}
...
...
profiler/include/profiler/profile_batched_gemm_softmax_gemm_permute_impl.hpp
View file @
175a17f8
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
@@ -140,7 +140,7 @@ bool profile_batched_gemm_softmax_gemm_permute_impl(bool do_verification,
break
;
default:
a_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_1
<
ADataType
>
{
1
});
b0_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_Sequential
<
1
>
{});
b0_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_Sequential
<
B0DataType
,
1
>
{});
b1_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
B1DataType
>
{});
}
...
...
profiler/include/profiler/profile_gemm_impl.hpp
View file @
175a17f8
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
@@ -74,8 +74,8 @@ int profile_gemm_impl(int do_verification,
switch
(
init_method
)
{
case
0
:
ck
::
utils
::
FillConstant
<
ADataType
>
{
static_cas
t
<
ADataType
>
(
1.
f
)}(
a_m_k
);
ck
::
utils
::
FillConstant
<
BDataType
>
{
static_cas
t
<
BDataType
>
(
1.
f
)}(
b_k_n
);
ck
::
utils
::
FillConstant
<
ADataType
>
{
type_conver
t
<
ADataType
>
(
1.
f
)}(
a_m_k
);
ck
::
utils
::
FillConstant
<
BDataType
>
{
type_conver
t
<
BDataType
>
(
1.
f
)}(
b_k_n
);
break
;
case
1
:
ck
::
utils
::
FillUniformDistributionIntegerValue
<
ADataType
>
{
-
5.
f
,
5.
f
}(
a_m_k
);
...
...
profiler/include/profiler/profile_pool3d_fwd_impl.hpp
View file @
175a17f8
...
...
@@ -240,6 +240,19 @@ bool profile_pool3d_fwd_impl(PoolFwdInputParams& in_params, PoolFwdKernelParams&
{
out_device_buf
.
FromDevice
(
out_n_c_do_ho_wo_device
.
mData
.
data
());
auto
number_of_accumulations
=
1
;
static_assert
(
ReduceOpId
==
ck
::
ReduceTensorOp
::
AVG
||
ReduceOpId
==
ck
::
ReduceTensorOp
::
MAX
,
"Warning: Unhandled ReduceOpId for setting up the number of accumulations!"
);
if
constexpr
(
ReduceOpId
==
ck
::
ReduceTensorOp
::
AVG
)
{
for
(
size_t
i
=
0
;
i
<
kernel_params
.
window_spatial_lengths
.
size
();
++
i
)
{
number_of_accumulations
*=
kernel_params
.
window_spatial_lengths
.
at
(
i
);
}
}
auto
absolute_error_threshold
=
1.0
;
switch
(
in_params
.
init_method
)
{
...
...
@@ -250,9 +263,10 @@ bool profile_pool3d_fwd_impl(PoolFwdInputParams& in_params, PoolFwdKernelParams&
absolute_error_threshold
=
ck
::
utils
::
get_absolute_threshold
<
ComputeDataType
,
OutDataType
>
(
absolute_error_threshold
);
absolute_error_threshold
,
number_of_accumulations
);
auto
relative_error_threshold
=
ck
::
utils
::
get_relative_threshold
<
ComputeDataType
,
OutDataType
>
();
ck
::
utils
::
get_relative_threshold
<
ComputeDataType
,
OutDataType
>
(
number_of_accumulations
);
bool
pass
=
ck
::
utils
::
check_err
(
out_n_c_do_ho_wo_device
.
mData
,
out_n_c_do_ho_wo_host
.
mData
,
...
...
profiler/src/profile_gemm_universal.cpp
View file @
175a17f8
...
...
@@ -101,7 +101,7 @@ int profile_gemm_universal(int argc, char* argv[])
using
F32
=
float
;
using
F16
=
ck
::
half_t
;
using
BF16
=
ck
::
bhalf_t
;
#if defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH)
#if defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH)
|| defined(CK_USE_GFX94)
using
F8
=
ck
::
f8_t
;
#endif
...
...
@@ -164,7 +164,7 @@ int profile_gemm_universal(int argc, char* argv[])
{
return
profile
(
F16
{},
F16
{},
F16
{},
F32
{},
F16
{},
Row
{},
Col
{},
Row
{});
}
#if defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH)
#if defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH)
|| defined(CK_USE_GFX94)
else
if
(
data_type
==
GemmDataType
::
F16_F8_F16
&&
layout
==
GemmMatrixLayout
::
MK_KN_MN
)
{
return
profile
(
F16
{},
F8
{},
F16
{},
F32
{},
F16
{},
Row
{},
Row
{},
Row
{});
...
...
@@ -198,7 +198,7 @@ int profile_gemm_universal(int argc, char* argv[])
{
return
profile
(
BF16
{},
BF16
{},
BF16
{},
F32
{},
BF16
{},
Col
{},
Row
{},
Row
{});
}
#if defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH)
#if defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH)
|| defined(CK_USE_GFX94)
else
if
(
data_type
==
GemmDataType
::
F8_F8_BF16
&&
layout
==
GemmMatrixLayout
::
MK_KN_MN
)
{
return
profile
(
F8
{},
F8
{},
F8
{},
F32
{},
BF16
{},
Row
{},
Row
{},
Row
{});
...
...
profiler/src/profile_layernorm_fwd.cpp
View file @
175a17f8
...
...
@@ -85,7 +85,7 @@ int profile_layernorm(int argc, char* argv[])
if
(
data_type
==
ck
::
DataTypeEnum
::
Half
)
{
ck
::
profiler
::
profile_layernorm_impl
<
F16
,
F16
,
F16
,
F32
,
F16
,
F
32
,
false
,
rank
>
(
ck
::
profiler
::
profile_layernorm_impl
<
F16
,
F16
,
F16
,
F32
,
F16
,
F
16
,
false
,
rank
>
(
do_verification
,
init_method
,
do_log
,
time_kernel
,
length
);
}
else
if
(
data_type
==
ck
::
DataTypeEnum
::
Float
)
...
...
script/process_perf_data.py
View file @
175a17f8
...
...
@@ -133,12 +133,12 @@ def parse_logfile(logfile):
if
'Best Perf'
in
line
:
lst
=
line
.
split
()
res
.
append
(
lst
[
4
])
elif
'onnx_gemm'
in
logfile
or
'mixed_gemm'
in
logfile
:
elif
'onnx_gemm'
in
logfile
:
for
line
in
open
(
logfile
):
if
'Best Perf'
in
line
:
lst
=
line
.
split
()
res
.
append
(
lst
[
33
])
elif
'splitK_gemm'
in
logfile
:
elif
'splitK_gemm'
in
logfile
or
'mixed_gemm'
in
logfile
:
for
line
in
open
(
logfile
):
if
'Best Perf'
in
line
:
lst
=
line
.
split
()
...
...
script/process_qa_data.sh
View file @
175a17f8
...
...
@@ -22,6 +22,7 @@ python3 process_perf_data.py perf_gemm_bilinear.log
python3 process_perf_data.py perf_reduction.log
python3 process_perf_data.py perf_splitK_gemm.log
python3 process_perf_data.py perf_onnx_gemm.log
python3 process_perf_data.py perf_mixed_gemm.log
file
=
./perf_fmha_fwd_gfx942.log
if
[
-e
"
$file
"
]
;
then
...
...
test/CMakeLists.txt
View file @
175a17f8
...
...
@@ -64,11 +64,11 @@ function(add_test_executable TEST_NAME)
#only continue if there are some source files left on the list
if
(
ARGN
)
if
(
ARGN MATCHES
"_xdl"
)
list
(
REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201
)
list
(
REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201
gfx10.3-generic gfx11-generic gfx12-generic
)
elseif
(
ARGN MATCHES
"_wmma"
)
list
(
REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030 gfx950
)
elseif
(
ARGN MATCHES
"_smfmac"
)
list
(
REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx908 gfx90a gfx1200 gfx1201
)
list
(
REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx908 gfx90a gfx1200 gfx1201
gfx10.3-generic gfx11-generic gfx12-generic
)
endif
()
set_source_files_properties
(
${
ARGN
}
PROPERTIES LANGUAGE HIP
)
add_executable
(
${
TEST_NAME
}
${
ARGN
}
)
...
...
@@ -141,11 +141,11 @@ function(add_gtest_executable TEST_NAME)
#only continue if there are some source files left on the list
if
(
ARGN
)
if
(
ARGN MATCHES
"_xdl"
)
list
(
REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201
)
list
(
REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201
gfx10.3-generic gfx11-generic gfx12-generic
)
elseif
(
ARGN MATCHES
"_wmma"
)
list
(
REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030 gfx950
)
elseif
(
ARGN MATCHES
"_smfmac"
)
list
(
REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx908 gfx90a gfx1200 gfx1201
)
list
(
REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx908 gfx90a gfx1200 gfx1201
gfx10.3-generic gfx11-generic gfx12-generic
)
endif
()
set_source_files_properties
(
${
ARGN
}
PROPERTIES LANGUAGE HIP
)
add_executable
(
${
TEST_NAME
}
${
ARGN
}
)
...
...
@@ -206,7 +206,7 @@ add_subdirectory(wrapper)
if
(
SUPPORTED_GPU_TARGETS MATCHES
"gfx11"
)
add_subdirectory
(
wmma_op
)
endif
()
if
(
SUPPORTED_GPU_TARGETS MATCHES
"gfx942"
AND CK_HIP_VERSION_MAJOR GREATER_EQUAL 6 AND CK_HIP_VERSION_MINOR GREATER_EQUAL 2
)
# smfmac needs ROCm6.2
if
(
(
SUPPORTED_GPU_TARGETS MATCHES
"gfx942"
OR SUPPORTED_GPU_TARGETS MATCHES
"gfx95"
)
AND CK_HIP_VERSION_MAJOR GREATER_EQUAL 6 AND CK_HIP_VERSION_MINOR GREATER_EQUAL 2
)
# smfmac needs ROCm6.2
add_subdirectory
(
smfmac_op
)
endif
()
add_subdirectory
(
position_embedding
)
...
...
test/ck_tile/gemm/test_gemm_mem_pipeline_util.hpp
View file @
175a17f8
...
...
@@ -53,9 +53,9 @@ class TestCkTileGemmMemPipeline : public ::testing::Test
constexpr
ck_tile
::
index_t
N_Warp_Tile
=
32
;
constexpr
ck_tile
::
index_t
K_Warp_Tile
=
8
;
constexpr
bool
kPad
A
=
true
;
constexpr
bool
kPad
B
=
true
;
constexpr
bool
kPad
C
=
true
;
constexpr
bool
kPad
M
=
true
;
constexpr
bool
kPad
N
=
true
;
constexpr
bool
kPad
K
=
true
;
constexpr
int
kBlockPerCu
=
1
;
...
...
@@ -68,9 +68,9 @@ class TestCkTileGemmMemPipeline : public ::testing::Test
using
TilePartitioner
=
ck_tile
::
GemmTilePartitioner
<
GemmShape
>
;
using
GemmEpilogue
=
ck_tile
::
Default2DEpilogue
<
ck_tile
::
Default2DEpilogueProblem
<
AccDataType
,
CDataType
,
false
,
kPad
C
>>
;
ck_tile
::
Default2DEpilogueProblem
<
AccDataType
,
CDataType
,
kPadM
,
kPad
N
>>
;
using
Traits
=
ck_tile
::
TileGemmTraits
<
kPad
A
,
kPad
B
,
kPad
C
,
ALayout
,
BLayout
,
CLayout
>
;
using
Traits
=
ck_tile
::
TileGemmTraits
<
kPad
M
,
kPad
N
,
kPad
K
,
ALayout
,
BLayout
,
CLayout
>
;
using
BaseGemmPipeline
=
ck_tile
::
BaseGemmPipelineAgBgCrMem
<
ck_tile
::
GemmPipelineProblem
<
ADataType
,
BDataType
,
AccDataType
,
GemmShape
,
Traits
>>
;
...
...
@@ -108,7 +108,7 @@ class TestCkTileGemmMemPipeline : public ::testing::Test
if
(
s
.
log_level_
>
0
)
{
std
::
cout
<<
"Lunching kernel with args:"
std
::
cout
<<
"L
a
unching kernel with args:"
<<
" grid: {"
<<
grids
.
x
<<
", "
<<
grids
.
y
<<
", "
<<
grids
.
z
<<
"}"
<<
", blocks: {"
<<
blocks
.
x
<<
", "
<<
blocks
.
y
<<
", "
<<
blocks
.
z
<<
"}"
<<
std
::
endl
;
...
...
test/data_type/CMakeLists.txt
View file @
175a17f8
...
...
@@ -9,13 +9,38 @@ if (USE_BITINT_EXTENSION_INT4)
endif
()
endif
()
add_gtest_executable
(
test_fp8 test_fp8.cpp
)
if
(
result EQUAL 0
)
target_link_libraries
(
test_fp8 PRIVATE utility
)
add_custom_target
(
test_fp8
)
if
(
CK_USE_OCP_FP8
)
add_gtest_executable
(
test_fp8_ocp test_fp8_ocp.cpp
)
if
(
result EQUAL 0
)
target_link_libraries
(
test_fp8_ocp PRIVATE utility
)
endif
()
add_gtest_executable
(
test_bf8_ocp test_bf8_ocp.cpp
)
if
(
result EQUAL 0
)
target_link_libraries
(
test_bf8_ocp PRIVATE utility
)
endif
()
add_dependencies
(
test_fp8 test_fp8_ocp
)
add_dependencies
(
test_fp8 test_bf8_ocp
)
endif
()
add_gtest_executable
(
test_bf8 test_bf8.cpp
)
if
(
result EQUAL 0
)
target_link_libraries
(
test_bf8 PRIVATE utility
)
if
(
CK_USE_FNUZ_FP8
)
add_gtest_executable
(
test_fp8_fnuz test_fp8_fnuz.cpp
)
if
(
result EQUAL 0
)
target_link_libraries
(
test_fp8_fnuz PRIVATE utility
)
endif
()
add_gtest_executable
(
test_bf8_fnuz test_bf8_fnuz.cpp
)
if
(
result EQUAL 0
)
target_link_libraries
(
test_bf8_fnuz PRIVATE utility
)
endif
()
add_dependencies
(
test_fp8 test_fp8_fnuz
)
add_dependencies
(
test_fp8 test_bf8_fnuz
)
endif
()
add_gtest_executable
(
test_fp4 test_fp4.cpp
)
if
(
result EQUAL 0
)
...
...
test/data_type/test_bf8.cpp
→
test/data_type/test_bf8
_fnuz
.cpp
View file @
175a17f8
...
...
@@ -5,158 +5,169 @@
#include "ck/utility/data_type.hpp"
#include "ck/utility/type_convert.hpp"
using
ck
::
bf8_t
;
using
ck
::
bf8_
fnuz_
t
;
using
ck
::
f8_convert_rne
;
using
ck
::
f8_convert_sr
;
using
ck
::
half_t
;
using
ck
::
type_convert
;
TEST
(
BF8
,
NumericLimits
)
TEST
(
BF8
FNUZ
,
NumericLimits
)
{
// constants given for negative zero nan mode
EXPECT_EQ
(
ck
::
NumericLimits
<
bf8_t
>::
Min
(),
type_convert
<
bf8_t
>
(
0x04
));
EXPECT_EQ
(
ck
::
NumericLimits
<
bf8_t
>::
Max
(),
type_convert
<
bf8_t
>
(
0x7F
));
EXPECT_EQ
(
ck
::
NumericLimits
<
bf8_t
>::
Lowest
(),
type_convert
<
bf8_t
>
(
0xFF
));
EXPECT_EQ
(
ck
::
NumericLimits
<
bf8_t
>::
QuietNaN
(),
type_convert
<
bf8_t
>
(
0x80
));
EXPECT_EQ
(
ck
::
NumericLimits
<
bf8_
fnuz_
t
>::
Min
(),
type_convert
<
bf8_
fnuz_
t
>
(
0x04
));
EXPECT_EQ
(
ck
::
NumericLimits
<
bf8_
fnuz_
t
>::
Max
(),
type_convert
<
bf8_
fnuz_
t
>
(
0x7F
));
EXPECT_EQ
(
ck
::
NumericLimits
<
bf8_
fnuz_
t
>::
Lowest
(),
type_convert
<
bf8_
fnuz_
t
>
(
0xFF
));
EXPECT_EQ
(
ck
::
NumericLimits
<
bf8_
fnuz_
t
>::
QuietNaN
(),
type_convert
<
bf8_
fnuz_
t
>
(
0x80
));
}
TEST
(
BF8
,
ConvertFP32Nearest
)
TEST
(
BF8
FNUZ
,
ConvertFP32Nearest
)
{
// fix the tolerance value
float
abs_tol
=
1e-6
;
// convert 0 float to bf8 and back, check if holds
ASSERT_NEAR
(
0.0
f
,
type_convert
<
float
>
(
f8_convert_rne
<
bf8_t
>
(
0.0
f
)),
abs_tol
);
ASSERT_NEAR
(
0.0
f
,
type_convert
<
float
>
(
f8_convert_rne
<
bf8_
fnuz_
t
>
(
0.0
f
)),
abs_tol
);
// don't run the next test on gfx11 devices
#ifndef CK_SKIP_FLAKY_F8_TEST
// convert minimal float to bf8 and back, check if holds
ASSERT_NEAR
(
std
::
numeric_limits
<
float
>::
min
(),
type_convert
<
float
>
(
f8_convert_rne
<
bf8_t
>
(
std
::
numeric_limits
<
float
>::
min
())),
type_convert
<
float
>
(
f8_convert_rne
<
bf8_
fnuz_
t
>
(
std
::
numeric_limits
<
float
>::
min
())),
abs_tol
);
#endif
// convert maximal bf8_t to float and check if equal to 57344.0
ASSERT_NEAR
(
57344.0
f
,
type_convert
<
float
>
(
f8_convert_rne
<
bf8_t
>
(
57344.0
f
)),
abs_tol
);
const
auto
max_bf8_t_float
=
type_convert
<
float
>
(
ck
::
NumericLimits
<
bf8_fnuz_t
>::
Max
());
// convert maximal bf8_fnuz_t to float and check if equal to 57344.0
ASSERT_NEAR
(
max_bf8_t_float
,
type_convert
<
float
>
(
f8_convert_rne
<
bf8_fnuz_t
>
(
max_bf8_t_float
)),
abs_tol
);
// convert maximal float to bf8 and back, check if clipped to 57344.0
ASSERT_NEAR
(
57344.0
f
,
type_convert
<
float
>
(
f8_convert_rne
<
bf8_t
>
(
std
::
numeric_limits
<
float
>::
max
())),
ASSERT_NEAR
(
max_bf8_t_float
,
type_convert
<
float
>
(
f8_convert_rne
<
bf8_
fnuz_
t
>
(
std
::
numeric_limits
<
float
>::
max
())),
abs_tol
);
// convert inf float to bf8_t and check if it is qNan
ASSERT_NEAR
(
type_convert
<
bf8_t
>
(
0x80
),
f8_convert_rne
<
bf8_t
>
(
std
::
numeric_limits
<
float
>::
infinity
()),
// convert inf float to bf8_
fnuz_
t and check if it is qNan
ASSERT_NEAR
(
ck
::
NumericLimits
<
bf8_fnuz_t
>::
QuietNaN
(
),
f8_convert_rne
<
bf8_
fnuz_
t
>
(
std
::
numeric_limits
<
float
>::
infinity
()),
abs_tol
);
// positive norm float value to bf8 and back, check if holds
float
pos_float
=
0.0000762939
f
;
ASSERT_NEAR
(
pos_float
,
type_convert
<
float
>
(
f8_convert_rne
<
bf8_t
>
(
pos_float
)),
abs_tol
);
ASSERT_NEAR
(
pos_float
,
type_convert
<
float
>
(
f8_convert_rne
<
bf8_
fnuz_
t
>
(
pos_float
)),
abs_tol
);
// negative norm float value to bf8 and back, check if holds
float
neg_float
=
-
0.0000610351
f
;
ASSERT_NEAR
(
neg_float
,
type_convert
<
float
>
(
f8_convert_rne
<
bf8_t
>
(
neg_float
)),
abs_tol
);
ASSERT_NEAR
(
neg_float
,
type_convert
<
float
>
(
f8_convert_rne
<
bf8_
fnuz_
t
>
(
neg_float
)),
abs_tol
);
// positive subnorm float value to bf8 and back, check if holds
pos_float
=
0.0000305175
f
;
ASSERT_NEAR
(
pos_float
,
type_convert
<
float
>
(
f8_convert_rne
<
bf8_t
>
(
pos_float
)),
abs_tol
);
ASSERT_NEAR
(
pos_float
,
type_convert
<
float
>
(
f8_convert_rne
<
bf8_
fnuz_
t
>
(
pos_float
)),
abs_tol
);
// negative subnorm float value to bf8 and back, check if holds
neg_float
=
-
0.0000152587
f
;
ASSERT_NEAR
(
neg_float
,
type_convert
<
float
>
(
f8_convert_rne
<
bf8_t
>
(
neg_float
)),
abs_tol
);
ASSERT_NEAR
(
neg_float
,
type_convert
<
float
>
(
f8_convert_rne
<
bf8_
fnuz_
t
>
(
neg_float
)),
abs_tol
);
}
TEST
(
BF8
,
ConvertFP32Stochastic
)
TEST
(
BF8
FNUZ
,
ConvertFP32Stochastic
)
{
// fix the tolerance value
float
abs_tol
=
1e-6
;
// convert 0 float to bf8 and back, check if holds
ASSERT_NEAR
(
0.0
f
,
type_convert
<
float
>
(
f8_convert_sr
<
bf8_t
>
(
0.0
f
)),
abs_tol
);
ASSERT_NEAR
(
0.0
f
,
type_convert
<
float
>
(
f8_convert_sr
<
bf8_
fnuz_
t
>
(
0.0
f
)),
abs_tol
);
// convert minimal float to bf8 and back, check if holds
ASSERT_NEAR
(
std
::
numeric_limits
<
float
>::
min
(),
type_convert
<
float
>
(
f8_convert_sr
<
bf8_t
>
(
std
::
numeric_limits
<
float
>::
min
())),
type_convert
<
float
>
(
f8_convert_sr
<
bf8_
fnuz_
t
>
(
std
::
numeric_limits
<
float
>::
min
())),
abs_tol
);
// convert maximal bf8_t to float and check if equal to 57344.0
ASSERT_NEAR
(
57344.0
f
,
type_convert
<
float
>
(
f8_convert_sr
<
bf8_t
>
(
57344.0
f
)),
abs_tol
);
const
auto
max_bf8_t_float
=
type_convert
<
float
>
(
ck
::
NumericLimits
<
bf8_fnuz_t
>::
Max
());
// convert maximal bf8_fnuz_t to float and check if equal to 57344.0
ASSERT_NEAR
(
max_bf8_t_float
,
type_convert
<
float
>
(
f8_convert_sr
<
bf8_fnuz_t
>
(
max_bf8_t_float
)),
abs_tol
);
// convert maximal float to bf8 and back, check if clipped to 57344.0
ASSERT_NEAR
(
57344.0
f
,
type_convert
<
float
>
(
f8_convert_sr
<
bf8_t
>
(
std
::
numeric_limits
<
float
>::
max
())),
ASSERT_NEAR
(
max_bf8_t_float
,
type_convert
<
float
>
(
f8_convert_sr
<
bf8_
fnuz_
t
>
(
std
::
numeric_limits
<
float
>::
max
())),
abs_tol
);
// convert inf float to bf8_t and check if it is qNan
ASSERT_NEAR
(
type_convert
<
bf8_t
>
(
0x80
),
f8_convert_sr
<
bf8_t
>
(
std
::
numeric_limits
<
float
>::
infinity
()),
// convert inf float to bf8_
fnuz_
t and check if it is qNan
ASSERT_NEAR
(
ck
::
NumericLimits
<
bf8_fnuz_t
>::
QuietNaN
(
),
f8_convert_sr
<
bf8_
fnuz_
t
>
(
std
::
numeric_limits
<
float
>::
infinity
()),
abs_tol
);
// positive norm float value to bf8 and back, check if holds
float
pos_float
=
0.0000762939
f
;
ASSERT_NEAR
(
pos_float
,
type_convert
<
float
>
(
f8_convert_sr
<
bf8_t
>
(
pos_float
)),
abs_tol
);
ASSERT_NEAR
(
pos_float
,
type_convert
<
float
>
(
f8_convert_sr
<
bf8_
fnuz_
t
>
(
pos_float
)),
abs_tol
);
// negative norm float value to bf8 and back, check if holds
float
neg_float
=
-
0.0000610351
f
;
ASSERT_NEAR
(
neg_float
,
type_convert
<
float
>
(
f8_convert_sr
<
bf8_t
>
(
neg_float
)),
abs_tol
);
ASSERT_NEAR
(
neg_float
,
type_convert
<
float
>
(
f8_convert_sr
<
bf8_
fnuz_
t
>
(
neg_float
)),
abs_tol
);
// positive subnorm float value to bf8 and back, check if holds
pos_float
=
0.0000305175
f
;
ASSERT_NEAR
(
pos_float
,
type_convert
<
float
>
(
f8_convert_sr
<
bf8_t
>
(
pos_float
)),
abs_tol
);
ASSERT_NEAR
(
pos_float
,
type_convert
<
float
>
(
f8_convert_sr
<
bf8_
fnuz_
t
>
(
pos_float
)),
abs_tol
);
// negative subnorm float value to bf8 and back, check if holds
neg_float
=
-
0.0000152587
f
;
ASSERT_NEAR
(
neg_float
,
type_convert
<
float
>
(
f8_convert_sr
<
bf8_t
>
(
neg_float
)),
abs_tol
);
ASSERT_NEAR
(
neg_float
,
type_convert
<
float
>
(
f8_convert_sr
<
bf8_
fnuz_
t
>
(
neg_float
)),
abs_tol
);
}
TEST
(
BF8
,
ConvertFP16Nearest
)
TEST
(
BF8
FNUZ
,
ConvertFP16Nearest
)
{
// fix the tolerance value
float
abs_tol
=
1e-3
;
// convert 0 fp16 to bf8 and back, check if holds
ASSERT_NEAR
(
half_t
{
0.0
},
type_convert
<
half_t
>
(
f8_convert_rne
<
bf8_t
>
(
half_t
{
0.0
})),
abs_tol
);
ASSERT_NEAR
(
half_t
{
0.0
},
type_convert
<
half_t
>
(
f8_convert_rne
<
bf8_fnuz_t
>
(
half_t
{
0.0
})),
abs_tol
);
// convert minimal fp16 to bf8 and back, check if holds
ASSERT_NEAR
(
ck
::
NumericLimits
<
half_t
>::
Min
(),
type_convert
<
half_t
>
(
f8_convert_rne
<
bf8_t
>
(
ck
::
NumericLimits
<
half_t
>::
Min
())),
type_convert
<
half_t
>
(
f8_convert_rne
<
bf8_
fnuz_
t
>
(
ck
::
NumericLimits
<
half_t
>::
Min
())),
abs_tol
);
// convert maximal bf8_t to fp16 and check if equal to 57344.0
const
auto
max_bf8_t_half
=
type_convert
<
half_t
>
(
ck
::
NumericLimits
<
bf8_fnuz_t
>::
Max
());
// convert maximal bf8_fnuz_t to fp16 and check if equal to 57344.0
ASSERT_NEAR
(
half_t
{
57344.0
}
,
type_convert
<
half_t
>
(
f8_convert_rne
<
bf8_
t
>
(
half_t
{
57344.0
}
)),
abs_tol
);
max_bf8_t_half
,
type_convert
<
half_t
>
(
f8_convert_rne
<
bf8_
fnuz_t
>
(
max_bf8_t_half
)),
abs_tol
);
// convert maximal fp16 to bf8 and back, check if clipped to 57344.0
ASSERT_NEAR
(
half_t
{
57344.0
}
,
type_convert
<
half_t
>
(
f8_convert_rne
<
bf8_t
>
(
ck
::
NumericLimits
<
half_t
>::
Max
())),
ASSERT_NEAR
(
max_bf8_t_half
,
type_convert
<
half_t
>
(
f8_convert_rne
<
bf8_
fnuz_
t
>
(
ck
::
NumericLimits
<
half_t
>::
Max
())),
abs_tol
);
// convert QuietNaN fp16 to bf8_t and check if it is QuietNaN
ASSERT_NEAR
(
type_convert
<
bf8_t
>
(
0x80
),
f8_convert_rne
<
bf8_t
>
(
ck
::
NumericLimits
<
half_t
>::
QuietNaN
()),
// convert QuietNaN fp16 to bf8_
fnuz_
t and check if it is QuietNaN
ASSERT_NEAR
(
ck
::
NumericLimits
<
bf8_fnuz_t
>::
QuietNaN
(
),
f8_convert_rne
<
bf8_
fnuz_
t
>
(
ck
::
NumericLimits
<
half_t
>::
QuietNaN
()),
abs_tol
);
// positive norm fp16 value to bf8 and back, check if holds
half_t
pos_half
=
half_t
{
0.0000762939
};
ASSERT_NEAR
(
pos_half
,
type_convert
<
half_t
>
(
f8_convert_rne
<
bf8_t
>
(
pos_half
)),
abs_tol
);
ASSERT_NEAR
(
pos_half
,
type_convert
<
half_t
>
(
f8_convert_rne
<
bf8_
fnuz_
t
>
(
pos_half
)),
abs_tol
);
// negative norm fp16 value to bf8 and back, check if holds
half_t
neg_half
=
half_t
{
-
0.0000610351
};
ASSERT_NEAR
(
neg_half
,
type_convert
<
half_t
>
(
f8_convert_rne
<
bf8_t
>
(
neg_half
)),
abs_tol
);
ASSERT_NEAR
(
neg_half
,
type_convert
<
half_t
>
(
f8_convert_rne
<
bf8_
fnuz_
t
>
(
neg_half
)),
abs_tol
);
// positive subnorm fp16 value to bf8 and back, check if holds
pos_half
=
half_t
{
0.0000305175
};
ASSERT_NEAR
(
pos_half
,
type_convert
<
half_t
>
(
f8_convert_rne
<
bf8_t
>
(
pos_half
)),
abs_tol
);
ASSERT_NEAR
(
pos_half
,
type_convert
<
half_t
>
(
f8_convert_rne
<
bf8_
fnuz_
t
>
(
pos_half
)),
abs_tol
);
// negative subnorm fp16 value to bf8 and back, check if holds
neg_half
=
half_t
{
-
0.0000152587
};
ASSERT_NEAR
(
neg_half
,
type_convert
<
half_t
>
(
f8_convert_rne
<
bf8_t
>
(
neg_half
)),
abs_tol
);
ASSERT_NEAR
(
neg_half
,
type_convert
<
half_t
>
(
f8_convert_rne
<
bf8_
fnuz_
t
>
(
neg_half
)),
abs_tol
);
}
TEST
(
BF8
,
ConvertFP16Stochastic
)
TEST
(
BF8
FNUZ
,
ConvertFP16Stochastic
)
{
// fix the tolerance value
float
abs_tol
=
1e-3
;
// convert 0 fp16 to bf8 and back, check if holds
ASSERT_NEAR
(
half_t
{
0.0
},
type_convert
<
half_t
>
(
f8_convert_sr
<
bf8_t
>
(
half_t
{
0.0
})),
abs_tol
);
ASSERT_NEAR
(
half_t
{
0.0
},
type_convert
<
half_t
>
(
f8_convert_sr
<
bf8_
fnuz_
t
>
(
half_t
{
0.0
})),
abs_tol
);
// convert minimal fp16 to bf8 and back, check if holds
ASSERT_NEAR
(
ck
::
NumericLimits
<
half_t
>::
Min
(),
type_convert
<
half_t
>
(
f8_convert_sr
<
bf8_t
>
(
ck
::
NumericLimits
<
half_t
>::
Min
())),
type_convert
<
half_t
>
(
f8_convert_sr
<
bf8_
fnuz_
t
>
(
ck
::
NumericLimits
<
half_t
>::
Min
())),
abs_tol
);
// convert maximal bf8_t to fp16 and check if equal to 57344.0
const
auto
max_bf8_t_half
=
type_convert
<
half_t
>
(
ck
::
NumericLimits
<
bf8_fnuz_t
>::
Max
());
// convert maximal bf8_fnuz_t to fp16 and check if equal to 57344.0
ASSERT_NEAR
(
half_t
{
57344.0
}
,
type_convert
<
half_t
>
(
f8_convert_sr
<
bf8_
t
>
(
half_t
{
57344.0
}
)),
abs_tol
);
max_bf8_t_half
,
type_convert
<
half_t
>
(
f8_convert_sr
<
bf8_
fnuz_t
>
(
max_bf8_t_half
)),
abs_tol
);
// convert maximal fp16 to bf8 and back, check if clipped to 57344.0
ASSERT_NEAR
(
half_t
{
57344.0
}
,
type_convert
<
half_t
>
(
f8_convert_sr
<
bf8_t
>
(
ck
::
NumericLimits
<
half_t
>::
Max
())),
ASSERT_NEAR
(
max_bf8_t_half
,
type_convert
<
half_t
>
(
f8_convert_sr
<
bf8_
fnuz_
t
>
(
ck
::
NumericLimits
<
half_t
>::
Max
())),
abs_tol
);
// convert QuietNaN fp16 to bf8_t and check if it is QuietNaN
ASSERT_NEAR
(
type_convert
<
bf8_t
>
(
0x80
),
f8_convert_sr
<
bf8_t
>
(
ck
::
NumericLimits
<
half_t
>::
QuietNaN
()),
// convert QuietNaN fp16 to bf8_
fnuz_
t and check if it is QuietNaN
ASSERT_NEAR
(
ck
::
NumericLimits
<
bf8_fnuz_t
>::
QuietNaN
(
),
f8_convert_sr
<
bf8_
fnuz_
t
>
(
ck
::
NumericLimits
<
half_t
>::
QuietNaN
()),
abs_tol
);
// positive norm fp16 value to bf8 and back, check if holds
half_t
pos_half
=
half_t
{
0.0000762939
};
ASSERT_NEAR
(
pos_half
,
type_convert
<
half_t
>
(
f8_convert_sr
<
bf8_t
>
(
pos_half
)),
abs_tol
);
ASSERT_NEAR
(
pos_half
,
type_convert
<
half_t
>
(
f8_convert_sr
<
bf8_
fnuz_
t
>
(
pos_half
)),
abs_tol
);
// negative norm fp16 value to bf8 and back, check if holds
half_t
neg_half
=
half_t
{
-
0.0000610351
};
ASSERT_NEAR
(
neg_half
,
type_convert
<
half_t
>
(
f8_convert_sr
<
bf8_t
>
(
neg_half
)),
abs_tol
);
ASSERT_NEAR
(
neg_half
,
type_convert
<
half_t
>
(
f8_convert_sr
<
bf8_
fnuz_
t
>
(
neg_half
)),
abs_tol
);
// positive subnorm fp16 value to bf8 and back, check if holds
pos_half
=
half_t
{
0.0000305175
};
ASSERT_NEAR
(
pos_half
,
type_convert
<
half_t
>
(
f8_convert_sr
<
bf8_t
>
(
pos_half
)),
abs_tol
);
ASSERT_NEAR
(
pos_half
,
type_convert
<
half_t
>
(
f8_convert_sr
<
bf8_
fnuz_
t
>
(
pos_half
)),
abs_tol
);
// negative subnorm fp16 value to bf8 and back, check if holds
neg_half
=
half_t
{
-
0.0000152587
};
ASSERT_NEAR
(
neg_half
,
type_convert
<
half_t
>
(
f8_convert_sr
<
bf8_t
>
(
neg_half
)),
abs_tol
);
ASSERT_NEAR
(
neg_half
,
type_convert
<
half_t
>
(
f8_convert_sr
<
bf8_
fnuz_
t
>
(
neg_half
)),
abs_tol
);
}
test/data_type/test_bf8_ocp.cpp
0 → 100644
View file @
175a17f8
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "gtest/gtest.h"
#include "ck/utility/data_type.hpp"
#include "ck/utility/type_convert.hpp"
using
ck
::
bf8_ocp_t
;
using
ck
::
f8_convert_rne
;
using
ck
::
f8_convert_sr
;
using
ck
::
half_t
;
using
ck
::
type_convert
;
TEST
(
BF8OCP
,
NumericLimits
)
{
// constants given for OCP FP8
EXPECT_EQ
(
ck
::
NumericLimits
<
bf8_ocp_t
>::
Min
(),
type_convert
<
bf8_ocp_t
>
(
0x04
));
// 0b00000100 = 2^-14
EXPECT_EQ
(
ck
::
NumericLimits
<
bf8_ocp_t
>::
Max
(),
type_convert
<
bf8_ocp_t
>
(
0x7B
));
// 0b01111011 = 57344
EXPECT_EQ
(
ck
::
NumericLimits
<
bf8_ocp_t
>::
Lowest
(),
type_convert
<
bf8_ocp_t
>
(
0xFB
));
// 0b11111011 = -57344
EXPECT_EQ
(
ck
::
NumericLimits
<
bf8_ocp_t
>::
QuietNaN
().
data
,
type_convert
<
bf8_ocp_t
>
(
0x7D
).
data
);
// 0b01111101
EXPECT_FALSE
(
ck
::
NumericLimits
<
bf8_ocp_t
>::
QuietNaN
()
==
ck
::
NumericLimits
<
bf8_ocp_t
>::
QuietNaN
());
EXPECT_TRUE
(
ck
::
fp8_is_inf
(
type_convert
<
bf8_ocp_t
>
(
0xFC
))
&&
ck
::
fp8_is_inf
(
type_convert
<
bf8_ocp_t
>
(
0x7C
)));
}
TEST
(
BF8OCP
,
ConvertFP32Nearest
)
{
// fix the tolerance value
float
abs_tol
=
1e-6
;
// convert 0 float to bfp8 and back, check if holds
ASSERT_NEAR
(
0.0
f
,
type_convert
<
float
>
(
f8_convert_rne
<
bf8_ocp_t
>
(
0.0
f
)),
0.0
f
);
// convert minimal float to bf8 and back, check if holds
ASSERT_NEAR
(
std
::
numeric_limits
<
float
>::
min
(),
type_convert
<
float
>
(
f8_convert_rne
<
bf8_ocp_t
>
(
std
::
numeric_limits
<
float
>::
min
())),
abs_tol
);
const
auto
max_bf8_t_float
=
type_convert
<
float
>
(
ck
::
NumericLimits
<
bf8_ocp_t
>::
Max
());
// convert maximal bf8_ocp_t to float and check if equal to bf8 max
ASSERT_NEAR
(
max_bf8_t_float
,
type_convert
<
float
>
(
f8_convert_rne
<
bf8_ocp_t
>
(
max_bf8_t_float
)),
0.0
f
);
// convert maximal float to bf8 and back, check if clipped to bf8 max (saturation to finite)
ASSERT_NEAR
(
max_bf8_t_float
,
type_convert
<
float
>
(
f8_convert_rne
<
bf8_ocp_t
>
(
std
::
numeric_limits
<
float
>::
max
())),
0.0
f
);
// convert float infinity to bf8_ocp_t and check if it is max value (saturation to finite)
ASSERT_EQ
(
ck
::
NumericLimits
<
bf8_ocp_t
>::
Max
(),
f8_convert_rne
<
bf8_ocp_t
>
(
std
::
numeric_limits
<
float
>::
infinity
()));
// positive normal float value to bf8 and back, check if holds
float
pos_float
=
0.0000762939
f
;
// 10*2^-17
ASSERT_NEAR
(
pos_float
,
type_convert
<
float
>
(
f8_convert_rne
<
bf8_ocp_t
>
(
pos_float
)),
abs_tol
);
// negative smallest normal bf8 value to bf8 and back, check if holds
constexpr
auto
neg_min_bf8
=
-
0.00006103515625
f
;
//-2^-14
ASSERT_NEAR
(
neg_min_bf8
,
type_convert
<
float
>
(
f8_convert_rne
<
bf8_ocp_t
>
(
neg_min_bf8
)),
0.0
f
);
// positive subnorm float value to bf8 and back, check if holds
constexpr
auto
pos_subnorm_bf8
=
0.000030517578125
f
;
// 2^-15
ASSERT_NEAR
(
pos_subnorm_bf8
,
type_convert
<
float
>
(
f8_convert_rne
<
bf8_ocp_t
>
(
pos_subnorm_bf8
)),
0.0
f
);
// min subnorm bf8 value to bf8 and back, check if holds
constexpr
auto
min_subnorm_bf8
=
-
0.0000152587890625
f
;
//-2^-16
ASSERT_NEAR
(
min_subnorm_bf8
,
type_convert
<
float
>
(
f8_convert_rne
<
bf8_ocp_t
>
(
min_subnorm_bf8
)),
0.0
f
);
// smaller than min subnorm bf8 value to bf8 must be zero
constexpr
auto
less_than_min_subnorm
=
0.00000762939453125
f
;
// 2^-17
ASSERT_EQ
(
0.0
f
,
type_convert
<
float
>
(
f8_convert_rne
<
bf8_ocp_t
>
(
less_than_min_subnorm
)));
// convert quiet NaN to bf8_ocp_t and check if it is quiet NaN
const
auto
bf8_nan
=
f8_convert_rne
<
bf8_ocp_t
>
(
std
::
numeric_limits
<
float
>::
quiet_NaN
());
ASSERT_TRUE
(
ck
::
fp8_impl
::
ocp_bf8_is_nan
(
bf8_nan
.
data
));
}
TEST
(
BF8OCP
,
ConvertFP32Stochastic
)
{
// fix the tolerance value
float
abs_tol
=
1e-6
;
// convert 0 float to bfp8 and back, check if holds
ASSERT_NEAR
(
0.0
f
,
type_convert
<
float
>
(
f8_convert_sr
<
bf8_ocp_t
>
(
0.0
f
)),
0.0
f
);
// convert minimal float to bf8 and back, check if holds
ASSERT_NEAR
(
std
::
numeric_limits
<
float
>::
min
(),
type_convert
<
float
>
(
f8_convert_sr
<
bf8_ocp_t
>
(
std
::
numeric_limits
<
float
>::
min
())),
abs_tol
);
const
auto
max_bf8_t_float
=
type_convert
<
float
>
(
ck
::
NumericLimits
<
bf8_ocp_t
>::
Max
());
// convert maximal bf8_ocp_t to float and check if equal to bf8 max
ASSERT_NEAR
(
max_bf8_t_float
,
type_convert
<
float
>
(
f8_convert_sr
<
bf8_ocp_t
>
(
max_bf8_t_float
)),
0.0
f
);
// convert maximal float to bf8 and back, check if clipped to bf8 max (saturation to finite)
ASSERT_NEAR
(
max_bf8_t_float
,
type_convert
<
float
>
(
f8_convert_sr
<
bf8_ocp_t
>
(
std
::
numeric_limits
<
float
>::
max
())),
0.0
f
);
// convert float infinity to bf8_ocp_t and check if it is max value (saturation to finite)
ASSERT_EQ
(
ck
::
NumericLimits
<
bf8_ocp_t
>::
Max
(),
f8_convert_sr
<
bf8_ocp_t
>
(
std
::
numeric_limits
<
float
>::
infinity
()));
// positive normal float value to bf8 and back, check if holds
float
pos_float
=
0.0000762939
f
;
// 10*2^-17
ASSERT_NEAR
(
pos_float
,
type_convert
<
float
>
(
f8_convert_sr
<
bf8_ocp_t
>
(
pos_float
)),
abs_tol
);
// negative smallest normal bf8 value to bf8 and back, check if holds
constexpr
auto
neg_min_bf8
=
-
0.00006103515625
f
;
//-2^-14
ASSERT_NEAR
(
neg_min_bf8
,
type_convert
<
float
>
(
f8_convert_sr
<
bf8_ocp_t
>
(
neg_min_bf8
)),
0.0
f
);
// positive subnorm float value to bf8 and back, check if holds
constexpr
auto
pos_subnorm_bf8
=
0.000030517578125
f
;
// 2^-15
ASSERT_NEAR
(
pos_subnorm_bf8
,
type_convert
<
float
>
(
f8_convert_sr
<
bf8_ocp_t
>
(
pos_subnorm_bf8
)),
0.0
f
);
// min subnorm bf8 value to bf8 and back, check if holds
constexpr
auto
min_subnorm_bf8
=
-
0.0000152587890625
f
;
//-2^-16
ASSERT_NEAR
(
min_subnorm_bf8
,
type_convert
<
float
>
(
f8_convert_sr
<
bf8_ocp_t
>
(
min_subnorm_bf8
)),
0.0
f
);
// smaller than min subnorm bf8 value to bf8 alternates between 0 and 2^-16
constexpr
auto
less_than_min_subnorm
=
0.00000762939453125
f
;
// 2^-17
ASSERT_NEAR
(
0.0
f
,
type_convert
<
float
>
(
f8_convert_sr
<
bf8_ocp_t
>
(
less_than_min_subnorm
)),
0.0000152587890625
f
);
// convert quiet NaN to bf8_ocp_t and check if it is quiet NaN
const
auto
bf8_nan
=
f8_convert_sr
<
bf8_ocp_t
>
(
std
::
numeric_limits
<
float
>::
quiet_NaN
());
ASSERT_TRUE
(
ck
::
fp8_impl
::
ocp_bf8_is_nan
(
bf8_nan
.
data
));
}
TEST
(
BF8OCP
,
ConvertFP16Nearest
)
{
// fix the tolerance value
constexpr
half_t
half_t_tol
=
1e-3
;
constexpr
half_t
half_t_zero
=
0.0
;
// convert 0 half_t to bfp8 and back, check if holds
ASSERT_NEAR
(
half_t_zero
,
type_convert
<
half_t
>
(
f8_convert_rne
<
bf8_ocp_t
>
(
half_t_zero
)),
half_t_zero
);
// convert minimal half_t to bf8 and back, check if holds
ASSERT_NEAR
(
ck
::
NumericLimits
<
half_t
>::
Min
(),
type_convert
<
half_t
>
(
f8_convert_rne
<
bf8_ocp_t
>
(
ck
::
NumericLimits
<
half_t
>::
Min
())),
half_t_tol
);
const
auto
max_bf8_t_half_t
=
type_convert
<
half_t
>
(
ck
::
NumericLimits
<
bf8_ocp_t
>::
Max
());
// convert maximal bf8_ocp_t to half_t and check if equal to bf8 max
ASSERT_NEAR
(
max_bf8_t_half_t
,
type_convert
<
half_t
>
(
f8_convert_rne
<
bf8_ocp_t
>
(
max_bf8_t_half_t
)),
half_t_zero
);
// convert maximal half_t to bf8 and back, check if clipped to bf8 max (saturation to finite)
ASSERT_NEAR
(
max_bf8_t_half_t
,
type_convert
<
half_t
>
(
f8_convert_rne
<
bf8_ocp_t
>
(
ck
::
NumericLimits
<
half_t
>::
Max
())),
half_t_zero
);
// convert half_t infinity to bf8_ocp_t and check if it is max value (saturation to finite)
ASSERT_EQ
(
ck
::
NumericLimits
<
bf8_ocp_t
>::
Max
(),
f8_convert_rne
<
bf8_ocp_t
>
(
type_convert
<
half_t
>
(
std
::
numeric_limits
<
float
>::
infinity
())));
// positive normal bf8 value to bf8 and back, check if holds
constexpr
half_t
pos_norm_bf8
{
0.0000762939
f
};
// 10*2^-17
ASSERT_NEAR
(
pos_norm_bf8
,
type_convert
<
half_t
>
(
f8_convert_rne
<
bf8_ocp_t
>
(
pos_norm_bf8
)),
half_t_tol
);
// negative smallest normal bf8 value to bf8 and back, check if holds
constexpr
half_t
neg_min_bf8
{
-
0.00006103515625
f
};
//-2^-14
ASSERT_NEAR
(
neg_min_bf8
,
type_convert
<
half_t
>
(
f8_convert_rne
<
bf8_ocp_t
>
(
neg_min_bf8
)),
half_t_zero
);
// positive subnorm bf8 value to bf8 and back, check if holds
constexpr
half_t
pos_subnorm_bf8
{
0.000030517578125
f
};
// 2^-15
ASSERT_NEAR
(
pos_subnorm_bf8
,
type_convert
<
half_t
>
(
f8_convert_rne
<
bf8_ocp_t
>
(
pos_subnorm_bf8
)),
half_t_zero
);
// min subnorm bf8 value to bf8 and back, check if holds
constexpr
half_t
min_subnorm_bf8
{
-
0.0000152587890625
f
};
//-2^-16
ASSERT_NEAR
(
min_subnorm_bf8
,
type_convert
<
half_t
>
(
f8_convert_rne
<
bf8_ocp_t
>
(
min_subnorm_bf8
)),
half_t_zero
);
// smaller than min subnorm bf8 value to bf8 must be zero
constexpr
half_t
less_than_min_subnorm
{
0.00000762939453125
f
};
// 2^-17
ASSERT_EQ
(
half_t_zero
,
type_convert
<
half_t
>
(
f8_convert_rne
<
bf8_ocp_t
>
(
less_than_min_subnorm
)));
// convert quiet NaN to bf8_ocp_t and check if it is quiet NaN
const
auto
bf8_nan
=
f8_convert_rne
<
bf8_ocp_t
>
(
ck
::
NumericLimits
<
half_t
>::
QuietNaN
());
ASSERT_TRUE
(
ck
::
fp8_impl
::
ocp_bf8_is_nan
(
bf8_nan
.
data
));
}
TEST
(
BF8OCP
,
ConvertFP16Stochastic
)
{
// fix the tolerance value
constexpr
half_t
half_t_tol
=
1e-3
;
constexpr
half_t
half_t_zero
=
0.0
;
constexpr
auto
min_subnorm_bf8
=
0.0000152587890625
f
;
// 2^-16
// convert 0 half_t to bfp8 and back, check if holds
ASSERT_NEAR
(
half_t_zero
,
type_convert
<
half_t
>
(
f8_convert_sr
<
bf8_ocp_t
>
(
half_t_zero
)),
half_t_zero
);
// convert minimal half_t (6.103515625e-05) to fp8 and back
ASSERT_NEAR
(
ck
::
NumericLimits
<
half_t
>::
Min
(),
type_convert
<
half_t
>
(
f8_convert_sr
<
bf8_ocp_t
>
(
ck
::
NumericLimits
<
half_t
>::
Min
())),
half_t_zero
);
const
auto
max_bf8_t_half_t
=
type_convert
<
half_t
>
(
ck
::
NumericLimits
<
bf8_ocp_t
>::
Max
());
// convert maximal bf8_ocp_t to half_t and check if equal to bf8 max
ASSERT_NEAR
(
max_bf8_t_half_t
,
type_convert
<
half_t
>
(
f8_convert_sr
<
bf8_ocp_t
>
(
max_bf8_t_half_t
)),
half_t_zero
);
// convert maximal half_t to bf8 and back, check if clipped to bf8 max (saturation to finite)
ASSERT_NEAR
(
max_bf8_t_half_t
,
type_convert
<
half_t
>
(
f8_convert_sr
<
bf8_ocp_t
>
(
ck
::
NumericLimits
<
half_t
>::
Max
())),
half_t_zero
);
// convert half_t infinity to bf8_ocp_t and check if it is max value (saturation to finite)
ASSERT_EQ
(
ck
::
NumericLimits
<
bf8_ocp_t
>::
Max
(),
f8_convert_sr
<
bf8_ocp_t
>
(
type_convert
<
half_t
>
(
std
::
numeric_limits
<
float
>::
infinity
())));
// positive normal bf8 value to bf8 and back, check if holds
constexpr
half_t
pos_norm_bf8
{
0.0000762939
f
};
// 10*2^-17
ASSERT_NEAR
(
pos_norm_bf8
,
type_convert
<
half_t
>
(
f8_convert_sr
<
bf8_ocp_t
>
(
pos_norm_bf8
)),
half_t_tol
);
// negative smallest normal bf8 value to bf8 and back, check if holds
constexpr
half_t
neg_min_bf8
{
-
0.00006103515625
f
};
//-2^-14
ASSERT_NEAR
(
neg_min_bf8
,
type_convert
<
half_t
>
(
f8_convert_sr
<
bf8_ocp_t
>
(
neg_min_bf8
)),
half_t_zero
);
// positive subnorm bf8 value to bf8 and back, check if holds
constexpr
half_t
pos_subnorm_bf8
{
0.000030517578125
f
};
// 2^-15
ASSERT_NEAR
(
pos_subnorm_bf8
,
type_convert
<
half_t
>
(
f8_convert_sr
<
bf8_ocp_t
>
(
pos_subnorm_bf8
)),
half_t_zero
);
// min subnorm bf8 value to bf8 and back, check if holds
ASSERT_NEAR
(
half_t
{
-
min_subnorm_bf8
},
type_convert
<
half_t
>
(
f8_convert_sr
<
bf8_ocp_t
>
(
half_t
{
-
min_subnorm_bf8
})),
half_t_zero
);
// smaller than min subnorm bf8 value to bf8 alternates between 0 and 2^-16
constexpr
half_t
less_than_min_subnorm
{
0.00000762939453125
f
};
// 2^-17
ASSERT_NEAR
(
half_t_zero
,
type_convert
<
half_t
>
(
f8_convert_sr
<
bf8_ocp_t
>
(
less_than_min_subnorm
)),
half_t
{
min_subnorm_bf8
});
// convert quiet NaN to bf8_ocp_t and check if it is quiet NaN
const
auto
bf8_nan
=
f8_convert_sr
<
bf8_ocp_t
>
(
ck
::
NumericLimits
<
half_t
>::
QuietNaN
());
ASSERT_TRUE
(
ck
::
fp8_impl
::
ocp_bf8_is_nan
(
bf8_nan
.
data
));
}
test/data_type/test_custom_type.cpp
View file @
175a17f8
...
...
@@ -872,3 +872,153 @@ TEST(Complex_half, TestAsTypeReshape)
test_vec
.
at
(
num_elem
*
i
+
1
));
});
}
#if CK_USE_OCP_FP8
TEST
(
FP8OCP
,
TestSize
)
{
static_assert
(
std
::
is_same_v
<
f8_t
,
ck
::
f8_ocp_t
>
,
"OCP FP8 is not enabled"
);
ASSERT_EQ
(
sizeof
(
f8_t
),
sizeof
(
ck
::
fp8_storage_t
));
ASSERT_EQ
(
sizeof
(
vector_type
<
f8_t
,
2
>
),
sizeof
(
vector_type
<
ck
::
fp8_storage_t
,
2
>
));
ASSERT_EQ
(
sizeof
(
vector_type
<
f8_t
,
4
>
),
sizeof
(
vector_type
<
ck
::
fp8_storage_t
,
4
>
));
ASSERT_EQ
(
sizeof
(
vector_type
<
f8_t
,
8
>
),
sizeof
(
vector_type
<
ck
::
fp8_storage_t
,
8
>
));
ASSERT_EQ
(
sizeof
(
vector_type
<
f8_t
,
16
>
),
sizeof
(
vector_type
<
ck
::
fp8_storage_t
,
16
>
));
ASSERT_EQ
(
sizeof
(
vector_type
<
f8_t
,
32
>
),
sizeof
(
vector_type
<
ck
::
fp8_storage_t
,
32
>
));
ASSERT_EQ
(
sizeof
(
vector_type
<
f8_t
,
64
>
),
sizeof
(
vector_type
<
ck
::
fp8_storage_t
,
64
>
));
}
TEST
(
FP8OCP
,
TestAsType
)
{
static_assert
(
std
::
is_same_v
<
f8_t
,
ck
::
f8_ocp_t
>
,
"OCP FP8 is not enabled"
);
// test size
std
::
array
<
float
,
8
>
test_vec
=
{
-
4
,
-
2
,
-
0.5
,
-
0.25
,
1.0
/
8.0
,
1
,
1.5
,
16
};
constexpr
int
size
=
test_vec
.
size
();
// reference vector
vector_type
<
f8_t
,
size
>
right_vec
;
// check default CTOR
ck
::
static_for
<
0
,
size
,
1
>
{}(
[
&
](
auto
i
)
{
ASSERT_EQ
(
right_vec
.
template
AsType
<
f8_t
>()(
Number
<
i
>
{}),
f8_t
{
0
});
});
// assign test values to the vector
ck
::
static_for
<
0
,
size
,
1
>
{}([
&
](
auto
i
)
{
right_vec
.
template
AsType
<
f8_t
>()(
Number
<
i
>
{})
=
ck
::
type_convert
<
f8_t
>
(
test_vec
.
at
(
i
));
});
// copy the vector
vector_type
<
f8_t
,
size
>
left_vec
{
right_vec
};
// check if values were copied correctly
ck
::
static_for
<
0
,
size
,
1
>
{}([
&
](
auto
i
)
{
ASSERT_EQ
(
left_vec
.
template
AsType
<
f8_t
>()(
Number
<
i
>
{}),
ck
::
type_convert
<
f8_t
>
(
test_vec
.
at
(
i
)));
});
}
TEST
(
FP8OCP
,
TestAsTypeReshape
)
{
static_assert
(
std
::
is_same_v
<
f8_t
,
ck
::
f8_ocp_t
>
,
"OCP FP8 is not enabled"
);
// test size
std
::
array
<
float
,
8
>
test_vec
=
{
-
8
,
-
0.5
,
-
0.25
,
1.0
/
8.0
,
1
/
256
,
1
,
1.5
,
16
};
constexpr
int
size
=
test_vec
.
size
();
// reference vector
vector_type
<
f8_t
,
size
>
right_vec
;
// check default CTOR
ck
::
static_for
<
0
,
size
,
1
>
{}(
[
&
](
auto
i
)
{
ASSERT_EQ
(
right_vec
.
template
AsType
<
f8_t
>()(
Number
<
i
>
{}),
f8_t
{
0
});
});
// assign test values to the vector
ck
::
static_for
<
0
,
size
,
1
>
{}([
&
](
auto
i
)
{
right_vec
.
template
AsType
<
f8_t
>()(
Number
<
i
>
{})
=
ck
::
type_convert
<
f8_t
>
(
test_vec
.
at
(
i
));
});
// copy the first half of a vector
vector_type
<
f8_t
,
size
/
2
>
left_vec
{
right_vec
.
template
AsType
<
vector_type
<
f8_t
,
size
/
2
>
::
type
>
()(
Number
<
0
>
{})};
// check if values were copied correctly
ck
::
static_for
<
0
,
size
/
2
,
1
>
{}([
&
](
auto
i
)
{
ASSERT_EQ
(
left_vec
.
template
AsType
<
f8_t
>()(
Number
<
i
>
{}),
ck
::
type_convert
<
f8_t
>
(
test_vec
.
at
(
i
)));
});
}
TEST
(
BF8OCP
,
TestSize
)
{
static_assert
(
std
::
is_same_v
<
bf8_t
,
ck
::
bf8_ocp_t
>
,
"OCP BF8 is not enabled"
);
ASSERT_EQ
(
sizeof
(
bf8_t
),
sizeof
(
ck
::
fp8_storage_t
));
ASSERT_EQ
(
sizeof
(
vector_type
<
bf8_t
,
2
>
),
sizeof
(
vector_type
<
ck
::
fp8_storage_t
,
2
>
));
ASSERT_EQ
(
sizeof
(
vector_type
<
bf8_t
,
4
>
),
sizeof
(
vector_type
<
ck
::
fp8_storage_t
,
4
>
));
ASSERT_EQ
(
sizeof
(
vector_type
<
bf8_t
,
8
>
),
sizeof
(
vector_type
<
ck
::
fp8_storage_t
,
8
>
));
ASSERT_EQ
(
sizeof
(
vector_type
<
bf8_t
,
16
>
),
sizeof
(
vector_type
<
ck
::
fp8_storage_t
,
16
>
));
ASSERT_EQ
(
sizeof
(
vector_type
<
bf8_t
,
32
>
),
sizeof
(
vector_type
<
ck
::
fp8_storage_t
,
32
>
));
ASSERT_EQ
(
sizeof
(
vector_type
<
bf8_t
,
64
>
),
sizeof
(
vector_type
<
ck
::
fp8_storage_t
,
64
>
));
}
TEST
(
BF8OCP
,
TestAsType
)
{
static_assert
(
std
::
is_same_v
<
bf8_t
,
ck
::
bf8_ocp_t
>
,
"OCP BF8 is not enabled"
);
// test size
std
::
array
<
float
,
8
>
test_vec
=
{
-
4
,
-
2
,
-
0.5
,
-
0.25
,
1.0
/
8.0
,
1
,
1.5
,
16
};
constexpr
int
size
=
test_vec
.
size
();
// reference vector
vector_type
<
bf8_t
,
size
>
right_vec
;
// check default CTOR
ck
::
static_for
<
0
,
size
,
1
>
{}(
[
&
](
auto
i
)
{
ASSERT_EQ
(
right_vec
.
template
AsType
<
bf8_t
>()(
Number
<
i
>
{}),
bf8_t
{
0
});
});
// assign test values to the vector
ck
::
static_for
<
0
,
size
,
1
>
{}([
&
](
auto
i
)
{
right_vec
.
template
AsType
<
bf8_t
>()(
Number
<
i
>
{})
=
ck
::
type_convert
<
bf8_t
>
(
test_vec
.
at
(
i
));
});
// copy the vector
vector_type
<
bf8_t
,
size
>
left_vec
{
right_vec
};
// check if values were copied correctly
ck
::
static_for
<
0
,
size
,
1
>
{}([
&
](
auto
i
)
{
ASSERT_EQ
(
left_vec
.
template
AsType
<
bf8_t
>()(
Number
<
i
>
{}),
ck
::
type_convert
<
bf8_t
>
(
test_vec
.
at
(
i
)));
});
}
TEST
(
BF8OCP
,
TestAsTypeReshape
)
{
static_assert
(
std
::
is_same_v
<
bf8_t
,
ck
::
bf8_ocp_t
>
,
"OCP BF8 is not enabled"
);
// test size
std
::
array
<
float
,
8
>
test_vec
=
{
-
8
,
-
0.5
,
-
0.25
,
1.0
/
8.0
,
1
/
256
,
1
,
1.5
,
16
};
constexpr
int
size
=
test_vec
.
size
();
// reference vector
vector_type
<
bf8_t
,
size
>
right_vec
;
// check default CTOR
ck
::
static_for
<
0
,
size
,
1
>
{}(
[
&
](
auto
i
)
{
ASSERT_EQ
(
right_vec
.
template
AsType
<
bf8_t
>()(
Number
<
i
>
{}),
bf8_t
{
0
});
});
// assign test values to the vector
ck
::
static_for
<
0
,
size
,
1
>
{}([
&
](
auto
i
)
{
right_vec
.
template
AsType
<
bf8_t
>()(
Number
<
i
>
{})
=
ck
::
type_convert
<
bf8_t
>
(
test_vec
.
at
(
i
));
});
// copy the first half of a vector
vector_type
<
bf8_t
,
size
/
2
>
left_vec
{
right_vec
.
template
AsType
<
vector_type
<
bf8_t
,
size
/
2
>
::
type
>
()(
Number
<
0
>
{})};
// check if values were copied correctly
ck
::
static_for
<
0
,
size
/
2
,
1
>
{}([
&
](
auto
i
)
{
ASSERT_EQ
(
left_vec
.
template
AsType
<
bf8_t
>()(
Number
<
i
>
{}),
ck
::
type_convert
<
bf8_t
>
(
test_vec
.
at
(
i
)));
});
}
#endif
test/data_type/test_fp8.cpp
→
test/data_type/test_fp8
_fnuz
.cpp
View file @
175a17f8
...
...
@@ -7,154 +7,171 @@
using
ck
::
f8_convert_rne
;
using
ck
::
f8_convert_sr
;
using
ck
::
f8_t
;
using
ck
::
f8_
fnuz_
t
;
using
ck
::
half_t
;
using
ck
::
type_convert
;
TEST
(
FP8
,
NumericLimits
)
TEST
(
FP8
FNUZ
,
NumericLimits
)
{
// constants given for negative zero nan mode
EXPECT_EQ
(
ck
::
NumericLimits
<
f8_t
>::
Min
(),
type_convert
<
f8_t
>
(
0x08
));
EXPECT_EQ
(
ck
::
NumericLimits
<
f8_t
>::
Max
(),
type_convert
<
f8_t
>
(
0x7F
));
EXPECT_EQ
(
ck
::
NumericLimits
<
f8_t
>::
Lowest
(),
type_convert
<
f8_t
>
(
0xFF
));
EXPECT_EQ
(
ck
::
NumericLimits
<
f8_t
>::
QuietNaN
(),
type_convert
<
f8_t
>
(
0x80
));
EXPECT_EQ
(
ck
::
NumericLimits
<
f8_
fnuz_
t
>::
Min
(),
type_convert
<
f8_
fnuz_
t
>
(
0x08
));
EXPECT_EQ
(
ck
::
NumericLimits
<
f8_
fnuz_
t
>::
Max
(),
type_convert
<
f8_
fnuz_
t
>
(
0x7F
));
EXPECT_EQ
(
ck
::
NumericLimits
<
f8_
fnuz_
t
>::
Lowest
(),
type_convert
<
f8_
fnuz_
t
>
(
0xFF
));
EXPECT_EQ
(
ck
::
NumericLimits
<
f8_
fnuz_
t
>::
QuietNaN
(),
type_convert
<
f8_
fnuz_
t
>
(
0x80
));
}
TEST
(
FP8
,
ConvertFP32Nearest
)
TEST
(
FP8
FNUZ
,
ConvertFP32Nearest
)
{
// fix the tolerance value
float
abs_tol
=
1e-6
;
// convert 0 float to fp8 and back, check if holds
ASSERT_NEAR
(
0.0
f
,
type_convert
<
float
>
(
f8_convert_rne
<
f8_t
>
(
0.0
f
)),
abs_tol
);
ASSERT_NEAR
(
0.0
f
,
type_convert
<
float
>
(
f8_convert_rne
<
f8_
fnuz_
t
>
(
0.0
f
)),
abs_tol
);
// don't run the next test on gfx11 devices
#ifndef CK_SKIP_FLAKY_F8_TEST
// convert minimal float to fp8 and back, check if holds
ASSERT_NEAR
(
std
::
numeric_limits
<
float
>::
min
(),
type_convert
<
float
>
(
f8_convert_rne
<
f8_t
>
(
std
::
numeric_limits
<
float
>::
min
())),
type_convert
<
float
>
(
f8_convert_rne
<
f8_
fnuz_
t
>
(
std
::
numeric_limits
<
float
>::
min
())),
abs_tol
);
#endif
// convert maximal f8_t to float and check if equal to 240.0
ASSERT_NEAR
(
240.0
f
,
type_convert
<
float
>
(
f8_convert_rne
<
f8_t
>
(
240.0
f
)),
abs_tol
);
// convert maximal float to fp8 and back, check if clipped to 240.0
ASSERT_NEAR
(
240.0
f
,
type_convert
<
float
>
(
f8_convert_rne
<
f8_t
>
(
std
::
numeric_limits
<
float
>::
max
())),
const
auto
max_f8_t_float
=
type_convert
<
float
>
(
ck
::
NumericLimits
<
f8_fnuz_t
>::
Max
());
// convert maximal f8_fnuz_t to float and check if equal to fp8 max
ASSERT_NEAR
(
max_f8_t_float
,
type_convert
<
float
>
(
f8_convert_rne
<
f8_fnuz_t
>
(
max_f8_t_float
)),
abs_tol
);
// XXX: FNUZ f8_convert_rne behavior is inconsistent.
// Clipping large values to fp8 max (saturation to finite) contradicts converting inf float to
// fp8 qNAN (no saturation).
// convert maximal float to fp8 and back, check if clipped to fp8 max
ASSERT_NEAR
(
max_f8_t_float
,
type_convert
<
float
>
(
f8_convert_rne
<
f8_fnuz_t
>
(
std
::
numeric_limits
<
float
>::
max
())),
abs_tol
);
// convert inf float to f8_t and check if it is qNan
ASSERT_NEAR
(
type_convert
<
f8_t
>
(
0x80
),
f8_convert_rne
<
f8_t
>
(
std
::
numeric_limits
<
float
>::
infinity
()),
// convert inf float to f8_
fnuz_
t and check if it is qNan
ASSERT_NEAR
(
ck
::
NumericLimits
<
f8_fnuz_t
>::
QuietNaN
(
),
f8_convert_rne
<
f8_
fnuz_
t
>
(
std
::
numeric_limits
<
float
>::
infinity
()),
abs_tol
);
// positive norm float value to fp8 and back, check if holds
float
pos_float
=
0.017578125
f
;
ASSERT_NEAR
(
pos_float
,
type_convert
<
float
>
(
f8_convert_rne
<
f8_t
>
(
pos_float
)),
abs_tol
);
ASSERT_NEAR
(
pos_float
,
type_convert
<
float
>
(
f8_convert_rne
<
f8_
fnuz_
t
>
(
pos_float
)),
abs_tol
);
// negative norm float value to fp8 and back, check if holds
float
neg_float
=
-
0.015625
f
;
ASSERT_NEAR
(
neg_float
,
type_convert
<
float
>
(
f8_convert_rne
<
f8_t
>
(
neg_float
)),
abs_tol
);
ASSERT_NEAR
(
neg_float
,
type_convert
<
float
>
(
f8_convert_rne
<
f8_
fnuz_
t
>
(
neg_float
)),
abs_tol
);
// positive subnorm float value to fp8 and back, check if holds
pos_float
=
0.00390625
f
;
ASSERT_NEAR
(
pos_float
,
type_convert
<
float
>
(
f8_convert_rne
<
f8_t
>
(
pos_float
)),
abs_tol
);
ASSERT_NEAR
(
pos_float
,
type_convert
<
float
>
(
f8_convert_rne
<
f8_
fnuz_
t
>
(
pos_float
)),
abs_tol
);
// negative subnorm float value to fp8 and back, check if holds
neg_float
=
-
0.001953125
f
;
ASSERT_NEAR
(
neg_float
,
type_convert
<
float
>
(
f8_convert_rne
<
f8_t
>
(
neg_float
)),
abs_tol
);
ASSERT_NEAR
(
neg_float
,
type_convert
<
float
>
(
f8_convert_rne
<
f8_
fnuz_
t
>
(
neg_float
)),
abs_tol
);
}
TEST
(
FP8
,
ConvertFP32Stochastic
)
TEST
(
FP8
FNUZ
,
ConvertFP32Stochastic
)
{
// fix the tolerance value
float
abs_tol
=
1e-6
;
// convert 0 float to fp8 and back, check if holds
ASSERT_NEAR
(
0.0
f
,
type_convert
<
float
>
(
f8_convert_sr
<
f8_t
>
(
0.0
f
)),
abs_tol
);
ASSERT_NEAR
(
0.0
f
,
type_convert
<
float
>
(
f8_convert_sr
<
f8_
fnuz_
t
>
(
0.0
f
)),
abs_tol
);
// convert minimal float to fp8 and back, check if holds
ASSERT_NEAR
(
std
::
numeric_limits
<
float
>::
min
(),
type_convert
<
float
>
(
f8_convert_sr
<
f8_t
>
(
std
::
numeric_limits
<
float
>::
min
())),
type_convert
<
float
>
(
f8_convert_sr
<
f8_
fnuz_
t
>
(
std
::
numeric_limits
<
float
>::
min
())),
abs_tol
);
// convert maximal f8_t to float and check if equal to 240.0
ASSERT_NEAR
(
240.0
f
,
type_convert
<
float
>
(
f8_convert_sr
<
f8_t
>
(
240.0
f
)),
abs_tol
);
// convert maximal float to fp8 and back, check if clipped to 240.0
ASSERT_NEAR
(
240.0
f
,
type_convert
<
float
>
(
f8_convert_sr
<
f8_t
>
(
std
::
numeric_limits
<
float
>::
max
())),
const
auto
max_f8_t_float
=
type_convert
<
float
>
(
ck
::
NumericLimits
<
f8_fnuz_t
>::
Max
());
// convert maximal f8_fnuz_t to float and check if equal to fp8 max
ASSERT_NEAR
(
max_f8_t_float
,
type_convert
<
float
>
(
f8_convert_sr
<
f8_fnuz_t
>
(
max_f8_t_float
)),
abs_tol
);
// convert maximal float to fp8 and back, check if clipped to fp8 max
ASSERT_NEAR
(
max_f8_t_float
,
type_convert
<
float
>
(
f8_convert_sr
<
f8_fnuz_t
>
(
std
::
numeric_limits
<
float
>::
max
())),
abs_tol
);
// convert inf float to f8_t and check if it is qNan
ASSERT_NEAR
(
type_convert
<
f8_t
>
(
0x80
),
f8_convert_sr
<
f8_t
>
(
std
::
numeric_limits
<
float
>::
infinity
()),
// convert inf float to f8_
fnuz_
t and check if it is qNan
ASSERT_NEAR
(
ck
::
NumericLimits
<
f8_fnuz_t
>::
QuietNaN
(
),
f8_convert_sr
<
f8_
fnuz_
t
>
(
std
::
numeric_limits
<
float
>::
infinity
()),
abs_tol
);
// positive norm float value to fp8 and back, check if holds
float
pos_float
=
0.017578125
f
;
ASSERT_NEAR
(
pos_float
,
type_convert
<
float
>
(
f8_convert_sr
<
f8_t
>
(
pos_float
)),
abs_tol
);
ASSERT_NEAR
(
pos_float
,
type_convert
<
float
>
(
f8_convert_sr
<
f8_
fnuz_
t
>
(
pos_float
)),
abs_tol
);
// negative norm float value to fp8 and back, check if holds
float
neg_float
=
-
0.015625
f
;
ASSERT_NEAR
(
neg_float
,
type_convert
<
float
>
(
f8_convert_sr
<
f8_t
>
(
neg_float
)),
abs_tol
);
ASSERT_NEAR
(
neg_float
,
type_convert
<
float
>
(
f8_convert_sr
<
f8_
fnuz_
t
>
(
neg_float
)),
abs_tol
);
// positive subnorm float value to fp8 and back, check if holds
pos_float
=
0.00390625
f
;
ASSERT_NEAR
(
pos_float
,
type_convert
<
float
>
(
f8_convert_sr
<
f8_t
>
(
pos_float
)),
abs_tol
);
ASSERT_NEAR
(
pos_float
,
type_convert
<
float
>
(
f8_convert_sr
<
f8_
fnuz_
t
>
(
pos_float
)),
abs_tol
);
// negative subnorm float value to fp8 and back, check if holds
neg_float
=
-
0.001953125
f
;
ASSERT_NEAR
(
neg_float
,
type_convert
<
float
>
(
f8_convert_sr
<
f8_t
>
(
neg_float
)),
abs_tol
);
ASSERT_NEAR
(
neg_float
,
type_convert
<
float
>
(
f8_convert_sr
<
f8_
fnuz_
t
>
(
neg_float
)),
abs_tol
);
}
TEST
(
FP8
,
ConvertFP16Nearest
)
TEST
(
FP8
FNUZ
,
ConvertFP16Nearest
)
{
// fix the tolerance value
float
abs_tol
=
1e-3
;
// convert 0 fp16 to fp8 and back, check if holds
ASSERT_NEAR
(
half_t
{
0.0
},
type_convert
<
half_t
>
(
f8_convert_rne
<
f8_t
>
(
half_t
{
0.0
})),
abs_tol
);
ASSERT_NEAR
(
half_t
{
0.0
},
type_convert
<
half_t
>
(
f8_convert_rne
<
f8_
fnuz_
t
>
(
half_t
{
0.0
})),
abs_tol
);
// convert minimal fp16 to fp8 and back, check if holds
ASSERT_NEAR
(
ck
::
NumericLimits
<
half_t
>::
Min
(),
type_convert
<
half_t
>
(
f8_convert_rne
<
f8_t
>
(
ck
::
NumericLimits
<
half_t
>::
Min
())),
type_convert
<
half_t
>
(
f8_convert_rne
<
f8_
fnuz_
t
>
(
ck
::
NumericLimits
<
half_t
>::
Min
())),
abs_tol
);
// convert maximal f8_t to fp16 and check if equal to 240.0
ASSERT_NEAR
(
half_t
{
240.0
},
type_convert
<
half_t
>
(
f8_convert_rne
<
f8_t
>
(
half_t
{
240.0
})),
abs_tol
);
// convert maximal fp16 to fp8 and back, check if clipped to 240.0
ASSERT_NEAR
(
half_t
{
240.0
},
type_convert
<
half_t
>
(
f8_convert_rne
<
f8_t
>
(
ck
::
NumericLimits
<
half_t
>::
Max
())),
const
auto
max_f8_t_half
=
type_convert
<
half_t
>
(
ck
::
NumericLimits
<
f8_fnuz_t
>::
Max
());
// convert maximal f8_fnuz_t to fp16 and check if equal to fp8 max
ASSERT_NEAR
(
max_f8_t_half
,
type_convert
<
half_t
>
(
f8_convert_rne
<
f8_fnuz_t
>
(
max_f8_t_half
)),
abs_tol
);
// convert maximal fp16 to fp8 and back, check if clipped to fp8 max
ASSERT_NEAR
(
max_f8_t_half
,
type_convert
<
half_t
>
(
f8_convert_rne
<
f8_fnuz_t
>
(
ck
::
NumericLimits
<
half_t
>::
Max
())),
abs_tol
);
// convert QuietNaN fp16 to f8_t and check if it is QuietNaN
ASSERT_NEAR
(
type_convert
<
f8_t
>
(
0x80
),
f8_convert_rne
<
f8_t
>
(
ck
::
NumericLimits
<
half_t
>::
QuietNaN
()),
// convert QuietNaN fp16 to f8_
fnuz_
t and check if it is QuietNaN
ASSERT_NEAR
(
ck
::
NumericLimits
<
f8_fnuz_t
>::
QuietNaN
(
),
f8_convert_rne
<
f8_
fnuz_
t
>
(
ck
::
NumericLimits
<
half_t
>::
QuietNaN
()),
abs_tol
);
// positive norm fp16 value to fp8 and back, check if holds
half_t
pos_half
=
half_t
{
0.017578125
};
ASSERT_NEAR
(
pos_half
,
type_convert
<
half_t
>
(
f8_convert_rne
<
f8_t
>
(
pos_half
)),
abs_tol
);
ASSERT_NEAR
(
pos_half
,
type_convert
<
half_t
>
(
f8_convert_rne
<
f8_
fnuz_
t
>
(
pos_half
)),
abs_tol
);
// negative norm fp16 value to fp8 and back, check if holds
half_t
neg_half
=
half_t
{
-
0.015625
};
ASSERT_NEAR
(
neg_half
,
type_convert
<
half_t
>
(
f8_convert_rne
<
f8_t
>
(
neg_half
)),
abs_tol
);
ASSERT_NEAR
(
neg_half
,
type_convert
<
half_t
>
(
f8_convert_rne
<
f8_
fnuz_
t
>
(
neg_half
)),
abs_tol
);
// positive subnorm fp16 value to fp8 and back, check if holds
pos_half
=
half_t
{
0.00390625
};
ASSERT_NEAR
(
pos_half
,
type_convert
<
half_t
>
(
f8_convert_rne
<
f8_t
>
(
pos_half
)),
abs_tol
);
ASSERT_NEAR
(
pos_half
,
type_convert
<
half_t
>
(
f8_convert_rne
<
f8_
fnuz_
t
>
(
pos_half
)),
abs_tol
);
// negative subnorm fp16 value to fp8 and back, check if holds
neg_half
=
half_t
{
-
0.001953125
};
ASSERT_NEAR
(
neg_half
,
type_convert
<
half_t
>
(
f8_convert_rne
<
f8_t
>
(
neg_half
)),
abs_tol
);
ASSERT_NEAR
(
neg_half
,
type_convert
<
half_t
>
(
f8_convert_rne
<
f8_
fnuz_
t
>
(
neg_half
)),
abs_tol
);
}
TEST
(
FP8
,
ConvertFP16Stochastic
)
TEST
(
FP8
FNUZ
,
ConvertFP16Stochastic
)
{
// fix the tolerance value
float
abs_tol
=
1e-3
;
// convert 0 fp16 to fp8 and back, check if holds
ASSERT_NEAR
(
half_t
{
0.0
},
type_convert
<
half_t
>
(
f8_convert_sr
<
f8_t
>
(
half_t
{
0.0
})),
abs_tol
);
ASSERT_NEAR
(
half_t
{
0.0
},
type_convert
<
half_t
>
(
f8_convert_sr
<
f8_
fnuz_
t
>
(
half_t
{
0.0
})),
abs_tol
);
// convert minimal fp16 to fp8 and back, check if holds
ASSERT_NEAR
(
ck
::
NumericLimits
<
half_t
>::
Min
(),
type_convert
<
half_t
>
(
f8_convert_sr
<
f8_t
>
(
ck
::
NumericLimits
<
half_t
>::
Min
())),
type_convert
<
half_t
>
(
f8_convert_sr
<
f8_
fnuz_
t
>
(
ck
::
NumericLimits
<
half_t
>::
Min
())),
abs_tol
);
// convert maximal f8_t to fp16 and check if equal to 240.0
ASSERT_NEAR
(
half_t
{
240.0
},
type_convert
<
half_t
>
(
f8_convert_sr
<
f8_t
>
(
half_t
{
240.0
})),
abs_tol
);
// convert maximal fp16 to fp8 and back, check if clipped to 240.0
ASSERT_NEAR
(
half_t
{
240.0
},
type_convert
<
half_t
>
(
f8_convert_sr
<
f8_t
>
(
ck
::
NumericLimits
<
half_t
>::
Max
())),
const
auto
max_f8_t_half
=
type_convert
<
half_t
>
(
ck
::
NumericLimits
<
f8_fnuz_t
>::
Max
());
// convert maximal f8_fnuz_t to fp16 and check if equal to fp8 max
ASSERT_NEAR
(
max_f8_t_half
,
type_convert
<
half_t
>
(
f8_convert_sr
<
f8_fnuz_t
>
(
max_f8_t_half
)),
abs_tol
);
// convert maximal fp16 to fp8 and back, check if clipped to fp8 max
ASSERT_NEAR
(
max_f8_t_half
,
type_convert
<
half_t
>
(
f8_convert_sr
<
f8_fnuz_t
>
(
ck
::
NumericLimits
<
half_t
>::
Max
())),
abs_tol
);
// convert QuietNaN fp16 to f8_t and check if it is QuietNaN
ASSERT_NEAR
(
type_convert
<
f8_t
>
(
0x80
),
f8_convert_sr
<
f8_t
>
(
ck
::
NumericLimits
<
half_t
>::
QuietNaN
()),
// convert QuietNaN fp16 to f8_
fnuz_
t and check if it is QuietNaN
ASSERT_NEAR
(
ck
::
NumericLimits
<
f8_fnuz_t
>::
QuietNaN
(
),
f8_convert_sr
<
f8_
fnuz_
t
>
(
ck
::
NumericLimits
<
half_t
>::
QuietNaN
()),
abs_tol
);
// positive norm fp16 value to fp8 and back, check if holds
half_t
pos_half
=
half_t
{
0.017578125
};
ASSERT_NEAR
(
pos_half
,
type_convert
<
half_t
>
(
f8_convert_sr
<
f8_t
>
(
pos_half
)),
abs_tol
);
ASSERT_NEAR
(
pos_half
,
type_convert
<
half_t
>
(
f8_convert_sr
<
f8_
fnuz_
t
>
(
pos_half
)),
abs_tol
);
// negative norm fp16 value to fp8 and back, check if holds
half_t
neg_half
=
half_t
{
-
0.015625
};
ASSERT_NEAR
(
neg_half
,
type_convert
<
half_t
>
(
f8_convert_sr
<
f8_t
>
(
neg_half
)),
abs_tol
);
ASSERT_NEAR
(
neg_half
,
type_convert
<
half_t
>
(
f8_convert_sr
<
f8_
fnuz_
t
>
(
neg_half
)),
abs_tol
);
// positive subnorm fp16 value to fp8 and back, check if holds
pos_half
=
half_t
{
0.00390625
};
ASSERT_NEAR
(
pos_half
,
type_convert
<
half_t
>
(
f8_convert_sr
<
f8_t
>
(
pos_half
)),
abs_tol
);
ASSERT_NEAR
(
pos_half
,
type_convert
<
half_t
>
(
f8_convert_sr
<
f8_
fnuz_
t
>
(
pos_half
)),
abs_tol
);
// negative subnorm fp16 value to fp8 and back, check if holds
neg_half
=
half_t
{
-
0.001953125
};
ASSERT_NEAR
(
neg_half
,
type_convert
<
half_t
>
(
f8_convert_sr
<
f8_t
>
(
neg_half
)),
abs_tol
);
ASSERT_NEAR
(
neg_half
,
type_convert
<
half_t
>
(
f8_convert_sr
<
f8_
fnuz_
t
>
(
neg_half
)),
abs_tol
);
}
test/data_type/test_fp8_ocp.cpp
0 → 100644
View file @
175a17f8
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "gtest/gtest.h"
#include "ck/utility/data_type.hpp"
#include "ck/utility/type_convert.hpp"
using
ck
::
f8_convert_rne
;
using
ck
::
f8_convert_sr
;
using
ck
::
f8_ocp_t
;
using
ck
::
half_t
;
using
ck
::
type_convert
;
TEST
(
FP8OCP
,
NumericLimits
)
{
// constants given for OCP FP8
EXPECT_EQ
(
ck
::
NumericLimits
<
f8_ocp_t
>::
Min
(),
type_convert
<
f8_ocp_t
>
(
0x08
));
// 0b00001000 = 2^-6
EXPECT_EQ
(
ck
::
NumericLimits
<
f8_ocp_t
>::
Max
(),
type_convert
<
f8_ocp_t
>
(
0x7E
));
// 0b01111110 = 448
EXPECT_EQ
(
ck
::
NumericLimits
<
f8_ocp_t
>::
Lowest
(),
type_convert
<
f8_ocp_t
>
(
0xFE
));
// 0b11111110 = -448
EXPECT_EQ
(
ck
::
NumericLimits
<
f8_ocp_t
>::
QuietNaN
().
data
,
type_convert
<
f8_ocp_t
>
(
0x7F
).
data
);
// 0b01111111
EXPECT_FALSE
(
ck
::
NumericLimits
<
f8_ocp_t
>::
QuietNaN
()
==
ck
::
NumericLimits
<
f8_ocp_t
>::
QuietNaN
());
}
TEST
(
FP8OCP
,
ConvertFP32Nearest
)
{
// fix the tolerance value
float
abs_tol
=
1e-6
;
// convert 0 float to fp8 and back, check if holds
ASSERT_NEAR
(
0.0
f
,
type_convert
<
float
>
(
f8_convert_rne
<
f8_ocp_t
>
(
0.0
f
)),
0.0
f
);
// convert minimal float to fp8 and back, check if holds
ASSERT_NEAR
(
std
::
numeric_limits
<
float
>::
min
(),
type_convert
<
float
>
(
f8_convert_rne
<
f8_ocp_t
>
(
std
::
numeric_limits
<
float
>::
min
())),
abs_tol
);
const
auto
max_f8_t_float
=
type_convert
<
float
>
(
ck
::
NumericLimits
<
f8_ocp_t
>::
Max
());
// convert maximal f8_ocp_t to float and check if equal to fp8 max
ASSERT_NEAR
(
max_f8_t_float
,
type_convert
<
float
>
(
f8_convert_rne
<
f8_ocp_t
>
(
max_f8_t_float
)),
0.0
f
);
// convert maximal float to fp8 and back, check if clipped to fp8 max (saturation to finite)
ASSERT_NEAR
(
max_f8_t_float
,
type_convert
<
float
>
(
f8_convert_rne
<
f8_ocp_t
>
(
std
::
numeric_limits
<
float
>::
max
())),
0.0
f
);
// convert float infinity to f8_ocp_t and check if it is max value (saturation to finite)
ASSERT_EQ
(
ck
::
NumericLimits
<
f8_ocp_t
>::
Max
(),
f8_convert_rne
<
f8_ocp_t
>
(
std
::
numeric_limits
<
float
>::
infinity
()));
// positive norm float value to fp8 and back, check if holds
float
pos_float
=
0.017578125
f
;
ASSERT_NEAR
(
pos_float
,
type_convert
<
float
>
(
f8_convert_rne
<
f8_ocp_t
>
(
pos_float
)),
abs_tol
);
// smallest normal fp8 value to fp8 and back, check if holds
float
neg_float
=
-
0.015625
f
;
//-2^-6
ASSERT_NEAR
(
neg_float
,
type_convert
<
float
>
(
f8_convert_rne
<
f8_ocp_t
>
(
neg_float
)),
0.0
f
);
// positive subnorm float value to fp8 and back, check if holds
pos_float
=
0.00390625
f
;
ASSERT_NEAR
(
pos_float
,
type_convert
<
float
>
(
f8_convert_rne
<
f8_ocp_t
>
(
pos_float
)),
abs_tol
);
// min subnorm fp8 value to fp8 and back, check if holds
neg_float
=
-
0.001953125
f
;
//-2^-9
ASSERT_NEAR
(
neg_float
,
type_convert
<
float
>
(
f8_convert_rne
<
f8_ocp_t
>
(
neg_float
)),
0.0
f
);
// smaller than min subnorm fp8 value to fp8 must be zero
auto
less_than_min_subnorm
=
0.0009765625
f
;
// 2^-10
ASSERT_EQ
(
0.0
f
,
type_convert
<
float
>
(
f8_convert_rne
<
f8_ocp_t
>
(
less_than_min_subnorm
)));
// convert quiet NaN to f8_ocp_t and check if it is quiet NaN
auto
f8_nan
=
f8_convert_rne
<
f8_ocp_t
>
(
std
::
numeric_limits
<
float
>::
quiet_NaN
());
ASSERT_TRUE
((
f8_nan
.
data
&
0x7f
)
==
0x7f
);
}
TEST
(
FP8OCP
,
ConvertFP32Stochastic
)
{
// fix the tolerance value
float
abs_tol
=
1e-6
;
// convert 0 float to fp8 and back, check if holds
ASSERT_NEAR
(
0.0
f
,
type_convert
<
float
>
(
f8_convert_sr
<
f8_ocp_t
>
(
0.0
f
)),
0.0
f
);
// convert minimal float to fp8 and back, check if holds
ASSERT_NEAR
(
std
::
numeric_limits
<
float
>::
min
(),
type_convert
<
float
>
(
f8_convert_sr
<
f8_ocp_t
>
(
std
::
numeric_limits
<
float
>::
min
())),
abs_tol
);
const
auto
max_f8_t_float
=
type_convert
<
float
>
(
ck
::
NumericLimits
<
f8_ocp_t
>::
Max
());
// convert maximal f8_ocp_t to float and check if equal to fp8 max
ASSERT_NEAR
(
max_f8_t_float
,
type_convert
<
float
>
(
f8_convert_sr
<
f8_ocp_t
>
(
max_f8_t_float
)),
0.0
f
);
// convert maximal float to fp8 and back, check if clipped to fp8 max (saturation to finite)
ASSERT_NEAR
(
max_f8_t_float
,
type_convert
<
float
>
(
f8_convert_sr
<
f8_ocp_t
>
(
std
::
numeric_limits
<
float
>::
max
())),
0.0
f
);
// convert float infinity to f8_ocp_t and check if it is max value (saturation to finite)
ASSERT_EQ
(
ck
::
NumericLimits
<
f8_ocp_t
>::
Max
(),
f8_convert_sr
<
f8_ocp_t
>
(
std
::
numeric_limits
<
float
>::
infinity
()));
// positive norm float value to fp8 and back, check if holds
float
pos_float
=
0.017578125
f
;
ASSERT_NEAR
(
pos_float
,
type_convert
<
float
>
(
f8_convert_sr
<
f8_ocp_t
>
(
pos_float
)),
abs_tol
);
// smallest normal fp8 value to fp8 and back, check if holds
float
neg_float
=
-
0.015625
f
;
//-2^-6
ASSERT_NEAR
(
neg_float
,
type_convert
<
float
>
(
f8_convert_sr
<
f8_ocp_t
>
(
neg_float
)),
0.0
f
);
// positive subnorm float value to fp8 and back, check if holds
pos_float
=
0.00390625
f
;
ASSERT_NEAR
(
pos_float
,
type_convert
<
float
>
(
f8_convert_sr
<
f8_ocp_t
>
(
pos_float
)),
abs_tol
);
// min subnorm fp8 value to fp8 and back, check if holds
constexpr
auto
min_subnorm_fp8
=
-
0.001953125
f
;
//-2^-9
ASSERT_NEAR
(
min_subnorm_fp8
,
type_convert
<
float
>
(
f8_convert_sr
<
f8_ocp_t
>
(
min_subnorm_fp8
)),
0.0
f
);
// smaller than min subnorm fp8 value to fp8 alternates between 0 and 2^-9
auto
less_than_min_subnorm
=
0.0009765625
f
;
// 2^-10
ASSERT_NEAR
(
0.0
f
,
type_convert
<
float
>
(
f8_convert_sr
<
f8_ocp_t
>
(
less_than_min_subnorm
)),
0.001953125
f
);
// convert quiet NaN to f8_ocp_t and check if it is quiet NaN
auto
f8_nan
=
f8_convert_sr
<
f8_ocp_t
>
(
std
::
numeric_limits
<
float
>::
quiet_NaN
());
ASSERT_TRUE
((
f8_nan
.
data
&
0x7f
)
==
0x7f
);
}
TEST
(
FP8OCP
,
ConvertFP16Nearest
)
{
// fix the tolerance value
constexpr
half_t
half_t_tol
=
1e-3
;
constexpr
half_t
half_t_zero
=
0.0
;
// convert 0 half_t to fp8 and back, check if holds
ASSERT_NEAR
(
half_t_zero
,
type_convert
<
half_t
>
(
f8_convert_rne
<
f8_ocp_t
>
(
half_t_zero
)),
half_t_zero
);
// convert minimal half_t to fp8 and back, check if holds
ASSERT_NEAR
(
ck
::
NumericLimits
<
half_t
>::
Min
(),
type_convert
<
half_t
>
(
f8_convert_rne
<
f8_ocp_t
>
(
ck
::
NumericLimits
<
half_t
>::
Min
())),
half_t_tol
);
const
auto
max_f8_t_half_t
=
type_convert
<
half_t
>
(
ck
::
NumericLimits
<
f8_ocp_t
>::
Max
());
// convert maximal f8_ocp_t to half_t and check if equal to fp8 max
ASSERT_NEAR
(
max_f8_t_half_t
,
type_convert
<
half_t
>
(
f8_convert_rne
<
f8_ocp_t
>
(
max_f8_t_half_t
)),
half_t_zero
);
// convert maximal half_t to fp8 and back, check if clipped to fp8 max (saturation to finite)
ASSERT_NEAR
(
max_f8_t_half_t
,
type_convert
<
half_t
>
(
f8_convert_rne
<
f8_ocp_t
>
(
ck
::
NumericLimits
<
half_t
>::
Max
())),
half_t_zero
);
// convert half_t infinity to f8_ocp_t and check if it is max value (saturation to finite)
ASSERT_EQ
(
ck
::
NumericLimits
<
f8_ocp_t
>::
Max
(),
f8_convert_rne
<
f8_ocp_t
>
(
type_convert
<
half_t
>
(
std
::
numeric_limits
<
float
>::
infinity
())));
// positive norm half_t value to fp8 and back, check if holds
half_t
pos_half_t
{
0.017578125
f
};
ASSERT_NEAR
(
pos_half_t
,
type_convert
<
half_t
>
(
f8_convert_rne
<
f8_ocp_t
>
(
pos_half_t
)),
half_t_tol
);
// smallest normal fp8 value to fp8 and back, check if holds
half_t
neg_half_t
{
-
0.015625
f
};
//-2^-6
ASSERT_NEAR
(
neg_half_t
,
type_convert
<
half_t
>
(
f8_convert_rne
<
f8_ocp_t
>
(
neg_half_t
)),
half_t_zero
);
// positive subnorm half_t value to fp8 and back, check if holds
pos_half_t
=
half_t
{
0.00390625
f
};
ASSERT_NEAR
(
pos_half_t
,
type_convert
<
half_t
>
(
f8_convert_rne
<
f8_ocp_t
>
(
pos_half_t
)),
half_t_tol
);
// min subnorm fp8 value to fp8 and back, check if holds
neg_half_t
=
half_t
{
-
0.001953125
f
};
//-2^-9
ASSERT_NEAR
(
neg_half_t
,
type_convert
<
half_t
>
(
f8_convert_rne
<
f8_ocp_t
>
(
neg_half_t
)),
half_t_zero
);
// smaller than min subnorm fp8 value to fp8 must be zero
auto
less_than_min_subnorm
=
half_t
{
0.0009765625
f
};
// 2^-10
ASSERT_EQ
(
half_t_zero
,
type_convert
<
half_t
>
(
f8_convert_rne
<
f8_ocp_t
>
(
less_than_min_subnorm
)));
// convert quiet NaN to f8_ocp_t and check if it is quiet NaN
auto
f8_nan
=
f8_convert_rne
<
f8_ocp_t
>
(
ck
::
NumericLimits
<
half_t
>::
QuietNaN
());
ASSERT_TRUE
(
ck
::
fp8_impl
::
ocp_f8_is_nan
(
f8_nan
.
data
));
}
TEST
(
FP8OCP
,
ConvertFP16Stochastic
)
{
// fix the tolerance value
constexpr
half_t
half_t_tol
=
1e-3
;
constexpr
half_t
half_t_zero
=
0.0
;
constexpr
auto
min_subnorm_fp8
=
0.001953125
f
;
// 2^-9
// convert 0 half_t to fp8 and back, check if holds
ASSERT_NEAR
(
half_t_zero
,
type_convert
<
half_t
>
(
f8_convert_sr
<
f8_ocp_t
>
(
half_t_zero
)),
half_t_zero
);
// convert minimal half_t (6.103515625e-05) to fp8 and back
// alternates between 0 and 2^-9 (0.001953125)
ASSERT_NEAR
(
ck
::
NumericLimits
<
half_t
>::
Min
(),
type_convert
<
half_t
>
(
f8_convert_sr
<
f8_ocp_t
>
(
ck
::
NumericLimits
<
half_t
>::
Min
())),
type_convert
<
half_t
>
(
min_subnorm_fp8
));
const
auto
max_f8_t_half_t
=
type_convert
<
half_t
>
(
ck
::
NumericLimits
<
f8_ocp_t
>::
Max
());
// convert maximal f8_ocp_t to half_t and check if equal to fp8 max
ASSERT_NEAR
(
max_f8_t_half_t
,
type_convert
<
half_t
>
(
f8_convert_sr
<
f8_ocp_t
>
(
max_f8_t_half_t
)),
half_t_zero
);
// convert maximal half_t to fp8 and back, check if clipped to fp8 max (saturation to finite)
ASSERT_NEAR
(
max_f8_t_half_t
,
type_convert
<
half_t
>
(
f8_convert_sr
<
f8_ocp_t
>
(
ck
::
NumericLimits
<
half_t
>::
Max
())),
half_t_zero
);
// convert half_t infinity to f8_ocp_t and check if it is max value (saturation to finite)
ASSERT_EQ
(
ck
::
NumericLimits
<
f8_ocp_t
>::
Max
(),
f8_convert_sr
<
f8_ocp_t
>
(
type_convert
<
half_t
>
(
std
::
numeric_limits
<
float
>::
infinity
())));
// positive norm half_t value to fp8 and back, check if holds
half_t
pos_half_t
{
0.017578125
f
};
ASSERT_NEAR
(
pos_half_t
,
type_convert
<
half_t
>
(
f8_convert_sr
<
f8_ocp_t
>
(
pos_half_t
)),
half_t_tol
);
// smallest normal fp8 value to fp8 and back, check if holds
half_t
neg_half_t
{
-
0.015625
f
};
//-2^-6
ASSERT_NEAR
(
neg_half_t
,
type_convert
<
half_t
>
(
f8_convert_sr
<
f8_ocp_t
>
(
neg_half_t
)),
half_t_zero
);
// positive subnorm half_t value to fp8 and back, check if holds
pos_half_t
=
half_t
{
0.00390625
f
};
ASSERT_NEAR
(
pos_half_t
,
type_convert
<
half_t
>
(
f8_convert_sr
<
f8_ocp_t
>
(
pos_half_t
)),
half_t_tol
);
// min subnorm fp8 value to fp8 and back, check if holds
neg_half_t
=
half_t
{
-
min_subnorm_fp8
};
//-2^-9
ASSERT_NEAR
(
neg_half_t
,
type_convert
<
half_t
>
(
f8_convert_sr
<
f8_ocp_t
>
(
neg_half_t
)),
half_t_zero
);
// smaller than min subnorm fp8 value to fp8 alternates between 0 and 2^-9
auto
less_than_min_subnorm
=
half_t
{
0.0009765625
f
};
// 2^-10
ASSERT_NEAR
(
type_convert
<
float
>
(
half_t_zero
),
type_convert
<
float
>
(
type_convert
<
half_t
>
(
f8_convert_sr
<
f8_ocp_t
>
(
less_than_min_subnorm
))),
min_subnorm_fp8
);
// convert quiet NaN to f8_ocp_t and check if it is quiet NaN
auto
f8_nan
=
f8_convert_sr
<
f8_ocp_t
>
(
ck
::
NumericLimits
<
half_t
>::
QuietNaN
());
ASSERT_TRUE
(
ck
::
fp8_impl
::
ocp_f8_is_nan
(
f8_nan
.
data
));
}
test/gemm_universal/test_gemm_universal_xdl.cpp
View file @
175a17f8
...
...
@@ -56,7 +56,7 @@ class TestGemmUniversal_KM_NK
using
KernelTypes_MK_KN
=
::
testing
::
Types
<
// ADataType, BDataType, ComputeDataType, CDataType
std
::
tuple
<
F16
,
F16
,
F16
,
F16
>
,
#if defined(CK_ENABLE_FP8) && defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH)
#if defined(CK_ENABLE_FP8) &&
(
defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH)
|| defined(CK_USE_GFX94))
std
::
tuple
<
F16
,
F8
,
F16
,
F16
>
,
std
::
tuple
<
F8
,
F16
,
F16
,
F16
>
,
std
::
tuple
<
F8
,
F8
,
F8
,
BF16
>
,
...
...
@@ -66,7 +66,7 @@ using KernelTypes_MK_KN = ::testing::Types<
using
KernelTypes_MK_NK
=
::
testing
::
Types
<
// ADataType, BDataType, ComputeDataType, CDataType
std
::
tuple
<
F16
,
F16
,
F16
,
F16
>
,
#if defined(CK_ENABLE_FP8) && defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH)
#if defined(CK_ENABLE_FP8) &&
(
defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH)
|| defined(CK_USE_GFX94))
std
::
tuple
<
F16
,
F8
,
F16
,
F16
>
,
std
::
tuple
<
F8
,
F16
,
F16
,
F16
>
,
std
::
tuple
<
F8
,
F8
,
F8
,
BF16
>
,
...
...
Prev
1
…
3
4
5
6
7
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment