"components/backends/sglang/vscode:/vscode.git/clone" did not exist on "3ea3d59b914ad91a95b128c363116797b5f4669f"
Unverified Commit 27fad26f authored by Olga Andreeva, committed by GitHub
Browse files

refactor: Split ModelType to ModelInput for request and response type;...


refactor: Split ModelType to ModelInput for request and response type; ModelType for the supported workloads (#2714)
Signed-off-by: Guan Luo <gluo@nvidia.com>
Signed-off-by: GuanLuo <41310872+GuanLuo@users.noreply.github.com>
Co-authored-by: Guan Luo <gluo@nvidia.com>
Co-authored-by: GuanLuo <41310872+GuanLuo@users.noreply.github.com>
parent b97db875
......@@ -220,6 +220,36 @@ pub async fn build_routed_pipeline<Req, Resp>(
busy_threshold: Option<f64>,
chooser: Option<Arc<KvRouter>>,
) -> anyhow::Result<ServiceEngine<SingleIn<Req>, ManyOut<Annotated<Resp>>>>
where
Req: Data,
Resp: Data,
OpenAIPreprocessor: Operator<
Context<Req>,
Pin<Box<dyn AsyncEngineStream<Annotated<Resp>>>>,
Context<PreprocessedRequest>,
Pin<Box<dyn AsyncEngineStream<Annotated<BackendOutput>>>>,
>,
{
let preprocessor = OpenAIPreprocessor::new(card.clone()).await?;
build_routed_pipeline_with_preprocessor(
card,
client,
router_mode,
busy_threshold,
chooser,
preprocessor,
)
.await
}
pub async fn build_routed_pipeline_with_preprocessor<Req, Resp>(
card: &ModelDeploymentCard,
client: &Client,
router_mode: RouterMode,
busy_threshold: Option<f64>,
chooser: Option<Arc<KvRouter>>,
preprocessor: Arc<OpenAIPreprocessor>,
) -> anyhow::Result<ServiceEngine<SingleIn<Req>, ManyOut<Annotated<Resp>>>>
where
Req: Data,
Resp: Data,
......@@ -231,7 +261,7 @@ where
>,
{
let frontend = SegmentSource::<SingleIn<Req>, ManyOut<Annotated<Resp>>>::new();
let preprocessor = OpenAIPreprocessor::new(card.clone()).await?.into_operator();
let preprocessor_op = preprocessor.into_operator();
let backend = Backend::from_mdc(card.clone()).await?.into_operator();
let migration = Migration::from_mdc(card.clone()).await?.into_operator();
let router =
......@@ -255,13 +285,13 @@ where
};
let engine = frontend
.link(preprocessor.forward_edge())?
.link(preprocessor_op.forward_edge())?
.link(backend.forward_edge())?
.link(migration.forward_edge())?
.link(service_backend)?
.link(migration.backward_edge())?
.link(backend.backward_edge())?
.link(preprocessor.backward_edge())?
.link(preprocessor_op.backward_edge())?
.link(frontend)?;
Ok(engine)
}
......
......@@ -6,7 +6,7 @@ use std::{future::Future, pin::Pin, sync::Arc};
use crate::{
backend::Backend,
engines::StreamingEngineAdapter,
model_type::ModelType,
model_type::{ModelInput, ModelType},
preprocessor::{BackendOutput, PreprocessedRequest},
types::{
Annotated,
......@@ -55,7 +55,9 @@ pub async fn run(
>::for_engine(engine)?;
if !is_static {
model.attach(&endpoint, ModelType::Chat).await?;
model
.attach(&endpoint, ModelType::Chat, ModelInput::Text)
.await?;
}
let fut_chat = endpoint.endpoint_builder().handler(ingress_chat).start();
......@@ -83,8 +85,13 @@ pub async fn run(
let ingress = Ingress::for_pipeline(pipeline)?;
if !is_static {
model.attach(&endpoint, ModelType::Backend).await?;
// Default to supporting both Chat and Completions endpoints
let model_type = ModelType::Chat | ModelType::Completions;
model
.attach(&endpoint, model_type, ModelInput::Tokens)
.await?;
}
let fut = endpoint.endpoint_builder().handler(ingress).start();
(Box::pin(fut), Some(model.card().clone()))
......
......@@ -10,7 +10,6 @@ use crate::{
entrypoint::{self, EngineConfig, input::common},
http::service::service_v2::{self, HttpService},
kv_router::KvRouterConfig,
model_type::ModelType,
namespace::is_global_namespace,
types::openai::{
chat_completions::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse},
......@@ -247,23 +246,17 @@ fn update_http_endpoints(service: Arc<HttpService>, model_type: ModelUpdate) {
model_type
);
match model_type {
ModelUpdate::Added(model_type) => match model_type {
ModelType::Backend => {
service.enable_model_endpoint(EndpointType::Chat, true);
service.enable_model_endpoint(EndpointType::Completion, true);
ModelUpdate::Added(model_type) => {
// Handle all supported endpoint types, not just the first one
for endpoint_type in model_type.as_endpoint_types() {
service.enable_model_endpoint(endpoint_type, true);
}
_ => {
service.enable_model_endpoint(model_type.as_endpoint_type(), true);
}
},
ModelUpdate::Removed(model_type) => match model_type {
ModelType::Backend => {
service.enable_model_endpoint(EndpointType::Chat, false);
service.enable_model_endpoint(EndpointType::Completion, false);
ModelUpdate::Removed(model_type) => {
// Handle all supported endpoint types, not just the first one
for endpoint_type in model_type.as_endpoint_types() {
service.enable_model_endpoint(endpoint_type, false);
}
_ => {
service.enable_model_endpoint(model_type.as_endpoint_type(), false);
}
},
}
}
......@@ -18,7 +18,7 @@ use crate::discovery::ModelEntry;
use crate::entrypoint::RouterConfig;
use crate::mocker::protocols::MockEngineArgs;
use crate::model_card::{self, ModelDeploymentCard};
use crate::model_type::ModelType;
use crate::model_type::{ModelInput, ModelType};
use crate::request_template::RequestTemplate;
mod network_name;
......@@ -403,6 +403,7 @@ impl LocalModel {
&mut self,
endpoint: &Endpoint,
model_type: ModelType,
model_input: ModelInput,
) -> anyhow::Result<()> {
// A static component doesn't have an etcd_client because it doesn't need to register
let Some(etcd_client) = endpoint.drt().etcd_client() else {
......@@ -431,6 +432,7 @@ impl LocalModel {
endpoint_id: endpoint.id(),
model_type,
runtime_config: Some(self.runtime_config.clone()),
model_input,
};
etcd_client
.kv_create(
......
......@@ -13,41 +13,107 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use bitflags::bitflags;
use serde::{Deserialize, Serialize};
use std::fmt;
use strum::Display;
#[derive(Copy, Debug, Clone, Display, Serialize, Deserialize, Eq, PartialEq)]
pub enum ModelType {
// Chat Completions API
Chat,
/// Older completions API
Completion,
/// Embeddings API
Embedding,
// Pre-processed requests
Backend,
bitflags! {
    /// Set of endpoint capabilities a model can serve.
    ///
    /// `ModelType` is a bit set rather than a plain `enum`, so one value can
    /// describe any combination of the individual capabilities:
    ///
    /// - `ModelType::Chat`
    /// - `ModelType::Completions`
    /// - `ModelType::Embedding`
    ///
    /// A model serving both chat and completions is written as the union of
    /// the two flags:
    ///
    /// ```rust
    /// use dynamo_llm::model_type::ModelType;
    /// let mt = ModelType::Chat | ModelType::Completions;
    /// assert!(mt.supports_chat());
    /// assert!(mt.supports_completions());
    /// ```
    ///
    /// Every capability owns exactly one bit of the underlying `u8`, so
    /// membership checks are simple `contains` calls instead of branching on
    /// a single enum variant.
    #[derive(Copy, Debug, Clone, Serialize, Deserialize, Eq, PartialEq)]
    pub struct ModelType: u8 {
        /// Bit 0: chat capability.
        const Chat = 0b0000_0001;
        /// Bit 1: completions capability.
        const Completions = 0b0000_0010;
        /// Bit 2: embedding capability.
        const Embedding = 0b0000_0100;
    }
}
impl ModelType {
pub fn as_str(&self) -> &str {
match self {
Self::Chat => "chat",
Self::Completion => "completion",
Self::Embedding => "embedding",
Self::Backend => "backend",
pub fn as_str(&self) -> String {
self.as_vec().join(",")
}
pub fn supports_chat(&self) -> bool {
self.contains(ModelType::Chat)
}
pub fn supports_completions(&self) -> bool {
self.contains(ModelType::Completions)
}
pub fn supports_embedding(&self) -> bool {
self.contains(ModelType::Embedding)
}
pub fn all() -> Vec<Self> {
vec![Self::Chat, Self::Completion, Self::Embedding, Self::Backend]
pub fn as_vec(&self) -> Vec<&'static str> {
let mut result = Vec::new();
if self.supports_chat() {
result.push("chat");
}
if self.supports_completions() {
result.push("completions");
}
if self.supports_embedding() {
result.push("embedding");
}
result
}
pub fn as_endpoint_type(&self) -> crate::endpoint_type::EndpointType {
/// Returns all endpoint types supported by this model type.
/// This properly handles combinations like Chat | Completions.
pub fn as_endpoint_types(&self) -> Vec<crate::endpoint_type::EndpointType> {
let mut endpoint_types = Vec::new();
if self.contains(Self::Chat) {
endpoint_types.push(crate::endpoint_type::EndpointType::Chat);
}
if self.contains(Self::Completions) {
endpoint_types.push(crate::endpoint_type::EndpointType::Completion);
}
if self.contains(Self::Embedding) {
endpoint_types.push(crate::endpoint_type::EndpointType::Embedding);
}
endpoint_types
}
}
/// Displays a [`ModelType`] as the string returned by `ModelType::as_str`.
impl fmt::Display for ModelType {
    /// Writes the `as_str` rendering to `f` verbatim; formatter options such
    /// as width or alignment are not applied.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let rendered = self.as_str();
        f.write_str(&rendered)
    }
}
/// Kind of input a model endpoint accepts.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Display, Serialize, Deserialize)]
pub enum ModelInput {
    /// The endpoint takes raw text.
    Text,
    /// The endpoint takes input that has already been pre-processed.
    Tokens,
}
impl ModelInput {
pub fn as_str(&self) -> &str {
match self {
Self::Chat => crate::endpoint_type::EndpointType::Chat,
Self::Completion => crate::endpoint_type::EndpointType::Completion,
Self::Embedding => crate::endpoint_type::EndpointType::Embedding,
Self::Backend => panic!("Backend model type does not map to an endpoint type"),
Self::Text => "text",
Self::Tokens => "tokens",
}
}
}
......@@ -99,9 +99,18 @@ pub struct OpenAIPreprocessor {
impl OpenAIPreprocessor {
pub async fn new(mdc: ModelDeploymentCard) -> Result<Arc<Self>> {
let mdcsum = mdc.mdcsum();
let formatter = PromptFormatter::from_mdc(mdc.clone()).await?;
let PromptFormatter::OAI(formatter) = formatter;
match formatter {
PromptFormatter::OAI(formatter) => Self::new_with_formatter(mdc, formatter).await,
}
}
pub async fn new_with_formatter(
mdc: ModelDeploymentCard,
formatter: Arc<dyn OAIPromptFormatter>,
) -> Result<Arc<Self>> {
let mdcsum = mdc.mdcsum();
let tokenizer = match &mdc.tokenizer {
Some(TokenizerKind::HfTokenizerJson(file)) => HuggingFaceTokenizer::from_file(file)?,
Some(TokenizerKind::GGUF(tokenizer)) => {
......@@ -129,7 +138,6 @@ impl OpenAIPreprocessor {
mdcsum,
}))
}
/// Encode a string to its tokens
pub fn tokenize(&self, s: &str) -> anyhow::Result<Encoding> {
self.tokenizer.encode(s)
......
......@@ -92,3 +92,37 @@ pub trait OAIPromptFormatter: Send + Sync + 'static {
pub enum PromptFormatter {
OAI(Arc<dyn OAIPromptFormatter>),
}
// No-op formatter: used for models without chat_template
#[derive(Debug, Default)]
pub struct NoOpFormatter;
impl OAIPromptFormatter for NoOpFormatter {
fn supports_add_generation_prompt(&self) -> bool {
false
}
fn render(&self, req: &dyn OAIChatLikeRequest) -> Result<String> {
let messages = req.messages();
let first_message = messages
.get_item_by_index(0)
.map_err(|_| anyhow::Error::msg("No message at index 0 or messages array is empty"))?;
let content = first_message
.get_attr("content")
.map_err(|_| anyhow::Error::msg("First message has no 'content' field"))?;
let content_str = content
.as_str()
.ok_or_else(|| anyhow::Error::msg("Message content is not a string"))?
.to_string();
Ok(content_str)
}
}
impl PromptFormatter {
pub fn no_op() -> Self {
Self::OAI(Arc::new(NoOpFormatter))
}
}
......@@ -6,7 +6,7 @@
use dynamo_llm::{
discovery::ModelEntry,
model_type::ModelType,
model_type::{ModelInput, ModelType},
namespace::{GLOBAL_NAMESPACE, is_global_namespace},
};
use dynamo_runtime::protocols::EndpointId;
......@@ -18,6 +18,7 @@ fn create_test_model_entry(
component: &str,
endpoint_name: &str,
model_type: ModelType,
model_input: ModelInput,
) -> ModelEntry {
ModelEntry {
name: name.to_string(),
......@@ -27,6 +28,7 @@ fn create_test_model_entry(
name: endpoint_name.to_string(),
},
model_type,
model_input,
runtime_config: None,
}
}
......@@ -41,6 +43,7 @@ fn test_namespace_filtering_behavior() {
"backend",
"generate",
ModelType::Chat,
ModelInput::Tokens,
),
create_test_model_entry(
"model-2",
......@@ -48,13 +51,15 @@ fn test_namespace_filtering_behavior() {
"backend",
"generate",
ModelType::Chat,
ModelInput::Tokens,
),
create_test_model_entry(
"model-3",
"dynamo",
"backend",
"generate",
ModelType::Completion,
ModelType::Completions,
ModelInput::Tokens,
),
create_test_model_entry(
"model-4",
......@@ -62,6 +67,7 @@ fn test_namespace_filtering_behavior() {
"backend",
"generate",
ModelType::Embedding,
ModelInput::Tokens,
),
];
......@@ -165,6 +171,7 @@ fn test_model_discovery_scoping_scenarios() {
"backend",
"generate",
ModelType::Chat,
ModelInput::Tokens,
),
create_test_model_entry(
"mistral-7b",
......@@ -172,6 +179,7 @@ fn test_model_discovery_scoping_scenarios() {
"backend",
"generate",
ModelType::Chat,
ModelInput::Tokens,
),
create_test_model_entry(
"gpt-3.5",
......@@ -179,8 +187,16 @@ fn test_model_discovery_scoping_scenarios() {
"backend",
"generate",
ModelType::Chat,
ModelInput::Tokens,
),
create_test_model_entry(
"claude-3",
"dynamo",
"backend",
"generate",
ModelType::Chat,
ModelInput::Tokens,
),
create_test_model_entry("claude-3", "dynamo", "backend", "generate", ModelType::Chat),
];
let visible_models: Vec<&ModelEntry> = available_models
......@@ -228,14 +244,29 @@ fn test_namespace_boundary_conditions() {
// Test edge cases and boundary conditions for namespace handling
let test_models = vec![
create_test_model_entry("model-1", "", "backend", "generate", ModelType::Chat), // Empty namespace
create_test_model_entry("model-2", "dynamo", "backend", "generate", ModelType::Chat), // Global namespace
create_test_model_entry(
"model-1",
"",
"backend",
"generate",
ModelType::Chat,
ModelInput::Tokens,
), // Empty namespace
create_test_model_entry(
"model-2",
"dynamo",
"backend",
"generate",
ModelType::Chat,
ModelInput::Tokens,
), // Global namespace
create_test_model_entry(
"model-3",
"ns-with-special-chars_123",
"backend",
"generate",
ModelType::Chat,
ModelInput::Tokens,
),
];
......
......@@ -3,7 +3,7 @@
use dynamo_llm::{
discovery::ModelEntry,
model_type::ModelType,
model_type::{ModelInput, ModelType},
namespace::{GLOBAL_NAMESPACE, is_global_namespace},
};
use dynamo_runtime::protocols::EndpointId;
......@@ -53,6 +53,7 @@ fn create_test_model_entry(
component: &str,
endpoint_name: &str,
model_type: ModelType,
model_input: ModelInput,
) -> ModelEntry {
ModelEntry {
name: name.to_string(),
......@@ -62,6 +63,7 @@ fn create_test_model_entry(
name: endpoint_name.to_string(),
},
model_type,
model_input,
runtime_config: None,
}
}
......@@ -75,6 +77,7 @@ fn test_model_entry_creation_with_different_namespaces() {
"backend",
"generate",
ModelType::Chat,
ModelInput::Tokens,
);
assert_eq!(model_vllm.name, "test-model-1");
......@@ -82,6 +85,7 @@ fn test_model_entry_creation_with_different_namespaces() {
assert_eq!(model_vllm.endpoint_id.component, "backend");
assert_eq!(model_vllm.endpoint_id.name, "generate");
assert_eq!(model_vllm.model_type, ModelType::Chat);
assert_eq!(model_vllm.model_input, ModelInput::Tokens);
// Test creating ModelEntry with global namespace
let model_global = create_test_model_entry(
......@@ -89,14 +93,16 @@ fn test_model_entry_creation_with_different_namespaces() {
"dynamo",
"frontend",
"http",
ModelType::Completion,
ModelType::Completions,
ModelInput::Text,
);
assert_eq!(model_global.name, "test-model-2");
assert_eq!(model_global.endpoint_id.namespace, "dynamo");
assert_eq!(model_global.endpoint_id.component, "frontend");
assert_eq!(model_global.endpoint_id.name, "http");
assert_eq!(model_global.model_type, ModelType::Completion);
assert_eq!(model_global.model_type, ModelType::Completions);
assert_eq!(model_global.model_input, ModelInput::Text);
}
#[test]
......@@ -109,6 +115,7 @@ fn test_namespace_filtering_logic() {
"backend",
"generate",
ModelType::Chat,
ModelInput::Tokens,
),
create_test_model_entry(
"model-2",
......@@ -116,9 +123,24 @@ fn test_namespace_filtering_logic() {
"backend",
"generate",
ModelType::Chat,
ModelInput::Tokens,
),
create_test_model_entry(
"model-3",
"dynamo",
"backend",
"generate",
ModelType::Chat,
ModelInput::Tokens,
),
create_test_model_entry(
"model-4",
"",
"backend",
"generate",
ModelType::Chat,
ModelInput::Tokens,
),
create_test_model_entry("model-3", "dynamo", "backend", "generate", ModelType::Chat),
create_test_model_entry("model-4", "", "backend", "generate", ModelType::Chat),
];
// Test filtering for specific namespace "vllm-agg"
......@@ -173,6 +195,7 @@ fn test_model_entry_serialization() {
"backend",
"generate",
ModelType::Chat,
ModelInput::Tokens,
);
// Serialize to JSON
......@@ -196,6 +219,7 @@ fn test_model_entry_serialization() {
);
assert_eq!(deserialized.endpoint_id.name, model.endpoint_id.name);
assert_eq!(deserialized.model_type, model.model_type);
assert_eq!(deserialized.model_input, model.model_input);
}
#[test]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment