Commit 2a0eec0a authored by Ahsan Saghir's avatar Ahsan Saghir
Browse files

Fix handling of inputs of different types to return

parent 2a2f82c7
......@@ -64,8 +64,12 @@ void autocast_fp8_pass::apply(module& m) const
migraphx::make_op("convert", {{"target_type", migraphx::to_value(target_type)}}),
i));
}
else
{
new_inputs.push_back(m.insert_instruction(ins, migraphx::make_op("identity"), i));
}
}
if(new_inputs.size())
if(new_inputs != inputs)
{
auto new_ins = m.insert_instruction(ins, ins->get_operator(), {new_inputs});
m.replace_instruction(ins, new_ins);
......
......@@ -87,7 +87,7 @@ TEST_CASE(autocast_fp8_2)
EXPECT(m1 == m2);
}
// multiple inputs to return
// multiple inputs (of same type) to return
TEST_CASE(autocast_fp8_3)
{
migraphx::module m1;
......@@ -96,8 +96,7 @@ TEST_CASE(autocast_fp8_3)
auto y = m1.add_parameter("y", {migraphx::shape::fp8e4m3fnuz_type, {1}});
auto sum = m1.add_instruction(migraphx::make_op("add"), x, y);
auto diff = m1.add_instruction(migraphx::make_op("sub"), x, y);
auto result = m1.add_instruction(migraphx::make_op("concat", {{"axis", 0}}), sum, diff);
m1.add_return({result});
m1.add_return({sum, diff});
}
run_pass(m1);
......@@ -109,15 +108,46 @@ TEST_CASE(autocast_fp8_3)
auto x_fp8 = m2.add_instruction(migraphx::make_op("convert", {{"target_type", migraphx::shape::fp8e4m3fnuz_type}}), x_fp32);
auto sum_fp8 = m2.add_instruction(migraphx::make_op("add"), x_fp8, y_fp8);
auto diff_fp8 = m2.add_instruction(migraphx::make_op("sub"), x_fp8, y_fp8);
auto concat_fp8 = m2.add_instruction(migraphx::make_op("concat", {{"axis", 0}}), sum_fp8, diff_fp8);
auto result_fp32 = m2.add_instruction(migraphx::make_op("convert", {{"target_type", migraphx::shape::float_type}}), concat_fp8);
m2.add_return({result_fp32});
auto sum_fp32 = m2.add_instruction(migraphx::make_op("convert", {{"target_type", migraphx::shape::float_type}}), sum_fp8);
auto diff_fp32 = m2.add_instruction(migraphx::make_op("convert", {{"target_type", migraphx::shape::float_type}}), diff_fp8);
m2.add_return({sum_fp32, diff_fp32});
}
EXPECT(m1 == m2);
}
// autocast pass does not do any changes
// multiple inputs (of different types) to return
TEST_CASE(autocast_fp8_4)
{
migraphx::module m1;
{
auto x1 = m1.add_parameter("x1", {migraphx::shape::fp8e4m3fnuz_type, {1}});
auto y1 = m1.add_parameter("y1", {migraphx::shape::fp8e4m3fnuz_type, {1}});
auto x2 = m1.add_parameter("x2", {migraphx::shape::float_type, {1}});
auto y2 = m1.add_parameter("y2", {migraphx::shape::float_type, {1}});
auto sum1 = m1.add_instruction(migraphx::make_op("add"), x1, y1);
auto sum2 = m1.add_instruction(migraphx::make_op("add"), x2, y2);
m1.add_return({sum1, sum2});
}
run_pass(m1);
migraphx::module m2;
{
auto x2 = m2.add_parameter("x2", {migraphx::shape::float_type, {1}});
auto y2 = m2.add_parameter("y2", {migraphx::shape::float_type, {1}});
auto y1 = m2.add_parameter("y1", {migraphx::shape::float_type, {1}});
auto x1 = m2.add_parameter("x1", {migraphx::shape::float_type, {1}});
auto y1_fp8 = m2.add_instruction(migraphx::make_op("convert", {{"target_type", migraphx::shape::fp8e4m3fnuz_type}}), y1);
auto x1_fp8 = m2.add_instruction(migraphx::make_op("convert", {{"target_type", migraphx::shape::fp8e4m3fnuz_type}}), x1);
auto sum1 = m2.add_instruction(migraphx::make_op("add"), x1_fp8, y1_fp8);
auto sum2 = m2.add_instruction(migraphx::make_op("add"), x2, y2);
auto result_sum1 = m2.add_instruction(migraphx::make_op("convert", {{"target_type", migraphx::shape::float_type}}), sum1);
m2.add_return({result_sum1, sum2});
}
EXPECT(m1 == m2);
}
// autocast pass does not do any changes
TEST_CASE(autocast_fp8_5)
{
migraphx::module m1;
{
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment