Commit 9a3fc32d authored by Paul's avatar Paul
Browse files

Use const refs where possible

parent 61991b42
......@@ -70,7 +70,7 @@ rocm_enable_clang_tidy(
-modernize-pass-by-value
-modernize-use-default-member-init
-modernize-use-transparent-functors
-performance-unnecessary-value-param
# -performance-unnecessary-value-param
-readability-braces-around-statements
-readability-else-after-return
-readability-named-parameter
......
......@@ -5,10 +5,11 @@
#include <migraph/iterator_for.hpp>
#include <migraph/ranges.hpp>
#include <migraph/stringutils.hpp>
#include <utility>
namespace migraph {
bool try_compute_shape(operation op, std::vector<instruction_ref> args)
bool try_compute_shape(const operation& op, const std::vector<instruction_ref>& args)
{
try
{
......
......@@ -4,6 +4,7 @@
#include <migraph/shape.hpp>
#include <migraph/raw_data.hpp>
#include <functional>
#include <utility>
namespace migraph {
......@@ -18,16 +19,16 @@ struct argument : raw_data<argument>
{
argument() {}
argument(shape s) : m_shape(s)
argument(const shape& s) : m_shape(s)
{
std::vector<char> buffer(s.bytes());
// TODO: Move vector
data = [=]() mutable { return buffer.data(); };
}
argument(shape s, std::function<char*()> d) : data(d), m_shape(s) {}
argument(shape s, std::function<char*()> d) : data(std::move(d)), m_shape(std::move(s)) {}
template <class T>
argument(shape s, T* d) : data([d] { return reinterpret_cast<char*>(d); }), m_shape(s)
argument(shape s, T* d) : data([d] { return reinterpret_cast<char*>(d); }), m_shape(std::move(s))
{
}
......
......@@ -12,24 +12,24 @@ namespace builtin {
struct literal
{
std::string name() const { return "@literal"; }
shape compute_shape(std::vector<shape>) const { MIGRAPH_THROW("builtin"); }
argument compute(context&, shape, std::vector<argument>) const { MIGRAPH_THROW("builtin"); }
shape compute_shape(const std::vector<shape>&) const { MIGRAPH_THROW("builtin"); }
argument compute(context&, const shape&, const std::vector<argument>&) const { MIGRAPH_THROW("builtin"); }
};
struct outline
{
shape s;
std::string name() const { return "@outline"; }
shape compute_shape(std::vector<shape>) const { return s; }
argument compute(context&, shape, std::vector<argument>) const { MIGRAPH_THROW("builtin"); }
shape compute_shape(const std::vector<shape>&) const { return s; }
argument compute(context&, const shape&, const std::vector<argument>&) const { MIGRAPH_THROW("builtin"); }
};
struct param
{
std::string parameter;
std::string name() const { return "@param"; }
shape compute_shape(std::vector<shape>) const { MIGRAPH_THROW("builtin"); }
argument compute(context&, shape, std::vector<argument>) const { MIGRAPH_THROW("builtin"); }
shape compute_shape(const std::vector<shape>&) const { MIGRAPH_THROW("builtin"); }
argument compute(context&, const shape&, const std::vector<argument>&) const { MIGRAPH_THROW("builtin"); }
friend std::ostream& operator<<(std::ostream& os, const param& op)
{
os << op.name() << ":" << op.parameter;
......
......@@ -11,8 +11,8 @@ struct check_context
struct op
{
std::string name() const { return "check_context"; }
shape compute_shape(std::vector<shape>) const { return {}; }
argument compute(context& ctx, shape, std::vector<argument>) const
shape compute_shape(const std::vector<shape>&) const { return {}; }
argument compute(context& ctx, const shape&, const std::vector<argument>&) const
{
T* x = any_cast<T>(&ctx);
if(x == nullptr)
......
......@@ -10,7 +10,7 @@ namespace migraph {
/// Represents exceptions that can be thrown by migraphlib
struct exception : std::runtime_error
{
exception(std::string msg = "") : std::runtime_error(msg) {}
exception(const std::string& msg = "") : std::runtime_error(msg) {}
};
/**
......@@ -20,7 +20,7 @@ struct exception : std::runtime_error
* @param message Custom message for the error
* @return Exceptions
*/
inline exception make_exception(std::string context, std::string message = "")
inline exception make_exception(const std::string& context, const std::string& message = "")
{
return {context + ": " + message};
}
......
......@@ -31,7 +31,7 @@ struct xorshf96_generator
};
template <class T>
std::vector<T> generate_tensor_data(migraph::shape s, std::mt19937::result_type)
std::vector<T> generate_tensor_data(const migraph::shape& s, std::mt19937::result_type)
{
std::vector<T> result(s.elements());
std::generate(result.begin(), result.end(), xorshf96_generator<T>{});
......
......@@ -8,10 +8,11 @@
#include <migraph/operation.hpp>
#include <migraph/erase.hpp>
#include <string>
#include <utility>
namespace migraph {
shape compute_shape(operation op, std::vector<instruction_ref> args);
shape compute_shape(const operation& op, const std::vector<instruction_ref>& args);
struct instruction
{
......@@ -25,14 +26,14 @@ struct instruction
instruction(literal l) : op(builtin::literal{}), result(l.get_shape()), lit(std::move(l)) {}
// internal
void replace(operation o, shape r, std::vector<instruction_ref> args)
void replace(operation o, const shape& r, std::vector<instruction_ref> args)
{
op = o;
replace(std::move(r));
op = std::move(o);
replace(r);
replace(std::move(args));
}
void replace(shape r)
void replace(const shape& r)
{
if(r != result)
{
......@@ -155,7 +156,7 @@ inline void replace_argument(instruction_ref ins, instruction_ref old, instructi
// TODO: Move to a cpp file
// TODO: Use const ref for vector
inline shape compute_shape(operation op, std::vector<instruction_ref> args)
inline shape compute_shape(const operation& op, const std::vector<instruction_ref>& args)
{
std::vector<shape> shapes(args.size());
std::transform(
......
......@@ -27,7 +27,7 @@ struct literal : raw_data<literal>
}
template <class T>
literal(shape s, const std::vector<T>& x)
literal(const shape& s, const std::vector<T>& x)
: buffer(std::make_unique<char[]>(s.bytes())), m_shape(s)
{
static_assert(std::is_trivial<T>{}, "Literals can only be trivial types");
......@@ -35,7 +35,7 @@ struct literal : raw_data<literal>
}
template <class T>
literal(shape s, const std::initializer_list<T>& x)
literal(const shape& s, const std::initializer_list<T>& x)
: buffer(std::make_unique<char[]>(s.bytes())), m_shape(s)
{
static_assert(std::is_trivial<T>{}, "Literals can only be trivial types");
......@@ -43,13 +43,13 @@ struct literal : raw_data<literal>
}
template <class Iterator>
literal(shape s, Iterator start, Iterator end)
literal(const shape& s, Iterator start, Iterator end)
: buffer(std::make_unique<char[]>(s.bytes())), m_shape(s)
{
fill(start, end);
}
literal(shape s, const char* x) : buffer(std::make_unique<char[]>(s.bytes())), m_shape(s)
literal(const shape& s, const char* x) : buffer(std::make_unique<char[]>(s.bytes())), m_shape(s)
{
std::copy(x, x + s.bytes(), buffer.get());
}
......
......@@ -25,7 +25,7 @@ struct operation
/// This is used to compute the resulting shape from an operation. If an
/// operation cannot be run with input shapes, then it should throw an
/// exception.
shape compute_shape(std::vector<shape> input) const;
shape compute_shape(const std::vector<shape>& input) const;
/**
* @brief This performs the operation's computation
*
......@@ -37,7 +37,7 @@ struct operation
* @return Return an `argument` of the result computation. The `shape` of `argument` should be
* the same the `output` shape.
*/
argument compute(context& ctx, shape output, std::vector<argument> input) const;
argument compute(context& ctx, const shape& output, const std::vector<argument>& input) const;
/// An optional stream operator to print the operation. When this is not
/// implemented, it will just print the operation's name.
friend std::ostream& operator<<(std::ostream& os, const operation& op);
......@@ -56,7 +56,8 @@ auto operator<<(std::ostream& os, const T& x) -> decltype(os << x.name())
} // namespace operation_stream
template <class T>
argument compute_op(const T& x, context& ctx, shape output_shape, std::vector<argument> input)
argument
compute_op(const T& x, context& ctx, const shape& output_shape, const std::vector<argument>& input)
{
return x.compute(auto_any_cast(ctx), output_shape, input);
}
......@@ -67,8 +68,8 @@ argument compute_op(const T& x, context& ctx, shape output_shape, std::vector<ar
* struct operation
* {
* std::string name() const;
* shape compute_shape(std::vector<shape> input) const;
* argument compute(context& ctx,shape output,std::vector<argument> input) const;
* shape compute_shape(const std::vector<shape>& input) const;
* argument compute(context& ctx,const shape& output,const std::vector<argument>& input) const;
* friend std::ostream & operator<<(std::ostream & os,const operation & op) ;
* };
*
......@@ -137,17 +138,16 @@ struct operation
return (*this).private_detail_te_get_handle().name();
}
shape compute_shape(std::vector<shape> input) const
shape compute_shape(const std::vector<shape>& input) const
{
assert((*this).private_detail_te_handle_mem_var);
return (*this).private_detail_te_get_handle().compute_shape(std::move(input));
return (*this).private_detail_te_get_handle().compute_shape(input);
}
argument compute(context& ctx, shape output, std::vector<argument> input) const
argument compute(context& ctx, const shape& output, const std::vector<argument>& input) const
{
assert((*this).private_detail_te_handle_mem_var);
return (*this).private_detail_te_get_handle().compute(
ctx, std::move(output), std::move(input));
return (*this).private_detail_te_get_handle().compute(ctx, output, input);
}
friend std::ostream& operator<<(std::ostream& os, const operation& op)
......@@ -163,10 +163,11 @@ struct operation
virtual std::shared_ptr<private_detail_te_handle_base_type> clone() const = 0;
virtual const std::type_info& type() const = 0;
virtual std::string name() const = 0;
virtual shape compute_shape(std::vector<shape> input) const = 0;
virtual argument compute(context& ctx, shape output, std::vector<argument> input) const = 0;
virtual std::ostream& operator_shift_left(std::ostream& os) const = 0;
virtual std::string name() const = 0;
virtual shape compute_shape(const std::vector<shape>& input) const = 0;
virtual argument
compute(context& ctx, const shape& output, const std::vector<argument>& input) const = 0;
virtual std::ostream& operator_shift_left(std::ostream& os) const = 0;
};
template <typename PrivateDetailTypeErasedT>
......@@ -199,16 +200,18 @@ struct operation
std::string name() const override { return private_detail_te_value.name(); }
shape compute_shape(std::vector<shape> input) const override
shape compute_shape(const std::vector<shape>& input) const override
{
return private_detail_te_value.compute_shape(std::move(input));
return private_detail_te_value.compute_shape(input);
}
argument compute(context& ctx, shape output, std::vector<argument> input) const override
argument compute(context& ctx,
const shape& output,
const std::vector<argument>& input) const override
{
return compute_op(private_detail_te_value, ctx, std::move(output), std::move(input));
return compute_op(private_detail_te_value, ctx, output, input);
}
std::ostream& operator_shift_left(std::ostream& os) const override
......
......@@ -7,12 +7,13 @@
#include <migraph/stringutils.hpp>
#include <migraph/streamutils.hpp>
#include <cmath>
#include <utility>
namespace migraph {
struct not_computable
{
argument compute(context&, shape, std::vector<argument>) const
argument compute(context&, const shape&, const std::vector<argument>&) const
{
MIGRAPH_THROW("not computable");
}
......@@ -41,7 +42,7 @@ struct batch_norm_inference
return inputs.front();
}
argument compute(context&, shape, std::vector<argument>) const
argument compute(context&, const shape&, const std::vector<argument>&) const
{
MIGRAPH_THROW("not computable");
}
......@@ -114,7 +115,7 @@ struct convolution
}
}
argument compute(context&, shape, std::vector<argument>) const
argument compute(context&, const shape&, const std::vector<argument>&) const
{
MIGRAPH_THROW("not computable");
}
......@@ -165,7 +166,7 @@ struct pooling
}};
}
argument compute(context&, shape, std::vector<argument>) const
argument compute(context&, const shape&, const std::vector<argument>&) const
{
MIGRAPH_THROW("not computable");
}
......@@ -191,7 +192,7 @@ struct activation
return inputs.front();
}
argument compute(context&, shape, std::vector<argument>) const
argument compute(context&, const shape&, const std::vector<argument>&) const
{
MIGRAPH_THROW("not computable");
}
......@@ -234,7 +235,7 @@ struct transpose
}
argument compute(context&, shape output_shape, std::vector<argument> args) const
{
return {output_shape, std::move(args.front().data)};
return {std::move(output_shape), std::move(args.front().data)};
}
friend std::ostream& operator<<(std::ostream& os, const transpose& op)
{
......@@ -259,7 +260,7 @@ struct contiguous
}
return {t, lens};
}
argument compute(context&, shape, std::vector<argument>) const
argument compute(context&, const shape&, const std::vector<argument>&) const
{
MIGRAPH_THROW("not computable");
}
......@@ -306,7 +307,7 @@ struct reshape
argument compute(context&, shape output_shape, std::vector<argument> args) const
{
return {output_shape, std::move(args.front().data)};
return {std::move(output_shape), std::move(args.front().data)};
}
friend std::ostream& operator<<(std::ostream& os, const reshape& op)
......@@ -336,7 +337,7 @@ struct gemm
return {t, {a.lens()[0], b.lens()[1]}};
}
argument compute(context&, shape, std::vector<argument>) const
argument compute(context&, const shape&, const std::vector<argument>&) const
{
MIGRAPH_THROW("not computable");
}
......@@ -356,7 +357,7 @@ struct unary
check_shapes{inputs}.has(1);
return inputs.at(0);
}
argument compute(context&, shape, std::vector<argument>) const
argument compute(context&, const shape&, const std::vector<argument>&) const
{
MIGRAPH_THROW("not computable");
}
......@@ -448,7 +449,7 @@ struct flatten
}
argument compute(context&, shape output_shape, std::vector<argument> args) const
{
return {output_shape, std::move(args.front().data)};
return {std::move(output_shape), std::move(args.front().data)};
}
friend std::ostream& operator<<(std::ostream& os, const flatten& op)
{
......@@ -488,7 +489,7 @@ struct broadcast
}
argument compute(context&, shape output_shape, std::vector<argument> args) const
{
return {output_shape, std::move(args.at(1).data)};
return {std::move(output_shape), std::move(args.at(1).data)};
}
friend std::ostream& operator<<(std::ostream& os, const broadcast& op)
{
......@@ -507,7 +508,7 @@ struct binary
check_shapes{inputs}.has(2).same_type().same_dims();
return inputs.at(0);
}
argument compute(context&, shape, std::vector<argument>) const
argument compute(context&, const shape&, const std::vector<argument>&) const
{
MIGRAPH_THROW("not computable");
}
......@@ -537,12 +538,12 @@ struct outline
{
shape s;
std::string name() const { return "outline"; }
shape compute_shape(std::vector<shape> inputs) const
shape compute_shape(const std::vector<shape>& inputs) const
{
check_shapes{inputs, *this}.has(0);
return s;
}
argument compute(context&, shape, std::vector<argument>) const { return {s, nullptr}; }
argument compute(context&, const shape&, const std::vector<argument>&) const { return {s, nullptr}; }
};
} // namespace migraph
......
......@@ -34,7 +34,7 @@ struct program
{
return add_instruction(op, {args...});
}
instruction_ref add_instruction(operation op, std::vector<instruction_ref> args);
instruction_ref add_instruction(const operation& op, std::vector<instruction_ref> args);
template <class... Ts>
instruction_ref insert_instruction(instruction_ref ins, operation op, Ts... args)
......@@ -42,7 +42,7 @@ struct program
return insert_instruction(ins, op, {args...});
}
instruction_ref
insert_instruction(instruction_ref ins, operation op, std::vector<instruction_ref> args);
insert_instruction(instruction_ref ins, const operation& op, std::vector<instruction_ref> args);
template <class... Ts>
instruction_ref replace_instruction(instruction_ref ins, operation op, Ts... args)
......@@ -50,7 +50,7 @@ struct program
return replace_instruction(ins, op, {args...});
}
instruction_ref
replace_instruction(instruction_ref ins, operation op, std::vector<instruction_ref> args);
replace_instruction(instruction_ref ins, const operation& op, std::vector<instruction_ref> args);
instruction_ref replace_instruction(instruction_ref ins, instruction_ref rep);
......@@ -67,7 +67,7 @@ struct program
instruction_ref add_literal(literal l);
instruction_ref add_outline(shape s);
instruction_ref add_outline(const shape& s);
instruction_ref add_parameter(std::string name, shape s);
......
......@@ -29,7 +29,7 @@ inline bool ends_with(const std::string& value, const std::string& suffix)
}
template <class Strings>
inline std::string join_strings(Strings strings, std::string delim)
inline std::string join_strings(Strings strings, const std::string& delim)
{
auto it = strings.begin();
if(it == strings.end())
......@@ -57,7 +57,7 @@ inline bool starts_with(const std::string& value, const std::string& prefix)
return std::equal(prefix.begin(), prefix.end(), value.begin());
}
inline std::string remove_prefix(std::string s, std::string prefix)
inline std::string remove_prefix(std::string s, const std::string& prefix)
{
if(starts_with(s, prefix))
return s.substr(prefix.length());
......
......@@ -6,6 +6,7 @@
#include <migraph/requires.hpp>
#include <iostream>
#include <utility>
namespace migraph {
......@@ -14,7 +15,7 @@ struct tensor_view
{
using value_type = T;
tensor_view() : m_data(nullptr) {}
tensor_view(shape s, T* d) : m_data(d), m_shape(s) {}
tensor_view(shape s, T* d) : m_data(d), m_shape(std::move(s)) {}
const shape& get_shape() const { return this->m_shape; }
......
......@@ -20,7 +20,7 @@ auto reverse_int(unsigned int i)
(static_cast<unsigned int>(c3) << 8u) + c4;
};
std::vector<float> read_mnist_images(std::string full_path, int& number_of_images, int& image_size)
std::vector<float> read_mnist_images(const std::string& full_path, int& number_of_images, int& image_size)
{
using uchar = unsigned char;
......@@ -63,7 +63,7 @@ std::vector<float> read_mnist_images(std::string full_path, int& number_of_image
}
}
std::vector<int32_t> read_mnist_labels(std::string full_path, int& number_of_labels)
std::vector<int32_t> read_mnist_labels(const std::string& full_path, int& number_of_labels)
{
using uchar = unsigned char;
......
......@@ -6,6 +6,7 @@
#include <unordered_map>
#include <functional>
#include <array>
#include <utility>
#include <vector>
#include <migraph/fallthrough.hpp>
......@@ -27,7 +28,7 @@ struct unknown
else
return input.front();
}
argument compute(context&, shape, std::vector<argument>) const
argument compute(context&, const shape&, const std::vector<argument>&) const
{
MIGRAPH_THROW("not computable");
}
......@@ -103,7 +104,7 @@ struct onnx_parser
}
instruction_ref
parse_conv(std::string, attribute_map attributes, std::vector<instruction_ref> args)
parse_conv(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{
convolution op;
if(contains(attributes, "pads"))
......@@ -129,7 +130,7 @@ struct onnx_parser
}
instruction_ref
parse_pooling(std::string name, attribute_map attributes, std::vector<instruction_ref> args)
parse_pooling(const std::string& name, attribute_map attributes, std::vector<instruction_ref> args)
{
pooling op{name == "MaxPool" ? "max" : "average"};
if(contains(attributes, "pads"))
......@@ -144,11 +145,11 @@ struct onnx_parser
{
copy(attributes["kernel_shape"].ints(), op.lengths.begin());
}
return prog.add_instruction(op, args);
return prog.add_instruction(op, std::move(args));
}
instruction_ref
parse_reshape(std::string, attribute_map attributes, std::vector<instruction_ref> args)
parse_reshape(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{
reshape op;
if(args.size() == 1)
......@@ -165,7 +166,7 @@ struct onnx_parser
}
instruction_ref
parse_flatten(std::string, attribute_map attributes, std::vector<instruction_ref> args)
parse_flatten(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{
uint64_t axis = 0;
if(contains(attributes, "axis"))
......@@ -176,14 +177,14 @@ struct onnx_parser
}
instruction_ref
parse_constant(std::string, attribute_map attributes, std::vector<instruction_ref>)
parse_constant(const std::string&, attribute_map attributes, const std::vector<instruction_ref>&)
{
literal v = parse_value(attributes.at("value"));
return prog.add_literal(v);
}
instruction_ref
parse_gemm(std::string, attribute_map attributes, std::vector<instruction_ref> args)
parse_gemm(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{
float alpha = 1.0f;
float beta = 0.0f;
......@@ -219,7 +220,7 @@ struct onnx_parser
}
instruction_ref
parse_batchnorm(std::string, attribute_map attributes, std::vector<instruction_ref> args)
parse_batchnorm(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{
float epsilon = 1e-5f;
float momentum = 0.9f;
......@@ -244,7 +245,7 @@ struct onnx_parser
: batch_norm_inference::per_activation;
}
batch_norm_inference op{epsilon, momentum, bn_mode, is_test};
return prog.add_instruction(op, args);
return prog.add_instruction(op, std::move(args));
}
void parse_from(std::istream& is)
......@@ -293,7 +294,7 @@ struct onnx_parser
}
}
void parse_node(std::string name)
void parse_node(const std::string& name)
{
if(name.empty())
MIGRAPH_THROW("Onnx node must have a name");
......
......@@ -7,7 +7,7 @@
#include <migraph/generate.hpp>
#include <migraph/verify.hpp>
migraph::argument run_cpu(std::string file)
migraph::argument run_cpu(const std::string& file)
{
auto p = migraph::parse_onnx(file);
p.compile(migraph::cpu::cpu_target{});
......@@ -21,7 +21,7 @@ migraph::argument run_cpu(std::string file)
return out;
}
migraph::argument run_gpu(std::string file)
migraph::argument run_gpu(const std::string& file)
{
auto p = migraph::parse_onnx(file);
p.compile(migraph::gpu::target{});
......
......@@ -7,6 +7,7 @@
#include <iostream>
#include <sstream>
#include <algorithm>
#include <utility>
namespace migraph {
......@@ -76,12 +77,12 @@ program::program(program&&) noexcept = default;
program& program::operator=(program&&) noexcept = default;
program::~program() noexcept = default;
instruction_ref program::add_instruction(operation op, std::vector<instruction_ref> args)
instruction_ref program::add_instruction(const operation& op, std::vector<instruction_ref> args)
{
return insert_instruction(impl->instructions.end(), std::move(op), std::move(args));
return insert_instruction(impl->instructions.end(), op, std::move(args));
}
instruction_ref
program::insert_instruction(instruction_ref ins, operation op, std::vector<instruction_ref> args)
program::insert_instruction(instruction_ref ins, const operation& op, std::vector<instruction_ref> args)
{
assert(std::all_of(
args.begin(), args.end(), [&](instruction_ref x) { return has_instruction(x); }) &&
......@@ -97,7 +98,7 @@ program::insert_instruction(instruction_ref ins, operation op, std::vector<instr
}
instruction_ref
program::replace_instruction(instruction_ref ins, operation op, std::vector<instruction_ref> args)
program::replace_instruction(instruction_ref ins, const operation& op, std::vector<instruction_ref> args)
{
assert(std::all_of(
args.begin(), args.end(), [&](instruction_ref x) { return has_instruction(x); }) &&
......@@ -168,7 +169,7 @@ instruction_ref program::add_literal(literal l)
return impl->instructions.begin();
}
instruction_ref program::add_outline(shape s)
instruction_ref program::add_outline(const shape& s)
{
impl->instructions.push_front({builtin::outline{s}, s, {}});
return impl->instructions.begin();
......@@ -176,7 +177,7 @@ instruction_ref program::add_outline(shape s)
instruction_ref program::add_parameter(std::string name, shape s)
{
impl->instructions.push_front({builtin::param{std::move(name)}, s, {}});
impl->instructions.push_front({builtin::param{std::move(name)}, std::move(s), {}});
return impl->instructions.begin();
}
......@@ -317,7 +318,7 @@ argument generic_eval(const program& p,
argument program::eval(std::unordered_map<std::string, argument> params) const
{
return generic_eval(*this, this->impl->ctx, params, [](auto&, auto f) { return f(); });
return generic_eval(*this, this->impl->ctx, std::move(params), [](auto&, auto f) { return f(); });
}
double common_average(const std::vector<double>& v)
......
......@@ -7,6 +7,7 @@
#include <migraph/iterator_for.hpp>
#include <migraph/cpu/gemm.hpp>
#include <unordered_map>
#include <utility>
namespace migraph {
namespace cpu {
......@@ -39,9 +40,9 @@ struct cpu_batch_norm_inference
std::string name() const { return "cpu::batch_norm_inference"; }
shape compute_shape(std::vector<shape> inputs) const { return op.compute_shape(inputs); }
shape compute_shape(const std::vector<shape>& inputs) const { return op.compute_shape(inputs); }
argument compute(context&, shape output_shape, std::vector<argument> args) const
argument compute(context&, const shape& output_shape, std::vector<argument> args) const
{
argument output{output_shape};
......@@ -95,7 +96,7 @@ struct cpu_convolution
convolution op;
std::string name() const { return "cpu::convolution"; }
shape compute_shape(std::vector<shape> inputs) const { return op.compute_shape(inputs); }
shape compute_shape(const std::vector<shape>& inputs) const { return op.compute_shape(inputs); }
argument compute(context&, shape output_shape, std::vector<argument> args) const
{
argument result{output_shape};
......@@ -161,8 +162,8 @@ struct cpu_pooling
pooling op;
std::string name() const { return "cpu::pooling_" + Op::name(); }
shape compute_shape(std::vector<shape> inputs) const { return op.compute_shape(inputs); }
argument compute(context&, shape output_shape, std::vector<argument> args) const
shape compute_shape(const std::vector<shape>& inputs) const { return op.compute_shape(inputs); }
argument compute(context&, const shape& output_shape, std::vector<argument> args) const
{
argument result{output_shape};
visit_all(result, args[0])([&](auto output, auto input) {
......@@ -208,8 +209,8 @@ struct cpu_contiguous
{
contiguous op;
std::string name() const { return "cpu::contiguous"; }
shape compute_shape(std::vector<shape> inputs) const { return op.compute_shape(inputs); }
argument compute(context&, shape output_shape, std::vector<argument> args) const
shape compute_shape(const std::vector<shape>& inputs) const { return op.compute_shape(inputs); }
argument compute(context&, const shape& output_shape, std::vector<argument> args) const
{
argument result{output_shape};
visit_all(result, args[0])([&](auto output, auto input) {
......@@ -225,9 +226,9 @@ struct cpu_gemm
{
gemm op;
std::string name() const { return "cpu::gemm"; }
shape compute_shape(std::vector<shape> inputs) const { return op.compute_shape(inputs); }
shape compute_shape(const std::vector<shape>& inputs) const { return op.compute_shape(inputs); }
argument compute(context&, shape output_shape, std::vector<argument> args) const
argument compute(context&, const shape& output_shape, std::vector<argument> args) const
{
argument result{output_shape};
migemm(result, args[0], args[1], op.alpha, op.beta);
......@@ -357,8 +358,8 @@ struct cpu_unary
{
Op op;
std::string name() const { return op.name(); }
shape compute_shape(std::vector<shape> inputs) const { return inputs.front(); }
argument compute(context&, shape output_shape, std::vector<argument> args) const
shape compute_shape(const std::vector<shape>& inputs) const { return inputs.front(); }
argument compute(context&, const shape& output_shape, std::vector<argument> args) const
{
argument result{output_shape};
result.visit([&](auto output) {
......@@ -373,8 +374,8 @@ struct cpu_unary
struct softmax2d
{
std::string name() const { return "cpu::softmax2d"; }
shape compute_shape(std::vector<shape> inputs) const { return inputs.front(); }
argument compute(context&, shape output_shape, std::vector<argument> args) const
shape compute_shape(const std::vector<shape>& inputs) const { return inputs.front(); }
argument compute(context&, const shape& output_shape, std::vector<argument> args) const
{
argument result{output_shape};
visit_all(result, args[0])([&](auto output, auto input) {
......@@ -449,8 +450,8 @@ struct cpu_binary
{
Op op;
std::string name() const { return op.name(); }
shape compute_shape(std::vector<shape> inputs) const { return inputs.front(); }
argument compute(context&, shape output_shape, std::vector<argument> args) const
shape compute_shape(const std::vector<shape>& inputs) const { return inputs.front(); }
argument compute(context&, const shape& output_shape, std::vector<argument> args) const
{
argument result{output_shape};
visit_all(result, args[0], args[1])([&](auto output, auto input1, auto input2) {
......
......@@ -65,7 +65,7 @@ hip_ptr write_to_gpu(const void* x, std::size_t sz, bool host = false)
return result;
}
argument allocate_gpu(shape s, bool host)
argument allocate_gpu(const shape& s, bool host)
{
auto p = share(allocate_gpu(s.bytes() + 1, host));
return {s, [p]() mutable { return reinterpret_cast<char*>(p.get()); }};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment