# SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import os import time from collections.abc import Iterable from typing import Any import PIL.Image from vllm.logger import init_logger from vllm_omni.diffusion.data import OmniDiffusionConfig from vllm_omni.diffusion.executor.abstract import DiffusionExecutor from vllm_omni.diffusion.registry import ( DiffusionModelRegistry, get_diffusion_post_process_func, get_diffusion_pre_process_func, ) from vllm_omni.diffusion.request import OmniDiffusionRequest from vllm_omni.inputs.data import OmniDiffusionSamplingParams, OmniTextPrompt from vllm_omni.outputs import OmniRequestOutput logger = init_logger(__name__) def supports_image_input(model_class_name: str) -> bool: model_cls = DiffusionModelRegistry._try_load_model_cls(model_class_name) if model_cls is None: return False return bool(getattr(model_cls, "support_image_input", False)) def image_color_format(model_class_name: str) -> str: model_cls = DiffusionModelRegistry._try_load_model_cls(model_class_name) return getattr(model_cls, "color_format", "RGB") def supports_audio_output(model_class_name: str) -> bool: model_cls = DiffusionModelRegistry._try_load_model_cls(model_class_name) if model_cls is None: return False return bool(getattr(model_cls, "support_audio_output", False)) class DiffusionEngine: """The diffusion engine for vLLM-Omni diffusion models.""" def __init__(self, od_config: OmniDiffusionConfig): """Initialize the diffusion engine. Args: config: The configuration for the diffusion engine. """ self.od_config = od_config self.post_process_func = get_diffusion_post_process_func(od_config) self.pre_process_func = get_diffusion_pre_process_func(od_config) executor_class = DiffusionExecutor.get_class(od_config) self.executor = executor_class(od_config) try: self._dummy_run() except Exception as e: logger.error(f"Dummy run failed: {e}") self.close() raise e def step(self, request: OmniDiffusionRequest) -> list[OmniRequestOutput]: # Apply pre-processing if available if self.pre_process_func is not None: preprocess_start_time = time.time() request = self.pre_process_func(request) preprocess_time = time.time() - preprocess_start_time logger.info(f"Pre-processing completed in {preprocess_time:.4f} seconds") output = self.add_req_and_wait_for_response(request) if output.error: raise Exception(f"{output.error}") logger.info("Generation completed successfully.") if output.output is None: logger.warning("Output is None, returning empty OmniRequestOutput") return [ OmniRequestOutput.from_diffusion( request_id=request.request_ids[i] if i < len(request.request_ids) else "", images=[], prompt=prompt, metrics={}, latents=None, ) for i, prompt in enumerate(request.prompts) ] postprocess_start_time = time.time() outputs = self.post_process_func(output.output) if self.post_process_func is not None else output.output postprocess_time = time.time() - postprocess_start_time logger.info(f"Post-processing completed in {postprocess_time:.4f} seconds") # Convert to OmniRequestOutput format # Ensure outputs is a list if not isinstance(outputs, list): outputs = [outputs] if outputs is not None else [] # Handle single request or multiple requests if len(request.prompts) == 1: # Single request: return single OmniRequestOutput prompt = request.prompts[0] request_id = request.request_ids[0] if request.request_ids else "" metrics = {} if output.trajectory_timesteps is not None: metrics["trajectory_timesteps"] = output.trajectory_timesteps if supports_audio_output(self.od_config.model_class_name): audio_payload = outputs[0] if len(outputs) == 1 else outputs return [ OmniRequestOutput.from_diffusion( request_id=request_id, images=[], prompt=prompt, metrics=metrics, latents=output.trajectory_latents, multimodal_output={"audio": audio_payload}, final_output_type="audio", ), ] else: return [ OmniRequestOutput.from_diffusion( request_id=request_id, images=outputs, prompt=prompt, metrics=metrics, latents=output.trajectory_latents, ), ] else: # Multiple requests: return list of OmniRequestOutput # Split images based on num_outputs_per_prompt for each request results = [] output_idx = 0 for i, prompt in enumerate(request.prompts): request_id = request.request_ids[i] if i < len(request.request_ids) else "" # Get images for this request num_outputs = request.sampling_params.num_outputs_per_prompt request_outputs = outputs[output_idx : output_idx + num_outputs] if output_idx < len(outputs) else [] output_idx += num_outputs metrics = {} if output.trajectory_timesteps is not None: metrics["trajectory_timesteps"] = output.trajectory_timesteps if supports_audio_output(self.od_config.model_class_name): audio_payload = request_outputs[0] if len(request_outputs) == 1 else request_outputs results.append( OmniRequestOutput.from_diffusion( request_id=request_id, images=[], prompt=prompt, metrics=metrics, latents=output.trajectory_latents, multimodal_output={"audio": audio_payload}, final_output_type="audio", ) ) else: results.append( OmniRequestOutput.from_diffusion( request_id=request_id, images=request_outputs, prompt=prompt, metrics=metrics, latents=output.trajectory_latents, ) ) return results @staticmethod def make_engine(config: OmniDiffusionConfig) -> "DiffusionEngine": """Factory method to create a DiffusionEngine instance. Args: config: The configuration for the diffusion engine. Returns: An instance of DiffusionEngine. """ return DiffusionEngine(config) def add_req_and_wait_for_response(self, request: OmniDiffusionRequest): return self.executor.add_req(request) def start_profile(self, trace_filename: str | None = None) -> None: """ Start torch profiling on all diffusion workers. Creates a directory (if needed) and sets up a base filename template for per-rank profiler traces (typically saved as