"git@developer.sourcefind.cn:chenpangpang/parler-tts.git" did not exist on "b1fb784463a02192c785cbcab4dfe4f9b20dce88"
Commit 2a0eec0a authored by Ahsan Saghir's avatar Ahsan Saghir
Browse files

Fix handling of inputs of different types to return

parent 2a2f82c7
...@@ -64,8 +64,12 @@ void autocast_fp8_pass::apply(module& m) const ...@@ -64,8 +64,12 @@ void autocast_fp8_pass::apply(module& m) const
migraphx::make_op("convert", {{"target_type", migraphx::to_value(target_type)}}), migraphx::make_op("convert", {{"target_type", migraphx::to_value(target_type)}}),
i)); i));
} }
else
{
new_inputs.push_back(m.insert_instruction(ins, migraphx::make_op("identity"), i));
}
} }
if(new_inputs.size()) if(new_inputs != inputs)
{ {
auto new_ins = m.insert_instruction(ins, ins->get_operator(), {new_inputs}); auto new_ins = m.insert_instruction(ins, ins->get_operator(), {new_inputs});
m.replace_instruction(ins, new_ins); m.replace_instruction(ins, new_ins);
......
...@@ -87,7 +87,7 @@ TEST_CASE(autocast_fp8_2) ...@@ -87,7 +87,7 @@ TEST_CASE(autocast_fp8_2)
EXPECT(m1 == m2); EXPECT(m1 == m2);
} }
// multiple inputs to return // multiple inputs (of same type) to return
TEST_CASE(autocast_fp8_3) TEST_CASE(autocast_fp8_3)
{ {
migraphx::module m1; migraphx::module m1;
...@@ -96,8 +96,7 @@ TEST_CASE(autocast_fp8_3) ...@@ -96,8 +96,7 @@ TEST_CASE(autocast_fp8_3)
auto y = m1.add_parameter("y", {migraphx::shape::fp8e4m3fnuz_type, {1}}); auto y = m1.add_parameter("y", {migraphx::shape::fp8e4m3fnuz_type, {1}});
auto sum = m1.add_instruction(migraphx::make_op("add"), x, y); auto sum = m1.add_instruction(migraphx::make_op("add"), x, y);
auto diff = m1.add_instruction(migraphx::make_op("sub"), x, y); auto diff = m1.add_instruction(migraphx::make_op("sub"), x, y);
auto result = m1.add_instruction(migraphx::make_op("concat", {{"axis", 0}}), sum, diff); m1.add_return({sum, diff});
m1.add_return({result});
} }
run_pass(m1); run_pass(m1);
...@@ -109,15 +108,46 @@ TEST_CASE(autocast_fp8_3) ...@@ -109,15 +108,46 @@ TEST_CASE(autocast_fp8_3)
auto x_fp8 = m2.add_instruction(migraphx::make_op("convert", {{"target_type", migraphx::shape::fp8e4m3fnuz_type}}), x_fp32); auto x_fp8 = m2.add_instruction(migraphx::make_op("convert", {{"target_type", migraphx::shape::fp8e4m3fnuz_type}}), x_fp32);
auto sum_fp8 = m2.add_instruction(migraphx::make_op("add"), x_fp8, y_fp8); auto sum_fp8 = m2.add_instruction(migraphx::make_op("add"), x_fp8, y_fp8);
auto diff_fp8 = m2.add_instruction(migraphx::make_op("sub"), x_fp8, y_fp8); auto diff_fp8 = m2.add_instruction(migraphx::make_op("sub"), x_fp8, y_fp8);
auto concat_fp8 = m2.add_instruction(migraphx::make_op("concat", {{"axis", 0}}), sum_fp8, diff_fp8); auto sum_fp32 = m2.add_instruction(migraphx::make_op("convert", {{"target_type", migraphx::shape::float_type}}), sum_fp8);
auto result_fp32 = m2.add_instruction(migraphx::make_op("convert", {{"target_type", migraphx::shape::float_type}}), concat_fp8); auto diff_fp32 = m2.add_instruction(migraphx::make_op("convert", {{"target_type", migraphx::shape::float_type}}), diff_fp8);
m2.add_return({result_fp32}); m2.add_return({sum_fp32, diff_fp32});
} }
EXPECT(m1 == m2); EXPECT(m1 == m2);
} }
// autocast pass does not do any changes // multiple inputs (of different types) to return
TEST_CASE(autocast_fp8_4) TEST_CASE(autocast_fp8_4)
{
migraphx::module m1;
{
auto x1 = m1.add_parameter("x1", {migraphx::shape::fp8e4m3fnuz_type, {1}});
auto y1 = m1.add_parameter("y1", {migraphx::shape::fp8e4m3fnuz_type, {1}});
auto x2 = m1.add_parameter("x2", {migraphx::shape::float_type, {1}});
auto y2 = m1.add_parameter("y2", {migraphx::shape::float_type, {1}});
auto sum1 = m1.add_instruction(migraphx::make_op("add"), x1, y1);
auto sum2 = m1.add_instruction(migraphx::make_op("add"), x2, y2);
m1.add_return({sum1, sum2});
}
run_pass(m1);
migraphx::module m2;
{
auto x2 = m2.add_parameter("x2", {migraphx::shape::float_type, {1}});
auto y2 = m2.add_parameter("y2", {migraphx::shape::float_type, {1}});
auto y1 = m2.add_parameter("y1", {migraphx::shape::float_type, {1}});
auto x1 = m2.add_parameter("x1", {migraphx::shape::float_type, {1}});
auto y1_fp8 = m2.add_instruction(migraphx::make_op("convert", {{"target_type", migraphx::shape::fp8e4m3fnuz_type}}), y1);
auto x1_fp8 = m2.add_instruction(migraphx::make_op("convert", {{"target_type", migraphx::shape::fp8e4m3fnuz_type}}), x1);
auto sum1 = m2.add_instruction(migraphx::make_op("add"), x1_fp8, y1_fp8);
auto sum2 = m2.add_instruction(migraphx::make_op("add"), x2, y2);
auto result_sum1 = m2.add_instruction(migraphx::make_op("convert", {{"target_type", migraphx::shape::float_type}}), sum1);
m2.add_return({result_sum1, sum2});
}
EXPECT(m1 == m2);
}
// autocast pass does not do any changes
TEST_CASE(autocast_fp8_5)
{ {
migraphx::module m1; migraphx::module m1;
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment