from lightx2v.models.input_encoders.hf.t5_v1_1_xxl.model import T5EncoderModel_v1_1_xxl
from lightx2v.models.networks.cogvideox.model import CogvideoxModel
from lightx2v.models.runners.default_runner import DefaultRunner
from lightx2v.models.schedulers.cogvideox.scheduler import CogvideoxXDPMScheduler
from lightx2v.models.video_encoders.hf.cogvideox.model import CogvideoxVAE
from lightx2v.utils.registry_factory import RUNNER_REGISTER


@RUNNER_REGISTER("cogvideox")
class CogvideoxRunner(DefaultRunner):
    """Runner wiring CogVideoX components into the DefaultRunner pipeline.

    Provides the model loaders (T5 text encoder, CogVideoX transformer, VAE),
    the CogVideoX DPM scheduler, and the target latent-shape computation.
    Only text-to-video ("t2v") is implemented; every image-to-video ("i2v")
    entry point raises NotImplementedError.
    """

    def __init__(self, config):
        super().__init__(config)

    def load_transformer(self):
        """Instantiate and return the CogVideoX transformer (denoising) model."""
        model = CogvideoxModel(self.config)
        return model

    def load_image_encoder(self):
        """CogVideoX T2V uses no image encoder; return None."""
        return None

    def load_text_encoder(self):
        """Return the list of text encoders (a single T5 v1.1 XXL encoder)."""
        text_encoder = T5EncoderModel_v1_1_xxl(self.config)
        text_encoders = [text_encoder]
        return text_encoders

    def load_vae(self):
        """Return the (encoder, decoder) pair.

        CogVideoX uses one VAE object for both roles, so the same instance is
        returned twice.
        """
        vae_model = CogvideoxVAE(self.config)
        return vae_model, vae_model

    def init_scheduler(self):
        """Create the CogVideoX DPM scheduler and attach it to the runner."""
        self.scheduler = CogvideoxXDPMScheduler(self.config)

    def run_text_encoder(self, text, img):
        """Encode the positive prompt and the (possibly empty) negative prompt.

        Args:
            text: positive prompt string.
            img: unused for T2V; kept for pipeline interface compatibility.

        Returns:
            dict with keys "context" (positive prompt embedding) and
            "context_null" (negative prompt embedding).
        """
        text_encoder_output = {}
        n_prompt = self.config.get("negative_prompt", "")
        context = self.text_encoders[0].infer([text], self.config)
        # `or ""` also covers a negative_prompt explicitly set to None.
        context_null = self.text_encoders[0].infer([n_prompt or ""], self.config)
        text_encoder_output["context"] = context
        text_encoder_output["context_null"] = context_null
        return text_encoder_output

    def run_vae_encoder(self, img):
        # TODO: implement vae encoder for Cogvideox
        raise NotImplementedError("I2V inference is not implemented for Cogvideox.")

    def get_encoder_output_i2v(self, clip_encoder_out, vae_encoder_out, text_encoder_output, img):
        # TODO: Implement image encoder for Cogvideox-I2V
        # NOTE: raise NotImplementedError (previously ValueError with a
        # misleading "Unsupported model class" message) so every unimplemented
        # I2V path in this runner fails with a consistent exception type.
        raise NotImplementedError(f"I2V inference is not implemented for Cogvideox (model_cls={self.config['model_cls']}).")

    def set_target_shape(self):
        """Compute the latent-space target shape and store it in the config.

        Returns:
            dict containing "target_shape" for the t2v path.

        Raises:
            NotImplementedError: for the i2v task (not yet implemented).
        """
        ret = {}
        if self.config.task == "i2v":
            # TODO: implement set_target_shape for Cogvideox-I2V
            raise NotImplementedError("I2V inference is not implemented for Cogvideox.")
        num_frames = self.config.target_video_length
        # Temporal VAE compression: number of latent frames for num_frames video frames.
        latent_frames = (num_frames - 1) // self.config.vae_scale_factor_temporal + 1
        additional_frames = 0
        patch_size_t = self.config.patch_size_t
        if patch_size_t is not None and latent_frames % patch_size_t != 0:
            # Pad so the latent frame count is divisible by the temporal patch size.
            additional_frames = patch_size_t - latent_frames % patch_size_t
            num_frames += additional_frames * self.config.vae_scale_factor_temporal
        self.config.target_shape = (
            self.config.batch_size,
            # Equals (num_frames - 1) // vae_scale_factor_temporal + 1 after padding;
            # reuse the values already computed instead of redoing the division.
            latent_frames + additional_frames,
            self.config.latent_channels,
            self.config.height // self.config.vae_scale_factor_spatial,
            self.config.width // self.config.vae_scale_factor_spatial,
        )
        ret["target_shape"] = self.config.target_shape
        return ret