Commit 83b91229 authored by Paul's avatar Paul
Browse files

Formatting

parent 8e0fff81
...@@ -15,22 +15,26 @@ struct check_shapes ...@@ -15,22 +15,26 @@ struct check_shapes
const std::string name; const std::string name;
check_shapes(const std::vector<shape>& s) : shapes(&s) {} check_shapes(const std::vector<shape>& s) : shapes(&s) {}
template<class Op> template <class Op>
check_shapes(const std::vector<shape>& s, const Op& op) : shapes(&s), name(op.name()) {} check_shapes(const std::vector<shape>& s, const Op& op) : shapes(&s), name(op.name())
{
}
std::string prefix() const std::string prefix() const
{ {
if(name.empty()) return ""; if(name.empty())
else return name + ": "; return "";
else
return name + ": ";
} }
const check_shapes& has(std::size_t n) const const check_shapes& has(std::size_t n) const
{ {
assert(shapes != nullptr); assert(shapes != nullptr);
if(shapes->size() != n) if(shapes->size() != n)
RTG_THROW(prefix() + "Wrong number of arguments: expected " + std::to_string(n) + " but given " + RTG_THROW(prefix() + "Wrong number of arguments: expected " + std::to_string(n) +
std::to_string(shapes->size())); " but given " + std::to_string(shapes->size()));
return *this; return *this;
} }
...@@ -508,15 +512,12 @@ struct outline ...@@ -508,15 +512,12 @@ struct outline
argument compute(context&, shape, std::vector<argument>) const { return {s, nullptr}; } argument compute(context&, shape, std::vector<argument>) const { return {s, nullptr}; }
}; };
template<class T> template <class T>
struct check_context struct check_context
{ {
std::string name() const { return "check_context"; } std::string name() const { return "check_context"; }
shape compute_shape(std::vector<shape>) const shape compute_shape(std::vector<shape>) const { return {}; }
{ argument compute(context& ctx, shape, std::vector<argument>) const
return {};
}
argument compute(context& ctx, shape, std::vector<argument>) const
{ {
T* x = any_cast<T>(&ctx); T* x = any_cast<T>(&ctx);
if(x == nullptr) if(x == nullptr)
......
...@@ -29,8 +29,7 @@ rtg::argument run_gpu(std::string file) ...@@ -29,8 +29,7 @@ rtg::argument run_gpu(std::string file)
auto output = rtg::miopen::to_gpu(rtg::generate_argument(p.get_parameter_shape("output"))); auto output = rtg::miopen::to_gpu(rtg::generate_argument(p.get_parameter_shape("output")));
auto handle = rtg::miopen::make_obj<rtg::miopen::miopen_handle>(&miopenCreate); auto handle = rtg::miopen::make_obj<rtg::miopen::miopen_handle>(&miopenCreate);
auto out = p.eval( auto out = p.eval({{"Input3", input3}, {"output", output}});
{{"Input3", input3}, {"output", output}});
std::cout << p << std::endl; std::cout << p << std::endl;
return rtg::miopen::from_gpu(out); return rtg::miopen::from_gpu(out);
} }
......
...@@ -28,7 +28,7 @@ struct miopen_convolution ...@@ -28,7 +28,7 @@ struct miopen_convolution
} }
argument compute(context& gctx, shape output_shape, std::vector<argument> args) const argument compute(context& gctx, shape output_shape, std::vector<argument> args) const
{ {
auto& ctx = any_cast<miopen_context>(gctx); auto& ctx = any_cast<miopen_context>(gctx);
auto x_desc = make_tensor(args[0].get_shape()); auto x_desc = make_tensor(args[0].get_shape());
auto w_desc = make_tensor(args[1].get_shape()); auto w_desc = make_tensor(args[1].get_shape());
auto y_desc = make_tensor(output_shape); auto y_desc = make_tensor(output_shape);
...@@ -80,7 +80,7 @@ struct miopen_pooling ...@@ -80,7 +80,7 @@ struct miopen_pooling
} }
argument compute(context& gctx, shape output_shape, std::vector<argument> args) const argument compute(context& gctx, shape output_shape, std::vector<argument> args) const
{ {
auto& ctx = any_cast<miopen_context>(gctx); auto& ctx = any_cast<miopen_context>(gctx);
auto x_desc = make_tensor(args[0].get_shape()); auto x_desc = make_tensor(args[0].get_shape());
auto y_desc = make_tensor(output_shape); auto y_desc = make_tensor(output_shape);
...@@ -128,7 +128,7 @@ struct miopen_add ...@@ -128,7 +128,7 @@ struct miopen_add
} }
else else
{ {
auto& ctx = any_cast<miopen_context>(gctx); auto& ctx = any_cast<miopen_context>(gctx);
float alpha = 1, beta = 0; float alpha = 1, beta = 0;
auto a_desc = make_tensor(args[0].get_shape()); auto a_desc = make_tensor(args[0].get_shape());
auto b_desc = make_tensor(args[1].get_shape()); auto b_desc = make_tensor(args[1].get_shape());
...@@ -185,7 +185,7 @@ struct miopen_relu ...@@ -185,7 +185,7 @@ struct miopen_relu
argument compute(context& gctx, shape output_shape, std::vector<argument> args) const argument compute(context& gctx, shape output_shape, std::vector<argument> args) const
{ {
auto& ctx = any_cast<miopen_context>(gctx); auto& ctx = any_cast<miopen_context>(gctx);
float alpha = 1, beta = 0; float alpha = 1, beta = 0;
auto x_desc = make_tensor(args[0].get_shape()); auto x_desc = make_tensor(args[0].get_shape());
auto y_desc = make_tensor(output_shape); auto y_desc = make_tensor(output_shape);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment