"configs/datasets/vscode:/vscode.git/clone" did not exist on "d501710155b0b8bb5808fff8d3107e0589650026"
Unverified Commit 37f5df20 authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Change compiler_replace to a class that stores the code objects directly (#1739)

Enable retrieving the code object to do tuning in the future.
parent 77042e30
......@@ -101,7 +101,7 @@ void compile_ops::apply(module& m) const
par_compile(compiles.size(), [&](auto i) { results[i] = compiles[i](); });
for(const auto& cr : results)
{
cr.replace(m, cr.ins);
cr.replace.replace(m, cr.ins);
}
}
......
......@@ -38,7 +38,34 @@ namespace gpu {
struct context;
using compiler_replace = std::function<void(module& m, instruction_ref ins)>;
struct compiler_replace
{
compiler_replace() = default;
compiler_replace(const operation& op) : code_object{op} {}
template <class F>
compiler_replace(const operation& op, F f)
: code_object{op},
replace_fn([=](const compiler_replace& cr, module& m, instruction_ref ins) {
f(m, ins, cr.code_object);
})
{
}
operation code_object = {};
std::function<void(const compiler_replace& cr, module& m, instruction_ref ins)> replace_fn =
nullptr;
void replace(module& m, instruction_ref ins) const
{
if(replace_fn)
replace_fn(*this, m, ins);
else
m.replace_instruction(ins, code_object, ins->inputs());
}
};
using compiler_compile = std::function<compiler_replace(context&, instruction_ref, operation)>;
using compiler_compile_op =
std::function<operation(context&, const std::vector<shape>& inputs, const value&)>;
......@@ -78,11 +105,6 @@ using auto_register_compiler = auto_register<register_compiler_action, T>;
template <class Derived>
struct compiler : auto_register_compiler<Derived>
{
auto replace(const operation& op) const
{
return
[=](module& m, instruction_ref ins) { m.replace_instruction(ins, op, ins->inputs()); };
}
operation compile_op(context&, const std::vector<shape>&, const value&) const { return {}; }
};
......
......@@ -108,7 +108,7 @@ struct concat_compiler : compiler<concat_compiler>
v["post"] = "MIGRAPHX_LIFT(post_concat)";
v["kernel"] = "concat_" + generate_name_from_ops(*pm) + "_kernel";
}
return replace(compile_op(ctx, to_shapes(ins->inputs()), v));
return compile_op(ctx, to_shapes(ins->inputs()), v);
}
};
......
......@@ -80,7 +80,7 @@ struct gather_compiler : compiler<gather_compiler>
compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const
{
return replace(compile_op(ctx, to_shapes(ins->inputs()), op.to_value()));
return compile_op(ctx, to_shapes(ins->inputs()), op.to_value());
}
};
......
......@@ -82,7 +82,7 @@ struct gathernd_compiler : compiler<gathernd_compiler>
compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const
{
return replace(compile_op(ctx, to_shapes(ins->inputs()), op.to_value()));
return compile_op(ctx, to_shapes(ins->inputs()), op.to_value());
}
};
......
......@@ -122,7 +122,7 @@ struct layernorm_compiler : compiler<layernorm_compiler>
v["kernel"] =
v["layernorm"].to<std::string>() + "_" + generate_name_from_ops(*pm) + "_kernel";
}
return replace(compile_op(ctx, to_shapes(ins->inputs()), v));
return compile_op(ctx, to_shapes(ins->inputs()), v);
}
};
......
......@@ -45,10 +45,10 @@ struct mlir_compiler : compiler<mlir_compiler>
compiler_replace insert(code_object_op co) const
{
return [co = std::move(co)](module& m, instruction_ref ins) {
auto mlir = insert_mlir(m, ins, co, ins->inputs());
m.replace_instruction(ins, mlir);
};
return {std::move(co), [](module& m, instruction_ref ins, const operation& op) {
auto mlir = insert_mlir(m, ins, any_cast<code_object_op>(op), ins->inputs());
m.replace_instruction(ins, mlir);
}};
}
};
......
......@@ -92,7 +92,7 @@ struct pad_compiler : compiler<pad_compiler>
compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const
{
return replace(compile_op(ctx, to_shapes(ins->inputs()), op.to_value()));
return compile_op(ctx, to_shapes(ins->inputs()), op.to_value());
}
};
} // namespace gpu
......
......@@ -93,10 +93,10 @@ struct pointwise_compiler : compiler<pointwise_compiler>
{
if(contains({"layout", "contiguous"}, op.name()))
{
return replace(compile_op(
return compile_op(
ctx,
to_shapes(ins->inputs()),
{{"lambda", "[](auto x) { return x; }"}, {"kernel", op.name() + "_kernel"}}));
{{"lambda", "[](auto x) { return x; }"}, {"kernel", op.name() + "_kernel"}});
}
else
{
......@@ -105,10 +105,9 @@ struct pointwise_compiler : compiler<pointwise_compiler>
auto pf = generate_pointwise(*pm, "inner_pointwise");
std::string lambda = "MIGRAPHX_LIFT(inner_pointwise)";
auto kernel_name = generate_name_from_ops(*pm) + "_kernel";
return replace(
compile_op(ctx,
to_shapes(ins->inputs()),
{{"lambda", lambda}, {"preamble", pf}, {"kernel", kernel_name}}));
return compile_op(ctx,
to_shapes(ins->inputs()),
{{"lambda", lambda}, {"preamble", pf}, {"kernel", kernel_name}});
}
}
};
......
......@@ -189,7 +189,7 @@ struct simple_reduce_compiler : compiler<simple_reduce_compiler>
v["read"] = r.read;
v["write"] = r.write;
v["init"] = r.init;
return replace(compile_op(ctx, to_shapes(ins->inputs()), v));
return compile_op(ctx, to_shapes(ins->inputs()), v);
}
};
......@@ -285,7 +285,7 @@ struct fused_reduce_compiler : compiler<fused_reduce_compiler>
v["preamble"] = generate_reduce(*rm, "fused_reduce_op");
v["lambda"] = "MIGRAPHX_LIFT(fused_reduce_op)";
v["kernel"] = generate_name_from_ops(*rm) + "_kernel";
return replace(compile_op(ctx, to_shapes(ins->inputs()), v));
return compile_op(ctx, to_shapes(ins->inputs()), v);
}
};
} // namespace gpu
......
......@@ -92,7 +92,7 @@ struct roialign_compiler : compiler<roialign_compiler>
compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const
{
return replace(compile_op(ctx, to_shapes(ins->inputs()), op.to_value()));
return compile_op(ctx, to_shapes(ins->inputs()), op.to_value());
}
};
......
......@@ -85,15 +85,15 @@ struct scatternd_compiler : compiler<scatternd_compiler>
{{"reduction", reduction}}));
}
compiler_replace insert(const operation& op) const
compiler_replace insert(const operation& co) const
{
return [=](module& m, instruction_ref ins) {
auto args = ins->inputs();
args.back() =
m.insert_instruction(ins, make_op("hip::copy"), args.front(), args.back());
args.erase(args.begin());
return m.replace_instruction(ins, op, args);
};
return {co, [](module& m, instruction_ref ins, const operation& op) {
auto args = ins->inputs();
args.back() =
m.insert_instruction(ins, make_op("hip::copy"), args.front(), args.back());
args.erase(args.begin());
return m.replace_instruction(ins, op, args);
}};
}
};
......
......@@ -95,7 +95,7 @@ struct softmax_compiler : compiler<softmax_compiler>
compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const
{
return replace(compile_op(ctx, to_shapes(ins->inputs()), op.to_value()));
return compile_op(ctx, to_shapes(ins->inputs()), op.to_value());
}
};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment