Commit 5c74135c authored by Paul's avatar Paul
Browse files

Call find during compilation

parent 25cb66bd
...@@ -2,12 +2,14 @@ ...@@ -2,12 +2,14 @@
#define MIGRAPH_GUARD_RTGLIB_MIOPEN_LOWERING_HPP #define MIGRAPH_GUARD_RTGLIB_MIOPEN_LOWERING_HPP
#include <migraph/program.hpp> #include <migraph/program.hpp>
#include <migraph/gpu/context.hpp>
namespace migraph { namespace migraph {
namespace gpu { namespace gpu {
struct lowering struct lowering
{ {
context ctx;
std::string name() const { return "gpu::lowering"; } std::string name() const { return "gpu::lowering"; }
void apply(program& p) const; void apply(program& p) const;
}; };
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
#include <migraph/manage_ptr.hpp> #include <migraph/manage_ptr.hpp>
#include <migraph/instruction.hpp> #include <migraph/instruction.hpp>
#include <migraph/operators.hpp> #include <migraph/operators.hpp>
#include <migraph/generate.hpp>
#include <migraph/shape_for_each.hpp> #include <migraph/shape_for_each.hpp>
#include <migraph/gpu/miopen.hpp> #include <migraph/gpu/miopen.hpp>
#include <migraph/gpu/hip.hpp> #include <migraph/gpu/hip.hpp>
...@@ -62,6 +63,7 @@ struct miopen_convolution ...@@ -62,6 +63,7 @@ struct miopen_convolution
{ {
convolution op; convolution op;
shared<convolution_descriptor> cd; shared<convolution_descriptor> cd;
miopenConvFwdAlgorithm_t algo{};
std::string name() const { return "gpu::convolution"; } std::string name() const { return "gpu::convolution"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
...@@ -76,22 +78,6 @@ struct miopen_convolution ...@@ -76,22 +78,6 @@ struct miopen_convolution
auto y_desc = make_tensor(output_shape); auto y_desc = make_tensor(output_shape);
float alpha = 1, beta = 0; float alpha = 1, beta = 0;
int algo_count;
miopenConvAlgoPerf_t perf;
miopenFindConvolutionForwardAlgorithm(ctx.handle.get(),
x_desc.get(),
args[0].implicit(),
w_desc.get(),
args[1].implicit(),
cd.get(),
y_desc.get(),
args[2].implicit(),
1,
&algo_count,
&perf,
nullptr,
0,
false);
miopenConvolutionForward(ctx.handle.get(), miopenConvolutionForward(ctx.handle.get(),
&alpha, &alpha,
x_desc.get(), x_desc.get(),
...@@ -99,7 +85,7 @@ struct miopen_convolution ...@@ -99,7 +85,7 @@ struct miopen_convolution
w_desc.get(), w_desc.get(),
args[1].implicit(), args[1].implicit(),
cd.get(), cd.get(),
perf.fwd_algo, algo,
&beta, &beta,
y_desc.get(), y_desc.get(),
args[2].implicit(), args[2].implicit(),
...@@ -107,6 +93,35 @@ struct miopen_convolution ...@@ -107,6 +93,35 @@ struct miopen_convolution
0); 0);
return args[2]; return args[2];
} }
void compile(context& ctx, shape output_shape, std::vector<instruction_ref> inputs)
{
auto x_desc = make_tensor(inputs[0]->get_shape());
auto w_desc = make_tensor(inputs[1]->get_shape());
auto y_desc = make_tensor(output_shape);
auto x = to_gpu(generate_argument(inputs[0]->get_shape()));
auto w = to_gpu(generate_argument(inputs[1]->get_shape()));
auto y = to_gpu(generate_argument(output_shape));
int algo_count;
miopenConvAlgoPerf_t perf;
miopenFindConvolutionForwardAlgorithm(ctx.handle.get(),
x_desc.get(),
x.implicit(),
w_desc.get(),
w.implicit(),
cd.get(),
y_desc.get(),
y.implicit(),
1,
&algo_count,
&perf,
nullptr,
0,
false);
algo = perf.fwd_algo;
}
}; };
struct miopen_pooling struct miopen_pooling
...@@ -275,6 +290,7 @@ struct miopen_relu ...@@ -275,6 +290,7 @@ struct miopen_relu
struct miopen_apply struct miopen_apply
{ {
program* prog = nullptr; program* prog = nullptr;
context ctx{};
void apply() void apply()
{ {
...@@ -330,11 +346,12 @@ struct miopen_apply ...@@ -330,11 +346,12 @@ struct miopen_apply
void apply_convolution(instruction_ref ins) void apply_convolution(instruction_ref ins)
{ {
auto&& op = any_cast<convolution>(ins->op); auto&& op = any_cast<convolution>(ins->op);
auto cd = make_conv(op); auto conv = miopen_convolution{op, make_conv(op)};
conv.compile(ctx, ins->result, ins->arguments);
auto output = insert_allocation(ins, ins->result); auto output = insert_allocation(ins, ins->result);
prog->replace_instruction(ins, prog->replace_instruction(ins,
miopen_convolution{op, std::move(cd)}, conv,
ins->arguments.at(0), ins->arguments.at(0),
ins->arguments.at(1), ins->arguments.at(1),
output); output);
...@@ -411,7 +428,7 @@ struct miopen_apply ...@@ -411,7 +428,7 @@ struct miopen_apply
} }
}; };
void lowering::apply(program& p) const { miopen_apply{&p}.apply(); } void lowering::apply(program& p) const { miopen_apply{&p, ctx}.apply(); }
} // namespace gpu } // namespace gpu
......
...@@ -10,14 +10,15 @@ ...@@ -10,14 +10,15 @@
namespace migraph { namespace migraph {
namespace gpu { namespace gpu {
std::vector<pass> target::get_passes(migraph::context&) const std::vector<pass> target::get_passes(migraph::context& gctx) const
{ {
auto& ctx = any_cast<context>(gctx);
// clang-format off // clang-format off
return return
{ {
auto_contiguous{}, auto_contiguous{},
simplify_reshapes{}, simplify_reshapes{},
lowering{}, lowering{ctx},
write_literals{}, write_literals{},
check_context<context>{}, check_context<context>{},
dead_code_elimination{} dead_code_elimination{}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment