Commit e8deff52 authored by mei-ye's avatar mei-ye
Browse files

resolve merge conflicts

parents edfee372 6c649c7b
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
#include <migraph/instruction.hpp> #include <migraph/instruction.hpp>
#include <basic_ops.hpp> #include <basic_ops.hpp>
#include <test.hpp> #include <test.hpp>
#include <rob.hpp>
void simple_test() void simple_test()
{ {
...@@ -38,6 +39,11 @@ void incomplete_args() ...@@ -38,6 +39,11 @@ void incomplete_args()
EXPECT(bool{p.validate() == ins}); EXPECT(bool{p.validate() == ins});
} }
MIGRAPH_ROB(access_ins_arguments,
std::vector<migraph::instruction_ref>,
migraph::instruction,
arguments)
void invalid_args() void invalid_args()
{ {
migraph::program p; migraph::program p;
...@@ -45,7 +51,7 @@ void invalid_args() ...@@ -45,7 +51,7 @@ void invalid_args()
auto one = p.add_literal(1); auto one = p.add_literal(1);
auto two = p.add_literal(2); auto two = p.add_literal(2);
auto ins = p.add_instruction(sum_op{}, one, two); auto ins = p.add_instruction(sum_op{}, one, two);
ins->arguments.clear(); access_ins_arguments(*ins).clear();
EXPECT(bool{p.validate() == p.begin()}); EXPECT(bool{p.validate() == p.begin()});
} }
......
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
#include <type_traits> #include <type_traits>
#include <utility> #include <utility>
#include <migraph/shape.hpp> #include <migraph/shape.hpp>
#include <migraph/rank.hpp>
#include <migraph/argument.hpp> #include <migraph/argument.hpp>
#include <migraph/context.hpp> #include <migraph/context.hpp>
#include <migraph/auto_any_cast.hpp> #include <migraph/auto_any_cast.hpp>
...@@ -27,13 +28,16 @@ struct operation ...@@ -27,13 +28,16 @@ struct operation
/// exception. /// exception.
shape compute_shape(const std::vector<shape>& input) const; shape compute_shape(const std::vector<shape>& input) const;
/** /**
* @brief This performs the operation's computation * @brief This performs the operation's computation.
*
* This method can be optional when the operation is only used as a placeholder to be lowered
* later on.
* *
* @param ctx This is the context created by the `target` during compilation. Implementations * @param ctx This is the context created by the `target` during compilation. Implementations
* can use the target's `context` class rather than the `context` interface class. * can use the target's `context` class rather than the `context` interface class.
* @param output This is the output shape. It is equivalent to running `compute_shape` with each * @param output This is the output shape. It is equivalent to running `compute_shape` with each
* `shape` of the `argument`. * `shape` of the `argument`.
* @param input This is the `argument` result from the previous instuction's computation. * @param input This is the `argument` result from the previous instruction's computation.
* @return Return an `argument` of the result computation. The `shape` of `argument` should be * @return Return an `argument` of the result computation. The `shape` of `argument` should be
* the same the `output` shape. * the same the `output` shape.
*/ */
...@@ -55,11 +59,29 @@ auto operator<<(std::ostream& os, const T& x) -> decltype(os << x.name()) ...@@ -55,11 +59,29 @@ auto operator<<(std::ostream& os, const T& x) -> decltype(os << x.name())
} // namespace operation_stream } // namespace operation_stream
template <class T>
auto compute_op(rank<1>,
const T& x,
context& ctx,
const shape& output_shape,
const std::vector<argument>& input)
-> decltype(x.compute(auto_any_cast(ctx), output_shape, input))
{
return x.compute(auto_any_cast(ctx), output_shape, input);
}
template <class T>
argument compute_op(rank<0>, const T& x, context&, const shape&, const std::vector<argument>&)
{
std::string name = x.name();
MIGRAPH_THROW("Not computable: " + name);
}
template <class T> template <class T>
argument argument
compute_op(const T& x, context& ctx, const shape& output_shape, const std::vector<argument>& input) compute_op(const T& x, context& ctx, const shape& output_shape, const std::vector<argument>& input)
{ {
return x.compute(auto_any_cast(ctx), output_shape, input); return compute_op(rank<1>{}, x, ctx, output_shape, input);
} }
<% <%
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment