Unverified Commit 2ace5a4a authored by Ayush Agarwal's avatar Ayush Agarwal Committed by GitHub
Browse files

chore: multimodal endpoint registration via cli modality (#6270)


Signed-off-by: ayushag <ayushag@nvidia.com>
parent 3081ca17
......@@ -3,12 +3,13 @@
"""Dynamo runtime configuration ArgGroup."""
from typing import Optional
from typing import List, Optional
from dynamo._core import get_reasoning_parser_names, get_tool_parser_names
from dynamo.common.configuration.arg_group import ArgGroup
from dynamo.common.configuration.config_base import ConfigBase
from dynamo.common.configuration.utils import add_argument, add_negatable_bool_argument
from dynamo.common.utils.output_modalities import OutputModality
class DynamoRuntimeConfig(ConfigBase):
......@@ -29,10 +30,25 @@ class DynamoRuntimeConfig(ConfigBase):
endpoint_types: str
dump_config_to: Optional[str] = None
multimodal_embedding_cache_capacity_gb: float
output_modalities: List[str]
def validate(self) -> None:
    """Derive dependent settings, then check the output-modality selection.

    Sets ``enable_local_indexer`` to the negation of ``durable_kv_events`` and
    afterwards runs ``_validate_output_modalities``; any exception raised by
    that check propagates to the caller.
    """
    # TODO: find a cleaner mechanism for one-off fixups like this one.
    durable_events = self.durable_kv_events
    self.enable_local_indexer = not durable_events
    self._validate_output_modalities()
def _validate_output_modalities(self) -> None:
    """Reject any --output-modalities entry that is not a known modality name.

    Entries are compared case-insensitively (lowercased) against
    ``OutputModality.valid_names()``. An empty or unset list is accepted
    without further checks.

    Raises:
        ValueError: If one or more entries are not valid modality names. The
            message lists the offending entries (lowercased) and the sorted
            set of valid names.
    """
    requested = self.output_modalities
    if not requested:
        return

    known = OutputModality.valid_names()
    rejected = []
    for entry in requested:
        lowered = entry.lower()
        if lowered not in known:
            rejected.append(lowered)

    if not rejected:
        return

    raise ValueError(
        f"Invalid output modality: {', '.join(rejected)}. "
        f"Valid options are: {', '.join(sorted(known))}"
    )
# For simplicity, we do not prepend "dyn-" unless it's absolutely necessary. These are
......@@ -151,3 +167,12 @@ class DynamoRuntimeArgGroup(ArgGroup):
arg_type=float,
help="Capacity of the multimodal embedding cache in GB. 0 = disabled.",
)
add_argument(
g,
flag_name="--output-modalities",
env_var="DYN_OUTPUT_MODALITIES",
default=["text"],
help="Output modalities for omni/diffusion mode (e.g., --output-modalities text image audio video).",
nargs="*",
)
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from enum import Enum
from typing import List, Optional
from dynamo.llm import ModelType
class OutputModality(Enum):
    """Maps CLI modality names to their corresponding ModelType flags.

    Each member's value is a tuple of the ``ModelType`` flags associated with
    that modality. Member names, lowercased, are the names accepted on the CLI.
    """

    TEXT = (ModelType.Chat, ModelType.Completions)
    IMAGE = (ModelType.Images, ModelType.Chat)
    VIDEO = (ModelType.Videos,)
    AUDIO = (ModelType.Audios,)

    @classmethod
    def from_name(cls, name: str) -> "OutputModality":
        """Look up a modality by its CLI name (case-insensitive).

        Args:
            name: Modality name such as ``"text"`` or ``"IMAGE"``.

        Returns:
            The matching enum member.

        Raises:
            ValueError: If ``name`` does not match any member name.
        """
        try:
            return cls[name.upper()]
        except KeyError:
            valid = ", ".join(m.name.lower() for m in cls)
            # The KeyError from the enum lookup is an implementation detail;
            # suppress it so the traceback shows only the ValueError.
            raise ValueError(
                f"Unknown output modality: {name!r}. Valid options: {valid}"
            ) from None

    @classmethod
    def valid_names(cls) -> set:
        """Return the set of valid CLI modality names (lowercase)."""
        return {m.name.lower() for m in cls}
def get_output_modalities(cli_input: List[str], model_repo: str) -> Optional[ModelType]:
    """Combine the ModelType flags implied by a list of CLI modality names.

    Every name is resolved through ``OutputModality.from_name`` and all flags
    in the resulting members' values are OR-ed together.

    Args:
        cli_input: Modality name strings (e.g. ["text", "image"]).
        model_repo: Model repo string. Currently unused; kept in the signature
            for future per-model logic.

    Returns:
        The OR-combination of all resolved ModelType flags, or None when
        ``cli_input`` yields no flags (e.g. it is empty).

    Raises:
        ValueError: Propagated from ``OutputModality.from_name`` for an
            unrecognized name.
    """
    # model_repo is intentionally ignored for now; only the CLI input matters.
    combined = None
    for modality_name in cli_input:
        for model_flag in OutputModality.from_name(modality_name).value:
            if combined is None:
                combined = model_flag
            else:
                combined = combined | model_flag
    return combined
......@@ -630,6 +630,7 @@ async def init_image_diffusion(runtime: DistributedRuntime, config: Config):
generator,
generate_endpoint,
server_args,
output_modalities=dynamo_args.output_modalities,
readiness_gate=ready_event,
),
)
......
......@@ -4,13 +4,14 @@
import asyncio
import logging
import socket
from typing import Any, Optional
from typing import Any, List, Optional
import sglang as sgl
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import get_local_ip_auto
from dynamo._core import Endpoint
from dynamo.common.utils.output_modalities import get_output_modalities
from dynamo.llm import ModelInput, ModelRuntimeConfig, ModelType, register_model
from dynamo.sglang.args import DynamoConfig
......@@ -278,6 +279,7 @@ async def register_image_diffusion_model(
generator: Any, # DiffGenerator
endpoint: Endpoint,
server_args: ServerArgs,
output_modalities: Optional[List[str]] = None,
readiness_gate: Optional[asyncio.Event] = None,
) -> None:
"""Register diffusion model with Dynamo runtime.
......@@ -286,18 +288,37 @@ async def register_image_diffusion_model(
generator: The SGLang DiffGenerator instance.
endpoint: The Dynamo endpoint for generation requests.
server_args: SGLang server configuration.
output_modalities: Optional list of output modality names to override
the default ModelType.Images registration.
readiness_gate: Optional event to signal when registration completes.
Note:
Image diffusion models use ModelInput.Text (text prompts) and ModelType.Images.
Image diffusion models use ModelInput.Text (text prompts) and ModelType.Images
by default. When output_modalities is provided, the ModelType is derived
from the given modality names instead.
"""
# Use model_path as the model name (diffusion workers don't have served_model_name)
model_name = server_args.model_path
model_type = ModelType.Images
if output_modalities:
resolved = get_output_modalities(output_modalities, model_name)
if resolved is not None:
model_type = resolved
logging.info(
"Using output modalities %s for diffusion model registration",
output_modalities,
)
else:
logging.warning(
"No recognized output modalities from %s, defaulting to ModelType.Images",
output_modalities,
)
try:
await register_model(
ModelInput.Text,
ModelType.Images,
model_type,
endpoint,
model_name,
model_name,
......
......@@ -18,6 +18,7 @@ from vllm.v1.metrics.prometheus import setup_multiprocess_prometheus
from dynamo import prometheus_names
from dynamo.common.config_dump import dump_config
from dynamo.common.utils.endpoint_types import parse_endpoint_types
from dynamo.common.utils.output_modalities import get_output_modalities
from dynamo.common.utils.prometheus import (
LLMBackendMetrics,
register_engine_metrics_callback,
......@@ -1117,9 +1118,13 @@ async def init_omni(
return
# TODO: extend for multi-stage pipelines
model_type = get_output_modalities(config.output_modalities, config.model)
if model_type is None:
# Default to Images
model_type = ModelType.Images
await register_model(
ModelInput.Text,
ModelType.Images,
model_type,
generate_endpoint,
config.model,
config.served_model_name,
......
......@@ -529,6 +529,10 @@ impl ModelType {
inner: llm_rs::model_type::ModelType::Images,
};
#[classattr]
const Audios: Self = ModelType {
inner: llm_rs::model_type::ModelType::Audios,
};
#[classattr]
const Videos: Self = ModelType {
inner: llm_rs::model_type::ModelType::Videos,
};
......
......@@ -922,6 +922,7 @@ class ModelType:
TensorBased: ModelType
Prefill: ModelType
Images: ModelType
Audios: ModelType
Videos: ModelType
...
......
......@@ -72,6 +72,7 @@ const ALL_MODEL_TYPES: &[ModelType] = &[
ModelType::Completions,
ModelType::Embedding,
ModelType::Images,
ModelType::Audios,
ModelType::Videos,
ModelType::TensorBased,
ModelType::Prefill,
......@@ -659,45 +660,60 @@ impl ModelWatcher {
let engine = Arc::new(push_router);
self.manager
.add_tensor_model(card.name(), checksum, engine)?;
} else if card.model_input == ModelInput::Text && card.model_type.supports_images() {
// Case: Text + Images (e.g. vLLM-Omni, diffusion models)
// Takes text prompts as input, generates images. Images models also support
// chat completions (see model_type.rs as_endpoint_types).
let images_router = PushRouter::<
NvCreateImageRequest,
Annotated<NvImagesResponse>,
>::from_client_with_threshold(
client.clone(), self.router_config.router_mode, None, None
)
.await?;
self.manager
.add_images_model(card.name(), checksum, Arc::new(images_router))?;
}
// Case: Text + (Images, Audio, Videos)
else if card.model_input == ModelInput::Text
&& (card.model_type.supports_images()
|| card.model_type.supports_audios()
|| card.model_type.supports_videos())
{
// Image Models can support chat completions (vllm omni way)
// So register chat_completions model as well
if card.model_type.supports_chat() {
let chat_router = PushRouter::<
NvCreateChatCompletionRequest,
Annotated<NvCreateChatCompletionStreamResponse>,
>::from_client_with_threshold(
client.clone(),
self.router_config.router_mode,
None,
None,
)
.await?;
self.manager.add_chat_completions_model(
card.name(),
checksum,
Arc::new(chat_router),
)?;
}
let chat_router = PushRouter::<
NvCreateChatCompletionRequest,
Annotated<NvCreateChatCompletionStreamResponse>,
>::from_client_with_threshold(
client, self.router_config.router_mode, None, None
)
.await?;
self.manager.add_chat_completions_model(
card.name(),
checksum,
Arc::new(chat_router),
)?;
} else if card.model_input == ModelInput::Text && card.model_type.supports_videos() {
// Case: Text + Videos (video generation models)
// Takes text prompts as input, generates videos
let push_router = PushRouter::<
NvCreateVideoRequest,
Annotated<NvVideosResponse>,
>::from_client_with_threshold(
client, self.router_config.router_mode, None, None
)
.await?;
let engine = Arc::new(push_router);
self.manager
.add_videos_model(card.name(), checksum, engine)?;
// This is ModelType::Images : registers /v1/images/* endpoints
if card.model_type.supports_images() {
let images_router = PushRouter::<
NvCreateImageRequest,
Annotated<NvImagesResponse>,
>::from_client_with_threshold(
client.clone(), self.router_config.router_mode, None, None
)
.await?;
self.manager
.add_images_model(card.name(), checksum, Arc::new(images_router))?;
}
// This is ModelType::Videos : registers /v1/videos/* endpoints
if card.model_type.supports_videos() {
let videos_router = PushRouter::<
NvCreateVideoRequest,
Annotated<NvVideosResponse>,
>::from_client_with_threshold(
client.clone(), self.router_config.router_mode, None, None
)
.await?;
self.manager
.add_videos_model(card.name(), checksum, Arc::new(videos_router))?;
}
// TODO: add audio models support
} else if card.model_type.supports_prefill() {
// Case 6: Prefill
// Guardrail: Verify model_input is Tokens
......
......@@ -14,7 +14,9 @@ pub enum EndpointType {
Embedding,
/// Images API (Diffusion/DALL-E)
Images,
/// Videos API (Video Generation)
/// Audios API (speech/audio generation)
Audios,
/// Videos API (video generation)
Videos,
/// Responses API
Responses,
......@@ -29,6 +31,7 @@ impl EndpointType {
Self::Completion => "completion",
Self::Embedding => "embedding",
Self::Images => "images",
Self::Audios => "audios",
Self::Videos => "videos",
Self::Responses => "responses",
Self::AnthropicMessages => "anthropic_messages",
......@@ -41,6 +44,7 @@ impl EndpointType {
Self::Completion,
Self::Embedding,
Self::Images,
Self::Audios,
Self::Videos,
Self::Responses,
Self::AnthropicMessages,
......
......@@ -60,6 +60,8 @@ impl StateFlags {
EndpointType::Embedding => self.embeddings_endpoints_enabled.load(Ordering::Relaxed),
EndpointType::Images => self.images_endpoints_enabled.load(Ordering::Relaxed),
EndpointType::Videos => self.videos_endpoints_enabled.load(Ordering::Relaxed),
// TODO: add audios_endpoints_enabled flag
EndpointType::Audios => false,
EndpointType::Responses => self.responses_endpoints_enabled.load(Ordering::Relaxed),
EndpointType::AnthropicMessages => {
self.anthropic_endpoints_enabled.load(Ordering::Relaxed)
......@@ -84,6 +86,8 @@ impl StateFlags {
EndpointType::Videos => self
.videos_endpoints_enabled
.store(enabled, Ordering::Relaxed),
// TODO: add audios_endpoints_enabled flag
EndpointType::Audios => {}
EndpointType::Responses => self
.responses_endpoints_enabled
.store(enabled, Ordering::Relaxed),
......
......@@ -38,7 +38,8 @@ bitflags! {
const TensorBased = 1 << 3;
const Prefill = 1 << 4;
const Images = 1 << 5;
const Videos = 1 << 6;
const Audios = 1 << 6;
const Videos = 1 << 7;
}
}
......@@ -65,6 +66,9 @@ impl ModelType {
pub fn supports_images(&self) -> bool {
self.contains(ModelType::Images)
}
pub fn supports_audios(&self) -> bool {
self.contains(ModelType::Audios)
}
pub fn supports_videos(&self) -> bool {
self.contains(ModelType::Videos)
}
......@@ -89,6 +93,9 @@ impl ModelType {
if self.supports_images() {
result.push("images");
}
if self.supports_audios() {
result.push("audios");
}
if self.supports_videos() {
result.push("videos");
}
......@@ -117,6 +124,9 @@ impl ModelType {
if self.supports_images() {
result.push(ModelType::Images);
}
if self.supports_audios() {
result.push(ModelType::Audios);
}
if self.supports_videos() {
result.push(ModelType::Videos);
}
......@@ -147,7 +157,9 @@ impl ModelType {
// Images models support both chat and completions endpoints
if self.contains(Self::Images) {
endpoint_types.push(crate::endpoint_type::EndpointType::Images);
endpoint_types.push(crate::endpoint_type::EndpointType::Chat);
}
if self.contains(Self::Audios) {
endpoint_types.push(crate::endpoint_type::EndpointType::Audios);
}
if self.contains(Self::Videos) {
endpoint_types.push(crate::endpoint_type::EndpointType::Videos);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment