Commit 8e0fff81 authored by Paul's avatar Paul
Browse files

Store handle in context

parent 43ae3419
......@@ -62,6 +62,14 @@ struct context
: nullptr;
}
const std::type_info& type_id() const
{
if(private_detail_te_handle_empty())
return typeid(std::nullptr_t);
else
return private_detail_te_get_handle().type();
}
private:
struct private_detail_te_handle_base_type
{
......@@ -111,13 +119,20 @@ struct context
}
};
bool private_detail_te_handle_empty() const
{
return private_detail_te_handle_mem_var == nullptr;
}
const private_detail_te_handle_base_type& private_detail_te_get_handle() const
{
assert(private_detail_te_handle_mem_var != nullptr);
return *private_detail_te_handle_mem_var;
}
private_detail_te_handle_base_type& private_detail_te_get_handle()
{
assert(private_detail_te_handle_mem_var != nullptr);
if(!private_detail_te_handle_mem_var.unique())
private_detail_te_handle_mem_var = private_detail_te_handle_mem_var->clone();
return *private_detail_te_handle_mem_var;
......
......@@ -84,6 +84,14 @@ struct operation
: nullptr;
}
const std::type_info& type_id() const
{
if(private_detail_te_handle_empty())
return typeid(std::nullptr_t);
else
return private_detail_te_get_handle().type();
}
std::string name() const
{
assert((*this).private_detail_te_handle_mem_var);
......@@ -183,13 +191,20 @@ struct operation
}
};
bool private_detail_te_handle_empty() const
{
return private_detail_te_handle_mem_var == nullptr;
}
const private_detail_te_handle_base_type& private_detail_te_get_handle() const
{
assert(private_detail_te_handle_mem_var != nullptr);
return *private_detail_te_handle_mem_var;
}
private_detail_te_handle_base_type& private_detail_te_get_handle()
{
assert(private_detail_te_handle_mem_var != nullptr);
if(!private_detail_te_handle_mem_var.unique())
private_detail_te_handle_mem_var = private_detail_te_handle_mem_var->clone();
return *private_detail_te_handle_mem_var;
......
......@@ -12,14 +12,24 @@ namespace rtg {
struct check_shapes
{
const std::vector<shape>* shapes;
const std::string name;
check_shapes(const std::vector<shape>& s) : shapes(&s) {}
template<class Op>
check_shapes(const std::vector<shape>& s, const Op& op) : shapes(&s), name(op.name()) {}
std::string prefix() const
{
if(name.empty()) return "";
else return name + ": ";
}
const check_shapes& has(std::size_t n) const
{
assert(shapes != nullptr);
if(shapes->size() != n)
RTG_THROW("Wrong number of arguments: expected " + std::to_string(n) + " but given " +
RTG_THROW(prefix() + "Wrong number of arguments: expected " + std::to_string(n) + " but given " +
std::to_string(shapes->size()));
return *this;
}
......@@ -30,7 +40,7 @@ struct check_shapes
if(!shapes->empty())
{
if(shapes->front().lens().size() != n)
RTG_THROW("Only " + std::to_string(n) + "d supported");
RTG_THROW(prefix() + "Only " + std::to_string(n) + "d supported");
}
return *this;
}
......@@ -38,28 +48,28 @@ struct check_shapes
const check_shapes& same_shape() const
{
if(!this->same([](const shape& s) { return s; }))
RTG_THROW("Shapes do not match");
RTG_THROW(prefix() + "Shapes do not match");
return *this;
}
const check_shapes& same_type() const
{
if(!this->same([](const shape& s) { return s.type(); }))
RTG_THROW("Types do not match");
RTG_THROW(prefix() + "Types do not match");
return *this;
}
const check_shapes& same_dims() const
{
if(!this->same([](const shape& s) { return s.lens(); }))
RTG_THROW("Dimensions do not match");
RTG_THROW(prefix() + "Dimensions do not match");
return *this;
}
const check_shapes& same_ndims() const
{
if(!this->same([](const shape& s) { return s.lens().size(); }))
RTG_THROW("Dimensions do not match");
RTG_THROW(prefix() + "Dimensions do not match");
return *this;
}
......@@ -101,7 +111,7 @@ struct convolution
std::string name() const { return "convolution"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs}.has(2).same_type().same_ndims().only_dims(4);
check_shapes{inputs, *this}.has(2).same_type().same_ndims().only_dims(4);
const shape& input = inputs.at(0);
const shape& weights = inputs.at(1);
......@@ -175,7 +185,7 @@ struct pooling
std::string name() const { return "pooling"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs}.has(1).only_dims(4);
check_shapes{inputs, *this}.has(1).only_dims(4);
const shape& input = inputs.at(0);
auto t = input.type();
......@@ -219,7 +229,7 @@ struct activation
std::string name() const { return "activation"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs}.has(1);
check_shapes{inputs, *this}.has(1);
return inputs.front();
}
......@@ -237,7 +247,7 @@ struct transpose
std::string name() const { return "transpose"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs}.has(1);
check_shapes{inputs, *this}.has(1);
auto input = inputs.at(0);
auto input_lens = input.lens();
auto input_strides = input.strides();
......@@ -269,7 +279,7 @@ struct contiguous
std::string name() const { return "contiguous"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs}.has(1);
check_shapes{inputs, *this}.has(1);
auto lens = inputs.at(0).lens();
auto t = inputs.at(0).type();
if(lens.size() < 2)
......@@ -287,7 +297,7 @@ struct reshape
std::string name() const { return "reshape"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs}.has(1);
check_shapes{inputs, *this}.has(1);
auto&& idims = inputs.front().lens();
std::vector<std::size_t> rdims(dims.begin(), dims.end());
for(std::size_t i = 0; i < dims.size(); i++)
......@@ -322,7 +332,7 @@ struct gemm
std::string name() const { return "gemm"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs}.has(2).same_type();
check_shapes{inputs, *this}.has(2).same_type();
const shape& a = inputs.at(0);
const shape& b = inputs.at(1);
auto t = a.type();
......@@ -492,12 +502,29 @@ struct outline
std::string name() const { return "outline"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs}.has(0);
check_shapes{inputs, *this}.has(0);
return s;
}
argument compute(context&, shape, std::vector<argument>) const { return {s, nullptr}; }
};
template<class T>
struct check_context
{
std::string name() const { return "check_context"; }
shape compute_shape(std::vector<shape>) const
{
return {};
}
argument compute(context& ctx, shape, std::vector<argument>) const
{
T* x = any_cast<T>(&ctx);
if(x == nullptr)
RTG_THROW(std::string("Unexpected context type: ") + ctx.type_id().name());
return {};
}
};
} // namespace rtg
#endif
......@@ -73,6 +73,14 @@ struct target
: nullptr;
}
const std::type_info& type_id() const
{
if(private_detail_te_handle_empty())
return typeid(std::nullptr_t);
else
return private_detail_te_get_handle().type();
}
std::string name() const
{
assert((*this).private_detail_te_handle_mem_var);
......@@ -150,13 +158,20 @@ struct target
}
};
bool private_detail_te_handle_empty() const
{
return private_detail_te_handle_mem_var == nullptr;
}
const private_detail_te_handle_base_type& private_detail_te_get_handle() const
{
assert(private_detail_te_handle_mem_var != nullptr);
return *private_detail_te_handle_mem_var;
}
private_detail_te_handle_base_type& private_detail_te_get_handle()
{
assert(private_detail_te_handle_mem_var != nullptr);
if(!private_detail_te_handle_mem_var.unique())
private_detail_te_handle_mem_var = private_detail_te_handle_mem_var->clone();
return *private_detail_te_handle_mem_var;
......
......@@ -30,7 +30,7 @@ rtg::argument run_gpu(std::string file)
auto handle = rtg::miopen::make_obj<rtg::miopen::miopen_handle>(&miopenCreate);
auto out = p.eval(
{{"Input3", input3}, {"handle", {rtg::shape::any_type, handle.get()}}, {"output", output}});
{{"Input3", input3}, {"output", output}});
std::cout << p << std::endl;
return rtg::miopen::from_gpu(out);
}
......
......@@ -10,7 +10,7 @@ struct miopen_target
{
std::string name() const;
void apply(program& p) const;
context get_context() const { return {}; }
context get_context() const;
};
} // namespace miopen
......
......@@ -10,6 +10,11 @@
namespace rtg {
namespace miopen {
struct miopen_context
{
shared<miopen_handle> handle;
};
struct miopen_convolution
{
convolution op;
......@@ -18,46 +23,47 @@ struct miopen_convolution
std::string name() const { return "miopen::convolution"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs}.has(4);
return op.compute_shape({inputs.at(1), inputs.at(2)});
check_shapes{inputs, *this}.has(3);
return op.compute_shape({inputs.at(0), inputs.at(1)});
}
argument compute(context&, shape output_shape, std::vector<argument> args) const
argument compute(context& gctx, shape output_shape, std::vector<argument> args) const
{
auto x_desc = make_tensor(args[1].get_shape());
auto w_desc = make_tensor(args[2].get_shape());
auto& ctx = any_cast<miopen_context>(gctx);
auto x_desc = make_tensor(args[0].get_shape());
auto w_desc = make_tensor(args[1].get_shape());
auto y_desc = make_tensor(output_shape);
float alpha = 1, beta = 0;
int algo_count;
miopenConvAlgoPerf_t perf;
miopenFindConvolutionForwardAlgorithm(args[0].implicit(),
miopenFindConvolutionForwardAlgorithm(ctx.handle.get(),
x_desc.get(),
args[1].implicit(),
args[0].implicit(),
w_desc.get(),
args[2].implicit(),
args[1].implicit(),
cd.get(),
y_desc.get(),
args[3].implicit(),
args[2].implicit(),
1,
&algo_count,
&perf,
nullptr,
0,
false);
miopenConvolutionForward(args[0].implicit(),
miopenConvolutionForward(ctx.handle.get(),
&alpha,
x_desc.get(),
args[1].implicit(),
args[0].implicit(),
w_desc.get(),
args[2].implicit(),
args[1].implicit(),
cd.get(),
perf.fwd_algo,
&beta,
y_desc.get(),
args[3].implicit(),
args[2].implicit(),
nullptr,
0);
return args[3];
return args[2];
}
};
......@@ -69,29 +75,30 @@ struct miopen_pooling
std::string name() const { return "miopen::pooling"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs}.has(3);
check_shapes{inputs, *this}.has(2);
return op.compute_shape({inputs.at(1)});
}
argument compute(context&, shape output_shape, std::vector<argument> args) const
argument compute(context& gctx, shape output_shape, std::vector<argument> args) const
{
auto x_desc = make_tensor(args[1].get_shape());
auto& ctx = any_cast<miopen_context>(gctx);
auto x_desc = make_tensor(args[0].get_shape());
auto y_desc = make_tensor(output_shape);
float alpha = 1, beta = 0;
miopenPoolingForward(args[0].implicit(),
miopenPoolingForward(ctx.handle.get(),
pd.get(),
&alpha,
x_desc.get(),
args[1].implicit(),
args[0].implicit(),
&beta,
y_desc.get(),
args[2].implicit(),
args[1].implicit(),
false,
nullptr,
0);
return args[2];
return args[1];
}
};
......@@ -100,17 +107,17 @@ struct miopen_add
std::string name() const { return "miopen::add"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs}.has(4);
return inputs.at(1);
check_shapes{inputs, *this}.has(3);
return inputs.at(0);
}
argument compute(context&, shape output_shape, std::vector<argument> args) const
argument compute(context& gctx, shape output_shape, std::vector<argument> args) const
{
if(args[2].get_shape().broadcasted())
if(args[1].get_shape().broadcasted())
{
argument result{output_shape};
visit_all(result, from_gpu(args[1]), from_gpu(args[2]))(
visit_all(result, from_gpu(args[0]), from_gpu(args[1]))(
[&](auto output, auto input1, auto input2) {
shape_for_each(output.get_shape(), [&](const auto& idx) {
output(idx.begin(), idx.end()) =
......@@ -121,22 +128,23 @@ struct miopen_add
}
else
{
auto& ctx = any_cast<miopen_context>(gctx);
float alpha = 1, beta = 0;
auto a_desc = make_tensor(args[1].get_shape());
auto b_desc = make_tensor(args[2].get_shape());
auto a_desc = make_tensor(args[0].get_shape());
auto b_desc = make_tensor(args[1].get_shape());
auto c_desc = make_tensor(output_shape);
miopenOpTensor(args[0].implicit(),
miopenOpTensor(ctx.handle.get(),
miopenTensorOpAdd,
&alpha,
a_desc.get(),
args[1].implicit(),
args[0].implicit(),
&alpha,
b_desc.get(),
args[2].implicit(),
args[1].implicit(),
&beta,
c_desc.get(),
args[3].implicit());
return args[3];
args[2].implicit());
return args[2];
}
}
};
......@@ -147,14 +155,14 @@ struct miopen_gemm
std::string name() const { return "miopen::convolution"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs}.has(4);
return op.compute_shape({inputs.at(1), inputs.at(2)});
check_shapes{inputs, *this}.has(3);
return op.compute_shape({inputs.at(0), inputs.at(1)});
}
argument compute(context&, shape output_shape, std::vector<argument> args) const
{
argument result{output_shape};
visit_all(result, from_gpu(args[1]), from_gpu(args[2]))(
visit_all(result, from_gpu(args[0]), from_gpu(args[1]))(
[&](auto output, auto input1, auto input2) {
dfor(input1.get_shape().lens()[0],
input2.get_shape().lens()[1],
......@@ -171,36 +179,36 @@ struct miopen_relu
std::string name() const { return "miopen::relu"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs}.has(3);
check_shapes{inputs, *this}.has(2);
return inputs.at(1);
}
argument compute(context&, shape output_shape, std::vector<argument> args) const
argument compute(context& gctx, shape output_shape, std::vector<argument> args) const
{
auto& ctx = any_cast<miopen_context>(gctx);
float alpha = 1, beta = 0;
auto x_desc = make_tensor(args[1].get_shape());
auto x_desc = make_tensor(args[0].get_shape());
auto y_desc = make_tensor(output_shape);
miopenActivationForward(args[0].implicit(),
miopenActivationForward(ctx.handle.get(),
ad.get(),
&alpha,
x_desc.get(),
args[1].implicit(),
args[0].implicit(),
&beta,
y_desc.get(),
args[2].implicit());
args[1].implicit());
return args[2];
return args[1];
}
};
struct miopen_apply
{
program* prog = nullptr;
instruction_ref handle{};
void apply()
{
handle = prog->add_parameter("handle", shape{shape::any_type});
prog->insert_instruction(prog->begin(), check_context<miopen_context>{});
for(auto it = prog->begin(); it != prog->end(); it++)
{
if(it->op.name() == "convolution")
......@@ -248,7 +256,6 @@ struct miopen_apply
prog->replace_instruction(ins,
miopen_convolution{op, std::move(cd)},
handle,
ins->arguments.at(0),
ins->arguments.at(1),
output);
......@@ -261,7 +268,7 @@ struct miopen_apply
auto output = insert_allocation(ins, ins->result);
prog->replace_instruction(
ins, miopen_pooling{op, std::move(pd)}, handle, ins->arguments.at(0), output);
ins, miopen_pooling{op, std::move(pd)}, ins->arguments.at(0), output);
}
void apply_activation(instruction_ref ins)
......@@ -272,7 +279,7 @@ struct miopen_apply
{
auto output = insert_allocation(ins, ins->result);
prog->replace_instruction(
ins, miopen_relu{std::move(ad)}, handle, ins->arguments.at(0), output);
ins, miopen_relu{std::move(ad)}, ins->arguments.at(0), output);
}
}
......@@ -280,7 +287,7 @@ struct miopen_apply
{
auto output = insert_allocation(ins, ins->result);
prog->replace_instruction(
ins, miopen_add{}, handle, ins->arguments.at(0), ins->arguments.at(1), output);
ins, miopen_add{}, ins->arguments.at(0), ins->arguments.at(1), output);
}
void apply_gemm(instruction_ref ins)
......@@ -288,7 +295,7 @@ struct miopen_apply
auto&& op = any_cast<gemm>(ins->op);
auto output = insert_allocation(ins, ins->result);
prog->replace_instruction(
ins, miopen_gemm{op}, handle, ins->arguments.at(0), ins->arguments.at(1), output);
ins, miopen_gemm{op}, ins->arguments.at(0), ins->arguments.at(1), output);
}
};
......@@ -296,6 +303,11 @@ std::string miopen_target::name() const { return "miopen"; }
void miopen_target::apply(program& p) const { miopen_apply{&p}.apply(); }
context miopen_target::get_context() const
{
return miopen_context{share(make_obj<miopen_handle>(&miopenCreate))};
}
} // namespace miopen
} // namespace rtg
......@@ -36,8 +36,6 @@ rtg::argument run_gpu()
}
m["output"] = rtg::miopen::to_gpu(rtg::generate_argument(p.get_parameter_shape("output")));
auto handle = rtg::miopen::make_obj<rtg::miopen::miopen_handle>(&miopenCreate);
m["handle"] = {rtg::shape::any_type, handle.get()};
return rtg::miopen::from_gpu(p.eval(m));
}
......
......@@ -63,6 +63,12 @@ struct ${struct_name}
nullptr;
}
const std::type_info& type_id() const
{
if(private_detail_te_handle_empty()) return typeid(std::nullptr_t);
else return private_detail_te_get_handle().type();
}
${nonvirtual_members}
private:
......@@ -118,11 +124,20 @@ private:
{}
};
bool private_detail_te_handle_empty() const
{
return private_detail_te_handle_mem_var == nullptr;
}
const private_detail_te_handle_base_type & private_detail_te_get_handle () const
{ return *private_detail_te_handle_mem_var; }
{
assert(private_detail_te_handle_mem_var != nullptr);
return *private_detail_te_handle_mem_var;
}
private_detail_te_handle_base_type & private_detail_te_get_handle ()
{
assert(private_detail_te_handle_mem_var != nullptr);
if (!private_detail_te_handle_mem_var.unique())
private_detail_te_handle_mem_var = private_detail_te_handle_mem_var->clone();
return *private_detail_te_handle_mem_var;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment