#ifndef MIGRAPH_GUARD_MIGRAPHLIB_LITERAL_HPP #define MIGRAPH_GUARD_MIGRAPHLIB_LITERAL_HPP #include #include #include #include #include #include #include namespace migraph { /** * @brief Represents a raw literal * @details This stores the literal has a raw buffer that is owned by this class */ struct literal : raw_data { literal() {} template literal(T x) : buffer(make_shared_array(sizeof(T))), m_shape(shape::get_type{}) { static_assert(std::is_trivial{}, "Literals can only be trivial types"); *(reinterpret_cast(buffer.get())) = x; } template literal(const shape& s, const std::vector& x) : buffer(make_shared_array(s.bytes())), m_shape(s) { static_assert(std::is_trivial{}, "Literals can only be trivial types"); fill(x.begin(), x.end()); } template literal(const shape& s, const std::initializer_list& x) : buffer(make_shared_array(s.bytes())), m_shape(s) { static_assert(std::is_trivial{}, "Literals can only be trivial types"); fill(x.begin(), x.end()); } template literal(const shape& s, Iterator start, Iterator end) : buffer(make_shared_array(s.bytes())), m_shape(s) { fill(start, end); } literal(const shape& s, const char* x) : buffer(make_shared_array(s.bytes())), m_shape(s) { std::copy(x, x + s.bytes(), buffer.get()); } /// Whether data is available bool empty() const { return this->buffer == nullptr; } /// Provides a raw pointer to the data const char* data() const { return this->buffer.get(); } const shape& get_shape() const { return this->m_shape; } /// Convert the data to an argument argument get_argument() const { std::vector b(buffer.get(), buffer.get() + m_shape.bytes()); return {m_shape, [b]() mutable { return b.data(); }}; } private: std::shared_ptr buffer; shape m_shape; template void fill(Iterator start, Iterator end) { if(m_shape.standard()) { m_shape.visit_type([&](auto as) { std::copy(start, end, as.from(buffer.get())); }); } else { auto it = start; m_shape.visit_type([&](auto as) { auto output = make_view(m_shape, as.from(buffer.get())); shape_for_each(output.get_shape(), [&](const auto& idx) { it++; output(idx.begin(), idx.end()) = *it; }); }); } } }; template literal transform(literal l, F f) { literal result; l.visit([&](auto x) { using type = std::remove_cv_t; std::vector output(x.size(), 0.0); std::transform(x.begin(), x.end(), output.begin(), f); result = literal{l.get_shape(), output}; }); return result; } } // namespace migraph #endif