Unverified Commit 37f5df20 authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Change compiler_replace to a class that stores the code objects directly (#1739)

Enable retrieving the code object to do tuning in the future.
parent 77042e30
...@@ -101,7 +101,7 @@ void compile_ops::apply(module& m) const ...@@ -101,7 +101,7 @@ void compile_ops::apply(module& m) const
par_compile(compiles.size(), [&](auto i) { results[i] = compiles[i](); }); par_compile(compiles.size(), [&](auto i) { results[i] = compiles[i](); });
for(const auto& cr : results) for(const auto& cr : results)
{ {
cr.replace(m, cr.ins); cr.replace.replace(m, cr.ins);
} }
} }
......
...@@ -38,7 +38,34 @@ namespace gpu { ...@@ -38,7 +38,34 @@ namespace gpu {
struct context; struct context;
using compiler_replace = std::function<void(module& m, instruction_ref ins)>; struct compiler_replace
{
compiler_replace() = default;
compiler_replace(const operation& op) : code_object{op} {}
template <class F>
compiler_replace(const operation& op, F f)
: code_object{op},
replace_fn([=](const compiler_replace& cr, module& m, instruction_ref ins) {
f(m, ins, cr.code_object);
})
{
}
operation code_object = {};
std::function<void(const compiler_replace& cr, module& m, instruction_ref ins)> replace_fn =
nullptr;
void replace(module& m, instruction_ref ins) const
{
if(replace_fn)
replace_fn(*this, m, ins);
else
m.replace_instruction(ins, code_object, ins->inputs());
}
};
using compiler_compile = std::function<compiler_replace(context&, instruction_ref, operation)>; using compiler_compile = std::function<compiler_replace(context&, instruction_ref, operation)>;
using compiler_compile_op = using compiler_compile_op =
std::function<operation(context&, const std::vector<shape>& inputs, const value&)>; std::function<operation(context&, const std::vector<shape>& inputs, const value&)>;
...@@ -78,11 +105,6 @@ using auto_register_compiler = auto_register<register_compiler_action, T>; ...@@ -78,11 +105,6 @@ using auto_register_compiler = auto_register<register_compiler_action, T>;
template <class Derived> template <class Derived>
struct compiler : auto_register_compiler<Derived> struct compiler : auto_register_compiler<Derived>
{ {
auto replace(const operation& op) const
{
return
[=](module& m, instruction_ref ins) { m.replace_instruction(ins, op, ins->inputs()); };
}
operation compile_op(context&, const std::vector<shape>&, const value&) const { return {}; } operation compile_op(context&, const std::vector<shape>&, const value&) const { return {}; }
}; };
......
...@@ -108,7 +108,7 @@ struct concat_compiler : compiler<concat_compiler> ...@@ -108,7 +108,7 @@ struct concat_compiler : compiler<concat_compiler>
v["post"] = "MIGRAPHX_LIFT(post_concat)"; v["post"] = "MIGRAPHX_LIFT(post_concat)";
v["kernel"] = "concat_" + generate_name_from_ops(*pm) + "_kernel"; v["kernel"] = "concat_" + generate_name_from_ops(*pm) + "_kernel";
} }
return replace(compile_op(ctx, to_shapes(ins->inputs()), v)); return compile_op(ctx, to_shapes(ins->inputs()), v);
} }
}; };
......
...@@ -80,7 +80,7 @@ struct gather_compiler : compiler<gather_compiler> ...@@ -80,7 +80,7 @@ struct gather_compiler : compiler<gather_compiler>
compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const
{ {
return replace(compile_op(ctx, to_shapes(ins->inputs()), op.to_value())); return compile_op(ctx, to_shapes(ins->inputs()), op.to_value());
} }
}; };
......
...@@ -82,7 +82,7 @@ struct gathernd_compiler : compiler<gathernd_compiler> ...@@ -82,7 +82,7 @@ struct gathernd_compiler : compiler<gathernd_compiler>
compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const
{ {
return replace(compile_op(ctx, to_shapes(ins->inputs()), op.to_value())); return compile_op(ctx, to_shapes(ins->inputs()), op.to_value());
} }
}; };
......
...@@ -122,7 +122,7 @@ struct layernorm_compiler : compiler<layernorm_compiler> ...@@ -122,7 +122,7 @@ struct layernorm_compiler : compiler<layernorm_compiler>
v["kernel"] = v["kernel"] =
v["layernorm"].to<std::string>() + "_" + generate_name_from_ops(*pm) + "_kernel"; v["layernorm"].to<std::string>() + "_" + generate_name_from_ops(*pm) + "_kernel";
} }
return replace(compile_op(ctx, to_shapes(ins->inputs()), v)); return compile_op(ctx, to_shapes(ins->inputs()), v);
} }
}; };
......
...@@ -45,10 +45,10 @@ struct mlir_compiler : compiler<mlir_compiler> ...@@ -45,10 +45,10 @@ struct mlir_compiler : compiler<mlir_compiler>
compiler_replace insert(code_object_op co) const compiler_replace insert(code_object_op co) const
{ {
return [co = std::move(co)](module& m, instruction_ref ins) { return {std::move(co), [](module& m, instruction_ref ins, const operation& op) {
auto mlir = insert_mlir(m, ins, co, ins->inputs()); auto mlir = insert_mlir(m, ins, any_cast<code_object_op>(op), ins->inputs());
m.replace_instruction(ins, mlir); m.replace_instruction(ins, mlir);
}; }};
} }
}; };
......
...@@ -92,7 +92,7 @@ struct pad_compiler : compiler<pad_compiler> ...@@ -92,7 +92,7 @@ struct pad_compiler : compiler<pad_compiler>
compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const
{ {
return replace(compile_op(ctx, to_shapes(ins->inputs()), op.to_value())); return compile_op(ctx, to_shapes(ins->inputs()), op.to_value());
} }
}; };
} // namespace gpu } // namespace gpu
......
...@@ -93,10 +93,10 @@ struct pointwise_compiler : compiler<pointwise_compiler> ...@@ -93,10 +93,10 @@ struct pointwise_compiler : compiler<pointwise_compiler>
{ {
if(contains({"layout", "contiguous"}, op.name())) if(contains({"layout", "contiguous"}, op.name()))
{ {
return replace(compile_op( return compile_op(
ctx, ctx,
to_shapes(ins->inputs()), to_shapes(ins->inputs()),
{{"lambda", "[](auto x) { return x; }"}, {"kernel", op.name() + "_kernel"}})); {{"lambda", "[](auto x) { return x; }"}, {"kernel", op.name() + "_kernel"}});
} }
else else
{ {
...@@ -105,10 +105,9 @@ struct pointwise_compiler : compiler<pointwise_compiler> ...@@ -105,10 +105,9 @@ struct pointwise_compiler : compiler<pointwise_compiler>
auto pf = generate_pointwise(*pm, "inner_pointwise"); auto pf = generate_pointwise(*pm, "inner_pointwise");
std::string lambda = "MIGRAPHX_LIFT(inner_pointwise)"; std::string lambda = "MIGRAPHX_LIFT(inner_pointwise)";
auto kernel_name = generate_name_from_ops(*pm) + "_kernel"; auto kernel_name = generate_name_from_ops(*pm) + "_kernel";
return replace( return compile_op(ctx,
compile_op(ctx,
to_shapes(ins->inputs()), to_shapes(ins->inputs()),
{{"lambda", lambda}, {"preamble", pf}, {"kernel", kernel_name}})); {{"lambda", lambda}, {"preamble", pf}, {"kernel", kernel_name}});
} }
} }
}; };
......
...@@ -189,7 +189,7 @@ struct simple_reduce_compiler : compiler<simple_reduce_compiler> ...@@ -189,7 +189,7 @@ struct simple_reduce_compiler : compiler<simple_reduce_compiler>
v["read"] = r.read; v["read"] = r.read;
v["write"] = r.write; v["write"] = r.write;
v["init"] = r.init; v["init"] = r.init;
return replace(compile_op(ctx, to_shapes(ins->inputs()), v)); return compile_op(ctx, to_shapes(ins->inputs()), v);
} }
}; };
...@@ -285,7 +285,7 @@ struct fused_reduce_compiler : compiler<fused_reduce_compiler> ...@@ -285,7 +285,7 @@ struct fused_reduce_compiler : compiler<fused_reduce_compiler>
v["preamble"] = generate_reduce(*rm, "fused_reduce_op"); v["preamble"] = generate_reduce(*rm, "fused_reduce_op");
v["lambda"] = "MIGRAPHX_LIFT(fused_reduce_op)"; v["lambda"] = "MIGRAPHX_LIFT(fused_reduce_op)";
v["kernel"] = generate_name_from_ops(*rm) + "_kernel"; v["kernel"] = generate_name_from_ops(*rm) + "_kernel";
return replace(compile_op(ctx, to_shapes(ins->inputs()), v)); return compile_op(ctx, to_shapes(ins->inputs()), v);
} }
}; };
} // namespace gpu } // namespace gpu
......
...@@ -92,7 +92,7 @@ struct roialign_compiler : compiler<roialign_compiler> ...@@ -92,7 +92,7 @@ struct roialign_compiler : compiler<roialign_compiler>
compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const
{ {
return replace(compile_op(ctx, to_shapes(ins->inputs()), op.to_value())); return compile_op(ctx, to_shapes(ins->inputs()), op.to_value());
} }
}; };
......
...@@ -85,15 +85,15 @@ struct scatternd_compiler : compiler<scatternd_compiler> ...@@ -85,15 +85,15 @@ struct scatternd_compiler : compiler<scatternd_compiler>
{{"reduction", reduction}})); {{"reduction", reduction}}));
} }
compiler_replace insert(const operation& op) const compiler_replace insert(const operation& co) const
{ {
return [=](module& m, instruction_ref ins) { return {co, [](module& m, instruction_ref ins, const operation& op) {
auto args = ins->inputs(); auto args = ins->inputs();
args.back() = args.back() =
m.insert_instruction(ins, make_op("hip::copy"), args.front(), args.back()); m.insert_instruction(ins, make_op("hip::copy"), args.front(), args.back());
args.erase(args.begin()); args.erase(args.begin());
return m.replace_instruction(ins, op, args); return m.replace_instruction(ins, op, args);
}; }};
} }
}; };
......
...@@ -95,7 +95,7 @@ struct softmax_compiler : compiler<softmax_compiler> ...@@ -95,7 +95,7 @@ struct softmax_compiler : compiler<softmax_compiler>
compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const
{ {
return replace(compile_op(ctx, to_shapes(ins->inputs()), op.to_value())); return compile_op(ctx, to_shapes(ins->inputs()), op.to_value());
} }
}; };
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment