Commit 870c3a76 authored by Ville Pietilä's avatar Ville Pietilä
Browse files

Refactor mempool implementation. Fix a bug in allocating pinned host memory.

parent 0c7b35c4
......@@ -12,117 +12,72 @@
namespace ck {
namespace memory {
class DebugStream
{
public:
DebugStream(bool enable_output = true) : enable_output_(enable_output) {}
template <typename T>
DebugStream& operator<<(const T& value)
{
if (enable_output_)
{
std::cout << value;
}
return *this;
}
// Overload for std::ostream manipulators like std::endl
using Manipulator = std::ostream& (*)(std::ostream&);
DebugStream& operator<<(Manipulator manip)
{
if (enable_output_)
{
manip(std::cout);
}
return *this;
}
void enableOutput(bool enable)
{
enable_output_ = enable;
}
private:
bool enable_output_;
};
template <typename T>
class MemPool
{
public:
MemPool() : debug_stream(true) {}
MemPool() = default;
~MemPool()
{
std::lock_guard<std::mutex> lock(mutex_);
debug_stream << "Destroying memory pool of type " << typeid(T).name() << std::endl;
for (auto& [size, q] : memory_pool_)
{
// Iterate through the queue and free the memory
while (!q.empty())
{
T* p = q.front();
q.pop();
hip_check_error(hipHostFree(p));
}
clearMemoryPoolQueue(q);
}
}
T* allocate(std::size_t n)
void* allocate(std::size_t sizeInBytes)
{
std::lock_guard<std::mutex> lock(mutex_);
debug_stream << "Allocating size " << n << " for type " << typeid(T).name() << std::endl;
// If there is a memory pool for the requested size, return the memory from the pool.
if (memory_pool_.find(n) != memory_pool_.end() && !memory_pool_[n].empty())
if (memory_pool_.find(sizeInBytes) != memory_pool_.end() && !memory_pool_[sizeInBytes].empty())
{
debug_stream << "\tReturning from memory pool" << std::endl;
T* p = memory_pool_[n].front();
memory_pool_[n].pop();
void* p = memory_pool_[sizeInBytes].front();
memory_pool_[sizeInBytes].pop();
return p;
}
debug_stream << "\tAllocating new memory" << std::endl;
T* p;
hip_check_error(hipHostMalloc(&p, n));
void* p;
hip_check_error(hipHostMalloc(&p, sizeInBytes));
return p;
}
void deallocate(T* p, std::size_t size)
void deallocate(void* p, std::size_t sizeInBytes)
{
std::lock_guard<std::mutex> lock(mutex_);
if (memory_pool_.find(size) != memory_pool_.end())
if (memory_pool_.find(sizeInBytes) != memory_pool_.end())
{
auto& q = memory_pool_[size];
auto& q = memory_pool_[sizeInBytes];
q.push(p);
debug_stream << "Deallocating size " << size << " and type " << typeid(T).name() << " to memory pool." << std::endl;
debug_stream << "\tPool size: " << q.size() << std::endl;
memPoolSizeInBytes_ += sizeInBytes;
// If the memory pool size exceeds the maximum size, free the memory.
if (q.size() > maxMemoryPoolSize_)
if (memPoolSizeInBytes_ > maxMemoryPoolSizeInBytes_)
{
debug_stream << "Memory pool size exceeds the maximum size for type " << typeid(T).name()
<< ". Freeing the memory." << std::endl;
while (!q.empty())
{
T* ptr = q.front();
q.pop();
hip_check_error(hipHostFree(ptr));
}
clearMemoryPoolQueue(q);
}
}
else {
debug_stream << "Creating new memory pool for size " << size << " and type " << typeid(T).name() << std::endl;
std::queue<T*> q;
std::queue<void*> q;
q.push(p);
memory_pool_.insert(std::make_pair(size, std::move(q)));
memory_pool_.insert(std::make_pair(sizeInBytes, std::move(q)));
memPoolSizeInBytes_ += sizeInBytes;
}
}
private:
constexpr static size_t maxMemoryPoolSizeInBytes_ = 1 << 20; // 1MB
constexpr static size_t maxMemoryPoolSize_ = maxMemoryPoolSizeInBytes_ / sizeof(T);
constexpr static size_t maxMemoryPoolSizeInBytes_ = 10 * 1024 * 1024; // 10MB
static void clearMemoryPoolQueue(std::queue<void*>& q)
{
while (!q.empty())
{
void* p = q.front();
q.pop();
hip_check_error(hipHostFree(p));
}
}
std::mutex mutex_; // Mutex to protect access to the memory pool.
std::map<size_t, std::queue<T*>> memory_pool_{};
DebugStream debug_stream;
std::map<size_t, std::queue<void*>> memory_pool_{};
size_t memPoolSizeInBytes_{0};
};
template <typename T>
......@@ -149,17 +104,15 @@ namespace memory {
{}
T* allocate(std::size_t n) {
// T* p;
// hip_check_error(hipHostMalloc(&p, n * sizeof(T)));
// return p;
auto& memory_pool = get_memory_pool();
return memory_pool.allocate(n);
const size_t sizeInBytes = n * sizeof(T);
return static_cast<T*>(memory_pool.allocate(sizeInBytes));
}
void deallocate(T* p, std::size_t size) {
//hip_check_error(hipHostFree(p));
void deallocate(T* p, std::size_t n) {
auto& memory_pool = get_memory_pool();
memory_pool.deallocate(p, size);
const size_t sizeInBytes = n * sizeof(T);
memory_pool.deallocate(p, sizeInBytes);
}
template<typename U, typename... Args>
......@@ -172,9 +125,9 @@ namespace memory {
p->~U();
}
private:
static MemPool<T>& get_memory_pool() {
static MemPool<T> memory_pool_;
return memory_pool_;
static MemPool& get_memory_pool() {
static MemPool memory_pool;
return memory_pool;
}
};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment