Commit fb75dfaf authored by Paul's avatar Paul
Browse files

Only use no-cache on jenkins

parents e596eec2 f0604d78
#ifndef MIGRAPH_GUARD_RTGLIB_ITERATOR_FOR_HPP
#define MIGRAPH_GUARD_RTGLIB_ITERATOR_FOR_HPP
#include <cassert>
#include <type_traits>
namespace migraph {
template <class T>
struct iterator_for_range
{
T* base;
using base_iterator = std::remove_reference_t<decltype(base->begin())>;
struct iterator
{
base_iterator i;
base_iterator operator*() { return i; }
base_iterator operator++() { return ++i; }
bool operator!=(const iterator& rhs) { return i != rhs.i; }
};
iterator begin()
{
assert(base != nullptr);
return {base->begin()};
}
iterator end()
{
assert(base != nullptr);
return {base->end()};
}
};
template <class T>
iterator_for_range<T> iterator_for(T& x)
{
return {&x};
}
} // namespace migraph
#endif
#ifndef RTG_GUARD_RTGLIB_LITERAL_HPP
#define RTG_GUARD_RTGLIB_LITERAL_HPP
#ifndef MIGRAPH_GUARD_MIGRAPHLIB_LITERAL_HPP
#define MIGRAPH_GUARD_MIGRAPHLIB_LITERAL_HPP
#include <rtg/shape.hpp>
#include <rtg/argument.hpp>
#include <rtg/tensor_view.hpp>
#include <rtg/raw_data.hpp>
#include <migraph/shape.hpp>
#include <migraph/shape_for_each.hpp>
#include <migraph/argument.hpp>
#include <migraph/tensor_view.hpp>
#include <migraph/raw_data.hpp>
namespace rtg {
namespace migraph {
/**
* @brief Represents a raw literal
......@@ -26,24 +27,21 @@ struct literal : raw_data<literal>
template <class T>
literal(shape s, const std::vector<T>& x) : buffer(s.bytes(), 0), m_shape(s)
{
assert(s.packed());
static_assert(std::is_trivial<T>{}, "Literals can only be trivial types");
s.visit_type([&](auto as) { std::copy(x.begin(), x.end(), as.from(buffer.data())); });
fill(x.begin(), x.end());
}
template <class T>
literal(shape s, const std::initializer_list<T>& x) : buffer(s.bytes(), 0), m_shape(s)
{
assert(s.packed());
static_assert(std::is_trivial<T>{}, "Literals can only be trivial types");
s.visit_type([&](auto as) { std::copy(x.begin(), x.end(), as.from(buffer.data())); });
fill(x.begin(), x.end());
}
template <class Iterator>
literal(shape s, Iterator start, Iterator end) : buffer(s.bytes(), 0), m_shape(s)
{
assert(s.packed());
s.visit_type([&](auto as) { std::copy(start, end, as.from(buffer.data())); });
fill(start, end);
}
literal(shape s, const char* x) : buffer(x, x + s.bytes()), m_shape(s) {}
......@@ -66,8 +64,28 @@ struct literal : raw_data<literal>
private:
std::vector<char> buffer;
shape m_shape;
template <class Iterator>
void fill(Iterator start, Iterator end)
{
if(m_shape.standard())
{
m_shape.visit_type([&](auto as) { std::copy(start, end, as.from(buffer.data())); });
}
else
{
auto it = start;
m_shape.visit_type([&](auto as) {
auto output = make_view(m_shape, as.from(buffer.data()));
shape_for_each(output.get_shape(), [&](const auto& idx) {
it++;
output(idx.begin(), idx.end()) = *it;
});
});
}
}
};
} // namespace rtg
} // namespace migraph
#endif
#ifndef RTG_GUARD_RTG_MANAGE_PTR_HPP
#define RTG_GUARD_RTG_MANAGE_PTR_HPP
#ifndef MIGRAPH_GUARD_MIGRAPH_MANAGE_PTR_HPP
#define MIGRAPH_GUARD_MIGRAPH_MANAGE_PTR_HPP
#include <memory>
#include <type_traits>
namespace rtg {
namespace migraph {
template <class F, F f> // NOLINT
struct manage_deleter
......@@ -49,8 +49,9 @@ shared<T> share(T p)
return shared<T>{std::move(p)};
}
} // namespace rtg
} // namespace migraph
#define RTG_MANAGE_PTR(T, F) rtg::manage_ptr<std::remove_pointer_t<T>, decltype(&F), &F> // NOLINT
#define MIGRAPH_MANAGE_PTR(T, F) \
migraph::manage_ptr<std::remove_pointer_t<T>, decltype(&F), &F> // NOLINT
#endif
#ifndef GUARD_MIGRAPHLIB_ONNX_HPP
#define GUARD_MIGRAPHLIB_ONNX_HPP
#include <migraph/program.hpp>
namespace migraph {
/// Create a program from an onnx file
program parse_onnx(const std::string& name);
} // namespace migraph
#endif
#ifndef RTG_GUARD_RTGLIB_OPERAND_HPP
#define RTG_GUARD_RTGLIB_OPERAND_HPP
#ifndef MIGRAPH_GUARD_MIGRAPHLIB_OPERAND_HPP
#define MIGRAPH_GUARD_MIGRAPHLIB_OPERAND_HPP
#include <cassert>
#include <string>
#include <functional>
#include <memory>
#include <type_traits>
#include <utility>
#include <rtg/shape.hpp>
#include <rtg/argument.hpp>
#include <migraph/shape.hpp>
#include <migraph/argument.hpp>
#include <migraph/context.hpp>
#include <migraph/auto_any_cast.hpp>
namespace rtg {
namespace migraph {
#ifdef DOXYGEN
/// The operation interface represents an action an instruction will perform. All
/// operation classes must be CopyConstructible.
struct operation
{
/// A unique name identifying the operation
std::string name() const;
/// This is used to compute the resulting shape from an operation. If an
/// operation cannot be run with input shapes, then it should throw an
/// exception.
shape compute_shape(std::vector<shape> input) const;
/**
* @brief This performs the operation's computation
*
* @param ctx This is the context created by the `target` during compilation. Implementations
* can use the target's `context` class rather than the `context` interface class.
* @param output This is the output shape. It is equivalent to running `compute_shape` with each
* `shape` of the `argument`.
* @param input This is the `argument` result from the previous instuction's computation.
* @return Return an `argument` of the result computation. The `shape` of `argument` should be
* the same the `output` shape.
*/
argument compute(context& ctx, shape output, std::vector<argument> input) const;
/// An optional stream operator to print the operation. When this is not
/// implemented, it will just print the operation's name.
friend std::ostream& operator<<(std::ostream& os, const operation& op);
};
#else
namespace operation_stream {
......@@ -21,6 +55,12 @@ auto operator<<(std::ostream& os, const T& x) -> decltype(os << x.name())
} // namespace operation_stream
template <class T>
argument compute_op(const T& x, context& ctx, shape output_shape, std::vector<argument> input)
{
return x.compute(auto_any_cast(ctx), output_shape, input);
}
/*
* Type-erased interface for:
*
......@@ -28,7 +68,7 @@ auto operator<<(std::ostream& os, const T& x) -> decltype(os << x.name())
* {
* std::string name() const;
* shape compute_shape(std::vector<shape> input) const;
* argument compute(shape output,std::vector<argument> input) const;
* argument compute(context& ctx,shape output,std::vector<argument> input) const;
* friend std::ostream & operator<<(std::ostream & os,const operation & op) ;
* };
*
......@@ -83,6 +123,14 @@ struct operation
: nullptr;
}
const std::type_info& type_id() const
{
if(private_detail_te_handle_empty())
return typeid(std::nullptr_t);
else
return private_detail_te_get_handle().type();
}
std::string name() const
{
assert((*this).private_detail_te_handle_mem_var);
......@@ -95,10 +143,11 @@ struct operation
return (*this).private_detail_te_get_handle().compute_shape(std::move(input));
}
argument compute(shape output, std::vector<argument> input) const
argument compute(context& ctx, shape output, std::vector<argument> input) const
{
assert((*this).private_detail_te_handle_mem_var);
return (*this).private_detail_te_get_handle().compute(std::move(output), std::move(input));
return (*this).private_detail_te_get_handle().compute(
ctx, std::move(output), std::move(input));
}
friend std::ostream& operator<<(std::ostream& os, const operation& op)
......@@ -114,10 +163,10 @@ struct operation
virtual std::shared_ptr<private_detail_te_handle_base_type> clone() const = 0;
virtual const std::type_info& type() const = 0;
virtual std::string name() const = 0;
virtual shape compute_shape(std::vector<shape> input) const = 0;
virtual argument compute(shape output, std::vector<argument> input) const = 0;
virtual std::ostream& operator_shift_left(std::ostream& os) const = 0;
virtual std::string name() const = 0;
virtual shape compute_shape(std::vector<shape> input) const = 0;
virtual argument compute(context& ctx, shape output, std::vector<argument> input) const = 0;
virtual std::ostream& operator_shift_left(std::ostream& os) const = 0;
};
template <typename PrivateDetailTypeErasedT>
......@@ -156,15 +205,15 @@ struct operation
return private_detail_te_value.compute_shape(std::move(input));
}
argument compute(shape output, std::vector<argument> input) const override
argument compute(context& ctx, shape output, std::vector<argument> input) const override
{
return private_detail_te_value.compute(std::move(output), std::move(input));
return compute_op(private_detail_te_value, ctx, std::move(output), std::move(input));
}
std::ostream& operator_shift_left(std::ostream& os) const override
{
using rtg::operation_stream::operator<<;
using migraph::operation_stream::operator<<;
return os << private_detail_te_value;
}
......@@ -181,13 +230,20 @@ struct operation
}
};
bool private_detail_te_handle_empty() const
{
return private_detail_te_handle_mem_var == nullptr;
}
const private_detail_te_handle_base_type& private_detail_te_get_handle() const
{
assert(private_detail_te_handle_mem_var != nullptr);
return *private_detail_te_handle_mem_var;
}
private_detail_te_handle_base_type& private_detail_te_get_handle()
{
assert(private_detail_te_handle_mem_var != nullptr);
if(!private_detail_te_handle_mem_var.unique())
private_detail_te_handle_mem_var = private_detail_te_handle_mem_var->clone();
return *private_detail_te_handle_mem_var;
......@@ -226,6 +282,8 @@ inline const ValueType& any_cast(const operation& x)
return *y;
}
} // namespace rtg
#endif
} // namespace migraph
#endif
#ifndef RTG_GUARD_OPERATORS_HPP
#define RTG_GUARD_OPERATORS_HPP
#ifndef MIGRAPH_GUARD_OPERATORS_HPP
#define MIGRAPH_GUARD_OPERATORS_HPP
#include <rtg/operation.hpp>
#include <rtg/stringutils.hpp>
#include <rtg/streamutils.hpp>
#include <array>
#include <migraph/operation.hpp>
#include <migraph/check_shapes.hpp>
#include <migraph/stringutils.hpp>
#include <migraph/streamutils.hpp>
#include <cmath>
namespace rtg {
namespace migraph {
struct check_shapes
struct not_computable
{
const std::vector<shape>* shapes;
check_shapes(const std::vector<shape>& s) : shapes(&s) {}
const check_shapes& has(std::size_t n) const
argument compute(context&, shape, std::vector<argument>) const
{
assert(shapes != nullptr);
if(shapes->size() != n)
RTG_THROW("Wrong number of arguments: expected " + std::to_string(n) + " but given " +
std::to_string(shapes->size()));
return *this;
MIGRAPH_THROW("not computable");
}
};
const check_shapes& only_dims(std::size_t n) const
{
assert(shapes != nullptr);
if(!shapes->empty())
{
if(shapes->front().lens().size() != n)
RTG_THROW("Only " + std::to_string(n) + "d supported");
}
return *this;
}
struct batch_norm_inference
{
float epsilon = 1.0e-6f;
float momentum = 0.9f;
const check_shapes& same_shape() const
{
if(!this->same([](const shape& s) { return s; }))
RTG_THROW("Shapes do not match");
return *this;
}
std::string name() const { return "batch_norm_inference"; }
const check_shapes& same_type() const
enum bn_infer_mode_t
{
if(!this->same([](const shape& s) { return s.type(); }))
RTG_THROW("Types do not match");
return *this;
}
per_activation,
spatial,
};
const check_shapes& same_dims() const
{
if(!this->same([](const shape& s) { return s.lens(); }))
RTG_THROW("Dimensions do not match");
return *this;
}
bn_infer_mode_t bn_mode = spatial;
bool is_test = false;
template <class F>
bool same(F f) const
shape compute_shape(std::vector<shape> inputs) const
{
assert(shapes != nullptr);
if(shapes->empty())
return true;
auto&& key = f(shapes->front());
return this->all_of([&](const shape& s) { return f(s) == key; });
check_shapes{inputs, *this}.has(5);
return inputs.front();
}
template <class Predicate>
bool all_of(Predicate p) const
argument compute(context&, shape, std::vector<argument>) const
{
assert(shapes != nullptr);
return std::all_of(shapes->begin(), shapes->end(), p);
MIGRAPH_THROW("not computable");
}
};
struct not_computable
{
argument compute(shape, std::vector<argument>) const { RTG_THROW("not computable"); }
};
struct convolution
{
std::array<std::size_t, 2> padding = {{0, 0}};
......@@ -93,7 +62,7 @@ struct convolution
std::string name() const { return "convolution"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs}.has(2).same_type().same_dims().only_dims(4);
check_shapes{inputs, *this}.has(2).same_type().same_ndims().only_dims(4);
const shape& input = inputs.at(0);
const shape& weights = inputs.at(1);
......@@ -141,11 +110,14 @@ struct convolution
}
else
{
RTG_THROW("Invalid padding mode");
MIGRAPH_THROW("Invalid padding mode");
}
}
argument compute(shape, std::vector<argument>) const { RTG_THROW("not computable"); }
argument compute(context&, shape, std::vector<argument>) const
{
MIGRAPH_THROW("not computable");
}
friend std::ostream& operator<<(std::ostream& os, const convolution& op)
{
......@@ -165,30 +137,38 @@ struct pooling
std::array<std::size_t, 2> stride = {{1, 1}};
std::array<std::size_t, 2> lengths = {{1, 1}};
std::string name() const { return "pooling"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs}.has(1).only_dims(4);
check_shapes{inputs, *this}.has(1).only_dims(4);
const shape& input = inputs.at(0);
auto t = input.type();
assert(lengths[0] < (input.lens()[2] + 2 * padding[0]));
assert(lengths[1] < (input.lens()[3] + 2 * padding[1]));
return {t,
{
input.lens()[0],
input.lens()[1],
std::size_t(std::max<std::ptrdiff_t>(
1,
std::ceil((input.lens()[3] + 2 * padding[0] - lengths[0]) /
static_cast<float>(stride[0])) +
std::ptrdiff_t(std::ceil((input.lens()[2] + 2 * padding[0] - lengths[0]) /
static_cast<float>(stride[0]))) +
1)),
std::size_t(std::max<std::ptrdiff_t>(
1,
std::ceil((input.lens()[4] + 2 * padding[1] - lengths[1]) /
static_cast<float>(stride[1])) +
std::ptrdiff_t(std::ceil((input.lens()[3] + 2 * padding[1] - lengths[1]) /
static_cast<float>(stride[1]))) +
1)),
}};
}
argument compute(shape, std::vector<argument>) const { RTG_THROW("not computable"); }
argument compute(context&, shape, std::vector<argument>) const
{
MIGRAPH_THROW("not computable");
}
friend std::ostream& operator<<(std::ostream& os, const pooling& op)
{
......@@ -207,11 +187,14 @@ struct activation
std::string name() const { return "activation"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs}.has(1);
check_shapes{inputs, *this}.has(1);
return inputs.front();
}
argument compute(shape, std::vector<argument>) const { RTG_THROW("not computable"); }
argument compute(context&, shape, std::vector<argument>) const
{
MIGRAPH_THROW("not computable");
}
friend std::ostream& operator<<(std::ostream& os, const activation& op)
{
os << op.name() << ":" << op.mode;
......@@ -219,30 +202,105 @@ struct activation
}
};
struct transpose
{
std::vector<int64_t> dims;
std::string name() const { return "transpose"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(1);
auto input = inputs.at(0);
auto input_lens = input.lens();
auto input_strides = input.strides();
auto t = input.type();
if(dims.size() != input_lens.size())
{
MIGRAPH_THROW("Permutation has wrong number of axes");
}
std::vector<int64_t> axes(dims.size());
std::iota(axes.begin(), axes.end(), 0);
if(!std::is_permutation(axes.begin(), axes.end(), dims.begin()))
{
MIGRAPH_THROW("Invalid permutation");
}
std::vector<size_t> output_lens(input_lens.size());
std::vector<size_t> output_strides(input_lens.size());
for(int i = 0; i < output_lens.size(); i++)
{
output_lens[i] = input_lens[dims[i]];
output_strides[i] = input_strides[dims[i]];
}
return {t, output_lens, output_strides};
}
argument compute(context&, shape output_shape, std::vector<argument> args) const
{
return {output_shape, std::move(args.front().data)};
}
};
struct contiguous
{
std::string name() const { return "contiguous"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(1);
auto lens = inputs.at(0).lens();
auto t = inputs.at(0).type();
if(lens.size() < 2)
{
MIGRAPH_THROW("Number of dimensions should exceed 1");
}
return {t, lens};
}
argument compute(context&, shape, std::vector<argument>) const
{
MIGRAPH_THROW("not computable");
}
};
struct reshape
{
std::vector<int64_t> dims;
std::string name() const { return "reshape"; }
shape compute_shape(std::vector<shape> inputs) const
{
if(inputs.empty())
RTG_THROW("Wrong number of arguments");
check_shapes{inputs, *this}.has(1);
auto&& idims = inputs.front().lens();
std::vector<std::size_t> rdims(dims.begin(), dims.end());
auto n_neg_dims = std::count(dims.begin(), dims.end(), -1);
if(n_neg_dims > 1)
MIGRAPH_THROW("Dimensions for reshape can only have one -1 dim");
for(std::size_t i = 0; i < dims.size(); i++)
{
if(dims[i] == 0)
rdims[i] = idims[i];
}
if(n_neg_dims > 0)
{
size_t missing_dim =
-inputs.front().elements() /
std::accumulate(rdims.begin(), rdims.end(), 1, std::multiplies<int64_t>());
for(std::size_t i = 0; i < rdims.size(); i++)
{
if(dims[i] == -1)
rdims[i] = missing_dim;
}
}
if(dims.back() == -1)
{
rdims.pop_back();
std::copy(idims.begin() + rdims.size(), idims.end(), std::back_inserter(rdims));
}
return {inputs.front().type(), rdims};
shape s{inputs.front().type(), rdims};
if(s.elements() != inputs.front().elements())
MIGRAPH_THROW("Wrong number of elements for reshape");
return s;
}
argument compute(shape, std::vector<argument>) const { RTG_THROW("not computable"); }
argument compute(context&, shape output_shape, std::vector<argument> args) const
{
return {output_shape, std::move(args.front().data)};
}
friend std::ostream& operator<<(std::ostream& os, const reshape& op)
{
......@@ -253,18 +311,200 @@ struct reshape
}
};
struct gemm
{
float alpha = 1.0;
float beta = 0.0;
std::string name() const { return "gemm"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(2).same_type();
const shape& a = inputs.at(0);
const shape& b = inputs.at(1);
auto t = a.type();
if(a.lens()[1] != b.lens()[0])
MIGRAPH_THROW("Inner dimensions do not match: {" + to_string_range(a.lens()) + "} x {" +
to_string_range(b.lens()) + "}");
return {t, {a.lens()[0], b.lens()[1]}};
}
argument compute(context&, shape, std::vector<argument>) const
{
MIGRAPH_THROW("not computable");
}
friend std::ostream& operator<<(std::ostream& os, const gemm& op)
{
os << op.name() << "[";
os << "]";
return os;
}
};
struct unary
{
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs}.has(1);
return inputs.at(0);
}
argument compute(context&, shape, std::vector<argument>) const
{
MIGRAPH_THROW("not computable");
}
};
struct identity : unary
{
std::string name() const { return "identity"; }
};
struct abs : unary
{
std::string name() const { return "abs"; }
};
struct exp : unary
{
std::string name() const { return "exp"; }
};
struct sin : unary
{
std::string name() const { return "sin"; }
};
struct cos : unary
{
std::string name() const { return "cos"; }
};
struct tan : unary
{
std::string name() const { return "tan"; }
};
struct asin : unary
{
std::string name() const { return "asin"; }
};
struct acos : unary
{
std::string name() const { return "acos"; }
};
struct atan : unary
{
std::string name() const { return "atan"; }
};
struct softmax : unary
{
std::string name() const { return "softmax"; }
};
struct tanh : unary
{
std::string name() const { return "tanh"; }
};
struct sigmoid : unary
{
std::string name() const { return "sigmoid"; }
};
struct neg : unary
{
std::string name() const { return "neg"; }
};
struct flatten
{
std::string name() const { return "flatten"; }
};
struct broadcast
{
uint64_t axis = 0;
std::string name() const { return "broadcast"; }
shape compute_shape(std::vector<shape> inputs) const
{
auto t = inputs.at(0).type();
auto result = inputs.at(0);
auto input = inputs.at(1);
std::vector<size_t> bcast_strides(result.lens().size(), 0);
if(std::all_of(
result.lens().cbegin(), result.lens().cend(), [&](auto x) { return x == 1; }))
{
if(axis != 0)
MIGRAPH_THROW("when broadcasting tensor of size 1, axis should be 0");
return {t, result.lens(), std::move(bcast_strides)};
}
else
{
assert(result.lens().size() - axis >= input.lens().size());
if(!std::equal(input.lens().begin(), input.lens().end(), result.lens().begin() + axis))
MIGRAPH_THROW("when broadcasting success sizes must match");
std::copy(input.strides().begin(), input.strides().end(), bcast_strides.begin() + axis);
return {t, result.lens(), std::move(bcast_strides)};
}
}
argument compute(context&, shape output_shape, std::vector<argument> args) const
{
return {output_shape, std::move(args.at(1).data)};
}
};
struct binary
{
uint64_t broadcast = 0;
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs}.has(2).same_type().same_dims();
return inputs.at(0);
}
argument compute(context&, shape, std::vector<argument>) const
{
MIGRAPH_THROW("not computable");
}
};
struct add : binary
{
std::string name() const { return "add"; }
};
struct sub : binary
{
std::string name() const { return "sub"; }
};
struct mul : binary
{
std::string name() const { return "mul"; }
};
struct div : binary
{
std::string name() const { return "div"; }
};
struct outline
{
shape s;
std::string name() const { return "outline"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs}.has(0);
check_shapes{inputs, *this}.has(0);
return s;
}
argument compute(shape, std::vector<argument>) const { return {s, nullptr}; }
argument compute(context&, shape, std::vector<argument>) const { return {s, nullptr}; }
};
} // namespace rtg
} // namespace migraph
#endif
#ifndef RTG_GUARD_RTGLIB_TARGET_HPP
#define RTG_GUARD_RTGLIB_TARGET_HPP
#ifndef MIGRAPH_GUARD_PASS_HPP
#define MIGRAPH_GUARD_PASS_HPP
#include <cassert>
#include <string>
#include <functional>
#include <memory>
#include <type_traits>
#include <utility>
namespace rtg {
namespace migraph {
struct program;
#ifdef DOXYGEN
/// An interface for applying a transformation to the instructions in a
/// `program`
struct pass
{
/// A unique name used to identify the pass
std::string name() const;
/// Run the pass on the program
void apply(program& p) const;
};
#else
/*
* Type-erased interface for:
*
* struct target
* struct pass
* {
* std::string name() const;
* void apply(program & p) const;
......@@ -22,13 +37,13 @@ struct program;
*
*/
struct target
struct pass
{
// Constructors
target() = default;
pass() = default;
template <typename PrivateDetailTypeErasedT>
target(PrivateDetailTypeErasedT value)
pass(PrivateDetailTypeErasedT value)
: private_detail_te_handle_mem_var(
std::make_shared<private_detail_te_handle_type<
typename std::remove_reference<PrivateDetailTypeErasedT>::type>>(
......@@ -38,7 +53,7 @@ struct target
// Assignment
template <typename PrivateDetailTypeErasedT>
target& operator=(PrivateDetailTypeErasedT value)
pass& operator=(PrivateDetailTypeErasedT value)
{
if(private_detail_te_handle_mem_var.unique())
*private_detail_te_handle_mem_var = std::forward<PrivateDetailTypeErasedT>(value);
......@@ -71,6 +86,14 @@ struct target
: nullptr;
}
const std::type_info& type_id() const
{
if(private_detail_te_handle_empty())
return typeid(std::nullptr_t);
else
return private_detail_te_get_handle().type();
}
std::string name() const
{
assert((*this).private_detail_te_handle_mem_var);
......@@ -139,13 +162,20 @@ struct target
}
};
bool private_detail_te_handle_empty() const
{
return private_detail_te_handle_mem_var == nullptr;
}
const private_detail_te_handle_base_type& private_detail_te_get_handle() const
{
assert(private_detail_te_handle_mem_var != nullptr);
return *private_detail_te_handle_mem_var;
}
private_detail_te_handle_base_type& private_detail_te_get_handle()
{
assert(private_detail_te_handle_mem_var != nullptr);
if(!private_detail_te_handle_mem_var.unique())
private_detail_te_handle_mem_var = private_detail_te_handle_mem_var->clone();
return *private_detail_te_handle_mem_var;
......@@ -155,19 +185,19 @@ struct target
};
template <typename ValueType>
inline const ValueType* any_cast(const target* x)
inline const ValueType* any_cast(const pass* x)
{
return x->any_cast<ValueType>();
}
template <typename ValueType>
inline ValueType* any_cast(target* x)
inline ValueType* any_cast(pass* x)
{
return x->any_cast<ValueType>();
}
template <typename ValueType>
inline ValueType& any_cast(target& x)
inline ValueType& any_cast(pass& x)
{
auto* y = x.any_cast<typename std::remove_reference<ValueType>::type>();
if(y == nullptr)
......@@ -176,7 +206,7 @@ inline ValueType& any_cast(target& x)
}
template <typename ValueType>
inline const ValueType& any_cast(const target& x)
inline const ValueType& any_cast(const pass& x)
{
const auto* y = x.any_cast<typename std::remove_reference<ValueType>::type>();
if(y == nullptr)
......@@ -184,6 +214,8 @@ inline const ValueType& any_cast(const target& x)
return *y;
}
} // namespace rtg
#endif
} // namespace migraph
#endif
#ifndef RTG_GUARD_RTGLIB_PROGRAM_HPP
#define RTG_GUARD_RTGLIB_PROGRAM_HPP
#ifndef MIGRAPH_GUARD_MIGRAPHLIB_PROGRAM_HPP
#define MIGRAPH_GUARD_MIGRAPHLIB_PROGRAM_HPP
#include <list>
#include <unordered_map>
#include <rtg/operation.hpp>
#include <rtg/literal.hpp>
#include <rtg/builtin.hpp>
#include <rtg/instruction_ref.hpp>
#include <rtg/target.hpp>
#include <migraph/operation.hpp>
#include <migraph/literal.hpp>
#include <migraph/builtin.hpp>
#include <migraph/instruction_ref.hpp>
#include <migraph/target.hpp>
#include <algorithm>
#include <iostream>
namespace rtg {
namespace migraph {
struct program_impl;
const operation& get_operation(instruction_ref ins);
/**
* @brief Stores the instruction stream
*/
......@@ -25,6 +27,8 @@ struct program
program& operator=(program&&) noexcept;
~program() noexcept;
using parameter_map = std::unordered_map<std::string, argument>;
template <class... Ts>
instruction_ref add_instruction(operation op, Ts... args)
{
......@@ -48,6 +52,13 @@ struct program
instruction_ref
replace_instruction(instruction_ref ins, operation op, std::vector<instruction_ref> args);
instruction_ref replace_instruction(instruction_ref ins, instruction_ref rep);
instruction_ref remove_instruction(instruction_ref ins);
instruction_ref remove_instructions(instruction_ref first, instruction_ref last);
instruction_ref move_instruction(instruction_ref src, instruction_ref dst);
template <class... Ts>
instruction_ref add_literal(Ts&&... xs)
{
......@@ -60,25 +71,31 @@ struct program
instruction_ref add_parameter(std::string name, shape s);
shape get_parameter_shape(std::string name);
shape get_parameter_shape(std::string name) const;
argument eval(std::unordered_map<std::string, argument> params) const;
std::unordered_map<std::string, shape> get_parameter_shapes() const;
friend std::ostream& operator<<(std::ostream& os, const program& p);
argument eval(parameter_map params) const;
bool has_instruction(instruction_ref ins) const;
instruction_ref begin();
instruction_ref end();
instruction_ref begin() const;
instruction_ref end() const;
shape get_shape() const;
instruction_ref validate() const;
void compile(const target& t);
friend std::ostream& operator<<(std::ostream& os, const program& p);
friend bool operator==(const program& x, const program& y);
friend bool operator!=(const program& x, const program& y) { return !(x == y); }
private:
std::unique_ptr<program_impl> impl;
};
} // namespace rtg
} // namespace migraph
#endif
#ifndef MIGRAPH_GUARD_MIGRAPHLIB_RANGES_HPP
#define MIGRAPH_GUARD_MIGRAPHLIB_RANGES_HPP
#include <algorithm>
namespace migraph {
template <class C, class T>
bool contains(C&& c, T&& x)
{
return c.find(x) != c.end();
}
template <class Range, class Iterator>
void copy(Range&& r, Iterator it)
{
std::copy(r.begin(), r.end(), it);
}
template <class Iterator>
struct iterator_range
{
Iterator start;
Iterator last;
Iterator begin() const { return start; }
Iterator end() const { return last; }
};
template <class Iterator>
iterator_range<Iterator> range(Iterator start, Iterator last)
{
return {start, last};
}
} // namespace migraph
#endif
#ifndef RTG_GUARD_RAW_DATA_HPP
#define RTG_GUARD_RAW_DATA_HPP
#ifndef MIGRAPH_GUARD_RAW_DATA_HPP
#define MIGRAPH_GUARD_RAW_DATA_HPP
#include <rtg/tensor_view.hpp>
#include <migraph/tensor_view.hpp>
#include <migraph/requires.hpp>
namespace rtg {
#define RTG_REQUIRES(...) class = typename std::enable_if<(__VA_ARGS__)>::type
namespace migraph {
struct raw_data_base
{
......@@ -90,13 +89,27 @@ struct raw_data : raw_data_base
assert(self->single());
return self->template at<T>();
}
template <class T>
using is_data_ptr =
bool_c<(std::is_void<T>{} or std::is_same<char, std::remove_cv_t<T>>{} or
std::is_same<unsigned char, std::remove_cv_t<T>>{})>;
template <class T>
using get_data_type = std::conditional_t<is_data_ptr<T>{}, float, T>;
template <class T>
bool matches() const
{
return is_data_ptr<T>{} ||
self->get_shape().type() == migraph::shape::get_type<get_data_type<T>>{};
}
template <class T>
operator T*()
{
using type = std::remove_cv_t<T>;
assert((std::is_void<T>{} or std::is_same<char, type>{} or
std::is_same<unsigned char, type>{} or
self->get_shape().type() == rtg::shape::get_type<T>{}));
assert(matches<T>());
return reinterpret_cast<type*>(self->data());
}
};
......@@ -110,15 +123,26 @@ struct raw_data : raw_data_base
{
auto&& s = static_cast<const Derived&>(*this).get_shape();
auto&& buffer = static_cast<const Derived&>(*this).data();
if(s.type() != rtg::shape::get_type<T>{})
RTG_THROW("Incorrect data type for raw data");
if(s.type() != migraph::shape::get_type<T>{})
MIGRAPH_THROW("Incorrect data type for raw data");
return make_view(s, reinterpret_cast<T*>(buffer));
}
/// Cast the data pointer
template <class T>
T* cast() const
{
auto&& s = static_cast<const Derived&>(*this).get_shape();
auto&& buffer = static_cast<const Derived&>(*this).data();
assert(s.type() == migraph::shape::get_type<T>{});
return reinterpret_cast<T*>(buffer);
}
};
template <class T,
class U,
RTG_REQUIRES(std::is_base_of<raw_data_base, T>{} && std::is_base_of<raw_data_base, U>{})>
MIGRAPH_REQUIRES(std::is_base_of<raw_data_base, T>{} &&
std::is_base_of<raw_data_base, U>{})>
bool operator==(const T& x, const U& y)
{
auto&& xshape = x.get_shape();
......@@ -140,7 +164,8 @@ bool operator==(const T& x, const U& y)
template <class T,
class U,
RTG_REQUIRES(std::is_base_of<raw_data_base, T>{} && std::is_base_of<raw_data_base, U>{})>
MIGRAPH_REQUIRES(std::is_base_of<raw_data_base, T>{} &&
std::is_base_of<raw_data_base, U>{})>
bool operator!=(const T& x, const U& y)
{
return !(x == y);
......@@ -171,13 +196,13 @@ auto visit_all(T&& x, Ts&&... xs)
auto&& s = x.get_shape();
std::initializer_list<shape::type_t> types = {xs.get_shape().type()...};
if(!std::all_of(types.begin(), types.end(), [&](shape::type_t t) { return t == s.type(); }))
RTG_THROW("Types must be the same");
MIGRAPH_THROW("Types must be the same");
return [&](auto v) {
// Workaround for https://gcc.gnu.org/bugzilla/show_bug.cgi?id=70100
detail::visit_all_impl(s, v, x, xs...);
};
}
} // namespace rtg
} // namespace migraph
#endif
#ifndef MIGRAPH_GUARD_MIGRAPHLIB_REQUIRES_HPP
#define MIGRAPH_GUARD_MIGRAPHLIB_REQUIRES_HPP
#include <type_traits>
namespace migraph {
template <bool... Bs>
struct and_ : std::is_same<and_<Bs...>, and_<(Bs || true)...>> // NOLINT
{
};
template <bool B>
using bool_c = std::integral_constant<bool, B>;
#ifdef CPPCHECK
#define MIGRAPH_REQUIRES(...) class = void
#else
#define MIGRAPH_REQUIRES(...) \
bool PrivateRequires##__LINE__ = true, \
class = typename std::enable_if<and_<__VA_ARGS__, PrivateRequires##__LINE__>{}>::type
#endif
} // namespace migraph
#endif
#ifndef RTG_GUARD_RTGLIB_SHAPE_HPP
#define RTG_GUARD_RTGLIB_SHAPE_HPP
#ifndef MIGRAPH_GUARD_MIGRAPHLIB_SHAPE_HPP
#define MIGRAPH_GUARD_MIGRAPHLIB_SHAPE_HPP
#include <vector>
#include <cassert>
#include <ostream>
#include <numeric>
#include <rtg/errors.hpp>
#include <migraph/errors.hpp>
namespace rtg {
namespace migraph {
struct shape
{
// Add new types here
// clang-format off
#define RTG_SHAPE_VISIT_TYPES(m) \
#define MIGRAPH_SHAPE_VISIT_TYPES(m) \
m(float_type, float) \
m(double_type, double) \
m(uint8_type, uint8_t) \
......@@ -27,25 +28,27 @@ struct shape
m(uint64_type, uint64_t)
// clang-format on
#define RTG_SHAPE_ENUM_TYPES(x, t) x,
#define MIGRAPH_SHAPE_ENUM_TYPES(x, t) x,
enum type_t
{
any_type,
RTG_SHAPE_VISIT_TYPES(RTG_SHAPE_ENUM_TYPES)
MIGRAPH_SHAPE_VISIT_TYPES(MIGRAPH_SHAPE_ENUM_TYPES)
};
#undef RTG_SHAPE_ENUM_TYPES
#undef MIGRAPH_SHAPE_ENUM_TYPES
template <class T, class = void>
struct get_type : std::integral_constant<type_t, any_type>
{
};
#define RTG_SHAPE_GET_TYPE(x, t) \
struct get_type;
#define MIGRAPH_SHAPE_GET_TYPE(x, t) \
template <class T> \
struct get_type<t, T> : std::integral_constant<type_t, x> \
{ \
};
RTG_SHAPE_VISIT_TYPES(RTG_SHAPE_GET_TYPE)
#undef RTG_SHAPE_GET_TYPE
MIGRAPH_SHAPE_VISIT_TYPES(MIGRAPH_SHAPE_GET_TYPE)
#undef MIGRAPH_SHAPE_GET_TYPE
template <class T>
struct get_type<const T> : get_type<T>
{
};
shape();
shape(type_t t);
......@@ -58,13 +61,33 @@ struct shape
std::size_t elements() const;
std::size_t bytes() const;
/// Map multiple indices to space index
std::size_t index(std::initializer_list<std::size_t> l) const;
/// Map multiple indices to space index
std::size_t index(const std::vector<std::size_t>& l) const;
// Map element index to space index
/// Map multiple indices from a range of iterator to a space index
template <class Iterator>
std::size_t index(Iterator start, Iterator last) const
{
assert(std::distance(start, last) <= this->lens().size());
assert(this->lens().size() == this->strides().size());
return std::inner_product(start, last, this->strides().begin(), std::size_t{0});
}
/// Map element index to space index
std::size_t index(std::size_t i) const;
/// Returns true if the shape is packed with no padding
bool packed() const;
/// Returns true is the shape has been transposed. That is the strides are not in descending
/// order
bool transposed() const;
/// Returns true if the shape is broadcasting a dimension. That is, one of the strides are zero
bool broadcasted() const;
/// Returns true if the shape is in its standard format. That is, the shape is both packed and
/// not transposed.
bool standard() const;
friend bool operator==(const shape& x, const shape& y);
friend bool operator!=(const shape& x, const shape& y);
......@@ -115,26 +138,25 @@ struct shape
{
switch(this->m_type)
{
case any_type: RTG_THROW("Cannot visit the any_type");
#define RTG_SHAPE_VISITOR_CASE(x, t) \
#define MIGRAPH_SHAPE_VISITOR_CASE(x, t) \
case x: v(as<t>()); return;
RTG_SHAPE_VISIT_TYPES(RTG_SHAPE_VISITOR_CASE)
#undef RTG_SHAPE_VISITOR_CASE
MIGRAPH_SHAPE_VISIT_TYPES(MIGRAPH_SHAPE_VISITOR_CASE)
#undef MIGRAPH_SHAPE_VISITOR_CASE
}
RTG_THROW("Unknown type");
MIGRAPH_THROW("Unknown type");
}
private:
type_t m_type;
std::vector<std::size_t> m_lens;
std::vector<std::size_t> m_strides;
bool m_packed;
bool m_standard;
void calculate_strides();
std::size_t element_space() const;
std::string type_string() const;
};
} // namespace rtg
} // namespace migraph
#endif
#ifndef MIGRAPH_GUARD_MIGRAPHLIB_SHAPE_FOR_EACH_HPP
#define MIGRAPH_GUARD_MIGRAPHLIB_SHAPE_FOR_EACH_HPP
#include <migraph/shape.hpp>
#include <algorithm>
namespace migraph {
template <class F>
void shape_for_each(const migraph::shape& s, F f)
{
// Ensure calls to f use const ref to vector
auto call = [&f](const std::vector<std::size_t>& i) { f(i); };
std::vector<std::size_t> indices(s.lens().size());
for(std::size_t i = 0; i < s.elements(); i++)
{
std::transform(s.strides().begin(),
s.strides().end(),
s.lens().begin(),
indices.begin(),
[&](std::size_t stride, std::size_t len) {
assert(len > 0 and stride > 0);
return (i / stride) % len;
});
call(indices);
}
}
} // namespace migraph
#endif
#ifndef MIGRAPH_GUARD_RTGLIB_SIMPLIFY_RESHAPES_HPP
#define MIGRAPH_GUARD_RTGLIB_SIMPLIFY_RESHAPES_HPP
#include <string>
#include <migraph/instruction_ref.hpp>
namespace migraph {
struct program;
struct simplify_reshapes
{
std::string name() const { return "simplify_reshapes"; }
void apply(program& p) const;
};
} // namespace migraph
#endif
#ifndef RTG_GUARD_STREAMUTILS_HPP
#define RTG_GUARD_STREAMUTILS_HPP
#ifndef MIGRAPH_GUARD_STREAMUTILS_HPP
#define MIGRAPH_GUARD_STREAMUTILS_HPP
#include <ostream>
#include <algorithm>
namespace rtg {
namespace migraph {
template <class T>
struct stream_range_container
......@@ -31,6 +31,6 @@ inline stream_range_container<Range> stream_range(const Range& r)
return {r};
}
} // namespace rtg
} // namespace migraph
#endif
#ifndef RTG_GUARD_RTGLIB_STRINGUTILS_HPP
#define RTG_GUARD_RTGLIB_STRINGUTILS_HPP
#ifndef MIGRAPH_GUARD_MIGRAPHLIB_STRINGUTILS_HPP
#define MIGRAPH_GUARD_MIGRAPHLIB_STRINGUTILS_HPP
#include <algorithm>
#include <numeric>
#include <string>
#include <sstream>
namespace rtg {
namespace migraph {
inline std::string
replace_string(std::string subject, const std::string& search, const std::string& replace)
......@@ -66,7 +66,7 @@ inline std::string remove_prefix(std::string s, std::string prefix)
}
template <class Range>
inline std::string to_string(const Range& r)
inline std::string to_string_range(const Range& r)
{
std::stringstream ss;
if(!r.empty())
......@@ -77,6 +77,14 @@ inline std::string to_string(const Range& r)
return ss.str();
}
} // namespace rtg
template <class T>
inline std::string to_string(const T& x)
{
std::stringstream ss;
ss << x;
return ss.str();
}
} // namespace migraph
#endif
#ifndef MIGRAPH_GUARD_MIGRAPHLIB_TARGET_HPP
#define MIGRAPH_GUARD_MIGRAPHLIB_TARGET_HPP
#include <cassert>
#include <string>
#include <functional>
#include <memory>
#include <type_traits>
#include <utility>
#include <vector>
#include <migraph/context.hpp>
#include <migraph/pass.hpp>
namespace migraph {
#ifdef DOXYGEN
/// An interface for a compilation target
struct target
{
/// A unique name used to identify the target
std::string name() const;
/// The transformation passes to be run
/**
* @brief The transformation pass to be run during compilation.
* @details [long description]
*
* @param ctx This is the target-dependent context that is created by `get_context`
* @return The passes to be ran
*/
std::vector<pass> get_passes(context& ctx) const;
/**
* @brief Construct a context for the target.
* @return The context to be used during compilation and execution.
*/
context get_context() const;
};
#else
/*
* Type-erased interface for:
*
* struct target
* {
* std::string name() const;
* std::vector<pass> get_passes(context& ctx) const;
* context get_context() const;
* };
*
*/
struct target
{
// Constructors
target() = default;
template <typename PrivateDetailTypeErasedT>
target(PrivateDetailTypeErasedT value)
: private_detail_te_handle_mem_var(
std::make_shared<private_detail_te_handle_type<
typename std::remove_reference<PrivateDetailTypeErasedT>::type>>(
std::forward<PrivateDetailTypeErasedT>(value)))
{
}
// Assignment
template <typename PrivateDetailTypeErasedT>
target& operator=(PrivateDetailTypeErasedT value)
{
if(private_detail_te_handle_mem_var.unique())
*private_detail_te_handle_mem_var = std::forward<PrivateDetailTypeErasedT>(value);
else if(!private_detail_te_handle_mem_var)
private_detail_te_handle_mem_var = std::make_shared<PrivateDetailTypeErasedT>(
std::forward<PrivateDetailTypeErasedT>(value));
return *this;
}
// Cast
template <typename PrivateDetailTypeErasedT>
PrivateDetailTypeErasedT* any_cast()
{
return private_detail_te_get_handle().type() == typeid(PrivateDetailTypeErasedT)
? std::addressof(static_cast<private_detail_te_handle_type<
typename std::remove_cv<PrivateDetailTypeErasedT>::type>&>(
private_detail_te_get_handle())
.private_detail_te_value)
: nullptr;
}
template <typename PrivateDetailTypeErasedT>
const typename std::remove_cv<PrivateDetailTypeErasedT>::type* any_cast() const
{
return private_detail_te_get_handle().type() == typeid(PrivateDetailTypeErasedT)
? std::addressof(static_cast<const private_detail_te_handle_type<
typename std::remove_cv<PrivateDetailTypeErasedT>::type>&>(
private_detail_te_get_handle())
.private_detail_te_value)
: nullptr;
}
const std::type_info& type_id() const
{
if(private_detail_te_handle_empty())
return typeid(std::nullptr_t);
else
return private_detail_te_get_handle().type();
}
std::string name() const
{
assert((*this).private_detail_te_handle_mem_var);
return (*this).private_detail_te_get_handle().name();
}
std::vector<pass> get_passes(context& ctx) const
{
assert((*this).private_detail_te_handle_mem_var);
return (*this).private_detail_te_get_handle().get_passes(ctx);
}
context get_context() const
{
assert((*this).private_detail_te_handle_mem_var);
return (*this).private_detail_te_get_handle().get_context();
}
private:
struct private_detail_te_handle_base_type
{
virtual ~private_detail_te_handle_base_type() {}
virtual std::shared_ptr<private_detail_te_handle_base_type> clone() const = 0;
virtual const std::type_info& type() const = 0;
virtual std::string name() const = 0;
virtual std::vector<pass> get_passes(context& ctx) const = 0;
virtual context get_context() const = 0;
};
template <typename PrivateDetailTypeErasedT>
struct private_detail_te_handle_type : private_detail_te_handle_base_type
{
template <typename PrivateDetailTypeErasedU = PrivateDetailTypeErasedT>
private_detail_te_handle_type(
PrivateDetailTypeErasedT value,
typename std::enable_if<std::is_reference<PrivateDetailTypeErasedU>::value>::type* =
nullptr)
: private_detail_te_value(value)
{
}
template <typename PrivateDetailTypeErasedU = PrivateDetailTypeErasedT>
private_detail_te_handle_type(
PrivateDetailTypeErasedT value,
typename std::enable_if<!std::is_reference<PrivateDetailTypeErasedU>::value,
int>::type* = nullptr) noexcept
: private_detail_te_value(std::move(value))
{
}
std::shared_ptr<private_detail_te_handle_base_type> clone() const override
{
return std::make_shared<private_detail_te_handle_type>(private_detail_te_value);
}
const std::type_info& type() const override { return typeid(private_detail_te_value); }
std::string name() const override { return private_detail_te_value.name(); }
std::vector<pass> get_passes(context& ctx) const override
{
return private_detail_te_value.get_passes(ctx);
}
context get_context() const override { return private_detail_te_value.get_context(); }
PrivateDetailTypeErasedT private_detail_te_value;
};
template <typename PrivateDetailTypeErasedT>
struct private_detail_te_handle_type<std::reference_wrapper<PrivateDetailTypeErasedT>>
: private_detail_te_handle_type<PrivateDetailTypeErasedT&>
{
private_detail_te_handle_type(std::reference_wrapper<PrivateDetailTypeErasedT> ref)
: private_detail_te_handle_type<PrivateDetailTypeErasedT&>(ref.get())
{
}
};
bool private_detail_te_handle_empty() const
{
return private_detail_te_handle_mem_var == nullptr;
}
const private_detail_te_handle_base_type& private_detail_te_get_handle() const
{
assert(private_detail_te_handle_mem_var != nullptr);
return *private_detail_te_handle_mem_var;
}
private_detail_te_handle_base_type& private_detail_te_get_handle()
{
assert(private_detail_te_handle_mem_var != nullptr);
if(!private_detail_te_handle_mem_var.unique())
private_detail_te_handle_mem_var = private_detail_te_handle_mem_var->clone();
return *private_detail_te_handle_mem_var;
}
std::shared_ptr<private_detail_te_handle_base_type> private_detail_te_handle_mem_var;
};
template <typename ValueType>
inline const ValueType* any_cast(const target* x)
{
return x->any_cast<ValueType>();
}
template <typename ValueType>
inline ValueType* any_cast(target* x)
{
return x->any_cast<ValueType>();
}
template <typename ValueType>
inline ValueType& any_cast(target& x)
{
auto* y = x.any_cast<typename std::remove_reference<ValueType>::type>();
if(y == nullptr)
throw std::bad_cast();
return *y;
}
template <typename ValueType>
inline const ValueType& any_cast(const target& x)
{
const auto* y = x.any_cast<typename std::remove_reference<ValueType>::type>();
if(y == nullptr)
throw std::bad_cast();
return *y;
}
#endif
} // namespace migraph
#endif
#ifndef RTG_GUARD_TENSOR_VIEW_HPP
#define RTG_GUARD_TENSOR_VIEW_HPP
#ifndef MIGRAPH_GUARD_TENSOR_VIEW_HPP
#define MIGRAPH_GUARD_TENSOR_VIEW_HPP
#include <rtg/shape.hpp>
#include <rtg/float_equal.hpp>
#include <migraph/shape.hpp>
#include <migraph/float_equal.hpp>
#include <migraph/requires.hpp>
#include <iostream>
namespace rtg {
namespace migraph {
template <class T>
struct tensor_view
{
using value_type = T;
tensor_view() : m_data(nullptr) {}
tensor_view(shape s, T* d) : m_data(d), m_shape(s) {}
......@@ -24,18 +26,34 @@ struct tensor_view
const T* data() const { return this->m_data; }
template <class... Ts>
template <class... Ts, MIGRAPH_REQUIRES(std::is_integral<Ts>{}...)>
const T& operator()(Ts... xs) const
{
return m_data[m_shape.index({xs...})];
assert(std::vector<std::size_t>{static_cast<std::size_t>(xs)...} < m_shape.lens());
assert(m_shape.index({static_cast<std::size_t>(xs)...}) < m_shape.bytes() / sizeof(T));
return m_data[m_shape.index({static_cast<std::size_t>(xs)...})];
}
template <class... Ts>
template <class... Ts, MIGRAPH_REQUIRES(std::is_integral<Ts>{}...)>
T& operator()(Ts... xs)
{
assert(std::vector<std::size_t>{static_cast<std::size_t>(xs)...} < m_shape.lens());
assert(m_shape.index({static_cast<std::size_t>(xs)...}) < m_shape.bytes() / sizeof(T));
return m_data[m_shape.index({static_cast<std::size_t>(xs)...})];
}
template <class Iterator, MIGRAPH_REQUIRES(not std::is_integral<Iterator>{})>
const T& operator()(Iterator start, Iterator last) const
{
return m_data[m_shape.index(start, last)];
}
template <class Iterator, MIGRAPH_REQUIRES(not std::is_integral<Iterator>{})>
T& operator()(Iterator start, Iterator last)
{
return m_data[m_shape.index(start, last)];
}
T& operator[](std::size_t i)
{
assert(!this->empty() && i < this->size());
......@@ -72,16 +90,16 @@ struct tensor_view
return m_data[m_shape.index(this->size() - 1)];
}
// TODO: Add iterators so it can handle nonpacked tensors
// TODO: Add iterators so it can handle nonstandard tensors
T* begin()
{
assert(this->m_shape.packed());
assert(this->m_shape.standard());
return m_data;
}
T* end()
{
assert(this->m_shape.packed());
assert(this->m_shape.standard());
if(this->empty())
return m_data;
else
......@@ -90,13 +108,13 @@ struct tensor_view
const T* begin() const
{
assert(this->m_shape.packed());
assert(this->m_shape.standard());
return m_data;
}
const T* end() const
{
assert(this->m_shape.packed());
assert(this->m_shape.standard());
if(this->empty())
return m_data;
else
......@@ -148,6 +166,6 @@ tensor_view<T> make_view(shape s, T* data)
return {s, data};
}
} // namespace rtg
} // namespace migraph
#endif
#ifndef MIGRAPH_GUARD_RTGLIB_TYPE_NAME_HPP
#define MIGRAPH_GUARD_RTGLIB_TYPE_NAME_HPP
#include <string>
namespace migraph {
template <class PrivateMigraphTypeNameProbe>
const std::string& get_type_name()
{
static std::string name;
if(name.empty())
{
#ifdef _MSC_VER
name = typeid(PrivateMigraphTypeNameProbe).name();
name = name.substr(7);
#else
const char parameter_name[] = "PrivateMigraphTypeNameProbe =";
name = __PRETTY_FUNCTION__;
auto begin = name.find(parameter_name) + sizeof(parameter_name);
#if(defined(__GNUC__) && !defined(__clang__) && __GNUC__ == 4 && __GNUC_MINOR__ < 7)
auto length = name.find_last_of(",") - begin;
#else
auto length = name.find_first_of("];", begin) - begin;
#endif
name = name.substr(begin, length);
#endif
}
return name;
}
template <class T>
const std::string& get_type_name(const T&)
{
return migraph::get_type_name<T>();
}
} // namespace migraph
#endif
#ifndef RTG_GUARD_FALLTHROUGH_HPP
#define RTG_GUARD_FALLTHROUGH_HPP
namespace rtg {
#ifdef __clang__
#define RTG_FALLTHROUGH [[clang::fallthrough]]
#else
#define RTG_FALLTHROUGH
#endif
} // namespace rtg
#endif
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment