Commit 402c66ab authored by Umang Yadav's avatar Umang Yadav
Browse files

use eliminate_data_type pass instead of eliminate_fp8 pass

parent 5423577a
...@@ -31,6 +31,72 @@ ...@@ -31,6 +31,72 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
/// Rewrites `ins` so that it no longer consumes inputs of an unsupported type.
///
/// Every input of `ins` whose shape type is contained in `unsupported_types` is routed
/// through a `convert` to `target_type`, inserted in front of `ins`. When no input needs
/// converting, the module is left untouched. Otherwise the operator is re-inserted on the
/// converted inputs (swapped for the operator named by its "general_data_type" attribute
/// when that attribute is present) and the result is converted back:
///  - non-tuple result: a single `convert` back to the original result type replaces `ins`;
///  - tuple result: each user of `ins` is re-created on the new instruction, and a
///    `convert` back is added only for users whose type is in `unsupported_types`.
///
/// Throws via MIGRAPHX_THROW when `ins` has a tuple result and any of its users is not a
/// `get_tuple_elem` instruction.
void insert_convert_to_supported_type(module& m,
                                      instruction_ref ins,
                                      migraphx::shape::type_t target_type,
                                      std::set<migraphx::shape::type_t> unsupported_types)
{
    migraphx::shape::type_t result_type = ins->get_shape().type();

    // Work on a copy of the input list so it can be compared with the original afterwards.
    std::vector<instruction_ref> converted_inputs = ins->inputs();
    for(auto& arg : converted_inputs)
    {
        if(not contains(unsupported_types, arg->get_shape().type()))
            continue;
        arg = m.insert_instruction(
            ins,
            migraphx::make_op("convert", {{"target_type", migraphx::to_value(target_type)}}),
            arg);
    }

    // Nothing was converted, so there is nothing to rewrite.
    if(converted_inputs == ins->inputs())
        return;

    auto replacement_op = ins->get_operator();
    auto op_attrs       = replacement_op.attributes();
    if(op_attrs.contains("general_data_type"))
    {
        replacement_op = make_op(op_attrs["general_data_type"].to<std::string>(),
                                 replacement_op.to_value());
    }
    auto rewritten = m.insert_instruction(ins, replacement_op, converted_inputs);

    if(result_type != shape::tuple_type)
    {
        auto restored = m.insert_instruction(
            ins,
            migraphx::make_op("convert", {{"target_type", migraphx::to_value(result_type)}}),
            rewritten);
        m.replace_instruction(ins, restored);
        return;
    }

    // Tuple result: snapshot the users, since they are replaced one by one below.
    auto users = ins->outputs();
    const bool only_tuple_elems =
        std::all_of(users.begin(), users.end(), [](const auto& user) {
            return user->name() == "get_tuple_elem";
        });
    if(not only_tuple_elems)
        MIGRAPHX_THROW("eliminate_data_type: Instruction with tuple output doesn't have all its "
                       "usages as get_tuple_elem instruction");

    for(const auto& user : users)
    {
        auto elem      = m.insert_instruction(ins, user->get_operator(), rewritten);
        auto elem_type = user->get_shape().type();
        if(contains(unsupported_types, elem_type))
        {
            auto elem_restored = m.insert_instruction(
                ins, make_op("convert", {{"target_type", elem_type}}), elem);
            m.replace_instruction(user, elem_restored);
        }
        else
        {
            m.replace_instruction(user, elem);
        }
    }
}
void eliminate_data_type::apply(module& m) const void eliminate_data_type::apply(module& m) const
{ {
static const std::vector<std::string> skip_op_names = {"convert", static const std::vector<std::string> skip_op_names = {"convert",
...@@ -42,31 +108,36 @@ void eliminate_data_type::apply(module& m) const ...@@ -42,31 +108,36 @@ void eliminate_data_type::apply(module& m) const
"scatternd_add", "scatternd_add",
"scatternd_mul", "scatternd_mul",
"scatternd_none"}; "scatternd_none"};
for(auto ins : iterator_for(m)) if(unsupported_fp8_ops.empty() and unsupported_types.empty())
{
return;
}
else if(not unsupported_fp8_ops.empty() and not unsupported_types.empty())
{
MIGRAPHX_THROW("eliminate_data_type: specify either unsupported FP8 ops or unsupported "
"data types not both.");
}
else if(unsupported_fp8_ops.empty())
{
for(auto ins : iterator_for(m))
{
if(ins->name()[0] == '@')
continue;
if(contains(skip_op_names, ins->name()))
continue;
insert_convert_to_supported_type(m, ins, target_type, unsupported_types);
}
}
else
{ {
if(ins->name()[0] == '@') std::set<migraphx::shape::type_t> unsupported_fp8_types = {
continue; migraphx::shape::fp8e4m3fnuz_type};
if(contains(skip_op_names, ins->name())) for(auto ins : iterator_for(m))
continue;
auto inputs = ins->inputs();
std::transform(inputs.begin(), inputs.end(), inputs.begin(), [&](auto i) {
if(types.count(i->get_shape().type()) == 0)
return i;
return m.insert_instruction(ins, make_op("convert", {{"target_type", target_type}}), i);
});
if(inputs == ins->inputs())
continue;
auto op = ins->get_operator();
auto attributes = op.attributes();
if(attributes.contains("general_data_type"))
{ {
op = make_op(attributes["general_data_type"].to<std::string>(), op.to_value()); if(not contains(unsupported_fp8_ops, ins->name()))
continue;
insert_convert_to_supported_type(m, ins, target_type, unsupported_fp8_types);
} }
auto old_type = ins->get_shape().type();
auto out = m.insert_instruction(ins, op, inputs);
auto convert =
m.insert_instruction(ins, make_op("convert", {{"target_type", old_type}}), out);
m.replace_instruction(ins, convert);
} }
} }
......
...@@ -40,7 +40,8 @@ struct module; ...@@ -40,7 +40,8 @@ struct module;
*/ */
struct MIGRAPHX_EXPORT eliminate_data_type struct MIGRAPHX_EXPORT eliminate_data_type
{ {
std::set<shape::type_t> types; std::set<shape::type_t> unsupported_types;
std::set<std::string> unsupported_fp8_ops;
shape::type_t target_type; shape::type_t target_type;
std::string name() const { return "eliminate_data_type"; } std::string name() const { return "eliminate_data_type"; }
void apply(module& m) const; void apply(module& m) const;
......
...@@ -70,7 +70,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti ...@@ -70,7 +70,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
return {normalize_ops{}, return {normalize_ops{},
rewrite_quantization{}, rewrite_quantization{},
dead_code_elimination{}, dead_code_elimination{},
eliminate_data_type{unsupported_types, shape::type_t::float_type}, eliminate_data_type{unsupported_types, {}, shape::type_t::float_type},
dead_code_elimination{}, dead_code_elimination{},
simplify_reshapes{}, simplify_reshapes{},
eliminate_identity{}, eliminate_identity{},
......
...@@ -21,6 +21,7 @@ ...@@ -21,6 +21,7 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE. * THE SOFTWARE.
*/ */
#include "migraphx/shape.hpp"
#include <migraphx/adjust_allocation.hpp> #include <migraphx/adjust_allocation.hpp>
#include <migraphx/auto_contiguous.hpp> #include <migraphx/auto_contiguous.hpp>
#include <migraphx/check_context.hpp> #include <migraphx/check_context.hpp>
...@@ -123,7 +124,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti ...@@ -123,7 +124,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
simplify_qdq{}, simplify_qdq{},
enable_pass(not mlir_enabled(), rewrite_quantization{}), enable_pass(not mlir_enabled(), rewrite_quantization{}),
dead_code_elimination{}, dead_code_elimination{},
eliminate_data_type{unsupported_types, shape::type_t::float_type}, eliminate_data_type{unsupported_types, {}, shape::type_t::float_type},
simplify_reshapes{}, simplify_reshapes{},
eliminate_identity{}, eliminate_identity{},
eliminate_pad{}, eliminate_pad{},
...@@ -142,7 +143,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti ...@@ -142,7 +143,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
prefuse_ops{}, prefuse_ops{},
dead_code_elimination{}, dead_code_elimination{},
auto_contiguous{}, auto_contiguous{},
eliminate_fp8{unsupported_fp8_ops}, eliminate_data_type{{}, unsupported_fp8_ops, shape::float_type},
dead_code_elimination{}, dead_code_elimination{},
optimize_module{}, optimize_module{},
fuse_pointwise{}, fuse_pointwise{},
......
...@@ -30,13 +30,15 @@ ...@@ -30,13 +30,15 @@
#include <test.hpp> #include <test.hpp>
void run_pass(migraphx::module& m, std::set<migraphx::shape::type_t> types) void run_pass(migraphx::module& m,
std::set<migraphx::shape::type_t> types,
std::set<std::string> unsupported_fp8_ops = {})
{ {
migraphx::run_passes( migraphx::run_passes(m,
m, {migraphx::eliminate_data_type{
{migraphx::eliminate_data_type{std::move(types), migraphx::shape::float_type}, std::move(types), unsupported_fp8_ops, migraphx::shape::float_type},
migraphx::eliminate_identity{}, migraphx::eliminate_identity{},
migraphx::dead_code_elimination{}}); migraphx::dead_code_elimination{}});
} }
TEST_CASE(simple) TEST_CASE(simple)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment