Commit d7dfe995 authored by Khalique Ahmed's avatar Khalique Ahmed
Browse files

Merge branch 'develop' of https://github.com/ROCmSoftwarePlatform/AMDMIGraphX into auto_contig_fix

parents c6ec6638 e3e00547
/* /*
* The MIT License (MIT) * The MIT License (MIT)
* *
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* *
* Permission is hereby granted, free of charge, to any person obtaining a copy * Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal * of this software and associated documentation files (the "Software"), to deal
...@@ -2735,67 +2735,6 @@ TEST_CASE(softsign_test) ...@@ -2735,67 +2735,6 @@ TEST_CASE(softsign_test)
EXPECT(migraphx::verify::verify_rms_range(result_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
} }
TEST_CASE(upsample_test)
{
migraphx::program p = migraphx::parse_onnx("upsample_test.onnx");
std::vector<float> x_data = {1, 2, 3, 4};
migraphx::shape sx{migraphx::shape::float_type, {1, 1, 2, 2}};
migraphx::parameter_map pp;
pp["X"] = migraphx::argument(sx, x_data.data());
auto result = p.eval(pp).back();
std::vector<float> result_vector;
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {1, 1, 1, 2, 2, 2, 1, 1, 1, 2, 2, 2,
3, 3, 3, 4, 4, 4, 3, 3, 3, 4, 4, 4};
EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}
TEST_CASE(where_test)
{
migraphx::program p = migraphx::parse_onnx("where_test.onnx");
p.compile(migraphx::make_target("ref"));
migraphx::shape c_shape{migraphx::shape::bool_type, {2}};
std::vector<int8_t> c_data = {1, 0};
migraphx::shape x_shape{migraphx::shape::float_type, {2, 2, 2}};
std::vector<float> x_data(8, 1.0f);
migraphx::shape y_shape{migraphx::shape::float_type, {2, 1, 2, 2}};
std::vector<float> y_data(8, 2.0f);
migraphx::parameter_map pp;
pp["c"] = migraphx::argument(c_shape, c_data.data());
pp["x"] = migraphx::argument(x_shape, x_data.data());
pp["y"] = migraphx::argument(y_shape, y_data.data());
auto result = p.eval(pp).back();
std::vector<float> result_vector;
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {1.0f,
2.0f,
1.0f,
2.0f,
1.0f,
2.0f,
1.0f,
2.0f,
1.0f,
2.0f,
1.0f,
2.0f,
1.0f,
2.0f,
1.0f,
2.0f};
EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}
std::vector<float> gen_trilu_test(const migraphx::shape& s, const migraphx::program& p) std::vector<float> gen_trilu_test(const migraphx::shape& s, const migraphx::program& p)
{ {
// input data filled with values 1 to nelements // input data filled with values 1 to nelements
...@@ -2921,4 +2860,131 @@ TEST_CASE(tril_row_one_test) ...@@ -2921,4 +2860,131 @@ TEST_CASE(tril_row_one_test)
EXPECT(migraphx::verify::verify_rms_range(result_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
} }
TEST_CASE(upsample_test)
{
migraphx::program p = migraphx::parse_onnx("upsample_test.onnx");
std::vector<float> x_data = {1, 2, 3, 4};
migraphx::shape sx{migraphx::shape::float_type, {1, 1, 2, 2}};
migraphx::parameter_map pp;
pp["X"] = migraphx::argument(sx, x_data.data());
auto result = p.eval(pp).back();
std::vector<float> result_vector;
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {1, 1, 1, 2, 2, 2, 1, 1, 1, 2, 2, 2,
3, 3, 3, 4, 4, 4, 3, 3, 3, 4, 4, 4};
EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}
TEST_CASE(unique_dynamic_sorted_test)
{
migraphx::program p = migraphx::parse_onnx("unique_dynamic_sorted_test.onnx");
p.compile(migraphx::make_target("ref"));
std::vector<float> x{2, 1, 1, 3, 4, 3};
std::vector<float> y_gold = {1, 2, 3, 4};
std::vector<size_t> y_idx_gold = {1, 0, 3, 4};
std::vector<size_t> x_idx_gold = {1, 0, 0, 2, 3, 2};
std::vector<size_t> y_ct_gold = {2, 1, 2, 1};
migraphx::shape s{migraphx::shape::float_type, {x.size()}};
migraphx::parameter_map pm;
pm["X"] = migraphx::argument(s, x.data());
auto result = p.eval(pm);
std::vector<float> yvec;
result[0].visit([&](auto out) { yvec.assign(out.begin(), out.end()); });
EXPECT(yvec == y_gold);
std::vector<size_t> y_idx_vec;
result[1].visit([&](auto out) { y_idx_vec.assign(out.begin(), out.end()); });
EXPECT(y_idx_vec == y_idx_gold);
std::vector<size_t> x_idx_vec;
result[2].visit([&](auto out) { x_idx_vec.assign(out.begin(), out.end()); });
EXPECT(x_idx_vec == x_idx_gold);
std::vector<size_t> y_ct_vec;
result[3].visit([&](auto out) { y_ct_vec.assign(out.begin(), out.end()); });
EXPECT(y_ct_vec == y_ct_gold);
}
TEST_CASE(unique_dynamic_unsorted_test)
{
migraphx::program p = migraphx::parse_onnx("unique_dynamic_unsorted_test.onnx");
p.compile(migraphx::make_target("ref"));
std::vector<float> x{2, 1, 1, 3, 4, 3};
std::vector<float> y_gold = {2, 1, 3, 4};
std::vector<size_t> y_idx_gold = {0, 1, 3, 4};
std::vector<size_t> x_idx_gold = {0, 1, 1, 2, 3, 2};
std::vector<size_t> y_ct_gold = {1, 2, 2, 1};
migraphx::shape s{migraphx::shape::float_type, {x.size()}};
migraphx::parameter_map pm;
pm["X"] = migraphx::argument(s, x.data());
auto result = p.eval(pm);
std::vector<float> yvec;
result[0].visit([&](auto out) { yvec.assign(out.begin(), out.end()); });
EXPECT(yvec == y_gold);
std::vector<size_t> y_idx_vec;
result[1].visit([&](auto out) { y_idx_vec.assign(out.begin(), out.end()); });
EXPECT(y_idx_vec == y_idx_gold);
std::vector<size_t> x_idx_vec;
result[2].visit([&](auto out) { x_idx_vec.assign(out.begin(), out.end()); });
EXPECT(x_idx_vec == x_idx_gold);
std::vector<size_t> y_ct_vec;
result[3].visit([&](auto out) { y_ct_vec.assign(out.begin(), out.end()); });
EXPECT(y_ct_vec == y_ct_gold);
}
TEST_CASE(where_test)
{
migraphx::program p = migraphx::parse_onnx("where_test.onnx");
p.compile(migraphx::make_target("ref"));
migraphx::shape c_shape{migraphx::shape::bool_type, {2}};
std::vector<int8_t> c_data = {1, 0};
migraphx::shape x_shape{migraphx::shape::float_type, {2, 2, 2}};
std::vector<float> x_data(8, 1.0f);
migraphx::shape y_shape{migraphx::shape::float_type, {2, 1, 2, 2}};
std::vector<float> y_data(8, 2.0f);
migraphx::parameter_map pp;
pp["c"] = migraphx::argument(c_shape, c_data.data());
pp["x"] = migraphx::argument(x_shape, x_data.data());
pp["y"] = migraphx::argument(y_shape, y_data.data());
auto result = p.eval(pp).back();
std::vector<float> result_vector;
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {1.0f,
2.0f,
1.0f,
2.0f,
1.0f,
2.0f,
1.0f,
2.0f,
1.0f,
2.0f,
1.0f,
2.0f,
1.0f,
2.0f,
1.0f,
2.0f};
EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -4166,6 +4166,40 @@ TEST_CASE(test_squeeze_wrong_axis) ...@@ -4166,6 +4166,40 @@ TEST_CASE(test_squeeze_wrong_axis)
throws_shape(migraphx::make_op("squeeze", {{"axes", {0}}}), s1); throws_shape(migraphx::make_op("squeeze", {{"axes", {0}}}), s1);
} }
TEST_CASE(test_unique_axis_invalid)
{
migraphx::shape x_shape{migraphx::shape::float_type, {10, 4, 3}};
throws_shape(migraphx::make_op("unique", {{"axis", -1}}), x_shape);
}
TEST_CASE(test_unique_axis_negative)
{
migraphx::shape x_shape{migraphx::shape::float_type, {10, 4, 3}};
std::vector<migraphx::shape::dynamic_dimension> y_dims{{1, 10}, {4, 4}, {3, 3}};
std::vector<migraphx::shape::dynamic_dimension> idx_dims{{1, 10}};
std::vector<migraphx::shape> y_dyn_shape{{migraphx::shape::float_type, y_dims},
{migraphx::shape::int64_type, idx_dims},
{migraphx::shape::int64_type, idx_dims},
{migraphx::shape::int64_type, idx_dims}};
expect_shape(y_dyn_shape, migraphx::make_op("unique", {{"axis", -3}}), x_shape);
}
TEST_CASE(test_unique_axis_none)
{
migraphx::shape x_shape{migraphx::shape::half_type, {10, 4, 3}};
std::vector<migraphx::shape::dynamic_dimension> y_dims{{1, 120}};
std::vector<migraphx::shape::dynamic_dimension> idx_dims{{1, 120}};
std::vector<migraphx::shape> y_dyn_shape{{migraphx::shape::half_type, y_dims},
{migraphx::shape::int64_type, idx_dims},
{migraphx::shape::int64_type, idx_dims},
{migraphx::shape::int64_type, idx_dims}};
expect_shape(y_dyn_shape, migraphx::make_op("unique"), x_shape);
}
TEST_CASE(test_unsqueeze) TEST_CASE(test_unsqueeze)
{ {
migraphx::shape s1{migraphx::shape::float_type, {4, 5, 3}}; migraphx::shape s1{migraphx::shape::float_type, {4, 5, 3}};
......
This diff is collapsed.
...@@ -1017,6 +1017,40 @@ TEST_CASE(simplify_concat_add_relu_broadcast_same_axis) ...@@ -1017,6 +1017,40 @@ TEST_CASE(simplify_concat_add_relu_broadcast_same_axis)
EXPECT(m1 == m2); EXPECT(m1 == m2);
} }
TEST_CASE(concat_convert_fusion)
{
auto s = migraphx::shape{migraphx::shape::float_type, {64}};
migraphx::module m1;
{
auto x = m1.add_parameter("x", s);
auto y = m1.add_parameter("y", s);
auto xh = m1.add_instruction(
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::half_type)}}),
x);
auto yh = m1.add_instruction(
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::half_type)}}),
y);
auto concat = m1.add_instruction(migraphx::make_op("concat", {{"axis", 0}}), xh, yh);
m1.add_instruction(pass_op{}, concat);
}
run_pass(m1);
migraphx::module m2;
{
auto x = m2.add_parameter("x", s);
auto y = m2.add_parameter("y", s);
auto concat = m2.add_instruction(migraphx::make_op("concat", {{"axis", 0}}), x, y);
auto concath = m2.add_instruction(
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::half_type)}}),
concat);
m2.add_instruction(pass_op{}, concath);
}
EXPECT(m1 == m2);
}
TEST_CASE(simplify_div_const) TEST_CASE(simplify_div_const)
{ {
migraphx::module m1; migraphx::module m1;
......
...@@ -21,23 +21,36 @@ ...@@ -21,23 +21,36 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE. * THE SOFTWARE.
*/ */
#include <limits>
#include "verify_program.hpp" #include "verify_program.hpp"
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/half.hpp>
struct test_isnan_half : verify_program<test_isnan_half> struct gemm_softmax_gemm_relu : verify_program<gemm_softmax_gemm_relu>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::half_type, {2}}); migraphx::shape m1_shape{migraphx::shape::half_type, {1, 12, 256, 256}};
auto l0 = mm->add_literal(std::numeric_limits<migraphx::half>::quiet_NaN()); migraphx::shape m2_shape{migraphx::shape::half_type, {1, 12, 256, 256}};
x = mm->add_instruction(migraphx::make_op("concat", {{"axis", 0}}), x, l0); auto m2_elements = m2_shape.elements();
mm->add_instruction(migraphx::make_op("isnan"), x); auto a = mm->add_parameter("1", m1_shape);
auto b = mm->add_parameter("2", m1_shape);
auto b1 = mm->add_parameter("3", m1_shape);
std::vector<float> eights(m2_elements, 0.125);
auto eight = mm->add_literal(migraphx::literal{m2_shape, eights});
std::vector<float> zeros(m2_elements, 0);
auto zero = mm->add_literal(migraphx::literal{m2_shape, zeros});
b = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), b);
auto gemm1 = mm->add_instruction(migraphx::make_op("dot"), a, b);
auto scale = mm->add_instruction(migraphx::make_op("mul"), gemm1, eight);
auto bias = mm->add_instruction(migraphx::make_op("add"), scale, zero);
auto softmax = mm->add_instruction(migraphx::make_op("softmax", {{"axis", 3}}), bias);
auto gemm2 = mm->add_instruction(migraphx::make_op("dot"), softmax, b1);
mm->add_instruction(migraphx::make_op("relu"), gemm2);
return p; return p;
} }
}; };
...@@ -27,14 +27,19 @@ ...@@ -27,14 +27,19 @@
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
struct test_abs : verify_program<test_abs> template <migraphx::shape::type_t DType>
struct test_abs : verify_program<test_abs<DType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}}); auto x = mm->add_parameter("x", migraphx::shape{DType, {4, 3, 3, 3}});
mm->add_instruction(migraphx::make_op("abs"), x); mm->add_instruction(migraphx::make_op("abs"), x);
return p; return p;
} }
}; };
template struct test_abs<migraphx::shape::fp8e4m3fnuz_type>;
template struct test_abs<migraphx::shape::half_type>;
template struct test_abs<migraphx::shape::float_type>;
...@@ -27,15 +27,20 @@ ...@@ -27,15 +27,20 @@
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
struct test_acos : verify_program<test_acos> template <migraphx::shape::type_t DType>
struct test_acos : verify_program<test_acos<DType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {16}}; migraphx::shape s{DType, {16}};
auto x = mm->add_parameter("x", s); auto x = mm->add_parameter("x", s);
mm->add_instruction(migraphx::make_op("acos"), x); mm->add_instruction(migraphx::make_op("acos"), x);
return p; return p;
} }
}; };
template struct test_acos<migraphx::shape::fp8e4m3fnuz_type>;
template struct test_acos<migraphx::shape::half_type>;
template struct test_acos<migraphx::shape::float_type>;
...@@ -23,20 +23,23 @@ ...@@ -23,20 +23,23 @@
*/ */
#include "verify_program.hpp" #include "verify_program.hpp"
#include <migraphx/literal.hpp>
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
struct test_acosh : verify_program<test_acosh> template <typename CType>
struct test_acosh : verify_program<test_acosh<CType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {16}}; migraphx::shape::type_t dtype = migraphx::shape::get_type<CType>();
migraphx::shape s{dtype, {16}};
auto x = mm->add_parameter("x", s); auto x = mm->add_parameter("x", s);
auto min_val = mm->add_literal(1.1f); auto min_val = mm->add_literal(migraphx::literal{migraphx::shape{dtype}, {1.1f}});
auto max_val = mm->add_literal(100.0f); auto max_val = mm->add_literal(migraphx::literal{migraphx::shape{dtype}, {100.0f}});
min_val = min_val =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {16}}}), min_val); mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {16}}}), min_val);
max_val = max_val =
...@@ -46,3 +49,7 @@ struct test_acosh : verify_program<test_acosh> ...@@ -46,3 +49,7 @@ struct test_acosh : verify_program<test_acosh>
return p; return p;
} }
}; };
template struct test_acosh<float>;
template struct test_acosh<migraphx::half>;
template struct test_acosh<migraphx::fp8::fp8e4m3fnuz>;
...@@ -27,16 +27,21 @@ ...@@ -27,16 +27,21 @@
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
struct test_add : verify_program<test_add> template <migraphx::shape::type_t DType>
struct test_add : verify_program<test_add<DType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {3}}; migraphx::shape s{DType, {8}};
auto x = mm->add_parameter("x", s); auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s); auto y = mm->add_parameter("y", s);
mm->add_instruction(migraphx::make_op("add"), x, y); mm->add_instruction(migraphx::make_op("add"), x, y);
return p; return p;
} }
}; };
template struct test_add<migraphx::shape::fp8e4m3fnuz_type>;
template struct test_add<migraphx::shape::half_type>;
template struct test_add<migraphx::shape::float_type>;
...@@ -27,15 +27,20 @@ ...@@ -27,15 +27,20 @@
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
struct test_asin : verify_program<test_asin> template <migraphx::shape::type_t DType>
struct test_asin : verify_program<test_asin<DType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {16}}; migraphx::shape s{DType, {16}};
auto x = mm->add_parameter("x", s); auto x = mm->add_parameter("x", s);
mm->add_instruction(migraphx::make_op("asin"), x); mm->add_instruction(migraphx::make_op("asin"), x);
return p; return p;
} }
}; };
template struct test_asin<migraphx::shape::float_type>;
template struct test_asin<migraphx::shape::half_type>;
template struct test_asin<migraphx::shape::fp8e4m3fnuz_type>;
...@@ -27,15 +27,20 @@ ...@@ -27,15 +27,20 @@
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
struct test_asinh : verify_program<test_asinh> template <migraphx::shape::type_t DType>
struct test_asinh : verify_program<test_asinh<DType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {16}}; migraphx::shape s{DType, {16}};
auto x = mm->add_parameter("x", s); auto x = mm->add_parameter("x", s);
mm->add_instruction(migraphx::make_op("asinh"), x); mm->add_instruction(migraphx::make_op("asinh"), x);
return p; return p;
} }
}; };
template struct test_asinh<migraphx::shape::float_type>;
template struct test_asinh<migraphx::shape::half_type>;
template struct test_asinh<migraphx::shape::fp8e4m3fnuz_type>;
...@@ -27,15 +27,20 @@ ...@@ -27,15 +27,20 @@
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
struct test_atan : verify_program<test_atan> template <migraphx::shape::type_t DType>
struct test_atan : verify_program<test_atan<DType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {16}}; migraphx::shape s{DType, {16}};
auto x = mm->add_parameter("x", s); auto x = mm->add_parameter("x", s);
mm->add_instruction(migraphx::make_op("atan"), x); mm->add_instruction(migraphx::make_op("atan"), x);
return p; return p;
} }
}; };
template struct test_atan<migraphx::shape::float_type>;
template struct test_atan<migraphx::shape::half_type>;
template struct test_atan<migraphx::shape::fp8e4m3fnuz_type>;
...@@ -23,20 +23,24 @@ ...@@ -23,20 +23,24 @@
*/ */
#include "verify_program.hpp" #include "verify_program.hpp"
#include <migraphx/float8.hpp>
#include <migraphx/half.hpp>
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
struct test_atanh : verify_program<test_atanh> template <typename CType>
struct test_atanh : verify_program<test_atanh<CType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {16}}; migraphx::shape::type_t dtype = migraphx::shape::get_type<CType>();
migraphx::shape s{dtype, {16}};
auto x = mm->add_parameter("x", s); auto x = mm->add_parameter("x", s);
auto min_val = mm->add_literal(-0.95f); auto min_val = mm->add_literal(migraphx::literal{migraphx::shape{dtype}, {-0.95f}});
auto max_val = mm->add_literal(0.95f); auto max_val = mm->add_literal(migraphx::literal{migraphx::shape{dtype}, {0.95f}});
min_val = min_val =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {16}}}), min_val); mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {16}}}), min_val);
max_val = max_val =
...@@ -46,3 +50,7 @@ struct test_atanh : verify_program<test_atanh> ...@@ -46,3 +50,7 @@ struct test_atanh : verify_program<test_atanh>
return p; return p;
} }
}; };
template struct test_atanh<float>;
template struct test_atanh<migraphx::half>;
template struct test_atanh<migraphx::fp8::fp8e4m3fnuz>;
...@@ -27,16 +27,21 @@ ...@@ -27,16 +27,21 @@
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
struct test_ceil : verify_program<test_ceil> template <migraphx::shape::type_t DType>
struct test_ceil : verify_program<test_ceil<DType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::double_type, {2, 3, 4, 6}}; migraphx::shape s{DType, {2, 3, 4, 6}};
auto param = mm->add_parameter("x", s); auto param = mm->add_parameter("x", s);
mm->add_instruction(migraphx::make_op("ceil"), param); mm->add_instruction(migraphx::make_op("ceil"), param);
return p; return p;
}; };
}; };
template struct test_ceil<migraphx::shape::float_type>;
template struct test_ceil<migraphx::shape::half_type>;
template struct test_ceil<migraphx::shape::fp8e4m3fnuz_type>;
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment