"tests/vscode:/vscode.git/clone" did not exist on "955b44191da945d9a3b9f9a1f8c5a57fc891e0c7"
Unverified Commit 1f827a7a authored by Umang Yadav's avatar Umang Yadav Committed by GitHub
Browse files

Fix assertion error during verify and make DCE work with tuples (#1857)

parent 6a303918
...@@ -49,8 +49,10 @@ void dead_code_elimination::apply(module& m) const ...@@ -49,8 +49,10 @@ void dead_code_elimination::apply(module& m) const
if(i == last) if(i == last)
break; break;
// Skip instruction with empty shape as output unless its [dynamic, builtin, undefined, // Skip instruction with empty shape as output unless its [dynamic, builtin, undefined,
// identity, allocate] // identity, allocate or tuple_type]
if((not i->get_shape().dynamic() and i->get_shape().elements() == 0) and if((not i->get_shape().dynamic() and
(i->get_shape().elements() == 0 and
i->get_shape().type() != migraphx::shape::tuple_type)) and
not(i->name().front() == '@') and not contains({"identity", "allocate"}, i->name()) and not(i->name().front() == '@') and not contains({"identity", "allocate"}, i->name()) and
not i->is_undefined()) not i->is_undefined())
continue; continue;
......
...@@ -326,6 +326,8 @@ instruction_ref module::replace_instruction(instruction_ref ins, instruction_ref ...@@ -326,6 +326,8 @@ instruction_ref module::replace_instruction(instruction_ref ins, instruction_ref
if(ins == std::prev(this->end())) if(ins == std::prev(this->end()))
{ {
// "rep" instruction could be used earlier in the program and moving it at the end
// may cause invalid program, therefore make an identity operation in this case.
return replace_instruction(ins, make_op("identity"), rep); return replace_instruction(ins, make_op("identity"), rep);
} }
......
...@@ -52,14 +52,6 @@ static void quantize_module(module& m, const std::vector<std::string>& ins_names ...@@ -52,14 +52,6 @@ static void quantize_module(module& m, const std::vector<std::string>& ins_names
auto mod_inputs = ins->module_inputs(); auto mod_inputs = ins->module_inputs();
auto s = ins->get_shape(); auto s = ins->get_shape();
// Convert back to original type before quantizing the inputs
if(mod_inputs.empty())
{
auto r = m.insert_instruction(
std::next(ins), make_op("convert", {{"target_type", s.type()}}), ins);
m.replace_instruction(ins, r);
}
// Convert each of the inputs that are floating point to fp16 // Convert each of the inputs that are floating point to fp16
auto inputs = ins->inputs(); auto inputs = ins->inputs();
std::transform(inputs.begin(), inputs.end(), inputs.begin(), [&](auto input) { std::transform(inputs.begin(), inputs.end(), inputs.begin(), [&](auto input) {
...@@ -70,8 +62,17 @@ static void quantize_module(module& m, const std::vector<std::string>& ins_names ...@@ -70,8 +62,17 @@ static void quantize_module(module& m, const std::vector<std::string>& ins_names
ins, make_op("convert", {{"target_type", shape::half_type}}), input); ins, make_op("convert", {{"target_type", shape::half_type}}), input);
}); });
// Replace inputs // Insert quantized ins
m.replace_instruction(ins, ins->get_operator(), inputs, mod_inputs); auto converted_ins = m.insert_instruction(ins, ins->get_operator(), inputs, mod_inputs);
// Convert back to original type after quantizing
if(mod_inputs.empty())
{
converted_ins = m.insert_instruction(
ins, make_op("convert", {{"target_type", s.type()}}), converted_ins);
}
// Replace original instruction
m.replace_instruction(ins, converted_ins);
} }
} }
......
...@@ -232,7 +232,6 @@ TEST_CASE(reused_twice) ...@@ -232,7 +232,6 @@ TEST_CASE(reused_twice)
auto count = std::distance(mm->begin(), mm->end()); auto count = std::distance(mm->begin(), mm->end());
run_pass(p); run_pass(p);
p.debug_print();
EXPECT(std::distance(mm->begin(), mm->end()) != count); EXPECT(std::distance(mm->begin(), mm->end()) != count);
EXPECT(std::distance(mm->begin(), mm->end()) == 4); EXPECT(std::distance(mm->begin(), mm->end()) == 4);
} }
...@@ -274,4 +273,17 @@ TEST_CASE(param_not_eliminated) ...@@ -274,4 +273,17 @@ TEST_CASE(param_not_eliminated)
EXPECT(p == create_program()); EXPECT(p == create_program());
} }
TEST_CASE(tuple_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
mm->add_instruction(tuple_op{}, one, two);
mm->add_return({one, two});
auto count = std::distance(mm->begin(), mm->end());
run_pass(p);
EXPECT(std::distance(mm->begin(), mm->end()) == (count - 1));
}
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -186,6 +186,21 @@ struct nop ...@@ -186,6 +186,21 @@ struct nop
migraphx::shape compute_shape(const std::vector<migraphx::shape>&) const { return {}; } migraphx::shape compute_shape(const std::vector<migraphx::shape>&) const { return {}; }
}; };
struct tuple_op
{
std::string name() const { return "tuple_op"; }
migraphx::shape compute_shape(const std::vector<migraphx::shape>& inputs) const
{
return {inputs};
}
migraphx::argument compute(migraphx::context&,
const migraphx::shape&,
const std::vector<migraphx::argument>& input_args) const
{
return input_args;
}
};
inline migraphx::literal get_2x2(int base = 0) inline migraphx::literal get_2x2(int base = 0)
{ {
return migraphx::literal{{migraphx::shape::float_type, {2, 2}}, return migraphx::literal{{migraphx::shape::float_type, {2, 2}},
......
...@@ -82,13 +82,17 @@ TEST_CASE(param_add) ...@@ -82,13 +82,17 @@ TEST_CASE(param_add)
auto hp1 = mm->add_instruction(migraphx::make_op("convert"), p1); auto hp1 = mm->add_instruction(migraphx::make_op("convert"), p1);
auto hp2 = mm->add_instruction(migraphx::make_op("convert"), p2); auto hp2 = mm->add_instruction(migraphx::make_op("convert"), p2);
auto hs = mm->add_instruction(migraphx::make_op("add"), hp1, hp2); auto hs = mm->add_instruction(migraphx::make_op("add"), hp1, hp2);
auto res = mm->add_instruction( auto fs = mm->add_instruction(
migraphx::make_op("convert", migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::float_type)}}), {{"target_type", migraphx::to_value(migraphx::shape::float_type)}}),
hs); hs);
if(add_return) if(add_return)
{ {
mm->add_return({res}); mm->add_return({fs});
}
else
{
mm->add_instruction(migraphx::make_op("identity"), {fs});
} }
return p; return p;
...@@ -159,10 +163,10 @@ TEST_CASE(param_add_sub) ...@@ -159,10 +163,10 @@ TEST_CASE(param_add_sub)
auto diff = mm->add_instruction(migraphx::make_op("sub"), sum, p2); auto diff = mm->add_instruction(migraphx::make_op("sub"), sum, p2);
auto hdiff = mm->add_instruction( auto hdiff = mm->add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), diff); migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), diff);
auto res = mm->add_instruction(migraphx::make_op("add"), hdiff, hp1); auto hadd = mm->add_instruction(migraphx::make_op("add"), hdiff, hp1);
auto r = mm->add_instruction( auto fadd = mm->add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::float_type}}), res); migraphx::make_op("convert", {{"target_type", migraphx::shape::float_type}}), hadd);
mm->add_return({r}); mm->add_return({fadd});
return p; return p;
}; };
...@@ -258,7 +262,8 @@ TEST_CASE(param_add_sub) ...@@ -258,7 +262,8 @@ TEST_CASE(param_add_sub)
}; };
auto p0 = create_program_float(); auto p0 = create_program_float();
migraphx::run_passes(p0, {migraphx::quantize_fp16_pass{{"all"}}}); migraphx::run_passes(
p0, {migraphx::quantize_fp16_pass{{"all"}}, migraphx::dead_code_elimination{}});
EXPECT(p0 == create_program_fp16()); EXPECT(p0 == create_program_fp16());
auto p1 = create_program_float(); auto p1 = create_program_float();
...@@ -278,7 +283,6 @@ TEST_CASE(literal_add) ...@@ -278,7 +283,6 @@ TEST_CASE(literal_add)
auto l1 = mm->add_literal(migraphx::literal(s, data)); auto l1 = mm->add_literal(migraphx::literal(s, data));
auto l2 = mm->add_literal(migraphx::literal(s, data)); auto l2 = mm->add_literal(migraphx::literal(s, data));
mm->add_instruction(migraphx::make_op("add"), l1, l2); mm->add_instruction(migraphx::make_op("add"), l1, l2);
return p; return p;
}; };
...@@ -291,11 +295,11 @@ TEST_CASE(literal_add) ...@@ -291,11 +295,11 @@ TEST_CASE(literal_add)
auto l1 = mm->add_literal(migraphx::literal(s, data)); auto l1 = mm->add_literal(migraphx::literal(s, data));
auto l2 = mm->add_literal(migraphx::literal(s, data)); auto l2 = mm->add_literal(migraphx::literal(s, data));
auto hs = mm->add_instruction(migraphx::make_op("add"), l1, l2); auto hs = mm->add_instruction(migraphx::make_op("add"), l1, l2);
mm->add_instruction( auto fs = mm->add_instruction(
migraphx::make_op("convert", migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::float_type)}}), {{"target_type", migraphx::to_value(migraphx::shape::float_type)}}),
hs); hs);
mm->add_instruction(migraphx::make_op("identity"), fs);
return p; return p;
}; };
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment