Commit a22b422a authored by Paul's avatar Paul
Browse files

Add test for operand alias

parent d0575532
...@@ -69,6 +69,8 @@ struct instruction ...@@ -69,6 +69,8 @@ struct instruction
static void static void
replace(instruction_ref ins, operation o, const shape& r, std::vector<instruction_ref> args); replace(instruction_ref ins, operation o, const shape& r, std::vector<instruction_ref> args);
static instruction_ref get_output_alias(instruction_ref ins);
private: private:
// internal // internal
void replace(operation o, const shape& r, std::vector<instruction_ref> args); void replace(operation o, const shape& r, std::vector<instruction_ref> args);
......
...@@ -161,12 +161,24 @@ void instruction::replace_argument(instruction_ref old, instruction_ref new_ins) ...@@ -161,12 +161,24 @@ void instruction::replace_argument(instruction_ref old, instruction_ref new_ins)
old->remove_output(*this); old->remove_output(*this);
} }
shape compute_shape(const operation& op, const std::vector<instruction_ref>& args) std::vector<shape> compute_shapes(const std::vector<instruction_ref>& args)
{ {
std::vector<shape> shapes(args.size()); std::vector<shape> shapes(args.size());
std::transform( std::transform(
args.begin(), args.end(), shapes.begin(), [](instruction_ref i) { return i->get_shape(); }); args.begin(), args.end(), shapes.begin(), [](instruction_ref i) { return i->get_shape(); });
return op.compute_shape(shapes); return shapes;
}
instruction_ref instruction::get_output_alias(instruction_ref ins)
{
auto i = ins->get_operator().output_alias(compute_shapes(ins->inputs()));
if(i < 0) return ins;
return get_output_alias(ins->inputs().at(i));
}
shape compute_shape(const operation& op, const std::vector<instruction_ref>& args)
{
return op.compute_shape(compute_shapes(args));
} }
} // namespace migraph } // namespace migraph
...@@ -79,6 +79,7 @@ struct pass_op ...@@ -79,6 +79,7 @@ struct pass_op
return {}; return {};
return inputs.front(); return inputs.front();
} }
int output_alias(const std::vector<migraph::shape>&) const { return 0; }
}; };
struct pass_standard_op struct pass_standard_op
...@@ -103,6 +104,7 @@ struct pass_standard_op ...@@ -103,6 +104,7 @@ struct pass_standard_op
return {}; return {};
return inputs.front(); return inputs.front();
} }
int output_alias(const std::vector<migraph::shape>&) const { return 0; }
}; };
struct nop struct nop
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment