from abc import ABC

from lightx2v.utils.utils import save_videos_grid


class BaseRunner(ABC):
    """Base class for all Runners.

    Declares the hook methods a concrete runner is expected to provide.
    None of the hooks below is decorated with ``@abstractmethod``, so nothing
    is enforced when a subclass is instantiated: unless stated otherwise, each
    default implementation is a no-op that returns ``None``, and subclasses
    override the hooks they need.
    """

    def __init__(self, config):
        """Store the runner configuration.

        Args:
            config: Mapping-like configuration object; ``save_video_func``
                reads it through ``.get``.
        """
        self.config = config

    def load_transformer(self):
        """Load transformer model.

        Returns:
            Loaded transformer model instance (``None`` in this default).
        """
        pass

    def load_text_encoder(self):
        """Load text encoder.

        Returns:
            Text encoder instance or list of text encoder instances
            (``None`` in this default).
        """
        pass

    def load_image_encoder(self):
        """Load image encoder.

        Returns:
            Image encoder instance, or ``None`` if not needed (``None`` in
            this default).
        """
        pass

    def load_vae(self):
        """Load VAE encoder and decoder.

        Returns:
            Tuple[vae_encoder, vae_decoder]: VAE encoder and decoder instances.
            This default returns ``None``; ``load_vae_decoder`` relies on
            subclasses overriding it to return a 2-tuple.
        """
        pass

    def run_image_encoder(self, img):
        """Run image encoder.

        Args:
            img: Input image.

        Returns:
            Image encoding result (``None`` in this default).
        """
        pass

    def run_vae_encoder(self, img):
        """Run VAE encoder.

        Args:
            img: Input image.

        Returns:
            Tuple of VAE encoding result and additional parameters
            (``None`` in this default).
        """
        pass

    def run_text_encoder(self, prompt, img):
        """Run text encoder.

        Args:
            prompt: Input text prompt.
            img: Optional input image (for some models).

        Returns:
            Text encoding result (``None`` in this default).
        """
        pass

    def get_encoder_output_i2v(self, clip_encoder_out, vae_encoder_out, text_encoder_output, img):
        """Combine encoder outputs for the i2v task.

        Args:
            clip_encoder_out: CLIP encoder output.
            vae_encoder_out: VAE encoder output.
            text_encoder_output: Text encoder output.
            img: Original image.

        Returns:
            Combined encoder output dictionary (``None`` in this default).
        """
        pass

    def init_scheduler(self):
        """Initialize scheduler (no-op in this default)."""
        pass

    def set_target_shape(self):
        """Set target shape.

        Subclasses can override this method to provide a specific
        implementation.

        Returns:
            Dictionary containing target shape information; this default
            returns an empty dict.
        """
        return {}

    def save_video_func(self, images):
        """Save a generated image sequence as a video.

        Subclasses can override this method to customize save logic.

        Args:
            images: Image sequence to save. Written to
                ``config["save_video_path"]`` (default ``"./output.mp4"``)
                at ``config["fps"]`` frames per second (default 8), in a
                single-row grid.
        """
        save_videos_grid(
            images,
            self.config.get("save_video_path", "./output.mp4"),
            n_rows=1,
            fps=self.config.get("fps", 8),
        )

    def load_vae_decoder(self):
        """Load the VAE decoder, caching it on ``self.vae_decoder``.

        Default implementation: take the decoder (second element) from the
        tuple returned by ``load_vae``. The result is cached, so ``load_vae``
        is only called when no decoder has been set yet (attribute missing
        or ``None``). Subclasses can override this method to provide
        different loading logic.

        Returns:
            VAE decoder instance.

        Raises:
            NotImplementedError: If ``load_vae`` returns ``None`` (e.g. the
                subclass did not override it), instead of the opaque
                unpacking ``TypeError`` that would otherwise occur.
        """
        if getattr(self, "vae_decoder", None) is None:
            vae = self.load_vae()
            if vae is None:
                raise NotImplementedError(
                    f"{type(self).__name__}.load_vae() returned None; override "
                    "load_vae() to return (vae_encoder, vae_decoder), or "
                    "override load_vae_decoder()."
                )
            _, self.vae_decoder = vae
        return self.vae_decoder

    def get_video_segment_num(self):
        """Set ``self.video_segment_num`` (1 in this default) and return it.

        Subclasses that split generation into several segments presumably
        override this; TODO confirm against the caller that drives segments.
        """
        self.video_segment_num = 1
        return self.video_segment_num

    def init_run(self):
        """Hook for run-level setup (no-op in this default).

        NOTE(review): presumably called once before segments are processed;
        confirm against the driver.
        """
        pass

    def init_run_segment(self, segment_idx):
        """Record the index of the segment about to run.

        Args:
            segment_idx: Index of the current segment; stored on
                ``self.segment_idx``.
        """
        self.segment_idx = segment_idx

    def run_segment(self, total_steps=None):
        """Hook that runs the current segment (no-op in this default).

        Args:
            total_steps: Optional step count; ignored by this default.
        """
        pass

    def end_run_segment(self):
        """Hook for per-segment teardown (no-op in this default)."""
        pass

    def end_run(self):
        """Hook for run-level teardown (no-op in this default)."""
        pass