Unverified Commit 27fad26f authored by Olga Andreeva's avatar Olga Andreeva Committed by GitHub
Browse files

refactor: Split ModelType to ModelInput for request and response type;...


refactor: Split ModelType to ModelInput for request and response type; ModelType for the supported workloads (#2714)
Signed-off-by: default avatarGuan Luo <gluo@nvidia.com>
Signed-off-by: default avatarGuanLuo <41310872+GuanLuo@users.noreply.github.com>
Co-authored-by: default avatarGuan Luo <gluo@nvidia.com>
Co-authored-by: default avatarGuanLuo <41310872+GuanLuo@users.noreply.github.com>
parent b97db875
...@@ -726,6 +726,9 @@ name = "bitflags" ...@@ -726,6 +726,9 @@ name = "bitflags"
version = "2.9.3" version = "2.9.3"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "34efbcccd345379ca2868b2b2c9d3782e9cc58ba87bc7d79d5b53d9c9ae6f25d" checksum = "34efbcccd345379ca2868b2b2c9d3782e9cc58ba87bc7d79d5b53d9c9ae6f25d"
dependencies = [
"serde",
]
[[package]] [[package]]
name = "bitstream-io" name = "bitstream-io"
...@@ -1985,6 +1988,7 @@ dependencies = [ ...@@ -1985,6 +1988,7 @@ dependencies = [
"async_zmq", "async_zmq",
"axum 0.8.4", "axum 0.8.4",
"axum-server", "axum-server",
"bitflags 2.9.3",
"blake3", "blake3",
"bs62", "bs62",
"bytemuck", "bytemuck",
......
...@@ -11,7 +11,7 @@ from typing import Optional ...@@ -11,7 +11,7 @@ from typing import Optional
import uvloop import uvloop
from llama_cpp import Llama from llama_cpp import Llama
from dynamo.llm import ModelType, register_llm from dynamo.llm import ModelInput, ModelType, register_llm
from dynamo.runtime import DistributedRuntime, dynamo_worker from dynamo.runtime import DistributedRuntime, dynamo_worker
from dynamo.runtime.logging import configure_dynamo_logging from dynamo.runtime.logging import configure_dynamo_logging
...@@ -41,10 +41,10 @@ async def worker(runtime: DistributedRuntime): ...@@ -41,10 +41,10 @@ async def worker(runtime: DistributedRuntime):
component = runtime.namespace(config.namespace).component(config.component) component = runtime.namespace(config.namespace).component(config.component)
await component.create_service() await component.create_service()
model_type = ModelType.Chat # llama.cpp does the pre-processing
endpoint = component.endpoint(config.endpoint) endpoint = component.endpoint(config.endpoint)
await register_llm( await register_llm(
model_type, ModelInput.Tokens,
ModelType.Chat,
endpoint, endpoint,
config.model_path, config.model_path,
config.model_name, config.model_name,
......
...@@ -8,7 +8,7 @@ import sglang as sgl ...@@ -8,7 +8,7 @@ import sglang as sgl
from sglang.srt.server_args import ServerArgs from sglang.srt.server_args import ServerArgs
from dynamo._core import Endpoint from dynamo._core import Endpoint
from dynamo.llm import ModelRuntimeConfig, ModelType, register_llm from dynamo.llm import ModelInput, ModelRuntimeConfig, ModelType, register_llm
from dynamo.sglang.args import DynamoArgs from dynamo.sglang.args import DynamoArgs
...@@ -26,7 +26,8 @@ async def register_llm_with_runtime_config( ...@@ -26,7 +26,8 @@ async def register_llm_with_runtime_config(
runtime_config = await _get_runtime_config(engine, dynamo_args) runtime_config = await _get_runtime_config(engine, dynamo_args)
try: try:
await register_llm( await register_llm(
ModelType.Backend, ModelInput.Tokens,
ModelType.Chat | ModelType.Completions,
endpoint, endpoint,
server_args.model_path, server_args.model_path,
server_args.served_model_name, server_args.served_model_name,
......
...@@ -22,7 +22,7 @@ from torch.cuda import device_count ...@@ -22,7 +22,7 @@ from torch.cuda import device_count
from transformers import AutoConfig from transformers import AutoConfig
import dynamo.nixl_connect as nixl_connect import dynamo.nixl_connect as nixl_connect
from dynamo.llm import ModelRuntimeConfig, ModelType, register_llm from dynamo.llm import ModelInput, ModelRuntimeConfig, ModelType, register_llm
from dynamo.runtime import DistributedRuntime, dynamo_worker from dynamo.runtime import DistributedRuntime, dynamo_worker
from dynamo.runtime.logging import configure_dynamo_logging from dynamo.runtime.logging import configure_dynamo_logging
from dynamo.trtllm.engine import TensorRTLLMEngine, get_llm_engine from dynamo.trtllm.engine import TensorRTLLMEngine, get_llm_engine
...@@ -223,7 +223,8 @@ async def init(runtime: DistributedRuntime, config: Config): ...@@ -223,7 +223,8 @@ async def init(runtime: DistributedRuntime, config: Config):
default_sampling_params = SamplingParams() default_sampling_params = SamplingParams()
default_sampling_params._setup(tokenizer) default_sampling_params._setup(tokenizer)
default_sampling_params.stop = None default_sampling_params.stop = None
modelType = ModelType.Backend model_input = ModelInput.Tokens
model_type = ModelType.Chat | ModelType.Completions
multimodal_processor = None multimodal_processor = None
if os.getenv("DYNAMO_ENABLE_TEST_LOGITS_PROCESSOR") == "1": if os.getenv("DYNAMO_ENABLE_TEST_LOGITS_PROCESSOR") == "1":
...@@ -234,7 +235,7 @@ async def init(runtime: DistributedRuntime, config: Config): ...@@ -234,7 +235,7 @@ async def init(runtime: DistributedRuntime, config: Config):
if modality == "multimodal": if modality == "multimodal":
engine_args["skip_tokenizer_init"] = False engine_args["skip_tokenizer_init"] = False
modelType = ModelType.Chat model_input = ModelInput.Text
model_config = AutoConfig.from_pretrained( model_config = AutoConfig.from_pretrained(
config.model_path, trust_remote_code=True config.model_path, trust_remote_code=True
) )
...@@ -292,7 +293,8 @@ async def init(runtime: DistributedRuntime, config: Config): ...@@ -292,7 +293,8 @@ async def init(runtime: DistributedRuntime, config: Config):
if is_first_worker(config): if is_first_worker(config):
# Register the model with runtime config # Register the model with runtime config
await register_llm( await register_llm(
modelType, model_input,
model_type,
endpoint, endpoint,
config.model_path, config.model_path,
config.served_model_name, config.served_model_name,
......
...@@ -12,6 +12,7 @@ from vllm.usage.usage_lib import UsageContext ...@@ -12,6 +12,7 @@ from vllm.usage.usage_lib import UsageContext
from vllm.v1.engine.async_llm import AsyncLLM from vllm.v1.engine.async_llm import AsyncLLM
from dynamo.llm import ( from dynamo.llm import (
ModelInput,
ModelRuntimeConfig, ModelRuntimeConfig,
ModelType, ModelType,
ZmqKvEventPublisher, ZmqKvEventPublisher,
...@@ -251,7 +252,8 @@ async def init(runtime: DistributedRuntime, config: Config): ...@@ -251,7 +252,8 @@ async def init(runtime: DistributedRuntime, config: Config):
runtime_config.reasoning_parser = config.reasoning_parser runtime_config.reasoning_parser = config.reasoning_parser
await register_llm( await register_llm(
ModelType.Backend, ModelInput.Tokens,
ModelType.Chat | ModelType.Completions,
generate_endpoint, generate_endpoint,
config.model, config.model,
config.served_model_name, config.served_model_name,
......
...@@ -16,7 +16,7 @@ The Python file must do three things: ...@@ -16,7 +16,7 @@ The Python file must do three things:
3. Attach a request handler 3. Attach a request handler
``` ```
from dynamo.llm import ModelType, register_llm from dynamo.llm import ModelInput, ModelType, register_llm
from dynamo.runtime import DistributedRuntime, dynamo_worker from dynamo.runtime import DistributedRuntime, dynamo_worker
# 1. Decorate a function to get the runtime # 1. Decorate a function to get the runtime
...@@ -29,10 +29,11 @@ from dynamo.runtime import DistributedRuntime, dynamo_worker ...@@ -29,10 +29,11 @@ from dynamo.runtime import DistributedRuntime, dynamo_worker
component = runtime.namespace("namespace").component("component") component = runtime.namespace("namespace").component("component")
await component.create_service() await component.create_service()
model_path = "Qwen/Qwen3-0.6B" # or "/data/models/Qwen3-0.6B" model_path = "Qwen/Qwen3-0.6B" # or "/data/models/Qwen3-0.6B"
model_type = ModelType.Backend model_input = ModelInput.Tokens # or ModelInput.Text if engine handles pre-processing
model_type = ModelType.Chat # or ModelType.Chat | ModelType.Completions if model can be deployed on chat and completions endpoints
endpoint = component.endpoint("endpoint") endpoint = component.endpoint("endpoint")
# Optional last param to register_llm is model_name. If not present derives it from model_path # Optional last param to register_llm is model_name. If not present derives it from model_path
await register_llm(model_type, endpoint, model_path) await register_llm(model_input, model_type, endpoint, model_path)
# Initialize your engine here # Initialize your engine here
# engine = ... # engine = ...
...@@ -62,10 +63,13 @@ The `model_path` can be: ...@@ -62,10 +63,13 @@ The `model_path` can be:
- The path to a checkout of a HuggingFace repo - any folder containing safetensor files as well as `config.json`, `tokenizer.json` and `tokenizer_config.json`. - The path to a checkout of a HuggingFace repo - any folder containing safetensor files as well as `config.json`, `tokenizer.json` and `tokenizer_config.json`.
- The path to a GGUF file, if your engine supports that. - The path to a GGUF file, if your engine supports that.
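The `model_input` can be:
- ModelInput.Tokens. Your engine expects pre-processed input (token IDs). Dynamo handles tokenization and pre-processing. Your `generate` method receives a `request` dict containing a `token_ids` array of int. It must return a dict also containing a `token_ids` array and an optional `finish_reason` string.
- ModelInput.Text. Your engine expects raw text input and handles its own tokenization and pre-processing. Your `generate` method receives the request, and must return the response, in the format described for the chosen `model_type` below.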
The `model_type` can be: The `model_type` can be:
- ModelType.Backend. Dynamo handles pre-processing. Your `generate` method receives a `request` dict containing a `token_ids` array of int. It must return a dict also containing a `token_ids` array and an optional `finish_reason` string. - ModelType.Chat. Your `generate` method receives a `request` and must return a response dict of type [OpenAI Chat Completion](https://platform.openai.com/docs/api-reference/chat).
- ModelType.Chat. Your `generate` method receives a `request` and must return a response dict of type [OpenAI Chat Completion](https://platform.openai.com/docs/api-reference/chat). Your engine handles pre-processing. - ModelType.Completions. Your `generate` method receives a `request` and must return a response dict of the older [Completions](https://platform.openai.com/docs/api-reference/completions).
- ModelType.Completion. Your `generate` method receives a `request` and must return a response dict of the older [Completions](https://platform.openai.com/docs/api-reference/completions). Your engine handles pre-processing.
`register_llm` can also take the following kwargs: `register_llm` can also take the following kwargs:
- `model_name`: The name to call the model. Your incoming HTTP requests model name must match this. Defaults to the hugging face repo name, the folder name, or the GGUF file name. - `model_name`: The name to call the model. Your incoming HTTP requests model name must match this. Defaults to the hugging face repo name, the folder name, or the GGUF file name.
......
...@@ -389,7 +389,7 @@ The Python file must do three things: ...@@ -389,7 +389,7 @@ The Python file must do three things:
3. Attach a request handler 3. Attach a request handler
``` ```
from dynamo.llm import ModelType, register_llm from dynamo.llm import ModelInput, ModelType, register_llm
from dynamo.runtime import DistributedRuntime, dynamo_worker from dynamo.runtime import DistributedRuntime, dynamo_worker
# 1. Decorate a function to get the runtime # 1. Decorate a function to get the runtime
...@@ -402,10 +402,11 @@ from dynamo.runtime import DistributedRuntime, dynamo_worker ...@@ -402,10 +402,11 @@ from dynamo.runtime import DistributedRuntime, dynamo_worker
component = runtime.namespace("namespace").component("component") component = runtime.namespace("namespace").component("component")
await component.create_service() await component.create_service()
model_path = "Qwen/Qwen3-0.6B" # or "/data/models/Qwen3-0.6B" model_path = "Qwen/Qwen3-0.6B" # or "/data/models/Qwen3-0.6B"
model_type = ModelType.Backend model_input = ModelInput.Tokens # or ModelInput.Text if engine handles pre-processing
model_type = ModelType.Chat # or ModelType.Chat | ModelType.Completions if model can be deployed on chat and completions endpoints
endpoint = component.endpoint("endpoint") endpoint = component.endpoint("endpoint")
# Optional last param to register_llm is model_name. If not present derives it from model_path # Optional last param to register_llm is model_name. If not present derives it from model_path
await register_llm(model_type, endpoint, model_path) await register_llm(model_input, model_type, endpoint, model_path)
# Initialize your engine here # Initialize your engine here
# engine = ... # engine = ...
...@@ -435,10 +436,13 @@ The `model_path` can be: ...@@ -435,10 +436,13 @@ The `model_path` can be:
- The path to a checkout of a HuggingFace repo - any folder containing safetensor files as well as `config.json`, `tokenizer.json` and `tokenizer_config.json`. - The path to a checkout of a HuggingFace repo - any folder containing safetensor files as well as `config.json`, `tokenizer.json` and `tokenizer_config.json`.
- The path to a GGUF file, if your engine supports that. - The path to a GGUF file, if your engine supports that.
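The `model_input` can be:
- ModelInput.Tokens. Your engine expects pre-processed input (token IDs). Dynamo handles tokenization and pre-processing. Your `generate` method receives a `request` dict containing a `token_ids` array of int. It must return a dict also containing a `token_ids` array and an optional `finish_reason` string.
- ModelInput.Text. Your engine expects raw text input and handles its own tokenization and pre-processing. Your `generate` method receives the request, and must return the response, in the format described for the chosen `model_type` below.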
The `model_type` can be: The `model_type` can be:
- ModelType.Backend. Dynamo handles pre-processing. Your `generate` method receives a `request` dict containing a `token_ids` array of int. It must return a dict also containing a `token_ids` array and an optional `finish_reason` string. - ModelType.Chat. Your `generate` method receives a `request` and must return a response dict of type [OpenAI Chat Completion](https://platform.openai.com/docs/api-reference/chat).
- ModelType.Chat. Your `generate` method receives a `request` and must return a response dict of type [OpenAI Chat Completion](https://platform.openai.com/docs/api-reference/chat). Your engine handles pre-processing. - ModelType.Completions. Your `generate` method receives a `request` and must return a response dict of the older [Completions](https://platform.openai.com/docs/api-reference/completions).
- ModelType.Completion. Your `generate` method receives a `request` and must return a response dict of the older [Completions](https://platform.openai.com/docs/api-reference/completions). Your engine handles pre-processing.
`register_llm` can also take the following kwargs: `register_llm` can also take the following kwargs:
- `model_name`: The name to call the model. Your incoming HTTP requests model name must match this. Defaults to the hugging face repo name, the folder name, or the GGUF file name. - `model_name`: The name to call the model. Your incoming HTTP requests model name must match this. Defaults to the hugging face repo name, the folder name, or the GGUF file name.
......
...@@ -32,7 +32,7 @@ from vllm.outputs import RequestOutput ...@@ -32,7 +32,7 @@ from vllm.outputs import RequestOutput
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import FlexibleArgumentParser from vllm.utils import FlexibleArgumentParser
from dynamo.llm import ModelType, register_llm from dynamo.llm import ModelInput, ModelType, register_llm
from dynamo.runtime import Client, DistributedRuntime, dynamo_worker from dynamo.runtime import Client, DistributedRuntime, dynamo_worker
from dynamo.runtime.logging import configure_dynamo_logging from dynamo.runtime.logging import configure_dynamo_logging
...@@ -321,7 +321,8 @@ async def init(runtime: DistributedRuntime, args: argparse.Namespace, config: Co ...@@ -321,7 +321,8 @@ async def init(runtime: DistributedRuntime, args: argparse.Namespace, config: Co
# Register the endpoint as entrypoint to a model # Register the endpoint as entrypoint to a model
await register_llm( await register_llm(
ModelType.Chat, # Custom processor is used and this type bypasses SDK processor ModelInput.Text, # Custom processor is used and this type bypasses SDK processor
ModelType.Chat,
generate_endpoint, generate_endpoint,
config.model, config.model,
config.served_model_name, config.served_model_name,
......
...@@ -590,6 +590,9 @@ name = "bitflags" ...@@ -590,6 +590,9 @@ name = "bitflags"
version = "2.9.3" version = "2.9.3"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "34efbcccd345379ca2868b2b2c9d3782e9cc58ba87bc7d79d5b53d9c9ae6f25d" checksum = "34efbcccd345379ca2868b2b2c9d3782e9cc58ba87bc7d79d5b53d9c9ae6f25d"
dependencies = [
"serde",
]
[[package]] [[package]]
name = "bitstream-io" name = "bitstream-io"
...@@ -1383,6 +1386,7 @@ dependencies = [ ...@@ -1383,6 +1386,7 @@ dependencies = [
"async_zmq", "async_zmq",
"axum", "axum",
"axum-server", "axum-server",
"bitflags 2.9.3",
"blake3", "blake3",
"bs62", "bs62",
"bytemuck", "bytemuck",
......
...@@ -8,7 +8,7 @@ ...@@ -8,7 +8,7 @@
# request via NATS to this python script, which runs sglang. # request via NATS to this python script, which runs sglang.
# #
# The key differences between this and `server_sglang_tok.py` are: # The key differences between this and `server_sglang_tok.py` are:
# - The `register_llm` function registers us a `Backend` model # - The `register_llm` function registers us a `Chat` and `Completions` model that accepts `Tokens` input
# - The `generate` function receives a pre-tokenized request and must return token_ids in the response. # - The `generate` function receives a pre-tokenized request and must return token_ids in the response.
# #
# Setup a virtualenv with dynamo.llm, dynamo.runtime and sglang[all] installed # Setup a virtualenv with dynamo.llm, dynamo.runtime and sglang[all] installed
...@@ -27,7 +27,7 @@ import sglang ...@@ -27,7 +27,7 @@ import sglang
import uvloop import uvloop
from sglang.srt.server_args import ServerArgs from sglang.srt.server_args import ServerArgs
from dynamo.llm import ModelType, register_llm from dynamo.llm import ModelInput, ModelType, register_llm
from dynamo.runtime import DistributedRuntime, dynamo_worker from dynamo.runtime import DistributedRuntime, dynamo_worker
DEFAULT_ENDPOINT = "dyn://dynamo.backend.generate" DEFAULT_ENDPOINT = "dyn://dynamo.backend.generate"
...@@ -91,7 +91,12 @@ async def init(runtime: DistributedRuntime, config: Config): ...@@ -91,7 +91,12 @@ async def init(runtime: DistributedRuntime, config: Config):
await component.create_service() await component.create_service()
endpoint = component.endpoint(config.endpoint) endpoint = component.endpoint(config.endpoint)
await register_llm(ModelType.Backend, endpoint, config.model) await register_llm(
ModelInput.Tokens,
ModelType.Chat | ModelType.Completions,
endpoint,
config.model,
)
engine_args = ServerArgs( engine_args = ServerArgs(
model_path=config.model, model_path=config.model,
......
...@@ -9,7 +9,7 @@ ...@@ -9,7 +9,7 @@
# do the pre/post-processing. # do the pre/post-processing.
# #
# The key differences between this and `server_sglang.py` are: # The key differences between this and `server_sglang.py` are:
# - The `register_llm` function registers us a `Chat` model # - The `register_llm` function registers us a `Chat` and `Completions` model that accepts `Text` input
# - The `generate` function receives a chat completion request and must return matching response # - The `generate` function receives a chat completion request and must return matching response
# #
# Setup a virtualenv with dynamo.llm, dynamo.runtime and sglang[all] installed # Setup a virtualenv with dynamo.llm, dynamo.runtime and sglang[all] installed
...@@ -31,7 +31,7 @@ from sglang.srt.openai_api.adapter import v1_chat_generate_request ...@@ -31,7 +31,7 @@ from sglang.srt.openai_api.adapter import v1_chat_generate_request
from sglang.srt.openai_api.protocol import ChatCompletionRequest from sglang.srt.openai_api.protocol import ChatCompletionRequest
from sglang.srt.server_args import ServerArgs from sglang.srt.server_args import ServerArgs
from dynamo.llm import ModelType, register_llm from dynamo.llm import ModelInput, ModelType, register_llm
from dynamo.runtime import DistributedRuntime, dynamo_worker from dynamo.runtime import DistributedRuntime, dynamo_worker
DEFAULT_ENDPOINT = "dyn://dynamo.backend.generate" DEFAULT_ENDPOINT = "dyn://dynamo.backend.generate"
...@@ -104,7 +104,9 @@ async def init(runtime: DistributedRuntime, config: Config): ...@@ -104,7 +104,9 @@ async def init(runtime: DistributedRuntime, config: Config):
await component.create_service() await component.create_service()
endpoint = component.endpoint(config.endpoint) endpoint = component.endpoint(config.endpoint)
await register_llm(ModelType.Chat, endpoint, config.model) await register_llm(
ModelInput.Text, ModelType.Chat | ModelType.Completions, endpoint, config.model
)
server_args = ServerArgs(model_path=config.model) server_args = ServerArgs(model_path=config.model)
tokenizer_manager, _scheduler_info = _launch_subprocesses(server_args=server_args) tokenizer_manager, _scheduler_info = _launch_subprocesses(server_args=server_args)
......
...@@ -39,7 +39,7 @@ from vllm.entrypoints.openai.api_server import ( ...@@ -39,7 +39,7 @@ from vllm.entrypoints.openai.api_server import (
) )
from vllm.inputs import TokensPrompt from vllm.inputs import TokensPrompt
from dynamo.llm import ModelType, register_llm from dynamo.llm import ModelInput, ModelType, register_llm
from dynamo.runtime import DistributedRuntime, dynamo_worker from dynamo.runtime import DistributedRuntime, dynamo_worker
DEFAULT_ENDPOINT = "dyn://dynamo.backend.generate" DEFAULT_ENDPOINT = "dyn://dynamo.backend.generate"
...@@ -114,7 +114,12 @@ async def init(runtime: DistributedRuntime, config: Config): ...@@ -114,7 +114,12 @@ async def init(runtime: DistributedRuntime, config: Config):
await component.create_service() await component.create_service()
endpoint = component.endpoint(config.endpoint) endpoint = component.endpoint(config.endpoint)
await register_llm(ModelType.Backend, endpoint, config.model) await register_llm(
ModelInput.Tokens,
ModelType.Chat | ModelType.Completions,
endpoint,
config.model,
)
engine_args = AsyncEngineArgs( engine_args = AsyncEngineArgs(
model=config.model, model=config.model,
......
...@@ -109,6 +109,7 @@ fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> { ...@@ -109,6 +109,7 @@ fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<context::Context>()?; m.add_class::<context::Context>()?;
m.add_class::<EtcdKvCache>()?; m.add_class::<EtcdKvCache>()?;
m.add_class::<ModelType>()?; m.add_class::<ModelType>()?;
m.add_class::<ModelInput>()?;
m.add_class::<llm::kv::ForwardPassMetrics>()?; m.add_class::<llm::kv::ForwardPassMetrics>()?;
m.add_class::<llm::kv::WorkerStats>()?; m.add_class::<llm::kv::WorkerStats>()?;
m.add_class::<llm::kv::KvStats>()?; m.add_class::<llm::kv::KvStats>()?;
...@@ -141,10 +142,11 @@ fn log_message(level: &str, message: &str, module: &str, file: &str, line: u32) ...@@ -141,10 +142,11 @@ fn log_message(level: &str, message: &str, module: &str, file: &str, line: u32)
} }
#[pyfunction] #[pyfunction]
#[pyo3(signature = (model_type, endpoint, model_path, model_name=None, context_length=None, kv_cache_block_size=None, router_mode=None, migration_limit=0, runtime_config=None, user_data=None, custom_template_path=None))] #[pyo3(signature = (model_input, model_type, endpoint, model_path, model_name=None, context_length=None, kv_cache_block_size=None, router_mode=None, migration_limit=0, runtime_config=None, user_data=None, custom_template_path=None))]
#[allow(clippy::too_many_arguments)] #[allow(clippy::too_many_arguments)]
fn register_llm<'p>( fn register_llm<'p>(
py: Python<'p>, py: Python<'p>,
model_input: ModelInput,
model_type: ModelType, model_type: ModelType,
endpoint: Endpoint, endpoint: Endpoint,
model_path: &str, model_path: &str,
...@@ -157,13 +159,13 @@ fn register_llm<'p>( ...@@ -157,13 +159,13 @@ fn register_llm<'p>(
user_data: Option<&Bound<'p, PyDict>>, user_data: Option<&Bound<'p, PyDict>>,
custom_template_path: Option<&str>, custom_template_path: Option<&str>,
) -> PyResult<Bound<'p, PyAny>> { ) -> PyResult<Bound<'p, PyAny>> {
let model_type_obj = match model_type { let model_input = match model_input {
ModelType::Chat => llm_rs::model_type::ModelType::Chat, ModelInput::Text => llm_rs::model_type::ModelInput::Text,
ModelType::Completion => llm_rs::model_type::ModelType::Completion, ModelInput::Tokens => llm_rs::model_type::ModelInput::Tokens,
ModelType::Backend => llm_rs::model_type::ModelType::Backend,
ModelType::Embedding => llm_rs::model_type::ModelType::Embedding,
}; };
let model_type_obj = model_type.inner;
let inner_path = model_path.to_string(); let inner_path = model_path.to_string();
let model_name = model_name.map(|n| n.to_string()); let model_name = model_name.map(|n| n.to_string());
let router_mode = router_mode.unwrap_or(RouterMode::RoundRobin); let router_mode = router_mode.unwrap_or(RouterMode::RoundRobin);
...@@ -205,7 +207,7 @@ fn register_llm<'p>( ...@@ -205,7 +207,7 @@ fn register_llm<'p>(
let mut local_model = builder.build().await.map_err(to_pyerr)?; let mut local_model = builder.build().await.map_err(to_pyerr)?;
// Advertise ourself on etcd so ingress can find us // Advertise ourself on etcd so ingress can find us
local_model local_model
.attach(&endpoint.inner, model_type_obj) .attach(&endpoint.inner, model_type_obj, model_input)
.await .await
.map_err(to_pyerr)?; .map_err(to_pyerr)?;
...@@ -272,14 +274,44 @@ struct Client { ...@@ -272,14 +274,44 @@ struct Client {
router: rs::pipeline::PushRouter<serde_json::Value, RsAnnotated<serde_json::Value>>, router: rs::pipeline::PushRouter<serde_json::Value, RsAnnotated<serde_json::Value>>,
} }
/// Python-facing handle describing which request types (workloads) a model serves.
///
/// This is a thin wrapper around `llm_rs::model_type::ModelType`. The class
/// attributes `Chat`, `Completions` and `Embedding` expose the individual values
/// and `|` combines them, e.g. `ModelType.Chat | ModelType.Completions`.
///
/// `eq` makes Python `==` / `!=` delegate to the derived `PartialEq`. Without it,
/// Python falls back to object identity, so two equal values that were built
/// separately (for example the same `|` expression evaluated twice) would
/// compare as different.
#[pyclass(eq)]
#[derive(Clone, PartialEq)]
struct ModelType {
    inner: llm_rs::model_type::ModelType,
}

#[pymethods]
// The constants below keep the Python-visible spelling (`ModelType.Chat`)
// instead of Rust's SCREAMING_SNAKE_CASE, hence the lint exemption.
#[allow(non_upper_case_globals)]
impl ModelType {
    /// Exposed to Python as `ModelType.Chat`.
    #[classattr]
    const Chat: Self = Self {
        inner: llm_rs::model_type::ModelType::Chat,
    };

    /// Exposed to Python as `ModelType.Completions`.
    #[classattr]
    const Completions: Self = Self {
        inner: llm_rs::model_type::ModelType::Completions,
    };

    /// Exposed to Python as `ModelType.Embedding`.
    #[classattr]
    const Embedding: Self = Self {
        inner: llm_rs::model_type::ModelType::Embedding,
    };

    /// Python `a | b`: returns a new value whose inner type is `a.inner | b.inner`.
    /// Neither operand is modified.
    fn __or__(&self, other: &Self) -> Self {
        Self {
            inner: self.inner | other.inner,
        }
    }

    /// Python `str(x)`: the inner value's `to_string()` rendering.
    fn __str__(&self) -> String {
        self.inner.to_string()
    }
}
#[pyclass(eq, eq_int)] #[pyclass(eq, eq_int)]
#[derive(Clone, PartialEq)] #[derive(Clone, PartialEq)]
#[repr(i32)] enum ModelInput {
enum ModelType { Text = 1,
Chat = 1, Tokens = 2,
Completion = 2,
Backend = 3,
Embedding = 4,
} }
#[pymethods] #[pymethods]
......
...@@ -843,8 +843,12 @@ class HttpAsyncEngine: ...@@ -843,8 +843,12 @@ class HttpAsyncEngine:
... ...
class ModelInput:
"""What type of request this model needs: Text or Tokens"""
...
class ModelType: class ModelType:
"""What type of request this model needs: Chat, Component or Backend (pre-processed)""" """What type of request this model needs: Chat, Completions or Embedding"""
... ...
class RouterMode: class RouterMode:
...@@ -859,7 +863,19 @@ class KvRouterConfig: ...@@ -859,7 +863,19 @@ class KvRouterConfig:
"""Values for KV router""" """Values for KV router"""
... ...
async def register_llm(
    model_input: ModelInput,
    model_type: ModelType,
    endpoint: Endpoint,
    model_path: str,
    model_name: Optional[str] = None,
    context_length: Optional[int] = None,
    kv_cache_block_size: Optional[int] = None,
    router_mode: Optional[RouterMode] = None,
    migration_limit: int = 0,
    runtime_config: Optional[ModelRuntimeConfig] = None,
    user_data: Optional[Dict[str, Any]] = None,
    custom_template_path: Optional[str] = None,
) -> None:
    """
    Attach the model at path to the given endpoint, and advertise it as model_type.

    The parameter order mirrors the positional order declared by the native
    binding, so positional calls type-check the same way they execute.

    Args:
        model_input: `ModelInput.Tokens` if the engine expects pre-processed
            token IDs, `ModelInput.Text` if it does its own pre-processing.
        model_type: The request types the model serves. Values can be combined
            with `|`, e.g. `ModelType.Chat | ModelType.Completions`.
        endpoint: The endpoint the model is attached to.
        model_path: HuggingFace repo name, local checkout folder, or GGUF file.
        model_name: Name the model is served under. Derived from `model_path`
            when omitted.
        router_mode: Round-robin is used when omitted.
        runtime_config: Optional runtime configuration for this model instance.
    """
    ...
......
...@@ -28,6 +28,7 @@ from dynamo._core import KvMetricsAggregator as KvMetricsAggregator ...@@ -28,6 +28,7 @@ from dynamo._core import KvMetricsAggregator as KvMetricsAggregator
from dynamo._core import KvRecorder as KvRecorder from dynamo._core import KvRecorder as KvRecorder
from dynamo._core import KvRouterConfig as KvRouterConfig from dynamo._core import KvRouterConfig as KvRouterConfig
from dynamo._core import KvStats as KvStats from dynamo._core import KvStats as KvStats
from dynamo._core import ModelInput as ModelInput
from dynamo._core import ModelRuntimeConfig as ModelRuntimeConfig from dynamo._core import ModelRuntimeConfig as ModelRuntimeConfig
from dynamo._core import ModelType as ModelType from dynamo._core import ModelType as ModelType
from dynamo._core import OverlapScores as OverlapScores from dynamo._core import OverlapScores as OverlapScores
......
...@@ -89,6 +89,7 @@ xxhash-rust = { workspace = true } ...@@ -89,6 +89,7 @@ xxhash-rust = { workspace = true }
# model_express_common = { version = "0.1.0", optional = true } # model_express_common = { version = "0.1.0", optional = true }
akin = "0.4.0" akin = "0.4.0"
bitflags = { version = "2.4", features = ["serde"] }
blake3 = "1" blake3 = "1"
bytemuck = "1.22" bytemuck = "1.22"
candle-core = { version = "0.9.1" } candle-core = { version = "0.9.1" }
......
...@@ -14,7 +14,7 @@ use serde::{Deserialize, Serialize}; ...@@ -14,7 +14,7 @@ use serde::{Deserialize, Serialize};
use crate::{ use crate::{
local_model::runtime_config::ModelRuntimeConfig, local_model::runtime_config::ModelRuntimeConfig,
model_card::{self, ModelDeploymentCard}, model_card::{self, ModelDeploymentCard},
model_type::ModelType, model_type::{ModelInput, ModelType},
}; };
/// [ModelEntry] contains the information to discover models from the etcd cluster. /// [ModelEntry] contains the information to discover models from the etcd cluster.
...@@ -34,6 +34,11 @@ pub struct ModelEntry { ...@@ -34,6 +34,11 @@ pub struct ModelEntry {
/// Runtime configuration specific to this model instance /// Runtime configuration specific to this model instance
#[serde(default, skip_serializing_if = "Option::is_none")] #[serde(default, skip_serializing_if = "Option::is_none")]
pub runtime_config: Option<ModelRuntimeConfig>, pub runtime_config: Option<ModelRuntimeConfig>,
/// Specifies the model input type.
/// `Tokens` for engines that expect pre-processed input.
/// `Text` for engines that take care of pre-processing themselves.
pub model_input: ModelInput,
} }
impl ModelEntry { impl ModelEntry {
...@@ -43,7 +48,7 @@ impl ModelEntry { ...@@ -43,7 +48,7 @@ impl ModelEntry {
} }
pub fn requires_preprocessing(&self) -> bool { pub fn requires_preprocessing(&self) -> bool {
matches!(self.model_type, ModelType::Backend) matches!(self.model_input, ModelInput::Tokens)
} }
/// Fetch the ModelDeploymentCard from etcd. /// Fetch the ModelDeploymentCard from etcd.
......
...@@ -21,8 +21,8 @@ use crate::{ ...@@ -21,8 +21,8 @@ use crate::{
backend::Backend, backend::Backend,
entrypoint, entrypoint,
kv_router::KvRouterConfig, kv_router::KvRouterConfig,
model_type::ModelType, model_type::{ModelInput, ModelType},
preprocessor::{OpenAIPreprocessor, PreprocessedEmbeddingRequest}, preprocessor::{OpenAIPreprocessor, PreprocessedEmbeddingRequest, prompt::PromptFormatter},
protocols::{ protocols::{
common::llm_backend::EmbeddingsEngineOutput, common::llm_backend::EmbeddingsEngineOutput,
openai::{ openai::{
...@@ -54,8 +54,11 @@ pub struct ModelWatcher { ...@@ -54,8 +54,11 @@ pub struct ModelWatcher {
busy_threshold: Option<f64>, busy_threshold: Option<f64>,
} }
// The model types the watcher loops over when handling removed models: each entry is compared
// against the chat / completions / embeddings "removed" flags.
// Diff note: the new column renames `ModelType::Completion` to `ModelType::Completions`.
const ALL_MODEL_TYPES: &[ModelType] = const ALL_MODEL_TYPES: &[ModelType] = &[
&[ModelType::Chat, ModelType::Completion, ModelType::Embedding]; ModelType::Chat,
ModelType::Completions,
ModelType::Embedding,
];
impl ModelWatcher { impl ModelWatcher {
pub fn new( pub fn new(
...@@ -236,7 +239,7 @@ impl ModelWatcher { ...@@ -236,7 +239,7 @@ impl ModelWatcher {
} else { } else {
for model_type in ALL_MODEL_TYPES { for model_type in ALL_MODEL_TYPES {
if ((chat_model_removed && *model_type == ModelType::Chat) if ((chat_model_removed && *model_type == ModelType::Chat)
|| (completions_model_removed && *model_type == ModelType::Completion) || (completions_model_removed && *model_type == ModelType::Completions)
|| (embeddings_model_removed && *model_type == ModelType::Embedding)) || (embeddings_model_removed && *model_type == ModelType::Embedding))
&& let Some(tx) = &self.model_update_tx && let Some(tx) = &self.model_update_tx
{ {
...@@ -273,35 +276,40 @@ impl ModelWatcher { ...@@ -273,35 +276,40 @@ impl ModelWatcher {
} }
}; };
match model_entry.model_type { if model_entry.model_input == ModelInput::Tokens
ModelType::Backend => { && (model_entry.model_type.supports_chat()
// A Backend model expects pre-processed requests meaning it's up to us whether we || model_entry.model_type.supports_completions())
// handle Chat or Completions requests, so handle both. {
// Case 1: Tokens + (Chat OR Completions OR Both)
let Some(mut card) = card else { // A model that expects pre-processed requests meaning it's up to us whether we
anyhow::bail!("Missing model deployment card"); // handle Chat or Completions requests, so handle whatever the model supports.
};
// Download tokenizer.json etc to local disk let Some(mut card) = card else {
// This cache_dir is a tempfile::TempDir will be deleted on drop. I _think_ anyhow::bail!("Missing model deployment card");
// OpenAIPreprocessor::new loads the files, so we can delete them after this };
// function. Needs checking carefully, possibly we need to store it in state. // Download tokenizer.json etc to local disk
let _cache_dir = Some(card.move_from_nats(self.drt.nats_client()).await?); // This cache_dir is a tempfile::TempDir will be deleted on drop. I _think_
// OpenAIPreprocessor::new loads the files, so we can delete them after this
let kv_chooser = if self.router_mode == RouterMode::KV { // function. Needs checking carefully, possibly we need to store it in state.
Some( let _cache_dir = Some(card.move_from_nats(self.drt.nats_client()).await?);
self.manager
.kv_chooser_for( let kv_chooser = if self.router_mode == RouterMode::KV {
&model_entry.name, Some(
&component, self.manager
card.kv_cache_block_size, .kv_chooser_for(
self.kv_router_config, &model_entry.name,
) &component,
.await?, card.kv_cache_block_size,
) self.kv_router_config,
} else { )
None .await?,
}; )
} else {
None
};
// Add chat engine only if the model supports chat
if model_entry.model_type.supports_chat() {
let chat_engine = entrypoint::build_routed_pipeline::< let chat_engine = entrypoint::build_routed_pipeline::<
NvCreateChatCompletionRequest, NvCreateChatCompletionRequest,
NvCreateChatCompletionStreamResponse, NvCreateChatCompletionStreamResponse,
...@@ -315,8 +323,15 @@ impl ModelWatcher { ...@@ -315,8 +323,15 @@ impl ModelWatcher {
.await?; .await?;
self.manager self.manager
.add_chat_completions_model(&model_entry.name, chat_engine)?; .add_chat_completions_model(&model_entry.name, chat_engine)?;
}
let completions_engine = entrypoint::build_routed_pipeline::< // Add completions engine only if the model supports completions
if model_entry.model_type.supports_completions() {
let formatter = PromptFormatter::no_op();
let PromptFormatter::OAI(formatter) = formatter;
let preprocessor =
OpenAIPreprocessor::new_with_formatter(card.clone(), formatter).await?;
let completions_engine = entrypoint::build_routed_pipeline_with_preprocessor::<
NvCreateCompletionRequest, NvCreateCompletionRequest,
NvCreateCompletionResponse, NvCreateCompletionResponse,
>( >(
...@@ -325,75 +340,90 @@ impl ModelWatcher { ...@@ -325,75 +340,90 @@ impl ModelWatcher {
self.router_mode, self.router_mode,
self.busy_threshold, self.busy_threshold,
kv_chooser, kv_chooser,
preprocessor,
) )
.await?; .await?;
self.manager self.manager
.add_completions_model(&model_entry.name, completions_engine)?; .add_completions_model(&model_entry.name, completions_engine)?;
} }
ModelType::Chat => { } else if model_entry.model_input == ModelInput::Text
let push_router = PushRouter::< && model_entry.model_type.supports_chat()
NvCreateChatCompletionRequest, {
Annotated<NvCreateChatCompletionStreamResponse>, // Case 3: Text + Chat
>::from_client_with_threshold( let push_router = PushRouter::<
client, Default::default(), self.busy_threshold NvCreateChatCompletionRequest,
) Annotated<NvCreateChatCompletionStreamResponse>,
.await?; >::from_client_with_threshold(
let engine = Arc::new(push_router); client, self.router_mode, self.busy_threshold
self.manager )
.add_chat_completions_model(&model_entry.name, engine)?; .await?;
} let engine = Arc::new(push_router);
ModelType::Completion => { self.manager
let push_router = PushRouter::< .add_chat_completions_model(&model_entry.name, engine)?;
NvCreateCompletionRequest, } else if model_entry.model_input == ModelInput::Text
Annotated<NvCreateCompletionResponse>, && model_entry.model_type.supports_completions()
>::from_client_with_threshold( {
client, Default::default(), self.busy_threshold // Case 2: Text + Completions
) let push_router = PushRouter::<
.await?; NvCreateCompletionRequest,
let engine = Arc::new(push_router); Annotated<NvCreateCompletionResponse>,
self.manager >::from_client_with_threshold(
.add_completions_model(&model_entry.name, engine)?; client, self.router_mode, self.busy_threshold
} )
ModelType::Embedding => { .await?;
let Some(mut card) = card else { let engine = Arc::new(push_router);
anyhow::bail!("Missing model deployment card for embedding model"); self.manager
}; .add_completions_model(&model_entry.name, engine)?;
} else if model_entry.model_input == ModelInput::Tokens
// Download tokenizer files to local disk && model_entry.model_type.supports_embedding()
let _cache_dir = Some(card.move_from_nats(self.drt.nats_client()).await?); {
// Case 4: Tokens + Embeddings
// Create preprocessing pipeline similar to Backend let Some(mut card) = card else {
let frontend = SegmentSource::< anyhow::bail!("Missing model deployment card for embedding model");
SingleIn<NvCreateEmbeddingRequest>, };
ManyOut<Annotated<NvCreateEmbeddingResponse>>,
>::new();
let preprocessor = OpenAIPreprocessor::new(card.clone()).await?.into_operator();
let backend = Backend::from_mdc(card.clone()).await?.into_operator();
let router = PushRouter::<
PreprocessedEmbeddingRequest,
Annotated<EmbeddingsEngineOutput>,
>::from_client_with_threshold(
client, self.router_mode, self.busy_threshold
)
.await?;
// Note: Embeddings don't need KV routing complexity
let service_backend = ServiceBackend::from_engine(Arc::new(router));
// Link the pipeline: frontend -> preprocessor -> backend -> service_backend -> backend -> preprocessor -> frontend
let embedding_engine = frontend
.link(preprocessor.forward_edge())?
.link(backend.forward_edge())?
.link(service_backend)?
.link(backend.backward_edge())?
.link(preprocessor.backward_edge())?
.link(frontend)?;
self.manager // Download tokenizer files to local disk
.add_embeddings_model(&model_entry.name, embedding_engine)?; let _cache_dir = Some(card.move_from_nats(self.drt.nats_client()).await?);
}
// Create preprocessing pipeline similar to Backend
let frontend = SegmentSource::<
SingleIn<NvCreateEmbeddingRequest>,
ManyOut<Annotated<NvCreateEmbeddingResponse>>,
>::new();
let preprocessor = OpenAIPreprocessor::new(card.clone()).await?.into_operator();
let backend = Backend::from_mdc(card.clone()).await?.into_operator();
let router = PushRouter::<
PreprocessedEmbeddingRequest,
Annotated<EmbeddingsEngineOutput>,
>::from_client_with_threshold(
client, self.router_mode, self.busy_threshold
)
.await?;
// Note: Embeddings don't need KV routing complexity
let service_backend = ServiceBackend::from_engine(Arc::new(router));
// Link the pipeline: frontend -> preprocessor -> backend -> service_backend -> backend -> preprocessor -> frontend
let embedding_engine = frontend
.link(preprocessor.forward_edge())?
.link(backend.forward_edge())?
.link(service_backend)?
.link(backend.backward_edge())?
.link(preprocessor.backward_edge())?
.link(frontend)?;
self.manager
.add_embeddings_model(&model_entry.name, embedding_engine)?;
} else {
// Reject unsupported combinations
anyhow::bail!(
"Unsupported model configuration: {} with {} input. Supported combinations: \
Tokens+(Chat|Completions), Text+Chat, Text+Completions, Tokens+Embeddings",
model_entry.model_type,
model_entry.model_input.as_str()
);
} }
Ok(()) Ok(())
......
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
//! - Connect it to an Input //! - Connect it to an Input
pub mod input; pub mod input;
pub use input::build_routed_pipeline; pub use input::{build_routed_pipeline, build_routed_pipeline_with_preprocessor};
use std::sync::Arc; use std::sync::Arc;
......
...@@ -16,7 +16,7 @@ use std::{ ...@@ -16,7 +16,7 @@ use std::{
pub mod batch; pub mod batch;
mod common; mod common;
pub use common::build_routed_pipeline; pub use common::{build_routed_pipeline, build_routed_pipeline_with_preprocessor};
pub mod endpoint; pub mod endpoint;
pub mod grpc; pub mod grpc;
pub mod http; pub mod http;
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment