Commit c1535ae8 authored by PanZezhong's avatar PanZezhong
Browse files

issue/810 feat: allow a graph tensor's memory to be resumed into the allocator's tracking

parent 7c978949
......@@ -12,6 +12,7 @@ class GraphManager;
// A Tensor participating in graph capture; its storage can later be
// re-registered ("resumed") with the device allocator's tracking.
class GraphTensor : public Tensor {
public:
// Wrap an existing tensor; storage is converted via Tensor::to_blob()
// (to_blob() semantics defined elsewhere — not visible here).
GraphTensor(const Tensor &);
// Re-register this tensor's backing memory with the allocator's tracking.
void resume() const;
};
class GraphOperator {
......
......@@ -90,6 +90,8 @@ protected:
// Wrap an existing implementation object; shares ownership of impl.
Tensor(std::shared_ptr<TensorImpl> impl) : impl_(std::move(impl)) {}
std::shared_ptr<TensorImpl> impl_;
friend class TensorImpl;
// Re-register this tensor's backing memory with the current runtime's
// allocator; used by graph::GraphTensor::resume().
void resume_from_blob_() const;
};
class TensorImpl : public std::enable_shared_from_this<TensorImpl> {
......
......@@ -125,6 +125,16 @@ void PinnableBlockAllocator::deallocate(std::byte *ptr) {
}
}
// Force-set the in_use flag of a block previously handed out by this
// allocator.
// @param ptr    block pointer as returned by this allocator
// @param in_use new value for the block's in_use flag
// @return size of the block in bytes
// @throws std::runtime_error if ptr is unknown to this allocator
size_t PinnableBlockAllocator::mark_in_use_(void *ptr, bool in_use) {
    // Acquire the lock BEFORE touching all_blocks_: other members (e.g.
    // trim(), which locks mutex_) mutate allocator state concurrently, so an
    // unlocked find() would be a data race. The original code locked only
    // after the lookup.
    std::lock_guard<std::mutex> lock(mutex_);
    auto it = all_blocks_.find(ptr); // ptr is already void*; no cast needed
    if (it == all_blocks_.end()) {
        throw std::runtime_error("Pointer not allocated by this allocator");
    }
    it->second->in_use = in_use;
    return it->second->size;
}
// ------------------- trim -------------------
void PinnableBlockAllocator::trim() {
std::lock_guard<std::mutex> lock(mutex_);
......
......@@ -32,6 +32,10 @@ public:
// Switch pinned/graph mode for subsequent allocations.
void set_pin_mode(bool pinned) { pinned_mode_ = pinned; }
// internal use only, force set in_use flag for a mem block
// return the size of the block
// throws std::runtime_error if ptr was not allocated by this allocator
size_t mark_in_use_(void *ptr, bool in_use);
// trim cached blocks back to GPU (not pinned)
void trim();
......
#include "context_impl.hpp"
#include "internal.hpp"
#include "../utils.hpp"
......@@ -194,6 +195,12 @@ void addGraphOperator(std::shared_ptr<graph::GraphOperator> op) {
// Stop recording on the current runtime and hand back the captured graph.
std::shared_ptr<graph::Graph> stopGraphRecording() {
    auto runtime = ContextImpl::singleton().getCurrentRuntime();
    return runtime->stopGraphRecording();
}
// Re-attach an existing blob to the current runtime's allocator tracking.
// Switches the active device to the blob's device before delegating.
std::shared_ptr<Memory> reinstantiateBlob(std::shared_ptr<Memory> blob) {
    setDevice(blob->device());
    auto runtime = ContextImpl::singleton().getCurrentRuntime();
    return runtime->reinstantiateBlob(blob);
}
} // namespace context
} // namespace infinicore
#pragma once
#include "infinicore/device.hpp"
#include "infinicore/memory.hpp"
#include "infinicore/graph/graph.hpp"
namespace infinicore::context {
// Re-register an existing device blob with the current runtime allocator's
// tracking and return a Memory handle wrapping the same storage.
std::shared_ptr<Memory> reinstantiateBlob(std::shared_ptr<Memory> blob);
};
......@@ -77,6 +77,15 @@ std::shared_ptr<Memory> Runtime::allocatePinnedHostMemory(size_t size) {
true);
}
std::shared_ptr<Memory> Runtime::reinstantiateBlob(std::shared_ptr<Memory> blob) {
device_memory_allocator_.get()->mark_in_use_(blob->data(), true);
return std::make_shared<Memory>(
blob->data(), blob->size(), device_,
[alloc = device_memory_allocator_.get()](std::byte *p) {
alloc->deallocate(p);
});
}
void Runtime::memcpyH2D(void *dst, const void *src, size_t size, bool async) {
if (async) {
INFINICORE_CHECK_ERROR(infinirtMemcpyAsync(dst, src, size, INFINIRT_MEMCPY_H2D, stream_));
......
......@@ -37,6 +37,7 @@ public:
// Allocate device memory tracked by this runtime's allocator.
std::shared_ptr<Memory> allocateMemory(size_t size);
// Allocate page-locked (pinned) host memory.
std::shared_ptr<Memory> allocatePinnedHostMemory(size_t size);
// Re-register an existing blob with the device allocator's tracking and
// return a Memory handle over the same storage.
std::shared_ptr<Memory> reinstantiateBlob(std::shared_ptr<Memory> blob);
// Copy host -> device; async (on this runtime's stream) by default.
void memcpyH2D(void *dst, const void *src, size_t size, bool async = true);
// Copy device -> host (synchronous).
void memcpyD2H(void *dst, const void *src, size_t size);
......
......@@ -11,6 +11,10 @@ namespace infinicore::graph {
// Wrap a tensor for graph capture: the base Tensor is constructed from
// tensor->to_blob() — presumably detaching storage from normal allocator
// tracking; to_blob() is defined elsewhere, so confirm its semantics there.
GraphTensor::GraphTensor(const Tensor &tensor) : Tensor(tensor->to_blob()) {
}
// Resume this tensor's storage into the allocator's tracking by delegating
// to the protected Tensor hook.
void GraphTensor::resume() const {
resume_from_blob_();
}
/* =========================
* GraphOperator
* ========================= */
......
#include "infinicore/tensor.hpp"
#include "../context/internal.hpp"
#include "../utils.hpp"
#include "infinicore/context/context.hpp"
#include "infinicore/dtype.hpp"
......@@ -64,6 +65,10 @@ Tensor::operator bool() const {
return impl_ != nullptr;
}
// Re-register this tensor's backing memory with the current runtime's
// allocator via context::reinstantiateBlob. The returned Memory handle is
// deliberately discarded — NOTE(review): its destruction likely triggers the
// allocator-deallocate deleter, returning the block to allocator management;
// confirm this is the intended "resume" mechanism.
void Tensor::resume_from_blob_() const {
context::reinstantiateBlob(impl_->data_.memory);
}
TensorMetaData::TensorMetaData(const Shape &_shape, const Strides &_strides, const DataType &_dtype)
: shape(_shape), strides(_strides), dtype(_dtype) {
INFINICORE_CHECK_ERROR(infiniopCreateTensorDescriptor(&desc, shape.size(), shape.data(), strides.data(), (infiniDtype_t)dtype));
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment