Commit cf91c2b1 authored by Umang Yadav's avatar Umang Yadav
Browse files

add changes for the eliminate_data_type pass

parent 7d6e6ad7
...@@ -108,37 +108,18 @@ void eliminate_data_type::apply(module& m) const ...@@ -108,37 +108,18 @@ void eliminate_data_type::apply(module& m) const
"scatternd_add", "scatternd_add",
"scatternd_mul", "scatternd_mul",
"scatternd_none"}; "scatternd_none"};
if(unsupported_types.empty() and unsupported_fp8_ops.empty()) if(unsupported_types.empty())
{
return; return;
}
else if(not unsupported_fp8_ops.empty() and not unsupported_types.empty())
{
MIGRAPHX_THROW("eliminate_data_type: specify either unsupported FP8 ops or unsupported "
"data types not both.");
}
else if(unsupported_fp8_ops.empty())
{
for(auto ins : iterator_for(m)) for(auto ins : iterator_for(m))
{ {
if(ins->name()[0] == '@') if(ins->name()[0] == '@')
continue; continue;
if(contains(skip_op_names, ins->name())) if(contains(skip_op_names, ins->name()) and not contains(unsupported_ops, ins->name()))
continue; continue;
if(contains(unsupported_ops, "all") or contains(unsupported_ops, ins->name()))
insert_convert_to_supported_type(m, ins, target_type, unsupported_types); insert_convert_to_supported_type(m, ins, target_type, unsupported_types);
} }
}
else
{
std::set<migraphx::shape::type_t> unsupported_fp8_types = {
migraphx::shape::fp8e4m3fnuz_type};
for(auto ins : iterator_for(m))
{
if(not contains(unsupported_fp8_ops, ins->name()))
continue;
insert_convert_to_supported_type(m, ins, target_type, unsupported_fp8_types);
}
}
} }
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -41,8 +41,8 @@ struct module; ...@@ -41,8 +41,8 @@ struct module;
struct MIGRAPHX_EXPORT eliminate_data_type struct MIGRAPHX_EXPORT eliminate_data_type
{ {
std::set<shape::type_t> unsupported_types; std::set<shape::type_t> unsupported_types;
std::set<std::string> unsupported_fp8_ops;
shape::type_t target_type; shape::type_t target_type;
std::set<std::string> unsupported_ops = {"all"};
std::string name() const { return "eliminate_data_type"; } std::string name() const { return "eliminate_data_type"; }
void apply(module& m) const; void apply(module& m) const;
}; };
......
...@@ -70,7 +70,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti ...@@ -70,7 +70,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
return {normalize_ops{}, return {normalize_ops{},
rewrite_quantization{}, rewrite_quantization{},
dead_code_elimination{}, dead_code_elimination{},
eliminate_data_type{unsupported_types, {}, shape::type_t::float_type}, eliminate_data_type{unsupported_types, shape::type_t::float_type},
dead_code_elimination{}, dead_code_elimination{},
simplify_reshapes{}, simplify_reshapes{},
eliminate_identity{}, eliminate_identity{},
......
...@@ -122,7 +122,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti ...@@ -122,7 +122,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
simplify_qdq{}, simplify_qdq{},
enable_pass(not mlir_enabled(), rewrite_quantization{}), enable_pass(not mlir_enabled(), rewrite_quantization{}),
dead_code_elimination{}, dead_code_elimination{},
eliminate_data_type{unsupported_types, {}, shape::type_t::float_type}, eliminate_data_type{unsupported_types, shape::type_t::float_type},
simplify_reshapes{}, simplify_reshapes{},
eliminate_identity{}, eliminate_identity{},
eliminate_pad{}, eliminate_pad{},
...@@ -141,7 +141,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti ...@@ -141,7 +141,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
prefuse_ops{}, prefuse_ops{},
dead_code_elimination{}, dead_code_elimination{},
auto_contiguous{}, auto_contiguous{},
eliminate_data_type{{}, unsupported_fp8_ops, shape::float_type}, eliminate_data_type{{migraphx::shape::fp8e4m3fnuz_type}, shape::float_type, unsupported_fp8_ops},
dead_code_elimination{}, dead_code_elimination{},
optimize_module{}, optimize_module{},
fuse_pointwise{}, fuse_pointwise{},
......
...@@ -30,13 +30,11 @@ ...@@ -30,13 +30,11 @@
#include <test.hpp> #include <test.hpp>
void run_pass(migraphx::module& m, void run_pass(migraphx::module& m, std::set<migraphx::shape::type_t> types)
std::set<migraphx::shape::type_t> types,
std::set<std::string> unsupported_fp8_ops = {})
{ {
migraphx::run_passes(m, migraphx::run_passes(
{migraphx::eliminate_data_type{ m,
std::move(types), unsupported_fp8_ops, migraphx::shape::float_type}, {migraphx::eliminate_data_type{std::move(types), migraphx::shape::float_type},
migraphx::eliminate_identity{}, migraphx::eliminate_identity{},
migraphx::dead_code_elimination{}}); migraphx::dead_code_elimination{}});
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment