Commit 6d8d7d41 authored by Paul's avatar Paul
Browse files

Add equality operator

parent 68d69739
......@@ -73,8 +73,6 @@ struct program
argument eval(parameter_map params) const;
friend std::ostream& operator<<(std::ostream& os, const program& p);
bool has_instruction(instruction_ref ins) const;
instruction_ref begin();
......@@ -84,6 +82,13 @@ struct program
void compile(const target& t);
friend std::ostream& operator<<(std::ostream& os, const program& p);
friend bool operator==(const program& x, const program& y);
friend bool operator!=(const program& x, const program& y)
{
return !(x == y);
}
private:
std::unique_ptr<program_impl> impl;
};
......
......@@ -66,7 +66,7 @@ inline std::string remove_prefix(std::string s, std::string prefix)
}
template <class Range>
inline std::string to_string(const Range& r)
inline std::string to_string_range(const Range& r)
{
std::stringstream ss;
if(!r.empty())
......@@ -77,6 +77,14 @@ inline std::string to_string(const Range& r)
return ss.str();
}
template<class T>
inline std::string to_string(const T& x)
{
std::stringstream ss;
ss << x;
return ss.str();
}
} // namespace migraph
#endif
......@@ -2,6 +2,7 @@
#include <migraph/stringutils.hpp>
#include <migraph/instruction.hpp>
#include <iostream>
#include <sstream>
#include <algorithm>
namespace migraph {
......@@ -190,6 +191,11 @@ argument program::eval(std::unordered_map<std::string, argument> params) const
return result;
}
bool operator==(const program& x, const program& y)
{
return to_string(x) == to_string(y);
}
std::ostream& operator<<(std::ostream& os, const program& p)
{
std::unordered_map<const instruction*, std::string> names;
......
......@@ -126,8 +126,8 @@ bool operator!=(const shape& x, const shape& y) { return !(x == y); }
std::ostream& operator<<(std::ostream& os, const shape& x)
{
os << x.type_string() << ", ";
os << "{" << to_string(x.lens()) << "}, ";
os << "{" << to_string(x.strides()) << "}";
os << "{" << to_string_range(x.lens()) << "}, ";
os << "{" << to_string_range(x.strides()) << "}";
return os;
}
......
......@@ -48,8 +48,9 @@ struct expression
decltype(auto) value() const { return Operator::call(lhs, rhs); };
};
// TODO: Remove rvalue references
template <class T, class U, class Operator>
expression<typename std::decay<T>::type, typename std::decay<U>::type, Operator>
expression<T, U, Operator>
make_expression(T&& rhs, U&& lhs, Operator)
{
return {std::forward<T>(rhs), std::forward<U>(lhs)};
......@@ -58,10 +59,11 @@ make_expression(T&& rhs, U&& lhs, Operator)
template <class T>
struct lhs_expression;
// TODO: Remove rvalue reference
template <class T>
lhs_expression<typename std::decay<T>::type> make_lhs_expression(T&& lhs)
lhs_expression<T> make_lhs_expression(T&& lhs)
{
return lhs_expression<typename std::decay<T>::type>{std::forward<T>(lhs)};
return lhs_expression<T>{std::forward<T>(lhs)};
}
template <class T>
......
#include <migraph/program.hpp>
#include <migraph/iterator_for.hpp>
#include <migraph/instruction.hpp>
#include <sstream>
#include "test.hpp"
#include <basic_ops.hpp>
migraph::program create_program()
{
migraph::program p;
auto x = p.add_parameter("x", {migraph::shape::int64_type});
auto y = p.add_parameter("y", {migraph::shape::int64_type});
auto sum = p.add_instruction(sum_op{}, x, y);
auto one = p.add_literal(1);
p.add_instruction(sum_op{}, sum, one);
return p;
}
void program_equality()
{
migraph::program x = create_program();
migraph::program y = create_program();
EXPECT(x == y);
}
int main() {
program_equality();
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment