"...include/git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "e76bd7293eb27828cab07c35395d898d7cec8eeb"
Commit f8a12015 authored by Ted Themistokleous

work in progress, seems to break my tests and parsing

parent 8ff4b151
@@ -42,28 +42,11 @@ namespace op {
 struct fmod : binary<fmod>
 {
-    bool fmod_flag;
-    template <class Self, class F>
-    static auto reflect(Self& self, F f)
-    {
-        return pack(f(self.fmod_flag, "fmod_flag"));
-    }
-    value attributes() const
-    {
-        auto a = base_attributes();
-        a["fmod_flag"] = fmod_flag;
-        return a;
-    }
-    std::string point_function() const { return "fmod(${0}, ${1})"; }
+    std::string point_function() const { return "fmod"; }
     auto apply() const
     {
-        return [&](auto x, auto y) { return std::fmod(x, y); };
+        return [](auto x, auto y) { return std::fmod(x, y); };
     }
-    fmod(bool fmod = true) : fmod_flag{fmod} {}
 };
 } // namespace op
......
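For reference, the simplified fmod operator forwards directly to std::fmod, whose result keeps the sign of the dividend. A minimal standalone sketch (not part of this commit) of the semantics the lambda relies on:

    // fmod_semantics.cpp -- illustration only
    #include <cassert>
    #include <cmath>

    int main()
    {
        assert(std::fmod(3.0, 2.0) == 1.0);
        assert(std::fmod(-3.0, 2.0) == -1.0); // sign follows the dividend
        assert(std::fmod(-9.0, 5.0) == -4.0);
        return 0;
    }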
@@ -42,28 +42,11 @@ namespace op {
 struct mod : binary<mod>
 {
-    bool fmod_flag;
-    template <class Self, class F>
-    static auto reflect(Self& self, F f)
-    {
-        return pack(f(self.fmod_flag, "fmod_flag"));
-    }
-    value attributes() const
-    {
-        auto a = base_attributes();
-        a["fmod_flag"] = fmod_flag;
-        return a;
-    }
     std::string point_function() const { return "mod"; }
     auto apply() const
     {
-        return [&](auto x, auto y) { return std::fmod((std::fmod(x, y) + y), y); };
+        return [](auto x, auto y) { return std::fmod((std::fmod(x, y) + y), y); };
     }
-    mod(bool fmod = false) : fmod_flag{fmod} {}
 };
 } // namespace op
......
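The mod operator applies std::fmod twice so that the C-style remainder is shifted into the divisor's sign, i.e. Python-style modulo for the sign combinations the tests exercise. A standalone sketch (not part of this commit) of that double-fmod trick:

    // mod_semantics.cpp -- illustration only
    #include <cassert>
    #include <cmath>

    static double floored_mod(double x, double y)
    {
        // same expression as the operator's lambda
        return std::fmod(std::fmod(x, y) + y, y);
    }

    int main()
    {
        assert(floored_mod(-3.0, 2.0) == 1.0); // std::fmod alone gives -1
        assert(floored_mod(-9.0, 5.0) == 1.0); // -4 + 5 = 1, fmod(1, 5) = 1
        assert(floored_mod(9.0, 5.0) == 4.0);
        return 0;
    }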
@@ -57,6 +57,7 @@
 #include <migraphx/op/exp.hpp>
 #include <migraphx/op/flatten.hpp>
 #include <migraphx/op/floor.hpp>
+#include <migraphx/op/fmod.hpp>
 #include <migraphx/op/gather.hpp>
 #include <migraphx/op/gathernd.hpp>
 #include <migraphx/op/get_tuple_elem.hpp>
@@ -79,6 +80,7 @@
 #include <migraphx/op/lstm.hpp>
 #include <migraphx/op/max.hpp>
 #include <migraphx/op/min.hpp>
+#include <migraphx/op/mod.hpp>
 #include <migraphx/op/mul.hpp>
 #include <migraphx/op/multibroadcast.hpp>
 #include <migraphx/op/neg.hpp>
......
@@ -39,21 +39,18 @@ struct parse_mod : op_parser<parse_mod>
                           onnx_parser::node_info info,
                           std::vector<instruction_ref> args) const
     {
-        int fmod_flag = 0;
+        if(args.size() < 2)
+            MIGRAPHX_THROW("mod operators should have 2 operands");
+        std::string mod = "mod";
         if(contains(info.attributes, "fmod"))
         {
-            fmod_flag = parser.parse_value(info.attributes.at("fmod")).at<int>();
-        }
-
-        if(fmod_flag == 1)
-        {
-            return info.add_common_op("fmod", args[0], args[1]);
-        }
-        else
-        {
-            return info.add_common_op("mod", args[0], args[1]);
+            if(parser.parse_value(info.attributes.at("fmod")).at<int>() == 1)
+            {
+                mod = "fmod";
+            }
         }
+        return info.add_common_op(mod, args[0], args[1]);
     }
 };
......
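The parser maps the ONNX Mod node's optional fmod attribute onto the two operators: attribute absent or 0 selects "mod", 1 selects "fmod". A hedged sketch of that dispatch as a standalone helper (the helper name is hypothetical, not part of the commit):

    // select_mod_op.cpp -- illustration of the attribute-to-operator mapping
    #include <cassert>
    #include <string>

    static std::string select_mod_op(bool has_fmod_attr, int fmod_value)
    {
        // fmod absent or 0 -> integer-style "mod"; fmod == 1 -> C-style "fmod"
        return (has_fmod_attr && fmod_value == 1) ? "fmod" : "mod";
    }

    int main()
    {
        assert(select_mod_op(false, 0) == "mod");
        assert(select_mod_op(true, 0) == "mod");
        assert(select_mod_op(true, 1) == "fmod");
        return 0;
    }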
@@ -138,6 +138,8 @@ struct pointwise_compiler : compiler<pointwise_compiler>
         g.add_point_op("less", "migraphx::abs(${0} < ${1})");
         g.add_point_op("greater", "migraphx::abs(${0} > ${1})");
         g.add_point_op("not", "migraphx::abs(not ${0})");
+        g.add_point_op("mod", "migraphx::mod(${0}, ${1})");
+        g.add_point_op("fmod", "migraphx::fmod(${0}, ${1})");
         // Add explicit conversions
         g.fresult([](const shape& s) {
             return "migraphx::convert<" + shape::cpp_type(s.type()) + ">";
......
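add_point_op registers a template string for the pointwise code generator; the ${0} and ${1} placeholders stand for the instruction's operand expressions, so after substitution the new entries expand to calls like migraphx::mod(x0, x1). A minimal sketch of that placeholder expansion (the substitute helper below is hypothetical, not MIGraphX API):

    // point_op_template.cpp -- illustration of ${N} placeholder expansion
    #include <cassert>
    #include <string>

    static std::string substitute(std::string tmpl, const std::string& arg0, const std::string& arg1)
    {
        auto replace_all = [&](const std::string& from, const std::string& to) {
            for(auto pos = tmpl.find(from); pos != std::string::npos; pos = tmpl.find(from))
                tmpl.replace(pos, from.size(), to);
        };
        replace_all("${0}", arg0);
        replace_all("${1}", arg1);
        return tmpl;
    }

    int main()
    {
        assert(substitute("migraphx::mod(${0}, ${1})", "x0", "x1") == "migraphx::mod(x0, x1)");
        assert(substitute("migraphx::fmod(${0}, ${1})", "x0", "x1") == "migraphx::fmod(x0, x1)");
        return 0;
    }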
@@ -5,7 +5,7 @@
 #
 # Permission is hereby granted, free of charge, to any person obtaining a copy
 # of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
+# in the Software without rest?mod_riction, including without limitation the rights
 # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 # copies of the Software, and to permit persons to whom the Software is
 # furnished to do so, subject to the following conditions:
@@ -3231,6 +3231,33 @@ def min_test():
     return ([node], [a, b, c], [y])
 
 
+@onnx_test
+def mod_test():
+    a = helper.make_tensor_value_info('0', TensorProto.FLOAT, [2])
+    b = helper.make_tensor_value_info('1', TensorProto.FLOAT, [2])
+    y = helper.make_tensor_value_info('2', TensorProto.FLOAT, [2])
+
+    node = onnx.helper.make_node('Mod', inputs=['0', '1'], outputs=['2'])
+
+    return ([node], [a, b], [y])
+
+
+@onnx_test
+def mod_test_fmod():
+    a = helper.make_tensor_value_info('0', TensorProto.FLOAT, [2])
+    b = helper.make_tensor_value_info('1', TensorProto.FLOAT, [2])
+    y = helper.make_tensor_value_info('2', TensorProto.FLOAT, [2])
+
+    node = onnx.helper.make_node(
+        'Mod',
+        inputs=['0', '1'],
+        outputs=['2'],
+        fmod=1  # fmod flag = 1
+    )
+
+    return ([node], [a, b], [y])
+
+
 @onnx_test
 def multinomial_test():
     sample_size = 10
......
[New binary files mod_test.onnx and mod_test_fmod.onnx: serialized ONNX protobufs, each containing a single Mod node over float inputs '0' and '1' with output '2'; the mod_test_fmod model additionally carries the fmod attribute.]
@@ -631,6 +631,55 @@ TEST_CASE(mean_integral_test)
     EXPECT(migraphx::verify_range(result_vector, gold));
 }
 
+TEST_CASE(mod_test)
+{
+    migraphx::program p = migraphx::parse_onnx("mod_test.onnx");
+    p.compile(migraphx::ref::target{});
+
+    // mod_test.onnx takes two float inputs named "0" and "1", each of shape {2}
+    migraphx::shape s{migraphx::shape::float_type, {2}};
+    std::vector<float> dividend = {-9.0, 6.0};
+    std::vector<float> divisor  = {5.0, 9.0};
+
+    migraphx::parameter_map p_map;
+    p_map["0"] = migraphx::argument(s, dividend.data());
+    p_map["1"] = migraphx::argument(s, divisor.data());
+
+    auto result = p.eval(p_map).back();
+    std::vector<float> result_vector;
+    result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
+
+    // Integer-style mod takes the sign of the divisor: -9 mod 5 == 1
+    std::vector<float> gold{1.0, 6.0};
+    EXPECT(migraphx::verify_range(result_vector, gold));
+}
+
+TEST_CASE(mod_test_fmod)
+{
+    migraphx::program p = migraphx::parse_onnx("mod_test_fmod.onnx");
+    p.compile(migraphx::ref::target{});
+
+    migraphx::shape s{migraphx::shape::float_type, {2}};
+    std::vector<float> dividend = {-9.0, 6.0};
+    std::vector<float> divisor  = {5.0, 9.0};
+
+    migraphx::parameter_map p_map;
+    p_map["0"] = migraphx::argument(s, dividend.data());
+    p_map["1"] = migraphx::argument(s, divisor.data());
+
+    auto result = p.eval(p_map).back();
+    std::vector<float> result_vector;
+    result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
+
+    // fmod takes the sign of the dividend: std::fmod(-9, 5) == -4
+    std::vector<float> gold{-4.0, 6.0};
+    EXPECT(migraphx::verify_range(result_vector, gold));
+}
+
 TEST_CASE(nonzero_test)
 {
     migraphx::program p = migraphx::parse_onnx("nonzero_dynamic_test.onnx");
......
@@ -162,6 +162,8 @@ def create_backend_test(testname=None, target_device=None):
         backend_test.include(r'.*test_MaxPool[1-9]d.*')
         backend_test.include(r'.*test_mean.*')
         backend_test.include(r'.*test_min.*')
+        backend_test.include(r'.*test_mod.*')
+        backend_test.include(r'.*test_fmod.*')
         backend_test.include(r'.*test_mul.*')
         backend_test.include(r'.*test_multinomial.*')
         backend_test.include(r'.*test_Multinomial.*')
......