#ifndef MIGRAPH_GUARD_RAW_DATA_HPP #define MIGRAPH_GUARD_RAW_DATA_HPP #include #include namespace migraph { struct raw_data_base { }; /** * @brief Provides a base class for common operations with raw buffer * * For classes that handle a raw buffer of data, this will provide common operations such as equals, * printing, and visitors. To use this class the derived class needs to provide a `data()` method to * retrieve a raw pointer to the data, and `get_shape` method that provides the shape of the data. * */ template struct raw_data : raw_data_base { template friend Stream& operator<<(Stream& os, const Derived& d) { d.visit([&](auto x) { os << x; }); return os; } /** * @brief Visits a single data element at a certain index. * * @param v A function which will be called with the type of data * @param n The index to read from */ template void visit_at(Visitor v, std::size_t n = 0) const { auto&& s = static_cast(*this).get_shape(); auto&& buffer = static_cast(*this).data(); s.visit_type([&](auto as) { v(*(as.from(buffer) + s.index(n))); }); } /** * @brief Visits the data * * This will call the visitor function with a `tensor_view` based on the shape of the data. * * @param v A function to be called with `tensor_view` */ template void visit(Visitor v) const { auto&& s = static_cast(*this).get_shape(); auto&& buffer = static_cast(*this).data(); s.visit_type([&](auto as) { v(make_view(s, as.from(buffer))); }); } /// Returns true if the raw data is only one element bool single() const { auto&& s = static_cast(*this).get_shape(); return s.elements() == 1; } /** * @brief Retrieves a single element of data * * @param n The index to retrieve the data from * @tparam T The type of data to be retrieved * @return The element as `T` */ template T at(std::size_t n = 0) const { T result; this->visit_at([&](auto x) { result = x; }, n); return result; } struct auto_cast { const Derived* self; template operator T() { assert(self->single()); return self->template at(); } template using is_data_ptr = bool_c<(std::is_void{} or std::is_same>{} or std::is_same>{})>; template using get_data_type = std::conditional_t{}, float, T>; template bool matches() const { return is_data_ptr{} || self->get_shape().type() == migraph::shape::get_type>{}; } template operator T*() { using type = std::remove_cv_t; assert(matches()); return reinterpret_cast(self->data()); } }; /// Implicit conversion of raw data pointer auto_cast implicit() const { return {static_cast(this)}; } /// Get a tensor_view to the data template tensor_view get() const { auto&& s = static_cast(*this).get_shape(); auto&& buffer = static_cast(*this).data(); if(s.type() != migraph::shape::get_type{}) MIGRAPH_THROW("Incorrect data type for raw data"); return make_view(s, reinterpret_cast(buffer)); } /// Cast the data pointer template T* cast() const { auto&& s = static_cast(*this).get_shape(); auto&& buffer = static_cast(*this).data(); assert(s.type() == migraph::shape::get_type{}); return reinterpret_cast(buffer); } }; template {} && std::is_base_of{})> bool operator==(const T& x, const U& y) { auto&& xshape = x.get_shape(); auto&& yshape = y.get_shape(); bool result = x.empty() && y.empty(); if(not result && xshape == yshape) { auto&& xbuffer = x.data(); auto&& ybuffer = y.data(); // TODO: Dont use tensor view for single values xshape.visit_type([&](auto as) { auto xview = make_view(xshape, as.from(xbuffer)); auto yview = make_view(yshape, as.from(ybuffer)); result = xview == yview; }); } return result; } template {} && std::is_base_of{})> bool operator!=(const T& x, const U& y) { return !(x == y); } namespace detail { template void visit_all_impl(const shape& s, V&& v, Ts&&... xs) { s.visit_type([&](auto as) { v(make_view(xs.get_shape(), as.from(xs.data()))...); }); } } // namespace detail /** * @brief Visits every object together * @details This will visit every object, but assumes each object is the same type. This can reduce * the deeply nested visit calls. This will return a function that will take the visitor callback. * So it will be called with `visit_all(xs...)([](auto... ys) {})` where `xs...` and `ys...` are the * same number of parameters. * * @param x A raw data object * @param xs Many raw data objects * @return A function to be called with the visitor */ template auto visit_all(T&& x, Ts&&... xs) { auto&& s = x.get_shape(); std::initializer_list types = {xs.get_shape().type()...}; if(!std::all_of(types.begin(), types.end(), [&](shape::type_t t) { return t == s.type(); })) MIGRAPH_THROW("Types must be the same"); return [&](auto v) { // Workaround for https://gcc.gnu.org/bugzilla/show_bug.cgi?id=70100 detail::visit_all_impl(s, v, x, xs...); }; } } // namespace migraph #endif