Commit fb75dfaf authored by Paul's avatar Paul
Browse files

Only use no-cache on jenkins

parents e596eec2 f0604d78
#include <migraph/auto_contiguous.hpp>
#include <migraph/program.hpp>
#include <migraph/instruction.hpp>
#include <migraph/operators.hpp>
#include <migraph/iterator_for.hpp>
namespace migraph {
void auto_contiguous::apply(program& p) const
{
for(auto ins : iterator_for(p))
{
shape s = ins->result;
if(not s.standard())
{
auto c = p.insert_instruction(std::next(ins), contiguous{}, ins);
p.replace_instruction(ins, c);
}
}
}
} // namespace migraph
#include <migraph/dead_code_elimination.hpp>
#include <migraph/program.hpp>
#include <migraph/instruction.hpp>
#include <migraph/iterator_for.hpp>
#include <migraph/functional.hpp>
namespace migraph {
void dead_code_elimination::apply(program& p) const
{
auto last = std::prev(p.end());
for(auto ins : iterator_for(p))
{
// Skip the first instruction, since we always process the previous
// instruction
if(ins == p.begin())
continue;
const auto i = std::prev(ins);
// Skip instruction with empty shape as output
if(i->result.elements() == 0)
continue;
// Skip the last instruction
if(i == last)
break;
fix([&](auto self, auto leaf) {
assert(p.has_instruction(leaf));
if(leaf->output.empty())
{
auto args = leaf->arguments;
leaf->clear_arguments();
p.move_instruction(leaf, p.end());
for(auto arg : args)
self(arg);
}
})(i);
}
p.remove_instructions(std::next(last), p.end());
}
} // namespace migraph
#include <migraph/generate.hpp>
namespace migraph {
argument generate_argument(shape s, std::mt19937::result_type seed)
{
argument result;
s.visit_type([&](auto as) {
using type = typename decltype(as)::type;
auto v = generate_tensor_data<type>(s, seed);
result = {s, [v]() mutable { return reinterpret_cast<char*>(v.data()); }};
});
return result;
}
literal generate_literal(shape s, std::mt19937::result_type seed)
{
literal result;
s.visit_type([&](auto as) {
using type = typename decltype(as)::type;
auto v = generate_tensor_data<type>(s, seed);
result = {s, v};
});
return result;
}
} // namespace migraph
#ifndef RTG_GUARD_RTGLIB_ARGUMENT_HPP #ifndef MIGRAPH_GUARD_MIGRAPHLIB_ARGUMENT_HPP
#define RTG_GUARD_RTGLIB_ARGUMENT_HPP #define MIGRAPH_GUARD_MIGRAPHLIB_ARGUMENT_HPP
#include <rtg/shape.hpp> #include <migraph/shape.hpp>
#include <rtg/raw_data.hpp> #include <migraph/raw_data.hpp>
#include <functional> #include <functional>
namespace rtg { namespace migraph {
/** /**
* @brief Arguments passed to instructions * @brief Arguments passed to instructions
...@@ -39,16 +39,10 @@ struct argument : raw_data<argument> ...@@ -39,16 +39,10 @@ struct argument : raw_data<argument>
const shape& get_shape() const { return this->m_shape; } const shape& get_shape() const { return this->m_shape; }
template <class T>
T* cast() const
{
return reinterpret_cast<T*>(this->data());
}
private: private:
shape m_shape; shape m_shape;
}; };
} // namespace rtg } // namespace migraph
#endif #endif
#ifndef MIGRAPH_GUARD_RTGLIB_AUTO_ANY_CAST_HPP
#define MIGRAPH_GUARD_RTGLIB_AUTO_ANY_CAST_HPP
namespace migraph {
namespace detail {
template <class U>
void any_cast()
{
}
template <class T>
struct auto_any_caster
{
T& x;
template <class U>
operator U&()
{
return any_cast<U>(x);
}
operator T&() { return x; }
};
} // namespace detail
template <class T>
detail::auto_any_caster<T> auto_any_cast(T& x)
{
return {x};
}
} // namespace migraph
#endif
#ifndef MIGRAPH_GUARD_RTGLIB_AUTO_CONTIGOUS_HPP
#define MIGRAPH_GUARD_RTGLIB_AUTO_CONTIGOUS_HPP
#include <string>
#include <migraph/instruction_ref.hpp>
namespace migraph {
struct program;
struct auto_contiguous
{
std::string name() const { return "auto_contiguous"; }
void apply(program& p) const;
};
} // namespace migraph
#endif
#ifndef RTG_GUARD_BUILTIN_HPP #ifndef MIGRAPH_GUARD_BUILTIN_HPP
#define RTG_GUARD_BUILTIN_HPP #define MIGRAPH_GUARD_BUILTIN_HPP
#include <rtg/operation.hpp> #include <migraph/context.hpp>
#include <rtg/errors.hpp> #include <migraph/errors.hpp>
#include <migraph/argument.hpp>
namespace rtg { namespace migraph {
namespace builtin { namespace builtin {
struct literal struct literal
{ {
std::string name() const { return "@literal"; } std::string name() const { return "@literal"; }
shape compute_shape(std::vector<shape>) const { RTG_THROW("builtin"); } shape compute_shape(std::vector<shape>) const { MIGRAPH_THROW("builtin"); }
argument compute(shape, std::vector<argument>) const { RTG_THROW("builtin"); } argument compute(context&, shape, std::vector<argument>) const { MIGRAPH_THROW("builtin"); }
}; };
struct outline struct outline
{ {
shape s; shape s;
std::string name() const { return "@outline"; } std::string name() const { return "@outline"; }
shape compute_shape(std::vector<shape>) const { RTG_THROW("builtin"); } shape compute_shape(std::vector<shape>) const { return s; }
argument compute(shape, std::vector<argument>) const { RTG_THROW("builtin"); } argument compute(context&, shape, std::vector<argument>) const { MIGRAPH_THROW("builtin"); }
}; };
struct param struct param
{ {
std::string parameter; std::string parameter;
std::string name() const { return "@param"; } std::string name() const { return "@param"; }
shape compute_shape(std::vector<shape>) const { RTG_THROW("builtin"); } shape compute_shape(std::vector<shape>) const { MIGRAPH_THROW("builtin"); }
argument compute(shape, std::vector<argument>) const { RTG_THROW("builtin"); } argument compute(context&, shape, std::vector<argument>) const { MIGRAPH_THROW("builtin"); }
friend std::ostream& operator<<(std::ostream& os, const param& op) friend std::ostream& operator<<(std::ostream& os, const param& op)
{ {
os << op.name() << ":" << op.parameter; os << op.name() << ":" << op.parameter;
...@@ -38,6 +39,6 @@ struct param ...@@ -38,6 +39,6 @@ struct param
} // namespace builtin } // namespace builtin
} // namespace rtg } // namespace migraph
#endif #endif
#ifndef MIGRAPH_GUARD_RTGLIB_CHECK_CONTEXT_HPP
#define MIGRAPH_GUARD_RTGLIB_CHECK_CONTEXT_HPP
#include <migraph/program.hpp>
namespace migraph {
template <class T>
struct check_context
{
struct op
{
std::string name() const { return "check_context"; }
shape compute_shape(std::vector<shape>) const { return {}; }
argument compute(context& ctx, shape, std::vector<argument>) const
{
T* x = any_cast<T>(&ctx);
if(x == nullptr)
MIGRAPH_THROW(std::string("Unexpected context type: ") + ctx.type_id().name());
return {};
}
};
std::string name() const { return "check_context"; }
void apply(program& p) const { p.insert_instruction(p.begin(), op{}); }
};
} // namespace migraph
#endif
#ifndef MIGRAPH_GUARD_RTGLIB_CHECK_SHAPES_HPP
#define MIGRAPH_GUARD_RTGLIB_CHECK_SHAPES_HPP
#include <migraph/shape.hpp>
#include <algorithm>
namespace migraph {
struct check_shapes
{
const std::vector<shape>* shapes;
const std::string name;
check_shapes(const std::vector<shape>& s) : shapes(&s) {}
template <class Op>
check_shapes(const std::vector<shape>& s, const Op& op) : shapes(&s), name(op.name())
{
}
std::string prefix() const
{
if(name.empty())
return "";
else
return name + ": ";
}
const check_shapes& has(std::size_t n) const
{
assert(shapes != nullptr);
if(shapes->size() != n)
MIGRAPH_THROW(prefix() + "Wrong number of arguments: expected " + std::to_string(n) +
" but given " + std::to_string(shapes->size()));
return *this;
}
const check_shapes& only_dims(std::size_t n) const
{
assert(shapes != nullptr);
if(!shapes->empty())
{
if(shapes->front().lens().size() != n)
MIGRAPH_THROW(prefix() + "Only " + std::to_string(n) + "d supported");
}
return *this;
}
const check_shapes& same_shape() const
{
if(!this->same([](const shape& s) { return s; }))
MIGRAPH_THROW(prefix() + "Shapes do not match");
return *this;
}
const check_shapes& same_type() const
{
if(!this->same([](const shape& s) { return s.type(); }))
MIGRAPH_THROW(prefix() + "Types do not match");
return *this;
}
const check_shapes& same_dims() const
{
if(!this->same([](const shape& s) { return s.lens(); }))
MIGRAPH_THROW(prefix() + "Dimensions do not match");
return *this;
}
const check_shapes& same_ndims() const
{
if(!this->same([](const shape& s) { return s.lens().size(); }))
MIGRAPH_THROW(prefix() + "Number of dimensions do not match");
return *this;
}
const check_shapes& standard() const
{
if(!this->all_of([](const shape& s) { return s.standard(); }))
MIGRAPH_THROW(prefix() + "Shapes are not in standard layout");
return *this;
}
const check_shapes& packed() const
{
if(!this->all_of([](const shape& s) { return s.packed(); }))
MIGRAPH_THROW(prefix() + "Shapes are not packed");
return *this;
}
const check_shapes& not_transposed() const
{
if(!this->all_of([](const shape& s) { return not s.transposed(); }))
MIGRAPH_THROW(prefix() + "Shapes are transposed");
return *this;
}
const check_shapes& not_broadcasted() const
{
// if(!this->all_of([](const shape& s) { return not s.broadcasted(); }))
// MIGRAPH_THROW(prefix() + "Shapes are broadcasted");
return *this;
}
template <class F>
bool same(F f) const
{
assert(shapes != nullptr);
if(shapes->empty())
return true;
auto&& key = f(shapes->front());
return this->all_of([&](const shape& s) { return f(s) == key; });
}
template <class Predicate>
bool all_of(Predicate p) const
{
assert(shapes != nullptr);
return std::all_of(shapes->begin(), shapes->end(), p);
}
};
} // namespace migraph
#endif
#ifndef MIGRAPH_GUARD_CONTEXT_HPP
#define MIGRAPH_GUARD_CONTEXT_HPP
#include <cassert>
#include <string>
#include <functional>
#include <memory>
#include <type_traits>
#include <utility>
namespace migraph {
#ifdef DOXYGEN
/// A context is used to store internal data for a `target`. A context is
/// constructed by a target during compilation and passed to the operations
/// during `eval`.
struct context
{
};
#else
/*
* Type-erased interface for:
*
* struct context
* {
* };
*
*/
struct context
{
// Constructors
context() = default;
template <typename PrivateDetailTypeErasedT>
context(PrivateDetailTypeErasedT value)
: private_detail_te_handle_mem_var(
std::make_shared<private_detail_te_handle_type<
typename std::remove_reference<PrivateDetailTypeErasedT>::type>>(
std::forward<PrivateDetailTypeErasedT>(value)))
{
}
// Assignment
template <typename PrivateDetailTypeErasedT>
context& operator=(PrivateDetailTypeErasedT value)
{
if(private_detail_te_handle_mem_var.unique())
*private_detail_te_handle_mem_var = std::forward<PrivateDetailTypeErasedT>(value);
else if(!private_detail_te_handle_mem_var)
private_detail_te_handle_mem_var = std::make_shared<PrivateDetailTypeErasedT>(
std::forward<PrivateDetailTypeErasedT>(value));
return *this;
}
// Cast
template <typename PrivateDetailTypeErasedT>
PrivateDetailTypeErasedT* any_cast()
{
return private_detail_te_get_handle().type() == typeid(PrivateDetailTypeErasedT)
? std::addressof(static_cast<private_detail_te_handle_type<
typename std::remove_cv<PrivateDetailTypeErasedT>::type>&>(
private_detail_te_get_handle())
.private_detail_te_value)
: nullptr;
}
template <typename PrivateDetailTypeErasedT>
const typename std::remove_cv<PrivateDetailTypeErasedT>::type* any_cast() const
{
return private_detail_te_get_handle().type() == typeid(PrivateDetailTypeErasedT)
? std::addressof(static_cast<const private_detail_te_handle_type<
typename std::remove_cv<PrivateDetailTypeErasedT>::type>&>(
private_detail_te_get_handle())
.private_detail_te_value)
: nullptr;
}
const std::type_info& type_id() const
{
if(private_detail_te_handle_empty())
return typeid(std::nullptr_t);
else
return private_detail_te_get_handle().type();
}
private:
struct private_detail_te_handle_base_type
{
virtual ~private_detail_te_handle_base_type() {}
virtual std::shared_ptr<private_detail_te_handle_base_type> clone() const = 0;
virtual const std::type_info& type() const = 0;
};
template <typename PrivateDetailTypeErasedT>
struct private_detail_te_handle_type : private_detail_te_handle_base_type
{
template <typename PrivateDetailTypeErasedU = PrivateDetailTypeErasedT>
private_detail_te_handle_type(
PrivateDetailTypeErasedT value,
typename std::enable_if<std::is_reference<PrivateDetailTypeErasedU>::value>::type* =
nullptr)
: private_detail_te_value(value)
{
}
template <typename PrivateDetailTypeErasedU = PrivateDetailTypeErasedT>
private_detail_te_handle_type(
PrivateDetailTypeErasedT value,
typename std::enable_if<!std::is_reference<PrivateDetailTypeErasedU>::value,
int>::type* = nullptr) noexcept
: private_detail_te_value(std::move(value))
{
}
std::shared_ptr<private_detail_te_handle_base_type> clone() const override
{
return std::make_shared<private_detail_te_handle_type>(private_detail_te_value);
}
const std::type_info& type() const override { return typeid(private_detail_te_value); }
PrivateDetailTypeErasedT private_detail_te_value;
};
template <typename PrivateDetailTypeErasedT>
struct private_detail_te_handle_type<std::reference_wrapper<PrivateDetailTypeErasedT>>
: private_detail_te_handle_type<PrivateDetailTypeErasedT&>
{
private_detail_te_handle_type(std::reference_wrapper<PrivateDetailTypeErasedT> ref)
: private_detail_te_handle_type<PrivateDetailTypeErasedT&>(ref.get())
{
}
};
bool private_detail_te_handle_empty() const
{
return private_detail_te_handle_mem_var == nullptr;
}
const private_detail_te_handle_base_type& private_detail_te_get_handle() const
{
assert(private_detail_te_handle_mem_var != nullptr);
return *private_detail_te_handle_mem_var;
}
private_detail_te_handle_base_type& private_detail_te_get_handle()
{
assert(private_detail_te_handle_mem_var != nullptr);
if(!private_detail_te_handle_mem_var.unique())
private_detail_te_handle_mem_var = private_detail_te_handle_mem_var->clone();
return *private_detail_te_handle_mem_var;
}
std::shared_ptr<private_detail_te_handle_base_type> private_detail_te_handle_mem_var;
};
template <typename ValueType>
inline const ValueType* any_cast(const context* x)
{
return x->any_cast<ValueType>();
}
template <typename ValueType>
inline ValueType* any_cast(context* x)
{
return x->any_cast<ValueType>();
}
template <typename ValueType>
inline ValueType& any_cast(context& x)
{
auto* y = x.any_cast<typename std::remove_reference<ValueType>::type>();
if(y == nullptr)
throw std::bad_cast();
return *y;
}
template <typename ValueType>
inline const ValueType& any_cast(const context& x)
{
const auto* y = x.any_cast<typename std::remove_reference<ValueType>::type>();
if(y == nullptr)
throw std::bad_cast();
return *y;
}
#endif
} // namespace migraph
#endif
#ifndef MIGRAPH_GUARD_RTGLIB_DEAD_CODE_ELIMINATION_HPP
#define MIGRAPH_GUARD_RTGLIB_DEAD_CODE_ELIMINATION_HPP
#include <string>
#include <migraph/instruction_ref.hpp>
namespace migraph {
struct program;
struct dead_code_elimination
{
std::string name() const { return "dead_code_elimination"; }
void apply(program& p) const;
};
} // namespace migraph
#endif
#ifndef RTG_GUARD_RTGLIB_DFOR_HPP #ifndef MIGRAPH_GUARD_MIGRAPHLIB_DFOR_HPP
#define RTG_GUARD_RTGLIB_DFOR_HPP #define MIGRAPH_GUARD_MIGRAPHLIB_DFOR_HPP
namespace rtg { namespace migraph {
// Multidimensional for loop // Multidimensional for loop
inline auto dfor() inline auto dfor()
...@@ -20,6 +20,6 @@ auto dfor(T x, Ts... xs) ...@@ -20,6 +20,6 @@ auto dfor(T x, Ts... xs)
}; };
} }
} // namespace rtg } // namespace migraph
#endif #endif
#ifndef RTG_GUARD_ERASE_HPP #ifndef MIGRAPH_GUARD_ERASE_HPP
#define RTG_GUARD_ERASE_HPP #define MIGRAPH_GUARD_ERASE_HPP
namespace rtg { #include <algorithm>
namespace migraph {
/** /**
* @brief Erase all elements from a container * @brief Erase all elements from a container
...@@ -29,6 +31,6 @@ auto erase_if(R&& r, P&& pred) ...@@ -29,6 +31,6 @@ auto erase_if(R&& r, P&& pred)
return r.erase(std::remove_if(r.begin(), r.end(), pred), r.end()); return r.erase(std::remove_if(r.begin(), r.end(), pred), r.end());
} }
} // namespace rtg } // namespace migraph
#endif #endif
#ifndef RTG_GUARD_ERRORS_HPP #ifndef MIGRAPH_GUARD_ERRORS_HPP
#define RTG_GUARD_ERRORS_HPP #define MIGRAPH_GUARD_ERRORS_HPP
#include <exception> #include <exception>
#include <stdexcept> #include <stdexcept>
#include <string> #include <string>
namespace rtg { namespace migraph {
/// Represents exceptions that can be thrown by rtglib /// Represents exceptions that can be thrown by migraphlib
struct exception : std::runtime_error struct exception : std::runtime_error
{ {
exception(std::string msg = "") : std::runtime_error(msg) {} exception(std::string msg = "") : std::runtime_error(msg) {}
...@@ -41,9 +41,9 @@ inline std::string make_source_context(const std::string& file, int line) ...@@ -41,9 +41,9 @@ inline std::string make_source_context(const std::string& file, int line)
/** /**
* @brief Throw an exception with context information * @brief Throw an exception with context information
*/ */
#define RTG_THROW(...) \ #define MIGRAPH_THROW(...) \
throw rtg::make_exception(rtg::make_source_context(__FILE__, __LINE__), __VA_ARGS__) throw migraph::make_exception(migraph::make_source_context(__FILE__, __LINE__), __VA_ARGS__)
} // namespace rtg } // namespace migraph
#endif #endif
#ifndef MIGRAPH_GUARD_FALLTHROUGH_HPP
#define MIGRAPH_GUARD_FALLTHROUGH_HPP
namespace migraph {
#ifdef __clang__
#define MIGRAPH_FALLTHROUGH [[clang::fallthrough]]
#else
#define MIGRAPH_FALLTHROUGH
#endif
} // namespace migraph
#endif
#ifndef RTG_GUARD_RTGLIB_FLOAT_EQUAL_HPP #ifndef MIGRAPH_GUARD_MIGRAPHLIB_FLOAT_EQUAL_HPP
#define RTG_GUARD_RTGLIB_FLOAT_EQUAL_HPP #define MIGRAPH_GUARD_MIGRAPHLIB_FLOAT_EQUAL_HPP
#include <algorithm> #include <algorithm>
#include <cmath> #include <cmath>
...@@ -8,7 +8,7 @@ ...@@ -8,7 +8,7 @@
#include <iso646.h> #include <iso646.h>
#endif #endif
namespace rtg { namespace migraph {
template <class... Ts> template <class... Ts>
using common_type = typename std::common_type<Ts...>::type; using common_type = typename std::common_type<Ts...>::type;
...@@ -32,6 +32,6 @@ struct float_equal_fn ...@@ -32,6 +32,6 @@ struct float_equal_fn
static constexpr float_equal_fn float_equal{}; static constexpr float_equal_fn float_equal{};
} // namespace rtg } // namespace migraph
#endif #endif
#ifndef MIGRAPH_GUARD_RTGLIB_FUNCTIONAL_HPP
#define MIGRAPH_GUARD_RTGLIB_FUNCTIONAL_HPP
#include <utility>
namespace migraph {
namespace detail {
template <class R, class F>
struct fix_f
{
F f;
template <class... Ts>
R operator()(Ts&&... xs) const
{
return f(*this, std::forward<Ts>(xs)...);
}
};
} // namespace detail
/// Implements a fix-point combinator
template <class R, class F>
detail::fix_f<R, F> fix(F f)
{
return {f};
}
template <class F>
auto fix(F f)
{
return fix<void>(f);
}
} // namespace migraph
#endif
#ifndef MIGRAPH_GUARD_MIGRAPHLIB_GENERATE_HPP
#define MIGRAPH_GUARD_MIGRAPHLIB_GENERATE_HPP
#include <migraph/argument.hpp>
#include <migraph/literal.hpp>
#include <random>
namespace migraph {
template <class T>
std::vector<T> generate_tensor_data(migraph::shape s, std::mt19937::result_type seed = 0)
{
std::vector<T> result(s.elements());
std::mt19937 engine{seed};
std::uniform_real_distribution<> dist;
std::generate(result.begin(), result.end(), [&] { return dist(engine); });
return result;
}
argument generate_argument(shape s, std::mt19937::result_type seed = 0);
literal generate_literal(shape s, std::mt19937::result_type seed = 0);
} // namespace migraph
#endif
#ifndef RTG_GUARD_RTGLIB_INSTRUCTION_HPP #ifndef MIGRAPH_GUARD_MIGRAPHLIB_INSTRUCTION_HPP
#define RTG_GUARD_RTGLIB_INSTRUCTION_HPP #define MIGRAPH_GUARD_MIGRAPHLIB_INSTRUCTION_HPP
#include <rtg/literal.hpp> #include <migraph/literal.hpp>
#include <rtg/shape.hpp> #include <migraph/shape.hpp>
#include <rtg/builtin.hpp> #include <migraph/builtin.hpp>
#include <rtg/instruction_ref.hpp> #include <migraph/instruction_ref.hpp>
#include <rtg/erase.hpp> #include <migraph/operation.hpp>
#include <migraph/erase.hpp>
#include <string> #include <string>
namespace rtg { namespace migraph {
shape compute_shape(operation op, std::vector<instruction_ref> args); shape compute_shape(operation op, std::vector<instruction_ref> args);
...@@ -37,23 +38,33 @@ struct instruction ...@@ -37,23 +38,33 @@ struct instruction
result = r; result = r;
for(auto&& ins : output) for(auto&& ins : output)
{ {
ins->replace(compute_shape(ins->op, ins->arguments)); assert(ins->op.name().front() != '@');
ins->recompute_shape();
} }
} }
} }
void recompute_shape() { replace(compute_shape(op, arguments)); }
void replace(std::vector<instruction_ref> args) void replace(std::vector<instruction_ref> args)
{ {
clear_arguments(); clear_arguments();
arguments = std::move(args); arguments = std::move(args);
} }
void replace_argument(instruction_ref old, instruction_ref new_ins)
{
std::replace(arguments.begin(), arguments.end(), old, new_ins);
old->remove_output(*this);
}
void clear_arguments() void clear_arguments()
{ {
for(auto&& arg : arguments) for(auto&& arg : arguments)
{ {
rtg::erase(arg->output, *this); arg->remove_output(*this);
} }
arguments.clear();
} }
friend bool operator==(const instruction& i, instruction_ref ref) friend bool operator==(const instruction& i, instruction_ref ref)
...@@ -61,25 +72,64 @@ struct instruction ...@@ -61,25 +72,64 @@ struct instruction
return std::addressof(i) == std::addressof(*ref); return std::addressof(i) == std::addressof(*ref);
} }
bool valid(instruction_ref start) const
{
return valid() && std::all_of(arguments.begin(), arguments.end(), [&](instruction_ref i) {
auto self = std::find(i->output.begin(), i->output.end(), *this);
return self != i->output.end() &&
std::distance(start, i) < std::distance(start, *self);
});
}
bool valid() const bool valid() const
{ {
return std::all_of(output.begin(), shape computed;
output.end(), if(op.name() == "@literal")
[&](instruction_ref i) { {
return std::find(i->arguments.begin(), i->arguments.end(), *this) != computed = lit.get_shape();
i->arguments.end(); }
}) && else if(op.name() == "@param")
std::all_of(arguments.begin(), arguments.end(), [&](instruction_ref i) { {
return std::find(i->output.begin(), i->output.end(), *this) != i->output.end(); computed = result;
}
else
{
try
{
computed = compute_shape(op, arguments);
}
catch(migraph::exception&)
{
return false;
}
}
return result == computed &&
std::all_of(output.begin(), output.end(), [&](instruction_ref i) {
return std::find(i->arguments.begin(), i->arguments.end(), *this) !=
i->arguments.end();
}); });
} }
shape get_shape() const { return result; }
friend bool operator==(instruction_ref ref, const instruction& i) { return i == ref; } friend bool operator==(instruction_ref ref, const instruction& i) { return i == ref; }
friend bool operator!=(const instruction& i, instruction_ref ref) { return !(i == ref); } friend bool operator!=(const instruction& i, instruction_ref ref) { return !(i == ref); }
friend bool operator!=(instruction_ref ref, const instruction& i) { return !(i == ref); } friend bool operator!=(instruction_ref ref, const instruction& i) { return !(i == ref); }
void add_output(instruction_ref ins)
{
if(std::find(output.begin(), output.end(), ins) == output.end())
output.push_back(ins);
}
template <class T>
void remove_output(const T& ins)
{
migraph::erase(output, ins);
}
operation op; operation op;
shape result; shape result;
std::vector<instruction_ref> output; std::vector<instruction_ref> output;
...@@ -90,7 +140,14 @@ struct instruction ...@@ -90,7 +140,14 @@ struct instruction
inline void backreference(instruction_ref ref) inline void backreference(instruction_ref ref)
{ {
for(auto&& arg : ref->arguments) for(auto&& arg : ref->arguments)
arg->output.push_back(ref); arg->add_output(ref);
}
inline void replace_argument(instruction_ref ins, instruction_ref old, instruction_ref new_ins)
{
ins->replace_argument(old, new_ins);
backreference(ins);
ins->recompute_shape();
} }
// TODO: Move to a cpp file // TODO: Move to a cpp file
...@@ -103,6 +160,6 @@ inline shape compute_shape(operation op, std::vector<instruction_ref> args) ...@@ -103,6 +160,6 @@ inline shape compute_shape(operation op, std::vector<instruction_ref> args)
return op.compute_shape(shapes); return op.compute_shape(shapes);
} }
} // namespace rtg } // namespace migraph
#endif #endif
#ifndef RTG_GUARD_INSTRUCTION_REF_HPP #ifndef MIGRAPH_GUARD_INSTRUCTION_REF_HPP
#define RTG_GUARD_INSTRUCTION_REF_HPP #define MIGRAPH_GUARD_INSTRUCTION_REF_HPP
#include <list> #include <list>
namespace rtg { namespace migraph {
struct instruction; struct instruction;
using instruction_ref = std::list<instruction>::iterator; using instruction_ref = std::list<instruction>::iterator;
} // namespace rtg } // namespace migraph
#endif #endif
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment