Commit bb93c9d5 authored by Paul's avatar Paul
Browse files

Make literals ref counted

parent 8ca97ec3
......@@ -7,6 +7,8 @@
#include <migraph/tensor_view.hpp>
#include <migraph/raw_data.hpp>
#include <memory>
namespace migraph {
/**
......@@ -18,51 +20,54 @@ struct literal : raw_data<literal>
literal() {}
template <class T>
literal(T x) : buffer(sizeof(T), 0), m_shape(shape::get_type<T>{})
literal(T x) : buffer(std::make_unique<char[]>(sizeof(T))), m_shape(shape::get_type<T>{})
{
static_assert(std::is_trivial<T>{}, "Literals can only be trivial types");
*(reinterpret_cast<T*>(buffer.data())) = x;
*(reinterpret_cast<T*>(buffer.get())) = x;
}
template <class T>
literal(shape s, const std::vector<T>& x) : buffer(s.bytes(), 0), m_shape(s)
literal(shape s, const std::vector<T>& x) : buffer(std::make_unique<char[]>(s.bytes())), m_shape(s)
{
static_assert(std::is_trivial<T>{}, "Literals can only be trivial types");
fill(x.begin(), x.end());
}
template <class T>
literal(shape s, const std::initializer_list<T>& x) : buffer(s.bytes(), 0), m_shape(s)
literal(shape s, const std::initializer_list<T>& x) : buffer(std::make_unique<char[]>(s.bytes())), m_shape(s)
{
static_assert(std::is_trivial<T>{}, "Literals can only be trivial types");
fill(x.begin(), x.end());
}
template <class Iterator>
literal(shape s, Iterator start, Iterator end) : buffer(s.bytes(), 0), m_shape(s)
literal(shape s, Iterator start, Iterator end) : buffer(std::make_unique<char[]>(s.bytes())), m_shape(s)
{
fill(start, end);
}
literal(shape s, const char* x) : buffer(x, x + s.bytes()), m_shape(s) {}
literal(shape s, const char* x) : buffer(std::make_unique<char[]>(s.bytes())), m_shape(s)
{
std::copy(x, x + s.bytes(), buffer.get());
}
/// Whether data is available
bool empty() const { return this->buffer.empty(); }
bool empty() const { return this->buffer == nullptr; }
/// Provides a raw pointer to the data
const char* data() const { return this->buffer.data(); }
const char* data() const { return this->buffer.get(); }
const shape& get_shape() const { return this->m_shape; }
/// Convert the data to an argument
argument get_argument() const
{
auto b = buffer;
std::vector<char> b(buffer.get(), buffer.get() + m_shape.bytes());
return {m_shape, [b]() mutable { return b.data(); }};
}
private:
std::vector<char> buffer;
std::shared_ptr<char> buffer;
shape m_shape;
template <class Iterator>
......@@ -70,13 +75,13 @@ struct literal : raw_data<literal>
{
if(m_shape.standard())
{
m_shape.visit_type([&](auto as) { std::copy(start, end, as.from(buffer.data())); });
m_shape.visit_type([&](auto as) { std::copy(start, end, as.from(buffer.get())); });
}
else
{
auto it = start;
m_shape.visit_type([&](auto as) {
auto output = make_view(m_shape, as.from(buffer.data()));
auto output = make_view(m_shape, as.from(buffer.get()));
shape_for_each(output.get_shape(), [&](const auto& idx) {
it++;
output(idx.begin(), idx.end()) = *it;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment