Unverified commit 0cc667e1, authored by akhoroshev, committed by GitHub
Browse files

[feature] Graceful termination of background threads in LlamaV2 (#458)

* cuda allocator fix

* graceful termination

* lint and compilation fix
parent ce9e0756
...@@ -126,6 +126,7 @@ LlamaV2<T>::LlamaV2(size_t head_num, ...@@ -126,6 +126,7 @@ LlamaV2<T>::LlamaV2(size_t head_num,
template<typename T> template<typename T>
LlamaV2<T>::~LlamaV2() LlamaV2<T>::~LlamaV2()
{ {
shared_state_->request_queue.close();
internal_thread_.join(); internal_thread_.join();
delete decoder_; delete decoder_;
...@@ -448,12 +449,24 @@ void LlamaV2<T>::internalThreadEntry(int device_id) ...@@ -448,12 +449,24 @@ void LlamaV2<T>::internalThreadEntry(int device_id)
request_queue.dequeue(stop_requests, infer_requests, free_slot_count, is_empty); request_queue.dequeue(stop_requests, infer_requests, free_slot_count, is_empty);
// request queue was closed
// and there are no unprocessed requests in the queue
if (is_empty && infer_requests.empty() && stop_requests.empty()) {
// rank 0 sets flag
shared_state_->should_stop = true;
}
batch_.verifyRequests(stop_requests, infer_requests); batch_.verifyRequests(stop_requests, infer_requests);
} }
// wait while rank-0 is dequeueing // wait while rank-0 is dequeueing
shared_state_->barrier->wait(); shared_state_->barrier->wait();
// exit if job is done
if (shared_state_->should_stop) {
return;
}
bool modified = false; bool modified = false;
if (!(batch_.finishedCount() == 0 && stop_requests.empty() && infer_requests.empty())) { if (!(batch_.finishedCount() == 0 && stop_requests.empty() && infer_requests.empty())) {
...@@ -486,8 +499,6 @@ void LlamaV2<T>::internalThreadEntry(int device_id) ...@@ -486,8 +499,6 @@ void LlamaV2<T>::internalThreadEntry(int device_id)
batch_.finish(); batch_.finish();
} }
} }
FT_CHECK(0);
} }
template<typename T> template<typename T>
......
...@@ -46,6 +46,9 @@ public: ...@@ -46,6 +46,9 @@ public:
std::vector<std::shared_ptr<Request>> stop_requests; std::vector<std::shared_ptr<Request>> stop_requests;
RequestQueue request_queue; RequestQueue request_queue;
std::shared_ptr<Barrier> barrier; std::shared_ptr<Barrier> barrier;
// rank 0 sets flag to true if there are no more tasks in the request_queue
bool should_stop = false;
}; };
~LlamaV2(); ~LlamaV2();
......
...@@ -44,6 +44,11 @@ public: ...@@ -44,6 +44,11 @@ public:
futures.reserve(requests.size()); futures.reserve(requests.size());
{ {
std::lock_guard<std::mutex> lock(mutex_); std::lock_guard<std::mutex> lock(mutex_);
if (closed_) {
throw std::runtime_error("Queue is closed");
}
for (auto& r : requests) { for (auto& r : requests) {
futures.push_back(r->signal.get_future()); futures.push_back(r->signal.get_future());
if (r->stop_flag) { if (r->stop_flag) {
...@@ -65,7 +70,7 @@ public: ...@@ -65,7 +70,7 @@ public:
{ {
std::unique_lock<std::mutex> lock(mutex_); std::unique_lock<std::mutex> lock(mutex_);
if (blocking) { if (blocking) {
cv_.wait(lock, [this] { return !(stop_queue_.empty() && infer_queue_.empty()); }); cv_.wait(lock, [this] { return !(stop_queue_.empty() && infer_queue_.empty() && closed_ == false); });
} }
stop_requests.clear(); stop_requests.clear();
...@@ -81,11 +86,19 @@ public: ...@@ -81,11 +86,19 @@ public:
} }
} }
// Closes the queue for graceful shutdown: after this call enqueue()
// throws (see the closed_ check in enqueue) and a blocking dequeue()
// stops waiting, allowing worker threads to observe shutdown and exit.
void close()
{
std::lock_guard<std::mutex> lock(mutex_);
closed_ = true;
// Notify while holding the lock so no waiter can miss the transition
// between testing its predicate and going back to sleep.
cv_.notify_all();
}
private: private:
std::queue<std::shared_ptr<Request>> stop_queue_; std::queue<std::shared_ptr<Request>> stop_queue_;
std::queue<std::shared_ptr<Request>> infer_queue_; std::queue<std::shared_ptr<Request>> infer_queue_;
std::mutex mutex_; std::mutex mutex_;
std::condition_variable cv_; std::condition_variable cv_;
bool closed_ = false;
}; };
} // namespace turbomind } // namespace turbomind
...@@ -125,9 +125,15 @@ class Allocator; ...@@ -125,9 +125,15 @@ class Allocator;
template<> template<>
class Allocator<AllocatorType::CUDA>: public IAllocator { class Allocator<AllocatorType::CUDA>: public IAllocator {
private: private:
const int device_id_; enum class MemoryType
cudaStream_t stream_ = 0; // initialize as default stream {
std::unordered_map<void*, size_t>* pointer_mapping_; HOST,
DEVICE
};
const int device_id_;
cudaStream_t stream_ = 0; // initialize as default stream
std::unordered_map<void*, std::pair<size_t, MemoryType>>* pointer_mapping_;
bool isExist(void* address) const bool isExist(void* address) const
{ {
...@@ -136,10 +142,10 @@ private: ...@@ -136,10 +142,10 @@ private:
ReallocType isReMalloc(void* address, size_t size) const ReallocType isReMalloc(void* address, size_t size) const
{ {
FT_CHECK(isExist(address)); FT_CHECK(isExist(address));
if (pointer_mapping_->at(address) < size) { if (pointer_mapping_->at(address).first < size) {
return ReallocType::INCREASE; return ReallocType::INCREASE;
} }
else if (pointer_mapping_->at(address) == size) { else if (pointer_mapping_->at(address).first == size) {
return ReallocType::REUSE; return ReallocType::REUSE;
} }
else { else {
...@@ -151,7 +157,7 @@ public: ...@@ -151,7 +157,7 @@ public:
Allocator(int device_id): device_id_(device_id) Allocator(int device_id): device_id_(device_id)
{ {
TM_LOG_DEBUG(__PRETTY_FUNCTION__); TM_LOG_DEBUG(__PRETTY_FUNCTION__);
pointer_mapping_ = new std::unordered_map<void*, size_t>(); pointer_mapping_ = new std::unordered_map<void*, std::pair<size_t, MemoryType>>();
#if defined(CUDA_MEMORY_POOL_DISABLED) #if defined(CUDA_MEMORY_POOL_DISABLED)
TM_LOG_WARNING( TM_LOG_WARNING(
"Async cudaMalloc/Free is not supported before CUDA 11.2. Using Sync cudaMalloc/Free." "Async cudaMalloc/Free is not supported before CUDA 11.2. Using Sync cudaMalloc/Free."
...@@ -188,7 +194,9 @@ public: ...@@ -188,7 +194,9 @@ public:
{ {
TM_LOG_DEBUG(__PRETTY_FUNCTION__); TM_LOG_DEBUG(__PRETTY_FUNCTION__);
while (!pointer_mapping_->empty()) { while (!pointer_mapping_->empty()) {
free((void**)(&pointer_mapping_->begin()->first)); auto ptr = pointer_mapping_->begin()->first;
auto size_and_type = pointer_mapping_->begin()->second;
free(&ptr, size_and_type.second == MemoryType::HOST);
} }
delete pointer_mapping_; delete pointer_mapping_;
} }
...@@ -229,18 +237,19 @@ public: ...@@ -229,18 +237,19 @@ public:
check_cuda_error(getSetDevice(o_device)); check_cuda_error(getSetDevice(o_device));
TM_LOG_DEBUG("malloc buffer %p with size %ld", ptr, size); TM_LOG_DEBUG("malloc buffer %p with size %ld", ptr, size);
pointer_mapping_->insert({getAddress(ptr), size}); pointer_mapping_->insert({getAddress(ptr), {size, is_host ? MemoryType::HOST : MemoryType::DEVICE}});
return ptr; return ptr;
} }
void free(void** ptr, bool is_host = false) const void free(void** ptr, bool _ = false) const
{ {
TM_LOG_DEBUG(__PRETTY_FUNCTION__); TM_LOG_DEBUG(__PRETTY_FUNCTION__);
void* address = getAddress(*ptr); void* address = getAddress(*ptr);
if (*ptr != nullptr) { if (*ptr != nullptr) {
int o_device = 0; int o_device = 0;
if (pointer_mapping_->count(address)) { if (pointer_mapping_->count(address)) {
const auto is_host = pointer_mapping_->at(address).second == MemoryType::HOST;
TM_LOG_DEBUG("Free buffer %p", address); TM_LOG_DEBUG("Free buffer %p", address);
check_cuda_error(getSetDevice(device_id_, &o_device)); check_cuda_error(getSetDevice(device_id_, &o_device));
if (is_host) { if (is_host) {
...@@ -361,7 +370,7 @@ public: ...@@ -361,7 +370,7 @@ public:
{ {
while (!pointer_mapping_->empty()) { while (!pointer_mapping_->empty()) {
void* ptr = pointer_mapping_->begin()->second.flat<uint8>().data(); void* ptr = pointer_mapping_->begin()->second.flat<uint8>().data();
free((void**)(&ptr)); free(&ptr);
} }
pointer_mapping_->clear(); pointer_mapping_->clear();
delete pointer_mapping_; delete pointer_mapping_;
...@@ -454,7 +463,7 @@ public: ...@@ -454,7 +463,7 @@ public:
TM_LOG_DEBUG(__PRETTY_FUNCTION__); TM_LOG_DEBUG(__PRETTY_FUNCTION__);
while (!pointer_mapping_->empty()) { while (!pointer_mapping_->empty()) {
void* ptr = pointer_mapping_->begin()->second.data_ptr(); void* ptr = pointer_mapping_->begin()->second.data_ptr();
free((void**)(&ptr)); free(&ptr);
} }
pointer_mapping_->clear(); pointer_mapping_->clear();
delete pointer_mapping_; delete pointer_mapping_;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment