Commit 5c74135c authored by Paul's avatar Paul
Browse files

Call find during compilation

parent 25cb66bd
......@@ -2,12 +2,14 @@
#define MIGRAPH_GUARD_RTGLIB_MIOPEN_LOWERING_HPP
#include <migraph/program.hpp>
#include <migraph/gpu/context.hpp>
namespace migraph {
namespace gpu {
struct lowering
{
context ctx;
std::string name() const { return "gpu::lowering"; }
void apply(program& p) const;
};
......
......@@ -3,6 +3,7 @@
#include <migraph/manage_ptr.hpp>
#include <migraph/instruction.hpp>
#include <migraph/operators.hpp>
#include <migraph/generate.hpp>
#include <migraph/shape_for_each.hpp>
#include <migraph/gpu/miopen.hpp>
#include <migraph/gpu/hip.hpp>
......@@ -62,6 +63,7 @@ struct miopen_convolution
{
convolution op;
shared<convolution_descriptor> cd;
miopenConvFwdAlgorithm_t algo{};
std::string name() const { return "gpu::convolution"; }
shape compute_shape(std::vector<shape> inputs) const
......@@ -76,22 +78,6 @@ struct miopen_convolution
auto y_desc = make_tensor(output_shape);
float alpha = 1, beta = 0;
int algo_count;
miopenConvAlgoPerf_t perf;
miopenFindConvolutionForwardAlgorithm(ctx.handle.get(),
x_desc.get(),
args[0].implicit(),
w_desc.get(),
args[1].implicit(),
cd.get(),
y_desc.get(),
args[2].implicit(),
1,
&algo_count,
&perf,
nullptr,
0,
false);
miopenConvolutionForward(ctx.handle.get(),
&alpha,
x_desc.get(),
......@@ -99,7 +85,7 @@ struct miopen_convolution
w_desc.get(),
args[1].implicit(),
cd.get(),
perf.fwd_algo,
algo,
&beta,
y_desc.get(),
args[2].implicit(),
......@@ -107,6 +93,35 @@ struct miopen_convolution
0);
return args[2];
}
void compile(context& ctx, shape output_shape, std::vector<instruction_ref> inputs)
{
auto x_desc = make_tensor(inputs[0]->get_shape());
auto w_desc = make_tensor(inputs[1]->get_shape());
auto y_desc = make_tensor(output_shape);
auto x = to_gpu(generate_argument(inputs[0]->get_shape()));
auto w = to_gpu(generate_argument(inputs[1]->get_shape()));
auto y = to_gpu(generate_argument(output_shape));
int algo_count;
miopenConvAlgoPerf_t perf;
miopenFindConvolutionForwardAlgorithm(ctx.handle.get(),
x_desc.get(),
x.implicit(),
w_desc.get(),
w.implicit(),
cd.get(),
y_desc.get(),
y.implicit(),
1,
&algo_count,
&perf,
nullptr,
0,
false);
algo = perf.fwd_algo;
}
};
struct miopen_pooling
......@@ -275,6 +290,7 @@ struct miopen_relu
struct miopen_apply
{
program* prog = nullptr;
context ctx{};
void apply()
{
......@@ -330,11 +346,12 @@ struct miopen_apply
void apply_convolution(instruction_ref ins)
{
auto&& op = any_cast<convolution>(ins->op);
auto cd = make_conv(op);
auto conv = miopen_convolution{op, make_conv(op)};
conv.compile(ctx, ins->result, ins->arguments);
auto output = insert_allocation(ins, ins->result);
prog->replace_instruction(ins,
miopen_convolution{op, std::move(cd)},
conv,
ins->arguments.at(0),
ins->arguments.at(1),
output);
......@@ -411,7 +428,7 @@ struct miopen_apply
}
};
void lowering::apply(program& p) const { miopen_apply{&p}.apply(); }
void lowering::apply(program& p) const { miopen_apply{&p, ctx}.apply(); }
} // namespace gpu
......
......@@ -10,14 +10,15 @@
namespace migraph {
namespace gpu {
std::vector<pass> target::get_passes(migraph::context&) const
std::vector<pass> target::get_passes(migraph::context& gctx) const
{
auto& ctx = any_cast<context>(gctx);
// clang-format off
return
{
auto_contiguous{},
simplify_reshapes{},
lowering{},
lowering{ctx},
write_literals{},
check_context<context>{},
dead_code_elimination{}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment