__main__.py
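# LightX2V inference entry point: load the text/image encoders, the generation model,
# and the VAE for the chosen model class ("hunyuan" or "wan2.1"), run the sampling loop,
# decode the latents with the VAE, and save the resulting video.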
import argparse
import torch
import torch.distributed as dist
import os
import time
import gc
import json
import torchvision
import torchvision.transforms.functional as TF
import numpy as np
from PIL import Image

from lightx2v.utils.envs import *
from lightx2v.utils.utils import save_videos_grid, seed_all, cache_video
from lightx2v.utils.profiler import ProfilingContext, ProfilingContext4Debug
from lightx2v.utils.set_config import set_config

from lightx2v.models.input_encoders.hf.llama.model import TextEncoderHFLlamaModel
from lightx2v.models.input_encoders.hf.clip.model import TextEncoderHFClipModel
from lightx2v.models.input_encoders.hf.t5.model import T5EncoderModel
from lightx2v.models.input_encoders.hf.llava.model import TextEncoderHFLlavaModel
from lightx2v.models.input_encoders.hf.xlm_roberta.model import CLIPModel

from lightx2v.models.schedulers.hunyuan.scheduler import HunyuanScheduler
from lightx2v.models.schedulers.hunyuan.feature_caching.scheduler import HunyuanSchedulerTaylorCaching, HunyuanSchedulerTeaCaching
from lightx2v.models.schedulers.wan.scheduler import WanScheduler
from lightx2v.models.schedulers.wan.feature_caching.scheduler import WanSchedulerTeaCaching

from lightx2v.models.networks.hunyuan.model import HunyuanModel
from lightx2v.models.networks.wan.model import WanModel
from lightx2v.models.networks.wan.lora_adapter import WanLoraWrapper

from lightx2v.models.video_encoders.hf.autoencoder_kl_causal_3d.model import VideoEncoderKLCausal3DModel
from lightx2v.models.video_encoders.hf.wan.vae import WanVAE

from lightx2v.models.runners.default_runner import DefaultRunner
from lightx2v.models.runners.graph_runner import GraphRunner

from lightx2v.common.ops import *


def load_models(config):
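    """Instantiate the text encoder(s), generation model, VAE, and (for wan2.1 i2v) the image encoder.

    Weights are initialized on CPU when cpu_offload is enabled, otherwise on CUDA.
    Returns (model, text_encoders, vae_model, image_encoder); image_encoder stays None
    unless the wan2.1 i2v path is taken.
    """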
    if config["parallel_attn_type"]:
        cur_rank = dist.get_rank()  # get the rank of the current process
        torch.cuda.set_device(cur_rank)  # bind this process to its CUDA device
    image_encoder = None
    if config.cpu_offload:
        init_device = torch.device("cpu")
    else:
        init_device = torch.device("cuda")

    if config.model_cls == "hunyuan":
        if config.task == "t2v":
            text_encoder_1 = TextEncoderHFLlamaModel(os.path.join(config.model_path, "text_encoder"), init_device)
        else:
            text_encoder_1 = TextEncoderHFLlavaModel(os.path.join(config.model_path, "text_encoder_i2v"), init_device)
        text_encoder_2 = TextEncoderHFClipModel(os.path.join(config.model_path, "text_encoder_2"), init_device)
        text_encoders = [text_encoder_1, text_encoder_2]
        model = HunyuanModel(config.model_path, config, init_device, config)
        vae_model = VideoEncoderKLCausal3DModel(config.model_path, dtype=torch.float16, device=init_device, config=config)

    elif config.model_cls == "wan2.1":
        with ProfilingContext("Load Text Encoder"):
            text_encoder = T5EncoderModel(
                text_len=config["text_len"],
                dtype=torch.bfloat16,
                device=init_device,
                checkpoint_path=os.path.join(config.model_path, "models_t5_umt5-xxl-enc-bf16.pth"),
                tokenizer_path=os.path.join(config.model_path, "google/umt5-xxl"),
                shard_fn=None,
            )
            text_encoders = [text_encoder]
        with ProfilingContext("Load Wan Model"):
            model = WanModel(config.model_path, config, init_device)

        if config.lora_path:
            lora_wrapper = WanLoraWrapper(model)
            with ProfilingContext("Load LoRA Model"):
                lora_name = lora_wrapper.load_lora(config.lora_path)
                lora_wrapper.apply_lora(lora_name, config.strength_model)
                print(f"Loaded LoRA: {lora_name}")

        with ProfilingContext("Load WAN VAE Model"):
            vae_model = WanVAE(vae_pth=os.path.join(config.model_path, "Wan2.1_VAE.pth"), device=init_device, parallel=config.parallel_vae)
        if config.task == "i2v":
            with ProfilingContext("Load Image Encoder"):
                image_encoder = CLIPModel(
                    dtype=torch.float16,
                    device=init_device,
                    checkpoint_path=os.path.join(config.model_path, "models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth"),
                    tokenizer_path=os.path.join(config.model_path, "xlm-roberta-large"),
                )
    else:
        raise NotImplementedError(f"Unsupported model class: {config.model_cls}")

    return model, text_encoders, vae_model, image_encoder


def set_target_shape(config, image_encoder_output):
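    """Compute the latent-space shape for generation and store it in config.target_shape.

    For hunyuan, spatial dims are divided by the VAE scale factor (2 ** 3 = 8); for wan2.1,
    they are divided by the spatial VAE strides, and the i2v case reuses the lat_h / lat_w
    computed in run_image_encoder.
    """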
    if config.model_cls == "hunyuan":
        if config.task == "t2v":
            vae_scale_factor = 2 ** (4 - 1)
            config.target_shape = (
                1,
                16,
                (config.target_video_length - 1) // 4 + 1,
                int(config.target_height) // vae_scale_factor,
                int(config.target_width) // vae_scale_factor,
            )
        elif config.task == "i2v":
            vae_scale_factor = 2 ** (4 - 1)
            config.target_shape = (
                1,
                16,
                (config.target_video_length - 1) // 4 + 1,
                int(image_encoder_output["target_height"]) // vae_scale_factor,
                int(image_encoder_output["target_width"]) // vae_scale_factor,
            )
    elif config.model_cls == "wan2.1":
        if config.task == "i2v":
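            # 21 latent frames corresponds to the 81-frame i2v setup used in
            # run_image_encoder: (81 - 1) // 4 + 1.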
            config.target_shape = (16, 21, config.lat_h, config.lat_w)
        elif config.task == "t2v":
            config.target_shape = (
                16,
                (config.target_video_length - 1) // 4 + 1,
                int(config.target_height) // config.vae_stride[1],
                int(config.target_width) // config.vae_stride[2],
            )


def generate_crop_size_list(base_size=256, patch_size=32, max_ratio=4.0):
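    """Enumerate candidate crop sizes on a patch_size grid, keeping pairs whose total patch
    count fits within (base_size / patch_size) ** 2 and whose aspect ratio is at most max_ratio."""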
    num_patches = round((base_size / patch_size) ** 2)
    assert max_ratio >= 1.0
    crop_size_list = []
    wp, hp = num_patches, 1
    while wp > 0:
        if max(wp, hp) / min(wp, hp) <= max_ratio:
            crop_size_list.append((wp * patch_size, hp * patch_size))
        if (hp + 1) * wp <= num_patches:
            hp += 1
        else:
            wp -= 1
    return crop_size_list


def get_closest_ratio(height: float, width: float, ratios: list, buckets: list):
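    """Return the (size, ratio) bucket whose aspect ratio (h / w) is closest to the input image's.

    When the input ratio is >= 1, only buckets with a ratio <= the input are considered;
    otherwise only buckets with a ratio above the input.
    """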
    aspect_ratio = float(height) / float(width)
    diff_ratios = ratios - aspect_ratio

    if aspect_ratio >= 1:
        indices = [(index, x) for index, x in enumerate(diff_ratios) if x <= 0]
    else:
        indices = [(index, x) for index, x in enumerate(diff_ratios) if x > 0]

    closest_ratio_id = min(indices, key=lambda pair: abs(pair[1]))[0]
    closest_size = buckets[closest_ratio_id]
    closest_ratio = ratios[closest_ratio_id]

    return closest_size, closest_ratio


def run_image_encoder(config, image_encoder, vae_model):
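    """Encode the conditioning image for i2v generation.

    hunyuan: crop/resize the image to the closest aspect-ratio bucket, encode it with the
    causal 3D VAE, and return the scaled latents plus the target resolution.
    wan2.1: extract CLIP visual features, derive the latent grid (lat_h, lat_w), and return
    the CLIP features together with a mask-prefixed VAE encoding of the first frame.
    """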
    if config.model_cls == "hunyuan":
        img = Image.open(config.image_path).convert("RGB")
        origin_size = img.size

        i2v_resolution = "720p"
        if i2v_resolution == "720p":
            bucket_hw_base_size = 960
        elif i2v_resolution == "540p":
            bucket_hw_base_size = 720
        elif i2v_resolution == "360p":
            bucket_hw_base_size = 480
        else:
            raise ValueError(f"i2v_resolution: {i2v_resolution} must be in [360p, 540p, 720p]")

        crop_size_list = generate_crop_size_list(bucket_hw_base_size, 32)
        aspect_ratios = np.array([round(float(h) / float(w), 5) for h, w in crop_size_list])
        closest_size, closest_ratio = get_closest_ratio(origin_size[1], origin_size[0], aspect_ratios, crop_size_list)

        resize_param = min(closest_size)
        center_crop_param = closest_size

        ref_image_transform = torchvision.transforms.Compose(
            [torchvision.transforms.Resize(resize_param), torchvision.transforms.CenterCrop(center_crop_param), torchvision.transforms.ToTensor(), torchvision.transforms.Normalize([0.5], [0.5])]
        )

        semantic_image_pixel_values = [ref_image_transform(img)]
        semantic_image_pixel_values = torch.cat(semantic_image_pixel_values).unsqueeze(0).unsqueeze(2).to(torch.float16).to(torch.device("cuda"))

        img_latents = vae_model.encode(semantic_image_pixel_values, config).mode()

        scaling_factor = 0.476986
        img_latents.mul_(scaling_factor)

        target_height, target_width = closest_size

        return {"img": img, "img_latents": img_latents, "target_height": target_height, "target_width": target_width}

    elif config.model_cls == "wan2.1":
        img = Image.open(config.image_path).convert("RGB")
        img = TF.to_tensor(img).sub_(0.5).div_(0.5).cuda()
        clip_encoder_out = image_encoder.visual([img[:, None, :, :]], config).squeeze(0).to(torch.bfloat16)
        h, w = img.shape[1:]
        aspect_ratio = h / w
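        # Pick a latent grid that preserves the image aspect ratio within the
        # target_height * target_width pixel budget, rounded down to multiples of the
        # spatial patch size; h and w are the corresponding pixel dimensions.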
        max_area = config.target_height * config.target_width
        lat_h = round(np.sqrt(max_area * aspect_ratio) // config.vae_stride[1] // config.patch_size[1] * config.patch_size[1])
        lat_w = round(np.sqrt(max_area / aspect_ratio) // config.vae_stride[2] // config.patch_size[2] * config.patch_size[2])
        h = lat_h * config.vae_stride[1]
        w = lat_w * config.vae_stride[2]

        config.lat_h = lat_h
        config.lat_w = lat_w
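
        # Temporal mask over 81 frames: frame 0 (the conditioning image) is kept, the
        # remaining 80 frames are masked out; the first-frame mask is repeated 4x to match
        # the VAE's temporal compression, then reshaped to (4, 21, lat_h, lat_w).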
        msk = torch.ones(1, 81, lat_h, lat_w, device=torch.device("cuda"))
        msk[:, 1:] = 0
        msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
        msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w)
        msk = msk.transpose(1, 2)[0]
        vae_encode_out = vae_model.encode(
            [torch.concat([torch.nn.functional.interpolate(img[None].cpu(), size=(h, w), mode="bicubic").transpose(0, 1), torch.zeros(3, 80, h, w)], dim=1).cuda()], config
        )[0]
        vae_encode_out = torch.concat([msk, vae_encode_out]).to(torch.bfloat16)
        return {"clip_encoder_out": clip_encoder_out, "vae_encode_out": vae_encode_out}

    else:
        raise NotImplementedError(f"Unsupported model class: {config.model_cls}")


def run_text_encoder(text, text_encoders, config, image_encoder_output):
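    """Encode the prompt with the text encoder(s) of the selected model class.

    hunyuan: returns per-encoder text states and attention masks (for i2v, the first encoder
    also conditions on the input image). wan2.1: returns T5 context embeddings for the prompt
    and for the negative prompt.
    """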
    text_encoder_output = {}
    if config.model_cls == "hunyuan":
        for i, encoder in enumerate(text_encoders):
            if config.task == "i2v" and i == 0:
                text_state, attention_mask = encoder.infer(text, image_encoder_output["img"], config)
            else:
                text_state, attention_mask = encoder.infer(text, config)
            text_encoder_output[f"text_encoder_{i + 1}_text_states"] = text_state.to(dtype=torch.bfloat16)
            text_encoder_output[f"text_encoder_{i + 1}_attention_mask"] = attention_mask

    elif config.model_cls == "wan2.1":
        n_prompt = config.get("sample_neg_prompt", "")
        context = text_encoders[0].infer([text], config)
        context_null = text_encoders[0].infer([n_prompt if n_prompt else ""], config)
        text_encoder_output["context"] = context
        text_encoder_output["context_null"] = context_null

    else:
        raise NotImplementedError(f"Unsupported model type: {config.model_cls}")

    return text_encoder_output


def init_scheduler(config, image_encoder_output):
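    """Build the scheduler matching the model class and feature-caching mode
    (NoCaching / Tea / TaylorSeer; TaylorSeer is only available for hunyuan)."""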
    if config.model_cls == "hunyuan":
        if config.feature_caching == "NoCaching":
            scheduler = HunyuanScheduler(config, image_encoder_output)
        elif config.feature_caching == "Tea":
            scheduler = HunyuanSchedulerTeaCaching(config, image_encoder_output)
        elif config.feature_caching == "TaylorSeer":
            scheduler = HunyuanSchedulerTaylorCaching(config, image_encoder_output)
        else:
            raise NotImplementedError(f"Unsupported feature_caching type: {config.feature_caching}")

    elif config.model_cls == "wan2.1":
        if config.feature_caching == "NoCaching":
            scheduler = WanScheduler(config)
        elif config.feature_caching == "Tea":
            scheduler = WanSchedulerTeaCaching(config)
        else:
            raise NotImplementedError(f"Unsupported feature_caching type: {config.feature_caching}")

    else:
        raise NotImplementedError(f"Unsupported model class: {config.model_cls}")
    return scheduler


def run_vae(latents, generator, config):
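    """Decode latents into pixel frames using the globally loaded vae_model."""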
    images = vae_model.decode(latents, generator=generator, config=config)
    return images


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_cls", type=str, required=True, choices=["wan2.1", "hunyuan"], default="hunyuan")
    parser.add_argument("--task", type=str, choices=["t2v", "i2v"], default="t2v")
    parser.add_argument("--model_path", type=str, required=True)
    parser.add_argument("--config_path", type=str, default=None)
    parser.add_argument("--image_path", type=str, default=None)
    parser.add_argument("--save_video_path", type=str, default="./output_ligthx2v.mp4")
    parser.add_argument("--prompt", type=str, required=True)
    parser.add_argument("--infer_steps", type=int, required=True)
    parser.add_argument("--target_video_length", type=int, required=True)
    parser.add_argument("--target_width", type=int, required=True)
    parser.add_argument("--target_height", type=int, required=True)
    parser.add_argument("--attention_type", type=str, required=True)
    parser.add_argument("--sample_neg_prompt", type=str, default="")
    parser.add_argument("--sample_guide_scale", type=float, default=5.0)
    parser.add_argument("--sample_shift", type=float, default=5.0)
    parser.add_argument("--do_mm_calib", action="store_true")
    parser.add_argument("--cpu_offload", action="store_true")
    parser.add_argument("--feature_caching", choices=["NoCaching", "TaylorSeer", "Tea"], default="NoCaching")
    parser.add_argument("--mm_config", default=None)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--parallel_attn_type", default=None, choices=["ulysses", "ring"])
    parser.add_argument("--parallel_vae", action="store_true")
    parser.add_argument("--max_area", action="store_true")
    parser.add_argument("--vae_stride", default=(4, 8, 8))
    parser.add_argument("--patch_size", default=(1, 2, 2))
    parser.add_argument("--teacache_thresh", type=float, default=0.26)
    parser.add_argument("--use_ret_steps", action="store_true", default=False)
    parser.add_argument("--use_bfloat16", action="store_true", default=True)
    parser.add_argument("--lora_path", type=str, default=None)
    parser.add_argument("--strength_model", type=float, default=1.0)
    args = parser.parse_args()

    start_time = time.perf_counter()
    print(f"args: {args}")

    seed_all(args.seed)

    config = set_config(args)

    if config.parallel_attn_type:
        dist.init_process_group(backend="nccl")

    print(f"config: {config}")

    with ProfilingContext("Load models"):
        model, text_encoders, vae_model, image_encoder = load_models(config)

    if config["task"] in ["i2v"]:
        image_encoder_output = run_image_encoder(config, image_encoder, vae_model)
    else:
        image_encoder_output = {"clip_encoder_out": None, "vae_encode_out": None}

    with ProfilingContext("Run Text Encoder"):
        text_encoder_output = run_text_encoder(config["prompt"], text_encoders, config, image_encoder_output)

    inputs = {"text_encoder_output": text_encoder_output, "image_encoder_output": image_encoder_output}

    set_target_shape(config, image_encoder_output)
    scheduler = init_scheduler(config, image_encoder_output)

    model.set_scheduler(scheduler)

    gc.collect()
    torch.cuda.empty_cache()

    if CHECK_ENABLE_GRAPH_MODE():
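        # Graph mode: route the whole inference pass through GraphRunner, which wraps
        # the default runner.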
        default_runner = DefaultRunner(model, inputs)
        runner = GraphRunner(default_runner)
    else:
        runner = DefaultRunner(model, inputs)

    latents, generator = runner.run()

    if config.cpu_offload:
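        # With CPU offload enabled, release scheduler/encoder/model references before
        # VAE decoding to keep peak GPU memory low.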
        scheduler.clear()
        del text_encoder_output, image_encoder_output, model, text_encoders, scheduler
        torch.cuda.empty_cache()

    with ProfilingContext("Run VAE"):
        images = run_vae(latents, generator, config)

    if not config.parallel_attn_type or (config.parallel_attn_type and dist.get_rank() == 0):
        with ProfilingContext("Save video"):
            if config.model_cls == "wan2.1":
                cache_video(tensor=images, save_file=config.save_video_path, fps=16, nrow=1, normalize=True, value_range=(-1, 1))
            else:
                save_videos_grid(images, config.save_video_path, fps=24)

    end_time = time.perf_counter()
    print(f"Total cost: {end_time - start_time}")