base_runner.py 4.33 KB
Newer Older
PengGao's avatar
PengGao committed
1
from abc import ABC, abstractmethod
PengGao's avatar
PengGao committed
2
3
from typing import Any, Dict, List, Optional, Protocol, Tuple, Union

PengGao's avatar
PengGao committed
4
5
6
from lightx2v.utils.utils import save_videos_grid


7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
class TransformerModel(Protocol):
    """Structural interface that transformer models handed to a runner satisfy."""

    def set_scheduler(self, scheduler: Any) -> None:
        """Attach ``scheduler`` to the model."""

    def scheduler(self) -> Any:
        """Accessor for the model's scheduler."""


class TextEncoderModel(Protocol):
    """Structural interface that text encoder models handed to a runner satisfy."""

    def infer(self, texts: List[str], config: Dict[str, Any]) -> Any:
        """Encode ``texts`` using the settings in ``config``."""


class ImageEncoderModel(Protocol):
    """Structural interface that image encoder models handed to a runner satisfy."""

    def encode(self, image: Any) -> Any:
        """Encode a single ``image``."""


class VAEModel(Protocol):
    """Structural interface that VAE models handed to a runner satisfy."""

    def encode(self, image: Any) -> Tuple[Any, Dict[str, Any]]:
        """Encode ``image``; yields the encoding plus a dict of extra values."""

    def decode(self, latents: Any, generator: Optional[Any] = None, config: Optional[Dict[str, Any]] = None) -> Any:
        """Decode ``latents``; ``generator`` and ``config`` are optional."""


PengGao's avatar
PengGao committed
33
34
35
36
37
38
39
40
41
42
class BaseRunner(ABC):
    """Abstract base class for all Runners

    Defines interface methods that all subclasses must implement: model
    loading (transformer, text encoder, image encoder, VAE), encoder
    execution, encoder-output assembly for i2v, and scheduler setup.
    Default implementations are provided for ``set_target_shape``,
    ``save_video_func`` and ``load_vae_decoder``.
    """

    def __init__(self, config: Dict[str, Any]):
        """Store the runner configuration.

        Args:
            config: Runner configuration mapping. Within this class it is read
                via ``.get`` by ``save_video_func`` (keys ``"save_video_path"``
                and ``"fps"``); subclasses may read further keys.
        """
        self.config = config

    @abstractmethod
    def load_transformer(self) -> TransformerModel:
        """Load transformer model

        Returns:
            Loaded transformer model instance
        """
        pass

    @abstractmethod
    def load_text_encoder(self) -> Union[TextEncoderModel, List[TextEncoderModel]]:
        """Load text encoder

        Returns:
            Text encoder instance or list of text encoder instances
        """
        pass

    @abstractmethod
    def load_image_encoder(self) -> Optional[ImageEncoderModel]:
        """Load image encoder

        Returns:
            Image encoder instance or None if not needed
        """
        pass

    @abstractmethod
    def load_vae(self) -> Tuple[VAEModel, VAEModel]:
        """Load VAE encoder and decoder

        Returns:
            Tuple[vae_encoder, vae_decoder]: VAE encoder and decoder instances
        """
        pass

    @abstractmethod
    def run_image_encoder(self, img: Any) -> Any:
        """Run image encoder

        Args:
            img: Input image

        Returns:
            Image encoding result
        """
        pass

    @abstractmethod
    def run_vae_encoder(self, img: Any) -> Tuple[Any, Dict[str, Any]]:
        """Run VAE encoder

        Args:
            img: Input image

        Returns:
            Tuple of VAE encoding result and additional parameters
        """
        pass

    @abstractmethod
    def run_text_encoder(self, prompt: str, img: Optional[Any] = None) -> Any:
        """Run text encoder

        Args:
            prompt: Input text prompt
            img: Optional input image (for some models)

        Returns:
            Text encoding result
        """
        pass

    @abstractmethod
    def get_encoder_output_i2v(self, clip_encoder_out: Any, vae_encoder_out: Any, text_encoder_output: Any, img: Any) -> Dict[str, Any]:
        """Combine encoder outputs for i2v task

        Args:
            clip_encoder_out: CLIP encoder output
            vae_encoder_out: VAE encoder output
            text_encoder_output: Text encoder output
            img: Original image

        Returns:
            Combined encoder output dictionary
        """
        pass

    @abstractmethod
    def init_scheduler(self) -> None:
        """Initialize scheduler"""
        pass

    def set_target_shape(self) -> Dict[str, Any]:
        """Set target shape

        Subclasses can override this method to provide specific implementation

        Returns:
            Dictionary containing target shape information (empty by default)
        """
        return {}

    def save_video_func(self, images: Any) -> None:
        """Save video implementation

        Writes ``images`` via ``save_videos_grid`` as a single-row grid to
        ``config["save_video_path"]`` (default ``"./output.mp4"``) at
        ``config["fps"]`` frames per second (default 8).
        Subclasses can override this method to customize save logic

        Args:
            images: Image sequence to save
        """
        save_videos_grid(images, self.config.get("save_video_path", "./output.mp4"), n_rows=1, fps=self.config.get("fps", 8))

    def load_vae_decoder(self) -> VAEModel:
        """Load VAE decoder

        Default implementation: get decoder from ``load_vae`` and cache it on
        ``self.vae_decoder`` so later calls reuse it instead of reloading.
        The encoder half returned by ``load_vae`` is discarded here.
        Subclasses can override this method to provide different loading logic

        Returns:
            VAE decoder instance
        """
        # getattr with a default covers both "attribute never set" and
        # "attribute explicitly set to None" in a single check.
        if getattr(self, "vae_decoder", None) is None:
            _, self.vae_decoder = self.load_vae()
        return self.vae_decoder