Unverified commit 08d5c02c, authored by Illia Silin, committed by GitHub

OCP FP8 support for gfx12. (#1710)

* (2/5) bilinear gemm pass, perf bug: skip a lds has lower performance than skip b lds

* (3/5) batched gemm pass, perf bug: skip a lds has lower performance than skip b lds

* (4/5) grouped conv pass

* (5/5) attention pass, todo: debug lds perf bug

* AIT Attention API refactor (#8)

* sanity pass

* sanity pass 2

* confirm significant performance regression.

* turn on all instances

* turn off instance format

* Fix bug & tuning & format

* DML meta, self_attn+cross_attn

* sanity pass

* remove useless flag

* update tile and problem size used in AIT attention

* bug fix in grouped conv supporting check

* deprecate inline asm wmma

* Bug fix: double lds skip

* clang-format

* Fix errors in
1. example, fmha
2. gridwise pipeline
3. deviceop, fmha, change some containers from vector to array

* part2 of previous commit

* clang format

* API fix of gridwisegemmpipeline

* separate array base and vector base attention tensor transformation

* fix gemm

* clang format

* add gemm fp16 instances

* Temp save

* fpAintB kernel compile pass

* Sanity pass.

* Temp save

* debug code enabled

* Fp16AInt8B_GEMM sanity

* MQA implementation

* GQA-4 example

* tempsave

* Compile pass

* New implementation of fp16Aint8B Gemm, achieve similar math throughput to native fp16 Gemm

* Bump rocm-docs-core from 0.24.0 to 0.29.0 in /docs/sphinx

Bumps [rocm-docs-core](https://github.com/RadeonOpenCompute/rocm-docs-core) from 0.24.0 to 0.29.0.
- [Release notes](https://github.com/RadeonOpenCompute/rocm-docs-core/releases)
- [Changelog](https://github.com/RadeonOpenCompute/rocm-docs-core/blob/develop/CHANGELOG.md)
- [Commits](https://github.com/RadeonOpenCompute/rocm-docs-core/compare/v0.24.0...v0.29.0)

---
updated-dependencies:
- dependency-name: rocm-docs-core
  dependency-type: direct:production
  update-type: version-update:semver-minor
...
Signed-off-by: dependabot[bot] <support@github.com>

* initial enablement of gfx950

* fix clang format

* disable examples 31 and 41 int8 on gfx950

* initial navi4x enablement

* remove extra endif

* enabled dl_gemm

* update s_barrier and s_waitcnt for gfx12

* fix the gfx12 assembly syntax

* fixed block_sync_lds

* add support for more dl kernels on navi4

* add wmma

* format

* Todo: fix gemm_bilinear_wmma instances compilation bug

* Solve a bug when K1=16

* remove unnecessary changes

* Remove tensor layout limitation to LDS usage in tensor contraction

* fixed block_sync_lds

* merge navi3_ref

* update self-attention and cross-attention

* fix a typo of name

* fixed layout

* debugging

* Add arch limiter for fp8 gemm

* fixed wmma

* enable fp8 gemm_xdl for all gfx9 targets

* temporarily disable gemm_xdl_fp16_fp8 on MI100/200

* fix the cmake logic for gemm_xdl_fp16_fp8

* fixed c_output

* re-enable the gemm_xdl_fp16_fp8 on MI100/200

* fixed gfx12

* fixed

* fixed

* separate gfx12 blockwise_gemm

* fixed

* enable fwd conv on navi4x

* enable gridwise

* enabled gemm

* fixed merge

* remove empty example fold

* fixed conflicts

* some small changes

* Update cmake-ck-dev.sh

* Update cmake-ck-dev.sh

* enabled other types

* fixed register loads

* test fa

* enable gfx12

* clean up

* enable some instances on gfx12

* add gfx1201 macro in amd_wmma header

* fix clang format

* enable batched_gemm_softmax_gemm_perm_wmma for gfx12

* disable instances with blocksize=256 in attention examples

* debugging

* debug

* fixed lds_enabled

* debugging

* Fix and add limit to skiplds feature

* Enable skipLds feature and fix compilation bugs

* add ck_tile definitions for gfx12

* fix clang format and test/wmma_op

* update instances cmake for gfx12

* disable the test_wmma_op on gfx12

* fix the builds for gfx950

* add gfx12 and gfx950 to default target list

* clean-up cmake file

* Initial introduction of OFP8 data types.

* Renamed FP8 and BF8 tests into FP8_FNUZ and BF8_FNUZ.

* Implementation of ConvertFP32Nearest in test_fp8_ocp.

* Remove dependence on possibly undeclared alias.

* Implement FP8OCP test for stochastic rounding mode.

* Implement FP8OCP tests for half_t type conversions.

* enable bf16 atomic add on gfx950

* Implement ConvertFP32Nearest test.

* Implement ConvertFP32Stochastic test.

* Implement ConvertFP16Nearest and ConvertFP16Stochastic tests.

* Refactoring. Move FP8 definitions into a separate header file.

* Enable easy switching between architectures.

* Fix compilation error for gfx942 architecture.

* only build gfx950 branch for gfx950 target by default

* Enable OCP build of example_gemm_xdl_fp8.

* Fix formatting.

* fix the build logic for gfx950

* Improve GEMM example verbosity.

* Add constexpr where applicable.

* fix the logic of enabling XDL and WMMA instances

* Improve GEMM example verbosity.

* Enable build of example_gemm_xdl_fp8_bf8 test.

* Fix tests for gfx1101 architecture.

* Build DPP examples only on gfx103 and gfx11 architectures.

* Optionally run either CPU or GPU verifications with GEMM examples.

* Extend GeneratorTensor_Sequential to produce values of prescribed data types.

* Add missing constructor.

* Improve infrastructure for OFP8 data type support.

* BUGFIX. Should not use FP8 as Compute/Accum data type.

* Add custom target for grouped_convnd_bwd_weight tests.

* Can build `tests` target on gfx950.

* Bugfixes on gfx1101 architecture.

* Fix dependencies.

* Provide single point of truth for FP8 INF and NAN checks

* Prevent instantiation of operators that are not supported by FP8 data types

* Add FP8 type selection into client_example CMakeLists.txt

* Prevent sccache server from shutting down during build

* Fix test success reporting logic

* Change default verification method to CPU.

GPU verification takes too much time to complete on the emulator.

* Make sure all tests and examples are built for gfx950

* Facilitate testing of FP8 data types on the emulator

* Introduce two new tensor generators

* Enable instances built for gfx94 to be built on gfx950

* Verify 35_splitk_gemm on floating point numbers.

splitk gemm appears to be losing precision VS reference implementation when FP numbers are involved.

* Verify 04_gemm_add_add_fastgelu on floating point numbers

* Verify 20_grouped_conv_bwd_weight on floating point numbers

* Verify 38_grouped_conv_bwd_data_multiple_d on floating point numbers

* Verify more tests on floating point data

* Fix data types and improve testing verbosity.

* Upgrade to NPI 573 build docker.

* Skip on gemm_universal tests.

The tests take too long to complete on the emulator.
Need to see if it is possible to reduce the scope of the testing to just FP8 data types.

* Fix gfx1101 build

* Document test availability

* Re-enable fp8 gemms for gfx94/95

* Cherry-pick GEMM Universal tests for FP8 data types

* Cleanup

* CK_USE_GFX94 has already been set on this branch

* Address formatting issues and leftovers

* Make fail/pass logic consistent within 01_gemm folder

Removed multiple negations in fail/pass logic to propagate `true` as the success indicator.

* Fix GPU verification reporting logic.

* Update year in copyright notice.

* Cleanup

* Use `enum class` instead of `enum`

* Remove set_property for FP8 tests

* Narrowing the scope of PR to OCP FP8 enablement only

* Add tests for OCP FP8 vector_type storage

* Enable gemm kernel on all gfx9 architectures (#227)

* clean-up

* Implement `non_native_vector_base` with `ext_vector_type` array. (#232)

* Enable support of 1, 2, 4, and 8-byte custom types in CK.

* Fix pool tests for OCP FP8 data type

* fix jenkins file

* restore cron trigger

---------
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: aska-0096 <haocwang@amd.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Jing Zhang <jizhan@amd.com>
Co-authored-by: zjing14 <zhangjing14@gmail.com>
Co-authored-by: Jun Liu <Liu.Jun@amd.com>
Co-authored-by: Andriy Roshchenko <andriy.roshchenko@amd.com>
Co-authored-by: Andriy Roshchenko <107577548+andriy-ca@users.noreply.github.com>
parent 50ee4267
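
Note on the two 8-bit encodings named in the commit message (OCP and FNUZ): the tests in this commit quote specific bit patterns for the bf8 (E5M2) limits, for example 0x7B = 57344 for OCP and 0x7F = 57344 for FNUZ. The sketch below decodes such patterns so the quoted values can be checked by hand. It is not CK code: `Bf8Flavor`, `decode_e5m2`, and `check_quoted_patterns` are hypothetical names, and the exponent biases (15 for OCP, 16 for FNUZ) are an assumption taken from the usual definitions of these formats rather than from this diff.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>
#include <limits>

// Hypothetical names for illustration only; not part of the CK API.
enum class Bf8Flavor { OCP, FNUZ };

// Decode an 8-bit E5M2 pattern (1 sign, 5 exponent, 2 mantissa bits) to float.
// Assumption: exponent bias is 15 for OCP and 16 for FNUZ.
inline float decode_e5m2(std::uint8_t bits, Bf8Flavor flavor)
{
    const bool negative = (bits & 0x80) != 0;
    const int exponent  = (bits >> 2) & 0x1F;
    const int mantissa  = bits & 0x03;

    if(flavor == Bf8Flavor::FNUZ)
    {
        // FNUZ: 0x80 is the only NaN; there is no infinity and no negative zero.
        if(bits == 0x80)
            return std::numeric_limits<float>::quiet_NaN();
    }
    else if(exponent == 0x1F)
    {
        // OCP: all-ones exponent is infinity (mantissa 0) or NaN (mantissa != 0).
        if(mantissa != 0)
            return std::numeric_limits<float>::quiet_NaN();
        const float inf = std::numeric_limits<float>::infinity();
        return negative ? -inf : inf;
    }

    const int bias = (flavor == Bf8Flavor::OCP) ? 15 : 16;
    const float magnitude =
        (exponent == 0)
            ? std::ldexp(mantissa / 4.0f, 1 - bias)                // subnormal
            : std::ldexp(1.0f + mantissa / 4.0f, exponent - bias); // normal
    return negative ? -magnitude : magnitude;
}

// Spot checks against bit patterns quoted in the tests of this commit.
inline void check_quoted_patterns()
{
    assert(decode_e5m2(0x7B, Bf8Flavor::OCP) == 57344.0f);             // OCP Max
    assert(decode_e5m2(0x04, Bf8Flavor::OCP) == std::ldexp(1.0f, -14)); // OCP Min
    assert(std::isinf(decode_e5m2(0x7C, Bf8Flavor::OCP)));             // OCP +inf
    assert(std::isnan(decode_e5m2(0x7D, Bf8Flavor::OCP)));             // OCP NaN
    assert(decode_e5m2(0x7F, Bf8Flavor::FNUZ) == 57344.0f);            // FNUZ Max
    assert(std::isnan(decode_e5m2(0x80, Bf8Flavor::FNUZ)));            // FNUZ NaN
}
```

Under these assumptions both flavors share the same largest finite value, 57344, but reach it with different bit patterns, which is why the renamed FNUZ tests and the new OCP tests compare against `NumericLimits` rather than a shared constant.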
@@ -62,7 +62,7 @@ function(add_instance_library INSTANCE_NAME)
     endforeach()
     # Do not build mha instances if gfx94 or gfx90a targets are not on the target list
     foreach(source IN LISTS ARGN)
-        if(NOT INST_TARGETS MATCHES "gfx94" AND NOT INST_TARGETS MATCHES "gfx90a" AND source MATCHES "mha")
+        if(NOT INST_TARGETS MATCHES "gfx94" AND NOT INST_TARGETS MATCHES "gfx90a" AND source MATCHES "mha")
             message("removing mha instance ${source} ")
             list(REMOVE_ITEM ARGN "${source}")
         endif()
@@ -346,7 +346,7 @@ if(CK_DEVICE_CONV_INSTANCES)
 endif()
 if(CK_DEVICE_MHA_INSTANCES)
     set(gpu_list ${INST_TARGETS})
-    if(gpu_list MATCHES "gfx94" OR gpu_list MATCHES "gfx90a")
+    if(gpu_list MATCHES "gfx94" OR gpu_list MATCHES "gfx90a")
         add_library(device_mha_operations STATIC ${CK_DEVICE_MHA_INSTANCES})
         add_library(composablekernels::device_mha_operations ALIAS device_mha_operations)
         target_compile_features(device_mha_operations PUBLIC)
......
@@ -15,7 +15,7 @@ void add_device_pool3d_fwd_ndhwc_f8_instances(
         instances)
 {
     add_device_operation_instances(
-        instances, device_pool3d_fwd_ndhwc_instances<F8, F8, I32, F8, ReduceOpId, false>{});
+        instances, device_pool3d_fwd_ndhwc_instances<F8, F8, I32, F32, ReduceOpId, false>{});
 }
 void add_device_pool3d_fwd_ndhwc_index_f8_instances(
@@ -23,7 +23,7 @@ void add_device_pool3d_fwd_ndhwc_index_f8_instances(
         instances)
 {
     add_device_operation_instances(
-        instances, device_pool3d_fwd_ndhwc_instances<F8, F8, I32, F8, ReduceOpId, true>{});
+        instances, device_pool3d_fwd_ndhwc_instances<F8, F8, I32, F32, ReduceOpId, true>{});
 }
 } // namespace instance
......
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
 #pragma once
@@ -150,7 +150,7 @@ bool profile_batched_gemm_bias_softmax_gemm_permute_impl(bool do_verification,
         break;
     default:
         a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
-        b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Sequential<1>{});
+        b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Sequential<B0DataType, 1>{});
         b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
         d0_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<D0DataType>{1});
     }
......
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
 #pragma once
@@ -157,7 +157,7 @@ bool profile_batched_gemm_gemm_impl(bool do_verification,
         break;
     default:
         a_g_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
-        b0_g_k_n.GenerateTensorValue(GeneratorTensor_Sequential<1>{});
+        b0_g_k_n.GenerateTensorValue(GeneratorTensor_Sequential<B0DataType, 1>{});
         b1_g_n_o.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
     }
......
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
 #pragma once
@@ -174,7 +174,7 @@ bool profile_batched_gemm_softmax_gemm_impl(bool do_verification,
         break;
     default:
         a_g_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
-        b0_g_k_n.GenerateTensorValue(GeneratorTensor_Sequential<1>{});
+        b0_g_k_n.GenerateTensorValue(GeneratorTensor_Sequential<B0DataType, 1>{});
         b1_g_n_o.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
     }
......
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
 #pragma once
@@ -140,7 +140,7 @@ bool profile_batched_gemm_softmax_gemm_permute_impl(bool do_verification,
         break;
     default:
         a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
-        b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Sequential<1>{});
+        b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Sequential<B0DataType, 1>{});
         b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
     }
......
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
 #pragma once
@@ -74,8 +74,8 @@ int profile_gemm_impl(int do_verification,
     switch(init_method)
     {
     case 0:
-        ck::utils::FillConstant<ADataType>{static_cast<ADataType>(1.f)}(a_m_k);
-        ck::utils::FillConstant<BDataType>{static_cast<BDataType>(1.f)}(b_k_n);
+        ck::utils::FillConstant<ADataType>{type_convert<ADataType>(1.f)}(a_m_k);
+        ck::utils::FillConstant<BDataType>{type_convert<BDataType>(1.f)}(b_k_n);
         break;
     case 1:
         ck::utils::FillUniformDistributionIntegerValue<ADataType>{-5.f, 5.f}(a_m_k);
......
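
Note on the `static_cast` to `type_convert` change in `profile_gemm_impl`: the commit does not state the motivation. One plausible reading, offered here only as an assumption, is that an 8-bit float whose storage is a plain integer would turn `static_cast<T>(1.f)` into the integer 1, that is, the raw bit pattern 0x01, instead of the encoding of the value 1.0. The sketch below illustrates that difference with `std::uint8_t` standing in for the storage type and OCP E5M2 as the example format. `decode_ocp_e5m2` and `encode_ocp_e5m2_nearest` are hypothetical helpers, and the brute-force encoder breaks ties toward the lowest bit pattern, so it is not the round-to-nearest-even conversion CK implements.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>
#include <limits>

// Hypothetical helper: decode an OCP E5M2 pattern (assumed exponent bias 15).
inline float decode_ocp_e5m2(std::uint8_t bits)
{
    const bool negative = (bits & 0x80) != 0;
    const int exponent  = (bits >> 2) & 0x1F;
    const int mantissa  = bits & 0x03;

    if(exponent == 0x1F)
    {
        if(mantissa != 0)
            return std::numeric_limits<float>::quiet_NaN();
        const float inf = std::numeric_limits<float>::infinity();
        return negative ? -inf : inf;
    }
    const float magnitude = (exponent == 0)
                                ? std::ldexp(mantissa / 4.0f, -14)
                                : std::ldexp(1.0f + mantissa / 4.0f, exponent - 15);
    return negative ? -magnitude : magnitude;
}

// Hypothetical helper: encode a finite float by exhaustive search over all
// 256 patterns, keeping the finite one closest to x (first match wins on ties).
inline std::uint8_t encode_ocp_e5m2_nearest(float x)
{
    std::uint8_t best = 0;
    float best_err    = std::numeric_limits<float>::infinity();
    for(int b = 0; b < 256; ++b)
    {
        const float v = decode_ocp_e5m2(static_cast<std::uint8_t>(b));
        if(!std::isfinite(v))
            continue; // skip inf and NaN patterns
        const float err = std::fabs(v - x);
        if(err < best_err)
        {
            best_err = err;
            best     = static_cast<std::uint8_t>(b);
        }
    }
    return best;
}

// Contrast the two ways of producing "1.0" in an 8-bit storage type.
inline void check_cast_vs_encode()
{
    const std::uint8_t raw     = static_cast<std::uint8_t>(1.f); // integer conversion
    const std::uint8_t encoded = encode_ocp_e5m2_nearest(1.f);   // value conversion
    assert(raw == 0x01);
    assert(encoded == 0x3C);
    assert(decode_ocp_e5m2(encoded) == 1.0f);
    assert(decode_ocp_e5m2(raw) == std::ldexp(1.0f, -16)); // a tiny subnormal, not 1.0
}
```

If the storage type behaves like the stand-in here, filling a tensor through the integer cast would produce values near zero rather than ones, which a value-preserving conversion avoids.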
@@ -9,13 +9,38 @@ if (USE_BITINT_EXTENSION_INT4)
 endif()
 endif()
-add_gtest_executable(test_fp8 test_fp8.cpp)
-if(result EQUAL 0)
-    target_link_libraries(test_fp8 PRIVATE utility)
+add_custom_target(test_fp8)
+if (CK_USE_OCP_FP8)
+    add_gtest_executable(test_fp8_ocp test_fp8_ocp.cpp)
+    if(result EQUAL 0)
+        target_link_libraries(test_fp8_ocp PRIVATE utility)
+    endif()
+    add_gtest_executable(test_bf8_ocp test_bf8_ocp.cpp)
+    if(result EQUAL 0)
+        target_link_libraries(test_bf8_ocp PRIVATE utility)
+    endif()
+    add_dependencies(test_fp8 test_fp8_ocp)
+    add_dependencies(test_fp8 test_bf8_ocp)
 endif()
-add_gtest_executable(test_bf8 test_bf8.cpp)
-if(result EQUAL 0)
-    target_link_libraries(test_bf8 PRIVATE utility)
+if (CK_USE_FNUZ_FP8)
+    add_gtest_executable(test_fp8_fnuz test_fp8_fnuz.cpp)
+    if(result EQUAL 0)
+        target_link_libraries(test_fp8_fnuz PRIVATE utility)
+    endif()
+    add_gtest_executable(test_bf8_fnuz test_bf8_fnuz.cpp)
+    if(result EQUAL 0)
+        target_link_libraries(test_bf8_fnuz PRIVATE utility)
+    endif()
+    add_dependencies(test_fp8 test_fp8_fnuz)
+    add_dependencies(test_fp8 test_bf8_fnuz)
 endif()
 add_gtest_executable(test_custom_type test_custom_type.cpp)
......
@@ -5,158 +5,169 @@
 #include "ck/utility/data_type.hpp"
 #include "ck/utility/type_convert.hpp"
-using ck::bf8_t;
+using ck::bf8_fnuz_t;
 using ck::f8_convert_rne;
 using ck::f8_convert_sr;
 using ck::half_t;
 using ck::type_convert;
-TEST(BF8, NumericLimits)
+TEST(BF8FNUZ, NumericLimits)
 {
     // constants given for negative zero nan mode
-    EXPECT_EQ(ck::NumericLimits<bf8_t>::Min(), type_convert<bf8_t>(0x04));
-    EXPECT_EQ(ck::NumericLimits<bf8_t>::Max(), type_convert<bf8_t>(0x7F));
-    EXPECT_EQ(ck::NumericLimits<bf8_t>::Lowest(), type_convert<bf8_t>(0xFF));
-    EXPECT_EQ(ck::NumericLimits<bf8_t>::QuietNaN(), type_convert<bf8_t>(0x80));
+    EXPECT_EQ(ck::NumericLimits<bf8_fnuz_t>::Min(), type_convert<bf8_fnuz_t>(0x04));
+    EXPECT_EQ(ck::NumericLimits<bf8_fnuz_t>::Max(), type_convert<bf8_fnuz_t>(0x7F));
+    EXPECT_EQ(ck::NumericLimits<bf8_fnuz_t>::Lowest(), type_convert<bf8_fnuz_t>(0xFF));
+    EXPECT_EQ(ck::NumericLimits<bf8_fnuz_t>::QuietNaN(), type_convert<bf8_fnuz_t>(0x80));
 }
-TEST(BF8, ConvertFP32Nearest)
+TEST(BF8FNUZ, ConvertFP32Nearest)
 {
     // fix the tolerance value
     float abs_tol = 1e-6;
     // convert 0 float to bf8 and back, check if holds
-    ASSERT_NEAR(0.0f, type_convert<float>(f8_convert_rne<bf8_t>(0.0f)), abs_tol);
+    ASSERT_NEAR(0.0f, type_convert<float>(f8_convert_rne<bf8_fnuz_t>(0.0f)), abs_tol);
     // don't run the next test on gfx11 devices
 #ifndef CK_SKIP_FLAKY_F8_TEST
     // convert minimal float to bf8 and back, check if holds
     ASSERT_NEAR(std::numeric_limits<float>::min(),
-                type_convert<float>(f8_convert_rne<bf8_t>(std::numeric_limits<float>::min())),
+                type_convert<float>(f8_convert_rne<bf8_fnuz_t>(std::numeric_limits<float>::min())),
                 abs_tol);
 #endif
-    // convert maximal bf8_t to float and check if equal to 57344.0
-    ASSERT_NEAR(57344.0f, type_convert<float>(f8_convert_rne<bf8_t>(57344.0f)), abs_tol);
+    const auto max_bf8_t_float = type_convert<float>(ck::NumericLimits<bf8_fnuz_t>::Max());
+    // convert maximal bf8_fnuz_t to float and check if equal to 57344.0
+    ASSERT_NEAR(
+        max_bf8_t_float, type_convert<float>(f8_convert_rne<bf8_fnuz_t>(max_bf8_t_float)), abs_tol);
     // convert maximal float to bf8 and back, check if clipped to 57344.0
-    ASSERT_NEAR(57344.0f,
-                type_convert<float>(f8_convert_rne<bf8_t>(std::numeric_limits<float>::max())),
+    ASSERT_NEAR(max_bf8_t_float,
+                type_convert<float>(f8_convert_rne<bf8_fnuz_t>(std::numeric_limits<float>::max())),
                 abs_tol);
-    // convert inf float to bf8_t and check if it is qNan
-    ASSERT_NEAR(type_convert<bf8_t>(0x80),
-                f8_convert_rne<bf8_t>(std::numeric_limits<float>::infinity()),
+    // convert inf float to bf8_fnuz_t and check if it is qNan
+    ASSERT_NEAR(ck::NumericLimits<bf8_fnuz_t>::QuietNaN(),
+                f8_convert_rne<bf8_fnuz_t>(std::numeric_limits<float>::infinity()),
                 abs_tol);
     // positive norm float value to bf8 and back, check if holds
     float pos_float = 0.0000762939f;
-    ASSERT_NEAR(pos_float, type_convert<float>(f8_convert_rne<bf8_t>(pos_float)), abs_tol);
+    ASSERT_NEAR(pos_float, type_convert<float>(f8_convert_rne<bf8_fnuz_t>(pos_float)), abs_tol);
     // negative norm float value to bf8 and back, check if holds
     float neg_float = -0.0000610351f;
-    ASSERT_NEAR(neg_float, type_convert<float>(f8_convert_rne<bf8_t>(neg_float)), abs_tol);
+    ASSERT_NEAR(neg_float, type_convert<float>(f8_convert_rne<bf8_fnuz_t>(neg_float)), abs_tol);
     // positive subnorm float value to bf8 and back, check if holds
     pos_float = 0.0000305175f;
-    ASSERT_NEAR(pos_float, type_convert<float>(f8_convert_rne<bf8_t>(pos_float)), abs_tol);
+    ASSERT_NEAR(pos_float, type_convert<float>(f8_convert_rne<bf8_fnuz_t>(pos_float)), abs_tol);
     // negative subnorm float value to bf8 and back, check if holds
     neg_float = -0.0000152587f;
-    ASSERT_NEAR(neg_float, type_convert<float>(f8_convert_rne<bf8_t>(neg_float)), abs_tol);
+    ASSERT_NEAR(neg_float, type_convert<float>(f8_convert_rne<bf8_fnuz_t>(neg_float)), abs_tol);
 }
-TEST(BF8, ConvertFP32Stochastic)
+TEST(BF8FNUZ, ConvertFP32Stochastic)
 {
     // fix the tolerance value
     float abs_tol = 1e-6;
     // convert 0 float to bf8 and back, check if holds
-    ASSERT_NEAR(0.0f, type_convert<float>(f8_convert_sr<bf8_t>(0.0f)), abs_tol);
+    ASSERT_NEAR(0.0f, type_convert<float>(f8_convert_sr<bf8_fnuz_t>(0.0f)), abs_tol);
     // convert minimal float to bf8 and back, check if holds
     ASSERT_NEAR(std::numeric_limits<float>::min(),
-                type_convert<float>(f8_convert_sr<bf8_t>(std::numeric_limits<float>::min())),
+                type_convert<float>(f8_convert_sr<bf8_fnuz_t>(std::numeric_limits<float>::min())),
                 abs_tol);
-    // convert maximal bf8_t to float and check if equal to 57344.0
-    ASSERT_NEAR(57344.0f, type_convert<float>(f8_convert_sr<bf8_t>(57344.0f)), abs_tol);
+    const auto max_bf8_t_float = type_convert<float>(ck::NumericLimits<bf8_fnuz_t>::Max());
+    // convert maximal bf8_fnuz_t to float and check if equal to 57344.0
+    ASSERT_NEAR(
+        max_bf8_t_float, type_convert<float>(f8_convert_sr<bf8_fnuz_t>(max_bf8_t_float)), abs_tol);
     // convert maximal float to bf8 and back, check if clipped to 57344.0
-    ASSERT_NEAR(57344.0f,
-                type_convert<float>(f8_convert_sr<bf8_t>(std::numeric_limits<float>::max())),
+    ASSERT_NEAR(max_bf8_t_float,
+                type_convert<float>(f8_convert_sr<bf8_fnuz_t>(std::numeric_limits<float>::max())),
                 abs_tol);
-    // convert inf float to bf8_t and check if it is qNan
-    ASSERT_NEAR(type_convert<bf8_t>(0x80),
-                f8_convert_sr<bf8_t>(std::numeric_limits<float>::infinity()),
+    // convert inf float to bf8_fnuz_t and check if it is qNan
+    ASSERT_NEAR(ck::NumericLimits<bf8_fnuz_t>::QuietNaN(),
+                f8_convert_sr<bf8_fnuz_t>(std::numeric_limits<float>::infinity()),
                 abs_tol);
     // positive norm float value to bf8 and back, check if holds
     float pos_float = 0.0000762939f;
-    ASSERT_NEAR(pos_float, type_convert<float>(f8_convert_sr<bf8_t>(pos_float)), abs_tol);
+    ASSERT_NEAR(pos_float, type_convert<float>(f8_convert_sr<bf8_fnuz_t>(pos_float)), abs_tol);
     // negative norm float value to bf8 and back, check if holds
     float neg_float = -0.0000610351f;
-    ASSERT_NEAR(neg_float, type_convert<float>(f8_convert_sr<bf8_t>(neg_float)), abs_tol);
+    ASSERT_NEAR(neg_float, type_convert<float>(f8_convert_sr<bf8_fnuz_t>(neg_float)), abs_tol);
     // positive subnorm float value to bf8 and back, check if holds
     pos_float = 0.0000305175f;
-    ASSERT_NEAR(pos_float, type_convert<float>(f8_convert_sr<bf8_t>(pos_float)), abs_tol);
+    ASSERT_NEAR(pos_float, type_convert<float>(f8_convert_sr<bf8_fnuz_t>(pos_float)), abs_tol);
     // negative subnorm float value to bf8 and back, check if holds
     neg_float = -0.0000152587f;
-    ASSERT_NEAR(neg_float, type_convert<float>(f8_convert_sr<bf8_t>(neg_float)), abs_tol);
+    ASSERT_NEAR(neg_float, type_convert<float>(f8_convert_sr<bf8_fnuz_t>(neg_float)), abs_tol);
 }
-TEST(BF8, ConvertFP16Nearest)
+TEST(BF8FNUZ, ConvertFP16Nearest)
 {
     // fix the tolerance value
     float abs_tol = 1e-3;
     // convert 0 fp16 to bf8 and back, check if holds
-    ASSERT_NEAR(half_t{0.0}, type_convert<half_t>(f8_convert_rne<bf8_t>(half_t{0.0})), abs_tol);
+    ASSERT_NEAR(
+        half_t{0.0}, type_convert<half_t>(f8_convert_rne<bf8_fnuz_t>(half_t{0.0})), abs_tol);
     // convert minimal fp16 to bf8 and back, check if holds
     ASSERT_NEAR(ck::NumericLimits<half_t>::Min(),
-                type_convert<half_t>(f8_convert_rne<bf8_t>(ck::NumericLimits<half_t>::Min())),
+                type_convert<half_t>(f8_convert_rne<bf8_fnuz_t>(ck::NumericLimits<half_t>::Min())),
                 abs_tol);
-    // convert maximal bf8_t to fp16 and check if equal to 57344.0
+    const auto max_bf8_t_half = type_convert<half_t>(ck::NumericLimits<bf8_fnuz_t>::Max());
+    // convert maximal bf8_fnuz_t to fp16 and check if equal to 57344.0
     ASSERT_NEAR(
-        half_t{57344.0}, type_convert<half_t>(f8_convert_rne<bf8_t>(half_t{57344.0})), abs_tol);
+        max_bf8_t_half, type_convert<half_t>(f8_convert_rne<bf8_fnuz_t>(max_bf8_t_half)), abs_tol);
     // convert maximal fp16 to bf8 and back, check if clipped to 57344.0
-    ASSERT_NEAR(half_t{57344.0},
-                type_convert<half_t>(f8_convert_rne<bf8_t>(ck::NumericLimits<half_t>::Max())),
+    ASSERT_NEAR(max_bf8_t_half,
+                type_convert<half_t>(f8_convert_rne<bf8_fnuz_t>(ck::NumericLimits<half_t>::Max())),
                 abs_tol);
-    // convert QuietNaN fp16 to bf8_t and check if it is QuietNaN
-    ASSERT_NEAR(type_convert<bf8_t>(0x80),
-                f8_convert_rne<bf8_t>(ck::NumericLimits<half_t>::QuietNaN()),
+    // convert QuietNaN fp16 to bf8_fnuz_t and check if it is QuietNaN
+    ASSERT_NEAR(ck::NumericLimits<bf8_fnuz_t>::QuietNaN(),
+                f8_convert_rne<bf8_fnuz_t>(ck::NumericLimits<half_t>::QuietNaN()),
                 abs_tol);
     // positive norm fp16 value to bf8 and back, check if holds
     half_t pos_half = half_t{0.0000762939};
-    ASSERT_NEAR(pos_half, type_convert<half_t>(f8_convert_rne<bf8_t>(pos_half)), abs_tol);
+    ASSERT_NEAR(pos_half, type_convert<half_t>(f8_convert_rne<bf8_fnuz_t>(pos_half)), abs_tol);
     // negative norm fp16 value to bf8 and back, check if holds
     half_t neg_half = half_t{-0.0000610351};
-    ASSERT_NEAR(neg_half, type_convert<half_t>(f8_convert_rne<bf8_t>(neg_half)), abs_tol);
+    ASSERT_NEAR(neg_half, type_convert<half_t>(f8_convert_rne<bf8_fnuz_t>(neg_half)), abs_tol);
     // positive subnorm fp16 value to bf8 and back, check if holds
     pos_half = half_t{0.0000305175};
-    ASSERT_NEAR(pos_half, type_convert<half_t>(f8_convert_rne<bf8_t>(pos_half)), abs_tol);
+    ASSERT_NEAR(pos_half, type_convert<half_t>(f8_convert_rne<bf8_fnuz_t>(pos_half)), abs_tol);
     // negative subnorm fp16 value to bf8 and back, check if holds
     neg_half = half_t{-0.0000152587};
-    ASSERT_NEAR(neg_half, type_convert<half_t>(f8_convert_rne<bf8_t>(neg_half)), abs_tol);
+    ASSERT_NEAR(neg_half, type_convert<half_t>(f8_convert_rne<bf8_fnuz_t>(neg_half)), abs_tol);
 }
-TEST(BF8, ConvertFP16Stochastic)
+TEST(BF8FNUZ, ConvertFP16Stochastic)
 {
     // fix the tolerance value
     float abs_tol = 1e-3;
     // convert 0 fp16 to bf8 and back, check if holds
-    ASSERT_NEAR(half_t{0.0}, type_convert<half_t>(f8_convert_sr<bf8_t>(half_t{0.0})), abs_tol);
+    ASSERT_NEAR(half_t{0.0}, type_convert<half_t>(f8_convert_sr<bf8_fnuz_t>(half_t{0.0})), abs_tol);
     // convert minimal fp16 to bf8 and back, check if holds
     ASSERT_NEAR(ck::NumericLimits<half_t>::Min(),
-                type_convert<half_t>(f8_convert_sr<bf8_t>(ck::NumericLimits<half_t>::Min())),
+                type_convert<half_t>(f8_convert_sr<bf8_fnuz_t>(ck::NumericLimits<half_t>::Min())),
                 abs_tol);
-    // convert maximal bf8_t to fp16 and check if equal to 57344.0
+    const auto max_bf8_t_half = type_convert<half_t>(ck::NumericLimits<bf8_fnuz_t>::Max());
+    // convert maximal bf8_fnuz_t to fp16 and check if equal to 57344.0
     ASSERT_NEAR(
-        half_t{57344.0}, type_convert<half_t>(f8_convert_sr<bf8_t>(half_t{57344.0})), abs_tol);
+        max_bf8_t_half, type_convert<half_t>(f8_convert_sr<bf8_fnuz_t>(max_bf8_t_half)), abs_tol);
     // convert maximal fp16 to bf8 and back, check if clipped to 57344.0
-    ASSERT_NEAR(half_t{57344.0},
-                type_convert<half_t>(f8_convert_sr<bf8_t>(ck::NumericLimits<half_t>::Max())),
+    ASSERT_NEAR(max_bf8_t_half,
+                type_convert<half_t>(f8_convert_sr<bf8_fnuz_t>(ck::NumericLimits<half_t>::Max())),
                 abs_tol);
-    // convert QuietNaN fp16 to bf8_t and check if it is QuietNaN
-    ASSERT_NEAR(type_convert<bf8_t>(0x80),
-                f8_convert_sr<bf8_t>(ck::NumericLimits<half_t>::QuietNaN()),
+    // convert QuietNaN fp16 to bf8_fnuz_t and check if it is QuietNaN
+    ASSERT_NEAR(ck::NumericLimits<bf8_fnuz_t>::QuietNaN(),
+                f8_convert_sr<bf8_fnuz_t>(ck::NumericLimits<half_t>::QuietNaN()),
                 abs_tol);
     // positive norm fp16 value to bf8 and back, check if holds
     half_t pos_half = half_t{0.0000762939};
-    ASSERT_NEAR(pos_half, type_convert<half_t>(f8_convert_sr<bf8_t>(pos_half)), abs_tol);
+    ASSERT_NEAR(pos_half, type_convert<half_t>(f8_convert_sr<bf8_fnuz_t>(pos_half)), abs_tol);
     // negative norm fp16 value to bf8 and back, check if holds
     half_t neg_half = half_t{-0.0000610351};
-    ASSERT_NEAR(neg_half, type_convert<half_t>(f8_convert_sr<bf8_t>(neg_half)), abs_tol);
+    ASSERT_NEAR(neg_half, type_convert<half_t>(f8_convert_sr<bf8_fnuz_t>(neg_half)), abs_tol);
     // positive subnorm fp16 value to bf8 and back, check if holds
     pos_half = half_t{0.0000305175};
-    ASSERT_NEAR(pos_half, type_convert<half_t>(f8_convert_sr<bf8_t>(pos_half)), abs_tol);
+    ASSERT_NEAR(pos_half, type_convert<half_t>(f8_convert_sr<bf8_fnuz_t>(pos_half)), abs_tol);
     // negative subnorm fp16 value to bf8 and back, check if holds
     neg_half = half_t{-0.0000152587};
-    ASSERT_NEAR(neg_half, type_convert<half_t>(f8_convert_sr<bf8_t>(neg_half)), abs_tol);
+    ASSERT_NEAR(neg_half, type_convert<half_t>(f8_convert_sr<bf8_fnuz_t>(neg_half)), abs_tol);
 }
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "gtest/gtest.h"
#include "ck/utility/data_type.hpp"
#include "ck/utility/type_convert.hpp"
using ck::bf8_ocp_t;
using ck::f8_convert_rne;
using ck::f8_convert_sr;
using ck::half_t;
using ck::type_convert;
TEST(BF8OCP, NumericLimits)
{ // constants given for OCP FP8
EXPECT_EQ(ck::NumericLimits<bf8_ocp_t>::Min(),
type_convert<bf8_ocp_t>(0x04)); // 0b00000100 = 2^-14
EXPECT_EQ(ck::NumericLimits<bf8_ocp_t>::Max(),
type_convert<bf8_ocp_t>(0x7B)); // 0b01111011 = 57344
EXPECT_EQ(ck::NumericLimits<bf8_ocp_t>::Lowest(),
type_convert<bf8_ocp_t>(0xFB)); // 0b11111011 = -57344
EXPECT_EQ(ck::NumericLimits<bf8_ocp_t>::QuietNaN().data,
type_convert<bf8_ocp_t>(0x7D).data); // 0b01111101
EXPECT_FALSE(ck::NumericLimits<bf8_ocp_t>::QuietNaN() ==
ck::NumericLimits<bf8_ocp_t>::QuietNaN());
EXPECT_TRUE(ck::fp8_is_inf(type_convert<bf8_ocp_t>(0xFC)) &&
ck::fp8_is_inf(type_convert<bf8_ocp_t>(0x7C)));
}
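The bit patterns checked in the test above can be decoded by hand. The following is a minimal sketch, assuming the OCP E5M2 layout (1 sign bit, 5 exponent bits, 2 mantissa bits, exponent bias 15); `decode_e5m2` is a hypothetical helper written for illustration and is not part of CK.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>
#include <limits>

// Hypothetical helper, not CK code: decode one OCP E5M2 byte into a float.
// Assumed layout: 1 sign bit, 5 exponent bits, 2 mantissa bits, bias 15.
inline float decode_e5m2(std::uint8_t bits)
{
    const float sign   = (bits & 0x80) ? -1.0f : 1.0f;
    const int exponent = (bits >> 2) & 0x1F;
    const int mantissa = bits & 0x03;

    // All-ones exponent: infinity (mantissa 0) or NaN (mantissa != 0).
    if(exponent == 0x1F)
        return mantissa == 0 ? sign * std::numeric_limits<float>::infinity()
                             : std::numeric_limits<float>::quiet_NaN();

    // Subnormal: no implicit leading one, value = mantissa * 2^-16.
    if(exponent == 0)
        return sign * std::ldexp(static_cast<float>(mantissa), -16);

    // Normal: (1 + mantissa / 4) * 2^(exponent - 15).
    return sign * std::ldexp(1.0f + mantissa / 4.0f, exponent - 15);
}
```

Under these assumptions 0x04 decodes to 2^-14, 0x7B to 57344, 0x7C and 0xFC to the two infinities, and 0x7D to a NaN, which matches the constants in the test.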
TEST(BF8OCP, ConvertFP32Nearest)
{
// fix the tolerance value
float abs_tol = 1e-6;
// convert 0 float to bf8 and back, check if holds
ASSERT_NEAR(0.0f, type_convert<float>(f8_convert_rne<bf8_ocp_t>(0.0f)), 0.0f);
// convert minimal float to bf8 and back, check if holds
ASSERT_NEAR(std::numeric_limits<float>::min(),
type_convert<float>(f8_convert_rne<bf8_ocp_t>(std::numeric_limits<float>::min())),
abs_tol);
const auto max_bf8_t_float = type_convert<float>(ck::NumericLimits<bf8_ocp_t>::Max());
// convert maximal bf8_ocp_t to float and check if equal to bf8 max
ASSERT_NEAR(
max_bf8_t_float, type_convert<float>(f8_convert_rne<bf8_ocp_t>(max_bf8_t_float)), 0.0f);
// convert maximal float to bf8 and back, check if clipped to bf8 max (saturation to finite)
ASSERT_NEAR(max_bf8_t_float,
type_convert<float>(f8_convert_rne<bf8_ocp_t>(std::numeric_limits<float>::max())),
0.0f);
// convert float infinity to bf8_ocp_t and check if it is max value (saturation to finite)
ASSERT_EQ(ck::NumericLimits<bf8_ocp_t>::Max(),
f8_convert_rne<bf8_ocp_t>(std::numeric_limits<float>::infinity()));
// positive normal float value to bf8 and back, check if holds
float pos_float = 0.0000762939f; // 10*2^-17
ASSERT_NEAR(pos_float, type_convert<float>(f8_convert_rne<bf8_ocp_t>(pos_float)), abs_tol);
// negative smallest normal bf8 value to bf8 and back, check if holds
constexpr auto neg_min_bf8 = -0.00006103515625f; //-2^-14
ASSERT_NEAR(neg_min_bf8, type_convert<float>(f8_convert_rne<bf8_ocp_t>(neg_min_bf8)), 0.0f);
// positive subnorm float value to bf8 and back, check if holds
constexpr auto pos_subnorm_bf8 = 0.000030517578125f; // 2^-15
ASSERT_NEAR(
pos_subnorm_bf8, type_convert<float>(f8_convert_rne<bf8_ocp_t>(pos_subnorm_bf8)), 0.0f);
// min subnorm bf8 value to bf8 and back, check if holds
constexpr auto min_subnorm_bf8 = -0.0000152587890625f; //-2^-16
ASSERT_NEAR(
min_subnorm_bf8, type_convert<float>(f8_convert_rne<bf8_ocp_t>(min_subnorm_bf8)), 0.0f);
// smaller than min subnorm bf8 value to bf8 must be zero
constexpr auto less_than_min_subnorm = 0.00000762939453125f; // 2^-17
ASSERT_EQ(0.0f, type_convert<float>(f8_convert_rne<bf8_ocp_t>(less_than_min_subnorm)));
// convert quiet NaN to bf8_ocp_t and check if it is quiet NaN
const auto bf8_nan = f8_convert_rne<bf8_ocp_t>(std::numeric_limits<float>::quiet_NaN());
ASSERT_TRUE(ck::fp8_impl::ocp_bf8_is_nan(bf8_nan.data));
}
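The subnormal checks in the test above depend on ties rounding to even: 2^-17 lies exactly halfway between 0 and the smallest bf8 subnormal 2^-16, so round-to-nearest-even picks 0. A minimal sketch of that behavior, restricted to magnitudes below 2^-14 and assuming the default floating-point rounding mode; `round_to_e5m2_subnormal_rne` is a hypothetical name and not how CK implements `f8_convert_rne`.

```cpp
#include <cassert>
#include <cmath>

// Hypothetical helper, not CK code: round x (|x| < 2^-14) to the nearest
// multiple of 2^-16, the spacing of E5M2 subnormals. Halfway cases go to the
// even multiple because std::nearbyint follows the current rounding mode,
// which is assumed to be the default FE_TONEAREST.
inline float round_to_e5m2_subnormal_rne(float x)
{
    const float steps = std::ldexp(x, 16); // x expressed in units of 2^-16
    return std::ldexp(std::nearbyint(steps), -16);
}
```

With this sketch 2^-17 maps to 0, while 3 * 2^-17 (halfway between one and two steps) maps to 2^-15.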
TEST(BF8OCP, ConvertFP32Stochastic)
{
// fix the tolerance value
float abs_tol = 1e-6;
// convert 0 float to bf8 and back, check if holds
ASSERT_NEAR(0.0f, type_convert<float>(f8_convert_sr<bf8_ocp_t>(0.0f)), 0.0f);
// convert minimal float to bf8 and back, check if holds
ASSERT_NEAR(std::numeric_limits<float>::min(),
type_convert<float>(f8_convert_sr<bf8_ocp_t>(std::numeric_limits<float>::min())),
abs_tol);
const auto max_bf8_t_float = type_convert<float>(ck::NumericLimits<bf8_ocp_t>::Max());
// convert maximal bf8_ocp_t to float and check if equal to bf8 max
ASSERT_NEAR(
max_bf8_t_float, type_convert<float>(f8_convert_sr<bf8_ocp_t>(max_bf8_t_float)), 0.0f);
// convert maximal float to bf8 and back, check if clipped to bf8 max (saturation to finite)
ASSERT_NEAR(max_bf8_t_float,
type_convert<float>(f8_convert_sr<bf8_ocp_t>(std::numeric_limits<float>::max())),
0.0f);
// convert float infinity to bf8_ocp_t and check if it is max value (saturation to finite)
ASSERT_EQ(ck::NumericLimits<bf8_ocp_t>::Max(),
f8_convert_sr<bf8_ocp_t>(std::numeric_limits<float>::infinity()));
// positive normal float value to bf8 and back, check if holds
float pos_float = 0.0000762939f; // 10*2^-17
ASSERT_NEAR(pos_float, type_convert<float>(f8_convert_sr<bf8_ocp_t>(pos_float)), abs_tol);
// negative smallest normal bf8 value to bf8 and back, check if holds
constexpr auto neg_min_bf8 = -0.00006103515625f; //-2^-14
ASSERT_NEAR(neg_min_bf8, type_convert<float>(f8_convert_sr<bf8_ocp_t>(neg_min_bf8)), 0.0f);
// positive subnorm float value to bf8 and back, check if holds
constexpr auto pos_subnorm_bf8 = 0.000030517578125f; // 2^-15
ASSERT_NEAR(
pos_subnorm_bf8, type_convert<float>(f8_convert_sr<bf8_ocp_t>(pos_subnorm_bf8)), 0.0f);
// min subnorm bf8 value to bf8 and back, check if holds
constexpr auto min_subnorm_bf8 = -0.0000152587890625f; //-2^-16
ASSERT_NEAR(
min_subnorm_bf8, type_convert<float>(f8_convert_sr<bf8_ocp_t>(min_subnorm_bf8)), 0.0f);
// smaller than min subnorm bf8 value to bf8 alternates between 0 and 2^-16
constexpr auto less_than_min_subnorm = 0.00000762939453125f; // 2^-17
ASSERT_NEAR(0.0f,
type_convert<float>(f8_convert_sr<bf8_ocp_t>(less_than_min_subnorm)),
0.0000152587890625f);
// convert quiet NaN to bf8_ocp_t and check if it is quiet NaN
const auto bf8_nan = f8_convert_sr<bf8_ocp_t>(std::numeric_limits<float>::quiet_NaN());
ASSERT_TRUE(ck::fp8_impl::ocp_bf8_is_nan(bf8_nan.data));
}
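The stochastic test above accepts either 0 or 2^-16 for the input 2^-17, because stochastic rounding picks one of the two neighboring representable values at random. A minimal sketch of the idea, assuming the random draw is passed in by the caller so the result is reproducible; `round_to_e5m2_subnormal_sr` and its parameter `u` are hypothetical, and how CK's `f8_convert_sr` obtains its random bits is not shown here.

```cpp
#include <cassert>
#include <cmath>

// Hypothetical helper, not CK code: stochastically round x (|x| < 2^-14) to a
// multiple of 2^-16. `u` is a uniform random draw in [0, 1) supplied by the
// caller. The result is the upper neighbor when u is below the fractional
// part, otherwise the lower neighbor, so the expected result equals x.
inline float round_to_e5m2_subnormal_sr(float x, float u)
{
    const float steps  = std::ldexp(x, 16); // x expressed in units of 2^-16
    const float lower  = std::floor(steps);
    const float frac   = steps - lower;     // in [0, 1)
    const float picked = (u < frac) ? lower + 1.0f : lower;
    return std::ldexp(picked, -16);
}
```

Values that already sit on the grid have a fractional part of 0 and are returned unchanged for every `u`, which is why the exact-value checks in the test can use a tolerance of 0.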
TEST(BF8OCP, ConvertFP16Nearest)
{
// fix the tolerance value
constexpr half_t half_t_tol = 1e-3;
constexpr half_t half_t_zero = 0.0;
// convert 0 half_t to bf8 and back, check if holds
ASSERT_NEAR(
half_t_zero, type_convert<half_t>(f8_convert_rne<bf8_ocp_t>(half_t_zero)), half_t_zero);
// convert minimal half_t to bf8 and back, check if holds
ASSERT_NEAR(ck::NumericLimits<half_t>::Min(),
type_convert<half_t>(f8_convert_rne<bf8_ocp_t>(ck::NumericLimits<half_t>::Min())),
half_t_tol);
const auto max_bf8_t_half_t = type_convert<half_t>(ck::NumericLimits<bf8_ocp_t>::Max());
// convert maximal bf8_ocp_t to half_t and check if equal to bf8 max
ASSERT_NEAR(max_bf8_t_half_t,
type_convert<half_t>(f8_convert_rne<bf8_ocp_t>(max_bf8_t_half_t)),
half_t_zero);
// convert maximal half_t to bf8 and back, check if clipped to bf8 max (saturation to finite)
ASSERT_NEAR(max_bf8_t_half_t,
type_convert<half_t>(f8_convert_rne<bf8_ocp_t>(ck::NumericLimits<half_t>::Max())),
half_t_zero);
// convert half_t infinity to bf8_ocp_t and check if it is max value (saturation to finite)
ASSERT_EQ(
ck::NumericLimits<bf8_ocp_t>::Max(),
f8_convert_rne<bf8_ocp_t>(type_convert<half_t>(std::numeric_limits<float>::infinity())));
// positive normal bf8 value to bf8 and back, check if holds
constexpr half_t pos_norm_bf8{0.0000762939f}; // 10*2^-17
ASSERT_NEAR(
pos_norm_bf8, type_convert<half_t>(f8_convert_rne<bf8_ocp_t>(pos_norm_bf8)), half_t_tol);
// negative smallest normal bf8 value to bf8 and back, check if holds
constexpr half_t neg_min_bf8{-0.00006103515625f}; //-2^-14
ASSERT_NEAR(
neg_min_bf8, type_convert<half_t>(f8_convert_rne<bf8_ocp_t>(neg_min_bf8)), half_t_zero);
// positive subnorm bf8 value to bf8 and back, check if holds
constexpr half_t pos_subnorm_bf8{0.000030517578125f}; // 2^-15
ASSERT_NEAR(pos_subnorm_bf8,
type_convert<half_t>(f8_convert_rne<bf8_ocp_t>(pos_subnorm_bf8)),
half_t_zero);
// min subnorm bf8 value to bf8 and back, check if holds
constexpr half_t min_subnorm_bf8{-0.0000152587890625f}; //-2^-16
ASSERT_NEAR(min_subnorm_bf8,
type_convert<half_t>(f8_convert_rne<bf8_ocp_t>(min_subnorm_bf8)),
half_t_zero);
// smaller than min subnorm bf8 value to bf8 must be zero
constexpr half_t less_than_min_subnorm{0.00000762939453125f}; // 2^-17
ASSERT_EQ(half_t_zero, type_convert<half_t>(f8_convert_rne<bf8_ocp_t>(less_than_min_subnorm)));
// convert quiet NaN to bf8_ocp_t and check if it is quiet NaN
const auto bf8_nan = f8_convert_rne<bf8_ocp_t>(ck::NumericLimits<half_t>::QuietNaN());
ASSERT_TRUE(ck::fp8_impl::ocp_bf8_is_nan(bf8_nan.data));
}
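Several checks above expect "saturation to finite": inputs beyond the largest bf8 value, including infinity, come back as the bf8 maximum, while NaN stays NaN. A minimal sketch of that clamping step on plain floats, assuming the finite maximum is passed in; `saturate_to_finite` is a hypothetical helper and not a CK function.

```cpp
#include <cassert>
#include <cmath>
#include <limits>

// Hypothetical helper, not CK code: clamp x into [-max_finite, max_finite].
// Infinities and oversized values collapse onto the finite maximum; NaN is
// passed through untouched so it can still be encoded as NaN afterwards.
inline float saturate_to_finite(float x, float max_finite)
{
    if(std::isnan(x))
        return x;
    return std::fmin(std::fmax(x, -max_finite), max_finite);
}
```

For the bf8 maximum of 57344 used in these tests, both `std::numeric_limits<float>::max()` and positive infinity would clamp to 57344 under this sketch.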
TEST(BF8OCP, ConvertFP16Stochastic)
{
// fix the tolerance value
constexpr half_t half_t_tol = 1e-3;
constexpr half_t half_t_zero = 0.0;
constexpr auto min_subnorm_bf8 = 0.0000152587890625f; // 2^-16
// convert 0 half_t to bf8 and back, check if holds
ASSERT_NEAR(
half_t_zero, type_convert<half_t>(f8_convert_sr<bf8_ocp_t>(half_t_zero)), half_t_zero);
// convert minimal half_t (6.103515625e-05) to bf8 and back, check if holds
ASSERT_NEAR(ck::NumericLimits<half_t>::Min(),
type_convert<half_t>(f8_convert_sr<bf8_ocp_t>(ck::NumericLimits<half_t>::Min())),
half_t_zero);
const auto max_bf8_t_half_t = type_convert<half_t>(ck::NumericLimits<bf8_ocp_t>::Max());
// convert maximal bf8_ocp_t to half_t and check if equal to bf8 max
ASSERT_NEAR(max_bf8_t_half_t,
type_convert<half_t>(f8_convert_sr<bf8_ocp_t>(max_bf8_t_half_t)),
half_t_zero);
// convert maximal half_t to bf8 and back, check if clipped to bf8 max (saturation to finite)
ASSERT_NEAR(max_bf8_t_half_t,
type_convert<half_t>(f8_convert_sr<bf8_ocp_t>(ck::NumericLimits<half_t>::Max())),
half_t_zero);
// convert half_t infinity to bf8_ocp_t and check if it is max value (saturation to finite)
ASSERT_EQ(
ck::NumericLimits<bf8_ocp_t>::Max(),
f8_convert_sr<bf8_ocp_t>(type_convert<half_t>(std::numeric_limits<float>::infinity())));
// positive normal bf8 value to bf8 and back, check if holds
constexpr half_t pos_norm_bf8{0.0000762939f}; // 10*2^-17
ASSERT_NEAR(
pos_norm_bf8, type_convert<half_t>(f8_convert_sr<bf8_ocp_t>(pos_norm_bf8)), half_t_tol);
// negative smallest normal bf8 value to bf8 and back, check if holds
constexpr half_t neg_min_bf8{-0.00006103515625f}; //-2^-14
ASSERT_NEAR(
neg_min_bf8, type_convert<half_t>(f8_convert_sr<bf8_ocp_t>(neg_min_bf8)), half_t_zero);
// positive subnorm bf8 value to bf8 and back, check if holds
constexpr half_t pos_subnorm_bf8{0.000030517578125f}; // 2^-15
ASSERT_NEAR(pos_subnorm_bf8,
type_convert<half_t>(f8_convert_sr<bf8_ocp_t>(pos_subnorm_bf8)),
half_t_zero);
// min subnorm bf8 value to bf8 and back, check if holds
ASSERT_NEAR(half_t{-min_subnorm_bf8},
type_convert<half_t>(f8_convert_sr<bf8_ocp_t>(half_t{-min_subnorm_bf8})),
half_t_zero);
// smaller than min subnorm bf8 value to bf8 alternates between 0 and 2^-16
constexpr half_t less_than_min_subnorm{0.00000762939453125f}; // 2^-17
ASSERT_NEAR(half_t_zero,
type_convert<half_t>(f8_convert_sr<bf8_ocp_t>(less_than_min_subnorm)),
half_t{min_subnorm_bf8});
// convert quiet NaN to bf8_ocp_t and check if it is quiet NaN
const auto bf8_nan = f8_convert_sr<bf8_ocp_t>(ck::NumericLimits<half_t>::QuietNaN());
ASSERT_TRUE(ck::fp8_impl::ocp_bf8_is_nan(bf8_nan.data));
}
@@ -872,3 +872,161 @@ TEST(Complex_half, TestAsTypeReshape)
test_vec.at(num_elem * i + 1)); test_vec.at(num_elem * i + 1));
}); });
} }
#if CK_USE_OCP_FP8
TEST(FP8OCP, TestSize)
{
static_assert(std::is_same_v<f8_t, ck::f8_ocp_t>, "OCP FP8 is not enabled");
ASSERT_EQ(sizeof(f8_t), sizeof(ck::fp8_storage_t));
ASSERT_EQ(sizeof(vector_type<f8_t, 2>), sizeof(vector_type<ck::fp8_storage_t, 2>));
ASSERT_EQ(sizeof(vector_type<f8_t, 4>), sizeof(vector_type<ck::fp8_storage_t, 4>));
ASSERT_EQ(sizeof(vector_type<f8_t, 8>), sizeof(vector_type<ck::fp8_storage_t, 8>));
ASSERT_EQ(sizeof(vector_type<f8_t, 16>), sizeof(vector_type<ck::fp8_storage_t, 16>));
ASSERT_EQ(sizeof(vector_type<f8_t, 32>), sizeof(vector_type<ck::fp8_storage_t, 32>));
ASSERT_EQ(sizeof(vector_type<f8_t, 64>), sizeof(vector_type<ck::fp8_storage_t, 64>));
}
TEST(FP8OCP, TestAsType)
{
static_assert(std::is_same_v<f8_t, ck::f8_ocp_t>, "OCP FP8 is not enabled");
// test size
std::array<float, 8> test_vec = {-4, -2, -0.5, -0.25, 1.0 / 8.0, 1, 1.5, 16};
constexpr int size = test_vec.size();
// reference vector
vector_type<f8_t, size> right_vec;
// check default CTOR
ck::static_for<0, size, 1>{}(
[&](auto i) { ASSERT_EQ(right_vec.template AsType<f8_t>()(Number<i>{}), f8_t{0}); });
// assign test values to the vector
ck::static_for<0, size, 1>{}([&](auto i) {
right_vec.template AsType<f8_t>()(Number<i>{}) = ck::type_convert<f8_t>(test_vec.at(i));
});
// copy the vector
vector_type<f8_t, size> left_vec{right_vec};
// check if values were copied correctly
ck::static_for<0, size, 1>{}([&](auto i) {
ASSERT_EQ(left_vec.template AsType<f8_t>()(Number<i>{}),
ck::type_convert<f8_t>(test_vec.at(i)));
});
ck::non_native_vector_base<ck::f8_ocp_t, 2> nnvb_f8x2(ck::type_convert<f8_t>(-10.0f));
ASSERT_EQ(nnvb_f8x2.template AsType<f8_t>()(Number<0>{}), ck::type_convert<f8_t>(-10.0f));
ASSERT_EQ(nnvb_f8x2.template AsType<f8_t>()(Number<1>{}), ck::type_convert<f8_t>(-10.0f));
}
TEST(FP8OCP, TestAsTypeReshape)
{
static_assert(std::is_same_v<f8_t, ck::f8_ocp_t>, "OCP FP8 is not enabled");
// test size
std::array<float, 8> test_vec = {-8, -0.5, -0.25, 1.0 / 8.0, 1.0 / 256.0, 1, 1.5, 16};
constexpr int size = test_vec.size();
// reference vector
vector_type<f8_t, size> right_vec;
// check default CTOR
ck::static_for<0, size, 1>{}(
[&](auto i) { ASSERT_EQ(right_vec.template AsType<f8_t>()(Number<i>{}), f8_t{0}); });
// assign test values to the vector
ck::static_for<0, size, 1>{}([&](auto i) {
right_vec.template AsType<f8_t>()(Number<i>{}) = ck::type_convert<f8_t>(test_vec.at(i));
});
// copy the first half of a vector
vector_type<f8_t, size / 2> left_vec{
right_vec.template AsType<vector_type<f8_t, size / 2>::type>()(Number<0>{})};
// check if values were copied correctly
ck::static_for<0, size / 2, 1>{}([&](auto i) {
ASSERT_EQ(left_vec.template AsType<f8_t>()(Number<i>{}),
ck::type_convert<f8_t>(test_vec.at(i)));
});
}
TEST(BF8OCP, TestSize)
{
static_assert(std::is_same_v<bf8_t, ck::bf8_ocp_t>, "OCP BF8 is not enabled");
ASSERT_EQ(sizeof(bf8_t), sizeof(ck::fp8_storage_t));
ASSERT_EQ(sizeof(vector_type<bf8_t, 2>), sizeof(vector_type<ck::fp8_storage_t, 2>));
ASSERT_EQ(sizeof(vector_type<bf8_t, 4>), sizeof(vector_type<ck::fp8_storage_t, 4>));
ASSERT_EQ(sizeof(vector_type<bf8_t, 8>), sizeof(vector_type<ck::fp8_storage_t, 8>));
ASSERT_EQ(sizeof(vector_type<bf8_t, 16>), sizeof(vector_type<ck::fp8_storage_t, 16>));
ASSERT_EQ(sizeof(vector_type<bf8_t, 32>), sizeof(vector_type<ck::fp8_storage_t, 32>));
ASSERT_EQ(sizeof(vector_type<bf8_t, 64>), sizeof(vector_type<ck::fp8_storage_t, 64>));
}
TEST(BF8OCP, TestAsType)
{
static_assert(std::is_same_v<bf8_t, ck::bf8_ocp_t>, "OCP BF8 is not enabled");
// test size
std::array<float, 8> test_vec = {-4, -2, -0.5, -0.25, 1.0 / 8.0, 1, 1.5, 16};
constexpr int size = test_vec.size();
// reference vector
vector_type<bf8_t, size> right_vec;
// check default CTOR
ck::static_for<0, size, 1>{}(
[&](auto i) { ASSERT_EQ(right_vec.template AsType<bf8_t>()(Number<i>{}), bf8_t{0}); });
// assign test values to the vector
ck::static_for<0, size, 1>{}([&](auto i) {
right_vec.template AsType<bf8_t>()(Number<i>{}) = ck::type_convert<bf8_t>(test_vec.at(i));
});
// copy the vector
vector_type<bf8_t, size> left_vec{right_vec};
// check if values were copied correctly
ck::static_for<0, size, 1>{}([&](auto i) {
ASSERT_EQ(left_vec.template AsType<bf8_t>()(Number<i>{}),
ck::type_convert<bf8_t>(test_vec.at(i)));
});
ck::non_native_vector_base<bf8_t, 2> nnvb_bf8x2(ck::type_convert<bf8_t>(-10.0f));
ASSERT_EQ(nnvb_bf8x2.template AsType<bf8_t>()(Number<0>{}), ck::type_convert<bf8_t>(-10.0f));
ASSERT_EQ(nnvb_bf8x2.template AsType<bf8_t>()(Number<1>{}), ck::type_convert<bf8_t>(-10.0f));
}
TEST(BF8OCP, TestAsTypeReshape)
{
static_assert(std::is_same_v<bf8_t, ck::bf8_ocp_t>, "OCP BF8 is not enabled");
// test size
std::array<float, 8> test_vec = {-8, -0.5, -0.25, 1.0 / 8.0, 1.0 / 256.0, 1, 1.5, 16};
constexpr int size = test_vec.size();
// reference vector
vector_type<bf8_t, size> right_vec;
// check default CTOR
ck::static_for<0, size, 1>{}(
[&](auto i) { ASSERT_EQ(right_vec.template AsType<bf8_t>()(Number<i>{}), bf8_t{0}); });
// assign test values to the vector
ck::static_for<0, size, 1>{}([&](auto i) {
right_vec.template AsType<bf8_t>()(Number<i>{}) = ck::type_convert<bf8_t>(test_vec.at(i));
});
// copy the first half of a vector
vector_type<bf8_t, size / 2> left_vec{
right_vec.template AsType<vector_type<bf8_t, size / 2>::type>()(Number<0>{})};
// check if values were copied correctly
ck::static_for<0, size / 2, 1>{}([&](auto i) {
ASSERT_EQ(left_vec.template AsType<bf8_t>()(Number<i>{}),
ck::type_convert<bf8_t>(test_vec.at(i)));
});
}
#endif
@@ -7,154 +7,171 @@
using ck::f8_convert_rne; using ck::f8_convert_rne;
using ck::f8_convert_sr; using ck::f8_convert_sr;
using ck::f8_t; using ck::f8_fnuz_t;
using ck::half_t; using ck::half_t;
using ck::type_convert; using ck::type_convert;
TEST(FP8, NumericLimits) TEST(FP8FNUZ, NumericLimits)
{ {
// constants given for negative zero nan mode // constants given for negative zero nan mode
EXPECT_EQ(ck::NumericLimits<f8_t>::Min(), type_convert<f8_t>(0x08)); EXPECT_EQ(ck::NumericLimits<f8_fnuz_t>::Min(), type_convert<f8_fnuz_t>(0x08));
EXPECT_EQ(ck::NumericLimits<f8_t>::Max(), type_convert<f8_t>(0x7F)); EXPECT_EQ(ck::NumericLimits<f8_fnuz_t>::Max(), type_convert<f8_fnuz_t>(0x7F));
EXPECT_EQ(ck::NumericLimits<f8_t>::Lowest(), type_convert<f8_t>(0xFF)); EXPECT_EQ(ck::NumericLimits<f8_fnuz_t>::Lowest(), type_convert<f8_fnuz_t>(0xFF));
EXPECT_EQ(ck::NumericLimits<f8_t>::QuietNaN(), type_convert<f8_t>(0x80)); EXPECT_EQ(ck::NumericLimits<f8_fnuz_t>::QuietNaN(), type_convert<f8_fnuz_t>(0x80));
} }
TEST(FP8, ConvertFP32Nearest) TEST(FP8FNUZ, ConvertFP32Nearest)
{ {
// fix the tolerance value // fix the tolerance value
float abs_tol = 1e-6; float abs_tol = 1e-6;
// convert 0 float to fp8 and back, check if holds // convert 0 float to fp8 and back, check if holds
ASSERT_NEAR(0.0f, type_convert<float>(f8_convert_rne<f8_t>(0.0f)), abs_tol); ASSERT_NEAR(0.0f, type_convert<float>(f8_convert_rne<f8_fnuz_t>(0.0f)), abs_tol);
// don't run the next test on gfx11 devices // don't run the next test on gfx11 devices
#ifndef CK_SKIP_FLAKY_F8_TEST #ifndef CK_SKIP_FLAKY_F8_TEST
// convert minimal float to fp8 and back, check if holds // convert minimal float to fp8 and back, check if holds
ASSERT_NEAR(std::numeric_limits<float>::min(), ASSERT_NEAR(std::numeric_limits<float>::min(),
type_convert<float>(f8_convert_rne<f8_t>(std::numeric_limits<float>::min())), type_convert<float>(f8_convert_rne<f8_fnuz_t>(std::numeric_limits<float>::min())),
abs_tol); abs_tol);
#endif #endif
// convert maximal f8_t to float and check if equal to 240.0
ASSERT_NEAR(240.0f, type_convert<float>(f8_convert_rne<f8_t>(240.0f)), abs_tol); const auto max_f8_t_float = type_convert<float>(ck::NumericLimits<f8_fnuz_t>::Max());
// convert maximal float to fp8 and back, check if clipped to 240.0 // convert maximal f8_fnuz_t to float and check if equal to fp8 max
ASSERT_NEAR(240.0f, ASSERT_NEAR(
type_convert<float>(f8_convert_rne<f8_t>(std::numeric_limits<float>::max())), max_f8_t_float, type_convert<float>(f8_convert_rne<f8_fnuz_t>(max_f8_t_float)), abs_tol);
// XXX: FNUZ f8_convert_rne behavior is inconsistent.
// Clipping large values to fp8 max (saturation to finite) contradicts converting inf float to
// fp8 qNAN (no saturation).
// convert maximal float to fp8 and back, check if clipped to fp8 max
ASSERT_NEAR(max_f8_t_float,
type_convert<float>(f8_convert_rne<f8_fnuz_t>(std::numeric_limits<float>::max())),
abs_tol); abs_tol);
// convert inf float to f8_t and check if it is qNan // convert inf float to f8_fnuz_t and check if it is qNan
ASSERT_NEAR(type_convert<f8_t>(0x80), ASSERT_NEAR(ck::NumericLimits<f8_fnuz_t>::QuietNaN(),
f8_convert_rne<f8_t>(std::numeric_limits<float>::infinity()), f8_convert_rne<f8_fnuz_t>(std::numeric_limits<float>::infinity()),
abs_tol); abs_tol);
// positive norm float value to fp8 and back, check if holds // positive norm float value to fp8 and back, check if holds
float pos_float = 0.017578125f; float pos_float = 0.017578125f;
ASSERT_NEAR(pos_float, type_convert<float>(f8_convert_rne<f8_t>(pos_float)), abs_tol); ASSERT_NEAR(pos_float, type_convert<float>(f8_convert_rne<f8_fnuz_t>(pos_float)), abs_tol);
// negative norm float value to fp8 and back, check if holds // negative norm float value to fp8 and back, check if holds
float neg_float = -0.015625f; float neg_float = -0.015625f;
ASSERT_NEAR(neg_float, type_convert<float>(f8_convert_rne<f8_t>(neg_float)), abs_tol); ASSERT_NEAR(neg_float, type_convert<float>(f8_convert_rne<f8_fnuz_t>(neg_float)), abs_tol);
// positive subnorm float value to fp8 and back, check if holds // positive subnorm float value to fp8 and back, check if holds
pos_float = 0.00390625f; pos_float = 0.00390625f;
ASSERT_NEAR(pos_float, type_convert<float>(f8_convert_rne<f8_t>(pos_float)), abs_tol); ASSERT_NEAR(pos_float, type_convert<float>(f8_convert_rne<f8_fnuz_t>(pos_float)), abs_tol);
// negative subnorm float value to fp8 and back, check if holds // negative subnorm float value to fp8 and back, check if holds
neg_float = -0.001953125f; neg_float = -0.001953125f;
ASSERT_NEAR(neg_float, type_convert<float>(f8_convert_rne<f8_t>(neg_float)), abs_tol); ASSERT_NEAR(neg_float, type_convert<float>(f8_convert_rne<f8_fnuz_t>(neg_float)), abs_tol);
} }
TEST(FP8, ConvertFP32Stochastic) TEST(FP8FNUZ, ConvertFP32Stochastic)
{ {
// fix the tolerance value // fix the tolerance value
float abs_tol = 1e-6; float abs_tol = 1e-6;
// convert 0 float to fp8 and back, check if holds // convert 0 float to fp8 and back, check if holds
ASSERT_NEAR(0.0f, type_convert<float>(f8_convert_sr<f8_t>(0.0f)), abs_tol); ASSERT_NEAR(0.0f, type_convert<float>(f8_convert_sr<f8_fnuz_t>(0.0f)), abs_tol);
// convert minimal float to fp8 and back, check if holds // convert minimal float to fp8 and back, check if holds
ASSERT_NEAR(std::numeric_limits<float>::min(), ASSERT_NEAR(std::numeric_limits<float>::min(),
type_convert<float>(f8_convert_sr<f8_t>(std::numeric_limits<float>::min())), type_convert<float>(f8_convert_sr<f8_fnuz_t>(std::numeric_limits<float>::min())),
abs_tol); abs_tol);
// convert maximal f8_t to float and check if equal to 240.0
ASSERT_NEAR(240.0f, type_convert<float>(f8_convert_sr<f8_t>(240.0f)), abs_tol); const auto max_f8_t_float = type_convert<float>(ck::NumericLimits<f8_fnuz_t>::Max());
// convert maximal float to fp8 and back, check if clipped to 240.0 // convert maximal f8_fnuz_t to float and check if equal to fp8 max
ASSERT_NEAR(240.0f, ASSERT_NEAR(
type_convert<float>(f8_convert_sr<f8_t>(std::numeric_limits<float>::max())), max_f8_t_float, type_convert<float>(f8_convert_sr<f8_fnuz_t>(max_f8_t_float)), abs_tol);
// convert maximal float to fp8 and back, check if clipped to fp8 max
ASSERT_NEAR(max_f8_t_float,
type_convert<float>(f8_convert_sr<f8_fnuz_t>(std::numeric_limits<float>::max())),
abs_tol); abs_tol);
// convert inf float to f8_t and check if it is qNan // convert inf float to f8_fnuz_t and check if it is qNan
ASSERT_NEAR(type_convert<f8_t>(0x80), ASSERT_NEAR(ck::NumericLimits<f8_fnuz_t>::QuietNaN(),
f8_convert_sr<f8_t>(std::numeric_limits<float>::infinity()), f8_convert_sr<f8_fnuz_t>(std::numeric_limits<float>::infinity()),
abs_tol); abs_tol);
// positive norm float value to fp8 and back, check if holds // positive norm float value to fp8 and back, check if holds
float pos_float = 0.017578125f; float pos_float = 0.017578125f;
ASSERT_NEAR(pos_float, type_convert<float>(f8_convert_sr<f8_t>(pos_float)), abs_tol); ASSERT_NEAR(pos_float, type_convert<float>(f8_convert_sr<f8_fnuz_t>(pos_float)), abs_tol);
// negative norm float value to fp8 and back, check if holds // negative norm float value to fp8 and back, check if holds
float neg_float = -0.015625f; float neg_float = -0.015625f;
ASSERT_NEAR(neg_float, type_convert<float>(f8_convert_sr<f8_t>(neg_float)), abs_tol); ASSERT_NEAR(neg_float, type_convert<float>(f8_convert_sr<f8_fnuz_t>(neg_float)), abs_tol);
// positive subnorm float value to fp8 and back, check if holds // positive subnorm float value to fp8 and back, check if holds
pos_float = 0.00390625f; pos_float = 0.00390625f;
ASSERT_NEAR(pos_float, type_convert<float>(f8_convert_sr<f8_t>(pos_float)), abs_tol); ASSERT_NEAR(pos_float, type_convert<float>(f8_convert_sr<f8_fnuz_t>(pos_float)), abs_tol);
// negative subnorm float value to fp8 and back, check if holds // negative subnorm float value to fp8 and back, check if holds
neg_float = -0.001953125f; neg_float = -0.001953125f;
ASSERT_NEAR(neg_float, type_convert<float>(f8_convert_sr<f8_t>(neg_float)), abs_tol); ASSERT_NEAR(neg_float, type_convert<float>(f8_convert_sr<f8_fnuz_t>(neg_float)), abs_tol);
} }
TEST(FP8, ConvertFP16Nearest) TEST(FP8FNUZ, ConvertFP16Nearest)
{ {
// fix the tolerance value // fix the tolerance value
float abs_tol = 1e-3; float abs_tol = 1e-3;
// convert 0 fp16 to fp8 and back, check if holds // convert 0 fp16 to fp8 and back, check if holds
ASSERT_NEAR(half_t{0.0}, type_convert<half_t>(f8_convert_rne<f8_t>(half_t{0.0})), abs_tol); ASSERT_NEAR(half_t{0.0}, type_convert<half_t>(f8_convert_rne<f8_fnuz_t>(half_t{0.0})), abs_tol);
// convert minimal fp16 to fp8 and back, check if holds // convert minimal fp16 to fp8 and back, check if holds
ASSERT_NEAR(ck::NumericLimits<half_t>::Min(), ASSERT_NEAR(ck::NumericLimits<half_t>::Min(),
type_convert<half_t>(f8_convert_rne<f8_t>(ck::NumericLimits<half_t>::Min())), type_convert<half_t>(f8_convert_rne<f8_fnuz_t>(ck::NumericLimits<half_t>::Min())),
abs_tol); abs_tol);
// convert maximal f8_t to fp16 and check if equal to 240.0
ASSERT_NEAR(half_t{240.0}, type_convert<half_t>(f8_convert_rne<f8_t>(half_t{240.0})), abs_tol); const auto max_f8_t_half = type_convert<half_t>(ck::NumericLimits<f8_fnuz_t>::Max());
// convert maximal fp16 to fp8 and back, check if clipped to 240.0 // convert maximal f8_fnuz_t to fp16 and check if equal to fp8 max
ASSERT_NEAR(half_t{240.0}, ASSERT_NEAR(
type_convert<half_t>(f8_convert_rne<f8_t>(ck::NumericLimits<half_t>::Max())), max_f8_t_half, type_convert<half_t>(f8_convert_rne<f8_fnuz_t>(max_f8_t_half)), abs_tol);
// convert maximal fp16 to fp8 and back, check if clipped to fp8 max
ASSERT_NEAR(max_f8_t_half,
type_convert<half_t>(f8_convert_rne<f8_fnuz_t>(ck::NumericLimits<half_t>::Max())),
abs_tol); abs_tol);
// convert QuietNaN fp16 to f8_t and check if it is QuietNaN // convert QuietNaN fp16 to f8_fnuz_t and check if it is QuietNaN
ASSERT_NEAR(type_convert<f8_t>(0x80), ASSERT_NEAR(ck::NumericLimits<f8_fnuz_t>::QuietNaN(),
f8_convert_rne<f8_t>(ck::NumericLimits<half_t>::QuietNaN()), f8_convert_rne<f8_fnuz_t>(ck::NumericLimits<half_t>::QuietNaN()),
abs_tol); abs_tol);
// positive norm fp16 value to fp8 and back, check if holds // positive norm fp16 value to fp8 and back, check if holds
half_t pos_half = half_t{0.017578125}; half_t pos_half = half_t{0.017578125};
ASSERT_NEAR(pos_half, type_convert<half_t>(f8_convert_rne<f8_t>(pos_half)), abs_tol); ASSERT_NEAR(pos_half, type_convert<half_t>(f8_convert_rne<f8_fnuz_t>(pos_half)), abs_tol);
// negative norm fp16 value to fp8 and back, check if holds // negative norm fp16 value to fp8 and back, check if holds
half_t neg_half = half_t{-0.015625}; half_t neg_half = half_t{-0.015625};
ASSERT_NEAR(neg_half, type_convert<half_t>(f8_convert_rne<f8_t>(neg_half)), abs_tol); ASSERT_NEAR(neg_half, type_convert<half_t>(f8_convert_rne<f8_fnuz_t>(neg_half)), abs_tol);
// positive subnorm fp16 value to fp8 and back, check if holds // positive subnorm fp16 value to fp8 and back, check if holds
pos_half = half_t{0.00390625}; pos_half = half_t{0.00390625};
ASSERT_NEAR(pos_half, type_convert<half_t>(f8_convert_rne<f8_t>(pos_half)), abs_tol); ASSERT_NEAR(pos_half, type_convert<half_t>(f8_convert_rne<f8_fnuz_t>(pos_half)), abs_tol);
// negative subnorm fp16 value to fp8 and back, check if holds // negative subnorm fp16 value to fp8 and back, check if holds
neg_half = half_t{-0.001953125}; neg_half = half_t{-0.001953125};
ASSERT_NEAR(neg_half, type_convert<half_t>(f8_convert_rne<f8_t>(neg_half)), abs_tol); ASSERT_NEAR(neg_half, type_convert<half_t>(f8_convert_rne<f8_fnuz_t>(neg_half)), abs_tol);
} }
TEST(FP8, ConvertFP16Stochastic) TEST(FP8FNUZ, ConvertFP16Stochastic)
{ {
// fix the tolerance value // fix the tolerance value
float abs_tol = 1e-3; float abs_tol = 1e-3;
// convert 0 fp16 to fp8 and back, check if holds // convert 0 fp16 to fp8 and back, check if holds
ASSERT_NEAR(half_t{0.0}, type_convert<half_t>(f8_convert_sr<f8_t>(half_t{0.0})), abs_tol); ASSERT_NEAR(half_t{0.0}, type_convert<half_t>(f8_convert_sr<f8_fnuz_t>(half_t{0.0})), abs_tol);
// convert minimal fp16 to fp8 and back, check if holds // convert minimal fp16 to fp8 and back, check if holds
ASSERT_NEAR(ck::NumericLimits<half_t>::Min(), ASSERT_NEAR(ck::NumericLimits<half_t>::Min(),
type_convert<half_t>(f8_convert_sr<f8_t>(ck::NumericLimits<half_t>::Min())), type_convert<half_t>(f8_convert_sr<f8_fnuz_t>(ck::NumericLimits<half_t>::Min())),
abs_tol); abs_tol);
// convert maximal f8_t to fp16 and check if equal to 240.0
ASSERT_NEAR(half_t{240.0}, type_convert<half_t>(f8_convert_sr<f8_t>(half_t{240.0})), abs_tol); const auto max_f8_t_half = type_convert<half_t>(ck::NumericLimits<f8_fnuz_t>::Max());
// convert maximal fp16 to fp8 and back, check if clipped to 240.0 // convert maximal f8_fnuz_t to fp16 and check if equal to fp8 max
ASSERT_NEAR(half_t{240.0}, ASSERT_NEAR(
type_convert<half_t>(f8_convert_sr<f8_t>(ck::NumericLimits<half_t>::Max())), max_f8_t_half, type_convert<half_t>(f8_convert_sr<f8_fnuz_t>(max_f8_t_half)), abs_tol);
// convert maximal fp16 to fp8 and back, check if clipped to fp8 max
ASSERT_NEAR(max_f8_t_half,
type_convert<half_t>(f8_convert_sr<f8_fnuz_t>(ck::NumericLimits<half_t>::Max())),
abs_tol); abs_tol);
// convert QuietNaN fp16 to f8_t and check if it is QuietNaN // convert QuietNaN fp16 to f8_fnuz_t and check if it is QuietNaN
ASSERT_NEAR(type_convert<f8_t>(0x80), ASSERT_NEAR(ck::NumericLimits<f8_fnuz_t>::QuietNaN(),
f8_convert_sr<f8_t>(ck::NumericLimits<half_t>::QuietNaN()), f8_convert_sr<f8_fnuz_t>(ck::NumericLimits<half_t>::QuietNaN()),
abs_tol); abs_tol);
// positive norm fp16 value to fp8 and back, check if holds // positive norm fp16 value to fp8 and back, check if holds
half_t pos_half = half_t{0.017578125}; half_t pos_half = half_t{0.017578125};
ASSERT_NEAR(pos_half, type_convert<half_t>(f8_convert_sr<f8_t>(pos_half)), abs_tol); ASSERT_NEAR(pos_half, type_convert<half_t>(f8_convert_sr<f8_fnuz_t>(pos_half)), abs_tol);
// negative norm fp16 value to fp8 and back, check if holds // negative norm fp16 value to fp8 and back, check if holds
half_t neg_half = half_t{-0.015625}; half_t neg_half = half_t{-0.015625};
ASSERT_NEAR(neg_half, type_convert<half_t>(f8_convert_sr<f8_t>(neg_half)), abs_tol); ASSERT_NEAR(neg_half, type_convert<half_t>(f8_convert_sr<f8_fnuz_t>(neg_half)), abs_tol);
// positive subnorm fp16 value to fp8 and back, check if holds // positive subnorm fp16 value to fp8 and back, check if holds
pos_half = half_t{0.00390625}; pos_half = half_t{0.00390625};
ASSERT_NEAR(pos_half, type_convert<half_t>(f8_convert_sr<f8_t>(pos_half)), abs_tol); ASSERT_NEAR(pos_half, type_convert<half_t>(f8_convert_sr<f8_fnuz_t>(pos_half)), abs_tol);
// negative subnorm fp16 value to fp8 and back, check if holds // negative subnorm fp16 value to fp8 and back, check if holds
neg_half = half_t{-0.001953125}; neg_half = half_t{-0.001953125};
ASSERT_NEAR(neg_half, type_convert<half_t>(f8_convert_sr<f8_t>(neg_half)), abs_tol); ASSERT_NEAR(neg_half, type_convert<half_t>(f8_convert_sr<f8_fnuz_t>(neg_half)), abs_tol);
} }
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "gtest/gtest.h"
#include "ck/utility/data_type.hpp"
#include "ck/utility/type_convert.hpp"
using ck::f8_convert_rne;
using ck::f8_convert_sr;
using ck::f8_ocp_t;
using ck::half_t;
using ck::type_convert;
TEST(FP8OCP, NumericLimits)
{
// constants given for OCP FP8
EXPECT_EQ(ck::NumericLimits<f8_ocp_t>::Min(),
type_convert<f8_ocp_t>(0x08)); // 0b00001000 = 2^-6
EXPECT_EQ(ck::NumericLimits<f8_ocp_t>::Max(), type_convert<f8_ocp_t>(0x7E)); // 0b01111110 = 448
EXPECT_EQ(ck::NumericLimits<f8_ocp_t>::Lowest(),
type_convert<f8_ocp_t>(0xFE)); // 0b11111110 = -448
EXPECT_EQ(ck::NumericLimits<f8_ocp_t>::QuietNaN().data,
type_convert<f8_ocp_t>(0x7F).data); // 0b01111111
EXPECT_FALSE(ck::NumericLimits<f8_ocp_t>::QuietNaN() ==
ck::NumericLimits<f8_ocp_t>::QuietNaN());
}
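The bit patterns checked in NumericLimits follow from the OCP E4M3 layout: 1 sign bit, 4 exponent bits with bias 7, 3 mantissa bits, no infinities, and S.1111.111 as the only NaN encoding. The decoder below is a minimal sketch for illustration only; `decode_e4m3_ocp` and `check_known_encodings` are hypothetical names, not part of CK, and the sketch does not reflect how `type_convert` is implemented.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>
#include <limits>

// Hypothetical helper (not a CK API): decode one OCP E4M3 byte into a float.
inline float decode_e4m3_ocp(std::uint8_t bits)
{
    const int sign     = (bits >> 7) & 0x1;
    const int exponent = (bits >> 3) & 0xF;
    const int mantissa = bits & 0x7;
    const float s      = sign ? -1.0f : 1.0f;

    // S.1111.111 is the only NaN encoding; E4M3 has no infinities.
    if(exponent == 0xF && mantissa == 0x7)
        return std::numeric_limits<float>::quiet_NaN();

    // Subnormals: no implicit leading one, value = (mantissa / 8) * 2^-6 = mantissa * 2^-9.
    if(exponent == 0)
        return s * std::ldexp(static_cast<float>(mantissa), -9);

    // Normals: implicit leading one, exponent bias 7.
    return s * std::ldexp(1.0f + static_cast<float>(mantissa) / 8.0f, exponent - 7);
}

// Re-derive the constants that the NumericLimits test compares against.
inline void check_known_encodings()
{
    assert(decode_e4m3_ocp(0x08) == 0.015625f);    // smallest normal, 2^-6
    assert(decode_e4m3_ocp(0x7E) == 448.0f);       // largest finite value
    assert(decode_e4m3_ocp(0xFE) == -448.0f);      // lowest finite value
    assert(decode_e4m3_ocp(0x01) == 0.001953125f); // smallest subnormal, 2^-9
    assert(std::isnan(decode_e4m3_ocp(0x7F)));     // quiet NaN
}
```

Because NaN compares unequal to itself, the test checks the NaN encoding through `.data` and separately expects `QuietNaN() == QuietNaN()` to be false.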
TEST(FP8OCP, ConvertFP32Nearest)
{
// fix the tolerance value
float abs_tol = 1e-6;
// convert 0 float to fp8 and back, check if holds
ASSERT_NEAR(0.0f, type_convert<float>(f8_convert_rne<f8_ocp_t>(0.0f)), 0.0f);
// convert the smallest normal float to fp8 and back; it underflows to zero, which is within abs_tol
ASSERT_NEAR(std::numeric_limits<float>::min(),
type_convert<float>(f8_convert_rne<f8_ocp_t>(std::numeric_limits<float>::min())),
abs_tol);
const auto max_f8_t_float = type_convert<float>(ck::NumericLimits<f8_ocp_t>::Max());
// convert maximal f8_ocp_t to float and check if equal to fp8 max
ASSERT_NEAR(
max_f8_t_float, type_convert<float>(f8_convert_rne<f8_ocp_t>(max_f8_t_float)), 0.0f);
// convert maximal float to fp8 and back, check if clipped to fp8 max (saturation to finite)
ASSERT_NEAR(max_f8_t_float,
type_convert<float>(f8_convert_rne<f8_ocp_t>(std::numeric_limits<float>::max())),
0.0f);
// convert float infinity to f8_ocp_t and check if it is max value (saturation to finite)
ASSERT_EQ(ck::NumericLimits<f8_ocp_t>::Max(),
f8_convert_rne<f8_ocp_t>(std::numeric_limits<float>::infinity()));
// positive norm float value to fp8 and back, check if holds
float pos_float = 0.017578125f;
ASSERT_NEAR(pos_float, type_convert<float>(f8_convert_rne<f8_ocp_t>(pos_float)), abs_tol);
// negative smallest-magnitude normal fp8 value (-2^-6) to fp8 and back, check if holds exactly
float neg_float = -0.015625f; //-2^-6
ASSERT_NEAR(neg_float, type_convert<float>(f8_convert_rne<f8_ocp_t>(neg_float)), 0.0f);
// positive float value that is subnormal in fp8 (2^-8) to fp8 and back, check if holds
pos_float = 0.00390625f;
ASSERT_NEAR(pos_float, type_convert<float>(f8_convert_rne<f8_ocp_t>(pos_float)), abs_tol);
// min subnorm fp8 value to fp8 and back, check if holds
neg_float = -0.001953125f; //-2^-9
ASSERT_NEAR(neg_float, type_convert<float>(f8_convert_rne<f8_ocp_t>(neg_float)), 0.0f);
// smaller than min subnorm fp8 value to fp8 must be zero
auto less_than_min_subnorm = 0.0009765625f; // 2^-10
ASSERT_EQ(0.0f, type_convert<float>(f8_convert_rne<f8_ocp_t>(less_than_min_subnorm)));
// convert quiet NaN to f8_ocp_t and check if it is quiet NaN
auto f8_nan = f8_convert_rne<f8_ocp_t>(std::numeric_limits<float>::quiet_NaN());
ASSERT_TRUE((f8_nan.data & 0x7f) == 0x7f);
}
TEST(FP8OCP, ConvertFP32Stochastic)
{
// fix the tolerance value
float abs_tol = 1e-6;
// convert 0 float to fp8 and back, check if holds
ASSERT_NEAR(0.0f, type_convert<float>(f8_convert_sr<f8_ocp_t>(0.0f)), 0.0f);
// convert minimal float to fp8 and back, check if holds
ASSERT_NEAR(std::numeric_limits<float>::min(),
type_convert<float>(f8_convert_sr<f8_ocp_t>(std::numeric_limits<float>::min())),
abs_tol);
const auto max_f8_t_float = type_convert<float>(ck::NumericLimits<f8_ocp_t>::Max());
// convert maximal f8_ocp_t to float and check if equal to fp8 max
ASSERT_NEAR(max_f8_t_float, type_convert<float>(f8_convert_sr<f8_ocp_t>(max_f8_t_float)), 0.0f);
// convert maximal float to fp8 and back, check if clipped to fp8 max (saturation to finite)
ASSERT_NEAR(max_f8_t_float,
type_convert<float>(f8_convert_sr<f8_ocp_t>(std::numeric_limits<float>::max())),
0.0f);
// convert float infinity to f8_ocp_t and check if it is max value (saturation to finite)
ASSERT_EQ(ck::NumericLimits<f8_ocp_t>::Max(),
f8_convert_sr<f8_ocp_t>(std::numeric_limits<float>::infinity()));
// positive norm float value to fp8 and back, check if holds
float pos_float = 0.017578125f;
ASSERT_NEAR(pos_float, type_convert<float>(f8_convert_sr<f8_ocp_t>(pos_float)), abs_tol);
// negative smallest-magnitude normal fp8 value (-2^-6) to fp8 and back, check if holds exactly
float neg_float = -0.015625f; //-2^-6
ASSERT_NEAR(neg_float, type_convert<float>(f8_convert_sr<f8_ocp_t>(neg_float)), 0.0f);
// positive float value that is subnormal in fp8 (2^-8) to fp8 and back, check if holds
pos_float = 0.00390625f;
ASSERT_NEAR(pos_float, type_convert<float>(f8_convert_sr<f8_ocp_t>(pos_float)), abs_tol);
// min subnorm fp8 value to fp8 and back, check if holds
constexpr auto min_subnorm_fp8 = -0.001953125f; //-2^-9
ASSERT_NEAR(
min_subnorm_fp8, type_convert<float>(f8_convert_sr<f8_ocp_t>(min_subnorm_fp8)), 0.0f);
// smaller than min subnorm fp8 value to fp8 alternates between 0 and 2^-9
auto less_than_min_subnorm = 0.0009765625f; // 2^-10
ASSERT_NEAR(
0.0f, type_convert<float>(f8_convert_sr<f8_ocp_t>(less_than_min_subnorm)), 0.001953125f);
// convert quiet NaN to f8_ocp_t and check if it is quiet NaN
auto f8_nan = f8_convert_sr<f8_ocp_t>(std::numeric_limits<float>::quiet_NaN());
ASSERT_TRUE((f8_nan.data & 0x7f) == 0x7f);
}
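The nearest-even and stochastic tests treat 2^-10 differently: it lies exactly halfway between 0 and the smallest fp8 subnormal 2^-9, so round-to-nearest-even always yields 0, while stochastic rounding may yield either neighbour. The sketch below models both on a uniform grid with spacing `step`; `round_nearest_even` and `stochastic_round` are hypothetical helpers, the random value `u` is supplied by the caller, and this is not CK's conversion code.

```cpp
#include <cassert>
#include <cmath>

// Hypothetical helper: round x to the nearest multiple of step, ties to even.
// Relies on the default floating-point rounding mode (FE_TONEAREST).
inline float round_nearest_even(float x, float step)
{
    return std::nearbyint(x / step) * step;
}

// Hypothetical helper: stochastically round x to a multiple of step.
// u is a uniform random value in [0, 1); the result rounds up with
// probability equal to the fractional distance from the lower neighbour.
inline float stochastic_round(float x, float step, float u)
{
    const float scaled = x / step;
    const float lower  = std::floor(scaled);
    const float frac   = scaled - lower;
    return (u < frac ? lower + 1.0f : lower) * step;
}
```

With `step = 0.001953125f` (2^-9) and `x = 0.0009765625f` (2^-10), the fractional part is 0.5, so the stochastic result is 2^-9 for `u < 0.5` and 0 otherwise. This is why the stochastic tests use a tolerance of 2^-9 around zero instead of requiring an exact zero, whereas values already on the grid round to themselves for every `u`.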
TEST(FP8OCP, ConvertFP16Nearest)
{
// fix the tolerance value
constexpr half_t half_t_tol = 1e-3;
constexpr half_t half_t_zero = 0.0;
// convert 0 half_t to fp8 and back, check if holds
ASSERT_NEAR(
half_t_zero, type_convert<half_t>(f8_convert_rne<f8_ocp_t>(half_t_zero)), half_t_zero);
// convert minimal half_t to fp8 and back, check if holds
ASSERT_NEAR(ck::NumericLimits<half_t>::Min(),
type_convert<half_t>(f8_convert_rne<f8_ocp_t>(ck::NumericLimits<half_t>::Min())),
half_t_tol);
const auto max_f8_t_half_t = type_convert<half_t>(ck::NumericLimits<f8_ocp_t>::Max());
// convert maximal f8_ocp_t to half_t and check if equal to fp8 max
ASSERT_NEAR(max_f8_t_half_t,
type_convert<half_t>(f8_convert_rne<f8_ocp_t>(max_f8_t_half_t)),
half_t_zero);
// convert maximal half_t to fp8 and back, check if clipped to fp8 max (saturation to finite)
ASSERT_NEAR(max_f8_t_half_t,
type_convert<half_t>(f8_convert_rne<f8_ocp_t>(ck::NumericLimits<half_t>::Max())),
half_t_zero);
// convert half_t infinity to f8_ocp_t and check if it is max value (saturation to finite)
ASSERT_EQ(
ck::NumericLimits<f8_ocp_t>::Max(),
f8_convert_rne<f8_ocp_t>(type_convert<half_t>(std::numeric_limits<float>::infinity())));
// positive norm half_t value to fp8 and back, check if holds
half_t pos_half_t{0.017578125f};
ASSERT_NEAR(pos_half_t, type_convert<half_t>(f8_convert_rne<f8_ocp_t>(pos_half_t)), half_t_tol);
// negative smallest-magnitude normal fp8 value (-2^-6) to fp8 and back, check if holds exactly
half_t neg_half_t{-0.015625f}; //-2^-6
ASSERT_NEAR(
neg_half_t, type_convert<half_t>(f8_convert_rne<f8_ocp_t>(neg_half_t)), half_t_zero);
// positive half_t value that is subnormal in fp8 (2^-8) to fp8 and back, check if holds
pos_half_t = half_t{0.00390625f};
ASSERT_NEAR(pos_half_t, type_convert<half_t>(f8_convert_rne<f8_ocp_t>(pos_half_t)), half_t_tol);
// min subnorm fp8 value to fp8 and back, check if holds
neg_half_t = half_t{-0.001953125f}; //-2^-9
ASSERT_NEAR(
neg_half_t, type_convert<half_t>(f8_convert_rne<f8_ocp_t>(neg_half_t)), half_t_zero);
// smaller than min subnorm fp8 value to fp8 must be zero
auto less_than_min_subnorm = half_t{0.0009765625f}; // 2^-10
ASSERT_EQ(half_t_zero, type_convert<half_t>(f8_convert_rne<f8_ocp_t>(less_than_min_subnorm)));
// convert quiet NaN to f8_ocp_t and check if it is quiet NaN
auto f8_nan = f8_convert_rne<f8_ocp_t>(ck::NumericLimits<half_t>::QuietNaN());
ASSERT_TRUE(ck::fp8_impl::ocp_f8_is_nan(f8_nan.data));
}
TEST(FP8OCP, ConvertFP16Stochastic)
{
// fix the tolerance value
constexpr half_t half_t_tol = 1e-3;
constexpr half_t half_t_zero = 0.0;
constexpr auto min_subnorm_fp8 = 0.001953125f; // 2^-9
// convert 0 half_t to fp8 and back, check if holds
ASSERT_NEAR(
half_t_zero, type_convert<half_t>(f8_convert_sr<f8_ocp_t>(half_t_zero)), half_t_zero);
// convert minimal half_t (6.103515625e-05) to fp8 and back
// alternates between 0 and 2^-9 (0.001953125)
ASSERT_NEAR(ck::NumericLimits<half_t>::Min(),
type_convert<half_t>(f8_convert_sr<f8_ocp_t>(ck::NumericLimits<half_t>::Min())),
type_convert<half_t>(min_subnorm_fp8));
const auto max_f8_t_half_t = type_convert<half_t>(ck::NumericLimits<f8_ocp_t>::Max());
// convert maximal f8_ocp_t to half_t and check if equal to fp8 max
ASSERT_NEAR(max_f8_t_half_t,
type_convert<half_t>(f8_convert_sr<f8_ocp_t>(max_f8_t_half_t)),
half_t_zero);
// convert maximal half_t to fp8 and back, check if clipped to fp8 max (saturation to finite)
ASSERT_NEAR(max_f8_t_half_t,
type_convert<half_t>(f8_convert_sr<f8_ocp_t>(ck::NumericLimits<half_t>::Max())),
half_t_zero);
// convert half_t infinity to f8_ocp_t and check if it is max value (saturation to finite)
ASSERT_EQ(
ck::NumericLimits<f8_ocp_t>::Max(),
f8_convert_sr<f8_ocp_t>(type_convert<half_t>(std::numeric_limits<float>::infinity())));
// positive norm half_t value to fp8 and back, check if holds
half_t pos_half_t{0.017578125f};
ASSERT_NEAR(pos_half_t, type_convert<half_t>(f8_convert_sr<f8_ocp_t>(pos_half_t)), half_t_tol);
// negative smallest-magnitude normal fp8 value (-2^-6) to fp8 and back, check if holds exactly
half_t neg_half_t{-0.015625f}; //-2^-6
ASSERT_NEAR(neg_half_t, type_convert<half_t>(f8_convert_sr<f8_ocp_t>(neg_half_t)), half_t_zero);
// positive half_t value that is subnormal in fp8 (2^-8) to fp8 and back, check if holds
pos_half_t = half_t{0.00390625f};
ASSERT_NEAR(pos_half_t, type_convert<half_t>(f8_convert_sr<f8_ocp_t>(pos_half_t)), half_t_tol);
// min subnorm fp8 value to fp8 and back, check if holds
neg_half_t = half_t{-min_subnorm_fp8}; //-2^-9
ASSERT_NEAR(neg_half_t, type_convert<half_t>(f8_convert_sr<f8_ocp_t>(neg_half_t)), half_t_zero);
// smaller than min subnorm fp8 value to fp8 alternates between 0 and 2^-9
auto less_than_min_subnorm = half_t{0.0009765625f}; // 2^-10
ASSERT_NEAR(
type_convert<float>(half_t_zero),
type_convert<float>(type_convert<half_t>(f8_convert_sr<f8_ocp_t>(less_than_min_subnorm))),
min_subnorm_fp8);
// convert quiet NaN to f8_ocp_t and check if it is quiet NaN
auto f8_nan = f8_convert_sr<f8_ocp_t>(ck::NumericLimits<half_t>::QuietNaN());
ASSERT_TRUE(ck::fp8_impl::ocp_f8_is_nan(f8_nan.data));
}
...@@ -138,7 +138,7 @@ TYPED_TEST_SUITE(AvgPool2D_BF16, AvgPool2D_BF16_Types); ...@@ -138,7 +138,7 @@ TYPED_TEST_SUITE(AvgPool2D_BF16, AvgPool2D_BF16_Types);
TYPED_TEST_SUITE(AvgPool2D_I8, AvgPool2D_I8_Types); TYPED_TEST_SUITE(AvgPool2D_I8, AvgPool2D_I8_Types);
TYPED_TEST_SUITE(AvgPool2D_F8, AvgPool2D_F8_Types); TYPED_TEST_SUITE(AvgPool2D_F8, AvgPool2D_F8_Types);
TYPED_TEST(AvgPool2D_F32, AvgPool2D_I8_Test) { this->Run(); } TYPED_TEST(AvgPool2D_F32, AvgPool2D_F32_Test) { this->Run(); }
TYPED_TEST(AvgPool2D_F16, AvgPool2D_F16_Test) { this->Run(); } TYPED_TEST(AvgPool2D_F16, AvgPool2D_F16_Test) { this->Run(); }
TYPED_TEST(AvgPool2D_BF16, AvgPool2D_BF16_Test) { this->Run(); } TYPED_TEST(AvgPool2D_BF16, AvgPool2D_BF16_Test) { this->Run(); }
TYPED_TEST(AvgPool2D_I8, AvgPool2D_I8_Test) { this->Run(); } TYPED_TEST(AvgPool2D_I8, AvgPool2D_I8_Test) { this->Run(); }
......
...@@ -143,7 +143,7 @@ TYPED_TEST_SUITE(MaxPool2D_BF16, MaxPool2D_BF16_Types); ...@@ -143,7 +143,7 @@ TYPED_TEST_SUITE(MaxPool2D_BF16, MaxPool2D_BF16_Types);
TYPED_TEST_SUITE(MaxPool2D_I8, MaxPool2D_I8_Types); TYPED_TEST_SUITE(MaxPool2D_I8, MaxPool2D_I8_Types);
TYPED_TEST_SUITE(MaxPool2D_F8, MaxPool2D_F8_Types); TYPED_TEST_SUITE(MaxPool2D_F8, MaxPool2D_F8_Types);
TYPED_TEST(MaxPool2D_F32, MaxPool2D_I8_Test) { this->Run(); } TYPED_TEST(MaxPool2D_F32, MaxPool2D_F32_Test) { this->Run(); }
TYPED_TEST(MaxPool2D_F16, MaxPool2D_F16_Test) { this->Run(); } TYPED_TEST(MaxPool2D_F16, MaxPool2D_F16_Test) { this->Run(); }
TYPED_TEST(MaxPool2D_BF16, MaxPool2D_BF16_Test) { this->Run(); } TYPED_TEST(MaxPool2D_BF16, MaxPool2D_BF16_Test) { this->Run(); }
TYPED_TEST(MaxPool2D_I8, MaxPool2D_I8_Test) { this->Run(); } TYPED_TEST(MaxPool2D_I8, MaxPool2D_I8_Test) { this->Run(); }
......