Commit b9d37172 authored by Khalique Ahmed's avatar Khalique Ahmed
Browse files

manual merge

parents 1af66a1c ea62d7aa
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/instruction.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/program.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/verify.hpp>
#include <test.hpp>
TEST_CASE(reverse_test_axis0)
{
migraphx::shape in_shape{migraphx::shape::float_type, {2, 16}};
std::vector<float> data(32);
std::iota(data.begin(), data.end(), 1);
migraphx::program p;
auto* mm = p.get_main_module();
auto l = mm->add_literal(migraphx::literal{in_shape, data});
std::vector<int> axes = {0};
mm->add_instruction(migraphx::make_op("reverse", {{"axes", axes}}), l);
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = data;
std::swap_ranges(gold.begin(), gold.begin() + 16, gold.begin() + 16);
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
TEST_CASE(reverse_test_axis1)
{
migraphx::shape in_shape{migraphx::shape::float_type, {2, 16}};
std::vector<float> data(32);
std::iota(data.begin(), data.end(), 1);
migraphx::program p;
auto* mm = p.get_main_module();
auto l = mm->add_literal(migraphx::literal{in_shape, data});
std::vector<int> axes = {1};
mm->add_instruction(migraphx::make_op("reverse", {{"axes", axes}}), l);
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = data;
std::reverse(gold.begin(), gold.begin() + 16);
std::reverse(gold.end() - 16, gold.end());
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
TEST_CASE(reverse_test_axis10)
{
migraphx::shape in_shape{migraphx::shape::float_type, {2, 16}};
std::vector<float> data(32);
std::iota(data.begin(), data.end(), 1);
migraphx::program p;
auto* mm = p.get_main_module();
auto l = mm->add_literal(migraphx::literal{in_shape, data});
std::vector<int> axes = {1, 0};
mm->add_instruction(migraphx::make_op("reverse", {{"axes", axes}}), l);
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = data;
std::reverse(gold.begin(), gold.begin() + 16);
std::reverse(gold.end() - 16, gold.end());
std::swap_ranges(gold.begin(), gold.begin() + 16, gold.begin() + 16);
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
......@@ -21,18 +21,13 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <iostream>
#include <vector>
#include <migraphx/instruction.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/op/common.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/program.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/verify.hpp>
#include <migraphx/onnx.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/quantization.hpp>
#include <migraphx/serialize.hpp>
#include "test.hpp"
......@@ -150,8 +145,8 @@ TEST_CASE(rnn_forward)
-0.16477929,
-0.11893477};
EXPECT(migraphx::verify::verify_range(hs_data, hs_data_gold));
EXPECT(migraphx::verify::verify_range(lho_data, lho_data_gold));
EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold));
EXPECT(migraphx::verify::verify_rms_range(lho_data, lho_data_gold));
}
{
......@@ -211,8 +206,8 @@ TEST_CASE(rnn_forward)
0.44193283,
-0.16477929,
-0.11893477};
EXPECT(migraphx::verify::verify_range(last_output_data, last_output_data_gold));
EXPECT(migraphx::verify::verify_range(hs_data, hs_data_gold));
EXPECT(migraphx::verify::verify_rms_range(last_output_data, last_output_data_gold));
EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold));
}
{
......@@ -271,8 +266,8 @@ TEST_CASE(rnn_forward)
0};
std::vector<float> last_output_data_gold{
0.034457, 0.191679, -0.394683, -0.308897, -0.371446, 0.317082, 0.131042, -0.18736};
EXPECT(migraphx::verify::verify_range(last_output_data, last_output_data_gold));
EXPECT(migraphx::verify::verify_range(hs_data, hs_data_gold));
EXPECT(migraphx::verify::verify_rms_range(last_output_data, last_output_data_gold));
EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold));
}
// 3 args
......@@ -302,7 +297,7 @@ TEST_CASE(rnn_forward)
std::vector<float> last_output_data_gold{
0.2935145, -0.23719997, -0.31123261, -0.18357255, 0., 0., 0., 0.};
EXPECT(migraphx::verify::verify_range(last_output_data, last_output_data_gold));
EXPECT(migraphx::verify::verify_rms_range(last_output_data, last_output_data_gold));
}
// seq_len = 1
......@@ -349,7 +344,7 @@ TEST_CASE(rnn_forward)
0.31708236,
0.13104209,
-0.18736027};
EXPECT(migraphx::verify::verify_range(hs_data, hs_data_gold));
EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold));
}
}
......@@ -400,7 +395,6 @@ TEST_CASE(rnn_reverse)
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
// concatenation of hidden states as program output
{
migraphx::program p;
auto* mm = p.get_main_module();
auto seq = mm->add_literal(migraphx::literal{in_shape, input});
......@@ -444,7 +438,7 @@ TEST_CASE(rnn_reverse)
0.46251031,
-0.20639211,
0.37488942};
EXPECT(migraphx::verify::verify_range(hs_data, hs_data_gold));
EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold));
}
// rnn last output as program output
......@@ -487,7 +481,7 @@ TEST_CASE(rnn_reverse)
0.44124447,
0.14365635,
0.14803654};
EXPECT(migraphx::verify::verify_range(last_output_data, last_output_data_gold));
EXPECT(migraphx::verify::verify_rms_range(last_output_data, last_output_data_gold));
}
// rnn hidden states and last hidden state output as program outputs
......@@ -550,8 +544,8 @@ TEST_CASE(rnn_reverse)
0.14365635,
0.14803654};
EXPECT(migraphx::verify::verify_range(hs_data, hs_data_gold));
EXPECT(migraphx::verify::verify_range(last_output_data, last_output_data_gold));
EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold));
EXPECT(migraphx::verify::verify_rms_range(last_output_data, last_output_data_gold));
}
// rnn hidden states and last hidden state output as program outputs
......@@ -612,8 +606,8 @@ TEST_CASE(rnn_reverse)
std::vector<float> last_output_data_gold{
-0.293853, 0.167968, 0.51076, 0.402587, -0.0070999, 0.46251, -0.206392, 0.374889};
EXPECT(migraphx::verify::verify_range(hs_data, hs_data_gold));
EXPECT(migraphx::verify::verify_range(last_output_data, last_output_data_gold));
EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold));
EXPECT(migraphx::verify::verify_rms_range(last_output_data, last_output_data_gold));
}
}
......@@ -724,8 +718,8 @@ TEST_CASE(rnn_bidirectional)
0.14365635,
0.14803654};
EXPECT(migraphx::verify::verify_range(hs_data, hs_data_gold));
EXPECT(migraphx::verify::verify_range(last_output_data, last_output_data_gold));
EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold));
EXPECT(migraphx::verify::verify_rms_range(last_output_data, last_output_data_gold));
}
// last rnn output for program output
......@@ -790,8 +784,8 @@ TEST_CASE(rnn_bidirectional)
0.143656,
0.148037};
EXPECT(migraphx::verify::verify_range(hs_data, hs_data_gold));
EXPECT(migraphx::verify::verify_range(last_output_data, last_output_data_gold));
EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold));
EXPECT(migraphx::verify::verify_rms_range(last_output_data, last_output_data_gold));
}
// 4 args
......@@ -841,7 +835,7 @@ TEST_CASE(rnn_bidirectional)
0.14365635,
0.14803654};
EXPECT(migraphx::verify::verify_range(last_output_data, last_output_data_gold));
EXPECT(migraphx::verify::verify_rms_range(last_output_data, last_output_data_gold));
}
// 3 args
......@@ -876,7 +870,7 @@ TEST_CASE(rnn_bidirectional)
0.2935145, -0.23719997, -0.31123261, -0.18357255, 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0.};
EXPECT(migraphx::verify::verify_range(last_output_data, last_output_data_gold));
EXPECT(migraphx::verify::verify_rms_range(last_output_data, last_output_data_gold));
}
// concatenation of hidden state for program output
......@@ -929,7 +923,7 @@ TEST_CASE(rnn_bidirectional)
-0.20639211,
0.37488942};
EXPECT(migraphx::verify::verify_range(hs_data, hs_data_gold));
EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold));
}
}
......@@ -1014,7 +1008,10 @@ TEST_CASE(rnn_fp16)
std::vector<float> last_output_data_gold{
0.2935145, -0.23719997, -0.31123261, -0.18357255, 0., 0., 0., 0.};
EXPECT(migraphx::verify::verify_range(last_output_data, last_output_data_gold, 5e4));
EXPECT(migraphx::verify::verify_range_with_tolerance(
last_output_data,
migraphx::verify::expected{last_output_data_gold},
migraphx::verify::tolerance{0.005}));
}
TEST_CASE(gru_forward)
......@@ -1112,7 +1109,7 @@ TEST_CASE(gru_forward)
0.48523626, 0.60002893, -0.3969709, 0.43360898, 0.35775262, 0.23280787,
-0.52179873, -0.21944991, 0.4535257, -0.13735442, 0.51757574, 0.50380427};
EXPECT(migraphx::verify::verify_range(hs_data, hs_data_gold));
EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold));
}
// last output for output
......@@ -1158,7 +1155,7 @@ TEST_CASE(gru_forward)
0.51757574,
0.50380427};
EXPECT(migraphx::verify::verify_range(hs_data, hs_data_gold));
EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold));
}
// two rnn_last_hs_output operators after gru
......@@ -1205,7 +1202,7 @@ TEST_CASE(gru_forward)
0.51757574,
0.50380427};
EXPECT(migraphx::verify::verify_range(hs_data, hs_data_gold));
EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold));
}
// last output for output, linear_before_reset = 0
......@@ -1251,7 +1248,7 @@ TEST_CASE(gru_forward)
0.6014447,
0.43445644};
EXPECT(migraphx::verify::verify_range(hs_data, hs_data_gold));
EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold));
}
}
......@@ -1336,7 +1333,7 @@ TEST_CASE(gru_forward_args)
-0.232523, 0.00214573, 0.231693, -0.160475, -0.518952,
0.0467166, 0.12327, -0.374162, 0.137778, 0.251976};
EXPECT(migraphx::verify::verify_range(hs_data, hs_data_gold));
EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold));
}
// 4 args (bias is used)
......@@ -1379,7 +1376,7 @@ TEST_CASE(gru_forward_args)
-0.416866, 0.377186, 0.32922, 0.162214, -0.519973,
-0.140072, 0.465076, -0.229563, 0.500164, 0.195166};
EXPECT(migraphx::verify::verify_range(hs_data, hs_data_gold));
EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold));
}
// 4 args (ih is used)
......@@ -1423,7 +1420,7 @@ TEST_CASE(gru_forward_args)
-0.197, 0.0885705, 0.269396, -0.0414511, -0.515137,
-0.03075, 0.158326, -0.296488, 0.177983, 0.519498};
EXPECT(migraphx::verify::verify_range(hs_data, hs_data_gold));
EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold));
}
}
......@@ -1525,7 +1522,7 @@ TEST_CASE(gru_forward_actv_funcs)
0.51757574,
0.50380427};
EXPECT(migraphx::verify::verify_range(hs_data, hs_data_gold));
EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold));
}
// 1 activation function (sigmoid) specified
......@@ -1566,7 +1563,7 @@ TEST_CASE(gru_forward_actv_funcs)
0.35652235, 0.6033026, 0.52634895, 0.5815402, 0.3001663,
0.39814138, 0.4354002, 0.4310627, 0.6708563, 0.7509278};
EXPECT(migraphx::verify::verify_range(hs_data, hs_data_gold));
EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold));
}
// 1 activation function (tanh) specified
......@@ -1611,7 +1608,7 @@ TEST_CASE(gru_forward_actv_funcs)
0.65615714,
0.53612584};
EXPECT(migraphx::verify::verify_range(hs_data, hs_data_gold));
EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold));
}
// seq length of 1
......@@ -1661,7 +1658,7 @@ TEST_CASE(gru_forward_actv_funcs)
0.6104771,
0.79759157};
EXPECT(migraphx::verify::verify_range(hs_data, hs_data_gold));
EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold));
}
}
......@@ -1736,12 +1733,12 @@ TEST_CASE(gru_reverse)
migraphx::make_op(
"gru",
{{"hidden_size", hidden_size},
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"),
migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)},
{"clip", clip},
{"linear_before_reset", 1}}),
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"),
migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)},
{"clip", clip},
{"linear_before_reset", 1}}),
seq,
w,
r,
......@@ -1777,8 +1774,8 @@ TEST_CASE(gru_reverse)
0.55703,
0.54711};
EXPECT(migraphx::verify::verify_range(hs_data, hs_data_gold));
EXPECT(migraphx::verify::verify_range(lho_data, lho_data_gold));
EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold));
EXPECT(migraphx::verify::verify_rms_range(lho_data, lho_data_gold));
}
// variable input sequence length
......@@ -1838,8 +1835,8 @@ TEST_CASE(gru_reverse)
0.558397,
0.664423};
EXPECT(migraphx::verify::verify_range(hs_data, hs_data_gold));
EXPECT(migraphx::verify::verify_range(lho_data, lho_data_gold));
EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold));
EXPECT(migraphx::verify::verify_rms_range(lho_data, lho_data_gold));
}
// last output for output, linear_before_reset = 0
......@@ -1885,7 +1882,7 @@ TEST_CASE(gru_reverse)
0.646604,
0.463943};
EXPECT(migraphx::verify::verify_range(hs_data, hs_data_gold));
EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold));
}
// no activation function specified, so default is used.
......@@ -1924,7 +1921,7 @@ TEST_CASE(gru_reverse)
-0.329512, 0.476095, 0.284044, 0.392077, -0.369226,
-0.3275, -0.027301, 0.143774, 0.655686, 0.782831};
EXPECT(migraphx::verify::verify_range(hs_data, hs_data_gold));
EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold));
}
// seq length of 1
......@@ -1974,7 +1971,7 @@ TEST_CASE(gru_reverse)
0.610477,
0.797592};
EXPECT(migraphx::verify::verify_range(hs_data, hs_data_gold));
EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold));
}
}
......@@ -2067,12 +2064,12 @@ TEST_CASE(gru_bidirectional)
migraphx::make_op(
"gru",
{{"hidden_size", hidden_size},
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"),
migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)},
{"clip", clip},
{"linear_before_reset", 1}}),
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"),
migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)},
{"clip", clip},
{"linear_before_reset", 1}}),
seq,
w,
r,
......@@ -2105,8 +2102,8 @@ TEST_CASE(gru_bidirectional)
0.0248217, 0.435231, -0.144448, 0.101531, -0.111305,
0.381317, 0.468983, 0.230557, 0.348021, 0.180229};
EXPECT(migraphx::verify::verify_range(hs_data, hs_data_gold));
EXPECT(migraphx::verify::verify_range(lho_data, lho_data_gold));
EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold));
EXPECT(migraphx::verify::verify_rms_range(lho_data, lho_data_gold));
}
// same input sequence length, but shorter than max squence length
......@@ -2174,8 +2171,8 @@ TEST_CASE(gru_bidirectional)
0.0248217, 0.435231, -0.144448, 0.101531, -0.111305,
0.381317, 0.468983, 0.230557, 0.348021, 0.180229};
EXPECT(migraphx::verify::verify_range(hs_data, hs_data_gold));
EXPECT(migraphx::verify::verify_range(lho_data, lho_data_gold));
EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold));
EXPECT(migraphx::verify::verify_rms_range(lho_data, lho_data_gold));
}
// variable input sequence lengths
......@@ -2233,8 +2230,8 @@ TEST_CASE(gru_bidirectional)
-0.0271321, 0.624762, -0.117084, 0.509115, -0.0175078,
0.182457, 0.304506, 0.313825, 0.397697, 0.300873};
EXPECT(migraphx::verify::verify_range(hs_data, hs_data_gold));
EXPECT(migraphx::verify::verify_range(lho_data, lho_data_gold));
EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold));
EXPECT(migraphx::verify::verify_rms_range(lho_data, lho_data_gold));
}
// last output for output, linear_before_reset = 0
......@@ -2274,7 +2271,7 @@ TEST_CASE(gru_bidirectional)
-0.10688055, -0.4767866, 0.6317833, 0.00286336, 0.53692746, -0.00617076, 0.04564289,
-0.18030001, 0.39584228, 0.53879917, 0.384983, 0.2759448, 0.11611474};
EXPECT(migraphx::verify::verify_range(hs_data, hs_data_gold));
EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold));
}
}
......@@ -2376,7 +2373,7 @@ TEST_CASE(gru_bidirectional_args)
0.469122, -0.306578, -0.221095, -0.106449, -0.248934, -0.00682121, 0.288407,
0.198708, 0.0695644, 0.211621, 0.00246037};
EXPECT(migraphx::verify::verify_range(hs_data, hs_data_gold));
EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold));
}
// 4 args (bias is used)
......@@ -2427,7 +2424,7 @@ TEST_CASE(gru_bidirectional_args)
0.476508, -0.313413, -0.0361821, -0.173037, -0.235731, -0.163113, 0.349008,
0.248674, -0.0295413, 0.291437, -0.165005};
EXPECT(migraphx::verify::verify_range(hs_data, hs_data_gold));
EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold));
}
// 4 args (ih is used)
......@@ -2475,7 +2472,7 @@ TEST_CASE(gru_bidirectional_args)
0.233106, 0.32996, -0.17175, 0.0190231, -0.154805, -0.205631, -0.405354,
0.519054, -0.380409, -0.0350301, -0.00633752, 0.403791, 0.181883, -0.0977917,
-0.0339407, 0.413089, 0.721238, 0.431879};
EXPECT(migraphx::verify::verify_range(hs_data, hs_data_gold));
EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold));
}
}
......@@ -2589,7 +2586,7 @@ TEST_CASE(gru_bidirectional_actv_funcs)
0.0248217, 0.435231, -0.144448, 0.101531, -0.111305,
0.381317, 0.468983, 0.230557, 0.348021, 0.180229};
EXPECT(migraphx::verify::verify_range(hs_data, hs_data_gold));
EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold));
}
// 1 activation function (sigmoid) specified
......@@ -2632,7 +2629,7 @@ TEST_CASE(gru_bidirectional_actv_funcs)
0.463795, 0.539649, 0.487682, 0.554471, 0.395916, 0.430744, 0.415923, 0.424275,
0.409655, 0.698256, 0.126883, 0.554374, 0.216137, 0.671491, 0.263833, 0.0678646,
0.132732, 0.477083, 0.802206, 0.626802};
EXPECT(migraphx::verify::verify_range(hs_data, hs_data_gold));
EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold));
}
// 1 activation function (tanh) specified
......@@ -2676,7 +2673,7 @@ TEST_CASE(gru_bidirectional_actv_funcs)
0.66716, -0.704461, -0.393346, -0.627123, 0.210395, 0.0563026, 0.31419,
0.759629, 0.000258222, 0.350835, -0.682684};
EXPECT(migraphx::verify::verify_range(hs_data, hs_data_gold));
EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold));
}
// 3 activation functions specified
......@@ -2716,7 +2713,7 @@ TEST_CASE(gru_bidirectional_actv_funcs)
1.15142, 0.457633, 0.300962, 0.361245, 0.666199,
0.330446, 0.301982, -0.443763, -0.0655817, -0.326473,
0.861394, 0.560799, -0.101768, 0.145142, 0.128956};
EXPECT(migraphx::verify::verify_range(hs_data, hs_data_gold));
EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold));
}
// 4 activation functions all specified
......@@ -2764,7 +2761,7 @@ TEST_CASE(gru_bidirectional_actv_funcs)
0.648851, -0.395918, 0.231694, -0.160503, 0.383289, 0.0879262, -0.0254665,
0.079043, 0.322652, 0.752701, 0.243775};
EXPECT(migraphx::verify::verify_range(hs_data, hs_data_gold));
EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold));
}
}
......@@ -2879,7 +2876,7 @@ TEST_CASE(gru_bidirectional_seq_1)
-0.0271321, 0.624762, -0.117084, 0.509115, -0.0175078,
-0.144492, -0.0115366, 0.409153, 0.487015, 0.550755};
EXPECT(migraphx::verify::verify_range(hs_data, hs_data_gold));
EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold));
}
TEST_CASE(gru_fp16)
......@@ -2989,7 +2986,8 @@ TEST_CASE(gru_fp16)
-0.3969709, 0.43360898, 0.35775262, 0.23280787, -0.52179873,
-0.21944991, 0.4535257, -0.13735442, 0.51757574, 0.50380427};
EXPECT(migraphx::verify::verify_range(hs_data, hs_data_gold, 5e4));
EXPECT(migraphx::verify::verify_range_with_tolerance(
hs_data, migraphx::verify::expected{hs_data_gold}, migraphx::verify::tolerance{0.005}));
}
TEST_CASE(lstm_forward)
......@@ -3120,7 +3118,7 @@ TEST_CASE(lstm_forward)
0.0498799, 0.125772, 0.0533032, -0.131413, 0.0988431, -0.018085, -0.159434,
0.030266, -0.0847427, 0.0874114, 0.304256, -0.0585745, -0.0223018, 0.131113,
0.135643, -0.0566208, 0.142701, 0.0342236, -0.198664, 0.0702607};
EXPECT(migraphx::verify::verify_range(hs_data, hs_data_gold));
EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold));
}
// forward, last_output as program output
......@@ -3173,7 +3171,7 @@ TEST_CASE(lstm_forward)
0.0342236,
-0.198664,
0.0702607};
EXPECT(migraphx::verify::verify_range(output_data, output_data_gold));
EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold));
}
// forward, last_cell_output as program output
......@@ -3226,7 +3224,7 @@ TEST_CASE(lstm_forward)
0.078598,
-0.64457,
0.119811};
EXPECT(migraphx::verify::verify_range(output_data, output_data_gold));
EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold));
}
}
......@@ -3348,7 +3346,7 @@ TEST_CASE(lstm_forward_more)
0.00496085, 0.0662588, -0.048577, -0.187329, 0.0855831, -0.0171894, -0.140202,
0.0828391, -0.165194, -0.0372928, 0.273786, -0.100877, -0.0458544, -0.0401315,
0.0737483, -0.064505, 0.136898, 0.00160891, -0.184812, 0.147774};
EXPECT(migraphx::verify::verify_range(output_data, output_data_gold));
EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold));
}
// forward, 8 args
......@@ -3397,7 +3395,7 @@ TEST_CASE(lstm_forward_more)
0.218258, 0.0944405, 0.0431211, -0.132394, 0.103489, 0.0142918, -0.123408,
0.0401075, -0.058052, 0.0795391, 0.266617, -0.0128746, 0.0309878, 0.0971544,
0.149294, -0.0492549, 0.187761, 0.0501726, -0.121584, 0.0606723};
EXPECT(migraphx::verify::verify_range(hs_data, hs_data_gold));
EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold));
}
// forward, last_output as program output, sequence length shorter
......@@ -3459,7 +3457,7 @@ TEST_CASE(lstm_forward_more)
0.0342236,
-0.198664,
0.0702607};
EXPECT(migraphx::verify::verify_range(output_data, output_data_gold));
EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold));
}
// seq_len = 1
......@@ -3517,7 +3515,7 @@ TEST_CASE(lstm_forward_more)
-0.121195,
-0.4065,
-0.252054};
EXPECT(migraphx::verify::verify_range(hs_data, hs_data_gold));
EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold));
}
}
......@@ -3647,7 +3645,7 @@ TEST_CASE(lstm_reverse)
0.960938, 0.133565, 0.269741, 0.130438, -0.0252804, 0.267356, 0.146353,
0.0789186, -0.185038, -0.026845, 0.177273, -0.0774616, 0.946669, 0.0868676,
0.044508, -0.373961, -0.0681467, 0.382748, 0.230211, -0.161537};
EXPECT(migraphx::verify::verify_range(output_data, output_data_gold));
EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold));
}
// reverse, sequence lengths are the same, but less than max_seq_lens
......@@ -3705,7 +3703,7 @@ TEST_CASE(lstm_reverse)
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0};
EXPECT(migraphx::verify::verify_range(output_data, output_data_gold));
EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold));
}
// variable sequence lengths
......@@ -3755,7 +3753,7 @@ TEST_CASE(lstm_reverse)
0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0};
EXPECT(migraphx::verify::verify_range(output_data, output_data_gold));
EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold));
}
// reverse, 3 args, last cell output as program output
......@@ -3769,13 +3767,13 @@ TEST_CASE(lstm_reverse)
migraphx::make_op(
"lstm",
{{"hidden_size", hidden_size},
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"),
migraphx::make_op("tanh"),
migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)},
{"clip", clip},
{"input_forget", 0}}),
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"),
migraphx::make_op("tanh"),
migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)},
{"clip", clip},
{"input_forget", 0}}),
seq,
w,
r);
......@@ -3797,7 +3795,7 @@ TEST_CASE(lstm_reverse)
0.141613,
0.348002,
0.667298};
EXPECT(migraphx::verify::verify_range(output_data, output_data_gold));
EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold));
}
// reverse, 3 args, 0 actv function
......@@ -3811,10 +3809,10 @@ TEST_CASE(lstm_reverse)
migraphx::make_op(
"lstm",
{{"hidden_size", hidden_size},
{"actv_func", {}},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)},
{"clip", clip},
{"input_forget", 0}}),
{"actv_func", {}},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)},
{"clip", clip},
{"input_forget", 0}}),
seq,
w,
r);
......@@ -3836,7 +3834,7 @@ TEST_CASE(lstm_reverse)
0.141613,
0.348002,
0.667298};
EXPECT(migraphx::verify::verify_range(output_data, output_data_gold));
EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold));
}
}
......@@ -3954,7 +3952,7 @@ TEST_CASE(lstm_reverse_actv)
0.310306, 0.262902, 0.276964, 0.295002, 0.373802, 0.366785, 0.419791, 0.393216,
0.262827, 0.371441, 0.369022, 0.298262, 0.334143, 0.309444, 0.174822, 0.251634,
0.244564, 0.214386, 0.185994, 0.226699, 0.28445, 0.376092, 0.338326, 0.259502};
EXPECT(migraphx::verify::verify_range(output_data, output_data_gold));
EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold));
}
// reverse, 3 args, 2 actv functions
......@@ -3995,7 +3993,7 @@ TEST_CASE(lstm_reverse_actv)
0.233866,
0.48646,
0.481844};
EXPECT(migraphx::verify::verify_range(output_data, output_data_gold));
EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold));
}
// reverse, 3 args, seq_len = 1, concatenation of hidden states as program output
......@@ -4041,7 +4039,7 @@ TEST_CASE(lstm_reverse_actv)
0.070535,
0.327809,
0.407388};
EXPECT(migraphx::verify::verify_range(output_data, output_data_gold));
EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold));
}
}
......@@ -4168,7 +4166,7 @@ TEST_CASE(lstm_bidirectional)
0.0971544, 0.149294, -0.0492549, 0.187761, 0.0501726, -0.121584, 0.0606723,
-0.185038, -0.026845, 0.177273, -0.0774616, 0.946669, 0.0868676, 0.044508,
-0.373961, -0.0681467, 0.382748, 0.230211, -0.161537};
EXPECT(migraphx::verify::verify_range(output_data, output_data_gold));
EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold));
}
// last hidden state as program output
......@@ -4187,13 +4185,13 @@ TEST_CASE(lstm_bidirectional)
migraphx::make_op(
"lstm",
{{"hidden_size", hidden_size},
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"),
migraphx::make_op("tanh"),
migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)},
{"clip", clip},
{"input_forget", 0}}),
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"),
migraphx::make_op("tanh"),
migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)},
{"clip", clip},
{"input_forget", 0}}),
seq,
w,
r,
......@@ -4211,7 +4209,7 @@ TEST_CASE(lstm_bidirectional)
-0.058052, 0.0795391, 0.266617, -0.0128746, 0.0309878, 0.0971544, 0.149294, -0.0492549,
0.187761, 0.0501726, -0.121584, 0.0606723, -0.120174, 0.043157, 0.117138, -0.222188,
0.789732, 0.128538, 0.20909, 0.0553812, -0.224905, 0.32421, 0.344048, 0.271694};
EXPECT(migraphx::verify::verify_range(output_data, output_data_gold));
EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold));
}
// last cell output as program output
......@@ -4230,13 +4228,13 @@ TEST_CASE(lstm_bidirectional)
migraphx::make_op(
"lstm",
{{"hidden_size", hidden_size},
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"),
migraphx::make_op("tanh"),
migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)},
{"clip", clip},
{"input_forget", 0}}),
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"),
migraphx::make_op("tanh"),
migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)},
{"clip", clip},
{"input_forget", 0}}),
seq,
w,
r,
......@@ -4254,7 +4252,7 @@ TEST_CASE(lstm_bidirectional)
-0.077353, 0.245616, 0.361023, -0.0443759, 0.0685243, 0.20465, 0.277867, -0.112934,
0.67312, 0.120508, -0.726968, 0.113845, -0.889294, 0.182463, 0.186512, -0.402334,
1.48161, 0.524116, 0.347113, 0.181813, -0.434265, 0.747833, 0.416053, 0.558713};
EXPECT(migraphx::verify::verify_range(output_data, output_data_gold));
EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold));
}
// 3 args, concatenation of hidden states as program output
......@@ -4297,7 +4295,7 @@ TEST_CASE(lstm_bidirectional)
-0.0401315, 0.0737483, -0.064505, 0.136898, 0.00160891, -0.184812, 0.147774,
-0.021205, -0.125423, 0.0206439, -0.187097, -0.0051453, -0.0767618, -0.0735348,
-0.0826436, 0.214159, 0.262295, 0.0247127, 0.14472};
EXPECT(migraphx::verify::verify_range(output_data, output_data_gold));
EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold));
}
// sequence length is 1, contenation of hidden state as program output
......@@ -4334,7 +4332,7 @@ TEST_CASE(lstm_bidirectional)
-0.0623361, 0.0598866, 0.101585, 0.0687269, -0.161725, -0.25617,
-0.104351, -0.0471426, -0.0905753, 0.01506, 0.059797, 0.104239,
-0.0266768, 0.0727547, -0.146298, 0.070535, 0.327809, 0.407388};
EXPECT(migraphx::verify::verify_range(output_data, output_data_gold));
EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold));
}
}
......@@ -4486,9 +4484,9 @@ TEST_CASE(lstm_bidirectional_var_seq_lens)
0.391174, 0.0308845, -0.561745, 0.0730323, -0.326822, 0.301121, 0.219523, 0.415242,
2.08242, 0.442513, 0.187127, 0.0577626, -0.611307, 0.55454, 0.4364, 0.509436};
EXPECT(migraphx::verify::verify_range(output_data, output_data_gold));
EXPECT(migraphx::verify::verify_range(last_output_data, last_output_data_gold));
EXPECT(migraphx::verify::verify_range(last_cell_data, last_cell_data_gold));
EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold));
EXPECT(migraphx::verify::verify_rms_range(last_output_data, last_output_data_gold));
EXPECT(migraphx::verify::verify_rms_range(last_cell_data, last_cell_data_gold));
}
// last cell output as program output
......@@ -4573,9 +4571,9 @@ TEST_CASE(lstm_bidirectional_var_seq_lens)
-0.077353, 0.245616, 0.361023, -0.0443759, 0.0685243, 0.20465, 0.277867, -0.112934,
0.67312, 0.120508, -0.726968, 0.113845, -0.889294, 0.182463, 0.186512, -0.402334,
1.48161, 0.524116, 0.347113, 0.181813, -0.434265, 0.747833, 0.416053, 0.558713};
EXPECT(migraphx::verify::verify_range(hs_data, hs_data_gold));
EXPECT(migraphx::verify::verify_range(lho_data, lho_data_gold));
EXPECT(migraphx::verify::verify_range(lco_data, lco_data_gold));
EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold));
EXPECT(migraphx::verify::verify_rms_range(lho_data, lho_data_gold));
EXPECT(migraphx::verify::verify_rms_range(lco_data, lco_data_gold));
}
}
......@@ -4660,7 +4658,7 @@ TEST_CASE(lstm_bidirectional_actv_func)
-0.0401315, 0.0737483, -0.064505, 0.136898, 0.00160891, -0.184812, 0.147774,
-0.021205, -0.125423, 0.0206439, -0.187097, -0.0051453, -0.0767618, -0.0735348,
-0.0826436, 0.214159, 0.262295, 0.0247127, 0.14472};
EXPECT(migraphx::verify::verify_range(output_data, output_data_gold));
EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold));
}
// 3 args, 1 actv func
......@@ -4700,7 +4698,7 @@ TEST_CASE(lstm_bidirectional_actv_func)
0.450186, 0.263538, 0.402895, 0.216177, 0.267257, 0.342535, 0.257797, 0.268563,
0.193043, 0.275645, 0.167678, 0.350889, 0.334143, 0.309444, 0.174822, 0.251634,
0.244564, 0.214386, 0.185994, 0.226699, 0.28445, 0.376092, 0.338326, 0.259502};
EXPECT(migraphx::verify::verify_range(output_data, output_data_gold));
EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold));
}
// 3 args, 2 actv func
......@@ -4714,12 +4712,12 @@ TEST_CASE(lstm_bidirectional_actv_func)
migraphx::make_op(
"lstm",
{{"hidden_size", hidden_size},
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"),
migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)},
{"clip", clip},
{"input_forget", 0}}),
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"),
migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)},
{"clip", clip},
{"input_forget", 0}}),
seq,
w,
r);
......@@ -4733,7 +4731,7 @@ TEST_CASE(lstm_bidirectional_actv_func)
0.0737483, -0.064505, 0.136898, 0.00160891, -0.184812, 0.147774,
-0.162851, -0.102647, -0.113827, -0.142818, 0.0513685, 0.0547876,
0.0201981, -0.00808453, -0.00520328, 0.0945081, 0.264123, 0.410805};
EXPECT(migraphx::verify::verify_range(output_data, output_data_gold));
EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold));
}
// 3 args, 4 actv func
......@@ -4747,15 +4745,15 @@ TEST_CASE(lstm_bidirectional_actv_func)
migraphx::make_op(
"lstm",
{{"hidden_size", hidden_size},
{"actv_func",
migraphx::to_value(
{"actv_func",
migraphx::to_value(
std::vector<migraphx::operation>{migraphx::make_op("sigmoid"),
migraphx::make_op("tanh"),
migraphx::make_op("tanh"),
migraphx::make_op("sigmoid")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)},
{"clip", clip},
{"input_forget", 0}}),
migraphx::make_op("tanh"),
migraphx::make_op("tanh"),
migraphx::make_op("sigmoid")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)},
{"clip", clip},
{"input_forget", 0}}),
seq,
w,
r);
......@@ -4769,7 +4767,7 @@ TEST_CASE(lstm_bidirectional_actv_func)
0.0737483, -0.064505, 0.136898, 0.00160891, -0.184812, 0.147774,
0.246078, 0.199709, 0.303753, 0.301178, 0.264634, 0.304661,
0.349371, 0.288934, 0.405483, 0.445586, 0.515814, 0.473186};
EXPECT(migraphx::verify::verify_range(output_data, output_data_gold));
EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold));
}
// 3 args, 5 actv func
......@@ -4783,15 +4781,15 @@ TEST_CASE(lstm_bidirectional_actv_func)
migraphx::make_op(
"lstm",
{{"hidden_size", hidden_size},
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"),
migraphx::make_op("tanh"),
migraphx::make_op("tanh"),
migraphx::make_op("sigmoid"),
migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)},
{"clip", clip},
{"input_forget", 0}}),
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"),
migraphx::make_op("tanh"),
migraphx::make_op("tanh"),
migraphx::make_op("sigmoid"),
migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)},
{"clip", clip},
{"input_forget", 0}}),
seq,
w,
r);
......@@ -4805,7 +4803,7 @@ TEST_CASE(lstm_bidirectional_actv_func)
0.0737483, -0.064505, 0.136898, 0.00160891, -0.184812, 0.147774,
-0.162851, -0.102647, -0.113827, -0.142818, 0.0513685, 0.0547876,
0.0201981, -0.00808453, -0.00520328, 0.0945081, 0.264123, 0.410805};
EXPECT(migraphx::verify::verify_range(output_data, output_data_gold));
EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold));
}
// 3 args, 6 actv func
......@@ -4819,16 +4817,16 @@ TEST_CASE(lstm_bidirectional_actv_func)
migraphx::make_op(
"lstm",
{{"hidden_size", hidden_size},
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"),
migraphx::make_op("tanh"),
migraphx::make_op("tanh"),
migraphx::make_op("sigmoid"),
migraphx::make_op("tanh"),
migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)},
{"clip", clip},
{"input_forget", 0}}),
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"),
migraphx::make_op("tanh"),
migraphx::make_op("tanh"),
migraphx::make_op("sigmoid"),
migraphx::make_op("tanh"),
migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)},
{"clip", clip},
{"input_forget", 0}}),
seq,
w,
r);
......@@ -4842,7 +4840,7 @@ TEST_CASE(lstm_bidirectional_actv_func)
0.0737483, -0.064505, 0.136898, 0.00160891, -0.184812, 0.147774,
-0.162851, -0.102647, -0.113827, -0.142818, 0.0513685, 0.0547876,
0.0201981, -0.00808453, -0.00520328, 0.0945081, 0.264123, 0.410805};
EXPECT(migraphx::verify::verify_range(output_data, output_data_gold));
EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold));
}
}
......@@ -4987,7 +4985,5 @@ TEST_CASE(lstm_fp16)
0.0498799, 0.125772, 0.0533032, -0.131413, 0.0988431, -0.018085, -0.159434,
0.030266, -0.0847427, 0.0874114, 0.304256, -0.0585745, -0.0223018, 0.131113,
0.135643, -0.0566208, 0.142701, 0.0342236, -0.198664, 0.0702607};
EXPECT(migraphx::verify::verify_range(hs_data, hs_data_gold, 5e4));
EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold, 5e4));
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/instruction.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/program.hpp>
#include <migraphx/op/pooling.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/verify.hpp>
#include <test.hpp>
TEST_CASE(roialign_out_of_bound_test)
{
auto create_program = [](const std::string& trans_mode = "half_pixel") {
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape x_s{migraphx::shape::float_type, {1, 1, 10, 10}};
std::vector<float> x_vec = {
0.2764, 0.7150, 0.1958, 0.3416, 0.4638, 0.0259, 0.2963, 0.6518, 0.4856, 0.7250,
0.9637, 0.0895, 0.2919, 0.6753, 0.0234, 0.6132, 0.8085, 0.5324, 0.8992, 0.4467,
0.3265, 0.8479, 0.9698, 0.2471, 0.9336, 0.1878, 0.4766, 0.4308, 0.3400, 0.2162,
0.0206, 0.1720, 0.2155, 0.4394, 0.0653, 0.3406, 0.7724, 0.3921, 0.2541, 0.5799,
0.4062, 0.2194, 0.4473, 0.4687, 0.7109, 0.9327, 0.9815, 0.6320, 0.1728, 0.6119,
0.3097, 0.1283, 0.4984, 0.5068, 0.4279, 0.0173, 0.4388, 0.0430, 0.4671, 0.7119,
0.1011, 0.8477, 0.4726, 0.1777, 0.9923, 0.4042, 0.1869, 0.7795, 0.9946, 0.9689,
0.1366, 0.3671, 0.7011, 0.6234, 0.9867, 0.5585, 0.6985, 0.5609, 0.8788, 0.9928,
0.5697, 0.8511, 0.6711, 0.9406, 0.8751, 0.7496, 0.1650, 0.1049, 0.1559, 0.2514,
0.7012, 0.4056, 0.7879, 0.3461, 0.0415, 0.2998, 0.5094, 0.3727, 0.5482, 0.0502};
migraphx::shape roi_s{migraphx::shape::float_type, {3, 4}};
std::vector<float> roi_vec = {0, 0, 9.99, 9.99, 0, 5, 4, 9, 5, 5, 9.9, 9.9};
migraphx::shape ind_s{migraphx::shape::int64_type, {3}};
std::vector<int64_t> ind_vec = {0, 0, 0};
auto x = mm->add_literal(migraphx::literal(x_s, x_vec));
auto roi = mm->add_literal(migraphx::literal(roi_s, roi_vec));
auto ind = mm->add_literal(migraphx::literal(ind_s, ind_vec));
auto r =
mm->add_instruction(migraphx::make_op("roialign",
{{"coordinate_transformation_mode", trans_mode},
{"spatial_scale", 5.0},
{"output_height", 1},
{"output_width", 1},
{"sampling_ratio", 1}}),
x,
roi,
ind);
mm->add_return({r});
return p;
};
{
auto p = create_program("half_pixel");
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {0.0f, 0.0f, 0.0f};
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
}
TEST_CASE(roialign_test)
{
auto create_program = [](const std::string& trans_mode = "half_pixel",
const migraphx::op::pooling_mode pooling_mode =
migraphx::op::pooling_mode::average,
int64_t sampling_ratio = 2) {
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape x_s{migraphx::shape::float_type, {1, 1, 10, 10}};
std::vector<float> x_vec = {
0.2764, 0.7150, 0.1958, 0.3416, 0.4638, 0.0259, 0.2963, 0.6518, 0.4856, 0.7250,
0.9637, 0.0895, 0.2919, 0.6753, 0.0234, 0.6132, 0.8085, 0.5324, 0.8992, 0.4467,
0.3265, 0.8479, 0.9698, 0.2471, 0.9336, 0.1878, 0.4766, 0.4308, 0.3400, 0.2162,
0.0206, 0.1720, 0.2155, 0.4394, 0.0653, 0.3406, 0.7724, 0.3921, 0.2541, 0.5799,
0.4062, 0.2194, 0.4473, 0.4687, 0.7109, 0.9327, 0.9815, 0.6320, 0.1728, 0.6119,
0.3097, 0.1283, 0.4984, 0.5068, 0.4279, 0.0173, 0.4388, 0.0430, 0.4671, 0.7119,
0.1011, 0.8477, 0.4726, 0.1777, 0.9923, 0.4042, 0.1869, 0.7795, 0.9946, 0.9689,
0.1366, 0.3671, 0.7011, 0.6234, 0.9867, 0.5585, 0.6985, 0.5609, 0.8788, 0.9928,
0.5697, 0.8511, 0.6711, 0.9406, 0.8751, 0.7496, 0.1650, 0.1049, 0.1559, 0.2514,
0.7012, 0.4056, 0.7879, 0.3461, 0.0415, 0.2998, 0.5094, 0.3727, 0.5482, 0.0502};
migraphx::shape roi_s{migraphx::shape::float_type, {3, 4}};
std::vector<float> roi_vec = {0, 0, 9, 9, 0, 5, 4, 9, 5, 5, 9, 9};
migraphx::shape ind_s{migraphx::shape::int64_type, {3}};
std::vector<int64_t> ind_vec = {0, 0, 0};
auto x = mm->add_literal(migraphx::literal(x_s, x_vec));
auto roi = mm->add_literal(migraphx::literal(roi_s, roi_vec));
auto ind = mm->add_literal(migraphx::literal(ind_s, ind_vec));
auto r =
mm->add_instruction(migraphx::make_op("roialign",
{{"coordinate_transformation_mode", trans_mode},
{"spatial_scale", 1.0},
{"output_height", 5},
{"output_width", 5},
{"sampling_ratio", sampling_ratio},
{"mode", pooling_mode}}),
x,
roi,
ind);
mm->add_return({r});
return p;
};
{
auto p = create_program("output_half_pixel");
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {
0.466421425, 0.446552634, 0.340521216, 0.568848491, 0.606780827, 0.371379346,
0.429571986, 0.383519977, 0.556241512, 0.351050019, 0.27680251, 0.488286227,
0.522200167, 0.552770197, 0.417057365, 0.471240699, 0.4844096, 0.690457463,
0.492039412, 0.877398551, 0.623889625, 0.712461948, 0.628926516, 0.335504025,
0.349469036, 0.302179992, 0.43046391, 0.469585985, 0.39774403, 0.542259991,
0.365552008, 0.704923987, 0.516481996, 0.317131996, 0.701444089, 0.291239977,
0.505897999, 0.647610962, 0.623489916, 0.829879999, 0.591567993, 0.738860011,
0.704825997, 0.837148011, 0.889315963, 0.622680008, 0.615276039, 0.709713995,
0.615356028, 0.458524048, 0.238451958, 0.337952018, 0.371693879, 0.609999895,
0.760059953, 0.376724035, 0.378532052, 0.71468991, 0.924308002, 0.972783983,
0.574903965, 0.582623959, 0.570936024, 0.761904061, 0.876998067, 0.535508037,
0.256580025, 0.214098021, 0.279604018, 0.360000014, 0.436488032, 0.350427985,
0.288755983, 0.366139978, 0.234920025};
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
{
auto p = create_program("half_pixel");
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {
0.517783, 0.343411, 0.322905, 0.447362, 0.634375, 0.40308, 0.536647, 0.442791,
0.486144, 0.402313, 0.251194, 0.400154, 0.515524, 0.695369, 0.346537, 0.33504,
0.460099, 0.588069, 0.343863, 0.684932, 0.49319, 0.714058, 0.821744, 0.471935,
0.403946, 0.306955, 0.218678, 0.33369, 0.488001, 0.486962, 0.18709, 0.49142,
0.55611, 0.419167, 0.368608, 0.143278, 0.460835, 0.597125, 0.53096, 0.498207,
0.278818, 0.438569, 0.6022, 0.700038, 0.752436, 0.577385, 0.702383, 0.725097,
0.733754, 0.816304, 0.23933, 0.407514, 0.337893, 0.252521, 0.474335, 0.367075,
0.270168, 0.41051, 0.64189, 0.830777, 0.55564, 0.454295, 0.55645, 0.75015,
0.929997, 0.66257, 0.561664, 0.481275, 0.495449, 0.666306, 0.663573, 0.372107,
0.205603, 0.192776, 0.247849};
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
{
auto p = create_program("half_pixel", migraphx::op::pooling_mode::max, 0);
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {
0.819145, 0.373103, 0.258302, 0.515419, 0.726104, 0.540536, 0.545512, 0.38511,
0.376545, 0.274635, 0.22341, 0.184511, 0.230843, 0.404869, 0.29546, 0.540409,
0.265838, 0.409324, 0.213915, 0.708654, 0.687264, 0.580821, 0.461283, 0.462879,
0.709632, 0.27873, 0.083619, 0.22428, 0.313992, 0.410508, 0.0929099, 0.415373,
0.296695, 0.231574, 0.136836, 0.0683, 0.296695, 0.211925, 0.245385, 0.28053,
0.17091, 0.179879, 0.245385, 0.343539, 0.392742, 0.51273, 0.536193, 0.382995,
0.422793, 0.761886, 0.0839429, 0.276444, 0.19746, 0.126117, 0.378351, 0.254646,
0.092148, 0.272825, 0.381955, 0.626599, 0.251325, 0.244475, 0.194875, 0.272825,
0.44757, 0.351855, 0.342265, 0.244475, 0.274841, 0.553644, 0.607176, 0.202392,
0.07425, 0.066087, 0.126279};
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
}
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/instruction.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/program.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/verify.hpp>
#include <test.hpp>
TEST_CASE(round_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {9}};
auto l =
mm->add_literal(migraphx::literal{s, {1.1, 1.5, 1.6, -1.1, -1.5, -1.6, 0.0, 2.0, -2.0}});
mm->add_instruction(migraphx::make_op("round"), l);
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {1.0, 2.0, 2.0, -1.0, -2.0, -2.0, 0.0, 2.0, -2.0};
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
TEST_CASE(round_dyn_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape::dynamic_dimension dd{4, 10};
migraphx::shape s{migraphx::shape::float_type, {dd}};
auto input = mm->add_parameter("X", s);
mm->add_instruction(migraphx::make_op("round"), input);
p.compile(migraphx::make_target("ref"));
std::vector<float> input_data{1.1, 1.5, 1.6, -1.1, -1.5, -1.6, 0.0, 2.0, -2.0};
migraphx::parameter_map params0;
migraphx::shape input_fixed_shape0{migraphx::shape::float_type, {9}};
params0["X"] = migraphx::argument(input_fixed_shape0, input_data.data());
auto result = p.eval(params0).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {1.0, 2.0, 2.0, -1.0, -2.0, -2.0, 0.0, 2.0, -2.0};
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/instruction.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/program.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/verify.hpp>
#include <test.hpp>
TEST_CASE(rsqrt_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {3}};
auto l = mm->add_literal(migraphx::literal{s, {4.0, 16.0, 64.0}});
mm->add_instruction(migraphx::make_op("rsqrt"), l);
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<float> results_vector(3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {0.5, 0.25, 0.125};
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
TEST_CASE(rsqrt_dyn_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape::dynamic_dimension dd{3, 8};
migraphx::shape s{migraphx::shape::float_type, {dd}};
auto input = mm->add_parameter("X", s);
mm->add_instruction(migraphx::make_op("rsqrt"), input);
p.compile(migraphx::make_target("ref"));
std::vector<float> input_data{4.0, 16.0, 64.0};
migraphx::parameter_map params0;
migraphx::shape input_fixed_shape0{migraphx::shape::float_type, {3}};
params0["X"] = migraphx::argument(input_fixed_shape0, input_data.data());
auto result = p.eval(params0).back();
std::vector<float> results_vector(3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {0.5, 0.25, 0.125};
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/instruction.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/program.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/verify.hpp>
#include <test.hpp>
TEST_CASE(imagescaler_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {1, 3, 2, 2}};
auto img = mm->add_literal(migraphx::literal{s,
{0.2,
0.3,
0.5,
0.4,
0.7,
0.8,
0.1,
0.9,
0.15,
0.25,
0.35,
0.45}});
auto scale_val = mm->add_literal(2.f);
auto scaled_tensor = mm->add_instruction(
migraphx::make_op("scalar", {{"scalar_bcst_dims", s.lens()}}), scale_val);
auto img_scaled = mm->add_instruction(migraphx::make_op("mul"), img, scaled_tensor);
auto bias_vals = mm->add_literal(
migraphx::literal{migraphx::shape{migraphx::shape::float_type, {3}}, {0.01, 0.02, 0.03}});
auto bias_bcast = mm->add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", s.lens()}}), bias_vals);
mm->add_instruction(migraphx::make_op("add"), img_scaled, bias_bcast);
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<float> results_vector(12);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {0.41,
0.61,
1.01,
0.81,
1.42,
1.62,
0.22,
1.82,
0.33,
0.53,
0.73,
0.93};
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/instruction.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/program.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/verify.hpp>
#include <test.hpp>
// reduction_mode: "scatter_none", "scatter_add", "scatter_mul"
migraphx::program create_scatter_program(const std::string& reduction_mode, int axis)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape sd{migraphx::shape::float_type, {3, 3}};
std::vector<float> vd(sd.elements(), 0.0f);
migraphx::shape si{migraphx::shape::int32_type, {2, 3}};
std::vector<int> vi = {1, 0, 2, 0, 2, 1};
migraphx::shape su{migraphx::shape::float_type, {2, 3}};
std::vector<float> vu = {1.0, 1.1, 1.2, 2.0, 2.1, 2.2};
auto ld = mm->add_literal(migraphx::literal{sd, vd});
auto li = mm->add_literal(migraphx::literal{si, vi});
auto lu = mm->add_literal(migraphx::literal{su, vu});
// scatter_none, formerly the scatter op
auto r = mm->add_instruction(migraphx::make_op(reduction_mode, {{"axis", axis}}), ld, li, lu);
mm->add_return({r});
return p;
}
TEST_CASE(scatter_ax0_test)
{
// this tests what used to be the only scatter op, now changed to 3 sub-ops
// which have their own test case
{
migraphx::program p = create_scatter_program("scatter_none", 0);
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {2.0, 1.1, 0.0, 1.0, 0.0, 2.2, 0.0, 2.1, 1.2};
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
}
TEST_CASE(scatter_ax_neg_test)
{
{
migraphx::program p = create_scatter_program("scatter_none", -2);
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {2.0, 1.1, 0.0, 1.0, 0.0, 2.2, 0.0, 2.1, 1.2};
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
}
TEST_CASE(scatter_ax1_test)
{
{
migraphx::program p = create_scatter_program("scatter_none", 1);
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {1.1, 1.0, 1.2, 2.0, 2.2, 2.1, 0.0, 0.0, 0.0};
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
}
// similar to create_scatter_program but with different tensor values
// reduction_mode: "scatter_none", "scatter_add", "scatter_mul"
migraphx::program create_scatter_program2(const std::string& reduction_mode, int axis)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape sd{migraphx::shape::float_type, {1, 5}};
std::vector<float> vd({1., 2., 3., 4., 5.});
migraphx::shape si{migraphx::shape::int32_type, {1, 2}};
std::vector<int> vi = {1, 3};
migraphx::shape su{migraphx::shape::float_type, {1, 2}};
std::vector<float> vu = {1.1, 2.1};
auto ld = mm->add_literal(migraphx::literal{sd, vd});
auto li = mm->add_literal(migraphx::literal{si, vi});
auto lu = mm->add_literal(migraphx::literal{su, vu});
auto r = mm->add_instruction(migraphx::make_op(reduction_mode, {{"axis", axis}}), ld, li, lu);
mm->add_return({r});
return p;
}
TEST_CASE(scatter_reduction1_test)
{
{
// Test sub-ops for the three reduction values scatter_none, scatter_add, scatter_mul
migraphx::program p = create_scatter_program2("scatter_none", 1);
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold_none = {1.0, 1.1, 3.0, 2.1, 5.0};
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold_none));
}
}
TEST_CASE(scatter_reduction2_test)
{
{
migraphx::program p = create_scatter_program2("scatter_mul", 1);
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold_mul = {1.0, 2.2, 3.0, 8.4, 5.0};
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold_mul));
}
}
TEST_CASE(scatter_reduction3_test)
{
{
migraphx::program p = create_scatter_program2("scatter_add", 1);
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold_add = {1.0, 3.1, 3.0, 6.1, 5.0};
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold_add));
}
}
TEST_CASE(scatter_reduction_3x3_test)
{
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape sd{migraphx::shape::float_type, {3, 3}};
std::vector<float> vd(sd.elements(), 3.0f);
migraphx::shape si{migraphx::shape::int32_type, {2, 3}};
std::vector<int> vi = {1, 0, 2, 0, 2, 1};
migraphx::shape su{migraphx::shape::float_type, {2, 3}};
std::vector<float> vu = {1.0, 1.1, 1.2, 7.0, 7.1, 7.2};
auto ld = mm->add_literal(migraphx::literal{sd, vd});
auto li = mm->add_literal(migraphx::literal{si, vi});
auto lu = mm->add_literal(migraphx::literal{su, vu});
auto r = mm->add_instruction(migraphx::make_op("scatter_add", {{"axis", 1}}), ld, li, lu);
mm->add_return({r});
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold_a2 = {4.1, 4.0, 4.2, 10.0, 10.2, 10.1, 3.0, 3.0, 3.0};
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold_a2));
}
}
// create a test scatter program with a 3x3 tensor;
// su and si are transposed from previous case
migraphx::program create_scatter_program_3x3(const std::string& reduction_mode, int axis)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape sd{migraphx::shape::float_type, {3, 3}};
std::vector<float> vd(sd.elements(), 3.0f);
migraphx::shape si{migraphx::shape::int32_type, {3, 2}};
std::vector<int> vi = {1, 0, 0, 2, 2, 1};
migraphx::shape su{migraphx::shape::float_type, {3, 2}};
std::vector<float> vu = {1.0, 7.0, 1.1, 7.1, 1.2, 7.2};
auto ld = mm->add_literal(migraphx::literal{sd, vd});
auto li = mm->add_literal(migraphx::literal{si, vi});
auto lu = mm->add_literal(migraphx::literal{su, vu});
auto r = mm->add_instruction(migraphx::make_op(reduction_mode, {{"axis", axis}}), ld, li, lu);
mm->add_return({r});
return p;
}
TEST_CASE(scatter_reduction_3x3_xpose1_test)
{
// test on vertical (0) axis. su and si are transposed from previous case
{
migraphx::program p = create_scatter_program_3x3("scatter_none", 0);
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold_none2 = {1.1, 7.0, 3.0, 1.0, 7.2, 3.0, 1.2, 7.1, 3.0};
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold_none2));
}
}
TEST_CASE(scatter_reduction_3x3_xpose2_test)
{
// test on vertical (0) axis.
{
migraphx::program p = create_scatter_program_3x3("scatter_add", 0);
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold_a3 = {4.1, 10.0, 3.0, 4.0, 10.2, 3.0, 4.2, 10.1, 3.0};
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold_a3));
}
}
TEST_CASE(scatter_reduction_3x3_xpose3_test)
{
{
migraphx::program p = create_scatter_program_3x3("scatter_mul", 0);
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold_mul2 = {3.3, 21.0, 3.0, 3.0, 21.6, 3.0, 3.6, 21.3, 3.0};
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold_mul2));
}
}
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/instruction.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/program.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/verify.hpp>
#include <test.hpp>
TEST_CASE(scatternd_add_reduction_test)
{
// reduction = add
migraphx::program p;
auto* mm = p.get_main_module();
auto dtype = migraphx::shape::float_type;
auto itype = migraphx::shape::int64_type;
migraphx::shape ds{dtype, {8}};
migraphx::shape is{itype, {8, 1}};
migraphx::shape us{dtype, {8}};
std::vector<float> data_vec{1, 2, 3, 4, 5, 6, 7, 8};
std::vector<int64_t> ind_vec{4, 3, 1, 7, 4, 3, 1, 7};
std::vector<float> upd_vec{9, 10, 11, 12, -8, -9, -10, -11};
auto data = mm->add_literal(migraphx::literal{ds, data_vec});
auto indices = mm->add_literal(migraphx::literal{is, ind_vec});
auto updates = mm->add_literal(migraphx::literal{us, upd_vec});
auto scatternd =
mm->add_instruction(migraphx::make_op("scatternd_add"), data, indices, updates);
mm->add_return({scatternd});
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{1, 3, 3, 5, 6, 6, 7, 9};
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
TEST_CASE(scatternd_reduction_dyn_test)
{
// reduction = add, with dynamic input shapes
migraphx::program p;
auto* mm = p.get_main_module();
auto dtype = migraphx::shape::float_type;
auto itype = migraphx::shape::int64_type;
migraphx::shape::dynamic_dimension dd{3, 6};
migraphx::shape ds{migraphx::shape::float_type, {dd, dd, dd}};
migraphx::shape is{itype, {2, 1}};
migraphx::shape us{dtype, {{2, 2}, dd, dd}};
auto xdata = mm->add_parameter("X", ds);
auto xindex = mm->add_parameter("I", is);
auto xupdates = mm->add_parameter("U", us);
auto scatternd_add_op = migraphx::make_op("scatternd_add");
auto scatternd = mm->add_instruction(scatternd_add_op, xdata, xindex, xupdates);
mm->add_return({scatternd});
p.compile(migraphx::make_target("ref"));
migraphx::parameter_map params;
migraphx::shape input_fixed_shape0{migraphx::shape::float_type, {4, 4, 4}}; // data
std::vector<float> input_data{1, 2, 3, 4, 5, 6, 7, 8, 8, 7, 6, 5, 4, 3, 2, 1, 1, 2, 3, 4, 5, 6,
7, 8, 8, 7, 6, 5, 4, 3, 2, 1, 8, 7, 6, 5, 4, 3, 2, 1, 1, 2, 3, 4,
5, 6, 7, 8, 8, 7, 6, 5, 4, 3, 2, 1, 1, 2, 3, 4, 5, 6, 7, 8};
std::vector<uint64_t> input_index{0, 2};
migraphx::shape input_fixed_shape1{migraphx::shape::float_type, {2, 4, 4}}; // updates
std::vector<float> input_updates{5, 5, 5, 5, 6, 6, 6, 6, 7, 7, 7, 7, 8, 8, 8, 8,
1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4};
params["X"] = migraphx::argument(input_fixed_shape0, input_data.data());
params["I"] = migraphx::argument(is, input_index.data());
params["U"] = migraphx::argument(input_fixed_shape1, input_updates.data());
auto result = p.eval(params).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{6, 7, 8, 9, 11, 12, 13, 14, 15, 14, 13, 12, 12, 11, 10, 9,
1, 2, 3, 4, 5, 6, 7, 8, 8, 7, 6, 5, 4, 3, 2, 1,
9, 8, 7, 6, 6, 5, 4, 3, 4, 5, 6, 7, 9, 10, 11, 12,
8, 7, 6, 5, 4, 3, 2, 1, 1, 2, 3, 4, 5, 6, 7, 8};
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/instruction.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/program.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/verify.hpp>
#include <test.hpp>
TEST_CASE(scatternd_mul_reduction_test)
{
// reduction = mul
migraphx::program p;
auto* mm = p.get_main_module();
auto dtype = migraphx::shape::float_type;
auto itype = migraphx::shape::int64_type;
migraphx::shape ds{dtype, {8}};
migraphx::shape is{itype, {4, 1}};
migraphx::shape us{dtype, {4}};
std::vector<float> data_vec{1, 2, 3, 4, 5, 6, 7, 8};
std::vector<int64_t> ind_vec{4, 3, 1, 7};
std::vector<float> upd_vec{9, 10, 11, 12};
auto data = mm->add_literal(migraphx::literal{ds, data_vec});
auto indices = mm->add_literal(migraphx::literal{is, ind_vec});
auto updates = mm->add_literal(migraphx::literal{us, upd_vec});
auto scatternd =
mm->add_instruction(migraphx::make_op("scatternd_mul"), data, indices, updates);
mm->add_return({scatternd});
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{1, 22, 3, 40, 45, 6, 7, 96};
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/instruction.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/program.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/verify.hpp>
#include <test.hpp>
TEST_CASE(scatternd_shapes_test_1)
{
// broadcasted input
migraphx::program p;
auto* mm = p.get_main_module();
auto dtype = migraphx::shape::float_type;
auto itype = migraphx::shape::int64_type;
migraphx::shape is{itype, {4, 1}};
migraphx::shape us{dtype, {4}};
std::vector<int64_t> ind_vec{4, 3, 1, 7};
std::vector<float> upd_vec{9, 10, 11, 12};
auto data = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {8}}}),
mm->add_literal(migraphx::literal{0.0f}));
auto indices = mm->add_literal(migraphx::literal{is, ind_vec});
auto updates = mm->add_literal(migraphx::literal{us, upd_vec});
auto scatternd =
mm->add_instruction(migraphx::make_op("scatternd_none"), data, indices, updates);
mm->add_return({scatternd});
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{0, 11, 0, 10, 9, 0, 0, 12};
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
TEST_CASE(scatternd_shapes_test_2)
{
// non-standard shape input
migraphx::program p;
auto* mm = p.get_main_module();
auto dtype = migraphx::shape::float_type;
auto itype = migraphx::shape::int64_type;
migraphx::shape ds{dtype, {2, 2}};
migraphx::shape is{itype, {2, 2}};
migraphx::shape us{dtype, {2}};
std::vector<float> data_vec{1, 2, 3, 4};
std::vector<int64_t> ind_vec{0, 0, 0, 1};
std::vector<float> upd_vec{5, 6};
auto data = mm->add_literal(migraphx::literal{ds, data_vec});
auto td = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), data);
auto indices = mm->add_literal(migraphx::literal{is, ind_vec});
auto updates = mm->add_literal(migraphx::literal{us, upd_vec});
auto scatternd = mm->add_instruction(migraphx::make_op("scatternd_none"), td, indices, updates);
mm->add_return({scatternd});
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{5, 6, 2, 4};
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
TEST_CASE(scatternd_shapes_test_3)
{
// non-standard updates shape
migraphx::program p;
auto* mm = p.get_main_module();
auto dtype = migraphx::shape::float_type;
auto itype = migraphx::shape::int64_type;
migraphx::shape ds{dtype, {2, 2, 2}};
migraphx::shape is{itype, {2, 1, 3}};
migraphx::shape us{dtype, {1, 2}};
std::vector<float> data_vec{1, 2, 3, 4, 5, 6, 7, 8};
std::vector<int64_t> ind_vec{0, 0, 0, 1, 1, 1};
std::vector<float> upd_vec{9, 10};
auto data = mm->add_literal(migraphx::literal{ds, data_vec});
auto indices = mm->add_literal(migraphx::literal{is, ind_vec});
auto updates = mm->add_literal(migraphx::literal{us, upd_vec});
auto tu =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), updates);
auto scatternd = mm->add_instruction(migraphx::make_op("scatternd_none"), data, indices, tu);
mm->add_return({scatternd});
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{9, 2, 3, 4, 5, 6, 7, 10};
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
TEST_CASE(scatternd_test_1)
{
// r=1, q=2, k=1
migraphx::program p;
auto* mm = p.get_main_module();
auto dtype = migraphx::shape::float_type;
auto itype = migraphx::shape::int64_type;
migraphx::shape ds{dtype, {8}};
migraphx::shape is{itype, {4, 1}};
migraphx::shape us{dtype, {4}};
std::vector<float> data_vec{1, 2, 3, 4, 5, 6, 7, 8};
std::vector<int64_t> ind_vec{4, 3, 1, 7};
std::vector<float> upd_vec{9, 10, 11, 12};
auto data = mm->add_literal(migraphx::literal{ds, data_vec});
auto indices = mm->add_literal(migraphx::literal{is, ind_vec});
auto updates = mm->add_literal(migraphx::literal{us, upd_vec});
auto scatternd =
mm->add_instruction(migraphx::make_op("scatternd_none"), data, indices, updates);
mm->add_return({scatternd});
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{1, 11, 3, 10, 9, 6, 7, 12};
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
TEST_CASE(scatternd_test_2)
{
// r=2, q=2, k=2
migraphx::program p;
auto* mm = p.get_main_module();
auto dtype = migraphx::shape::float_type;
auto itype = migraphx::shape::int64_type;
migraphx::shape ds{dtype, {2, 2}};
migraphx::shape is{itype, {2, 2}};
migraphx::shape us{dtype, {2}};
std::vector<float> data_vec{1, 2, 3, 4};
std::vector<int64_t> ind_vec{0, 0, 0, 1};
std::vector<float> upd_vec{5, 6};
auto data = mm->add_literal(migraphx::literal{ds, data_vec});
auto indices = mm->add_literal(migraphx::literal{is, ind_vec});
auto updates = mm->add_literal(migraphx::literal{us, upd_vec});
auto scatternd =
mm->add_instruction(migraphx::make_op("scatternd_none"), data, indices, updates);
mm->add_return({scatternd});
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{5, 6, 3, 4};
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
TEST_CASE(scatternd_test_3)
{
// r=3, q=3, k=3
migraphx::program p;
auto* mm = p.get_main_module();
auto dtype = migraphx::shape::float_type;
auto itype = migraphx::shape::int64_type;
migraphx::shape ds{dtype, {2, 2, 2}};
migraphx::shape is{itype, {2, 1, 3}};
migraphx::shape us{dtype, {2, 1}};
std::vector<float> data_vec{1, 2, 3, 4, 5, 6, 7, 8};
std::vector<int64_t> ind_vec{0, 0, 0, 1, 1, 1};
std::vector<float> upd_vec{9, 10};
auto data = mm->add_literal(migraphx::literal{ds, data_vec});
auto indices = mm->add_literal(migraphx::literal{is, ind_vec});
auto updates = mm->add_literal(migraphx::literal{us, upd_vec});
auto scatternd =
mm->add_instruction(migraphx::make_op("scatternd_none"), data, indices, updates);
mm->add_return({scatternd});
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{9, 2, 3, 4, 5, 6, 7, 10};
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
TEST_CASE(scatternd_test_4)
{
// r=3, q=2, k=1
migraphx::program p;
auto* mm = p.get_main_module();
auto dtype = migraphx::shape::float_type;
auto itype = migraphx::shape::int64_type;
migraphx::shape ds{dtype, {4, 4, 4}};
migraphx::shape is{itype, {2, 1}};
migraphx::shape us{dtype, {2, 4, 4}};
std::vector<float> data_vec{1, 2, 3, 4, 5, 6, 7, 8, 8, 7, 6, 5, 4, 3, 2, 1, 1, 2, 3, 4, 5, 6,
7, 8, 8, 7, 6, 5, 4, 3, 2, 1, 8, 7, 6, 5, 4, 3, 2, 1, 1, 2, 3, 4,
5, 6, 7, 8, 8, 7, 6, 5, 4, 3, 2, 1, 1, 2, 3, 4, 5, 6, 7, 8};
std::vector<int64_t> ind_vec{0, 2};
std::vector<float> upd_vec{5, 5, 5, 5, 6, 6, 6, 6, 7, 7, 7, 7, 8, 8, 8, 8,
1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4};
auto data = mm->add_literal(migraphx::literal{ds, data_vec});
auto indices = mm->add_literal(migraphx::literal{is, ind_vec});
auto updates = mm->add_literal(migraphx::literal{us, upd_vec});
auto scatternd =
mm->add_instruction(migraphx::make_op("scatternd_none"), data, indices, updates);
mm->add_return({scatternd});
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{5, 5, 5, 5, 6, 6, 6, 6, 7, 7, 7, 7, 8, 8, 8, 8, 1, 2, 3, 4, 5, 6,
7, 8, 8, 7, 6, 5, 4, 3, 2, 1, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3,
4, 4, 4, 4, 8, 7, 6, 5, 4, 3, 2, 1, 1, 2, 3, 4, 5, 6, 7, 8};
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
TEST_CASE(scatternd_test_5)
{
// r=5, q=1, k=1
migraphx::program p;
auto* mm = p.get_main_module();
auto dtype = migraphx::shape::float_type;
auto itype = migraphx::shape::int64_type;
migraphx::shape ds{dtype, {2, 2, 2, 2, 2}};
migraphx::shape is{itype, {1}};
migraphx::shape us{dtype, {2, 2, 2, 2}};
std::vector<float> data_vec(32, 1);
std::vector<int64_t> ind_vec{1};
std::vector<float> upd_vec(16, 0);
auto data = mm->add_literal(migraphx::literal{ds, data_vec});
auto indices = mm->add_literal(migraphx::literal{is, ind_vec});
auto updates = mm->add_literal(migraphx::literal{us, upd_vec});
auto scatternd =
mm->add_instruction(migraphx::make_op("scatternd_none"), data, indices, updates);
mm->add_return({scatternd});
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold(32, 0);
std::copy(data_vec.begin(), data_vec.begin() + 16, gold.begin());
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/instruction.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/program.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/verify.hpp>
#include <test.hpp>
TEST_CASE(select_module_add_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape lit_s{migraphx::shape{migraphx::shape::float_type, {1}}};
auto literal_ins = mm->add_literal(migraphx::literal{lit_s, {6}});
// create batch submodules
auto create_submodule = [&](std::size_t batch_size, const std::string& module_name) {
auto* submod = p.create_module(module_name);
migraphx::shape sm_shape{migraphx::shape::float_type, {batch_size, 4}};
auto sm_input = submod->add_parameter("data", sm_shape);
auto broadcast_lit =
submod->add_instruction(migraphx::make_op("multibroadcast"), literal_ins, sm_input);
auto add_ins = submod->add_instruction(migraphx::make_op("add"), sm_input, broadcast_lit);
submod->add_return({add_ins});
return submod;
};
auto* batch1 = create_submodule(1, "batch_1");
auto* batch2 = create_submodule(2, "batch_2");
auto* batch3 = create_submodule(3, "batch_3");
auto* batch4 = create_submodule(4, "batch_4");
migraphx::shape s{migraphx::shape::float_type, {{1, 4}, {4, 4}}};
auto input = mm->add_parameter("data", s);
std::vector<migraphx::shape> sub_shapes = {};
sub_shapes.push_back(migraphx::shape{migraphx::shape::float_type, {{1, 4}, {4, 4}}});
migraphx::shape out_attr = migraphx::shape{sub_shapes};
auto sm_ins = mm->add_instruction(
migraphx::make_op("select_module", {{"output_dyn_shapes", migraphx::to_value(out_attr)}}),
{input},
{batch1, batch2, batch3, batch4});
auto ret = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), sm_ins);
mm->add_return({ret});
p.compile(migraphx::make_target("ref"));
std::vector<float> input_data{-4, 8, -1, 4, -1, 8, 8, -4};
migraphx::parameter_map params;
migraphx::shape input_fixed_shape{migraphx::shape::float_type, {2, 4}};
params["data"] = migraphx::argument(input_fixed_shape, input_data.data());
auto result = p.eval(params).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{2, 14, 5, 10, 5, 14, 14, 2};
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
TEST_CASE(select_module_reduce_test0)
{
migraphx::program p;
// create batch submodules
auto create_submodule = [&](std::size_t batch_size, const std::string& module_name) {
auto* submod = p.create_module(module_name);
migraphx::shape sm_shape{migraphx::shape::float_type, {batch_size, 2, 2}};
auto sm_input = submod->add_parameter("data", sm_shape);
auto reduce_ins =
submod->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {1}}}), sm_input);
auto squeeze_ins =
submod->add_instruction(migraphx::make_op("squeeze", {{"axes", {1}}}), reduce_ins);
submod->add_return({squeeze_ins});
return submod;
};
auto* batch1 = create_submodule(1, "batch_1");
auto* batch2 = create_submodule(2, "batch_2");
auto* batch3 = create_submodule(3, "batch_3");
auto* batch4 = create_submodule(4, "batch_4");
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {{1, 4}, {2, 2}, {2, 2}}};
auto input = mm->add_parameter("data", s);
std::vector<migraphx::shape> sub_shapes = {};
sub_shapes.push_back(migraphx::shape{migraphx::shape::float_type, {{1, 4}, {2, 2}}});
migraphx::shape out_attr = migraphx::shape{sub_shapes};
auto sm_ins = mm->add_instruction(
migraphx::make_op("select_module", {{"output_dyn_shapes", migraphx::to_value(out_attr)}}),
{input},
{batch1, batch2, batch3, batch4});
auto ret = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), sm_ins);
mm->add_return({ret});
p.compile(migraphx::make_target("ref"));
std::vector<float> input_data{-4, 8, -1, 4, -1, 8, 8, -4};
migraphx::parameter_map params;
migraphx::shape input_fixed_shape{migraphx::shape::float_type, {2, 2, 2}};
params["data"] = migraphx::argument(input_fixed_shape, input_data.data());
auto result = p.eval(params).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{-5, 12, 7, 4};
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
TEST_CASE(select_module_reduce_test1)
{
migraphx::program p;
// create batch submodules
auto create_submodule = [&](std::size_t batch_size, const std::string& module_name) {
auto* submod = p.create_module(module_name);
migraphx::shape sm_shape{migraphx::shape::float_type, {batch_size, 2, 2}};
auto sm_input = submod->add_parameter("data", sm_shape);
auto reduce_ins =
submod->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {1}}}), sm_input);
auto squeeze_ins =
submod->add_instruction(migraphx::make_op("squeeze", {{"axes", {1}}}), reduce_ins);
submod->add_return({squeeze_ins});
return submod;
};
auto* batch1 = create_submodule(1, "batch_1");
auto* batch2 = create_submodule(2, "batch_2");
auto* batch3 = create_submodule(3, "batch_3");
auto* batch4 = create_submodule(4, "batch_4");
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {{1, 4}, {2, 2}, {2, 2}}};
auto input = mm->add_parameter("data", s);
std::vector<migraphx::shape> sub_shapes = {};
sub_shapes.push_back(migraphx::shape{migraphx::shape::float_type, {{1, 4}, {2, 2}}});
migraphx::shape out_attr = migraphx::shape{sub_shapes};
auto sm_ins = mm->add_instruction(
migraphx::make_op("select_module", {{"output_dyn_shapes", migraphx::to_value(out_attr)}}),
{input},
{batch1, batch2, batch3, batch4});
auto ret = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), sm_ins);
mm->add_return({ret});
p.compile(migraphx::make_target("ref"));
std::vector<float> input_data{-4, 8, -1, 4, -1, 8, 8, -4, -4, 8, -1, 4, -1, 8, 8, -4};
migraphx::parameter_map params;
migraphx::shape input_fixed_shape{migraphx::shape::float_type, {4, 2, 2}};
params["data"] = migraphx::argument(input_fixed_shape, input_data.data());
auto result = p.eval(params).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{-5, 12, 7, 4, -5, 12, 7, 4};
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
TEST_CASE(select_module_not_found_error)
{
migraphx::program p;
// create batch submodules
auto create_submodule = [&](std::size_t batch_size, const std::string& module_name) {
auto* submod = p.create_module(module_name);
migraphx::shape sm_shape{migraphx::shape::float_type, {batch_size, 2, 2}};
auto sm_input = submod->add_parameter("data", sm_shape);
auto reduce_ins =
submod->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {1}}}), sm_input);
auto squeeze_ins =
submod->add_instruction(migraphx::make_op("squeeze", {{"axes", {1}}}), reduce_ins);
submod->add_return({squeeze_ins});
return submod;
};
auto* batch1 = create_submodule(1, "batch_1");
auto* batch2 = create_submodule(2, "batch_2");
auto* batch3 = create_submodule(3, "batch_3");
auto* batch4 = create_submodule(4, "batch_4");
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {{1, 4}, {2, 2}, {2, 2}}};
auto input = mm->add_parameter("data", s);
std::vector<migraphx::shape> sub_shapes = {};
sub_shapes.push_back(migraphx::shape{migraphx::shape::float_type, {{1, 4}, {2, 2}}});
migraphx::shape out_attr = migraphx::shape{sub_shapes};
auto sm_ins = mm->add_instruction(
migraphx::make_op("select_module", {{"output_dyn_shapes", migraphx::to_value(out_attr)}}),
{input},
{batch1, batch2, batch3, batch4});
auto ret = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), sm_ins);
mm->add_return({ret});
p.compile(migraphx::make_target("ref"));
std::vector<float> input_data{-4, 8, -1, 4, -1, 8, 8, -4, -4, 8,
-1, 4, -1, 8, 8, -4, -1, 8, 8, -4};
migraphx::parameter_map params;
migraphx::shape input_fixed_shape{migraphx::shape::float_type, {5, 2, 2}};
params["data"] = migraphx::argument(input_fixed_shape, input_data.data());
EXPECT(test::throws([&] { std::ignore = p.eval(params).back(); }));
}
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/instruction.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/program.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/verify.hpp>
#include <test.hpp>
float sigmoid(float x) { return 1 / (1 + expf(-x)); }
TEST_CASE(sigmoid_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {2, 2}};
auto l = mm->add_literal(migraphx::literal{s, {-1, 2, -3, 4}});
mm->add_instruction(migraphx::make_op("sigmoid"), l);
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<float> results_vector(4);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{sigmoid(-1), sigmoid(2), sigmoid(-3), sigmoid(4)};
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
TEST_CASE(sigmoid_dyn_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {{2, 4}, {2, 2}}};
auto input = mm->add_parameter("X", s);
mm->add_instruction(migraphx::make_op("sigmoid"), input);
p.compile(migraphx::make_target("ref"));
std::vector<float> input_data{-1, 2, -3, 4};
migraphx::parameter_map params0;
migraphx::shape input_fixed_shape0{migraphx::shape::float_type, {2, 2}};
params0["X"] = migraphx::argument(input_fixed_shape0, input_data.data());
auto result = p.eval(params0).back();
std::vector<float> results_vector(4);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{sigmoid(-1), sigmoid(2), sigmoid(-3), sigmoid(4)};
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/instruction.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/program.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/verify.hpp>
#include <test.hpp>
TEST_CASE(sign_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {5}};
auto l = mm->add_literal(
migraphx::literal{s, {1.02481645, 0.85643062, -0.03404123, -0.92791926, 0.0}});
mm->add_instruction(migraphx::make_op("sign"), l);
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {1.0, 1.0, -1.0, -1.0, 0.0};
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
TEST_CASE(sign_dyn_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape::dynamic_dimension dd{3, 8};
migraphx::shape s{migraphx::shape::float_type, {dd}};
auto input = mm->add_parameter("X", s);
mm->add_instruction(migraphx::make_op("sign"), input);
p.compile(migraphx::make_target("ref"));
std::vector<float> input_data{1.02481645, 0.85643062, -0.03404123, -0.92791926, 0.0};
migraphx::parameter_map params0;
migraphx::shape input_fixed_shape0{migraphx::shape::float_type, {5}};
params0["X"] = migraphx::argument(input_fixed_shape0, input_data.data());
auto result = p.eval(params0).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {1.0, 1.0, -1.0, -1.0, 0.0};
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/instruction.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/program.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/verify.hpp>
#include <test.hpp>
TEST_CASE(sin_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {3}};
std::vector<float> data = {-1, 0, 1};
auto l = mm->add_literal(migraphx::literal{s, data});
mm->add_instruction(migraphx::make_op("sin"), l);
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<float> results_vector(3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = data;
std::transform(
gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return sinf(n); });
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
TEST_CASE(sin_dyn_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape::dynamic_dimension dd{3, 8};
migraphx::shape s{migraphx::shape::float_type, {dd}};
auto input = mm->add_parameter("X", s);
mm->add_instruction(migraphx::make_op("sin"), input);
p.compile(migraphx::make_target("ref"));
std::vector<float> input_data = {-1, 0, 1};
migraphx::parameter_map params0;
migraphx::shape input_fixed_shape0{migraphx::shape::float_type, {3}};
params0["X"] = migraphx::argument(input_fixed_shape0, input_data.data());
auto result = p.eval(params0).back();
std::vector<float> results_vector(3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = input_data;
std::transform(
gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return sinf(n); });
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/instruction.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/program.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/verify.hpp>
#include <test.hpp>
TEST_CASE(sinh_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {2, 2}};
std::vector<float> data{-1.0, 2.0, -3.0, 4.0};
auto l = mm->add_literal(migraphx::literal{s, data});
mm->add_instruction(migraphx::make_op("sinh"), l);
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<float> results_vector(4);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = data;
std::transform(
gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return sinhf(n); });
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
TEST_CASE(sinh_dynamic_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {{2, 4}, {2, 4}}};
auto input = mm->add_parameter("X", s);
std::vector<float> input_data{-1.0, 2.0, -3.0, 4.0};
mm->add_instruction(migraphx::make_op("sinh"), input);
p.compile(migraphx::make_target("ref"));
migraphx::parameter_map params0;
migraphx::shape input_fixed_shape0{migraphx::shape::float_type, {4}};
params0["X"] = migraphx::argument(input_fixed_shape0, input_data.data());
auto result = p.eval(params0).back();
std::vector<float> results_vector(4);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = input_data;
std::transform(
gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return sinhf(n); });
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/instruction.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/program.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/verify.hpp>
#include <test.hpp>
TEST_CASE(slice_test_1)
{
migraphx::program p;
auto* mm = p.get_main_module();
std::vector<int> data(2 * 2 * 3);
std::iota(data.begin(), data.end(), 0);
migraphx::shape s{migraphx::shape::int32_type, {2, 2, 3}};
auto l0 = mm->add_literal(migraphx::literal{s, data});
mm->add_instruction(migraphx::make_op("slice", {{"axes", {2}}, {"starts", {1}}, {"ends", {3}}}),
l0);
migraphx::shape s2{migraphx::shape::int32_type, {2, 2, 2}, {6, 3, 1}};
EXPECT(p.get_output_shapes().back() == s2);
p.compile(migraphx::make_target("ref"));
migraphx::shape sresult{migraphx::shape::int32_type, {2, 2, 2}, {4, 2, 1}};
auto result = p.eval({}).back();
std::vector<int> gold = {1, 2, 4, 5, 7, 8, 10, 11};
std::vector<int> results_vector(2 * 2 * 2);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
EXPECT(result.get_shape() == sresult);
}
TEST_CASE(slice_test_2)
{
migraphx::program p;
auto* mm = p.get_main_module();
std::vector<int> data(2 * 2 * 3);
std::iota(data.begin(), data.end(), 0);
migraphx::shape s{migraphx::shape::int32_type, {2, 2, 3}};
auto l0 = mm->add_literal(migraphx::literal{s, data});
mm->add_instruction(
migraphx::make_op("slice",
{{"axes", {0, 1, 2}}, {"starts", {0, 0, 0}}, {"ends", {2, 2, 2}}}),
l0);
migraphx::shape s2{migraphx::shape::int32_type, {2, 2, 2}, {6, 3, 1}};
EXPECT(p.get_output_shapes().back() == s2);
p.compile(migraphx::make_target("ref"));
migraphx::shape sresult{migraphx::shape::int32_type, {2, 2, 2}, {4, 2, 1}};
auto result = p.eval({}).back();
std::vector<int> gold = {0, 1, 3, 4, 6, 7, 9, 10};
std::vector<int> results_vector(2 * 2 * 2);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
EXPECT(result.get_shape() == sresult);
}
TEST_CASE(slice_var_inputs_static0)
{
migraphx::program p;
auto* mm = p.get_main_module();
std::vector<int32_t> data(2 * 2 * 3);
std::iota(data.begin(), data.end(), 0);
migraphx::shape s0{migraphx::shape::int32_type, {2, 2, 3}};
auto l0 = mm->add_literal(migraphx::literal{s0, data});
migraphx::shape s1{migraphx::shape::int32_type, {1}};
auto starts = mm->add_parameter("starts", s1);
auto ends = mm->add_parameter("ends", s1);
mm->add_instruction(migraphx::make_op("slice", {{"axes", {2}}}), l0, starts, ends);
p.compile(migraphx::make_target("ref"));
migraphx::parameter_map params;
std::vector<int32_t> start_data = {1};
std::vector<int32_t> end_data = {3};
params["starts"] = migraphx::argument(s1, start_data.data());
params["ends"] = migraphx::argument(s1, end_data.data());
auto result = p.eval(params).back();
std::vector<int32_t> gold = {1, 2, 4, 5, 7, 8, 10, 11};
std::vector<int32_t> results_vector(2 * 2 * 2);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
TEST_CASE(slice_var_inputs_static1)
{
migraphx::program p;
auto* mm = p.get_main_module();
std::vector<int32_t> data(2 * 2 * 3);
std::iota(data.begin(), data.end(), 0);
migraphx::shape s0{migraphx::shape::int32_type, {2, 2, 3}};
auto l0 = mm->add_literal(migraphx::literal{s0, data});
migraphx::shape s1{migraphx::shape::int32_type, {1}};
auto starts = mm->add_parameter("starts", s1);
auto ends = mm->add_parameter("ends", s1);
mm->add_instruction(migraphx::make_op("slice", {{"axes", {2}}}), l0, starts, ends);
p.compile(migraphx::make_target("ref"));
migraphx::parameter_map params;
std::vector<int32_t> start_data = {-2};
std::vector<int32_t> end_data = {2831};
params["starts"] = migraphx::argument(s1, start_data.data());
params["ends"] = migraphx::argument(s1, end_data.data());
auto result = p.eval(params).back();
std::vector<int32_t> gold = {1, 2, 4, 5, 7, 8, 10, 11};
std::vector<int32_t> results_vector(2 * 2 * 2);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
TEST_CASE(slice_var_inputs_static2)
{
migraphx::program p;
auto* mm = p.get_main_module();
std::vector<float> data(2 * 2 * 3);
std::iota(data.begin(), data.end(), 0);
migraphx::shape s0{migraphx::shape::float_type, {2, 2, 3}};
auto l0 = mm->add_literal(migraphx::literal{s0, data});
migraphx::shape s1{migraphx::shape::int64_type, {3}};
auto starts = mm->add_parameter("starts", s1);
auto ends = mm->add_parameter("ends", s1);
auto axes = mm->add_parameter("axes", s1);
mm->add_instruction(migraphx::make_op("slice"), l0, starts, ends, axes);
p.compile(migraphx::make_target("ref"));
migraphx::parameter_map params;
std::vector<int64_t> start_data = {0, 0, 0};
std::vector<int64_t> end_data = {2, 2, 2};
std::vector<int64_t> axes_data = {0, 1, 2};
params["starts"] = migraphx::argument(s1, start_data.data());
params["ends"] = migraphx::argument(s1, end_data.data());
params["axes"] = migraphx::argument(s1, axes_data.data());
auto result = p.eval(params).back();
std::vector<float> gold = {0, 1, 3, 4, 6, 7, 9, 10};
std::vector<float> results_vector(2 * 2 * 2);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
TEST_CASE(slice_var_inputs_dyn)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s0{migraphx::shape::int32_type, {{2, 4, {2, 4}}, {2, 4, {2, 4}}, {3, 8}}};
auto input = mm->add_parameter("input", s0);
migraphx::shape s1{migraphx::shape::int32_type, {1}};
auto starts = mm->add_parameter("starts", s1);
auto ends = mm->add_parameter("ends", s1);
mm->add_instruction(migraphx::make_op("slice", {{"axes", {2}}}), input, starts, ends);
p.compile(migraphx::make_target("ref"));
migraphx::parameter_map params;
migraphx::shape s2{migraphx::shape::int32_type, {2, 2, 3}};
std::vector<int> input_data(2 * 2 * 3);
std::iota(input_data.begin(), input_data.end(), 0);
std::vector<int> start_data = {1};
std::vector<int> end_data = {3};
params["input"] = migraphx::argument(s2, input_data.data());
params["starts"] = migraphx::argument(s1, start_data.data());
params["ends"] = migraphx::argument(s1, end_data.data());
auto result = p.eval(params).back();
std::vector<int> gold = {1, 2, 4, 5, 7, 8, 10, 11};
std::vector<int> results_vector(2 * 2 * 2);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
TEST_CASE(slice_dyn_test0)
{
// Slice a single dynamic dimension. ax1 slice limits are smaller than min; ax2 "ends" is
// too large
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::int32_type, {{2, 3}, {2, 2}, {3, 3}}};
auto x = mm->add_parameter("x", s);
mm->add_instruction(
migraphx::make_op("slice", {{"axes", {1, 2}}, {"starts", {0, 1}}, {"ends", {1, 6}}}), x);
migraphx::shape s2{migraphx::shape::int32_type, {{2, 3}, {1, 1}, {2, 2}}};
EXPECT(p.get_output_shapes().back() == s2);
p.compile(migraphx::make_target("ref"));
// the strides of sresult are those of the original shape, not
// reduced to sliced size.
migraphx::shape sresult{migraphx::shape::int32_type, {2, 1, 2}, {6, 3, 1}};
migraphx::shape input_fixed_shape{migraphx::shape::int32_type, {2, 2, 3}};
migraphx::parameter_map params;
std::vector<int> data(2 * 2 * 3);
std::iota(data.begin(), data.end(), 0);
params["x"] = migraphx::argument(input_fixed_shape, data.data());
auto result = p.eval(params).back();
std::vector<int> gold = {1, 2, 7, 8};
std::vector<int> results_vector(2 * 1 * 2);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
EXPECT(result.get_shape() == sresult);
}
TEST_CASE(slice_dyn_test1)
{
// Slice all three dynamic dimensions
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::int32_type, {{2, 2}, {2, 2}, {3, 3}}};
auto x = mm->add_parameter("x", s);
mm->add_instruction(
migraphx::make_op("slice",
{{"axes", {0, 1, 2}}, {"starts", {0, 0, 0}}, {"ends", {2, 2, 2}}}),
x);
migraphx::shape s2{migraphx::shape::int32_type, {{2, 2}, {2, 2}, {2, 2}}};
EXPECT(p.get_output_shapes().back() == s2);
p.compile(migraphx::make_target("ref"));
migraphx::shape sresult{migraphx::shape::int32_type, {2, 2, 2}, {6, 3, 1}};
migraphx::shape input_fixed_shape{migraphx::shape::int32_type, {2, 2, 3}};
migraphx::parameter_map params;
std::vector<int> data(2 * 2 * 3);
std::iota(data.begin(), data.end(), 0);
params["x"] = migraphx::argument(input_fixed_shape, data.data());
auto result = p.eval(params).back();
std::vector<int> gold = {0, 1, 3, 4, 6, 7, 9, 10};
std::vector<int> results_vector(2 * 2 * 2);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
EXPECT(result.get_shape() == sresult);
}
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/instruction.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/program.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/verify.hpp>
#include <test.hpp>
TEST_CASE(softmax_simple_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
std::vector<float> a = {0.25, 0.75};
std::vector<float> gold = {0.377541, 0.622459};
migraphx::shape a_shape{migraphx::shape::float_type, {1, 2}};
auto al = mm->add_literal(migraphx::literal{a_shape, a});
mm->add_instruction(migraphx::make_op("softmax", {{"axis", 1}}), al);
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<float> results_vector(2);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
TEST_CASE(softmax_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
std::vector<float> a = {
-5.61869681e-01, 9.07827199e-01, 1.29255986e+00, 3.18533443e-02, -1.22183852e-03,
-2.83830553e-01, -1.03245842e+00, -9.28322077e-01, -8.82696748e-01, 1.11327164e-01,
-9.20038462e-01, 8.47388089e-01, 2.51734018e-01, 1.50563884e+00, 2.23056650e+00,
-6.17576987e-02, -1.00264274e-01, -6.10369384e-01, 1.17537189e+00, -2.51560897e-01,
-8.50333512e-01, -8.03578615e-01, -6.51194930e-01, -2.58137047e-01, 4.65528190e-01,
3.23284641e-02, -1.54700470e+00, 1.38096774e+00, 5.39869189e-01, -7.56884992e-01,
1.81503093e+00, -2.11269641e+00, 1.92466557e+00, 1.77230799e+00, 2.21660900e+00,
1.56777036e+00, -2.08995026e-03, 3.50566894e-01, -1.15042710e+00, -1.18577778e+00,
8.90633047e-01, -6.63949102e-02, 1.44661188e+00, 1.59215283e+00, -2.56262213e-01,
9.39079225e-01, 4.07298543e-02, 3.86590779e-01, 6.09607756e-01, 8.22331488e-01,
-2.82126725e-01, -9.49052632e-01, -4.24012303e-01, -5.32990396e-01, -3.18386006e+00,
3.27092171e-01, -1.33315325e+00, 3.62459183e-01, 3.74710828e-01, -1.30302286e+00,
1.79680198e-01, -4.51832324e-01, 4.34282750e-01, -7.09520102e-01, 6.20333970e-01,
-1.28712380e+00, 2.04130828e-01, -7.70607769e-01, 1.61889160e+00, -1.50951004e+00,
-4.10505563e-01, -3.56566496e-02, -1.29747534e+00, -1.49967879e-01, 7.77626812e-01,
-8.28408226e-02, 2.73412596e-02, 5.79780899e-03, 9.87900198e-02, -7.95276761e-01,
-1.38536084e+00, -6.63573861e-01, 3.89783204e-01, -1.30670881e+00, -7.62425125e-01,
-4.04883057e-01, 6.24344349e-01, 3.68128955e-01, -1.01577950e+00, -3.06715906e-01,
5.67961395e-01, 2.98198581e-01, -1.63613629e+00, -3.75131965e-01, -6.75393403e-01,
2.59172034e+00, 6.75538957e-01, 9.07939598e-02, 1.92257717e-01, -1.21592450e+00,
-2.73682117e-01, 1.25232983e+00, -1.39969170e+00, -1.91483587e-01, 2.57732719e-01,
3.10056299e-01, 1.41833842e+00, -1.81386679e-01, 3.92868072e-01, -8.14771175e-01,
2.02392387e+00, -9.42091495e-02, -3.77683818e-01, 2.05638766e+00, 2.93796062e-01,
-6.02131486e-01, 2.70461679e-01, -8.92358482e-01, 1.04388881e+00, 2.66154885e-01};
std::vector<float> gold = {
0.30191708, 0.59879845, 0.50029165, 0.24915339, 0.36823985, 0.13190967, 0.0349741,
0.18750034, 0.21905553, 0.27000085, 0.0547399, 0.56318235, 0.47422904, 0.78964758,
0.91381913, 0.44601166, 0.47902739, 0.13120073, 0.4449684, 0.18766427, 0.15753111,
0.07844277, 0.05120674, 0.36648798, 0.14637007, 0.13152322, 0.01560997, 0.29065287,
0.49196178, 0.10550152, 0.81890774, 0.06369215, 0.62972021, 0.74931765, 0.67285055,
0.35034987, 0.28612873, 0.31931475, 0.04220394, 0.16093165, 0.22390974, 0.11915915,
0.3115395, 0.35899726, 0.22190949, 0.57518375, 0.13888834, 0.7753762, 0.4642328,
0.57055861, 0.21954368, 0.34515455, 0.09486015, 0.40631217, 0.01842281, 0.48770609,
0.06652815, 0.36023033, 0.42343026, 0.24226256, 0.17348589, 0.44066274, 0.6865865,
0.17296699, 0.46923906, 0.06921105, 0.3570261, 0.4125829, 0.73165393, 0.15302512,
0.29499072, 0.33932695, 0.30852377, 0.40762195, 0.40170741, 0.36259529, 0.60848355,
0.42618036, 0.31721094, 0.02960522, 0.28256637, 0.24389413, 0.2725659, 0.10663581,
0.27622163, 0.28264219, 0.53652936, 0.09476089, 0.40890986, 0.34848392, 0.32572666,
0.53076893, 0.11529481, 0.29117745, 0.14625968, 0.8756339, 0.49818122, 0.10656087,
0.1813329, 0.17664003, 0.21410346, 0.80408043, 0.02315119, 0.27155462, 0.32804728,
0.13268511, 0.61795473, 0.49703068, 0.41696799, 0.10175809, 0.71028161, 0.29929739,
0.17377149, 0.76075399, 0.20071237, 0.32632929, 0.36892858, 0.09416146, 0.26656723,
0.42914796};
migraphx::shape a_shape{migraphx::shape::float_type, {5, 3, 4, 2}};
auto al = mm->add_literal(migraphx::literal{a_shape, a});
mm->add_instruction(migraphx::make_op("softmax", {{"axis", 1}}), al);
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<float> results_vector(120);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
TEST_CASE(softmax_dyn_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape a_shape{migraphx::shape::float_type,
{{1, 10}, {1, 3, {3}}, {4, 4}, {2, 2, {2}}}};
auto al = mm->add_parameter("a", a_shape);
mm->add_instruction(migraphx::make_op("softmax", {{"axis", 1}}), al);
p.compile(migraphx::make_target("ref"));
std::vector<float> a = {
-5.61869681e-01, 9.07827199e-01, 1.29255986e+00, 3.18533443e-02, -1.22183852e-03,
-2.83830553e-01, -1.03245842e+00, -9.28322077e-01, -8.82696748e-01, 1.11327164e-01,
-9.20038462e-01, 8.47388089e-01, 2.51734018e-01, 1.50563884e+00, 2.23056650e+00,
-6.17576987e-02, -1.00264274e-01, -6.10369384e-01, 1.17537189e+00, -2.51560897e-01,
-8.50333512e-01, -8.03578615e-01, -6.51194930e-01, -2.58137047e-01, 4.65528190e-01,
3.23284641e-02, -1.54700470e+00, 1.38096774e+00, 5.39869189e-01, -7.56884992e-01,
1.81503093e+00, -2.11269641e+00, 1.92466557e+00, 1.77230799e+00, 2.21660900e+00,
1.56777036e+00, -2.08995026e-03, 3.50566894e-01, -1.15042710e+00, -1.18577778e+00,
8.90633047e-01, -6.63949102e-02, 1.44661188e+00, 1.59215283e+00, -2.56262213e-01,
9.39079225e-01, 4.07298543e-02, 3.86590779e-01, 6.09607756e-01, 8.22331488e-01,
-2.82126725e-01, -9.49052632e-01, -4.24012303e-01, -5.32990396e-01, -3.18386006e+00,
3.27092171e-01, -1.33315325e+00, 3.62459183e-01, 3.74710828e-01, -1.30302286e+00,
1.79680198e-01, -4.51832324e-01, 4.34282750e-01, -7.09520102e-01, 6.20333970e-01,
-1.28712380e+00, 2.04130828e-01, -7.70607769e-01, 1.61889160e+00, -1.50951004e+00,
-4.10505563e-01, -3.56566496e-02, -1.29747534e+00, -1.49967879e-01, 7.77626812e-01,
-8.28408226e-02, 2.73412596e-02, 5.79780899e-03, 9.87900198e-02, -7.95276761e-01,
-1.38536084e+00, -6.63573861e-01, 3.89783204e-01, -1.30670881e+00, -7.62425125e-01,
-4.04883057e-01, 6.24344349e-01, 3.68128955e-01, -1.01577950e+00, -3.06715906e-01,
5.67961395e-01, 2.98198581e-01, -1.63613629e+00, -3.75131965e-01, -6.75393403e-01,
2.59172034e+00, 6.75538957e-01, 9.07939598e-02, 1.92257717e-01, -1.21592450e+00,
-2.73682117e-01, 1.25232983e+00, -1.39969170e+00, -1.91483587e-01, 2.57732719e-01,
3.10056299e-01, 1.41833842e+00, -1.81386679e-01, 3.92868072e-01, -8.14771175e-01,
2.02392387e+00, -9.42091495e-02, -3.77683818e-01, 2.05638766e+00, 2.93796062e-01,
-6.02131486e-01, 2.70461679e-01, -8.92358482e-01, 1.04388881e+00, 2.66154885e-01};
migraphx::parameter_map params;
migraphx::shape input_fixed_shape{migraphx::shape::float_type, {5, 3, 4, 2}};
params["a"] = migraphx::argument(input_fixed_shape, a.data());
auto result = p.eval(params).back();
std::vector<float> results_vector(120);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {
0.30191708, 0.59879845, 0.50029165, 0.24915339, 0.36823985, 0.13190967, 0.0349741,
0.18750034, 0.21905553, 0.27000085, 0.0547399, 0.56318235, 0.47422904, 0.78964758,
0.91381913, 0.44601166, 0.47902739, 0.13120073, 0.4449684, 0.18766427, 0.15753111,
0.07844277, 0.05120674, 0.36648798, 0.14637007, 0.13152322, 0.01560997, 0.29065287,
0.49196178, 0.10550152, 0.81890774, 0.06369215, 0.62972021, 0.74931765, 0.67285055,
0.35034987, 0.28612873, 0.31931475, 0.04220394, 0.16093165, 0.22390974, 0.11915915,
0.3115395, 0.35899726, 0.22190949, 0.57518375, 0.13888834, 0.7753762, 0.4642328,
0.57055861, 0.21954368, 0.34515455, 0.09486015, 0.40631217, 0.01842281, 0.48770609,
0.06652815, 0.36023033, 0.42343026, 0.24226256, 0.17348589, 0.44066274, 0.6865865,
0.17296699, 0.46923906, 0.06921105, 0.3570261, 0.4125829, 0.73165393, 0.15302512,
0.29499072, 0.33932695, 0.30852377, 0.40762195, 0.40170741, 0.36259529, 0.60848355,
0.42618036, 0.31721094, 0.02960522, 0.28256637, 0.24389413, 0.2725659, 0.10663581,
0.27622163, 0.28264219, 0.53652936, 0.09476089, 0.40890986, 0.34848392, 0.32572666,
0.53076893, 0.11529481, 0.29117745, 0.14625968, 0.8756339, 0.49818122, 0.10656087,
0.1813329, 0.17664003, 0.21410346, 0.80408043, 0.02315119, 0.27155462, 0.32804728,
0.13268511, 0.61795473, 0.49703068, 0.41696799, 0.10175809, 0.71028161, 0.29929739,
0.17377149, 0.76075399, 0.20071237, 0.32632929, 0.36892858, 0.09416146, 0.26656723,
0.42914796};
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/instruction.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/program.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/verify.hpp>
#include <test.hpp>
TEST_CASE(sqdiff_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {3}};
auto l1 = mm->add_literal(migraphx::literal{s, {-1, 0, 1}});
auto l2 = mm->add_literal(migraphx::literal{s, {1, 2, 3}});
mm->add_instruction(migraphx::make_op("sqdiff"), l1, l2);
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<float> results_vector(3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {4, 4, 4};
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
TEST_CASE(sqdiff_dyn_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
std::vector<migraphx::shape::dynamic_dimension> dd{{2, 6}};
migraphx::shape s{migraphx::shape::float_type, dd};
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
mm->add_instruction(migraphx::make_op("sqdiff"), x, y);
p.compile(migraphx::make_target("ref"));
std::vector<float> x_data{-1, 0, 1};
std::vector<float> y_data{1, 2, 3};
migraphx::parameter_map params0;
migraphx::shape input_fixed_shape0{migraphx::shape::float_type, {3}};
params0["x"] = migraphx::argument(input_fixed_shape0, x_data.data());
params0["y"] = migraphx::argument(input_fixed_shape0, y_data.data());
auto result = p.eval(params0).back();
std::vector<float> results_vector(3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {4, 4, 4};
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/instruction.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/program.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/verify.hpp>
#include <test.hpp>
TEST_CASE(sqrt_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {5}};
std::vector<float> data{1.02481645, 0.85643062, 0.03404123, 0.92791926, 0.10569184};
auto l = mm->add_literal(migraphx::literal{s, data});
mm->add_instruction(migraphx::make_op("sqrt"), l);
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = data;
std::transform(
gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return sqrtf(n); });
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
TEST_CASE(sqrt_dynamic_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape::dynamic_dimension dd{3, 8};
migraphx::shape s{migraphx::shape::float_type, {dd}};
auto input = mm->add_parameter("X", s);
std::vector<float> input_data{1.02481645, 0.85643062, 0.03404123, 0.92791926, 0.10569184};
mm->add_instruction(migraphx::make_op("sqrt"), input);
p.compile(migraphx::make_target("ref"));
migraphx::parameter_map params0;
migraphx::shape input_fixed_shape0{migraphx::shape::float_type, {5}};
params0["X"] = migraphx::argument(input_fixed_shape0, input_data.data());
auto result = p.eval(params0).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = input_data;
std::transform(
gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return sqrtf(n); });
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
......@@ -21,74 +21,74 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <iostream>
#include <vector>
#include <migraphx/literal.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/verify.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/pass_manager.hpp>
#include "test.hpp"
TEST_CASE(argmax_test_nonstd_shape)
#include <test.hpp>
TEST_CASE(squeeze_test_1)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto dl = mm->add_literal(migraphx::generate_literal({migraphx::shape::float_type, {2, 3, 4}}));
auto dl_trans =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 2, 0}}}), dl);
mm->add_instruction(migraphx::make_op("argmax", {{"axis", -3}}), dl_trans);
auto p_uncompiled = p;
std::vector<float> data(4 * 3 * 3);
migraphx::shape s1{migraphx::shape::float_type, {4, 1, 3, 1, 3}};
migraphx::shape s2{migraphx::shape::float_type, {4, 3, 1, 3}};
auto l0 = mm->add_literal(migraphx::literal{s1, data});
mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {1}}}), l0);
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
auto res_gold = p_uncompiled.eval({}).back();
std::vector<int64_t> result_vec;
result.visit([&](auto output) { result_vec.assign(output.begin(), output.end()); });
std::vector<int64_t> res_gold_vec;
res_gold.visit([&](auto output) { res_gold_vec.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(result_vec, res_gold_vec));
auto result = p.eval({}).back();
EXPECT(result.get_shape() == s2);
}
TEST_CASE(argmin_test_nonstd_shape)
TEST_CASE(squeeze_test_2)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto dl = mm->add_literal(migraphx::generate_literal({migraphx::shape::float_type, {2, 3, 4}}));
auto dl_trans =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 2, 0}}}), dl);
mm->add_instruction(migraphx::make_op("argmin", {{"axis", -1}}), dl_trans);
auto p_uncompiled = p;
std::vector<float> data(4 * 3 * 3);
migraphx::shape s1{migraphx::shape::float_type, {4, 1, 3, 1, 3}};
migraphx::shape s2{migraphx::shape::float_type, {4, 1, 3, 3}};
auto l0 = mm->add_literal(migraphx::literal{s1, data});
mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {3}}}), l0);
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
auto res_gold = p_uncompiled.eval({}).back();
std::vector<int64_t> result_vec;
result.visit([&](auto output) { result_vec.assign(output.begin(), output.end()); });
std::vector<int64_t> res_gold_vec;
res_gold.visit([&](auto output) { res_gold_vec.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(result_vec, res_gold_vec));
auto result = p.eval({}).back();
EXPECT(result.get_shape() == s2);
}
TEST_CASE(isnan_broadcast_test)
TEST_CASE(squeeze_test_3)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s0{migraphx::shape::float_type, {3}};
migraphx::shape s1{migraphx::shape::float_type, {3, 2}};
auto nan_val = std::numeric_limits<float>::quiet_NaN();
std::vector<float> data0 = {1.2, 5.2, nan_val};
auto l0 = mm->add_literal(migraphx::literal{s0, data0});
auto l1 = mm->add_instruction(
migraphx::make_op("broadcast", {{"axis", 0}, {"out_lens", s1.lens()}}), l0);
mm->add_instruction(migraphx::make_op("isnan"), l1);
std::vector<float> data(4 * 3 * 3);
migraphx::shape s1{migraphx::shape::float_type, {4, 1, 3, 1, 3}};
migraphx::shape s2{migraphx::shape::float_type, {4, 3, 3}};
auto l0 = mm->add_literal(migraphx::literal{s1, data});
mm->add_instruction(migraphx::make_op("squeeze"), l0);
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> correct = {0, 0, 0, 0, 1, 1};
EXPECT(migraphx::verify::verify_range(results_vector, correct));
EXPECT(result.get_shape() == s2);
}
TEST_CASE(squeeze_dyn_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s1{migraphx::shape::float_type, {{1, 4}, {1, 1}, {3, 3}, {1, 1}, {3, 3}}};
auto p0 = mm->add_parameter("x", s1);
mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {1}}}), p0);
p.compile(migraphx::make_target("ref"));
std::vector<float> input_data(4 * 3 * 3);
migraphx::parameter_map params0;
migraphx::shape input_fixed_shape0{migraphx::shape::float_type, {4, 1, 3, 1, 3}};
params0["x"] = migraphx::argument(input_fixed_shape0, input_data.data());
auto result = p.eval(params0).back();
migraphx::shape s2{migraphx::shape::float_type, {4, 3, 1, 3}};
EXPECT(result.get_shape() == s2);
}
TEST_CASE(squeeze_transpose_test)
......@@ -151,65 +151,3 @@ TEST_CASE(squeeze_slice_test)
EXPECT(result.get_shape() == migraphx::shape{migraphx::shape::float_type, {3, 3}});
EXPECT(result == expected_result);
}
TEST_CASE(unsqueeze_transpose_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s1{migraphx::shape::float_type, {4, 3, 3}};
auto l0 = mm->add_literal(migraphx::generate_literal(s1));
auto l0_trans =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {2, 0, 1}}}), l0);
mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {2}}}), l0_trans);
auto p_uncompiled = p;
auto* mm_uncompiled = p_uncompiled.get_main_module();
mm_uncompiled->add_instruction(migraphx::make_op("contiguous"),
std::prev(mm_uncompiled->end()));
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
auto expected_result = p_uncompiled.eval({}).back();
EXPECT(result.get_shape() == migraphx::shape{migraphx::shape::float_type, {3, 4, 1, 3}});
EXPECT(result == expected_result);
}
TEST_CASE(unsqueeze_multibroadcast_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s1{migraphx::shape::float_type, {4, 1, 3}};
auto l0 = mm->add_literal(migraphx::generate_literal(s1));
auto l0_brcst =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {4, 4, 3, 3}}}), l0);
mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {2}}}), l0_brcst);
auto p_uncompiled = p;
auto* mm_uncompiled = p_uncompiled.get_main_module();
mm_uncompiled->add_instruction(migraphx::make_op("contiguous"),
std::prev(mm_uncompiled->end()));
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
auto expected_result = p_uncompiled.eval({}).back();
EXPECT(result.get_shape() == migraphx::shape{migraphx::shape::float_type, {4, 4, 1, 3, 3}});
EXPECT(result == expected_result);
}
TEST_CASE(unsqueeze_slice_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s1{migraphx::shape::float_type, {2, 3, 4, 4}};
auto l0 = mm->add_literal(migraphx::generate_literal(s1));
auto l0_slice = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {3}}, {"starts", {2}}, {"ends", {3}}}), l0);
mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1}}}), l0_slice);
auto p_uncompiled = p;
auto* mm_uncompiled = p_uncompiled.get_main_module();
mm_uncompiled->add_instruction(migraphx::make_op("contiguous"),
std::prev(mm_uncompiled->end()));
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
auto expected_result = p_uncompiled.eval({}).back();
EXPECT(result.get_shape() == migraphx::shape{migraphx::shape::float_type, {2, 1, 3, 4, 1}});
EXPECT(result == expected_result);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment