Unverified commit 27fad26f, authored by Olga Andreeva, committed by GitHub
Browse files

refactor: Split ModelType to ModelInput for request and response type;...


refactor: Split ModelType to ModelInput for request and response type; ModelType for the supported workloads (#2714)
Signed-off-by: Guan Luo <gluo@nvidia.com>
Signed-off-by: GuanLuo <41310872+GuanLuo@users.noreply.github.com>
Co-authored-by: Guan Luo <gluo@nvidia.com>
Co-authored-by: GuanLuo <41310872+GuanLuo@users.noreply.github.com>
parent b97db875
...@@ -220,6 +220,36 @@ pub async fn build_routed_pipeline<Req, Resp>( ...@@ -220,6 +220,36 @@ pub async fn build_routed_pipeline<Req, Resp>(
busy_threshold: Option<f64>, busy_threshold: Option<f64>,
chooser: Option<Arc<KvRouter>>, chooser: Option<Arc<KvRouter>>,
) -> anyhow::Result<ServiceEngine<SingleIn<Req>, ManyOut<Annotated<Resp>>>> ) -> anyhow::Result<ServiceEngine<SingleIn<Req>, ManyOut<Annotated<Resp>>>>
where
Req: Data,
Resp: Data,
OpenAIPreprocessor: Operator<
Context<Req>,
Pin<Box<dyn AsyncEngineStream<Annotated<Resp>>>>,
Context<PreprocessedRequest>,
Pin<Box<dyn AsyncEngineStream<Annotated<BackendOutput>>>>,
>,
{
let preprocessor = OpenAIPreprocessor::new(card.clone()).await?;
build_routed_pipeline_with_preprocessor(
card,
client,
router_mode,
busy_threshold,
chooser,
preprocessor,
)
.await
}
pub async fn build_routed_pipeline_with_preprocessor<Req, Resp>(
card: &ModelDeploymentCard,
client: &Client,
router_mode: RouterMode,
busy_threshold: Option<f64>,
chooser: Option<Arc<KvRouter>>,
preprocessor: Arc<OpenAIPreprocessor>,
) -> anyhow::Result<ServiceEngine<SingleIn<Req>, ManyOut<Annotated<Resp>>>>
where where
Req: Data, Req: Data,
Resp: Data, Resp: Data,
...@@ -231,7 +261,7 @@ where ...@@ -231,7 +261,7 @@ where
>, >,
{ {
let frontend = SegmentSource::<SingleIn<Req>, ManyOut<Annotated<Resp>>>::new(); let frontend = SegmentSource::<SingleIn<Req>, ManyOut<Annotated<Resp>>>::new();
let preprocessor = OpenAIPreprocessor::new(card.clone()).await?.into_operator(); let preprocessor_op = preprocessor.into_operator();
let backend = Backend::from_mdc(card.clone()).await?.into_operator(); let backend = Backend::from_mdc(card.clone()).await?.into_operator();
let migration = Migration::from_mdc(card.clone()).await?.into_operator(); let migration = Migration::from_mdc(card.clone()).await?.into_operator();
let router = let router =
...@@ -255,13 +285,13 @@ where ...@@ -255,13 +285,13 @@ where
}; };
let engine = frontend let engine = frontend
.link(preprocessor.forward_edge())? .link(preprocessor_op.forward_edge())?
.link(backend.forward_edge())? .link(backend.forward_edge())?
.link(migration.forward_edge())? .link(migration.forward_edge())?
.link(service_backend)? .link(service_backend)?
.link(migration.backward_edge())? .link(migration.backward_edge())?
.link(backend.backward_edge())? .link(backend.backward_edge())?
.link(preprocessor.backward_edge())? .link(preprocessor_op.backward_edge())?
.link(frontend)?; .link(frontend)?;
Ok(engine) Ok(engine)
} }
......
...@@ -6,7 +6,7 @@ use std::{future::Future, pin::Pin, sync::Arc}; ...@@ -6,7 +6,7 @@ use std::{future::Future, pin::Pin, sync::Arc};
use crate::{ use crate::{
backend::Backend, backend::Backend,
engines::StreamingEngineAdapter, engines::StreamingEngineAdapter,
model_type::ModelType, model_type::{ModelInput, ModelType},
preprocessor::{BackendOutput, PreprocessedRequest}, preprocessor::{BackendOutput, PreprocessedRequest},
types::{ types::{
Annotated, Annotated,
...@@ -55,7 +55,9 @@ pub async fn run( ...@@ -55,7 +55,9 @@ pub async fn run(
>::for_engine(engine)?; >::for_engine(engine)?;
if !is_static { if !is_static {
model.attach(&endpoint, ModelType::Chat).await?; model
.attach(&endpoint, ModelType::Chat, ModelInput::Text)
.await?;
} }
let fut_chat = endpoint.endpoint_builder().handler(ingress_chat).start(); let fut_chat = endpoint.endpoint_builder().handler(ingress_chat).start();
...@@ -83,8 +85,13 @@ pub async fn run( ...@@ -83,8 +85,13 @@ pub async fn run(
let ingress = Ingress::for_pipeline(pipeline)?; let ingress = Ingress::for_pipeline(pipeline)?;
if !is_static { if !is_static {
model.attach(&endpoint, ModelType::Backend).await?; // Default to supporting both Chat and Completions endpoints
let model_type = ModelType::Chat | ModelType::Completions;
model
.attach(&endpoint, model_type, ModelInput::Tokens)
.await?;
} }
let fut = endpoint.endpoint_builder().handler(ingress).start(); let fut = endpoint.endpoint_builder().handler(ingress).start();
(Box::pin(fut), Some(model.card().clone())) (Box::pin(fut), Some(model.card().clone()))
......
...@@ -10,7 +10,6 @@ use crate::{ ...@@ -10,7 +10,6 @@ use crate::{
entrypoint::{self, EngineConfig, input::common}, entrypoint::{self, EngineConfig, input::common},
http::service::service_v2::{self, HttpService}, http::service::service_v2::{self, HttpService},
kv_router::KvRouterConfig, kv_router::KvRouterConfig,
model_type::ModelType,
namespace::is_global_namespace, namespace::is_global_namespace,
types::openai::{ types::openai::{
chat_completions::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse}, chat_completions::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse},
...@@ -247,23 +246,17 @@ fn update_http_endpoints(service: Arc<HttpService>, model_type: ModelUpdate) { ...@@ -247,23 +246,17 @@ fn update_http_endpoints(service: Arc<HttpService>, model_type: ModelUpdate) {
model_type model_type
); );
match model_type { match model_type {
ModelUpdate::Added(model_type) => match model_type { ModelUpdate::Added(model_type) => {
ModelType::Backend => { // Handle all supported endpoint types, not just the first one
service.enable_model_endpoint(EndpointType::Chat, true); for endpoint_type in model_type.as_endpoint_types() {
service.enable_model_endpoint(EndpointType::Completion, true); service.enable_model_endpoint(endpoint_type, true);
} }
_ => { }
service.enable_model_endpoint(model_type.as_endpoint_type(), true); ModelUpdate::Removed(model_type) => {
} // Handle all supported endpoint types, not just the first one
}, for endpoint_type in model_type.as_endpoint_types() {
ModelUpdate::Removed(model_type) => match model_type { service.enable_model_endpoint(endpoint_type, false);
ModelType::Backend => {
service.enable_model_endpoint(EndpointType::Chat, false);
service.enable_model_endpoint(EndpointType::Completion, false);
}
_ => {
service.enable_model_endpoint(model_type.as_endpoint_type(), false);
} }
}, }
} }
} }
...@@ -18,7 +18,7 @@ use crate::discovery::ModelEntry; ...@@ -18,7 +18,7 @@ use crate::discovery::ModelEntry;
use crate::entrypoint::RouterConfig; use crate::entrypoint::RouterConfig;
use crate::mocker::protocols::MockEngineArgs; use crate::mocker::protocols::MockEngineArgs;
use crate::model_card::{self, ModelDeploymentCard}; use crate::model_card::{self, ModelDeploymentCard};
use crate::model_type::ModelType; use crate::model_type::{ModelInput, ModelType};
use crate::request_template::RequestTemplate; use crate::request_template::RequestTemplate;
mod network_name; mod network_name;
...@@ -403,6 +403,7 @@ impl LocalModel { ...@@ -403,6 +403,7 @@ impl LocalModel {
&mut self, &mut self,
endpoint: &Endpoint, endpoint: &Endpoint,
model_type: ModelType, model_type: ModelType,
model_input: ModelInput,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
// A static component doesn't have an etcd_client because it doesn't need to register // A static component doesn't have an etcd_client because it doesn't need to register
let Some(etcd_client) = endpoint.drt().etcd_client() else { let Some(etcd_client) = endpoint.drt().etcd_client() else {
...@@ -431,6 +432,7 @@ impl LocalModel { ...@@ -431,6 +432,7 @@ impl LocalModel {
endpoint_id: endpoint.id(), endpoint_id: endpoint.id(),
model_type, model_type,
runtime_config: Some(self.runtime_config.clone()), runtime_config: Some(self.runtime_config.clone()),
model_input,
}; };
etcd_client etcd_client
.kv_create( .kv_create(
......
...@@ -13,41 +13,107 @@ ...@@ -13,41 +13,107 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use bitflags::bitflags;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::fmt;
use strum::Display; use strum::Display;
#[derive(Copy, Debug, Clone, Display, Serialize, Deserialize, Eq, PartialEq)] bitflags! {
pub enum ModelType { /// Represents the set of model capabilities (endpoints) a model can support.
// Chat Completions API ///
Chat, /// This type is implemented using `bitflags` instead of a plain `enum`
/// Older completions API /// so that multiple capabilities can be combined in a single value:
Completion, ///
/// Embeddings API /// - `ModelType::Chat`
Embedding, /// - `ModelType::Completions`
// Pre-processed requests /// - `ModelType::Embedding`
Backend, ///
/// For example, a model that supports both chat and completions can be
/// expressed as:
///
/// ```rust
/// use dynamo_llm::model_type::ModelType;
/// let mt = ModelType::Chat | ModelType::Completions;
/// assert!(mt.supports_chat());
/// assert!(mt.supports_completions());
/// ```
///
/// Using bitflags avoids deep branching on a single enum variant,
/// simplifies checks like `supports_chat()`, and enables efficient,
/// type-safe combinations of multiple endpoint types within a single byte.
#[derive(Copy, Debug, Clone, Serialize, Deserialize, Eq, PartialEq)]
pub struct ModelType: u8 {
const Chat = 1 << 0;
const Completions = 1 << 1;
const Embedding = 1 << 2;
}
} }
impl ModelType { impl ModelType {
pub fn as_str(&self) -> &str { pub fn as_str(&self) -> String {
match self { self.as_vec().join(",")
Self::Chat => "chat", }
Self::Completion => "completion",
Self::Embedding => "embedding", pub fn supports_chat(&self) -> bool {
Self::Backend => "backend", self.contains(ModelType::Chat)
}
pub fn supports_completions(&self) -> bool {
self.contains(ModelType::Completions)
}
pub fn supports_embedding(&self) -> bool {
self.contains(ModelType::Embedding)
}
pub fn as_vec(&self) -> Vec<&'static str> {
let mut result = Vec::new();
if self.supports_chat() {
result.push("chat");
}
if self.supports_completions() {
result.push("completions");
}
if self.supports_embedding() {
result.push("embedding");
} }
result
} }
pub fn all() -> Vec<Self> { /// Returns all endpoint types supported by this model type.
vec![Self::Chat, Self::Completion, Self::Embedding, Self::Backend] /// This properly handles combinations like Chat | Completions.
pub fn as_endpoint_types(&self) -> Vec<crate::endpoint_type::EndpointType> {
let mut endpoint_types = Vec::new();
if self.contains(Self::Chat) {
endpoint_types.push(crate::endpoint_type::EndpointType::Chat);
}
if self.contains(Self::Completions) {
endpoint_types.push(crate::endpoint_type::EndpointType::Completion);
}
if self.contains(Self::Embedding) {
endpoint_types.push(crate::endpoint_type::EndpointType::Embedding);
}
endpoint_types
} }
}
pub fn as_endpoint_type(&self) -> crate::endpoint_type::EndpointType { impl fmt::Display for ModelType {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}", self.as_str())
}
}
/// The form of input a model accepts: raw text or pre-processed input.
///
/// `Display` is derived via `strum`, so each variant formats as its own name.
#[derive(Copy, Debug, Clone, Display, Serialize, Deserialize, Eq, PartialEq)]
pub enum ModelInput {
/// Raw text input
Text,
/// Pre-processed input
Tokens,
}
impl ModelInput {
pub fn as_str(&self) -> &str {
match self { match self {
Self::Chat => crate::endpoint_type::EndpointType::Chat, Self::Text => "text",
Self::Completion => crate::endpoint_type::EndpointType::Completion, Self::Tokens => "tokens",
Self::Embedding => crate::endpoint_type::EndpointType::Embedding,
Self::Backend => panic!("Backend model type does not map to an endpoint type"),
} }
} }
} }
...@@ -99,9 +99,18 @@ pub struct OpenAIPreprocessor { ...@@ -99,9 +99,18 @@ pub struct OpenAIPreprocessor {
impl OpenAIPreprocessor { impl OpenAIPreprocessor {
pub async fn new(mdc: ModelDeploymentCard) -> Result<Arc<Self>> { pub async fn new(mdc: ModelDeploymentCard) -> Result<Arc<Self>> {
let mdcsum = mdc.mdcsum();
let formatter = PromptFormatter::from_mdc(mdc.clone()).await?; let formatter = PromptFormatter::from_mdc(mdc.clone()).await?;
let PromptFormatter::OAI(formatter) = formatter; match formatter {
PromptFormatter::OAI(formatter) => Self::new_with_formatter(mdc, formatter).await,
}
}
pub async fn new_with_formatter(
mdc: ModelDeploymentCard,
formatter: Arc<dyn OAIPromptFormatter>,
) -> Result<Arc<Self>> {
let mdcsum = mdc.mdcsum();
let tokenizer = match &mdc.tokenizer { let tokenizer = match &mdc.tokenizer {
Some(TokenizerKind::HfTokenizerJson(file)) => HuggingFaceTokenizer::from_file(file)?, Some(TokenizerKind::HfTokenizerJson(file)) => HuggingFaceTokenizer::from_file(file)?,
Some(TokenizerKind::GGUF(tokenizer)) => { Some(TokenizerKind::GGUF(tokenizer)) => {
...@@ -129,7 +138,6 @@ impl OpenAIPreprocessor { ...@@ -129,7 +138,6 @@ impl OpenAIPreprocessor {
mdcsum, mdcsum,
})) }))
} }
/// Encode a string to its tokens /// Encode a string to its tokens
pub fn tokenize(&self, s: &str) -> anyhow::Result<Encoding> { pub fn tokenize(&self, s: &str) -> anyhow::Result<Encoding> {
self.tokenizer.encode(s) self.tokenizer.encode(s)
......
...@@ -92,3 +92,37 @@ pub trait OAIPromptFormatter: Send + Sync + 'static { ...@@ -92,3 +92,37 @@ pub trait OAIPromptFormatter: Send + Sync + 'static {
pub enum PromptFormatter { pub enum PromptFormatter {
OAI(Arc<dyn OAIPromptFormatter>), OAI(Arc<dyn OAIPromptFormatter>),
} }
/// No-op prompt formatter, used for models without a `chat_template`.
///
/// Its `render` implementation applies no template: it hands back the
/// `content` of the first message in the request unchanged.
#[derive(Debug, Default)]
pub struct NoOpFormatter;
impl OAIPromptFormatter for NoOpFormatter {
fn supports_add_generation_prompt(&self) -> bool {
false
}
fn render(&self, req: &dyn OAIChatLikeRequest) -> Result<String> {
let messages = req.messages();
let first_message = messages
.get_item_by_index(0)
.map_err(|_| anyhow::Error::msg("No message at index 0 or messages array is empty"))?;
let content = first_message
.get_attr("content")
.map_err(|_| anyhow::Error::msg("First message has no 'content' field"))?;
let content_str = content
.as_str()
.ok_or_else(|| anyhow::Error::msg("Message content is not a string"))?
.to_string();
Ok(content_str)
}
}
impl PromptFormatter {
pub fn no_op() -> Self {
Self::OAI(Arc::new(NoOpFormatter))
}
}
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
use dynamo_llm::{ use dynamo_llm::{
discovery::ModelEntry, discovery::ModelEntry,
model_type::ModelType, model_type::{ModelInput, ModelType},
namespace::{GLOBAL_NAMESPACE, is_global_namespace}, namespace::{GLOBAL_NAMESPACE, is_global_namespace},
}; };
use dynamo_runtime::protocols::EndpointId; use dynamo_runtime::protocols::EndpointId;
...@@ -18,6 +18,7 @@ fn create_test_model_entry( ...@@ -18,6 +18,7 @@ fn create_test_model_entry(
component: &str, component: &str,
endpoint_name: &str, endpoint_name: &str,
model_type: ModelType, model_type: ModelType,
model_input: ModelInput,
) -> ModelEntry { ) -> ModelEntry {
ModelEntry { ModelEntry {
name: name.to_string(), name: name.to_string(),
...@@ -27,6 +28,7 @@ fn create_test_model_entry( ...@@ -27,6 +28,7 @@ fn create_test_model_entry(
name: endpoint_name.to_string(), name: endpoint_name.to_string(),
}, },
model_type, model_type,
model_input,
runtime_config: None, runtime_config: None,
} }
} }
...@@ -41,6 +43,7 @@ fn test_namespace_filtering_behavior() { ...@@ -41,6 +43,7 @@ fn test_namespace_filtering_behavior() {
"backend", "backend",
"generate", "generate",
ModelType::Chat, ModelType::Chat,
ModelInput::Tokens,
), ),
create_test_model_entry( create_test_model_entry(
"model-2", "model-2",
...@@ -48,13 +51,15 @@ fn test_namespace_filtering_behavior() { ...@@ -48,13 +51,15 @@ fn test_namespace_filtering_behavior() {
"backend", "backend",
"generate", "generate",
ModelType::Chat, ModelType::Chat,
ModelInput::Tokens,
), ),
create_test_model_entry( create_test_model_entry(
"model-3", "model-3",
"dynamo", "dynamo",
"backend", "backend",
"generate", "generate",
ModelType::Completion, ModelType::Completions,
ModelInput::Tokens,
), ),
create_test_model_entry( create_test_model_entry(
"model-4", "model-4",
...@@ -62,6 +67,7 @@ fn test_namespace_filtering_behavior() { ...@@ -62,6 +67,7 @@ fn test_namespace_filtering_behavior() {
"backend", "backend",
"generate", "generate",
ModelType::Embedding, ModelType::Embedding,
ModelInput::Tokens,
), ),
]; ];
...@@ -165,6 +171,7 @@ fn test_model_discovery_scoping_scenarios() { ...@@ -165,6 +171,7 @@ fn test_model_discovery_scoping_scenarios() {
"backend", "backend",
"generate", "generate",
ModelType::Chat, ModelType::Chat,
ModelInput::Tokens,
), ),
create_test_model_entry( create_test_model_entry(
"mistral-7b", "mistral-7b",
...@@ -172,6 +179,7 @@ fn test_model_discovery_scoping_scenarios() { ...@@ -172,6 +179,7 @@ fn test_model_discovery_scoping_scenarios() {
"backend", "backend",
"generate", "generate",
ModelType::Chat, ModelType::Chat,
ModelInput::Tokens,
), ),
create_test_model_entry( create_test_model_entry(
"gpt-3.5", "gpt-3.5",
...@@ -179,8 +187,16 @@ fn test_model_discovery_scoping_scenarios() { ...@@ -179,8 +187,16 @@ fn test_model_discovery_scoping_scenarios() {
"backend", "backend",
"generate", "generate",
ModelType::Chat, ModelType::Chat,
ModelInput::Tokens,
),
create_test_model_entry(
"claude-3",
"dynamo",
"backend",
"generate",
ModelType::Chat,
ModelInput::Tokens,
), ),
create_test_model_entry("claude-3", "dynamo", "backend", "generate", ModelType::Chat),
]; ];
let visible_models: Vec<&ModelEntry> = available_models let visible_models: Vec<&ModelEntry> = available_models
...@@ -228,14 +244,29 @@ fn test_namespace_boundary_conditions() { ...@@ -228,14 +244,29 @@ fn test_namespace_boundary_conditions() {
// Test edge cases and boundary conditions for namespace handling // Test edge cases and boundary conditions for namespace handling
let test_models = vec![ let test_models = vec![
create_test_model_entry("model-1", "", "backend", "generate", ModelType::Chat), // Empty namespace create_test_model_entry(
create_test_model_entry("model-2", "dynamo", "backend", "generate", ModelType::Chat), // Global namespace "model-1",
"",
"backend",
"generate",
ModelType::Chat,
ModelInput::Tokens,
), // Empty namespace
create_test_model_entry(
"model-2",
"dynamo",
"backend",
"generate",
ModelType::Chat,
ModelInput::Tokens,
), // Global namespace
create_test_model_entry( create_test_model_entry(
"model-3", "model-3",
"ns-with-special-chars_123", "ns-with-special-chars_123",
"backend", "backend",
"generate", "generate",
ModelType::Chat, ModelType::Chat,
ModelInput::Tokens,
), ),
]; ];
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
use dynamo_llm::{ use dynamo_llm::{
discovery::ModelEntry, discovery::ModelEntry,
model_type::ModelType, model_type::{ModelInput, ModelType},
namespace::{GLOBAL_NAMESPACE, is_global_namespace}, namespace::{GLOBAL_NAMESPACE, is_global_namespace},
}; };
use dynamo_runtime::protocols::EndpointId; use dynamo_runtime::protocols::EndpointId;
...@@ -53,6 +53,7 @@ fn create_test_model_entry( ...@@ -53,6 +53,7 @@ fn create_test_model_entry(
component: &str, component: &str,
endpoint_name: &str, endpoint_name: &str,
model_type: ModelType, model_type: ModelType,
model_input: ModelInput,
) -> ModelEntry { ) -> ModelEntry {
ModelEntry { ModelEntry {
name: name.to_string(), name: name.to_string(),
...@@ -62,6 +63,7 @@ fn create_test_model_entry( ...@@ -62,6 +63,7 @@ fn create_test_model_entry(
name: endpoint_name.to_string(), name: endpoint_name.to_string(),
}, },
model_type, model_type,
model_input,
runtime_config: None, runtime_config: None,
} }
} }
...@@ -75,6 +77,7 @@ fn test_model_entry_creation_with_different_namespaces() { ...@@ -75,6 +77,7 @@ fn test_model_entry_creation_with_different_namespaces() {
"backend", "backend",
"generate", "generate",
ModelType::Chat, ModelType::Chat,
ModelInput::Tokens,
); );
assert_eq!(model_vllm.name, "test-model-1"); assert_eq!(model_vllm.name, "test-model-1");
...@@ -82,6 +85,7 @@ fn test_model_entry_creation_with_different_namespaces() { ...@@ -82,6 +85,7 @@ fn test_model_entry_creation_with_different_namespaces() {
assert_eq!(model_vllm.endpoint_id.component, "backend"); assert_eq!(model_vllm.endpoint_id.component, "backend");
assert_eq!(model_vllm.endpoint_id.name, "generate"); assert_eq!(model_vllm.endpoint_id.name, "generate");
assert_eq!(model_vllm.model_type, ModelType::Chat); assert_eq!(model_vllm.model_type, ModelType::Chat);
assert_eq!(model_vllm.model_input, ModelInput::Tokens);
// Test creating ModelEntry with global namespace // Test creating ModelEntry with global namespace
let model_global = create_test_model_entry( let model_global = create_test_model_entry(
...@@ -89,14 +93,16 @@ fn test_model_entry_creation_with_different_namespaces() { ...@@ -89,14 +93,16 @@ fn test_model_entry_creation_with_different_namespaces() {
"dynamo", "dynamo",
"frontend", "frontend",
"http", "http",
ModelType::Completion, ModelType::Completions,
ModelInput::Text,
); );
assert_eq!(model_global.name, "test-model-2"); assert_eq!(model_global.name, "test-model-2");
assert_eq!(model_global.endpoint_id.namespace, "dynamo"); assert_eq!(model_global.endpoint_id.namespace, "dynamo");
assert_eq!(model_global.endpoint_id.component, "frontend"); assert_eq!(model_global.endpoint_id.component, "frontend");
assert_eq!(model_global.endpoint_id.name, "http"); assert_eq!(model_global.endpoint_id.name, "http");
assert_eq!(model_global.model_type, ModelType::Completion); assert_eq!(model_global.model_type, ModelType::Completions);
assert_eq!(model_global.model_input, ModelInput::Text);
} }
#[test] #[test]
...@@ -109,6 +115,7 @@ fn test_namespace_filtering_logic() { ...@@ -109,6 +115,7 @@ fn test_namespace_filtering_logic() {
"backend", "backend",
"generate", "generate",
ModelType::Chat, ModelType::Chat,
ModelInput::Tokens,
), ),
create_test_model_entry( create_test_model_entry(
"model-2", "model-2",
...@@ -116,9 +123,24 @@ fn test_namespace_filtering_logic() { ...@@ -116,9 +123,24 @@ fn test_namespace_filtering_logic() {
"backend", "backend",
"generate", "generate",
ModelType::Chat, ModelType::Chat,
ModelInput::Tokens,
),
create_test_model_entry(
"model-3",
"dynamo",
"backend",
"generate",
ModelType::Chat,
ModelInput::Tokens,
),
create_test_model_entry(
"model-4",
"",
"backend",
"generate",
ModelType::Chat,
ModelInput::Tokens,
), ),
create_test_model_entry("model-3", "dynamo", "backend", "generate", ModelType::Chat),
create_test_model_entry("model-4", "", "backend", "generate", ModelType::Chat),
]; ];
// Test filtering for specific namespace "vllm-agg" // Test filtering for specific namespace "vllm-agg"
...@@ -173,6 +195,7 @@ fn test_model_entry_serialization() { ...@@ -173,6 +195,7 @@ fn test_model_entry_serialization() {
"backend", "backend",
"generate", "generate",
ModelType::Chat, ModelType::Chat,
ModelInput::Tokens,
); );
// Serialize to JSON // Serialize to JSON
...@@ -196,6 +219,7 @@ fn test_model_entry_serialization() { ...@@ -196,6 +219,7 @@ fn test_model_entry_serialization() {
); );
assert_eq!(deserialized.endpoint_id.name, model.endpoint_id.name); assert_eq!(deserialized.endpoint_id.name, model.endpoint_id.name);
assert_eq!(deserialized.model_type, model.model_type); assert_eq!(deserialized.model_type, model.model_type);
assert_eq!(deserialized.model_input, model.model_input);
} }
#[test] #[test]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment