"docs/source/vscode:/vscode.git/clone" did not exist on "f69511ecc618330212e7148265e1c0323d2fa5cf"
Commit f33c7a48 authored by Paul's avatar Paul
Browse files

Formatting

parent 3d9bef13
...@@ -25,13 +25,13 @@ struct instruction ...@@ -25,13 +25,13 @@ struct instruction
void recompute_shape(); void recompute_shape();
void clear_arguments(); void clear_arguments();
friend bool operator==(const instruction& i, instruction_ref ref); friend bool operator==(const instruction& i, instruction_ref ref);
bool valid(instruction_ref start) const; bool valid(instruction_ref start) const;
bool valid() const; bool valid() const;
shape get_shape() const; shape get_shape() const;
const literal& get_literal() const; const literal& get_literal() const;
...@@ -67,10 +67,10 @@ struct instruction ...@@ -67,10 +67,10 @@ struct instruction
// internal // internal
void replace(std::vector<instruction_ref> args); void replace(std::vector<instruction_ref> args);
// internal // internal
void replace_argument(instruction_ref old, instruction_ref new_ins); void replace_argument(instruction_ref old, instruction_ref new_ins);
operation op; operation op;
shape result; shape result;
std::vector<instruction_ref> output; std::vector<instruction_ref> output;
......
...@@ -4,151 +4,156 @@ ...@@ -4,151 +4,156 @@
namespace migraph { namespace migraph {
instruction::instruction(operation o, shape r, std::vector<instruction_ref> args) instruction::instruction(operation o, shape r, std::vector<instruction_ref> args)
: op(std::move(o)), result(std::move(r)), arguments(std::move(args)) : op(std::move(o)), result(std::move(r)), arguments(std::move(args))
{ {
} }
instruction::instruction(literal l) : op(builtin::literal{}), result(l.get_shape()), lit(std::move(l)) {} instruction::instruction(literal l)
: op(builtin::literal{}), result(l.get_shape()), lit(std::move(l))
{
}
void instruction::replace(const shape& r) void instruction::replace(const shape& r)
{
if(r != result)
{ {
if(r != result) result = r;
for(auto&& ins : output)
{ {
result = r; assert(ins->name().front() != '@');
for(auto&& ins : output) ins->recompute_shape();
{
assert(ins->name().front() != '@');
ins->recompute_shape();
}
} }
} }
}
void instruction::recompute_shape() { replace(compute_shape(op, arguments)); } void instruction::recompute_shape() { replace(compute_shape(op, arguments)); }
void instruction::clear_arguments() void instruction::clear_arguments()
{
for(auto&& arg : arguments)
{ {
for(auto&& arg : arguments) arg->remove_output(*this);
{
arg->remove_output(*this);
}
arguments.clear();
} }
arguments.clear();
}
bool operator==(const instruction& i, instruction_ref ref)
{
return std::addressof(i) == std::addressof(*ref);
}
bool operator==(const instruction& i, instruction_ref ref) bool instruction::valid(instruction_ref start) const
{
return valid() && std::all_of(arguments.begin(), arguments.end(), [&](instruction_ref i) {
auto self = std::find(i->outputs().begin(), i->outputs().end(), *this);
return self != i->outputs().end() &&
std::distance(start, i) < std::distance(start, *self);
});
}
bool instruction::valid() const
{
shape computed;
if(op.name() == "@literal")
{ {
return std::addressof(i) == std::addressof(*ref); computed = lit.get_shape();
} }
else if(op.name() == "@param")
bool instruction::valid(instruction_ref start) const
{ {
return valid() && std::all_of(arguments.begin(), arguments.end(), [&](instruction_ref i) { computed = result;
auto self = std::find(i->outputs().begin(), i->outputs().end(), *this);
return self != i->outputs().end() &&
std::distance(start, i) < std::distance(start, *self);
});
} }
else
bool instruction::valid() const
{ {
shape computed; try
if(op.name() == "@literal")
{
computed = lit.get_shape();
}
else if(op.name() == "@param")
{ {
computed = result; computed = compute_shape(op, arguments);
} }
else catch(migraph::exception&)
{ {
try return false;
{
computed = compute_shape(op, arguments);
}
catch(migraph::exception&)
{
return false;
}
} }
return result == computed &&
std::all_of(output.begin(), output.end(), [&](instruction_ref i) {
return std::find(i->inputs().begin(), i->inputs().end(), *this) !=
i->inputs().end();
});
} }
return result == computed && std::all_of(output.begin(), output.end(), [&](instruction_ref i) {
return std::find(i->inputs().begin(), i->inputs().end(), *this) != i->inputs().end();
});
}
shape instruction::get_shape() const { return result; } shape instruction::get_shape() const { return result; }
const literal& instruction::get_literal() const const literal& instruction::get_literal() const
{ {
assert(op.name() == "@literal"); assert(op.name() == "@literal");
return lit; return lit;
} }
const operation& instruction::get_operator() const { return op; } const operation& instruction::get_operator() const { return op; }
std::string instruction::name() const { return op.name(); } std::string instruction::name() const { return op.name(); }
const std::vector<instruction_ref>& instruction::inputs() const { return arguments; } const std::vector<instruction_ref>& instruction::inputs() const { return arguments; }
const std::vector<instruction_ref>& instruction::outputs() const { return output; } const std::vector<instruction_ref>& instruction::outputs() const { return output; }
bool operator==(instruction_ref ref, const instruction& i) { return i == ref; } bool operator==(instruction_ref ref, const instruction& i) { return i == ref; }
bool operator!=(const instruction& i, instruction_ref ref) { return !(i == ref); } bool operator!=(const instruction& i, instruction_ref ref) { return !(i == ref); }
bool operator!=(instruction_ref ref, const instruction& i) { return !(i == ref); } bool operator!=(instruction_ref ref, const instruction& i) { return !(i == ref); }
void instruction::add_output(instruction_ref ins) void instruction::add_output(instruction_ref ins)
{ {
if(std::find(output.begin(), output.end(), ins) == output.end()) if(std::find(output.begin(), output.end(), ins) == output.end())
output.push_back(ins); output.push_back(ins);
} }
template <class T> template <class T>
void instruction::remove_output(const T& ins) void instruction::remove_output(const T& ins)
{ {
migraph::erase(output, ins); migraph::erase(output, ins);
} }
void instruction::backreference(instruction_ref ref) void instruction::backreference(instruction_ref ref)
{ {
for(auto&& arg : ref->inputs()) for(auto&& arg : ref->inputs())
arg->add_output(ref); arg->add_output(ref);
} }
void instruction::replace_argument(instruction_ref ins, instruction_ref old, instruction_ref new_ins) void instruction::replace_argument(instruction_ref ins,
{ instruction_ref old,
ins->replace_argument(old, new_ins); instruction_ref new_ins)
backreference(ins); {
ins->recompute_shape(); ins->replace_argument(old, new_ins);
} backreference(ins);
ins->recompute_shape();
}
void void instruction::replace(instruction_ref ins,
instruction::replace(instruction_ref ins, operation o, const shape& r, std::vector<instruction_ref> args) operation o,
{ const shape& r,
ins->replace(std::move(o), r, std::move(args)); std::vector<instruction_ref> args)
backreference(ins); {
} ins->replace(std::move(o), r, std::move(args));
backreference(ins);
}
void instruction::replace(operation o, const shape& r, std::vector<instruction_ref> args) void instruction::replace(operation o, const shape& r, std::vector<instruction_ref> args)
{ {
op = std::move(o); op = std::move(o);
replace(r); replace(r);
replace(std::move(args)); replace(std::move(args));
} }
void instruction::replace(std::vector<instruction_ref> args) void instruction::replace(std::vector<instruction_ref> args)
{ {
clear_arguments(); clear_arguments();
arguments = std::move(args); arguments = std::move(args);
} }
void instruction::replace_argument(instruction_ref old, instruction_ref new_ins) void instruction::replace_argument(instruction_ref old, instruction_ref new_ins)
{ {
std::replace(arguments.begin(), arguments.end(), old, new_ins); std::replace(arguments.begin(), arguments.end(), old, new_ins);
old->remove_output(*this); old->remove_output(*this);
} }
shape compute_shape(const operation& op, const std::vector<instruction_ref>& args) shape compute_shape(const operation& op, const std::vector<instruction_ref>& args)
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment