Commit c6ec6638 authored by Khalique Ahmed's avatar Khalique Ahmed
Browse files

Merge branch 'develop' of https://github.com/ROCmSoftwarePlatform/AMDMIGraphX into auto_contig_fix

parents b42c7b41 a6d1540f
......@@ -200,7 +200,7 @@ operation compile_hip_code_object(const std::string& content, hip_compile_option
options.params += " " + join_strings(compiler_warnings(), " ");
options.params += " -ftemplate-backtrace-limit=0";
options.params += " -Werror";
auto cos = compile_hip_src(srcs, std::move(options.params), get_device_name());
auto cos = compile_hip_src(srcs, options.params, get_device_name());
if(cos.size() != 1)
MIGRAPHX_THROW("No code object");
return code_object_op{value::binary{cos.front()},
......
......@@ -43,24 +43,32 @@ template <index_int N,
__device__ void block_scan(index idx, Op op, T init, ForStride fs, Input input, Output output)
{
using type = decltype(input(deduce_for_stride(fs)));
MIGRAPHX_DEVICE_SHARED type buffer[N];
MIGRAPHX_DEVICE_SHARED type buffer[2][N];
type x = init;
fs([&](auto i) {
index_int iout = 0;
index_int iin = 1;
if(idx.local == 0)
buffer[idx.local] = op(input(i), x);
buffer[iout][idx.local] = op(input(i), x);
else
buffer[idx.local] = input(i);
buffer[iout][idx.local] = input(i);
__syncthreads();
for(index_int s = 1; s < idx.nlocal(); s *= 2)
{
if(idx.local + s < idx.nlocal())
iout = 1 - iout;
iin = 1 - iin;
if(idx.local >= s)
{
buffer[idx.local + s] = op(buffer[idx.local], buffer[idx.local + s]);
buffer[iout][idx.local] = op(buffer[iin][idx.local], buffer[iin][idx.local - s]);
}
else
{
buffer[iout][idx.local] = buffer[iin][idx.local];
}
__syncthreads();
}
x = buffer[idx.nlocal() - 1];
output(i, buffer[idx.local]);
x = buffer[iout][idx.nlocal() - 1];
output(i, buffer[iout][idx.local]);
});
}
......
......@@ -146,7 +146,7 @@ __device__ __host__ T to_hip_type(T x)
// Hip doens't support __fp16
inline __device__ __host__ float to_hip_type(gpu_half x) { return x; }
#define MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(trait, T) \
#define MIGRAPHX_DEVICE_DETAIL_EXTEND_TRAIT_FOR(trait, T) \
template <class X> \
struct trait : std::trait<X> \
{ \
......@@ -157,9 +157,9 @@ inline __device__ __host__ float to_hip_type(gpu_half x) { return x; }
{ \
};
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_floating_point, __fp16)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_signed, __fp16)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_arithmetic, __fp16)
MIGRAPHX_DEVICE_DETAIL_EXTEND_TRAIT_FOR(is_floating_point, __fp16)
MIGRAPHX_DEVICE_DETAIL_EXTEND_TRAIT_FOR(is_signed, __fp16)
MIGRAPHX_DEVICE_DETAIL_EXTEND_TRAIT_FOR(is_arithmetic, __fp16)
} // namespace device
} // namespace gpu
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/gpu/driver/action.hpp>
#include <migraphx/gpu/time_op.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/lowering.hpp>
#include <migraphx/gpu/compile_ops.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace driver {
struct precompile_op : action<precompile_op>
{
static program create_preop_program(const operation& preop, std::vector<shape> inputs)
{
program p;
auto* mm = p.get_main_module();
std::vector<instruction_ref> args;
inputs.pop_back();
transform(inputs, range(inputs.size()), std::back_inserter(args), [&](auto input, auto i) {
return mm->add_parameter("x" + std::to_string(i), input);
});
mm->add_instruction(preop, args);
return p;
}
static operation get_code_object(const program& p)
{
MIGRAPHX_TIDY_CONST auto* mm = p.get_main_module();
auto it = std::find_if(mm->begin(), mm->end(), [](const auto& ins) {
return (ins.name() == "gpu::code_object");
});
if(it == mm->end())
MIGRAPHX_THROW("Failed to create code object");
return it->get_operator();
}
static void apply(const parser& p, const value& v)
{
context ctx;
auto inputs = p.parse_shapes(v.at("inputs"));
auto name = v.at("name").to<std::string>();
auto preop = make_op(name);
if(v.contains("fields"))
preop.from_value(v.at("fields"));
bool exhaustive = v.get("exhaustive", false);
auto prog = create_preop_program(preop, inputs);
run_passes(prog, {lowering{}, compile_ops{&ctx, exhaustive}});
auto op = get_code_object(prog);
auto t = time_op(ctx, op, inputs, p.get(v, "iterations", 100));
std::cout << preop << ": " << t << "ms" << std::endl;
}
};
} // namespace driver
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......@@ -46,6 +46,7 @@ rocblas_datatype get_type(shape::type_t type)
case shape::uint8_type: return rocblas_datatype_u8_r;
case shape::int32_type: return rocblas_datatype_i32_r;
case shape::uint32_type: return rocblas_datatype_u32_r;
case shape::fp8e4m3fnuz_type:
case shape::tuple_type:
case shape::bool_type:
case shape::uint16_type:
......
......@@ -58,10 +58,10 @@ struct hiprtc_src_file
MIGRAPHX_GPU_EXPORT bool hip_has_flags(const std::vector<std::string>& flags);
MIGRAPHX_GPU_EXPORT std::vector<std::vector<char>> compile_hip_src_with_hiprtc(
std::vector<hiprtc_src_file> srcs, std::string params, const std::string& arch);
std::vector<hiprtc_src_file> srcs, const std::string& params, const std::string& arch);
MIGRAPHX_GPU_EXPORT std::vector<std::vector<char>>
compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std::string& arch);
MIGRAPHX_GPU_EXPORT std::vector<std::vector<char>> compile_hip_src(
const std::vector<src_file>& srcs, const std::string& params, const std::string& arch);
MIGRAPHX_GPU_EXPORT std::string enum_params(std::size_t count, std::string param);
......
......@@ -211,6 +211,12 @@ inline pooling_descriptor make_pooling(const migraphx::op::pooling& op)
ss << op.mode;
MIGRAPHX_THROW(ss.str());
}
if(not std::all_of(
op.dilations.cbegin(), op.dilations.cend(), [](std::size_t d) { return d == 1; }))
{
MIGRAPHX_THROW("Unsupported dilations for pooling: [" + to_string_range(op.dilations) +
"]");
}
auto p = make_obj<pooling_descriptor>(&miopenCreatePoolingDescriptor);
int kdims = op.kdims();
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_JIT_SCATTER_HPP
#define MIGRAPHX_GUARD_JIT_SCATTER_HPP
#include <migraphx/gpu/compiler.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/compile_hip_code_object.hpp>
#include <migraphx/gpu/compile_hip.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
template <typename Derived>
struct scatter_compiler : compiler<Derived>
{
compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const
{
const auto inputs =
to_shapes(std::vector<instruction_ref>{ins->inputs().begin() + 1, ins->inputs().end()});
hip_compile_options options;
options.set_launch_params(op.to_value(), compute_global_for(ctx, inputs.at(1).elements()));
options.inputs = inputs;
options.output = inputs.back();
options.kernel_name = derived().get_kernel_name(op);
options.virtual_inputs = inputs;
// The compiler protests the inequality comparison in assign_mul when pertaining to floating
// point, despite it making sense in the context. Thus the warning removal.
options.params += "-Wno-float-equal";
const auto src = derived().make_interpolated_string(op);
return prepend_copy_data_to_output(compile_hip_code_object(src, options));
}
compiler_replace prepend_copy_data_to_output(const operation& co) const
{
return {co, [](module& m, instruction_ref ins, const operation& op) {
auto args = ins->inputs();
args.back() =
m.insert_instruction(ins, make_op("hip::copy"), args.front(), args.back());
args.erase(args.begin());
return m.replace_instruction(ins, op, args);
}};
}
std::string get_kernel_name(const operation& op) const { return op.name() + "_kernel"; }
const Derived& derived() const { return static_cast<const Derived&>(*this); }
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
......@@ -21,11 +21,7 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/gpu/compiler.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/compile_hip_code_object.hpp>
#include <migraphx/gpu/compile_hip.hpp>
#include "scatter.hpp"
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -55,46 +51,21 @@ MIGRAPHX_GLOBAL void scatternd_kernel(void* in_indices, void* in_updates, void*
)__migraphx__";
struct scatternd_compiler : compiler<scatternd_compiler>
struct scatternd_compiler : scatter_compiler<scatternd_compiler>
{
std::vector<std::string> names() const
{
return {"scatternd_none", "scatternd_add", "scatternd_mul"};
return {
"scatternd_none", "scatternd_add", "scatternd_mul", "scatternd_min", "scatternd_max"};
}
operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
std::string make_interpolated_string(const operation& op) const
{
hip_compile_options options;
options.set_launch_params(v, compute_global_for(ctx, inputs.at(1).elements()));
options.inputs = inputs;
options.output = inputs.back();
options.kernel_name = "scatternd_kernel";
options.virtual_inputs = inputs;
auto reduction = "assign_" + v.get("reduction", std::string{"none"});
auto src = interpolate_string(scatternd_kernel, {{"reduction", reduction}});
return compile_hip_code_object(src, options);
const auto reduction = op.name().substr(std::char_traits<char>::length("scatternd_"));
return interpolate_string(scatternd_kernel, {{"reduction", "assign_" + reduction}});
}
compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const
{
assert(starts_with(op.name(), "scatternd_"));
auto reduction = op.name().substr(10);
return insert(compile_op(
ctx,
to_shapes(std::vector<instruction_ref>{ins->inputs().begin() + 1, ins->inputs().end()}),
{{"reduction", reduction}}));
}
compiler_replace insert(const operation& co) const
{
return {co, [](module& m, instruction_ref ins, const operation& op) {
auto args = ins->inputs();
args.back() =
m.insert_instruction(ins, make_op("hip::copy"), args.front(), args.back());
args.erase(args.begin());
return m.replace_instruction(ins, op, args);
}};
}
std::string get_kernel_name(const operation&) const { return "scatternd_kernel"; }
};
} // namespace gpu
......
......@@ -53,35 +53,35 @@ __device__ void gathernd(const T& data_t, const U& indices_t, const V& output_t,
auto indices_shape_lens = indices_shape.lens;
auto data_shape_lens = data_shape.lens;
auto num_slice_dims = indices_shape_lens.back();
std::size_t num_slices =
size_t num_slices =
accumulate(indices_shape_lens.begin(), indices_shape_lens.end() - 1, 1, op::product{});
std::size_t slice_size = accumulate(data_shape_lens.begin() + num_slice_dims + batch_dims,
size_t slice_size = accumulate(data_shape_lens.begin() + num_slice_dims + batch_dims,
data_shape_lens.end(),
1,
op::product{});
const std::size_t num_batches =
const size_t num_batches =
accumulate(data_shape_lens.begin(), data_shape_lens.begin() + batch_dims, 1, op::product{});
const std::size_t data_batch_stride =
const size_t data_batch_stride =
accumulate(data_shape_lens.begin() + batch_dims, data_shape_lens.end(), 1, op::product{});
const auto num_slices_per_batch = num_slices / num_batches;
ind.global_stride(output_shape.elements(), [&](auto i) {
const auto* indices_ptr = indices_t.data();
const std::size_t j = i / slice_size;
const std::size_t batch_idx = j / num_slices_per_batch;
const size_t j = i / slice_size;
const size_t batch_idx = j / num_slices_per_batch;
auto* slice_indices = indices_ptr + (j * num_slice_dims);
std::size_t relative_slice_offset = 0;
for(std::size_t idx = 0; idx < num_slice_dims; ++idx)
size_t relative_slice_offset = 0;
for(size_t idx = 0; idx < num_slice_dims; ++idx)
{
int64_t index = slice_indices[idx];
const std::size_t input_dim_idx = batch_dims + idx;
const size_t input_dim_idx = batch_dims + idx;
const auto input_dim = data_shape_lens[input_dim_idx];
MIGRAPHX_ASSERT(index >= -static_cast<int64_t>(input_dim) and
index < static_cast<int64_t>(input_dim));
if(index < 0)
index += input_dim;
std::size_t size_from_slice_dims =
size_t size_from_slice_dims =
accumulate(data_shape_lens.begin() + batch_dims + idx + 1,
data_shape_lens.begin() + batch_dims + num_slice_dims,
slice_size,
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_KERNELS_SCATTER_REDUCTION_MODES_HPP
#define MIGRAPHX_GUARD_KERNELS_SCATTER_REDUCTION_MODES_HPP
#include <migraphx/kernels/types.hpp>
namespace migraphx {
struct assign_none
{
template <class T, class U>
MIGRAPHX_DEVICE_CONSTEXPR void operator()(T& x, U y) const
{
x = y;
}
};
struct assign_add
{
template <class T, class U>
MIGRAPHX_DEVICE_CONSTEXPR void operator()(T& x, U y) const
{
atomicAdd(&x, y);
}
};
struct assign_mul
{
template <class T, class U>
MIGRAPHX_DEVICE_CONSTEXPR void operator()(T& x, U y) const
{
T old = x;
T assumed;
do
{
assumed = old;
old = atomicCAS(&x, assumed, assumed * y);
} while(assumed != old);
}
};
struct assign_max
{
template <typename T, typename U>
MIGRAPHX_DEVICE_CONSTEXPR void operator()(T& x, U y) const
{
atomicMax(&x, y);
}
};
struct assign_min
{
template <typename T, typename U>
MIGRAPHX_DEVICE_CONSTEXPR void operator()(T& x, U y) const
{
atomicMin(&x, y);
}
};
} // namespace migraphx
#endif
......@@ -26,36 +26,10 @@
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/algorithm.hpp>
#include <migraphx/kernels/scatter_reduction_modes.hpp>
namespace migraphx {
struct assign_none
{
template <class T, class U>
MIGRAPHX_DEVICE_CONSTEXPR void operator()(T& x, U y) const
{
x = y;
}
};
struct assign_add
{
template <class T, class U>
MIGRAPHX_DEVICE_CONSTEXPR void operator()(T& x, U y) const
{
x += y;
}
};
struct assign_mul
{
template <class T, class U>
MIGRAPHX_DEVICE_CONSTEXPR void operator()(T& x, U y) const
{
x *= y;
}
};
template <class T, class U, class V, class F>
__device__ void scatternd(const T& indices_t, const U& updates_t, const V& output_t, F f)
{
......
......@@ -28,7 +28,9 @@
#include <migraphx/register_op.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/dead_code_elimination.hpp>
#ifdef MIGRAPHX_USE_COMPOSABLEKERNEL
#include <migraphx/gpu/ck.hpp>
#endif
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -122,6 +124,8 @@ struct find_add_layernorm
}
};
#ifdef MIGRAPHX_USE_COMPOSABLEKERNEL
struct pre_gemm_softmax_gemm : gemm_softmax_gemm
{
std::string name() const { return "gpu::pre_gemm_softmax_gemm"; }
......@@ -175,6 +179,8 @@ struct find_gemm_softmax_gemm
}
};
#endif
} // namespace
void prefuse_ops::apply(module_pass_manager& mpm) const
......@@ -182,8 +188,10 @@ void prefuse_ops::apply(module_pass_manager& mpm) const
match::find_matches(mpm.get_module(), find_layernorm{});
mpm.run_pass(dead_code_elimination{});
match::find_matches(mpm.get_module(), find_add_layernorm{});
#ifdef MIHRAPHX_USE_COMPOSABLEKERNEL
if(enabled(MIGRAPHX_ENABLE_CK{}))
match::find_matches(mpm, find_gemm_softmax_gemm{});
#endif
}
} // namespace gpu
......
......@@ -33,8 +33,9 @@ rocm_set_soversion(migraphx_ref ${MIGRAPHX_SO_VERSION})
find_path(BLAZE_INCLUDE blaze/Blaze.h)
rocm_clang_tidy_check(migraphx_ref)
target_link_libraries(migraphx_ref PRIVATE Threads::Threads)
target_link_libraries(migraphx_ref PUBLIC migraphx)
target_include_directories(migraphx_ref PRIVATE ${BLAZE_INCLUDE})
target_include_directories(migraphx_ref SYSTEM PRIVATE ${BLAZE_INCLUDE})
target_compile_definitions(migraphx_ref PRIVATE -DBLAZE_USE_CPP_THREADS)
migraphx_generate_export_header(migraphx_ref)
......
......@@ -38,7 +38,11 @@ protobuf_generate_cpp(
)
add_library(tf-proto STATIC ${PROTO_SRCS})
target_include_directories(tf-proto SYSTEM PUBLIC ${CMAKE_CURRENT_BINARY_DIR} ${PROTOBUF_INCLUDE_DIR})
target_compile_options(tf-proto PRIVATE -w)
if(MSVC)
target_compile_options(tf-proto PRIVATE /w)
else()
target_compile_options(tf-proto PRIVATE -w)
endif()
target_link_libraries(tf-proto PRIVATE ${PROTOBUF_LIBRARY})
set_target_properties(tf-proto PROPERTIES POSITION_INDEPENDENT_CODE On)
......@@ -49,7 +53,10 @@ target_include_directories(migraphx_tf PRIVATE include)
set_target_properties(migraphx_tf PROPERTIES EXPORT_NAME tf)
rocm_set_soversion(migraphx_tf ${MIGRAPHX_SO_VERSION})
rocm_clang_tidy_check(migraphx_tf)
target_link_libraries(migraphx_tf PRIVATE tf-proto "-Wl,--exclude-libs,ALL")
target_link_libraries(migraphx_tf PRIVATE tf-proto)
if(NOT WIN32)
target_link_libraries(migraphx_tf PRIVATE "-Wl,--exclude-libs,ALL")
endif()
target_link_libraries(migraphx_tf PUBLIC migraphx)
rocm_install_targets(
......
......@@ -31,8 +31,18 @@
#include <sstream>
#include <iostream>
#include <string>
#include <sys/types.h>
#ifdef _WIN32
// cppcheck-suppress definePrefix
#define WIN32_LEAN_AND_MEAN
#include <Windows.h>
#undef getpid
// cppcheck-suppress [definePrefix, defineUpperCase]
#define getpid _getpid
#else
#include <unistd.h>
#include <sys/types.h>
#endif
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......
......@@ -150,6 +150,7 @@ function(test_headers PREFIX)
list(REMOVE_ITEM HEADERS
${CMAKE_SOURCE_DIR}/src/targets/gpu/include/migraphx/gpu/ck.hpp)
endif()
list(REMOVE_ITEM HEADERS ${CMAKE_SOURCE_DIR}/src/include/migraphx/float8_impl.hpp)
foreach(HEADER ${HEADERS})
file(RELATIVE_PATH HEADER_REL ${CMAKE_SOURCE_DIR} ${HEADER})
string(MAKE_C_IDENTIFIER ${HEADER_REL} TEST_NAME)
......
......@@ -22,6 +22,7 @@
* THE SOFTWARE.
*/
#include <migraphx/float_equal.hpp>
#include <migraphx/float8.hpp>
#include <migraphx/half.hpp>
#include "test.hpp"
......@@ -53,7 +54,7 @@ auto test_float_equal(T x, U y)
template <class T, class U>
void test_equality()
{
auto x1 = T(0.1);
auto x1 = T(0.125);
auto x2 = U(0.0);
auto x3 = U(1.0);
EXPECT(test_float_equal(x1, x1));
......@@ -71,8 +72,12 @@ void test_equality()
TEST_CASE_REGISTER(test_equality<double, float>);
TEST_CASE_REGISTER(test_equality<double, int>);
TEST_CASE_REGISTER(test_equality<double, migraphx::half>);
TEST_CASE_REGISTER(test_equality<double, migraphx::fp8::fp8e4m3fnuz>);
TEST_CASE_REGISTER(test_equality<float, int>);
TEST_CASE_REGISTER(test_equality<float, migraphx::fp8::fp8e4m3fnuz>);
TEST_CASE_REGISTER(test_equality<migraphx::half, int>);
TEST_CASE_REGISTER(test_equality<migraphx::half, migraphx::fp8::fp8e4m3fnuz>);
TEST_CASE_REGISTER(test_equality<migraphx::fp8::fp8e4m3fnuz, int>);
template <class T, class U>
void test_limits()
......@@ -110,8 +115,13 @@ void test_limits()
TEST_CASE_REGISTER(test_limits<double, float>);
TEST_CASE_REGISTER(test_limits<double, int>);
TEST_CASE_REGISTER(test_limits<double, migraphx::half>);
TEST_CASE_REGISTER(test_limits<double, migraphx::fp8::fp8e4m3fnuz>);
TEST_CASE_REGISTER(test_limits<float, int>);
TEST_CASE_REGISTER(test_limits<float, migraphx::fp8::fp8e4m3fnuz>);
TEST_CASE_REGISTER(test_limits<int, migraphx::half>);
TEST_CASE_REGISTER(test_limits<int, migraphx::fp8::fp8e4m3fnuz>);
TEST_CASE_REGISTER(test_limits<migraphx::fp8::fp8e4m3fnuz, migraphx::half>);
#ifndef _WIN32
// On Windows, types int and long have the same min and max values.
TEST_CASE_REGISTER(test_limits<long, int>);
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <cmath>
#include <migraphx/float_equal.hpp>
#include <migraphx/float8.hpp>
#include <migraphx/half.hpp>
#include <migraphx/ranges.hpp>
#include "test.hpp"
#include <limits>
float fp8e4m3fn_to_fp32_value(uint8_t input)
{
constexpr std::array<float, 256> e4m3fnuz_lut = {
0.0, 0.001953125, 0.00390625, 0.005859375,
0.0078125, 0.009765625, 0.01171875, 0.013671875,
0.015625, 0.017578125, 0.01953125, 0.021484375,
0.0234375, 0.025390625, 0.02734375, 0.029296875,
0.03125, 0.03515625, 0.0390625, 0.04296875,
0.046875, 0.05078125, 0.0546875, 0.05859375,
0.0625, 0.0703125, 0.078125, 0.0859375,
0.09375, 0.1015625, 0.109375, 0.1171875,
0.125, 0.140625, 0.15625, 0.171875,
0.1875, 0.203125, 0.21875, 0.234375,
0.25, 0.28125, 0.3125, 0.34375,
0.375, 0.40625, 0.4375, 0.46875,
0.5, 0.5625, 0.625, 0.6875,
0.75, 0.8125, 0.875, 0.9375,
1.0, 1.125, 1.25, 1.375,
1.5, 1.625, 1.75, 1.875,
2.0, 2.25, 2.5, 2.75,
3.0, 3.25, 3.5, 3.75,
4.0, 4.5, 5.0, 5.5,
6.0, 6.5, 7.0, 7.5,
8.0, 9.0, 10.0, 11.0,
12.0, 13.0, 14.0, 15.0,
16.0, 18.0, 20.0, 22.0,
24.0, 26.0, 28.0, 30.0,
32.0, 36.0, 40.0, 44.0,
48.0, 52.0, 56.0, 60.0,
64.0, 72.0, 80.0, 88.0,
96.0, 104.0, 112.0, 120.0,
128.0, 144.0, 160.0, 176.0,
192.0, 208.0, 224.0, 240.0,
256.0, 288.0, 320.0, 352.0,
384.0, 416.0, 448.0, std::numeric_limits<float>::quiet_NaN(),
-0.0, -0.001953125, -0.00390625, -0.005859375,
-0.0078125, -0.009765625, -0.01171875, -0.013671875,
-0.015625, -0.017578125, -0.01953125, -0.021484375,
-0.0234375, -0.025390625, -0.02734375, -0.029296875,
-0.03125, -0.03515625, -0.0390625, -0.04296875,
-0.046875, -0.05078125, -0.0546875, -0.05859375,
-0.0625, -0.0703125, -0.078125, -0.0859375,
-0.09375, -0.1015625, -0.109375, -0.1171875,
-0.125, -0.140625, -0.15625, -0.171875,
-0.1875, -0.203125, -0.21875, -0.234375,
-0.25, -0.28125, -0.3125, -0.34375,
-0.375, -0.40625, -0.4375, -0.46875,
-0.5, -0.5625, -0.625, -0.6875,
-0.75, -0.8125, -0.875, -0.9375,
-1.0, -1.125, -1.25, -1.375,
-1.5, -1.625, -1.75, -1.875,
-2.0, -2.25, -2.5, -2.75,
-3.0, -3.25, -3.5, -3.75,
-4.0, -4.5, -5.0, -5.5,
-6.0, -6.5, -7.0, -7.5,
-8.0, -9.0, -10.0, -11.0,
-12.0, -13.0, -14.0, -15.0,
-16.0, -18.0, -20.0, -22.0,
-24.0, -26.0, -28.0, -30.0,
-32.0, -36.0, -40.0, -44.0,
-48.0, -52.0, -56.0, -60.0,
-64.0, -72.0, -80.0, -88.0,
-96.0, -104.0, -112.0, -120.0,
-128.0, -144.0, -160.0, -176.0,
-192.0, -208.0, -224.0, -240.0,
-256.0, -288.0, -320.0, -352.0,
-384.0, -416.0, -448.0, std::numeric_limits<float>::quiet_NaN(),
};
return e4m3fnuz_lut[input];
}
TEST_CASE(test_fp8_cast_to_float)
{
std::vector<uint8_t> bit_vals(256);
std::iota(bit_vals.begin(), bit_vals.end(), 0);
EXPECT(bool{std::all_of(bit_vals.begin(), bit_vals.end(), [](uint8_t bit_val) {
migraphx::fp8::fp8e4m3fn fp8_val(bit_val, migraphx::fp8::fp8e4m3fn::from_bits());
if(std::isnan(float(fp8_val)) and std::isnan(fp8e4m3fn_to_fp32_value(bit_val)))
{
return true;
}
return migraphx::float_equal(float(fp8_val), fp8e4m3fn_to_fp32_value(bit_val));
})});
}
TEST_CASE(test_fp8_cast_from_float)
{
std::unordered_map<float, uint8_t> test_vals = {
{{512, 0x7e}, {-512, 0xfe}, {448, 0x7e}, {-448, 0xfe},
{256, 0x78}, {-256, 0xf8}, {240, 0x77}, {-240, 0xf7},
{1e-07, 0x0}, {1e+07, 0x7e}, {1, 0x38}, {-1, 0xb8},
{0.1, 0x1d}, {0.11, 0x1e}, {0.111, 0x1e}, {0.1111, 0x1e},
{-0.1, 0x9d}, {-0.11, 0x9e}, {-0.111, 0x9e}, {-0.1111, 0x9e},
{0.2, 0x25}, {2, 0x40}, {20, 0x5a}, {200, 0x74},
{-0.2, 0xa5}, {-2, 0xc0}, {-20, 0xda}, {-200, 0xf4},
{0.5, 0x30}, {-0.5, 0xb0}, {1.17549e-38, 0x0}, {1.4013e-45, 0x0},
{0.0078125, 0x4}, {-0.0078125, 0x84}, {0.000976562, 0x0}, {-0.000976562, 0x80},
{0.000488281, 0x0}, {-0.000488281, 0x80}}};
EXPECT(bool{std::all_of(test_vals.begin(), test_vals.end(), [](const auto sample) {
return migraphx::float_equal(
migraphx::fp8::fp8e4m3fn(sample.first),
migraphx::fp8::fp8e4m3fn(sample.second, migraphx::fp8::fp8e4m3fn::from_bits()));
})});
}
TEST_CASE(test_positive_zero)
{
float zero = 0.0;
migraphx::fp8::fp8e4m3fn fp8_zero(zero);
EXPECT(fp8_zero.is_zero());
EXPECT(migraphx::float_equal(zero, float(fp8_zero)));
}
TEST_CASE(test_negative_zero)
{
float nzero = -0.0;
migraphx::fp8::fp8e4m3fn fp8_nzero(nzero);
EXPECT(fp8_nzero.is_zero());
// negative zero is preserved for fp8e4m3fn
EXPECT(migraphx::float_equal(nzero, float(fp8_nzero)));
}
TEST_CASE(test_pos_zero_eq_neg_zero)
{
float nzero = -0.0;
float pzero = 0.0;
migraphx::fp8::fp8e5m2 fp8_nzero(nzero);
migraphx::fp8::fp8e5m2 fp8_pzero(pzero);
EXPECT(fp8_nzero == fp8_pzero);
}
TEST_CASE(test_nan_1)
{
float fnan = std::numeric_limits<float>::quiet_NaN();
migraphx::fp8::fp8e4m3fn fp8_nan(fnan);
EXPECT(fp8_nan.is_nan());
EXPECT(std::isnan(fp8_nan));
}
TEST_CASE(test_nan_2)
{
auto fnan = std::numeric_limits<migraphx::fp8::fp8e4m3fn>::quiet_NaN();
migraphx::fp8::fp8e4m3fn fp8_nan(fnan.data, migraphx::fp8::fp8e4m3fn::from_bits());
EXPECT(fp8_nan.is_nan());
EXPECT(std::isnan(fp8_nan));
EXPECT(std::isnan(float(fp8_nan)));
}
TEST_CASE(test_infinity_1)
{
float finf = std::numeric_limits<float>::infinity();
// no inf in fp8e4m3fn, it gets clipped to max()
migraphx::fp8::fp8e4m3fn fp8_max(finf);
EXPECT(fp8_max == std::numeric_limits<migraphx::fp8::fp8e4m3fn>::max());
}
TEST_CASE(test_infinity_2)
{
// neg inf
float finf = -1.0 * std::numeric_limits<float>::infinity();
// no inf in fp8e4m3fn, it gets clipped to lowest
migraphx::fp8::fp8e4m3fn fp8_lowest(finf);
EXPECT(bool{fp8_lowest == std::numeric_limits<migraphx::fp8::fp8e4m3fn>::lowest()});
}
TEST_CASE(test_numeric_max_1)
{
float fmax = std::numeric_limits<float>::max();
migraphx::fp8::fp8e4m3fn fp8_max(fmax);
EXPECT(fp8_max == std::numeric_limits<migraphx::fp8::fp8e4m3fn>::max());
}
TEST_CASE(test_numeric_max_2)
{
// gets clipped to max
float fmax = 2 * std::numeric_limits<migraphx::fp8::fp8e4m3fn>::max();
migraphx::fp8::fp8e4m3fn fp8_max(fmax);
EXPECT(fp8_max == std::numeric_limits<migraphx::fp8::fp8e4m3fn>::max());
}
TEST_CASE(test_numeric_lowest_1)
{
float flowest = std::numeric_limits<float>::lowest();
migraphx::fp8::fp8e4m3fn fp8_lowest(flowest);
EXPECT(fp8_lowest == std::numeric_limits<migraphx::fp8::fp8e4m3fn>::lowest());
}
TEST_CASE(test_numeric_lowest_2)
{
// gets clipped to lowest
float fmin = 2.0 * std::numeric_limits<migraphx::fp8::fp8e4m3fn>::lowest();
migraphx::fp8::fp8e4m3fn fp8_lowest(fmin);
EXPECT(fp8_lowest == std::numeric_limits<migraphx::fp8::fp8e4m3fn>::lowest());
}
TEST_CASE(test_max_eq_lowest)
{
EXPECT(migraphx::float_equal(std::numeric_limits<migraphx::fp8::fp8e4m3fn>::lowest(),
-1 * std::numeric_limits<migraphx::fp8::fp8e4m3fn>::max()));
}
TEST_CASE(test_isfinite)
{
EXPECT(std::isfinite(migraphx::fp8::fp8e4m3fn(0.0)));
EXPECT(std::isfinite(migraphx::fp8::fp8e4m3fn(-0.0)));
EXPECT(not std::isfinite(
migraphx::fp8::fp8e4m3fn(std::numeric_limits<migraphx::fp8::fp8e4m3fn>::quiet_NaN())));
}
TEST_CASE(test_no_infinity)
{
EXPECT(not bool{std::numeric_limits<migraphx::fp8::fp8e4m3fn>::has_infinity});
}
TEST_CASE(test_binary_ops)
{
auto a = migraphx::fp8::fp8e4m3fn(-1.0);
auto b = migraphx::fp8::fp8e4m3fn(1.0);
auto c = migraphx::fp8::fp8e4m3fn(0.0);
auto d = migraphx::fp8::fp8e4m3fn(-0.0);
EXPECT(migraphx::float_equal((c + d), c));
EXPECT(migraphx::float_equal((c + d), d));
EXPECT(migraphx::float_equal((a + b), c));
EXPECT(migraphx::float_equal((a + b), d));
auto e = migraphx::fp8::fp8e4m3fn(10.0);
auto f = migraphx::fp8::fp8e4m3fn(-10.0);
EXPECT(bool{e > f});
EXPECT(bool{f < e});
EXPECT(bool{f <= e});
EXPECT(bool{e >= f});
EXPECT(bool{e <= e});
EXPECT(bool{f >= f});
EXPECT(not migraphx::float_equal(f, e));
}
TEST_CASE(test_fabs)
{
auto a = migraphx::fp8::fp8e4m3fn(-1.0);
auto b = migraphx::fp8::fp8e4m3fn(1.0);
EXPECT(migraphx::float_equal(b, migraphx::fp8::fabs(a)));
}
TEST_CASE(test_stream_op)
{
auto a = migraphx::fp8::fp8e4m3fn(-1.0);
std::stringstream ss;
ss << a;
EXPECT(std::string("-1") == ss.str());
ss = std::stringstream();
auto b = std::numeric_limits<migraphx::fp8::fp8e4m3fn>::quiet_NaN();
ss << b;
EXPECT(std::string("nan") == ss.str());
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <cmath>
#include <migraphx/float_equal.hpp>
#include <migraphx/float8.hpp>
#include <migraphx/half.hpp>
#include <migraphx/ranges.hpp>
#include "test.hpp"
#include <limits>
float fp8e4m3fnuz_to_fp32_value(uint8_t input)
{
constexpr std::array<float, 256> e4m3fnuz_lut = {
0.0f, 0.0009765625f, 0.001953125f,
0.0029296875f, 0.00390625f, 0.0048828125f,
0.005859375f, 0.0068359375f, 0.0078125f,
0.0087890625f, 0.009765625f, 0.0107421875f,
0.01171875f, 0.0126953125f, 0.013671875f,
0.0146484375f, 0.015625f, 0.017578125f,
0.01953125f, 0.021484375f, 0.0234375f,
0.025390625f, 0.02734375f, 0.029296875f,
0.03125f, 0.03515625f, 0.0390625f,
0.04296875f, 0.046875f, 0.05078125f,
0.0546875f, 0.05859375f, 0.0625f,
0.0703125f, 0.078125f, 0.0859375f,
0.09375f, 0.1015625f, 0.109375f,
0.1171875f, 0.125f, 0.140625f,
0.15625f, 0.171875f, 0.1875f,
0.203125f, 0.21875f, 0.234375f,
0.25f, 0.28125f, 0.3125f,
0.34375f, 0.375f, 0.40625f,
0.4375f, 0.46875f, 0.5f,
0.5625f, 0.625f, 0.6875f,
0.75f, 0.8125f, 0.875f,
0.9375f, 1.0f, 1.125f,
1.25f, 1.375f, 1.5f,
1.625f, 1.75f, 1.875f,
2.0f, 2.25f, 2.5f,
2.75f, 3.0f, 3.25f,
3.5f, 3.75f, 4.0f,
4.5f, 5.0f, 5.5f,
6.0f, 6.5f, 7.0f,
7.5f, 8.0f, 9.0f,
10.0f, 11.0f, 12.0f,
13.0f, 14.0f, 15.0f,
16.0f, 18.0f, 20.0f,
22.0f, 24.0f, 26.0f,
28.0f, 30.0f, 32.0f,
36.0f, 40.0f, 44.0f,
48.0f, 52.0f, 56.0f,
60.0f, 64.0f, 72.0f,
80.0f, 88.0f, 96.0f,
104.0f, 112.0f, 120.0f,
128.0f, 144.0f, 160.0f,
176.0f, 192.0f, 208.0f,
224.0f, 240.0f, std::numeric_limits<float>::quiet_NaN(),
-0.0009765625f, -0.001953125f, -0.0029296875f,
-0.00390625f, -0.0048828125f, -0.005859375f,
-0.0068359375f, -0.0078125f, -0.0087890625f,
-0.009765625f, -0.0107421875f, -0.01171875f,
-0.0126953125f, -0.013671875f, -0.0146484375f,
-0.015625f, -0.017578125f, -0.01953125f,
-0.021484375f, -0.0234375f, -0.025390625f,
-0.02734375f, -0.029296875f, -0.03125f,
-0.03515625f, -0.0390625f, -0.04296875f,
-0.046875f, -0.05078125f, -0.0546875f,
-0.05859375f, -0.0625f, -0.0703125f,
-0.078125f, -0.0859375f, -0.09375f,
-0.1015625f, -0.109375f, -0.1171875f,
-0.125f, -0.140625f, -0.15625f,
-0.171875f, -0.1875f, -0.203125f,
-0.21875f, -0.234375f, -0.25f,
-0.28125f, -0.3125f, -0.34375f,
-0.375f, -0.40625f, -0.4375f,
-0.46875f, -0.5f, -0.5625f,
-0.625f, -0.6875f, -0.75f,
-0.8125f, -0.875f, -0.9375f,
-1.0f, -1.125f, -1.25f,
-1.375f, -1.5f, -1.625f,
-1.75f, -1.875f, -2.0f,
-2.25f, -2.5f, -2.75f,
-3.0f, -3.25f, -3.5f,
-3.75f, -4.0f, -4.5f,
-5.0f, -5.5f, -6.0f,
-6.5f, -7.0f, -7.5f,
-8.0f, -9.0f, -10.0f,
-11.0f, -12.0f, -13.0f,
-14.0f, -15.0f, -16.0f,
-18.0f, -20.0f, -22.0f,
-24.0f, -26.0f, -28.0f,
-30.0f, -32.0f, -36.0f,
-40.0f, -44.0f, -48.0f,
-52.0f, -56.0f, -60.0f,
-64.0f, -72.0f, -80.0f,
-88.0f, -96.0f, -104.0f,
-112.0f, -120.0f, -128.0f,
-144.0f, -160.0f, -176.0f,
-192.0f, -208.0f, -224.0f,
-240.0f,
};
return e4m3fnuz_lut[input];
}
TEST_CASE(test_fp8_cast_to_float)
{
std::vector<uint8_t> bit_vals(256);
std::iota(bit_vals.begin(), bit_vals.end(), 0);
EXPECT(bool{std::all_of(bit_vals.begin(), bit_vals.end(), [](uint8_t bit_val) {
migraphx::fp8::fp8e4m3fnuz fp8_val(bit_val, migraphx::fp8::fp8e4m3fnuz::from_bits());
if(std::isnan(float(fp8_val)) and std::isnan(fp8e4m3fnuz_to_fp32_value(bit_val)))
{
return true;
}
return migraphx::float_equal(float(fp8_val), fp8e4m3fnuz_to_fp32_value(bit_val));
})});
}
TEST_CASE(test_fp8_cast_from_float)
{
std::unordered_map<float, uint8_t> test_vals = {{256, 0x7f}, {-256, 0xff},
{240, 0x7f}, {-240, 0xff},
{1e-07, 0x0}, {1e+07, 0x7f},
{1, 0x40}, {-1, 0xc0},
{0.1, 0x25}, {0.11, 0x26},
{0.111, 0x26}, {0.1111, 0x26},
{-0.1, 0xa5}, {-0.11, 0xa6},
{-0.111, 0xa6}, {-0.1111, 0xa6},
{0.2, 0x2d}, {2, 0x48},
{20, 0x62}, {200, 0x7c},
{-0.2, 0xad}, {-2, 0xc8},
{-20, 0xe2}, {-200, 0xfc},
{0.5, 0x38}, {-0.5, 0xb8},
{1.17549e-38, 0x0}, {1.4013e-45, 0x0},
{0.00390625, 0x4}, {-0.00390625, 0x84},
{0.00195312, 0x2}, {-0.00195312, 0x82},
{0.000976562, 0x1}, {-0.000976562, 0x81},
{0.000488281, 0x0}, {-0.000488281, 0x0}};
EXPECT(bool{std::all_of(test_vals.begin(), test_vals.end(), [](const auto sample) {
return migraphx::float_equal(
migraphx::fp8::fp8e4m3fnuz(sample.first),
migraphx::fp8::fp8e4m3fnuz(sample.second, migraphx::fp8::fp8e4m3fnuz::from_bits()));
})});
}
TEST_CASE(test_positive_zero)
{
float zero = 0.0;
migraphx::fp8::fp8e4m3fnuz fp8_zero(zero);
EXPECT(fp8_zero.is_zero());
EXPECT(migraphx::float_equal(zero, float(fp8_zero)));
}
TEST_CASE(test_negative_zero)
{
float nzero = -0.0;
float pzero = 0.0;
migraphx::fp8::fp8e4m3fnuz fp8_nzero(nzero);
EXPECT(fp8_nzero.is_zero());
// negative zero gets converted to positive zero
EXPECT(migraphx::float_equal(pzero, float(fp8_nzero)));
}
TEST_CASE(test_nan_1)
{
float fnan = std::numeric_limits<float>::quiet_NaN();
migraphx::fp8::fp8e4m3fnuz fp8_nan(fnan);
EXPECT(fp8_nan.is_nan());
EXPECT(std::isnan(fp8_nan));
}
TEST_CASE(test_nan_2)
{
auto fnan = std::numeric_limits<migraphx::fp8::fp8e4m3fnuz>::quiet_NaN();
migraphx::fp8::fp8e4m3fnuz fp8_nan(fnan.data, migraphx::fp8::fp8e4m3fnuz::from_bits());
EXPECT(fp8_nan.is_nan());
EXPECT(std::isnan(fp8_nan));
EXPECT(std::isnan(float(fp8_nan)));
}
TEST_CASE(test_infinity_1)
{
float finf = std::numeric_limits<float>::infinity();
// no inf in fp8e4m3fnuz it gets clipped to Nans
migraphx::fp8::fp8e4m3fnuz fp8_nan(finf);
EXPECT(fp8_nan.is_nan());
EXPECT(std::isnan(float(fp8_nan)));
}
TEST_CASE(test_infinity_2)
{
// neg inf
float finf = -1.0 * std::numeric_limits<float>::infinity();
// no inf in fp8e4m3fnuz it gets clipped to NaNs
migraphx::fp8::fp8e4m3fnuz fp8_nan(finf);
EXPECT(fp8_nan.is_nan());
EXPECT(std::isnan(float(fp8_nan)));
}
TEST_CASE(test_numeric_max_1)
{
float fmax = std::numeric_limits<float>::max();
migraphx::fp8::fp8e4m3fnuz fp8_max(fmax);
EXPECT(fp8_max == std::numeric_limits<migraphx::fp8::fp8e4m3fnuz>::max());
}
TEST_CASE(test_numeric_max_2)
{
// gets clipped to max
float fmax = 2 * std::numeric_limits<migraphx::fp8::fp8e4m3fnuz>::max();
migraphx::fp8::fp8e4m3fnuz fp8_max(fmax);
EXPECT(fp8_max == std::numeric_limits<migraphx::fp8::fp8e4m3fnuz>::max());
}
TEST_CASE(test_numeric_lowest_1)
{
float flowest = std::numeric_limits<float>::lowest();
migraphx::fp8::fp8e4m3fnuz fp8_lowest(flowest);
EXPECT(fp8_lowest == std::numeric_limits<migraphx::fp8::fp8e4m3fnuz>::lowest());
}
TEST_CASE(test_numeric_lowest_2)
{
// gets clipped to lowest
float fmin = 2.0 * std::numeric_limits<migraphx::fp8::fp8e4m3fnuz>::lowest();
migraphx::fp8::fp8e4m3fnuz fp8_lowest(fmin);
EXPECT(fp8_lowest == std::numeric_limits<migraphx::fp8::fp8e4m3fnuz>::lowest());
}
TEST_CASE(test_max_eq_lowest)
{
EXPECT(migraphx::float_equal(std::numeric_limits<migraphx::fp8::fp8e4m3fnuz>::lowest(),
-1 * std::numeric_limits<migraphx::fp8::fp8e4m3fnuz>::max()));
}
TEST_CASE(test_isfinite)
{
EXPECT(std::isfinite(migraphx::fp8::fp8e4m3fnuz(0.0)));
EXPECT(std::isfinite(migraphx::fp8::fp8e4m3fnuz(-0.0)));
EXPECT(not std::isfinite(
migraphx::fp8::fp8e4m3fnuz(std::numeric_limits<migraphx::fp8::fp8e4m3fnuz>::quiet_NaN())));
}
TEST_CASE(test_no_infinity)
{
EXPECT(not bool{std::numeric_limits<migraphx::fp8::fp8e4m3fnuz>::has_infinity});
}
TEST_CASE(test_binary_ops)
{
auto a = migraphx::fp8::fp8e4m3fnuz(-1.0);
auto b = migraphx::fp8::fp8e4m3fnuz(1.0);
auto c = migraphx::fp8::fp8e4m3fnuz(0.0);
auto d = migraphx::fp8::fp8e4m3fnuz(-0.0);
EXPECT(migraphx::float_equal((c + d), c));
EXPECT(migraphx::float_equal((c + d), d));
EXPECT(migraphx::float_equal((a + b), c));
EXPECT(migraphx::float_equal((a + b), d));
auto e = migraphx::fp8::fp8e4m3fnuz(10.0);
auto f = migraphx::fp8::fp8e4m3fnuz(-10.0);
EXPECT(bool{e > f});
EXPECT(bool{f < e});
EXPECT(bool{f <= e});
EXPECT(bool{e >= f});
EXPECT(bool{e <= e});
EXPECT(bool{f >= f});
EXPECT(not migraphx::float_equal(f, e));
}
TEST_CASE(test_fabs)
{
auto a = migraphx::fp8::fp8e4m3fnuz(-1.0);
auto b = migraphx::fp8::fp8e4m3fnuz(1.0);
EXPECT(migraphx::float_equal(b, migraphx::fp8::fabs(a)));
}
TEST_CASE(test_stream_op)
{
auto a = migraphx::fp8::fp8e4m3fnuz(-1.0);
std::stringstream ss;
ss << a;
EXPECT(std::string("-1") == ss.str());
ss = std::stringstream();
auto b = std::numeric_limits<migraphx::fp8::fp8e4m3fnuz>::quiet_NaN();
ss << b;
EXPECT(std::string("nan") == ss.str());
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment