Commit 6b5c64ff authored by Paul's avatar Paul
Browse files

Format

parent 02b0095c
......@@ -31,7 +31,7 @@ struct mlir_conv
if(inputs.size() < 2)
MIGRAPHX_THROW("should have at least two inputs.");
auto n = inputs.size();
return op.compute_shape({inputs[n -2], inputs[n - 1]});
return op.compute_shape({inputs[n - 2], inputs[n - 1]});
}
};
MIGRAPHX_REGISTER_OP(mlir_conv);
......@@ -86,7 +86,8 @@ struct find_conv_pointwise
std::back_inserter(inputs),
[&](auto input) { return input != conv_ins; });
inputs.insert(inputs.end(), conv_ins->inputs().begin(), conv_ins->inputs().end());
m.replace_instruction(ins, mlir_conv{conv_ins->get_operator()}, inputs, ins->module_inputs());
m.replace_instruction(
ins, mlir_conv{conv_ins->get_operator()}, inputs, ins->module_inputs());
}
};
} // namespace
......
......@@ -18,8 +18,7 @@ code_object_op compile_mlir(const context& ctx, const module& m);
instruction_ref insert_mlir(module& m,
instruction_ref ins,
code_object_op co,
const std::vector<instruction_ref>& inputs);
const std::vector<instruction_ref>& inputs);
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
......
......@@ -8,18 +8,11 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
struct mlir_compiler : compiler<mlir_compiler>
{
std::vector<std::string> names() const
{
return {"gpu::mlir_conv"};
}
std::vector<std::string> names() const { return {"gpu::mlir_conv"}; }
operation compile_op(context&, const std::vector<shape>&, const value&) const
{
return {};
}
operation compile_op(context&, const std::vector<shape>&, const value&) const { return {}; }
compiler_replace compile(context& ctx, instruction_ref ins, const operation&) const
{
......@@ -32,7 +25,7 @@ struct mlir_compiler : compiler<mlir_compiler>
{
return [=](module& m, instruction_ref ins) {
auto mlir = insert_mlir(m, ins, co, ins->inputs());
m.replace_instruction(ins, mlir);
m.replace_instruction(ins, mlir);
};
}
};
......
......@@ -545,7 +545,7 @@ code_object_op compile_mlir(const context&, const module& m)
auto mod_op = mlirModuleGetOperation(mp.mmodule.get());
if(trace)
std::cout << mlir_print(&mlirOperationPrint, mod_op) << std::endl;
auto co = mp.compile();
auto co = mp.compile();
co.output = m.get_output_shapes().front();
return co;
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment