Commit 06fb0905 authored by Scott Thornton's avatar Scott Thornton
Browse files

Added MNIST test for cpu target

parents 0a59f103 cff16121
#ifndef RTG_GUARD_INSTRUCTION_REF_HPP #ifndef MIGRAPH_GUARD_INSTRUCTION_REF_HPP
#define RTG_GUARD_INSTRUCTION_REF_HPP #define MIGRAPH_GUARD_INSTRUCTION_REF_HPP
#include <list> #include <list>
namespace rtg { namespace migraph {
struct instruction; struct instruction;
using instruction_ref = std::list<instruction>::iterator; using instruction_ref = std::list<instruction>::iterator;
} // namespace rtg } // namespace migraph
#endif #endif
#ifndef RTG_GUARD_RTGLIB_LITERAL_HPP #ifndef MIGRAPH_GUARD_MIGRAPHLIB_LITERAL_HPP
#define RTG_GUARD_RTGLIB_LITERAL_HPP #define MIGRAPH_GUARD_MIGRAPHLIB_LITERAL_HPP
#include <rtg/shape.hpp> #include <migraph/shape.hpp>
#include <rtg/argument.hpp> #include <migraph/argument.hpp>
#include <rtg/tensor_view.hpp> #include <migraph/tensor_view.hpp>
#include <rtg/raw_data.hpp> #include <migraph/raw_data.hpp>
namespace rtg { namespace migraph {
/** /**
* @brief Represents a raw literal * @brief Represents a raw literal
...@@ -68,6 +68,6 @@ struct literal : raw_data<literal> ...@@ -68,6 +68,6 @@ struct literal : raw_data<literal>
shape m_shape; shape m_shape;
}; };
} // namespace rtg } // namespace migraph
#endif #endif
#ifndef RTG_GUARD_RTG_MANAGE_PTR_HPP #ifndef MIGRAPH_GUARD_MIGRAPH_MANAGE_PTR_HPP
#define RTG_GUARD_RTG_MANAGE_PTR_HPP #define MIGRAPH_GUARD_MIGRAPH_MANAGE_PTR_HPP
#include <memory> #include <memory>
#include <type_traits> #include <type_traits>
namespace rtg { namespace migraph {
template <class F, F f> // NOLINT template <class F, F f> // NOLINT
struct manage_deleter struct manage_deleter
...@@ -49,8 +49,9 @@ shared<T> share(T p) ...@@ -49,8 +49,9 @@ shared<T> share(T p)
return shared<T>{std::move(p)}; return shared<T>{std::move(p)};
} }
} // namespace rtg } // namespace migraph
#define RTG_MANAGE_PTR(T, F) rtg::manage_ptr<std::remove_pointer_t<T>, decltype(&F), &F> // NOLINT #define MIGRAPH_MANAGE_PTR(T, F) \
migraph::manage_ptr<std::remove_pointer_t<T>, decltype(&F), &F> // NOLINT
#endif #endif
#ifndef GUARD_MIGRAPHLIB_ONNX_HPP
#define GUARD_MIGRAPHLIB_ONNX_HPP
#include <migraph/program.hpp>
namespace migraph {
program parse_onnx(const std::string& name);
} // namespace migraph
#endif
#ifndef RTG_GUARD_RTGLIB_OPERAND_HPP #ifndef MIGRAPH_GUARD_MIGRAPHLIB_OPERAND_HPP
#define RTG_GUARD_RTGLIB_OPERAND_HPP #define MIGRAPH_GUARD_MIGRAPHLIB_OPERAND_HPP
#include <string> #include <string>
#include <functional> #include <functional>
#include <memory> #include <memory>
#include <type_traits> #include <type_traits>
#include <utility> #include <utility>
#include <rtg/shape.hpp> #include <migraph/shape.hpp>
#include <rtg/argument.hpp> #include <migraph/argument.hpp>
#include <migraph/context.hpp>
namespace rtg { namespace migraph {
namespace operation_stream { namespace operation_stream {
...@@ -28,7 +29,7 @@ auto operator<<(std::ostream& os, const T& x) -> decltype(os << x.name()) ...@@ -28,7 +29,7 @@ auto operator<<(std::ostream& os, const T& x) -> decltype(os << x.name())
* { * {
* std::string name() const; * std::string name() const;
* shape compute_shape(std::vector<shape> input) const; * shape compute_shape(std::vector<shape> input) const;
* argument compute(shape output,std::vector<argument> input) const; * argument compute(context& ctx,shape output,std::vector<argument> input) const;
* friend std::ostream & operator<<(std::ostream & os,const operation & op) ; * friend std::ostream & operator<<(std::ostream & os,const operation & op) ;
* }; * };
* *
...@@ -83,6 +84,14 @@ struct operation ...@@ -83,6 +84,14 @@ struct operation
: nullptr; : nullptr;
} }
const std::type_info& type_id() const
{
if(private_detail_te_handle_empty())
return typeid(std::nullptr_t);
else
return private_detail_te_get_handle().type();
}
std::string name() const std::string name() const
{ {
assert((*this).private_detail_te_handle_mem_var); assert((*this).private_detail_te_handle_mem_var);
...@@ -95,10 +104,11 @@ struct operation ...@@ -95,10 +104,11 @@ struct operation
return (*this).private_detail_te_get_handle().compute_shape(std::move(input)); return (*this).private_detail_te_get_handle().compute_shape(std::move(input));
} }
argument compute(shape output, std::vector<argument> input) const argument compute(context& ctx, shape output, std::vector<argument> input) const
{ {
assert((*this).private_detail_te_handle_mem_var); assert((*this).private_detail_te_handle_mem_var);
return (*this).private_detail_te_get_handle().compute(std::move(output), std::move(input)); return (*this).private_detail_te_get_handle().compute(
ctx, std::move(output), std::move(input));
} }
friend std::ostream& operator<<(std::ostream& os, const operation& op) friend std::ostream& operator<<(std::ostream& os, const operation& op)
...@@ -114,10 +124,10 @@ struct operation ...@@ -114,10 +124,10 @@ struct operation
virtual std::shared_ptr<private_detail_te_handle_base_type> clone() const = 0; virtual std::shared_ptr<private_detail_te_handle_base_type> clone() const = 0;
virtual const std::type_info& type() const = 0; virtual const std::type_info& type() const = 0;
virtual std::string name() const = 0; virtual std::string name() const = 0;
virtual shape compute_shape(std::vector<shape> input) const = 0; virtual shape compute_shape(std::vector<shape> input) const = 0;
virtual argument compute(shape output, std::vector<argument> input) const = 0; virtual argument compute(context& ctx, shape output, std::vector<argument> input) const = 0;
virtual std::ostream& operator_shift_left(std::ostream& os) const = 0; virtual std::ostream& operator_shift_left(std::ostream& os) const = 0;
}; };
template <typename PrivateDetailTypeErasedT> template <typename PrivateDetailTypeErasedT>
...@@ -156,15 +166,15 @@ struct operation ...@@ -156,15 +166,15 @@ struct operation
return private_detail_te_value.compute_shape(std::move(input)); return private_detail_te_value.compute_shape(std::move(input));
} }
argument compute(shape output, std::vector<argument> input) const override argument compute(context& ctx, shape output, std::vector<argument> input) const override
{ {
return private_detail_te_value.compute(std::move(output), std::move(input)); return private_detail_te_value.compute(ctx, std::move(output), std::move(input));
} }
std::ostream& operator_shift_left(std::ostream& os) const override std::ostream& operator_shift_left(std::ostream& os) const override
{ {
using rtg::operation_stream::operator<<; using migraph::operation_stream::operator<<;
return os << private_detail_te_value; return os << private_detail_te_value;
} }
...@@ -181,13 +191,20 @@ struct operation ...@@ -181,13 +191,20 @@ struct operation
} }
}; };
bool private_detail_te_handle_empty() const
{
return private_detail_te_handle_mem_var == nullptr;
}
const private_detail_te_handle_base_type& private_detail_te_get_handle() const const private_detail_te_handle_base_type& private_detail_te_get_handle() const
{ {
assert(private_detail_te_handle_mem_var != nullptr);
return *private_detail_te_handle_mem_var; return *private_detail_te_handle_mem_var;
} }
private_detail_te_handle_base_type& private_detail_te_get_handle() private_detail_te_handle_base_type& private_detail_te_get_handle()
{ {
assert(private_detail_te_handle_mem_var != nullptr);
if(!private_detail_te_handle_mem_var.unique()) if(!private_detail_te_handle_mem_var.unique())
private_detail_te_handle_mem_var = private_detail_te_handle_mem_var->clone(); private_detail_te_handle_mem_var = private_detail_te_handle_mem_var->clone();
return *private_detail_te_handle_mem_var; return *private_detail_te_handle_mem_var;
...@@ -226,6 +243,6 @@ inline const ValueType& any_cast(const operation& x) ...@@ -226,6 +243,6 @@ inline const ValueType& any_cast(const operation& x)
return *y; return *y;
} }
} // namespace rtg } // namespace migraph
#endif #endif
#ifndef RTG_GUARD_OPERATORS_HPP #ifndef MIGRAPH_GUARD_OPERATORS_HPP
#define RTG_GUARD_OPERATORS_HPP #define MIGRAPH_GUARD_OPERATORS_HPP
#include <array> #include <array>
#include <rtg/operation.hpp> #include <migraph/operation.hpp>
#include <rtg/stringutils.hpp> #include <migraph/stringutils.hpp>
#include <rtg/streamutils.hpp> #include <migraph/streamutils.hpp>
#include <cmath> #include <cmath>
namespace rtg { namespace migraph {
struct check_shapes struct check_shapes
{ {
const std::vector<shape>* shapes; const std::vector<shape>* shapes;
const std::string name;
check_shapes(const std::vector<shape>& s) : shapes(&s) {} check_shapes(const std::vector<shape>& s) : shapes(&s) {}
template <class Op>
check_shapes(const std::vector<shape>& s, const Op& op) : shapes(&s), name(op.name())
{
}
std::string prefix() const
{
if(name.empty())
return "";
else
return name + ": ";
}
const check_shapes& has(std::size_t n) const const check_shapes& has(std::size_t n) const
{ {
assert(shapes != nullptr); assert(shapes != nullptr);
if(shapes->size() != n) if(shapes->size() != n)
RTG_THROW("Wrong number of arguments: expected " + std::to_string(n) + " but given " + MIGRAPH_THROW(prefix() + "Wrong number of arguments: expected " + std::to_string(n) +
std::to_string(shapes->size())); " but given " + std::to_string(shapes->size()));
return *this; return *this;
} }
...@@ -30,7 +44,7 @@ struct check_shapes ...@@ -30,7 +44,7 @@ struct check_shapes
if(!shapes->empty()) if(!shapes->empty())
{ {
if(shapes->front().lens().size() != n) if(shapes->front().lens().size() != n)
RTG_THROW("Only " + std::to_string(n) + "d supported"); MIGRAPH_THROW(prefix() + "Only " + std::to_string(n) + "d supported");
} }
return *this; return *this;
} }
...@@ -38,28 +52,28 @@ struct check_shapes ...@@ -38,28 +52,28 @@ struct check_shapes
const check_shapes& same_shape() const const check_shapes& same_shape() const
{ {
if(!this->same([](const shape& s) { return s; })) if(!this->same([](const shape& s) { return s; }))
RTG_THROW("Shapes do not match"); MIGRAPH_THROW(prefix() + "Shapes do not match");
return *this; return *this;
} }
const check_shapes& same_type() const const check_shapes& same_type() const
{ {
if(!this->same([](const shape& s) { return s.type(); })) if(!this->same([](const shape& s) { return s.type(); }))
RTG_THROW("Types do not match"); MIGRAPH_THROW(prefix() + "Types do not match");
return *this; return *this;
} }
const check_shapes& same_dims() const const check_shapes& same_dims() const
{ {
if(!this->same([](const shape& s) { return s.lens(); })) if(!this->same([](const shape& s) { return s.lens(); }))
RTG_THROW("Dimensions do not match"); MIGRAPH_THROW(prefix() + "Dimensions do not match");
return *this; return *this;
} }
const check_shapes& same_ndims() const const check_shapes& same_ndims() const
{ {
if(!this->same([](const shape& s) { return s.lens().size(); })) if(!this->same([](const shape& s) { return s.lens().size(); }))
RTG_THROW("Dimensions do not match"); MIGRAPH_THROW(prefix() + "Dimensions do not match");
return *this; return *this;
} }
...@@ -83,7 +97,10 @@ struct check_shapes ...@@ -83,7 +97,10 @@ struct check_shapes
struct not_computable struct not_computable
{ {
argument compute(shape, std::vector<argument>) const { RTG_THROW("not computable"); } argument compute(context&, shape, std::vector<argument>) const
{
MIGRAPH_THROW("not computable");
}
}; };
struct convolution struct convolution
...@@ -101,7 +118,7 @@ struct convolution ...@@ -101,7 +118,7 @@ struct convolution
std::string name() const { return "convolution"; } std::string name() const { return "convolution"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs}.has(2).same_type().same_ndims().only_dims(4); check_shapes{inputs, *this}.has(2).same_type().same_ndims().only_dims(4);
const shape& input = inputs.at(0); const shape& input = inputs.at(0);
const shape& weights = inputs.at(1); const shape& weights = inputs.at(1);
...@@ -149,11 +166,14 @@ struct convolution ...@@ -149,11 +166,14 @@ struct convolution
} }
else else
{ {
RTG_THROW("Invalid padding mode"); MIGRAPH_THROW("Invalid padding mode");
} }
} }
argument compute(shape, std::vector<argument>) const { RTG_THROW("not computable"); } argument compute(context&, shape, std::vector<argument>) const
{
MIGRAPH_THROW("not computable");
}
friend std::ostream& operator<<(std::ostream& os, const convolution& op) friend std::ostream& operator<<(std::ostream& os, const convolution& op)
{ {
...@@ -175,7 +195,7 @@ struct pooling ...@@ -175,7 +195,7 @@ struct pooling
std::string name() const { return "pooling"; } std::string name() const { return "pooling"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs}.has(1).only_dims(4); check_shapes{inputs, *this}.has(1).only_dims(4);
const shape& input = inputs.at(0); const shape& input = inputs.at(0);
auto t = input.type(); auto t = input.type();
...@@ -197,7 +217,10 @@ struct pooling ...@@ -197,7 +217,10 @@ struct pooling
}}; }};
} }
argument compute(shape, std::vector<argument>) const { RTG_THROW("not computable"); } argument compute(context&, shape, std::vector<argument>) const
{
MIGRAPH_THROW("not computable");
}
friend std::ostream& operator<<(std::ostream& os, const pooling& op) friend std::ostream& operator<<(std::ostream& os, const pooling& op)
{ {
...@@ -216,11 +239,14 @@ struct activation ...@@ -216,11 +239,14 @@ struct activation
std::string name() const { return "activation"; } std::string name() const { return "activation"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs}.has(1); check_shapes{inputs, *this}.has(1);
return inputs.front(); return inputs.front();
} }
argument compute(shape, std::vector<argument>) const { RTG_THROW("not computable"); } argument compute(context&, shape, std::vector<argument>) const
{
MIGRAPH_THROW("not computable");
}
friend std::ostream& operator<<(std::ostream& os, const activation& op) friend std::ostream& operator<<(std::ostream& os, const activation& op)
{ {
os << op.name() << ":" << op.mode; os << op.name() << ":" << op.mode;
...@@ -234,20 +260,20 @@ struct transpose ...@@ -234,20 +260,20 @@ struct transpose
std::string name() const { return "transpose"; } std::string name() const { return "transpose"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs}.has(1); check_shapes{inputs, *this}.has(1);
auto input = inputs.at(0); auto input = inputs.at(0);
auto input_lens = input.lens(); auto input_lens = input.lens();
auto input_strides = input.strides(); auto input_strides = input.strides();
auto t = input.type(); auto t = input.type();
if(dims.size() != input_lens.size()) if(dims.size() != input_lens.size())
{ {
RTG_THROW("Permutation has wrong number of axes"); MIGRAPH_THROW("Permutation has wrong number of axes");
} }
std::vector<int64_t> axes(dims.size()); std::vector<int64_t> axes(dims.size());
std::iota(axes.begin(), axes.end(), 0); std::iota(axes.begin(), axes.end(), 0);
if(!std::is_permutation(axes.begin(), axes.end(), dims.begin())) if(!std::is_permutation(axes.begin(), axes.end(), dims.begin()))
{ {
RTG_THROW("Invalid permutation"); MIGRAPH_THROW("Invalid permutation");
} }
std::vector<size_t> output_lens(input_lens.size()); std::vector<size_t> output_lens(input_lens.size());
std::vector<size_t> output_strides(input_lens.size()); std::vector<size_t> output_strides(input_lens.size());
...@@ -258,7 +284,10 @@ struct transpose ...@@ -258,7 +284,10 @@ struct transpose
} }
return {t, output_lens, output_strides}; return {t, output_lens, output_strides};
} }
argument compute(shape, std::vector<argument>) const { RTG_THROW("not computable"); } argument compute(context&, shape, std::vector<argument>) const
{
MIGRAPH_THROW("not computable");
}
}; };
struct contiguous struct contiguous
...@@ -266,16 +295,19 @@ struct contiguous ...@@ -266,16 +295,19 @@ struct contiguous
std::string name() const { return "contiguous"; } std::string name() const { return "contiguous"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs}.has(1); check_shapes{inputs, *this}.has(1);
auto lens = inputs.at(0).lens(); auto lens = inputs.at(0).lens();
auto t = inputs.at(0).type(); auto t = inputs.at(0).type();
if(lens.size() < 2) if(lens.size() < 2)
{ {
RTG_THROW("Number of dimensions should exceed 1"); MIGRAPH_THROW("Number of dimensions should exceed 1");
} }
return {t, lens}; return {t, lens};
} }
argument compute(shape, std::vector<argument>) const { RTG_THROW("not computable"); } argument compute(context&, shape, std::vector<argument>) const
{
MIGRAPH_THROW("not computable");
}
}; };
struct reshape struct reshape
...@@ -284,8 +316,7 @@ struct reshape ...@@ -284,8 +316,7 @@ struct reshape
std::string name() const { return "reshape"; } std::string name() const { return "reshape"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
if(inputs.empty()) check_shapes{inputs, *this}.has(1);
RTG_THROW("Wrong number of arguments");
auto&& idims = inputs.front().lens(); auto&& idims = inputs.front().lens();
std::vector<std::size_t> rdims(dims.begin(), dims.end()); std::vector<std::size_t> rdims(dims.begin(), dims.end());
for(std::size_t i = 0; i < dims.size(); i++) for(std::size_t i = 0; i < dims.size(); i++)
...@@ -299,11 +330,15 @@ struct reshape ...@@ -299,11 +330,15 @@ struct reshape
std::copy(idims.begin() + rdims.size(), idims.end(), std::back_inserter(rdims)); std::copy(idims.begin() + rdims.size(), idims.end(), std::back_inserter(rdims));
} }
shape s{inputs.front().type(), rdims}; shape s{inputs.front().type(), rdims};
if (s.elements() != inputs.front().elements()) RTG_THROW("Wrong number of elements"); if(s.elements() != inputs.front().elements())
MIGRAPH_THROW("Wrong number of elements for reshape");
return s; return s;
} }
argument compute(shape, std::vector<argument>) const { RTG_THROW("not computable"); } argument compute(context&, shape, std::vector<argument>) const
{
MIGRAPH_THROW("not computable");
}
friend std::ostream& operator<<(std::ostream& os, const reshape& op) friend std::ostream& operator<<(std::ostream& os, const reshape& op)
{ {
...@@ -319,17 +354,20 @@ struct gemm ...@@ -319,17 +354,20 @@ struct gemm
std::string name() const { return "gemm"; } std::string name() const { return "gemm"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs}.has(2).same_type().same_ndims().only_dims(2); check_shapes{inputs, *this}.has(2).same_type();
const shape& a = inputs.at(0); const shape& a = inputs.at(0);
const shape& b = inputs.at(1); const shape& b = inputs.at(1);
auto t = a.type(); auto t = a.type();
if(a.lens()[1] != b.lens()[0]) if(a.lens()[1] != b.lens()[0])
RTG_THROW("Inner dimensions do not match"); MIGRAPH_THROW("Inner dimensions do not match");
return {t, {a.lens()[0], b.lens()[1]}}; return {t, {a.lens()[0], b.lens()[1]}};
} }
argument compute(shape, std::vector<argument>) const { RTG_THROW("not computable"); } argument compute(context&, shape, std::vector<argument>) const
{
MIGRAPH_THROW("not computable");
}
friend std::ostream& operator<<(std::ostream& os, const gemm& op) friend std::ostream& operator<<(std::ostream& os, const gemm& op)
{ {
...@@ -346,7 +384,10 @@ struct unary ...@@ -346,7 +384,10 @@ struct unary
check_shapes{inputs}.has(1); check_shapes{inputs}.has(1);
return inputs.at(0); return inputs.at(0);
} }
argument compute(shape, std::vector<argument>) const { RTG_THROW("not computable"); } argument compute(context&, shape, std::vector<argument>) const
{
MIGRAPH_THROW("not computable");
}
}; };
struct identity : unary struct identity : unary
...@@ -435,19 +476,19 @@ struct broadcast ...@@ -435,19 +476,19 @@ struct broadcast
result.lens().cbegin(), result.lens().cend(), [&](auto x) { return x == 1; })) result.lens().cbegin(), result.lens().cend(), [&](auto x) { return x == 1; }))
{ {
if(axis != 0) if(axis != 0)
RTG_THROW("when broadcasting tensor of size 1, axis should be 0"); MIGRAPH_THROW("when broadcasting tensor of size 1, axis should be 0");
return {t, result.lens(), std::move(bcast_strides)}; return {t, result.lens(), std::move(bcast_strides)};
} }
else else
{ {
assert(result.lens().size() - axis >= input.lens().size()); assert(result.lens().size() - axis >= input.lens().size());
if(!std::equal(input.lens().begin(), input.lens().end(), result.lens().begin() + axis)) if(!std::equal(input.lens().begin(), input.lens().end(), result.lens().begin() + axis))
RTG_THROW("when broadcasting success sizes must match"); MIGRAPH_THROW("when broadcasting success sizes must match");
std::copy(input.strides().begin(), input.strides().end(), bcast_strides.begin() + axis); std::copy(input.strides().begin(), input.strides().end(), bcast_strides.begin() + axis);
return {t, result.lens(), std::move(bcast_strides)}; return {t, result.lens(), std::move(bcast_strides)};
} }
} }
argument compute(shape output_shape, std::vector<argument> args) const argument compute(context&, shape output_shape, std::vector<argument> args) const
{ {
return {output_shape, std::move(args.at(1).data)}; return {output_shape, std::move(args.at(1).data)};
} }
...@@ -461,7 +502,10 @@ struct binary ...@@ -461,7 +502,10 @@ struct binary
check_shapes{inputs}.has(2).same_type().same_dims(); check_shapes{inputs}.has(2).same_type().same_dims();
return inputs.at(0); return inputs.at(0);
} }
argument compute(shape, std::vector<argument>) const { RTG_THROW("not computable"); } argument compute(context&, shape, std::vector<argument>) const
{
MIGRAPH_THROW("not computable");
}
}; };
struct add : binary struct add : binary
...@@ -490,12 +534,26 @@ struct outline ...@@ -490,12 +534,26 @@ struct outline
std::string name() const { return "outline"; } std::string name() const { return "outline"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs}.has(0); check_shapes{inputs, *this}.has(0);
return s; return s;
} }
argument compute(shape, std::vector<argument>) const { return {s, nullptr}; } argument compute(context&, shape, std::vector<argument>) const { return {s, nullptr}; }
};
template <class T>
struct check_context
{
std::string name() const { return "check_context"; }
shape compute_shape(std::vector<shape>) const { return {}; }
argument compute(context& ctx, shape, std::vector<argument>) const
{
T* x = any_cast<T>(&ctx);
if(x == nullptr)
MIGRAPH_THROW(std::string("Unexpected context type: ") + ctx.type_id().name());
return {};
}
}; };
} // namespace rtg } // namespace migraph
#endif #endif
#ifndef RTG_GUARD_RTGLIB_PROGRAM_HPP #ifndef MIGRAPH_GUARD_MIGRAPHLIB_PROGRAM_HPP
#define RTG_GUARD_RTGLIB_PROGRAM_HPP #define MIGRAPH_GUARD_MIGRAPHLIB_PROGRAM_HPP
#include <list> #include <list>
#include <unordered_map> #include <unordered_map>
#include <rtg/operation.hpp> #include <migraph/operation.hpp>
#include <rtg/literal.hpp> #include <migraph/literal.hpp>
#include <rtg/builtin.hpp> #include <migraph/builtin.hpp>
#include <rtg/instruction_ref.hpp> #include <migraph/instruction_ref.hpp>
#include <rtg/target.hpp> #include <migraph/target.hpp>
#include <algorithm> #include <algorithm>
#include <iostream> #include <iostream>
namespace rtg { namespace migraph {
struct program_impl; struct program_impl;
...@@ -27,6 +27,8 @@ struct program ...@@ -27,6 +27,8 @@ struct program
program& operator=(program&&) noexcept; program& operator=(program&&) noexcept;
~program() noexcept; ~program() noexcept;
using parameter_map = std::unordered_map<std::string, argument>;
template <class... Ts> template <class... Ts>
instruction_ref add_instruction(operation op, Ts... args) instruction_ref add_instruction(operation op, Ts... args)
{ {
...@@ -64,7 +66,7 @@ struct program ...@@ -64,7 +66,7 @@ struct program
shape get_parameter_shape(std::string name); shape get_parameter_shape(std::string name);
argument eval(std::unordered_map<std::string, argument> params) const; argument eval(parameter_map params) const;
friend std::ostream& operator<<(std::ostream& os, const program& p); friend std::ostream& operator<<(std::ostream& os, const program& p);
...@@ -81,6 +83,6 @@ struct program ...@@ -81,6 +83,6 @@ struct program
std::unique_ptr<program_impl> impl; std::unique_ptr<program_impl> impl;
}; };
} // namespace rtg } // namespace migraph
#endif #endif
#ifndef RTG_GUARD_RTGLIB_RANGES_HPP #ifndef MIGRAPH_GUARD_MIGRAPHLIB_RANGES_HPP
#define RTG_GUARD_RTGLIB_RANGES_HPP #define MIGRAPH_GUARD_MIGRAPHLIB_RANGES_HPP
namespace rtg { namespace migraph {
template <class C, class T> template <class C, class T>
bool contains(C&& c, T&& x) bool contains(C&& c, T&& x)
...@@ -15,6 +15,6 @@ void copy(Range&& r, Iterator it) ...@@ -15,6 +15,6 @@ void copy(Range&& r, Iterator it)
std::copy(r.begin(), r.end(), it); std::copy(r.begin(), r.end(), it);
} }
} // namespace rtg } // namespace migraph
#endif #endif
#ifndef RTG_GUARD_RAW_DATA_HPP #ifndef MIGRAPH_GUARD_RAW_DATA_HPP
#define RTG_GUARD_RAW_DATA_HPP #define MIGRAPH_GUARD_RAW_DATA_HPP
#include <rtg/tensor_view.hpp> #include <migraph/tensor_view.hpp>
#include <rtg/requires.hpp> #include <migraph/requires.hpp>
namespace rtg { namespace migraph {
struct raw_data_base struct raw_data_base
{ {
...@@ -89,13 +89,27 @@ struct raw_data : raw_data_base ...@@ -89,13 +89,27 @@ struct raw_data : raw_data_base
assert(self->single()); assert(self->single());
return self->template at<T>(); return self->template at<T>();
} }
template <class T>
using is_data_ptr =
bool_c<(std::is_void<T>{} or std::is_same<char, std::remove_cv_t<T>>{} or
std::is_same<unsigned char, std::remove_cv_t<T>>{})>;
template <class T>
using get_data_type = std::conditional_t<is_data_ptr<T>{}, float, T>;
template <class T>
bool matches() const
{
return is_data_ptr<T>{} ||
self->get_shape().type() == migraph::shape::get_type<get_data_type<T>>{};
}
template <class T> template <class T>
operator T*() operator T*()
{ {
using type = std::remove_cv_t<T>; using type = std::remove_cv_t<T>;
assert((std::is_void<T>{} or std::is_same<char, type>{} or assert(matches<T>());
std::is_same<unsigned char, type>{} or
self->get_shape().type() == rtg::shape::get_type<T>{}));
return reinterpret_cast<type*>(self->data()); return reinterpret_cast<type*>(self->data());
} }
}; };
...@@ -109,15 +123,26 @@ struct raw_data : raw_data_base ...@@ -109,15 +123,26 @@ struct raw_data : raw_data_base
{ {
auto&& s = static_cast<const Derived&>(*this).get_shape(); auto&& s = static_cast<const Derived&>(*this).get_shape();
auto&& buffer = static_cast<const Derived&>(*this).data(); auto&& buffer = static_cast<const Derived&>(*this).data();
if(s.type() != rtg::shape::get_type<T>{}) if(s.type() != migraph::shape::get_type<T>{})
RTG_THROW("Incorrect data type for raw data"); MIGRAPH_THROW("Incorrect data type for raw data");
return make_view(s, reinterpret_cast<T*>(buffer)); return make_view(s, reinterpret_cast<T*>(buffer));
} }
/// Cast the data pointer
template <class T>
T* cast() const
{
auto&& s = static_cast<const Derived&>(*this).get_shape();
auto&& buffer = static_cast<const Derived&>(*this).data();
assert(s.type() == migraph::shape::get_type<T>{});
return reinterpret_cast<T*>(buffer);
}
}; };
template <class T, template <class T,
class U, class U,
RTG_REQUIRES(std::is_base_of<raw_data_base, T>{} && std::is_base_of<raw_data_base, U>{})> MIGRAPH_REQUIRES(std::is_base_of<raw_data_base, T>{} &&
std::is_base_of<raw_data_base, U>{})>
bool operator==(const T& x, const U& y) bool operator==(const T& x, const U& y)
{ {
auto&& xshape = x.get_shape(); auto&& xshape = x.get_shape();
...@@ -139,7 +164,8 @@ bool operator==(const T& x, const U& y) ...@@ -139,7 +164,8 @@ bool operator==(const T& x, const U& y)
template <class T, template <class T,
class U, class U,
RTG_REQUIRES(std::is_base_of<raw_data_base, T>{} && std::is_base_of<raw_data_base, U>{})> MIGRAPH_REQUIRES(std::is_base_of<raw_data_base, T>{} &&
std::is_base_of<raw_data_base, U>{})>
bool operator!=(const T& x, const U& y) bool operator!=(const T& x, const U& y)
{ {
return !(x == y); return !(x == y);
...@@ -170,13 +196,13 @@ auto visit_all(T&& x, Ts&&... xs) ...@@ -170,13 +196,13 @@ auto visit_all(T&& x, Ts&&... xs)
auto&& s = x.get_shape(); auto&& s = x.get_shape();
std::initializer_list<shape::type_t> types = {xs.get_shape().type()...}; std::initializer_list<shape::type_t> types = {xs.get_shape().type()...};
if(!std::all_of(types.begin(), types.end(), [&](shape::type_t t) { return t == s.type(); })) if(!std::all_of(types.begin(), types.end(), [&](shape::type_t t) { return t == s.type(); }))
RTG_THROW("Types must be the same"); MIGRAPH_THROW("Types must be the same");
return [&](auto v) { return [&](auto v) {
// Workaround for https://gcc.gnu.org/bugzilla/show_bug.cgi?id=70100 // Workaround for https://gcc.gnu.org/bugzilla/show_bug.cgi?id=70100
detail::visit_all_impl(s, v, x, xs...); detail::visit_all_impl(s, v, x, xs...);
}; };
} }
} // namespace rtg } // namespace migraph
#endif #endif
#ifndef MIGRAPH_GUARD_MIGRAPHLIB_REQUIRES_HPP
#define MIGRAPH_GUARD_MIGRAPHLIB_REQUIRES_HPP
#include <type_traits>
namespace migraph {
template <bool... Bs>
struct and_ : std::is_same<and_<Bs...>, and_<(Bs || true)...>> // NOLINT
{
};
template <bool B>
using bool_c = std::integral_constant<bool, B>;
#ifdef CPPCHECK
#define MIGRAPH_REQUIRES(...) class = void
#else
#define MIGRAPH_REQUIRES(...) \
bool PrivateRequires##__LINE__ = true, \
class = typename std::enable_if<and_<__VA_ARGS__, PrivateRequires##__LINE__>{}>::type
#endif
} // namespace migraph
#endif
#ifndef RTG_GUARD_RTGLIB_SHAPE_HPP #ifndef MIGRAPH_GUARD_MIGRAPHLIB_SHAPE_HPP
#define RTG_GUARD_RTGLIB_SHAPE_HPP #define MIGRAPH_GUARD_MIGRAPHLIB_SHAPE_HPP
#include <vector> #include <vector>
#include <cassert> #include <cassert>
#include <ostream> #include <ostream>
#include <numeric> #include <numeric>
#include <rtg/errors.hpp> #include <migraph/errors.hpp>
namespace rtg { namespace migraph {
struct shape struct shape
{ {
// Add new types here // Add new types here
// clang-format off // clang-format off
#define RTG_SHAPE_VISIT_TYPES(m) \ #define MIGRAPH_SHAPE_VISIT_TYPES(m) \
m(float_type, float) \ m(float_type, float) \
m(double_type, double) \ m(double_type, double) \
m(uint8_type, uint8_t) \ m(uint8_type, uint8_t) \
...@@ -28,25 +28,22 @@ struct shape ...@@ -28,25 +28,22 @@ struct shape
m(uint64_type, uint64_t) m(uint64_type, uint64_t)
// clang-format on // clang-format on
#define RTG_SHAPE_ENUM_TYPES(x, t) x, #define MIGRAPH_SHAPE_ENUM_TYPES(x, t) x,
enum type_t enum type_t
{ {
any_type, MIGRAPH_SHAPE_VISIT_TYPES(MIGRAPH_SHAPE_ENUM_TYPES)
RTG_SHAPE_VISIT_TYPES(RTG_SHAPE_ENUM_TYPES)
}; };
#undef RTG_SHAPE_ENUM_TYPES #undef MIGRAPH_SHAPE_ENUM_TYPES
template <class T, class = void> template <class T, class = void>
struct get_type : std::integral_constant<type_t, any_type> struct get_type;
{ #define MIGRAPH_SHAPE_GET_TYPE(x, t) \
};
#define RTG_SHAPE_GET_TYPE(x, t) \
template <class T> \ template <class T> \
struct get_type<t, T> : std::integral_constant<type_t, x> \ struct get_type<t, T> : std::integral_constant<type_t, x> \
{ \ { \
}; };
RTG_SHAPE_VISIT_TYPES(RTG_SHAPE_GET_TYPE) MIGRAPH_SHAPE_VISIT_TYPES(MIGRAPH_SHAPE_GET_TYPE)
#undef RTG_SHAPE_GET_TYPE #undef MIGRAPH_SHAPE_GET_TYPE
shape(); shape();
shape(type_t t); shape(type_t t);
...@@ -74,6 +71,7 @@ struct shape ...@@ -74,6 +71,7 @@ struct shape
std::size_t index(std::size_t i) const; std::size_t index(std::size_t i) const;
bool packed() const; bool packed() const;
bool broadcasted() const;
friend bool operator==(const shape& x, const shape& y); friend bool operator==(const shape& x, const shape& y);
friend bool operator!=(const shape& x, const shape& y); friend bool operator!=(const shape& x, const shape& y);
...@@ -124,13 +122,12 @@ struct shape ...@@ -124,13 +122,12 @@ struct shape
{ {
switch(this->m_type) switch(this->m_type)
{ {
case any_type: RTG_THROW("Cannot visit the any_type"); #define MIGRAPH_SHAPE_VISITOR_CASE(x, t) \
#define RTG_SHAPE_VISITOR_CASE(x, t) \
case x: v(as<t>()); return; case x: v(as<t>()); return;
RTG_SHAPE_VISIT_TYPES(RTG_SHAPE_VISITOR_CASE) MIGRAPH_SHAPE_VISIT_TYPES(MIGRAPH_SHAPE_VISITOR_CASE)
#undef RTG_SHAPE_VISITOR_CASE #undef MIGRAPH_SHAPE_VISITOR_CASE
} }
RTG_THROW("Unknown type"); MIGRAPH_THROW("Unknown type");
} }
private: private:
...@@ -144,6 +141,6 @@ struct shape ...@@ -144,6 +141,6 @@ struct shape
std::string type_string() const; std::string type_string() const;
}; };
} // namespace rtg } // namespace migraph
#endif #endif
#ifndef RTG_GUARD_RTGLIB_SHAPE_FOR_EACH_HPP #ifndef MIGRAPH_GUARD_MIGRAPHLIB_SHAPE_FOR_EACH_HPP
#define RTG_GUARD_RTGLIB_SHAPE_FOR_EACH_HPP #define MIGRAPH_GUARD_MIGRAPHLIB_SHAPE_FOR_EACH_HPP
#include <rtg/shape.hpp> #include <migraph/shape.hpp>
#include <algorithm> #include <algorithm>
namespace rtg { namespace migraph {
template <class F> template <class F>
void shape_for_each(const rtg::shape& s, F f) void shape_for_each(const migraph::shape& s, F f)
{ {
// Ensure calls to f use const ref to vector // Ensure calls to f use const ref to vector
auto call = [&f](const std::vector<std::size_t>& i) { f(i); }; auto call = [&f](const std::vector<std::size_t>& i) { f(i); };
...@@ -26,6 +26,6 @@ void shape_for_each(const rtg::shape& s, F f) ...@@ -26,6 +26,6 @@ void shape_for_each(const rtg::shape& s, F f)
} }
} }
} // namespace rtg } // namespace migraph
#endif #endif
#ifndef RTG_GUARD_STREAMUTILS_HPP #ifndef MIGRAPH_GUARD_STREAMUTILS_HPP
#define RTG_GUARD_STREAMUTILS_HPP #define MIGRAPH_GUARD_STREAMUTILS_HPP
#include <ostream> #include <ostream>
#include <algorithm> #include <algorithm>
namespace rtg { namespace migraph {
template <class T> template <class T>
struct stream_range_container struct stream_range_container
...@@ -31,6 +31,6 @@ inline stream_range_container<Range> stream_range(const Range& r) ...@@ -31,6 +31,6 @@ inline stream_range_container<Range> stream_range(const Range& r)
return {r}; return {r};
} }
} // namespace rtg } // namespace migraph
#endif #endif
#ifndef RTG_GUARD_RTGLIB_STRINGUTILS_HPP #ifndef MIGRAPH_GUARD_MIGRAPHLIB_STRINGUTILS_HPP
#define RTG_GUARD_RTGLIB_STRINGUTILS_HPP #define MIGRAPH_GUARD_MIGRAPHLIB_STRINGUTILS_HPP
#include <algorithm> #include <algorithm>
#include <numeric> #include <numeric>
#include <string> #include <string>
#include <sstream> #include <sstream>
namespace rtg { namespace migraph {
inline std::string inline std::string
replace_string(std::string subject, const std::string& search, const std::string& replace) replace_string(std::string subject, const std::string& search, const std::string& replace)
...@@ -77,6 +77,6 @@ inline std::string to_string(const Range& r) ...@@ -77,6 +77,6 @@ inline std::string to_string(const Range& r)
return ss.str(); return ss.str();
} }
} // namespace rtg } // namespace migraph
#endif #endif
#ifndef RTG_GUARD_RTGLIB_TARGET_HPP #ifndef MIGRAPH_GUARD_MIGRAPHLIB_TARGET_HPP
#define RTG_GUARD_RTGLIB_TARGET_HPP #define MIGRAPH_GUARD_MIGRAPHLIB_TARGET_HPP
#include <string> #include <string>
#include <functional> #include <functional>
#include <memory> #include <memory>
#include <type_traits> #include <type_traits>
#include <utility> #include <utility>
#include <migraph/context.hpp>
namespace rtg { namespace migraph {
struct program; struct program;
...@@ -18,6 +19,7 @@ struct program; ...@@ -18,6 +19,7 @@ struct program;
* { * {
* std::string name() const; * std::string name() const;
* void apply(program & p) const; * void apply(program & p) const;
* context get_context() const;
* }; * };
* *
*/ */
...@@ -71,6 +73,14 @@ struct target ...@@ -71,6 +73,14 @@ struct target
: nullptr; : nullptr;
} }
const std::type_info& type_id() const
{
if(private_detail_te_handle_empty())
return typeid(std::nullptr_t);
else
return private_detail_te_get_handle().type();
}
std::string name() const std::string name() const
{ {
assert((*this).private_detail_te_handle_mem_var); assert((*this).private_detail_te_handle_mem_var);
...@@ -83,6 +93,12 @@ struct target ...@@ -83,6 +93,12 @@ struct target
return (*this).private_detail_te_get_handle().apply(p); return (*this).private_detail_te_get_handle().apply(p);
} }
context get_context() const
{
assert((*this).private_detail_te_handle_mem_var);
return (*this).private_detail_te_get_handle().get_context();
}
private: private:
struct private_detail_te_handle_base_type struct private_detail_te_handle_base_type
{ {
...@@ -92,6 +108,7 @@ struct target ...@@ -92,6 +108,7 @@ struct target
virtual std::string name() const = 0; virtual std::string name() const = 0;
virtual void apply(program& p) const = 0; virtual void apply(program& p) const = 0;
virtual context get_context() const = 0;
}; };
template <typename PrivateDetailTypeErasedT> template <typename PrivateDetailTypeErasedT>
...@@ -126,6 +143,8 @@ struct target ...@@ -126,6 +143,8 @@ struct target
void apply(program& p) const override { return private_detail_te_value.apply(p); } void apply(program& p) const override { return private_detail_te_value.apply(p); }
context get_context() const override { return private_detail_te_value.get_context(); }
PrivateDetailTypeErasedT private_detail_te_value; PrivateDetailTypeErasedT private_detail_te_value;
}; };
...@@ -139,13 +158,20 @@ struct target ...@@ -139,13 +158,20 @@ struct target
} }
}; };
bool private_detail_te_handle_empty() const
{
return private_detail_te_handle_mem_var == nullptr;
}
const private_detail_te_handle_base_type& private_detail_te_get_handle() const const private_detail_te_handle_base_type& private_detail_te_get_handle() const
{ {
assert(private_detail_te_handle_mem_var != nullptr);
return *private_detail_te_handle_mem_var; return *private_detail_te_handle_mem_var;
} }
private_detail_te_handle_base_type& private_detail_te_get_handle() private_detail_te_handle_base_type& private_detail_te_get_handle()
{ {
assert(private_detail_te_handle_mem_var != nullptr);
if(!private_detail_te_handle_mem_var.unique()) if(!private_detail_te_handle_mem_var.unique())
private_detail_te_handle_mem_var = private_detail_te_handle_mem_var->clone(); private_detail_te_handle_mem_var = private_detail_te_handle_mem_var->clone();
return *private_detail_te_handle_mem_var; return *private_detail_te_handle_mem_var;
...@@ -184,6 +210,6 @@ inline const ValueType& any_cast(const target& x) ...@@ -184,6 +210,6 @@ inline const ValueType& any_cast(const target& x)
return *y; return *y;
} }
} // namespace rtg } // namespace migraph
#endif #endif
#ifndef RTG_GUARD_TENSOR_VIEW_HPP #ifndef MIGRAPH_GUARD_TENSOR_VIEW_HPP
#define RTG_GUARD_TENSOR_VIEW_HPP #define MIGRAPH_GUARD_TENSOR_VIEW_HPP
#include <rtg/shape.hpp> #include <migraph/shape.hpp>
#include <rtg/float_equal.hpp> #include <migraph/float_equal.hpp>
#include <rtg/requires.hpp> #include <migraph/requires.hpp>
#include <iostream> #include <iostream>
namespace rtg { namespace migraph {
template <class T> template <class T>
struct tensor_view struct tensor_view
...@@ -26,27 +26,27 @@ struct tensor_view ...@@ -26,27 +26,27 @@ struct tensor_view
const T* data() const { return this->m_data; } const T* data() const { return this->m_data; }
template <class... Ts, RTG_REQUIRES(std::is_integral<Ts>{}...)> template <class... Ts, MIGRAPH_REQUIRES(std::is_integral<Ts>{}...)>
const T& operator()(Ts... xs) const const T& operator()(Ts... xs) const
{ {
assert(m_shape.index({static_cast<std::size_t>(xs)...}) < m_shape.bytes() / sizeof(T)); assert(m_shape.index({static_cast<std::size_t>(xs)...}) < m_shape.bytes() / sizeof(T));
return m_data[m_shape.index({static_cast<std::size_t>(xs)...})]; return m_data[m_shape.index({static_cast<std::size_t>(xs)...})];
} }
template <class... Ts, RTG_REQUIRES(std::is_integral<Ts>{}...)> template <class... Ts, MIGRAPH_REQUIRES(std::is_integral<Ts>{}...)>
T& operator()(Ts... xs) T& operator()(Ts... xs)
{ {
assert(m_shape.index({static_cast<std::size_t>(xs)...}) < m_shape.bytes() / sizeof(T)); assert(m_shape.index({static_cast<std::size_t>(xs)...}) < m_shape.bytes() / sizeof(T));
return m_data[m_shape.index({static_cast<std::size_t>(xs)...})]; return m_data[m_shape.index({static_cast<std::size_t>(xs)...})];
} }
template <class Iterator, RTG_REQUIRES(not std::is_integral<Iterator>{})> template <class Iterator, MIGRAPH_REQUIRES(not std::is_integral<Iterator>{})>
const T& operator()(Iterator start, Iterator last) const const T& operator()(Iterator start, Iterator last) const
{ {
return m_data[m_shape.index(start, last)]; return m_data[m_shape.index(start, last)];
} }
template <class Iterator, RTG_REQUIRES(not std::is_integral<Iterator>{})> template <class Iterator, MIGRAPH_REQUIRES(not std::is_integral<Iterator>{})>
T& operator()(Iterator start, Iterator last) T& operator()(Iterator start, Iterator last)
{ {
return m_data[m_shape.index(start, last)]; return m_data[m_shape.index(start, last)];
...@@ -164,6 +164,6 @@ tensor_view<T> make_view(shape s, T* data) ...@@ -164,6 +164,6 @@ tensor_view<T> make_view(shape s, T* data)
return {s, data}; return {s, data};
} }
} // namespace rtg } // namespace migraph
#endif #endif
#ifndef RTG_GUARD_FALLTHROUGH_HPP
#define RTG_GUARD_FALLTHROUGH_HPP
namespace rtg {
#ifdef __clang__
#define RTG_FALLTHROUGH [[clang::fallthrough]]
#else
#define RTG_FALLTHROUGH
#endif
} // namespace rtg
#endif
#ifndef GUARD_RTGLIB_ONNX_HPP
#define GUARD_RTGLIB_ONNX_HPP
#include <rtg/program.hpp>
namespace rtg {
program parse_onnx(const std::string& name);
} // namespace rtg
#endif
#ifndef RTG_GUARD_RTGLIB_REQUIRES_HPP
#define RTG_GUARD_RTGLIB_REQUIRES_HPP
#include <type_traits>
namespace rtg {
template <bool... Bs>
struct and_ : std::is_same<and_<Bs...>, and_<(Bs || true)...>> // NOLINT
{
};
#ifdef CPPCHECK
#define RTG_REQUIRES(...) class = void
#else
#define RTG_REQUIRES(...) class = typename std::enable_if<and_<__VA_ARGS__, true>{}>::type
#endif
} // namespace rtg
#endif
...@@ -5,15 +5,23 @@ add_library(onnx-proto STATIC ${PROTO_SRCS}) ...@@ -5,15 +5,23 @@ add_library(onnx-proto STATIC ${PROTO_SRCS})
target_include_directories(onnx-proto SYSTEM PUBLIC ${CMAKE_CURRENT_BINARY_DIR} ${PROTOBUF_INCLUDE_DIR}) target_include_directories(onnx-proto SYSTEM PUBLIC ${CMAKE_CURRENT_BINARY_DIR} ${PROTOBUF_INCLUDE_DIR})
target_compile_options(onnx-proto PRIVATE -w) target_compile_options(onnx-proto PRIVATE -w)
target_link_libraries(onnx-proto PRIVATE ${PROTOBUF_LIBRARY}) target_link_libraries(onnx-proto PRIVATE ${PROTOBUF_LIBRARY})
set_target_properties(onnx-proto PROPERTIES POSITION_INDEPENDENT_CODE On)
add_library(rtg_onnx onnx.cpp) add_library(migraph_onnx onnx.cpp)
rocm_clang_tidy_check(rtg_onnx) rocm_clang_tidy_check(migraph_onnx)
target_link_libraries(rtg_onnx onnx-proto rtg) target_link_libraries(migraph_onnx PRIVATE onnx-proto)
target_link_libraries(migraph_onnx PUBLIC migraph)
add_executable(read_onnx read_onnx.cpp) add_executable(read_onnx read_onnx.cpp)
rocm_clang_tidy_check(read_onnx) rocm_clang_tidy_check(read_onnx)
target_link_libraries(read_onnx rtg_onnx rtg_cpu) target_link_libraries(read_onnx migraph_onnx)
add_executable(mnist mnist.cpp) add_executable(mnist mnist.cpp)
rocm_clang_tidy_check(mnist) rocm_clang_tidy_check(mnist)
target_link_libraries(mnist rtg_onnx rtg_cpu) target_link_libraries(mnist migraph_cpu migraph_onnx)
if(MIGRAPH_ENABLE_MIOPEN)
add_executable(verify_onnx verify_onnx.cpp)
rocm_clang_tidy_check(verify_onnx)
target_link_libraries(verify_onnx migraph_onnx migraph_cpu migraph_miopen)
endif()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment