Commit 343483a5 authored by charlie's avatar charlie
Browse files

Refactor to make more clear

parent e3ac3847
...@@ -77,10 +77,10 @@ bool operator==(const instruction& i, instruction_ref ref) ...@@ -77,10 +77,10 @@ bool operator==(const instruction& i, instruction_ref ref)
bool instruction::valid(const module& m, bool check_order) const bool instruction::valid(const module& m, bool check_order) const
{ {
return valid() && std::all_of(arguments.begin(), arguments.end(), [&](instruction_ref i) { return valid() && std::all_of(arguments.begin(), arguments.end(), [&](instruction_ref i) {
(*i).debug_print();
auto self = std::find(i->outputs().begin(), i->outputs().end(), *this); auto self = std::find(i->outputs().begin(), i->outputs().end(), *this);
bool ret = self != i->outputs().end(); bool ret = self != i->outputs().end();
// assume argument is in previous module if m.has_instruction(i) is false if(check_order)
if(check_order and m.has_instruction(i))
{ {
// check arguments for this instruction before this instruction // check arguments for this instruction before this instruction
ret = ret and (std::distance(m.begin(), i) < std::distance(m.begin(), *self)); ret = ret and (std::distance(m.begin(), i) < std::distance(m.begin(), *self));
......
...@@ -505,14 +505,22 @@ std::vector<shape> module::get_output_shapes() const ...@@ -505,14 +505,22 @@ std::vector<shape> module::get_output_shapes() const
instruction_ref module::validate() const instruction_ref module::validate() const
{ {
return std::find_if(
impl->instructions.begin(), impl->instructions.end(), [&](const instruction& i) { auto check_invalid = [&](instruction_ref i) {
auto inputs = i.inputs(); auto inputs = (*i).inputs();
bool check_order = std::all_of(inputs.begin(), inputs.end(), [&](auto in) { bool check_order = std::all_of(
return contains(impl->instructions, *in); inputs.begin(), inputs.end(), [&](instruction_ref in) { return has_instruction(in); });
}); return not(*i).valid(*this, check_order);
return !i.valid(*this, check_order); };
});
for(instruction_ref i = impl->instructions.begin(); i != impl->instructions.end(); ++i)
{
if(check_invalid(i))
{
return i;
}
}
return impl->instructions.end();
} }
bool is_borrowed(instruction_ref ins) bool is_borrowed(instruction_ref ins)
......
...@@ -154,6 +154,8 @@ void program::compile(const target& t, compile_options options) ...@@ -154,6 +154,8 @@ void program::compile(const target& t, compile_options options)
auto mods = this->get_modules(); auto mods = this->get_modules();
std::cout << "mods size: " << mods.size() << std::endl;
// Validate and finalize // Validate and finalize
for(const auto& mod : reverse(mods)) for(const auto& mod : reverse(mods))
{ {
......
...@@ -319,10 +319,12 @@ TEST_CASE(multiple_module_dependency) ...@@ -319,10 +319,12 @@ TEST_CASE(multiple_module_dependency)
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto* sub = p.create_module("sub"); auto* sub = p.create_module("sub");
auto l1 = mm->add_literal(migraphx::literal(3)); auto l1 = mm->add_literal(migraphx::literal(3));
sub->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 2 ,3}}}), l1); // second same literal to make sure instruction_ref is being compared, rahter than the
p.compile(migraphx::ref::target{}); // instructions
p.eval({}); sub->add_literal(migraphx::literal(3));
sub->add_instruction(sum_op{}, l1, l1);
sub->validate();
} }
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment