Unverified commit 19fea86c, authored by Lyu Han, committed by GitHub
Browse files

Change `shared_instances_` type from `weak_ptr` to `shared_ptr` (#507)

* change shared_instances_ from weak_ptr to shared_ptr

* update
parent 02684144
...@@ -273,7 +273,7 @@ LlamaTritonModel<T>::createModelInstance(int ...@@ -273,7 +273,7 @@ LlamaTritonModel<T>::createModelInstance(int
std::shared_ptr<LlamaTritonSharedModelInstance<T>> instance; std::shared_ptr<LlamaTritonSharedModelInstance<T>> instance;
{ {
std::lock_guard<std::mutex> lock(shared_mutexes_[device_id]); std::lock_guard<std::mutex> lock(shared_mutexes_[device_id]);
instance = shared_instances_[device_id].lock(); instance = shared_instances_[device_id];
if (!instance) { if (!instance) {
instance = createSharedModelInstance(device_id, rank, nccl_params, custom_all_reduce_comm); instance = createSharedModelInstance(device_id, rank, nccl_params, custom_all_reduce_comm);
instance->llm->setFfiLock(ffi_lock_); instance->llm->setFfiLock(ffi_lock_);
...@@ -347,7 +347,7 @@ LlamaTritonModel<T>::createNcclParams(const int node_id, const int device_id_sta ...@@ -347,7 +347,7 @@ LlamaTritonModel<T>::createNcclParams(const int node_id, const int device_id_sta
// create nccl group when there are non-occupied devices // create nccl group when there are non-occupied devices
for (int i = 0; i < device_count; ++i) { for (int i = 0; i < device_count; ++i) {
std::lock_guard<std::mutex> lock(shared_mutexes_[i]); std::lock_guard<std::mutex> lock(shared_mutexes_[i]);
if (shared_instances_[i].expired()) { if (shared_instances_[i] == nullptr) {
need_nccl_params = true; need_nccl_params = true;
break; break;
} }
......
...@@ -108,9 +108,8 @@ private: ...@@ -108,9 +108,8 @@ private:
std::shared_ptr<typename ft::LlamaV2<T>::SharedState> shared_state_; std::shared_ptr<typename ft::LlamaV2<T>::SharedState> shared_state_;
// weak_ptr is used so that the instances get released when all strong references are gone std::vector<std::shared_ptr<LlamaTritonSharedModelInstance<T>>> shared_instances_;
std::vector<std::weak_ptr<LlamaTritonSharedModelInstance<T>>> shared_instances_; std::deque<std::mutex> shared_mutexes_; // is locking really needed?
std::deque<std::mutex> shared_mutexes_; // is locking really needed?
bool is_fp16_; bool is_fp16_;
int enable_custom_all_reduce_ = 0; int enable_custom_all_reduce_ = 0;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment