Commit fed23ec7 authored by Artur Wojcik's avatar Artur Wojcik
Browse files

Merge branch 'develop' into uif2-initial

parents 9d933920 225873aa
...@@ -150,6 +150,7 @@ function(test_headers PREFIX) ...@@ -150,6 +150,7 @@ function(test_headers PREFIX)
list(REMOVE_ITEM HEADERS list(REMOVE_ITEM HEADERS
${CMAKE_SOURCE_DIR}/src/targets/gpu/include/migraphx/gpu/ck.hpp) ${CMAKE_SOURCE_DIR}/src/targets/gpu/include/migraphx/gpu/ck.hpp)
endif() endif()
list(REMOVE_ITEM HEADERS ${CMAKE_SOURCE_DIR}/src/include/migraphx/float8_impl.hpp)
foreach(HEADER ${HEADERS}) foreach(HEADER ${HEADERS})
file(RELATIVE_PATH HEADER_REL ${CMAKE_SOURCE_DIR} ${HEADER}) file(RELATIVE_PATH HEADER_REL ${CMAKE_SOURCE_DIR} ${HEADER})
string(MAKE_C_IDENTIFIER ${HEADER_REL} TEST_NAME) string(MAKE_C_IDENTIFIER ${HEADER_REL} TEST_NAME)
......
...@@ -22,6 +22,7 @@ ...@@ -22,6 +22,7 @@
* THE SOFTWARE. * THE SOFTWARE.
*/ */
#include <migraphx/float_equal.hpp> #include <migraphx/float_equal.hpp>
#include <migraphx/float8.hpp>
#include <migraphx/half.hpp> #include <migraphx/half.hpp>
#include "test.hpp" #include "test.hpp"
...@@ -53,7 +54,7 @@ auto test_float_equal(T x, U y) ...@@ -53,7 +54,7 @@ auto test_float_equal(T x, U y)
template <class T, class U> template <class T, class U>
void test_equality() void test_equality()
{ {
auto x1 = T(0.1); auto x1 = T(0.125);
auto x2 = U(0.0); auto x2 = U(0.0);
auto x3 = U(1.0); auto x3 = U(1.0);
EXPECT(test_float_equal(x1, x1)); EXPECT(test_float_equal(x1, x1));
...@@ -71,8 +72,12 @@ void test_equality() ...@@ -71,8 +72,12 @@ void test_equality()
TEST_CASE_REGISTER(test_equality<double, float>); TEST_CASE_REGISTER(test_equality<double, float>);
TEST_CASE_REGISTER(test_equality<double, int>); TEST_CASE_REGISTER(test_equality<double, int>);
TEST_CASE_REGISTER(test_equality<double, migraphx::half>); TEST_CASE_REGISTER(test_equality<double, migraphx::half>);
TEST_CASE_REGISTER(test_equality<double, migraphx::fp8::fp8e4m3fnuz>);
TEST_CASE_REGISTER(test_equality<float, int>); TEST_CASE_REGISTER(test_equality<float, int>);
TEST_CASE_REGISTER(test_equality<float, migraphx::fp8::fp8e4m3fnuz>);
TEST_CASE_REGISTER(test_equality<migraphx::half, int>); TEST_CASE_REGISTER(test_equality<migraphx::half, int>);
TEST_CASE_REGISTER(test_equality<migraphx::half, migraphx::fp8::fp8e4m3fnuz>);
TEST_CASE_REGISTER(test_equality<migraphx::fp8::fp8e4m3fnuz, int>);
template <class T, class U> template <class T, class U>
void test_limits() void test_limits()
...@@ -110,8 +115,13 @@ void test_limits() ...@@ -110,8 +115,13 @@ void test_limits()
TEST_CASE_REGISTER(test_limits<double, float>); TEST_CASE_REGISTER(test_limits<double, float>);
TEST_CASE_REGISTER(test_limits<double, int>); TEST_CASE_REGISTER(test_limits<double, int>);
TEST_CASE_REGISTER(test_limits<double, migraphx::half>); TEST_CASE_REGISTER(test_limits<double, migraphx::half>);
TEST_CASE_REGISTER(test_limits<double, migraphx::fp8::fp8e4m3fnuz>);
TEST_CASE_REGISTER(test_limits<float, int>); TEST_CASE_REGISTER(test_limits<float, int>);
TEST_CASE_REGISTER(test_limits<float, migraphx::fp8::fp8e4m3fnuz>);
TEST_CASE_REGISTER(test_limits<int, migraphx::half>); TEST_CASE_REGISTER(test_limits<int, migraphx::half>);
TEST_CASE_REGISTER(test_limits<int, migraphx::fp8::fp8e4m3fnuz>);
TEST_CASE_REGISTER(test_limits<migraphx::fp8::fp8e4m3fnuz, migraphx::half>);
#ifndef _WIN32 #ifndef _WIN32
// On Windows, types int and long have the same min and max values. // On Windows, types int and long have the same min and max values.
TEST_CASE_REGISTER(test_limits<long, int>); TEST_CASE_REGISTER(test_limits<long, int>);
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <cmath>
#include <migraphx/float_equal.hpp>
#include <migraphx/float8.hpp>
#include <migraphx/half.hpp>
#include <migraphx/ranges.hpp>
#include "test.hpp"
#include <limits>
// Reference table for the fp8e4m3fn tests: returns the float that the 8-bit pattern `input`
// is expected to decode to.
//   0x00-0x7e : +0.0 up to +448.0
//   0x7f      : NaN
//   0x80-0xfe : -0.0 down to -448.0
//   0xff      : NaN
float fp8e4m3fn_to_fp32_value(uint8_t input)
{
    // Named after the format it actually describes (e4m3fn, not e4m3fnuz) and made static so
    // the 256 entries are not re-materialised on every call.
    static constexpr std::array<float, 256> e4m3fn_lut = {
        0.0,          0.001953125,  0.00390625,   0.005859375,
        0.0078125,    0.009765625,  0.01171875,   0.013671875,
        0.015625,     0.017578125,  0.01953125,   0.021484375,
        0.0234375,    0.025390625,  0.02734375,   0.029296875,
        0.03125,      0.03515625,   0.0390625,    0.04296875,
        0.046875,     0.05078125,   0.0546875,    0.05859375,
        0.0625,       0.0703125,    0.078125,     0.0859375,
        0.09375,      0.1015625,    0.109375,     0.1171875,
        0.125,        0.140625,     0.15625,      0.171875,
        0.1875,       0.203125,     0.21875,      0.234375,
        0.25,         0.28125,      0.3125,       0.34375,
        0.375,        0.40625,      0.4375,       0.46875,
        0.5,          0.5625,       0.625,        0.6875,
        0.75,         0.8125,       0.875,        0.9375,
        1.0,          1.125,        1.25,         1.375,
        1.5,          1.625,        1.75,         1.875,
        2.0,          2.25,         2.5,          2.75,
        3.0,          3.25,         3.5,          3.75,
        4.0,          4.5,          5.0,          5.5,
        6.0,          6.5,          7.0,          7.5,
        8.0,          9.0,          10.0,         11.0,
        12.0,         13.0,         14.0,         15.0,
        16.0,         18.0,         20.0,         22.0,
        24.0,         26.0,         28.0,         30.0,
        32.0,         36.0,         40.0,         44.0,
        48.0,         52.0,         56.0,         60.0,
        64.0,         72.0,         80.0,         88.0,
        96.0,         104.0,        112.0,        120.0,
        128.0,        144.0,        160.0,        176.0,
        192.0,        208.0,        224.0,        240.0,
        256.0,        288.0,        320.0,        352.0,
        384.0,        416.0,        448.0,        std::numeric_limits<float>::quiet_NaN(),
        -0.0,         -0.001953125, -0.00390625,  -0.005859375,
        -0.0078125,   -0.009765625, -0.01171875,  -0.013671875,
        -0.015625,    -0.017578125, -0.01953125,  -0.021484375,
        -0.0234375,   -0.025390625, -0.02734375,  -0.029296875,
        -0.03125,     -0.03515625,  -0.0390625,   -0.04296875,
        -0.046875,    -0.05078125,  -0.0546875,   -0.05859375,
        -0.0625,      -0.0703125,   -0.078125,    -0.0859375,
        -0.09375,     -0.1015625,   -0.109375,    -0.1171875,
        -0.125,       -0.140625,    -0.15625,     -0.171875,
        -0.1875,      -0.203125,    -0.21875,     -0.234375,
        -0.25,        -0.28125,     -0.3125,      -0.34375,
        -0.375,       -0.40625,     -0.4375,      -0.46875,
        -0.5,         -0.5625,      -0.625,       -0.6875,
        -0.75,        -0.8125,      -0.875,       -0.9375,
        -1.0,         -1.125,       -1.25,        -1.375,
        -1.5,         -1.625,       -1.75,        -1.875,
        -2.0,         -2.25,        -2.5,         -2.75,
        -3.0,         -3.25,        -3.5,         -3.75,
        -4.0,         -4.5,         -5.0,         -5.5,
        -6.0,         -6.5,         -7.0,         -7.5,
        -8.0,         -9.0,         -10.0,        -11.0,
        -12.0,        -13.0,        -14.0,        -15.0,
        -16.0,        -18.0,        -20.0,        -22.0,
        -24.0,        -26.0,        -28.0,        -30.0,
        -32.0,        -36.0,        -40.0,        -44.0,
        -48.0,        -52.0,        -56.0,        -60.0,
        -64.0,        -72.0,        -80.0,        -88.0,
        -96.0,        -104.0,       -112.0,       -120.0,
        -128.0,       -144.0,       -160.0,       -176.0,
        -192.0,       -208.0,       -224.0,       -240.0,
        -256.0,       -288.0,       -320.0,       -352.0,
        -384.0,       -416.0,       -448.0,       std::numeric_limits<float>::quiet_NaN(),
    };
    // uint8_t can only hold 0..255, so the index is always within the table.
    return e4m3fn_lut[input];
}
// Exhaustive decode check: for every bit pattern 0..255, the float value of an fp8e4m3fn
// constructed from those raw bits must agree with fp8e4m3fn_to_fp32_value.
TEST_CASE(test_fp8_cast_to_float)
{
// All 256 possible 8-bit patterns, in increasing order.
std::vector<uint8_t> bit_vals(256);
std::iota(bit_vals.begin(), bit_vals.end(), 0);
EXPECT(bool{std::all_of(bit_vals.begin(), bit_vals.end(), [](uint8_t bit_val) {
migraphx::fp8::fp8e4m3fn fp8_val(bit_val, migraphx::fp8::fp8e4m3fn::from_bits());
if(std::isnan(float(fp8_val)) and std::isnan(fp8e4m3fn_to_fp32_value(bit_val)))
{
return true;
}
return migraphx::float_equal(float(fp8_val), fp8e4m3fn_to_fp32_value(bit_val));
})});
}
// Encode check: each sample float, converted to fp8e4m3fn, must compare equal to the fp8e4m3fn
// built from the paired bit pattern. The samples include magnitudes above 448 (expected 0x7e /
// 0xfe) and very small magnitudes (expected 0x0 / 0x80).
TEST_CASE(test_fp8_cast_from_float)
{
// key: input float, value: expected fp8e4m3fn bit pattern
std::unordered_map<float, uint8_t> test_vals = {
{{512, 0x7e}, {-512, 0xfe}, {448, 0x7e}, {-448, 0xfe},
{256, 0x78}, {-256, 0xf8}, {240, 0x77}, {-240, 0xf7},
{1e-07, 0x0}, {1e+07, 0x7e}, {1, 0x38}, {-1, 0xb8},
{0.1, 0x1d}, {0.11, 0x1e}, {0.111, 0x1e}, {0.1111, 0x1e},
{-0.1, 0x9d}, {-0.11, 0x9e}, {-0.111, 0x9e}, {-0.1111, 0x9e},
{0.2, 0x25}, {2, 0x40}, {20, 0x5a}, {200, 0x74},
{-0.2, 0xa5}, {-2, 0xc0}, {-20, 0xda}, {-200, 0xf4},
{0.5, 0x30}, {-0.5, 0xb0}, {1.17549e-38, 0x0}, {1.4013e-45, 0x0},
{0.0078125, 0x4}, {-0.0078125, 0x84}, {0.000976562, 0x0}, {-0.000976562, 0x80},
{0.000488281, 0x0}, {-0.000488281, 0x80}}};
EXPECT(bool{std::all_of(test_vals.begin(), test_vals.end(), [](const auto sample) {
return migraphx::float_equal(
migraphx::fp8::fp8e4m3fn(sample.first),
migraphx::fp8::fp8e4m3fn(sample.second, migraphx::fp8::fp8e4m3fn::from_bits()));
})});
}
// Float +0.0 converted to fp8e4m3fn must report is_zero() and convert back to a float that
// compares equal to the original zero.
TEST_CASE(test_positive_zero)
{
float zero = 0.0;
migraphx::fp8::fp8e4m3fn fp8_zero(zero);
EXPECT(fp8_zero.is_zero());
EXPECT(migraphx::float_equal(zero, float(fp8_zero)));
}
// Float -0.0 converted to fp8e4m3fn must report is_zero() and convert back to a float that
// compares equal to -0.0 under float_equal.
TEST_CASE(test_negative_zero)
{
float nzero = -0.0;
migraphx::fp8::fp8e4m3fn fp8_nzero(nzero);
EXPECT(fp8_nzero.is_zero());
// negative zero is preserved for fp8e4m3fn
EXPECT(migraphx::float_equal(nzero, float(fp8_nzero)));
}
// fp8e4m3fn values built from -0.0 and +0.0 must compare equal with operator==.
// This suite covers fp8e4m3fn, so the type under test is fp8e4m3fn; the equivalent fp8e5m2
// check lives in the fp8e5m2 test file.
TEST_CASE(test_pos_zero_eq_neg_zero)
{
    float nzero = -0.0;
    float pzero = 0.0;
    migraphx::fp8::fp8e4m3fn fp8_nzero(nzero);
    migraphx::fp8::fp8e4m3fn fp8_pzero(pzero);
    EXPECT(fp8_nzero == fp8_pzero);
}
// An fp8e4m3fn constructed from a float quiet NaN must be reported as NaN by both is_nan()
// and std::isnan.
TEST_CASE(test_nan_1)
{
float fnan = std::numeric_limits<float>::quiet_NaN();
migraphx::fp8::fp8e4m3fn fp8_nan(fnan);
EXPECT(fp8_nan.is_nan());
EXPECT(std::isnan(fp8_nan));
}
// Rebuilds numeric_limits<fp8e4m3fn>::quiet_NaN() from its raw `data` bits and checks that the
// result is NaN via is_nan(), std::isnan, and after conversion to float.
TEST_CASE(test_nan_2)
{
auto fnan = std::numeric_limits<migraphx::fp8::fp8e4m3fn>::quiet_NaN();
migraphx::fp8::fp8e4m3fn fp8_nan(fnan.data, migraphx::fp8::fp8e4m3fn::from_bits());
EXPECT(fp8_nan.is_nan());
EXPECT(std::isnan(fp8_nan));
EXPECT(std::isnan(float(fp8_nan)));
}
// Converting float +infinity to fp8e4m3fn must yield numeric_limits<fp8e4m3fn>::max().
TEST_CASE(test_infinity_1)
{
float finf = std::numeric_limits<float>::infinity();
// no inf in fp8e4m3fn, it gets clipped to max()
migraphx::fp8::fp8e4m3fn fp8_max(finf);
EXPECT(fp8_max == std::numeric_limits<migraphx::fp8::fp8e4m3fn>::max());
}
// Converting float -infinity to fp8e4m3fn must yield numeric_limits<fp8e4m3fn>::lowest().
TEST_CASE(test_infinity_2)
{
// neg inf
float finf = -1.0 * std::numeric_limits<float>::infinity();
// no inf in fp8e4m3fn, it gets clipped to lowest
migraphx::fp8::fp8e4m3fn fp8_lowest(finf);
EXPECT(bool{fp8_lowest == std::numeric_limits<migraphx::fp8::fp8e4m3fn>::lowest()});
}
// The largest finite float, converted to fp8e4m3fn, must equal numeric_limits<fp8e4m3fn>::max().
TEST_CASE(test_numeric_max_1)
{
float fmax = std::numeric_limits<float>::max();
migraphx::fp8::fp8e4m3fn fp8_max(fmax);
EXPECT(fp8_max == std::numeric_limits<migraphx::fp8::fp8e4m3fn>::max());
}
// Twice the fp8e4m3fn maximum, converted back to fp8e4m3fn, must still equal max().
TEST_CASE(test_numeric_max_2)
{
// gets clipped to max
float fmax = 2 * std::numeric_limits<migraphx::fp8::fp8e4m3fn>::max();
migraphx::fp8::fp8e4m3fn fp8_max(fmax);
EXPECT(fp8_max == std::numeric_limits<migraphx::fp8::fp8e4m3fn>::max());
}
// The lowest finite float, converted to fp8e4m3fn, must equal numeric_limits<fp8e4m3fn>::lowest().
TEST_CASE(test_numeric_lowest_1)
{
float flowest = std::numeric_limits<float>::lowest();
migraphx::fp8::fp8e4m3fn fp8_lowest(flowest);
EXPECT(fp8_lowest == std::numeric_limits<migraphx::fp8::fp8e4m3fn>::lowest());
}
// Twice the fp8e4m3fn lowest value, converted back to fp8e4m3fn, must still equal lowest().
TEST_CASE(test_numeric_lowest_2)
{
// gets clipped to lowest
float fmin = 2.0 * std::numeric_limits<migraphx::fp8::fp8e4m3fn>::lowest();
migraphx::fp8::fp8e4m3fn fp8_lowest(fmin);
EXPECT(fp8_lowest == std::numeric_limits<migraphx::fp8::fp8e4m3fn>::lowest());
}
// lowest() must compare equal to -1 * max() for fp8e4m3fn.
TEST_CASE(test_max_eq_lowest)
{
EXPECT(migraphx::float_equal(std::numeric_limits<migraphx::fp8::fp8e4m3fn>::lowest(),
-1 * std::numeric_limits<migraphx::fp8::fp8e4m3fn>::max()));
}
// std::isfinite must be true for fp8e4m3fn +0.0 and -0.0 and false for its quiet NaN.
TEST_CASE(test_isfinite)
{
EXPECT(std::isfinite(migraphx::fp8::fp8e4m3fn(0.0)));
EXPECT(std::isfinite(migraphx::fp8::fp8e4m3fn(-0.0)));
EXPECT(not std::isfinite(
migraphx::fp8::fp8e4m3fn(std::numeric_limits<migraphx::fp8::fp8e4m3fn>::quiet_NaN())));
}
// numeric_limits<fp8e4m3fn>::has_infinity must be false.
TEST_CASE(test_no_infinity)
{
EXPECT(not bool{std::numeric_limits<migraphx::fp8::fp8e4m3fn>::has_infinity});
}
// Arithmetic and comparison checks for fp8e4m3fn: sums that should be zero must compare equal
// to both +0.0 and -0.0, and the relational operators must order 10 and -10 correctly.
TEST_CASE(test_binary_ops)
{
auto a = migraphx::fp8::fp8e4m3fn(-1.0);
auto b = migraphx::fp8::fp8e4m3fn(1.0);
auto c = migraphx::fp8::fp8e4m3fn(0.0);
auto d = migraphx::fp8::fp8e4m3fn(-0.0);
// 0 + (-0) and (-1) + 1 are each compared against both signed zeros.
EXPECT(migraphx::float_equal((c + d), c));
EXPECT(migraphx::float_equal((c + d), d));
EXPECT(migraphx::float_equal((a + b), c));
EXPECT(migraphx::float_equal((a + b), d));
auto e = migraphx::fp8::fp8e4m3fn(10.0);
auto f = migraphx::fp8::fp8e4m3fn(-10.0);
EXPECT(bool{e > f});
EXPECT(bool{f < e});
EXPECT(bool{f <= e});
EXPECT(bool{e >= f});
EXPECT(bool{e <= e});
EXPECT(bool{f >= f});
EXPECT(not migraphx::float_equal(f, e));
}
// migraphx::fp8::fabs applied to fp8e4m3fn(-1.0) must compare equal to fp8e4m3fn(1.0).
TEST_CASE(test_fabs)
{
auto a = migraphx::fp8::fp8e4m3fn(-1.0);
auto b = migraphx::fp8::fp8e4m3fn(1.0);
EXPECT(migraphx::float_equal(b, migraphx::fp8::fabs(a)));
}
// Stream insertion of fp8e4m3fn: -1.0 must print as "-1" and the quiet NaN as "nan".
TEST_CASE(test_stream_op)
{
auto a = migraphx::fp8::fp8e4m3fn(-1.0);
std::stringstream ss;
ss << a;
EXPECT(std::string("-1") == ss.str());
// Start from a fresh stream so the second check sees only the NaN output.
ss = std::stringstream();
auto b = std::numeric_limits<migraphx::fp8::fp8e4m3fn>::quiet_NaN();
ss << b;
EXPECT(std::string("nan") == ss.str());
}
// Entry point of the fp8e4m3fn test binary: passes the command-line arguments to test::run.
int main(int argc, const char* argv[])
{
    test::run(argc, argv);
}
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <cmath>
#include <migraphx/float_equal.hpp>
#include <migraphx/float8.hpp>
#include <migraphx/half.hpp>
#include <migraphx/ranges.hpp>
#include "test.hpp"
#include <limits>
float fp8e4m3fnuz_to_fp32_value(uint8_t input)
{
constexpr std::array<float, 256> e4m3fnuz_lut = {
0.0f, 0.0009765625f, 0.001953125f,
0.0029296875f, 0.00390625f, 0.0048828125f,
0.005859375f, 0.0068359375f, 0.0078125f,
0.0087890625f, 0.009765625f, 0.0107421875f,
0.01171875f, 0.0126953125f, 0.013671875f,
0.0146484375f, 0.015625f, 0.017578125f,
0.01953125f, 0.021484375f, 0.0234375f,
0.025390625f, 0.02734375f, 0.029296875f,
0.03125f, 0.03515625f, 0.0390625f,
0.04296875f, 0.046875f, 0.05078125f,
0.0546875f, 0.05859375f, 0.0625f,
0.0703125f, 0.078125f, 0.0859375f,
0.09375f, 0.1015625f, 0.109375f,
0.1171875f, 0.125f, 0.140625f,
0.15625f, 0.171875f, 0.1875f,
0.203125f, 0.21875f, 0.234375f,
0.25f, 0.28125f, 0.3125f,
0.34375f, 0.375f, 0.40625f,
0.4375f, 0.46875f, 0.5f,
0.5625f, 0.625f, 0.6875f,
0.75f, 0.8125f, 0.875f,
0.9375f, 1.0f, 1.125f,
1.25f, 1.375f, 1.5f,
1.625f, 1.75f, 1.875f,
2.0f, 2.25f, 2.5f,
2.75f, 3.0f, 3.25f,
3.5f, 3.75f, 4.0f,
4.5f, 5.0f, 5.5f,
6.0f, 6.5f, 7.0f,
7.5f, 8.0f, 9.0f,
10.0f, 11.0f, 12.0f,
13.0f, 14.0f, 15.0f,
16.0f, 18.0f, 20.0f,
22.0f, 24.0f, 26.0f,
28.0f, 30.0f, 32.0f,
36.0f, 40.0f, 44.0f,
48.0f, 52.0f, 56.0f,
60.0f, 64.0f, 72.0f,
80.0f, 88.0f, 96.0f,
104.0f, 112.0f, 120.0f,
128.0f, 144.0f, 160.0f,
176.0f, 192.0f, 208.0f,
224.0f, 240.0f, std::numeric_limits<float>::quiet_NaN(),
-0.0009765625f, -0.001953125f, -0.0029296875f,
-0.00390625f, -0.0048828125f, -0.005859375f,
-0.0068359375f, -0.0078125f, -0.0087890625f,
-0.009765625f, -0.0107421875f, -0.01171875f,
-0.0126953125f, -0.013671875f, -0.0146484375f,
-0.015625f, -0.017578125f, -0.01953125f,
-0.021484375f, -0.0234375f, -0.025390625f,
-0.02734375f, -0.029296875f, -0.03125f,
-0.03515625f, -0.0390625f, -0.04296875f,
-0.046875f, -0.05078125f, -0.0546875f,
-0.05859375f, -0.0625f, -0.0703125f,
-0.078125f, -0.0859375f, -0.09375f,
-0.1015625f, -0.109375f, -0.1171875f,
-0.125f, -0.140625f, -0.15625f,
-0.171875f, -0.1875f, -0.203125f,
-0.21875f, -0.234375f, -0.25f,
-0.28125f, -0.3125f, -0.34375f,
-0.375f, -0.40625f, -0.4375f,
-0.46875f, -0.5f, -0.5625f,
-0.625f, -0.6875f, -0.75f,
-0.8125f, -0.875f, -0.9375f,
-1.0f, -1.125f, -1.25f,
-1.375f, -1.5f, -1.625f,
-1.75f, -1.875f, -2.0f,
-2.25f, -2.5f, -2.75f,
-3.0f, -3.25f, -3.5f,
-3.75f, -4.0f, -4.5f,
-5.0f, -5.5f, -6.0f,
-6.5f, -7.0f, -7.5f,
-8.0f, -9.0f, -10.0f,
-11.0f, -12.0f, -13.0f,
-14.0f, -15.0f, -16.0f,
-18.0f, -20.0f, -22.0f,
-24.0f, -26.0f, -28.0f,
-30.0f, -32.0f, -36.0f,
-40.0f, -44.0f, -48.0f,
-52.0f, -56.0f, -60.0f,
-64.0f, -72.0f, -80.0f,
-88.0f, -96.0f, -104.0f,
-112.0f, -120.0f, -128.0f,
-144.0f, -160.0f, -176.0f,
-192.0f, -208.0f, -224.0f,
-240.0f,
};
return e4m3fnuz_lut[input];
}
// Exhaustive decode check: for every bit pattern 0..255, the float value of an fp8e4m3fnuz
// constructed from those raw bits must agree with fp8e4m3fnuz_to_fp32_value.
TEST_CASE(test_fp8_cast_to_float)
{
// All 256 possible 8-bit patterns, in increasing order.
std::vector<uint8_t> bit_vals(256);
std::iota(bit_vals.begin(), bit_vals.end(), 0);
EXPECT(bool{std::all_of(bit_vals.begin(), bit_vals.end(), [](uint8_t bit_val) {
migraphx::fp8::fp8e4m3fnuz fp8_val(bit_val, migraphx::fp8::fp8e4m3fnuz::from_bits());
if(std::isnan(float(fp8_val)) and std::isnan(fp8e4m3fnuz_to_fp32_value(bit_val)))
{
return true;
}
return migraphx::float_equal(float(fp8_val), fp8e4m3fnuz_to_fp32_value(bit_val));
})});
}
// Encode check: each sample float, converted to fp8e4m3fnuz, must compare equal to the
// fp8e4m3fnuz built from the paired bit pattern. Magnitudes of 240 and above are expected to
// map to 0x7f / 0xff, and the smallest magnitudes (positive or negative) to 0x0.
TEST_CASE(test_fp8_cast_from_float)
{
// key: input float, value: expected fp8e4m3fnuz bit pattern
std::unordered_map<float, uint8_t> test_vals = {{256, 0x7f}, {-256, 0xff},
{240, 0x7f}, {-240, 0xff},
{1e-07, 0x0}, {1e+07, 0x7f},
{1, 0x40}, {-1, 0xc0},
{0.1, 0x25}, {0.11, 0x26},
{0.111, 0x26}, {0.1111, 0x26},
{-0.1, 0xa5}, {-0.11, 0xa6},
{-0.111, 0xa6}, {-0.1111, 0xa6},
{0.2, 0x2d}, {2, 0x48},
{20, 0x62}, {200, 0x7c},
{-0.2, 0xad}, {-2, 0xc8},
{-20, 0xe2}, {-200, 0xfc},
{0.5, 0x38}, {-0.5, 0xb8},
{1.17549e-38, 0x0}, {1.4013e-45, 0x0},
{0.00390625, 0x4}, {-0.00390625, 0x84},
{0.00195312, 0x2}, {-0.00195312, 0x82},
{0.000976562, 0x1}, {-0.000976562, 0x81},
{0.000488281, 0x0}, {-0.000488281, 0x0}};
EXPECT(bool{std::all_of(test_vals.begin(), test_vals.end(), [](const auto sample) {
return migraphx::float_equal(
migraphx::fp8::fp8e4m3fnuz(sample.first),
migraphx::fp8::fp8e4m3fnuz(sample.second, migraphx::fp8::fp8e4m3fnuz::from_bits()));
})});
}
// Float +0.0 converted to fp8e4m3fnuz must report is_zero() and convert back to a float that
// compares equal to the original zero.
TEST_CASE(test_positive_zero)
{
float zero = 0.0;
migraphx::fp8::fp8e4m3fnuz fp8_zero(zero);
EXPECT(fp8_zero.is_zero());
EXPECT(migraphx::float_equal(zero, float(fp8_zero)));
}
// Float -0.0 converted to fp8e4m3fnuz must report is_zero(); the round-tripped float is
// compared against +0.0.
TEST_CASE(test_negative_zero)
{
float nzero = -0.0;
float pzero = 0.0;
migraphx::fp8::fp8e4m3fnuz fp8_nzero(nzero);
EXPECT(fp8_nzero.is_zero());
// negative zero gets converted to positive zero
EXPECT(migraphx::float_equal(pzero, float(fp8_nzero)));
}
// An fp8e4m3fnuz constructed from a float quiet NaN must be reported as NaN by both is_nan()
// and std::isnan.
TEST_CASE(test_nan_1)
{
float fnan = std::numeric_limits<float>::quiet_NaN();
migraphx::fp8::fp8e4m3fnuz fp8_nan(fnan);
EXPECT(fp8_nan.is_nan());
EXPECT(std::isnan(fp8_nan));
}
// Rebuilds numeric_limits<fp8e4m3fnuz>::quiet_NaN() from its raw `data` bits and checks that
// the result is NaN via is_nan(), std::isnan, and after conversion to float.
TEST_CASE(test_nan_2)
{
auto fnan = std::numeric_limits<migraphx::fp8::fp8e4m3fnuz>::quiet_NaN();
migraphx::fp8::fp8e4m3fnuz fp8_nan(fnan.data, migraphx::fp8::fp8e4m3fnuz::from_bits());
EXPECT(fp8_nan.is_nan());
EXPECT(std::isnan(fp8_nan));
EXPECT(std::isnan(float(fp8_nan)));
}
// Converting float +infinity to fp8e4m3fnuz must produce NaN, both via is_nan() and after
// conversion back to float.
TEST_CASE(test_infinity_1)
{
float finf = std::numeric_limits<float>::infinity();
// there is no inf in fp8e4m3fnuz; infinity is converted to NaN
migraphx::fp8::fp8e4m3fnuz fp8_nan(finf);
EXPECT(fp8_nan.is_nan());
EXPECT(std::isnan(float(fp8_nan)));
}
// Converting float -infinity to fp8e4m3fnuz must produce NaN, both via is_nan() and after
// conversion back to float.
TEST_CASE(test_infinity_2)
{
// neg inf
float finf = -1.0 * std::numeric_limits<float>::infinity();
// there is no inf in fp8e4m3fnuz; negative infinity is converted to NaN
migraphx::fp8::fp8e4m3fnuz fp8_nan(finf);
EXPECT(fp8_nan.is_nan());
EXPECT(std::isnan(float(fp8_nan)));
}
// The largest finite float, converted to fp8e4m3fnuz, must equal
// numeric_limits<fp8e4m3fnuz>::max().
TEST_CASE(test_numeric_max_1)
{
float fmax = std::numeric_limits<float>::max();
migraphx::fp8::fp8e4m3fnuz fp8_max(fmax);
EXPECT(fp8_max == std::numeric_limits<migraphx::fp8::fp8e4m3fnuz>::max());
}
// Twice the fp8e4m3fnuz maximum, converted back to fp8e4m3fnuz, must still equal max().
TEST_CASE(test_numeric_max_2)
{
// gets clipped to max
float fmax = 2 * std::numeric_limits<migraphx::fp8::fp8e4m3fnuz>::max();
migraphx::fp8::fp8e4m3fnuz fp8_max(fmax);
EXPECT(fp8_max == std::numeric_limits<migraphx::fp8::fp8e4m3fnuz>::max());
}
// The lowest finite float, converted to fp8e4m3fnuz, must equal
// numeric_limits<fp8e4m3fnuz>::lowest().
TEST_CASE(test_numeric_lowest_1)
{
float flowest = std::numeric_limits<float>::lowest();
migraphx::fp8::fp8e4m3fnuz fp8_lowest(flowest);
EXPECT(fp8_lowest == std::numeric_limits<migraphx::fp8::fp8e4m3fnuz>::lowest());
}
// Twice the fp8e4m3fnuz lowest value, converted back to fp8e4m3fnuz, must still equal lowest().
TEST_CASE(test_numeric_lowest_2)
{
// gets clipped to lowest
float fmin = 2.0 * std::numeric_limits<migraphx::fp8::fp8e4m3fnuz>::lowest();
migraphx::fp8::fp8e4m3fnuz fp8_lowest(fmin);
EXPECT(fp8_lowest == std::numeric_limits<migraphx::fp8::fp8e4m3fnuz>::lowest());
}
// lowest() must compare equal to -1 * max() for fp8e4m3fnuz.
TEST_CASE(test_max_eq_lowest)
{
EXPECT(migraphx::float_equal(std::numeric_limits<migraphx::fp8::fp8e4m3fnuz>::lowest(),
-1 * std::numeric_limits<migraphx::fp8::fp8e4m3fnuz>::max()));
}
// std::isfinite must be true for fp8e4m3fnuz values built from 0.0 and -0.0 and false for its
// quiet NaN.
TEST_CASE(test_isfinite)
{
EXPECT(std::isfinite(migraphx::fp8::fp8e4m3fnuz(0.0)));
EXPECT(std::isfinite(migraphx::fp8::fp8e4m3fnuz(-0.0)));
EXPECT(not std::isfinite(
migraphx::fp8::fp8e4m3fnuz(std::numeric_limits<migraphx::fp8::fp8e4m3fnuz>::quiet_NaN())));
}
// numeric_limits<fp8e4m3fnuz>::has_infinity must be false.
TEST_CASE(test_no_infinity)
{
EXPECT(not bool{std::numeric_limits<migraphx::fp8::fp8e4m3fnuz>::has_infinity});
}
// Arithmetic and comparison checks for fp8e4m3fnuz: sums that should be zero must compare
// equal to the values built from 0.0 and -0.0, and the relational operators must order 10 and
// -10 correctly.
TEST_CASE(test_binary_ops)
{
auto a = migraphx::fp8::fp8e4m3fnuz(-1.0);
auto b = migraphx::fp8::fp8e4m3fnuz(1.0);
auto c = migraphx::fp8::fp8e4m3fnuz(0.0);
auto d = migraphx::fp8::fp8e4m3fnuz(-0.0);
// c + d and a + b are each compared against both c and d.
EXPECT(migraphx::float_equal((c + d), c));
EXPECT(migraphx::float_equal((c + d), d));
EXPECT(migraphx::float_equal((a + b), c));
EXPECT(migraphx::float_equal((a + b), d));
auto e = migraphx::fp8::fp8e4m3fnuz(10.0);
auto f = migraphx::fp8::fp8e4m3fnuz(-10.0);
EXPECT(bool{e > f});
EXPECT(bool{f < e});
EXPECT(bool{f <= e});
EXPECT(bool{e >= f});
EXPECT(bool{e <= e});
EXPECT(bool{f >= f});
EXPECT(not migraphx::float_equal(f, e));
}
// migraphx::fp8::fabs applied to fp8e4m3fnuz(-1.0) must compare equal to fp8e4m3fnuz(1.0).
TEST_CASE(test_fabs)
{
auto a = migraphx::fp8::fp8e4m3fnuz(-1.0);
auto b = migraphx::fp8::fp8e4m3fnuz(1.0);
EXPECT(migraphx::float_equal(b, migraphx::fp8::fabs(a)));
}
// Stream insertion of fp8e4m3fnuz: -1.0 must print as "-1" and the quiet NaN as "nan".
TEST_CASE(test_stream_op)
{
auto a = migraphx::fp8::fp8e4m3fnuz(-1.0);
std::stringstream ss;
ss << a;
EXPECT(std::string("-1") == ss.str());
// Start from a fresh stream so the second check sees only the NaN output.
ss = std::stringstream();
auto b = std::numeric_limits<migraphx::fp8::fp8e4m3fnuz>::quiet_NaN();
ss << b;
EXPECT(std::string("nan") == ss.str());
}
// Entry point of the fp8e4m3fnuz test binary: passes the command-line arguments to test::run.
int main(int argc, const char* argv[])
{
    test::run(argc, argv);
}
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <cmath>
#include <migraphx/float_equal.hpp>
#include <migraphx/float8.hpp>
#include <migraphx/half.hpp>
#include <migraphx/ranges.hpp>
#include "test.hpp"
#include <limits>
#include <sstream>
// Reference table for the fp8e5m2 tests: returns the float that the 8-bit pattern `input` is
// expected to decode to.
//   0x00-0x7b : +0.0 up to +57344.0
//   0x7c      : +infinity
//   0x7d-0x7f : NaN
//   0x80-0xfb : -0.0 down to -57344.0
//   0xfc      : -infinity
//   0xfd-0xff : NaN
float fp8e5m2_to_fp32_value(uint8_t input)
{
    // Named after the format it actually describes (e5m2, not e4m3fnuz) and made static so the
    // 256 entries are not re-materialised on every call. Each row holds four consecutive
    // bit patterns.
    static constexpr std::array<float, 256> e5m2_lut = {
        0.0,               1.52587890625e-05,  3.0517578125e-05,  4.57763671875e-05,
        6.103515625e-05,   7.62939453125e-05,  9.1552734375e-05,  0.0001068115234375,
        0.0001220703125,   0.000152587890625,  0.00018310546875,  0.000213623046875,
        0.000244140625,    0.00030517578125,   0.0003662109375,   0.00042724609375,
        0.00048828125,     0.0006103515625,    0.000732421875,    0.0008544921875,
        0.0009765625,      0.001220703125,     0.00146484375,     0.001708984375,
        0.001953125,       0.00244140625,      0.0029296875,      0.00341796875,
        0.00390625,        0.0048828125,       0.005859375,       0.0068359375,
        0.0078125,         0.009765625,        0.01171875,        0.013671875,
        0.015625,          0.01953125,         0.0234375,         0.02734375,
        0.03125,           0.0390625,          0.046875,          0.0546875,
        0.0625,            0.078125,           0.09375,           0.109375,
        0.125,             0.15625,            0.1875,            0.21875,
        0.25,              0.3125,             0.375,             0.4375,
        0.5,               0.625,              0.75,              0.875,
        1.0,               1.25,               1.5,               1.75,
        2.0,               2.5,                3.0,               3.5,
        4.0,               5.0,                6.0,               7.0,
        8.0,               10.0,               12.0,              14.0,
        16.0,              20.0,               24.0,              28.0,
        32.0,              40.0,               48.0,              56.0,
        64.0,              80.0,               96.0,              112.0,
        128.0,             160.0,              192.0,             224.0,
        256.0,             320.0,              384.0,             448.0,
        512.0,             640.0,              768.0,             896.0,
        1024.0,            1280.0,             1536.0,            1792.0,
        2048.0,            2560.0,             3072.0,            3584.0,
        4096.0,            5120.0,             6144.0,            7168.0,
        8192.0,            10240.0,            12288.0,           14336.0,
        16384.0,           20480.0,            24576.0,           28672.0,
        32768.0,           40960.0,            49152.0,           57344.0,
        std::numeric_limits<float>::infinity(),
        std::numeric_limits<float>::quiet_NaN(),
        std::numeric_limits<float>::quiet_NaN(),
        std::numeric_limits<float>::quiet_NaN(),
        -0.0,              -1.52587890625e-05, -3.0517578125e-05, -4.57763671875e-05,
        -6.103515625e-05,  -7.62939453125e-05, -9.1552734375e-05, -0.0001068115234375,
        -0.0001220703125,  -0.000152587890625, -0.00018310546875, -0.000213623046875,
        -0.000244140625,   -0.00030517578125,  -0.0003662109375,  -0.00042724609375,
        -0.00048828125,    -0.0006103515625,   -0.000732421875,   -0.0008544921875,
        -0.0009765625,     -0.001220703125,    -0.00146484375,    -0.001708984375,
        -0.001953125,      -0.00244140625,     -0.0029296875,     -0.00341796875,
        -0.00390625,       -0.0048828125,      -0.005859375,      -0.0068359375,
        -0.0078125,        -0.009765625,       -0.01171875,       -0.013671875,
        -0.015625,         -0.01953125,        -0.0234375,        -0.02734375,
        -0.03125,          -0.0390625,         -0.046875,         -0.0546875,
        -0.0625,           -0.078125,          -0.09375,          -0.109375,
        -0.125,            -0.15625,           -0.1875,           -0.21875,
        -0.25,             -0.3125,            -0.375,            -0.4375,
        -0.5,              -0.625,             -0.75,             -0.875,
        -1.0,              -1.25,              -1.5,              -1.75,
        -2.0,              -2.5,               -3.0,              -3.5,
        -4.0,              -5.0,               -6.0,              -7.0,
        -8.0,              -10.0,              -12.0,             -14.0,
        -16.0,             -20.0,              -24.0,             -28.0,
        -32.0,             -40.0,              -48.0,             -56.0,
        -64.0,             -80.0,              -96.0,             -112.0,
        -128.0,            -160.0,             -192.0,            -224.0,
        -256.0,            -320.0,             -384.0,            -448.0,
        -512.0,            -640.0,             -768.0,            -896.0,
        -1024.0,           -1280.0,            -1536.0,           -1792.0,
        -2048.0,           -2560.0,            -3072.0,           -3584.0,
        -4096.0,           -5120.0,            -6144.0,           -7168.0,
        -8192.0,           -10240.0,           -12288.0,          -14336.0,
        -16384.0,          -20480.0,           -24576.0,          -28672.0,
        -32768.0,          -40960.0,           -49152.0,          -57344.0,
        -1.0f * std::numeric_limits<float>::infinity(),
        std::numeric_limits<float>::quiet_NaN(),
        std::numeric_limits<float>::quiet_NaN(),
        std::numeric_limits<float>::quiet_NaN(),
    };
    // uint8_t can only hold 0..255, so the index is always within the table.
    return e5m2_lut[input];
}
// Exhaustive decode check: for every bit pattern 0..255, the float value of an fp8e5m2
// constructed from those raw bits must agree with fp8e5m2_to_fp32_value. A pattern passes
// outright when both sides are NaN or both are infinite; otherwise float_equal decides.
TEST_CASE(test_fp8_cast_to_float)
{
// All 256 possible 8-bit patterns, in increasing order.
std::vector<uint8_t> bit_vals(256);
std::iota(bit_vals.begin(), bit_vals.end(), 0);
EXPECT(bool{std::all_of(bit_vals.begin(), bit_vals.end(), [](uint8_t bit_val) {
migraphx::fp8::fp8e5m2 fp8_val(bit_val, migraphx::fp8::fp8e5m2::from_bits());
if(std::isnan(float(fp8_val)) and std::isnan(fp8e5m2_to_fp32_value(bit_val)))
{
return true;
}
else if(std::isinf(float(fp8_val)) and std::isinf(fp8e5m2_to_fp32_value(bit_val)))
{
return true;
}
return migraphx::float_equal(float(fp8_val), fp8e5m2_to_fp32_value(bit_val));
})});
}
// Encode check: each sample float, converted to fp8e5m2, must compare equal to the fp8e5m2
// built from the paired bit pattern. Magnitudes of 57344 and above are expected to map to
// 0x7b / 0xfb, and the smallest magnitudes to 0x0 / 0x80.
TEST_CASE(test_fp8_cast_from_float)
{
// key: input float, value: expected fp8e5m2 bit pattern
std::unordered_map<float, uint8_t> test_vals = {
{-60000, 0xfb},
{-57344, 0xfb},
{-448, 0xdf},
{-256, 0xdc},
{-240, 0xdc},
{-200, 0xda},
{-20, 0xcd},
{-2, 0xc0},
{-1, 0xbc},
{-0.5, 0xb8},
{-0.2, 0xb2},
{-0.1111, 0xaf},
{-0.111, 0xaf},
{-0.11, 0xaf},
{-0.1, 0xae},
{6.10351e-05, 0x4},
{-6.10351e-05, 0x84},
{3.05176e-05, 0x2},
{-3.05176e-05, 0x82},
{1.52588e-05, 0x1},
{-1.52588e-05, 0x81},
{7.62939e-06, 0x0},
{-7.62939e-06, 0x80},
{0.1, 0x2e},
{0.11, 0x2f},
{0.111, 0x2f},
{0.1111, 0x2f},
{0.2, 0x32},
{0.5, 0x38},
{1, 0x3c},
{2, 0x40},
{20, 0x4d},
{200, 0x5a},
{240, 0x5c},
{256, 0x5c},
{448, 0x5f},
{57344, 0x7b},
{60000, 0x7b},
{1e+07, 0x7b},
};
EXPECT(bool{std::all_of(test_vals.begin(), test_vals.end(), [](const auto sample) {
return migraphx::float_equal(
migraphx::fp8::fp8e5m2(sample.first),
migraphx::fp8::fp8e5m2(sample.second, migraphx::fp8::fp8e5m2::from_bits()));
})});
}
// Float +0.0 converted to fp8e5m2 must report is_zero() and convert back to a float that
// compares equal to the original zero.
TEST_CASE(test_positive_zero)
{
float zero = 0.0;
migraphx::fp8::fp8e5m2 fp8_zero(zero);
EXPECT(fp8_zero.is_zero());
EXPECT(migraphx::float_equal(zero, float(fp8_zero)));
}
// Float -0.0 converted to fp8e5m2 must report is_zero() and convert back to a float that
// compares equal to -0.0 under float_equal.
TEST_CASE(test_negative_zero)
{
float nzero = -0.0;
migraphx::fp8::fp8e5m2 fp8_nzero(nzero);
EXPECT(fp8_nzero.is_zero());
// negative zero is preserved for fp8e5m2
EXPECT(migraphx::float_equal(nzero, float(fp8_nzero)));
}
// fp8e5m2 values built from -0.0 and +0.0 must compare equal with operator==.
TEST_CASE(test_pos_zero_eq_neg_zero)
{
float nzero = -0.0;
float pzero = 0.0;
migraphx::fp8::fp8e5m2 fp8_nzero(nzero);
migraphx::fp8::fp8e5m2 fp8_pzero(pzero);
EXPECT(fp8_nzero == fp8_pzero);
}
// An fp8e5m2 constructed from a float quiet NaN must be reported as NaN by both is_nan() and
// std::isnan.
TEST_CASE(test_nan_1)
{
float fnan = std::numeric_limits<float>::quiet_NaN();
migraphx::fp8::fp8e5m2 fp8_nan(fnan);
EXPECT(fp8_nan.is_nan());
EXPECT(std::isnan(fp8_nan));
}
// Rebuilds numeric_limits<fp8e5m2>::quiet_NaN() from its raw `data` bits and checks that the
// result is NaN via is_nan(), std::isnan, and after conversion to float.
TEST_CASE(test_nan_2)
{
auto fnan = std::numeric_limits<migraphx::fp8::fp8e5m2>::quiet_NaN();
migraphx::fp8::fp8e5m2 fp8_nan(fnan.data, migraphx::fp8::fp8e5m2::from_bits());
EXPECT(fp8_nan.is_nan());
EXPECT(std::isnan(fp8_nan));
EXPECT(std::isnan(float(fp8_nan)));
}
// Converting float +infinity to fp8e5m2 must yield numeric_limits<fp8e5m2>::max().
TEST_CASE(test_infinity_1)
{
// float infinity should get clipped to max
float finf = std::numeric_limits<float>::infinity();
migraphx::fp8::fp8e5m2 fp8_max(finf);
EXPECT(fp8_max == std::numeric_limits<migraphx::fp8::fp8e5m2>::max());
}
// Converting float -infinity to fp8e5m2 must yield numeric_limits<fp8e5m2>::lowest().
TEST_CASE(test_infinity_2)
{
// neg inf
float finf = -1.0 * std::numeric_limits<float>::infinity();
// float -infinity is expected to be clipped to lowest when converted to fp8e5m2
migraphx::fp8::fp8e5m2 fp8_lowest(finf);
EXPECT(bool{fp8_lowest == std::numeric_limits<migraphx::fp8::fp8e5m2>::lowest()});
}
// The largest finite float, converted to fp8e5m2, must equal numeric_limits<fp8e5m2>::max().
TEST_CASE(test_numeric_max_1)
{
float fmax = std::numeric_limits<float>::max();
migraphx::fp8::fp8e5m2 fp8_max(fmax);
EXPECT(fp8_max == std::numeric_limits<migraphx::fp8::fp8e5m2>::max());
}
// Twice the fp8e5m2 maximum, converted back to fp8e5m2, must still equal max().
TEST_CASE(test_numeric_max_2)
{
// gets clipped to max
float fmax = 2 * std::numeric_limits<migraphx::fp8::fp8e5m2>::max();
migraphx::fp8::fp8e5m2 fp8_max(fmax);
EXPECT(fp8_max == std::numeric_limits<migraphx::fp8::fp8e5m2>::max());
}
// The lowest finite float, converted to fp8e5m2, must equal numeric_limits<fp8e5m2>::lowest().
TEST_CASE(test_numeric_lowest_1)
{
float flowest = std::numeric_limits<float>::lowest();
migraphx::fp8::fp8e5m2 fp8_lowest(flowest);
EXPECT(fp8_lowest == std::numeric_limits<migraphx::fp8::fp8e5m2>::lowest());
}
// Twice the fp8e5m2 lowest value, converted back to fp8e5m2, must still equal lowest().
TEST_CASE(test_numeric_lowest_2)
{
// gets clipped to lowest
float fmin = 2.0 * std::numeric_limits<migraphx::fp8::fp8e5m2>::lowest();
migraphx::fp8::fp8e5m2 fp8_lowest(fmin);
EXPECT(fp8_lowest == std::numeric_limits<migraphx::fp8::fp8e5m2>::lowest());
}
// lowest() must compare equal to -1 * max() for fp8e5m2.
TEST_CASE(test_max_eq_lowest)
{
EXPECT(migraphx::float_equal(std::numeric_limits<migraphx::fp8::fp8e5m2>::lowest(),
-1 * std::numeric_limits<migraphx::fp8::fp8e5m2>::max()));
}
// std::isfinite for fp8e5m2: true for +0.0 and -0.0; false for quiet NaN, for
// numeric_limits<fp8e5m2>::infinity(), and for the raw pattern 0xFC; true for an fp8e5m2
// constructed from -1.0 * infinity.
TEST_CASE(test_isfinite)
{
EXPECT(std::isfinite(migraphx::fp8::fp8e5m2(0.0)));
EXPECT(std::isfinite(migraphx::fp8::fp8e5m2(-0.0)));
EXPECT(not std::isfinite(
migraphx::fp8::fp8e5m2(std::numeric_limits<migraphx::fp8::fp8e5m2>::quiet_NaN())));
EXPECT(not std::isfinite(std::numeric_limits<migraphx::fp8::fp8e5m2>::infinity()));
// -1.0 * inf is float(-inf) which with clipping/saturation gets converted into fp8::lowest()
EXPECT(std::isfinite(
migraphx::fp8::fp8e5m2(-1.0 * std::numeric_limits<migraphx::fp8::fp8e5m2>::infinity())));
// 0xFC is the pattern that fp8e5m2_to_fp32_value in this file maps to -infinity.
EXPECT(not std::isfinite(migraphx::fp8::fp8e5m2(0xFC, migraphx::fp8::fp8e5m2::from_bits())));
}
// Arithmetic and comparison checks for fp8e5m2: sums that should be zero must compare equal to
// both +0.0 and -0.0, and the relational operators must order 10 and -10 correctly.
TEST_CASE(test_binary_ops)
{
auto a = migraphx::fp8::fp8e5m2(-1.0);
auto b = migraphx::fp8::fp8e5m2(1.0);
auto c = migraphx::fp8::fp8e5m2(0.0);
auto d = migraphx::fp8::fp8e5m2(-0.0);
// 0 + (-0) and (-1) + 1 are each compared against both signed zeros.
EXPECT(migraphx::float_equal((c + d), c));
EXPECT(migraphx::float_equal((c + d), d));
EXPECT(migraphx::float_equal((a + b), c));
EXPECT(migraphx::float_equal((a + b), d));
auto e = migraphx::fp8::fp8e5m2(10.0);
auto f = migraphx::fp8::fp8e5m2(-10.0);
EXPECT(bool{e > f});
EXPECT(bool{f < e});
EXPECT(bool{f <= e});
EXPECT(bool{e >= f});
EXPECT(bool{e <= e});
EXPECT(bool{f >= f});
EXPECT(not migraphx::float_equal(f, e));
}
// migraphx::fp8::fabs applied to fp8e5m2(-1.0) must compare equal to fp8e5m2(1.0).
TEST_CASE(test_fabs)
{
auto a = migraphx::fp8::fp8e5m2(-1.0);
auto b = migraphx::fp8::fp8e5m2(1.0);
EXPECT(migraphx::float_equal(b, migraphx::fp8::fabs(a)));
}
TEST_CASE(test_stream_op)
{
    using fp8_type = migraphx::fp8::fp8e5m2;

    // An ordinary value is streamed as its decimal text.
    auto minus_one = fp8_type(-1.0);
    std::stringstream value_stream;
    value_stream << minus_one;
    EXPECT(std::string("-1") == value_stream.str());

    // NaN is streamed as "nan"; a fresh stream replaces the original's reset-by-assignment.
    auto nan_value = std::numeric_limits<fp8_type>::quiet_NaN();
    std::stringstream nan_stream;
    nan_stream << nan_value;
    EXPECT(std::string("nan") == nan_stream.str());
}
// Test-binary entry point: forwards the command-line arguments to test::run.
int main(int argc, const char* argv[])
{
    test::run(argc, argv);
}
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <cmath>
#include <migraphx/float_equal.hpp>
#include <migraphx/float8.hpp>
#include <migraphx/half.hpp>
#include <migraphx/ranges.hpp>
#include "test.hpp"
#include <limits>
// Reference decoder for fp8e5m2fnuz: returns the float encoded by the raw
// 8-bit pattern `input`. The table is indexed directly by the bit pattern;
// entries 0x00-0x7F are the non-negative values, entry 0x80 is NaN, and
// entries 0x81-0xFF mirror the positive values with a negative sign. The table
// contains no infinity and no negative zero.
float fp8e5m2fnuz_to_fp32_value(uint8_t input)
{
    // `static` so the 256-entry table is initialised once instead of on every call.
    static constexpr std::array<float, 256> e5m2fnuz_lut = {
        0.0,
        7.62939453125e-06,
        1.52587890625e-05,
        2.288818359375e-05,
        3.0517578125e-05,
        3.814697265625e-05,
        4.57763671875e-05,
        5.340576171875e-05,
        6.103515625e-05,
        7.62939453125e-05,
        9.1552734375e-05,
        0.0001068115234375,
        0.0001220703125,
        0.000152587890625,
        0.00018310546875,
        0.000213623046875,
        0.000244140625,
        0.00030517578125,
        0.0003662109375,
        0.00042724609375,
        0.00048828125,
        0.0006103515625,
        0.000732421875,
        0.0008544921875,
        0.0009765625,
        0.001220703125,
        0.00146484375,
        0.001708984375,
        0.001953125,
        0.00244140625,
        0.0029296875,
        0.00341796875,
        0.00390625,
        0.0048828125,
        0.005859375,
        0.0068359375,
        0.0078125,
        0.009765625,
        0.01171875,
        0.013671875,
        0.015625,
        0.01953125,
        0.0234375,
        0.02734375,
        0.03125,
        0.0390625,
        0.046875,
        0.0546875,
        0.0625,
        0.078125,
        0.09375,
        0.109375,
        0.125,
        0.15625,
        0.1875,
        0.21875,
        0.25,
        0.3125,
        0.375,
        0.4375,
        0.5,
        0.625,
        0.75,
        0.875,
        1.0,
        1.25,
        1.5,
        1.75,
        2.0,
        2.5,
        3.0,
        3.5,
        4.0,
        5.0,
        6.0,
        7.0,
        8.0,
        10.0,
        12.0,
        14.0,
        16.0,
        20.0,
        24.0,
        28.0,
        32.0,
        40.0,
        48.0,
        56.0,
        64.0,
        80.0,
        96.0,
        112.0,
        128.0,
        160.0,
        192.0,
        224.0,
        256.0,
        320.0,
        384.0,
        448.0,
        512.0,
        640.0,
        768.0,
        896.0,
        1024.0,
        1280.0,
        1536.0,
        1792.0,
        2048.0,
        2560.0,
        3072.0,
        3584.0,
        4096.0,
        5120.0,
        6144.0,
        7168.0,
        8192.0,
        10240.0,
        12288.0,
        14336.0,
        16384.0,
        20480.0,
        24576.0,
        28672.0,
        32768.0,
        40960.0,
        49152.0,
        57344.0,
        std::numeric_limits<float>::quiet_NaN(),
        -7.62939453125e-06,
        -1.52587890625e-05,
        -2.288818359375e-05,
        -3.0517578125e-05,
        -3.814697265625e-05,
        -4.57763671875e-05,
        -5.340576171875e-05,
        -6.103515625e-05,
        -7.62939453125e-05,
        -9.1552734375e-05,
        -0.0001068115234375,
        -0.0001220703125,
        -0.000152587890625,
        -0.00018310546875,
        -0.000213623046875,
        -0.000244140625,
        -0.00030517578125,
        -0.0003662109375,
        -0.00042724609375,
        -0.00048828125,
        -0.0006103515625,
        -0.000732421875,
        -0.0008544921875,
        -0.0009765625,
        -0.001220703125,
        -0.00146484375,
        -0.001708984375,
        -0.001953125,
        -0.00244140625,
        -0.0029296875,
        -0.00341796875,
        -0.00390625,
        -0.0048828125,
        -0.005859375,
        -0.0068359375,
        -0.0078125,
        -0.009765625,
        -0.01171875,
        -0.013671875,
        -0.015625,
        -0.01953125,
        -0.0234375,
        -0.02734375,
        -0.03125,
        -0.0390625,
        -0.046875,
        -0.0546875,
        -0.0625,
        -0.078125,
        -0.09375,
        -0.109375,
        -0.125,
        -0.15625,
        -0.1875,
        -0.21875,
        -0.25,
        -0.3125,
        -0.375,
        -0.4375,
        -0.5,
        -0.625,
        -0.75,
        -0.875,
        -1.0,
        -1.25,
        -1.5,
        -1.75,
        -2.0,
        -2.5,
        -3.0,
        -3.5,
        -4.0,
        -5.0,
        -6.0,
        -7.0,
        -8.0,
        -10.0,
        -12.0,
        -14.0,
        -16.0,
        -20.0,
        -24.0,
        -28.0,
        -32.0,
        -40.0,
        -48.0,
        -56.0,
        -64.0,
        -80.0,
        -96.0,
        -112.0,
        -128.0,
        -160.0,
        -192.0,
        -224.0,
        -256.0,
        -320.0,
        -384.0,
        -448.0,
        -512.0,
        -640.0,
        -768.0,
        -896.0,
        -1024.0,
        -1280.0,
        -1536.0,
        -1792.0,
        -2048.0,
        -2560.0,
        -3072.0,
        -3584.0,
        -4096.0,
        -5120.0,
        -6144.0,
        -7168.0,
        -8192.0,
        -10240.0,
        -12288.0,
        -14336.0,
        -16384.0,
        -20480.0,
        -24576.0,
        -28672.0,
        -32768.0,
        -40960.0,
        -49152.0,
        -57344.0,
    };
    return e5m2fnuz_lut[input];
}
TEST_CASE(test_fp8_cast_to_float)
{
    // Decode every one of the 256 bit patterns and compare the result with the
    // reference table; two NaNs count as a match.
    std::vector<uint8_t> all_patterns(256);
    std::iota(all_patterns.begin(), all_patterns.end(), 0);

    auto matches_reference = [](uint8_t pattern) {
        using fp8_type = migraphx::fp8::fp8e5m2fnuz;
        fp8_type decoded(pattern, fp8_type::from_bits());
        const float actual   = float(decoded);
        const float expected = fp8e5m2fnuz_to_fp32_value(pattern);
        return (std::isnan(actual) and std::isnan(expected)) or
               migraphx::float_equal(actual, expected);
    };

    EXPECT(bool{std::all_of(all_patterns.begin(), all_patterns.end(), matches_reference)});
}
TEST_CASE(test_fp8_cast_from_float)
{
    // Each entry maps a float input to the fp8e5m2fnuz bit pattern the
    // conversion is expected to produce.
    std::unordered_map<float, uint8_t> expected_bits = {
        {57344, 0x7f},      {-57344, 0xff},      {60000, 0x7f},      {-60000, 0xff},
        {448, 0x63},        {-448, 0xe3},        {256, 0x60},        {-256, 0xe0},
        {240, 0x60},        {-240, 0xe0},        {3.05176e-05, 0x4}, {-3.05176e-05, 0x84},
        {1.52588e-05, 0x2}, {-1.52588e-05, 0x82}, {7.62939e-06, 0x1}, {-7.62939e-06, 0x81},
        {3.81469e-06, 0x0}, {-3.81469e-06, 0x0}, {1e+07, 0x7f},      {1, 0x40},
        {-1, 0xc0},         {0.1, 0x32},         {0.11, 0x33},       {0.111, 0x33},
        {0.1111, 0x33},     {-0.1, 0xb2},        {-0.11, 0xb3},      {-0.111, 0xb3},
        {-0.1111, 0xb3},    {0.2, 0x36},         {2, 0x44},          {20, 0x51},
        {200, 0x5e},        {-0.2, 0xb6},        {-2, 0xc4},         {-20, 0xd1},
        {-200, 0xde},       {0.5, 0x3c},         {-0.5, 0xbc},       {1.17549e-38, 0x0},
        {1.4013e-45, 0x0},
    };

    auto converts_as_expected = [](const auto& entry) {
        using fp8_type = migraphx::fp8::fp8e5m2fnuz;
        return migraphx::float_equal(fp8_type(entry.first),
                                     fp8_type(entry.second, fp8_type::from_bits()));
    };

    EXPECT(bool{std::all_of(expected_bits.begin(), expected_bits.end(), converts_as_expected)});
}
TEST_CASE(test_positive_zero)
{
    // +0.0f converts to an fp8 zero and converts back to a float equal to zero.
    using fp8_type = migraphx::fp8::fp8e5m2fnuz;
    float positive_zero = 0.0;
    fp8_type encoded(positive_zero);
    EXPECT(encoded.is_zero());
    EXPECT(migraphx::float_equal(positive_zero, float(encoded)));
}
TEST_CASE(test_negative_zero)
{
    using fp8_type = migraphx::fp8::fp8e5m2fnuz;
    float negative_zero = -0.0;
    float positive_zero = 0.0;
    fp8_type encoded(negative_zero);
    EXPECT(encoded.is_zero());
    // negative zero gets converted to positive zero
    EXPECT(migraphx::float_equal(positive_zero, float(encoded)));
}
TEST_CASE(test_nan_1)
{
    // A float NaN must convert to an fp8 NaN.
    using fp8_type = migraphx::fp8::fp8e5m2fnuz;
    float float_nan = std::numeric_limits<float>::quiet_NaN();
    fp8_type encoded(float_nan);
    EXPECT(encoded.is_nan());
    EXPECT(std::isnan(encoded));
}
TEST_CASE(test_nan_2)
{
    // Rebuild a value from the raw bits of numeric_limits::quiet_NaN() and
    // check it is still NaN, both as fp8 and after conversion to float.
    using fp8_type = migraphx::fp8::fp8e5m2fnuz;
    auto limits_nan = std::numeric_limits<fp8_type>::quiet_NaN();
    fp8_type rebuilt(limits_nan.data, fp8_type::from_bits());
    EXPECT(rebuilt.is_nan());
    EXPECT(std::isnan(rebuilt));
    EXPECT(std::isnan(float(rebuilt)));
}
TEST_CASE(test_infinity_1)
{
    using fp8_type = migraphx::fp8::fp8e5m2fnuz;
    float positive_inf = std::numeric_limits<float>::infinity();
    // fp8e5m2fnuz has no infinity, so +inf is expected to convert to NaN.
    fp8_type encoded(positive_inf);
    EXPECT(encoded.is_nan());
    EXPECT(std::isnan(float(encoded)));
}
TEST_CASE(test_infinity_2)
{
    using fp8_type = migraphx::fp8::fp8e5m2fnuz;
    // Negative infinity.
    float negative_inf = -std::numeric_limits<float>::infinity();
    // fp8e5m2fnuz has no infinity, so -inf is expected to convert to NaN.
    fp8_type encoded(negative_inf);
    EXPECT(encoded.is_nan());
    EXPECT(std::isnan(float(encoded)));
}
TEST_CASE(test_numeric_max_1)
{
    // The largest float is expected to convert to numeric_limits<fp8e5m2fnuz>::max().
    using fp8_type = migraphx::fp8::fp8e5m2fnuz;
    float largest_float = std::numeric_limits<float>::max();
    fp8_type converted(largest_float);
    EXPECT(converted == std::numeric_limits<fp8_type>::max());
}
TEST_CASE(test_numeric_max_2)
{
    using fp8_type = migraphx::fp8::fp8e5m2fnuz;
    // Twice the fp8 maximum is out of range and is expected to be clipped to max().
    float twice_max = 2 * std::numeric_limits<fp8_type>::max();
    fp8_type clipped(twice_max);
    EXPECT(clipped == std::numeric_limits<fp8_type>::max());
}
TEST_CASE(test_numeric_lowest_1)
{
    // The most negative float is expected to convert to
    // numeric_limits<fp8e5m2fnuz>::lowest().
    using fp8_type = migraphx::fp8::fp8e5m2fnuz;
    float most_negative_float = std::numeric_limits<float>::lowest();
    fp8_type converted(most_negative_float);
    EXPECT(converted == std::numeric_limits<fp8_type>::lowest());
}
TEST_CASE(test_numeric_lowest_2)
{
    using fp8_type = migraphx::fp8::fp8e5m2fnuz;
    // Twice the lowest fp8 value is out of range and is expected to be clipped to lowest().
    float twice_lowest = 2.0 * std::numeric_limits<fp8_type>::lowest();
    fp8_type clipped(twice_lowest);
    EXPECT(clipped == std::numeric_limits<fp8_type>::lowest());
}
TEST_CASE(test_max_eq_lowest)
{
    // lowest() is expected to be the negation of max().
    using fp8_limits = std::numeric_limits<migraphx::fp8::fp8e5m2fnuz>;
    auto lowest_value = fp8_limits::lowest();
    auto negated_max  = -1 * fp8_limits::max();
    EXPECT(migraphx::float_equal(lowest_value, negated_max));
}
TEST_CASE(test_isfinite)
{
    using fp8_type   = migraphx::fp8::fp8e5m2fnuz;
    using fp8_limits = std::numeric_limits<fp8_type>;

    // Values built from 0.0 and -0.0 are finite.
    EXPECT(std::isfinite(fp8_type(0.0)));
    EXPECT(std::isfinite(fp8_type(-0.0)));

    // NaN is not finite.
    EXPECT(not std::isfinite(fp8_type(fp8_limits::quiet_NaN())));
}
TEST_CASE(test_no_infinity)
{
    // numeric_limits must advertise that fp8e5m2fnuz has no infinity encoding.
    using fp8_limits = std::numeric_limits<migraphx::fp8::fp8e5m2fnuz>;
    EXPECT(not bool{fp8_limits::has_infinity});
}
TEST_CASE(test_binary_ops)
{
    using fp8_type = migraphx::fp8::fp8e5m2fnuz;

    // Sums that come out as zero must compare equal to the values built from
    // both 0.0 and -0.0.
    auto minus_one      = fp8_type(-1.0);
    auto plus_one       = fp8_type(1.0);
    auto from_plus_zero = fp8_type(0.0);
    auto from_neg_zero  = fp8_type(-0.0);
    EXPECT(migraphx::float_equal((from_plus_zero + from_neg_zero), from_plus_zero));
    EXPECT(migraphx::float_equal((from_plus_zero + from_neg_zero), from_neg_zero));
    EXPECT(migraphx::float_equal((minus_one + plus_one), from_plus_zero));
    EXPECT(migraphx::float_equal((minus_one + plus_one), from_neg_zero));

    // Relational operators on a positive/negative pair and on equal operands.
    auto plus_ten  = fp8_type(10.0);
    auto minus_ten = fp8_type(-10.0);
    EXPECT(bool{plus_ten > minus_ten});
    EXPECT(bool{minus_ten < plus_ten});
    EXPECT(bool{minus_ten <= plus_ten});
    EXPECT(bool{plus_ten >= minus_ten});
    EXPECT(bool{plus_ten <= plus_ten});
    EXPECT(bool{minus_ten >= minus_ten});
    EXPECT(not migraphx::float_equal(minus_ten, plus_ten));
}
TEST_CASE(test_fabs)
{
    // fabs(-1) is expected to equal +1.
    using fp8_type = migraphx::fp8::fp8e5m2fnuz;
    auto negative = fp8_type(-1.0);
    auto positive = fp8_type(1.0);
    EXPECT(migraphx::float_equal(positive, migraphx::fp8::fabs(negative)));
}
TEST_CASE(test_stream_op)
{
    using fp8_type = migraphx::fp8::fp8e5m2fnuz;

    // An ordinary value is streamed as its decimal text.
    auto minus_one = fp8_type(-1.0);
    std::stringstream value_stream;
    value_stream << minus_one;
    EXPECT(std::string("-1") == value_stream.str());

    // NaN is streamed as "nan"; a fresh stream replaces the original's reset-by-assignment.
    auto nan_value = std::numeric_limits<fp8_type>::quiet_NaN();
    std::stringstream nan_stream;
    nan_stream << nan_value;
    EXPECT(std::string("nan") == nan_stream.str());
}
// Test-binary entry point: forwards the command-line arguments to test::run.
int main(int argc, const char* argv[])
{
    test::run(argc, argv);
}
...@@ -414,8 +414,8 @@ TEST_CASE(add_reshape_add_nonstandard) ...@@ -414,8 +414,8 @@ TEST_CASE(add_reshape_add_nonstandard)
auto y = mm->add_parameter("y", s1); auto y = mm->add_parameter("y", s1);
auto z = mm->add_parameter("z", s2); auto z = mm->add_parameter("z", s2);
auto add1 = mm->add_instruction(migraphx::make_op("add"), x, y); auto add1 = mm->add_instruction(migraphx::make_op("add"), x, y);
auto c = mm->add_instruction(migraphx::make_op("contiguous"), add1); auto reshape =
auto reshape = mm->add_instruction(migraphx::make_op("reshape", {{"dims", s2.lens()}}), c); mm->add_instruction(migraphx::make_op("reshape", {{"dims", s2.lens()}}), add1);
auto add2 = mm->add_instruction(migraphx::make_op("add"), reshape, z); auto add2 = mm->add_instruction(migraphx::make_op("add"), reshape, z);
mm->add_return({add2}); mm->add_return({add2});
} }
...@@ -426,10 +426,8 @@ TEST_CASE(add_reshape_add_nonstandard) ...@@ -426,10 +426,8 @@ TEST_CASE(add_reshape_add_nonstandard)
auto x = mm->add_parameter("x", s1); auto x = mm->add_parameter("x", s1);
auto y = mm->add_parameter("y", s1); auto y = mm->add_parameter("y", s1);
auto z = mm->add_parameter("z", s2); auto z = mm->add_parameter("z", s2);
auto cx = mm->add_instruction(migraphx::make_op("contiguous"), x); auto x2 = mm->add_instruction(migraphx::make_op("reshape", {{"dims", s3.lens()}}), x);
auto cy = mm->add_instruction(migraphx::make_op("contiguous"), y); auto y2 = mm->add_instruction(migraphx::make_op("reshape", {{"dims", s3.lens()}}), y);
auto x2 = mm->add_instruction(migraphx::make_op("reshape", {{"dims", s3.lens()}}), cx);
auto y2 = mm->add_instruction(migraphx::make_op("reshape", {{"dims", s3.lens()}}), cy);
auto z2 = mm->add_instruction(migraphx::make_op("reshape", {{"dims", s3.lens()}}), z); auto z2 = mm->add_instruction(migraphx::make_op("reshape", {{"dims", s3.lens()}}), z);
auto fadd = auto fadd =
add_pointwise(p2, "main:pointwise0", {x2, y2, z2}, [=](auto* pm, const auto& inputs) { add_pointwise(p2, "main:pointwise0", {x2, y2, z2}, [=](auto* pm, const auto& inputs) {
...@@ -466,10 +464,8 @@ TEST_CASE(add_unsqueeze_add_nonstandard) ...@@ -466,10 +464,8 @@ TEST_CASE(add_unsqueeze_add_nonstandard)
auto x = mm->add_parameter("x", s1); auto x = mm->add_parameter("x", s1);
auto y = mm->add_parameter("y", s1); auto y = mm->add_parameter("y", s1);
auto z = mm->add_parameter("z", s2); auto z = mm->add_parameter("z", s2);
auto cx = mm->add_instruction(migraphx::make_op("contiguous"), x); auto x2 = mm->add_instruction(migraphx::make_op("reshape", {{"dims", s2.lens()}}), x);
auto cy = mm->add_instruction(migraphx::make_op("contiguous"), y); auto y2 = mm->add_instruction(migraphx::make_op("reshape", {{"dims", s2.lens()}}), y);
auto x2 = mm->add_instruction(migraphx::make_op("reshape", {{"dims", s2.lens()}}), cx);
auto y2 = mm->add_instruction(migraphx::make_op("reshape", {{"dims", s2.lens()}}), cy);
auto fadd = auto fadd =
add_pointwise(p2, "main:pointwise0", {x2, y2, z}, [=](auto* pm, const auto& inputs) { add_pointwise(p2, "main:pointwise0", {x2, y2, z}, [=](auto* pm, const auto& inputs) {
auto add1 = pm->add_instruction(migraphx::make_op("add"), inputs[0], inputs[1]); auto add1 = pm->add_instruction(migraphx::make_op("add"), inputs[0], inputs[1]);
......
...@@ -237,12 +237,12 @@ TEST_CASE(code_object_hip) ...@@ -237,12 +237,12 @@ TEST_CASE(code_object_hip)
std::vector<migraphx::shape> expected_inputs = {input, input}; std::vector<migraphx::shape> expected_inputs = {input, input};
auto co = migraphx::make_op("gpu::code_object", auto co = migraphx::make_op("gpu::code_object",
{{"code_object", migraphx::value::binary{binaries.front()}}, {{"code_object", migraphx::value::binary{binaries.front()}},
{"symbol_name", "add_2"}, {"symbol_name", "add_2"},
{"global", input.elements()}, {"global", input.elements()},
{"local", 1024}, {"local", 1024},
{"expected_inputs", migraphx::to_value(expected_inputs)}, {"expected_inputs", migraphx::to_value(expected_inputs)},
{"output", migraphx::to_value(input)}}); {"output", migraphx::to_value(input)}});
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
...@@ -348,7 +348,10 @@ TEST_CASE(compile_math) ...@@ -348,7 +348,10 @@ TEST_CASE(compile_math)
auto vec_sizes = {2, 4, 6}; auto vec_sizes = {2, 4, 6};
for(auto&& t : migraphx::shape::types()) for(auto&& t : migraphx::shape::types())
{ {
if(contains({migraphx::shape::bool_type, migraphx::shape::tuple_type}, t)) if(contains({migraphx::shape::bool_type,
migraphx::shape::fp8e4m3fnuz_type,
migraphx::shape::tuple_type},
t))
continue; continue;
auto name = migraphx::shape::cpp_type(t); auto name = migraphx::shape::cpp_type(t);
if(t == migraphx::shape::half_type) if(t == migraphx::shape::half_type)
...@@ -396,7 +399,10 @@ TEST_CASE(assert_type_min_max) ...@@ -396,7 +399,10 @@ TEST_CASE(assert_type_min_max)
migraphx::gpu::hip_compile_options options; migraphx::gpu::hip_compile_options options;
for(auto&& t : migraphx::shape::types()) for(auto&& t : migraphx::shape::types())
{ {
if(contains({migraphx::shape::bool_type, migraphx::shape::tuple_type}, t)) if(contains({migraphx::shape::bool_type,
migraphx::shape::fp8e4m3fnuz_type,
migraphx::shape::tuple_type},
t))
continue; continue;
auto name = migraphx::shape::cpp_type(t); auto name = migraphx::shape::cpp_type(t);
if(t == migraphx::shape::half_type) if(t == migraphx::shape::half_type)
......
4a8203033930da506b356cdaf88b1531d8d8fca3 a5537f2f563d4975c7e6121a7eb260bbbfd9455a
...@@ -4484,6 +4484,177 @@ def lrn_test(): ...@@ -4484,6 +4484,177 @@ def lrn_test():
return ([node], [x], [y]) return ([node], [x], [y])
@onnx_test()
def lstm_bi_layout_cell_test():
    """Bidirectional LSTM node with layout=1 that names only its third output, 'cellout'."""
    def float_info(name, dims):
        return helper.make_tensor_value_info(name, TensorProto.FLOAT, dims)

    seq_info = float_info('seq', [3, 5, 10])
    w_info = float_info('w', [2, 80, 10])
    r_info = float_info('r', [2, 80, 20])
    bias_info = float_info('bias', [2, 160])
    seq_len_info = helper.make_tensor_value_info('seq_len', TensorProto.INT32,
                                                 [3])
    h0_info = float_info('h0', [3, 2, 20])
    c0_info = float_info('c0', [3, 2, 20])
    pph_info = float_info('pph', [2, 60])

    cellout_info = float_info('cellout', [3, 2, 20])

    # The first two outputs are left unnamed (empty strings), so only the
    # third output is exposed.
    lstm_node = onnx.helper.make_node(
        'LSTM',
        inputs=['seq', 'w', 'r', 'bias', 'seq_len', 'h0', 'c0', 'pph'],
        outputs=['', '', 'cellout'],
        activations=['sigmoid', 'tanh', 'tanh'],
        clip=0,
        direction='bidirectional',
        hidden_size=20,
        input_forget=1,
        layout=1)

    graph_inputs = [
        seq_info, w_info, r_info, bias_info, seq_len_info, h0_info, c0_info,
        pph_info
    ]
    return ([lstm_node], graph_inputs, [cellout_info])
@onnx_test()
def lstm_bi_layout_last_test():
    """Bidirectional LSTM node with layout=1 exposing its first two outputs, 'hs' and 'output'."""
    def float_info(name, dims):
        return helper.make_tensor_value_info(name, TensorProto.FLOAT, dims)

    seq_info = float_info('seq', [3, 5, 10])
    w_info = float_info('w', [2, 80, 10])
    r_info = float_info('r', [2, 80, 20])
    bias_info = float_info('bias', [2, 160])
    seq_len_info = helper.make_tensor_value_info('seq_len', TensorProto.INT32,
                                                 [3])
    h0_info = float_info('h0', [3, 2, 20])
    c0_info = float_info('c0', [3, 2, 20])
    pph_info = float_info('pph', [2, 60])

    hs_info = float_info('hs', [3, 5, 2, 20])
    output_info = float_info('output', [3, 2, 20])

    lstm_node = onnx.helper.make_node(
        'LSTM',
        inputs=['seq', 'w', 'r', 'bias', 'seq_len', 'h0', 'c0', 'pph'],
        outputs=['hs', 'output'],
        activations=['sigmoid', 'tanh', 'tanh'],
        clip=0,
        direction='bidirectional',
        hidden_size=20,
        input_forget=1,
        layout=1)

    graph_inputs = [
        seq_info, w_info, r_info, bias_info, seq_len_info, h0_info, c0_info,
        pph_info
    ]
    return ([lstm_node], graph_inputs, [hs_info, output_info])
@onnx_test()
def lstm_f_layout_hs_test():
    """Forward LSTM node with layout=1 exposing its first two outputs, 'hs' and 'output'."""
    def float_info(name, dims):
        return helper.make_tensor_value_info(name, TensorProto.FLOAT, dims)

    seq_info = float_info('seq', [3, 5, 10])
    w_info = float_info('w', [1, 80, 10])
    r_info = float_info('r', [1, 80, 20])
    bias_info = float_info('bias', [1, 160])
    seq_len_info = helper.make_tensor_value_info('seq_len', TensorProto.INT32,
                                                 [3])
    h0_info = float_info('h0', [3, 1, 20])
    c0_info = float_info('c0', [3, 1, 20])
    pph_info = float_info('pph', [1, 60])

    hs_info = float_info('hs', [3, 5, 1, 20])
    output_info = float_info('output', [3, 1, 20])

    lstm_node = onnx.helper.make_node(
        'LSTM',
        inputs=['seq', 'w', 'r', 'bias', 'seq_len', 'h0', 'c0', 'pph'],
        outputs=['hs', 'output'],
        activations=['sigmoid', 'tanh', 'tanh'],
        clip=0,
        direction='forward',
        hidden_size=20,
        input_forget=1,
        layout=1)

    graph_inputs = [
        seq_info, w_info, r_info, bias_info, seq_len_info, h0_info, c0_info,
        pph_info
    ]
    return ([lstm_node], graph_inputs, [hs_info, output_info])
@onnx_test()
def lstm_f_layout_cell_test():
    """Forward LSTM node with layout=1 that names only its third output, 'cellout'."""
    def float_info(name, dims):
        return helper.make_tensor_value_info(name, TensorProto.FLOAT, dims)

    seq_info = float_info('seq', [3, 5, 10])
    w_info = float_info('w', [1, 80, 10])
    r_info = float_info('r', [1, 80, 20])
    bias_info = float_info('bias', [1, 160])
    seq_len_info = helper.make_tensor_value_info('seq_len', TensorProto.INT32,
                                                 [3])
    h0_info = float_info('h0', [3, 1, 20])
    c0_info = float_info('c0', [3, 1, 20])
    pph_info = float_info('pph', [1, 60])

    cellout_info = float_info('cellout', [3, 1, 20])

    # The first two outputs are left unnamed (empty strings), so only the
    # third output is exposed.
    lstm_node = onnx.helper.make_node(
        'LSTM',
        inputs=['seq', 'w', 'r', 'bias', 'seq_len', 'h0', 'c0', 'pph'],
        outputs=['', '', 'cellout'],
        activations=['sigmoid', 'tanh', 'tanh'],
        clip=0,
        direction='forward',
        hidden_size=20,
        input_forget=1,
        layout=1)

    graph_inputs = [
        seq_info, w_info, r_info, bias_info, seq_len_info, h0_info, c0_info,
        pph_info
    ]
    return ([lstm_node], graph_inputs, [cellout_info])
@onnx_test()
def lstm_r_layout_test():
    """Reverse LSTM node with layout=1 exposing only its first output, 'hs'."""
    def float_info(name, dims):
        return helper.make_tensor_value_info(name, TensorProto.FLOAT, dims)

    seq_info = float_info('seq', [3, 5, 10])
    w_info = float_info('w', [1, 80, 10])
    r_info = float_info('r', [1, 80, 20])
    bias_info = float_info('bias', [1, 160])
    seq_len_info = helper.make_tensor_value_info('seq_len', TensorProto.INT32,
                                                 [3])
    h0_info = float_info('h0', [3, 1, 20])
    c0_info = float_info('c0', [3, 1, 20])
    pph_info = float_info('pph', [1, 60])

    hs_info = float_info('hs', [3, 5, 1, 20])

    lstm_node = onnx.helper.make_node(
        'LSTM',
        inputs=['seq', 'w', 'r', 'bias', 'seq_len', 'h0', 'c0', 'pph'],
        outputs=['hs'],
        activations=['sigmoid', 'tanh', 'tanh'],
        clip=0,
        direction='reverse',
        hidden_size=20,
        input_forget=1,
        layout=1)

    graph_inputs = [
        seq_info, w_info, r_info, bias_info, seq_len_info, h0_info, c0_info,
        pph_info
    ]
    return ([lstm_node], graph_inputs, [hs_info])
@onnx_test()
def lstm_r_layout_hs_cell_test():
    """Reverse LSTM node with layout=1 exposing its second and third outputs, 'output' and 'cellout'."""
    def float_info(name, dims):
        return helper.make_tensor_value_info(name, TensorProto.FLOAT, dims)

    seq_info = float_info('seq', [3, 5, 10])
    w_info = float_info('w', [1, 80, 10])
    r_info = float_info('r', [1, 80, 20])
    bias_info = float_info('bias', [1, 160])
    seq_len_info = helper.make_tensor_value_info('seq_len', TensorProto.INT32,
                                                 [3])
    h0_info = float_info('h0', [3, 1, 20])
    c0_info = float_info('c0', [3, 1, 20])
    pph_info = float_info('pph', [1, 60])

    output_info = float_info('output', [3, 1, 20])
    cellout_info = float_info('cellout', [3, 1, 20])

    # The first output is left unnamed (empty string), so only the second and
    # third outputs are exposed.
    lstm_node = onnx.helper.make_node(
        'LSTM',
        inputs=['seq', 'w', 'r', 'bias', 'seq_len', 'h0', 'c0', 'pph'],
        outputs=['', 'output', 'cellout'],
        activations=['sigmoid', 'tanh', 'tanh'],
        clip=0,
        direction='reverse',
        hidden_size=20,
        input_forget=1,
        layout=1)

    graph_inputs = [
        seq_info, w_info, r_info, bias_info, seq_len_info, h0_info, c0_info,
        pph_info
    ]
    return ([lstm_node], graph_inputs, [output_info, cellout_info])
@onnx_test() @onnx_test()
def matmul_bmbm_test(): def matmul_bmbm_test():
m1 = helper.make_tensor_value_info('1', TensorProto.FLOAT, [3, 6, 7]) m1 = helper.make_tensor_value_info('1', TensorProto.FLOAT, [3, 6, 7])
...@@ -6008,6 +6179,61 @@ def qlinearmatmul_3D_test(): ...@@ -6008,6 +6179,61 @@ def qlinearmatmul_3D_test():
[sc_a, zero_pt_a, sc_b, zero_pt_b, sc_c, zero_pt_c]) [sc_a, zero_pt_a, sc_b, zero_pt_b, sc_c, zero_pt_c])
@onnx_test()
def qlinearmul_test():
    """QLinearMul node on two UINT8 vectors of 64 elements, with scalar scales and zero points supplied as initializers."""
    # Runtime operands and result.
    a_info = helper.make_tensor_value_info('A', TensorProto.UINT8, [64])
    b_info = helper.make_tensor_value_info('B', TensorProto.UINT8, [64])
    c_info = helper.make_tensor_value_info('C', TensorProto.UINT8, [64])

    # Scalar quantization parameters, in the order they are returned.
    quant_params = [
        helper.make_tensor('A_scale', TensorProto.FLOAT, [], [0.05]),
        helper.make_tensor('A_zero_point', TensorProto.UINT8, [], [0]),
        helper.make_tensor('B_scale', TensorProto.FLOAT, [], [0.05]),
        helper.make_tensor('B_zero_point', TensorProto.UINT8, [], [16]),
        helper.make_tensor('C_scale', TensorProto.FLOAT, [], [0.05]),
        helper.make_tensor('C_zero_point', TensorProto.UINT8, [], [100]),
    ]

    mul_node = onnx.helper.make_node(
        'QLinearMul',
        inputs=[
            'A', 'A_scale', 'A_zero_point', 'B', 'B_scale', 'B_zero_point',
            'C_scale', 'C_zero_point'
        ],
        outputs=['C'],
    )

    return ([mul_node], [a_info, b_info], [c_info], quant_params)
@onnx_test()
def qlinearmul_bcast_test():
    """QLinearMul node on INT8 operands of shapes [64] and [1, 1, 64], producing a [1, 1, 64] result."""
    # Runtime operands and result; A and B have different ranks.
    a_info = helper.make_tensor_value_info('A', TensorProto.INT8, [64])
    b_info = helper.make_tensor_value_info('B', TensorProto.INT8, [1, 1, 64])
    c_info = helper.make_tensor_value_info('C', TensorProto.INT8, [1, 1, 64])

    # Scalar quantization parameters, in the order they are returned.
    # NOTE(review): B_zero_point is declared INT8 but given the value 128,
    # which is outside [-128, 127] -- confirm this is intentional.
    quant_params = [
        helper.make_tensor('A_scale', TensorProto.FLOAT, [], [0.05]),
        helper.make_tensor('A_zero_point', TensorProto.INT8, [], [0]),
        helper.make_tensor('B_scale', TensorProto.FLOAT, [], [0.05]),
        helper.make_tensor('B_zero_point', TensorProto.INT8, [], [128]),
        helper.make_tensor('C_scale', TensorProto.FLOAT, [], [0.15]),
        helper.make_tensor('C_zero_point', TensorProto.INT8, [], [32]),
    ]

    mul_node = onnx.helper.make_node(
        'QLinearMul',
        inputs=[
            'A', 'A_scale', 'A_zero_point', 'B', 'B_scale', 'B_zero_point',
            'C_scale', 'C_zero_point'
        ],
        outputs=['C'],
    )

    return ([mul_node], [a_info, b_info], [c_info], quant_params)
@onnx_test() @onnx_test()
def quantizelinear_test(): def quantizelinear_test():
arg0 = helper.make_tensor_value_info('0', TensorProto.FLOAT, [5]) arg0 = helper.make_tensor_value_info('0', TensorProto.FLOAT, [5])
......
...@@ -1092,6 +1092,115 @@ TEST_CASE(lstm_forward) ...@@ -1092,6 +1092,115 @@ TEST_CASE(lstm_forward)
} }
} }
// Parses the forward-direction LSTM models lstm_f_layout_hs_test.onnx and
// lstm_f_layout_cell_test.onnx and compares each with a hand-built program in
// which seq/h0/c0 are transposed before the lstm op and the outputs are
// transposed afterwards. The comparison is `p == prog`, so the instructions
// below are added in the order the parsed program is expected to contain them.
TEST_CASE(lstm_forward_layout)
{
std::size_t sl = 5; // sequence len
std::size_t bs = 3; // batch size
std::size_t hs = 20; // hidden size
std::size_t is = 10; // input size
std::size_t nd = 1; // num directions
float clip = 0.0f;
int input_forget = 1;
// Parameter shapes; seq and the initial states carry the batch size in their first dimension.
migraphx::shape seq_shape{migraphx::shape::float_type, {bs, sl, is}};
migraphx::shape w_shape{migraphx::shape::float_type, {nd, 4 * hs, is}};
migraphx::shape r_shape{migraphx::shape::float_type, {nd, 4 * hs, hs}};
migraphx::shape bias_shape{migraphx::shape::float_type, {nd, 8 * hs}};
migraphx::shape sl_shape{migraphx::shape::int32_type, {bs}};
migraphx::shape ih_shape{migraphx::shape::float_type, {bs, nd, hs}};
migraphx::shape pph_shape{migraphx::shape::float_type, {nd, 3 * hs}};
// 8 args, hs and last output
{
migraphx::program p;
auto* mm = p.get_main_module();
auto seq = mm->add_parameter("seq", seq_shape);
auto w = mm->add_parameter("w", w_shape);
auto r = mm->add_parameter("r", r_shape);
auto bias = mm->add_parameter("bias", bias_shape);
auto seq_len = mm->add_parameter("seq_len", sl_shape);
auto ih = mm->add_parameter("h0", ih_shape);
auto ic = mm->add_parameter("c0", ih_shape);
auto pph = mm->add_parameter("pph", pph_shape);
// Swap the first two axes of seq, h0 and c0 before they reach the lstm op.
std::vector<int64_t> perm{1, 0, 2};
seq = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), seq);
ih = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ih);
ic = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ic);
auto out_hs = mm->add_instruction(
migraphx::make_op(
"lstm",
{{"hidden_size", hs},
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"),
migraphx::make_op("tanh"),
migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)},
{"clip", clip},
{"input_forget", input_forget}}),
seq,
w,
r,
bias,
seq_len,
ih,
ic,
pph);
// last_output is taken from the untransposed lstm result, before out_hs is reassigned below.
auto last_output = mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), out_hs);
// Moves axis 2 of the lstm result to the front; presumably this puts the
// batch axis first -- confirm against the lstm op's output shape.
std::vector<int64_t> perm_hid{2, 0, 1, 3};
out_hs = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm_hid}}),
out_hs);
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), last_output);
auto prog = optimize_onnx("lstm_f_layout_hs_test.onnx");
EXPECT(p == prog);
}
// 8 args, cell output
{
migraphx::program p;
auto* mm = p.get_main_module();
auto seq = mm->add_parameter("seq", seq_shape);
auto w = mm->add_parameter("w", w_shape);
auto r = mm->add_parameter("r", r_shape);
auto bias = mm->add_parameter("bias", bias_shape);
auto seq_len = mm->add_parameter("seq_len", sl_shape);
auto ih = mm->add_parameter("h0", ih_shape);
auto ic = mm->add_parameter("c0", ih_shape);
auto pph = mm->add_parameter("pph", pph_shape);
// Swap the first two axes of seq, h0 and c0 before they reach the lstm op.
std::vector<int64_t> perm{1, 0, 2};
seq = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), seq);
ih = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ih);
ic = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ic);
auto out_hs = mm->add_instruction(
migraphx::make_op(
"lstm",
{{"hidden_size", hs},
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"),
migraphx::make_op("tanh"),
migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)},
{"clip", clip},
{"input_forget", input_forget}}),
seq,
w,
r,
bias,
seq_len,
ih,
ic,
pph);
// Only the last cell output is transposed here; the hidden-state sequence gets no transpose.
auto last_cell = mm->add_instruction(migraphx::make_op("rnn_last_cell_output"), out_hs);
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), last_cell);
auto prog = optimize_onnx("lstm_f_layout_cell_test.onnx");
EXPECT(p == prog);
}
}
// activation functions // activation functions
TEST_CASE(lstm_forward_actv_func) TEST_CASE(lstm_forward_actv_func)
{ {
...@@ -1342,6 +1451,117 @@ TEST_CASE(lstm_reverse) ...@@ -1342,6 +1451,117 @@ TEST_CASE(lstm_reverse)
} }
} }
// Parses the reverse-direction LSTM models lstm_r_layout_test.onnx and
// lstm_r_layout_hs_cell_test.onnx and compares each with a hand-built program
// in which seq/h0/c0 are transposed before the lstm op and the outputs are
// transposed afterwards. The comparison is `p == prog`, so the instructions
// below are added in the order the parsed program is expected to contain them.
TEST_CASE(lstm_reverse_layout)
{
std::size_t sl = 5; // sequence len
std::size_t bs = 3; // batch size
std::size_t hs = 20; // hidden size
std::size_t is = 10; // input size
std::size_t nd = 1; // num directions
float clip = 0.0f;
int input_forget = 1;
// Parameter shapes; seq and the initial states carry the batch size in their first dimension.
migraphx::shape seq_shape{migraphx::shape::float_type, {bs, sl, is}};
migraphx::shape w_shape{migraphx::shape::float_type, {nd, 4 * hs, is}};
migraphx::shape r_shape{migraphx::shape::float_type, {nd, 4 * hs, hs}};
migraphx::shape bias_shape{migraphx::shape::float_type, {nd, 8 * hs}};
migraphx::shape sl_shape{migraphx::shape::int32_type, {bs}};
migraphx::shape ih_shape{migraphx::shape::float_type, {bs, nd, hs}};
migraphx::shape pph_shape{migraphx::shape::float_type, {nd, 3 * hs}};
// 8 args, hs output
{
migraphx::program p;
auto* mm = p.get_main_module();
auto seq = mm->add_parameter("seq", seq_shape);
auto w = mm->add_parameter("w", w_shape);
auto r = mm->add_parameter("r", r_shape);
auto bias = mm->add_parameter("bias", bias_shape);
auto seq_len = mm->add_parameter("seq_len", sl_shape);
auto ih = mm->add_parameter("h0", ih_shape);
auto ic = mm->add_parameter("c0", ih_shape);
auto pph = mm->add_parameter("pph", pph_shape);
// Swap the first two axes of seq, h0 and c0 before they reach the lstm op.
std::vector<int64_t> perm{1, 0, 2};
seq = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), seq);
ih = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ih);
ic = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ic);
auto out_hs = mm->add_instruction(
migraphx::make_op(
"lstm",
{{"hidden_size", hs},
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"),
migraphx::make_op("tanh"),
migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)},
{"clip", clip},
{"input_forget", input_forget}}),
seq,
w,
r,
bias,
seq_len,
ih,
ic,
pph);
// Moves axis 2 of the lstm result to the front; presumably this puts the
// batch axis first -- confirm against the lstm op's output shape.
std::vector<int64_t> perm_hid{2, 0, 1, 3};
out_hs = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm_hid}}),
out_hs);
auto prog = optimize_onnx("lstm_r_layout_test.onnx");
EXPECT(p == prog);
}
// 8 args, last and cell output
{
migraphx::program p;
auto* mm = p.get_main_module();
auto seq = mm->add_parameter("seq", seq_shape);
auto w = mm->add_parameter("w", w_shape);
auto r = mm->add_parameter("r", r_shape);
auto bias = mm->add_parameter("bias", bias_shape);
auto seq_len = mm->add_parameter("seq_len", sl_shape);
auto ih = mm->add_parameter("h0", ih_shape);
auto ic = mm->add_parameter("c0", ih_shape);
auto pph = mm->add_parameter("pph", pph_shape);
// Swap the first two axes of seq, h0 and c0 before they reach the lstm op.
std::vector<int64_t> perm{1, 0, 2};
seq = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), seq);
ih = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ih);
ic = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ic);
auto out_hs = mm->add_instruction(
migraphx::make_op(
"lstm",
{{"hidden_size", hs},
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"),
migraphx::make_op("tanh"),
migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)},
{"clip", clip},
{"input_forget", input_forget}}),
seq,
w,
r,
bias,
seq_len,
ih,
ic,
pph);
// Both last-step instructions are added before either transpose.
auto last_output = mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), out_hs);
auto last_cell = mm->add_instruction(migraphx::make_op("rnn_last_cell_output"), out_hs);
last_output = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}),
last_output);
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), last_cell);
auto prog = optimize_onnx("lstm_r_layout_hs_cell_test.onnx");
EXPECT(p == prog);
}
}
TEST_CASE(lstm_bidirectional) TEST_CASE(lstm_bidirectional)
{ {
std::size_t sl = 5; // sequence len std::size_t sl = 5; // sequence len
...@@ -1594,6 +1814,118 @@ TEST_CASE(lstm_bidirectional) ...@@ -1594,6 +1814,118 @@ TEST_CASE(lstm_bidirectional)
} }
} }
// Compares the programs parsed from the bidirectional "layout" LSTM ONNX models with
// hand-built references. The inputs are declared batch-first ({bs, sl, is} and
// {bs, nd, hs}) and are transposed with {1, 0, 2} before they reach the lstm operator.
// NOTE(review): presumably this corresponds to the ONNX LSTM attribute layout=1 -- confirm
// against the generator of the .onnx fixtures.
TEST_CASE(lstm_bidirectional_layout)
{
    std::size_t sl   = 5;  // sequence len
    std::size_t bs   = 3;  // batch size
    std::size_t hs   = 20; // hidden size
    std::size_t is   = 10; // input size
    std::size_t nd   = 2;  // num directions
    float clip       = 0.0f;
    int input_forget = 1;
    migraphx::shape seq_shape{migraphx::shape::float_type, {bs, sl, is}};
    migraphx::shape w_shape{migraphx::shape::float_type, {nd, 4 * hs, is}};
    migraphx::shape r_shape{migraphx::shape::float_type, {nd, 4 * hs, hs}};
    migraphx::shape bias_shape{migraphx::shape::float_type, {nd, 8 * hs}};
    migraphx::shape sl_shape{migraphx::shape::int32_type, {bs}};
    migraphx::shape ih_shape{migraphx::shape::float_type, {bs, nd, hs}};
    migraphx::shape pph_shape{migraphx::shape::float_type, {nd, 3 * hs}};

    // Permutation that swaps the first two axes; used on the way into the lstm operator
    // and again on the "last" outputs.
    std::vector<int64_t> perm{1, 0, 2};

    // Adds everything both sub-cases have in common to `mm`, in a fixed order: the eight
    // parameters, the transposes of seq/h0/c0 and the bidirectional lstm instruction.
    // Returns the lstm instruction. The order of the calls matters because the expected
    // program is compared with `==` below.
    auto add_layout_lstm = [&](migraphx::module* mm) {
        auto seq     = mm->add_parameter("seq", seq_shape);
        auto w       = mm->add_parameter("w", w_shape);
        auto r       = mm->add_parameter("r", r_shape);
        auto bias    = mm->add_parameter("bias", bias_shape);
        auto seq_len = mm->add_parameter("seq_len", sl_shape);
        auto ih      = mm->add_parameter("h0", ih_shape);
        auto ic      = mm->add_parameter("c0", ih_shape);
        auto pph     = mm->add_parameter("pph", pph_shape);

        seq = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), seq);
        ih  = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ih);
        ic  = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ic);

        // Three activation functions per direction: sigmoid, tanh, tanh.
        return mm->add_instruction(
            migraphx::make_op(
                "lstm",
                {{"hidden_size", hs},
                 {"actv_func",
                  migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"),
                                                                      migraphx::make_op("tanh"),
                                                                      migraphx::make_op("tanh"),
                                                                      migraphx::make_op("sigmoid"),
                                                                      migraphx::make_op("tanh"),
                                                                      migraphx::make_op("tanh")})},
                 {"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)},
                 {"clip", clip},
                 {"input_forget", input_forget}}),
            seq,
            w,
            r,
            bias,
            seq_len,
            ih,
            ic,
            pph);
    };

    // last hidden-state output
    {
        migraphx::program p;
        auto* mm         = p.get_main_module();
        auto out_hs      = add_layout_lstm(mm);
        auto last_output = mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), out_hs);
        // The transposed hidden-state sequence is not consumed by anything built here, but
        // the instruction is still added so that it is part of the expected program.
        std::vector<int64_t> perm_hid{2, 0, 1, 3};
        mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm_hid}}), out_hs);
        mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), last_output);
        auto prog = optimize_onnx("lstm_bi_layout_last_test.onnx");
        EXPECT(p == prog);
    }

    // last cell output
    {
        migraphx::program p;
        auto* mm       = p.get_main_module();
        auto out_hs    = add_layout_lstm(mm);
        auto last_cell = mm->add_instruction(migraphx::make_op("rnn_last_cell_output"), out_hs);
        mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), last_cell);
        auto prog = optimize_onnx("lstm_bi_layout_cell_test.onnx");
        EXPECT(p == prog);
    }
}
TEST_CASE(lstm_bi_actv_funcs) TEST_CASE(lstm_bi_actv_funcs)
{ {
std::size_t sl = 5; // sequence len std::size_t sl = 5; // sequence len
......
...@@ -5754,6 +5754,59 @@ TEST_CASE(qlinearmatmul_2D_test) ...@@ -5754,6 +5754,59 @@ TEST_CASE(qlinearmatmul_2D_test)
EXPECT(p.sort() == prog.sort()); EXPECT(p.sort() == prog.sort());
} }
// Builds the program expected from parsing qlinearmul_test.onnx: both uint8 inputs are
// dequantized, multiplied, and the product is quantized again. Every scale and zero
// point is a one-element literal that is multibroadcast to {64} before use.
TEST_CASE(qlinearmul_test)
{
    migraphx::program p;
    auto* mm = p.get_main_module();

    auto a = mm->add_parameter("A", {migraphx::shape::uint8_type, {64}});
    auto b = mm->add_parameter("B", {migraphx::shape::uint8_type, {64}});

    // Quantization constants, added in the order A, B, C (scale first, then zero point).
    auto sc_a   = mm->add_literal(migraphx::literal{migraphx::shape::float_type, {0.05}});
    auto z_pt_a = mm->add_literal(migraphx::literal{migraphx::shape::uint8_type, {0}});
    auto sc_b   = mm->add_literal(migraphx::literal{migraphx::shape::float_type, {0.05}});
    auto z_pt_b = mm->add_literal(migraphx::literal{migraphx::shape::uint8_type, {16}});
    auto sc_c   = mm->add_literal(migraphx::literal{migraphx::shape::float_type, {0.05}});
    auto z_pt_c = mm->add_literal(migraphx::literal{migraphx::shape::uint8_type, {100}});

    // One multibroadcast op is shared by all six broadcasts.
    const auto to_64 = migraphx::make_op("multibroadcast", {{"out_lens", {64}}});
    auto widen       = [&](migraphx::instruction_ref ins) { return mm->add_instruction(to_64, ins); };

    // Dequantize A. Each step is a separate statement so the instruction order is fixed.
    auto a_scale = widen(sc_a);
    auto a_zero  = widen(z_pt_a);
    auto a_real  = mm->add_instruction(migraphx::make_op("dequantizelinear"), a, a_scale, a_zero);

    // Dequantize B.
    auto b_scale = widen(sc_b);
    auto b_zero  = widen(z_pt_b);
    auto b_real  = mm->add_instruction(migraphx::make_op("dequantizelinear"), b, b_scale, b_zero);

    auto product = mm->add_instruction(migraphx::make_op("mul"), a_real, b_real);

    // Quantize the product with the C scale and zero point.
    auto c_scale = widen(sc_c);
    auto c_zero  = widen(z_pt_c);
    auto c = mm->add_instruction(migraphx::make_op("quantizelinear"), product, c_scale, c_zero);
    mm->add_return({c});

    auto prog = migraphx::parse_onnx("qlinearmul_test.onnx");

    EXPECT(p.sort() == prog.sort());
}
migraphx::instruction_ref insert_quantizelinear_clip(migraphx::module& m, migraphx::instruction_ref insert_quantizelinear_clip(migraphx::module& m,
const migraphx::instruction_ref ins, const migraphx::instruction_ref ins,
const migraphx::instruction_ref round, const migraphx::instruction_ref round,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment