# adapted from sglang and fastvideo # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from dataclasses import dataclass, field from vllm_omni.inputs.data import OmniDiffusionSamplingParams, OmniPromptType @dataclass class OmniDiffusionRequest: """ Complete state passed through the pipeline execution. This dataclass contains the prompts and sampling parameters for the diffusion pipeline execution. It also contains a request_id for other components to trace this request and its outputs. """ # TODO(will): double check that args are separate from server_args # properly. Also maybe think about providing an abstraction for pipeline # specific arguments. # data_type: DataType prompts: list[OmniPromptType] # Actually supporting str-based prompts sampling_params: OmniDiffusionSamplingParams request_ids: list[str] = field(default_factory=list) def __post_init__(self): """Initialize dependent fields after dataclass initialization.""" # Set do_classifier_free_guidance based on guidance scale and negative prompt if self.sampling_params.guidance_scale > 1.0 and any( (not isinstance(p, str) and p.get("negative_prompt")) for p in self.prompts ): self.sampling_params.do_classifier_free_guidance = True if self.sampling_params.guidance_scale_2 is None: self.sampling_params.guidance_scale_2 = self.sampling_params.guidance_scale # The dataclass default value is 0 (false-like), used to detect whether user explicitly provides this value # After this check is done, reset this value to old default 1 if self.sampling_params.guidance_scale: self.sampling_params.guidance_scale_provided = True else: self.sampling_params.guidance_scale = 1.0