Commit 8d32c6b8 authored by Paul's avatar Paul
Browse files

Merge branch 'develop' into blas_tuning

parents 23cb7917 f25606f9
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/instruction.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/program.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/verify.hpp>
#include <test.hpp>
TEST_CASE(pad_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {2, 2}};
auto l0 = mm->add_literal(migraphx::literal{s, {1, 2, 3, 4}});
mm->add_instruction(migraphx::make_op("pad", {{"pads", {1, 1, 1, 1}}}), l0);
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<float> results_vector(16);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{0, 0, 0, 0, 0, 1, 2, 0, 0, 3, 4, 0, 0, 0, 0, 0};
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
TEST_CASE(pad_test_asym)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {2, 2}};
auto l0 = mm->add_literal(migraphx::literal{s, {1, 2, 3, 4}});
mm->add_instruction(migraphx::make_op("pad", {{"pads", {0, 0, 1, 1}}}), l0);
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<float> results_vector(9);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{1, 2, 0, 3, 4, 0, 0, 0, 0};
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
TEST_CASE(pad_test_highest_half)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::half_type, {2, 2}};
auto l0 = mm->add_literal(migraphx::literal{s, {1, 2, 3, 4}});
mm->add_instruction(
migraphx::make_op("pad",
{{"pads", {1, 1, 1, 1}}, {"value", std::numeric_limits<float>::max()}}),
l0);
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<float> results_vector(16);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
const float x = std::numeric_limits<migraphx::half>::max();
std::vector<float> gold{x, x, x, x, x, 1, 2, x, x, 3, 4, x, x, x, x, x};
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
TEST_CASE(pad_test_lowest_half)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::half_type, {2, 2}};
auto l0 = mm->add_literal(migraphx::literal{s, {1, 2, 3, 4}});
mm->add_instruction(
migraphx::make_op(
"pad", {{"pads", {1, 1, 1, 1}}, {"value", std::numeric_limits<float>::lowest()}}),
l0);
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<float> results_vector(16);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
const float x = std::numeric_limits<migraphx::half>::lowest();
std::vector<float> gold{x, x, x, x, x, 1, 2, x, x, 3, 4, x, x, x, x, x};
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
TEST_CASE(pad_dyn_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {{2, 4, {2}}, {2, 4, {2}}}};
auto x = mm->add_parameter("x", s);
mm->add_instruction(migraphx::make_op("pad", {{"pads", {1, 1, 1, 1}}}), x);
p.compile(migraphx::make_target("ref"));
std::vector<float> data = {1, 2, 3, 4};
migraphx::parameter_map params;
migraphx::shape input_fixed_shape{migraphx::shape::float_type, {2, 2}};
params["x"] = migraphx::argument(input_fixed_shape, data.data());
auto result = p.eval(params).back();
std::vector<float> results_vector(16);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{0, 0, 0, 0, 0, 1, 2, 0, 0, 3, 4, 0, 0, 0, 0, 0};
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/instruction.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/program.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/verify.hpp>
#include <test.hpp>
TEST_CASE(pointwise_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {3}};
auto l1 = mm->add_literal(migraphx::literal{s, {-1, 0, 1}});
auto l2 = mm->add_literal(migraphx::literal{s, {1, 2, 3}});
auto* pm = p.create_module("pointwise");
auto x1 = pm->add_parameter("x1", {migraphx::shape::float_type});
auto x2 = pm->add_parameter("x2", {migraphx::shape::float_type});
pm->add_instruction(migraphx::make_op("add"), x1, x2);
mm->add_instruction(migraphx::make_op("pointwise"), {l1, l2}, {pm});
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<float> results_vector(3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {0, 2, 4};
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/instruction.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/program.hpp>
#include <migraphx/op/pooling.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/verify.hpp>
#include <test.hpp>
TEST_CASE(avgpool_rank3_test)
{
// 1D case 1, input is 3D
migraphx::program p;
auto* mm = p.get_main_module();
auto s = migraphx::shape{migraphx::shape::float_type, {1, 3, 4}};
auto op = migraphx::op::pooling{migraphx::op::pooling_mode::average};
op.lengths = {2};
op.padding = {0};
op.stride = {1};
std::vector<float> data{0.3, 0.2, 0.4, 0.1, 0.8, 0.5, 0.9, 0.1, 0.1, 0.7, 0.1, 0.6};
auto l0 = mm->add_literal(migraphx::literal{s, data});
mm->add_instruction(op, l0);
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{0.25, 0.3, 0.25, 0.65, 0.7, 0.5, 0.4, 0.4, 0.35};
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
TEST_CASE(avgpool_dyn_test)
{
// Dynamic input, no padding
migraphx::program p;
auto* mm = p.get_main_module();
auto s = migraphx::shape{migraphx::shape::float_type, {{1, 4}, {3, 3}, {4, 4}}};
auto x = mm->add_parameter("X", s);
mm->add_instruction(migraphx::make_op("pooling",
{{"mode", migraphx::op::pooling_mode::average},
{"lengths", {2}},
{"padding", {0}},
{"stride", {1}}}),
x);
p.compile(migraphx::make_target("ref"));
std::vector<float> data{0.3, 0.2, 0.4, 0.1, 0.8, 0.5, 0.9, 0.1, 0.1, 0.7, 0.1, 0.6};
migraphx::shape input_fixed_shape{migraphx::shape::float_type, {1, 3, 4}};
migraphx::parameter_map params;
params["X"] = migraphx::argument(input_fixed_shape, data.data());
auto result = p.eval(params).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{0.25, 0.3, 0.25, 0.65, 0.7, 0.5, 0.4, 0.4, 0.35};
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
TEST_CASE(avgpool_dyn_pad_test)
{
// Dynamic input with explicit padding/
migraphx::program p;
auto* mm = p.get_main_module();
auto s = migraphx::shape{migraphx::shape::float_type, {{1, 3}, {3, 3}, {4, 4}}};
auto x = mm->add_parameter("X", s);
mm->add_instruction(migraphx::make_op("pooling",
{{"mode", migraphx::op::pooling_mode::average},
{"lengths", {2}},
{"padding", {1}},
{"stride", {1}}}),
x);
p.compile(migraphx::make_target("ref"));
std::vector<float> data{0.3, 0.2, 0.4, 0.1, 0.8, 0.5, 0.9, 0.1, 0.1, 0.7, 0.1, 0.6};
migraphx::shape input_fixed_shape{migraphx::shape::float_type, {1, 3, 4}};
migraphx::parameter_map params;
params["X"] = migraphx::argument(input_fixed_shape, data.data());
auto result = p.eval(params).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{
0.3, 0.25, 0.3, 0.25, 0.1, 0.8, 0.65, 0.7, 0.5, 0.1, 0.1, 0.4, 0.4, 0.35, 0.6};
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
TEST_CASE(avgpool_dyn_auto_pad_test)
{
// Pooling with dynamic input, multidimensional kernel and auto-padding
migraphx::program p;
auto* mm = p.get_main_module();
auto s =
migraphx::shape{migraphx::shape::float_type, {{1, 1}, {1, 3}, {2, 6, {2}}, {2, 6, {2}}}};
auto x = mm->add_parameter("X", s);
mm->add_instruction(
migraphx::make_op("pooling",
{
{"mode", migraphx::op::pooling_mode::average},
{"dyn_global", false},
// non-default auto padding
{"padding_mode", migraphx::op::padding_mode_t::same_upper},
{"lengths", {2, 3}},
}),
x);
p.compile(migraphx::make_target("ref"));
std::vector<float> data{1, 2, 3, 4};
// * 1 2 * auto padding should look like this
// * 3 4 *
// * * * *
migraphx::shape input_fixed_shape{migraphx::shape::float_type, {1, 1, 2, 2}};
migraphx::parameter_map params;
params["X"] = migraphx::argument(input_fixed_shape, data.data());
auto result = p.eval(params).back();
std::vector<float> results_vector(12);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{2.5, 2.5, 3.5, 3.5};
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
TEST_CASE(avgpool_dyn_auto_pad_1d_test)
{
// Dynamic input with auto padding (== padding_mode specified)
migraphx::program p;
auto* mm = p.get_main_module();
auto s = migraphx::shape{migraphx::shape::float_type, {{1, 3}, {3, 3}, {4, 4}}};
auto x = mm->add_parameter("X", s);
mm->add_instruction(
migraphx::make_op("pooling",
{{"mode", migraphx::op::pooling_mode::average},
{"lengths", {2}},
// padding added will be {1, 0} to make output
// the same size as input
{"padding_mode", migraphx::op::padding_mode_t::same_lower},
{"stride", {1}}}),
x);
p.compile(migraphx::make_target("ref"));
std::vector<float> data{0.3, 0.2, 0.4, 0.1, 0.8, 0.5, 0.9, 0.1, 0.1, 0.7, 0.1, 0.6};
migraphx::shape input_fixed_shape{migraphx::shape::float_type, {1, 3, 4}};
migraphx::parameter_map params;
params["X"] = migraphx::argument(input_fixed_shape, data.data());
auto result = p.eval(params).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
// clang-format off
std::vector<float> gold{0.3, 0.25, 0.3, 0.25,
0.8, 0.65, 0.7, 0.5,
0.1, 0.4, 0.4, 0.35};
// clang-format on
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
TEST_CASE(avgpool_dyn_pad_ceil_test)
{
// pooling with dynamic input and padding
migraphx::program p;
auto* mm = p.get_main_module();
auto s = migraphx::shape{migraphx::shape::float_type, {{1, 4}, {1, 3}, {2, 4}, {2, 4}}};
auto x = mm->add_parameter("X", s);
mm->add_instruction(migraphx::make_op("pooling",
{{"mode", migraphx::op::pooling_mode::average},
{"lengths", {2, 3}},
{"padding", {1, 2}},
{"ceil_mode", true},
{"stride", {1, 1}}}),
x);
p.compile(migraphx::make_target("ref"));
std::vector<float> data{1, 2, 3, 4};
// * * * * * *
// * * 1 2 * * padded input will look like this
// * * 3 4 * * but the * are ignored in averaging
// * * * * * *
migraphx::shape input_fixed_shape{migraphx::shape::float_type, {1, 1, 2, 2}};
migraphx::parameter_map params;
params["X"] = migraphx::argument(input_fixed_shape, data.data());
auto result = p.eval(params).back();
std::vector<float> results_vector(12);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
// clang-format off
std::vector<float> gold{1.0, 1.5, 1.5, 2.0,
2.0, 2.5, 2.5, 3.0,
3.0, 3.5, 3.5, 4.0};
// clang-format on
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
TEST_CASE(avgpool_rank3_stride2_test)
{
// 1D case 2, stride 2
migraphx::program p;
auto* mm = p.get_main_module();
auto s = migraphx::shape{migraphx::shape::float_type, {2, 2, 4}};
auto op = migraphx::op::pooling{migraphx::op::pooling_mode::average};
op.lengths = {2};
op.padding = {1};
op.stride = {2};
// clang-format off
std::vector<float> data{1.6321, -2.4186, 0.2239, -1.4232,
0.8158, 0.4103, -0.3149, -0.1361,
-0.3442, 2.007, 0.4331, 1.5295,
0.9965, 0.4766, 1.0942, -0.2915};
// clang-format on
auto l0 = mm->add_literal(migraphx::literal{s, data});
mm->add_instruction(op, l0);
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
// clang-format off
std::vector<float> gold{1.6321, -1.09735, -1.4232,
0.8158, 0.0477, -0.1361,
-0.3442, 1.22005, 1.5295,
0.9965, 0.7854, -0.2915};
// clang-format on
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
TEST_CASE(avgpool_rank5_test)
{
// 3D, input is 5D
migraphx::program p;
auto* mm = p.get_main_module();
auto s = migraphx::shape{migraphx::shape::float_type, {2, 2, 3, 3, 3}};
auto op = migraphx::op::pooling{migraphx::op::pooling_mode::average};
op.lengths = {2, 2, 2};
op.padding = {0, 0, 0};
op.stride = {1, 1, 1};
std::vector<float> data{
-0.179, -1.756, 0.651, 1.955, 1.87, -0.604, 0.247, 0.449, -0.137, 1.187, 1.593,
0.424, 2.698, -0.104, -0.069, -1.293, 0.538, 1.291, 0.974, 1.096, 0.74, -0.669,
-1.08, -1.041, -1.407, 1.43, -0.211, -0.017, 0.532, 1.276, 0.627, 0.236, -0.396,
-0.204, 0.501, -0.599, -1.414, -0.615, -0.274, 0.168, -0.144, 0.5, 1.42, 1.082,
-0.952, -0.846, -1.244, 1.475, 1.246, 1.344, -1.722, -1.24, -0.851, 0.06, 0.507,
0.762, -0.007, -1.484, 1.028, 0.317, 1.077, -1.289, 0.875, -0.417, -0.673, 1.715,
-0.307, 0.264, -0.973, 1.412, 2.561, -0.515, -0.201, 0.827, -1.231, 1.958, -0.552,
0.036, -0.993, -0.859, -1.458, -0.575, 0.048, -0.779, -1.025, -1.135, 1.166, -0.131,
0.726, 0.52, 0.467, -0.494, 0.675, 0.203, -0.63, -0.918, -0.5, -1.395, 1.39,
1.705, 0.444, -0.835, -0.506, 0.101, 0.602, 0.543, 0.357, 1.042};
auto l0 = mm->add_literal(migraphx::literal{s, data});
mm->add_instruction(op, l0);
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{
0.908, 0.250625, 0.795, 0.40425, 0.711875, 0.194875, 0.014125, 0.09425,
-0.078375, 0.139375, 0.46075, 0.0285, -0.188125, -0.085, 0.378125, -0.085375,
-0.04, 0.304125, 0.40775, 0.2835, 0.112375, -0.073375, 0.4355, -0.187,
-0.392625, -0.258375, -0.485875, -0.0345, 0.16125, -0.131875, -0.228375, 0.068625};
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
TEST_CASE(globalavgpool_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto s = migraphx::shape{migraphx::shape::float_type, {1, 3, 2, 2}};
auto op = migraphx::op::pooling{migraphx::op::pooling_mode::average};
auto lens = s.lens();
op.lengths = {lens[2], lens[3]};
std::vector<float> data{0.3, 0.2, 0.4, 0.1, 0.8, 0.5, 0.9, 0.1, 0.1, 0.7, 0.1, 0.6};
auto l0 = mm->add_literal(migraphx::literal{s, data});
mm->add_instruction(op, l0);
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<float> results_vector(3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{0.25, 0.575, 0.375};
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
TEST_CASE(globalavgpool_dyn_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto s = migraphx::shape{migraphx::shape::float_type, {{1, 1}, {3, 3}, {2, 6}, {2, 6, {2}}}};
auto x = mm->add_parameter("X", s);
mm->add_instruction(
migraphx::make_op("pooling",
{{"mode", migraphx::op::pooling_mode::average}, {"dyn_global", true}}),
x);
p.compile(migraphx::make_target("ref"));
std::vector<float> data{0.3, 0.2, 0.4, 0.1, 0.8, 0.5, 0.9, 0.1, 0.1, 0.7, 0.1, 0.6};
migraphx::shape input_fixed_shape{migraphx::shape::float_type, {1, 3, 2, 2}};
migraphx::parameter_map params;
params["X"] = migraphx::argument(input_fixed_shape, data.data());
auto result = p.eval(params).back();
std::vector<float> results_vector(3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{0.25, 0.575, 0.375};
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
TEST_CASE(globallppool_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto s = migraphx::shape{migraphx::shape::float_type, {1, 3, 2, 2}};
auto op = migraphx::op::pooling{migraphx::op::pooling_mode::lpnorm};
auto lens = s.lens();
op.lengths = {lens[2], lens[3]};
op.lp_order = 2;
std::vector<float> data{0.3, 0.2, 0.4, 0.1, 0.8, 0.5, 0.9, 0.1, 0.1, 0.7, 0.1, 0.6};
auto l0 = mm->add_literal(migraphx::literal{s, data});
mm->add_instruction(op, l0);
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<float> results_vector(3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{0.5477225575051662, 1.307669683062202, 0.9327379053088815};
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
TEST_CASE(globallppool_dyn_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto s =
migraphx::shape{migraphx::shape::float_type, {{1, 1}, {3, 3}, {2, 6, {2}}, {2, 6, {2}}}};
auto x = mm->add_parameter("X", s);
mm->add_instruction(
migraphx::make_op("pooling",
{{"mode", migraphx::op::pooling_mode::lpnorm}, {"dyn_global", true}}),
x);
p.compile(migraphx::make_target("ref"));
std::vector<float> data{0.3, 0.2, 0.4, 0.1, 0.8, 0.5, 0.9, 0.1, 0.1, 0.7, 0.1, 0.6};
migraphx::shape input_fixed_shape{migraphx::shape::float_type, {1, 3, 2, 2}};
migraphx::parameter_map params;
params["X"] = migraphx::argument(input_fixed_shape, data.data());
auto result = p.eval(params).back();
std::vector<float> results_vector(3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{0.5477225575051662, 1.307669683062202, 0.9327379053088815};
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
TEST_CASE(globalmaxpool_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto s = migraphx::shape{migraphx::shape::float_type, {1, 3, 2, 2}};
auto op = migraphx::op::pooling{migraphx::op::pooling_mode::max};
auto lens = s.lens();
op.lengths = {lens[2], lens[3]};
std::vector<float> data{0.3, 0.2, 0.4, 0.1, 0.8, 0.5, 0.9, 0.1, 0.1, 0.7, 0.1, 0.6};
auto l0 = mm->add_literal(migraphx::literal{s, data});
mm->add_instruction(op, l0);
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<float> results_vector(3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{0.4, 0.9, 0.7};
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
TEST_CASE(globalmaxpool_dyn_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto s =
migraphx::shape{migraphx::shape::float_type, {{1, 1}, {3, 3}, {2, 6, {2}}, {2, 6, {2}}}};
auto x = mm->add_parameter("X", s);
mm->add_instruction(
migraphx::make_op("pooling",
{{"mode", migraphx::op::pooling_mode::max}, {"dyn_global", true}}),
x);
p.compile(migraphx::make_target("ref"));
std::vector<float> data{0.3, 0.2, 0.4, 0.1, 0.8, 0.5, 0.9, 0.1, 0.1, 0.7, 0.1, 0.6};
migraphx::shape input_fixed_shape{migraphx::shape::float_type, {1, 3, 2, 2}};
migraphx::parameter_map params;
params["X"] = migraphx::argument(input_fixed_shape, data.data());
auto result = p.eval(params).back();
std::vector<float> results_vector(3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{0.4, 0.9, 0.7};
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
TEST_CASE(lppool_l1_norm_test)
{
// L1 norm test
migraphx::program p;
auto* mm = p.get_main_module();
auto s = migraphx::shape{migraphx::shape::float_type, {1, 3, 4}};
auto op = migraphx::op::pooling{migraphx::op::pooling_mode::lpnorm};
op.lengths = {2};
op.padding = {0};
op.stride = {1};
op.lp_order = 1;
std::vector<float> data{0.3, 0.2, 0.4, 0.1, 0.8, 0.5, 0.9, 0.1, 0.1, 0.7, 0.1, 0.6};
auto l0 = mm->add_literal(migraphx::literal{s, data});
mm->add_instruction(op, l0);
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{0.5, 0.6, 0.5, 1.3, 1.4, 1.0, 0.8, 0.8, 0.7};
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
// TODO: this tests compliance with a oneDNN rule and a feature that's commented out
// in pooling.hpp
// TEST_CASE(lppool_l1_norm_err_test)
// {
// // padding too large for kernel size
// migraphx::program p;
// auto* mm = p.get_main_module();
// auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 5}};
// auto op = migraphx::op::pooling{migraphx::op::pooling_mode::lpnorm};
// op.lengths = {3};
// op.padding = {2};
// op.stride = {1};
// op.lp_order = 1;
// std::vector<float> data{0.3, 0.2, 0.4, 0.1, 0.8, 0.5, 0.9, 0.1, 0.1, 0.7};
// auto l0 = mm->add_literal(migraphx::literal{s, data});
// EXPECT(test::throws([&] {
// mm->add_instruction(op, l0);
// }));
// }
TEST_CASE(lppool_l2_norm_test)
{
// L2 norm test
migraphx::program p;
auto* mm = p.get_main_module();
auto s = migraphx::shape{migraphx::shape::float_type, {1, 3, 4}};
auto op = migraphx::op::pooling{migraphx::op::pooling_mode::lpnorm};
op.lengths = {2};
op.padding = {0};
op.stride = {1};
op.lp_order = 2;
std::vector<float> data{0.3, 0.2, 0.4, 0.1, 0.8, 0.5, 0.9, 0.1, 0.1, 0.7, 0.1, 0.6};
auto l0 = mm->add_literal(migraphx::literal{s, data});
mm->add_instruction(op, l0);
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{0.36055512754639896,
0.447213595499958,
0.4123105625617661,
0.9433981132056605,
1.0295630140987,
0.9055385138137417,
0.7071067811865475,
0.7071067811865475,
0.6082762530298219};
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
TEST_CASE(lppool_dyn_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto s = migraphx::shape{migraphx::shape::float_type, {{1, 4}, {3, 3}, {4, 4}}};
auto x = mm->add_parameter("X", s);
mm->add_instruction(migraphx::make_op("pooling",
{{"mode", migraphx::op::pooling_mode::lpnorm},
{"lengths", {2}},
{"padding", {0}},
{"stride", {1}}}),
x);
p.compile(migraphx::make_target("ref"));
std::vector<float> data{0.3, 0.2, 0.4, 0.1, 0.8, 0.5, 0.9, 0.1, 0.1, 0.7, 0.1, 0.6};
migraphx::shape input_fixed_shape{migraphx::shape::float_type, {1, 3, 4}};
migraphx::parameter_map params;
params["X"] = migraphx::argument(input_fixed_shape, data.data());
auto result = p.eval(params).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{0.36055512754639896,
0.447213595499958,
0.4123105625617661,
0.9433981132056605,
1.0295630140987,
0.9055385138137417,
0.7071067811865475,
0.7071067811865475,
0.6082762530298219};
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
TEST_CASE(maxpool_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
std::vector<float> a = {
-2.1314404, -1.63041711, 1.54562736, 1.04625261, -1.42931843, -0.48703974, 0.4065806,
-0.1524526, 1.30775225, 0.45538983, -0.06631992, -1.75332725, 1.33493888, 0.47327688,
0.36873096, 1.18358743, -0.34640595, 1.22098756, 0.01946825, -0.20238149, 0.43348005,
-0.67991608, -0.83041084, 0.93537551, 0.70241445, -0.5654031, -1.30899191, -0.26735824,
-0.52444768, 1.99097753, 1.86504853, -0.26506025, 0.26236168, 0.43763575, 0.95300823,
-1.02733946, -0.74655169, -0.5374338, -0.28901565, -0.59789604, 0.5310151, 0.99125904,
0.40609556, -1.57175648, 0.22031412, 1.45862222, 0.53217483, 1.39087725, 1.00170159,
-0.87175864, -1.7204628, -1.72008383, -0.38656762, -0.01443311, 1.46645272, -1.39995027,
0.22505587, -0.43461126, -0.05511411, -0.79950953, -0.01439556, 0.08795211, 1.18943918,
-0.84079367, -1.73383629, -0.55662078, -0.30626822, -0.67339015, 0.44179603, 0.54316711,
0.40899998, -0.27831686, -1.11900508, -0.0881724, 0.35483059, 2.36277103, -0.04765317,
-0.36865309, 0.73814237, 1.47151589, 1.36546791, -0.32649881, -1.0517807, 2.24768877,
0.68883753, 0.58646208, -0.91017133, -0.50462508, -0.4013325, -0.72348958, -0.47368807,
0.35285577, -1.01817429, -0.5152272, 0.60321307, 0.43521205, -0.23733577, 0.66427642,
0.82949388, 0.82443929, 0.71550399, 0.34561086, 0.68570769, -0.40718508, -1.20350206,
0.15793853, -2.31013632, -0.07934658, -0.09348056, 0.36576006, 2.46601582, 0.11090943,
0.9144392, 0.56759721, -0.22112127, -0.21955389, 0.72474903, -1.28448462, 1.53285873,
0.37437943, 0.31409341, 1.95433736, 0.91620457, 0.86205518, 1.24365854, 0.19248386,
0.22526583, 0.13462132, -0.27561715, -2.06446075, -0.02306402, -1.38278747, 1.1411345,
1.31293464, -1.86041689, 1.06763375, -0.26541466, 1.4545635, 1.11430049, -0.66491818,
0.87101674, 0.67768967, -1.02062869, -1.05031872, -2.2764678, -2.0200038, 0.37592548,
-0.26701379, -0.83388507, 0.19403623, 1.00968623, 0.11020003, 1.16736257, -1.1160326,
0.47346735, 0.6126079, -0.19135755, 1.33624589, -0.29802522, -0.57873946, -1.06555879,
-0.20686582, 1.36892557, -0.19937795, 0.8649236, -1.40126073, 1.53441942, 0.34682792,
-1.31724346, -1.32898355, 2.40126371, 0.07845283, 1.35732043, -0.63678312, 0.39429256,
-1.36487007, -0.31026676, -0.44981545, -0.28994772, -0.14657612, -1.75206447, -0.70612341,
1.20071781, -1.64647579, -0.7133292, 0.88494766, 0.52119428, -2.77387547, 2.07681108,
-0.90133125, 0.2847338, 0.6174528, -0.20616426, -0.64263535, -1.08496261, 0.54275119,
-0.88503587, 0.6629802, 1.47319221, -1.05829155, -0.97027361, -0.93187737, -1.39954746,
-0.52359426, -0.14743951, 1.51522756, 0.2078452, -1.28156149, -1.19363916, -0.78680223,
-0.89094824, 1.30212069, -0.77974445, -0.58411664, 0.48764706, -0.67132682};
migraphx::shape a_shape{migraphx::shape::float_type, {2, 3, 6, 6}};
auto al = mm->add_literal(migraphx::literal{a_shape, a});
mm->add_instruction(migraphx::make_op("pooling",
{{"mode", migraphx::op::pooling_mode::max},
{"padding", {0, 0}},
{"stride", {2, 2}},
{"lengths", {3, 2}}}),
al);
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<float> results_vector(36);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {
1.33493888, 1.54562736, 1.22098756, 1.33493888, 1.18358743, 1.99097753,
1.00170159, 1.45862222, 1.39087725, 1.46645272, 1.18943918, -0.01443311,
1.47151589, 2.36277103, 2.24768877, 0.68883753, 0.82949388, 0.71550399,
1.95433736, 2.46601582, 1.53285873, 1.95433736, 1.06763375, 1.4545635,
1.33624589, 1.16736257, 0.6126079, 1.36892557, 2.40126371, 1.53441942,
0.52119428, 2.07681108, 0.88494766, 1.51522756, 0.54275119, 0.6629802};
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
TEST_CASE(maxpool_pad_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
std::vector<float> a = {-6, -5, -4, -3, -5, -1, 0, 1, 2, 3, 4, 5};
migraphx::shape a_shape{migraphx::shape::float_type, {1, 2, 3, 2}};
auto al = mm->add_literal(migraphx::literal{a_shape, a});
mm->add_instruction(migraphx::make_op("pooling",
{{"mode", migraphx::op::pooling_mode::max},
{"padding", {1, 1}},
{"stride", {2, 2}},
{"lengths", {3, 2}}}),
al);
// * * * * * * * *
// * -6 -5 * * 0 1 *
// * -4 -3 * padding will look like this * 2 3 *
// * -5 -1 * and this * 4 5 *
// * * * * The * values are actually -INF * * * *
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<float> results_vector(8);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {-4, -3, -4, -1, 2, 3, 4, 5};
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
TEST_CASE(maxpool_rank3_test0)
{
// 1D case 1, input is 3D
migraphx::program p;
auto* mm = p.get_main_module();
auto s = migraphx::shape{migraphx::shape::float_type, {1, 3, 4}};
auto op = migraphx::op::pooling{migraphx::op::pooling_mode::max};
op.lengths = {2};
op.padding = {0};
op.stride = {1};
std::vector<float> data{0.3, 0.2, 0.4, 0.1, 0.8, 0.5, 0.9, 0.1, 0.1, 0.7, 0.1, 0.6};
auto l0 = mm->add_literal(migraphx::literal{s, data});
mm->add_instruction(op, l0);
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{0.3, 0.4, 0.4, 0.8, 0.9, 0.9, 0.7, 0.7, 0.6};
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
TEST_CASE(maxpool_rank3_test1)
{
// 1D case 2, input is 3D
migraphx::program p;
auto* mm = p.get_main_module();
auto s = migraphx::shape{migraphx::shape::float_type, {2, 2, 5}};
auto op = migraphx::op::pooling{migraphx::op::pooling_mode::max};
op.lengths = {2};
op.padding = {0};
op.stride = {2};
std::vector<float> data{0.4975, -0.1226, -0.0405, -0.2861, -0.1227, -0.6186, -0.9618,
0.6022, -0.1912, 1.1925, 0.5493, 0.1692, -0.8039, -1.0281,
0.9907, 0.477, 1.5001, -1.1603, -1.361, 1.2556};
auto l0 = mm->add_literal(migraphx::literal{s, data});
mm->add_instruction(op, l0);
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{0.4975, -0.0405, -0.6186, 0.6022, 0.5493, -0.8039, 1.5001, -1.1603};
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
TEST_CASE(maxpool_rank3_ceil_test)
{
// 1D case 2, input is 3D, ceil mode
migraphx::program p;
auto* mm = p.get_main_module();
auto s = migraphx::shape{migraphx::shape::float_type, {2, 2, 5}};
auto op = migraphx::op::pooling{migraphx::op::pooling_mode::max};
op.lengths = {2};
op.padding = {0};
op.stride = {2};
op.ceil_mode = true;
// clang-format off
std::vector<float> data{0.4975, -0.1226, -0.0405, -0.2861, -0.1227,
-0.6186, -0.9618, 0.6022, -0.1912, 1.1925,
0.5493, 0.1692, -0.8039, -1.0281, 0.9907,
0.477, 1.5001, -1.1603, -1.361, 1.2556};
// clang-format on
auto l0 = mm->add_literal(migraphx::literal{s, data});
mm->add_instruction(op, l0);
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
// clang-format off
std::vector<float> gold{0.4975, -0.0405, -0.1227, -0.6186,
0.6022, 1.1925, 0.5493, -0.8039,
0.9907, 1.5001, -1.1603, 1.2556};
// clang-format on
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
TEST_CASE(maxpool_rank5_test)
{
// 3D, input is 5D
migraphx::program p;
auto* mm = p.get_main_module();
auto s = migraphx::shape{migraphx::shape::float_type, {2, 2, 3, 3, 3}};
auto op = migraphx::op::pooling{migraphx::op::pooling_mode::max};
op.lengths = {2, 2, 2};
op.padding = {0, 0, 0};
op.stride = {2, 2, 2};
std::vector<float> data{
-2.8029, 0.5861, 0.7015, 0.1297, -1.44, -1.9472, 0.7812, 2.408, -0.3145, 0.3405,
-0.9146, 0.0624, 1.5064, -0.8345, 1.7977, 1.8949, 1.0073, -0.2102, -0.042, -0.7146,
0.6227, -0.5263, -2.2598, 0.1713, 0.449, 0.5303, -0.8622, -0.5691, 0.907, -0.0569,
-1.5348, -0.4109, -0.1461, -0.5445, 0.4266, 0.2282, 1.3655, -2.1519, 0.6068, -0.2001,
-0.4702, 0.3864, 1.7083, 0.9096, 0.4286, -1.8866, 0.7034, 0.0293, 1.4587, 0.7672,
-2.8614, 0.8124, -0.053, 1.0449, 0.845, -0.0131, 0.1139, -0.859, -1.2681, -0.6337,
-0.4644, 0.1938, 0.2889, 0.9035, 0.7118, -0.5767, 0.4577, -0.0549, 0.2237, 0.5756,
0.0677, -0.0223, -0.329, 0.2364, 2.7666, -0.7417, -1.3196, -0.2655, 0.1698, -0.1777,
-0.9427, 2.6859, -0.7501, 0.5175, 1.0029, -2.6436, -0.4388, -1.2348, -0.1539, -0.6229,
-0.4136, 0.5085, 0.4136, -0.6439, -1.1953, -0.406, -0.0195, 0.1869, -0.8664, 1.1364,
0.5041, 0.0647, 0.1941, -1.0819, -0.4629, -0.5107, 0.3612, -0.3583};
auto l0 = mm->add_literal(migraphx::literal{s, data});
mm->add_instruction(op, l0);
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{1.5064, 1.3655, 0.9035, 2.6859};
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
TEST_CASE(maxpool_dyn_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto s = migraphx::shape{migraphx::shape::float_type, {{1, 4}, {3, 3}, {4, 4}}};
auto x = mm->add_parameter("X", s);
mm->add_instruction(migraphx::make_op("pooling",
{{"mode", migraphx::op::pooling_mode::max},
{"lengths", {2}},
{"padding", {0}},
{"stride", {1}}}),
x);
p.compile(migraphx::make_target("ref"));
std::vector<float> data{0.3, 0.2, 0.4, 0.1, 0.8, 0.5, 0.9, 0.1, 0.1, 0.7, 0.1, 0.6};
migraphx::shape input_fixed_shape{migraphx::shape::float_type, {1, 3, 4}};
migraphx::parameter_map params;
params["X"] = migraphx::argument(input_fixed_shape, data.data());
auto result = p.eval(params).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{0.3, 0.4, 0.4, 0.8, 0.9, 0.9, 0.7, 0.7, 0.6};
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/instruction.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/program.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/verify.hpp>
#include <test.hpp>
TEST_CASE(pow_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {3}};
std::vector<float> data = {1, 2, 3};
auto b = mm->add_literal(migraphx::literal{s, data});
auto e = mm->add_literal(migraphx::literal{s, data});
mm->add_instruction(migraphx::make_op("pow"), b, e);
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<float> results_vector(3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = data;
std::transform(
gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return std::pow(n, n); });
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
TEST_CASE(pow_dyn_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {3}};
auto b = mm->add_parameter("b", s);
auto e = mm->add_parameter("e", s);
mm->add_instruction(migraphx::make_op("pow"), b, e);
p.compile(migraphx::make_target("ref"));
std::vector<float> data = {1, 2, 3};
migraphx::parameter_map params0;
migraphx::shape input_fixed_shape0{migraphx::shape::float_type, {3}};
params0["b"] = migraphx::argument(input_fixed_shape0, data.data());
params0["e"] = migraphx::argument(input_fixed_shape0, data.data());
auto result = p.eval(params0).back();
std::vector<float> results_vector(3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = data;
std::transform(
gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return std::pow(n, n); });
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/instruction.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/program.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/verify.hpp>
#include <test.hpp>
TEST_CASE(prefix_scan_sum_1d)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {6}};
auto input = migraphx::literal{s, {1, 2, 3, 4, 5, 6}};
auto l0 = mm->add_literal(input);
mm->add_instruction(migraphx::make_op("prefix_scan_sum", {{"axis", 0}, {"exclusive", false}}),
l0);
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{1.0, 3.0, 6.0, 10.0, 15.0, 21.0};
EXPECT(results_vector == gold);
}
TEST_CASE(prefix_scan_sum_dyn_1d)
{
migraphx::program p;
auto* mm = p.get_main_module();
std::vector<migraphx::shape::dynamic_dimension> dd{{5, 8}};
migraphx::shape s{migraphx::shape::float_type, dd};
auto input = mm->add_parameter("X", s);
mm->add_instruction(migraphx::make_op("prefix_scan_sum", {{"axis", 0}, {"exclusive", false}}),
input);
p.compile(migraphx::make_target("ref"));
std::vector<float> a = {1, 2, 3, 4, 5, 6};
migraphx::shape input_fixed_shape0{migraphx::shape::float_type, {6}};
migraphx::parameter_map params0;
params0["X"] = migraphx::argument(input_fixed_shape0, a.data());
auto result = p.eval(params0).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{1.0, 3.0, 6.0, 10.0, 15.0, 21.0};
EXPECT(results_vector == gold);
}
TEST_CASE(prefix_scan_sum_2d_1)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {3, 3}};
auto input = migraphx::literal{s, {1, 2, 3, 1, 2, 3, 1, 2, 3}};
auto l0 = mm->add_literal(input);
mm->add_instruction(migraphx::make_op("prefix_scan_sum", {{"axis", 0}, {"exclusive", false}}),
l0);
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{1.0, 2.0, 3.0, 2.0, 4.0, 6.0, 3.0, 6.0, 9.0};
EXPECT(results_vector == gold);
}
TEST_CASE(prefix_scan_sum_2d_2)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {3, 3}};
auto input = migraphx::literal{s, {1, 2, 3, 1, 2, 3, 1, 2, 3}};
auto l0 = mm->add_literal(input);
mm->add_instruction(migraphx::make_op("prefix_scan_sum", {{"axis", 1}, {"exclusive", false}}),
l0);
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{1.0, 3.0, 6.0, 1.0, 3.0, 6.0, 1.0, 3.0, 6.0};
EXPECT(results_vector == gold);
}
TEST_CASE(prefix_scan_sum_3d_1)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {2, 3, 3}};
auto input = migraphx::literal{s, {1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3}};
auto l0 = mm->add_literal(input);
mm->add_instruction(migraphx::make_op("prefix_scan_sum", {{"axis", 0}, {"exclusive", false}}),
l0);
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{
1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 2.0, 4.0, 6.0, 2.0, 4.0, 6.0, 2.0, 4.0, 6.0};
EXPECT(results_vector == gold);
}
TEST_CASE(prefix_scan_sum_3d_2)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {2, 3, 3}};
auto input = migraphx::literal{s, {1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3}};
auto l0 = mm->add_literal(input);
mm->add_instruction(migraphx::make_op("prefix_scan_sum", {{"axis", 1}, {"exclusive", false}}),
l0);
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{
1.0, 2.0, 3.0, 2.0, 4.0, 6.0, 3.0, 6.0, 9.0, 1.0, 2.0, 3.0, 2.0, 4.0, 6.0, 3.0, 6.0, 9.0};
EXPECT(results_vector == gold);
}
TEST_CASE(prefix_scan_sum_3d_3)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {2, 3, 3}};
auto input = migraphx::literal{s, {1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3}};
auto l0 = mm->add_literal(input);
mm->add_instruction(migraphx::make_op("prefix_scan_sum", {{"axis", 2}, {"exclusive", false}}),
l0);
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{
1.0, 3.0, 6.0, 1.0, 3.0, 6.0, 1.0, 3.0, 6.0, 1.0, 3.0, 6.0, 1.0, 3.0, 6.0, 1.0, 3.0, 6.0};
EXPECT(results_vector == gold);
}
TEST_CASE(prefix_scan_sum_exclusive_1)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {8}};
auto input = migraphx::literal{s, {1, 2, 3, 4, 1, 2, 3, 4}};
auto l0 = mm->add_literal(input);
mm->add_instruction(migraphx::make_op("prefix_scan_sum", {{"axis", 0}, {"exclusive", true}}),
l0);
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{0.0, 1.0, 3.0, 6.0, 10.0, 11.0, 13.0, 16.0};
EXPECT(results_vector == gold);
}
TEST_CASE(prefix_scan_sum_exclusive_2)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {2, 3, 3}};
auto input = migraphx::literal{s, {1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3}};
auto l0 = mm->add_literal(input);
mm->add_instruction(migraphx::make_op("prefix_scan_sum", {{"axis", 1}, {"exclusive", true}}),
l0);
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{
0.0, 0.0, 0.0, 1.0, 2.0, 3.0, 2.0, 4.0, 6.0, 0.0, 0.0, 0.0, 1.0, 2.0, 3.0, 2.0, 4.0, 6.0};
EXPECT(results_vector == gold);
}
TEST_CASE(prefix_scan_sum_exclusive_reverse)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {6}};
auto input = migraphx::literal{s, {1, 2, 3, 4, 5, 6}};
auto l0 = mm->add_literal(input);
mm->add_instruction(
migraphx::make_op("prefix_scan_sum", {{"axis", 0}, {"exclusive", true}, {"reverse", true}}),
l0);
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{20.0, 18.0, 15.0, 11.0, 6.0, 0.0};
EXPECT(results_vector == gold);
}
TEST_CASE(prefix_scan_sum_negative_axis_1)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {2, 3, 3}};
auto input = migraphx::literal{s, {1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3}};
auto l0 = mm->add_literal(input);
mm->add_instruction(migraphx::make_op("prefix_scan_sum", {{"axis", -3}, {"exclusive", false}}),
l0);
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{
1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 2.0, 4.0, 6.0, 2.0, 4.0, 6.0, 2.0, 4.0, 6.0};
EXPECT(results_vector == gold);
}
TEST_CASE(prefix_scan_sum_negative_axis_2)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {2, 3, 3}};
auto input = migraphx::literal{s, {1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3}};
auto l0 = mm->add_literal(input);
mm->add_instruction(migraphx::make_op("prefix_scan_sum", {{"axis", -2}, {"exclusive", false}}),
l0);
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{
1.0, 2.0, 3.0, 2.0, 4.0, 6.0, 3.0, 6.0, 9.0, 1.0, 2.0, 3.0, 2.0, 4.0, 6.0, 3.0, 6.0, 9.0};
EXPECT(results_vector == gold);
}
TEST_CASE(prefix_scan_sum_negative_axis_3)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {2, 3, 3}};
auto input = migraphx::literal{s, {1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3}};
auto l0 = mm->add_literal(input);
mm->add_instruction(migraphx::make_op("prefix_scan_sum", {{"axis", -1}, {"exclusive", false}}),
l0);
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{
1.0, 3.0, 6.0, 1.0, 3.0, 6.0, 1.0, 3.0, 6.0, 1.0, 3.0, 6.0, 1.0, 3.0, 6.0, 1.0, 3.0, 6.0};
EXPECT(results_vector == gold);
}
TEST_CASE(prefix_scan_sum_reverse_1)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {8}};
auto input = migraphx::literal{s, {1, 2, 3, 4, 1, 2, 3, 4}};
auto l0 = mm->add_literal(input);
mm->add_instruction(migraphx::make_op("prefix_scan_sum",
{{"axis", 0}, {"exclusive", false}, {"reverse", true}}),
l0);
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{20.0, 19.0, 17.0, 14.0, 10.0, 9.0, 7.0, 4.0};
EXPECT(results_vector == gold);
}
TEST_CASE(prefix_scan_sum_reverse_2)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {2, 2, 2}};
auto input = migraphx::literal{s, {1, 2, 3, 4, 1, 2, 3, 4}};
auto l0 = mm->add_literal(input);
mm->add_instruction(migraphx::make_op("prefix_scan_sum",
{{"axis", 0}, {"exclusive", false}, {"reverse", true}}),
l0);
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{2.0, 4.0, 6.0, 8.0, 1.0, 2.0, 3.0, 4.0};
EXPECT(results_vector == gold);
}
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/instruction.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/program.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/verify.hpp>
#include <test.hpp>
TEST_CASE(prelu_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {3}};
auto x = mm->add_literal(migraphx::literal{s, {-1, 0, 2}});
auto slope = mm->add_literal(migraphx::literal{s, {2, 1, 2}});
mm->add_instruction(migraphx::make_op("prelu"), x, slope);
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {-2.0f, 0.0f, 2.0f};
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
TEST_CASE(prelu_dyn_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
std::vector<migraphx::shape::dynamic_dimension> dd{{2, 6}};
migraphx::shape s{migraphx::shape::float_type, dd};
auto x = mm->add_parameter("x", s);
auto slope = mm->add_parameter("slope", s);
mm->add_instruction(migraphx::make_op("prelu"), x, slope);
p.compile(migraphx::make_target("ref"));
std::vector<float> x_data{-1, 0, 2};
std::vector<float> slope_data{2, 1, 2};
migraphx::parameter_map params0;
migraphx::shape input_fixed_shape0{migraphx::shape::float_type, {3}};
params0["x"] = migraphx::argument(input_fixed_shape0, x_data.data());
params0["slope"] = migraphx::argument(input_fixed_shape0, slope_data.data());
auto result = p.eval(params0).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {-2.0f, 0.0f, 2.0f};
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/instruction.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/program.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/verify.hpp>
#include <test.hpp>
TEST_CASE(quant_conv2d_padding_stride_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape a_shape{migraphx::shape::int8_type, {2, 3, 4, 4}};
std::vector<int8_t> a(2 * 3 * 4 * 4);
std::iota(a.begin(), a.end(), 0);
auto al = mm->add_literal(migraphx::literal{a_shape, a});
migraphx::shape c_shape{migraphx::shape::int8_type, {2, 3, 3, 3}};
std::vector<int8_t> c(2 * 3 * 3 * 3);
std::iota(c.begin(), c.end(), 0);
auto cl = mm->add_literal(migraphx::literal{c_shape, c});
mm->add_instruction(
migraphx::make_op("quant_convolution", {{"padding", {1, 1}}, {"stride", {2, 2}}}), al, cl);
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<int32_t> gold = {4521,
7014,
7830,
11952,
10515,
16734,
19737,
30906,
13161,
19542,
19494,
28800,
34707,
52590,
54729,
82746};
std::vector<int32_t> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
TEST_CASE(quant_conv2d_padding_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape a_shape{migraphx::shape::int8_type, {2, 3, 4, 4}};
std::vector<int8_t> a(2 * 3 * 4 * 4);
std::iota(a.begin(), a.end(), 0);
auto al = mm->add_literal(migraphx::literal{a_shape, a});
migraphx::shape c_shape{migraphx::shape::int8_type, {2, 3, 3, 3}};
std::vector<int8_t> c(2 * 3 * 3 * 3);
std::iota(c.begin(), c.end(), 0);
auto cl = mm->add_literal(migraphx::literal{c_shape, c});
mm->add_instruction(
migraphx::make_op("quant_convolution", {{"padding", {1, 1}}, {"stride", {1, 1}}}), al, cl);
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<int32_t> gold = {
4521, 6753, 7014, 4635, 6858, 10197, 10548, 6939, 7830, 11601, 11952, 7839, 5007,
7383, 7590, 4953, 10515, 15987, 16734, 11277, 16821, 25506, 26586, 17874, 19737, 29826,
30906, 20718, 13593, 20505, 21198, 14187, 13161, 19281, 19542, 12699, 18522, 27045, 27396,
17739, 19494, 28449, 28800, 18639, 11919, 17319, 17526, 11289, 34707, 51843, 52590, 34893,
51813, 77346, 78426, 52002, 54729, 81666, 82746, 54846, 36057, 53769, 54462, 36075};
std::vector<int32_t> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
TEST_CASE(quant_conv2d_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape a_shape{migraphx::shape::int8_type, {2, 3, 4, 4}};
std::vector<int8_t> a(2 * 3 * 4 * 4);
std::iota(a.begin(), a.end(), 0);
auto al = mm->add_literal(migraphx::literal{a_shape, a});
migraphx::shape c_shape{migraphx::shape::int8_type, {2, 3, 3, 3}};
std::vector<int8_t> c(2 * 3 * 3 * 3);
std::iota(c.begin(), c.end(), 0);
auto cl = mm->add_literal(migraphx::literal{c_shape, c});
mm->add_instruction(migraphx::make_op("quant_convolution"), al, cl);
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<int32_t> gold = {10197,
10548,
11601,
11952,
25506,
26586,
29826,
30906,
27045,
27396,
28449,
28800,
77346,
78426,
81666,
82746};
std::vector<int32_t> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/instruction.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/program.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/verify.hpp>
#include <test.hpp>
TEST_CASE(quantizelinear_1)
{
migraphx::shape xs{migraphx::shape::float_type, {2, 3, 3}};
std::vector<float> xv = {
-300, 600, 129, -1000, 4, 3, -6, 600, 550, -300, 600, 129, -1000, 4, 3, -6, 600, 550};
migraphx::shape ss{migraphx::shape::float_type, {2, 3, 3}};
std::vector<float> sv = {2, 2, 2, 4, 4, 4, 6, 6, 6, 2, 2, 2, 4, 4, 4, 6, 6, 6};
migraphx::shape zs{migraphx::shape::int8_type, {2, 3, 3}};
std::vector<uint8_t> zv = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
auto create_program = [&]() {
migraphx::program p;
auto* mm = p.get_main_module();
auto x = mm->add_literal(xs, xv);
auto s = mm->add_literal(ss, sv);
auto z = mm->add_literal(zs, zv);
mm->add_instruction(migraphx::make_op("quantizelinear"), x, s, z);
return p;
};
migraphx::program p1 = create_program();
p1.compile(migraphx::make_target("ref"));
auto result = p1.eval({}).back();
std::vector<float> results_vector(18);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{
-128, 127, 65, -128, 1, 1, -1, 100, 92, -128, 127, 65, -128, 1, 1, -1, 100, 92};
EXPECT(results_vector == gold);
}
TEST_CASE(quantizelinear_2)
{
migraphx::shape xs{migraphx::shape::float_type, {2, 3, 3}};
std::vector<float> xv = {
-300, 600, 129, -1000, 4, 3, -6, 600, 550, -300, 600, 129, -1000, 4, 3, -6, 600, 550};
migraphx::shape ss{migraphx::shape::float_type, {2, 3, 3}};
std::vector<float> sv = {2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2};
auto create_program = [&]() {
migraphx::program p;
auto* mm = p.get_main_module();
auto x = mm->add_literal(xs, xv);
auto s = mm->add_literal(ss, sv);
mm->add_instruction(migraphx::make_op("quantizelinear"), x, s);
return p;
};
migraphx::program p1 = create_program();
p1.compile(migraphx::make_target("ref"));
auto result = p1.eval({}).back();
std::vector<float> results_vector(18);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{0, 255, 65, 0, 2, 2, 0, 255, 255, 0, 255, 65, 0, 2, 2, 0, 255, 255};
EXPECT(results_vector == gold);
}
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/program.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/verify.hpp>
#include <random>
#include <test.hpp>
/**
* Reference test for the random_seed operation
*/
TEST_CASE(random_seed_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
mm->add_instruction(migraphx::make_op("random_seed"));
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<uint64_t> result_vec1(1);
result.visit([&](auto output) { result_vec1.assign(output.begin(), output.end()); });
std::vector<uint64_t> result_vec2(1);
// Identical calls should give different seeds every time with 1/(2^64) chance of a repeat.
// We don't analyze for true randomness.
result = p.eval({}).back();
result.visit([&](auto output) { result_vec2.assign(output.begin(), output.end()); });
EXPECT(result_vec1[0] != result_vec2[0]);
}
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/instruction.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/onnx.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/verify.hpp>
#include <random>
#include <test.hpp>
/**
* Reference test for the random_uniform operation. Also invokes the random_seed operation.
*/
TEST_CASE(random_uniform_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
uint64_t seed(0);
size_t sample_size(200);
// Shape of the random data
migraphx::shape rs{migraphx::shape::float_type, {1, sample_size}};
// data tensor must be allocated at this point but does not need to be initialized.
std::vector<float> data(sample_size);
auto input = mm->add_literal(migraphx::literal(rs, data));
// Runtime randomization seed
migraphx::shape seed_shape{migraphx::shape::uint64_type, {1}};
std::vector<uint64_t> seed_data{seed};
auto seed_input = mm->add_literal(migraphx::literal(seed_shape, seed_data));
mm->add_instruction(migraphx::make_op("random_uniform"), seed_input, input);
p.compile(migraphx::make_target("ref"));
// no params_map needed
auto result = p.eval({}).back();
std::vector<float> result_vec(sample_size);
result.visit([&](auto output) { result_vec.assign(output.begin(), output.end()); });
// Compare result with the STL's mt19937 generator
std::mt19937 gen(seed);
std::uniform_real_distribution<> dis(0.0, 1.0);
std::vector<float> rand_samples(sample_size);
std::generate(rand_samples.begin(), rand_samples.end(), [&]() { return dis(gen); });
EXPECT(migraphx::verify::verify_range_with_tolerance(result_vec,
migraphx::verify::expected{rand_samples},
migraphx::verify::tolerance{0.00001}));
}
TEST_CASE(random_uniform_int_test)
{
// random uniform distribution with an integer type input shape
migraphx::program p;
auto* mm = p.get_main_module();
float seed(0.1);
size_t sample_size(200);
// Shape of the random data
migraphx::shape rs{migraphx::shape::uint16_type, {1, sample_size}};
// data tensor must be allocated at this point but does not need to be initialized.
std::vector<uint16_t> data(sample_size);
auto input = mm->add_literal(migraphx::literal(rs, data));
// Runtime randomization seed
migraphx::shape seed_shape{migraphx::shape::float_type, {1}};
std::vector<float> seed_data{seed};
auto seed_input = mm->add_literal(migraphx::literal(seed_shape, seed_data));
mm->add_instruction(migraphx::make_op("random_uniform"), seed_input, input);
p.compile(migraphx::make_target("ref"));
migraphx::parameter_map params0;
auto result = p.eval(params0).back();
std::vector<uint16_t> result_vec(sample_size);
result.visit([&](auto output) { result_vec.assign(output.begin(), output.end()); });
// Compare result with the STL's mt19937 generator
std::mt19937 gen(seed);
std::uniform_int_distribution<uint16_t> dis;
std::vector<uint16_t> gold_rand_samples(sample_size);
std::generate(gold_rand_samples.begin(), gold_rand_samples.end(), [&]() { return dis(gen); });
EXPECT(migraphx::verify::verify_rms_range(result_vec, gold_rand_samples));
}
TEST_CASE(random_uniform_dyn_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
uint64_t seed(17);
size_t sample_size(200);
// Shape of the random data
migraphx::shape rs{migraphx::shape::float_type, {{1, 2}, {2, sample_size + 1}}};
auto input = mm->add_parameter("Input_1", rs);
// Runtime randomization seed
migraphx::shape seed_shape{migraphx::shape::uint64_type, {1}};
auto seed_input = mm->add_parameter("Seed", seed_shape);
mm->add_instruction(migraphx::make_op("random_uniform", {}), seed_input, input);
p.compile(migraphx::make_target("ref"));
// Create a dummy input to hold the random data
migraphx::shape input_fixed_shape1{migraphx::shape::float_type, {sample_size}};
migraphx::parameter_map params0;
params0["Input_1"] = migraphx::argument(input_fixed_shape1);
std::vector<uint64_t> seed_data = {seed};
params0["Seed"] = migraphx::argument(seed_shape, seed_data.data());
auto result = p.eval(params0).back();
std::vector<float> result_vec(sample_size);
result.visit([&](auto output) { result_vec.assign(output.begin(), output.end()); });
// Compare result with the STL's mt19937 generator
std::mt19937 gen(seed);
std::uniform_real_distribution<> dis(0.0, 1.0);
std::vector<float> gold_rand_samples(sample_size);
std::generate(gold_rand_samples.begin(), gold_rand_samples.end(), [&]() { return dis(gen); });
EXPECT(migraphx::verify::verify_rms_range(result_vec, gold_rand_samples));
}
TEST_CASE(random_uniform_and_seed_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
size_t sample_size(20000);
// Shape of the random data
migraphx::shape rs{migraphx::shape::float_type, {{1, 2}, {2, sample_size + 1}}};
auto input = mm->add_parameter("Input_1", rs);
// Runtime randomization seed
auto seed_input = mm->add_instruction(migraphx::make_op("random_seed"));
mm->add_instruction(migraphx::make_op("random_uniform"), seed_input, input);
p.compile(migraphx::make_target("ref"));
// Create a dummy input to hold the random data
migraphx::shape input_fixed_shape1{migraphx::shape::float_type, {sample_size}};
migraphx::parameter_map params0;
params0["Input_1"] = migraphx::argument(input_fixed_shape1);
auto result = p.eval(params0).back();
result.visit([&](auto output) { EXPECT(output.size() == sample_size); });
// Do not check the content of the data since it's not repeatable
}
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/instruction.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/program.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/verify.hpp>
#include <test.hpp>
TEST_CASE(recip_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::double_type, {3}};
std::vector<float> data{-0.5f, 0.1f, 0.5f};
auto l = mm->add_literal(migraphx::literal{s, data});
mm->add_instruction(migraphx::make_op("recip"), l);
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<float> results_vector(3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {-2.0f, 10.0f, 2.0f};
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
TEST_CASE(recip_dyn_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape::dynamic_dimension dd{3, 8};
migraphx::shape s{migraphx::shape::float_type, {dd}};
auto input = mm->add_parameter("X", s);
mm->add_instruction(migraphx::make_op("recip"), input);
p.compile(migraphx::make_target("ref"));
std::vector<float> input_data{-0.5f, 0.1f, 0.5f};
migraphx::parameter_map params0;
migraphx::shape input_fixed_shape0{migraphx::shape::float_type, {3}};
params0["X"] = migraphx::argument(input_fixed_shape0, input_data.data());
auto result = p.eval(params0).back();
std::vector<float> results_vector(3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {-2.0f, 10.0f, 2.0f};
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/instruction.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/program.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/verify.hpp>
#include <test.hpp>
TEST_CASE(reduce_max_axis0)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {3, 2, 2}};
auto input = migraphx::literal{s, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}};
auto l0 = mm->add_literal(input);
mm->add_instruction(migraphx::make_op("reduce_max", {{"axes", {0}}}), l0);
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{9, 10, 11, 12};
EXPECT(results_vector == gold);
}
TEST_CASE(reduce_max_dynamic_axis0)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {{2, 4, {2}}, {3, 5, {3}}}};
auto input = mm->add_parameter("X", s);
auto reduce_max_op = migraphx::make_op("reduce_max", {{"axes", {0}}});
mm->add_instruction(reduce_max_op, input);
p.compile(migraphx::make_target("ref"));
migraphx::parameter_map params;
migraphx::shape input_fixed_shape{migraphx::shape::float_type, {2, 5}};
std::vector<float> input_data{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12};
params["X"] = migraphx::argument(input_fixed_shape, input_data.data());
auto result = p.eval(params).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {6, 7, 8, 9, 10};
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
TEST_CASE(reduce_max_axis01)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {3, 2, 2}};
auto input = migraphx::literal{s, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}};
auto l0 = mm->add_literal(input);
mm->add_instruction(migraphx::make_op("reduce_max", {{"axes", {0, 1}}}), l0);
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{11, 12};
EXPECT(results_vector == gold);
}
TEST_CASE(reduce_max_axis02)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {3, 2, 2}};
auto input = migraphx::literal{s, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}};
auto l0 = mm->add_literal(input);
mm->add_instruction(migraphx::make_op("reduce_max", {{"axes", {0, 2}}}), l0);
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{10, 12};
EXPECT(results_vector == gold);
}
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/instruction.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/program.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/verify.hpp>
#include <test.hpp>
TEST_CASE(reduce_mean_axis02)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {3, 2, 2}};
auto input = migraphx::literal{s, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}};
auto l0 = mm->add_literal(input);
mm->add_instruction(migraphx::make_op("reduce_mean", {{"axes", {0, 2}}}), l0);
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{5.5, 7.5};
EXPECT(results_vector == gold);
}
TEST_CASE(reduce_mean_axis1)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {3, 2, 2}};
auto input = migraphx::literal{s, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}};
auto l0 = mm->add_literal(input);
mm->add_instruction(migraphx::make_op("reduce_mean", {{"axes", {1}}}), l0);
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{2, 3, 6, 7, 10, 11};
EXPECT(results_vector == gold);
}
TEST_CASE(reduce_mean_axis12)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {3, 2, 2}};
auto input = migraphx::literal{s, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}};
auto l0 = mm->add_literal(input);
mm->add_instruction(migraphx::make_op("reduce_mean", {{"axes", {1, 2}}}), l0);
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{2.5f, 6.5f, 10.5f};
EXPECT(results_vector == gold);
}
TEST_CASE(reduce_mean_axis2)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {3, 2, 2}};
auto input = migraphx::literal{s, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}};
auto l0 = mm->add_literal(input);
mm->add_instruction(migraphx::make_op("reduce_mean", {{"axes", {2}}}), l0);
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{1.5f, 3.5f, 5.5f, 7.5f, 9.5f, 11.5f};
EXPECT(results_vector == gold);
}
TEST_CASE(reduce_mean_int)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::int32_type, {3, 2, 2}};
auto input = migraphx::literal{s, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}};
auto l0 = mm->add_literal(input);
mm->add_instruction(migraphx::make_op("reduce_mean", {{"axes", {1, 2}}}), l0);
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<int> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<int> gold{2, 6, 10};
EXPECT(results_vector == gold);
}
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/instruction.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/program.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/verify.hpp>
#include <test.hpp>
TEST_CASE(reduce_min_axis02)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {3, 2, 2}};
auto input = migraphx::literal{s, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}};
auto l0 = mm->add_literal(input);
mm->add_instruction(migraphx::make_op("reduce_min", {{"axes", {0, 2}}}), l0);
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{1, 3};
EXPECT(results_vector == gold);
}
TEST_CASE(reduce_min_axis1)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {3, 2, 2}};
auto input = migraphx::literal{s, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}};
auto l0 = mm->add_literal(input);
mm->add_instruction(migraphx::make_op("reduce_min", {{"axes", {1}}}), l0);
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{1, 2, 5, 6, 9, 10};
EXPECT(results_vector == gold);
}
TEST_CASE(reduce_min_axis12)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {3, 2, 2}};
auto input = migraphx::literal{s, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}};
auto l0 = mm->add_literal(input);
mm->add_instruction(migraphx::make_op("reduce_min", {{"axes", {1, 2}}}), l0);
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{1, 5, 9};
EXPECT(results_vector == gold);
}
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/instruction.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/program.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/verify.hpp>
#include <test.hpp>
TEST_CASE(reduce_prod_axis0)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {4, 2, 2}};
auto input = migraphx::literal{s, {1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 3, 2, 3}};
auto l0 = mm->add_literal(input);
mm->add_instruction(migraphx::make_op("reduce_prod", {{"axes", {0}}}), l0);
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{6, 18, 12, 18};
EXPECT(results_vector == gold);
}
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/instruction.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/program.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/verify.hpp>
#include <test.hpp>
TEST_CASE(reduce_sum_axis0)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {3, 2, 2}};
auto input = migraphx::literal{s, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}};
auto l0 = mm->add_literal(input);
mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {0}}}), l0);
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{15, 18, 21, 24};
EXPECT(results_vector == gold);
}
TEST_CASE(reduce_sum_axis02)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {3, 2, 2}};
auto input = migraphx::literal{s, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}};
auto l0 = mm->add_literal(input);
mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {0, 2}}}), l0);
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{33, 45};
EXPECT(results_vector == gold);
}
TEST_CASE(reduce_sum_axis1)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {3, 2, 2}};
auto input = migraphx::literal{s, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}};
auto l0 = mm->add_literal(input);
mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {1}}}), l0);
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{4, 6, 12, 14, 20, 22};
EXPECT(results_vector == gold);
}
TEST_CASE(reduce_sum_axis12)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {3, 2, 2}};
auto input = migraphx::literal{s, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}};
auto l0 = mm->add_literal(input);
mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {1, 2}}}), l0);
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{10, 26, 42};
EXPECT(results_vector == gold);
}
TEST_CASE(reduce_sum_axis2)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {3, 2, 2}};
auto input = migraphx::literal{s, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}};
auto l0 = mm->add_literal(input);
mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {2}}}), l0);
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{3, 7, 11, 15, 19, 23};
EXPECT(results_vector == gold);
}
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/instruction.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/program.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/verify.hpp>
#include <test.hpp>
TEST_CASE(relu_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {3}};
auto l = mm->add_literal(migraphx::literal{s, {-1.f, 0.f, 1.f}});
mm->add_instruction(migraphx::make_op("relu"), l);
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<float> results_vector(3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {0.f, 0.f, 1.f};
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
TEST_CASE(relu_dyn_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape::dynamic_dimension dd{3, 8};
migraphx::shape s{migraphx::shape::float_type, {dd}};
auto input = mm->add_parameter("X", s);
mm->add_instruction(migraphx::make_op("relu"), input);
p.compile(migraphx::make_target("ref"));
std::vector<float> input_data{-1.f, 0.f, 1.f};
migraphx::parameter_map params0;
migraphx::shape input_fixed_shape0{migraphx::shape::float_type, {3}};
params0["X"] = migraphx::argument(input_fixed_shape0, input_data.data());
auto result = p.eval(params0).back();
std::vector<float> results_vector(3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {0.f, 0.f, 1.f};
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/instruction.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/program.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/verify.hpp>
#include <test.hpp>
TEST_CASE(reshape_lazy_test0)
{
migraphx::shape a_shape{migraphx::shape::float_type, {24, 1, 1, 1}};
std::vector<float> data(24);
std::iota(data.begin(), data.end(), -3);
migraphx::program p;
auto* mm = p.get_main_module();
auto l = mm->add_literal(migraphx::literal{a_shape, data});
std::vector<int64_t> new_shape = {8, 3, 1, 1};
mm->add_instruction(migraphx::make_op("reshape_lazy", {{"dims", new_shape}}), l);
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<float> results_vector{};
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_rms_range(results_vector, data));
}
TEST_CASE(reshape_lazy_test1)
{
migraphx::shape a_shape{migraphx::shape::float_type, {24, 1, 1, 1}};
std::vector<float> data(24);
std::iota(data.begin(), data.end(), -3);
migraphx::program p;
auto* mm = p.get_main_module();
auto l = mm->add_literal(migraphx::literal{a_shape, data});
std::vector<int64_t> new_shape = {1, 3, 4, 2};
mm->add_instruction(migraphx::make_op("reshape_lazy", {{"dims", new_shape}}), l);
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<float> results_vector{};
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_rms_range(results_vector, data));
}
TEST_CASE(reshape_lazy_test2)
{
migraphx::shape a_shape{migraphx::shape::float_type, {24, 1, 1, 1}};
std::vector<float> data(24);
std::iota(data.begin(), data.end(), -3);
migraphx::program p;
auto* mm = p.get_main_module();
auto l = mm->add_literal(migraphx::literal{a_shape, data});
std::vector<int64_t> new_shape = {1, 2, 3, 4};
mm->add_instruction(migraphx::make_op("reshape_lazy", {{"dims", new_shape}}), l);
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<float> results_vector{};
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_rms_range(results_vector, data));
}
TEST_CASE(reshape_lazy_dyn_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {{1, 4}, {24, 24}, {1, 1}, {1, 1}}};
std::vector<int64_t> new_shape = {0, 8, 3, 1};
auto input = mm->add_parameter("X", s);
mm->add_instruction(migraphx::make_op("reshape_lazy", {{"dims", new_shape}}), input);
p.compile(migraphx::make_target("ref"));
std::vector<float> data(48);
std::iota(data.begin(), data.end(), -3);
migraphx::parameter_map params;
migraphx::shape input_fixed_shape{migraphx::shape::float_type, {2, 24, 1, 1}};
params["X"] = migraphx::argument(input_fixed_shape, data.data());
auto result = p.eval(params).back();
std::vector<float> results_vector{};
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_rms_range(results_vector, data));
}
TEST_CASE(reshape_test0)
{
migraphx::shape a_shape{migraphx::shape::float_type, {24, 1, 1, 1}};
std::vector<float> gold(24);
std::iota(gold.begin(), gold.end(), -3);
migraphx::program p;
auto* mm = p.get_main_module();
auto l = mm->add_literal(migraphx::literal{a_shape, gold});
std::vector<int64_t> new_shape = {8, 3, 1, 1};
mm->add_instruction(migraphx::make_op("reshape", {{"dims", new_shape}}), l);
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<float> results_vector{};
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
TEST_CASE(reshape_test1)
{
migraphx::shape a_shape{migraphx::shape::float_type, {24, 1, 1, 1}};
std::vector<float> gold(24);
std::iota(gold.begin(), gold.end(), -3);
migraphx::program p;
auto* mm = p.get_main_module();
auto l = mm->add_literal(migraphx::literal{a_shape, gold});
std::vector<int64_t> new_shape = {1, 3, 4, 2};
mm->add_instruction(migraphx::make_op("reshape", {{"dims", new_shape}}), l);
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<float> results_vector{};
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
TEST_CASE(reshape_test2)
{
migraphx::shape a_shape{migraphx::shape::float_type, {24, 1, 1, 1}};
std::vector<float> gold(24);
std::iota(gold.begin(), gold.end(), -3);
migraphx::program p;
auto* mm = p.get_main_module();
auto l = mm->add_literal(migraphx::literal{a_shape, gold});
std::vector<int64_t> new_shape = {1, 2, 3, 4};
mm->add_instruction(migraphx::make_op("reshape", {{"dims", new_shape}}), l);
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<float> results_vector{};
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
TEST_CASE(reshape_dyn_1in_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {{1, 4}, {24, 24}, {1, 1}, {1, 1}}};
std::vector<int64_t> new_shape = {0, 8, 3, 1};
auto input = mm->add_parameter("X", s);
mm->add_instruction(migraphx::make_op("reshape", {{"dims", new_shape}}), input);
p.compile(migraphx::make_target("ref"));
std::vector<float> gold(48);
std::iota(gold.begin(), gold.end(), -3);
migraphx::parameter_map params;
migraphx::shape input_fixed_shape{migraphx::shape::float_type, {2, 24, 1, 1}};
params["X"] = migraphx::argument(input_fixed_shape, gold.data());
auto result = p.eval(params).back();
std::vector<float> results_vector{};
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
TEST_CASE(reshape_2in_test0)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s_in{migraphx::shape::float_type, {{1, 4}, {24, 24}, {1, 1}, {1, 1}}};
migraphx::shape s_out{migraphx::shape::float_type, {{1, 4}, {6, 6}, {4, 4}, {1, 1}}};
auto input = mm->add_parameter("X", s_in);
auto output_buffer = mm->add_parameter("Y", s_out);
mm->add_instruction(migraphx::make_op("reshape"), input, output_buffer);
p.compile(migraphx::make_target("ref"));
std::vector<float> gold(48);
std::iota(gold.begin(), gold.end(), -3.);
std::vector<float> buffer(48);
std::iota(buffer.begin(), buffer.end(), 0.);
migraphx::parameter_map params;
migraphx::shape input_fixed_shape{migraphx::shape::float_type, {2, 24, 1, 1}};
migraphx::shape output_fixed_shape{migraphx::shape::float_type, {2, 6, 4, 1}};
params["X"] = migraphx::argument(input_fixed_shape, gold.data());
params["Y"] = migraphx::argument(output_fixed_shape, buffer.data());
auto result = p.eval(params).back();
EXPECT(result.get_shape() == output_fixed_shape);
std::vector<float> results_vector{};
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
TEST_CASE(reshape_2in_test1)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s_in{migraphx::shape::float_type, {2, 24, 1, 1}};
migraphx::shape s_out{migraphx::shape::float_type, {{2, 4}, {6, 6}, {2, 4}, {1, 1}}};
auto input = mm->add_parameter("X", s_in);
auto output_buffer = mm->add_parameter("Y", s_out);
mm->add_instruction(migraphx::make_op("reshape"), input, output_buffer);
p.compile(migraphx::make_target("ref"));
std::vector<float> gold(48);
std::iota(gold.begin(), gold.end(), -3.);
std::vector<float> buffer(48);
std::iota(buffer.begin(), buffer.end(), 0.);
migraphx::parameter_map params;
migraphx::shape output_fixed_shape{migraphx::shape::float_type, {2, 6, 4, 1}};
params["X"] = migraphx::argument(s_in, gold.data());
params["Y"] = migraphx::argument(output_fixed_shape, buffer.data());
auto result = p.eval(params).back();
EXPECT(result.get_shape() == output_fixed_shape);
std::vector<float> results_vector{};
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
TEST_CASE(reshape_2in_elements_runtime_error)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s_in{migraphx::shape::float_type, {2, 24, 1, 1}};
migraphx::shape s_out{migraphx::shape::float_type, {{2, 4}, {6, 6}, {2, 4}, {1, 1}}};
auto input = mm->add_parameter("X", s_in);
auto output_buffer = mm->add_parameter("Y", s_out);
mm->add_instruction(migraphx::make_op("reshape"), input, output_buffer);
p.compile(migraphx::make_target("ref"));
std::vector<float> gold(48);
std::iota(gold.begin(), gold.end(), -3.);
std::vector<float> buffer(48);
std::iota(buffer.begin(), buffer.end(), 0.);
migraphx::parameter_map params;
// elements do not match up
migraphx::shape output_fixed_shape{migraphx::shape::float_type, {2, 6, 2, 1}};
params["X"] = migraphx::argument(s_in, gold.data());
params["Y"] = migraphx::argument(output_fixed_shape, buffer.data());
EXPECT(test::throws([&] { std::ignore = p.eval(params).back(); }));
}
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/instruction.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/program.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/verify.hpp>
#include <test.hpp>
TEST_CASE(reverse_test_axis0)
{
migraphx::shape in_shape{migraphx::shape::float_type, {2, 16}};
std::vector<float> data(32);
std::iota(data.begin(), data.end(), 1);
migraphx::program p;
auto* mm = p.get_main_module();
auto l = mm->add_literal(migraphx::literal{in_shape, data});
std::vector<int> axes = {0};
mm->add_instruction(migraphx::make_op("reverse", {{"axes", axes}}), l);
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = data;
std::swap_ranges(gold.begin(), gold.begin() + 16, gold.begin() + 16);
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
TEST_CASE(reverse_test_axis1)
{
migraphx::shape in_shape{migraphx::shape::float_type, {2, 16}};
std::vector<float> data(32);
std::iota(data.begin(), data.end(), 1);
migraphx::program p;
auto* mm = p.get_main_module();
auto l = mm->add_literal(migraphx::literal{in_shape, data});
std::vector<int> axes = {1};
mm->add_instruction(migraphx::make_op("reverse", {{"axes", axes}}), l);
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = data;
std::reverse(gold.begin(), gold.begin() + 16);
std::reverse(gold.end() - 16, gold.end());
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
TEST_CASE(reverse_test_axis10)
{
migraphx::shape in_shape{migraphx::shape::float_type, {2, 16}};
std::vector<float> data(32);
std::iota(data.begin(), data.end(), 1);
migraphx::program p;
auto* mm = p.get_main_module();
auto l = mm->add_literal(migraphx::literal{in_shape, data});
std::vector<int> axes = {1, 0};
mm->add_instruction(migraphx::make_op("reverse", {{"axes", axes}}), l);
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = data;
std::reverse(gold.begin(), gold.begin() + 16);
std::reverse(gold.end() - 16, gold.end());
std::swap_ranges(gold.begin(), gold.begin() + 16, gold.begin() + 16);
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
...@@ -21,18 +21,13 @@ ...@@ -21,18 +21,13 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE. * THE SOFTWARE.
*/ */
#include <iostream> #include <migraphx/instruction.hpp>
#include <vector>
#include <migraphx/literal.hpp> #include <migraphx/literal.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/op/common.hpp> #include <migraphx/op/common.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/program.hpp>
#include <migraphx/register_target.hpp> #include <migraphx/register_target.hpp>
#include <migraphx/verify.hpp> #include <migraphx/verify.hpp>
#include <migraphx/onnx.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/quantization.hpp>
#include <migraphx/serialize.hpp>
#include "test.hpp" #include "test.hpp"
...@@ -150,8 +145,8 @@ TEST_CASE(rnn_forward) ...@@ -150,8 +145,8 @@ TEST_CASE(rnn_forward)
-0.16477929, -0.16477929,
-0.11893477}; -0.11893477};
EXPECT(migraphx::verify::verify_range(hs_data, hs_data_gold)); EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold));
EXPECT(migraphx::verify::verify_range(lho_data, lho_data_gold)); EXPECT(migraphx::verify::verify_rms_range(lho_data, lho_data_gold));
} }
{ {
...@@ -211,8 +206,8 @@ TEST_CASE(rnn_forward) ...@@ -211,8 +206,8 @@ TEST_CASE(rnn_forward)
0.44193283, 0.44193283,
-0.16477929, -0.16477929,
-0.11893477}; -0.11893477};
EXPECT(migraphx::verify::verify_range(last_output_data, last_output_data_gold)); EXPECT(migraphx::verify::verify_rms_range(last_output_data, last_output_data_gold));
EXPECT(migraphx::verify::verify_range(hs_data, hs_data_gold)); EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold));
} }
{ {
...@@ -271,8 +266,8 @@ TEST_CASE(rnn_forward) ...@@ -271,8 +266,8 @@ TEST_CASE(rnn_forward)
0}; 0};
std::vector<float> last_output_data_gold{ std::vector<float> last_output_data_gold{
0.034457, 0.191679, -0.394683, -0.308897, -0.371446, 0.317082, 0.131042, -0.18736}; 0.034457, 0.191679, -0.394683, -0.308897, -0.371446, 0.317082, 0.131042, -0.18736};
EXPECT(migraphx::verify::verify_range(last_output_data, last_output_data_gold)); EXPECT(migraphx::verify::verify_rms_range(last_output_data, last_output_data_gold));
EXPECT(migraphx::verify::verify_range(hs_data, hs_data_gold)); EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold));
} }
// 3 args // 3 args
...@@ -302,7 +297,7 @@ TEST_CASE(rnn_forward) ...@@ -302,7 +297,7 @@ TEST_CASE(rnn_forward)
std::vector<float> last_output_data_gold{ std::vector<float> last_output_data_gold{
0.2935145, -0.23719997, -0.31123261, -0.18357255, 0., 0., 0., 0.}; 0.2935145, -0.23719997, -0.31123261, -0.18357255, 0., 0., 0., 0.};
EXPECT(migraphx::verify::verify_range(last_output_data, last_output_data_gold)); EXPECT(migraphx::verify::verify_rms_range(last_output_data, last_output_data_gold));
} }
// seq_len = 1 // seq_len = 1
...@@ -349,7 +344,7 @@ TEST_CASE(rnn_forward) ...@@ -349,7 +344,7 @@ TEST_CASE(rnn_forward)
0.31708236, 0.31708236,
0.13104209, 0.13104209,
-0.18736027}; -0.18736027};
EXPECT(migraphx::verify::verify_range(hs_data, hs_data_gold)); EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold));
} }
} }
...@@ -400,7 +395,6 @@ TEST_CASE(rnn_reverse) ...@@ -400,7 +395,6 @@ TEST_CASE(rnn_reverse)
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
// concatenation of hidden states as program output // concatenation of hidden states as program output
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto seq = mm->add_literal(migraphx::literal{in_shape, input}); auto seq = mm->add_literal(migraphx::literal{in_shape, input});
...@@ -444,7 +438,7 @@ TEST_CASE(rnn_reverse) ...@@ -444,7 +438,7 @@ TEST_CASE(rnn_reverse)
0.46251031, 0.46251031,
-0.20639211, -0.20639211,
0.37488942}; 0.37488942};
EXPECT(migraphx::verify::verify_range(hs_data, hs_data_gold)); EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold));
} }
// rnn last output as program output // rnn last output as program output
...@@ -487,7 +481,7 @@ TEST_CASE(rnn_reverse) ...@@ -487,7 +481,7 @@ TEST_CASE(rnn_reverse)
0.44124447, 0.44124447,
0.14365635, 0.14365635,
0.14803654}; 0.14803654};
EXPECT(migraphx::verify::verify_range(last_output_data, last_output_data_gold)); EXPECT(migraphx::verify::verify_rms_range(last_output_data, last_output_data_gold));
} }
// rnn hidden states and last hidden state output as program outputs // rnn hidden states and last hidden state output as program outputs
...@@ -550,8 +544,8 @@ TEST_CASE(rnn_reverse) ...@@ -550,8 +544,8 @@ TEST_CASE(rnn_reverse)
0.14365635, 0.14365635,
0.14803654}; 0.14803654};
EXPECT(migraphx::verify::verify_range(hs_data, hs_data_gold)); EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold));
EXPECT(migraphx::verify::verify_range(last_output_data, last_output_data_gold)); EXPECT(migraphx::verify::verify_rms_range(last_output_data, last_output_data_gold));
} }
// rnn hidden states and last hidden state output as program outputs // rnn hidden states and last hidden state output as program outputs
...@@ -612,8 +606,8 @@ TEST_CASE(rnn_reverse) ...@@ -612,8 +606,8 @@ TEST_CASE(rnn_reverse)
std::vector<float> last_output_data_gold{ std::vector<float> last_output_data_gold{
-0.293853, 0.167968, 0.51076, 0.402587, -0.0070999, 0.46251, -0.206392, 0.374889}; -0.293853, 0.167968, 0.51076, 0.402587, -0.0070999, 0.46251, -0.206392, 0.374889};
EXPECT(migraphx::verify::verify_range(hs_data, hs_data_gold)); EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold));
EXPECT(migraphx::verify::verify_range(last_output_data, last_output_data_gold)); EXPECT(migraphx::verify::verify_rms_range(last_output_data, last_output_data_gold));
} }
} }
...@@ -724,8 +718,8 @@ TEST_CASE(rnn_bidirectional) ...@@ -724,8 +718,8 @@ TEST_CASE(rnn_bidirectional)
0.14365635, 0.14365635,
0.14803654}; 0.14803654};
EXPECT(migraphx::verify::verify_range(hs_data, hs_data_gold)); EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold));
EXPECT(migraphx::verify::verify_range(last_output_data, last_output_data_gold)); EXPECT(migraphx::verify::verify_rms_range(last_output_data, last_output_data_gold));
} }
// last rnn output for program output // last rnn output for program output
...@@ -790,8 +784,8 @@ TEST_CASE(rnn_bidirectional) ...@@ -790,8 +784,8 @@ TEST_CASE(rnn_bidirectional)
0.143656, 0.143656,
0.148037}; 0.148037};
EXPECT(migraphx::verify::verify_range(hs_data, hs_data_gold)); EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold));
EXPECT(migraphx::verify::verify_range(last_output_data, last_output_data_gold)); EXPECT(migraphx::verify::verify_rms_range(last_output_data, last_output_data_gold));
} }
// 4 args // 4 args
...@@ -841,7 +835,7 @@ TEST_CASE(rnn_bidirectional) ...@@ -841,7 +835,7 @@ TEST_CASE(rnn_bidirectional)
0.14365635, 0.14365635,
0.14803654}; 0.14803654};
EXPECT(migraphx::verify::verify_range(last_output_data, last_output_data_gold)); EXPECT(migraphx::verify::verify_rms_range(last_output_data, last_output_data_gold));
} }
// 3 args // 3 args
...@@ -876,7 +870,7 @@ TEST_CASE(rnn_bidirectional) ...@@ -876,7 +870,7 @@ TEST_CASE(rnn_bidirectional)
0.2935145, -0.23719997, -0.31123261, -0.18357255, 0., 0., 0., 0., 0.2935145, -0.23719997, -0.31123261, -0.18357255, 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0.}; 0., 0., 0., 0., 0., 0., 0., 0.};
EXPECT(migraphx::verify::verify_range(last_output_data, last_output_data_gold)); EXPECT(migraphx::verify::verify_rms_range(last_output_data, last_output_data_gold));
} }
// concatenation of hidden state for program output // concatenation of hidden state for program output
...@@ -929,7 +923,7 @@ TEST_CASE(rnn_bidirectional) ...@@ -929,7 +923,7 @@ TEST_CASE(rnn_bidirectional)
-0.20639211, -0.20639211,
0.37488942}; 0.37488942};
EXPECT(migraphx::verify::verify_range(hs_data, hs_data_gold)); EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold));
} }
} }
...@@ -1014,7 +1008,10 @@ TEST_CASE(rnn_fp16) ...@@ -1014,7 +1008,10 @@ TEST_CASE(rnn_fp16)
std::vector<float> last_output_data_gold{ std::vector<float> last_output_data_gold{
0.2935145, -0.23719997, -0.31123261, -0.18357255, 0., 0., 0., 0.}; 0.2935145, -0.23719997, -0.31123261, -0.18357255, 0., 0., 0., 0.};
EXPECT(migraphx::verify::verify_range(last_output_data, last_output_data_gold, 5e4)); EXPECT(migraphx::verify::verify_range_with_tolerance(
last_output_data,
migraphx::verify::expected{last_output_data_gold},
migraphx::verify::tolerance{0.005}));
} }
TEST_CASE(gru_forward) TEST_CASE(gru_forward)
...@@ -1112,7 +1109,7 @@ TEST_CASE(gru_forward) ...@@ -1112,7 +1109,7 @@ TEST_CASE(gru_forward)
0.48523626, 0.60002893, -0.3969709, 0.43360898, 0.35775262, 0.23280787, 0.48523626, 0.60002893, -0.3969709, 0.43360898, 0.35775262, 0.23280787,
-0.52179873, -0.21944991, 0.4535257, -0.13735442, 0.51757574, 0.50380427}; -0.52179873, -0.21944991, 0.4535257, -0.13735442, 0.51757574, 0.50380427};
EXPECT(migraphx::verify::verify_range(hs_data, hs_data_gold)); EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold));
} }
// last output for output // last output for output
...@@ -1158,7 +1155,7 @@ TEST_CASE(gru_forward) ...@@ -1158,7 +1155,7 @@ TEST_CASE(gru_forward)
0.51757574, 0.51757574,
0.50380427}; 0.50380427};
EXPECT(migraphx::verify::verify_range(hs_data, hs_data_gold)); EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold));
} }
// two rnn_last_hs_output operators after gru // two rnn_last_hs_output operators after gru
...@@ -1205,7 +1202,7 @@ TEST_CASE(gru_forward) ...@@ -1205,7 +1202,7 @@ TEST_CASE(gru_forward)
0.51757574, 0.51757574,
0.50380427}; 0.50380427};
EXPECT(migraphx::verify::verify_range(hs_data, hs_data_gold)); EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold));
} }
// last output for output, linear_before_reset = 0 // last output for output, linear_before_reset = 0
...@@ -1251,7 +1248,7 @@ TEST_CASE(gru_forward) ...@@ -1251,7 +1248,7 @@ TEST_CASE(gru_forward)
0.6014447, 0.6014447,
0.43445644}; 0.43445644};
EXPECT(migraphx::verify::verify_range(hs_data, hs_data_gold)); EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold));
} }
} }
...@@ -1336,7 +1333,7 @@ TEST_CASE(gru_forward_args) ...@@ -1336,7 +1333,7 @@ TEST_CASE(gru_forward_args)
-0.232523, 0.00214573, 0.231693, -0.160475, -0.518952, -0.232523, 0.00214573, 0.231693, -0.160475, -0.518952,
0.0467166, 0.12327, -0.374162, 0.137778, 0.251976}; 0.0467166, 0.12327, -0.374162, 0.137778, 0.251976};
EXPECT(migraphx::verify::verify_range(hs_data, hs_data_gold)); EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold));
} }
// 4 args (bias is used) // 4 args (bias is used)
...@@ -1379,7 +1376,7 @@ TEST_CASE(gru_forward_args) ...@@ -1379,7 +1376,7 @@ TEST_CASE(gru_forward_args)
-0.416866, 0.377186, 0.32922, 0.162214, -0.519973, -0.416866, 0.377186, 0.32922, 0.162214, -0.519973,
-0.140072, 0.465076, -0.229563, 0.500164, 0.195166}; -0.140072, 0.465076, -0.229563, 0.500164, 0.195166};
EXPECT(migraphx::verify::verify_range(hs_data, hs_data_gold)); EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold));
} }
// 4 args (ih is used) // 4 args (ih is used)
...@@ -1423,7 +1420,7 @@ TEST_CASE(gru_forward_args) ...@@ -1423,7 +1420,7 @@ TEST_CASE(gru_forward_args)
-0.197, 0.0885705, 0.269396, -0.0414511, -0.515137, -0.197, 0.0885705, 0.269396, -0.0414511, -0.515137,
-0.03075, 0.158326, -0.296488, 0.177983, 0.519498}; -0.03075, 0.158326, -0.296488, 0.177983, 0.519498};
EXPECT(migraphx::verify::verify_range(hs_data, hs_data_gold)); EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold));
} }
} }
...@@ -1525,7 +1522,7 @@ TEST_CASE(gru_forward_actv_funcs) ...@@ -1525,7 +1522,7 @@ TEST_CASE(gru_forward_actv_funcs)
0.51757574, 0.51757574,
0.50380427}; 0.50380427};
EXPECT(migraphx::verify::verify_range(hs_data, hs_data_gold)); EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold));
} }
// 1 activation function (sigmoid) specified // 1 activation function (sigmoid) specified
...@@ -1566,7 +1563,7 @@ TEST_CASE(gru_forward_actv_funcs) ...@@ -1566,7 +1563,7 @@ TEST_CASE(gru_forward_actv_funcs)
0.35652235, 0.6033026, 0.52634895, 0.5815402, 0.3001663, 0.35652235, 0.6033026, 0.52634895, 0.5815402, 0.3001663,
0.39814138, 0.4354002, 0.4310627, 0.6708563, 0.7509278}; 0.39814138, 0.4354002, 0.4310627, 0.6708563, 0.7509278};
EXPECT(migraphx::verify::verify_range(hs_data, hs_data_gold)); EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold));
} }
// 1 activation function (tanh) specified // 1 activation function (tanh) specified
...@@ -1611,7 +1608,7 @@ TEST_CASE(gru_forward_actv_funcs) ...@@ -1611,7 +1608,7 @@ TEST_CASE(gru_forward_actv_funcs)
0.65615714, 0.65615714,
0.53612584}; 0.53612584};
EXPECT(migraphx::verify::verify_range(hs_data, hs_data_gold)); EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold));
} }
// seq length of 1 // seq length of 1
...@@ -1661,7 +1658,7 @@ TEST_CASE(gru_forward_actv_funcs) ...@@ -1661,7 +1658,7 @@ TEST_CASE(gru_forward_actv_funcs)
0.6104771, 0.6104771,
0.79759157}; 0.79759157};
EXPECT(migraphx::verify::verify_range(hs_data, hs_data_gold)); EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold));
} }
} }
...@@ -1736,12 +1733,12 @@ TEST_CASE(gru_reverse) ...@@ -1736,12 +1733,12 @@ TEST_CASE(gru_reverse)
migraphx::make_op( migraphx::make_op(
"gru", "gru",
{{"hidden_size", hidden_size}, {{"hidden_size", hidden_size},
{"actv_func", {"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"), migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"),
migraphx::make_op("tanh")})}, migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)}, {"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)},
{"clip", clip}, {"clip", clip},
{"linear_before_reset", 1}}), {"linear_before_reset", 1}}),
seq, seq,
w, w,
r, r,
...@@ -1777,8 +1774,8 @@ TEST_CASE(gru_reverse) ...@@ -1777,8 +1774,8 @@ TEST_CASE(gru_reverse)
0.55703, 0.55703,
0.54711}; 0.54711};
EXPECT(migraphx::verify::verify_range(hs_data, hs_data_gold)); EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold));
EXPECT(migraphx::verify::verify_range(lho_data, lho_data_gold)); EXPECT(migraphx::verify::verify_rms_range(lho_data, lho_data_gold));
} }
// variable input sequence length // variable input sequence length
...@@ -1838,8 +1835,8 @@ TEST_CASE(gru_reverse) ...@@ -1838,8 +1835,8 @@ TEST_CASE(gru_reverse)
0.558397, 0.558397,
0.664423}; 0.664423};
EXPECT(migraphx::verify::verify_range(hs_data, hs_data_gold)); EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold));
EXPECT(migraphx::verify::verify_range(lho_data, lho_data_gold)); EXPECT(migraphx::verify::verify_rms_range(lho_data, lho_data_gold));
} }
// last output for output, linear_before_reset = 0 // last output for output, linear_before_reset = 0
...@@ -1885,7 +1882,7 @@ TEST_CASE(gru_reverse) ...@@ -1885,7 +1882,7 @@ TEST_CASE(gru_reverse)
0.646604, 0.646604,
0.463943}; 0.463943};
EXPECT(migraphx::verify::verify_range(hs_data, hs_data_gold)); EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold));
} }
// no activation function specified, so default is used. // no activation function specified, so default is used.
...@@ -1924,7 +1921,7 @@ TEST_CASE(gru_reverse) ...@@ -1924,7 +1921,7 @@ TEST_CASE(gru_reverse)
-0.329512, 0.476095, 0.284044, 0.392077, -0.369226, -0.329512, 0.476095, 0.284044, 0.392077, -0.369226,
-0.3275, -0.027301, 0.143774, 0.655686, 0.782831}; -0.3275, -0.027301, 0.143774, 0.655686, 0.782831};
EXPECT(migraphx::verify::verify_range(hs_data, hs_data_gold)); EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold));
} }
// seq length of 1 // seq length of 1
...@@ -1974,7 +1971,7 @@ TEST_CASE(gru_reverse) ...@@ -1974,7 +1971,7 @@ TEST_CASE(gru_reverse)
0.610477, 0.610477,
0.797592}; 0.797592};
EXPECT(migraphx::verify::verify_range(hs_data, hs_data_gold)); EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold));
} }
} }
...@@ -2067,12 +2064,12 @@ TEST_CASE(gru_bidirectional) ...@@ -2067,12 +2064,12 @@ TEST_CASE(gru_bidirectional)
migraphx::make_op( migraphx::make_op(
"gru", "gru",
{{"hidden_size", hidden_size}, {{"hidden_size", hidden_size},
{"actv_func", {"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"), migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"),
migraphx::make_op("tanh")})}, migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)}, {"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)},
{"clip", clip}, {"clip", clip},
{"linear_before_reset", 1}}), {"linear_before_reset", 1}}),
seq, seq,
w, w,
r, r,
...@@ -2105,8 +2102,8 @@ TEST_CASE(gru_bidirectional) ...@@ -2105,8 +2102,8 @@ TEST_CASE(gru_bidirectional)
0.0248217, 0.435231, -0.144448, 0.101531, -0.111305, 0.0248217, 0.435231, -0.144448, 0.101531, -0.111305,
0.381317, 0.468983, 0.230557, 0.348021, 0.180229}; 0.381317, 0.468983, 0.230557, 0.348021, 0.180229};
EXPECT(migraphx::verify::verify_range(hs_data, hs_data_gold)); EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold));
EXPECT(migraphx::verify::verify_range(lho_data, lho_data_gold)); EXPECT(migraphx::verify::verify_rms_range(lho_data, lho_data_gold));
} }
// same input sequence length, but shorter than max squence length // same input sequence length, but shorter than max squence length
...@@ -2174,8 +2171,8 @@ TEST_CASE(gru_bidirectional) ...@@ -2174,8 +2171,8 @@ TEST_CASE(gru_bidirectional)
0.0248217, 0.435231, -0.144448, 0.101531, -0.111305, 0.0248217, 0.435231, -0.144448, 0.101531, -0.111305,
0.381317, 0.468983, 0.230557, 0.348021, 0.180229}; 0.381317, 0.468983, 0.230557, 0.348021, 0.180229};
EXPECT(migraphx::verify::verify_range(hs_data, hs_data_gold)); EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold));
EXPECT(migraphx::verify::verify_range(lho_data, lho_data_gold)); EXPECT(migraphx::verify::verify_rms_range(lho_data, lho_data_gold));
} }
// variable input sequence lengths // variable input sequence lengths
...@@ -2233,8 +2230,8 @@ TEST_CASE(gru_bidirectional) ...@@ -2233,8 +2230,8 @@ TEST_CASE(gru_bidirectional)
-0.0271321, 0.624762, -0.117084, 0.509115, -0.0175078, -0.0271321, 0.624762, -0.117084, 0.509115, -0.0175078,
0.182457, 0.304506, 0.313825, 0.397697, 0.300873}; 0.182457, 0.304506, 0.313825, 0.397697, 0.300873};
EXPECT(migraphx::verify::verify_range(hs_data, hs_data_gold)); EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold));
EXPECT(migraphx::verify::verify_range(lho_data, lho_data_gold)); EXPECT(migraphx::verify::verify_rms_range(lho_data, lho_data_gold));
} }
// last output for output, linear_before_reset = 0 // last output for output, linear_before_reset = 0
...@@ -2274,7 +2271,7 @@ TEST_CASE(gru_bidirectional) ...@@ -2274,7 +2271,7 @@ TEST_CASE(gru_bidirectional)
-0.10688055, -0.4767866, 0.6317833, 0.00286336, 0.53692746, -0.00617076, 0.04564289, -0.10688055, -0.4767866, 0.6317833, 0.00286336, 0.53692746, -0.00617076, 0.04564289,
-0.18030001, 0.39584228, 0.53879917, 0.384983, 0.2759448, 0.11611474}; -0.18030001, 0.39584228, 0.53879917, 0.384983, 0.2759448, 0.11611474};
EXPECT(migraphx::verify::verify_range(hs_data, hs_data_gold)); EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold));
} }
} }
...@@ -2376,7 +2373,7 @@ TEST_CASE(gru_bidirectional_args) ...@@ -2376,7 +2373,7 @@ TEST_CASE(gru_bidirectional_args)
0.469122, -0.306578, -0.221095, -0.106449, -0.248934, -0.00682121, 0.288407, 0.469122, -0.306578, -0.221095, -0.106449, -0.248934, -0.00682121, 0.288407,
0.198708, 0.0695644, 0.211621, 0.00246037}; 0.198708, 0.0695644, 0.211621, 0.00246037};
EXPECT(migraphx::verify::verify_range(hs_data, hs_data_gold)); EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold));
} }
// 4 args (bias is used) // 4 args (bias is used)
...@@ -2427,7 +2424,7 @@ TEST_CASE(gru_bidirectional_args) ...@@ -2427,7 +2424,7 @@ TEST_CASE(gru_bidirectional_args)
0.476508, -0.313413, -0.0361821, -0.173037, -0.235731, -0.163113, 0.349008, 0.476508, -0.313413, -0.0361821, -0.173037, -0.235731, -0.163113, 0.349008,
0.248674, -0.0295413, 0.291437, -0.165005}; 0.248674, -0.0295413, 0.291437, -0.165005};
EXPECT(migraphx::verify::verify_range(hs_data, hs_data_gold)); EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold));
} }
// 4 args (ih is used) // 4 args (ih is used)
...@@ -2475,7 +2472,7 @@ TEST_CASE(gru_bidirectional_args) ...@@ -2475,7 +2472,7 @@ TEST_CASE(gru_bidirectional_args)
0.233106, 0.32996, -0.17175, 0.0190231, -0.154805, -0.205631, -0.405354, 0.233106, 0.32996, -0.17175, 0.0190231, -0.154805, -0.205631, -0.405354,
0.519054, -0.380409, -0.0350301, -0.00633752, 0.403791, 0.181883, -0.0977917, 0.519054, -0.380409, -0.0350301, -0.00633752, 0.403791, 0.181883, -0.0977917,
-0.0339407, 0.413089, 0.721238, 0.431879}; -0.0339407, 0.413089, 0.721238, 0.431879};
EXPECT(migraphx::verify::verify_range(hs_data, hs_data_gold)); EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold));
} }
} }
...@@ -2589,7 +2586,7 @@ TEST_CASE(gru_bidirectional_actv_funcs) ...@@ -2589,7 +2586,7 @@ TEST_CASE(gru_bidirectional_actv_funcs)
0.0248217, 0.435231, -0.144448, 0.101531, -0.111305, 0.0248217, 0.435231, -0.144448, 0.101531, -0.111305,
0.381317, 0.468983, 0.230557, 0.348021, 0.180229}; 0.381317, 0.468983, 0.230557, 0.348021, 0.180229};
EXPECT(migraphx::verify::verify_range(hs_data, hs_data_gold)); EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold));
} }
// 1 activation function (sigmoid) specified // 1 activation function (sigmoid) specified
...@@ -2632,7 +2629,7 @@ TEST_CASE(gru_bidirectional_actv_funcs) ...@@ -2632,7 +2629,7 @@ TEST_CASE(gru_bidirectional_actv_funcs)
0.463795, 0.539649, 0.487682, 0.554471, 0.395916, 0.430744, 0.415923, 0.424275, 0.463795, 0.539649, 0.487682, 0.554471, 0.395916, 0.430744, 0.415923, 0.424275,
0.409655, 0.698256, 0.126883, 0.554374, 0.216137, 0.671491, 0.263833, 0.0678646, 0.409655, 0.698256, 0.126883, 0.554374, 0.216137, 0.671491, 0.263833, 0.0678646,
0.132732, 0.477083, 0.802206, 0.626802}; 0.132732, 0.477083, 0.802206, 0.626802};
EXPECT(migraphx::verify::verify_range(hs_data, hs_data_gold)); EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold));
} }
// 1 activation function (tanh) specified // 1 activation function (tanh) specified
...@@ -2676,7 +2673,7 @@ TEST_CASE(gru_bidirectional_actv_funcs) ...@@ -2676,7 +2673,7 @@ TEST_CASE(gru_bidirectional_actv_funcs)
0.66716, -0.704461, -0.393346, -0.627123, 0.210395, 0.0563026, 0.31419, 0.66716, -0.704461, -0.393346, -0.627123, 0.210395, 0.0563026, 0.31419,
0.759629, 0.000258222, 0.350835, -0.682684}; 0.759629, 0.000258222, 0.350835, -0.682684};
EXPECT(migraphx::verify::verify_range(hs_data, hs_data_gold)); EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold));
} }
// 3 activation functions specified // 3 activation functions specified
...@@ -2716,7 +2713,7 @@ TEST_CASE(gru_bidirectional_actv_funcs) ...@@ -2716,7 +2713,7 @@ TEST_CASE(gru_bidirectional_actv_funcs)
1.15142, 0.457633, 0.300962, 0.361245, 0.666199, 1.15142, 0.457633, 0.300962, 0.361245, 0.666199,
0.330446, 0.301982, -0.443763, -0.0655817, -0.326473, 0.330446, 0.301982, -0.443763, -0.0655817, -0.326473,
0.861394, 0.560799, -0.101768, 0.145142, 0.128956}; 0.861394, 0.560799, -0.101768, 0.145142, 0.128956};
EXPECT(migraphx::verify::verify_range(hs_data, hs_data_gold)); EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold));
} }
// 4 activation functions all specified // 4 activation functions all specified
...@@ -2764,7 +2761,7 @@ TEST_CASE(gru_bidirectional_actv_funcs) ...@@ -2764,7 +2761,7 @@ TEST_CASE(gru_bidirectional_actv_funcs)
0.648851, -0.395918, 0.231694, -0.160503, 0.383289, 0.0879262, -0.0254665, 0.648851, -0.395918, 0.231694, -0.160503, 0.383289, 0.0879262, -0.0254665,
0.079043, 0.322652, 0.752701, 0.243775}; 0.079043, 0.322652, 0.752701, 0.243775};
EXPECT(migraphx::verify::verify_range(hs_data, hs_data_gold)); EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold));
} }
} }
...@@ -2879,7 +2876,7 @@ TEST_CASE(gru_bidirectional_seq_1) ...@@ -2879,7 +2876,7 @@ TEST_CASE(gru_bidirectional_seq_1)
-0.0271321, 0.624762, -0.117084, 0.509115, -0.0175078, -0.0271321, 0.624762, -0.117084, 0.509115, -0.0175078,
-0.144492, -0.0115366, 0.409153, 0.487015, 0.550755}; -0.144492, -0.0115366, 0.409153, 0.487015, 0.550755};
EXPECT(migraphx::verify::verify_range(hs_data, hs_data_gold)); EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold));
} }
TEST_CASE(gru_fp16) TEST_CASE(gru_fp16)
...@@ -2989,7 +2986,8 @@ TEST_CASE(gru_fp16) ...@@ -2989,7 +2986,8 @@ TEST_CASE(gru_fp16)
-0.3969709, 0.43360898, 0.35775262, 0.23280787, -0.52179873, -0.3969709, 0.43360898, 0.35775262, 0.23280787, -0.52179873,
-0.21944991, 0.4535257, -0.13735442, 0.51757574, 0.50380427}; -0.21944991, 0.4535257, -0.13735442, 0.51757574, 0.50380427};
EXPECT(migraphx::verify::verify_range(hs_data, hs_data_gold, 5e4)); EXPECT(migraphx::verify::verify_range_with_tolerance(
hs_data, migraphx::verify::expected{hs_data_gold}, migraphx::verify::tolerance{0.005}));
} }
TEST_CASE(lstm_forward) TEST_CASE(lstm_forward)
...@@ -3120,7 +3118,7 @@ TEST_CASE(lstm_forward) ...@@ -3120,7 +3118,7 @@ TEST_CASE(lstm_forward)
0.0498799, 0.125772, 0.0533032, -0.131413, 0.0988431, -0.018085, -0.159434, 0.0498799, 0.125772, 0.0533032, -0.131413, 0.0988431, -0.018085, -0.159434,
0.030266, -0.0847427, 0.0874114, 0.304256, -0.0585745, -0.0223018, 0.131113, 0.030266, -0.0847427, 0.0874114, 0.304256, -0.0585745, -0.0223018, 0.131113,
0.135643, -0.0566208, 0.142701, 0.0342236, -0.198664, 0.0702607}; 0.135643, -0.0566208, 0.142701, 0.0342236, -0.198664, 0.0702607};
EXPECT(migraphx::verify::verify_range(hs_data, hs_data_gold)); EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold));
} }
// forward, last_output as program output // forward, last_output as program output
...@@ -3173,7 +3171,7 @@ TEST_CASE(lstm_forward) ...@@ -3173,7 +3171,7 @@ TEST_CASE(lstm_forward)
0.0342236, 0.0342236,
-0.198664, -0.198664,
0.0702607}; 0.0702607};
EXPECT(migraphx::verify::verify_range(output_data, output_data_gold)); EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold));
} }
// forward, last_cell_output as program output // forward, last_cell_output as program output
...@@ -3226,7 +3224,7 @@ TEST_CASE(lstm_forward) ...@@ -3226,7 +3224,7 @@ TEST_CASE(lstm_forward)
0.078598, 0.078598,
-0.64457, -0.64457,
0.119811}; 0.119811};
EXPECT(migraphx::verify::verify_range(output_data, output_data_gold)); EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold));
} }
} }
...@@ -3348,7 +3346,7 @@ TEST_CASE(lstm_forward_more) ...@@ -3348,7 +3346,7 @@ TEST_CASE(lstm_forward_more)
0.00496085, 0.0662588, -0.048577, -0.187329, 0.0855831, -0.0171894, -0.140202, 0.00496085, 0.0662588, -0.048577, -0.187329, 0.0855831, -0.0171894, -0.140202,
0.0828391, -0.165194, -0.0372928, 0.273786, -0.100877, -0.0458544, -0.0401315, 0.0828391, -0.165194, -0.0372928, 0.273786, -0.100877, -0.0458544, -0.0401315,
0.0737483, -0.064505, 0.136898, 0.00160891, -0.184812, 0.147774}; 0.0737483, -0.064505, 0.136898, 0.00160891, -0.184812, 0.147774};
EXPECT(migraphx::verify::verify_range(output_data, output_data_gold)); EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold));
} }
// forward, 8 args // forward, 8 args
...@@ -3397,7 +3395,7 @@ TEST_CASE(lstm_forward_more) ...@@ -3397,7 +3395,7 @@ TEST_CASE(lstm_forward_more)
0.218258, 0.0944405, 0.0431211, -0.132394, 0.103489, 0.0142918, -0.123408, 0.218258, 0.0944405, 0.0431211, -0.132394, 0.103489, 0.0142918, -0.123408,
0.0401075, -0.058052, 0.0795391, 0.266617, -0.0128746, 0.0309878, 0.0971544, 0.0401075, -0.058052, 0.0795391, 0.266617, -0.0128746, 0.0309878, 0.0971544,
0.149294, -0.0492549, 0.187761, 0.0501726, -0.121584, 0.0606723}; 0.149294, -0.0492549, 0.187761, 0.0501726, -0.121584, 0.0606723};
EXPECT(migraphx::verify::verify_range(hs_data, hs_data_gold)); EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold));
} }
// forward, last_output as program output, sequence length shorter // forward, last_output as program output, sequence length shorter
...@@ -3459,7 +3457,7 @@ TEST_CASE(lstm_forward_more) ...@@ -3459,7 +3457,7 @@ TEST_CASE(lstm_forward_more)
0.0342236, 0.0342236,
-0.198664, -0.198664,
0.0702607}; 0.0702607};
EXPECT(migraphx::verify::verify_range(output_data, output_data_gold)); EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold));
} }
// seq_len = 1 // seq_len = 1
...@@ -3517,7 +3515,7 @@ TEST_CASE(lstm_forward_more) ...@@ -3517,7 +3515,7 @@ TEST_CASE(lstm_forward_more)
-0.121195, -0.121195,
-0.4065, -0.4065,
-0.252054}; -0.252054};
EXPECT(migraphx::verify::verify_range(hs_data, hs_data_gold)); EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold));
} }
} }
...@@ -3647,7 +3645,7 @@ TEST_CASE(lstm_reverse) ...@@ -3647,7 +3645,7 @@ TEST_CASE(lstm_reverse)
0.960938, 0.133565, 0.269741, 0.130438, -0.0252804, 0.267356, 0.146353, 0.960938, 0.133565, 0.269741, 0.130438, -0.0252804, 0.267356, 0.146353,
0.0789186, -0.185038, -0.026845, 0.177273, -0.0774616, 0.946669, 0.0868676, 0.0789186, -0.185038, -0.026845, 0.177273, -0.0774616, 0.946669, 0.0868676,
0.044508, -0.373961, -0.0681467, 0.382748, 0.230211, -0.161537}; 0.044508, -0.373961, -0.0681467, 0.382748, 0.230211, -0.161537};
EXPECT(migraphx::verify::verify_range(output_data, output_data_gold)); EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold));
} }
// reverse, sequence lengths are the same, but less than max_seq_lens // reverse, sequence lengths are the same, but less than max_seq_lens
...@@ -3705,7 +3703,7 @@ TEST_CASE(lstm_reverse) ...@@ -3705,7 +3703,7 @@ TEST_CASE(lstm_reverse)
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0}; 0.0, 0.0};
EXPECT(migraphx::verify::verify_range(output_data, output_data_gold)); EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold));
} }
// variable sequence lengths // variable sequence lengths
...@@ -3755,7 +3753,7 @@ TEST_CASE(lstm_reverse) ...@@ -3755,7 +3753,7 @@ TEST_CASE(lstm_reverse)
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0}; 0, 0, 0, 0, 0, 0};
EXPECT(migraphx::verify::verify_range(output_data, output_data_gold)); EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold));
} }
// reverse, 3 args, last cell output as program output // reverse, 3 args, last cell output as program output
...@@ -3769,13 +3767,13 @@ TEST_CASE(lstm_reverse) ...@@ -3769,13 +3767,13 @@ TEST_CASE(lstm_reverse)
migraphx::make_op( migraphx::make_op(
"lstm", "lstm",
{{"hidden_size", hidden_size}, {{"hidden_size", hidden_size},
{"actv_func", {"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"), migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"),
migraphx::make_op("tanh"), migraphx::make_op("tanh"),
migraphx::make_op("tanh")})}, migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)}, {"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)},
{"clip", clip}, {"clip", clip},
{"input_forget", 0}}), {"input_forget", 0}}),
seq, seq,
w, w,
r); r);
...@@ -3797,7 +3795,7 @@ TEST_CASE(lstm_reverse) ...@@ -3797,7 +3795,7 @@ TEST_CASE(lstm_reverse)
0.141613, 0.141613,
0.348002, 0.348002,
0.667298}; 0.667298};
EXPECT(migraphx::verify::verify_range(output_data, output_data_gold)); EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold));
} }
// reverse, 3 args, 0 actv function // reverse, 3 args, 0 actv function
...@@ -3811,10 +3809,10 @@ TEST_CASE(lstm_reverse) ...@@ -3811,10 +3809,10 @@ TEST_CASE(lstm_reverse)
migraphx::make_op( migraphx::make_op(
"lstm", "lstm",
{{"hidden_size", hidden_size}, {{"hidden_size", hidden_size},
{"actv_func", {}}, {"actv_func", {}},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)}, {"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)},
{"clip", clip}, {"clip", clip},
{"input_forget", 0}}), {"input_forget", 0}}),
seq, seq,
w, w,
r); r);
...@@ -3836,7 +3834,7 @@ TEST_CASE(lstm_reverse) ...@@ -3836,7 +3834,7 @@ TEST_CASE(lstm_reverse)
0.141613, 0.141613,
0.348002, 0.348002,
0.667298}; 0.667298};
EXPECT(migraphx::verify::verify_range(output_data, output_data_gold)); EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold));
} }
} }
...@@ -3954,7 +3952,7 @@ TEST_CASE(lstm_reverse_actv) ...@@ -3954,7 +3952,7 @@ TEST_CASE(lstm_reverse_actv)
0.310306, 0.262902, 0.276964, 0.295002, 0.373802, 0.366785, 0.419791, 0.393216, 0.310306, 0.262902, 0.276964, 0.295002, 0.373802, 0.366785, 0.419791, 0.393216,
0.262827, 0.371441, 0.369022, 0.298262, 0.334143, 0.309444, 0.174822, 0.251634, 0.262827, 0.371441, 0.369022, 0.298262, 0.334143, 0.309444, 0.174822, 0.251634,
0.244564, 0.214386, 0.185994, 0.226699, 0.28445, 0.376092, 0.338326, 0.259502}; 0.244564, 0.214386, 0.185994, 0.226699, 0.28445, 0.376092, 0.338326, 0.259502};
EXPECT(migraphx::verify::verify_range(output_data, output_data_gold)); EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold));
} }
// reverse, 3 args, 2 actv functions // reverse, 3 args, 2 actv functions
...@@ -3995,7 +3993,7 @@ TEST_CASE(lstm_reverse_actv) ...@@ -3995,7 +3993,7 @@ TEST_CASE(lstm_reverse_actv)
0.233866, 0.233866,
0.48646, 0.48646,
0.481844}; 0.481844};
EXPECT(migraphx::verify::verify_range(output_data, output_data_gold)); EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold));
} }
// reverse, 3 args, seq_len = 1, concatenation of hidden states as program output // reverse, 3 args, seq_len = 1, concatenation of hidden states as program output
...@@ -4041,7 +4039,7 @@ TEST_CASE(lstm_reverse_actv) ...@@ -4041,7 +4039,7 @@ TEST_CASE(lstm_reverse_actv)
0.070535, 0.070535,
0.327809, 0.327809,
0.407388}; 0.407388};
EXPECT(migraphx::verify::verify_range(output_data, output_data_gold)); EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold));
} }
} }
...@@ -4168,7 +4166,7 @@ TEST_CASE(lstm_bidirectional) ...@@ -4168,7 +4166,7 @@ TEST_CASE(lstm_bidirectional)
0.0971544, 0.149294, -0.0492549, 0.187761, 0.0501726, -0.121584, 0.0606723, 0.0971544, 0.149294, -0.0492549, 0.187761, 0.0501726, -0.121584, 0.0606723,
-0.185038, -0.026845, 0.177273, -0.0774616, 0.946669, 0.0868676, 0.044508, -0.185038, -0.026845, 0.177273, -0.0774616, 0.946669, 0.0868676, 0.044508,
-0.373961, -0.0681467, 0.382748, 0.230211, -0.161537}; -0.373961, -0.0681467, 0.382748, 0.230211, -0.161537};
EXPECT(migraphx::verify::verify_range(output_data, output_data_gold)); EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold));
} }
// last hidden state as program output // last hidden state as program output
...@@ -4187,13 +4185,13 @@ TEST_CASE(lstm_bidirectional) ...@@ -4187,13 +4185,13 @@ TEST_CASE(lstm_bidirectional)
migraphx::make_op( migraphx::make_op(
"lstm", "lstm",
{{"hidden_size", hidden_size}, {{"hidden_size", hidden_size},
{"actv_func", {"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"), migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"),
migraphx::make_op("tanh"), migraphx::make_op("tanh"),
migraphx::make_op("tanh")})}, migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)}, {"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)},
{"clip", clip}, {"clip", clip},
{"input_forget", 0}}), {"input_forget", 0}}),
seq, seq,
w, w,
r, r,
...@@ -4211,7 +4209,7 @@ TEST_CASE(lstm_bidirectional) ...@@ -4211,7 +4209,7 @@ TEST_CASE(lstm_bidirectional)
-0.058052, 0.0795391, 0.266617, -0.0128746, 0.0309878, 0.0971544, 0.149294, -0.0492549, -0.058052, 0.0795391, 0.266617, -0.0128746, 0.0309878, 0.0971544, 0.149294, -0.0492549,
0.187761, 0.0501726, -0.121584, 0.0606723, -0.120174, 0.043157, 0.117138, -0.222188, 0.187761, 0.0501726, -0.121584, 0.0606723, -0.120174, 0.043157, 0.117138, -0.222188,
0.789732, 0.128538, 0.20909, 0.0553812, -0.224905, 0.32421, 0.344048, 0.271694}; 0.789732, 0.128538, 0.20909, 0.0553812, -0.224905, 0.32421, 0.344048, 0.271694};
EXPECT(migraphx::verify::verify_range(output_data, output_data_gold)); EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold));
} }
// last cell output as program output // last cell output as program output
...@@ -4230,13 +4228,13 @@ TEST_CASE(lstm_bidirectional) ...@@ -4230,13 +4228,13 @@ TEST_CASE(lstm_bidirectional)
migraphx::make_op( migraphx::make_op(
"lstm", "lstm",
{{"hidden_size", hidden_size}, {{"hidden_size", hidden_size},
{"actv_func", {"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"), migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"),
migraphx::make_op("tanh"), migraphx::make_op("tanh"),
migraphx::make_op("tanh")})}, migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)}, {"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)},
{"clip", clip}, {"clip", clip},
{"input_forget", 0}}), {"input_forget", 0}}),
seq, seq,
w, w,
r, r,
...@@ -4254,7 +4252,7 @@ TEST_CASE(lstm_bidirectional) ...@@ -4254,7 +4252,7 @@ TEST_CASE(lstm_bidirectional)
-0.077353, 0.245616, 0.361023, -0.0443759, 0.0685243, 0.20465, 0.277867, -0.112934, -0.077353, 0.245616, 0.361023, -0.0443759, 0.0685243, 0.20465, 0.277867, -0.112934,
0.67312, 0.120508, -0.726968, 0.113845, -0.889294, 0.182463, 0.186512, -0.402334, 0.67312, 0.120508, -0.726968, 0.113845, -0.889294, 0.182463, 0.186512, -0.402334,
1.48161, 0.524116, 0.347113, 0.181813, -0.434265, 0.747833, 0.416053, 0.558713}; 1.48161, 0.524116, 0.347113, 0.181813, -0.434265, 0.747833, 0.416053, 0.558713};
EXPECT(migraphx::verify::verify_range(output_data, output_data_gold)); EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold));
} }
// 3 args, concatenation of hidden states as program output // 3 args, concatenation of hidden states as program output
...@@ -4297,7 +4295,7 @@ TEST_CASE(lstm_bidirectional) ...@@ -4297,7 +4295,7 @@ TEST_CASE(lstm_bidirectional)
-0.0401315, 0.0737483, -0.064505, 0.136898, 0.00160891, -0.184812, 0.147774, -0.0401315, 0.0737483, -0.064505, 0.136898, 0.00160891, -0.184812, 0.147774,
-0.021205, -0.125423, 0.0206439, -0.187097, -0.0051453, -0.0767618, -0.0735348, -0.021205, -0.125423, 0.0206439, -0.187097, -0.0051453, -0.0767618, -0.0735348,
-0.0826436, 0.214159, 0.262295, 0.0247127, 0.14472}; -0.0826436, 0.214159, 0.262295, 0.0247127, 0.14472};
EXPECT(migraphx::verify::verify_range(output_data, output_data_gold)); EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold));
} }
// sequence length is 1, contenation of hidden state as program output // sequence length is 1, contenation of hidden state as program output
...@@ -4334,7 +4332,7 @@ TEST_CASE(lstm_bidirectional) ...@@ -4334,7 +4332,7 @@ TEST_CASE(lstm_bidirectional)
-0.0623361, 0.0598866, 0.101585, 0.0687269, -0.161725, -0.25617, -0.0623361, 0.0598866, 0.101585, 0.0687269, -0.161725, -0.25617,
-0.104351, -0.0471426, -0.0905753, 0.01506, 0.059797, 0.104239, -0.104351, -0.0471426, -0.0905753, 0.01506, 0.059797, 0.104239,
-0.0266768, 0.0727547, -0.146298, 0.070535, 0.327809, 0.407388}; -0.0266768, 0.0727547, -0.146298, 0.070535, 0.327809, 0.407388};
EXPECT(migraphx::verify::verify_range(output_data, output_data_gold)); EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold));
} }
} }
...@@ -4486,9 +4484,9 @@ TEST_CASE(lstm_bidirectional_var_seq_lens) ...@@ -4486,9 +4484,9 @@ TEST_CASE(lstm_bidirectional_var_seq_lens)
0.391174, 0.0308845, -0.561745, 0.0730323, -0.326822, 0.301121, 0.219523, 0.415242, 0.391174, 0.0308845, -0.561745, 0.0730323, -0.326822, 0.301121, 0.219523, 0.415242,
2.08242, 0.442513, 0.187127, 0.0577626, -0.611307, 0.55454, 0.4364, 0.509436}; 2.08242, 0.442513, 0.187127, 0.0577626, -0.611307, 0.55454, 0.4364, 0.509436};
EXPECT(migraphx::verify::verify_range(output_data, output_data_gold)); EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold));
EXPECT(migraphx::verify::verify_range(last_output_data, last_output_data_gold)); EXPECT(migraphx::verify::verify_rms_range(last_output_data, last_output_data_gold));
EXPECT(migraphx::verify::verify_range(last_cell_data, last_cell_data_gold)); EXPECT(migraphx::verify::verify_rms_range(last_cell_data, last_cell_data_gold));
} }
// last cell output as program output // last cell output as program output
...@@ -4573,9 +4571,9 @@ TEST_CASE(lstm_bidirectional_var_seq_lens) ...@@ -4573,9 +4571,9 @@ TEST_CASE(lstm_bidirectional_var_seq_lens)
-0.077353, 0.245616, 0.361023, -0.0443759, 0.0685243, 0.20465, 0.277867, -0.112934, -0.077353, 0.245616, 0.361023, -0.0443759, 0.0685243, 0.20465, 0.277867, -0.112934,
0.67312, 0.120508, -0.726968, 0.113845, -0.889294, 0.182463, 0.186512, -0.402334, 0.67312, 0.120508, -0.726968, 0.113845, -0.889294, 0.182463, 0.186512, -0.402334,
1.48161, 0.524116, 0.347113, 0.181813, -0.434265, 0.747833, 0.416053, 0.558713}; 1.48161, 0.524116, 0.347113, 0.181813, -0.434265, 0.747833, 0.416053, 0.558713};
EXPECT(migraphx::verify::verify_range(hs_data, hs_data_gold)); EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold));
EXPECT(migraphx::verify::verify_range(lho_data, lho_data_gold)); EXPECT(migraphx::verify::verify_rms_range(lho_data, lho_data_gold));
EXPECT(migraphx::verify::verify_range(lco_data, lco_data_gold)); EXPECT(migraphx::verify::verify_rms_range(lco_data, lco_data_gold));
} }
} }
...@@ -4660,7 +4658,7 @@ TEST_CASE(lstm_bidirectional_actv_func) ...@@ -4660,7 +4658,7 @@ TEST_CASE(lstm_bidirectional_actv_func)
-0.0401315, 0.0737483, -0.064505, 0.136898, 0.00160891, -0.184812, 0.147774, -0.0401315, 0.0737483, -0.064505, 0.136898, 0.00160891, -0.184812, 0.147774,
-0.021205, -0.125423, 0.0206439, -0.187097, -0.0051453, -0.0767618, -0.0735348, -0.021205, -0.125423, 0.0206439, -0.187097, -0.0051453, -0.0767618, -0.0735348,
-0.0826436, 0.214159, 0.262295, 0.0247127, 0.14472}; -0.0826436, 0.214159, 0.262295, 0.0247127, 0.14472};
EXPECT(migraphx::verify::verify_range(output_data, output_data_gold)); EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold));
} }
// 3 args, 1 actv func // 3 args, 1 actv func
...@@ -4700,7 +4698,7 @@ TEST_CASE(lstm_bidirectional_actv_func) ...@@ -4700,7 +4698,7 @@ TEST_CASE(lstm_bidirectional_actv_func)
0.450186, 0.263538, 0.402895, 0.216177, 0.267257, 0.342535, 0.257797, 0.268563, 0.450186, 0.263538, 0.402895, 0.216177, 0.267257, 0.342535, 0.257797, 0.268563,
0.193043, 0.275645, 0.167678, 0.350889, 0.334143, 0.309444, 0.174822, 0.251634, 0.193043, 0.275645, 0.167678, 0.350889, 0.334143, 0.309444, 0.174822, 0.251634,
0.244564, 0.214386, 0.185994, 0.226699, 0.28445, 0.376092, 0.338326, 0.259502}; 0.244564, 0.214386, 0.185994, 0.226699, 0.28445, 0.376092, 0.338326, 0.259502};
EXPECT(migraphx::verify::verify_range(output_data, output_data_gold)); EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold));
} }
// 3 args, 2 actv func // 3 args, 2 actv func
...@@ -4714,12 +4712,12 @@ TEST_CASE(lstm_bidirectional_actv_func) ...@@ -4714,12 +4712,12 @@ TEST_CASE(lstm_bidirectional_actv_func)
migraphx::make_op( migraphx::make_op(
"lstm", "lstm",
{{"hidden_size", hidden_size}, {{"hidden_size", hidden_size},
{"actv_func", {"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"), migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"),
migraphx::make_op("tanh")})}, migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)}, {"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)},
{"clip", clip}, {"clip", clip},
{"input_forget", 0}}), {"input_forget", 0}}),
seq, seq,
w, w,
r); r);
...@@ -4733,7 +4731,7 @@ TEST_CASE(lstm_bidirectional_actv_func) ...@@ -4733,7 +4731,7 @@ TEST_CASE(lstm_bidirectional_actv_func)
0.0737483, -0.064505, 0.136898, 0.00160891, -0.184812, 0.147774, 0.0737483, -0.064505, 0.136898, 0.00160891, -0.184812, 0.147774,
-0.162851, -0.102647, -0.113827, -0.142818, 0.0513685, 0.0547876, -0.162851, -0.102647, -0.113827, -0.142818, 0.0513685, 0.0547876,
0.0201981, -0.00808453, -0.00520328, 0.0945081, 0.264123, 0.410805}; 0.0201981, -0.00808453, -0.00520328, 0.0945081, 0.264123, 0.410805};
EXPECT(migraphx::verify::verify_range(output_data, output_data_gold)); EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold));
} }
// 3 args, 4 actv func // 3 args, 4 actv func
...@@ -4747,15 +4745,15 @@ TEST_CASE(lstm_bidirectional_actv_func) ...@@ -4747,15 +4745,15 @@ TEST_CASE(lstm_bidirectional_actv_func)
migraphx::make_op( migraphx::make_op(
"lstm", "lstm",
{{"hidden_size", hidden_size}, {{"hidden_size", hidden_size},
{"actv_func", {"actv_func",
migraphx::to_value( migraphx::to_value(
std::vector<migraphx::operation>{migraphx::make_op("sigmoid"), std::vector<migraphx::operation>{migraphx::make_op("sigmoid"),
migraphx::make_op("tanh"), migraphx::make_op("tanh"),
migraphx::make_op("tanh"), migraphx::make_op("tanh"),
migraphx::make_op("sigmoid")})}, migraphx::make_op("sigmoid")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)}, {"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)},
{"clip", clip}, {"clip", clip},
{"input_forget", 0}}), {"input_forget", 0}}),
seq, seq,
w, w,
r); r);
...@@ -4769,7 +4767,7 @@ TEST_CASE(lstm_bidirectional_actv_func) ...@@ -4769,7 +4767,7 @@ TEST_CASE(lstm_bidirectional_actv_func)
0.0737483, -0.064505, 0.136898, 0.00160891, -0.184812, 0.147774, 0.0737483, -0.064505, 0.136898, 0.00160891, -0.184812, 0.147774,
0.246078, 0.199709, 0.303753, 0.301178, 0.264634, 0.304661, 0.246078, 0.199709, 0.303753, 0.301178, 0.264634, 0.304661,
0.349371, 0.288934, 0.405483, 0.445586, 0.515814, 0.473186}; 0.349371, 0.288934, 0.405483, 0.445586, 0.515814, 0.473186};
EXPECT(migraphx::verify::verify_range(output_data, output_data_gold)); EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold));
} }
// 3 args, 5 actv func // 3 args, 5 actv func
...@@ -4783,15 +4781,15 @@ TEST_CASE(lstm_bidirectional_actv_func) ...@@ -4783,15 +4781,15 @@ TEST_CASE(lstm_bidirectional_actv_func)
migraphx::make_op( migraphx::make_op(
"lstm", "lstm",
{{"hidden_size", hidden_size}, {{"hidden_size", hidden_size},
{"actv_func", {"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"), migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"),
migraphx::make_op("tanh"), migraphx::make_op("tanh"),
migraphx::make_op("tanh"), migraphx::make_op("tanh"),
migraphx::make_op("sigmoid"), migraphx::make_op("sigmoid"),
migraphx::make_op("tanh")})}, migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)}, {"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)},
{"clip", clip}, {"clip", clip},
{"input_forget", 0}}), {"input_forget", 0}}),
seq, seq,
w, w,
r); r);
...@@ -4805,7 +4803,7 @@ TEST_CASE(lstm_bidirectional_actv_func) ...@@ -4805,7 +4803,7 @@ TEST_CASE(lstm_bidirectional_actv_func)
0.0737483, -0.064505, 0.136898, 0.00160891, -0.184812, 0.147774, 0.0737483, -0.064505, 0.136898, 0.00160891, -0.184812, 0.147774,
-0.162851, -0.102647, -0.113827, -0.142818, 0.0513685, 0.0547876, -0.162851, -0.102647, -0.113827, -0.142818, 0.0513685, 0.0547876,
0.0201981, -0.00808453, -0.00520328, 0.0945081, 0.264123, 0.410805}; 0.0201981, -0.00808453, -0.00520328, 0.0945081, 0.264123, 0.410805};
EXPECT(migraphx::verify::verify_range(output_data, output_data_gold)); EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold));
} }
// 3 args, 6 actv func // 3 args, 6 actv func
...@@ -4819,16 +4817,16 @@ TEST_CASE(lstm_bidirectional_actv_func) ...@@ -4819,16 +4817,16 @@ TEST_CASE(lstm_bidirectional_actv_func)
migraphx::make_op( migraphx::make_op(
"lstm", "lstm",
{{"hidden_size", hidden_size}, {{"hidden_size", hidden_size},
{"actv_func", {"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"), migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"),
migraphx::make_op("tanh"), migraphx::make_op("tanh"),
migraphx::make_op("tanh"), migraphx::make_op("tanh"),
migraphx::make_op("sigmoid"), migraphx::make_op("sigmoid"),
migraphx::make_op("tanh"), migraphx::make_op("tanh"),
migraphx::make_op("tanh")})}, migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)}, {"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)},
{"clip", clip}, {"clip", clip},
{"input_forget", 0}}), {"input_forget", 0}}),
seq, seq,
w, w,
r); r);
...@@ -4842,7 +4840,7 @@ TEST_CASE(lstm_bidirectional_actv_func) ...@@ -4842,7 +4840,7 @@ TEST_CASE(lstm_bidirectional_actv_func)
0.0737483, -0.064505, 0.136898, 0.00160891, -0.184812, 0.147774, 0.0737483, -0.064505, 0.136898, 0.00160891, -0.184812, 0.147774,
-0.162851, -0.102647, -0.113827, -0.142818, 0.0513685, 0.0547876, -0.162851, -0.102647, -0.113827, -0.142818, 0.0513685, 0.0547876,
0.0201981, -0.00808453, -0.00520328, 0.0945081, 0.264123, 0.410805}; 0.0201981, -0.00808453, -0.00520328, 0.0945081, 0.264123, 0.410805};
EXPECT(migraphx::verify::verify_range(output_data, output_data_gold)); EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold));
} }
} }
...@@ -4987,7 +4985,5 @@ TEST_CASE(lstm_fp16) ...@@ -4987,7 +4985,5 @@ TEST_CASE(lstm_fp16)
0.0498799, 0.125772, 0.0533032, -0.131413, 0.0988431, -0.018085, -0.159434, 0.0498799, 0.125772, 0.0533032, -0.131413, 0.0988431, -0.018085, -0.159434,
0.030266, -0.0847427, 0.0874114, 0.304256, -0.0585745, -0.0223018, 0.131113, 0.030266, -0.0847427, 0.0874114, 0.304256, -0.0585745, -0.0223018, 0.131113,
0.135643, -0.0566208, 0.142701, 0.0342236, -0.198664, 0.0702607}; 0.135643, -0.0566208, 0.142701, 0.0342236, -0.198664, 0.0702607};
EXPECT(migraphx::verify::verify_range(hs_data, hs_data_gold, 5e4)); EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold, 5e4));
} }
int main(int argc, const char* argv[]) { test::run(argc, argv); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment