Commit 038a4c52 authored by wsttiger's avatar wsttiger
Browse files

Merged from master still debugging resnet

parents 06cc4f8f 905d4ab0
CheckOptions: CheckOptions:
- key: modernize-loop-convert.MinConfidence
value: risky
- key: modernize-loop-convert.NamingStyle - key: modernize-loop-convert.NamingStyle
value: lower_case value: lower_case
- key: readability-function-size.BranchThreshold - key: readability-function-size.BranchThreshold
......
...@@ -36,9 +36,7 @@ include(ROCMClangTidy) ...@@ -36,9 +36,7 @@ include(ROCMClangTidy)
rocm_enable_clang_tidy( rocm_enable_clang_tidy(
CHECKS CHECKS
* *
-cert-env33-c
-android-cloexec-fopen -android-cloexec-fopen
-cert-msc50-cpp
-clang-analyzer-alpha.core.CastToStruct -clang-analyzer-alpha.core.CastToStruct
-clang-analyzer-optin.performance.Padding -clang-analyzer-optin.performance.Padding
-clang-diagnostic-deprecated-declarations -clang-diagnostic-deprecated-declarations
...@@ -72,7 +70,6 @@ rocm_enable_clang_tidy( ...@@ -72,7 +70,6 @@ rocm_enable_clang_tidy(
-modernize-pass-by-value -modernize-pass-by-value
-modernize-use-default-member-init -modernize-use-default-member-init
-modernize-use-transparent-functors -modernize-use-transparent-functors
-performance-unnecessary-value-param
-readability-braces-around-statements -readability-braces-around-statements
-readability-else-after-return -readability-else-after-return
-readability-named-parameter -readability-named-parameter
......
...@@ -5,10 +5,11 @@ ...@@ -5,10 +5,11 @@
#include <migraph/iterator_for.hpp> #include <migraph/iterator_for.hpp>
#include <migraph/ranges.hpp> #include <migraph/ranges.hpp>
#include <migraph/stringutils.hpp> #include <migraph/stringutils.hpp>
#include <utility>
namespace migraph { namespace migraph {
bool try_compute_shape(operation op, std::vector<instruction_ref> args) bool try_compute_shape(const operation& op, const std::vector<instruction_ref>& args)
{ {
try try
{ {
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
#include <migraph/shape.hpp> #include <migraph/shape.hpp>
#include <migraph/raw_data.hpp> #include <migraph/raw_data.hpp>
#include <functional> #include <functional>
#include <utility>
namespace migraph { namespace migraph {
...@@ -18,16 +19,17 @@ struct argument : raw_data<argument> ...@@ -18,16 +19,17 @@ struct argument : raw_data<argument>
{ {
argument() {} argument() {}
argument(shape s) : m_shape(s) argument(const shape& s) : m_shape(s)
{ {
std::vector<char> buffer(s.bytes()); std::vector<char> buffer(s.bytes());
// TODO: Move vector // TODO: Move vector
data = [=]() mutable { return buffer.data(); }; data = [=]() mutable { return buffer.data(); };
} }
argument(shape s, std::function<char*()> d) : data(d), m_shape(s) {} argument(shape s, std::function<char*()> d) : data(std::move(d)), m_shape(std::move(s)) {}
template <class T> template <class T>
argument(shape s, T* d) : data([d] { return reinterpret_cast<char*>(d); }), m_shape(s) argument(shape s, T* d)
: data([d] { return reinterpret_cast<char*>(d); }), m_shape(std::move(s))
{ {
} }
......
...@@ -12,24 +12,33 @@ namespace builtin { ...@@ -12,24 +12,33 @@ namespace builtin {
struct literal struct literal
{ {
std::string name() const { return "@literal"; } std::string name() const { return "@literal"; }
shape compute_shape(std::vector<shape>) const { MIGRAPH_THROW("builtin"); } shape compute_shape(const std::vector<shape>&) const { MIGRAPH_THROW("builtin"); }
argument compute(context&, shape, std::vector<argument>) const { MIGRAPH_THROW("builtin"); } argument compute(context&, const shape&, const std::vector<argument>&) const
{
MIGRAPH_THROW("builtin");
}
}; };
struct outline struct outline
{ {
shape s; shape s;
std::string name() const { return "@outline"; } std::string name() const { return "@outline"; }
shape compute_shape(std::vector<shape>) const { return s; } shape compute_shape(const std::vector<shape>&) const { return s; }
argument compute(context&, shape, std::vector<argument>) const { MIGRAPH_THROW("builtin"); } argument compute(context&, const shape&, const std::vector<argument>&) const
{
MIGRAPH_THROW("builtin");
}
}; };
struct param struct param
{ {
std::string parameter; std::string parameter;
std::string name() const { return "@param"; } std::string name() const { return "@param"; }
shape compute_shape(std::vector<shape>) const { MIGRAPH_THROW("builtin"); } shape compute_shape(const std::vector<shape>&) const { MIGRAPH_THROW("builtin"); }
argument compute(context&, shape, std::vector<argument>) const { MIGRAPH_THROW("builtin"); } argument compute(context&, const shape&, const std::vector<argument>&) const
{
MIGRAPH_THROW("builtin");
}
friend std::ostream& operator<<(std::ostream& os, const param& op) friend std::ostream& operator<<(std::ostream& os, const param& op)
{ {
os << op.name() << ":" << op.parameter; os << op.name() << ":" << op.parameter;
......
...@@ -11,8 +11,8 @@ struct check_context ...@@ -11,8 +11,8 @@ struct check_context
struct op struct op
{ {
std::string name() const { return "check_context"; } std::string name() const { return "check_context"; }
shape compute_shape(std::vector<shape>) const { return {}; } shape compute_shape(const std::vector<shape>&) const { return {}; }
argument compute(context& ctx, shape, std::vector<argument>) const argument compute(context& ctx, const shape&, const std::vector<argument>&) const
{ {
T* x = any_cast<T>(&ctx); T* x = any_cast<T>(&ctx);
if(x == nullptr) if(x == nullptr)
......
...@@ -10,7 +10,7 @@ namespace migraph { ...@@ -10,7 +10,7 @@ namespace migraph {
/// Represents exceptions that can be thrown by migraphlib /// Represents exceptions that can be thrown by migraphlib
struct exception : std::runtime_error struct exception : std::runtime_error
{ {
exception(std::string msg = "") : std::runtime_error(msg) {} exception(const std::string& msg = "") : std::runtime_error(msg) {}
}; };
/** /**
...@@ -20,7 +20,7 @@ struct exception : std::runtime_error ...@@ -20,7 +20,7 @@ struct exception : std::runtime_error
* @param message Custom message for the error * @param message Custom message for the error
* @return Exceptions * @return Exceptions
*/ */
inline exception make_exception(std::string context, std::string message = "") inline exception make_exception(const std::string& context, const std::string& message = "")
{ {
return {context + ": " + message}; return {context + ": " + message};
} }
......
...@@ -8,12 +8,33 @@ ...@@ -8,12 +8,33 @@
namespace migraph { namespace migraph {
template <class T> template <class T>
std::vector<T> generate_tensor_data(migraph::shape s, std::mt19937::result_type seed = 0) struct xorshf96_generator
{
unsigned long max = 31;
unsigned long x = 123456789;
unsigned long y = 362436069;
unsigned long z = 521288629;
constexpr T operator()() noexcept
{
x ^= x << 16U;
x ^= x >> 5U;
x ^= x << 1U;
unsigned long t = x;
x = y;
y = z;
z = t ^ x ^ y;
return z % max;
}
};
template <class T>
std::vector<T> generate_tensor_data(const migraph::shape& s, std::mt19937::result_type)
{ {
std::vector<T> result(s.elements()); std::vector<T> result(s.elements());
std::mt19937 engine{seed}; std::generate(result.begin(), result.end(), xorshf96_generator<T>{});
std::uniform_real_distribution<> dist;
std::generate(result.begin(), result.end(), [&] { return dist(engine); });
return result; return result;
} }
......
...@@ -8,10 +8,11 @@ ...@@ -8,10 +8,11 @@
#include <migraph/operation.hpp> #include <migraph/operation.hpp>
#include <migraph/erase.hpp> #include <migraph/erase.hpp>
#include <string> #include <string>
#include <utility>
namespace migraph { namespace migraph {
shape compute_shape(operation op, std::vector<instruction_ref> args); shape compute_shape(const operation& op, const std::vector<instruction_ref>& args);
struct instruction struct instruction
{ {
...@@ -25,14 +26,14 @@ struct instruction ...@@ -25,14 +26,14 @@ struct instruction
instruction(literal l) : op(builtin::literal{}), result(l.get_shape()), lit(std::move(l)) {} instruction(literal l) : op(builtin::literal{}), result(l.get_shape()), lit(std::move(l)) {}
// internal // internal
void replace(operation o, shape r, std::vector<instruction_ref> args) void replace(operation o, const shape& r, std::vector<instruction_ref> args)
{ {
op = o; op = std::move(o);
replace(std::move(r)); replace(r);
replace(std::move(args)); replace(std::move(args));
} }
void replace(shape r) void replace(const shape& r)
{ {
if(r != result) if(r != result)
{ {
...@@ -155,7 +156,7 @@ inline void replace_argument(instruction_ref ins, instruction_ref old, instructi ...@@ -155,7 +156,7 @@ inline void replace_argument(instruction_ref ins, instruction_ref old, instructi
// TODO: Move to a cpp file // TODO: Move to a cpp file
// TODO: Use const ref for vector // TODO: Use const ref for vector
inline shape compute_shape(operation op, std::vector<instruction_ref> args) inline shape compute_shape(const operation& op, const std::vector<instruction_ref>& args)
{ {
std::vector<shape> shapes(args.size()); std::vector<shape> shapes(args.size());
std::transform( std::transform(
...@@ -165,4 +166,17 @@ inline shape compute_shape(operation op, std::vector<instruction_ref> args) ...@@ -165,4 +166,17 @@ inline shape compute_shape(operation op, std::vector<instruction_ref> args)
} // namespace migraph } // namespace migraph
namespace std {
template <>
struct hash<migraph::instruction_ref>
{
using argument_type = migraph::instruction_ref;
using result_type = std::size_t;
result_type operator()(const argument_type& x) const noexcept
{
return std::hash<migraph::instruction*>{}(&*x);
}
};
} // namespace std
#endif #endif
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
#define MIGRAPH_GUARD_INSTRUCTION_REF_HPP #define MIGRAPH_GUARD_INSTRUCTION_REF_HPP
#include <list> #include <list>
#include <functional>
namespace migraph { namespace migraph {
......
...@@ -7,6 +7,8 @@ ...@@ -7,6 +7,8 @@
#include <migraph/tensor_view.hpp> #include <migraph/tensor_view.hpp>
#include <migraph/raw_data.hpp> #include <migraph/raw_data.hpp>
#include <memory>
namespace migraph { namespace migraph {
/** /**
...@@ -18,51 +20,57 @@ struct literal : raw_data<literal> ...@@ -18,51 +20,57 @@ struct literal : raw_data<literal>
literal() {} literal() {}
template <class T> template <class T>
literal(T x) : buffer(sizeof(T), 0), m_shape(shape::get_type<T>{}) literal(T x) : buffer(std::make_unique<char[]>(sizeof(T))), m_shape(shape::get_type<T>{})
{ {
static_assert(std::is_trivial<T>{}, "Literals can only be trivial types"); static_assert(std::is_trivial<T>{}, "Literals can only be trivial types");
*(reinterpret_cast<T*>(buffer.data())) = x; *(reinterpret_cast<T*>(buffer.get())) = x;
} }
template <class T> template <class T>
literal(shape s, const std::vector<T>& x) : buffer(s.bytes(), 0), m_shape(s) literal(const shape& s, const std::vector<T>& x)
: buffer(std::make_unique<char[]>(s.bytes())), m_shape(s)
{ {
static_assert(std::is_trivial<T>{}, "Literals can only be trivial types"); static_assert(std::is_trivial<T>{}, "Literals can only be trivial types");
fill(x.begin(), x.end()); fill(x.begin(), x.end());
} }
template <class T> template <class T>
literal(shape s, const std::initializer_list<T>& x) : buffer(s.bytes(), 0), m_shape(s) literal(const shape& s, const std::initializer_list<T>& x)
: buffer(std::make_unique<char[]>(s.bytes())), m_shape(s)
{ {
static_assert(std::is_trivial<T>{}, "Literals can only be trivial types"); static_assert(std::is_trivial<T>{}, "Literals can only be trivial types");
fill(x.begin(), x.end()); fill(x.begin(), x.end());
} }
template <class Iterator> template <class Iterator>
literal(shape s, Iterator start, Iterator end) : buffer(s.bytes(), 0), m_shape(s) literal(const shape& s, Iterator start, Iterator end)
: buffer(std::make_unique<char[]>(s.bytes())), m_shape(s)
{ {
fill(start, end); fill(start, end);
} }
literal(shape s, const char* x) : buffer(x, x + s.bytes()), m_shape(s) {} literal(const shape& s, const char* x) : buffer(std::make_unique<char[]>(s.bytes())), m_shape(s)
{
std::copy(x, x + s.bytes(), buffer.get());
}
/// Whether data is available /// Whether data is available
bool empty() const { return this->buffer.empty(); } bool empty() const { return this->buffer == nullptr; }
/// Provides a raw pointer to the data /// Provides a raw pointer to the data
const char* data() const { return this->buffer.data(); } const char* data() const { return this->buffer.get(); }
const shape& get_shape() const { return this->m_shape; } const shape& get_shape() const { return this->m_shape; }
/// Convert the data to an argument /// Convert the data to an argument
argument get_argument() const argument get_argument() const
{ {
auto b = buffer; std::vector<char> b(buffer.get(), buffer.get() + m_shape.bytes());
return {m_shape, [b]() mutable { return b.data(); }}; return {m_shape, [b]() mutable { return b.data(); }};
} }
private: private:
std::vector<char> buffer; std::shared_ptr<char> buffer;
shape m_shape; shape m_shape;
template <class Iterator> template <class Iterator>
...@@ -70,13 +78,13 @@ struct literal : raw_data<literal> ...@@ -70,13 +78,13 @@ struct literal : raw_data<literal>
{ {
if(m_shape.standard()) if(m_shape.standard())
{ {
m_shape.visit_type([&](auto as) { std::copy(start, end, as.from(buffer.data())); }); m_shape.visit_type([&](auto as) { std::copy(start, end, as.from(buffer.get())); });
} }
else else
{ {
auto it = start; auto it = start;
m_shape.visit_type([&](auto as) { m_shape.visit_type([&](auto as) {
auto output = make_view(m_shape, as.from(buffer.data())); auto output = make_view(m_shape, as.from(buffer.get()));
shape_for_each(output.get_shape(), [&](const auto& idx) { shape_for_each(output.get_shape(), [&](const auto& idx) {
it++; it++;
output(idx.begin(), idx.end()) = *it; output(idx.begin(), idx.end()) = *it;
......
...@@ -25,7 +25,7 @@ struct operation ...@@ -25,7 +25,7 @@ struct operation
/// This is used to compute the resulting shape from an operation. If an /// This is used to compute the resulting shape from an operation. If an
/// operation cannot be run with input shapes, then it should throw an /// operation cannot be run with input shapes, then it should throw an
/// exception. /// exception.
shape compute_shape(std::vector<shape> input) const; shape compute_shape(const std::vector<shape>& input) const;
/** /**
* @brief This performs the operation's computation * @brief This performs the operation's computation
* *
...@@ -37,7 +37,7 @@ struct operation ...@@ -37,7 +37,7 @@ struct operation
* @return Return an `argument` of the result computation. The `shape` of `argument` should be * @return Return an `argument` of the result computation. The `shape` of `argument` should be
* the same the `output` shape. * the same the `output` shape.
*/ */
argument compute(context& ctx, shape output, std::vector<argument> input) const; argument compute(context& ctx, const shape& output, const std::vector<argument>& input) const;
/// An optional stream operator to print the operation. When this is not /// An optional stream operator to print the operation. When this is not
/// implemented, it will just print the operation's name. /// implemented, it will just print the operation's name.
friend std::ostream& operator<<(std::ostream& os, const operation& op); friend std::ostream& operator<<(std::ostream& os, const operation& op);
...@@ -56,7 +56,8 @@ auto operator<<(std::ostream& os, const T& x) -> decltype(os << x.name()) ...@@ -56,7 +56,8 @@ auto operator<<(std::ostream& os, const T& x) -> decltype(os << x.name())
} // namespace operation_stream } // namespace operation_stream
template <class T> template <class T>
argument compute_op(const T& x, context& ctx, shape output_shape, std::vector<argument> input) argument
compute_op(const T& x, context& ctx, const shape& output_shape, const std::vector<argument>& input)
{ {
return x.compute(auto_any_cast(ctx), output_shape, input); return x.compute(auto_any_cast(ctx), output_shape, input);
} }
...@@ -67,8 +68,8 @@ argument compute_op(const T& x, context& ctx, shape output_shape, std::vector<ar ...@@ -67,8 +68,8 @@ argument compute_op(const T& x, context& ctx, shape output_shape, std::vector<ar
* struct operation * struct operation
* { * {
* std::string name() const; * std::string name() const;
* shape compute_shape(std::vector<shape> input) const; * shape compute_shape(const std::vector<shape>& input) const;
* argument compute(context& ctx,shape output,std::vector<argument> input) const; * argument compute(context& ctx,const shape& output,const std::vector<argument>& input) const;
* friend std::ostream & operator<<(std::ostream & os,const operation & op) ; * friend std::ostream & operator<<(std::ostream & os,const operation & op) ;
* }; * };
* *
...@@ -137,17 +138,16 @@ struct operation ...@@ -137,17 +138,16 @@ struct operation
return (*this).private_detail_te_get_handle().name(); return (*this).private_detail_te_get_handle().name();
} }
shape compute_shape(std::vector<shape> input) const shape compute_shape(const std::vector<shape>& input) const
{ {
assert((*this).private_detail_te_handle_mem_var); assert((*this).private_detail_te_handle_mem_var);
return (*this).private_detail_te_get_handle().compute_shape(std::move(input)); return (*this).private_detail_te_get_handle().compute_shape(input);
} }
argument compute(context& ctx, shape output, std::vector<argument> input) const argument compute(context& ctx, const shape& output, const std::vector<argument>& input) const
{ {
assert((*this).private_detail_te_handle_mem_var); assert((*this).private_detail_te_handle_mem_var);
return (*this).private_detail_te_get_handle().compute( return (*this).private_detail_te_get_handle().compute(ctx, output, input);
ctx, std::move(output), std::move(input));
} }
friend std::ostream& operator<<(std::ostream& os, const operation& op) friend std::ostream& operator<<(std::ostream& os, const operation& op)
...@@ -163,10 +163,11 @@ struct operation ...@@ -163,10 +163,11 @@ struct operation
virtual std::shared_ptr<private_detail_te_handle_base_type> clone() const = 0; virtual std::shared_ptr<private_detail_te_handle_base_type> clone() const = 0;
virtual const std::type_info& type() const = 0; virtual const std::type_info& type() const = 0;
virtual std::string name() const = 0; virtual std::string name() const = 0;
virtual shape compute_shape(std::vector<shape> input) const = 0; virtual shape compute_shape(const std::vector<shape>& input) const = 0;
virtual argument compute(context& ctx, shape output, std::vector<argument> input) const = 0; virtual argument
virtual std::ostream& operator_shift_left(std::ostream& os) const = 0; compute(context& ctx, const shape& output, const std::vector<argument>& input) const = 0;
virtual std::ostream& operator_shift_left(std::ostream& os) const = 0;
}; };
template <typename PrivateDetailTypeErasedT> template <typename PrivateDetailTypeErasedT>
...@@ -199,16 +200,18 @@ struct operation ...@@ -199,16 +200,18 @@ struct operation
std::string name() const override { return private_detail_te_value.name(); } std::string name() const override { return private_detail_te_value.name(); }
shape compute_shape(std::vector<shape> input) const override shape compute_shape(const std::vector<shape>& input) const override
{ {
return private_detail_te_value.compute_shape(std::move(input)); return private_detail_te_value.compute_shape(input);
} }
argument compute(context& ctx, shape output, std::vector<argument> input) const override argument compute(context& ctx,
const shape& output,
const std::vector<argument>& input) const override
{ {
return compute_op(private_detail_te_value, ctx, std::move(output), std::move(input)); return compute_op(private_detail_te_value, ctx, output, input);
} }
std::ostream& operator_shift_left(std::ostream& os) const override std::ostream& operator_shift_left(std::ostream& os) const override
......
...@@ -7,12 +7,13 @@ ...@@ -7,12 +7,13 @@
#include <migraph/stringutils.hpp> #include <migraph/stringutils.hpp>
#include <migraph/streamutils.hpp> #include <migraph/streamutils.hpp>
#include <cmath> #include <cmath>
#include <utility>
namespace migraph { namespace migraph {
struct not_computable struct not_computable
{ {
argument compute(context&, shape, std::vector<argument>) const argument compute(context&, const shape&, const std::vector<argument>&) const
{ {
MIGRAPH_THROW("not computable"); MIGRAPH_THROW("not computable");
} }
...@@ -41,7 +42,7 @@ struct batch_norm_inference ...@@ -41,7 +42,7 @@ struct batch_norm_inference
return inputs.front(); return inputs.front();
} }
argument compute(context&, shape, std::vector<argument>) const argument compute(context&, const shape&, const std::vector<argument>&) const
{ {
MIGRAPH_THROW("not computable"); MIGRAPH_THROW("not computable");
} }
...@@ -114,7 +115,7 @@ struct convolution ...@@ -114,7 +115,7 @@ struct convolution
} }
} }
argument compute(context&, shape, std::vector<argument>) const argument compute(context&, const shape&, const std::vector<argument>&) const
{ {
MIGRAPH_THROW("not computable"); MIGRAPH_THROW("not computable");
} }
...@@ -145,8 +146,8 @@ struct pooling ...@@ -145,8 +146,8 @@ struct pooling
const shape& input = inputs.at(0); const shape& input = inputs.at(0);
auto t = input.type(); auto t = input.type();
// assert(lengths[0] < (input.lens()[2] + 2 * padding[0])); assert(lengths[0] <= (input.lens()[2] + 2 * padding[0]));
// assert(lengths[1] < (input.lens()[3] + 2 * padding[1])); assert(lengths[1] <= (input.lens()[3] + 2 * padding[1]));
return {t, return {t,
{ {
...@@ -175,7 +176,7 @@ struct pooling ...@@ -175,7 +176,7 @@ struct pooling
}}; }};
} }
argument compute(context&, shape, std::vector<argument>) const argument compute(context&, const shape&, const std::vector<argument>&) const
{ {
MIGRAPH_THROW("not computable"); MIGRAPH_THROW("not computable");
} }
...@@ -201,7 +202,7 @@ struct activation ...@@ -201,7 +202,7 @@ struct activation
return inputs.front(); return inputs.front();
} }
argument compute(context&, shape, std::vector<argument>) const argument compute(context&, const shape&, const std::vector<argument>&) const
{ {
MIGRAPH_THROW("not computable"); MIGRAPH_THROW("not computable");
} }
...@@ -244,7 +245,14 @@ struct transpose ...@@ -244,7 +245,14 @@ struct transpose
} }
argument compute(context&, shape output_shape, std::vector<argument> args) const argument compute(context&, shape output_shape, std::vector<argument> args) const
{ {
return {output_shape, std::move(args.front().data)}; return {std::move(output_shape), std::move(args.front().data)};
}
friend std::ostream& operator<<(std::ostream& os, const transpose& op)
{
os << op.name() << "[";
os << "dims={" << stream_range(op.dims) << "}";
os << "]";
return os;
} }
}; };
...@@ -262,7 +270,7 @@ struct contiguous ...@@ -262,7 +270,7 @@ struct contiguous
} }
return {t, lens}; return {t, lens};
} }
argument compute(context&, shape, std::vector<argument>) const argument compute(context&, const shape&, const std::vector<argument>&) const
{ {
MIGRAPH_THROW("not computable"); MIGRAPH_THROW("not computable");
} }
...@@ -309,13 +317,13 @@ struct reshape ...@@ -309,13 +317,13 @@ struct reshape
argument compute(context&, shape output_shape, std::vector<argument> args) const argument compute(context&, shape output_shape, std::vector<argument> args) const
{ {
return {output_shape, std::move(args.front().data)}; return {std::move(output_shape), std::move(args.front().data)};
} }
friend std::ostream& operator<<(std::ostream& os, const reshape& op) friend std::ostream& operator<<(std::ostream& os, const reshape& op)
{ {
os << op.name() << "["; os << op.name() << "[";
os << "dims={" << stream_range(op.dims) << "}, "; os << "dims={" << stream_range(op.dims) << "}";
os << "]"; os << "]";
return os; return os;
} }
...@@ -339,7 +347,7 @@ struct gemm ...@@ -339,7 +347,7 @@ struct gemm
return {t, {a.lens()[0], b.lens()[1]}}; return {t, {a.lens()[0], b.lens()[1]}};
} }
argument compute(context&, shape, std::vector<argument>) const argument compute(context&, const shape&, const std::vector<argument>&) const
{ {
MIGRAPH_THROW("not computable"); MIGRAPH_THROW("not computable");
} }
...@@ -359,7 +367,7 @@ struct unary ...@@ -359,7 +367,7 @@ struct unary
check_shapes{inputs}.has(1); check_shapes{inputs}.has(1);
return inputs.at(0); return inputs.at(0);
} }
argument compute(context&, shape, std::vector<argument>) const argument compute(context&, const shape&, const std::vector<argument>&) const
{ {
MIGRAPH_THROW("not computable"); MIGRAPH_THROW("not computable");
} }
...@@ -439,26 +447,26 @@ struct flatten ...@@ -439,26 +447,26 @@ struct flatten
check_shapes{inputs}.has(1); check_shapes{inputs}.has(1);
auto&& lens = inputs.front().lens(); auto&& lens = inputs.front().lens();
if(axis == 0) if(axis > lens.size())
{
return {inputs.at(0).type(), {1, inputs.at(0).elements()}};
}
else if(axis < lens.size())
{
auto x = std::accumulate(
lens.begin(), lens.begin() + axis, std::size_t{1}, std::multiplies<>{});
auto y = std::accumulate(
lens.begin() + axis, lens.end(), std::size_t{1}, std::multiplies<>{});
return {inputs.at(0).type(), {x, y}};
}
else
{ {
MIGRAPH_THROW("axis for flatten must be less than tensor rank"); MIGRAPH_THROW("axis for flatten must be less than tensor rank");
} }
auto x =
std::accumulate(lens.begin(), lens.begin() + axis, std::size_t{1}, std::multiplies<>{});
auto y =
std::accumulate(lens.begin() + axis, lens.end(), std::size_t{1}, std::multiplies<>{});
return {inputs.at(0).type(), {x, y}};
} }
argument compute(context&, shape output_shape, std::vector<argument> args) const argument compute(context&, shape output_shape, std::vector<argument> args) const
{ {
return {output_shape, std::move(args.front().data)}; return {std::move(output_shape), std::move(args.front().data)};
}
friend std::ostream& operator<<(std::ostream& os, const flatten& op)
{
os << op.name() << "[";
os << "axis=" << op.axis;
os << "]";
return os;
} }
}; };
struct broadcast struct broadcast
...@@ -491,7 +499,14 @@ struct broadcast ...@@ -491,7 +499,14 @@ struct broadcast
} }
argument compute(context&, shape output_shape, std::vector<argument> args) const argument compute(context&, shape output_shape, std::vector<argument> args) const
{ {
return {output_shape, std::move(args.at(1).data)}; return {std::move(output_shape), std::move(args.at(1).data)};
}
friend std::ostream& operator<<(std::ostream& os, const broadcast& op)
{
os << op.name() << "[";
os << "axis=" << op.axis;
os << "]";
return os;
} }
}; };
...@@ -503,7 +518,7 @@ struct binary ...@@ -503,7 +518,7 @@ struct binary
check_shapes{inputs}.has(2).same_type().same_dims(); check_shapes{inputs}.has(2).same_type().same_dims();
return inputs.at(0); return inputs.at(0);
} }
argument compute(context&, shape, std::vector<argument>) const argument compute(context&, const shape&, const std::vector<argument>&) const
{ {
MIGRAPH_THROW("not computable"); MIGRAPH_THROW("not computable");
} }
...@@ -533,12 +548,15 @@ struct outline ...@@ -533,12 +548,15 @@ struct outline
{ {
shape s; shape s;
std::string name() const { return "outline"; } std::string name() const { return "outline"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(const std::vector<shape>& inputs) const
{ {
check_shapes{inputs, *this}.has(0); check_shapes{inputs, *this}.has(0);
return s; return s;
} }
argument compute(context&, shape, std::vector<argument>) const { return {s, nullptr}; } argument compute(context&, const shape&, const std::vector<argument>&) const
{
return {s, nullptr};
}
}; };
} // namespace migraph } // namespace migraph
......
...@@ -34,7 +34,7 @@ struct program ...@@ -34,7 +34,7 @@ struct program
{ {
return add_instruction(op, {args...}); return add_instruction(op, {args...});
} }
instruction_ref add_instruction(operation op, std::vector<instruction_ref> args); instruction_ref add_instruction(const operation& op, std::vector<instruction_ref> args);
template <class... Ts> template <class... Ts>
instruction_ref insert_instruction(instruction_ref ins, operation op, Ts... args) instruction_ref insert_instruction(instruction_ref ins, operation op, Ts... args)
...@@ -42,15 +42,16 @@ struct program ...@@ -42,15 +42,16 @@ struct program
return insert_instruction(ins, op, {args...}); return insert_instruction(ins, op, {args...});
} }
instruction_ref instruction_ref
insert_instruction(instruction_ref ins, operation op, std::vector<instruction_ref> args); insert_instruction(instruction_ref ins, const operation& op, std::vector<instruction_ref> args);
template <class... Ts> template <class... Ts>
instruction_ref replace_instruction(instruction_ref ins, operation op, Ts... args) instruction_ref replace_instruction(instruction_ref ins, operation op, Ts... args)
{ {
return replace_instruction(ins, op, {args...}); return replace_instruction(ins, op, {args...});
} }
instruction_ref instruction_ref replace_instruction(instruction_ref ins,
replace_instruction(instruction_ref ins, operation op, std::vector<instruction_ref> args); const operation& op,
std::vector<instruction_ref> args);
instruction_ref replace_instruction(instruction_ref ins, instruction_ref rep); instruction_ref replace_instruction(instruction_ref ins, instruction_ref rep);
...@@ -67,7 +68,7 @@ struct program ...@@ -67,7 +68,7 @@ struct program
instruction_ref add_literal(literal l); instruction_ref add_literal(literal l);
instruction_ref add_outline(shape s); instruction_ref add_outline(const shape& s);
instruction_ref add_parameter(std::string name, shape s); instruction_ref add_parameter(std::string name, shape s);
...@@ -79,6 +80,7 @@ struct program ...@@ -79,6 +80,7 @@ struct program
bool has_instruction(instruction_ref ins) const; bool has_instruction(instruction_ref ins) const;
std::size_t size() const;
instruction_ref begin() const; instruction_ref begin() const;
instruction_ref end() const; instruction_ref end() const;
...@@ -88,6 +90,8 @@ struct program ...@@ -88,6 +90,8 @@ struct program
void compile(const target& t); void compile(const target& t);
void perf_report(std::ostream& os, std::size_t n, parameter_map params) const;
friend std::ostream& operator<<(std::ostream& os, const program& p); friend std::ostream& operator<<(std::ostream& os, const program& p);
friend bool operator==(const program& x, const program& y); friend bool operator==(const program& x, const program& y);
friend bool operator!=(const program& x, const program& y) { return !(x == y); } friend bool operator!=(const program& x, const program& y) { return !(x == y); }
......
...@@ -5,11 +5,14 @@ ...@@ -5,11 +5,14 @@
#include <cassert> #include <cassert>
#include <ostream> #include <ostream>
#include <numeric> #include <numeric>
#include <memory>
#include <migraph/errors.hpp> #include <migraph/errors.hpp>
namespace migraph { namespace migraph {
struct shape_impl;
struct shape struct shape
{ {
...@@ -136,7 +139,7 @@ struct shape ...@@ -136,7 +139,7 @@ struct shape
template <class Visitor> template <class Visitor>
void visit_type(Visitor v) const void visit_type(Visitor v) const
{ {
switch(this->m_type) switch(this->type())
{ {
#define MIGRAPH_SHAPE_VISITOR_CASE(x, t) \ #define MIGRAPH_SHAPE_VISITOR_CASE(x, t) \
case x: v(as<t>()); return; case x: v(as<t>()); return;
...@@ -147,12 +150,8 @@ struct shape ...@@ -147,12 +150,8 @@ struct shape
} }
private: private:
type_t m_type; std::shared_ptr<const shape_impl> impl;
std::vector<std::size_t> m_lens;
std::vector<std::size_t> m_strides;
bool m_standard;
void calculate_strides();
std::size_t element_space() const; std::size_t element_space() const;
std::string type_string() const; std::string type_string() const;
}; };
......
...@@ -29,7 +29,7 @@ inline bool ends_with(const std::string& value, const std::string& suffix) ...@@ -29,7 +29,7 @@ inline bool ends_with(const std::string& value, const std::string& suffix)
} }
template <class Strings> template <class Strings>
inline std::string join_strings(Strings strings, std::string delim) inline std::string join_strings(Strings strings, const std::string& delim)
{ {
auto it = strings.begin(); auto it = strings.begin();
if(it == strings.end()) if(it == strings.end())
...@@ -57,7 +57,7 @@ inline bool starts_with(const std::string& value, const std::string& prefix) ...@@ -57,7 +57,7 @@ inline bool starts_with(const std::string& value, const std::string& prefix)
return std::equal(prefix.begin(), prefix.end(), value.begin()); return std::equal(prefix.begin(), prefix.end(), value.begin());
} }
inline std::string remove_prefix(std::string s, std::string prefix) inline std::string remove_prefix(std::string s, const std::string& prefix)
{ {
if(starts_with(s, prefix)) if(starts_with(s, prefix))
return s.substr(prefix.length()); return s.substr(prefix.length());
......
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
#include <migraph/requires.hpp> #include <migraph/requires.hpp>
#include <iostream> #include <iostream>
#include <utility>
namespace migraph { namespace migraph {
...@@ -14,7 +15,7 @@ struct tensor_view ...@@ -14,7 +15,7 @@ struct tensor_view
{ {
using value_type = T; using value_type = T;
tensor_view() : m_data(nullptr) {} tensor_view() : m_data(nullptr) {}
tensor_view(shape s, T* d) : m_data(d), m_shape(s) {} tensor_view(shape s, T* d) : m_data(d), m_shape(std::move(s)) {}
const shape& get_shape() const { return this->m_shape; } const shape& get_shape() const { return this->m_shape; }
......
#ifndef MIGRAPH_GUARD_RTGLIB_TIME_HPP
#define MIGRAPH_GUARD_RTGLIB_TIME_HPP
#include <chrono>
namespace migraph {
template <class Duration, class F>
auto time(F f)
{
auto start = std::chrono::steady_clock::now();
f();
auto finish = std::chrono::steady_clock::now();
return std::chrono::duration_cast<Duration>(finish - start).count();
}
} // namespace migraph
#endif
...@@ -16,6 +16,7 @@ add_executable(read_onnx read_onnx.cpp) ...@@ -16,6 +16,7 @@ add_executable(read_onnx read_onnx.cpp)
rocm_clang_tidy_check(read_onnx) rocm_clang_tidy_check(read_onnx)
target_link_libraries(read_onnx migraph_onnx) target_link_libraries(read_onnx migraph_onnx)
add_executable(mnist mnist.cpp) add_executable(mnist mnist.cpp)
rocm_clang_tidy_check(mnist) rocm_clang_tidy_check(mnist)
target_link_libraries(mnist migraph_cpu migraph_gpu migraph_onnx) target_link_libraries(mnist migraph_cpu migraph_gpu migraph_onnx)
...@@ -28,4 +29,8 @@ if(MIGRAPH_ENABLE_GPU) ...@@ -28,4 +29,8 @@ if(MIGRAPH_ENABLE_GPU)
add_executable(verify_onnx verify_onnx.cpp) add_executable(verify_onnx verify_onnx.cpp)
rocm_clang_tidy_check(verify_onnx) rocm_clang_tidy_check(verify_onnx)
target_link_libraries(verify_onnx migraph_onnx migraph_cpu migraph_gpu) target_link_libraries(verify_onnx migraph_onnx migraph_cpu migraph_gpu)
add_executable(perf_onnx perf_onnx.cpp)
rocm_clang_tidy_check(perf_onnx)
target_link_libraries(perf_onnx migraph_onnx migraph_cpu migraph_gpu)
endif() endif()
...@@ -21,7 +21,8 @@ auto reverse_int(unsigned int i) ...@@ -21,7 +21,8 @@ auto reverse_int(unsigned int i)
(static_cast<unsigned int>(c3) << 8u) + c4; (static_cast<unsigned int>(c3) << 8u) + c4;
}; };
std::vector<float> read_mnist_images(std::string full_path, int& number_of_images, int& image_size) std::vector<float>
read_mnist_images(const std::string& full_path, int& number_of_images, int& image_size)
{ {
using uchar = unsigned char; using uchar = unsigned char;
...@@ -64,7 +65,7 @@ std::vector<float> read_mnist_images(std::string full_path, int& number_of_image ...@@ -64,7 +65,7 @@ std::vector<float> read_mnist_images(std::string full_path, int& number_of_image
} }
} }
std::vector<int32_t> read_mnist_labels(std::string full_path, int& number_of_labels) std::vector<int32_t> read_mnist_labels(const std::string& full_path, int& number_of_labels)
{ {
using uchar = unsigned char; using uchar = unsigned char;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment