Commit ee03ee68 authored by wooway777's avatar wooway777
Browse files

Fixed a double free bug

Fixed a bug where allocating 0 workspace causes overlapping memory in memory pool which then results in a memory double free error.
parent a73433ab
...@@ -15,9 +15,13 @@ MemoryPool::~MemoryPool() { ...@@ -15,9 +15,13 @@ MemoryPool::~MemoryPool() {
} }
void *MemoryPool::alloc(size_t size) { void *MemoryPool::alloc(size_t size) {
if (size == 0) {
return nullptr;
}
auto it = _free_blocks.lower_bound(size); auto it = _free_blocks.lower_bound(size);
if (it == _free_blocks.end()) { if (it == _free_blocks.end()) {
allocateNewRegion(std::max(size, size_t(0))); allocateNewRegion(size);
it = _free_blocks.lower_bound(size); it = _free_blocks.lower_bound(size);
if (it == _free_blocks.end()) { if (it == _free_blocks.end()) {
throw std::bad_alloc(); throw std::bad_alloc();
...@@ -51,6 +55,10 @@ void *MemoryPool::alloc(size_t size) { ...@@ -51,6 +55,10 @@ void *MemoryPool::alloc(size_t size) {
} }
void MemoryPool::release(void *ptr) { void MemoryPool::release(void *ptr) {
if (ptr == nullptr) {
return;
}
auto it = _ptr_to_block.find(ptr); auto it = _ptr_to_block.find(ptr);
if (it == _ptr_to_block.end()) { if (it == _ptr_to_block.end()) {
throw std::runtime_error("Invalid pointer to free"); throw std::runtime_error("Invalid pointer to free");
......
...@@ -9,6 +9,9 @@ ...@@ -9,6 +9,9 @@
#include <vector> #include <vector>
class Storage { class Storage {
private:
Storage() = default;
public: public:
void *memory; void *memory;
size_t size; size_t size;
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
#include "../tensor.hpp" #include "../tensor.hpp"
std::shared_ptr<Storage> Storage::create(size_t size) { std::shared_ptr<Storage> Storage::create(size_t size) {
auto storage = std::make_shared<Storage>(); auto storage = std::shared_ptr<Storage>(new Storage());
RUN_INFINI(infinirtMalloc(&storage->memory, size)); RUN_INFINI(infinirtMalloc(&storage->memory, size));
storage->size = size; storage->size = size;
RUN_INFINI(infinirtGetDevice(&storage->device_type, &storage->device_id)); RUN_INFINI(infinirtGetDevice(&storage->device_type, &storage->device_id));
...@@ -10,7 +10,7 @@ std::shared_ptr<Storage> Storage::create(size_t size) { ...@@ -10,7 +10,7 @@ std::shared_ptr<Storage> Storage::create(size_t size) {
} }
std::shared_ptr<Storage> Storage::createAsync(size_t size, infinirtStream_t stream) { std::shared_ptr<Storage> Storage::createAsync(size_t size, infinirtStream_t stream) {
auto storage = std::make_shared<Storage>(); auto storage = std::shared_ptr<Storage>(new Storage());
RUN_INFINI(infinirtMallocAsync(&storage->memory, size, stream)); RUN_INFINI(infinirtMallocAsync(&storage->memory, size, stream));
storage->size = size; storage->size = size;
RUN_INFINI(infinirtGetDevice(&storage->device_type, &storage->device_id)); RUN_INFINI(infinirtGetDevice(&storage->device_type, &storage->device_id));
...@@ -18,7 +18,7 @@ std::shared_ptr<Storage> Storage::createAsync(size_t size, infinirtStream_t stre ...@@ -18,7 +18,7 @@ std::shared_ptr<Storage> Storage::createAsync(size_t size, infinirtStream_t stre
} }
std::shared_ptr<Storage> Storage::createFromPool(size_t size, std::shared_ptr<MemoryPool> pool) { std::shared_ptr<Storage> Storage::createFromPool(size_t size, std::shared_ptr<MemoryPool> pool) {
auto storage = std::make_shared<Storage>(); auto storage = std::shared_ptr<Storage>(new Storage());
storage->memory_pool = pool; storage->memory_pool = pool;
if (pool) { if (pool) {
storage->memory = pool->alloc(size); storage->memory = pool->alloc(size);
...@@ -31,7 +31,7 @@ std::shared_ptr<Storage> Storage::createFromPool(size_t size, std::shared_ptr<Me ...@@ -31,7 +31,7 @@ std::shared_ptr<Storage> Storage::createFromPool(size_t size, std::shared_ptr<Me
} }
std::shared_ptr<Storage> Storage::createHost(size_t size) { std::shared_ptr<Storage> Storage::createHost(size_t size) {
auto storage = std::make_shared<Storage>(); auto storage = std::shared_ptr<Storage>(new Storage());
RUN_INFINI(infinirtMallocHost(&storage->memory, size)); RUN_INFINI(infinirtMallocHost(&storage->memory, size));
storage->size = size; storage->size = size;
storage->device_type = INFINI_DEVICE_CPU; storage->device_type = INFINI_DEVICE_CPU;
...@@ -41,11 +41,11 @@ std::shared_ptr<Storage> Storage::createHost(size_t size) { ...@@ -41,11 +41,11 @@ std::shared_ptr<Storage> Storage::createHost(size_t size) {
} }
Storage::~Storage() { Storage::~Storage() {
if (device_type == INFINI_DEVICE_CPU) { if (memory_pool) {
RUN_INFINI(infinirtFreeHost(memory)); memory_pool->release(memory);
} else { } else {
if (memory_pool) { if (device_type == INFINI_DEVICE_CPU) {
memory_pool->release(memory); RUN_INFINI(infinirtFreeHost(memory));
} else { } else {
RUN_INFINI(infinirtFree(memory)); RUN_INFINI(infinirtFree(memory));
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment