Unverified Commit 2ace5a4a authored by Ayush Agarwal's avatar Ayush Agarwal Committed by GitHub
Browse files

chore: multimodal endpoint registration via cli modality (#6270)


Signed-off-by: ayushag <ayushag@nvidia.com>
parent 3081ca17
...@@ -3,12 +3,13 @@ ...@@ -3,12 +3,13 @@
"""Dynamo runtime configuration ArgGroup.""" """Dynamo runtime configuration ArgGroup."""
from typing import Optional from typing import List, Optional
from dynamo._core import get_reasoning_parser_names, get_tool_parser_names from dynamo._core import get_reasoning_parser_names, get_tool_parser_names
from dynamo.common.configuration.arg_group import ArgGroup from dynamo.common.configuration.arg_group import ArgGroup
from dynamo.common.configuration.config_base import ConfigBase from dynamo.common.configuration.config_base import ConfigBase
from dynamo.common.configuration.utils import add_argument, add_negatable_bool_argument from dynamo.common.configuration.utils import add_argument, add_negatable_bool_argument
from dynamo.common.utils.output_modalities import OutputModality
class DynamoRuntimeConfig(ConfigBase): class DynamoRuntimeConfig(ConfigBase):
...@@ -29,10 +30,25 @@ class DynamoRuntimeConfig(ConfigBase): ...@@ -29,10 +30,25 @@ class DynamoRuntimeConfig(ConfigBase):
endpoint_types: str endpoint_types: str
dump_config_to: Optional[str] = None dump_config_to: Optional[str] = None
multimodal_embedding_cache_capacity_gb: float multimodal_embedding_cache_capacity_gb: float
output_modalities: List[str]
def validate(self) -> None: def validate(self) -> None:
# TODO get a better way for spot fixes like this. # TODO get a better way for spot fixes like this.
self.enable_local_indexer = not self.durable_kv_events self.enable_local_indexer = not self.durable_kv_events
self._validate_output_modalities()
def _validate_output_modalities(self) -> None:
    """Reject any --output-modalities entry that is not a known modality.

    Entries are lowercased before being compared against
    ``OutputModality.valid_names()``. An empty (or otherwise falsy)
    ``output_modalities`` value is accepted without further checks.

    Raises:
        ValueError: If at least one entry is not a valid modality name. The
            message lists the offending entries (lowercased) followed by the
            sorted valid options.
    """
    requested = self.output_modalities
    if not requested:
        return

    allowed = OutputModality.valid_names()

    # Collect every bad entry so the user sees all mistakes in one error.
    unknown = []
    for entry in requested:
        lowered = entry.lower()
        if lowered not in allowed:
            unknown.append(lowered)

    if not unknown:
        return

    raise ValueError(
        f"Invalid output modality: {', '.join(unknown)}. "
        f"Valid options are: {', '.join(sorted(allowed))}"
    )
# For simplicity, we do not prepend "dyn-" unless it's absolutely necessary. These are # For simplicity, we do not prepend "dyn-" unless it's absolutely necessary. These are
...@@ -151,3 +167,12 @@ class DynamoRuntimeArgGroup(ArgGroup): ...@@ -151,3 +167,12 @@ class DynamoRuntimeArgGroup(ArgGroup):
arg_type=float, arg_type=float,
help="Capacity of the multimodal embedding cache in GB. 0 = disabled.", help="Capacity of the multimodal embedding cache in GB. 0 = disabled.",
) )
add_argument(
g,
flag_name="--output-modalities",
env_var="DYN_OUTPUT_MODALITIES",
default=["text"],
help="Output modalities for omni/diffusion mode (e.g., --output-modalities text image audio video).",
nargs="*",
)
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from enum import Enum
from typing import List, Optional
from dynamo.llm import ModelType
class OutputModality(Enum):
    """Maps CLI modality names to their corresponding ModelType flags.

    Each member's value is a tuple of the ``ModelType`` flags associated with
    that modality; member names (lowercased) are the accepted CLI spellings.
    """

    TEXT = (ModelType.Chat, ModelType.Completions)
    # Image output carries the Chat flag in addition to Images.
    IMAGE = (ModelType.Images, ModelType.Chat)
    VIDEO = (ModelType.Videos,)
    AUDIO = (ModelType.Audios,)

    @classmethod
    def from_name(cls, name: str) -> "OutputModality":
        """Look up a modality by its CLI name (case-insensitive).

        Args:
            name: Modality name as given on the command line, e.g. "image".

        Returns:
            The matching ``OutputModality`` member.

        Raises:
            ValueError: If ``name`` does not match any member name.
        """
        try:
            return cls[name.upper()]
        except KeyError:
            valid = ", ".join(m.name.lower() for m in cls)
            # ``from None`` drops the internal KeyError from the traceback; the
            # ValueError already names the bad input and the valid options.
            raise ValueError(
                f"Unknown output modality: {name!r}. Valid options: {valid}"
            ) from None

    @classmethod
    def valid_names(cls) -> set:
        """Return the set of valid CLI modality names (lowercase)."""
        return {m.name.lower() for m in cls}
def get_output_modalities(cli_input: List[str], model_repo: str) -> Optional[ModelType]:
    """Combine the ModelType flags implied by a list of CLI modality names.

    Args:
        cli_input: Modality name strings (e.g. ["text", "image"]).
        model_repo: Model repo string. Currently unused; reserved for future
            per-model logic.

    Returns:
        All flags of the named modalities merged with ``|``, or None when
        ``cli_input`` produces no flags (for example, an empty list).

    Raises:
        ValueError: Propagated from ``OutputModality.from_name`` when a name
            is not a known modality.
    """
    # Only the CLI input decides the result for now; model_repo is ignored.
    flags = [
        flag
        for name in cli_input
        for flag in OutputModality.from_name(name).value
    ]
    if not flags:
        return None

    combined = flags[0]
    for flag in flags[1:]:
        combined = combined | flag
    return combined
...@@ -630,6 +630,7 @@ async def init_image_diffusion(runtime: DistributedRuntime, config: Config): ...@@ -630,6 +630,7 @@ async def init_image_diffusion(runtime: DistributedRuntime, config: Config):
generator, generator,
generate_endpoint, generate_endpoint,
server_args, server_args,
output_modalities=dynamo_args.output_modalities,
readiness_gate=ready_event, readiness_gate=ready_event,
), ),
) )
......
...@@ -4,13 +4,14 @@ ...@@ -4,13 +4,14 @@
import asyncio import asyncio
import logging import logging
import socket import socket
from typing import Any, Optional from typing import Any, List, Optional
import sglang as sgl import sglang as sgl
from sglang.srt.server_args import ServerArgs from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import get_local_ip_auto from sglang.srt.utils import get_local_ip_auto
from dynamo._core import Endpoint from dynamo._core import Endpoint
from dynamo.common.utils.output_modalities import get_output_modalities
from dynamo.llm import ModelInput, ModelRuntimeConfig, ModelType, register_model from dynamo.llm import ModelInput, ModelRuntimeConfig, ModelType, register_model
from dynamo.sglang.args import DynamoConfig from dynamo.sglang.args import DynamoConfig
...@@ -278,6 +279,7 @@ async def register_image_diffusion_model( ...@@ -278,6 +279,7 @@ async def register_image_diffusion_model(
generator: Any, # DiffGenerator generator: Any, # DiffGenerator
endpoint: Endpoint, endpoint: Endpoint,
server_args: ServerArgs, server_args: ServerArgs,
output_modalities: Optional[List[str]] = None,
readiness_gate: Optional[asyncio.Event] = None, readiness_gate: Optional[asyncio.Event] = None,
) -> None: ) -> None:
"""Register diffusion model with Dynamo runtime. """Register diffusion model with Dynamo runtime.
...@@ -286,18 +288,37 @@ async def register_image_diffusion_model( ...@@ -286,18 +288,37 @@ async def register_image_diffusion_model(
generator: The SGLang DiffGenerator instance. generator: The SGLang DiffGenerator instance.
endpoint: The Dynamo endpoint for generation requests. endpoint: The Dynamo endpoint for generation requests.
server_args: SGLang server configuration. server_args: SGLang server configuration.
output_modalities: Optional list of output modality names to override
the default ModelType.Images registration.
readiness_gate: Optional event to signal when registration completes. readiness_gate: Optional event to signal when registration completes.
Note: Note:
Image diffusion models use ModelInput.Text (text prompts) and ModelType.Images. Image diffusion models use ModelInput.Text (text prompts) and ModelType.Images
by default. When output_modalities is provided, the ModelType is derived
from the given modality names instead.
""" """
# Use model_path as the model name (diffusion workers don't have served_model_name) # Use model_path as the model name (diffusion workers don't have served_model_name)
model_name = server_args.model_path model_name = server_args.model_path
model_type = ModelType.Images
if output_modalities:
resolved = get_output_modalities(output_modalities, model_name)
if resolved is not None:
model_type = resolved
logging.info(
"Using output modalities %s for diffusion model registration",
output_modalities,
)
else:
logging.warning(
"No recognized output modalities from %s, defaulting to ModelType.Images",
output_modalities,
)
try: try:
await register_model( await register_model(
ModelInput.Text, ModelInput.Text,
ModelType.Images, model_type,
endpoint, endpoint,
model_name, model_name,
model_name, model_name,
......
...@@ -18,6 +18,7 @@ from vllm.v1.metrics.prometheus import setup_multiprocess_prometheus ...@@ -18,6 +18,7 @@ from vllm.v1.metrics.prometheus import setup_multiprocess_prometheus
from dynamo import prometheus_names from dynamo import prometheus_names
from dynamo.common.config_dump import dump_config from dynamo.common.config_dump import dump_config
from dynamo.common.utils.endpoint_types import parse_endpoint_types from dynamo.common.utils.endpoint_types import parse_endpoint_types
from dynamo.common.utils.output_modalities import get_output_modalities
from dynamo.common.utils.prometheus import ( from dynamo.common.utils.prometheus import (
LLMBackendMetrics, LLMBackendMetrics,
register_engine_metrics_callback, register_engine_metrics_callback,
...@@ -1117,9 +1118,13 @@ async def init_omni( ...@@ -1117,9 +1118,13 @@ async def init_omni(
return return
# TODO: extend for multi-stage pipelines # TODO: extend for multi-stage pipelines
model_type = get_output_modalities(config.output_modalities, config.model)
if model_type is None:
# Default to Images
model_type = ModelType.Images
await register_model( await register_model(
ModelInput.Text, ModelInput.Text,
ModelType.Images, model_type,
generate_endpoint, generate_endpoint,
config.model, config.model,
config.served_model_name, config.served_model_name,
......
...@@ -529,6 +529,10 @@ impl ModelType { ...@@ -529,6 +529,10 @@ impl ModelType {
inner: llm_rs::model_type::ModelType::Images, inner: llm_rs::model_type::ModelType::Images,
}; };
#[classattr] #[classattr]
const Audios: Self = ModelType {
inner: llm_rs::model_type::ModelType::Audios,
};
#[classattr]
const Videos: Self = ModelType { const Videos: Self = ModelType {
inner: llm_rs::model_type::ModelType::Videos, inner: llm_rs::model_type::ModelType::Videos,
}; };
......
...@@ -922,6 +922,7 @@ class ModelType: ...@@ -922,6 +922,7 @@ class ModelType:
TensorBased: ModelType TensorBased: ModelType
Prefill: ModelType Prefill: ModelType
Images: ModelType Images: ModelType
Audios: ModelType
Videos: ModelType Videos: ModelType
... ...
......
...@@ -72,6 +72,7 @@ const ALL_MODEL_TYPES: &[ModelType] = &[ ...@@ -72,6 +72,7 @@ const ALL_MODEL_TYPES: &[ModelType] = &[
ModelType::Completions, ModelType::Completions,
ModelType::Embedding, ModelType::Embedding,
ModelType::Images, ModelType::Images,
ModelType::Audios,
ModelType::Videos, ModelType::Videos,
ModelType::TensorBased, ModelType::TensorBased,
ModelType::Prefill, ModelType::Prefill,
...@@ -659,45 +660,60 @@ impl ModelWatcher { ...@@ -659,45 +660,60 @@ impl ModelWatcher {
let engine = Arc::new(push_router); let engine = Arc::new(push_router);
self.manager self.manager
.add_tensor_model(card.name(), checksum, engine)?; .add_tensor_model(card.name(), checksum, engine)?;
} else if card.model_input == ModelInput::Text && card.model_type.supports_images() { }
// Case: Text + Images (e.g. vLLM-Omni, diffusion models) // Case: Text + (Images, Audio, Videos)
// Takes text prompts as input, generates images. Images models also support else if card.model_input == ModelInput::Text
// chat completions (see model_type.rs as_endpoint_types). && (card.model_type.supports_images()
let images_router = PushRouter::< || card.model_type.supports_audios()
NvCreateImageRequest, || card.model_type.supports_videos())
Annotated<NvImagesResponse>, {
>::from_client_with_threshold( // Image Models can support chat completions (vllm omni way)
client.clone(), self.router_config.router_mode, None, None // So register chat_completions model as well
) if card.model_type.supports_chat() {
.await?; let chat_router = PushRouter::<
self.manager NvCreateChatCompletionRequest,
.add_images_model(card.name(), checksum, Arc::new(images_router))?; Annotated<NvCreateChatCompletionStreamResponse>,
>::from_client_with_threshold(
client.clone(),
self.router_config.router_mode,
None,
None,
)
.await?;
self.manager.add_chat_completions_model(
card.name(),
checksum,
Arc::new(chat_router),
)?;
}
let chat_router = PushRouter::< // This is ModelType::Images : registers /v1/images/* endpoints
NvCreateChatCompletionRequest, if card.model_type.supports_images() {
Annotated<NvCreateChatCompletionStreamResponse>, let images_router = PushRouter::<
>::from_client_with_threshold( NvCreateImageRequest,
client, self.router_config.router_mode, None, None Annotated<NvImagesResponse>,
) >::from_client_with_threshold(
.await?; client.clone(), self.router_config.router_mode, None, None
self.manager.add_chat_completions_model( )
card.name(), .await?;
checksum, self.manager
Arc::new(chat_router), .add_images_model(card.name(), checksum, Arc::new(images_router))?;
)?; }
} else if card.model_input == ModelInput::Text && card.model_type.supports_videos() {
// Case: Text + Videos (video generation models) // This is ModelType::Videos : registers /v1/videos/* endpoints
// Takes text prompts as input, generates videos if card.model_type.supports_videos() {
let push_router = PushRouter::< let videos_router = PushRouter::<
NvCreateVideoRequest, NvCreateVideoRequest,
Annotated<NvVideosResponse>, Annotated<NvVideosResponse>,
>::from_client_with_threshold( >::from_client_with_threshold(
client, self.router_config.router_mode, None, None client.clone(), self.router_config.router_mode, None, None
) )
.await?; .await?;
let engine = Arc::new(push_router); self.manager
self.manager .add_videos_model(card.name(), checksum, Arc::new(videos_router))?;
.add_videos_model(card.name(), checksum, engine)?; }
// TODO: add audio models support
} else if card.model_type.supports_prefill() { } else if card.model_type.supports_prefill() {
// Case 6: Prefill // Case 6: Prefill
// Guardrail: Verify model_input is Tokens // Guardrail: Verify model_input is Tokens
......
...@@ -14,7 +14,9 @@ pub enum EndpointType { ...@@ -14,7 +14,9 @@ pub enum EndpointType {
Embedding, Embedding,
/// Images API (Diffusion/DALL-E) /// Images API (Diffusion/DALL-E)
Images, Images,
/// Videos API (Video Generation) /// Audios API (speech/audio generation)
Audios,
/// Videos API (video generation)
Videos, Videos,
/// Responses API /// Responses API
Responses, Responses,
...@@ -29,6 +31,7 @@ impl EndpointType { ...@@ -29,6 +31,7 @@ impl EndpointType {
Self::Completion => "completion", Self::Completion => "completion",
Self::Embedding => "embedding", Self::Embedding => "embedding",
Self::Images => "images", Self::Images => "images",
Self::Audios => "audios",
Self::Videos => "videos", Self::Videos => "videos",
Self::Responses => "responses", Self::Responses => "responses",
Self::AnthropicMessages => "anthropic_messages", Self::AnthropicMessages => "anthropic_messages",
...@@ -41,6 +44,7 @@ impl EndpointType { ...@@ -41,6 +44,7 @@ impl EndpointType {
Self::Completion, Self::Completion,
Self::Embedding, Self::Embedding,
Self::Images, Self::Images,
Self::Audios,
Self::Videos, Self::Videos,
Self::Responses, Self::Responses,
Self::AnthropicMessages, Self::AnthropicMessages,
......
...@@ -60,6 +60,8 @@ impl StateFlags { ...@@ -60,6 +60,8 @@ impl StateFlags {
EndpointType::Embedding => self.embeddings_endpoints_enabled.load(Ordering::Relaxed), EndpointType::Embedding => self.embeddings_endpoints_enabled.load(Ordering::Relaxed),
EndpointType::Images => self.images_endpoints_enabled.load(Ordering::Relaxed), EndpointType::Images => self.images_endpoints_enabled.load(Ordering::Relaxed),
EndpointType::Videos => self.videos_endpoints_enabled.load(Ordering::Relaxed), EndpointType::Videos => self.videos_endpoints_enabled.load(Ordering::Relaxed),
// TODO: add audios_endpoints_enabled flag
EndpointType::Audios => false,
EndpointType::Responses => self.responses_endpoints_enabled.load(Ordering::Relaxed), EndpointType::Responses => self.responses_endpoints_enabled.load(Ordering::Relaxed),
EndpointType::AnthropicMessages => { EndpointType::AnthropicMessages => {
self.anthropic_endpoints_enabled.load(Ordering::Relaxed) self.anthropic_endpoints_enabled.load(Ordering::Relaxed)
...@@ -84,6 +86,8 @@ impl StateFlags { ...@@ -84,6 +86,8 @@ impl StateFlags {
EndpointType::Videos => self EndpointType::Videos => self
.videos_endpoints_enabled .videos_endpoints_enabled
.store(enabled, Ordering::Relaxed), .store(enabled, Ordering::Relaxed),
// TODO: add audios_endpoints_enabled flag
EndpointType::Audios => {}
EndpointType::Responses => self EndpointType::Responses => self
.responses_endpoints_enabled .responses_endpoints_enabled
.store(enabled, Ordering::Relaxed), .store(enabled, Ordering::Relaxed),
......
...@@ -38,7 +38,8 @@ bitflags! { ...@@ -38,7 +38,8 @@ bitflags! {
const TensorBased = 1 << 3; const TensorBased = 1 << 3;
const Prefill = 1 << 4; const Prefill = 1 << 4;
const Images = 1 << 5; const Images = 1 << 5;
const Videos = 1 << 6; const Audios = 1 << 6;
const Videos = 1 << 7;
} }
} }
...@@ -65,6 +66,9 @@ impl ModelType { ...@@ -65,6 +66,9 @@ impl ModelType {
pub fn supports_images(&self) -> bool { pub fn supports_images(&self) -> bool {
self.contains(ModelType::Images) self.contains(ModelType::Images)
} }
/// Returns `true` when this model type contains the `ModelType::Audios` flag.
pub fn supports_audios(&self) -> bool {
self.contains(ModelType::Audios)
}
pub fn supports_videos(&self) -> bool { pub fn supports_videos(&self) -> bool {
self.contains(ModelType::Videos) self.contains(ModelType::Videos)
} }
...@@ -89,6 +93,9 @@ impl ModelType { ...@@ -89,6 +93,9 @@ impl ModelType {
if self.supports_images() { if self.supports_images() {
result.push("images"); result.push("images");
} }
if self.supports_audios() {
result.push("audios");
}
if self.supports_videos() { if self.supports_videos() {
result.push("videos"); result.push("videos");
} }
...@@ -117,6 +124,9 @@ impl ModelType { ...@@ -117,6 +124,9 @@ impl ModelType {
if self.supports_images() { if self.supports_images() {
result.push(ModelType::Images); result.push(ModelType::Images);
} }
if self.supports_audios() {
result.push(ModelType::Audios);
}
if self.supports_videos() { if self.supports_videos() {
result.push(ModelType::Videos); result.push(ModelType::Videos);
} }
...@@ -147,7 +157,9 @@ impl ModelType { ...@@ -147,7 +157,9 @@ impl ModelType {
// Images models support both chat and completions endpoints // Images models support both chat and completions endpoints
if self.contains(Self::Images) { if self.contains(Self::Images) {
endpoint_types.push(crate::endpoint_type::EndpointType::Images); endpoint_types.push(crate::endpoint_type::EndpointType::Images);
endpoint_types.push(crate::endpoint_type::EndpointType::Chat); }
if self.contains(Self::Audios) {
endpoint_types.push(crate::endpoint_type::EndpointType::Audios);
} }
if self.contains(Self::Videos) { if self.contains(Self::Videos) {
endpoint_types.push(crate::endpoint_type::EndpointType::Videos); endpoint_types.push(crate::endpoint_type::EndpointType::Videos);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment