import imageio
import numpy as np
from diffusers.utils import export_to_video

from lightx2v.models.input_encoders.hf.t5_v1_1_xxl.model import T5EncoderModel_v1_1_xxl
from lightx2v.models.networks.cogvideox.model import CogvideoxModel
from lightx2v.models.runners.default_runner import DefaultRunner
from lightx2v.models.schedulers.cogvideox.scheduler import CogvideoxXDPMScheduler
from lightx2v.models.video_encoders.hf.cogvideox.model import CogvideoxVAE
from lightx2v.utils.profiler import ProfilingContext
from lightx2v.utils.registry_factory import RUNNER_REGISTER


@RUNNER_REGISTER("cogvideox")
class CogvideoxRunner(DefaultRunner):
    """Runner that wires the CogVideoX components (T5 v1.1 XXL text encoder,
    CogVideoX transformer, CogVideoX VAE, and CogvideoxXDPMScheduler) into the
    DefaultRunner pipeline.

    Registered under the "cogvideox" key so the runner factory can build it
    from configuration. I2V-specific hooks are not implemented and raise.
    """

    def __init__(self, config):
        # config is used with both attribute access (config.task) and
        # dict-style access (config.get(...)) elsewhere in this class;
        # all state handling is delegated to DefaultRunner.
        super().__init__(config)

Zhuguanyu Wu's avatar
Zhuguanyu Wu committed
19
    def load_transformer(self):
20
21
22
        model = CogvideoxModel(self.config)
        return model

Zhuguanyu Wu's avatar
Zhuguanyu Wu committed
23
    def load_image_encoder(self):
24
25
        return None

Zhuguanyu Wu's avatar
Zhuguanyu Wu committed
26
    def load_text_encoder(self):
Watebear's avatar
Watebear committed
27
28
        text_encoder = T5EncoderModel_v1_1_xxl(self.config)
        text_encoders = [text_encoder]
29
30
        return text_encoders

Zhuguanyu Wu's avatar
Zhuguanyu Wu committed
31
    def load_vae(self):
Watebear's avatar
Watebear committed
32
        vae_model = CogvideoxVAE(self.config)
33
        return vae_model, vae_model

    def init_scheduler(self):
        """Create a CogvideoxXDPMScheduler and attach it to the loaded model."""
        self.model.set_scheduler(CogvideoxXDPMScheduler(self.config))

39
    def run_text_encoder(self, text, img):
Watebear's avatar
Watebear committed
40
        text_encoder_output = {}
41
42
43
        n_prompt = self.config.get("negative_prompt", "")
        context = self.text_encoders[0].infer([text], self.config)
        context_null = self.text_encoders[0].infer([n_prompt if n_prompt else ""], self.config)
Watebear's avatar
Watebear committed
44
45
46
47
        text_encoder_output["context"] = context
        text_encoder_output["context_null"] = context_null
        return text_encoder_output

48
49
50
51
    def run_vae_encoder(self, img):
        # TODO: implement vae encoder for Cogvideox
        raise NotImplementedError("I2V inference is not implemented for Cogvideox.")

52
    def get_encoder_output_i2v(self, clip_encoder_out, vae_encoder_out, text_encoder_output, img):
53
54
55
        # TODO: Implement image encoder for Cogvideox-I2V
        raise ValueError(f"Unsupported model class: {self.config['model_cls']}")

Watebear's avatar
Watebear committed
56
    def set_target_shape(self):
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
        ret = {}
        if self.config.task == "i2v":
            # TODO: implement set_target_shape for Cogvideox-I2V
            raise NotImplementedError("I2V inference is not implemented for Cogvideox.")
        else:
            num_frames = self.config.target_video_length
            latent_frames = (num_frames - 1) // self.config.vae_scale_factor_temporal + 1
            additional_frames = 0
            patch_size_t = self.config.patch_size_t
            if patch_size_t is not None and latent_frames % patch_size_t != 0:
                additional_frames = patch_size_t - latent_frames % patch_size_t
                num_frames += additional_frames * self.config.vae_scale_factor_temporal
            self.config.target_shape = (
                self.config.batch_size,
                (num_frames - 1) // self.config.vae_scale_factor_temporal + 1,
                self.config.latent_channels,
                self.config.height // self.config.vae_scale_factor_spatial,
                self.config.width // self.config.vae_scale_factor_spatial,
            )
            ret["target_shape"] = self.config.target_shape
        return ret

    def save_video_func(self, images):
        """Write a sequence of PIL frames to config.save_video_path at 16 fps.

        Frames are converted to uint8 numpy arrays before being appended;
        the imageio writer is closed by the context manager.
        """
        with imageio.get_writer(self.config.save_video_path, fps=16) as writer:
            for frame in images:
                writer.append_data(np.asarray(frame, dtype=np.uint8))