Commit bfae3bbb authored by wooway777's avatar wooway777
Browse files

issue/21 - made TensorDesc const

parent 21ef8820
...@@ -86,7 +86,7 @@ public: ...@@ -86,7 +86,7 @@ public:
class Tensor : public std::enable_shared_from_this<Tensor> { class Tensor : public std::enable_shared_from_this<Tensor> {
private: private:
std::shared_ptr<Storage> _storage; std::shared_ptr<Storage> _storage;
std::shared_ptr<TensorDesc> _desc; std::shared_ptr<const TensorDesc> _desc;
ptrdiff_t _offset; ptrdiff_t _offset;
......
...@@ -67,8 +67,14 @@ void TensorDesc::dimMerge(size_t dim_start, size_t dim_end) { ...@@ -67,8 +67,14 @@ void TensorDesc::dimMerge(size_t dim_start, size_t dim_end) {
} }
std::shared_ptr<Tensor> Tensor::dimMerge(size_t dim_start, size_t dim_end) { std::shared_ptr<Tensor> Tensor::dimMerge(size_t dim_start, size_t dim_end) {
this->_desc->dimMerge(dim_start, dim_end); auto new_desc = TensorDesc::create(_desc->dtype(), _desc->shape(), _desc->strides());
return shared_from_this(); new_desc->dimMerge(dim_start, dim_end);
auto tensor = std::make_shared<Tensor>();
tensor->_storage = _storage;
tensor->_desc = new_desc;
tensor->_offset = _offset;
return tensor;
} }
void TensorDesc::dimSplit(size_t dim, const std::vector<size_t> &dims) { void TensorDesc::dimSplit(size_t dim, const std::vector<size_t> &dims) {
...@@ -94,8 +100,14 @@ void TensorDesc::dimSplit(size_t dim, const std::vector<size_t> &dims) { ...@@ -94,8 +100,14 @@ void TensorDesc::dimSplit(size_t dim, const std::vector<size_t> &dims) {
} }
std::shared_ptr<Tensor> Tensor::dimSplit(size_t dim, const std::vector<size_t> &dims) { std::shared_ptr<Tensor> Tensor::dimSplit(size_t dim, const std::vector<size_t> &dims) {
this->_desc->dimSplit(dim, dims); auto new_desc = TensorDesc::create(_desc->dtype(), _desc->shape(), _desc->strides());
return shared_from_this(); new_desc->dimSplit(dim, dims);
auto tensor = std::make_shared<Tensor>();
tensor->_storage = _storage;
tensor->_desc = new_desc;
tensor->_offset = _offset;
return tensor;
} }
void TensorDesc::permute(const std::vector<size_t> &order) { void TensorDesc::permute(const std::vector<size_t> &order) {
...@@ -114,6 +126,12 @@ void TensorDesc::permute(const std::vector<size_t> &order) { ...@@ -114,6 +126,12 @@ void TensorDesc::permute(const std::vector<size_t> &order) {
} }
std::shared_ptr<Tensor> Tensor::permute(const std::vector<size_t> &order) { std::shared_ptr<Tensor> Tensor::permute(const std::vector<size_t> &order) {
this->_desc->permute(order); auto new_desc = TensorDesc::create(_desc->dtype(), _desc->shape(), _desc->strides());
return shared_from_this(); new_desc->permute(order);
auto tensor = std::make_shared<Tensor>();
tensor->_storage = _storage;
tensor->_desc = new_desc;
tensor->_offset = _offset;
return tensor;
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment