"docs/en_US/vscode:/vscode.git/clone" did not exist on "0fd38debd5c6567e22bd38d2752d7a34306d7c5f"
Commit 9a3fc32d authored by Paul's avatar Paul
Browse files

Use const refs where possible

parent 61991b42
...@@ -2,11 +2,12 @@ ...@@ -2,11 +2,12 @@
#define MIGRAPH_GUARD_MIGRAPHLIB_HIP_HPP #define MIGRAPH_GUARD_MIGRAPHLIB_HIP_HPP
#include <migraph/operators.hpp> #include <migraph/operators.hpp>
#include <utility>
namespace migraph { namespace migraph {
namespace gpu { namespace gpu {
migraph::argument allocate_gpu(migraph::shape s, bool host = false); migraph::argument allocate_gpu(const migraph::shape& s, bool host = false);
migraph::argument to_gpu(migraph::argument arg, bool host = false); migraph::argument to_gpu(migraph::argument arg, bool host = false);
...@@ -16,12 +17,12 @@ struct hip_allocate ...@@ -16,12 +17,12 @@ struct hip_allocate
{ {
std::string tag{}; std::string tag{};
std::string name() const { return "hip::allocate"; } std::string name() const { return "hip::allocate"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(const std::vector<shape>& inputs) const
{ {
check_shapes{inputs}.has(1); check_shapes{inputs}.has(1);
return inputs.front(); return inputs.front();
} }
argument compute(context&, shape output_shape, std::vector<argument>) const argument compute(context&, const shape& output_shape, const std::vector<argument>&) const
{ {
return allocate_gpu(output_shape); return allocate_gpu(output_shape);
} }
...@@ -30,12 +31,12 @@ struct hip_allocate ...@@ -30,12 +31,12 @@ struct hip_allocate
struct hip_write struct hip_write
{ {
std::string name() const { return "hip::write"; } std::string name() const { return "hip::write"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(const std::vector<shape>& inputs) const
{ {
check_shapes{inputs}.has(1); check_shapes{inputs}.has(1);
return inputs.front(); return inputs.front();
} }
argument compute(context&, shape, std::vector<argument> args) const argument compute(context&, const shape&, const std::vector<argument>& args) const
{ {
return to_gpu(args.front()); return to_gpu(args.front());
} }
......
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
#include <migraph/iterator_for.hpp> #include <migraph/iterator_for.hpp>
#include <migraph/gpu/rocblas.hpp> #include <migraph/gpu/rocblas.hpp>
#include <migraph/gpu/context.hpp> #include <migraph/gpu/context.hpp>
#include <utility>
namespace migraph { namespace migraph {
namespace gpu { namespace gpu {
...@@ -22,14 +23,14 @@ struct miopen_batch_norm_inference ...@@ -22,14 +23,14 @@ struct miopen_batch_norm_inference
std::string name() const { return "gpu::batch_norm_inference"; } std::string name() const { return "gpu::batch_norm_inference"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(const std::vector<shape>& inputs) const
{ {
check_shapes{inputs, *this}.has(6); check_shapes{inputs, *this}.has(6);
return op.compute_shape( return op.compute_shape(
{inputs.at(0), inputs.at(1), inputs.at(2), inputs.at(3), inputs.at(4)}); {inputs.at(0), inputs.at(1), inputs.at(2), inputs.at(3), inputs.at(4)});
} }
argument compute(context& ctx, shape output_shape, std::vector<argument> args) const argument compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const
{ {
auto x_desc = make_tensor(args[0].get_shape()); auto x_desc = make_tensor(args[0].get_shape());
auto y_desc = make_tensor(output_shape); auto y_desc = make_tensor(output_shape);
...@@ -63,12 +64,12 @@ struct miopen_convolution ...@@ -63,12 +64,12 @@ struct miopen_convolution
miopenConvFwdAlgorithm_t algo{}; miopenConvFwdAlgorithm_t algo{};
std::string name() const { return "gpu::convolution"; } std::string name() const { return "gpu::convolution"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(const std::vector<shape>& inputs) const
{ {
check_shapes{inputs, *this}.has(4).standard(); check_shapes{inputs, *this}.has(4).standard();
return op.compute_shape({inputs.at(0), inputs.at(1)}); return op.compute_shape({inputs.at(0), inputs.at(1)});
} }
argument compute(context& ctx, shape output_shape, std::vector<argument> args) const argument compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const
{ {
auto x_desc = make_tensor(args[0].get_shape()); auto x_desc = make_tensor(args[0].get_shape());
auto w_desc = make_tensor(args[1].get_shape()); auto w_desc = make_tensor(args[1].get_shape());
...@@ -91,7 +92,7 @@ struct miopen_convolution ...@@ -91,7 +92,7 @@ struct miopen_convolution
return args[3]; return args[3];
} }
shape compile(context& ctx, shape output_shape, std::vector<instruction_ref> inputs) shape compile(context& ctx, const shape& output_shape, std::vector<instruction_ref> inputs)
{ {
shape workspace_shape{}; shape workspace_shape{};
auto x_desc = make_tensor(inputs[0]->get_shape()); auto x_desc = make_tensor(inputs[0]->get_shape());
...@@ -136,12 +137,12 @@ struct miopen_pooling ...@@ -136,12 +137,12 @@ struct miopen_pooling
shared<pooling_descriptor> pd; shared<pooling_descriptor> pd;
std::string name() const { return "gpu::pooling"; } std::string name() const { return "gpu::pooling"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(const std::vector<shape>& inputs) const
{ {
check_shapes{inputs, *this}.has(2).standard(); check_shapes{inputs, *this}.has(2).standard();
return op.compute_shape({inputs.at(0)}); return op.compute_shape({inputs.at(0)});
} }
argument compute(context& ctx, shape output_shape, std::vector<argument> args) const argument compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const
{ {
auto x_desc = make_tensor(args[0].get_shape()); auto x_desc = make_tensor(args[0].get_shape());
auto y_desc = make_tensor(output_shape); auto y_desc = make_tensor(output_shape);
...@@ -167,13 +168,13 @@ struct miopen_pooling ...@@ -167,13 +168,13 @@ struct miopen_pooling
struct miopen_add struct miopen_add
{ {
std::string name() const { return "gpu::add"; } std::string name() const { return "gpu::add"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(const std::vector<shape>& inputs) const
{ {
check_shapes{inputs, *this}.has(3).not_broadcasted(); check_shapes{inputs, *this}.has(3).not_broadcasted();
return inputs.at(0); return inputs.at(0);
} }
argument compute(context& ctx, shape output_shape, std::vector<argument> args) const argument compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const
{ {
if(args[1].get_shape().broadcasted()) if(args[1].get_shape().broadcasted())
{ {
...@@ -214,12 +215,12 @@ struct miopen_gemm ...@@ -214,12 +215,12 @@ struct miopen_gemm
{ {
gemm op; gemm op;
std::string name() const { return "gpu::convolution"; } std::string name() const { return "gpu::convolution"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(const std::vector<shape>& inputs) const
{ {
check_shapes{inputs, *this}.has(3); check_shapes{inputs, *this}.has(3);
return op.compute_shape({inputs.at(0), inputs.at(1)}); return op.compute_shape({inputs.at(0), inputs.at(1)});
} }
argument compute(context& ctx, shape output_shape, std::vector<argument> args) const argument compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const
{ {
float alpha = 1.0f; float alpha = 1.0f;
float beta = 0.0f; float beta = 0.0f;
...@@ -253,14 +254,14 @@ struct miopen_contiguous ...@@ -253,14 +254,14 @@ struct miopen_contiguous
{ {
contiguous op; contiguous op;
std::string name() const { return "gpu::contiguous"; } std::string name() const { return "gpu::contiguous"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(const std::vector<shape>& inputs) const
{ {
check_shapes{inputs, *this}.has(2); check_shapes{inputs, *this}.has(2);
return op.compute_shape({inputs.at(0)}); return op.compute_shape({inputs.at(0)});
} }
argument compute(context&, shape output_shape, std::vector<argument> args) const argument compute(context&, shape output_shape, const std::vector<argument>& args) const
{ {
hip_contiguous(output_shape, args.at(0), args.at(1)); hip_contiguous(std::move(output_shape), args.at(0), args.at(1));
return args.at(1); return args.at(1);
} }
}; };
...@@ -269,13 +270,13 @@ struct miopen_relu ...@@ -269,13 +270,13 @@ struct miopen_relu
{ {
shared<activation_descriptor> ad; shared<activation_descriptor> ad;
std::string name() const { return "gpu::relu"; } std::string name() const { return "gpu::relu"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(const std::vector<shape>& inputs) const
{ {
check_shapes{inputs, *this}.has(2).not_broadcasted(); check_shapes{inputs, *this}.has(2).not_broadcasted();
return inputs.at(1); return inputs.at(1);
} }
argument compute(context& ctx, shape output_shape, std::vector<argument> args) const argument compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const
{ {
float alpha = 1, beta = 0; float alpha = 1, beta = 0;
auto x_desc = make_tensor(args[0].get_shape()); auto x_desc = make_tensor(args[0].get_shape());
...@@ -350,7 +351,7 @@ struct miopen_apply ...@@ -350,7 +351,7 @@ struct miopen_apply
else else
{ {
auto is = prog->add_outline(s); auto is = prog->add_outline(s);
auto result = prog->insert_instruction(ins, hip_allocate{tag}, is); auto result = prog->insert_instruction(ins, hip_allocate{std::move(tag)}, is);
return result; return result;
} }
} }
......
...@@ -6,7 +6,7 @@ struct sum_op ...@@ -6,7 +6,7 @@ struct sum_op
{ {
std::string name() const { return "sum"; } std::string name() const { return "sum"; }
migraph::argument migraph::argument
compute(migraph::context&, migraph::shape, std::vector<migraph::argument> args) const compute(migraph::context&, const migraph::shape&, std::vector<migraph::argument> args) const
{ {
migraph::argument result; migraph::argument result;
if(args.size() != 2) if(args.size() != 2)
...@@ -36,7 +36,7 @@ struct minus_op ...@@ -36,7 +36,7 @@ struct minus_op
{ {
std::string name() const { return "minus"; } std::string name() const { return "minus"; }
migraph::argument migraph::argument
compute(migraph::context&, migraph::shape, std::vector<migraph::argument> args) const compute(migraph::context&, const migraph::shape&, std::vector<migraph::argument> args) const
{ {
migraph::argument result; migraph::argument result;
if(args.size() != 2) if(args.size() != 2)
...@@ -66,7 +66,7 @@ struct pass_op ...@@ -66,7 +66,7 @@ struct pass_op
{ {
std::string name() const { return "pass"; } std::string name() const { return "pass"; }
migraph::argument migraph::argument
compute(migraph::context&, migraph::shape, std::vector<migraph::argument> args) const compute(migraph::context&, const migraph::shape&, std::vector<migraph::argument> args) const
{ {
if(args.empty()) if(args.empty())
return {}; return {};
...@@ -85,7 +85,7 @@ struct pass_standard_op ...@@ -85,7 +85,7 @@ struct pass_standard_op
{ {
std::string name() const { return "pass"; } std::string name() const { return "pass"; }
migraph::argument migraph::argument
compute(migraph::context&, migraph::shape, std::vector<migraph::argument> args) const compute(migraph::context&, const migraph::shape&, std::vector<migraph::argument> args) const
{ {
if(args.empty()) if(args.empty())
return {}; return {};
...@@ -109,12 +109,12 @@ struct nop ...@@ -109,12 +109,12 @@ struct nop
{ {
std::string name() const { return "nop"; } std::string name() const { return "nop"; }
migraph::argument migraph::argument
compute(migraph::context&, migraph::shape, std::vector<migraph::argument>) const compute(migraph::context&, const migraph::shape&, const std::vector<migraph::argument>&) const
{ {
return {}; return {};
} }
migraph::shape compute_shape(std::vector<migraph::shape>) const { return {}; } migraph::shape compute_shape(const std::vector<migraph::shape>&) const { return {}; }
}; };
inline migraph::literal get_2x2() inline migraph::literal get_2x2()
......
...@@ -141,7 +141,7 @@ bool throws(F f) ...@@ -141,7 +141,7 @@ bool throws(F f)
} }
template <class F, class Exception> template <class F, class Exception>
bool throws(F f, std::string msg = "") bool throws(F f, const std::string& msg = "")
{ {
try try
{ {
......
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
#include "test.hpp" #include "test.hpp"
template <class... Ts> template <class... Ts>
void expect_shape(migraph::shape expected, migraph::operation op, Ts... xs) void expect_shape(const migraph::shape& expected, const migraph::operation& op, Ts... xs)
{ {
migraph::program p; migraph::program p;
std::vector<migraph::shape> shapes{xs...}; std::vector<migraph::shape> shapes{xs...};
...@@ -24,7 +24,7 @@ void expect_shape(migraph::shape expected, migraph::operation op, Ts... xs) ...@@ -24,7 +24,7 @@ void expect_shape(migraph::shape expected, migraph::operation op, Ts... xs)
} }
template <class... Ts> template <class... Ts>
void throws_shape(migraph::operation op, Ts... xs) void throws_shape(const migraph::operation& op, Ts... xs)
{ {
migraph::program p; migraph::program p;
std::vector<migraph::shape> shapes{xs...}; std::vector<migraph::shape> shapes{xs...};
...@@ -46,7 +46,7 @@ struct always_false : std::false_type ...@@ -46,7 +46,7 @@ struct always_false : std::false_type
}; };
template <class... Ts> template <class... Ts>
void throws_shape(migraph::shape, Ts...) void throws_shape(const migraph::shape&, Ts...)
{ {
static_assert(always_false<Ts...>{}, static_assert(always_false<Ts...>{},
"An expected shape should not be passed to throws_shape function"); "An expected shape should not be passed to throws_shape function");
......
...@@ -8,12 +8,12 @@ struct simple_operation ...@@ -8,12 +8,12 @@ struct simple_operation
{ {
int data = 1; int data = 1;
std::string name() const { return "simple"; } std::string name() const { return "simple"; }
migraph::shape compute_shape(std::vector<migraph::shape>) const migraph::shape compute_shape(const std::vector<migraph::shape>&) const
{ {
MIGRAPH_THROW("not computable"); MIGRAPH_THROW("not computable");
} }
migraph::argument migraph::argument
compute(migraph::context&, migraph::shape, std::vector<migraph::argument>) const compute(migraph::context&, const migraph::shape&, const std::vector<migraph::argument>&) const
{ {
MIGRAPH_THROW("not computable"); MIGRAPH_THROW("not computable");
} }
...@@ -27,12 +27,12 @@ struct simple_operation ...@@ -27,12 +27,12 @@ struct simple_operation
struct simple_operation_no_print struct simple_operation_no_print
{ {
std::string name() const { return "simple"; } std::string name() const { return "simple"; }
migraph::shape compute_shape(std::vector<migraph::shape>) const migraph::shape compute_shape(const std::vector<migraph::shape>&) const
{ {
MIGRAPH_THROW("not computable"); MIGRAPH_THROW("not computable");
} }
migraph::argument migraph::argument
compute(migraph::context&, migraph::shape, std::vector<migraph::argument>) const compute(migraph::context&, const migraph::shape&, const std::vector<migraph::argument>&) const
{ {
MIGRAPH_THROW("not computable"); MIGRAPH_THROW("not computable");
} }
......
...@@ -25,7 +25,7 @@ struct operation ...@@ -25,7 +25,7 @@ struct operation
/// This is used to compute the resulting shape from an operation. If an /// This is used to compute the resulting shape from an operation. If an
/// operation cannot be run with input shapes, then it should throw an /// operation cannot be run with input shapes, then it should throw an
/// exception. /// exception.
shape compute_shape(std::vector<shape> input) const; shape compute_shape(const std::vector<shape>& input) const;
/** /**
* @brief This performs the operation's computation * @brief This performs the operation's computation
* *
...@@ -37,7 +37,7 @@ struct operation ...@@ -37,7 +37,7 @@ struct operation
* @return Return an `argument` of the result computation. The `shape` of `argument` should be * @return Return an `argument` of the result computation. The `shape` of `argument` should be
* the same the `output` shape. * the same the `output` shape.
*/ */
argument compute(context& ctx, shape output, std::vector<argument> input) const; argument compute(context& ctx, const shape& output, const std::vector<argument>& input) const;
/// An optional stream operator to print the operation. When this is not /// An optional stream operator to print the operation. When this is not
/// implemented, it will just print the operation's name. /// implemented, it will just print the operation's name.
friend std::ostream& operator<<(std::ostream& os, const operation& op); friend std::ostream& operator<<(std::ostream& os, const operation& op);
...@@ -56,7 +56,7 @@ auto operator<<(std::ostream& os, const T& x) -> decltype(os << x.name()) ...@@ -56,7 +56,7 @@ auto operator<<(std::ostream& os, const T& x) -> decltype(os << x.name())
} // namespace operation_stream } // namespace operation_stream
template <class T> template <class T>
argument compute_op(const T& x, context& ctx, shape output_shape, std::vector<argument> input) argument compute_op(const T& x, context& ctx, const shape& output_shape, const std::vector<argument>& input)
{ {
return x.compute(auto_any_cast(ctx), output_shape, input); return x.compute(auto_any_cast(ctx), output_shape, input);
} }
...@@ -64,8 +64,8 @@ argument compute_op(const T& x, context& ctx, shape output_shape, std::vector<ar ...@@ -64,8 +64,8 @@ argument compute_op(const T& x, context& ctx, shape output_shape, std::vector<ar
<% <%
interface('operation', interface('operation',
virtual('name', returns='std::string', const=True), virtual('name', returns='std::string', const=True),
virtual('compute_shape', returns='shape', input='std::vector<shape>', const=True), virtual('compute_shape', returns='shape', input='const std::vector<shape>&', const=True),
virtual('compute', returns='argument', ctx='context&', output='shape', input='std::vector<argument>', const=True, default='compute_op'), virtual('compute', returns='argument', ctx='context&', output='const shape&', input='const std::vector<argument>&', const=True, default='compute_op'),
friend('operator<<', returns='std::ostream &', os='std::ostream &', op='const operation &', using='migraph::operation_stream::operator<<') friend('operator<<', returns='std::ostream &', os='std::ostream &', op='const operation &', using='migraph::operation_stream::operator<<')
) )
%> %>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment