Commit 61edd67d authored by Sam Wu's avatar Sam Wu
Browse files

Merge branch 'develop' into doc-standard

parents a72c9e83 eafd55de
......@@ -634,8 +634,6 @@ def disabled_tests_onnx_1_11_0(backend_test):
# from OnnxBackendNodeModelTest
backend_test.exclude(r'test_roialign_aligned_false_cpu')
backend_test.exclude(r'test_roialign_aligned_true_cpu')
backend_test.exclude(r'test_scatternd_add_cpu')
backend_test.exclude(r'test_scatternd_multiply_cpu')
# errors
# from OnnxBackendNodeModelTest
......@@ -744,8 +742,6 @@ def disabled_tests_onnx_1_13_0(backend_test):
r'test_reduce_sum_square_negative_axes_keepdims_example_cpu')
backend_test.exclude(
r'test_reduce_sum_square_negative_axes_keepdims_random_cpu')
backend_test.exclude(r'test_scatternd_max_cpu')
backend_test.exclude(r'test_scatternd_min_cpu')
# errors
# from OnnxBackendNodeModelTest
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/instruction.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/program.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/verify.hpp>
#include <test.hpp>
TEST_CASE(scatternd_max_test_1)
{
// r=1, q=2, k=1
migraphx::program p;
auto* mm = p.get_main_module();
auto dtype = migraphx::shape::float_type;
auto itype = migraphx::shape::int64_type;
migraphx::shape ds{dtype, {8}};
migraphx::shape is{itype, {4, 1}};
migraphx::shape us{dtype, {4}};
std::vector<float> data_vec{1, 2, 3, 4, 5, 6, 7, 8};
std::vector<int64_t> ind_vec{4, 3, 1, 7};
std::vector<float> upd_vec{9, 3, 1, 12};
auto data = mm->add_literal(migraphx::literal{ds, data_vec});
auto indices = mm->add_literal(migraphx::literal{is, ind_vec});
auto updates = mm->add_literal(migraphx::literal{us, upd_vec});
auto scatternd =
mm->add_instruction(migraphx::make_op("scatternd_max"), data, indices, updates);
mm->add_return({scatternd});
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{1, 2, 3, 4, 9, 6, 7, 12};
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
TEST_CASE(scatternd_max_test_2)
{
// r=2, q=2, k=2
migraphx::program p;
auto* mm = p.get_main_module();
auto dtype = migraphx::shape::float_type;
auto itype = migraphx::shape::int64_type;
migraphx::shape ds{dtype, {2, 2}};
migraphx::shape is{itype, {2, 2}};
migraphx::shape us{dtype, {2}};
std::vector<float> data_vec{1, 2, 3, 4};
std::vector<int64_t> ind_vec{0, 0, 0, 1};
std::vector<float> upd_vec{5, 1};
auto data = mm->add_literal(migraphx::literal{ds, data_vec});
auto indices = mm->add_literal(migraphx::literal{is, ind_vec});
auto updates = mm->add_literal(migraphx::literal{us, upd_vec});
auto scatternd =
mm->add_instruction(migraphx::make_op("scatternd_max"), data, indices, updates);
mm->add_return({scatternd});
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{5, 2, 3, 4};
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
TEST_CASE(scatternd_max_test_3)
{
// r=3, q=3, k=3
migraphx::program p;
auto* mm = p.get_main_module();
auto dtype = migraphx::shape::float_type;
auto itype = migraphx::shape::int64_type;
migraphx::shape ds{dtype, {2, 2, 2}};
migraphx::shape is{itype, {2, 1, 3}};
migraphx::shape us{dtype, {2, 1}};
std::vector<float> data_vec{1, 2, 3, 4, 5, 6, 7, 8};
std::vector<int64_t> ind_vec{0, 0, 0, 1, 1, 1};
std::vector<float> upd_vec{9, 1};
auto data = mm->add_literal(migraphx::literal{ds, data_vec});
auto indices = mm->add_literal(migraphx::literal{is, ind_vec});
auto updates = mm->add_literal(migraphx::literal{us, upd_vec});
auto scatternd =
mm->add_instruction(migraphx::make_op("scatternd_max"), data, indices, updates);
mm->add_return({scatternd});
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{9, 2, 3, 4, 5, 6, 7, 8};
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
TEST_CASE(scatternd_max_test_4)
{
// r=3, q=2, k=1
migraphx::program p;
auto* mm = p.get_main_module();
auto dtype = migraphx::shape::float_type;
auto itype = migraphx::shape::int64_type;
migraphx::shape ds{dtype, {4, 4, 4}};
migraphx::shape is{itype, {2, 1}};
migraphx::shape us{dtype, {2, 4, 4}};
std::vector<float> data_vec{1, 2, 3, 4, 5, 6, 7, 8, 8, 7, 6, 5, 4, 3, 2, 1, 1, 2, 3, 4, 5, 6,
7, 8, 8, 7, 6, 5, 4, 3, 2, 1, 8, 7, 6, 5, 4, 3, 2, 1, 1, 2, 3, 4,
5, 6, 7, 8, 8, 7, 6, 5, 4, 3, 2, 1, 1, 2, 3, 4, 5, 6, 7, 8};
std::vector<int64_t> ind_vec{0, 2};
std::vector<float> upd_vec{5, 5, 5, 5, 6, 6, 6, 6, 7, 7, 7, 7, 8, 8, 8, 8,
1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4};
auto data = mm->add_literal(migraphx::literal{ds, data_vec});
auto indices = mm->add_literal(migraphx::literal{is, ind_vec});
auto updates = mm->add_literal(migraphx::literal{us, upd_vec});
auto scatternd =
mm->add_instruction(migraphx::make_op("scatternd_max"), data, indices, updates);
mm->add_return({scatternd});
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{5, 5, 5, 5, 6, 6, 7, 8, 8, 7, 7, 7, 8, 8, 8, 8, 1, 2, 3, 4, 5, 6,
7, 8, 8, 7, 6, 5, 4, 3, 2, 1, 8, 7, 6, 5, 4, 3, 2, 2, 3, 3, 3, 4,
5, 6, 7, 8, 8, 7, 6, 5, 4, 3, 2, 1, 1, 2, 3, 4, 5, 6, 7, 8};
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
TEST_CASE(scatternd_max_test_duplicate_idx)
{
// r=3, q=2, k=1
migraphx::program p;
auto* mm = p.get_main_module();
auto dtype = migraphx::shape::float_type;
auto itype = migraphx::shape::int64_type;
migraphx::shape ds{dtype, {4, 4, 4}};
migraphx::shape is{itype, {2, 1}};
migraphx::shape us{dtype, {2, 4, 4}};
std::vector<float> data_vec{1, 2, 3, 4, 5, 6, 7, 8, 8, 7, 6, 5, 4, 3, 2, 1, 1, 2, 3, 4, 5, 6,
7, 8, 8, 7, 6, 5, 4, 3, 2, 1, 8, 7, 6, 5, 4, 3, 2, 1, 1, 2, 3, 4,
5, 6, 7, 8, 8, 7, 6, 5, 4, 3, 2, 1, 1, 2, 3, 4, 5, 6, 7, 8};
std::vector<int64_t> ind_vec{0, 0};
std::vector<float> upd_vec{5, 5, 5, 5, 2, 2, 2, 2, 7, 7, 7, 7, 4, 4, 4, 4,
1, 1, 1, 1, 6, 6, 6, 6, 3, 3, 3, 3, 8, 8, 8, 8};
auto data = mm->add_literal(migraphx::literal{ds, data_vec});
auto indices = mm->add_literal(migraphx::literal{is, ind_vec});
auto updates = mm->add_literal(migraphx::literal{us, upd_vec});
auto scatternd =
mm->add_instruction(migraphx::make_op("scatternd_max"), data, indices, updates);
mm->add_return({scatternd});
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{5, 5, 5, 5, 6, 6, 7, 8, 8, 7, 7, 7, 8, 8, 8, 8, 1, 2, 3, 4, 5, 6,
7, 8, 8, 7, 6, 5, 4, 3, 2, 1, 8, 7, 6, 5, 4, 3, 2, 1, 1, 2, 3, 4,
5, 6, 7, 8, 8, 7, 6, 5, 4, 3, 2, 1, 1, 2, 3, 4, 5, 6, 7, 8};
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/instruction.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/program.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/verify.hpp>
#include <test.hpp>
TEST_CASE(scatternd_min_test_1)
{
// r=1, q=2, k=1
migraphx::program p;
auto* mm = p.get_main_module();
auto dtype = migraphx::shape::float_type;
auto itype = migraphx::shape::int64_type;
migraphx::shape ds{dtype, {8}};
migraphx::shape is{itype, {4, 1}};
migraphx::shape us{dtype, {4}};
std::vector<float> data_vec{1, 2, 3, 4, 5, 6, 7, 8};
std::vector<int64_t> ind_vec{4, 3, 1, 7};
std::vector<float> upd_vec{9, 3, 1, 12};
auto data = mm->add_literal(migraphx::literal{ds, data_vec});
auto indices = mm->add_literal(migraphx::literal{is, ind_vec});
auto updates = mm->add_literal(migraphx::literal{us, upd_vec});
auto scatternd =
mm->add_instruction(migraphx::make_op("scatternd_min"), data, indices, updates);
mm->add_return({scatternd});
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{1, 1, 3, 3, 5, 6, 7, 8};
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
TEST_CASE(scatternd_min_test_2)
{
// r=2, q=2, k=2
migraphx::program p;
auto* mm = p.get_main_module();
auto dtype = migraphx::shape::float_type;
auto itype = migraphx::shape::int64_type;
migraphx::shape ds{dtype, {2, 2}};
migraphx::shape is{itype, {2, 2}};
migraphx::shape us{dtype, {2}};
std::vector<float> data_vec{1, 2, 3, 4};
std::vector<int64_t> ind_vec{0, 0, 0, 1};
std::vector<float> upd_vec{5, 1};
auto data = mm->add_literal(migraphx::literal{ds, data_vec});
auto indices = mm->add_literal(migraphx::literal{is, ind_vec});
auto updates = mm->add_literal(migraphx::literal{us, upd_vec});
auto scatternd =
mm->add_instruction(migraphx::make_op("scatternd_min"), data, indices, updates);
mm->add_return({scatternd});
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{1, 1, 3, 4};
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
TEST_CASE(scatternd_min_test_3)
{
// r=3, q=3, k=3
migraphx::program p;
auto* mm = p.get_main_module();
auto dtype = migraphx::shape::float_type;
auto itype = migraphx::shape::int64_type;
migraphx::shape ds{dtype, {2, 2, 2}};
migraphx::shape is{itype, {2, 1, 3}};
migraphx::shape us{dtype, {2, 1}};
std::vector<float> data_vec{1, 2, 3, 4, 5, 6, 7, 8};
std::vector<int64_t> ind_vec{0, 0, 0, 1, 1, 1};
std::vector<float> upd_vec{9, 1};
auto data = mm->add_literal(migraphx::literal{ds, data_vec});
auto indices = mm->add_literal(migraphx::literal{is, ind_vec});
auto updates = mm->add_literal(migraphx::literal{us, upd_vec});
auto scatternd =
mm->add_instruction(migraphx::make_op("scatternd_min"), data, indices, updates);
mm->add_return({scatternd});
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{1, 2, 3, 4, 5, 6, 7, 1};
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
TEST_CASE(scatternd_min_test_4)
{
// r=3, q=2, k=1
migraphx::program p;
auto* mm = p.get_main_module();
auto dtype = migraphx::shape::float_type;
auto itype = migraphx::shape::int64_type;
migraphx::shape ds{dtype, {4, 4, 4}};
migraphx::shape is{itype, {2, 1}};
migraphx::shape us{dtype, {2, 4, 4}};
std::vector<float> data_vec{1, 2, 3, 4, 5, 6, 7, 8, 8, 7, 6, 5, 4, 3, 2, 1, 1, 2, 3, 4, 5, 6,
7, 8, 8, 7, 6, 5, 4, 3, 2, 1, 8, 7, 6, 5, 4, 3, 2, 1, 1, 2, 3, 4,
5, 6, 7, 8, 8, 7, 6, 5, 4, 3, 2, 1, 1, 2, 3, 4, 5, 6, 7, 8};
std::vector<int64_t> ind_vec{0, 2};
std::vector<float> upd_vec{5, 5, 5, 5, 6, 6, 6, 6, 7, 7, 7, 7, 8, 8, 8, 8,
1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4};
auto data = mm->add_literal(migraphx::literal{ds, data_vec});
auto indices = mm->add_literal(migraphx::literal{is, ind_vec});
auto updates = mm->add_literal(migraphx::literal{us, upd_vec});
auto scatternd =
mm->add_instruction(migraphx::make_op("scatternd_min"), data, indices, updates);
mm->add_return({scatternd});
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{1, 2, 3, 4, 5, 6, 6, 6, 7, 7, 6, 5, 4, 3, 2, 1, 1, 2, 3, 4, 5, 6,
7, 8, 8, 7, 6, 5, 4, 3, 2, 1, 1, 1, 1, 1, 2, 2, 2, 1, 1, 2, 3, 3,
4, 4, 4, 4, 8, 7, 6, 5, 4, 3, 2, 1, 1, 2, 3, 4, 5, 6, 7, 8};
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
TEST_CASE(scatternd_min_test_duplicate_idx)
{
// r=3, q=2, k=1
migraphx::program p;
auto* mm = p.get_main_module();
auto dtype = migraphx::shape::float_type;
auto itype = migraphx::shape::int64_type;
migraphx::shape ds{dtype, {4, 4, 4}};
migraphx::shape is{itype, {2, 1}};
migraphx::shape us{dtype, {2, 4, 4}};
std::vector<float> data_vec{1, 2, 3, 4, 5, 6, 7, 8, 8, 7, 6, 5, 4, 3, 2, 1, 1, 2, 3, 4, 5, 6,
7, 8, 8, 7, 6, 5, 4, 3, 2, 1, 8, 7, 6, 5, 4, 3, 2, 1, 1, 2, 3, 4,
5, 6, 7, 8, 8, 7, 6, 5, 4, 3, 2, 1, 1, 2, 3, 4, 5, 6, 7, 8};
std::vector<int64_t> ind_vec{0, 0};
std::vector<float> upd_vec{5, 5, 5, 5, 2, 2, 2, 2, 7, 7, 7, 7, 4, 4, 4, 4,
1, 1, 1, 1, 6, 6, 6, 6, 3, 3, 3, 3, 8, 8, 8, 8};
auto data = mm->add_literal(migraphx::literal{ds, data_vec});
auto indices = mm->add_literal(migraphx::literal{is, ind_vec});
auto updates = mm->add_literal(migraphx::literal{us, upd_vec});
auto scatternd =
mm->add_instruction(migraphx::make_op("scatternd_min"), data, indices, updates);
mm->add_return({scatternd});
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 3, 2, 1, 1, 2, 3, 4, 5, 6,
7, 8, 8, 7, 6, 5, 4, 3, 2, 1, 8, 7, 6, 5, 4, 3, 2, 1, 1, 2, 3, 4,
5, 6, 7, 8, 8, 7, 6, 5, 4, 3, 2, 1, 1, 2, 3, 4, 5, 6, 7, 8};
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/instruction.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/onnx.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/verify.hpp>
#include <optional>
#include <test.hpp>
namespace {
migraphx::program
create_program(const migraphx::shape& data_shape, int64_t sorted, std::optional<int64_t> axis)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto data = mm->add_parameter("X", data_shape);
auto op = axis ? migraphx::make_op("unique", {{"axis", *axis}, {"sorted", sorted}})
: migraphx::make_op("unique", {{"sorted", sorted}});
auto r = mm->add_instruction(op, data);
auto r0 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), r);
auto r1 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), r);
auto r2 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 2}}), r);
auto r3 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 3}}), r);
mm->add_return({r0, r1, r2, r3});
return p;
};
template <typename T>
auto run_program(T& data,
const migraphx::shape& data_shape,
int sorted,
std::optional<int64_t> axis = std::nullopt)
{
auto p = create_program(data_shape, sorted, axis);
p.compile(migraphx::make_target("ref"));
migraphx::parameter_map pp;
pp["X"] = migraphx::argument(data_shape, data.data());
auto rets = p.eval(pp);
std::vector<typename std::remove_reference_t<decltype(data)>::value_type> y;
rets[0].visit([&](auto v) { y.assign(v.begin(), v.end()); });
std::vector<int64_t> y_idx;
rets[1].visit([&](auto v) { y_idx.assign(v.begin(), v.end()); });
std::vector<int64_t> x_rev_idx;
rets[2].visit([&](auto v) { x_rev_idx.assign(v.begin(), v.end()); });
std::vector<int64_t> y_ct;
rets[3].visit([&](auto v) { y_ct.assign(v.begin(), v.end()); });
return std::make_tuple(y, y_idx, x_rev_idx, y_ct);
}
} // namespace
// sorted. single entry
TEST_CASE(unique_sorted_single_entry_test)
{
std::vector<int> data = {2};
int64_t axis = 0;
int64_t sorted = 1;
std::vector<size_t> lens = {1};
migraphx::shape data_shape{migraphx::shape::int32_type, lens};
const auto& [y, idx, x_rev, ct] = run_program(data, data_shape, sorted, axis);
std::vector<int> gold_val = {2};
EXPECT(y == gold_val);
std::vector<int64_t> gold_y_idx = {0};
EXPECT(idx == gold_y_idx);
std::vector<int64_t> gold_x_rev = {0};
EXPECT(x_rev == gold_x_rev);
std::vector<int64_t> gold_ct = {1};
EXPECT(ct == gold_ct);
}
// unsorted. single entry
TEST_CASE(unique_unsorted_single_entry_test)
{
std::vector<float> data = {3.33};
int64_t axis = -1;
int64_t sorted = 0;
std::vector<size_t> lens = {1};
migraphx::shape data_shape{migraphx::shape::float_type, lens};
const auto& [y, idx, x_rev, ct] = run_program(data, data_shape, sorted, axis);
std::vector<float> gold_val = {3.33};
EXPECT(y == gold_val);
std::vector<int64_t> gold_y_idx = {0};
EXPECT(idx == gold_y_idx);
std::vector<int64_t> gold_x_rev = {0};
EXPECT(x_rev == gold_x_rev);
std::vector<int64_t> gold_ct = {1};
EXPECT(ct == gold_ct);
}
// case 2 sorted. all unique input..
TEST_CASE(unique_sorted_all_unique_test)
{
std::vector<float> data = {2.1, 2.3, 2.4, 2.5, 1.9};
int64_t axis = 0;
int64_t sorted = 1;
std::vector<size_t> lens = {5};
migraphx::shape data_shape{migraphx::shape::float_type, lens};
const auto& [y, idx, x_rev, ct] = run_program(data, data_shape, sorted, axis);
std::vector<float> gold_val = {1.9, 2.1, 2.3, 2.4, 2.5};
EXPECT(y == gold_val);
std::vector<int64_t> gold_y_idx = {4, 0, 1, 2, 3};
EXPECT(idx == gold_y_idx);
std::vector<int64_t> gold_x_rev = {1, 2, 3, 4, 0};
EXPECT(x_rev == gold_x_rev);
std::vector<int64_t> gold_ct = {1, 1, 1, 1, 1};
EXPECT(ct == gold_ct);
}
// case 3 unsorted. all unique input
TEST_CASE(unique_unsorted_all_unique_test)
{
std::vector<float> data = {2.1, 2.3, 2.4, 2.5, 1.9};
int64_t axis = 0;
int64_t sorted = 0;
std::vector<size_t> lens = {5};
migraphx::shape data_shape{migraphx::shape::float_type, lens};
const auto& [y, idx, x_rev, ct] = run_program(data, data_shape, sorted, axis);
std::vector<float> gold_val = {2.1, 2.3, 2.4, 2.5, 1.9};
EXPECT(y == gold_val);
std::vector<int64_t> gold_y_idx = {0, 1, 2, 3, 4};
EXPECT(idx == gold_y_idx);
std::vector<int64_t> gold_x_rev = {0, 1, 2, 3, 4};
EXPECT(x_rev == gold_x_rev);
std::vector<int64_t> gold_ct = {1, 1, 1, 1, 1};
EXPECT(ct == gold_ct);
}
// case 4 sorted (with dup entries)
TEST_CASE(unique_sorted_dupes_test)
{
std::vector<double> data = {2.1, 2.3, 2.4, 2.5, 1.9, 2.5, 2.3, 2.5};
int64_t axis = 0;
int64_t sorted = 1;
std::vector<size_t> lens = {8};
migraphx::shape data_shape{migraphx::shape::double_type, lens};
const auto& [y, idx, x_rev, ct] = run_program(data, data_shape, sorted, axis);
std::vector<double> gold_val = {1.9, 2.1, 2.3, 2.4, 2.5};
EXPECT(y == gold_val);
std::vector<int64_t> gold_ct = {1, 1, 2, 1, 3};
EXPECT(ct == gold_ct);
}
// case 5 unsorted (with dup entries)
TEST_CASE(unique_unsorted_dupes_test)
{
std::vector<float> data = {2.1, 2.3, 2.4, 2.5, 1.9, 2.5, 2.3, 2.1};
int64_t axis = -1;
int64_t sorted = 0;
std::vector<size_t> lens = {8};
migraphx::shape data_shape{migraphx::shape::float_type, lens};
const auto& [y, idx, x_rev, ct] = run_program(data, data_shape, sorted, axis);
std::vector<float> gold_val = {2.1, 2.3, 2.4, 2.5, 1.9};
EXPECT(y == gold_val);
std::vector<int64_t> gold_y_idx = {0, 1, 2, 3, 4};
EXPECT(idx == gold_y_idx);
std::vector<int64_t> gold_x_rev = {0, 1, 2, 3, 4, 3, 1, 0};
EXPECT(x_rev == gold_x_rev);
std::vector<int64_t> gold_ct = {2, 2, 1, 2, 1};
EXPECT(ct == gold_ct);
}
TEST_CASE(unique_3D_no_axis_test)
{
// sorted 3D (with dup entries). no axis
int sorted = 1;
std::vector<double> data_3d = {2.1, 2.3, 2.4, 2.5, 1.9, 2.5, 2.3, 2.5};
std::vector<size_t> lens = {2, 2, 2}; // 3D data. type double
migraphx::shape data_shape{migraphx::shape::double_type, lens};
const auto& [y, idx, x_rev, ct] = run_program(data_3d, data_shape, sorted);
std::vector<double> gold_val = {1.9, 2.1, 2.3, 2.4, 2.5};
EXPECT(y == gold_val);
std::vector<int64_t> gold_ct = {1, 1, 2, 1, 3};
EXPECT(ct == gold_ct);
}
TEST_CASE(unique_3D_no_axis_unsorted_test)
// unsorted 3D (with dup entries). no axis
{
int sorted = 0;
std::vector<float> data = {2.1, 2.3, 2.4, 2.5, 1.9, 2.5, 2.3, 2.1};
std::vector<size_t> lens = {2, 1, 4}; // 3D data. type float
migraphx::shape data_shape{migraphx::shape::float_type, lens};
const auto& [y, idx, x_rev, ct] = run_program(data, data_shape, sorted);
std::vector<float> gold_val = {2.1, 2.3, 2.4, 2.5, 1.9};
EXPECT(y == gold_val);
std::vector<int64_t> gold_y_idx = {0, 1, 2, 3, 4};
EXPECT(idx == gold_y_idx);
std::vector<int64_t> gold_x_rev = {0, 1, 2, 3, 4, 3, 1, 0};
EXPECT(x_rev == gold_x_rev);
std::vector<int64_t> gold_ct = {2, 2, 1, 2, 1};
EXPECT(ct == gold_ct);
}
// unique integer sub-tensors: sorted (with dup entries)
TEST_CASE(unique_subtensors_sorted_test)
{
/*
input_X = [[1, 0, 0], [1, 0, 0], [2, 3, 4]]
attribute_sorted = 1
attribute_axis = 0
output_Y = [[1, 0, 0], [2, 3, 4]]
output_indices = [0, 2]
output_inverse_indices = [0, 0, 1]
output_counts = [2, 1]
*/
int axis = 0;
int sorted = 1;
std::vector<int32_t> data = {1, 0, 0, 1, 0, 0, 2, 3, 4};
std::vector<size_t> lens = {3, 3};
migraphx::shape data_shape{migraphx::shape::int32_type, lens};
const auto& [y, idx, x_rev, ct] = run_program(data, data_shape, sorted, axis);
std::vector<int32_t> gold_val = {1, 0, 0, 2, 3, 4};
EXPECT(y == gold_val);
std::vector<int64_t> gold_y_idx = {0, 2};
EXPECT(idx == gold_y_idx);
std::vector<int64_t> gold_x_rev = {0, 0, 1};
EXPECT(x_rev == gold_x_rev);
std::vector<int64_t> gold_ct = {2, 1};
EXPECT(ct == gold_ct);
}
// unique integer sub-tensors: un-sorted (with dup entries)
TEST_CASE(unique_subtensors_neg_axis_test)
{
/*
input_X = [[1, 0, 0], [1, 0, 0], [2, 3, 4]]
attribute_sorted = 0
attribute_axis = 0
output_Y = [[1, 0, 0], [2, 3, 4]]
output_indices = [0, 2]
output_inverse_indices = [0, 0, 1]
output_counts = [2, 1]
*/
int axis = -2; // == 0
int sorted = 0;
std::vector<int32_t> data = {1, 0, 0, 1, 0, 0, 2, 3, 4};
std::vector<size_t> lens = {3, 3};
migraphx::shape data_shape{migraphx::shape::int32_type, lens};
const auto& [y, idx, x_rev, ct] = run_program(data, data_shape, sorted, axis);
std::vector<int32_t> gold_val = {1, 0, 0, 2, 3, 4};
EXPECT(y == gold_val);
std::vector<int64_t> gold_y_idx = {0, 2};
EXPECT(idx == gold_y_idx);
std::vector<int64_t> gold_x_rev = {0, 0, 1};
EXPECT(x_rev == gold_x_rev);
std::vector<int64_t> gold_ct = {2, 1};
EXPECT(ct == gold_ct);
}
// unique float sub-tensors: sorted (with dup entries) axis = 0
TEST_CASE(unique_subtensors_zero_axis_test)
{
/*
input_x = [[[1., 1.], [0., 1.], [2., 1.], [0., 1.]],
[[1., 1.], [0., 1.], [2., 1.], [0., 1.]]]
attribute_sorted = 1
attribute_axis = 0
*/
int axis = 0;
int sorted = 1;
std::vector<float> data = {1., 1., 0., 1., 2., 1., 0., 1., 1., 1., 0., 1., 2., 1., 0., 1.};
std::vector<size_t> lens = {2, 4, 2};
migraphx::shape data_shape{migraphx::shape::float_type, lens};
const auto& [y, idx, x_rev, ct] = run_program(data, data_shape, sorted, axis);
std::vector<float> gold_val = {1., 1., 0., 1., 2., 1., 0., 1.};
EXPECT(y == gold_val);
std::vector<int64_t> gold_y_idx = {0};
EXPECT(idx == gold_y_idx);
std::vector<int64_t> gold_x_rev = {0, 0};
EXPECT(x_rev == gold_x_rev);
std::vector<int64_t> gold_ct = {2};
EXPECT(ct == gold_ct);
}
......@@ -1017,6 +1017,40 @@ TEST_CASE(simplify_concat_add_relu_broadcast_same_axis)
EXPECT(m1 == m2);
}
TEST_CASE(concat_convert_fusion)
{
auto s = migraphx::shape{migraphx::shape::float_type, {64}};
migraphx::module m1;
{
auto x = m1.add_parameter("x", s);
auto y = m1.add_parameter("y", s);
auto xh = m1.add_instruction(
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::half_type)}}),
x);
auto yh = m1.add_instruction(
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::half_type)}}),
y);
auto concat = m1.add_instruction(migraphx::make_op("concat", {{"axis", 0}}), xh, yh);
m1.add_instruction(pass_op{}, concat);
}
run_pass(m1);
migraphx::module m2;
{
auto x = m2.add_parameter("x", s);
auto y = m2.add_parameter("y", s);
auto concat = m2.add_instruction(migraphx::make_op("concat", {{"axis", 0}}), x, y);
auto concath = m2.add_instruction(
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::half_type)}}),
concat);
m2.add_instruction(pass_op{}, concath);
}
EXPECT(m1 == m2);
}
TEST_CASE(simplify_div_const)
{
migraphx::module m1;
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct gemm_softmax_gemm_relu : verify_program<gemm_softmax_gemm_relu>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape m1_shape{migraphx::shape::half_type, {1, 12, 256, 256}};
migraphx::shape m2_shape{migraphx::shape::half_type, {1, 12, 256, 256}};
auto m2_elements = m2_shape.elements();
auto a = mm->add_parameter("1", m1_shape);
auto b = mm->add_parameter("2", m1_shape);
auto b1 = mm->add_parameter("3", m1_shape);
std::vector<float> eights(m2_elements, 0.125);
auto eight = mm->add_literal(migraphx::literal{m2_shape, eights});
std::vector<float> zeros(m2_elements, 0);
auto zero = mm->add_literal(migraphx::literal{m2_shape, zeros});
b = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), b);
auto gemm1 = mm->add_instruction(migraphx::make_op("dot"), a, b);
auto scale = mm->add_instruction(migraphx::make_op("mul"), gemm1, eight);
auto bias = mm->add_instruction(migraphx::make_op("add"), scale, zero);
auto softmax = mm->add_instruction(migraphx::make_op("softmax", {{"axis", 3}}), bias);
auto gemm2 = mm->add_instruction(migraphx::make_op("dot"), softmax, b1);
mm->add_instruction(migraphx::make_op("relu"), gemm2);
return p;
}
};
......@@ -27,14 +27,19 @@
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct test_abs : verify_program<test_abs>
template <migraphx::shape::type_t DType>
struct test_abs : verify_program<test_abs<DType>>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}});
auto x = mm->add_parameter("x", migraphx::shape{DType, {4, 3, 3, 3}});
mm->add_instruction(migraphx::make_op("abs"), x);
return p;
}
};
template struct test_abs<migraphx::shape::fp8e4m3fnuz_type>;
template struct test_abs<migraphx::shape::half_type>;
template struct test_abs<migraphx::shape::float_type>;
......@@ -27,15 +27,20 @@
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct test_acos : verify_program<test_acos>
template <migraphx::shape::type_t DType>
struct test_acos : verify_program<test_acos<DType>>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {16}};
migraphx::shape s{DType, {16}};
auto x = mm->add_parameter("x", s);
mm->add_instruction(migraphx::make_op("acos"), x);
return p;
}
};
template struct test_acos<migraphx::shape::fp8e4m3fnuz_type>;
template struct test_acos<migraphx::shape::half_type>;
template struct test_acos<migraphx::shape::float_type>;
......@@ -23,20 +23,23 @@
*/
#include "verify_program.hpp"
#include <migraphx/literal.hpp>
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct test_acosh : verify_program<test_acosh>
template <typename CType>
struct test_acosh : verify_program<test_acosh<CType>>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {16}};
auto* mm = p.get_main_module();
migraphx::shape::type_t dtype = migraphx::shape::get_type<CType>();
migraphx::shape s{dtype, {16}};
auto x = mm->add_parameter("x", s);
auto min_val = mm->add_literal(1.1f);
auto max_val = mm->add_literal(100.0f);
auto min_val = mm->add_literal(migraphx::literal{migraphx::shape{dtype}, {1.1f}});
auto max_val = mm->add_literal(migraphx::literal{migraphx::shape{dtype}, {100.0f}});
min_val =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {16}}}), min_val);
max_val =
......@@ -46,3 +49,7 @@ struct test_acosh : verify_program<test_acosh>
return p;
}
};
template struct test_acosh<float>;
template struct test_acosh<migraphx::half>;
template struct test_acosh<migraphx::fp8::fp8e4m3fnuz>;
......@@ -27,16 +27,21 @@
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct test_add : verify_program<test_add>
template <migraphx::shape::type_t DType>
struct test_add : verify_program<test_add<DType>>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {3}};
migraphx::shape s{DType, {8}};
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
mm->add_instruction(migraphx::make_op("add"), x, y);
return p;
}
};
template struct test_add<migraphx::shape::fp8e4m3fnuz_type>;
template struct test_add<migraphx::shape::half_type>;
template struct test_add<migraphx::shape::float_type>;
......@@ -27,15 +27,20 @@
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct test_asin : verify_program<test_asin>
template <migraphx::shape::type_t DType>
struct test_asin : verify_program<test_asin<DType>>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {16}};
migraphx::shape s{DType, {16}};
auto x = mm->add_parameter("x", s);
mm->add_instruction(migraphx::make_op("asin"), x);
return p;
}
};
template struct test_asin<migraphx::shape::float_type>;
template struct test_asin<migraphx::shape::half_type>;
template struct test_asin<migraphx::shape::fp8e4m3fnuz_type>;
......@@ -27,15 +27,20 @@
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct test_asinh : verify_program<test_asinh>
template <migraphx::shape::type_t DType>
struct test_asinh : verify_program<test_asinh<DType>>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {16}};
migraphx::shape s{DType, {16}};
auto x = mm->add_parameter("x", s);
mm->add_instruction(migraphx::make_op("asinh"), x);
return p;
}
};
template struct test_asinh<migraphx::shape::float_type>;
template struct test_asinh<migraphx::shape::half_type>;
template struct test_asinh<migraphx::shape::fp8e4m3fnuz_type>;
......@@ -27,15 +27,20 @@
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct test_atan : verify_program<test_atan>
template <migraphx::shape::type_t DType>
struct test_atan : verify_program<test_atan<DType>>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {16}};
migraphx::shape s{DType, {16}};
auto x = mm->add_parameter("x", s);
mm->add_instruction(migraphx::make_op("atan"), x);
return p;
}
};
template struct test_atan<migraphx::shape::float_type>;
template struct test_atan<migraphx::shape::half_type>;
template struct test_atan<migraphx::shape::fp8e4m3fnuz_type>;
......@@ -23,20 +23,24 @@
*/
#include "verify_program.hpp"
#include <migraphx/float8.hpp>
#include <migraphx/half.hpp>
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct test_atanh : verify_program<test_atanh>
template <typename CType>
struct test_atanh : verify_program<test_atanh<CType>>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {16}};
auto* mm = p.get_main_module();
migraphx::shape::type_t dtype = migraphx::shape::get_type<CType>();
migraphx::shape s{dtype, {16}};
auto x = mm->add_parameter("x", s);
auto min_val = mm->add_literal(-0.95f);
auto max_val = mm->add_literal(0.95f);
auto min_val = mm->add_literal(migraphx::literal{migraphx::shape{dtype}, {-0.95f}});
auto max_val = mm->add_literal(migraphx::literal{migraphx::shape{dtype}, {0.95f}});
min_val =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {16}}}), min_val);
max_val =
......@@ -46,3 +50,7 @@ struct test_atanh : verify_program<test_atanh>
return p;
}
};
template struct test_atanh<float>;
template struct test_atanh<migraphx::half>;
template struct test_atanh<migraphx::fp8::fp8e4m3fnuz>;
......@@ -27,16 +27,21 @@
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct test_ceil : verify_program<test_ceil>
template <migraphx::shape::type_t DType>
struct test_ceil : verify_program<test_ceil<DType>>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::double_type, {2, 3, 4, 6}};
migraphx::shape s{DType, {2, 3, 4, 6}};
auto param = mm->add_parameter("x", s);
mm->add_instruction(migraphx::make_op("ceil"), param);
return p;
};
};
template struct test_ceil<migraphx::shape::float_type>;
template struct test_ceil<migraphx::shape::half_type>;
template struct test_ceil<migraphx::shape::fp8e4m3fnuz_type>;
......@@ -27,16 +27,17 @@
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct test_concat_axis_0 : verify_program<test_concat_axis_0>
template <migraphx::shape::type_t DType>
struct test_concat_axis_0 : verify_program<test_concat_axis_0<DType>>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
int axis = 0;
migraphx::shape s0{migraphx::shape::int32_type, {2, 2}};
migraphx::shape s1{migraphx::shape::int32_type, {3, 2}};
migraphx::shape s2{migraphx::shape::int32_type, {1, 2}};
migraphx::shape s0{DType, {2, 2}};
migraphx::shape s1{DType, {3, 2}};
migraphx::shape s2{DType, {1, 2}};
auto l0 = mm->add_parameter("x", s0);
auto l1 = mm->add_parameter("y", s1);
auto l2 = mm->add_parameter("z", s2);
......@@ -44,3 +45,8 @@ struct test_concat_axis_0 : verify_program<test_concat_axis_0>
return p;
}
};
template struct test_concat_axis_0<migraphx::shape::fp8e4m3fnuz_type>;
template struct test_concat_axis_0<migraphx::shape::half_type>;
template struct test_concat_axis_0<migraphx::shape::float_type>;
template struct test_concat_axis_0<migraphx::shape::int32_type>;
......@@ -27,15 +27,20 @@
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct test_cos : verify_program<test_cos>
template <migraphx::shape::type_t DType>
struct test_cos : verify_program<test_cos<DType>>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {8}};
migraphx::shape s{DType, {8}};
auto x = mm->add_parameter("x", s);
mm->add_instruction(migraphx::make_op("cos"), x);
return p;
}
};
template struct test_cos<migraphx::shape::float_type>;
template struct test_cos<migraphx::shape::half_type>;
template struct test_cos<migraphx::shape::fp8e4m3fnuz_type>;
......@@ -27,15 +27,20 @@
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct test_cosh : verify_program<test_cosh>
template <migraphx::shape::type_t DType>
struct test_cosh : verify_program<test_cosh<DType>>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {16}};
migraphx::shape s{DType, {16}};
auto x = mm->add_parameter("x", s);
mm->add_instruction(migraphx::make_op("cosh"), x);
return p;
}
};
template struct test_cosh<migraphx::shape::float_type>;
template struct test_cosh<migraphx::shape::half_type>;
template struct test_cosh<migraphx::shape::fp8e4m3fnuz_type>;
......@@ -27,15 +27,20 @@
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct test_erf : verify_program<test_erf>
template <migraphx::shape::type_t DType>
struct test_erf : verify_program<test_erf<DType>>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {2, 3, 4, 6}};
migraphx::shape s{DType, {2, 3, 4, 6}};
auto param = mm->add_parameter("x", s);
mm->add_instruction(migraphx::make_op("erf"), param);
return p;
}
};
template struct test_erf<migraphx::shape::float_type>;
template struct test_erf<migraphx::shape::half_type>;
template struct test_erf<migraphx::shape::fp8e4m3fnuz_type>;
......@@ -27,15 +27,20 @@
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct test_exp : verify_program<test_exp>
template <migraphx::shape::type_t DType>
struct test_exp : verify_program<test_exp<DType>>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {6}};
migraphx::shape s{DType, {6}};
auto x = mm->add_instruction(migraphx::make_op("abs"), mm->add_parameter("x", s));
mm->add_instruction(migraphx::make_op("exp"), x);
return p;
}
};
template struct test_exp<migraphx::shape::float_type>;
template struct test_exp<migraphx::shape::half_type>;
template struct test_exp<migraphx::shape::fp8e4m3fnuz_type>;
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment