from abc import ABC

import torch
import torch.distributed as dist

from lightx2v.utils.utils import save_videos_grid

class BaseRunner(ABC):
    """Abstract base class for all Runners.

    Defines interface methods that all subclasses must implement, plus a few
    default implementations (video saving, VAE-decoder lazy loading, segment
    lifecycle hooks, and distributed stop-signal checking).
    """

    def __init__(self, config):
        """Initialize the runner.

        Args:
            config: Configuration mapping; defaults are read via ``config.get(...)``.
        """
        self.config = config
        # Whether the VAE encoder needs the original (unprocessed) image.
        # Subclasses flip this to True when their VAE requires it.
        self.vae_encoder_need_img_original = False

    def load_transformer(self):
        """Load transformer model

        Returns:
            Loaded transformer model instance
        """
        pass

    def load_text_encoder(self):
        """Load text encoder

        Returns:
            Text encoder instance or list of text encoder instances
        """
        pass

    def load_image_encoder(self):
        """Load image encoder

        Returns:
            Image encoder instance or None if not needed
        """
        pass

    def load_vae(self):
        """Load VAE encoder and decoder

        Returns:
            Tuple[vae_encoder, vae_decoder]: VAE encoder and decoder instances
        """
        pass

    def run_image_encoder(self, img):
        """Run image encoder

        Args:
            img: Input image

        Returns:
            Image encoding result
        """
        pass

    def run_vae_encoder(self, img):
        """Run VAE encoder

        Args:
            img: Input image

        Returns:
            Tuple of VAE encoding result and additional parameters
        """
        pass

    def run_text_encoder(self, prompt, img=None):
        """Run text encoder

        Args:
            prompt: Input text prompt
            img: Optional input image (for some models); defaults to None so
                text-only models can omit it.

        Returns:
            Text encoding result
        """
        pass

    def get_encoder_output_i2v(self, clip_encoder_out, vae_encoder_out, text_encoder_output, img):
        """Combine encoder outputs for i2v task

        Args:
            clip_encoder_out: CLIP encoder output
            vae_encoder_out: VAE encoder output
            text_encoder_output: Text encoder output
            img: Original image

        Returns:
            Combined encoder output dictionary
        """
        pass

    def init_scheduler(self):
        """Initialize scheduler"""
        pass

    def set_target_shape(self):
        """Set target shape

        Subclasses can override this method to provide specific implementation

        Returns:
            Dictionary containing target shape information
        """
        return {}

    def save_video_func(self, images):
        """Save video implementation

        Subclasses can override this method to customize save logic

        Args:
            images: Image sequence to save
        """
        # Defaults: "./output.mp4" and 8 fps when not present in config.
        save_videos_grid(images, self.config.get("save_video_path", "./output.mp4"), n_rows=1, fps=self.config.get("fps", 8))

    def load_vae_decoder(self):
        """Load VAE decoder

        Default implementation: get decoder from load_vae method
        Subclasses can override this method to provide different loading logic

        Returns:
            VAE decoder instance
        """
        # Lazily populate and cache the decoder; load_vae() returns
        # (encoder, decoder) — only the decoder is kept here.
        if not hasattr(self, "vae_decoder") or self.vae_decoder is None:
            _, self.vae_decoder = self.load_vae()
        return self.vae_decoder

    def get_video_segment_num(self):
        """Set the number of video segments to generate (default: 1)."""
        self.video_segment_num = 1

    def init_run(self):
        """Hook called once before any segment runs."""
        pass

    def init_run_segment(self, segment_idx):
        """Hook called before each segment; records the current segment index.

        Args:
            segment_idx: Index of the segment about to run.
        """
        self.segment_idx = segment_idx

    def run_segment(self, total_steps=None):
        """Run one segment of the generation loop.

        Args:
            total_steps: Optional override for the number of steps.
        """
        pass

    def end_run_segment(self):
        """Hook called after each segment completes."""
        pass

    def end_run(self):
        """Hook called once after all segments have run."""
        pass

    def check_stop(self):
        """Check if the stop signal is received

        The last rank owns the stop flag; in multi-process runs its value is
        broadcast so every rank stops together.

        Raises:
            Exception: when the stop signal is set (an expected, deliberate stop).
        """
        rank, world_size = 0, 1
        if dist.is_initialized():
            rank = dist.get_rank()
            world_size = dist.get_world_size()
        # The highest rank is the one that receives/raises the stop signal.
        signal_rank = world_size - 1

        stopped = 0
        if rank == signal_rank and getattr(self, "stop_signal", False):
            stopped = 1

        if world_size > 1:
            # Construct directly on the GPU (avoids a CPU->GPU copy); on
            # non-signal ranks `stopped` is 0, so a single construction path
            # covers both sides of the broadcast.
            t = torch.tensor([stopped], dtype=torch.int32, device="cuda")
            dist.broadcast(t, src=signal_rank)
            stopped = t.item()

        if stopped == 1:
            raise Exception(f"find rank: {rank} stop_signal, stop running, it's an expected behavior")