#include "../tensor.hpp"
#include "../utils.hpp"

#include <cstddef>
#include <functional>
#include <memory>
#include <numeric>
#include <utility>
#include <vector>

namespace {

/// Product of the extents in [first, last); 1 for an empty range.
///
/// The seed is size_t(1) on purpose: seeding std::accumulate with a plain `1`
/// makes it accumulate in `int`, silently truncating large element counts.
template <typename It>
size_t shapeProduct(It first, It last) {
    return std::accumulate(first, last, size_t(1), std::multiplies<size_t>());
}

} // namespace

/// Builds a view of this tensor narrowed along one or more dimensions.
///
/// Each entry restricts dimension `dim` to the half-open range
/// [start, start + len). The returned tensor shares `_storage` with `this`,
/// keeps the original strides, and advances `_offset` by the element offset of
/// the slice origin scaled by dsize(dtype) (presumably bytes per element).
/// NOTE(review): if two entries name the same `dim`, the last one decides the
/// extent but both starts are added to the offset; callers should avoid that.
std::shared_ptr<Tensor> Tensor::sliceImpl(const std::vector<SliceParams> &slices) const {
    const auto &shape = this->shape();
    const auto &strides = this->strides();

    auto new_shape = shape;
    ptrdiff_t offset = 0;
    for (const auto &slice : slices) {
        ASSERT(slice.dim < shape.size());
        ASSERT(slice.len > 0);
        // Same as start + len <= extent, but written so the sum cannot wrap.
        ASSERT(slice.len <= shape[slice.dim]);
        ASSERT(slice.start <= shape[slice.dim] - slice.len);
        new_shape[slice.dim] = slice.len;
        // Signed arithmetic so a negative stride moves the origin backwards.
        offset += ptrdiff_t(slice.start) * strides[slice.dim];
    }

    auto tensor = std::make_shared<Tensor>();
    tensor->_desc = TensorDesc::create(this->dtype(), new_shape, strides);
    tensor->_offset = offset * ptrdiff_t(dsize(this->dtype())) + this->_offset;
    tensor->_storage = this->_storage;
    return tensor;
}

/// Mutable view of [start, start + len) along dimension `dim`.
std::shared_ptr<Tensor> Tensor::slice(size_t dim, size_t start, size_t len) {
    return this->sliceImpl({{dim, start, len}});
}

/// Read-only view of [start, start + len) along dimension `dim`.
std::shared_ptr<Tensor const> Tensor::slice(size_t dim, size_t start, size_t len) const {
    return this->sliceImpl({{dim, start, len}});
}

/// Mutable view narrowed by every entry of `slices`.
std::shared_ptr<Tensor> Tensor::slice(const std::vector<SliceParams> &slices) {
    return this->sliceImpl(slices);
}

/// Read-only view narrowed by every entry of `slices`.
std::shared_ptr<Tensor const> Tensor::slice(const std::vector<SliceParams> &slices) const {
    return this->sliceImpl(slices);
}

/// Merges the consecutive dimensions [dim_start, dim_end] (inclusive) into one.
///
/// The run must be contiguous: each dimension's stride equals the next
/// dimension's extent times its stride (checked with ASSERT_EQ). The merged
/// dimension takes the product of the extents and the innermost stride.
/// Merging a single dimension is a no-op.
void TensorDesc::dimMerge(size_t dim_start, size_t dim_end) {
    ASSERT(dim_start <= dim_end && dim_end < this->_shape.size());
    if (dim_start == dim_end) {
        return;
    }
    for (size_t i = dim_start + 1; i <= dim_end; i++) {
        ASSERT_EQ(this->_strides[i - 1], ptrdiff_t(this->_shape[i]) * this->_strides[i]);
    }

    const auto begin_idx = ptrdiff_t(dim_start);
    const auto end_idx = ptrdiff_t(dim_end) + 1;
    // Read everything needed from the run before erasing it.
    const size_t merged_len = shapeProduct(this->_shape.begin() + begin_idx,
                                           this->_shape.begin() + end_idx);
    const ptrdiff_t merged_stride = this->_strides[dim_end];

    // Keep slot dim_start for the merged dimension; drop the rest of the run.
    this->_shape.erase(this->_shape.begin() + begin_idx + 1, this->_shape.begin() + end_idx);
    this->_strides.erase(this->_strides.begin() + begin_idx + 1, this->_strides.begin() + end_idx);
    this->_shape[dim_start] = merged_len;
    this->_strides[dim_start] = merged_stride;

    this->resetDesc();
    this->computeTensorDesHash();
}

/// Returns a new tensor sharing this one's storage and offset, with
/// dimensions [dim_start, dim_end] merged (see TensorDesc::dimMerge).
std::shared_ptr<Tensor> Tensor::dimMerge(size_t dim_start, size_t dim_end) {
    auto new_desc = TensorDesc::create(_desc->dtype(), _desc->shape(), _desc->strides());
    new_desc->dimMerge(dim_start, dim_end);
    auto tensor = std::make_shared<Tensor>();
    tensor->_storage = _storage;
    tensor->_desc = new_desc;
    tensor->_offset = _offset;
    return tensor;
}

/// Splits dimension `dim` into consecutive dimensions with extents `dims`.
///
/// The product of `dims` must equal the current extent (checked with
/// ASSERT_EQ). Within the split, strides are row-major: the innermost new
/// dimension keeps the original stride, and each outer one steps over the
/// product of all extents inside it.
void TensorDesc::dimSplit(size_t dim, const std::vector<size_t> &dims) {
    ASSERT(dim < this->_shape.size());
    ASSERT_EQ(this->_shape[dim], shapeProduct(dims.begin(), dims.end()));

    // Multiply outward from the innermost dimension. This stays exact for
    // zero-sized extents (no division) and for negative strides (signed math).
    std::vector<ptrdiff_t> split_strides(dims.size());
    ptrdiff_t step = this->_strides[dim];
    for (size_t i = dims.size(); i-- > 0;) {
        split_strides[i] = step;
        step *= ptrdiff_t(dims[i]);
    }

    const auto pos = ptrdiff_t(dim);
    auto shape_it = this->_shape.erase(this->_shape.begin() + pos);
    this->_shape.insert(shape_it, dims.begin(), dims.end());
    auto stride_it = this->_strides.erase(this->_strides.begin() + pos);
    this->_strides.insert(stride_it, split_strides.begin(), split_strides.end());

    this->resetDesc();
    this->computeTensorDesHash();
}

/// Returns a new tensor sharing this one's storage and offset, with
/// dimension `dim` split into `dims` (see TensorDesc::dimSplit).
std::shared_ptr<Tensor> Tensor::dimSplit(size_t dim, const std::vector<size_t> &dims) {
    auto new_desc = TensorDesc::create(_desc->dtype(), _desc->shape(), _desc->strides());
    new_desc->dimSplit(dim, dims);
    auto tensor = std::make_shared<Tensor>();
    tensor->_storage = _storage;
    tensor->_desc = new_desc;
    tensor->_offset = _offset;
    return tensor;
}

/// Reorders dimensions so that new dimension i is old dimension order[i].
///
/// `order` must be a permutation of 0..ndim-1. It is validated in full before
/// being used as an index, so a bad entry never reads past `_shape`/`_strides`.
void TensorDesc::permute(const std::vector<size_t> &order) {
    const size_t ndim = this->_shape.size();
    ASSERT_EQ(ndim, order.size());

    std::vector<char> seen(ndim, 0);
    for (size_t axis : order) {
        ASSERT(axis < ndim);
        ASSERT(!seen[axis]);
        seen[axis] = 1;
    }

    std::vector<size_t> new_shape(ndim);
    std::vector<ptrdiff_t> new_strides(ndim);
    for (size_t i = 0; i < ndim; i++) {
        new_shape[i] = this->_shape[order[i]];
        new_strides[i] = this->_strides[order[i]];
    }
    this->_shape = std::move(new_shape);
    this->_strides = std::move(new_strides);

    this->resetDesc();
    this->computeTensorDesHash();
}

/// Returns a new tensor sharing this one's storage and offset, with its
/// dimensions reordered by `order` (see TensorDesc::permute).
std::shared_ptr<Tensor> Tensor::permute(const std::vector<size_t> &order) {
    auto new_desc = TensorDesc::create(_desc->dtype(), _desc->shape(), _desc->strides());
    new_desc->permute(order);
    auto tensor = std::make_shared<Tensor>();
    tensor->_storage = _storage;
    tensor->_desc = new_desc;
    tensor->_offset = _offset;
    return tensor;
}