hunyuan_runner.py 7.04 KB
Newer Older
helloyongyang's avatar
helloyongyang committed
1
import os
PengGao's avatar
PengGao committed
2

helloyongyang's avatar
helloyongyang committed
3
4
5
import numpy as np
import torch
import torchvision
PengGao's avatar
PengGao committed
6

helloyongyang's avatar
helloyongyang committed
7
from lightx2v.models.input_encoders.hf.clip.model import TextEncoderHFClipModel
PengGao's avatar
PengGao committed
8
from lightx2v.models.input_encoders.hf.llama.model import TextEncoderHFLlamaModel
helloyongyang's avatar
helloyongyang committed
9
10
from lightx2v.models.input_encoders.hf.llava.model import TextEncoderHFLlavaModel
from lightx2v.models.networks.hunyuan.model import HunyuanModel
PengGao's avatar
PengGao committed
11
12
13
from lightx2v.models.runners.default_runner import DefaultRunner
from lightx2v.models.schedulers.hunyuan.feature_caching.scheduler import HunyuanSchedulerAdaCaching, HunyuanSchedulerCustomCaching, HunyuanSchedulerTaylorCaching, HunyuanSchedulerTeaCaching
from lightx2v.models.schedulers.hunyuan.scheduler import HunyuanScheduler
helloyongyang's avatar
helloyongyang committed
14
from lightx2v.models.video_encoders.hf.autoencoder_kl_causal_3d.model import VideoEncoderKLCausal3DModel
15
from lightx2v.utils.envs import *
PengGao's avatar
PengGao committed
16
17
from lightx2v.utils.registry_factory import RUNNER_REGISTER
from lightx2v.utils.utils import save_videos_grid
helloyongyang's avatar
helloyongyang committed
18
19
20
21
22
23
24


@RUNNER_REGISTER("hunyuan")
class HunyuanRunner(DefaultRunner):
    def __init__(self, config):
        """Initialize the runner by forwarding ``config`` to ``DefaultRunner.__init__``."""
        super().__init__(config)

25
26
    def load_transformer(self):
        """Construct the ``HunyuanModel`` transformer for this runner.

        The model receives the configured ``model_path``, the config object,
        ``self.init_device``, and the config object again as its fourth
        positional argument.
        """
        cfg = self.config
        return HunyuanModel(cfg.model_path, cfg, self.init_device, cfg)
27

28
    def load_image_encoder(self):
        """Return ``None``: this runner does not load a separate image encoder."""
        return None
helloyongyang's avatar
helloyongyang committed
30

31
    def load_text_encoder(self):
        """Load the two text encoders used by the pipeline.

        Returns:
            A two-element list ``[text_encoder_1, text_encoder_2]``:

            * ``text_encoder_1`` is a ``TextEncoderHFLlamaModel`` loaded from
              ``<model_path>/text_encoder`` when ``config.task == "t2v"``;
              for any other task it is a ``TextEncoderHFLlavaModel`` loaded
              from ``<model_path>/text_encoder_i2v``.
            * ``text_encoder_2`` is a ``TextEncoderHFClipModel`` loaded from
              ``<model_path>/text_encoder_2``.

            Every encoder is constructed with ``self.init_device``.
        """
        # Choose the primary encoder class and its checkpoint sub-directory.
        if self.config.task == "t2v":
            primary_cls, primary_subdir = TextEncoderHFLlamaModel, "text_encoder"
        else:
            primary_cls, primary_subdir = TextEncoderHFLlavaModel, "text_encoder_i2v"

        primary_encoder = primary_cls(os.path.join(self.config.model_path, primary_subdir), self.init_device)
        secondary_encoder = TextEncoderHFClipModel(os.path.join(self.config.model_path, "text_encoder_2"), self.init_device)
        return [primary_encoder, secondary_encoder]

40
41
    def load_vae(self):
        """Load the causal 3D VAE once and return it twice.

        The model is built from ``config.model_path`` with ``torch.float16``
        on ``self.init_device``.

        Returns:
            A 2-tuple holding the *same* ``VideoEncoderKLCausal3DModel``
            instance in both positions (presumably the encoder and decoder
            slots expected by the base runner -- confirm against
            ``DefaultRunner``).
        """
        shared_vae = VideoEncoderKLCausal3DModel(
            self.config.model_path,
            dtype=torch.float16,
            device=self.init_device,
            config=self.config,
        )
        return shared_vae, shared_vae
helloyongyang's avatar
helloyongyang committed
43
44
45
46
47
48
49
50

    def init_scheduler(self):
        """Create the scheduler selected by ``config.feature_caching``.

        Supported values are ``"NoCaching"``, ``"Tea"``, ``"TaylorSeer"``,
        ``"Ada"`` and ``"Custom"``. The chosen scheduler is constructed with
        ``self.config`` and handed to ``self.model.set_scheduler``.

        Raises:
            NotImplementedError: If ``config.feature_caching`` is not one of
                the supported values.
        """
        scheduler_classes = {
            "NoCaching": HunyuanScheduler,
            "Tea": HunyuanSchedulerTeaCaching,
            "TaylorSeer": HunyuanSchedulerTaylorCaching,
            "Ada": HunyuanSchedulerAdaCaching,
            "Custom": HunyuanSchedulerCustomCaching,
        }
        scheduler_cls = scheduler_classes.get(self.config.feature_caching)
        if scheduler_cls is None:
            raise NotImplementedError(f"Unsupported feature_caching type: {self.config.feature_caching}")
        self.model.set_scheduler(scheduler_cls(self.config))

59
    def run_text_encoder(self, text, img):
        """Run every loaded text encoder on ``text`` and collect the outputs.

        Args:
            text: Prompt passed to each encoder's ``infer``.
            img: Image passed only to the first encoder, and only when
                ``config.task == "i2v"``.

        Returns:
            A dict with, for the N-th encoder (1-based), the keys
            ``text_encoder_N_text_states`` (cast to ``GET_DTYPE()``) and
            ``text_encoder_N_attention_mask`` (stored as returned).
        """
        outputs = {}
        for position, encoder in enumerate(self.text_encoders, start=1):
            # Only the first encoder of an i2v run is conditioned on the image.
            if position == 1 and self.config.task == "i2v":
                infer_args = (text, img, self.config)
            else:
                infer_args = (text, self.config)
            states, mask = encoder.infer(*infer_args)

            key_prefix = f"text_encoder_{position}"
            outputs[f"{key_prefix}_text_states"] = states.to(dtype=GET_DTYPE())
            outputs[f"{key_prefix}_attention_mask"] = mask
        return outputs

70
71
    @staticmethod
    def get_closest_ratio(height: float, width: float, ratios: list, buckets: list):
helloyongyang's avatar
helloyongyang committed
72
73
74
75
76
77
78
79
80
81
82
83
84
85
        aspect_ratio = float(height) / float(width)
        diff_ratios = ratios - aspect_ratio

        if aspect_ratio >= 1:
            indices = [(index, x) for index, x in enumerate(diff_ratios) if x <= 0]
        else:
            indices = [(index, x) for index, x in enumerate(diff_ratios) if x > 0]

        closest_ratio_id = min(indices, key=lambda pair: abs(pair[1]))[0]
        closest_size = buckets[closest_ratio_id]
        closest_ratio = ratios[closest_ratio_id]

        return closest_size, closest_ratio

86
87
    @staticmethod
    def generate_crop_size_list(base_size=256, patch_size=32, max_ratio=4.0):
helloyongyang's avatar
helloyongyang committed
88
89
90
91
92
93
94
95
96
97
98
99
100
        num_patches = round((base_size / patch_size) ** 2)
        assert max_ratio >= 1.0
        crop_size_list = []
        wp, hp = num_patches, 1
        while wp > 0:
            if max(wp, hp) / min(wp, hp) <= max_ratio:
                crop_size_list.append((wp * patch_size, hp * patch_size))
            if (hp + 1) * wp <= num_patches:
                hp += 1
            else:
                wp -= 1
        return crop_size_list

101
102
    def run_image_encoder(self, img):
        """Return ``None``; ``img`` is ignored because no image encoder is loaded by this runner."""
        return None
helloyongyang's avatar
helloyongyang committed
103

104
105
106
    def run_vae_encoder(self, img):
        kwargs = {}
        if self.config.i2v_resolution == "720p":
helloyongyang's avatar
helloyongyang committed
107
            bucket_hw_base_size = 960
108
        elif self.config.i2v_resolution == "540p":
helloyongyang's avatar
helloyongyang committed
109
            bucket_hw_base_size = 720
110
        elif self.config.i2v_resolution == "360p":
helloyongyang's avatar
helloyongyang committed
111
112
            bucket_hw_base_size = 480
        else:
113
            raise ValueError(f"self.config.i2v_resolution: {self.config.i2v_resolution} must be in [360p, 540p, 720p]")
helloyongyang's avatar
helloyongyang committed
114
115
116
117
118
119
120

        origin_size = img.size

        crop_size_list = self.generate_crop_size_list(bucket_hw_base_size, 32)
        aspect_ratios = np.array([round(float(h) / float(w), 5) for h, w in crop_size_list])
        closest_size, closest_ratio = self.get_closest_ratio(origin_size[1], origin_size[0], aspect_ratios, crop_size_list)

121
122
        self.config.target_height, self.config.target_width = closest_size
        kwargs["target_height"], kwargs["target_width"] = closest_size
helloyongyang's avatar
helloyongyang committed
123
124
125
126
127
128
129
130
131
132
133

        resize_param = min(closest_size)
        center_crop_param = closest_size

        ref_image_transform = torchvision.transforms.Compose(
            [torchvision.transforms.Resize(resize_param), torchvision.transforms.CenterCrop(center_crop_param), torchvision.transforms.ToTensor(), torchvision.transforms.Normalize([0.5], [0.5])]
        )

        semantic_image_pixel_values = [ref_image_transform(img)]
        semantic_image_pixel_values = torch.cat(semantic_image_pixel_values).unsqueeze(0).unsqueeze(2).to(torch.float16).to(torch.device("cuda"))

134
        img_latents = self.vae_encoder.encode(semantic_image_pixel_values, self.config).mode()
helloyongyang's avatar
helloyongyang committed
135
136
137
138

        scaling_factor = 0.476986
        img_latents.mul_(scaling_factor)

139
140
        return img_latents, kwargs

141
142
    def get_encoder_output_i2v(self, clip_encoder_out, vae_encoder_out, text_encoder_output, img):
        """Bundle the i2v encoder results into one nested dict.

        ``clip_encoder_out`` is accepted but not used.

        Returns:
            ``{"text_encoder_output": ..., "image_encoder_output": {"img": ..., "img_latents": ...}}``
            where ``img_latents`` is ``vae_encoder_out``.
        """
        return {
            "text_encoder_output": text_encoder_output,
            "image_encoder_output": {
                "img": img,
                "img_latents": vae_encoder_out,
            },
        }
helloyongyang's avatar
helloyongyang committed
144
145
146
147
148
149
150
151
152
153

    def set_target_shape(self):
        vae_scale_factor = 2 ** (4 - 1)
        self.config.target_shape = (
            1,
            16,
            (self.config.target_video_length - 1) // 4 + 1,
            int(self.config.target_height) // vae_scale_factor,
            int(self.config.target_width) // vae_scale_factor,
        )
154
155
156
157
        return {"target_height": self.config.target_height, "target_width": self.config.target_width, "target_shape": self.config.target_shape}

    def save_video_func(self, images):
        """Write ``images`` to ``config.save_video_path`` via ``save_videos_grid``.

        The frame rate is read from ``config.get("fps", 24)``, i.e. it
        defaults to 24 when the config has no ``fps`` entry.
        """
        output_path = self.config.save_video_path
        frame_rate = self.config.get("fps", 24)
        save_videos_grid(images, output_path, fps=frame_rate)