Unverified Commit 27fad26f authored by Olga Andreeva's avatar Olga Andreeva Committed by GitHub
Browse files

refactor: Split ModelType to ModelInput for request and response type;...


refactor: Split ModelType to ModelInput for request and response type; ModelType for the supported workloads (#2714)
Signed-off-by: default avatarGuan Luo <gluo@nvidia.com>
Signed-off-by: default avatarGuanLuo <41310872+GuanLuo@users.noreply.github.com>
Co-authored-by: default avatarGuan Luo <gluo@nvidia.com>
Co-authored-by: default avatarGuanLuo <41310872+GuanLuo@users.noreply.github.com>
parent b97db875
......@@ -726,6 +726,9 @@ name = "bitflags"
version = "2.9.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "34efbcccd345379ca2868b2b2c9d3782e9cc58ba87bc7d79d5b53d9c9ae6f25d"
dependencies = [
"serde",
]
[[package]]
name = "bitstream-io"
......@@ -1985,6 +1988,7 @@ dependencies = [
"async_zmq",
"axum 0.8.4",
"axum-server",
"bitflags 2.9.3",
"blake3",
"bs62",
"bytemuck",
......
......@@ -11,7 +11,7 @@ from typing import Optional
import uvloop
from llama_cpp import Llama
from dynamo.llm import ModelType, register_llm
from dynamo.llm import ModelInput, ModelType, register_llm
from dynamo.runtime import DistributedRuntime, dynamo_worker
from dynamo.runtime.logging import configure_dynamo_logging
......@@ -41,10 +41,10 @@ async def worker(runtime: DistributedRuntime):
component = runtime.namespace(config.namespace).component(config.component)
await component.create_service()
model_type = ModelType.Chat # llama.cpp does the pre-processing
endpoint = component.endpoint(config.endpoint)
await register_llm(
model_type,
ModelInput.Tokens,
ModelType.Chat,
endpoint,
config.model_path,
config.model_name,
......
......@@ -8,7 +8,7 @@ import sglang as sgl
from sglang.srt.server_args import ServerArgs
from dynamo._core import Endpoint
from dynamo.llm import ModelRuntimeConfig, ModelType, register_llm
from dynamo.llm import ModelInput, ModelRuntimeConfig, ModelType, register_llm
from dynamo.sglang.args import DynamoArgs
......@@ -26,7 +26,8 @@ async def register_llm_with_runtime_config(
runtime_config = await _get_runtime_config(engine, dynamo_args)
try:
await register_llm(
ModelType.Backend,
ModelInput.Tokens,
ModelType.Chat | ModelType.Completions,
endpoint,
server_args.model_path,
server_args.served_model_name,
......
......@@ -22,7 +22,7 @@ from torch.cuda import device_count
from transformers import AutoConfig
import dynamo.nixl_connect as nixl_connect
from dynamo.llm import ModelRuntimeConfig, ModelType, register_llm
from dynamo.llm import ModelInput, ModelRuntimeConfig, ModelType, register_llm
from dynamo.runtime import DistributedRuntime, dynamo_worker
from dynamo.runtime.logging import configure_dynamo_logging
from dynamo.trtllm.engine import TensorRTLLMEngine, get_llm_engine
......@@ -223,7 +223,8 @@ async def init(runtime: DistributedRuntime, config: Config):
default_sampling_params = SamplingParams()
default_sampling_params._setup(tokenizer)
default_sampling_params.stop = None
modelType = ModelType.Backend
model_input = ModelInput.Tokens
model_type = ModelType.Chat | ModelType.Completions
multimodal_processor = None
if os.getenv("DYNAMO_ENABLE_TEST_LOGITS_PROCESSOR") == "1":
......@@ -234,7 +235,7 @@ async def init(runtime: DistributedRuntime, config: Config):
if modality == "multimodal":
engine_args["skip_tokenizer_init"] = False
modelType = ModelType.Chat
model_input = ModelInput.Text
model_config = AutoConfig.from_pretrained(
config.model_path, trust_remote_code=True
)
......@@ -292,7 +293,8 @@ async def init(runtime: DistributedRuntime, config: Config):
if is_first_worker(config):
# Register the model with runtime config
await register_llm(
modelType,
model_input,
model_type,
endpoint,
config.model_path,
config.served_model_name,
......
......@@ -12,6 +12,7 @@ from vllm.usage.usage_lib import UsageContext
from vllm.v1.engine.async_llm import AsyncLLM
from dynamo.llm import (
ModelInput,
ModelRuntimeConfig,
ModelType,
ZmqKvEventPublisher,
......@@ -251,7 +252,8 @@ async def init(runtime: DistributedRuntime, config: Config):
runtime_config.reasoning_parser = config.reasoning_parser
await register_llm(
ModelType.Backend,
ModelInput.Tokens,
ModelType.Chat | ModelType.Completions,
generate_endpoint,
config.model,
config.served_model_name,
......
......@@ -16,7 +16,7 @@ The Python file must do three things:
3. Attach a request handler
```
from dynamo.llm import ModelType, register_llm
from dynamo.llm import ModelInput, ModelType, register_llm
from dynamo.runtime import DistributedRuntime, dynamo_worker
# 1. Decorate a function to get the runtime
......@@ -29,10 +29,11 @@ from dynamo.runtime import DistributedRuntime, dynamo_worker
component = runtime.namespace("namespace").component("component")
await component.create_service()
model_path = "Qwen/Qwen3-0.6B" # or "/data/models/Qwen3-0.6B"
model_type = ModelType.Backend
model_input = ModelInput.Tokens # or ModelInput.Text if engine handles pre-processing
model_type = ModelType.Chat # or ModelType.Chat | ModelType.Completions if model can be deployed on chat and completions endpoints
endpoint = component.endpoint("endpoint")
# Optional last param to register_llm is model_name. If not present derives it from model_path
await register_llm(model_type, endpoint, model_path)
await register_llm(model_input, model_type, endpoint, model_path)
# Initialize your engine here
# engine = ...
......@@ -62,10 +63,13 @@ The `model_path` can be:
- The path to a checkout of a HuggingFace repo - any folder containing safetensor files as well as `config.json`, `tokenizer.json` and `tokenizer_config.json`.
- The path to a GGUF file, if your engine supports that.
The `model_input` can be:
- ModelInput.Tokens. Your engine expects pre-processed input (token IDs). Dynamo handles tokenization and pre-processing.
- ModelInput.Text. Your engine expects raw text input and handles its own tokenization and pre-processing.
The `model_type` can be:
- ModelType.Backend. Dynamo handles pre-processing. Your `generate` method receives a `request` dict containing a `token_ids` array of int. It must return a dict also containing a `token_ids` array and an optional `finish_reason` string.
- ModelType.Chat. Your `generate` method receives a `request` and must return a response dict of type [OpenAI Chat Completion](https://platform.openai.com/docs/api-reference/chat). Your engine handles pre-processing.
- ModelType.Completion. Your `generate` method receives a `request` and must return a response dict of the older [Completions](https://platform.openai.com/docs/api-reference/completions). Your engine handles pre-processing.
- ModelType.Chat. Your `generate` method receives a `request` and must return a response dict of type [OpenAI Chat Completion](https://platform.openai.com/docs/api-reference/chat).
- ModelType.Completions. Your `generate` method receives a `request` and must return a response dict of the older [Completions](https://platform.openai.com/docs/api-reference/completions).
`register_llm` can also take the following kwargs:
- `model_name`: The name to call the model. Your incoming HTTP requests model name must match this. Defaults to the hugging face repo name, the folder name, or the GGUF file name.
......
......@@ -389,7 +389,7 @@ The Python file must do three things:
3. Attach a request handler
```
from dynamo.llm import ModelType, register_llm
from dynamo.llm import ModelInput, ModelType, register_llm
from dynamo.runtime import DistributedRuntime, dynamo_worker
# 1. Decorate a function to get the runtime
......@@ -402,10 +402,11 @@ from dynamo.runtime import DistributedRuntime, dynamo_worker
component = runtime.namespace("namespace").component("component")
await component.create_service()
model_path = "Qwen/Qwen3-0.6B" # or "/data/models/Qwen3-0.6B"
model_type = ModelType.Backend
model_input = ModelInput.Tokens # or ModelInput.Text if engine handles pre-processing
model_type = ModelType.Chat # or ModelType.Chat | ModelType.Completions if model can be deployed on chat and completions endpoints
endpoint = component.endpoint("endpoint")
# Optional last param to register_llm is model_name. If not present derives it from model_path
await register_llm(model_type, endpoint, model_path)
await register_llm(model_input, model_type, endpoint, model_path)
# Initialize your engine here
# engine = ...
......@@ -435,10 +436,13 @@ The `model_path` can be:
- The path to a checkout of a HuggingFace repo - any folder containing safetensor files as well as `config.json`, `tokenizer.json` and `tokenizer_config.json`.
- The path to a GGUF file, if your engine supports that.
The `model_input` can be:
- ModelInput.Tokens. Your engine expects pre-processed input (token IDs). Dynamo handles tokenization and pre-processing.
- ModelInput.Text. Your engine expects raw text input and handles its own tokenization and pre-processing.
The `model_type` can be:
- ModelType.Backend. Dynamo handles pre-processing. Your `generate` method receives a `request` dict containing a `token_ids` array of int. It must return a dict also containing a `token_ids` array and an optional `finish_reason` string.
- ModelType.Chat. Your `generate` method receives a `request` and must return a response dict of type [OpenAI Chat Completion](https://platform.openai.com/docs/api-reference/chat). Your engine handles pre-processing.
- ModelType.Completion. Your `generate` method receives a `request` and must return a response dict of the older [Completions](https://platform.openai.com/docs/api-reference/completions). Your engine handles pre-processing.
- ModelType.Chat. Your `generate` method receives a `request` and must return a response dict of type [OpenAI Chat Completion](https://platform.openai.com/docs/api-reference/chat).
- ModelType.Completions. Your `generate` method receives a `request` and must return a response dict of the older [Completions](https://platform.openai.com/docs/api-reference/completions).
`register_llm` can also take the following kwargs:
- `model_name`: The name to call the model. Your incoming HTTP requests model name must match this. Defaults to the hugging face repo name, the folder name, or the GGUF file name.
......
......@@ -32,7 +32,7 @@ from vllm.outputs import RequestOutput
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import FlexibleArgumentParser
from dynamo.llm import ModelType, register_llm
from dynamo.llm import ModelInput, ModelType, register_llm
from dynamo.runtime import Client, DistributedRuntime, dynamo_worker
from dynamo.runtime.logging import configure_dynamo_logging
......@@ -321,7 +321,8 @@ async def init(runtime: DistributedRuntime, args: argparse.Namespace, config: Co
# Register the endpoint as entrypoint to a model
await register_llm(
ModelType.Chat, # Custom processor is used and this type bypasses SDK processor
ModelInput.Text, # Custom processor is used and this type bypasses SDK processor
ModelType.Chat,
generate_endpoint,
config.model,
config.served_model_name,
......
......@@ -590,6 +590,9 @@ name = "bitflags"
version = "2.9.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "34efbcccd345379ca2868b2b2c9d3782e9cc58ba87bc7d79d5b53d9c9ae6f25d"
dependencies = [
"serde",
]
[[package]]
name = "bitstream-io"
......@@ -1383,6 +1386,7 @@ dependencies = [
"async_zmq",
"axum",
"axum-server",
"bitflags 2.9.3",
"blake3",
"bs62",
"bytemuck",
......
......@@ -8,7 +8,7 @@
# request via NATS to this python script, which runs sglang.
#
# The key differences between this and `server_sglang_tok.py` are:
# - The `register_llm` function registers us a `Backend` model
# - The `register_llm` function registers us a `Chat` and `Completions` model that accepts `Tokens` input
# - The `generate` function receives a pre-tokenized request and must return token_ids in the response.
#
# Setup a virtualenv with dynamo.llm, dynamo.runtime and sglang[all] installed
......@@ -27,7 +27,7 @@ import sglang
import uvloop
from sglang.srt.server_args import ServerArgs
from dynamo.llm import ModelType, register_llm
from dynamo.llm import ModelInput, ModelType, register_llm
from dynamo.runtime import DistributedRuntime, dynamo_worker
DEFAULT_ENDPOINT = "dyn://dynamo.backend.generate"
......@@ -91,7 +91,12 @@ async def init(runtime: DistributedRuntime, config: Config):
await component.create_service()
endpoint = component.endpoint(config.endpoint)
await register_llm(ModelType.Backend, endpoint, config.model)
await register_llm(
ModelInput.Tokens,
ModelType.Chat | ModelType.Completions,
endpoint,
config.model,
)
engine_args = ServerArgs(
model_path=config.model,
......
......@@ -9,7 +9,7 @@
# do the pre/post-processing.
#
# The key differences between this and `server_sglang.py` are:
# - The `register_llm` function registers us a `Chat` model
# - The `register_llm` function registers us a `Chat` and `Completions` model that accepts `Text` input
# - The `generate` function receives a chat completion request and must return matching response
#
# Setup a virtualenv with dynamo.llm, dynamo.runtime and sglang[all] installed
......@@ -31,7 +31,7 @@ from sglang.srt.openai_api.adapter import v1_chat_generate_request
from sglang.srt.openai_api.protocol import ChatCompletionRequest
from sglang.srt.server_args import ServerArgs
from dynamo.llm import ModelType, register_llm
from dynamo.llm import ModelInput, ModelType, register_llm
from dynamo.runtime import DistributedRuntime, dynamo_worker
DEFAULT_ENDPOINT = "dyn://dynamo.backend.generate"
......@@ -104,7 +104,9 @@ async def init(runtime: DistributedRuntime, config: Config):
await component.create_service()
endpoint = component.endpoint(config.endpoint)
await register_llm(ModelType.Chat, endpoint, config.model)
await register_llm(
ModelInput.Text, ModelType.Chat | ModelType.Completions, endpoint, config.model
)
server_args = ServerArgs(model_path=config.model)
tokenizer_manager, _scheduler_info = _launch_subprocesses(server_args=server_args)
......
......@@ -39,7 +39,7 @@ from vllm.entrypoints.openai.api_server import (
)
from vllm.inputs import TokensPrompt
from dynamo.llm import ModelType, register_llm
from dynamo.llm import ModelInput, ModelType, register_llm
from dynamo.runtime import DistributedRuntime, dynamo_worker
DEFAULT_ENDPOINT = "dyn://dynamo.backend.generate"
......@@ -114,7 +114,12 @@ async def init(runtime: DistributedRuntime, config: Config):
await component.create_service()
endpoint = component.endpoint(config.endpoint)
await register_llm(ModelType.Backend, endpoint, config.model)
await register_llm(
ModelInput.Tokens,
ModelType.Chat | ModelType.Completions,
endpoint,
config.model,
)
engine_args = AsyncEngineArgs(
model=config.model,
......
......@@ -109,6 +109,7 @@ fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<context::Context>()?;
m.add_class::<EtcdKvCache>()?;
m.add_class::<ModelType>()?;
m.add_class::<ModelInput>()?;
m.add_class::<llm::kv::ForwardPassMetrics>()?;
m.add_class::<llm::kv::WorkerStats>()?;
m.add_class::<llm::kv::KvStats>()?;
......@@ -141,10 +142,11 @@ fn log_message(level: &str, message: &str, module: &str, file: &str, line: u32)
}
#[pyfunction]
#[pyo3(signature = (model_type, endpoint, model_path, model_name=None, context_length=None, kv_cache_block_size=None, router_mode=None, migration_limit=0, runtime_config=None, user_data=None, custom_template_path=None))]
#[pyo3(signature = (model_input, model_type, endpoint, model_path, model_name=None, context_length=None, kv_cache_block_size=None, router_mode=None, migration_limit=0, runtime_config=None, user_data=None, custom_template_path=None))]
#[allow(clippy::too_many_arguments)]
fn register_llm<'p>(
py: Python<'p>,
model_input: ModelInput,
model_type: ModelType,
endpoint: Endpoint,
model_path: &str,
......@@ -157,13 +159,13 @@ fn register_llm<'p>(
user_data: Option<&Bound<'p, PyDict>>,
custom_template_path: Option<&str>,
) -> PyResult<Bound<'p, PyAny>> {
let model_type_obj = match model_type {
ModelType::Chat => llm_rs::model_type::ModelType::Chat,
ModelType::Completion => llm_rs::model_type::ModelType::Completion,
ModelType::Backend => llm_rs::model_type::ModelType::Backend,
ModelType::Embedding => llm_rs::model_type::ModelType::Embedding,
let model_input = match model_input {
ModelInput::Text => llm_rs::model_type::ModelInput::Text,
ModelInput::Tokens => llm_rs::model_type::ModelInput::Tokens,
};
let model_type_obj = model_type.inner;
let inner_path = model_path.to_string();
let model_name = model_name.map(|n| n.to_string());
let router_mode = router_mode.unwrap_or(RouterMode::RoundRobin);
......@@ -205,7 +207,7 @@ fn register_llm<'p>(
let mut local_model = builder.build().await.map_err(to_pyerr)?;
// Advertise ourself on etcd so ingress can find us
local_model
.attach(&endpoint.inner, model_type_obj)
.attach(&endpoint.inner, model_type_obj, model_input)
.await
.map_err(to_pyerr)?;
......@@ -272,14 +274,44 @@ struct Client {
router: rs::pipeline::PushRouter<serde_json::Value, RsAnnotated<serde_json::Value>>,
}
#[pyclass]
#[derive(Clone, PartialEq)]
struct ModelType {
inner: llm_rs::model_type::ModelType,
}
#[pymethods]
#[allow(non_upper_case_globals)]
impl ModelType {
#[classattr]
const Chat: Self = ModelType {
inner: llm_rs::model_type::ModelType::Chat,
};
#[classattr]
const Completions: Self = ModelType {
inner: llm_rs::model_type::ModelType::Completions,
};
#[classattr]
const Embedding: Self = ModelType {
inner: llm_rs::model_type::ModelType::Embedding,
};
fn __or__(&self, other: &Self) -> Self {
ModelType {
inner: self.inner | other.inner,
}
}
fn __str__(&self) -> String {
self.inner.to_string()
}
}
#[pyclass(eq, eq_int)]
#[derive(Clone, PartialEq)]
#[repr(i32)]
enum ModelType {
Chat = 1,
Completion = 2,
Backend = 3,
Embedding = 4,
enum ModelInput {
Text = 1,
Tokens = 2,
}
#[pymethods]
......
......@@ -843,8 +843,12 @@ class HttpAsyncEngine:
...
class ModelInput:
"""What type of request this model needs: Text or Tokens"""
...
class ModelType:
"""What type of request this model needs: Chat, Component or Backend (pre-processed)"""
"""What type of request this model needs: Chat, Completions or Embedding"""
...
class RouterMode:
......@@ -859,7 +863,19 @@ class KvRouterConfig:
"""Values for KV router"""
...
async def register_llm(model_type: ModelType, endpoint: Endpoint, model_path: str, model_name: Optional[str] = None, context_length: Optional[int] = None, kv_cache_block_size: Optional[int] = None, router_mode: Optional[RouterMode] = None, migration_limit: int = 0, runtime_config: Optional[ModelRuntimeConfig] = None, user_data: Optional[dict] = None, custom_template_path: Optional[str] = None) -> None:
async def register_llm(
model_input: ModelInput,
model_type: ModelType,
endpoint: Endpoint,
model_path: str,
model_name: Optional[str] = None,
context_length: Optional[int] = None,
kv_cache_block_size: Optional[int] = None,
migration_limit: int = 0,
router_mode: Optional[RouterMode] = None,
user_data: Optional[Dict[str, Any]] = None,
custom_template_path: Optional[str] = None,
) -> None:
"""Attach the model at path to the given endpoint, and advertise it as model_type"""
...
......
......@@ -28,6 +28,7 @@ from dynamo._core import KvMetricsAggregator as KvMetricsAggregator
from dynamo._core import KvRecorder as KvRecorder
from dynamo._core import KvRouterConfig as KvRouterConfig
from dynamo._core import KvStats as KvStats
from dynamo._core import ModelInput as ModelInput
from dynamo._core import ModelRuntimeConfig as ModelRuntimeConfig
from dynamo._core import ModelType as ModelType
from dynamo._core import OverlapScores as OverlapScores
......
......@@ -89,6 +89,7 @@ xxhash-rust = { workspace = true }
# model_express_common = { version = "0.1.0", optional = true }
akin = "0.4.0"
bitflags = { version = "2.4", features = ["serde"] }
blake3 = "1"
bytemuck = "1.22"
candle-core = { version = "0.9.1" }
......
......@@ -14,7 +14,7 @@ use serde::{Deserialize, Serialize};
use crate::{
local_model::runtime_config::ModelRuntimeConfig,
model_card::{self, ModelDeploymentCard},
model_type::ModelType,
model_type::{ModelInput, ModelType},
};
/// [ModelEntry] contains the information to discover models from the etcd cluster.
......@@ -34,6 +34,11 @@ pub struct ModelEntry {
/// Runtime configuration specific to this model instance
#[serde(default, skip_serializing_if = "Option::is_none")]
pub runtime_config: Option<ModelRuntimeConfig>,
/// Specifies the model input type.
/// `Tokens` for engines that expect pre-processed input.
/// `Text` for engines that take care of pre-processing themselves.
pub model_input: ModelInput,
}
impl ModelEntry {
......@@ -43,7 +48,7 @@ impl ModelEntry {
}
pub fn requires_preprocessing(&self) -> bool {
matches!(self.model_type, ModelType::Backend)
matches!(self.model_input, ModelInput::Tokens)
}
/// Fetch the ModelDeploymentCard from etcd.
......
......@@ -21,8 +21,8 @@ use crate::{
backend::Backend,
entrypoint,
kv_router::KvRouterConfig,
model_type::ModelType,
preprocessor::{OpenAIPreprocessor, PreprocessedEmbeddingRequest},
model_type::{ModelInput, ModelType},
preprocessor::{OpenAIPreprocessor, PreprocessedEmbeddingRequest, prompt::PromptFormatter},
protocols::{
common::llm_backend::EmbeddingsEngineOutput,
openai::{
......@@ -54,8 +54,11 @@ pub struct ModelWatcher {
busy_threshold: Option<f64>,
}
const ALL_MODEL_TYPES: &[ModelType] =
&[ModelType::Chat, ModelType::Completion, ModelType::Embedding];
const ALL_MODEL_TYPES: &[ModelType] = &[
ModelType::Chat,
ModelType::Completions,
ModelType::Embedding,
];
impl ModelWatcher {
pub fn new(
......@@ -236,7 +239,7 @@ impl ModelWatcher {
} else {
for model_type in ALL_MODEL_TYPES {
if ((chat_model_removed && *model_type == ModelType::Chat)
|| (completions_model_removed && *model_type == ModelType::Completion)
|| (completions_model_removed && *model_type == ModelType::Completions)
|| (embeddings_model_removed && *model_type == ModelType::Embedding))
&& let Some(tx) = &self.model_update_tx
{
......@@ -273,10 +276,13 @@ impl ModelWatcher {
}
};
match model_entry.model_type {
ModelType::Backend => {
// A Backend model expects pre-processed requests meaning it's up to us whether we
// handle Chat or Completions requests, so handle both.
if model_entry.model_input == ModelInput::Tokens
&& (model_entry.model_type.supports_chat()
|| model_entry.model_type.supports_completions())
{
// Case 1: Tokens + (Chat OR Completions OR Both)
// A model that expects pre-processed requests meaning it's up to us whether we
// handle Chat or Completions requests, so handle whatever the model supports.
let Some(mut card) = card else {
anyhow::bail!("Missing model deployment card");
......@@ -302,6 +308,8 @@ impl ModelWatcher {
None
};
// Add chat engine only if the model supports chat
if model_entry.model_type.supports_chat() {
let chat_engine = entrypoint::build_routed_pipeline::<
NvCreateChatCompletionRequest,
NvCreateChatCompletionStreamResponse,
......@@ -315,8 +323,15 @@ impl ModelWatcher {
.await?;
self.manager
.add_chat_completions_model(&model_entry.name, chat_engine)?;
}
let completions_engine = entrypoint::build_routed_pipeline::<
// Add completions engine only if the model supports completions
if model_entry.model_type.supports_completions() {
let formatter = PromptFormatter::no_op();
let PromptFormatter::OAI(formatter) = formatter;
let preprocessor =
OpenAIPreprocessor::new_with_formatter(card.clone(), formatter).await?;
let completions_engine = entrypoint::build_routed_pipeline_with_preprocessor::<
NvCreateCompletionRequest,
NvCreateCompletionResponse,
>(
......@@ -325,36 +340,44 @@ impl ModelWatcher {
self.router_mode,
self.busy_threshold,
kv_chooser,
preprocessor,
)
.await?;
self.manager
.add_completions_model(&model_entry.name, completions_engine)?;
}
ModelType::Chat => {
} else if model_entry.model_input == ModelInput::Text
&& model_entry.model_type.supports_chat()
{
// Case 3: Text + Chat
let push_router = PushRouter::<
NvCreateChatCompletionRequest,
Annotated<NvCreateChatCompletionStreamResponse>,
>::from_client_with_threshold(
client, Default::default(), self.busy_threshold
client, self.router_mode, self.busy_threshold
)
.await?;
let engine = Arc::new(push_router);
self.manager
.add_chat_completions_model(&model_entry.name, engine)?;
}
ModelType::Completion => {
} else if model_entry.model_input == ModelInput::Text
&& model_entry.model_type.supports_completions()
{
// Case 2: Text + Completions
let push_router = PushRouter::<
NvCreateCompletionRequest,
Annotated<NvCreateCompletionResponse>,
>::from_client_with_threshold(
client, Default::default(), self.busy_threshold
client, self.router_mode, self.busy_threshold
)
.await?;
let engine = Arc::new(push_router);
self.manager
.add_completions_model(&model_entry.name, engine)?;
}
ModelType::Embedding => {
} else if model_entry.model_input == ModelInput::Tokens
&& model_entry.model_type.supports_embedding()
{
// Case 4: Tokens + Embeddings
let Some(mut card) = card else {
anyhow::bail!("Missing model deployment card for embedding model");
};
......@@ -393,7 +416,14 @@ impl ModelWatcher {
self.manager
.add_embeddings_model(&model_entry.name, embedding_engine)?;
}
} else {
// Reject unsupported combinations
anyhow::bail!(
"Unsupported model configuration: {} with {} input. Supported combinations: \
Tokens+(Chat|Completions), Text+Chat, Text+Completions, Tokens+Embeddings",
model_entry.model_type,
model_entry.model_input.as_str()
);
}
Ok(())
......
......@@ -6,7 +6,7 @@
//! - Connect it to an Input
pub mod input;
pub use input::build_routed_pipeline;
pub use input::{build_routed_pipeline, build_routed_pipeline_with_preprocessor};
use std::sync::Arc;
......
......@@ -16,7 +16,7 @@ use std::{
pub mod batch;
mod common;
pub use common::build_routed_pipeline;
pub use common::{build_routed_pipeline, build_routed_pipeline_with_preprocessor};
pub mod endpoint;
pub mod grpc;
pub mod http;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment