Unverified Commit 76e0e207 authored by Ayush Agarwal's avatar Ayush Agarwal Committed by GitHub
Browse files

feat: basic vllm omni pipeline support (#5608)


Signed-off-by: default avatarayushag <ayushag@nvidia.com>
parent d22ca523
...@@ -80,6 +80,11 @@ class Config: ...@@ -80,6 +80,11 @@ class Config:
ec_extra_config: Optional[str] = None ec_extra_config: Optional[str] = None
ec_consumer_mode: bool = False ec_consumer_mode: bool = False
# vLLM-Omni worker for multi-stage pipelines
omni: bool = False
# Path to vLLM-Omni stage configuration YAML
stage_configs_path: Optional[str] = None
# dump config to file # dump config to file
dump_config_to: Optional[str] = None dump_config_to: Optional[str] = None
...@@ -258,6 +263,17 @@ def parse_args() -> Config: ...@@ -258,6 +263,17 @@ def parse_args() -> Config:
action="store_true", action="store_true",
help="Configure as ECConnector consumer for receiving encoder embeddings (for PD workers)", help="Configure as ECConnector consumer for receiving encoder embeddings (for PD workers)",
) )
parser.add_argument(
"--omni",
action="store_true",
help="Run as vLLM-Omni worker for multi-stage pipelines (supports text-to-text, text-to-image, etc.)",
)
parser.add_argument(
"--stage-configs-path",
type=str,
default=None,
help="Path to vLLM-Omni stage configuration YAML file. Required for --omni.",
)
parser.add_argument( parser.add_argument(
"--store-kv", "--store-kv",
type=str, type=str,
...@@ -379,6 +395,13 @@ def parse_args() -> Config: ...@@ -379,6 +395,13 @@ def parse_args() -> Config:
"Specify a shared storage path for encoder cache." "Specify a shared storage path for encoder cache."
) )
# Validate omni worker requirements
if args.omni and not args.stage_configs_path:
raise ValueError(
"--stage-configs-path is required when using --omni. "
"Specify a YAML file containing stage configurations for the multi-stage pipeline."
)
# Set component and endpoint based on worker type # Set component and endpoint based on worker type
if args.multimodal_processor or args.ec_processor: if args.multimodal_processor or args.ec_processor:
config.component = "processor" config.component = "processor"
...@@ -399,6 +422,10 @@ def parse_args() -> Config: ...@@ -399,6 +422,10 @@ def parse_args() -> Config:
# Multimodal prefill worker stays as "backend" to maintain encoder connection # Multimodal prefill worker stays as "backend" to maintain encoder connection
config.component = "backend" config.component = "backend"
config.endpoint = "generate" config.endpoint = "generate"
elif args.omni:
# Omni worker uses "backend" component for multi-stage pipeline orchestration
config.component = "backend"
config.endpoint = "generate"
elif args.is_prefill_worker: elif args.is_prefill_worker:
config.component = "prefill" config.component = "prefill"
config.endpoint = "generate" config.endpoint = "generate"
...@@ -428,6 +455,8 @@ def parse_args() -> Config: ...@@ -428,6 +455,8 @@ def parse_args() -> Config:
config.ec_storage_path = args.ec_storage_path config.ec_storage_path = args.ec_storage_path
config.ec_extra_config = args.ec_extra_config config.ec_extra_config = args.ec_extra_config
config.ec_consumer_mode = args.ec_consumer_mode config.ec_consumer_mode = args.ec_consumer_mode
config.omni = args.omni
config.stage_configs_path = args.stage_configs_path
config.store_kv = args.store_kv config.store_kv = args.store_kv
config.request_plane = args.request_plane config.request_plane = args.request_plane
config.event_plane = args.event_plane config.event_plane = args.event_plane
......
...@@ -16,6 +16,7 @@ logger = logging.getLogger(__name__) ...@@ -16,6 +16,7 @@ logger = logging.getLogger(__name__)
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.v1.engine.async_llm import AsyncLLM from vllm.v1.engine.async_llm import AsyncLLM
from vllm_omni.entrypoints import AsyncOmni
def _get_bos_token_id_from_engine(engine_client) -> int: def _get_bos_token_id_from_engine(engine_client) -> int:
...@@ -118,3 +119,78 @@ class VllmPrefillHealthCheckPayload(HealthCheckPayload): ...@@ -118,3 +119,78 @@ class VllmPrefillHealthCheckPayload(HealthCheckPayload):
""" """
self.default_payload = _make_default_payload(engine_client, use_text_input) self.default_payload = _make_default_payload(engine_client, use_text_input)
super().__init__() super().__init__()
async def get_bos_token_from_omni(async_omni: "AsyncOmni") -> int:
"""
Extract BOS token ID from AsyncOmni orchestrator's tokenizer.
Args:
async_omni: AsyncOmni orchestrator instance
Returns:
BOS token ID from the model's tokenizer, or 1 as fallback
"""
if async_omni is None:
return 1
try:
tokenizer = await async_omni.get_tokenizer()
if tokenizer and hasattr(tokenizer, "bos_token_id"):
bos_token_id = tokenizer.bos_token_id
if bos_token_id is not None:
logger.info(
f"Using model's BOS token ID for Omni health check: {bos_token_id}"
)
return int(bos_token_id)
except Exception as e:
logger.debug(f"Failed to get BOS token from AsyncOmni: {e}")
logger.debug("Using default BOS token ID (1) for Omni health check")
return 1
class VllmOmniHealthCheckPayload(HealthCheckPayload):
"""
vLLM-Omni-specific health check payload.
Unlike standard vLLM workers, Omni workers use AsyncOmni which requires
async access to the tokenizer. Use the async create() classmethod to
properly initialize with the model's BOS token.
"""
def __init__(self, bos_token_id: int = 1):
"""
Initialize vLLM-Omni health check payload with BOS token.
Args:
bos_token_id: BOS token ID from the model, or default to 1.
"""
self.default_payload = {
"token_ids": [bos_token_id],
"sampling_options": {
"temperature": 0.0,
},
"stop_conditions": {
"max_tokens": 1,
"stop": None,
"stop_token_ids": None,
"include_stop_str_in_output": False,
"ignore_eos": False,
},
}
super().__init__()
@classmethod
async def create(cls, async_omni: "AsyncOmni") -> "VllmOmniHealthCheckPayload":
"""
Create VllmOmniHealthCheckPayload by extracting BOS token from AsyncOmni.
Args:
async_omni: AsyncOmni orchestrator instance
Returns:
VllmOmniHealthCheckPayload instance with proper BOS token
"""
bos_token_id = await get_bos_token_from_omni(async_omni)
return cls(bos_token_id)
...@@ -52,7 +52,11 @@ from dynamo.vllm.multimodal_utils.encode_utils import create_ec_transfer_config ...@@ -52,7 +52,11 @@ from dynamo.vllm.multimodal_utils.encode_utils import create_ec_transfer_config
from .args import Config, overwrite_args, parse_args from .args import Config, overwrite_args, parse_args
from .handlers import DecodeWorkerHandler, PrefillWorkerHandler from .handlers import DecodeWorkerHandler, PrefillWorkerHandler
from .health_check import VllmHealthCheckPayload, VllmPrefillHealthCheckPayload from .health_check import (
VllmHealthCheckPayload,
VllmOmniHealthCheckPayload,
VllmPrefillHealthCheckPayload,
)
from .publisher import StatLoggerFactory from .publisher import StatLoggerFactory
configure_dynamo_logging() configure_dynamo_logging()
...@@ -261,6 +265,9 @@ async def worker(): ...@@ -261,6 +265,9 @@ async def worker():
runtime, config, shutdown_event, pre_created_engine=pre_created_engine runtime, config, shutdown_event, pre_created_engine=pre_created_engine
) )
logger.debug("init_multimodal_worker completed") logger.debug("init_multimodal_worker completed")
elif config.omni:
await init_omni(runtime, config, shutdown_event)
logger.debug("init_omni completed")
elif config.is_prefill_worker: elif config.is_prefill_worker:
await init_prefill( await init_prefill(
runtime, config, shutdown_event, pre_created_engine=pre_created_engine runtime, config, shutdown_event, pre_created_engine=pre_created_engine
...@@ -1224,6 +1231,80 @@ async def init_multimodal_worker( ...@@ -1224,6 +1231,80 @@ async def init_multimodal_worker(
handler.cleanup() handler.cleanup()
async def init_omni(
runtime: DistributedRuntime, config: Config, shutdown_event: asyncio.Event
):
"""
Initialize Omni worker for text-to-text generation using vLLM-Omni orchestrator.
Uses vLLM-Omni's Omni class for single-stage text generation pipeline.
For now, supports text-to-text only (no multimodal).
"""
# Lazy import to avoid loading vllm-omni unless explicitly needed
from dynamo.vllm.omni import OmniHandler
component = runtime.namespace(config.namespace).component(config.component)
generate_endpoint = component.endpoint(config.endpoint)
# Load default sampling params from model config (same as other workers)
default_sampling_params = (
config.engine_args.create_model_config().get_diff_sampling_param()
)
logger.info(f"Loaded default sampling params: {default_sampling_params}")
# Initialize OmniHandler with Omni orchestrator
handler = OmniHandler(
runtime=runtime,
component=component,
config=config,
default_sampling_params=default_sampling_params,
shutdown_event=shutdown_event,
)
logger.info(f"Omni worker initialized for model: {config.model}")
# Set up metrics collection for vLLM and LMCache metrics
setup_metrics_collection(config, generate_endpoint, logger)
# Handle non-leader nodes - don't serve endpoints
if config.engine_args.data_parallel_rank:
await _handle_non_leader_node(config.engine_args.data_parallel_rank)
return
# TODO: extend for multi-stage pipelines
# Register as Chat endpoint for text-to-text generation
# Use Tokens input since we're doing token-based processing
await register_llm(
ModelInput.Tokens,
ModelType.Chat,
generate_endpoint,
config.model,
config.served_model_name,
kv_cache_block_size=config.engine_args.block_size,
)
logger.info("Starting to serve Omni worker endpoint...")
# Create health check payload (extracts BOS token from AsyncOmni)
health_check_payload = (
await VllmOmniHealthCheckPayload.create(handler.engine_client)
).to_dict()
try:
await generate_endpoint.serve_endpoint(
handler.generate,
graceful_shutdown=True,
metrics_labels=[("model", config.served_model_name or config.model)],
health_check_payload=health_check_payload,
)
except Exception as e:
logger.error(f"Failed to serve Omni endpoint: {e}")
raise
finally:
logger.debug("Cleaning up Omni worker")
handler.cleanup()
def main(): def main():
uvloop.run(worker()) uvloop.run(worker())
......
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""vLLM-Omni integration for Dynamo."""
from .omni_handler import OmniHandler
__all__ = ["OmniHandler"]
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Omni handler for text-to-text generation using vLLM-Omni orchestrator."""
import asyncio
import logging
from typing import Any, AsyncGenerator, Dict
from vllm import SamplingParams
from vllm.inputs import TokensPrompt
from vllm_omni.entrypoints import AsyncOmni
from dynamo.vllm.handlers import BaseWorkerHandler, build_sampling_params
logger = logging.getLogger(__name__)
class OmniHandler(BaseWorkerHandler):
"""Handler for multi-stage pipelines using vLLM-Omni's AsyncOmni orchestrator."""
def __init__(
self,
runtime,
component,
config,
default_sampling_params: Dict[str, Any],
shutdown_event: asyncio.Event | None = None,
):
"""Initialize handler with AsyncOmni orchestrator."""
logger.info(
f"Initializing OmniHandler for multi-stage pipelines with model: {config.model}"
)
# Initialize AsyncOmni with stage configuration
# Note: stage_configs_path is validated as required in args.py
logger.info(f"Using stage config from: {config.stage_configs_path}")
omni_kwargs = {
"model": config.model,
"trust_remote_code": config.engine_args.trust_remote_code,
"stage_configs_path": config.stage_configs_path,
}
self.engine_client = AsyncOmni(**omni_kwargs)
# Initialize attributes needed from BaseWorkerHandler
# We don't call super().__init__() because VllmEngineMonitor expects AsyncLLM,
# but AsyncOmni manages its own engines internally
# TODO: Kv publishers not supported yet
# TODO: Adopt to baseworker initialization pattern
self.default_sampling_params = default_sampling_params
self.config = config
self.model_max_len = config.engine_args.max_model_len
self.shutdown_event = shutdown_event
logger.info("OmniHandler initialized successfully for text-to-text generation")
async def generate(
self, request: Dict[str, Any], context
) -> AsyncGenerator[Dict, None]:
"""Generate text using AsyncOmni orchestrator. Currently supports text-to-text only."""
request_id = context.id()
logger.debug(f"Omni Request ID: {request_id}")
# Extract token_ids from internal protocol format
token_ids = request.get("token_ids")
if not token_ids:
logger.error(f"Request {request_id}: No token_ids found in request")
yield {
"finish_reason": "error: No token_ids in request",
"token_ids": [],
}
return
logger.info(
f"Request {request_id}: Generating text for {len(token_ids)} input tokens"
)
# Build sampling parameters from request
sampling_params = self._build_sampling_params(request)
sampling_params_list = [sampling_params]
tokens_prompt: TokensPrompt = {
"prompt_token_ids": token_ids,
}
async with self._abort_monitor(context, request_id):
try:
num_output_tokens_so_far = 0
async for stage_output in self.engine_client.generate(
prompt=tokens_prompt, # Pass TokensPrompt format
request_id=request_id,
sampling_params_list=sampling_params_list,
):
# stage_output is OmniRequestOutput
# For text generation: stage_output.request_output is a single vLLM RequestOutput
if (
stage_output.final_output_type == "text"
and stage_output.request_output
):
vllm_output = stage_output.request_output
if not vllm_output.outputs:
logger.warning(f"Request {request_id} returned no outputs")
yield {
"finish_reason": "error: No outputs from vLLM engine",
"token_ids": [],
}
break
output = vllm_output.outputs[0]
next_total_toks = len(output.token_ids)
out = {"token_ids": output.token_ids[num_output_tokens_so_far:]}
if output.finish_reason:
out["finish_reason"] = self._normalize_finish_reason(
output.finish_reason
)
out["completion_usage"] = self._build_completion_usage(
vllm_output
)
logger.debug(
f"Completed generation for request {request_id}: "
f"{next_total_toks} output tokens, finish_reason={output.finish_reason}"
)
if output.stop_reason:
out["stop_reason"] = output.stop_reason
yield out
num_output_tokens_so_far = next_total_toks
except GeneratorExit:
# Shutdown was triggered during generation
logger.info(f"Request {request_id} aborted due to shutdown")
raise
except Exception as e:
logger.error(f"Error during generation for request {request_id}: {e}")
yield {
"finish_reason": f"error: {str(e)}",
"token_ids": [],
}
def _build_sampling_params(self, request: Dict[str, Any]) -> SamplingParams:
"""Build sampling params using shared handler utility."""
return build_sampling_params(
request, self.default_sampling_params, self.model_max_len
)
def cleanup(self):
"""Cleanup AsyncOmni orchestrator resources."""
try:
if hasattr(self, "engine_client"):
self.engine_client.close()
logger.info("AsyncOmni orchestrator closed")
except Exception as e:
logger.error(f"Error closing AsyncOmni orchestrator: {e}")
#!/bin/bash
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Launch script for testing vLLM-Omni integration with text-to-text generation
# This script starts an aggregated (frontend + omni worker) deployment for testing
set -e
trap 'echo Cleaning up...; kill 0' EXIT
# Default model - Qwen2.5-Omni-7B for text-to-text
MODEL="${MODEL:-Qwen/Qwen2.5-Omni-7B}"
# Stage config path - use single-stage LLM config for text-to-text
SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
STAGE_CONFIG="${STAGE_CONFIG:-$SCRIPT_DIR/stage_configs/single_stage_llm.yaml}"
# Parse command line arguments
EXTRA_ARGS=()
while [[ $# -gt 0 ]]; do
case $1 in
--model)
MODEL="$2"
shift 2
;;
*)
EXTRA_ARGS+=("$1")
shift
;;
esac
done
echo "=========================================="
echo "Starting vLLM-Omni Worker (Text-to-Text)"
echo "Model: $MODEL"
echo "Stage Config: $STAGE_CONFIG"
echo "=========================================="
# Run ingress (frontend)
# dynamo.frontend accepts either --http-port flag or DYN_HTTP_PORT env var (defaults to 8000)
echo "Starting frontend on port ${DYN_HTTP_PORT:-8000}..."
python -m dynamo.frontend &
FRONTEND_PID=$!
# Wait a bit for frontend to start
sleep 2
echo "Starting Omni worker..."
DYN_SYSTEM_PORT=${DYN_SYSTEM_PORT:-8081} \
python -m dynamo.vllm \
--model "$MODEL" \
--omni \
--stage-configs-path "$STAGE_CONFIG" \
--connector none \
"${EXTRA_ARGS[@]}"
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Single-stage LLM configuration for text-to-text generation
# This is a minimal config for testing vLLM-Omni with standard LLM models
stage_args:
- stage_id: 0
# stage_type defaults to "llm" when not specified
runtime:
devices: "0" # Single GPU
max_batch_size: 32
engine_args:
# Model stage identifier (required by vLLM-Omni)
# For Qwen2.5-Omni: thinker (comprehension), talker (enhanced gen), code2wav (audio)
# For single-stage text generation, use "thinker"
model_stage: thinker
model_arch: Qwen2_5OmniForConditionalGeneration
worker_cls: vllm_omni.worker.gpu_ar_worker.GPUARWorker
scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
# Engine configuration
tensor_parallel_size: 1
pipeline_parallel_size: 1
trust_remote_code: true
enforce_eager: true # Required by vLLM-Omni currently
gpu_memory_utilization: 0.8
enable_prefix_caching: false
max_num_batched_tokens: 32768
engine_output_type: latent
# Enable distributed backend for TP>1
distributed_executor_backend: "mp"
is_comprehension: true # Thinker stage is the comprehension/text generation stage
final_output: true
final_output_type: text
default_sampling_params:
temperature: 0.0
top_p: 1.0
top_k: -1
max_tokens: 2048
repetition_penalty: 1.1
seed: 42
detokenize: false # Token-based processing for Dynamo
...@@ -58,6 +58,7 @@ vllm = [ ...@@ -58,6 +58,7 @@ vllm = [
"uvloop", "uvloop",
"nixl[cu12]<=0.9.0", "nixl[cu12]<=0.9.0",
"vllm[flashinfer,runai]==0.14.1", "vllm[flashinfer,runai]==0.14.1",
"vllm-omni==0.14.0",
] ]
sglang = [ sglang = [
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment