Commit 3d9bef13 authored by Paul's avatar Paul
Browse files

Move functions to cpp file

parent b263425a
...@@ -7,6 +7,7 @@ add_library(migraph ...@@ -7,6 +7,7 @@ add_library(migraph
fwd_conv_batchnorm_rewrite.cpp fwd_conv_batchnorm_rewrite.cpp
env.cpp env.cpp
generate.cpp generate.cpp
instruction.cpp
program.cpp program.cpp
shape.cpp shape.cpp
simplify_reshapes.cpp simplify_reshapes.cpp
......
...@@ -3,10 +3,8 @@ ...@@ -3,10 +3,8 @@
#include <migraph/literal.hpp> #include <migraph/literal.hpp>
#include <migraph/shape.hpp> #include <migraph/shape.hpp>
#include <migraph/builtin.hpp>
#include <migraph/instruction_ref.hpp> #include <migraph/instruction_ref.hpp>
#include <migraph/operation.hpp> #include <migraph/operation.hpp>
#include <migraph/erase.hpp>
#include <string> #include <string>
#include <utility> #include <utility>
...@@ -18,156 +16,61 @@ struct instruction ...@@ -18,156 +16,61 @@ struct instruction
{ {
instruction() {} instruction() {}
instruction(operation o, shape r, std::vector<instruction_ref> args) instruction(operation o, shape r, std::vector<instruction_ref> args);
: op(std::move(o)), result(std::move(r)), arguments(std::move(args))
{
}
instruction(literal l) : op(builtin::literal{}), result(l.get_shape()), lit(std::move(l)) {}
void replace(const shape& r) instruction(literal l);
{
if(r != result)
{
result = r;
for(auto&& ins : output)
{
assert(ins->name().front() != '@');
ins->recompute_shape();
}
}
}
void recompute_shape() { replace(compute_shape(op, arguments)); } void replace(const shape& r);
void clear_arguments() void recompute_shape();
{
for(auto&& arg : arguments)
{
arg->remove_output(*this);
}
arguments.clear();
}
friend bool operator==(const instruction& i, instruction_ref ref)
{
return std::addressof(i) == std::addressof(*ref);
}
bool valid(instruction_ref start) const void clear_arguments();
{
return valid() && std::all_of(arguments.begin(), arguments.end(), [&](instruction_ref i) { friend bool operator==(const instruction& i, instruction_ref ref);
auto self = std::find(i->outputs().begin(), i->outputs().end(), *this);
return self != i->outputs().end() &&
std::distance(start, i) < std::distance(start, *self);
});
}
bool valid() const bool valid(instruction_ref start) const;
{
shape computed;
if(op.name() == "@literal")
{
computed = lit.get_shape();
}
else if(op.name() == "@param")
{
computed = result;
}
else
{
try
{
computed = compute_shape(op, arguments);
}
catch(migraph::exception&)
{
return false;
}
}
return result == computed &&
std::all_of(output.begin(), output.end(), [&](instruction_ref i) {
return std::find(i->inputs().begin(), i->inputs().end(), *this) !=
i->inputs().end();
});
}
shape get_shape() const { return result; } bool valid() const;
const literal& get_literal() const
{ shape get_shape() const;
assert(op.name() == "@literal"); const literal& get_literal() const;
return lit;
}
const operation& get_operator() const { return op; } const operation& get_operator() const;
std::string name() const { return op.name(); } std::string name() const;
const std::vector<instruction_ref>& inputs() const { return arguments; } const std::vector<instruction_ref>& inputs() const;
const std::vector<instruction_ref>& outputs() const { return output; } const std::vector<instruction_ref>& outputs() const;
friend bool operator==(instruction_ref ref, const instruction& i) { return i == ref; } friend bool operator==(instruction_ref ref, const instruction& i);
friend bool operator!=(const instruction& i, instruction_ref ref) { return !(i == ref); } friend bool operator!=(const instruction& i, instruction_ref ref);
friend bool operator!=(instruction_ref ref, const instruction& i) { return !(i == ref); } friend bool operator!=(instruction_ref ref, const instruction& i);
void add_output(instruction_ref ins) void add_output(instruction_ref ins);
{
if(std::find(output.begin(), output.end(), ins) == output.end())
output.push_back(ins);
}
template <class T> template <class T>
void remove_output(const T& ins) void remove_output(const T& ins);
{
migraph::erase(output, ins);
}
static void backreference(instruction_ref ref) static void backreference(instruction_ref ref);
{
for(auto&& arg : ref->inputs())
arg->add_output(ref);
}
static void replace_argument(instruction_ref ins, instruction_ref old, instruction_ref new_ins) static void replace_argument(instruction_ref ins, instruction_ref old, instruction_ref new_ins);
{
ins->replace_argument(old, new_ins);
backreference(ins);
ins->recompute_shape();
}
static void static void
replace(instruction_ref ins, operation o, const shape& r, std::vector<instruction_ref> args) replace(instruction_ref ins, operation o, const shape& r, std::vector<instruction_ref> args);
{
ins->replace(std::move(o), r, std::move(args));
backreference(ins);
}
private: private:
// internal // internal
void replace(operation o, const shape& r, std::vector<instruction_ref> args) void replace(operation o, const shape& r, std::vector<instruction_ref> args);
{
op = std::move(o);
replace(r);
replace(std::move(args));
}
// internal // internal
void replace(std::vector<instruction_ref> args) void replace(std::vector<instruction_ref> args);
{
clear_arguments();
arguments = std::move(args);
}
// internal // internal
void replace_argument(instruction_ref old, instruction_ref new_ins) void replace_argument(instruction_ref old, instruction_ref new_ins);
{
std::replace(arguments.begin(), arguments.end(), old, new_ins);
old->remove_output(*this);
}
operation op; operation op;
shape result; shape result;
std::vector<instruction_ref> output; std::vector<instruction_ref> output;
...@@ -175,15 +78,6 @@ struct instruction ...@@ -175,15 +78,6 @@ struct instruction
literal lit; literal lit;
}; };
// TODO: Move to a cpp file
inline shape compute_shape(const operation& op, const std::vector<instruction_ref>& args)
{
std::vector<shape> shapes(args.size());
std::transform(
args.begin(), args.end(), shapes.begin(), [](instruction_ref i) { return i->get_shape(); });
return op.compute_shape(shapes);
}
} // namespace migraph } // namespace migraph
namespace std { namespace std {
......
#include <migraph/instruction.hpp>
#include <migraph/builtin.hpp>
#include <migraph/erase.hpp>
namespace migraph {
instruction::instruction(operation o, shape r, std::vector<instruction_ref> args)
: op(std::move(o)), result(std::move(r)), arguments(std::move(args))
{
}
instruction::instruction(literal l) : op(builtin::literal{}), result(l.get_shape()), lit(std::move(l)) {}
void instruction::replace(const shape& r)
{
if(r != result)
{
result = r;
for(auto&& ins : output)
{
assert(ins->name().front() != '@');
ins->recompute_shape();
}
}
}
void instruction::recompute_shape() { replace(compute_shape(op, arguments)); }
void instruction::clear_arguments()
{
for(auto&& arg : arguments)
{
arg->remove_output(*this);
}
arguments.clear();
}
bool operator==(const instruction& i, instruction_ref ref)
{
return std::addressof(i) == std::addressof(*ref);
}
bool instruction::valid(instruction_ref start) const
{
return valid() && std::all_of(arguments.begin(), arguments.end(), [&](instruction_ref i) {
auto self = std::find(i->outputs().begin(), i->outputs().end(), *this);
return self != i->outputs().end() &&
std::distance(start, i) < std::distance(start, *self);
});
}
bool instruction::valid() const
{
shape computed;
if(op.name() == "@literal")
{
computed = lit.get_shape();
}
else if(op.name() == "@param")
{
computed = result;
}
else
{
try
{
computed = compute_shape(op, arguments);
}
catch(migraph::exception&)
{
return false;
}
}
return result == computed &&
std::all_of(output.begin(), output.end(), [&](instruction_ref i) {
return std::find(i->inputs().begin(), i->inputs().end(), *this) !=
i->inputs().end();
});
}
shape instruction::get_shape() const { return result; }
const literal& instruction::get_literal() const
{
assert(op.name() == "@literal");
return lit;
}
const operation& instruction::get_operator() const { return op; }
std::string instruction::name() const { return op.name(); }
const std::vector<instruction_ref>& instruction::inputs() const { return arguments; }
const std::vector<instruction_ref>& instruction::outputs() const { return output; }
bool operator==(instruction_ref ref, const instruction& i) { return i == ref; }
bool operator!=(const instruction& i, instruction_ref ref) { return !(i == ref); }
bool operator!=(instruction_ref ref, const instruction& i) { return !(i == ref); }
void instruction::add_output(instruction_ref ins)
{
if(std::find(output.begin(), output.end(), ins) == output.end())
output.push_back(ins);
}
template <class T>
void instruction::remove_output(const T& ins)
{
migraph::erase(output, ins);
}
void instruction::backreference(instruction_ref ref)
{
for(auto&& arg : ref->inputs())
arg->add_output(ref);
}
void instruction::replace_argument(instruction_ref ins, instruction_ref old, instruction_ref new_ins)
{
ins->replace_argument(old, new_ins);
backreference(ins);
ins->recompute_shape();
}
void
instruction::replace(instruction_ref ins, operation o, const shape& r, std::vector<instruction_ref> args)
{
ins->replace(std::move(o), r, std::move(args));
backreference(ins);
}
void instruction::replace(operation o, const shape& r, std::vector<instruction_ref> args)
{
op = std::move(o);
replace(r);
replace(std::move(args));
}
void instruction::replace(std::vector<instruction_ref> args)
{
clear_arguments();
arguments = std::move(args);
}
void instruction::replace_argument(instruction_ref old, instruction_ref new_ins)
{
std::replace(arguments.begin(), arguments.end(), old, new_ins);
old->remove_output(*this);
}
shape compute_shape(const operation& op, const std::vector<instruction_ref>& args)
{
std::vector<shape> shapes(args.size());
std::transform(
args.begin(), args.end(), shapes.begin(), [](instruction_ref i) { return i->get_shape(); });
return op.compute_shape(shapes);
}
} // namespace migraph
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment