Commit 8e0fff81 authored by Paul
Browse files

Store handle in context

parent 43ae3419
...@@ -62,6 +62,14 @@ struct context ...@@ -62,6 +62,14 @@ struct context
: nullptr; : nullptr;
} }
// Returns the std::type_info reported by the stored handle's type(); when no
// handle is stored, reports std::nullptr_t instead of dereferencing it.
const std::type_info& type_id() const
{
    const bool holds_value = !private_detail_te_handle_empty();
    if(holds_value)
        return private_detail_te_get_handle().type();
    return typeid(std::nullptr_t);
}
private: private:
struct private_detail_te_handle_base_type struct private_detail_te_handle_base_type
{ {
...@@ -111,13 +119,20 @@ struct context ...@@ -111,13 +119,20 @@ struct context
} }
}; };
// True when the handle pointer is null, i.e. no underlying object is stored.
bool private_detail_te_handle_empty() const
{
    return !private_detail_te_handle_mem_var;
}
const private_detail_te_handle_base_type& private_detail_te_get_handle() const const private_detail_te_handle_base_type& private_detail_te_get_handle() const
{ {
assert(private_detail_te_handle_mem_var != nullptr);
return *private_detail_te_handle_mem_var; return *private_detail_te_handle_mem_var;
} }
private_detail_te_handle_base_type& private_detail_te_get_handle() private_detail_te_handle_base_type& private_detail_te_get_handle()
{ {
assert(private_detail_te_handle_mem_var != nullptr);
if(!private_detail_te_handle_mem_var.unique()) if(!private_detail_te_handle_mem_var.unique())
private_detail_te_handle_mem_var = private_detail_te_handle_mem_var->clone(); private_detail_te_handle_mem_var = private_detail_te_handle_mem_var->clone();
return *private_detail_te_handle_mem_var; return *private_detail_te_handle_mem_var;
......
...@@ -84,6 +84,14 @@ struct operation ...@@ -84,6 +84,14 @@ struct operation
: nullptr; : nullptr;
} }
// Returns the std::type_info reported by the stored handle's type(); when no
// handle is stored, reports std::nullptr_t instead of dereferencing it.
const std::type_info& type_id() const
{
    const bool holds_value = !private_detail_te_handle_empty();
    if(holds_value)
        return private_detail_te_get_handle().type();
    return typeid(std::nullptr_t);
}
std::string name() const std::string name() const
{ {
assert((*this).private_detail_te_handle_mem_var); assert((*this).private_detail_te_handle_mem_var);
...@@ -183,13 +191,20 @@ struct operation ...@@ -183,13 +191,20 @@ struct operation
} }
}; };
// True when the handle pointer is null, i.e. no underlying object is stored.
bool private_detail_te_handle_empty() const
{
    return !private_detail_te_handle_mem_var;
}
const private_detail_te_handle_base_type& private_detail_te_get_handle() const const private_detail_te_handle_base_type& private_detail_te_get_handle() const
{ {
assert(private_detail_te_handle_mem_var != nullptr);
return *private_detail_te_handle_mem_var; return *private_detail_te_handle_mem_var;
} }
private_detail_te_handle_base_type& private_detail_te_get_handle() private_detail_te_handle_base_type& private_detail_te_get_handle()
{ {
assert(private_detail_te_handle_mem_var != nullptr);
if(!private_detail_te_handle_mem_var.unique()) if(!private_detail_te_handle_mem_var.unique())
private_detail_te_handle_mem_var = private_detail_te_handle_mem_var->clone(); private_detail_te_handle_mem_var = private_detail_te_handle_mem_var->clone();
return *private_detail_te_handle_mem_var; return *private_detail_te_handle_mem_var;
......
...@@ -12,14 +12,24 @@ namespace rtg { ...@@ -12,14 +12,24 @@ namespace rtg {
struct check_shapes struct check_shapes
{ {
const std::vector<shape>* shapes; const std::vector<shape>* shapes;
const std::string name;
check_shapes(const std::vector<shape>& s) : shapes(&s) {} check_shapes(const std::vector<shape>& s) : shapes(&s) {}
// Builds a checker over s that also records op.name(), so that every error
// raised by the checks can be prefixed with the operation being validated.
// Only a pointer to s is kept, so s must outlive this object.
template<class Op>
check_shapes(const std::vector<shape>& s, const Op& op) : shapes(&s), name(op.name()) {}
// Builds the "<name>: " lead-in placed in front of error messages; yields an
// empty string when no operation name was recorded.
std::string prefix() const
{
    return name.empty() ? std::string{} : name + ": ";
}
const check_shapes& has(std::size_t n) const const check_shapes& has(std::size_t n) const
{ {
assert(shapes != nullptr); assert(shapes != nullptr);
if(shapes->size() != n) if(shapes->size() != n)
RTG_THROW("Wrong number of arguments: expected " + std::to_string(n) + " but given " + RTG_THROW(prefix() + "Wrong number of arguments: expected " + std::to_string(n) + " but given " +
std::to_string(shapes->size())); std::to_string(shapes->size()));
return *this; return *this;
} }
...@@ -30,7 +40,7 @@ struct check_shapes ...@@ -30,7 +40,7 @@ struct check_shapes
if(!shapes->empty()) if(!shapes->empty())
{ {
if(shapes->front().lens().size() != n) if(shapes->front().lens().size() != n)
RTG_THROW("Only " + std::to_string(n) + "d supported"); RTG_THROW(prefix() + "Only " + std::to_string(n) + "d supported");
} }
return *this; return *this;
} }
...@@ -38,28 +48,28 @@ struct check_shapes ...@@ -38,28 +48,28 @@ struct check_shapes
const check_shapes& same_shape() const const check_shapes& same_shape() const
{ {
if(!this->same([](const shape& s) { return s; })) if(!this->same([](const shape& s) { return s; }))
RTG_THROW("Shapes do not match"); RTG_THROW(prefix() + "Shapes do not match");
return *this; return *this;
} }
const check_shapes& same_type() const const check_shapes& same_type() const
{ {
if(!this->same([](const shape& s) { return s.type(); })) if(!this->same([](const shape& s) { return s.type(); }))
RTG_THROW("Types do not match"); RTG_THROW(prefix() + "Types do not match");
return *this; return *this;
} }
const check_shapes& same_dims() const const check_shapes& same_dims() const
{ {
if(!this->same([](const shape& s) { return s.lens(); })) if(!this->same([](const shape& s) { return s.lens(); }))
RTG_THROW("Dimensions do not match"); RTG_THROW(prefix() + "Dimensions do not match");
return *this; return *this;
} }
// Verifies that all shapes have the same number of dimensions (lens().size()).
// Throws via RTG_THROW, prefixed with the operation name when one is set.
// Returns *this so checks can be chained.
const check_shapes& same_ndims() const
{
    if(!this->same([](const shape& s) { return s.lens().size(); }))
        // Worded differently from same_dims() so that a rank mismatch can be
        // told apart from an extent mismatch in the error output.
        RTG_THROW(prefix() + "Number of dimensions do not match");
    return *this;
}
...@@ -101,7 +111,7 @@ struct convolution ...@@ -101,7 +111,7 @@ struct convolution
std::string name() const { return "convolution"; } std::string name() const { return "convolution"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs}.has(2).same_type().same_ndims().only_dims(4); check_shapes{inputs, *this}.has(2).same_type().same_ndims().only_dims(4);
const shape& input = inputs.at(0); const shape& input = inputs.at(0);
const shape& weights = inputs.at(1); const shape& weights = inputs.at(1);
...@@ -175,7 +185,7 @@ struct pooling ...@@ -175,7 +185,7 @@ struct pooling
std::string name() const { return "pooling"; } std::string name() const { return "pooling"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs}.has(1).only_dims(4); check_shapes{inputs, *this}.has(1).only_dims(4);
const shape& input = inputs.at(0); const shape& input = inputs.at(0);
auto t = input.type(); auto t = input.type();
...@@ -219,7 +229,7 @@ struct activation ...@@ -219,7 +229,7 @@ struct activation
std::string name() const { return "activation"; } std::string name() const { return "activation"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs}.has(1); check_shapes{inputs, *this}.has(1);
return inputs.front(); return inputs.front();
} }
...@@ -237,7 +247,7 @@ struct transpose ...@@ -237,7 +247,7 @@ struct transpose
std::string name() const { return "transpose"; } std::string name() const { return "transpose"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs}.has(1); check_shapes{inputs, *this}.has(1);
auto input = inputs.at(0); auto input = inputs.at(0);
auto input_lens = input.lens(); auto input_lens = input.lens();
auto input_strides = input.strides(); auto input_strides = input.strides();
...@@ -269,7 +279,7 @@ struct contiguous ...@@ -269,7 +279,7 @@ struct contiguous
std::string name() const { return "contiguous"; } std::string name() const { return "contiguous"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs}.has(1); check_shapes{inputs, *this}.has(1);
auto lens = inputs.at(0).lens(); auto lens = inputs.at(0).lens();
auto t = inputs.at(0).type(); auto t = inputs.at(0).type();
if(lens.size() < 2) if(lens.size() < 2)
...@@ -287,7 +297,7 @@ struct reshape ...@@ -287,7 +297,7 @@ struct reshape
std::string name() const { return "reshape"; } std::string name() const { return "reshape"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs}.has(1); check_shapes{inputs, *this}.has(1);
auto&& idims = inputs.front().lens(); auto&& idims = inputs.front().lens();
std::vector<std::size_t> rdims(dims.begin(), dims.end()); std::vector<std::size_t> rdims(dims.begin(), dims.end());
for(std::size_t i = 0; i < dims.size(); i++) for(std::size_t i = 0; i < dims.size(); i++)
...@@ -322,7 +332,7 @@ struct gemm ...@@ -322,7 +332,7 @@ struct gemm
std::string name() const { return "gemm"; } std::string name() const { return "gemm"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs}.has(2).same_type(); check_shapes{inputs, *this}.has(2).same_type();
const shape& a = inputs.at(0); const shape& a = inputs.at(0);
const shape& b = inputs.at(1); const shape& b = inputs.at(1);
auto t = a.type(); auto t = a.type();
...@@ -492,12 +502,29 @@ struct outline ...@@ -492,12 +502,29 @@ struct outline
std::string name() const { return "outline"; } std::string name() const { return "outline"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs}.has(0); check_shapes{inputs, *this}.has(0);
return s; return s;
} }
argument compute(context&, shape, std::vector<argument>) const { return {s, nullptr}; } argument compute(context&, shape, std::vector<argument>) const { return {s, nullptr}; }
}; };
// Operation that checks, when computed, that the context it is given holds a T.
template <class T>
struct check_context
{
    std::string name() const { return "check_context"; }

    // Ignores its inputs and yields a default-constructed shape.
    shape compute_shape(std::vector<shape>) const { return {}; }

    // Throws via RTG_THROW, naming the actual context type, when ctx does not
    // hold a T; otherwise returns a default-constructed argument.
    argument compute(context& ctx, shape, std::vector<argument>) const
    {
        if(any_cast<T>(&ctx) == nullptr)
        {
            const std::string actual = ctx.type_id().name();
            RTG_THROW("Unexpected context type: " + actual);
        }
        return {};
    }
};
} // namespace rtg } // namespace rtg
#endif #endif
...@@ -73,6 +73,14 @@ struct target ...@@ -73,6 +73,14 @@ struct target
: nullptr; : nullptr;
} }
// Returns the std::type_info reported by the stored handle's type(); when no
// handle is stored, reports std::nullptr_t instead of dereferencing it.
const std::type_info& type_id() const
{
    const bool holds_value = !private_detail_te_handle_empty();
    if(holds_value)
        return private_detail_te_get_handle().type();
    return typeid(std::nullptr_t);
}
std::string name() const std::string name() const
{ {
assert((*this).private_detail_te_handle_mem_var); assert((*this).private_detail_te_handle_mem_var);
...@@ -150,13 +158,20 @@ struct target ...@@ -150,13 +158,20 @@ struct target
} }
}; };
// True when the handle pointer is null, i.e. no underlying object is stored.
bool private_detail_te_handle_empty() const
{
    return !private_detail_te_handle_mem_var;
}
const private_detail_te_handle_base_type& private_detail_te_get_handle() const const private_detail_te_handle_base_type& private_detail_te_get_handle() const
{ {
assert(private_detail_te_handle_mem_var != nullptr);
return *private_detail_te_handle_mem_var; return *private_detail_te_handle_mem_var;
} }
private_detail_te_handle_base_type& private_detail_te_get_handle() private_detail_te_handle_base_type& private_detail_te_get_handle()
{ {
assert(private_detail_te_handle_mem_var != nullptr);
if(!private_detail_te_handle_mem_var.unique()) if(!private_detail_te_handle_mem_var.unique())
private_detail_te_handle_mem_var = private_detail_te_handle_mem_var->clone(); private_detail_te_handle_mem_var = private_detail_te_handle_mem_var->clone();
return *private_detail_te_handle_mem_var; return *private_detail_te_handle_mem_var;
......
...@@ -30,7 +30,7 @@ rtg::argument run_gpu(std::string file) ...@@ -30,7 +30,7 @@ rtg::argument run_gpu(std::string file)
auto handle = rtg::miopen::make_obj<rtg::miopen::miopen_handle>(&miopenCreate); auto handle = rtg::miopen::make_obj<rtg::miopen::miopen_handle>(&miopenCreate);
auto out = p.eval( auto out = p.eval(
{{"Input3", input3}, {"handle", {rtg::shape::any_type, handle.get()}}, {"output", output}}); {{"Input3", input3}, {"output", output}});
std::cout << p << std::endl; std::cout << p << std::endl;
return rtg::miopen::from_gpu(out); return rtg::miopen::from_gpu(out);
} }
......
...@@ -10,7 +10,7 @@ struct miopen_target ...@@ -10,7 +10,7 @@ struct miopen_target
{ {
std::string name() const; std::string name() const;
void apply(program& p) const; void apply(program& p) const;
context get_context() const { return {}; } context get_context() const;
}; };
} // namespace miopen } // namespace miopen
......
...@@ -10,6 +10,11 @@ ...@@ -10,6 +10,11 @@
namespace rtg { namespace rtg {
namespace miopen { namespace miopen {
// Context for the MIOpen target: carries the MIOpen handle that the miopen_*
// operations in this file retrieve via any_cast<miopen_context> and pass as
// the first argument of the MIOpen API calls.
struct miopen_context
{
shared<miopen_handle> handle;
};
struct miopen_convolution struct miopen_convolution
{ {
convolution op; convolution op;
...@@ -18,46 +23,47 @@ struct miopen_convolution ...@@ -18,46 +23,47 @@ struct miopen_convolution
std::string name() const { return "miopen::convolution"; } std::string name() const { return "miopen::convolution"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs}.has(4); check_shapes{inputs, *this}.has(3);
return op.compute_shape({inputs.at(1), inputs.at(2)}); return op.compute_shape({inputs.at(0), inputs.at(1)});
} }
argument compute(context&, shape output_shape, std::vector<argument> args) const argument compute(context& gctx, shape output_shape, std::vector<argument> args) const
{ {
auto x_desc = make_tensor(args[1].get_shape()); auto& ctx = any_cast<miopen_context>(gctx);
auto w_desc = make_tensor(args[2].get_shape()); auto x_desc = make_tensor(args[0].get_shape());
auto w_desc = make_tensor(args[1].get_shape());
auto y_desc = make_tensor(output_shape); auto y_desc = make_tensor(output_shape);
float alpha = 1, beta = 0; float alpha = 1, beta = 0;
int algo_count; int algo_count;
miopenConvAlgoPerf_t perf; miopenConvAlgoPerf_t perf;
miopenFindConvolutionForwardAlgorithm(args[0].implicit(), miopenFindConvolutionForwardAlgorithm(ctx.handle.get(),
x_desc.get(), x_desc.get(),
args[1].implicit(), args[0].implicit(),
w_desc.get(), w_desc.get(),
args[2].implicit(), args[1].implicit(),
cd.get(), cd.get(),
y_desc.get(), y_desc.get(),
args[3].implicit(), args[2].implicit(),
1, 1,
&algo_count, &algo_count,
&perf, &perf,
nullptr, nullptr,
0, 0,
false); false);
miopenConvolutionForward(args[0].implicit(), miopenConvolutionForward(ctx.handle.get(),
&alpha, &alpha,
x_desc.get(), x_desc.get(),
args[1].implicit(), args[0].implicit(),
w_desc.get(), w_desc.get(),
args[2].implicit(), args[1].implicit(),
cd.get(), cd.get(),
perf.fwd_algo, perf.fwd_algo,
&beta, &beta,
y_desc.get(), y_desc.get(),
args[3].implicit(), args[2].implicit(),
nullptr, nullptr,
0); 0);
return args[3]; return args[2];
} }
}; };
...@@ -69,29 +75,30 @@ struct miopen_pooling ...@@ -69,29 +75,30 @@ struct miopen_pooling
std::string name() const { return "miopen::pooling"; } std::string name() const { return "miopen::pooling"; }
// Computes the pooled output shape.
// inputs: {x, output buffer}, matching the arguments used by compute().
// Throws via check_shapes when the argument count is not 2.
shape compute_shape(std::vector<shape> inputs) const
{
    check_shapes{inputs, *this}.has(2);
    // The data tensor is input 0 now that the handle is no longer passed as an
    // argument; index 1 is the preallocated output buffer and must not be fed
    // to the pooling shape computation.
    return op.compute_shape({inputs.at(0)});
}
argument compute(context&, shape output_shape, std::vector<argument> args) const argument compute(context& gctx, shape output_shape, std::vector<argument> args) const
{ {
auto x_desc = make_tensor(args[1].get_shape()); auto& ctx = any_cast<miopen_context>(gctx);
auto x_desc = make_tensor(args[0].get_shape());
auto y_desc = make_tensor(output_shape); auto y_desc = make_tensor(output_shape);
float alpha = 1, beta = 0; float alpha = 1, beta = 0;
miopenPoolingForward(args[0].implicit(), miopenPoolingForward(ctx.handle.get(),
pd.get(), pd.get(),
&alpha, &alpha,
x_desc.get(), x_desc.get(),
args[1].implicit(), args[0].implicit(),
&beta, &beta,
y_desc.get(), y_desc.get(),
args[2].implicit(), args[1].implicit(),
false, false,
nullptr, nullptr,
0); 0);
return args[2]; return args[1];
} }
}; };
...@@ -100,17 +107,17 @@ struct miopen_add ...@@ -100,17 +107,17 @@ struct miopen_add
std::string name() const { return "miopen::add"; } std::string name() const { return "miopen::add"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs}.has(4); check_shapes{inputs, *this}.has(3);
return inputs.at(1); return inputs.at(0);
} }
argument compute(context&, shape output_shape, std::vector<argument> args) const argument compute(context& gctx, shape output_shape, std::vector<argument> args) const
{ {
if(args[2].get_shape().broadcasted()) if(args[1].get_shape().broadcasted())
{ {
argument result{output_shape}; argument result{output_shape};
visit_all(result, from_gpu(args[1]), from_gpu(args[2]))( visit_all(result, from_gpu(args[0]), from_gpu(args[1]))(
[&](auto output, auto input1, auto input2) { [&](auto output, auto input1, auto input2) {
shape_for_each(output.get_shape(), [&](const auto& idx) { shape_for_each(output.get_shape(), [&](const auto& idx) {
output(idx.begin(), idx.end()) = output(idx.begin(), idx.end()) =
...@@ -121,22 +128,23 @@ struct miopen_add ...@@ -121,22 +128,23 @@ struct miopen_add
} }
else else
{ {
auto& ctx = any_cast<miopen_context>(gctx);
float alpha = 1, beta = 0; float alpha = 1, beta = 0;
auto a_desc = make_tensor(args[1].get_shape()); auto a_desc = make_tensor(args[0].get_shape());
auto b_desc = make_tensor(args[2].get_shape()); auto b_desc = make_tensor(args[1].get_shape());
auto c_desc = make_tensor(output_shape); auto c_desc = make_tensor(output_shape);
miopenOpTensor(args[0].implicit(), miopenOpTensor(ctx.handle.get(),
miopenTensorOpAdd, miopenTensorOpAdd,
&alpha, &alpha,
a_desc.get(), a_desc.get(),
args[1].implicit(), args[0].implicit(),
&alpha, &alpha,
b_desc.get(), b_desc.get(),
args[2].implicit(), args[1].implicit(),
&beta, &beta,
c_desc.get(), c_desc.get(),
args[3].implicit()); args[2].implicit());
return args[3]; return args[2];
} }
} }
}; };
...@@ -147,14 +155,14 @@ struct miopen_gemm ...@@ -147,14 +155,14 @@ struct miopen_gemm
// Name of this operation; previously reported the copy-pasted
// "miopen::convolution", which mislabelled gemm in printed programs and errors.
std::string name() const { return "miopen::gemm"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs}.has(4); check_shapes{inputs, *this}.has(3);
return op.compute_shape({inputs.at(1), inputs.at(2)}); return op.compute_shape({inputs.at(0), inputs.at(1)});
} }
argument compute(context&, shape output_shape, std::vector<argument> args) const argument compute(context&, shape output_shape, std::vector<argument> args) const
{ {
argument result{output_shape}; argument result{output_shape};
visit_all(result, from_gpu(args[1]), from_gpu(args[2]))( visit_all(result, from_gpu(args[0]), from_gpu(args[1]))(
[&](auto output, auto input1, auto input2) { [&](auto output, auto input1, auto input2) {
dfor(input1.get_shape().lens()[0], dfor(input1.get_shape().lens()[0],
input2.get_shape().lens()[1], input2.get_shape().lens()[1],
...@@ -171,36 +179,36 @@ struct miopen_relu ...@@ -171,36 +179,36 @@ struct miopen_relu
std::string name() const { return "miopen::relu"; } std::string name() const { return "miopen::relu"; }
// The result has the shape of the input tensor.
// inputs: {x, output buffer}, matching the arguments used by compute().
// Throws via check_shapes when the argument count is not 2.
shape compute_shape(std::vector<shape> inputs) const
{
    check_shapes{inputs, *this}.has(2);
    // Input 0 is the data tensor (the handle argument was removed); index 1 is
    // the output buffer. Use the data tensor, consistent with miopen_add.
    return inputs.at(0);
}
argument compute(context&, shape output_shape, std::vector<argument> args) const argument compute(context& gctx, shape output_shape, std::vector<argument> args) const
{ {
auto& ctx = any_cast<miopen_context>(gctx);
float alpha = 1, beta = 0; float alpha = 1, beta = 0;
auto x_desc = make_tensor(args[1].get_shape()); auto x_desc = make_tensor(args[0].get_shape());
auto y_desc = make_tensor(output_shape); auto y_desc = make_tensor(output_shape);
miopenActivationForward(args[0].implicit(), miopenActivationForward(ctx.handle.get(),
ad.get(), ad.get(),
&alpha, &alpha,
x_desc.get(), x_desc.get(),
args[1].implicit(), args[0].implicit(),
&beta, &beta,
y_desc.get(), y_desc.get(),
args[2].implicit()); args[1].implicit());
return args[2]; return args[1];
} }
}; };
struct miopen_apply struct miopen_apply
{ {
program* prog = nullptr; program* prog = nullptr;
instruction_ref handle{};
void apply() void apply()
{ {
handle = prog->add_parameter("handle", shape{shape::any_type}); prog->insert_instruction(prog->begin(), check_context<miopen_context>{});
for(auto it = prog->begin(); it != prog->end(); it++) for(auto it = prog->begin(); it != prog->end(); it++)
{ {
if(it->op.name() == "convolution") if(it->op.name() == "convolution")
...@@ -248,7 +256,6 @@ struct miopen_apply ...@@ -248,7 +256,6 @@ struct miopen_apply
prog->replace_instruction(ins, prog->replace_instruction(ins,
miopen_convolution{op, std::move(cd)}, miopen_convolution{op, std::move(cd)},
handle,
ins->arguments.at(0), ins->arguments.at(0),
ins->arguments.at(1), ins->arguments.at(1),
output); output);
...@@ -261,7 +268,7 @@ struct miopen_apply ...@@ -261,7 +268,7 @@ struct miopen_apply
auto output = insert_allocation(ins, ins->result); auto output = insert_allocation(ins, ins->result);
prog->replace_instruction( prog->replace_instruction(
ins, miopen_pooling{op, std::move(pd)}, handle, ins->arguments.at(0), output); ins, miopen_pooling{op, std::move(pd)}, ins->arguments.at(0), output);
} }
void apply_activation(instruction_ref ins) void apply_activation(instruction_ref ins)
...@@ -272,7 +279,7 @@ struct miopen_apply ...@@ -272,7 +279,7 @@ struct miopen_apply
{ {
auto output = insert_allocation(ins, ins->result); auto output = insert_allocation(ins, ins->result);
prog->replace_instruction( prog->replace_instruction(
ins, miopen_relu{std::move(ad)}, handle, ins->arguments.at(0), output); ins, miopen_relu{std::move(ad)}, ins->arguments.at(0), output);
} }
} }
...@@ -280,7 +287,7 @@ struct miopen_apply ...@@ -280,7 +287,7 @@ struct miopen_apply
{ {
auto output = insert_allocation(ins, ins->result); auto output = insert_allocation(ins, ins->result);
prog->replace_instruction( prog->replace_instruction(
ins, miopen_add{}, handle, ins->arguments.at(0), ins->arguments.at(1), output); ins, miopen_add{}, ins->arguments.at(0), ins->arguments.at(1), output);
} }
void apply_gemm(instruction_ref ins) void apply_gemm(instruction_ref ins)
...@@ -288,7 +295,7 @@ struct miopen_apply ...@@ -288,7 +295,7 @@ struct miopen_apply
auto&& op = any_cast<gemm>(ins->op); auto&& op = any_cast<gemm>(ins->op);
auto output = insert_allocation(ins, ins->result); auto output = insert_allocation(ins, ins->result);
prog->replace_instruction( prog->replace_instruction(
ins, miopen_gemm{op}, handle, ins->arguments.at(0), ins->arguments.at(1), output); ins, miopen_gemm{op}, ins->arguments.at(0), ins->arguments.at(1), output);
} }
}; };
...@@ -296,6 +303,11 @@ std::string miopen_target::name() const { return "miopen"; } ...@@ -296,6 +303,11 @@ std::string miopen_target::name() const { return "miopen"; }
void miopen_target::apply(program& p) const { miopen_apply{&p}.apply(); } void miopen_target::apply(program& p) const { miopen_apply{&p}.apply(); }
// Creates a MIOpen handle through miopenCreate, converts it with share(), and
// returns it wrapped in a miopen_context as the target's context.
context miopen_target::get_context() const
{
    auto handle = share(make_obj<miopen_handle>(&miopenCreate));
    return miopen_context{std::move(handle)};
}
} // namespace miopen } // namespace miopen
} // namespace rtg } // namespace rtg
...@@ -36,8 +36,6 @@ rtg::argument run_gpu() ...@@ -36,8 +36,6 @@ rtg::argument run_gpu()
} }
m["output"] = rtg::miopen::to_gpu(rtg::generate_argument(p.get_parameter_shape("output"))); m["output"] = rtg::miopen::to_gpu(rtg::generate_argument(p.get_parameter_shape("output")));
auto handle = rtg::miopen::make_obj<rtg::miopen::miopen_handle>(&miopenCreate);
m["handle"] = {rtg::shape::any_type, handle.get()};
return rtg::miopen::from_gpu(p.eval(m)); return rtg::miopen::from_gpu(p.eval(m));
} }
......
...@@ -63,6 +63,12 @@ struct ${struct_name} ...@@ -63,6 +63,12 @@ struct ${struct_name}
nullptr; nullptr;
} }
// Returns the std::type_info reported by the stored handle's type(); when no
// handle is stored, reports std::nullptr_t instead of dereferencing it.
const std::type_info& type_id() const
{
    const bool holds_value = !private_detail_te_handle_empty();
    if(holds_value)
        return private_detail_te_get_handle().type();
    return typeid(std::nullptr_t);
}
${nonvirtual_members} ${nonvirtual_members}
private: private:
...@@ -118,11 +124,20 @@ private: ...@@ -118,11 +124,20 @@ private:
{} {}
}; };
// True when the handle pointer is null, i.e. no underlying object is stored.
bool private_detail_te_handle_empty() const
{
    return !private_detail_te_handle_mem_var;
}
const private_detail_te_handle_base_type & private_detail_te_get_handle () const const private_detail_te_handle_base_type & private_detail_te_get_handle () const
{ return *private_detail_te_handle_mem_var; } {
assert(private_detail_te_handle_mem_var != nullptr);
return *private_detail_te_handle_mem_var;
}
private_detail_te_handle_base_type & private_detail_te_get_handle () private_detail_te_handle_base_type & private_detail_te_get_handle ()
{ {
assert(private_detail_te_handle_mem_var != nullptr);
if (!private_detail_te_handle_mem_var.unique()) if (!private_detail_te_handle_mem_var.unique())
private_detail_te_handle_mem_var = private_detail_te_handle_mem_var->clone(); private_detail_te_handle_mem_var = private_detail_te_handle_mem_var->clone();
return *private_detail_te_handle_mem_var; return *private_detail_te_handle_mem_var;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment