Unverified Commit 0cc667e1 authored by akhoroshev, committed by GitHub

[feature] Graceful termination of background threads in LlamaV2 (#458)

* cuda allocator fix

* graceful termination

* lint and compilation fix
parent ce9e0756
@@ -126,6 +126,7 @@ LlamaV2<T>::LlamaV2(size_t head_num,
 template<typename T>
 LlamaV2<T>::~LlamaV2()
 {
+    shared_state_->request_queue.close();
     internal_thread_.join();
     delete decoder_;
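The destructor now follows the close-then-join idiom: closing the request queue wakes the background thread out of its blocking dequeue, so the `internal_thread_.join()` that follows is guaranteed to return instead of hanging forever. A minimal self-contained sketch of the same idiom (illustrative names, not TurboMind's actual classes):

#include <condition_variable>
#include <mutex>
#include <queue>
#include <thread>

class Worker {
public:
    Worker(): thread_([this] { loop(); }) {}

    ~Worker()
    {
        close();         // wake the worker out of its blocking wait
        thread_.join();  // now join() is guaranteed to return
    }

private:
    void loop()
    {
        std::unique_lock<std::mutex> lock(mutex_);
        while (true) {
            cv_.wait(lock, [this] { return closed_ || !jobs_.empty(); });
            if (jobs_.empty()) {
                return;  // closed and drained -> exit gracefully
            }
            jobs_.pop();  // process one job (elided)
        }
    }

    void close()
    {
        std::lock_guard<std::mutex> lock(mutex_);
        closed_ = true;
        cv_.notify_all();
    }

    std::queue<int>         jobs_;
    std::mutex              mutex_;
    std::condition_variable cv_;
    bool                    closed_ = false;
    std::thread             thread_;  // declared last: starts after the members it uses
};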
@@ -448,12 +449,24 @@ void LlamaV2<T>::internalThreadEntry(int device_id)
             request_queue.dequeue(stop_requests, infer_requests, free_slot_count, is_empty);
 
+            // request queue was closed
+            // and there are no unprocessed requests in the queue
+            if (is_empty && infer_requests.empty() && stop_requests.empty()) {
+                // rank 0 sets flag
+                shared_state_->should_stop = true;
+            }
+
             batch_.verifyRequests(stop_requests, infer_requests);
         }
 
+        // wait while rank-0 is dequeueing
         shared_state_->barrier->wait();
 
+        // exit if job is done
+        if (shared_state_->should_stop) {
+            return;
+        }
 
         bool modified = false;
         if (!(batch_.finishedCount() == 0 && stop_requests.empty() && infer_requests.empty())) {
@@ -486,8 +499,6 @@ void LlamaV2<T>::internalThreadEntry(int device_id)
             batch_.finish();
         }
     }
-
-    FT_CHECK(0);
 }
 
 template<typename T>
......
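Taken together, the three hunks above give the worker loop a clean exit: rank 0 alone dequeues and decides that the queue is closed and drained, the barrier publishes that decision to every rank, and all ranks return on the same iteration — which is why the old `FT_CHECK(0)` guarding the formerly unreachable end of the function can be dropped. A simplified sketch of the handshake, using C++20 `std::barrier` as a stand-in for TurboMind's `Barrier`:

#include <barrier>
#include <cstddef>
#include <thread>
#include <vector>

struct SharedState {
    std::barrier<> barrier;
    bool           should_stop = false;  // written by rank 0 only
    explicit SharedState(std::ptrdiff_t n): barrier(n) {}
};

void threadEntry(int rank, SharedState& state)
{
    while (true) {
        if (rank == 0) {
            // rank 0 alone inspects the queue; closed + drained => stop
            bool closed_and_empty = true;  // stand-in for the dequeue result
            if (closed_and_empty) {
                state.should_stop = true;
            }
        }
        // rendezvous: arrive_and_wait() synchronizes, so rank 0's write
        // above happens-before every rank's read below
        state.barrier.arrive_and_wait();
        if (state.should_stop) {
            return;  // all ranks exit on the same iteration
        }
        // ... one decoding step per iteration in the real loop ...
    }
}

int main()
{
    const int                n = 4;
    SharedState              state(n);
    std::vector<std::thread> ranks;
    for (int r = 0; r < n; ++r) {
        ranks.emplace_back(threadEntry, r, std::ref(state));
    }
    for (auto& t : ranks) {
        t.join();
    }
}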
@@ -46,6 +46,9 @@ public:
         std::vector<std::shared_ptr<Request>> stop_requests;
         RequestQueue                          request_queue;
         std::shared_ptr<Barrier>              barrier;
+
+        // rank 0 sets flag to true if there are no more tasks in the request_queue
+        bool should_stop = false;
     };
 
     ~LlamaV2();
......
@@ -44,6 +44,11 @@ public:
         futures.reserve(requests.size());
         {
             std::lock_guard<std::mutex> lock(mutex_);
+
+            if (closed_) {
+                throw std::runtime_error("Queue is closed");
+            }
+
             for (auto& r : requests) {
                 futures.push_back(r->signal.get_future());
                 if (r->stop_flag) {
@@ -65,7 +70,7 @@
     {
         std::unique_lock<std::mutex> lock(mutex_);
         if (blocking) {
-            cv_.wait(lock, [this] { return !(stop_queue_.empty() && infer_queue_.empty()); });
+            cv_.wait(lock, [this] { return !(stop_queue_.empty() && infer_queue_.empty() && closed_ == false); });
         }
 
         stop_requests.clear();
@@ -81,11 +86,19 @@
         }
     }
 
+    void close()
+    {
+        std::lock_guard<std::mutex> lock(mutex_);
+        closed_ = true;
+        cv_.notify_all();
+    }
+
 private:
     std::queue<std::shared_ptr<Request>> stop_queue_;
     std::queue<std::shared_ptr<Request>> infer_queue_;
    std::mutex                           mutex_;
     std::condition_variable              cv_;
+    bool                                 closed_ = false;
 };
 
 }  // namespace turbomind
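With closed_ folded into the wait predicate, a blocked dequeue now wakes on close() as well as on new work, and a consumer can read "woken with nothing dequeued" as the shutdown signal, while late producers get a loud exception instead of a silently dropped request. A hypothetical usage sketch against the class above (the enqueue/dequeue signatures are inferred from the call sites in this diff, and the last dequeue argument is assumed to be the blocking flag):

#include <memory>
#include <stdexcept>
#include <thread>
#include <vector>

int main()
{
    RequestQueue queue;

    // mirrors the call site in internalThreadEntry
    std::thread consumer([&queue] {
        std::vector<std::shared_ptr<Request>> stop_reqs, infer_reqs;
        queue.dequeue(stop_reqs, infer_reqs, /*free_slot_count=*/1, /*blocking=*/true);
        if (stop_reqs.empty() && infer_reqs.empty()) {
            // woken by close() with nothing dequeued -> shutdown signal
        }
    });

    queue.close();   // closed_ = true, cv_.notify_all() wakes the consumer
    consumer.join();

    try {
        queue.enqueue({std::make_shared<Request>()});  // late producer
    }
    catch (const std::runtime_error&) {
        // "Queue is closed": requests after shutdown are rejected, not lost
    }
}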
@@ -125,9 +125,15 @@ class Allocator;
 template<>
 class Allocator<AllocatorType::CUDA>: public IAllocator {
 private:
-    const int                          device_id_;
-    cudaStream_t                       stream_ = 0;  // initialize as default stream
-    std::unordered_map<void*, size_t>* pointer_mapping_;
+    enum class MemoryType
+    {
+        HOST,
+        DEVICE
+    };
+
+    const int                                                  device_id_;
+    cudaStream_t                                               stream_ = 0;  // initialize as default stream
+    std::unordered_map<void*, std::pair<size_t, MemoryType>>* pointer_mapping_;
 
     bool isExist(void* address) const
     {
@@ -136,10 +142,10 @@ private:
     ReallocType isReMalloc(void* address, size_t size) const
     {
         FT_CHECK(isExist(address));
-        if (pointer_mapping_->at(address) < size) {
+        if (pointer_mapping_->at(address).first < size) {
             return ReallocType::INCREASE;
         }
-        else if (pointer_mapping_->at(address) == size) {
+        else if (pointer_mapping_->at(address).first == size) {
             return ReallocType::REUSE;
         }
         else {
@@ -151,7 +157,7 @@ public:
     Allocator(int device_id): device_id_(device_id)
     {
         TM_LOG_DEBUG(__PRETTY_FUNCTION__);
-        pointer_mapping_ = new std::unordered_map<void*, size_t>();
+        pointer_mapping_ = new std::unordered_map<void*, std::pair<size_t, MemoryType>>();
 #if defined(CUDA_MEMORY_POOL_DISABLED)
         TM_LOG_WARNING(
             "Async cudaMalloc/Free is not supported before CUDA 11.2. Using Sync cudaMalloc/Free."
@@ -188,7 +194,9 @@ public:
     {
         TM_LOG_DEBUG(__PRETTY_FUNCTION__);
         while (!pointer_mapping_->empty()) {
-            free((void**)(&pointer_mapping_->begin()->first));
+            auto ptr           = pointer_mapping_->begin()->first;
+            auto size_and_type = pointer_mapping_->begin()->second;
+            free(&ptr, size_and_type.second == MemoryType::HOST);
         }
         delete pointer_mapping_;
     }
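This hunk fixes two latent problems at once: the old call took the address of the map's own key (a pointer into the container that free() then erases) and it freed every remaining block as device memory, even blocks allocated with is_host = true. Copying the key and value out first, then dispatching on the recorded MemoryType, avoids both. A generic, runnable illustration of the copy-before-erase pattern (standard-library stand-ins, not the allocator itself):

#include <cstdlib>
#include <unordered_map>

int main()
{
    std::unordered_map<void*, bool> live;  // ptr -> is_host
    live[std::malloc(16)] = true;
    live[std::malloc(32)] = false;

    while (!live.empty()) {
        // copy the key out *before* erasing; taking the address of the
        // map's own key (as the old code did) leaves a dangling pointer
        // once the entry is gone
        void* ptr     = live.begin()->first;
        bool  is_host = live.begin()->second;
        live.erase(ptr);  // in the allocator, free() performs this erase
        std::free(ptr);   // the real code dispatches on is_host here
        (void)is_host;
    }
}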
@@ -229,18 +237,19 @@
         check_cuda_error(getSetDevice(o_device));
         TM_LOG_DEBUG("malloc buffer %p with size %ld", ptr, size);
 
-        pointer_mapping_->insert({getAddress(ptr), size});
+        pointer_mapping_->insert({getAddress(ptr), {size, is_host ? MemoryType::HOST : MemoryType::DEVICE}});
 
         return ptr;
     }
 
-    void free(void** ptr, bool is_host = false) const
+    void free(void** ptr, bool _ = false) const
     {
         TM_LOG_DEBUG(__PRETTY_FUNCTION__);
         void* address = getAddress(*ptr);
         if (*ptr != nullptr) {
             int o_device = 0;
             if (pointer_mapping_->count(address)) {
+                const auto is_host = pointer_mapping_->at(address).second == MemoryType::HOST;
                 TM_LOG_DEBUG("Free buffer %p", address);
                 check_cuda_error(getSetDevice(device_id_, &o_device));
                 if (is_host) {
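Recording at malloc() time whether a block is host or device memory means free() no longer has to trust its caller's flag, which is exactly what let the old destructor send host blocks down the device path; the second parameter survives only for signature compatibility (hence the `_` placeholder). A minimal sketch of the record-at-allocation, dispatch-at-free scheme, with std::malloc/std::free standing in for the CUDA calls:

#include <cstdlib>
#include <unordered_map>
#include <utility>

enum class MemoryType { HOST, DEVICE };

class TrackingAllocator {
public:
    void* allocate(std::size_t size, bool is_host)
    {
        void* ptr = std::malloc(size);  // stand-in for cudaMallocHost / cudaMallocAsync
        map_.insert({ptr, {size, is_host ? MemoryType::HOST : MemoryType::DEVICE}});
        return ptr;
    }

    void release(void* ptr)
    {
        auto it = map_.find(ptr);
        if (it == map_.end()) {
            return;  // unknown pointer: not ours to free
        }
        // the recorded type picks the path; no caller-supplied flag needed
        if (it->second.second == MemoryType::HOST) {
            std::free(ptr);  // would be cudaFreeHost
        }
        else {
            std::free(ptr);  // would be cudaFreeAsync / cudaFree
        }
        map_.erase(it);
    }

private:
    std::unordered_map<void*, std::pair<std::size_t, MemoryType>> map_;
};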
@@ -361,7 +370,7 @@
     {
         while (!pointer_mapping_->empty()) {
             void* ptr = pointer_mapping_->begin()->second.flat<uint8>().data();
-            free((void**)(&ptr));
+            free(&ptr);
         }
         pointer_mapping_->clear();
         delete pointer_mapping_;
@@ -454,7 +463,7 @@
         TM_LOG_DEBUG(__PRETTY_FUNCTION__);
         while (!pointer_mapping_->empty()) {
             void* ptr = pointer_mapping_->begin()->second.data_ptr();
-            free((void**)(&ptr));
+            free(&ptr);
         }
         pointer_mapping_->clear();
         delete pointer_mapping_;
......