Unverified Commit 8065fe12 authored by Paul Hendricks's avatar Paul Hendricks Committed by GitHub
Browse files

refactor: Refactor discovery ModelManager to use `parking_lot::RwLock` (#2902)


Signed-off-by: default avatarPaul Hendricks <phendricks@nvidia.com>
parent 85b3dd44
......@@ -3,10 +3,10 @@
use std::{
collections::{HashMap, HashSet},
sync::{Arc, RwLock},
sync::Arc,
};
use parking_lot::Mutex;
use parking_lot::{Mutex, RwLock};
use dynamo_runtime::component::Component;
use dynamo_runtime::prelude::DistributedRuntimeProvider;
......@@ -64,8 +64,8 @@ impl ModelManager {
}
pub fn has_model_any(&self, model: &str) -> bool {
self.chat_completion_engines.read().unwrap().contains(model)
|| self.completion_engines.read().unwrap().contains(model)
self.chat_completion_engines.read().contains(model)
|| self.completion_engines.read().contains(model)
}
pub fn model_display_names(&self) -> HashSet<String> {
......@@ -77,15 +77,15 @@ impl ModelManager {
}
pub fn list_chat_completions_models(&self) -> Vec<String> {
self.chat_completion_engines.read().unwrap().list()
self.chat_completion_engines.read().list()
}
pub fn list_completions_models(&self) -> Vec<String> {
self.completion_engines.read().unwrap().list()
self.completion_engines.read().list()
}
pub fn list_embeddings_models(&self) -> Vec<String> {
self.embeddings_engines.read().unwrap().list()
self.embeddings_engines.read().list()
}
pub fn add_completions_model(
......@@ -93,7 +93,7 @@ impl ModelManager {
model: &str,
engine: OpenAICompletionsStreamingEngine,
) -> Result<(), ModelManagerError> {
let mut clients = self.completion_engines.write().unwrap();
let mut clients = self.completion_engines.write();
clients.add(model, engine)
}
......@@ -102,7 +102,7 @@ impl ModelManager {
model: &str,
engine: OpenAIChatCompletionsStreamingEngine,
) -> Result<(), ModelManagerError> {
let mut clients = self.chat_completion_engines.write().unwrap();
let mut clients = self.chat_completion_engines.write();
clients.add(model, engine)
}
......@@ -111,22 +111,22 @@ impl ModelManager {
model: &str,
engine: OpenAIEmbeddingsStreamingEngine,
) -> Result<(), ModelManagerError> {
let mut clients = self.embeddings_engines.write().unwrap();
let mut clients = self.embeddings_engines.write();
clients.add(model, engine)
}
pub fn remove_completions_model(&self, model: &str) -> Result<(), ModelManagerError> {
let mut clients = self.completion_engines.write().unwrap();
let mut clients = self.completion_engines.write();
clients.remove(model)
}
pub fn remove_chat_completions_model(&self, model: &str) -> Result<(), ModelManagerError> {
let mut clients = self.chat_completion_engines.write().unwrap();
let mut clients = self.chat_completion_engines.write();
clients.remove(model)
}
pub fn remove_embeddings_model(&self, model: &str) -> Result<(), ModelManagerError> {
let mut clients = self.embeddings_engines.write().unwrap();
let mut clients = self.embeddings_engines.write();
clients.remove(model)
}
......@@ -136,7 +136,6 @@ impl ModelManager {
) -> Result<OpenAIEmbeddingsStreamingEngine, ModelManagerError> {
self.embeddings_engines
.read()
.unwrap()
.get(model)
.cloned()
.ok_or(ModelManagerError::ModelNotFound(model.to_string()))
......@@ -148,7 +147,6 @@ impl ModelManager {
) -> Result<OpenAICompletionsStreamingEngine, ModelManagerError> {
self.completion_engines
.read()
.unwrap()
.get(model)
.cloned()
.ok_or(ModelManagerError::ModelNotFound(model.to_string()))
......@@ -160,7 +158,6 @@ impl ModelManager {
) -> Result<OpenAIChatCompletionsStreamingEngine, ModelManagerError> {
self.chat_completion_engines
.read()
.unwrap()
.get(model)
.cloned()
.ok_or(ModelManagerError::ModelNotFound(model.to_string()))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment