import os

import numpy as np
import torch
import torchvision
from PIL import Image

from lightx2v.models.input_encoders.hf.clip.model import TextEncoderHFClipModel
from lightx2v.models.input_encoders.hf.llama.model import TextEncoderHFLlamaModel
from lightx2v.models.input_encoders.hf.llava.model import TextEncoderHFLlavaModel
from lightx2v.models.networks.hunyuan.model import HunyuanModel
from lightx2v.models.runners.default_runner import DefaultRunner
from lightx2v.models.schedulers.hunyuan.feature_caching.scheduler import HunyuanSchedulerAdaCaching, HunyuanSchedulerCustomCaching, HunyuanSchedulerTaylorCaching, HunyuanSchedulerTeaCaching
from lightx2v.models.schedulers.hunyuan.scheduler import HunyuanScheduler
from lightx2v.models.video_encoders.hf.autoencoder_kl_causal_3d.model import VideoEncoderKLCausal3DModel
from lightx2v.utils.profiler import ProfilingContext
from lightx2v.utils.registry_factory import RUNNER_REGISTER
from lightx2v.utils.utils import save_videos_grid


@RUNNER_REGISTER("hunyuan")
class HunyuanRunner(DefaultRunner):
    """Runner registered under the name ``"hunyuan"``.

    Provides the model/encoder/VAE loaders, the scheduler selection and the
    encoder helpers for the ``t2v`` and ``i2v`` tasks of the Hunyuan model.
    The methods are presumably driven by ``DefaultRunner`` (not visible in
    this file) -- confirm the exact call order there.
    """

    # Bucket base size handed to generate_crop_size_list for every supported
    # value of config.i2v_resolution.
    _I2V_BUCKET_BASE_SIZES = {"720p": 960, "540p": 720, "360p": 480}

    # Factor the encoded reference-image latents are multiplied by in
    # run_vae_encoder.
    _VAE_SCALING_FACTOR = 0.476986

    def __init__(self, config):
        """Forward ``config`` unchanged to ``DefaultRunner.__init__``."""
        super().__init__(config)

    def load_transformer(self):
        """Build the ``HunyuanModel`` from ``config.model_path``."""
        # NOTE(review): self.config is passed twice (2nd and 4th argument);
        # kept exactly as in the original call -- confirm against the
        # HunyuanModel signature.
        return HunyuanModel(self.config.model_path, self.config, self.init_device, self.config)

    def load_image_encoder(self):
        """Return ``None``: this runner loads no separate image encoder."""
        return None

    def load_text_encoder(self):
        """Build the two text encoders and return them as a list.

        The first encoder depends on the task: a Llama encoder loaded from
        ``<model_path>/text_encoder`` for ``"t2v"``, otherwise a Llava encoder
        loaded from ``<model_path>/text_encoder_i2v``. The second encoder is
        always the CLIP encoder loaded from ``<model_path>/text_encoder_2``.
        """
        model_path = self.config.model_path
        if self.config.task == "t2v":
            text_encoder_1 = TextEncoderHFLlamaModel(os.path.join(model_path, "text_encoder"), self.init_device)
        else:
            text_encoder_1 = TextEncoderHFLlavaModel(os.path.join(model_path, "text_encoder_i2v"), self.init_device)
        text_encoder_2 = TextEncoderHFClipModel(os.path.join(model_path, "text_encoder_2"), self.init_device)
        return [text_encoder_1, text_encoder_2]

    def load_vae(self):
        """Build one float16 VAE and return it twice.

        The same instance fills both slots of the returned pair (presumably
        the encoder and the decoder expected by the base runner -- confirm).
        """
        vae_model = VideoEncoderKLCausal3DModel(
            self.config.model_path,
            dtype=torch.float16,
            device=self.init_device,
            config=self.config,
        )
        return vae_model, vae_model

    def init_scheduler(self):
        """Create the scheduler selected by ``config.feature_caching``.

        The scheduler is attached to ``self.model`` via ``set_scheduler``.

        Raises:
            NotImplementedError: if ``config.feature_caching`` is not one of
                ``NoCaching``, ``Tea``, ``TaylorSeer``, ``Ada`` or ``Custom``.
        """
        scheduler_classes = {
            "NoCaching": HunyuanScheduler,
            "Tea": HunyuanSchedulerTeaCaching,
            "TaylorSeer": HunyuanSchedulerTaylorCaching,
            "Ada": HunyuanSchedulerAdaCaching,
            "Custom": HunyuanSchedulerCustomCaching,
        }
        feature_caching = self.config.feature_caching
        scheduler_cls = scheduler_classes.get(feature_caching)
        if scheduler_cls is None:
            raise NotImplementedError(f"Unsupported feature_caching type: {feature_caching}")
        self.model.set_scheduler(scheduler_cls(self.config))

    def run_text_encoder(self, text, img):
        """Run every text encoder on ``text`` and collect the results.

        For the ``"i2v"`` task the first encoder additionally receives
        ``img``; every other call gets ``text`` and the config only.

        Returns:
            A dict holding, for the encoder at position ``k`` (1-based),
            ``text_encoder_<k>_text_states`` (cast to bfloat16) and
            ``text_encoder_<k>_attention_mask``.
        """
        text_encoder_output = {}
        for i, encoder in enumerate(self.text_encoders):
            if self.config.task == "i2v" and i == 0:
                text_state, attention_mask = encoder.infer(text, img, self.config)
            else:
                text_state, attention_mask = encoder.infer(text, self.config)
            text_encoder_output[f"text_encoder_{i + 1}_text_states"] = text_state.to(dtype=torch.bfloat16)
            text_encoder_output[f"text_encoder_{i + 1}_attention_mask"] = attention_mask
        return text_encoder_output

    @staticmethod
    def get_closest_ratio(height: float, width: float, ratios: list, buckets: list):
        """Pick the bucket whose ratio is closest to ``height / width``.

        Only one side of the target ratio is searched: when
        ``height / width >= 1`` the candidates are the ratios less than or
        equal to it, otherwise the ratios strictly greater than it. Among the
        candidates the smallest absolute difference wins (first one on ties).

        Args:
            height: Source height.
            width: Source width.
            ratios: Bucket ratios, index-aligned with ``buckets``. A plain
                sequence or a NumPy array.
            buckets: Bucket sizes, index-aligned with ``ratios``.

        Returns:
            ``(buckets[i], ratios[i])`` for the selected index ``i``.

        Raises:
            ValueError: if no ratio lies on the searched side.
        """
        aspect_ratio = float(height) / float(width)
        # Coerce so the subtraction broadcasts for plain lists as well as
        # for ndarrays.
        diff_ratios = np.asarray(ratios) - aspect_ratio

        # NOTE(review): the landscape branch uses a strict "> 0", so a ratio
        # that matches exactly is skipped there but accepted in the portrait
        # branch. Kept as-is because changing it alters the chosen resolution
        # -- confirm whether that asymmetry is intended.
        if aspect_ratio >= 1:
            candidates = [(index, diff) for index, diff in enumerate(diff_ratios) if diff <= 0]
        else:
            candidates = [(index, diff) for index, diff in enumerate(diff_ratios) if diff > 0]

        if not candidates:
            # Same exception type min() would raise on an empty sequence,
            # but with an actionable message.
            raise ValueError(f"No bucket ratio available on the required side of aspect ratio {aspect_ratio}")

        closest_ratio_id = min(candidates, key=lambda pair: abs(pair[1]))[0]
        return buckets[closest_ratio_id], ratios[closest_ratio_id]

    @staticmethod
    def generate_crop_size_list(base_size=256, patch_size=32, max_ratio=4.0):
        """Enumerate bucket sizes that are multiples of ``patch_size``.

        The patch budget is ``round((base_size / patch_size) ** 2)``. Starting
        from ``(budget, 1)`` patches, the walk increases the second count
        while the product of both counts stays within the budget and
        otherwise decreases the first count, until the first count reaches
        zero. A pair is kept when its longer/shorter ratio is at most
        ``max_ratio``.

        Each entry is ``(first_count * patch_size, second_count * patch_size)``.
        ``run_vae_encoder`` in this class unpacks every entry as
        ``(height, width)``.

        Raises:
            ValueError: if ``max_ratio`` is not ``>= 1.0``.
        """
        num_patches = round((base_size / patch_size) ** 2)
        # Explicit check instead of assert so validation survives python -O;
        # "not >=" also rejects NaN.
        if not max_ratio >= 1.0:
            raise ValueError(f"max_ratio must be >= 1.0, got {max_ratio}")

        crop_size_list = []
        wp, hp = num_patches, 1
        while wp > 0:
            if max(wp, hp) / min(wp, hp) <= max_ratio:
                crop_size_list.append((wp * patch_size, hp * patch_size))
            if (hp + 1) * wp <= num_patches:
                hp += 1
            else:
                wp -= 1
        return crop_size_list

    def run_image_encoder(self, img):
        """Return ``None``: no image-encoder pass is performed; ``img`` is unused."""
        return None

    def run_vae_encoder(self, img):
        """Encode the reference image into scaled VAE latents.

        Steps: choose the bucket size closest to the image's aspect ratio for
        ``config.i2v_resolution``, store it on the config as
        ``target_height`` / ``target_width``, resize + center-crop +
        normalize the image, encode it with ``self.vae_encoder`` and multiply
        the latents by the scaling factor.

        Args:
            img: Image whose ``size`` is ``(width, height)`` -- presumably a
                PIL image; confirm against the caller.

        Returns:
            ``(img_latents, kwargs)`` where ``kwargs`` carries
            ``target_height`` and ``target_width``.

        Raises:
            ValueError: if ``config.i2v_resolution`` is not ``720p``,
                ``540p`` or ``360p``.
        """
        resolution = self.config.i2v_resolution
        bucket_hw_base_size = self._I2V_BUCKET_BASE_SIZES.get(resolution)
        if bucket_hw_base_size is None:
            raise ValueError(f"self.config.i2v_resolution: {resolution} must be in [360p, 540p, 720p]")

        origin_width, origin_height = img.size

        crop_size_list = self.generate_crop_size_list(bucket_hw_base_size, 32)
        aspect_ratios = np.array([round(float(h) / float(w), 5) for h, w in crop_size_list])
        closest_size, _ = self.get_closest_ratio(origin_height, origin_width, aspect_ratios, crop_size_list)

        # The selected bucket is published both on the shared config and in
        # the returned kwargs.
        self.config.target_height, self.config.target_width = closest_size
        kwargs = {"target_height": closest_size[0], "target_width": closest_size[1]}

        transforms = torchvision.transforms
        ref_image_transform = transforms.Compose(
            [
                # Resize the shorter edge to the bucket's smaller side.
                transforms.Resize(min(closest_size)),
                transforms.CenterCrop(closest_size),
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5]),
            ]
        )

        # Insert new axes at positions 0 and 2 of the transformed image
        # tensor before handing it to the VAE.
        semantic_image_pixel_values = ref_image_transform(img).unsqueeze(0).unsqueeze(2)
        semantic_image_pixel_values = semantic_image_pixel_values.to(torch.float16).to(torch.device("cuda"))

        img_latents = self.vae_encoder.encode(semantic_image_pixel_values, self.config).mode()
        img_latents.mul_(self._VAE_SCALING_FACTOR)

        return img_latents, kwargs

    def get_encoder_output_i2v(self, clip_encoder_out, vae_encoder_out, text_encoder_output, img):
        """Bundle the i2v encoder results into one dict.

        ``clip_encoder_out`` is accepted but not used. The result has the
        keys ``text_encoder_output`` and ``image_encoder_output``; the latter
        holds ``img`` and ``img_latents`` (set to ``vae_encoder_out``).
        """
        image_encoder_output = {"img": img, "img_latents": vae_encoder_out}
        return {"text_encoder_output": text_encoder_output, "image_encoder_output": image_encoder_output}

    def set_target_shape(self):
        """Compute ``config.target_shape`` and return it with the target size.

        The shape is ``(1, 16, (target_video_length - 1) // 4 + 1,
        target_height // 8, target_width // 8)``; the constants 16 and 4 are
        presumably the latent channel count and the temporal compression of
        the VAE -- confirm.

        Returns:
            A dict with ``target_height``, ``target_width`` and
            ``target_shape``.
        """
        vae_scale_factor = 2 ** (4 - 1)
        self.config.target_shape = (
            1,
            16,
            (self.config.target_video_length - 1) // 4 + 1,
            int(self.config.target_height) // vae_scale_factor,
            int(self.config.target_width) // vae_scale_factor,
        )
        return {
            "target_height": self.config.target_height,
            "target_width": self.config.target_width,
            "target_shape": self.config.target_shape,
        }

    def save_video_func(self, images):
        """Write ``images`` to ``config.save_video_path`` at ``config["fps"]`` (default 24)."""
        save_videos_grid(images, self.config.save_video_path, fps=self.config.get("fps", 24))