Commit f7a59edb authored by Paul's avatar Paul
Browse files

Whitelist operators

parent 561456e7
...@@ -28,6 +28,11 @@ struct find_conv_pointwise ...@@ -28,6 +28,11 @@ struct find_conv_pointwise
auto x_ins = r.instructions["x"]; // input after contiguous auto x_ins = r.instructions["x"]; // input after contiguous
auto pm = ins->module_inputs().front(); auto pm = ins->module_inputs().front();
auto names = pm->get_parameter_names(); auto names = pm->get_parameter_names();
// Whitelist pointwise operators
if (std::any_of(pm->begin(), pm->end(), [](const auto& i) {
return not contains({"@literal", "@param", "@return", "convolution", "add", "relu"}, i.name());
}))
return;
std::sort(names.begin(), names.end()); std::sort(names.begin(), names.end());
module mm{}; module mm{};
std::unordered_map<instruction_ref, instruction_ref> param_map; std::unordered_map<instruction_ref, instruction_ref> param_map;
......
...@@ -468,7 +468,8 @@ struct mlir_program ...@@ -468,7 +468,8 @@ struct mlir_program
std::string tname = get_device_name(); std::string tname = get_device_name();
// HACK: Since MLIR can't handle the full target name // HACK: Since MLIR can't handle the full target name
auto hacked_tname = tname.substr(0, tname.find(":")); auto hacked_tname = tname.substr(0, tname.find(":"));
mlirMIGraphXAddBackendPipeline(pm.get(), hacked_tname.c_str(), nullptr, nullptr); auto hacked_features = tname.substr(tname.find(":"));
mlirMIGraphXAddBackendPipeline(pm.get(), hacked_tname.c_str(), "", hacked_features.c_str());
mlirPassManagerRun(pm.get(), mmodule.get()); mlirPassManagerRun(pm.get(), mmodule.get());
code_object_op op; code_object_op op;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment