# base_runner.py
import os
from abc import ABC

import torch
import torch.distributed as dist

from lightx2v_platform.base.global_var import AI_DEVICE


class BaseRunner(ABC):
    """Abstract base class for all Runners

    Defines interface methods that all subclasses must implement
    """

    def __init__(self, config):
        self.config = config
        self.vae_encoder_need_img_original = False
        self.input_info = None

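    # Loading hooks: concrete runners return the model instances used by the run_* methods below.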
    def load_transformer(self):
        """Load transformer model

        Returns:
            Loaded transformer model instance
        """
        pass

    def load_text_encoder(self):
        """Load text encoder

        Returns:
            Text encoder instance or list of text encoder instances
        """
        pass

    def load_image_encoder(self):
        """Load image encoder

        Returns:
            Image encoder instance or None if not needed
        """
        pass

    def load_vae(self):
        """Load VAE encoder and decoder

        Returns:
            Tuple[vae_encoder, vae_decoder]: VAE encoder and decoder instances
        """
        pass

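    # Inference hooks: run each loaded encoder on a single input and return its encoding.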
    def run_image_encoder(self, img):
        """Run image encoder

        Args:
            img: Input image

        Returns:
            Image encoding result
        """
        pass

    def run_vae_encoder(self, img):
        """Run VAE encoder

        Args:
            img: Input image

        Returns:
            Tuple of VAE encoding result and additional parameters
        """
        pass

    def run_text_encoder(self, prompt, img):
        """Run text encoder

        Args:
            prompt: Input text prompt
            img: Optional input image (for some models)

        Returns:
            Text encoding result
        """
        pass

    def get_encoder_output_i2v(self, clip_encoder_out, vae_encoder_out, text_encoder_output, img):
        """Combine encoder outputs for i2v task

        Args:
            clip_encoder_out: CLIP encoder output
            vae_encoder_out: VAE encoder output
            text_encoder_output: Text encoder output
            img: Original image

        Returns:
            Combined encoder output dictionary
        """
        pass

    def init_scheduler(self):
        """Initialize scheduler"""
        pass

    def load_vae_decoder(self):
        """Load VAE decoder

        Default implementation: get decoder from load_vae method
        Subclasses can override this method to provide different loading logic

        Returns:
            VAE decoder instance
        """
        if not hasattr(self, "vae_decoder") or self.vae_decoder is None:
            _, self.vae_decoder = self.load_vae()
        return self.vae_decoder

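    # Segment-loop hooks: the defaults cover a single-segment run; multi-segment runners override them.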
    def get_video_segment_num(self):
        self.video_segment_num = 1

    def init_run(self):
        pass

    def init_run_segment(self, segment_idx):
        self.segment_idx = segment_idx

    def run_segment(self, segment_idx=0):
        pass

    def end_run_segment(self, segment_idx=None):
        self.gen_video_final = self.gen_video

    def end_run(self):
        pass

    def check_stop(self):
        """Check if the stop signal is received"""

        rank, world_size = 0, 1
        if dist.is_initialized():
            rank = dist.get_rank()
            world_size = dist.get_world_size()
        stop_rank = int(os.getenv("WORKER_RANK", "0")) % world_size  # same as worker hub target_rank
        pause_rank = int(os.getenv("READER_RANK", "0")) % world_size  # same as va_reader target_rank

        stopped, paused = 0, 0
        if rank == stop_rank and hasattr(self, "stop_signal") and self.stop_signal:
            stopped = 1
        if rank == pause_rank and hasattr(self, "pause_signal") and self.pause_signal:
            paused = 1

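        # Broadcast both flags so every rank sees the same stop/pause decision.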
        if world_size > 1:
            if rank == stop_rank:
                t1 = torch.tensor([stopped], dtype=torch.int32).to(device=AI_DEVICE)
            else:
                t1 = torch.zeros(1, dtype=torch.int32, device=AI_DEVICE)
            if rank == pause_rank:
                t2 = torch.tensor([paused], dtype=torch.int32).to(device=AI_DEVICE)
            else:
                t2 = torch.zeros(1, dtype=torch.int32, device=AI_DEVICE)
            dist.broadcast(t1, src=stop_rank)
            dist.broadcast(t2, src=pause_rank)
            stopped = t1.item()
            paused = t2.item()

        if stopped == 1:
            raise Exception(f"find rank: {rank} stop_signal, stop running, it's an expected behavior")
        if paused == 1:
            raise Exception(f"find rank: {rank} pause_signal, pause running, it's an expected behavior")