Commit 49ab4a51 authored by Ville Pietilä

Fixed problems with memory fragmentation.

parent fedb9211
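In short, StaticMemPool no longer backs all allocations with a single fixed pinned buffer: it keeps a stack of pinned blocks, bump-allocates from the newest block, and reserves a fresh block when the current one cannot serve a request and no recycled chunk fits. The sketch below illustrates that growth strategy in isolation; it is a simplified stand-in under stated assumptions (plain malloc/free instead of hipHostMalloc/hipHostFree, no recycling, no locking), not the commit's code.

// Simplified sketch of the block-growth idea introduced in this commit
// (illustrative only): bump-allocate from the newest block and push a fresh
// block when a request no longer fits. malloc/free stand in for the
// hipHostMalloc/hipHostFree calls used in the real pool.
#include <cstddef>
#include <cstdlib>
#include <stack>

class GrowingArenaSketch
{
public:
    explicit GrowingArenaSketch(std::size_t blockSizeInBytes)
        : blockSizeInBytes_(blockSizeInBytes)
    {
        pushNewBlock();
    }
    ~GrowingArenaSketch()
    {
        // Free every block that was ever reserved.
        while (!blocks_.empty())
        {
            std::free(blocks_.top());
            blocks_.pop();
        }
    }
    // Requests are assumed to be no larger than one block.
    void* allocate(std::size_t sizeInBytes)
    {
        if (offsetInBytes_ + sizeInBytes > blockSizeInBytes_)
        {
            pushNewBlock(); // current block is exhausted or too fragmented
        }
        void* p = blocks_.top() + offsetInBytes_;
        offsetInBytes_ += sizeInBytes;
        return p;
    }
private:
    void pushNewBlock()
    {
        blocks_.push(static_cast<std::byte*>(std::malloc(blockSizeInBytes_)));
        offsetInBytes_ = 0;
    }
    std::stack<std::byte*> blocks_;
    std::size_t blockSizeInBytes_;
    std::size_t offsetInBytes_{0};
};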
@@ -8,6 +8,7 @@
 #include "ck/utility/env.hpp"
 #include <map>
 #include <queue>
+#include <stack>
 #include <mutex>
 #include <cstddef>
 #include <limits>
@@ -31,9 +32,10 @@ namespace memory {
     class DynamicMemPool : public IMemPool
     {
     public:
-        DynamicMemPool() :
+        DynamicMemPool(size_t maxPoolSizeInBytes = defaultMaxMemoryPoolSizeInBytes_) :
             enableLogging_(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))),
-            pid_(getpid())
+            pid_(getpid()),
+            maxPoolSizeInBytes_(maxPoolSizeInBytes)
         {
             if (enableLogging_)
                 std::cout << "[ DynamicMemPool ] Created memory pool for process " << pid_ << std::endl;
@@ -97,7 +99,7 @@ namespace memory {
                 q.push(p);
                 memPoolSizeInBytes_ += sizeInBytes;
                 // If the memory pool size exceeds the maximum size, free the memory.
-                if (memPoolSizeInBytes_ > maxMemoryPoolSizeInBytes_)
+                if (memPoolSizeInBytes_ > maxPoolSizeInBytes_)
                 {
                     if (enableLogging_)
                     {
@@ -123,7 +125,7 @@ namespace memory {
             }
         }
     private:
-        constexpr static size_t maxMemoryPoolSizeInBytes_ = 100 * 1024 * 1024; // 100MB
+        constexpr static size_t defaultMaxMemoryPoolSizeInBytes_ = 100 * 1024 * 1024; // 100MB
         static void clearMemoryPoolQueue(std::queue<void*>& q)
         {
@@ -140,27 +142,30 @@ namespace memory {
         size_t memPoolSizeInBytes_{0};
         bool enableLogging_{false};
         int pid_{-1};
+        size_t maxPoolSizeInBytes_;
     };
     class StaticMemPool : public IMemPool
     {
     public:
-        StaticMemPool() :
+        StaticMemPool(size_t poolSizeInBytes = defaultMaxMemoryPoolSizeInBytes_) :
             enableLogging_(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))),
             pid_(getpid()),
             offsetInBytes_(0),
-            preferRecycledMem_(ck::EnvIsEnabled(CK_ENV(CK_PREFER_RECYCLED_PINNED_MEM)))
+            preferRecycledMem_(ck::EnvIsEnabled(CK_ENV(CK_PREFER_RECYCLED_PINNED_MEM))),
+            memoryPoolSizeInBytes_(poolSizeInBytes)
        {
-            hip_check_error(hipHostMalloc(&pinnedMemoryBaseAddress_, memoryPoolSizeInBytes_));
-            if (enableLogging_)
-            {
-                std::cout << "[ StaticMemPool ] Created memory pool with " << memoryPoolSizeInBytes_ << " bytes for process " << pid_ << std::endl;
-            }
+            allocateNewPinnedMemoryBlock();
         }
         ~StaticMemPool() override
         {
-            hip_check_error(hipHostFree(pinnedMemoryBaseAddress_));
+            // Loop through all the pinned memory blocks and free them.
+            while (!pinnedMemoryBaseAddress_.empty())
+            {
+                hip_check_error(hipHostFree(pinnedMemoryBaseAddress_.top()));
+                pinnedMemoryBaseAddress_.pop();
+            }
             if (enableLogging_)
             {
                 std::cout << "[ StaticMemPool ] Deleted pool for process " << pid_ << std::endl;
@@ -171,7 +176,7 @@ namespace memory {
         {
             std::lock_guard<std::mutex> lock(mutex_);
-            if (!preferRecycledMem_ && offsetInBytes_ + sizeInBytes < memoryPoolSizeInBytes_)
+            if (!preferRecycledMem_ && offsetInBytes_ + sizeInBytes - 1 < memoryPoolSizeInBytes_)
             {
                 return allocateNewMemory(sizeInBytes);
             }
@@ -182,16 +187,15 @@ namespace memory {
                 return ptr;
             }
-            if (offsetInBytes_ + sizeInBytes < memoryPoolSizeInBytes_)
+            if (offsetInBytes_ + sizeInBytes - 1 < memoryPoolSizeInBytes_)
             {
                 return allocateNewMemory(sizeInBytes);
             }
-            if (enableLogging_)
-            {
-                std::cerr << "[ StaticMemPool ] Memory pool exausted." << std::endl;
-            }
-            throw std::runtime_error("[ StaticMemPool ] Memory pool exausted.");
+            // Memory became too fragmented, reserve a new block.
+            // This should not happen very often.
+            allocateNewPinnedMemoryBlock();
+            return allocateNewMemory(sizeInBytes);
         }
         void deallocate(void* p, std::size_t sizeInBytes) override
@@ -200,42 +204,57 @@ namespace memory {
             if (memory_pool_.find(sizeInBytes) != memory_pool_.end())
             {
-                if (enableLogging_)
-                {
-                    std::cout << "[ StaticMemPool ] Deallocate: Adding memory to pool for size " << sizeInBytes << std::endl;
-                }
                 auto& q = memory_pool_[sizeInBytes];
                 q.push(p);
-            }
-            else {
                 if (enableLogging_)
                 {
-                    std::cout << "[ StaticMemPool ] Deallocate: Creating new pool queue for size " << sizeInBytes << std::endl;
+                    std::cout << "[ StaticMemPool ] Deallocate: Added memory to back to pool for size " << sizeInBytes <<
+                        ", pool has now " << q.size() << " elements." << std::endl;
                 }
+            }
+            else {
                 std::queue<void*> q;
                 q.push(p);
                 memory_pool_.insert(std::make_pair(sizeInBytes, std::move(q)));
+                if (enableLogging_)
+                {
+                    std::cout << "[ StaticMemPool ] Deallocate: Created new pool for size " << sizeInBytes << std::endl;
+                }
             }
         }
     private:
-        constexpr static size_t memoryPoolSizeInBytes_ = 10 * 1024 * 1024; // 10MB
+        constexpr static size_t defaultMaxMemoryPoolSizeInBytes_ = 10 * 1024 * 1024; // 10MB
         std::mutex mutex_; // Mutex to protect access to the memory pool.
         std::map<size_t, std::queue<void*>> memory_pool_{};
-        std::byte* pinnedMemoryBaseAddress_;
+        std::stack<std::byte*> pinnedMemoryBaseAddress_;
         bool enableLogging_;
         int pid_;
         int offsetInBytes_;
         bool preferRecycledMem_;
+        size_t memoryPoolSizeInBytes_;
+        void allocateNewPinnedMemoryBlock()
+        {
+            std::byte* pinnedMemoryBaseAddress;
+            hip_check_error(hipHostMalloc(&pinnedMemoryBaseAddress, memoryPoolSizeInBytes_));
+            pinnedMemoryBaseAddress_.push(pinnedMemoryBaseAddress);
+            offsetInBytes_ = 0;
+            if (enableLogging_)
+            {
+                std::cout << "[ StaticMemPool ] Allocation: created new pinned memory block of " << memoryPoolSizeInBytes_ << " bytes." << std::endl;
+            }
+        }
         void* allocateNewMemory(size_t sizeInBytes)
         {
             // Return new memory from the preallocated block
-            void* p = pinnedMemoryBaseAddress_ + offsetInBytes_;
+            void* p = pinnedMemoryBaseAddress_.top() + offsetInBytes_;
             offsetInBytes_ += sizeInBytes;
             if (enableLogging_)
             {
                 const auto pct = 100.0f * static_cast<float>(offsetInBytes_) / memoryPoolSizeInBytes_;
-                std::cout << "[ StaticMemPool ] Allocation: return new memory, pinned host memory usage: " << pct << "%." << std::endl;
+                std::cout << "[ StaticMemPool ] Allocation: return new memory of " << sizeInBytes <<
+                    " bytes, pinned host memory usage: " << pct << "%." << std::endl;
             }
             return p;
         }
@@ -255,27 +274,31 @@ namespace memory {
             }
             // Try to find memory from the queue that is nearest in size.
-            std::pair<size_t, std::queue<void*>> nearest_queue = {std::numeric_limits<size_t>::max(), std::queue<void*>()};
+            size_t nearest_queue_size = std::numeric_limits<size_t>::max();
             for (auto& [size, q] : memory_pool_)
             {
-                if (size > sizeInBytes && !q.empty() && size < nearest_queue.first)
+                if (size > sizeInBytes && !q.empty() && size < nearest_queue_size)
                 {
-                    nearest_queue = {size, q};
+                    nearest_queue_size = size;
                 }
             }
-            if (nearest_queue.first != std::numeric_limits<size_t>::max())
+            if (nearest_queue_size != std::numeric_limits<size_t>::max())
             {
+                auto& nearest_queue = memory_pool_[nearest_queue_size];
+                void* p = nearest_queue.front();
+                nearest_queue.pop();
                 if (enableLogging_)
                 {
-                    std::cout << "[ StaticMemPool ] Allocation: reusing memory from pool for size " << nearest_queue.first <<
-                        " to allocate " << sizeInBytes << "bytes" <<std::endl;
+                    std::cout << "[ StaticMemPool ] Allocation: reusing memory from pool for size " << nearest_queue_size <<
+                        " to allocate " << sizeInBytes << " bytes, pool has " << nearest_queue.size() << " elements." <<
+                        std::endl;
                 }
-                void* p = nearest_queue.second.front();
-                nearest_queue.second.pop();
                 return p;
             }
+            std::cerr << "[ StaticMemPool ] WARNING: Could not find memory from pool to allocate " << sizeInBytes <<
+                " bytes." << std::endl;
             return nullptr;
         }
     };
@@ -283,7 +306,7 @@ namespace memory {
     class PinnedHostMemoryAllocatorBase
     {
     protected:
-        static IMemPool* get_memory_pool() {
+        virtual IMemPool* get_memory_pool() {
            static DynamicMemPool dynamic_memory_pool;
            static StaticMemPool static_memory_pool;
            static bool use_dynamic_mem_pool = ck::EnvIsEnabled(CK_ENV(CK_USE_DYNAMIC_MEM_POOL));
@@ -343,6 +366,10 @@ namespace memory {
         void destroy(U* p) noexcept {
             p->~U();
         }
+    protected:
+        IMemPool* get_memory_pool() override {
+            return PinnedHostMemoryAllocatorBase::get_memory_pool();
+        }
     };
     template <typename T, typename U>
...
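For orientation, a hypothetical usage sketch (not part of the commit): PinnedHostMemoryAllocator is written against the standard allocator interface, so a container can draw pinned host memory from the shared pool, and whether that pool is DynamicMemPool or StaticMemPool is decided once through the CK_USE_DYNAMIC_MEM_POOL environment variable in get_memory_pool() above. The function and loop below are illustrative names, and the assumption is that the allocator provides allocate/deallocate alongside the construct/destroy shown in the hunk.

// Hypothetical usage, assuming PinnedHostMemoryAllocator satisfies the
// standard allocator requirements as the diff above suggests.
#include <vector>

void fill_staging_buffer()
{
    std::vector<int, ck::memory::PinnedHostMemoryAllocator<int>> staging(1024);
    for (int i = 0; i < static_cast<int>(staging.size()); ++i)
    {
        staging[i] = i; // backing storage comes from the pinned-memory pool
    }
}   // on destruction the chunk is returned to the pool for reuse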
@@ -9,7 +9,64 @@
 using namespace ck::memory;
-TEST(UtilityTests, PinnedHostMemoryAllocator_recycle_pinned_host_memory)
+namespace
+{
+    class TestMemoryAllocator : public PinnedHostMemoryAllocator<std::byte>
+    {
+    public:
+        TestMemoryAllocator() : PinnedHostMemoryAllocator()
+        {
+        }
+    protected:
+        IMemPool* get_memory_pool() override {
+            static StaticMemPool pool(maxMemoryPoolSizeInBytes_);
+            throw std::runtime_error("Static memory pool should not be used.");
+            return &pool;
+        }
+    private:
+        static constexpr size_t maxMemoryPoolSizeInBytes_ = 10;
+    };
+}
+TEST(UtilityTests, StaticMemoryPool_test_memory_allocation)
+{
+    const size_t size1 = 8;
+    const size_t size2 = 2;
+    std::byte *ptr1, *ptr2;
+    StaticMemPool pool(size1 + size2);
+    ptr1 = static_cast<std::byte*>(pool.allocate(size1));
+    ptr2 = static_cast<std::byte*>(pool.allocate(size2));
+    EXPECT_TRUE(ptr1 != nullptr);
+    EXPECT_TRUE(ptr2 != nullptr);
+    pool.deallocate(ptr1, size1);
+    pool.deallocate(ptr2, size2);
+    std::byte* ptr3 = static_cast<std::byte*>(pool.allocate(size2));
+    std::byte* ptr4 = static_cast<std::byte*>(pool.allocate(size1));
+    EXPECT_TRUE(ptr3 != nullptr);
+    EXPECT_TRUE(ptr4 != nullptr);
+    EXPECT_TRUE(ptr3 != ptr4);
+    EXPECT_TRUE(ptr3 == ptr2);
+    EXPECT_TRUE(ptr4 == ptr1);
+    pool.deallocate(ptr3, size2);
+    pool.deallocate(ptr4, size1);
+    const size_t size3 = 6;
+    const size_t size4 = 4;
+    std::byte* ptr5 = static_cast<std::byte*>(pool.allocate(size3));
+    std::byte* ptr6 = static_cast<std::byte*>(pool.allocate(size4));
+    EXPECT_TRUE(ptr5 != nullptr);
+    EXPECT_TRUE(ptr6 != nullptr);
+    pool.deallocate(ptr5, size3);
+    pool.deallocate(ptr6, size4);
+}
+TEST(UtilityTests, PinnedHostMemoryAllocator_new_memory_is_allocated)
 {
     const size_t vSize = 10;
     int* ptr1;
...
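One detail worth noting from the hunks above: the old bounds check offsetInBytes_ + sizeInBytes < memoryPoolSizeInBytes_ rejected a request that exactly fills the remaining block, while the new offsetInBytes_ + sizeInBytes - 1 < memoryPoolSizeInBytes_ is equivalent to offset + size <= pool size for non-zero requests, which the 8-plus-2-bytes-from-a-10-byte-pool test relies on. A tiny standalone check of that boundary with those same values, for illustration only:

// Illustrative check of the exact-fit boundary using the test's values.
#include <cassert>
#include <cstddef>

int main()
{
    const std::size_t poolSize = 10, offset = 8, request = 2;
    // Old condition: rejects the exact fit, since 8 + 2 < 10 is false.
    assert(!(offset + request < poolSize));
    // New condition: accepts it, since 8 + 2 - 1 < 10 holds.
    assert(offset + request - 1 < poolSize);
    return 0;
}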