# "vscode:/vscode.git/clone" did not exist on "59eaa0295af78870334776d256605b80e988b404"
# base_runner.py
from abc import ABC

from lightx2v.utils.utils import save_videos_grid


class BaseRunner(ABC):
    """Common base class for all Runners.

    Declares the hooks a runner exposes: model loading, encoder execution,
    scheduler setup, video saving and the per-segment run lifecycle.

    None of the hooks is decorated with ``@abstractmethod``, so this class is
    instantiable and any hook a subclass does not override falls back to the
    default defined here. Most defaults do nothing and return ``None``; the
    "Returns" sections below describe what an overriding implementation is
    expected to provide.
    """

    def __init__(self, config):
        """Store the runner configuration.

        Args:
            config: Runner configuration. ``save_video_func`` reads it via
                ``config.get(key, default)``, so it must offer a dict-style
                ``get`` method.
        """
        self.config = config

    def load_transformer(self):
        """Load transformer model.

        The base implementation does nothing and returns ``None``.

        Returns:
            Loaded transformer model instance
        """

    def load_text_encoder(self):
        """Load text encoder.

        The base implementation does nothing and returns ``None``.

        Returns:
            Text encoder instance or list of text encoder instances
        """

    def load_image_encoder(self):
        """Load image encoder.

        The base implementation does nothing and returns ``None``.

        Returns:
            Image encoder instance or None if not needed
        """

    def load_vae(self):
        """Load VAE encoder and decoder.

        The base implementation does nothing and returns ``None``;
        ``load_vae_decoder`` relies on an override returning a 2-item pair.

        Returns:
            Tuple[vae_encoder, vae_decoder]: VAE encoder and decoder instances
        """

    def run_image_encoder(self, img):
        """Run image encoder.

        The base implementation does nothing and returns ``None``.

        Args:
            img: Input image

        Returns:
            Image encoding result
        """

    def run_vae_encoder(self, img):
        """Run VAE encoder.

        The base implementation does nothing and returns ``None``.

        Args:
            img: Input image

        Returns:
            Tuple of VAE encoding result and additional parameters
        """

    def run_text_encoder(self, prompt, img):
        """Run text encoder.

        The base implementation does nothing and returns ``None``.

        Args:
            prompt: Input text prompt
            img: Optional input image (for some models)

        Returns:
            Text encoding result
        """

    def get_encoder_output_i2v(self, clip_encoder_out, vae_encoder_out, text_encoder_output, img):
        """Combine encoder outputs for i2v task.

        The base implementation does nothing and returns ``None``.

        Args:
            clip_encoder_out: CLIP encoder output
            vae_encoder_out: VAE encoder output
            text_encoder_output: Text encoder output
            img: Original image

        Returns:
            Combined encoder output dictionary
        """

    def init_scheduler(self):
        """Initialize scheduler.

        The base implementation does nothing and returns ``None``.
        """

    def set_target_shape(self):
        """Set target shape.

        Subclasses can override this method to provide specific implementation.

        Returns:
            Dictionary containing target shape information; the base
            implementation returns a new empty dict.
        """
        return {}

    def save_video_func(self, images):
        """Save video implementation.

        Delegates to ``save_videos_grid`` with ``n_rows=1``. The output path
        is read from config key ``"save_video_path"`` (default
        ``"./output.mp4"``) and the frame rate from config key ``"fps"``
        (default ``8``).

        Subclasses can override this method to customize save logic.

        Args:
            images: Image sequence to save
        """
        output_path = self.config.get("save_video_path", "./output.mp4")
        frame_rate = self.config.get("fps", 8)
        save_videos_grid(images, output_path, n_rows=1, fps=frame_rate)

    def load_vae_decoder(self):
        """Load VAE decoder.

        Default implementation: take the decoder from ``load_vae`` and cache
        it on ``self.vae_decoder``; later calls reuse the cached instance.
        Subclasses can override this method to provide different loading logic.

        Returns:
            VAE decoder instance

        Raises:
            TypeError: If ``load_vae`` returns ``None``, which is what the
                base ``load_vae`` does when it has not been overridden.
        """
        # A missing attribute and an explicit None both mean "not loaded yet".
        if getattr(self, "vae_decoder", None) is None:
            vae_pair = self.load_vae()
            if vae_pair is None:
                # Same exception type the bare unpack would raise, but with a
                # message that points at the actual cause.
                raise TypeError(
                    f"{type(self).__name__}.load_vae() returned None; "
                    "expected a (vae_encoder, vae_decoder) pair"
                )
            _, self.vae_decoder = vae_pair
        return self.vae_decoder

    def get_video_segment_num(self):
        """Set ``self.video_segment_num`` to 1.

        Despite the ``get_`` prefix, nothing is returned; the value is only
        stored on the instance.
        """
        self.video_segment_num = 1

    def init_run(self):
        """Run-level setup hook; the base implementation does nothing."""

    def init_run_segment(self, segment_idx):
        """Record which segment is about to run.

        Args:
            segment_idx: Segment index, stored as ``self.segment_idx``.
        """
        self.segment_idx = segment_idx

    def run_segment(self, total_steps=None):
        """Segment execution hook; the base implementation does nothing.

        Args:
            total_steps: Ignored by the base implementation.
        """

    def end_run_segment(self):
        """Per-segment teardown hook; the base implementation does nothing."""

    def end_run(self):
        """Run-level teardown hook; the base implementation does nothing."""