Commit 402c66ab authored by Umang Yadav's avatar Umang Yadav
Browse files

use eliminate_data_type pass instead of eliminate_fp8 pass

parent 5423577a
......@@ -31,6 +31,72 @@
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
void insert_convert_to_supported_type(module& m,
instruction_ref ins,
migraphx::shape::type_t target_type,
std::set<migraphx::shape::type_t> unsupported_types)
{
migraphx::shape::type_t orig_type = ins->get_shape().type();
std::vector<instruction_ref> inputs = ins->inputs();
std::transform(inputs.begin(), inputs.end(), inputs.begin(), [&](const auto& i) {
if(contains(unsupported_types, i->get_shape().type()))
{
return m.insert_instruction(
ins,
migraphx::make_op("convert", {{"target_type", migraphx::to_value(target_type)}}),
i);
}
else
{
return i;
}
});
// if no change
if(inputs == ins->inputs())
return;
auto op = ins->get_operator();
auto attributes = op.attributes();
if(attributes.contains("general_data_type"))
{
op = make_op(attributes["general_data_type"].to<std::string>(), op.to_value());
}
auto new_ins = m.insert_instruction(ins, op, inputs);
if(orig_type == shape::tuple_type)
{
auto orig_outs = ins->outputs();
if(not std::all_of(orig_outs.begin(), orig_outs.end(), [&](const auto out_ins) {
return out_ins->name() == "get_tuple_elem";
}))
MIGRAPHX_THROW(
"eliminate_data_type: Instruction with tuple output doesn't have all its "
"usages as get_tuple_elem instruction");
std::transform(
orig_outs.begin(), orig_outs.end(), orig_outs.begin(), [&](const auto out_ins) {
auto gte_ins = m.insert_instruction(ins, out_ins->get_operator(), new_ins);
auto orig_out_type = out_ins->get_shape().type();
if(contains(unsupported_types, orig_out_type))
{
auto gte_convert = m.insert_instruction(
ins, make_op("convert", {{"target_type", orig_out_type}}), gte_ins);
return m.replace_instruction(out_ins, gte_convert);
}
else
{
return m.replace_instruction(out_ins, gte_ins);
}
});
}
else
{
auto convert_back_ins = m.insert_instruction(
ins,
migraphx::make_op("convert", {{"target_type", migraphx::to_value(orig_type)}}),
new_ins);
m.replace_instruction(ins, convert_back_ins);
}
}
void eliminate_data_type::apply(module& m) const
{
static const std::vector<std::string> skip_op_names = {"convert",
......@@ -42,31 +108,36 @@ void eliminate_data_type::apply(module& m) const
"scatternd_add",
"scatternd_mul",
"scatternd_none"};
for(auto ins : iterator_for(m))
if(unsupported_types.empty() and unsupported_types.empty())
{
return;
}
else if(not unsupported_fp8_ops.empty() and not unsupported_types.empty())
{
MIGRAPHX_THROW("eliminate_data_type: specify either unsupported FP8 ops or unsupported "
"data types not both.");
}
else if(unsupported_fp8_ops.empty())
{
for(auto ins : iterator_for(m))
{
if(ins->name()[0] == '@')
continue;
if(contains(skip_op_names, ins->name()))
continue;
insert_convert_to_supported_type(m, ins, target_type, unsupported_types);
}
}
else
{
if(ins->name()[0] == '@')
continue;
if(contains(skip_op_names, ins->name()))
continue;
auto inputs = ins->inputs();
std::transform(inputs.begin(), inputs.end(), inputs.begin(), [&](auto i) {
if(types.count(i->get_shape().type()) == 0)
return i;
return m.insert_instruction(ins, make_op("convert", {{"target_type", target_type}}), i);
});
if(inputs == ins->inputs())
continue;
auto op = ins->get_operator();
auto attributes = op.attributes();
if(attributes.contains("general_data_type"))
std::set<migraphx::shape::type_t> unsupported_fp8_types = {
migraphx::shape::fp8e4m3fnuz_type};
for(auto ins : iterator_for(m))
{
op = make_op(attributes["general_data_type"].to<std::string>(), op.to_value());
if(not contains(unsupported_fp8_ops, ins->name()))
continue;
insert_convert_to_supported_type(m, ins, target_type, unsupported_fp8_types);
}
auto old_type = ins->get_shape().type();
auto out = m.insert_instruction(ins, op, inputs);
auto convert =
m.insert_instruction(ins, make_op("convert", {{"target_type", old_type}}), out);
m.replace_instruction(ins, convert);
}
}
......
......@@ -40,7 +40,8 @@ struct module;
*/
struct MIGRAPHX_EXPORT eliminate_data_type
{
std::set<shape::type_t> types;
std::set<shape::type_t> unsupported_types;
std::set<std::string> unsupported_fp8_ops;
shape::type_t target_type;
std::string name() const { return "eliminate_data_type"; }
void apply(module& m) const;
......
......@@ -70,7 +70,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
return {normalize_ops{},
rewrite_quantization{},
dead_code_elimination{},
eliminate_data_type{unsupported_types, shape::type_t::float_type},
eliminate_data_type{unsupported_types, {}, shape::type_t::float_type},
dead_code_elimination{},
simplify_reshapes{},
eliminate_identity{},
......
......@@ -21,6 +21,7 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "migraphx/shape.hpp"
#include <migraphx/adjust_allocation.hpp>
#include <migraphx/auto_contiguous.hpp>
#include <migraphx/check_context.hpp>
......@@ -123,7 +124,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
simplify_qdq{},
enable_pass(not mlir_enabled(), rewrite_quantization{}),
dead_code_elimination{},
eliminate_data_type{unsupported_types, shape::type_t::float_type},
eliminate_data_type{unsupported_types, {}, shape::type_t::float_type},
simplify_reshapes{},
eliminate_identity{},
eliminate_pad{},
......@@ -142,7 +143,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
prefuse_ops{},
dead_code_elimination{},
auto_contiguous{},
eliminate_fp8{unsupported_fp8_ops},
eliminate_data_type{{}, unsupported_fp8_ops, shape::float_type},
dead_code_elimination{},
optimize_module{},
fuse_pointwise{},
......
......@@ -30,13 +30,15 @@
#include <test.hpp>
void run_pass(migraphx::module& m, std::set<migraphx::shape::type_t> types)
void run_pass(migraphx::module& m,
std::set<migraphx::shape::type_t> types,
std::set<std::string> unsupported_fp8_ops = {})
{
migraphx::run_passes(
m,
{migraphx::eliminate_data_type{std::move(types), migraphx::shape::float_type},
migraphx::eliminate_identity{},
migraphx::dead_code_elimination{}});
migraphx::run_passes(m,
{migraphx::eliminate_data_type{
std::move(types), unsupported_fp8_ops, migraphx::shape::float_type},
migraphx::eliminate_identity{},
migraphx::dead_code_elimination{}});
}
TEST_CASE(simple)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment