Commit ade3a03c authored by Khalique's avatar Khalique
Browse files

Merge branch 'master' of https://github.com/ROCmSoftwarePlatform/MIGraph into leaky_relu

parents 7746373c 7d76401e
...@@ -87,6 +87,11 @@ constexpr void each_args(F f, Ts&&... xs) ...@@ -87,6 +87,11 @@ constexpr void each_args(F f, Ts&&... xs)
swallow{(f(std::forward<Ts>(xs)), 0)...}; swallow{(f(std::forward<Ts>(xs)), 0)...};
} }
template <class F>
constexpr void each_args(F)
{
}
/// Implements a fix-point combinator /// Implements a fix-point combinator
template <class R, class F> template <class R, class F>
detail::fix_f<R, F> fix(F f) detail::fix_f<R, F> fix(F f)
......
...@@ -8,7 +8,8 @@ ...@@ -8,7 +8,8 @@
#include <type_traits> #include <type_traits>
#include <utility> #include <utility>
#include <migraph/shape.hpp> #include <migraph/shape.hpp>
#include <migraph/rank.hpp> #include <migraph/reflect.hpp>
#include <migraph/streamutils.hpp>
#include <migraph/argument.hpp> #include <migraph/argument.hpp>
#include <migraph/context.hpp> #include <migraph/context.hpp>
#include <migraph/auto_any_cast.hpp> #include <migraph/auto_any_cast.hpp>
...@@ -54,11 +55,34 @@ namespace operation_stream { ...@@ -54,11 +55,34 @@ namespace operation_stream {
template <class T> template <class T>
auto operator<<(std::ostream& os, const T& x) -> decltype(os << x.name()) auto operator<<(std::ostream& os, const T& x) -> decltype(os << x.name())
{ {
return os << x.name(); os << x.name();
char delim = '[';
reflect_each(x, [&](auto& y, auto name) {
os << delim;
os << name << "=";
stream_write_value(os, y);
delim = ',';
});
if(delim == ',')
os << "]";
return os;
} }
} // namespace operation_stream } // namespace operation_stream
namespace operation_equal {
template <class T, class U>
auto operator==(const T& x, const U& y) -> decltype(x.name() == y.name())
{
if(x.name() != y.name())
return false;
const auto& yy = any_cast<T>(y);
return reflect_tie(x) == reflect_tie(yy);
}
} // namespace operation_equal
template <class T> template <class T>
auto compute_op(rank<1>, auto compute_op(rank<1>,
const T& x, const T& x,
...@@ -93,6 +117,7 @@ compute_op(const T& x, context& ctx, const shape& output_shape, const std::vecto ...@@ -93,6 +117,7 @@ compute_op(const T& x, context& ctx, const shape& output_shape, const std::vecto
* shape compute_shape(const std::vector<shape>& input) const; * shape compute_shape(const std::vector<shape>& input) const;
* argument compute(context& ctx,const shape& output,const std::vector<argument>& input) const; * argument compute(context& ctx,const shape& output,const std::vector<argument>& input) const;
* friend std::ostream & operator<<(std::ostream & os,const operation & op) ; * friend std::ostream & operator<<(std::ostream & os,const operation & op) ;
* friend bool operator==(const operation & x,const operation & y) ;
* }; * };
* *
*/ */
...@@ -178,6 +203,12 @@ struct operation ...@@ -178,6 +203,12 @@ struct operation
return op.private_detail_te_get_handle().operator_shift_left(os); return op.private_detail_te_get_handle().operator_shift_left(os);
} }
friend bool operator==(const operation& x, const operation& y)
{
assert(x.private_detail_te_handle_mem_var);
return x.private_detail_te_get_handle().operator==(y);
}
private: private:
struct private_detail_te_handle_base_type struct private_detail_te_handle_base_type
{ {
...@@ -190,6 +221,7 @@ struct operation ...@@ -190,6 +221,7 @@ struct operation
virtual argument virtual argument
compute(context& ctx, const shape& output, const std::vector<argument>& input) const = 0; compute(context& ctx, const shape& output, const std::vector<argument>& input) const = 0;
virtual std::ostream& operator_shift_left(std::ostream& os) const = 0; virtual std::ostream& operator_shift_left(std::ostream& os) const = 0;
virtual bool operator==(const operation& y) const = 0;
}; };
template <typename PrivateDetailTypeErasedT> template <typename PrivateDetailTypeErasedT>
...@@ -242,6 +274,12 @@ struct operation ...@@ -242,6 +274,12 @@ struct operation
return os << private_detail_te_value; return os << private_detail_te_value;
} }
bool operator==(const operation& y) const override
{
using migraph::operation_equal::operator==;
return private_detail_te_value == y;
}
PrivateDetailTypeErasedT private_detail_te_value; PrivateDetailTypeErasedT private_detail_te_value;
}; };
...@@ -307,6 +345,8 @@ inline const ValueType& any_cast(const operation& x) ...@@ -307,6 +345,8 @@ inline const ValueType& any_cast(const operation& x)
return *y; return *y;
} }
inline bool operator!=(const operation& x, const operation& y) { return !(x == y); }
#endif #endif
} // namespace migraph } // namespace migraph
......
...@@ -35,7 +35,12 @@ struct batch_norm_inference ...@@ -35,7 +35,12 @@ struct batch_norm_inference
bn_infer_mode_t bn_mode = spatial; bn_infer_mode_t bn_mode = spatial;
bool is_test = false; template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(
f(self.epsilon, "epsilon"), f(self.momentum, "momentum"), f(self.bn_mode, "bn_mode"));
}
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
...@@ -56,6 +61,16 @@ struct convolution ...@@ -56,6 +61,16 @@ struct convolution
valid valid
}; };
padding_mode_t padding_mode = default_; padding_mode_t padding_mode = default_;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.padding, "padding"),
f(self.stride, "stride"),
f(self.dilation, "dilation"),
f(self.padding_mode, "padding_mode"));
}
std::string name() const { return "convolution"; } std::string name() const { return "convolution"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
...@@ -110,16 +125,6 @@ struct convolution ...@@ -110,16 +125,6 @@ struct convolution
MIGRAPH_THROW("Invalid padding mode"); MIGRAPH_THROW("Invalid padding mode");
} }
} }
friend std::ostream& operator<<(std::ostream& os, const convolution& op)
{
os << op.name() << "[";
os << "padding={" << stream_range(op.padding) << "}, ";
os << "stride={" << stream_range(op.stride) << "}, ";
os << "dilation={" << stream_range(op.dilation) << "}";
os << "]";
return os;
}
}; };
struct im2col struct im2col
...@@ -133,6 +138,16 @@ struct im2col ...@@ -133,6 +138,16 @@ struct im2col
same, same,
valid valid
}; };
padding_mode_t padding_mode = default_;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.padding, "padding"),
f(self.stride, "stride"),
f(self.dilation, "dilation"),
f(self.padding_mode, "padding_mode"));
}
std::string name() const { return "im2col"; } std::string name() const { return "im2col"; }
...@@ -168,6 +183,16 @@ struct pooling ...@@ -168,6 +183,16 @@ struct pooling
std::array<std::size_t, 2> padding = {{0, 0}}; std::array<std::size_t, 2> padding = {{0, 0}};
std::array<std::size_t, 2> stride = {{1, 1}}; std::array<std::size_t, 2> stride = {{1, 1}};
std::array<std::size_t, 2> lengths = {{1, 1}}; std::array<std::size_t, 2> lengths = {{1, 1}};
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.mode, "mode"),
f(self.padding, "padding"),
f(self.stride, "stride"),
f(self.lengths, "lengths"));
}
std::string name() const { return "pooling"; } std::string name() const { return "pooling"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
...@@ -196,16 +221,6 @@ struct pooling ...@@ -196,16 +221,6 @@ struct pooling
1)), 1)),
}}; }};
} }
friend std::ostream& operator<<(std::ostream& os, const pooling& op)
{
os << op.name() << "[";
os << "padding={" << stream_range(op.padding) << "}, ";
os << "stride={" << stream_range(op.stride) << "}, ";
os << "lengths={" << stream_range(op.lengths) << "}";
os << "]";
return os;
}
}; };
struct activation struct activation
...@@ -243,6 +258,13 @@ struct leaky_relu ...@@ -243,6 +258,13 @@ struct leaky_relu
struct transpose struct transpose
{ {
std::vector<int64_t> dims; std::vector<int64_t> dims;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.dims, "dims"));
}
std::string name() const { return "transpose"; } std::string name() const { return "transpose"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
...@@ -274,13 +296,6 @@ struct transpose ...@@ -274,13 +296,6 @@ struct transpose
{ {
return {std::move(output_shape), std::move(args.front().data)}; return {std::move(output_shape), std::move(args.front().data)};
} }
friend std::ostream& operator<<(std::ostream& os, const transpose& op)
{
os << op.name() << "[";
os << "dims={" << stream_range(op.dims) << "}";
os << "]";
return os;
}
}; };
struct contiguous struct contiguous
...@@ -304,6 +319,13 @@ struct slice ...@@ -304,6 +319,13 @@ struct slice
std::vector<int64_t> axes; std::vector<int64_t> axes;
std::vector<int64_t> starts; std::vector<int64_t> starts;
std::vector<int64_t> ends; std::vector<int64_t> ends;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.axes, "axes"), f(self.starts, "starts"), f(self.ends, "ends"));
}
std::string name() const { return "slice"; } std::string name() const { return "slice"; }
auto fix_index(const std::vector<std::size_t>& lens, std::size_t axis, int64_t index) const auto fix_index(const std::vector<std::size_t>& lens, std::size_t axis, int64_t index) const
...@@ -376,6 +398,13 @@ struct slice ...@@ -376,6 +398,13 @@ struct slice
struct squeeze struct squeeze
{ {
std::vector<int64_t> axes; std::vector<int64_t> axes;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.axes, "axes"));
}
std::string name() const { return "squeeze"; } std::string name() const { return "squeeze"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
...@@ -416,6 +445,13 @@ struct squeeze ...@@ -416,6 +445,13 @@ struct squeeze
struct unsqueeze struct unsqueeze
{ {
std::vector<int64_t> axes; std::vector<int64_t> axes;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.axes, "axes"));
}
std::string name() const { return "unsqueeze"; } std::string name() const { return "unsqueeze"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
...@@ -447,6 +483,13 @@ struct unsqueeze ...@@ -447,6 +483,13 @@ struct unsqueeze
struct reshape struct reshape
{ {
std::vector<int64_t> dims; std::vector<int64_t> dims;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.dims, "dims"));
}
std::string name() const { return "reshape"; } std::string name() const { return "reshape"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
...@@ -486,19 +529,19 @@ struct reshape ...@@ -486,19 +529,19 @@ struct reshape
{ {
return {std::move(output_shape), std::move(args.front().data)}; return {std::move(output_shape), std::move(args.front().data)};
} }
friend std::ostream& operator<<(std::ostream& os, const reshape& op)
{
os << op.name() << "[";
os << "dims={" << stream_range(op.dims) << "}";
os << "]";
return os;
}
}; };
struct gemm struct gemm
{ {
float alpha = 1.0; float alpha = 1.0;
float beta = 0.0; float beta = 0.0;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.alpha, "alpha"), f(self.beta, "beta"));
}
std::string name() const { return "gemm"; } std::string name() const { return "gemm"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
...@@ -512,13 +555,6 @@ struct gemm ...@@ -512,13 +555,6 @@ struct gemm
to_string_range(b.lens()) + "}"); to_string_range(b.lens()) + "}");
return {t, {a.lens()[0], b.lens()[1]}}; return {t, {a.lens()[0], b.lens()[1]}};
} }
friend std::ostream& operator<<(std::ostream& os, const gemm& op)
{
os << op.name() << "[";
os << "]";
return os;
}
}; };
struct unary struct unary
...@@ -603,6 +639,13 @@ struct softmax ...@@ -603,6 +639,13 @@ struct softmax
struct flatten struct flatten
{ {
uint64_t axis = 0; uint64_t axis = 0;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.axis, "axis"));
}
std::string name() const { return "flatten"; } std::string name() const { return "flatten"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
...@@ -623,17 +666,17 @@ struct flatten ...@@ -623,17 +666,17 @@ struct flatten
{ {
return {std::move(output_shape), std::move(args.front().data)}; return {std::move(output_shape), std::move(args.front().data)};
} }
friend std::ostream& operator<<(std::ostream& os, const flatten& op)
{
os << op.name() << "[";
os << "axis=" << op.axis;
os << "]";
return os;
}
}; };
struct broadcast struct broadcast
{ {
uint64_t axis = 0; uint64_t axis = 0;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.axis, "axis"));
}
shape broadcast_shape; shape broadcast_shape;
std::string name() const { return "broadcast"; } std::string name() const { return "broadcast"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
...@@ -665,18 +708,10 @@ struct broadcast ...@@ -665,18 +708,10 @@ struct broadcast
{ {
return {std::move(output_shape), std::move(args.at(0).data)}; return {std::move(output_shape), std::move(args.at(0).data)};
} }
friend std::ostream& operator<<(std::ostream& os, const broadcast& op)
{
os << op.name() << "[";
os << "axis=" << op.axis;
os << "]";
return os;
}
}; };
struct binary struct binary
{ {
uint64_t broadcast = 0;
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs}.has(2).same_type().same_dims(); check_shapes{inputs}.has(2).same_type().same_dims();
...@@ -708,6 +743,13 @@ struct load ...@@ -708,6 +743,13 @@ struct load
{ {
shape s; shape s;
std::size_t offset = 0; std::size_t offset = 0;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.s, "shape"), f(self.offset, "offset"));
}
std::string name() const { return "load"; } std::string name() const { return "load"; }
shape compute_shape(const std::vector<shape>& inputs) const shape compute_shape(const std::vector<shape>& inputs) const
{ {
...@@ -723,6 +765,13 @@ struct load ...@@ -723,6 +765,13 @@ struct load
struct outline struct outline
{ {
shape s; shape s;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.s, "shape"));
}
std::string name() const { return "outline"; } std::string name() const { return "outline"; }
shape compute_shape(const std::vector<shape>& inputs) const shape compute_shape(const std::vector<shape>& inputs) const
{ {
......
#ifndef MIGRAPH_GUARD_RTGLIB_REFLECT_HPP
#define MIGRAPH_GUARD_RTGLIB_REFLECT_HPP
#include <migraph/functional.hpp>
#include <migraph/rank.hpp>
#include <functional>
namespace migraph {
namespace detail {
template <class T, class Selector>
auto reflect_impl(rank<1>, T& x, Selector f) -> decltype(T::reflect(x, f))
{
return T::reflect(x, std::move(f));
}
template <class T, class Selector>
auto reflect_impl(rank<0>, T&, Selector)
{
return pack();
}
} // namespace detail
template <class T, class Selector>
auto reflect(T& x, Selector f)
{
return detail::reflect_impl(rank<1>{}, x, std::move(f));
}
template <class T>
auto reflect_tie(T& x)
{
return reflect(x, [](auto&& y, auto&&...) { return std::ref(y); })(
[](auto&&... xs) { return std::tie(xs.get()...); });
}
template <class T, class F>
void reflect_each(T& x, F f)
{
return reflect(x, [](auto&& y, auto... ys) { return pack(std::ref(y), ys...); })(
[&](auto&&... xs) {
each_args([&](auto p) { p([&](auto&& y, auto... ys) { f(y.get(), ys...); }); }, xs...);
});
}
} // namespace migraph
#endif
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
#include <ostream> #include <ostream>
#include <algorithm> #include <algorithm>
#include <migraph/rank.hpp>
namespace migraph { namespace migraph {
...@@ -31,6 +32,28 @@ inline stream_range_container<Range> stream_range(const Range& r) ...@@ -31,6 +32,28 @@ inline stream_range_container<Range> stream_range(const Range& r)
return {r}; return {r};
} }
namespace detail {
template <class Range>
auto stream_write_value_impl(rank<1>, std::ostream& os, const Range& r)
-> decltype(r.begin(), r.end(), void())
{
os << stream_range(r);
}
template <class T>
void stream_write_value_impl(rank<0>, std::ostream& os, const T& x)
{
os << x;
}
} // namespace detail
template <class T>
void stream_write_value(std::ostream& os, const T& x)
{
detail::stream_write_value_impl(rank<1>{}, os, x);
}
} // namespace migraph } // namespace migraph
#endif #endif
...@@ -255,7 +255,8 @@ struct onnx_parser ...@@ -255,7 +255,8 @@ struct onnx_parser
? op::batch_norm_inference::spatial ? op::batch_norm_inference::spatial
: op::batch_norm_inference::per_activation; : op::batch_norm_inference::per_activation;
} }
op::batch_norm_inference op{epsilon, momentum, bn_mode, is_test}; (void)is_test;
op::batch_norm_inference op{epsilon, momentum, bn_mode};
return prog.add_instruction(op, std::move(args)); return prog.add_instruction(op, std::move(args));
} }
......
...@@ -192,7 +192,7 @@ void memory_coloring_impl::register_operand_alias() ...@@ -192,7 +192,7 @@ void memory_coloring_impl::register_operand_alias()
operand_alias["@param"] = -1; operand_alias["@param"] = -1;
operand_alias["transpose"] = 0; operand_alias["transpose"] = 0;
operand_alias["flatten"] = 0; operand_alias["flatten"] = 0;
operand_alias["broadcast"] = 1; operand_alias["broadcast"] = 0;
operand_alias["reshape"] = 0; operand_alias["reshape"] = 0;
operand_alias["pass"] = 0; operand_alias["pass"] = 0;
} }
......
...@@ -26,19 +26,18 @@ struct miopen_convolution ...@@ -26,19 +26,18 @@ struct miopen_convolution
shared<convolution_descriptor> cd; shared<convolution_descriptor> cd;
miopenConvFwdAlgorithm_t algo{}; miopenConvFwdAlgorithm_t algo{};
template <class Self, class F>
static auto reflect(Self& self, F f)
{
// TODO: Add algo
return op::convolution::reflect(self.op, f);
}
std::string name() const { return "gpu::convolution"; } std::string name() const { return "gpu::convolution"; }
shape compute_shape(const std::vector<shape>& inputs) const; shape compute_shape(const std::vector<shape>& inputs) const;
argument argument
compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const; compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const;
shape compile(context& ctx, const shape& output_shape, std::vector<instruction_ref> inputs); shape compile(context& ctx, const shape& output_shape, std::vector<instruction_ref> inputs);
friend std::ostream& operator<<(std::ostream& os, const miopen_convolution& self)
{
os << self.name() << "[";
os << self.op << ", ";
os << "algo=" << self.algo;
os << "]";
return os;
}
}; };
} // namespace gpu } // namespace gpu
......
...@@ -54,7 +54,7 @@ void pytorch_conv_bn_relu_maxpool() ...@@ -54,7 +54,7 @@ void pytorch_conv_bn_relu_maxpool()
auto l3 = p.add_instruction(migraph::op::convolution{}, l0, l1); auto l3 = p.add_instruction(migraph::op::convolution{}, l0, l1);
auto l4 = p.add_instruction(migraph::op::broadcast{axis, l3->get_shape()}, l2); auto l4 = p.add_instruction(migraph::op::broadcast{axis, l3->get_shape()}, l2);
auto l5 = p.add_instruction(migraph::op::add{}, l3, l4); auto l5 = p.add_instruction(migraph::op::add{}, l3, l4);
auto l6 = p.add_instruction(migraph::op::batch_norm_inference{}, l5, p3, p4, p5, p6); auto l6 = p.add_instruction(migraph::op::batch_norm_inference{1.0e-5f}, l5, p3, p4, p5, p6);
auto l7 = p.add_instruction(migraph::op::activation{"relu"}, l6); auto l7 = p.add_instruction(migraph::op::activation{"relu"}, l6);
p.add_instruction(migraph::op::pooling{"max", {{0, 0}}, {{2, 2}}, {{2, 2}}}, l7); p.add_instruction(migraph::op::pooling{"max", {{0, 0}}, {{2, 2}}, {{2, 2}}}, l7);
......
...@@ -6,6 +6,11 @@ ...@@ -6,6 +6,11 @@
struct simple_operation struct simple_operation
{ {
template <class T, class F>
static auto reflect(T& x, F f)
{
return migraph::pack(f(x.data, "data"));
}
int data = 1; int data = 1;
std::string name() const { return "simple"; } std::string name() const { return "simple"; }
migraph::shape compute_shape(const std::vector<migraph::shape>&) const migraph::shape compute_shape(const std::vector<migraph::shape>&) const
...@@ -19,7 +24,7 @@ struct simple_operation ...@@ -19,7 +24,7 @@ struct simple_operation
} }
friend std::ostream& operator<<(std::ostream& os, const simple_operation& op) friend std::ostream& operator<<(std::ostream& os, const simple_operation& op)
{ {
os << "[" << op.name() << "]"; os << op.name() << "[" << op.data << "]";
return os; return os;
} }
}; };
...@@ -44,9 +49,23 @@ void operation_copy_test() ...@@ -44,9 +49,23 @@ void operation_copy_test()
migraph::operation op1 = s; // NOLINT migraph::operation op1 = s; // NOLINT
migraph::operation op2 = op1; // NOLINT migraph::operation op2 = op1; // NOLINT
// cppcheck-suppress duplicateExpression // cppcheck-suppress duplicateExpression
EXPECT(s.name() == op1.name()); EXPECT(s == op1);
// cppcheck-suppress duplicateExpression // cppcheck-suppress duplicateExpression
EXPECT(op2.name() == op1.name()); EXPECT(op2 == op1);
}
void operation_equal_test()
{
simple_operation s{};
migraph::operation op1 = s;
s.data = 2;
migraph::operation op2 = op1; // NOLINT
migraph::operation op3 = s; // NOLINT
EXPECT(s != op1);
EXPECT(op2 == op1);
EXPECT(op3 != op2);
EXPECT(op3 != op1);
} }
struct not_operation struct not_operation
...@@ -70,7 +89,7 @@ void operation_print() ...@@ -70,7 +89,7 @@ void operation_print()
std::stringstream ss; std::stringstream ss;
ss << op; ss << op;
std::string s = ss.str(); std::string s = ss.str();
EXPECT(s == "[simple]"); EXPECT(s == "simple[1]");
} }
void operation_default_print() void operation_default_print()
...@@ -85,6 +104,7 @@ void operation_default_print() ...@@ -85,6 +104,7 @@ void operation_default_print()
int main() int main()
{ {
operation_copy_test(); operation_copy_test();
operation_equal_test();
operation_any_cast(); operation_any_cast();
operation_print(); operation_print();
operation_default_print(); operation_default_print();
......
...@@ -8,7 +8,8 @@ ...@@ -8,7 +8,8 @@
#include <type_traits> #include <type_traits>
#include <utility> #include <utility>
#include <migraph/shape.hpp> #include <migraph/shape.hpp>
#include <migraph/rank.hpp> #include <migraph/reflect.hpp>
#include <migraph/streamutils.hpp>
#include <migraph/argument.hpp> #include <migraph/argument.hpp>
#include <migraph/context.hpp> #include <migraph/context.hpp>
#include <migraph/auto_any_cast.hpp> #include <migraph/auto_any_cast.hpp>
...@@ -54,11 +55,34 @@ namespace operation_stream { ...@@ -54,11 +55,34 @@ namespace operation_stream {
template <class T> template <class T>
auto operator<<(std::ostream& os, const T& x) -> decltype(os << x.name()) auto operator<<(std::ostream& os, const T& x) -> decltype(os << x.name())
{ {
return os << x.name(); os << x.name();
char delim = '[';
reflect_each(x, [&](auto& y, auto name) {
os << delim;
os << name << "=";
stream_write_value(os, y);
delim = ',';
});
if(delim == ',')
os << "]";
return os;
} }
} // namespace operation_stream } // namespace operation_stream
namespace operation_equal {
template <class T, class U>
auto operator==(const T& x, const U& y) -> decltype(x.name() == y.name())
{
if(x.name() != y.name())
return false;
const auto& yy = any_cast<T>(y);
return reflect_tie(x) == reflect_tie(yy);
}
} // namespace operation_equal
template <class T> template <class T>
auto compute_op(rank<1>, auto compute_op(rank<1>,
const T& x, const T& x,
...@@ -85,13 +109,32 @@ compute_op(const T& x, context& ctx, const shape& output_shape, const std::vecto ...@@ -85,13 +109,32 @@ compute_op(const T& x, context& ctx, const shape& output_shape, const std::vecto
} }
<% <%
interface('operation', interface(
virtual('name', returns='std::string', const=True), 'operation',
virtual('compute_shape', returns='shape', input='const std::vector<shape>&', const=True), virtual('name', returns = 'std::string', const = True),
virtual('compute', returns='argument', ctx='context&', output='const shape&', input='const std::vector<argument>&', const=True, default='compute_op'), virtual('compute_shape', returns = 'shape', input = 'const std::vector<shape>&', const = True),
friend('operator<<', returns='std::ostream &', os='std::ostream &', op='const operation &', using='migraph::operation_stream::operator<<') virtual('compute',
) returns = 'argument',
%> ctx = 'context&',
output = 'const shape&',
input = 'const std::vector<argument>&',
const = True,
default = 'compute_op'),
friend('operator<<',
returns = 'std::ostream &',
os = 'std::ostream &',
op = 'const operation &',
using = 'migraph::operation_stream::operator<<'),
friend('operator==',
returns = 'bool',
x = 'const operation &',
y = 'const operation &',
using = 'migraph::operation_equal::operator==')) %>
inline bool operator!=(const operation& x, const operation& y)
{
return !(x == y);
}
#endif #endif
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment