Commit cf91c2b1 authored by Umang Yadav's avatar Umang Yadav
Browse files

add changes for the eliminate_data_type pass

parent 7d6e6ad7
......@@ -108,36 +108,17 @@ void eliminate_data_type::apply(module& m) const
"scatternd_add",
"scatternd_mul",
"scatternd_none"};
if(unsupported_types.empty() and unsupported_fp8_ops.empty())
{
if(unsupported_types.empty())
return;
}
else if(not unsupported_fp8_ops.empty() and not unsupported_types.empty())
{
MIGRAPHX_THROW("eliminate_data_type: specify either unsupported FP8 ops or unsupported "
"data types not both.");
}
else if(unsupported_fp8_ops.empty())
for(auto ins : iterator_for(m))
{
for(auto ins : iterator_for(m))
{
if(ins->name()[0] == '@')
continue;
if(contains(skip_op_names, ins->name()))
continue;
if(ins->name()[0] == '@')
continue;
if(contains(skip_op_names, ins->name()) and not contains(unsupported_ops, ins->name()))
continue;
if(contains(unsupported_ops, "all") or contains(unsupported_ops, ins->name()))
insert_convert_to_supported_type(m, ins, target_type, unsupported_types);
}
}
else
{
std::set<migraphx::shape::type_t> unsupported_fp8_types = {
migraphx::shape::fp8e4m3fnuz_type};
for(auto ins : iterator_for(m))
{
if(not contains(unsupported_fp8_ops, ins->name()))
continue;
insert_convert_to_supported_type(m, ins, target_type, unsupported_fp8_types);
}
}
}
......
......@@ -41,8 +41,8 @@ struct module;
struct MIGRAPHX_EXPORT eliminate_data_type
{
std::set<shape::type_t> unsupported_types;
std::set<std::string> unsupported_fp8_ops;
shape::type_t target_type;
std::set<std::string> unsupported_ops = {"all"};
std::string name() const { return "eliminate_data_type"; }
void apply(module& m) const;
};
......
......@@ -70,7 +70,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
return {normalize_ops{},
rewrite_quantization{},
dead_code_elimination{},
eliminate_data_type{unsupported_types, {}, shape::type_t::float_type},
eliminate_data_type{unsupported_types, shape::type_t::float_type},
dead_code_elimination{},
simplify_reshapes{},
eliminate_identity{},
......
......@@ -122,7 +122,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
simplify_qdq{},
enable_pass(not mlir_enabled(), rewrite_quantization{}),
dead_code_elimination{},
eliminate_data_type{unsupported_types, {}, shape::type_t::float_type},
eliminate_data_type{unsupported_types, shape::type_t::float_type},
simplify_reshapes{},
eliminate_identity{},
eliminate_pad{},
......@@ -141,7 +141,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
prefuse_ops{},
dead_code_elimination{},
auto_contiguous{},
eliminate_data_type{{}, unsupported_fp8_ops, shape::float_type},
eliminate_data_type{{migraphx::shape::fp8e4m3fnuz_type}, shape::float_type, unsupported_fp8_ops},
dead_code_elimination{},
optimize_module{},
fuse_pointwise{},
......
......@@ -30,15 +30,13 @@
#include <test.hpp>
void run_pass(migraphx::module& m,
std::set<migraphx::shape::type_t> types,
std::set<std::string> unsupported_fp8_ops = {})
void run_pass(migraphx::module& m, std::set<migraphx::shape::type_t> types)
{
migraphx::run_passes(m,
{migraphx::eliminate_data_type{
std::move(types), unsupported_fp8_ops, migraphx::shape::float_type},
migraphx::eliminate_identity{},
migraphx::dead_code_elimination{}});
migraphx::run_passes(
m,
{migraphx::eliminate_data_type{std::move(types), migraphx::shape::float_type},
migraphx::eliminate_identity{},
migraphx::dead_code_elimination{}});
}
TEST_CASE(simple)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment