Commit 02b0095c authored by Paul's avatar Paul
Browse files

Move mlir compile to jit pipeline

parent 3325ac9c
......@@ -3,6 +3,7 @@
#include <migraphx/matcher.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/register_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -11,6 +12,30 @@ struct module;
namespace gpu {
struct mlir_conv
{
operation op = make_op("convolution");
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.op, "op"));
}
std::string name() const { return "gpu::mlir_conv"; }
shape compute_shape(std::vector<shape> inputs, std::vector<module_ref> mods) const
{
check_shapes{inputs, *this}.standard();
if(mods.size() != 1)
MIGRAPHX_THROW("should have one submodule.");
if(inputs.size() < 2)
MIGRAPHX_THROW("should have at least two inputs.");
auto n = inputs.size();
return op.compute_shape({inputs[n -2], inputs[n - 1]});
}
};
MIGRAPHX_REGISTER_OP(mlir_conv);
namespace {
struct find_conv_pointwise
{
......@@ -61,10 +86,7 @@ struct find_conv_pointwise
std::back_inserter(inputs),
[&](auto input) { return input != conv_ins; });
inputs.insert(inputs.end(), conv_ins->inputs().begin(), conv_ins->inputs().end());
inputs.push_back(m.insert_instruction(
ins, make_op("hip::allocate", {{"shape", to_value(ins->get_shape())}})));
auto mlir = insert_mlir(m, ins, mm, inputs);
m.replace_instruction(ins, mlir);
m.replace_instruction(ins, mlir_conv{conv_ins->get_operator()}, inputs, ins->module_inputs());
}
};
} // namespace
......
......@@ -13,11 +13,13 @@ struct module;
namespace gpu {
std::string dump_mlir(const module& m);
code_object_op compile_mlir(const module& m);
code_object_op compile_mlir(const context& ctx, const module& m);
instruction_ref insert_mlir(module& m,
instruction_ref ins,
const module& mmlir,
const std::vector<instruction_ref>& inputs);
code_object_op co,
const std::vector<instruction_ref>& inputs);
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
......
#include <migraphx/gpu/compiler.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/mlir.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
struct mlir_compiler : compiler<mlir_compiler>
{
std::vector<std::string> names() const
{
return {"gpu::mlir_conv"};
}
operation compile_op(context&, const std::vector<shape>&, const value&) const
{
return {};
}
compiler_replace compile(context& ctx, instruction_ref ins, const operation&) const
{
auto* smod = ins->module_inputs().front();
assert(smod->get_parameter_names().size() == ins->inputs().size() - 1);
return insert(compile_mlir(ctx, *smod));
}
compiler_replace insert(code_object_op co) const
{
return [=](module& m, instruction_ref ins) {
auto mlir = insert_mlir(m, ins, co, ins->inputs());
m.replace_instruction(ins, mlir);
};
}
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......@@ -535,7 +535,7 @@ std::string dump_mlir(const module& m)
return mlir_print(&mlirOperationPrint, mod_op);
}
code_object_op compile_mlir(const module& m)
code_object_op compile_mlir(const context&, const module& m)
{
const bool trace = enabled(MIGRAPHX_TRACE_MLIR{});
if(trace)
......@@ -545,17 +545,16 @@ code_object_op compile_mlir(const module& m)
auto mod_op = mlirModuleGetOperation(mp.mmodule.get());
if(trace)
std::cout << mlir_print(&mlirOperationPrint, mod_op) << std::endl;
return mp.compile();
auto co = mp.compile();
co.output = m.get_output_shapes().front();
return co;
}
instruction_ref insert_mlir(module& m,
instruction_ref ins,
const module& mmlir,
code_object_op co,
const std::vector<instruction_ref>& inputs)
{
assert(mmlir.get_parameter_names().size() == inputs.size() - 1);
auto co = compile_mlir(mmlir);
std::vector<instruction_ref> refs;
refs.reserve(inputs.size() * 15);
......@@ -594,7 +593,6 @@ instruction_ref insert_mlir(module& m,
// refs.push_back(get_literal(1)); // G
}
co.expected_inputs = to_shapes(refs);
co.output = mmlir.get_output_shapes().front();
co.output_arg = last;
return m.insert_instruction(ins, co, refs);
}
......
......@@ -69,7 +69,8 @@ migraphx::program create_program_from_mlir(const migraphx::module& mmlir)
}));
inputs.push_back(mm->add_parameter("output", mmlir.get_output_shapes().front()));
migraphx::gpu::insert_mlir(*mm, mm->end(), mmlir, inputs);
migraphx::gpu::context ctx;
migraphx::gpu::insert_mlir(*mm, mm->end(), compile_mlir(ctx, mmlir), inputs);
return p;
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment