Commit 8ff4b151 authored by Ted Themistokleous's avatar Ted Themistokleous
Browse files

Split mod operation into fmod & mod equivalents

Since onnx's Mod operation changes behavior based on whether the fmod flag is set, functionality is now split to mirror python's fmod() functionality.

For the integer mod case, I had to use a componsition of std::fmod() so that floating and integral types are all handled while also perserving sign to be identital to the python numpy::mod() case.
parent 76035525
......@@ -133,6 +133,7 @@ register_migraphx_ops(
exp
flatten
floor
fmod
gather
gathernd
get_tuple_elem
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_OPERATORS_FMOD_HPP
#define MIGRAPHX_GUARD_OPERATORS_FMOD_HPP
#include <array>
#include <migraphx/op/binary.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <cmath>
#include <utility>
#include <type_traits>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct fmod : binary<fmod>
{
bool fmod_flag;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.fmod_flag, "fmod_flag"));
}
value attributes() const
{
auto a = base_attributes();
a["fmod_flag"] = fmod_flag;
return a;
}
std::string point_function() const { return "fmod(${0}, ${1})"; }
auto apply() const
{
return [&](auto x, auto y) { return std::fmod(x, y); };
}
fmod(bool fmod = true) : fmod_flag{fmod} {}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
......@@ -21,8 +21,8 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_OPERATORS_MUL_HPP
#define MIGRAPHX_GUARD_OPERATORS_MUL_HPP
#ifndef MIGRAPHX_GUARD_OPERATORS_MOD_HPP
#define MIGRAPHX_GUARD_OPERATORS_MOD_HPP
#include <array>
#include <migraphx/op/binary.hpp>
......@@ -40,30 +40,6 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
template <typename T>
T mod_op(T x, T y)
{
return (x % y);
}
template <>
float mod_op<float>(float x, float y)
{
return std::fmod(x, y);
}
template <>
double mod_op<double>(double x, double y)
{
return std::fmod(x, y);
}
template <>
half_float::half mod_op<half_float::half>(half_float::half x, half_float::half y)
{
return half_float::fmod(x, y);
}
struct mod : binary<mod>
{
bool fmod_flag;
......@@ -81,41 +57,10 @@ struct mod : binary<mod>
return a;
}
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, (*this)}.has(2).same_type().same_dims();
auto s0 = inputs.at(0);
auto s1 = inputs.at(1);
if((s0.type() == shape::float_type || s0.type() == shape::double_type ||
s0.type() == shape::half_type) &&
(fmod_flag == false))
{
MIGRAPHX_THROW("fmod must be true for floating data types");
}
if(s0 == s1 and s0.packed())
{
return s0;
}
else if(s0.packed() != s1.packed())
{
return s0.packed() ? s0 : s1;
}
else if(s0.broadcasted() != s1.broadcasted())
{
return s0.broadcasted() ? s1.with_lens(s0.lens()) : s0.with_lens(s0.lens());
}
else
{
return {s0.type(), s0.lens()};
}
}
std::string point_function() const { return "mod"; }
auto apply() const
{
return [&](auto x, auto y) { return mod_op<decltype(x)>(x, y); };
return [&](auto x, auto y) { return std::fmod((std::fmod(x, y) + y), y); };
}
mod(bool fmod = false) : fmod_flag{fmod} {}
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
struct parse_mod : op_parser<parse_mod>
{
std::vector<op_desc> operators() const { return {{"Mod"}}; }
instruction_ref parse(const op_desc& /*opd*/,
const onnx_parser& parser,
onnx_parser::node_info info,
std::vector<instruction_ref> args) const
{
int fmod_flag = 0;
if(contains(info.attributes, "fmod"))
{
fmod_flag = parser.parse_value(info.attributes.at("fmod")).at<int>();
}
if(fmod_flag == 1)
{
return info.add_common_op("fmod", args[0], args[1]);
}
else
{
return info.add_common_op("mod", args[0], args[1]);
}
}
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......@@ -3030,21 +3030,58 @@ TEST_CASE(min_test)
EXPECT(migraphx::verify_range(results_vector, gold));
}
TEST_CASE(mod_test)
TEST_CASE(fmod_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::int32_type, {3}};
auto l0 = mm->add_literal(migraphx::literal{s, {-7, 8, 3}});
auto l0 = mm->add_literal(migraphx::literal{s, {-7, 8, -3}});
auto l1 = mm->add_literal(migraphx::literal{s, {2, 4, 6}});
auto l2 = mm->add_literal(migraphx::literal{s, {7, 5, 9}});
auto curr_mod = mm->add_instruction(migraphx::make_op("fmod"), l0, l1);
mm->add_instruction(migraphx::make_op("fmod"), curr_mod, l2);
p.compile(migraphx::ref::target{});
auto result = p.eval({}).back();
std::vector<float> results_vector(4);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{-1, 0, -3};
EXPECT(migraphx::verify_range(results_vector, gold));
}
TEST_CASE(fmod_floatingPoint_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {3}};
auto l0 = mm->add_literal(migraphx::literal{s, {-7.2f, 8.5f, -3.3f}});
auto l1 = mm->add_literal(migraphx::literal{s, {2.0f, 4.0f, 6.0f}});
auto l2 = mm->add_literal(migraphx::literal{s, {7.0f, 5.0f, 9.0f}});
auto curr_mod = mm->add_instruction(migraphx::make_op("fmod"), l0, l1);
mm->add_instruction(migraphx::make_op("fmod"), curr_mod, l2);
p.compile(migraphx::ref::target{});
auto result = p.eval({}).back();
std::vector<float> results_vector(4);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{-1.2f, 0.5f, -3.3f};
EXPECT(migraphx::verify_range(results_vector, gold));
}
TEST_CASE(mod_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::int32_type, {3}};
auto l0 = mm->add_literal(migraphx::literal{s, {-3, 8, -7}});
auto l1 = mm->add_literal(migraphx::literal{s, {3, 3, 3}});
auto l2 = mm->add_literal(migraphx::literal{s, {10, 2, 9}});
auto curr_mod = mm->add_instruction(migraphx::make_op("mod"), l0, l1);
mm->add_instruction(migraphx::make_op("mod"), curr_mod, l2);
p.compile(migraphx::ref::target{});
auto result = p.eval({}).back();
std::vector<float> results_vector(4);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{-1, 0, 3};
std::vector<float> gold{0, 0, 2};
EXPECT(migraphx::verify_range(results_vector, gold));
}
......@@ -3053,17 +3090,17 @@ TEST_CASE(mod_floatingPoint_test)
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {3}};
auto l0 = mm->add_literal(migraphx::literal{s, {7.2f, 8.5f, 3.3f}});
auto l1 = mm->add_literal(migraphx::literal{s, {2.0f, 4.0f, 6.0f}});
auto l2 = mm->add_literal(migraphx::literal{s, {7.0f, 5.0f, 9.0f}});
auto curr_mod = mm->add_instruction(migraphx::make_op("mod", {{"fmod_flag", true}}), l0, l1);
mm->add_instruction(migraphx::make_op("mod", {{"fmod_flag", true}}), curr_mod, l2);
auto l0 = mm->add_literal(migraphx::literal{s, {-3.0f, 8.5f, -7.0f}});
auto l1 = mm->add_literal(migraphx::literal{s, {2.0f, 3.0f, 3.0f}});
auto l2 = mm->add_literal(migraphx::literal{s, {3.0f, 3.0f, 4.0f}});
auto curr_mod = mm->add_instruction(migraphx::make_op("mod"), l0, l1);
mm->add_instruction(migraphx::make_op("mod"), curr_mod, l2);
p.compile(migraphx::ref::target{});
auto result = p.eval({}).back();
std::vector<float> results_vector(4);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{1.2f, 0.5f, 3.3f};
std::vector<float> gold{1.0f, 2.5f, 2.0f};
EXPECT(migraphx::verify_range(results_vector, gold));
}
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct test_fmod : verify_program<test_fmod>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {3}};
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
mm->add_instruction(migraphx::make_op("fmod"), x, y);
return p;
}
};
......@@ -36,7 +36,7 @@ struct test_mod : verify_program<test_mod>
migraphx::shape s{migraphx::shape::float_type, {3}};
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
mm->add_instruction(migraphx::make_op("mod", {{"fmod_flag", true}}), x, y);
mm->add_instruction(migraphx::make_op("mod"), x, y);
return p;
}
};
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment