Commit 870c3a76 authored by Ville Pietilä
Browse files

Refactor mempool implementation. Fix a bug in allocating pinned host memory.

parent 0c7b35c4
...@@ -12,117 +12,72 @@ ...@@ -12,117 +12,72 @@
namespace ck { namespace ck {
namespace memory { namespace memory {
/// Switchable wrapper around std::cout.
///
/// While output is enabled every streamed value or manipulator is forwarded to
/// std::cout; while it is disabled the operands are accepted and discarded.
/// All insertion operators return *this so calls can be chained.
class DebugStream
{
public:
    /// @param enable_output  Initial state of the switch (defaults to enabled).
    ///
    /// Marked `explicit` so that a plain bool never converts implicitly into a
    /// DebugStream (e.g. when passed to a function taking a DebugStream).
    explicit DebugStream(bool enable_output = true) : enable_output_(enable_output) {}

    /// Streams @p value to std::cout when output is enabled.
    /// @return *this, to allow chaining.
    template <typename T>
    DebugStream& operator<<(const T& value)
    {
        if(enable_output_)
        {
            std::cout << value;
        }
        return *this;
    }

    // Overload for std::ostream manipulators like std::endl. These are function
    // templates, so the generic operator<< above cannot deduce T for them.
    using Manipulator = std::ostream& (*)(std::ostream&);

    /// Applies @p manip to std::cout when output is enabled.
    /// @return *this, to allow chaining.
    DebugStream& operator<<(Manipulator manip)
    {
        if(enable_output_)
        {
            manip(std::cout);
        }
        return *this;
    }

    /// Turns forwarding to std::cout on (true) or off (false).
    void enableOutput(bool enable)
    {
        enable_output_ = enable;
    }

private:
    bool enable_output_; ///< true while values are forwarded to std::cout.
};
template <typename T>
class MemPool class MemPool
{ {
public: public:
MemPool() : debug_stream(true) {} MemPool() = default;
~MemPool() ~MemPool()
{ {
std::lock_guard<std::mutex> lock(mutex_); std::lock_guard<std::mutex> lock(mutex_);
debug_stream << "Destroying memory pool of type " << typeid(T).name() << std::endl;
for (auto& [size, q] : memory_pool_) for (auto& [size, q] : memory_pool_)
{ {
// Iterate through the queue and free the memory clearMemoryPoolQueue(q);
while (!q.empty())
{
T* p = q.front();
q.pop();
hip_check_error(hipHostFree(p));
}
} }
} }
T* allocate(std::size_t n) void* allocate(std::size_t sizeInBytes)
{ {
std::lock_guard<std::mutex> lock(mutex_); std::lock_guard<std::mutex> lock(mutex_);
debug_stream << "Allocating size " << n << " for type " << typeid(T).name() << std::endl;
// If there is a memory pool for the requested size, return the memory from the pool. // If there is a memory pool for the requested size, return the memory from the pool.
if (memory_pool_.find(n) != memory_pool_.end() && !memory_pool_[n].empty()) if (memory_pool_.find(sizeInBytes) != memory_pool_.end() && !memory_pool_[sizeInBytes].empty())
{ {
debug_stream << "\tReturning from memory pool" << std::endl; void* p = memory_pool_[sizeInBytes].front();
T* p = memory_pool_[n].front(); memory_pool_[sizeInBytes].pop();
memory_pool_[n].pop();
return p; return p;
} }
debug_stream << "\tAllocating new memory" << std::endl; void* p;
T* p; hip_check_error(hipHostMalloc(&p, sizeInBytes));
hip_check_error(hipHostMalloc(&p, n));
return p; return p;
} }
void deallocate(T* p, std::size_t size) void deallocate(void* p, std::size_t sizeInBytes)
{ {
std::lock_guard<std::mutex> lock(mutex_); std::lock_guard<std::mutex> lock(mutex_);
if (memory_pool_.find(size) != memory_pool_.end()) if (memory_pool_.find(sizeInBytes) != memory_pool_.end())
{ {
auto& q = memory_pool_[size]; auto& q = memory_pool_[sizeInBytes];
q.push(p); q.push(p);
debug_stream << "Deallocating size " << size << " and type " << typeid(T).name() << " to memory pool." << std::endl; memPoolSizeInBytes_ += sizeInBytes;
debug_stream << "\tPool size: " << q.size() << std::endl;
// If the memory pool size exceeds the maximum size, free the memory. // If the memory pool size exceeds the maximum size, free the memory.
if (q.size() > maxMemoryPoolSize_) if (memPoolSizeInBytes_ > maxMemoryPoolSizeInBytes_)
{ {
debug_stream << "Memory pool size exceeds the maximum size for type " << typeid(T).name() clearMemoryPoolQueue(q);
<< ". Freeing the memory." << std::endl;
while (!q.empty())
{
T* ptr = q.front();
q.pop();
hip_check_error(hipHostFree(ptr));
}
} }
} }
else { else {
debug_stream << "Creating new memory pool for size " << size << " and type " << typeid(T).name() << std::endl; std::queue<void*> q;
std::queue<T*> q;
q.push(p); q.push(p);
memory_pool_.insert(std::make_pair(size, std::move(q))); memory_pool_.insert(std::make_pair(sizeInBytes, std::move(q)));
memPoolSizeInBytes_ += sizeInBytes;
} }
} }
private: private:
constexpr static size_t maxMemoryPoolSizeInBytes_ = 1 << 20; // 1MB constexpr static size_t maxMemoryPoolSizeInBytes_ = 10 * 1024 * 1024; // 10MB
constexpr static size_t maxMemoryPoolSize_ = maxMemoryPoolSizeInBytes_ / sizeof(T);
static void clearMemoryPoolQueue(std::queue<void*>& q)
{
while (!q.empty())
{
void* p = q.front();
q.pop();
hip_check_error(hipHostFree(p));
}
}
std::mutex mutex_; // Mutex to protect access to the memory pool. std::mutex mutex_; // Mutex to protect access to the memory pool.
std::map<size_t, std::queue<T*>> memory_pool_{}; std::map<size_t, std::queue<void*>> memory_pool_{};
DebugStream debug_stream; size_t memPoolSizeInBytes_{0};
}; };
template <typename T> template <typename T>
...@@ -149,17 +104,15 @@ namespace memory { ...@@ -149,17 +104,15 @@ namespace memory {
{} {}
// Allocates storage for n objects of type T through the pool returned by
// get_memory_pool(). Each line below shows the pre-commit text followed by the
// post-commit text: before the change n was forwarded to the pool unchanged;
// after it, n is converted to bytes (n * sizeof(T)) and the pool's void*
// result is cast to T*.
// NOTE(review): n * sizeof(T) is not checked for overflow.
T* allocate(std::size_t n) { T* allocate(std::size_t n) {
// T* p;
// hip_check_error(hipHostMalloc(&p, n * sizeof(T)));
// return p;
auto& memory_pool = get_memory_pool(); auto& memory_pool = get_memory_pool();
return memory_pool.allocate(n); const size_t sizeInBytes = n * sizeof(T);
return static_cast<T*>(memory_pool.allocate(sizeInBytes));
} }
// Hands p back to the pool returned by get_memory_pool(). Each line below
// shows the pre-commit text followed by the post-commit text: before the
// change the second argument was forwarded unchanged; after it, the element
// count n is converted to bytes (n * sizeof(T)) so the size passed here is
// computed the same way as in allocate().
void deallocate(T* p, std::size_t size) { void deallocate(T* p, std::size_t n) {
//hip_check_error(hipHostFree(p));
auto& memory_pool = get_memory_pool(); auto& memory_pool = get_memory_pool();
memory_pool.deallocate(p, size); const size_t sizeInBytes = n * sizeof(T);
memory_pool.deallocate(p, sizeInBytes);
} }
template<typename U, typename... Args> template<typename U, typename... Args>
...@@ -172,9 +125,9 @@ namespace memory { ...@@ -172,9 +125,9 @@ namespace memory {
p->~U(); p->~U();
} }
private: private:
// Returns a reference to a function-local static pool. Each line below shows
// the pre-commit text followed by the post-commit text: the pool type changes
// from MemPool<T> to the non-template MemPool.
// NOTE(review): this is a member of a class that uses T, so presumably there is
// still one pool per instantiation of the enclosing class -- confirm.
static MemPool<T>& get_memory_pool() { static MemPool& get_memory_pool() {
static MemPool<T> memory_pool_; static MemPool memory_pool;
return memory_pool_; return memory_pool;
} }
}; };
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment