"examples/dreambooth/train_dreambooth_lora.py" did not exist on "badddee0ef8282e07232dcd3990070e2d524f57c"
Commit aaba2033 authored by Ville Pietilä

Proper clean-up mechanism for pinned host memory.

parent 5755f841
@@ -610,6 +610,8 @@ struct DeviceGroupedGemmMultipleD_Dl : public DeviceGroupedGemm<ALayout,
             arg.gemm_desc_kernel_arg_.size() * sizeof(GemmKernelArg),
             hipMemcpyHostToDevice,
             stream_config.stream_id_));
+        ck::memory::PinnedHostMemoryDeallocator::instance().destruct_device(
+            static_cast<const void*>(arg.gemm_desc_kernel_arg_.data()), stream_config.stream_id_);
         auto launch_kernel = [&](auto has_main_k_block_loop,
                                  auto has_double_tail_k_block_loop) {
...
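The same pair of added lines appears in each of the four grouped-GEMM device ops touched by this commit: immediately after the kernel-argument block is staged host-to-device, an event is recorded on the copy stream so the pinned staging buffer cannot be recycled while the copy is still in flight. A minimal sketch of that pattern in plain HIP (function and parameter names here are illustrative, not CK's):

    #include <cstddef>
    #include <hip/hip_runtime.h>

    // Enqueue the H2D copy of a pinned argument block, then record an event
    // on the same stream; the staging buffer may be handed back to the pool
    // only once the event completes, which is what destruct_device() arranges.
    void copy_args_and_mark(void* device_args, const void* pinned_args,
                            size_t bytes, hipStream_t stream, hipEvent_t copy_done)
    {
        (void)hipMemcpyAsync(device_args, pinned_args, bytes,
                             hipMemcpyHostToDevice, stream);
        (void)hipEventRecord(copy_done, stream);
    }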
@@ -767,6 +767,8 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
             arg.gemm_kernel_args_.size() * sizeof(GemmTransKernelArg),
             hipMemcpyHostToDevice,
             stream_config.stream_id_));
+        ck::memory::PinnedHostMemoryDeallocator::instance().destruct_device(
+            static_cast<const void*>(arg.gemm_kernel_args_.data()), stream_config.stream_id_);
         auto preprocess = [&]() {
             hip_check_error(hipMemsetAsync(
...
@@ -564,6 +564,8 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
             arg.gemm_desc_kernel_arg_.size() * sizeof(GemmBiasTransKernelArg),
             hipMemcpyHostToDevice,
             stream_config.stream_id_));
+        ck::memory::PinnedHostMemoryDeallocator::instance().destruct_device(
+            static_cast<const void*>(arg.gemm_desc_kernel_arg_.data()), stream_config.stream_id_);
         float ave_time = 0;
...
@@ -427,6 +427,8 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
             arg.gemm_kernel_args_.size() * sizeof(GemmTransKernelArg),
             hipMemcpyHostToDevice,
             stream_config.stream_id_));
+        ck::memory::PinnedHostMemoryDeallocator::instance().destruct_device(
+            static_cast<const void*>(arg.gemm_kernel_args_.data()), stream_config.stream_id_);
         float ave_time = 0;
...
@@ -17,6 +17,7 @@
 CK_DECLARE_ENV_VAR_BOOL(CK_USE_DYNAMIC_MEM_POOL)
 CK_DECLARE_ENV_VAR_BOOL(CK_PREFER_RECYCLED_PINNED_MEM)
+CK_DECLARE_ENV_VAR_UINT64(CK_PINNED_MEM_SIZE_KB)
 namespace ck {
 namespace memory {
@@ -141,7 +142,7 @@ namespace memory {
 #endif
 }
 private:
-    constexpr static size_t defaultMaxMemoryPoolSizeInBytes_ = 10 * 1024 * 1024; // 10MB
+    constexpr static size_t defaultMaxMemoryPoolSizeInBytes_ = 1 * 1024 * 1024; // 1MB
     void clearMemoryPoolQueue(size_t sizeInBytes)
     {
@@ -169,9 +170,16 @@ namespace memory {
     pid_(getpid()),
     offsetInBytes_(0),
     preferRecycledMem_(ck::EnvIsEnabled(CK_ENV(CK_PREFER_RECYCLED_PINNED_MEM))),
-    memoryPoolSizeInBytes_(poolSizeInBytes)
+    activeMemoryPoolSizeInBytes_(poolSizeInBytes)
 {
-    allocateNewPinnedMemoryBlock();
+    if (!ck::EnvIsUnset(CK_ENV(CK_PINNED_MEM_SIZE_KB)))
+    {
+        // kB to bytes conversion
+        constexpr size_t KB = 1024;
+        activeMemoryPoolSizeInBytes_ = ck::EnvValue(CK_ENV(CK_PINNED_MEM_SIZE_KB)) * KB;
+        std::cout << "[ StaticMemPool ] Override of default memory size to " << activeMemoryPoolSizeInBytes_ << " bytes." << std::endl;
+    }
+    allocateNewPinnedMemoryBlock(activeMemoryPoolSizeInBytes_);
 }
 ~StaticMemPool() override
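Outside CK's env helpers, the constructor's override logic reduces to the following sketch (std::getenv standing in for EnvIsUnset/EnvValue; the variable name comes from the hunk above):

    #include <cstddef>
    #include <cstdlib>
    #include <string>

    // Resolve the initial pool size: CK_PINNED_MEM_SIZE_KB, given in kB,
    // overrides the compiled-in default.
    size_t initial_pool_size(size_t default_bytes)
    {
        if (const char* kb = std::getenv("CK_PINNED_MEM_SIZE_KB"))
        {
            return std::stoull(kb) * 1024; // kB to bytes, as in the constructor
        }
        return default_bytes;
    }

Running with CK_PINNED_MEM_SIZE_KB=4096, for example, requests a 4 MiB initial block instead of the compiled-in default.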
@@ -192,7 +200,7 @@ namespace memory {
 {
     std::lock_guard<std::mutex> lock(mutex_);
-    if (!preferRecycledMem_ && offsetInBytes_ + sizeInBytes - 1 < memoryPoolSizeInBytes_)
+    if (!preferRecycledMem_ && offsetInBytes_ + sizeInBytes - 1 < activeMemoryPoolSizeInBytes_)
     {
         return allocateNewMemory(sizeInBytes);
     }
@@ -203,14 +211,14 @@ namespace memory {
         return ptr;
     }
-    if (offsetInBytes_ + sizeInBytes - 1 < memoryPoolSizeInBytes_)
+    if (offsetInBytes_ + sizeInBytes - 1 < activeMemoryPoolSizeInBytes_)
     {
         return allocateNewMemory(sizeInBytes);
     }
     // Memory became too fragmented, reserve a new block.
-    // This should not happen very often, practically never.
-    allocateNewPinnedMemoryBlock();
+    size_t requestedBlockSize = std::max(activeMemoryPoolSizeInBytes_, 2 * sizeInBytes);
+    allocateNewPinnedMemoryBlock(requestedBlockSize);
     return allocateNewMemory(sizeInBytes);
 }
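The replacement sizing rule grows with demand and never shrinks: the new block is at least as large as the current pool and always fits the failed request with headroom. In isolation:

    #include <algorithm>
    #include <cstddef>

    // Block size reserved when the bump region can no longer satisfy a request.
    size_t requested_block_size(size_t active_pool_bytes, size_t request_bytes)
    {
        return std::max(active_pool_bytes, 2 * request_bytes);
    }
    // e.g. a 3 MB request against a 10 MB pool reserves another 10 MB block;
    //      a 16 MB request against a 10 MB pool reserves a 32 MB block.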
@@ -255,7 +263,7 @@ namespace memory {
 size_t memoryPoolSizeInBytes() const
 {
-    return memoryPoolSizeInBytes_;
+    return activeMemoryPoolSizeInBytes_;
 }
 const std::map<size_t, std::queue<void*>>& memoryPool() const
const std::map<size_t, std::queue<void*>>& memoryPool() const const std::map<size_t, std::queue<void*>>& memoryPool() const
...@@ -265,24 +273,25 @@ namespace memory { ...@@ -265,24 +273,25 @@ namespace memory {
private: private:
constexpr static size_t defaultMaxMemoryPoolSizeInBytes_ = 10 * 1024 * 1024; // 10MB constexpr static size_t defaultMaxMemoryPoolSizeInBytes_ = 10 * 1024 * 1024; // 10MB
std::mutex mutex_; // Mutex to protect access to the memory pool. std::mutex mutex_;
std::map<size_t, std::queue<void*>> memory_pool_{}; std::map<size_t, std::queue<void*>> memory_pool_{};
std::stack<std::byte*> pinnedMemoryBaseAddress_; std::stack<std::byte*> pinnedMemoryBaseAddress_;
bool enableLogging_; bool enableLogging_;
int pid_; int pid_;
int offsetInBytes_; int offsetInBytes_;
bool preferRecycledMem_; bool preferRecycledMem_;
size_t memoryPoolSizeInBytes_; size_t activeMemoryPoolSizeInBytes_;
void allocateNewPinnedMemoryBlock() void allocateNewPinnedMemoryBlock(size_t memoryPoolSizeInBytes)
{ {
activeMemoryPoolSizeInBytes_ = memoryPoolSizeInBytes;
std::byte* pinnedMemoryBaseAddress; std::byte* pinnedMemoryBaseAddress;
hip_check_error(hipHostMalloc(&pinnedMemoryBaseAddress, memoryPoolSizeInBytes_)); hip_check_error(hipHostMalloc(&pinnedMemoryBaseAddress, activeMemoryPoolSizeInBytes_));
pinnedMemoryBaseAddress_.push(pinnedMemoryBaseAddress); pinnedMemoryBaseAddress_.push(pinnedMemoryBaseAddress);
offsetInBytes_ = 0; offsetInBytes_ = 0;
if (enableLogging_) if (enableLogging_)
{ {
std::cout << "[ StaticMemPool ] Allocation: Created new pinned memory block of " << memoryPoolSizeInBytes_ << " bytes." << std::endl; std::cout << "[ StaticMemPool ] Allocation: Created new pinned memory block of " << activeMemoryPoolSizeInBytes_ << " bytes." << std::endl;
} }
} }
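allocateNewMemory itself is not shown in these hunks; from the offset bookkeeping it is presumably a bump allocation against the newest block, so pushing each base address onto pinnedMemoryBaseAddress_ and zeroing offsetInBytes_ retires the previous block for fresh allocations while keeping it alive for pointers already handed out. A sketch of the presumed bump step, reusing the member names from the diff:

    #include <cstddef>

    // Presumed shape of allocateNewMemory(): carve sizeInBytes out of the
    // newest pinned block and advance the running offset.
    void* bump_allocate(std::byte* newest_base, int& offsetInBytes, size_t sizeInBytes)
    {
        void* p = newest_base + offsetInBytes;
        offsetInBytes += static_cast<int>(sizeInBytes);
        return p;
    }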
@@ -294,7 +303,7 @@ namespace memory {
 #ifdef ENABLE_MEM_POOL_LOGGING
 if (enableLogging_)
 {
-    const auto pct = 100.0f * static_cast<float>(offsetInBytes_) / memoryPoolSizeInBytes_;
+    const auto pct = 100.0f * static_cast<float>(offsetInBytes_) / activeMemoryPoolSizeInBytes_;
     std::cout << "[ StaticMemPool ] Allocation: Return new memory of " << sizeInBytes <<
         " bytes, pinned host memory usage: " << pct << "%." << std::endl;
 }
@@ -354,10 +363,144 @@ namespace memory {
 {
 public:
 IMemPool* get_memory_pool() {
-    static DynamicMemPool dynamic_memory_pool;
+    //static DynamicMemPool dynamic_memory_pool;
     static StaticMemPool static_memory_pool;
-    static bool use_dynamic_mem_pool = ck::EnvIsEnabled(CK_ENV(CK_USE_DYNAMIC_MEM_POOL));
-    return use_dynamic_mem_pool ? static_cast<IMemPool*>(&dynamic_memory_pool) : static_cast<IMemPool*>(&static_memory_pool);
+    //static bool use_dynamic_mem_pool = ck::EnvIsEnabled(CK_ENV(CK_USE_DYNAMIC_MEM_POOL));
+    //return use_dynamic_mem_pool ? static_cast<IMemPool*>(&dynamic_memory_pool) : static_cast<IMemPool*>(&static_memory_pool);
+    return &static_memory_pool;
+}
+};
+
+class MemoryCleanupThread
+{
+public:
+    MemoryCleanupThread(std::function<void()> cleanup_function) : cleanup_callback_(cleanup_function)
+    {
+        cleanup_thread_ = std::thread([this]() {
+            while (!should_stop_) {
+                std::this_thread::sleep_for(std::chrono::milliseconds(50));
+                try
+                {
+                    cleanup_callback_();
+                }
+                catch (const std::exception& e)
+                {
+                    std::cerr << "Error in cleanup thread: " << e.what() << std::endl;
+                }
+                catch (...)
+                {
+                    std::cerr << "Error in cleanup thread." << std::endl;
+                }
+            }
+        });
+    }
+
+    ~MemoryCleanupThread() {
+        should_stop_ = true;
+        if(cleanup_thread_.joinable()) {
+            cleanup_thread_.join();
+        }
+    }
+
+    MemoryCleanupThread(const MemoryCleanupThread&) = delete;
+    MemoryCleanupThread& operator=(const MemoryCleanupThread&) = delete;
+    // The polling lambda captures `this`, so a moved-from object would still
+    // be observed by the running thread; moving must be forbidden as well.
+    MemoryCleanupThread(MemoryCleanupThread&&) = delete;
+    MemoryCleanupThread& operator=(MemoryCleanupThread&&) = delete;
+
+private:
+    std::function<void()> cleanup_callback_;
+    std::thread cleanup_thread_;
+    // Written by the destructor, read by the polling thread: must be atomic.
+    std::atomic<bool> should_stop_{false};
+};
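The deallocator below owns one of these threads; its intended use is simply to hand over a callback at construction, for example:

    // Periodic task that stops and joins when `cleaner` goes out of scope.
    MemoryCleanupThread cleaner([] {
        PinnedHostMemoryDeallocator::instance().deallocate_all();
    });

Since the loop only re-checks the stop flag between 50 ms sleeps, destruction can block for up to one polling period.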
+
+class PinnedHostMemoryDeallocator : public PinnedHostMemoryAllocatorBase
+{
+public:
+    PinnedHostMemoryDeallocator() : cleanup_thread_([this]() { deallocate_all(); })
+    {
+    }
+
+    void register_allocated_memory(void* p, size_t sizeInBytes)
+    {
+        std::lock_guard<std::mutex> lock(mutex_);
+        hipEvent_t event;
+        hip_check_error(hipEventCreate(&event));
+        device_destruct_events_.insert({p, event});
+        allocated_memory_.insert({p, sizeInBytes});
+        host_destruct_events_.insert({p, false});
+    }
+
+    void destruct_host(void* p /*, std::function<void()>&& destructor*/)
+    {
+        std::lock_guard<std::mutex> lock(mutex_);
+        host_destruct_events_[p] = true;
+    }
+
+    void destruct_device(const void* p, hipStream_t stream)
+    {
+        std::lock_guard<std::mutex> lock(mutex_);
+        hip_check_error(hipEventRecord(device_destruct_events_[const_cast<void*>(p)], stream));
+    }
+
+    void deallocate_all()
+    {
+        std::lock_guard<std::mutex> lock(mutex_);
+        std::vector<void*> keys;
+        for (const auto& [p, _] : allocated_memory_)
+        {
+            keys.push_back(p);
+        }
+        for (auto p : keys)
+        {
+            if (canDeallocate(p))
+            {
+                deallocate(p);
+            }
+        }
+    }
+
+    static PinnedHostMemoryDeallocator& instance()
+    {
+        static PinnedHostMemoryDeallocator instance;
+        return instance;
+    }
+
+private:
+    std::mutex mutex_;
+    std::map<void*, std::size_t> allocated_memory_;
+    std::map<void*, bool> host_destruct_events_;
+    std::map<void*, hipEvent_t> device_destruct_events_;
+    MemoryCleanupThread cleanup_thread_;
+
+    void deallocate(void* p)
+    {
+        //destructors_[p]();
+        host_destruct_events_.erase(p);
+        auto* memory_pool = get_memory_pool();
+        memory_pool->deallocate(p, allocated_memory_[p]);
+        hip_check_error(hipEventDestroy(device_destruct_events_[p]));
+        device_destruct_events_.erase(p);
+        allocated_memory_.erase(p);
+    }
+
+    bool canDeallocate(void* p)
+    {
+        bool can_deallocate_on_device = false;
+        if (device_destruct_events_.find(p) != device_destruct_events_.end())
+        {
+            hipError_t state = hipEventQuery(device_destruct_events_[p]);
+            if (state == hipSuccess)
+            {
+                can_deallocate_on_device = true;
+            }
+            else if (state != hipErrorNotReady)
+            {
+                throw std::runtime_error("Error querying event state: " + std::to_string(state));
+            }
+        }
+        const bool can_deallocate_on_host = host_destruct_events_[p];
+        return can_deallocate_on_device && can_deallocate_on_host;
+    }
 }
 };
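Stripped of the pool bookkeeping, the invariant this class enforces is that a pinned buffer is released only after both the host container has dropped it and the event recorded on the copy stream has completed. A self-contained sketch of that dual gate in plain HIP (all names here are illustrative, not CK's):

    #include <atomic>
    #include <hip/hip_runtime.h>

    struct PendingFree
    {
        void*             ptr;              // pinned host buffer awaiting release
        hipEvent_t        copy_done;        // recorded after the last H2D copy
        std::atomic<bool> host_done{false}; // set once the container lets go
    };

    // Polled from a cleanup thread: release only when both gates are open.
    bool try_release(PendingFree& pf)
    {
        if (!pf.host_done.load())
            return false;                       // host side still owns the buffer
        if (hipEventQuery(pf.copy_done) != hipSuccess)
            return false;                       // copy still in flight (or error)
        (void)hipEventDestroy(pf.copy_done);
        (void)hipHostFree(pf.ptr);              // the real class recycles via the pool
        return true;
    }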
@@ -384,24 +527,26 @@ namespace memory {
 PinnedHostMemoryAllocator(const PinnedHostMemoryAllocator<U>&)
 {}
-T* allocate(std::size_t n) {
+T* allocate(std::size_t n)
+{
     auto* memory_pool = get_memory_pool();
     const size_t sizeInBytes = n * sizeof(T);
-    return static_cast<T*>(memory_pool->allocate(sizeInBytes));
+    T* p = static_cast<T*>(memory_pool->allocate(sizeInBytes));
+    PinnedHostMemoryDeallocator::instance().register_allocated_memory(p, sizeInBytes);
+    return p;
 }
-void deallocate(T* p, std::size_t n)
+void deallocate(T* p, std::size_t)
 {
-    if constexpr (std::is_destructible_v<T>)
-    {
-        for (size_t i = 0; i < n; ++i) {
-            p[i].~T();
-        }
-    }
-    auto* memory_pool = get_memory_pool();
-    const size_t sizeInBytes = n * sizeof(T);
-    memory_pool->deallocate(p, sizeInBytes);
+    // auto destructor = [&]() {
+    //     if constexpr (std::is_destructible_v<T>)
+    //     {
+    //         for (size_t i = 0; i < n; ++i) {
+    //             p[i].~T();
+    //         }
+    //     }
+    // };
+    PinnedHostMemoryDeallocator::instance().destruct_host(p /*, std::move(destructor)*/);
 }
 template<typename U, typename... Args>
...
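Tying this back to the first four hunks: the kernel-argument containers would be declared against this allocator, along the lines of the following hypothetical alias (the element type is assumed for illustration):

    // Pinned vector of grouped-GEMM kernel arguments. Destroying it now only
    // flags destruct_host(); the cleanup thread returns the storage to the
    // pool once the recorded device-side event has completed.
    using PinnedArgs =
        std::vector<GemmKernelArg,
                    ck::memory::PinnedHostMemoryAllocator<GemmKernelArg>>;

Note that deallocate() no longer runs element destructors (the loop is commented out), which is only safe while the element types stay trivially destructible, as plain kernel-argument structs are.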