Commit 6b5c64ff authored by Paul's avatar Paul
Browse files

Format

parent 02b0095c
...@@ -31,7 +31,7 @@ struct mlir_conv ...@@ -31,7 +31,7 @@ struct mlir_conv
if(inputs.size() < 2) if(inputs.size() < 2)
MIGRAPHX_THROW("should have at least two inputs."); MIGRAPHX_THROW("should have at least two inputs.");
auto n = inputs.size(); auto n = inputs.size();
return op.compute_shape({inputs[n -2], inputs[n - 1]}); return op.compute_shape({inputs[n - 2], inputs[n - 1]});
} }
}; };
MIGRAPHX_REGISTER_OP(mlir_conv); MIGRAPHX_REGISTER_OP(mlir_conv);
...@@ -86,7 +86,8 @@ struct find_conv_pointwise ...@@ -86,7 +86,8 @@ struct find_conv_pointwise
std::back_inserter(inputs), std::back_inserter(inputs),
[&](auto input) { return input != conv_ins; }); [&](auto input) { return input != conv_ins; });
inputs.insert(inputs.end(), conv_ins->inputs().begin(), conv_ins->inputs().end()); inputs.insert(inputs.end(), conv_ins->inputs().begin(), conv_ins->inputs().end());
m.replace_instruction(ins, mlir_conv{conv_ins->get_operator()}, inputs, ins->module_inputs()); m.replace_instruction(
ins, mlir_conv{conv_ins->get_operator()}, inputs, ins->module_inputs());
} }
}; };
} // namespace } // namespace
......
...@@ -18,8 +18,7 @@ code_object_op compile_mlir(const context& ctx, const module& m); ...@@ -18,8 +18,7 @@ code_object_op compile_mlir(const context& ctx, const module& m);
instruction_ref insert_mlir(module& m, instruction_ref insert_mlir(module& m,
instruction_ref ins, instruction_ref ins,
code_object_op co, code_object_op co,
const std::vector<instruction_ref>& inputs); const std::vector<instruction_ref>& inputs);
} // namespace gpu } // namespace gpu
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -8,18 +8,11 @@ namespace migraphx { ...@@ -8,18 +8,11 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace gpu { namespace gpu {
struct mlir_compiler : compiler<mlir_compiler> struct mlir_compiler : compiler<mlir_compiler>
{ {
std::vector<std::string> names() const std::vector<std::string> names() const { return {"gpu::mlir_conv"}; }
{
return {"gpu::mlir_conv"};
}
operation compile_op(context&, const std::vector<shape>&, const value&) const operation compile_op(context&, const std::vector<shape>&, const value&) const { return {}; }
{
return {};
}
compiler_replace compile(context& ctx, instruction_ref ins, const operation&) const compiler_replace compile(context& ctx, instruction_ref ins, const operation&) const
{ {
...@@ -32,7 +25,7 @@ struct mlir_compiler : compiler<mlir_compiler> ...@@ -32,7 +25,7 @@ struct mlir_compiler : compiler<mlir_compiler>
{ {
return [=](module& m, instruction_ref ins) { return [=](module& m, instruction_ref ins) {
auto mlir = insert_mlir(m, ins, co, ins->inputs()); auto mlir = insert_mlir(m, ins, co, ins->inputs());
m.replace_instruction(ins, mlir); m.replace_instruction(ins, mlir);
}; };
} }
}; };
......
...@@ -545,7 +545,7 @@ code_object_op compile_mlir(const context&, const module& m) ...@@ -545,7 +545,7 @@ code_object_op compile_mlir(const context&, const module& m)
auto mod_op = mlirModuleGetOperation(mp.mmodule.get()); auto mod_op = mlirModuleGetOperation(mp.mmodule.get());
if(trace) if(trace)
std::cout << mlir_print(&mlirOperationPrint, mod_op) << std::endl; std::cout << mlir_print(&mlirOperationPrint, mod_op) << std::endl;
auto co = mp.compile(); auto co = mp.compile();
co.output = m.get_output_shapes().front(); co.output = m.get_output_shapes().front();
return co; return co;
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment