Unverified Commit 1f827a7a authored by Umang Yadav's avatar Umang Yadav Committed by GitHub
Browse files

Fix assertion error during verify and make DCE work with tuples (#1857)

parent 6a303918
...@@ -49,8 +49,10 @@ void dead_code_elimination::apply(module& m) const ...@@ -49,8 +49,10 @@ void dead_code_elimination::apply(module& m) const
if(i == last) if(i == last)
break; break;
// Skip instruction with empty shape as output unless its [dynamic, builtin, undefined, // Skip instruction with empty shape as output unless its [dynamic, builtin, undefined,
// identity, allocate] // identity, allocate or tuple_type]
if((not i->get_shape().dynamic() and i->get_shape().elements() == 0) and if((not i->get_shape().dynamic() and
(i->get_shape().elements() == 0 and
i->get_shape().type() != migraphx::shape::tuple_type)) and
not(i->name().front() == '@') and not contains({"identity", "allocate"}, i->name()) and not(i->name().front() == '@') and not contains({"identity", "allocate"}, i->name()) and
not i->is_undefined()) not i->is_undefined())
continue; continue;
......
...@@ -326,6 +326,8 @@ instruction_ref module::replace_instruction(instruction_ref ins, instruction_ref ...@@ -326,6 +326,8 @@ instruction_ref module::replace_instruction(instruction_ref ins, instruction_ref
if(ins == std::prev(this->end())) if(ins == std::prev(this->end()))
{ {
// "rep" instruction could be used earlier in the program and moving it at the end
// may cause invalid program, therefore make an identity operation in this case.
return replace_instruction(ins, make_op("identity"), rep); return replace_instruction(ins, make_op("identity"), rep);
} }
......
...@@ -52,14 +52,6 @@ static void quantize_module(module& m, const std::vector<std::string>& ins_names ...@@ -52,14 +52,6 @@ static void quantize_module(module& m, const std::vector<std::string>& ins_names
auto mod_inputs = ins->module_inputs(); auto mod_inputs = ins->module_inputs();
auto s = ins->get_shape(); auto s = ins->get_shape();
// Convert back to original type before quantizing the inputs
if(mod_inputs.empty())
{
auto r = m.insert_instruction(
std::next(ins), make_op("convert", {{"target_type", s.type()}}), ins);
m.replace_instruction(ins, r);
}
// Convert each of the inputs that are floating point to fp16 // Convert each of the inputs that are floating point to fp16
auto inputs = ins->inputs(); auto inputs = ins->inputs();
std::transform(inputs.begin(), inputs.end(), inputs.begin(), [&](auto input) { std::transform(inputs.begin(), inputs.end(), inputs.begin(), [&](auto input) {
...@@ -70,8 +62,17 @@ static void quantize_module(module& m, const std::vector<std::string>& ins_names ...@@ -70,8 +62,17 @@ static void quantize_module(module& m, const std::vector<std::string>& ins_names
ins, make_op("convert", {{"target_type", shape::half_type}}), input); ins, make_op("convert", {{"target_type", shape::half_type}}), input);
}); });
// Replace inputs // Insert quantized ins
m.replace_instruction(ins, ins->get_operator(), inputs, mod_inputs); auto converted_ins = m.insert_instruction(ins, ins->get_operator(), inputs, mod_inputs);
// Convert back to original type after quantizing
if(mod_inputs.empty())
{
converted_ins = m.insert_instruction(
ins, make_op("convert", {{"target_type", s.type()}}), converted_ins);
}
// Replace original instruction
m.replace_instruction(ins, converted_ins);
} }
} }
......
...@@ -232,7 +232,6 @@ TEST_CASE(reused_twice) ...@@ -232,7 +232,6 @@ TEST_CASE(reused_twice)
auto count = std::distance(mm->begin(), mm->end()); auto count = std::distance(mm->begin(), mm->end());
run_pass(p); run_pass(p);
p.debug_print();
EXPECT(std::distance(mm->begin(), mm->end()) != count); EXPECT(std::distance(mm->begin(), mm->end()) != count);
EXPECT(std::distance(mm->begin(), mm->end()) == 4); EXPECT(std::distance(mm->begin(), mm->end()) == 4);
} }
...@@ -274,4 +273,17 @@ TEST_CASE(param_not_eliminated) ...@@ -274,4 +273,17 @@ TEST_CASE(param_not_eliminated)
EXPECT(p == create_program()); EXPECT(p == create_program());
} }
TEST_CASE(tuple_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
mm->add_instruction(tuple_op{}, one, two);
mm->add_return({one, two});
auto count = std::distance(mm->begin(), mm->end());
run_pass(p);
EXPECT(std::distance(mm->begin(), mm->end()) == (count - 1));
}
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -186,6 +186,21 @@ struct nop ...@@ -186,6 +186,21 @@ struct nop
migraphx::shape compute_shape(const std::vector<migraphx::shape>&) const { return {}; } migraphx::shape compute_shape(const std::vector<migraphx::shape>&) const { return {}; }
}; };
struct tuple_op
{
std::string name() const { return "tuple_op"; }
migraphx::shape compute_shape(const std::vector<migraphx::shape>& inputs) const
{
return {inputs};
}
migraphx::argument compute(migraphx::context&,
const migraphx::shape&,
const std::vector<migraphx::argument>& input_args) const
{
return input_args;
}
};
inline migraphx::literal get_2x2(int base = 0) inline migraphx::literal get_2x2(int base = 0)
{ {
return migraphx::literal{{migraphx::shape::float_type, {2, 2}}, return migraphx::literal{{migraphx::shape::float_type, {2, 2}},
......
...@@ -82,13 +82,17 @@ TEST_CASE(param_add) ...@@ -82,13 +82,17 @@ TEST_CASE(param_add)
auto hp1 = mm->add_instruction(migraphx::make_op("convert"), p1); auto hp1 = mm->add_instruction(migraphx::make_op("convert"), p1);
auto hp2 = mm->add_instruction(migraphx::make_op("convert"), p2); auto hp2 = mm->add_instruction(migraphx::make_op("convert"), p2);
auto hs = mm->add_instruction(migraphx::make_op("add"), hp1, hp2); auto hs = mm->add_instruction(migraphx::make_op("add"), hp1, hp2);
auto res = mm->add_instruction( auto fs = mm->add_instruction(
migraphx::make_op("convert", migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::float_type)}}), {{"target_type", migraphx::to_value(migraphx::shape::float_type)}}),
hs); hs);
if(add_return) if(add_return)
{ {
mm->add_return({res}); mm->add_return({fs});
}
else
{
mm->add_instruction(migraphx::make_op("identity"), {fs});
} }
return p; return p;
...@@ -159,10 +163,10 @@ TEST_CASE(param_add_sub) ...@@ -159,10 +163,10 @@ TEST_CASE(param_add_sub)
auto diff = mm->add_instruction(migraphx::make_op("sub"), sum, p2); auto diff = mm->add_instruction(migraphx::make_op("sub"), sum, p2);
auto hdiff = mm->add_instruction( auto hdiff = mm->add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), diff); migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), diff);
auto res = mm->add_instruction(migraphx::make_op("add"), hdiff, hp1); auto hadd = mm->add_instruction(migraphx::make_op("add"), hdiff, hp1);
auto r = mm->add_instruction( auto fadd = mm->add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::float_type}}), res); migraphx::make_op("convert", {{"target_type", migraphx::shape::float_type}}), hadd);
mm->add_return({r}); mm->add_return({fadd});
return p; return p;
}; };
...@@ -258,7 +262,8 @@ TEST_CASE(param_add_sub) ...@@ -258,7 +262,8 @@ TEST_CASE(param_add_sub)
}; };
auto p0 = create_program_float(); auto p0 = create_program_float();
migraphx::run_passes(p0, {migraphx::quantize_fp16_pass{{"all"}}}); migraphx::run_passes(
p0, {migraphx::quantize_fp16_pass{{"all"}}, migraphx::dead_code_elimination{}});
EXPECT(p0 == create_program_fp16()); EXPECT(p0 == create_program_fp16());
auto p1 = create_program_float(); auto p1 = create_program_float();
...@@ -278,7 +283,6 @@ TEST_CASE(literal_add) ...@@ -278,7 +283,6 @@ TEST_CASE(literal_add)
auto l1 = mm->add_literal(migraphx::literal(s, data)); auto l1 = mm->add_literal(migraphx::literal(s, data));
auto l2 = mm->add_literal(migraphx::literal(s, data)); auto l2 = mm->add_literal(migraphx::literal(s, data));
mm->add_instruction(migraphx::make_op("add"), l1, l2); mm->add_instruction(migraphx::make_op("add"), l1, l2);
return p; return p;
}; };
...@@ -291,11 +295,11 @@ TEST_CASE(literal_add) ...@@ -291,11 +295,11 @@ TEST_CASE(literal_add)
auto l1 = mm->add_literal(migraphx::literal(s, data)); auto l1 = mm->add_literal(migraphx::literal(s, data));
auto l2 = mm->add_literal(migraphx::literal(s, data)); auto l2 = mm->add_literal(migraphx::literal(s, data));
auto hs = mm->add_instruction(migraphx::make_op("add"), l1, l2); auto hs = mm->add_instruction(migraphx::make_op("add"), l1, l2);
mm->add_instruction( auto fs = mm->add_instruction(
migraphx::make_op("convert", migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::float_type)}}), {{"target_type", migraphx::to_value(migraphx::shape::float_type)}}),
hs); hs);
mm->add_instruction(migraphx::make_op("identity"), fs);
return p; return p;
}; };
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment