# SPDX-License-Identifier: Apache-2.0
"""
Wan video diffusion pipeline implementation.

This module contains an implementation of the Wan video diffusion pipeline
using the modular pipeline architecture.
"""

from fastvideo.v1.fastvideo_args import FastVideoArgs
from fastvideo.v1.logger import init_logger
from fastvideo.v1.pipelines.composed_pipeline_base import ComposedPipelineBase
from fastvideo.v1.pipelines.stages import (ConditioningStage, DecodingStage,
                                           DenoisingStage, InputValidationStage,
                                           LatentPreparationStage,
                                           TextEncodingStage,
                                           TimestepPreparationStage)

# TODO(will): move PRECISION_TO_TYPE to better place

logger = init_logger(__name__)


class WanPipeline(ComposedPipelineBase):
    """Wan video diffusion pipeline assembled from individual stages.

    The pipeline registers seven stages, in this order: input validation,
    prompt encoding, conditioning, timestep preparation, latent preparation,
    denoising and decoding.
    """

    # Every name listed here is fetched through ``get_module`` in
    # ``create_pipeline_stages``.
    # NOTE(review): presumably the base class uses this list to check the
    # loaded config -- confirm against ComposedPipelineBase.
    _required_config_modules = [
        "text_encoder", "tokenizer", "vae", "transformer", "scheduler"
    ]

    def create_pipeline_stages(self, fastvideo_args: FastVideoArgs):
        """Build each stage and register it with the pipeline, in run order.

        Each stage is constructed with the modules it needs (looked up via
        ``self.get_module``) and then handed to ``self.add_stage`` under a
        fixed stage name.

        Args:
            fastvideo_args: Part of the method signature; it is not read by
                this implementation.
        """
        # Input validation needs no modules.
        validation = InputValidationStage()
        self.add_stage(stage_name="input_validation_stage", stage=validation)

        # Prompt encoding wraps the single text encoder / tokenizer pair in
        # one-element lists, as the stage takes list arguments.
        prompt_encoding = TextEncodingStage(
            text_encoders=[self.get_module("text_encoder")],
            tokenizers=[self.get_module("tokenizer")],
        )
        self.add_stage(stage_name="prompt_encoding_stage",
                       stage=prompt_encoding)

        conditioning = ConditioningStage()
        self.add_stage(stage_name="conditioning_stage", stage=conditioning)

        timestep_preparation = TimestepPreparationStage(
            scheduler=self.get_module("scheduler"))
        self.add_stage(stage_name="timestep_preparation_stage",
                       stage=timestep_preparation)

        latent_preparation = LatentPreparationStage(
            scheduler=self.get_module("scheduler"),
            transformer=self.get_module("transformer"),
        )
        self.add_stage(stage_name="latent_preparation_stage",
                       stage=latent_preparation)

        denoising = DenoisingStage(
            transformer=self.get_module("transformer"),
            scheduler=self.get_module("scheduler"),
        )
        self.add_stage(stage_name="denoising_stage", stage=denoising)

        # The VAE is only needed by the final decoding stage.
        decoding = DecodingStage(vae=self.get_module("vae"))
        self.add_stage(stage_name="decoding_stage", stage=decoding)


# Module-level alias for the pipeline class defined in this file.
# NOTE(review): presumably looked up by name by a pipeline loader -- confirm.
EntryClass = WanPipeline