Unverified Commit 3c8fb3c0 authored by PanZezhong1725's avatar PanZezhong1725 Committed by GitHub
Browse files

Merge pull request #967 from InfiniTensor/issue/811

issue/811 fix tensor to blob and resume
parents f00c06d0 90cc3bdd
...@@ -12,7 +12,6 @@ class GraphManager; ...@@ -12,7 +12,6 @@ class GraphManager;
class GraphTensor : public Tensor { class GraphTensor : public Tensor {
public: public:
GraphTensor(const Tensor &); GraphTensor(const Tensor &);
void resume() const;
}; };
class GraphOperator { class GraphOperator {
......
...@@ -90,8 +90,6 @@ protected: ...@@ -90,8 +90,6 @@ protected:
Tensor(std::shared_ptr<TensorImpl> impl) : impl_(std::move(impl)) {} Tensor(std::shared_ptr<TensorImpl> impl) : impl_(std::move(impl)) {}
std::shared_ptr<TensorImpl> impl_; std::shared_ptr<TensorImpl> impl_;
friend class TensorImpl; friend class TensorImpl;
void resume_from_blob_() const;
}; };
class TensorImpl : public std::enable_shared_from_this<TensorImpl> { class TensorImpl : public std::enable_shared_from_this<TensorImpl> {
...@@ -135,7 +133,18 @@ public: ...@@ -135,7 +133,18 @@ public:
void debug() const; void debug() const;
Tensor to_blob() const; /**
* Unsafe API that returns a new tensor sharing the same raw memory, which is not tracked by the allocator.
* This API is used for loosely tracking a piece of memory while allowing it to be reused,
* typically in a compute graph scenario.
*/
Tensor to_blob_() const;
/**
* Unsafe API that returns a new tensor with the same memory and lets the allocator re-track that memory.
* Should only be used on the tensor returned by to_blob_().
*/
Tensor resume_from_blob_() const;
/// ///
/// Data Transfer APIs /// Data Transfer APIs
...@@ -301,6 +310,10 @@ protected: ...@@ -301,6 +310,10 @@ protected:
protected: protected:
TensorMetaData meta_; TensorMetaData meta_;
TensorData data_; TensorData data_;
private:
// Flag indicating whether this tensor was created by to_blob_()
bool to_blob_mark_ = false;
}; };
} // namespace infinicore } // namespace infinicore
...@@ -52,9 +52,19 @@ std::byte *PinnableBlockAllocator::allocate(size_t size) { ...@@ -52,9 +52,19 @@ std::byte *PinnableBlockAllocator::allocate(size_t size) {
if (size <= cls.block_size) { if (size <= cls.block_size) {
if (!cls.free_blocks.empty()) { if (!cls.free_blocks.empty()) {
block = cls.free_blocks.back(); block = cls.free_blocks.back();
cls.free_blocks.pop_back(); while (block != nullptr && block->in_use) {
block->in_use = true; cls.free_blocks.pop_back();
return reinterpret_cast<std::byte *>(block->ptr); if (cls.free_blocks.empty()) {
block = nullptr;
break;
}
block = cls.free_blocks.back();
}
if (block != nullptr) {
cls.free_blocks.pop_back();
block->in_use = true;
return reinterpret_cast<std::byte *>(block->ptr);
}
} }
// Allocate a new block for this class // Allocate a new block for this class
block = std::make_shared<Block>(); block = std::make_shared<Block>();
......
...@@ -10,11 +10,7 @@ namespace infinicore::graph { ...@@ -10,11 +10,7 @@ namespace infinicore::graph {
* GraphTensor * GraphTensor
* ========================= */ * ========================= */
GraphTensor::GraphTensor(const Tensor &tensor) : Tensor(tensor->to_blob()) { GraphTensor::GraphTensor(const Tensor &tensor) : Tensor(tensor->to_blob_()) {
}
void GraphTensor::resume() const {
resume_from_blob_();
} }
/* ========================= /* =========================
......
...@@ -38,7 +38,7 @@ void TensorImpl::copy_from(Tensor src) { ...@@ -38,7 +38,7 @@ void TensorImpl::copy_from(Tensor src) {
} else { } else {
auto local_src = Tensor::empty(this->shape(), this->dtype(), this->device()); auto local_src = Tensor::empty(this->shape(), this->dtype(), this->device());
context::setDevice(src->device()); context::setDevice(src->device());
context::memcpyD2H(local_src->data(), src->data(), this->data_.memory->size()); context::memcpyD2H(local_src->data(), src->data(), copy_size);
op::rearrange_(Tensor(const_cast<TensorImpl *>(this)->shared_from_this()), local_src); op::rearrange_(Tensor(const_cast<TensorImpl *>(this)->shared_from_this()), local_src);
} }
} else if (src->device().getType() == Device::Type::CPU) { } else if (src->device().getType() == Device::Type::CPU) {
......
...@@ -65,10 +65,6 @@ Tensor::operator bool() const { ...@@ -65,10 +65,6 @@ Tensor::operator bool() const {
return impl_ != nullptr; return impl_ != nullptr;
} }
void Tensor::resume_from_blob_() const {
context::reinstantiateBlob(impl_->data_.memory);
}
TensorMetaData::TensorMetaData(const Shape &_shape, const Strides &_strides, const DataType &_dtype) TensorMetaData::TensorMetaData(const Shape &_shape, const Strides &_strides, const DataType &_dtype)
: shape(_shape), strides(_strides), dtype(_dtype) { : shape(_shape), strides(_strides), dtype(_dtype) {
INFINICORE_CHECK_ERROR(infiniopCreateTensorDescriptor(&desc, shape.size(), shape.data(), strides.data(), (infiniDtype_t)dtype)); INFINICORE_CHECK_ERROR(infiniopCreateTensorDescriptor(&desc, shape.size(), shape.data(), strides.data(), (infiniDtype_t)dtype));
...@@ -280,10 +276,22 @@ std::shared_ptr<TensorImpl> TensorImpl::strided_from_blob( ...@@ -280,10 +276,22 @@ std::shared_ptr<TensorImpl> TensorImpl::strided_from_blob(
return t; return t;
} }
Tensor TensorImpl::to_blob() const { Tensor TensorImpl::to_blob_() const {
auto t = std::shared_ptr<TensorImpl>(new TensorImpl(shape(), strides(), dtype())); auto t = std::shared_ptr<TensorImpl>(new TensorImpl(shape(), strides(), dtype()));
t->data_.offset = this->data_.offset; t->data_.offset = this->data_.offset;
t->data_.memory = std::make_shared<Memory>(this->data_.memory->data(), this->data_.memory->size(), this->data_.memory->device(), nullptr); t->data_.memory = std::make_shared<Memory>(this->data_.memory->data(), this->data_.memory->size(), this->data_.memory->device(), nullptr);
t->to_blob_mark_ = true;
return Tensor{t};
}
Tensor TensorImpl::resume_from_blob_() const {
auto t = std::shared_ptr<TensorImpl>(new TensorImpl(shape(), strides(), dtype()));
t->data_.offset = this->data_.offset;
if (to_blob_mark_) {
t->data_.memory = context::reinstantiateBlob(this->data_.memory);
} else {
t->data_.memory = this->data_.memory;
}
return Tensor{t}; return Tensor{t};
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment