Commit 0c7b35c4 authored by Ville Pietilä's avatar Ville Pietilä
Browse files

Improve logging.

parent 6787ca76
...@@ -6,7 +6,11 @@ ...@@ -6,7 +6,11 @@
"extensions": ["ms-vscode.cpptools-extension-pack", "eamodio.gitlens", "ms-python.python", "benjamin-simmonds.pythoncpp-debug"] "extensions": ["ms-vscode.cpptools-extension-pack", "eamodio.gitlens", "ms-python.python", "benjamin-simmonds.pythoncpp-debug"]
} }
}, },
"containerEnv": {"LANG": "C.UTF-8"}, "containerEnv":
{
"LANG": "C.UTF-8",
"HIP_VISIBLE_DEVICES": "6,7"
},
"mounts": [ "mounts": [
"source=${localEnv:HOME}/.ssh,target=/root/.ssh,type=bind,consistency=cached" "source=${localEnv:HOME}/.ssh,target=/root/.ssh,type=bind,consistency=cached"
] ]
......
...@@ -5,12 +5,128 @@ ...@@ -5,12 +5,128 @@
#include <hip/hip_runtime.h> #include <hip/hip_runtime.h>
#include "ck/host_utility/hip_check_error.hpp" #include "ck/host_utility/hip_check_error.hpp"
#include <map>
#include <queue>
#include <mutex>
namespace ck { namespace ck {
namespace memory { namespace memory {
class DebugStream
{
public:
DebugStream(bool enable_output = true) : enable_output_(enable_output) {}
template <typename T>
DebugStream& operator<<(const T& value)
{
if (enable_output_)
{
std::cout << value;
}
return *this;
}
// Overload for std::ostream manipulators like std::endl
using Manipulator = std::ostream& (*)(std::ostream&);
DebugStream& operator<<(Manipulator manip)
{
if (enable_output_)
{
manip(std::cout);
}
return *this;
}
void enableOutput(bool enable)
{
enable_output_ = enable;
}
private:
bool enable_output_;
};
template <typename T>
class MemPool
{
public:
MemPool() : debug_stream(true) {}
~MemPool()
{
std::lock_guard<std::mutex> lock(mutex_);
debug_stream << "Destroying memory pool of type " << typeid(T).name() << std::endl;
for (auto& [size, q] : memory_pool_)
{
// Iterate through the queue and free the memory
while (!q.empty())
{
T* p = q.front();
q.pop();
hip_check_error(hipHostFree(p));
}
}
}
T* allocate(std::size_t n)
{
std::lock_guard<std::mutex> lock(mutex_);
debug_stream << "Allocating size " << n << " for type " << typeid(T).name() << std::endl;
// If there is a memory pool for the requested size, return the memory from the pool.
if (memory_pool_.find(n) != memory_pool_.end() && !memory_pool_[n].empty())
{
debug_stream << "\tReturning from memory pool" << std::endl;
T* p = memory_pool_[n].front();
memory_pool_[n].pop();
return p;
}
debug_stream << "\tAllocating new memory" << std::endl;
T* p;
hip_check_error(hipHostMalloc(&p, n));
return p;
}
void deallocate(T* p, std::size_t size)
{
std::lock_guard<std::mutex> lock(mutex_);
if (memory_pool_.find(size) != memory_pool_.end())
{
auto& q = memory_pool_[size];
q.push(p);
debug_stream << "Deallocating size " << size << " and type " << typeid(T).name() << " to memory pool." << std::endl;
debug_stream << "\tPool size: " << q.size() << std::endl;
// If the memory pool size exceeds the maximum size, free the memory.
if (q.size() > maxMemoryPoolSize_)
{
debug_stream << "Memory pool size exceeds the maximum size for type " << typeid(T).name()
<< ". Freeing the memory." << std::endl;
while (!q.empty())
{
T* ptr = q.front();
q.pop();
hip_check_error(hipHostFree(ptr));
}
}
}
else {
debug_stream << "Creating new memory pool for size " << size << " and type " << typeid(T).name() << std::endl;
std::queue<T*> q;
q.push(p);
memory_pool_.insert(std::make_pair(size, std::move(q)));
}
}
private:
constexpr static size_t maxMemoryPoolSizeInBytes_ = 1 << 20; // 1MB
constexpr static size_t maxMemoryPoolSize_ = maxMemoryPoolSizeInBytes_ / sizeof(T);
std::mutex mutex_; // Mutex to protect access to the memory pool.
std::map<size_t, std::queue<T*>> memory_pool_{};
DebugStream debug_stream;
};
template <typename T> template <typename T>
struct PinnedHostMemoryAllocator class PinnedHostMemoryAllocator
{ {
public: public:
using value_type = T; using value_type = T;
...@@ -33,13 +149,17 @@ namespace memory { ...@@ -33,13 +149,17 @@ namespace memory {
{} {}
T* allocate(std::size_t n) { T* allocate(std::size_t n) {
T* p; // T* p;
hip_check_error(hipHostMalloc(&p, n * sizeof(T))); // hip_check_error(hipHostMalloc(&p, n * sizeof(T)));
return p; // return p;
auto& memory_pool = get_memory_pool();
return memory_pool.allocate(n);
} }
void deallocate(T* p, std::size_t) { void deallocate(T* p, std::size_t size) {
hip_check_error(hipHostFree(p)); //hip_check_error(hipHostFree(p));
auto& memory_pool = get_memory_pool();
memory_pool.deallocate(p, size);
} }
template<typename U, typename... Args> template<typename U, typename... Args>
...@@ -51,6 +171,11 @@ namespace memory { ...@@ -51,6 +171,11 @@ namespace memory {
void destroy(U* p) noexcept { void destroy(U* p) noexcept {
p->~U(); p->~U();
} }
private:
static MemPool<T>& get_memory_pool() {
static MemPool<T> memory_pool_;
return memory_pool_;
}
}; };
template <typename T, typename U> template <typename T, typename U>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment