Commit 06fb0905 authored by Scott Thornton's avatar Scott Thornton
Browse files

Added MNIST test for cpu target

parents 0a59f103 cff16121
#ifndef RTG_GUARD_INSTRUCTION_REF_HPP
#define RTG_GUARD_INSTRUCTION_REF_HPP
#ifndef MIGRAPH_GUARD_INSTRUCTION_REF_HPP
#define MIGRAPH_GUARD_INSTRUCTION_REF_HPP
#include <list>
namespace rtg {
namespace migraph {
struct instruction;
using instruction_ref = std::list<instruction>::iterator;
} // namespace rtg
} // namespace migraph
#endif
#ifndef RTG_GUARD_RTGLIB_LITERAL_HPP
#define RTG_GUARD_RTGLIB_LITERAL_HPP
#ifndef MIGRAPH_GUARD_MIGRAPHLIB_LITERAL_HPP
#define MIGRAPH_GUARD_MIGRAPHLIB_LITERAL_HPP
#include <rtg/shape.hpp>
#include <rtg/argument.hpp>
#include <rtg/tensor_view.hpp>
#include <rtg/raw_data.hpp>
#include <migraph/shape.hpp>
#include <migraph/argument.hpp>
#include <migraph/tensor_view.hpp>
#include <migraph/raw_data.hpp>
namespace rtg {
namespace migraph {
/**
* @brief Represents a raw literal
......@@ -68,6 +68,6 @@ struct literal : raw_data<literal>
shape m_shape;
};
} // namespace rtg
} // namespace migraph
#endif
#ifndef RTG_GUARD_RTG_MANAGE_PTR_HPP
#define RTG_GUARD_RTG_MANAGE_PTR_HPP
#ifndef MIGRAPH_GUARD_MIGRAPH_MANAGE_PTR_HPP
#define MIGRAPH_GUARD_MIGRAPH_MANAGE_PTR_HPP
#include <memory>
#include <type_traits>
namespace rtg {
namespace migraph {
template <class F, F f> // NOLINT
struct manage_deleter
......@@ -49,8 +49,9 @@ shared<T> share(T p)
return shared<T>{std::move(p)};
}
} // namespace rtg
} // namespace migraph
#define RTG_MANAGE_PTR(T, F) rtg::manage_ptr<std::remove_pointer_t<T>, decltype(&F), &F> // NOLINT
#define MIGRAPH_MANAGE_PTR(T, F) \
migraph::manage_ptr<std::remove_pointer_t<T>, decltype(&F), &F> // NOLINT
#endif
#ifndef GUARD_MIGRAPHLIB_ONNX_HPP
#define GUARD_MIGRAPHLIB_ONNX_HPP
#include <migraph/program.hpp>
namespace migraph {
program parse_onnx(const std::string& name);
} // namespace migraph
#endif
#ifndef RTG_GUARD_RTGLIB_OPERAND_HPP
#define RTG_GUARD_RTGLIB_OPERAND_HPP
#ifndef MIGRAPH_GUARD_MIGRAPHLIB_OPERAND_HPP
#define MIGRAPH_GUARD_MIGRAPHLIB_OPERAND_HPP
#include <string>
#include <functional>
#include <memory>
#include <type_traits>
#include <utility>
#include <rtg/shape.hpp>
#include <rtg/argument.hpp>
#include <migraph/shape.hpp>
#include <migraph/argument.hpp>
#include <migraph/context.hpp>
namespace rtg {
namespace migraph {
namespace operation_stream {
......@@ -28,7 +29,7 @@ auto operator<<(std::ostream& os, const T& x) -> decltype(os << x.name())
* {
* std::string name() const;
* shape compute_shape(std::vector<shape> input) const;
* argument compute(shape output,std::vector<argument> input) const;
* argument compute(context& ctx,shape output,std::vector<argument> input) const;
* friend std::ostream & operator<<(std::ostream & os,const operation & op) ;
* };
*
......@@ -83,6 +84,14 @@ struct operation
: nullptr;
}
const std::type_info& type_id() const
{
if(private_detail_te_handle_empty())
return typeid(std::nullptr_t);
else
return private_detail_te_get_handle().type();
}
std::string name() const
{
assert((*this).private_detail_te_handle_mem_var);
......@@ -95,10 +104,11 @@ struct operation
return (*this).private_detail_te_get_handle().compute_shape(std::move(input));
}
argument compute(shape output, std::vector<argument> input) const
argument compute(context& ctx, shape output, std::vector<argument> input) const
{
assert((*this).private_detail_te_handle_mem_var);
return (*this).private_detail_te_get_handle().compute(std::move(output), std::move(input));
return (*this).private_detail_te_get_handle().compute(
ctx, std::move(output), std::move(input));
}
friend std::ostream& operator<<(std::ostream& os, const operation& op)
......@@ -114,10 +124,10 @@ struct operation
virtual std::shared_ptr<private_detail_te_handle_base_type> clone() const = 0;
virtual const std::type_info& type() const = 0;
virtual std::string name() const = 0;
virtual shape compute_shape(std::vector<shape> input) const = 0;
virtual argument compute(shape output, std::vector<argument> input) const = 0;
virtual std::ostream& operator_shift_left(std::ostream& os) const = 0;
virtual std::string name() const = 0;
virtual shape compute_shape(std::vector<shape> input) const = 0;
virtual argument compute(context& ctx, shape output, std::vector<argument> input) const = 0;
virtual std::ostream& operator_shift_left(std::ostream& os) const = 0;
};
template <typename PrivateDetailTypeErasedT>
......@@ -156,15 +166,15 @@ struct operation
return private_detail_te_value.compute_shape(std::move(input));
}
argument compute(shape output, std::vector<argument> input) const override
argument compute(context& ctx, shape output, std::vector<argument> input) const override
{
return private_detail_te_value.compute(std::move(output), std::move(input));
return private_detail_te_value.compute(ctx, std::move(output), std::move(input));
}
std::ostream& operator_shift_left(std::ostream& os) const override
{
using rtg::operation_stream::operator<<;
using migraph::operation_stream::operator<<;
return os << private_detail_te_value;
}
......@@ -181,13 +191,20 @@ struct operation
}
};
bool private_detail_te_handle_empty() const
{
return private_detail_te_handle_mem_var == nullptr;
}
const private_detail_te_handle_base_type& private_detail_te_get_handle() const
{
assert(private_detail_te_handle_mem_var != nullptr);
return *private_detail_te_handle_mem_var;
}
private_detail_te_handle_base_type& private_detail_te_get_handle()
{
assert(private_detail_te_handle_mem_var != nullptr);
if(!private_detail_te_handle_mem_var.unique())
private_detail_te_handle_mem_var = private_detail_te_handle_mem_var->clone();
return *private_detail_te_handle_mem_var;
......@@ -226,6 +243,6 @@ inline const ValueType& any_cast(const operation& x)
return *y;
}
} // namespace rtg
} // namespace migraph
#endif
#ifndef RTG_GUARD_OPERATORS_HPP
#define RTG_GUARD_OPERATORS_HPP
#ifndef MIGRAPH_GUARD_OPERATORS_HPP
#define MIGRAPH_GUARD_OPERATORS_HPP
#include <array>
#include <rtg/operation.hpp>
#include <rtg/stringutils.hpp>
#include <rtg/streamutils.hpp>
#include <migraph/operation.hpp>
#include <migraph/stringutils.hpp>
#include <migraph/streamutils.hpp>
#include <cmath>
namespace rtg {
namespace migraph {
struct check_shapes
{
const std::vector<shape>* shapes;
const std::string name;
check_shapes(const std::vector<shape>& s) : shapes(&s) {}
template <class Op>
check_shapes(const std::vector<shape>& s, const Op& op) : shapes(&s), name(op.name())
{
}
std::string prefix() const
{
if(name.empty())
return "";
else
return name + ": ";
}
const check_shapes& has(std::size_t n) const
{
assert(shapes != nullptr);
if(shapes->size() != n)
RTG_THROW("Wrong number of arguments: expected " + std::to_string(n) + " but given " +
std::to_string(shapes->size()));
MIGRAPH_THROW(prefix() + "Wrong number of arguments: expected " + std::to_string(n) +
" but given " + std::to_string(shapes->size()));
return *this;
}
......@@ -30,7 +44,7 @@ struct check_shapes
if(!shapes->empty())
{
if(shapes->front().lens().size() != n)
RTG_THROW("Only " + std::to_string(n) + "d supported");
MIGRAPH_THROW(prefix() + "Only " + std::to_string(n) + "d supported");
}
return *this;
}
......@@ -38,28 +52,28 @@ struct check_shapes
const check_shapes& same_shape() const
{
if(!this->same([](const shape& s) { return s; }))
RTG_THROW("Shapes do not match");
MIGRAPH_THROW(prefix() + "Shapes do not match");
return *this;
}
const check_shapes& same_type() const
{
if(!this->same([](const shape& s) { return s.type(); }))
RTG_THROW("Types do not match");
MIGRAPH_THROW(prefix() + "Types do not match");
return *this;
}
const check_shapes& same_dims() const
{
if(!this->same([](const shape& s) { return s.lens(); }))
RTG_THROW("Dimensions do not match");
MIGRAPH_THROW(prefix() + "Dimensions do not match");
return *this;
}
const check_shapes& same_ndims() const
{
if(!this->same([](const shape& s) { return s.lens().size(); }))
RTG_THROW("Dimensions do not match");
MIGRAPH_THROW(prefix() + "Dimensions do not match");
return *this;
}
......@@ -83,7 +97,10 @@ struct check_shapes
struct not_computable
{
argument compute(shape, std::vector<argument>) const { RTG_THROW("not computable"); }
argument compute(context&, shape, std::vector<argument>) const
{
MIGRAPH_THROW("not computable");
}
};
struct convolution
......@@ -101,7 +118,7 @@ struct convolution
std::string name() const { return "convolution"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs}.has(2).same_type().same_ndims().only_dims(4);
check_shapes{inputs, *this}.has(2).same_type().same_ndims().only_dims(4);
const shape& input = inputs.at(0);
const shape& weights = inputs.at(1);
......@@ -149,11 +166,14 @@ struct convolution
}
else
{
RTG_THROW("Invalid padding mode");
MIGRAPH_THROW("Invalid padding mode");
}
}
argument compute(shape, std::vector<argument>) const { RTG_THROW("not computable"); }
argument compute(context&, shape, std::vector<argument>) const
{
MIGRAPH_THROW("not computable");
}
friend std::ostream& operator<<(std::ostream& os, const convolution& op)
{
......@@ -175,7 +195,7 @@ struct pooling
std::string name() const { return "pooling"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs}.has(1).only_dims(4);
check_shapes{inputs, *this}.has(1).only_dims(4);
const shape& input = inputs.at(0);
auto t = input.type();
......@@ -197,7 +217,10 @@ struct pooling
}};
}
argument compute(shape, std::vector<argument>) const { RTG_THROW("not computable"); }
argument compute(context&, shape, std::vector<argument>) const
{
MIGRAPH_THROW("not computable");
}
friend std::ostream& operator<<(std::ostream& os, const pooling& op)
{
......@@ -216,11 +239,14 @@ struct activation
std::string name() const { return "activation"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs}.has(1);
check_shapes{inputs, *this}.has(1);
return inputs.front();
}
argument compute(shape, std::vector<argument>) const { RTG_THROW("not computable"); }
argument compute(context&, shape, std::vector<argument>) const
{
MIGRAPH_THROW("not computable");
}
friend std::ostream& operator<<(std::ostream& os, const activation& op)
{
os << op.name() << ":" << op.mode;
......@@ -234,20 +260,20 @@ struct transpose
std::string name() const { return "transpose"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs}.has(1);
check_shapes{inputs, *this}.has(1);
auto input = inputs.at(0);
auto input_lens = input.lens();
auto input_strides = input.strides();
auto t = input.type();
if(dims.size() != input_lens.size())
{
RTG_THROW("Permutation has wrong number of axes");
MIGRAPH_THROW("Permutation has wrong number of axes");
}
std::vector<int64_t> axes(dims.size());
std::iota(axes.begin(), axes.end(), 0);
if(!std::is_permutation(axes.begin(), axes.end(), dims.begin()))
{
RTG_THROW("Invalid permutation");
MIGRAPH_THROW("Invalid permutation");
}
std::vector<size_t> output_lens(input_lens.size());
std::vector<size_t> output_strides(input_lens.size());
......@@ -258,7 +284,10 @@ struct transpose
}
return {t, output_lens, output_strides};
}
argument compute(shape, std::vector<argument>) const { RTG_THROW("not computable"); }
argument compute(context&, shape, std::vector<argument>) const
{
MIGRAPH_THROW("not computable");
}
};
struct contiguous
......@@ -266,16 +295,19 @@ struct contiguous
std::string name() const { return "contiguous"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs}.has(1);
check_shapes{inputs, *this}.has(1);
auto lens = inputs.at(0).lens();
auto t = inputs.at(0).type();
if(lens.size() < 2)
{
RTG_THROW("Number of dimensions should exceed 1");
MIGRAPH_THROW("Number of dimensions should exceed 1");
}
return {t, lens};
}
argument compute(shape, std::vector<argument>) const { RTG_THROW("not computable"); }
argument compute(context&, shape, std::vector<argument>) const
{
MIGRAPH_THROW("not computable");
}
};
struct reshape
......@@ -284,8 +316,7 @@ struct reshape
std::string name() const { return "reshape"; }
shape compute_shape(std::vector<shape> inputs) const
{
if(inputs.empty())
RTG_THROW("Wrong number of arguments");
check_shapes{inputs, *this}.has(1);
auto&& idims = inputs.front().lens();
std::vector<std::size_t> rdims(dims.begin(), dims.end());
for(std::size_t i = 0; i < dims.size(); i++)
......@@ -299,11 +330,15 @@ struct reshape
std::copy(idims.begin() + rdims.size(), idims.end(), std::back_inserter(rdims));
}
shape s{inputs.front().type(), rdims};
if (s.elements() != inputs.front().elements()) RTG_THROW("Wrong number of elements");
if(s.elements() != inputs.front().elements())
MIGRAPH_THROW("Wrong number of elements for reshape");
return s;
}
argument compute(shape, std::vector<argument>) const { RTG_THROW("not computable"); }
argument compute(context&, shape, std::vector<argument>) const
{
MIGRAPH_THROW("not computable");
}
friend std::ostream& operator<<(std::ostream& os, const reshape& op)
{
......@@ -319,17 +354,20 @@ struct gemm
std::string name() const { return "gemm"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs}.has(2).same_type().same_ndims().only_dims(2);
check_shapes{inputs, *this}.has(2).same_type();
const shape& a = inputs.at(0);
const shape& b = inputs.at(1);
auto t = a.type();
if(a.lens()[1] != b.lens()[0])
RTG_THROW("Inner dimensions do not match");
MIGRAPH_THROW("Inner dimensions do not match");
return {t, {a.lens()[0], b.lens()[1]}};
}
argument compute(shape, std::vector<argument>) const { RTG_THROW("not computable"); }
argument compute(context&, shape, std::vector<argument>) const
{
MIGRAPH_THROW("not computable");
}
friend std::ostream& operator<<(std::ostream& os, const gemm& op)
{
......@@ -346,7 +384,10 @@ struct unary
check_shapes{inputs}.has(1);
return inputs.at(0);
}
argument compute(shape, std::vector<argument>) const { RTG_THROW("not computable"); }
argument compute(context&, shape, std::vector<argument>) const
{
MIGRAPH_THROW("not computable");
}
};
struct identity : unary
......@@ -435,19 +476,19 @@ struct broadcast
result.lens().cbegin(), result.lens().cend(), [&](auto x) { return x == 1; }))
{
if(axis != 0)
RTG_THROW("when broadcasting tensor of size 1, axis should be 0");
MIGRAPH_THROW("when broadcasting tensor of size 1, axis should be 0");
return {t, result.lens(), std::move(bcast_strides)};
}
else
{
assert(result.lens().size() - axis >= input.lens().size());
if(!std::equal(input.lens().begin(), input.lens().end(), result.lens().begin() + axis))
RTG_THROW("when broadcasting success sizes must match");
MIGRAPH_THROW("when broadcasting success sizes must match");
std::copy(input.strides().begin(), input.strides().end(), bcast_strides.begin() + axis);
return {t, result.lens(), std::move(bcast_strides)};
}
}
argument compute(shape output_shape, std::vector<argument> args) const
argument compute(context&, shape output_shape, std::vector<argument> args) const
{
return {output_shape, std::move(args.at(1).data)};
}
......@@ -461,7 +502,10 @@ struct binary
check_shapes{inputs}.has(2).same_type().same_dims();
return inputs.at(0);
}
argument compute(shape, std::vector<argument>) const { RTG_THROW("not computable"); }
argument compute(context&, shape, std::vector<argument>) const
{
MIGRAPH_THROW("not computable");
}
};
struct add : binary
......@@ -490,12 +534,26 @@ struct outline
std::string name() const { return "outline"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs}.has(0);
check_shapes{inputs, *this}.has(0);
return s;
}
argument compute(shape, std::vector<argument>) const { return {s, nullptr}; }
argument compute(context&, shape, std::vector<argument>) const { return {s, nullptr}; }
};
template <class T>
struct check_context
{
std::string name() const { return "check_context"; }
shape compute_shape(std::vector<shape>) const { return {}; }
argument compute(context& ctx, shape, std::vector<argument>) const
{
T* x = any_cast<T>(&ctx);
if(x == nullptr)
MIGRAPH_THROW(std::string("Unexpected context type: ") + ctx.type_id().name());
return {};
}
};
} // namespace rtg
} // namespace migraph
#endif
#ifndef RTG_GUARD_RTGLIB_PROGRAM_HPP
#define RTG_GUARD_RTGLIB_PROGRAM_HPP
#ifndef MIGRAPH_GUARD_MIGRAPHLIB_PROGRAM_HPP
#define MIGRAPH_GUARD_MIGRAPHLIB_PROGRAM_HPP
#include <list>
#include <unordered_map>
#include <rtg/operation.hpp>
#include <rtg/literal.hpp>
#include <rtg/builtin.hpp>
#include <rtg/instruction_ref.hpp>
#include <rtg/target.hpp>
#include <migraph/operation.hpp>
#include <migraph/literal.hpp>
#include <migraph/builtin.hpp>
#include <migraph/instruction_ref.hpp>
#include <migraph/target.hpp>
#include <algorithm>
#include <iostream>
namespace rtg {
namespace migraph {
struct program_impl;
......@@ -27,6 +27,8 @@ struct program
program& operator=(program&&) noexcept;
~program() noexcept;
using parameter_map = std::unordered_map<std::string, argument>;
template <class... Ts>
instruction_ref add_instruction(operation op, Ts... args)
{
......@@ -64,7 +66,7 @@ struct program
shape get_parameter_shape(std::string name);
argument eval(std::unordered_map<std::string, argument> params) const;
argument eval(parameter_map params) const;
friend std::ostream& operator<<(std::ostream& os, const program& p);
......@@ -81,6 +83,6 @@ struct program
std::unique_ptr<program_impl> impl;
};
} // namespace rtg
} // namespace migraph
#endif
#ifndef RTG_GUARD_RTGLIB_RANGES_HPP
#define RTG_GUARD_RTGLIB_RANGES_HPP
#ifndef MIGRAPH_GUARD_MIGRAPHLIB_RANGES_HPP
#define MIGRAPH_GUARD_MIGRAPHLIB_RANGES_HPP
namespace rtg {
namespace migraph {
template <class C, class T>
bool contains(C&& c, T&& x)
......@@ -15,6 +15,6 @@ void copy(Range&& r, Iterator it)
std::copy(r.begin(), r.end(), it);
}
} // namespace rtg
} // namespace migraph
#endif
#ifndef RTG_GUARD_RAW_DATA_HPP
#define RTG_GUARD_RAW_DATA_HPP
#ifndef MIGRAPH_GUARD_RAW_DATA_HPP
#define MIGRAPH_GUARD_RAW_DATA_HPP
#include <rtg/tensor_view.hpp>
#include <rtg/requires.hpp>
#include <migraph/tensor_view.hpp>
#include <migraph/requires.hpp>
namespace rtg {
namespace migraph {
struct raw_data_base
{
......@@ -89,13 +89,27 @@ struct raw_data : raw_data_base
assert(self->single());
return self->template at<T>();
}
template <class T>
using is_data_ptr =
bool_c<(std::is_void<T>{} or std::is_same<char, std::remove_cv_t<T>>{} or
std::is_same<unsigned char, std::remove_cv_t<T>>{})>;
template <class T>
using get_data_type = std::conditional_t<is_data_ptr<T>{}, float, T>;
template <class T>
bool matches() const
{
return is_data_ptr<T>{} ||
self->get_shape().type() == migraph::shape::get_type<get_data_type<T>>{};
}
template <class T>
operator T*()
{
using type = std::remove_cv_t<T>;
assert((std::is_void<T>{} or std::is_same<char, type>{} or
std::is_same<unsigned char, type>{} or
self->get_shape().type() == rtg::shape::get_type<T>{}));
assert(matches<T>());
return reinterpret_cast<type*>(self->data());
}
};
......@@ -109,15 +123,26 @@ struct raw_data : raw_data_base
{
auto&& s = static_cast<const Derived&>(*this).get_shape();
auto&& buffer = static_cast<const Derived&>(*this).data();
if(s.type() != rtg::shape::get_type<T>{})
RTG_THROW("Incorrect data type for raw data");
if(s.type() != migraph::shape::get_type<T>{})
MIGRAPH_THROW("Incorrect data type for raw data");
return make_view(s, reinterpret_cast<T*>(buffer));
}
/// Cast the data pointer
template <class T>
T* cast() const
{
auto&& s = static_cast<const Derived&>(*this).get_shape();
auto&& buffer = static_cast<const Derived&>(*this).data();
assert(s.type() == migraph::shape::get_type<T>{});
return reinterpret_cast<T*>(buffer);
}
};
template <class T,
class U,
RTG_REQUIRES(std::is_base_of<raw_data_base, T>{} && std::is_base_of<raw_data_base, U>{})>
MIGRAPH_REQUIRES(std::is_base_of<raw_data_base, T>{} &&
std::is_base_of<raw_data_base, U>{})>
bool operator==(const T& x, const U& y)
{
auto&& xshape = x.get_shape();
......@@ -139,7 +164,8 @@ bool operator==(const T& x, const U& y)
template <class T,
class U,
RTG_REQUIRES(std::is_base_of<raw_data_base, T>{} && std::is_base_of<raw_data_base, U>{})>
MIGRAPH_REQUIRES(std::is_base_of<raw_data_base, T>{} &&
std::is_base_of<raw_data_base, U>{})>
bool operator!=(const T& x, const U& y)
{
return !(x == y);
......@@ -170,13 +196,13 @@ auto visit_all(T&& x, Ts&&... xs)
auto&& s = x.get_shape();
std::initializer_list<shape::type_t> types = {xs.get_shape().type()...};
if(!std::all_of(types.begin(), types.end(), [&](shape::type_t t) { return t == s.type(); }))
RTG_THROW("Types must be the same");
MIGRAPH_THROW("Types must be the same");
return [&](auto v) {
// Workaround for https://gcc.gnu.org/bugzilla/show_bug.cgi?id=70100
detail::visit_all_impl(s, v, x, xs...);
};
}
} // namespace rtg
} // namespace migraph
#endif
#ifndef MIGRAPH_GUARD_MIGRAPHLIB_REQUIRES_HPP
#define MIGRAPH_GUARD_MIGRAPHLIB_REQUIRES_HPP
#include <type_traits>
namespace migraph {
template <bool... Bs>
struct and_ : std::is_same<and_<Bs...>, and_<(Bs || true)...>> // NOLINT
{
};
template <bool B>
using bool_c = std::integral_constant<bool, B>;
#ifdef CPPCHECK
#define MIGRAPH_REQUIRES(...) class = void
#else
#define MIGRAPH_REQUIRES(...) \
bool PrivateRequires##__LINE__ = true, \
class = typename std::enable_if<and_<__VA_ARGS__, PrivateRequires##__LINE__>{}>::type
#endif
} // namespace migraph
#endif
#ifndef RTG_GUARD_RTGLIB_SHAPE_HPP
#define RTG_GUARD_RTGLIB_SHAPE_HPP
#ifndef MIGRAPH_GUARD_MIGRAPHLIB_SHAPE_HPP
#define MIGRAPH_GUARD_MIGRAPHLIB_SHAPE_HPP
#include <vector>
#include <cassert>
#include <ostream>
#include <numeric>
#include <rtg/errors.hpp>
#include <migraph/errors.hpp>
namespace rtg {
namespace migraph {
struct shape
{
// Add new types here
// clang-format off
#define RTG_SHAPE_VISIT_TYPES(m) \
#define MIGRAPH_SHAPE_VISIT_TYPES(m) \
m(float_type, float) \
m(double_type, double) \
m(uint8_type, uint8_t) \
......@@ -28,25 +28,22 @@ struct shape
m(uint64_type, uint64_t)
// clang-format on
#define RTG_SHAPE_ENUM_TYPES(x, t) x,
#define MIGRAPH_SHAPE_ENUM_TYPES(x, t) x,
enum type_t
{
any_type,
RTG_SHAPE_VISIT_TYPES(RTG_SHAPE_ENUM_TYPES)
MIGRAPH_SHAPE_VISIT_TYPES(MIGRAPH_SHAPE_ENUM_TYPES)
};
#undef RTG_SHAPE_ENUM_TYPES
#undef MIGRAPH_SHAPE_ENUM_TYPES
template <class T, class = void>
struct get_type : std::integral_constant<type_t, any_type>
{
};
#define RTG_SHAPE_GET_TYPE(x, t) \
struct get_type;
#define MIGRAPH_SHAPE_GET_TYPE(x, t) \
template <class T> \
struct get_type<t, T> : std::integral_constant<type_t, x> \
{ \
};
RTG_SHAPE_VISIT_TYPES(RTG_SHAPE_GET_TYPE)
#undef RTG_SHAPE_GET_TYPE
MIGRAPH_SHAPE_VISIT_TYPES(MIGRAPH_SHAPE_GET_TYPE)
#undef MIGRAPH_SHAPE_GET_TYPE
shape();
shape(type_t t);
......@@ -74,6 +71,7 @@ struct shape
std::size_t index(std::size_t i) const;
bool packed() const;
bool broadcasted() const;
friend bool operator==(const shape& x, const shape& y);
friend bool operator!=(const shape& x, const shape& y);
......@@ -124,13 +122,12 @@ struct shape
{
switch(this->m_type)
{
case any_type: RTG_THROW("Cannot visit the any_type");
#define RTG_SHAPE_VISITOR_CASE(x, t) \
#define MIGRAPH_SHAPE_VISITOR_CASE(x, t) \
case x: v(as<t>()); return;
RTG_SHAPE_VISIT_TYPES(RTG_SHAPE_VISITOR_CASE)
#undef RTG_SHAPE_VISITOR_CASE
MIGRAPH_SHAPE_VISIT_TYPES(MIGRAPH_SHAPE_VISITOR_CASE)
#undef MIGRAPH_SHAPE_VISITOR_CASE
}
RTG_THROW("Unknown type");
MIGRAPH_THROW("Unknown type");
}
private:
......@@ -144,6 +141,6 @@ struct shape
std::string type_string() const;
};
} // namespace rtg
} // namespace migraph
#endif
#ifndef RTG_GUARD_RTGLIB_SHAPE_FOR_EACH_HPP
#define RTG_GUARD_RTGLIB_SHAPE_FOR_EACH_HPP
#ifndef MIGRAPH_GUARD_MIGRAPHLIB_SHAPE_FOR_EACH_HPP
#define MIGRAPH_GUARD_MIGRAPHLIB_SHAPE_FOR_EACH_HPP
#include <rtg/shape.hpp>
#include <migraph/shape.hpp>
#include <algorithm>
namespace rtg {
namespace migraph {
template <class F>
void shape_for_each(const rtg::shape& s, F f)
void shape_for_each(const migraph::shape& s, F f)
{
// Ensure calls to f use const ref to vector
auto call = [&f](const std::vector<std::size_t>& i) { f(i); };
......@@ -26,6 +26,6 @@ void shape_for_each(const rtg::shape& s, F f)
}
}
} // namespace rtg
} // namespace migraph
#endif
#ifndef RTG_GUARD_STREAMUTILS_HPP
#define RTG_GUARD_STREAMUTILS_HPP
#ifndef MIGRAPH_GUARD_STREAMUTILS_HPP
#define MIGRAPH_GUARD_STREAMUTILS_HPP
#include <ostream>
#include <algorithm>
namespace rtg {
namespace migraph {
template <class T>
struct stream_range_container
......@@ -31,6 +31,6 @@ inline stream_range_container<Range> stream_range(const Range& r)
return {r};
}
} // namespace rtg
} // namespace migraph
#endif
#ifndef RTG_GUARD_RTGLIB_STRINGUTILS_HPP
#define RTG_GUARD_RTGLIB_STRINGUTILS_HPP
#ifndef MIGRAPH_GUARD_MIGRAPHLIB_STRINGUTILS_HPP
#define MIGRAPH_GUARD_MIGRAPHLIB_STRINGUTILS_HPP
#include <algorithm>
#include <numeric>
#include <string>
#include <sstream>
namespace rtg {
namespace migraph {
inline std::string
replace_string(std::string subject, const std::string& search, const std::string& replace)
......@@ -77,6 +77,6 @@ inline std::string to_string(const Range& r)
return ss.str();
}
} // namespace rtg
} // namespace migraph
#endif
#ifndef RTG_GUARD_RTGLIB_TARGET_HPP
#define RTG_GUARD_RTGLIB_TARGET_HPP
#ifndef MIGRAPH_GUARD_MIGRAPHLIB_TARGET_HPP
#define MIGRAPH_GUARD_MIGRAPHLIB_TARGET_HPP
#include <string>
#include <functional>
#include <memory>
#include <type_traits>
#include <utility>
#include <migraph/context.hpp>
namespace rtg {
namespace migraph {
struct program;
......@@ -18,6 +19,7 @@ struct program;
* {
* std::string name() const;
* void apply(program & p) const;
* context get_context() const;
* };
*
*/
......@@ -71,6 +73,14 @@ struct target
: nullptr;
}
const std::type_info& type_id() const
{
if(private_detail_te_handle_empty())
return typeid(std::nullptr_t);
else
return private_detail_te_get_handle().type();
}
std::string name() const
{
assert((*this).private_detail_te_handle_mem_var);
......@@ -83,6 +93,12 @@ struct target
return (*this).private_detail_te_get_handle().apply(p);
}
context get_context() const
{
assert((*this).private_detail_te_handle_mem_var);
return (*this).private_detail_te_get_handle().get_context();
}
private:
struct private_detail_te_handle_base_type
{
......@@ -92,6 +108,7 @@ struct target
virtual std::string name() const = 0;
virtual void apply(program& p) const = 0;
virtual context get_context() const = 0;
};
template <typename PrivateDetailTypeErasedT>
......@@ -126,6 +143,8 @@ struct target
void apply(program& p) const override { return private_detail_te_value.apply(p); }
context get_context() const override { return private_detail_te_value.get_context(); }
PrivateDetailTypeErasedT private_detail_te_value;
};
......@@ -139,13 +158,20 @@ struct target
}
};
bool private_detail_te_handle_empty() const
{
return private_detail_te_handle_mem_var == nullptr;
}
const private_detail_te_handle_base_type& private_detail_te_get_handle() const
{
assert(private_detail_te_handle_mem_var != nullptr);
return *private_detail_te_handle_mem_var;
}
private_detail_te_handle_base_type& private_detail_te_get_handle()
{
assert(private_detail_te_handle_mem_var != nullptr);
if(!private_detail_te_handle_mem_var.unique())
private_detail_te_handle_mem_var = private_detail_te_handle_mem_var->clone();
return *private_detail_te_handle_mem_var;
......@@ -184,6 +210,6 @@ inline const ValueType& any_cast(const target& x)
return *y;
}
} // namespace rtg
} // namespace migraph
#endif
#ifndef RTG_GUARD_TENSOR_VIEW_HPP
#define RTG_GUARD_TENSOR_VIEW_HPP
#ifndef MIGRAPH_GUARD_TENSOR_VIEW_HPP
#define MIGRAPH_GUARD_TENSOR_VIEW_HPP
#include <rtg/shape.hpp>
#include <rtg/float_equal.hpp>
#include <rtg/requires.hpp>
#include <migraph/shape.hpp>
#include <migraph/float_equal.hpp>
#include <migraph/requires.hpp>
#include <iostream>
namespace rtg {
namespace migraph {
template <class T>
struct tensor_view
......@@ -26,27 +26,27 @@ struct tensor_view
const T* data() const { return this->m_data; }
template <class... Ts, RTG_REQUIRES(std::is_integral<Ts>{}...)>
template <class... Ts, MIGRAPH_REQUIRES(std::is_integral<Ts>{}...)>
const T& operator()(Ts... xs) const
{
assert(m_shape.index({static_cast<std::size_t>(xs)...}) < m_shape.bytes() / sizeof(T));
return m_data[m_shape.index({static_cast<std::size_t>(xs)...})];
}
template <class... Ts, RTG_REQUIRES(std::is_integral<Ts>{}...)>
template <class... Ts, MIGRAPH_REQUIRES(std::is_integral<Ts>{}...)>
T& operator()(Ts... xs)
{
assert(m_shape.index({static_cast<std::size_t>(xs)...}) < m_shape.bytes() / sizeof(T));
return m_data[m_shape.index({static_cast<std::size_t>(xs)...})];
}
template <class Iterator, RTG_REQUIRES(not std::is_integral<Iterator>{})>
template <class Iterator, MIGRAPH_REQUIRES(not std::is_integral<Iterator>{})>
const T& operator()(Iterator start, Iterator last) const
{
return m_data[m_shape.index(start, last)];
}
template <class Iterator, RTG_REQUIRES(not std::is_integral<Iterator>{})>
template <class Iterator, MIGRAPH_REQUIRES(not std::is_integral<Iterator>{})>
T& operator()(Iterator start, Iterator last)
{
return m_data[m_shape.index(start, last)];
......@@ -164,6 +164,6 @@ tensor_view<T> make_view(shape s, T* data)
return {s, data};
}
} // namespace rtg
} // namespace migraph
#endif
#ifndef RTG_GUARD_FALLTHROUGH_HPP
#define RTG_GUARD_FALLTHROUGH_HPP
namespace rtg {
#ifdef __clang__
#define RTG_FALLTHROUGH [[clang::fallthrough]]
#else
#define RTG_FALLTHROUGH
#endif
} // namespace rtg
#endif
#ifndef GUARD_RTGLIB_ONNX_HPP
#define GUARD_RTGLIB_ONNX_HPP
#include <rtg/program.hpp>
namespace rtg {
program parse_onnx(const std::string& name);
} // namespace rtg
#endif
#ifndef RTG_GUARD_RTGLIB_REQUIRES_HPP
#define RTG_GUARD_RTGLIB_REQUIRES_HPP
#include <type_traits>
namespace rtg {
template <bool... Bs>
struct and_ : std::is_same<and_<Bs...>, and_<(Bs || true)...>> // NOLINT
{
};
#ifdef CPPCHECK
#define RTG_REQUIRES(...) class = void
#else
#define RTG_REQUIRES(...) class = typename std::enable_if<and_<__VA_ARGS__, true>{}>::type
#endif
} // namespace rtg
#endif
......@@ -5,15 +5,23 @@ add_library(onnx-proto STATIC ${PROTO_SRCS})
target_include_directories(onnx-proto SYSTEM PUBLIC ${CMAKE_CURRENT_BINARY_DIR} ${PROTOBUF_INCLUDE_DIR})
target_compile_options(onnx-proto PRIVATE -w)
target_link_libraries(onnx-proto PRIVATE ${PROTOBUF_LIBRARY})
set_target_properties(onnx-proto PROPERTIES POSITION_INDEPENDENT_CODE On)
add_library(rtg_onnx onnx.cpp)
rocm_clang_tidy_check(rtg_onnx)
target_link_libraries(rtg_onnx onnx-proto rtg)
add_library(migraph_onnx onnx.cpp)
rocm_clang_tidy_check(migraph_onnx)
target_link_libraries(migraph_onnx PRIVATE onnx-proto)
target_link_libraries(migraph_onnx PUBLIC migraph)
add_executable(read_onnx read_onnx.cpp)
rocm_clang_tidy_check(read_onnx)
target_link_libraries(read_onnx rtg_onnx rtg_cpu)
target_link_libraries(read_onnx migraph_onnx)
add_executable(mnist mnist.cpp)
rocm_clang_tidy_check(mnist)
target_link_libraries(mnist rtg_onnx rtg_cpu)
target_link_libraries(mnist migraph_cpu migraph_onnx)
if(MIGRAPH_ENABLE_MIOPEN)
add_executable(verify_onnx verify_onnx.cpp)
rocm_clang_tidy_check(verify_onnx)
target_link_libraries(verify_onnx migraph_onnx migraph_cpu migraph_miopen)
endif()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment