Unverified commit 8065fe12, authored by Paul Hendricks, committed by GitHub
Browse files

refactor: Refactor discovery ModelManager to use `parking_lot::RwLock` (#2902)


Signed-off-by: Paul Hendricks <phendricks@nvidia.com>
parent 85b3dd44
...@@ -3,10 +3,10 @@ ...@@ -3,10 +3,10 @@
use std::{ use std::{
collections::{HashMap, HashSet}, collections::{HashMap, HashSet},
sync::{Arc, RwLock}, sync::Arc,
}; };
use parking_lot::Mutex; use parking_lot::{Mutex, RwLock};
use dynamo_runtime::component::Component; use dynamo_runtime::component::Component;
use dynamo_runtime::prelude::DistributedRuntimeProvider; use dynamo_runtime::prelude::DistributedRuntimeProvider;
...@@ -64,8 +64,8 @@ impl ModelManager { ...@@ -64,8 +64,8 @@ impl ModelManager {
} }
pub fn has_model_any(&self, model: &str) -> bool { pub fn has_model_any(&self, model: &str) -> bool {
self.chat_completion_engines.read().unwrap().contains(model) self.chat_completion_engines.read().contains(model)
|| self.completion_engines.read().unwrap().contains(model) || self.completion_engines.read().contains(model)
} }
pub fn model_display_names(&self) -> HashSet<String> { pub fn model_display_names(&self) -> HashSet<String> {
...@@ -77,15 +77,15 @@ impl ModelManager { ...@@ -77,15 +77,15 @@ impl ModelManager {
} }
pub fn list_chat_completions_models(&self) -> Vec<String> { pub fn list_chat_completions_models(&self) -> Vec<String> {
self.chat_completion_engines.read().unwrap().list() self.chat_completion_engines.read().list()
} }
pub fn list_completions_models(&self) -> Vec<String> { pub fn list_completions_models(&self) -> Vec<String> {
self.completion_engines.read().unwrap().list() self.completion_engines.read().list()
} }
pub fn list_embeddings_models(&self) -> Vec<String> { pub fn list_embeddings_models(&self) -> Vec<String> {
self.embeddings_engines.read().unwrap().list() self.embeddings_engines.read().list()
} }
pub fn add_completions_model( pub fn add_completions_model(
...@@ -93,7 +93,7 @@ impl ModelManager { ...@@ -93,7 +93,7 @@ impl ModelManager {
model: &str, model: &str,
engine: OpenAICompletionsStreamingEngine, engine: OpenAICompletionsStreamingEngine,
) -> Result<(), ModelManagerError> { ) -> Result<(), ModelManagerError> {
let mut clients = self.completion_engines.write().unwrap(); let mut clients = self.completion_engines.write();
clients.add(model, engine) clients.add(model, engine)
} }
...@@ -102,7 +102,7 @@ impl ModelManager { ...@@ -102,7 +102,7 @@ impl ModelManager {
model: &str, model: &str,
engine: OpenAIChatCompletionsStreamingEngine, engine: OpenAIChatCompletionsStreamingEngine,
) -> Result<(), ModelManagerError> { ) -> Result<(), ModelManagerError> {
let mut clients = self.chat_completion_engines.write().unwrap(); let mut clients = self.chat_completion_engines.write();
clients.add(model, engine) clients.add(model, engine)
} }
...@@ -111,22 +111,22 @@ impl ModelManager { ...@@ -111,22 +111,22 @@ impl ModelManager {
model: &str, model: &str,
engine: OpenAIEmbeddingsStreamingEngine, engine: OpenAIEmbeddingsStreamingEngine,
) -> Result<(), ModelManagerError> { ) -> Result<(), ModelManagerError> {
let mut clients = self.embeddings_engines.write().unwrap(); let mut clients = self.embeddings_engines.write();
clients.add(model, engine) clients.add(model, engine)
} }
pub fn remove_completions_model(&self, model: &str) -> Result<(), ModelManagerError> { pub fn remove_completions_model(&self, model: &str) -> Result<(), ModelManagerError> {
let mut clients = self.completion_engines.write().unwrap(); let mut clients = self.completion_engines.write();
clients.remove(model) clients.remove(model)
} }
pub fn remove_chat_completions_model(&self, model: &str) -> Result<(), ModelManagerError> { pub fn remove_chat_completions_model(&self, model: &str) -> Result<(), ModelManagerError> {
let mut clients = self.chat_completion_engines.write().unwrap(); let mut clients = self.chat_completion_engines.write();
clients.remove(model) clients.remove(model)
} }
pub fn remove_embeddings_model(&self, model: &str) -> Result<(), ModelManagerError> { pub fn remove_embeddings_model(&self, model: &str) -> Result<(), ModelManagerError> {
let mut clients = self.embeddings_engines.write().unwrap(); let mut clients = self.embeddings_engines.write();
clients.remove(model) clients.remove(model)
} }
...@@ -136,7 +136,6 @@ impl ModelManager { ...@@ -136,7 +136,6 @@ impl ModelManager {
) -> Result<OpenAIEmbeddingsStreamingEngine, ModelManagerError> { ) -> Result<OpenAIEmbeddingsStreamingEngine, ModelManagerError> {
self.embeddings_engines self.embeddings_engines
.read() .read()
.unwrap()
.get(model) .get(model)
.cloned() .cloned()
.ok_or(ModelManagerError::ModelNotFound(model.to_string())) .ok_or(ModelManagerError::ModelNotFound(model.to_string()))
...@@ -148,7 +147,6 @@ impl ModelManager { ...@@ -148,7 +147,6 @@ impl ModelManager {
) -> Result<OpenAICompletionsStreamingEngine, ModelManagerError> { ) -> Result<OpenAICompletionsStreamingEngine, ModelManagerError> {
self.completion_engines self.completion_engines
.read() .read()
.unwrap()
.get(model) .get(model)
.cloned() .cloned()
.ok_or(ModelManagerError::ModelNotFound(model.to_string())) .ok_or(ModelManagerError::ModelNotFound(model.to_string()))
...@@ -160,7 +158,6 @@ impl ModelManager { ...@@ -160,7 +158,6 @@ impl ModelManager {
) -> Result<OpenAIChatCompletionsStreamingEngine, ModelManagerError> { ) -> Result<OpenAIChatCompletionsStreamingEngine, ModelManagerError> {
self.chat_completion_engines self.chat_completion_engines
.read() .read()
.unwrap()
.get(model) .get(model)
.cloned() .cloned()
.ok_or(ModelManagerError::ModelNotFound(model.to_string())) .ok_or(ModelManagerError::ModelNotFound(model.to_string()))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment