Commit c1535ae8 authored by PanZezhong

issue/810 feat: allow graph tensor to resume to allocator's tracking

parent 7c978949
@@ -12,6 +12,7 @@ class GraphManager;
class GraphTensor : public Tensor {
public:
GraphTensor(const Tensor &);
void resume() const;
};
class GraphOperator {
...
@@ -90,6 +90,8 @@ protected:
Tensor(std::shared_ptr<TensorImpl> impl) : impl_(std::move(impl)) {}
std::shared_ptr<TensorImpl> impl_;
friend class TensorImpl;
void resume_from_blob_() const;
};
class TensorImpl : public std::enable_shared_from_this<TensorImpl> {
...
@@ -125,6 +125,16 @@ void PinnableBlockAllocator::deallocate(std::byte *ptr) {
}
}
size_t PinnableBlockAllocator::mark_in_use_(void *ptr, bool in_use) {
// take the lock before touching all_blocks_ so lookup and update stay consistent with deallocate()/trim()
std::lock_guard<std::mutex> lock(mutex_);
auto it = all_blocks_.find(ptr);
if (it == all_blocks_.end()) {
throw std::runtime_error("Pointer not allocated by this allocator");
}
it->second->in_use = in_use;
return it->second->size;
}
// ------------------- trim -------------------
void PinnableBlockAllocator::trim() {
std::lock_guard<std::mutex> lock(mutex_);
...
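For reference, mark_in_use_() only flips the allocator's tracking flag and reports the block size; the block actually returns to the free pool through the normal deallocate() path. A small sketch of that pairing (the caller function is hypothetical; it mirrors how Runtime::reinstantiateBlob() uses the call, see the Runtime hunk below):

// Hypothetical caller: re-adopt a blob pointer into the allocator's tracking, then release it normally.
void readopt_block(PinnableBlockAllocator &alloc, std::byte *blob) {
    size_t size = alloc.mark_in_use_(blob, /*in_use=*/true); // block is counted as live again
    (void)size;
    alloc.deallocate(blob); // block re-enters the allocator's cache / free pool
}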
@@ -32,6 +32,10 @@ public:
// Switch pinned/graph mode
void set_pin_mode(bool pinned) { pinned_mode_ = pinned; }
// Internal use only: force-set the in_use flag for a memory block.
// Returns the size of the block.
size_t mark_in_use_(void *ptr, bool in_use);
// trim cached blocks back to GPU (not pinned)
void trim();
...
#include "context_impl.hpp" #include "context_impl.hpp"
#include "internal.hpp"
#include "../utils.hpp" #include "../utils.hpp"
@@ -194,6 +195,12 @@ void addGraphOperator(std::shared_ptr<graph::GraphOperator> op) {
std::shared_ptr<graph::Graph> stopGraphRecording() {
return ContextImpl::singleton().getCurrentRuntime()->stopGraphRecording();
}
std::shared_ptr<Memory> reinstantiateBlob(std::shared_ptr<Memory> blob) {
setDevice(blob->device());
return ContextImpl::singleton().getCurrentRuntime()->reinstantiateBlob(blob);
}
} // namespace context
} // namespace infinicore
#pragma once
#include "infinicore/device.hpp"
#include "infinicore/memory.hpp"
#include "infinicore/graph/graph.hpp"
namespace infinicore::context {
std::shared_ptr<Memory> reinstantiateBlob(std::shared_ptr<Memory> blob);
} // namespace infinicore::context
@@ -77,6 +77,15 @@ std::shared_ptr<Memory> Runtime::allocatePinnedHostMemory(size_t size) {
true);
}
std::shared_ptr<Memory> Runtime::reinstantiateBlob(std::shared_ptr<Memory> blob) {
device_memory_allocator_.get()->mark_in_use_(blob->data(), true);
return std::make_shared<Memory>(
blob->data(), blob->size(), device_,
[alloc = device_memory_allocator_.get()](std::byte *p) {
alloc->deallocate(p);
});
}
void Runtime::memcpyH2D(void *dst, const void *src, size_t size, bool async) {
if (async) {
INFINICORE_CHECK_ERROR(infinirtMemcpyAsync(dst, src, size, INFINIRT_MEMCPY_H2D, stream_));
...
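Design note: reinstantiateBlob() re-wraps the blob's raw pointer in a fresh Memory whose deleter is the device allocator's deallocate(), so once the last reference to that Memory is dropped the block flows back into the allocator's cache rather than staying outside its tracking. A lifetime sketch (names and namespaces are assumptions for illustration):

// Sketch only: shows the ownership hand-off set up by Runtime::reinstantiateBlob().
void readopt_blob(infinicore::Runtime &runtime, std::shared_ptr<infinicore::Memory> blob) {
    auto tracked = runtime.reinstantiateBlob(blob); // allocator marks the block in_use again
    // tracked->data() aliases blob->data(); use it like any other device Memory here.
}   // `tracked` destroyed -> its deleter calls deallocate() on the block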
@@ -37,6 +37,7 @@ public:
std::shared_ptr<Memory> allocateMemory(size_t size);
std::shared_ptr<Memory> allocatePinnedHostMemory(size_t size);
std::shared_ptr<Memory> reinstantiateBlob(std::shared_ptr<Memory> blob);
void memcpyH2D(void *dst, const void *src, size_t size, bool async = true);
void memcpyD2H(void *dst, const void *src, size_t size);
...
@@ -11,6 +11,10 @@ namespace infinicore::graph {
GraphTensor::GraphTensor(const Tensor &tensor) : Tensor(tensor->to_blob()) {
}
void GraphTensor::resume() const {
resume_from_blob_();
}
/* =========================
* GraphOperator
* ========================= */
...
#include "infinicore/tensor.hpp" #include "infinicore/tensor.hpp"
#include "../context/internal.hpp"
#include "../utils.hpp" #include "../utils.hpp"
#include "infinicore/context/context.hpp" #include "infinicore/context/context.hpp"
#include "infinicore/dtype.hpp" #include "infinicore/dtype.hpp"
...@@ -64,6 +65,10 @@ Tensor::operator bool() const { ...@@ -64,6 +65,10 @@ Tensor::operator bool() const {
return impl_ != nullptr; return impl_ != nullptr;
} }
void Tensor::resume_from_blob_() const {
context::reinstantiateBlob(impl_->data_.memory);
}
TensorMetaData::TensorMetaData(const Shape &_shape, const Strides &_strides, const DataType &_dtype)
: shape(_shape), strides(_strides), dtype(_dtype) {
INFINICORE_CHECK_ERROR(infiniopCreateTensorDescriptor(&desc, shape.size(), shape.data(), strides.data(), (infiniDtype_t)dtype));
...