Commit 5423577a authored by Umang Yadav's avatar Umang Yadav
Browse files

use updated eliminate_fp8 pass

parent 381b2d9e
...@@ -24,8 +24,10 @@ ...@@ -24,8 +24,10 @@
#include <iterator> #include <iterator>
#include <utility> #include <utility>
#include <migraphx/eliminate_fp8.hpp> #include <migraphx/eliminate_fp8.hpp>
#include <migraphx/shape.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/algorithm.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp> #include <migraphx/iterator_for.hpp>
#include <migraphx/stringutils.hpp> #include <migraphx/stringutils.hpp>
...@@ -39,29 +41,67 @@ void eliminate_fp8::apply(module& m) const ...@@ -39,29 +41,67 @@ void eliminate_fp8::apply(module& m) const
{ {
for(auto ins : iterator_for(m)) for(auto ins : iterator_for(m))
{ {
if(not contains(op_names, ins->name()) or if(not contains(op_names, ins->name()))
ins->get_shape().type() != migraphx::shape::fp8e4m3fnuz_type)
continue; continue;
migraphx::shape::type_t orig_type = ins->get_shape().type(); migraphx::shape::type_t orig_type = ins->get_shape().type();
std::vector<instruction_ref> orig_inputs = ins->inputs(); std::vector<instruction_ref> inputs = ins->inputs();
std::vector<instruction_ref> new_inputs; migraphx::transform_if(
std::transform(orig_inputs.begin(), inputs.begin(),
orig_inputs.end(), inputs.end(),
std::back_inserter(new_inputs), inputs.begin(),
[&](const auto& i) { [&](const auto& i) { return i->get_shape().type() == shape::fp8e4m3fnuz_type; },
return m.insert_instruction( [&](const auto& i) {
ins, return m.insert_instruction(
migraphx::make_op( ins,
"convert", {{"target_type", migraphx::to_value(target_type)}}), migraphx::make_op("convert",
i); {{"target_type", migraphx::to_value(target_type)}}),
}); i);
});
if(inputs == ins->inputs())
{
return;
}
auto op = ins->get_operator();
auto attributes = op.attributes();
if(attributes.contains("general_data_type"))
{
op = make_op(attributes["general_data_type"].to<std::string>(), op.to_value());
}
auto new_ins = m.insert_instruction(ins, op, inputs);
if(orig_type == shape::tuple_type)
{
auto orig_outs = ins->outputs();
if(not std::all_of(orig_outs.begin(), orig_outs.end(), [&](const auto out_ins) {
return out_ins->name() == "get_tuple_elem";
}))
MIGRAPHX_THROW("EliminateFP8: Instruction with tuple output doesn't have all its "
"usages as get_tuple_elem instruction");
auto new_ins = m.insert_instruction(ins, ins->get_operator(), {new_inputs}); std::transform(
auto convert_back_ins = m.insert_instruction( orig_outs.begin(), orig_outs.end(), orig_outs.begin(), [&](const auto out_ins) {
ins, auto gte_ins = m.insert_instruction(ins, out_ins->get_operator(), new_ins);
migraphx::make_op("convert", {{"target_type", migraphx::to_value(orig_type)}}), if(out_ins->get_shape().type() == shape::type_t::fp8e4m3fnuz_type)
new_ins); {
m.replace_instruction(ins, convert_back_ins); auto gte_convert = m.insert_instruction(
ins,
make_op("convert", {{"target_type", shape::type_t::fp8e4m3fnuz_type}}),
gte_ins);
return m.replace_instruction(out_ins, gte_convert);
}
else
{
return m.replace_instruction(out_ins, gte_ins);
}
});
}
else
{
auto convert_back_ins = m.insert_instruction(
ins,
migraphx::make_op("convert", {{"target_type", migraphx::to_value(orig_type)}}),
new_ins);
m.replace_instruction(ins, convert_back_ins);
}
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment