/* * The MIT License (MIT) * * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal * in the Software without restriction, including without limitation the rights * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell * copies of the Software, and to permit persons to whom the Software is * furnished to do so, subject to the following conditions: * * The above copyright notice and this permission notice shall be included in * all copies or substantial portions of the Software. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * THE SOFTWARE. */ #ifndef MIGRAPHX_GUARD_TENSOR_VIEW_HPP #define MIGRAPHX_GUARD_TENSOR_VIEW_HPP #include #include #include #include #include #include #include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { template T as_number(T x) { return x; } inline int32_t as_number(int8_t x) { return static_cast(x); } inline uint32_t as_number(uint8_t x) { return static_cast(x); } template struct tensor_view_iterator_read { T* view; auto& operator()(std::size_t n) const { assert(view != nullptr); return (*view)[n]; } }; template struct tensor_view { using value_type = T; using iterator = basic_iota_iterator>, std::size_t>; using const_iterator = basic_iota_iterator>, std::size_t>; tensor_view() : m_data(nullptr) {} tensor_view(shape s, T* d) : m_data(d), m_shape(std::move(s)) {} const shape& get_shape() const { return this->m_shape; } bool empty() const { return m_data == nullptr || m_shape.lens().empty(); } std::size_t size() const { return m_shape.elements(); } T* data() { return this->m_data; } const T* data() const { return this->m_data; } template {}...)> const T& operator()(Ts... xs) const { assert(std::vector{static_cast(xs)...} < m_shape.lens()); assert(m_shape.index({static_cast(xs)...}) < m_shape.bytes() / sizeof(T)); return m_data[m_shape.index({static_cast(xs)...})]; } template {}...)> T& operator()(Ts... xs) { assert(std::vector{static_cast(xs)...} < m_shape.lens()); assert(m_shape.index({static_cast(xs)...}) < m_shape.bytes() / sizeof(T)); return m_data[m_shape.index({static_cast(xs)...})]; } template {})> const T& operator()(Iterator start, Iterator last) const { assert(std::distance(start, last) > 0); assert(std::all_of(start, last, [](auto x) { return x >= 0; })); return m_data[m_shape.index(start, last)]; } template {})> T& operator()(Iterator start, Iterator last) { assert(std::distance(start, last) > 0); assert(std::all_of(start, last, [](auto x) { return x >= 0; })); return m_data[m_shape.index(start, last)]; } T& operator[](std::size_t i) { assert(!this->empty() && i < this->size()); return m_data[m_shape.index(i)]; } const T& operator[](std::size_t i) const { assert(!this->empty() && i < this->size()); return m_data[m_shape.index(i)]; } T& front() { assert(!this->empty()); return m_data[0]; } const T& front() const { assert(!this->empty()); return m_data[0]; } T& back() { assert(!this->empty()); return m_data[m_shape.index(this->size() - 1)]; } const T& back() const { assert(!this->empty()); return m_data[m_shape.index(this->size() - 1)]; } iterator begin() { return {0, {this}}; } iterator end() { return {this->size(), {this}}; } const_iterator begin() const { return {0, {this}}; } const_iterator end() const { return {this->size(), {this}}; } template std::vector to_vector() const { return std::vector(this->begin(), this->end()); } friend std::ostream& operator<<(std::ostream& os, const tensor_view& x) { if(!x.empty()) { os << as_number(x.front()); for(std::size_t i = 1; i < x.m_shape.elements(); i++) { os << ", " << as_number(x.m_data[x.m_shape.index(i)]); } } return os; } private: T* m_data; shape m_shape; }; template bool operator==(const tensor_view& x, const tensor_view& y) { if(x.get_shape() == y.get_shape()) { for(std::size_t i = 0; i < x.get_shape().elements(); i++) { if(!float_equal(x[i], y[i])) return false; } return true; } return false; } template bool operator!=(const tensor_view& x, const tensor_view& y) { return !(x == y); } template tensor_view make_view(const shape& s, T* data) { return {s, data}; } } // namespace MIGRAPHX_INLINE_NS } // namespace migraphx #endif