Unverified Commit a54e3e09 authored by akhoroshev's avatar akhoroshev Committed by GitHub
Browse files

fix race condition (#460)

parent 327deaee
......@@ -30,6 +30,9 @@ void LlamaBatch<T>::verifyRequests(std::vector<std::shared_ptr<Request>>& stop_r
auto invalidate = [](const char* type, std::shared_ptr<Request>& req, int ec) {
TM_LOG_WARNING("[verifyRequests] Skipping invalid %s request for id %ld, code = %d", type, (long)req->id, ec);
// We don't need a barrier there because
// this lambda is called only for new requests
// which are visible only for rank = 0 thread.
req->signal.set_value(ec);
req.reset();
};
......@@ -139,6 +142,12 @@ void LlamaBatch<T>::handleStopRequests(const std::vector<std::shared_ptr<Request
check_cuda_error(cudaMemsetAsync(sequence_length.getPtr<int>(), 0, sizeof(int), stream_));
check_cuda_error(cudaStreamSynchronize(stream_));
}
// Once the signal is set, threads in LlamaV2::forward can exit
// and free the input/output tensors.
// Therefore we must make sure that no thread in LlamaV2::internalThreadEntry
// is still accessing those tensors before the signal is set.
llama_->shared_state_->barrier->wait();
if (rank_ == 0) {
r->signal.set_value(ec);
}
......@@ -1112,6 +1121,11 @@ void LlamaBatch<T>::finishRequest(int index, bool force_end)
llama_->kv_cache_mgr_->update(cached_seq_[index], stream_);
}
// Once the signal is set, threads in LlamaV2::forward can exit
// and free the input/output tensors.
// Therefore we must make sure that no thread in LlamaV2::internalThreadEntry
// is still accessing those tensors before the signal is set.
llama_->shared_state_->barrier->wait();
if (rank_ == 0) {
requests_[index]->signal.set_value(0);
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment