cogvidex_runner.py 2.66 KB
Newer Older
Watebear's avatar
Watebear committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
from diffusers.utils import export_to_video
import imageio
import numpy as np

from lightx2v.utils.registry_factory import RUNNER_REGISTER
from lightx2v.models.runners.default_runner import DefaultRunner
from lightx2v.utils.profiler import ProfilingContext
from lightx2v.models.input_encoders.hf.t5_v1_1_xxl.model import T5EncoderModel_v1_1_xxl
from lightx2v.models.networks.cogvideox.model import CogvideoxModel
from lightx2v.models.video_encoders.hf.cogvideox.model import CogvideoxVAE
from lightx2v.models.schedulers.cogvideox.scheduler import CogvideoxXDPMScheduler


@RUNNER_REGISTER("cogvideox")
class CogvideoxRunner(DefaultRunner):
    """Pipeline runner for the CogVideoX text-to-video model.

    Wires together the T5 text encoder, the CogVideoX transformer, the
    CogVideoX VAE, and the XDPM scheduler, and handles latent-shape
    computation and video export. Registered under the key ``"cogvideox"``.
    """

    def __init__(self, config):
        super().__init__(config)

    @ProfilingContext("Load models")
    def load_model(self):
        """Instantiate all sub-models from ``self.config``.

        Returns:
            tuple: ``(model, text_encoders, vae_model, image_encoder)`` where
            ``text_encoders`` is a single-element list and ``image_encoder``
            is always ``None`` (this runner has no image-conditioning branch).
        """
        text_encoder = T5EncoderModel_v1_1_xxl(self.config)
        text_encoders = [text_encoder]
        model = CogvideoxModel(self.config)
        vae_model = CogvideoxVAE(self.config)
        # Text-to-video only: no image encoder is loaded.
        image_encoder = None
        return model, text_encoders, vae_model, image_encoder

    def init_scheduler(self):
        """Create the CogVideoX XDPM scheduler and attach it to the model."""
        scheduler = CogvideoxXDPMScheduler(self.config)
        self.model.set_scheduler(scheduler)

    def run_text_encoder(self, text, text_encoders, config, image_encoder_output):
        """Encode the prompt and the (possibly empty) negative prompt.

        Args:
            text: The positive prompt string.
            text_encoders: List of encoders; only the first is used.
            config: Config mapping; ``negative_prompt`` is read via ``get``.
            image_encoder_output: Unused (kept for interface parity with
                other runners).

        Returns:
            dict: ``{"context": ..., "context_null": ...}`` holding the
            encoder outputs for the prompt and the negative prompt.
        """
        text_encoder_output = {}
        n_prompt = config.get("negative_prompt", "")
        context = text_encoders[0].infer([text], config)
        # `or ""` guards against an explicit None stored under negative_prompt.
        context_null = text_encoders[0].infer([n_prompt or ""], config)
        text_encoder_output["context"] = context
        text_encoder_output["context_null"] = context_null
        return text_encoder_output

    def set_target_shape(self):
        """Compute the latent tensor shape and store it in the config.

        Pads the frame count so the temporal latent length is divisible by
        the transformer's temporal patch size, then derives the latent shape
        ``(batch, latent_frames, channels, latent_h, latent_w)`` from the
        VAE spatial/temporal scale factors.
        """
        num_frames = self.config.target_video_length
        latent_frames = (num_frames - 1) // self.config.vae_scale_factor_temporal + 1
        additional_frames = 0
        patch_size_t = self.config.patch_size_t
        if patch_size_t is not None and latent_frames % patch_size_t != 0:
            # Pad up to the next multiple of patch_size_t latent frames.
            additional_frames = patch_size_t - latent_frames % patch_size_t
            num_frames += additional_frames * self.config.vae_scale_factor_temporal
        self.config.target_shape = (
            self.config.batch_size,
            (num_frames - 1) // self.config.vae_scale_factor_temporal + 1,
            self.config.latent_channels,
            self.config.height // self.config.vae_scale_factor_spatial,
            self.config.width // self.config.vae_scale_factor_spatial,
        )

    def save_video(self, images):
        """Write a sequence of PIL images to ``config.save_video_path``.

        The frame rate is taken from ``config["fps"]`` when present and
        defaults to 16 (the previous hard-coded value), so existing configs
        behave unchanged.
        """
        fps = self.config.get("fps", 16)
        with imageio.get_writer(self.config.save_video_path, fps=fps) as writer:
            for pil_image in images:
                frame_np = np.array(pil_image, dtype=np.uint8)
                writer.append_data(frame_np)