Unverified commit 829e0eb3, authored by Illia Silin and committed by GitHub

Merge pull request #138 from ROCm/merge_from_public

Merge from public
parents 3d61f89a 1111bc69
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_outelementop_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_convscale.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_conv3d_fwd_xdl_combconvscale_ndhwgc_gkzyxc_ndhwgk_f8_f8_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
ck::Tuple<>,
NDHWGK,
F8,
F8,
ck::Tuple<>,
F32,
PassThrough,
PassThrough,
CombConvScale,
F8,
F8>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_conv_fwd_xdl_outelementop_f8_f8_f32_instances<3,
NDHWGC,
GKZYXC,
ck::Tuple<>,
NDHWGK,
ConvFwdDefault,
CombConvScale>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_xdl_outelementop_f8_f8_f32_instances<3,
NDHWGC,
GKZYXC,
ck::Tuple<>,
NDHWGK,
ConvFwd1x1P0,
CombConvScale>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_xdl_outelementop_f8_f8_f32_instances<3,
NDHWGC,
GKZYXC,
ck::Tuple<>,
NDHWGK,
ConvFwd1x1S1P0,
CombConvScale>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
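
For orientation, here is a minimal consumer-side sketch (not part of this commit) of how the instance list populated above might be enumerated. It reuses the operation type from the declaration above and assumes the layout/type aliases (NDHWGC, GKZYXC, NDHWGK, F8, F32, PassThrough, CombConvScale) are exposed by the same instance headers this file includes; argument construction and kernel launch are omitted.

// Hedged sketch: count the f8/f8 -> f32 CombConvScale forward-conv instances registered
// by the function defined above. Assumption: the aliases used in the signature are made
// visible through the included instance header, as in the instance file above.
#include <iostream>
#include <memory>
#include <vector>
#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_convscale.hpp"

int main()
{
    using namespace ck::tensor_operation::device;
    using namespace ck::tensor_operation::device::instance;

    // Same operation type as in the declaration above: 3-D grouped forward convolution,
    // NDHWGC input / GKZYXC weight / NDHWGK output, f8 inputs, f32 output, with the
    // combined ConvScale output element-wise op.
    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
                                                                NDHWGC,
                                                                GKZYXC,
                                                                ck::Tuple<>,
                                                                NDHWGK,
                                                                F8,
                                                                F8,
                                                                ck::Tuple<>,
                                                                F32,
                                                                PassThrough,
                                                                PassThrough,
                                                                CombConvScale,
                                                                F8,
                                                                F8>>>
        instances;

    add_device_grouped_conv3d_fwd_xdl_combconvscale_ndhwgc_gkzyxc_ndhwgk_f8_f8_f32_instances(
        instances);

    // One entry per kernel specialization added above (ConvFwdDefault, ConvFwd1x1P0, ConvFwd1x1S1P0).
    std::cout << "registered instances: " << instances.size() << std::endl;
    return 0;
}
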
 # ONLY XDL_KERNELS
 set(GROUPED_CONV3D_FWD_CONVSCALE_RELU
-    xdl/device_grouped_conv3d_fwd_xdl_convscale_relu_ndhwgc_gkzyxc_ndhwgk_f8_instance.cpp)
+    xdl/device_grouped_conv3d_fwd_xdl_convscale_relu_ndhwgc_gkzyxc_ndhwgk_f8_instance.cpp
+    xdl/device_grouped_conv3d_fwd_xdl_combconvscale_relu_ndhwgc_gkzyxc_ndhwgk_f8_f8_f32_instance.cpp)
 add_instance_library(device_grouped_conv3d_fwd_convscale_relu_instance ${GROUPED_CONV3D_FWD_CONVSCALE_RELU})
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_outelementop_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_convscale_relu.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_conv3d_fwd_xdl_combconvscale_relu_ndhwgc_gkzyxc_ndhwgk_f8_f8_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
ck::Tuple<>,
NDHWGK,
F8,
F8,
ck::Tuple<>,
F32,
PassThrough,
PassThrough,
CombConvScaleRelu,
F8,
F8>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_conv_fwd_xdl_outelementop_f8_f8_f32_instances<3,
NDHWGC,
GKZYXC,
ck::Tuple<>,
NDHWGK,
ConvFwdDefault,
CombConvScaleRelu>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_xdl_outelementop_f8_f8_f32_instances<3,
NDHWGC,
GKZYXC,
ck::Tuple<>,
NDHWGK,
ConvFwd1x1P0,
CombConvScaleRelu>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_xdl_outelementop_f8_f8_f32_instances<3,
NDHWGC,
GKZYXC,
ck::Tuple<>,
NDHWGK,
ConvFwd1x1S1P0,
CombConvScaleRelu>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
@@ -3,16 +3,13 @@
 #include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_outelementop_instance.hpp"
 #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
-#include "ck/tensor_operation/gpu/element/combined_element_wise_operation.hpp"
-#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
+#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_convscale_relu.hpp"
 namespace ck {
 namespace tensor_operation {
 namespace device {
 namespace instance {
-using ConvScaleRelu = ck::tensor_operation::element_wise::ConvScaleRelu;
 void add_device_grouped_conv3d_fwd_xdl_convscale_relu_ndhwgc_gkzyxc_ndhwgk_f8_instances(
 std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
 NDHWGC,
@@ -57,55 +54,6 @@ void add_device_grouped_conv3d_fwd_xdl_convscale_relu_ndhwgc_gkzyxc_ndhwgk_f8_in
 ConvFwd1x1S1P0,
 ConvScaleRelu>{});
 }
-namespace ew = ck::tensor_operation::element_wise;
-using CombConvScaleRelu = ew::UnaryCombinedOp<ew::Scale, ew::Scale, ew::Relu>;
-void add_device_grouped_conv3d_fwd_xdl_combconvscale_relu_ndhwgc_gkzyxc_ndhwgk_f8_f8_f32_instances(
-std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
-NDHWGC,
-GKZYXC,
-ck::Tuple<>,
-NDHWGK,
-F8,
-F8,
-ck::Tuple<>,
-F32,
-PassThrough,
-PassThrough,
-CombConvScaleRelu,
-F8,
-F8>>>& instances)
-{
-add_device_operation_instances(
-instances,
-device_grouped_conv_fwd_xdl_outelementop_f8_f8_f32_instances<3,
-NDHWGC,
-GKZYXC,
-ck::Tuple<>,
-NDHWGK,
-ConvFwdDefault,
-CombConvScaleRelu>{});
-add_device_operation_instances(
-instances,
-device_grouped_conv_fwd_xdl_outelementop_f8_f8_f32_instances<3,
-NDHWGC,
-GKZYXC,
-ck::Tuple<>,
-NDHWGK,
-ConvFwd1x1P0,
-CombConvScaleRelu>{});
-add_device_operation_instances(
-instances,
-device_grouped_conv_fwd_xdl_outelementop_f8_f8_f32_instances<3,
-NDHWGC,
-GKZYXC,
-ck::Tuple<>,
-NDHWGK,
-ConvFwd1x1S1P0,
-CombConvScaleRelu>{});
-}
 } // namespace instance
 } // namespace device
 } // namespace tensor_operation
...
@@ -250,18 +250,14 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification,
         {
             LogRangeAsType<float>(std::cout << "output : ", output.mData, ",")
                 << std::endl;
-            ;
             LogRangeAsType<float>(
                 std::cout << "weight (device): ", weight_device_result.mData, ",")
                 << std::endl;
-            ;
             LogRangeAsType<float>(
                 std::cout << "weight (host): ", weight_host_result.mData, ",")
                 << std::endl;
-            ;
             LogRangeAsType<float>(std::cout << "input: ", input.mData, ",")
                 << std::endl;
-            ;
         }
     }
 }
...
@@ -26,7 +26,7 @@ def run_ck_profiler_cmd(cmd):
 def parse_data_type(args):
     if args.data_type == "fp32":
         if args.ck_profier_op == "grouped_conv_bwd_weight" or \
-           args.ck_profier_op == "grouped_conv_bwd_weight" or \
+           args.ck_profier_op == "grouped_conv_bwd_data" or \
            args.ck_profier_op == "grouped_conv_fwd":
             args.data_type = 0
     if args.data_type == "fp16":
...
+if (GPU_TARGETS)
+    if (GPU_TARGETS MATCHES "gfx10" OR GPU_TARGETS MATCHES "gfx11" OR GPU_TARGETS MATCHES "gfx12")
+        add_definitions(-DCK_SKIP_FLAKY_F8_TEST)
+        set(CK_SKIP_FLAKY_F8_TEST "ON")
+    endif()
+else()
+    add_definitions(-DCK_SKIP_FLAKY_F8_TEST)
+    set(CK_SKIP_FLAKY_F8_TEST "ON")
+endif()
 if (USE_BITINT_EXTENSION_INT4)
     add_gtest_executable(test_int4 test_int4.cpp)
     if(result EQUAL 0)
...
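
A hedged, self-contained sketch of what the new CK_SKIP_FLAKY_F8_TEST definition is meant to control: when CMake adds -DCK_SKIP_FLAKY_F8_TEST (gfx10/gfx11/gfx12 targets, or no GPU_TARGETS set), the denormal-range round-trip assertion is compiled out. The test name below is hypothetical; the real guards added by this commit appear in the test_bf8.cpp and test_fp8.cpp diffs that follow.

// Hedged sketch of the guard pattern enabled by -DCK_SKIP_FLAKY_F8_TEST
// (illustrative test name; links against gtest_main like the existing tests).
#include <limits>
#include "gtest/gtest.h"
#include "ck/utility/data_type.hpp"
#include "ck/utility/type_convert.hpp"

using ck::f8_convert_rne;
using ck::f8_t;
using ck::type_convert;

TEST(FP8, SkipFlakyGuardSketch)
{
    float abs_tol = 1e-6;
#ifndef CK_SKIP_FLAKY_F8_TEST
    // Only compiled when the flaky-f8 guard is not defined (i.e. not a gfx10/11/12 build).
    ASSERT_NEAR(std::numeric_limits<float>::min(),
                type_convert<float>(f8_convert_rne<f8_t>(std::numeric_limits<float>::min())),
                abs_tol);
#endif
    // Always compiled: zero round-trips exactly regardless of target.
    ASSERT_NEAR(0.0f, type_convert<float>(f8_convert_rne<f8_t>(0.0f)), abs_tol);
}
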
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
 #include "gtest/gtest.h"
 #include "ck/utility/data_type.hpp"
 #include "ck/utility/type_convert.hpp"
 using ck::bf8_t;
+using ck::f8_convert_rne;
 using ck::f8_convert_sr;
 using ck::half_t;
 using ck::type_convert;
@@ -24,33 +25,36 @@ TEST(BF8, ConvertFP32Nearest)
     // fix the tolerance value
     float abs_tol = 1e-6;
     // convert 0 float to bf8 and back, check if holds
-    ASSERT_NEAR(0.0f, type_convert<float>(type_convert<bf8_t>(0.0f)), abs_tol);
+    ASSERT_NEAR(0.0f, type_convert<float>(f8_convert_rne<bf8_t>(0.0f)), abs_tol);
+    // don't run the next test on gfx11 devices
+#ifndef CK_SKIP_FLAKY_F8_TEST
     // convert minimal float to bf8 and back, check if holds
     ASSERT_NEAR(std::numeric_limits<float>::min(),
-                type_convert<float>(type_convert<bf8_t>(std::numeric_limits<float>::min())),
+                type_convert<float>(f8_convert_rne<bf8_t>(std::numeric_limits<float>::min())),
                 abs_tol);
+#endif
     // convert maximal bf8_t to float and check if equal to 57344.0
-    ASSERT_NEAR(57344.0f, type_convert<float>(type_convert<bf8_t>(57344.0f)), abs_tol);
+    ASSERT_NEAR(57344.0f, type_convert<float>(f8_convert_rne<bf8_t>(57344.0f)), abs_tol);
     // convert maximal float to bf8 and back, check if clipped to 57344.0
     ASSERT_NEAR(57344.0f,
-                type_convert<float>(type_convert<bf8_t>(std::numeric_limits<float>::max())),
+                type_convert<float>(f8_convert_rne<bf8_t>(std::numeric_limits<float>::max())),
                 abs_tol);
     // convert inf float to bf8_t and check if it is qNan
     ASSERT_NEAR(type_convert<bf8_t>(0x80),
-                type_convert<bf8_t>(std::numeric_limits<float>::infinity()),
+                f8_convert_rne<bf8_t>(std::numeric_limits<float>::infinity()),
                 abs_tol);
     // positive norm float value to bf8 and back, check if holds
     float pos_float = 0.0000762939f;
-    ASSERT_NEAR(pos_float, type_convert<float>(type_convert<bf8_t>(pos_float)), abs_tol);
+    ASSERT_NEAR(pos_float, type_convert<float>(f8_convert_rne<bf8_t>(pos_float)), abs_tol);
     // negative norm float value to bf8 and back, check if holds
     float neg_float = -0.0000610351f;
-    ASSERT_NEAR(neg_float, type_convert<float>(type_convert<bf8_t>(neg_float)), abs_tol);
+    ASSERT_NEAR(neg_float, type_convert<float>(f8_convert_rne<bf8_t>(neg_float)), abs_tol);
     // positive subnorm float value to bf8 and back, check if holds
     pos_float = 0.0000305175f;
-    ASSERT_NEAR(pos_float, type_convert<float>(type_convert<bf8_t>(pos_float)), abs_tol);
+    ASSERT_NEAR(pos_float, type_convert<float>(f8_convert_rne<bf8_t>(pos_float)), abs_tol);
     // negative subnorm float value to bf8 and back, check if holds
     neg_float = -0.0000152587f;
-    ASSERT_NEAR(neg_float, type_convert<float>(type_convert<bf8_t>(neg_float)), abs_tol);
+    ASSERT_NEAR(neg_float, type_convert<float>(f8_convert_rne<bf8_t>(neg_float)), abs_tol);
 }
 TEST(BF8, ConvertFP32Stochastic)
@@ -92,34 +96,34 @@ TEST(BF8, ConvertFP16Nearest)
     // fix the tolerance value
     float abs_tol = 1e-3;
     // convert 0 fp16 to bf8 and back, check if holds
-    ASSERT_NEAR(half_t{0.0}, type_convert<half_t>(type_convert<bf8_t>(half_t{0.0})), abs_tol);
+    ASSERT_NEAR(half_t{0.0}, type_convert<half_t>(f8_convert_rne<bf8_t>(half_t{0.0})), abs_tol);
     // convert minimal fp16 to bf8 and back, check if holds
     ASSERT_NEAR(ck::NumericLimits<half_t>::Min(),
-                type_convert<half_t>(type_convert<bf8_t>(ck::NumericLimits<half_t>::Min())),
+                type_convert<half_t>(f8_convert_rne<bf8_t>(ck::NumericLimits<half_t>::Min())),
                 abs_tol);
     // convert maximal bf8_t to fp16 and check if equal to 57344.0
     ASSERT_NEAR(
-        half_t{57344.0}, type_convert<half_t>(type_convert<bf8_t>(half_t{57344.0})), abs_tol);
+        half_t{57344.0}, type_convert<half_t>(f8_convert_rne<bf8_t>(half_t{57344.0})), abs_tol);
     // convert maximal fp16 to bf8 and back, check if clipped to 57344.0
     ASSERT_NEAR(half_t{57344.0},
-                type_convert<half_t>(type_convert<bf8_t>(ck::NumericLimits<half_t>::Max())),
+                type_convert<half_t>(f8_convert_rne<bf8_t>(ck::NumericLimits<half_t>::Max())),
                 abs_tol);
     // convert QuietNaN fp16 to bf8_t and check if it is QuietNaN
     ASSERT_NEAR(type_convert<bf8_t>(0x80),
-                type_convert<bf8_t>(ck::NumericLimits<half_t>::QuietNaN()),
+                f8_convert_rne<bf8_t>(ck::NumericLimits<half_t>::QuietNaN()),
                 abs_tol);
     // positive norm fp16 value to bf8 and back, check if holds
     half_t pos_half = half_t{0.0000762939};
-    ASSERT_NEAR(pos_half, type_convert<half_t>(type_convert<bf8_t>(pos_half)), abs_tol);
+    ASSERT_NEAR(pos_half, type_convert<half_t>(f8_convert_rne<bf8_t>(pos_half)), abs_tol);
     // negative norm fp16 value to bf8 and back, check if holds
     half_t neg_half = half_t{-0.0000610351};
-    ASSERT_NEAR(neg_half, type_convert<half_t>(type_convert<bf8_t>(neg_half)), abs_tol);
+    ASSERT_NEAR(neg_half, type_convert<half_t>(f8_convert_rne<bf8_t>(neg_half)), abs_tol);
     // positive subnorm fp16 value to bf8 and back, check if holds
     pos_half = half_t{0.0000305175};
-    ASSERT_NEAR(pos_half, type_convert<half_t>(type_convert<bf8_t>(pos_half)), abs_tol);
+    ASSERT_NEAR(pos_half, type_convert<half_t>(f8_convert_rne<bf8_t>(pos_half)), abs_tol);
     // negative subnorm fp16 value to bf8 and back, check if holds
     neg_half = half_t{-0.0000152587};
-    ASSERT_NEAR(neg_half, type_convert<half_t>(type_convert<bf8_t>(neg_half)), abs_tol);
+    ASSERT_NEAR(neg_half, type_convert<half_t>(f8_convert_rne<bf8_t>(neg_half)), abs_tol);
 }
 TEST(BF8, ConvertFP16Stochastic)
...
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
 #include "gtest/gtest.h"
 #include "ck/utility/data_type.hpp"
 #include "ck/utility/type_convert.hpp"
+using ck::f8_convert_rne;
 using ck::f8_convert_sr;
 using ck::f8_t;
 using ck::half_t;
@@ -24,33 +25,36 @@ TEST(FP8, ConvertFP32Nearest)
     // fix the tolerance value
     float abs_tol = 1e-6;
     // convert 0 float to fp8 and back, check if holds
-    ASSERT_NEAR(0.0f, type_convert<float>(type_convert<f8_t>(0.0f)), abs_tol);
+    ASSERT_NEAR(0.0f, type_convert<float>(f8_convert_rne<f8_t>(0.0f)), abs_tol);
+    // don't run the next test on gfx11 devices
+#ifndef CK_SKIP_FLAKY_F8_TEST
     // convert minimal float to fp8 and back, check if holds
     ASSERT_NEAR(std::numeric_limits<float>::min(),
-                type_convert<float>(type_convert<f8_t>(std::numeric_limits<float>::min())),
+                type_convert<float>(f8_convert_rne<f8_t>(std::numeric_limits<float>::min())),
                 abs_tol);
+#endif
     // convert maximal f8_t to float and check if equal to 240.0
-    ASSERT_NEAR(240.0f, type_convert<float>(type_convert<f8_t>(240.0f)), abs_tol);
+    ASSERT_NEAR(240.0f, type_convert<float>(f8_convert_rne<f8_t>(240.0f)), abs_tol);
     // convert maximal float to fp8 and back, check if clipped to 240.0
     ASSERT_NEAR(240.0f,
-                type_convert<float>(type_convert<f8_t>(std::numeric_limits<float>::max())),
+                type_convert<float>(f8_convert_rne<f8_t>(std::numeric_limits<float>::max())),
                 abs_tol);
     // convert inf float to f8_t and check if it is qNan
     ASSERT_NEAR(type_convert<f8_t>(0x80),
-                type_convert<f8_t>(std::numeric_limits<float>::infinity()),
+                f8_convert_rne<f8_t>(std::numeric_limits<float>::infinity()),
                 abs_tol);
     // positive norm float value to fp8 and back, check if holds
     float pos_float = 0.017578125f;
-    ASSERT_NEAR(pos_float, type_convert<float>(type_convert<f8_t>(pos_float)), abs_tol);
+    ASSERT_NEAR(pos_float, type_convert<float>(f8_convert_rne<f8_t>(pos_float)), abs_tol);
     // negative norm float value to fp8 and back, check if holds
     float neg_float = -0.015625f;
-    ASSERT_NEAR(neg_float, type_convert<float>(type_convert<f8_t>(neg_float)), abs_tol);
+    ASSERT_NEAR(neg_float, type_convert<float>(f8_convert_rne<f8_t>(neg_float)), abs_tol);
     // positive subnorm float value to fp8 and back, check if holds
     pos_float = 0.00390625f;
-    ASSERT_NEAR(pos_float, type_convert<float>(type_convert<f8_t>(pos_float)), abs_tol);
+    ASSERT_NEAR(pos_float, type_convert<float>(f8_convert_rne<f8_t>(pos_float)), abs_tol);
     // negative subnorm float value to fp8 and back, check if holds
     neg_float = -0.001953125f;
-    ASSERT_NEAR(neg_float, type_convert<float>(type_convert<f8_t>(neg_float)), abs_tol);
+    ASSERT_NEAR(neg_float, type_convert<float>(f8_convert_rne<f8_t>(neg_float)), abs_tol);
 }
 TEST(FP8, ConvertFP32Stochastic)
@@ -92,33 +96,33 @@ TEST(FP8, ConvertFP16Nearest)
     // fix the tolerance value
     float abs_tol = 1e-3;
     // convert 0 fp16 to fp8 and back, check if holds
-    ASSERT_NEAR(half_t{0.0}, type_convert<half_t>(type_convert<f8_t>(half_t{0.0})), abs_tol);
+    ASSERT_NEAR(half_t{0.0}, type_convert<half_t>(f8_convert_rne<f8_t>(half_t{0.0})), abs_tol);
     // convert minimal fp16 to fp8 and back, check if holds
     ASSERT_NEAR(ck::NumericLimits<half_t>::Min(),
-                type_convert<half_t>(type_convert<f8_t>(ck::NumericLimits<half_t>::Min())),
+                type_convert<half_t>(f8_convert_rne<f8_t>(ck::NumericLimits<half_t>::Min())),
                 abs_tol);
     // convert maximal f8_t to fp16 and check if equal to 240.0
-    ASSERT_NEAR(half_t{240.0}, type_convert<half_t>(type_convert<f8_t>(half_t{240.0})), abs_tol);
+    ASSERT_NEAR(half_t{240.0}, type_convert<half_t>(f8_convert_rne<f8_t>(half_t{240.0})), abs_tol);
     // convert maximal fp16 to fp8 and back, check if clipped to 240.0
     ASSERT_NEAR(half_t{240.0},
-                type_convert<half_t>(type_convert<f8_t>(ck::NumericLimits<half_t>::Max())),
+                type_convert<half_t>(f8_convert_rne<f8_t>(ck::NumericLimits<half_t>::Max())),
                 abs_tol);
     // convert QuietNaN fp16 to f8_t and check if it is QuietNaN
     ASSERT_NEAR(type_convert<f8_t>(0x80),
-                type_convert<f8_t>(ck::NumericLimits<half_t>::QuietNaN()),
+                f8_convert_rne<f8_t>(ck::NumericLimits<half_t>::QuietNaN()),
                 abs_tol);
     // positive norm fp16 value to fp8 and back, check if holds
     half_t pos_half = half_t{0.017578125};
-    ASSERT_NEAR(pos_half, type_convert<half_t>(type_convert<f8_t>(pos_half)), abs_tol);
+    ASSERT_NEAR(pos_half, type_convert<half_t>(f8_convert_rne<f8_t>(pos_half)), abs_tol);
     // negative norm fp16 value to fp8 and back, check if holds
     half_t neg_half = half_t{-0.015625};
-    ASSERT_NEAR(neg_half, type_convert<half_t>(type_convert<f8_t>(neg_half)), abs_tol);
+    ASSERT_NEAR(neg_half, type_convert<half_t>(f8_convert_rne<f8_t>(neg_half)), abs_tol);
     // positive subnorm fp16 value to fp8 and back, check if holds
     pos_half = half_t{0.00390625};
-    ASSERT_NEAR(pos_half, type_convert<half_t>(type_convert<f8_t>(pos_half)), abs_tol);
+    ASSERT_NEAR(pos_half, type_convert<half_t>(f8_convert_rne<f8_t>(pos_half)), abs_tol);
     // negative subnorm fp16 value to fp8 and back, check if holds
     neg_half = half_t{-0.001953125};
-    ASSERT_NEAR(neg_half, type_convert<half_t>(type_convert<f8_t>(neg_half)), abs_tol);
+    ASSERT_NEAR(neg_half, type_convert<half_t>(f8_convert_rne<f8_t>(neg_half)), abs_tol);
 }
 TEST(FP8, ConvertFP16Stochastic)
...
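
As a quick reference for the conversion helpers exercised in the two test files above, here is a minimal standalone sketch (using only the CK utility headers those tests already include) contrasting round-to-nearest-even and stochastic-rounding fp8 conversion; the input is one of the exactly representable constants from TEST(FP8, ConvertFP32Nearest).

// Hedged sketch: f8_convert_rne (deterministic, used by the *Nearest tests) versus
// f8_convert_sr (stochastic rounding, used by the *Stochastic tests).
#include <iostream>
#include "ck/utility/data_type.hpp"
#include "ck/utility/type_convert.hpp"

int main()
{
    float x = 0.017578125f; // positive normal value, exactly representable in f8_t

    ck::f8_t y_rne = ck::f8_convert_rne<ck::f8_t>(x); // round to nearest even
    ck::f8_t y_sr  = ck::f8_convert_sr<ck::f8_t>(x);  // stochastic rounding

    std::cout << "rne round-trip: " << ck::type_convert<float>(y_rne) << "\n"
              << "sr  round-trip: " << ck::type_convert<float>(y_sr) << std::endl;
    return 0;
}
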