import argparse
import torch
import torch.distributed as dist
import os
import time
import gc
import json
import torchvision
import torchvision.transforms.functional as TF
import numpy as np
from contextlib import contextmanager
from PIL import Image
from lightx2v.models.input_encoders.hf.llama.model import TextEncoderHFLlamaModel
from lightx2v.models.input_encoders.hf.clip.model import TextEncoderHFClipModel
from lightx2v.models.input_encoders.hf.t5.model import T5EncoderModel
from lightx2v.models.input_encoders.hf.llava.model import TextEncoderHFLlavaModel
from lightx2v.models.input_encoders.hf.xlm_roberta.model import CLIPModel

from lightx2v.models.schedulers.hunyuan.scheduler import HunyuanScheduler
from lightx2v.models.schedulers.hunyuan.feature_caching.scheduler import HunyuanSchedulerTaylorCaching, HunyuanSchedulerTeaCaching
from lightx2v.models.schedulers.wan.scheduler import WanScheduler
from lightx2v.models.schedulers.wan.feature_caching.scheduler import WanSchedulerTeaCaching

from lightx2v.models.networks.hunyuan.model import HunyuanModel
from lightx2v.models.networks.wan.model import WanModel
from lightx2v.models.networks.wan.lora_adapter import WanLoraWrapper

from lightx2v.models.video_encoders.hf.autoencoder_kl_causal_3d.model import VideoEncoderKLCausal3DModel
from lightx2v.models.video_encoders.hf.wan.vae import WanVAE
from lightx2v.utils.utils import save_videos_grid, seed_all, cache_video
from lightx2v.common.ops import *
from lightx2v.utils.set_config import set_config


@contextmanager
def time_duration(label: str = ""):
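    """Context manager that reports the start time and wall-clock duration of the wrapped block, synchronizing CUDA before and after."""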
    torch.cuda.synchronize()
    start_time = time.time()
    yield
    torch.cuda.synchronize()
    end_time = time.time()
    print(f"==> {label} start:{time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(start_time))} cost {end_time - start_time:.2f} seconds")


def load_models(config):
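    """Load the text encoder(s), diffusion model, and VAE for the configured model class; for wan2.1 i2v, also load the CLIP image encoder."""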
    if config["parallel_attn_type"]:
        cur_rank = dist.get_rank()  # rank of the current process
        torch.cuda.set_device(cur_rank)  # bind this process to its own CUDA device
    image_encoder = None
    if config.cpu_offload:
        init_device = torch.device("cpu")
    else:
        init_device = torch.device("cuda")

    if config.model_cls == "hunyuan":
        if config.task == "t2v":
            text_encoder_1 = TextEncoderHFLlamaModel(os.path.join(config.model_path, "text_encoder"), init_device)
        else:
            text_encoder_1 = TextEncoderHFLlavaModel(os.path.join(config.model_path, "text_encoder_i2v"), init_device)
        text_encoder_2 = TextEncoderHFClipModel(os.path.join(config.model_path, "text_encoder_2"), init_device)
        text_encoders = [text_encoder_1, text_encoder_2]
        model = HunyuanModel(config.model_path, config, init_device, config)
        vae_model = VideoEncoderKLCausal3DModel(config.model_path, dtype=torch.float16, device=init_device, config=config)

    elif config.model_cls == "wan2.1":
        with time_duration("Load Text Encoder"):
            text_encoder = T5EncoderModel(
                text_len=config["text_len"],
                dtype=torch.bfloat16,
                device=init_device,
                checkpoint_path=os.path.join(config.model_path, "models_t5_umt5-xxl-enc-bf16.pth"),
                tokenizer_path=os.path.join(config.model_path, "google/umt5-xxl"),
                shard_fn=None,
            )
            text_encoders = [text_encoder]
        with time_duration("Load Wan Model"):
            model = WanModel(config.model_path, config, init_device)

        if config.lora_path:
            lora_wrapper = WanLoraWrapper(model)
            with time_duration("Load LoRA Model"):
                lora_name = lora_wrapper.load_lora(config.lora_path)
                lora_wrapper.apply_lora(lora_name, config.strength_model)
                print(f"Loaded LoRA: {lora_name}")

        with time_duration("Load WAN VAE Model"):
87
88
            vae_model = WanVAE(vae_pth=os.path.join(config.model_path, "Wan2.1_VAE.pth"), device=init_device, parallel=config.parallel_vae)
        if config.task == "i2v":
lijiaqi2's avatar
lijiaqi2 committed
89
90
91
92
            with time_duration("Load Image Encoder"):
                image_encoder = CLIPModel(
                    dtype=torch.float16,
                    device=init_device,
                    checkpoint_path=os.path.join(config.model_path, "models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth"),
                    tokenizer_path=os.path.join(config.model_path, "xlm-roberta-large"),
                )
    else:
        raise NotImplementedError(f"Unsupported model class: {config.model_cls}")

    return model, text_encoders, vae_model, image_encoder


def set_target_shape(config, image_encoder_output):
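    """Set config.target_shape to the latent-space shape that the scheduler and diffusion model will denoise."""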
    if config.model_cls == "hunyuan":
        if config.task == "t2v":
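            # The causal 3D VAE downsamples height and width by 2 ** (4 - 1) = 8.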
            vae_scale_factor = 2 ** (4 - 1)
            config.target_shape = (
                1,
                16,
                (config.target_video_length - 1) // 4 + 1,
                int(config.target_height) // vae_scale_factor,
                int(config.target_width) // vae_scale_factor,
            )
        elif config.task == "i2v":
            vae_scale_factor = 2 ** (4 - 1)
            config.target_shape = (
                1,
                16,
                (config.target_video_length - 1) // 4 + 1,
                int(image_encoder_output["target_height"]) // vae_scale_factor,
                int(image_encoder_output["target_width"]) // vae_scale_factor,
            )
    elif config.model_cls == "wan2.1":
        if config.task == "i2v":
            config.target_shape = (16, 21, config.lat_h, config.lat_w)
        elif config.task == "t2v":
            config.target_shape = (
                16,
                (config.target_video_length - 1) // 4 + 1,
                int(config.target_height) // config.vae_stride[1],
                int(config.target_width) // config.vae_stride[2],
            )


def generate_crop_size_list(base_size=256, patch_size=32, max_ratio=4.0):
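    """Enumerate (width, height) crop sizes, in multiples of patch_size, whose patch count fits a base_size x base_size budget and whose aspect ratio stays within max_ratio."""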
    num_patches = round((base_size / patch_size) ** 2)
    assert max_ratio >= 1.0
    crop_size_list = []
    wp, hp = num_patches, 1
    while wp > 0:
        if max(wp, hp) / min(wp, hp) <= max_ratio:
            crop_size_list.append((wp * patch_size, hp * patch_size))
        if (hp + 1) * wp <= num_patches:
            hp += 1
        else:
            wp -= 1
    return crop_size_list


def get_closest_ratio(height: float, width: float, ratios: list, buckets: list):
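    """Return the bucket size and aspect ratio closest to height/width, never exceeding it for portrait inputs and never below it for landscape inputs."""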
    aspect_ratio = float(height) / float(width)
    diff_ratios = ratios - aspect_ratio

    if aspect_ratio >= 1:
        indices = [(index, x) for index, x in enumerate(diff_ratios) if x <= 0]
    else:
        indices = [(index, x) for index, x in enumerate(diff_ratios) if x > 0]

    closest_ratio_id = min(indices, key=lambda pair: abs(pair[1]))[0]
    closest_size = buckets[closest_ratio_id]
    closest_ratio = ratios[closest_ratio_id]

    return closest_size, closest_ratio


def run_image_encoder(config, image_encoder, vae_model):
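    """Preprocess the conditioning image and encode it into the image and latent features the selected model expects."""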
    if config.model_cls == "hunyuan":
        img = Image.open(config.image_path).convert("RGB")
        origin_size = img.size

        i2v_resolution = "720p"
        if i2v_resolution == "720p":
            bucket_hw_base_size = 960
        elif i2v_resolution == "540p":
            bucket_hw_base_size = 720
        elif i2v_resolution == "360p":
            bucket_hw_base_size = 480
        else:
            raise ValueError(f"i2v_resolution: {i2v_resolution} must be in [360p, 540p, 720p]")

        crop_size_list = generate_crop_size_list(bucket_hw_base_size, 32)
        aspect_ratios = np.array([round(float(h) / float(w), 5) for h, w in crop_size_list])
        closest_size, closest_ratio = get_closest_ratio(origin_size[1], origin_size[0], aspect_ratios, crop_size_list)

        resize_param = min(closest_size)
        center_crop_param = closest_size

        ref_image_transform = torchvision.transforms.Compose(
            [torchvision.transforms.Resize(resize_param), torchvision.transforms.CenterCrop(center_crop_param), torchvision.transforms.ToTensor(), torchvision.transforms.Normalize([0.5], [0.5])]
        )

        semantic_image_pixel_values = [ref_image_transform(img)]
        semantic_image_pixel_values = torch.cat(semantic_image_pixel_values).unsqueeze(0).unsqueeze(2).to(torch.float16).to(torch.device("cuda"))

        img_latents = vae_model.encode(semantic_image_pixel_values, config).mode()

        scaling_factor = 0.476986
        img_latents.mul_(scaling_factor)

        target_height, target_width = closest_size

        return {"img": img, "img_latents": img_latents, "target_height": target_height, "target_width": target_width}

    elif config.model_cls == "wan2.1":
        img = Image.open(config.image_path).convert("RGB")
        img = TF.to_tensor(img).sub_(0.5).div_(0.5).cuda()
        clip_encoder_out = image_encoder.visual([img[:, None, :, :]], config).squeeze(0).to(torch.bfloat16)
        h, w = img.shape[1:]
        aspect_ratio = h / w
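        # Pick latent dimensions that keep roughly the requested pixel area and the image's aspect ratio, aligned to the VAE stride and patch size.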
        max_area = config.target_height * config.target_width
        lat_h = round(np.sqrt(max_area * aspect_ratio) // config.vae_stride[1] // config.patch_size[1] * config.patch_size[1])
        lat_w = round(np.sqrt(max_area / aspect_ratio) // config.vae_stride[2] // config.patch_size[2] * config.patch_size[2])
        h = lat_h * config.vae_stride[1]
        w = lat_w * config.vae_stride[2]

        config.lat_h = lat_h
        config.lat_w = lat_w
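
        # Temporal conditioning mask: only the first frame (the reference image) is marked as known; all later frames are generated.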
        msk = torch.ones(1, 81, lat_h, lat_w, device=torch.device("cuda"))
        msk[:, 1:] = 0
        msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
        msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w)
        msk = msk.transpose(1, 2)[0]
        vae_encode_out = vae_model.encode(
            [torch.concat([torch.nn.functional.interpolate(img[None].cpu(), size=(h, w), mode="bicubic").transpose(0, 1), torch.zeros(3, 80, h, w)], dim=1).cuda()], config
        )[0]
        vae_encode_out = torch.concat([msk, vae_encode_out]).to(torch.bfloat16)
        return {"clip_encoder_out": clip_encoder_out, "vae_encode_out": vae_encode_out}

    else:
        raise NotImplementedError(f"Unsupported model class: {config.model_cls}")


def run_text_encoder(text, text_encoders, config, image_encoder_output):
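    """Encode the prompt (and, for wan2.1, the negative prompt) with the loaded text encoder(s)."""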
    text_encoder_output = {}
    if config.model_cls == "hunyuan":
        for i, encoder in enumerate(text_encoders):
            if config.task == "i2v" and i == 0:
                text_state, attention_mask = encoder.infer(text, image_encoder_output["img"], config)
            else:
                text_state, attention_mask = encoder.infer(text, config)
            text_encoder_output[f"text_encoder_{i + 1}_text_states"] = text_state.to(dtype=torch.bfloat16)
            text_encoder_output[f"text_encoder_{i + 1}_attention_mask"] = attention_mask

    elif config.model_cls == "wan2.1":
        n_prompt = config.get("sample_neg_prompt", "")
        context = text_encoders[0].infer([text], config)
        context_null = text_encoders[0].infer([n_prompt if n_prompt else ""], config)
        text_encoder_output["context"] = context
        text_encoder_output["context_null"] = context_null

    else:
        raise NotImplementedError(f"Unsupported model type: {config.model_cls}")

    return text_encoder_output


def init_scheduler(config, image_encoder_output):
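    """Create the sampling scheduler matching the model class and feature-caching mode."""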
    if config.model_cls == "hunyuan":
        if config.feature_caching == "NoCaching":
            scheduler = HunyuanScheduler(config, image_encoder_output)
        elif config.feature_caching == "Tea":
            scheduler = HunyuanSchedulerTeaCaching(config, image_encoder_output)
        elif config.feature_caching == "TaylorSeer":
            scheduler = HunyuanSchedulerTaylorCaching(config, image_encoder_output)
        else:
            raise NotImplementedError(f"Unsupported feature_caching type: {config.feature_caching}")

    elif config.model_cls == "wan2.1":
        if config.feature_caching == "NoCaching":
            scheduler = WanScheduler(config)
        elif config.feature_caching == "Tea":
            scheduler = WanSchedulerTeaCaching(config)
        else:
            raise NotImplementedError(f"Unsupported feature_caching type: {config.feature_caching}")

    else:
        raise NotImplementedError(f"Unsupported model class: {config.model_cls}")
    return scheduler


def run_main_inference(model, inputs):
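    """Run the denoising loop, timing each step, and return the final latents together with the RNG generator."""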
    for step_index in range(model.scheduler.infer_steps):
        torch.cuda.synchronize()
        time1 = time.time()

        model.scheduler.step_pre(step_index=step_index)

        torch.cuda.synchronize()
        time2 = time.time()

        model.infer(inputs)

        torch.cuda.synchronize()
        time3 = time.time()

        model.scheduler.step_post()

        torch.cuda.synchronize()
        time4 = time.time()

        print(f"step {step_index} infer time: {time3 - time2}")
        print(f"step {step_index} all time: {time4 - time1}")
        print("*" * 10)

    return model.scheduler.latents, model.scheduler.generator


def run_vae(vae_model, latents, generator, config):
    images = vae_model.decode(latents, generator=generator, config=config)
    return images


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_cls", type=str, required=True, choices=["wan2.1", "hunyuan"], default="hunyuan")
    parser.add_argument("--task", type=str, choices=["t2v", "i2v"], default="t2v")
    parser.add_argument("--model_path", type=str, required=True)
    parser.add_argument("--config_path", type=str, default=None)
    parser.add_argument("--image_path", type=str, default=None)
    parser.add_argument("--save_video_path", type=str, default="./output_lightx2v.mp4")
    parser.add_argument("--prompt", type=str, required=True)
    parser.add_argument("--infer_steps", type=int, required=True)
    parser.add_argument("--target_video_length", type=int, required=True)
    parser.add_argument("--target_width", type=int, required=True)
    parser.add_argument("--target_height", type=int, required=True)
    parser.add_argument("--attention_type", type=str, required=True)
    parser.add_argument("--sample_neg_prompt", type=str, default="")
    parser.add_argument("--sample_guide_scale", type=float, default=5.0)
    parser.add_argument("--sample_shift", type=float, default=5.0)
    parser.add_argument("--do_mm_calib", action="store_true")
    parser.add_argument("--cpu_offload", action="store_true")
    parser.add_argument("--feature_caching", choices=["NoCaching", "TaylorSeer", "Tea"], default="NoCaching")
    parser.add_argument("--mm_config", default=None)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--parallel_attn_type", default=None, choices=["ulysses", "ring"])
    parser.add_argument("--parallel_vae", action="store_true")
    parser.add_argument("--max_area", action="store_true")
    parser.add_argument("--vae_stride", default=(4, 8, 8))
    parser.add_argument("--patch_size", default=(1, 2, 2))
    parser.add_argument("--teacache_thresh", type=float, default=0.26)
    parser.add_argument("--use_ret_steps", action="store_true", default=False)
    parser.add_argument("--use_bfloat16", action="store_true", default=True)
    parser.add_argument("--lora_path", type=str, default=None)
    parser.add_argument("--strength_model", type=float, default=1.0)
    args = parser.parse_args()

    start_time = time.time()
    print(f"args: {args}")

    seed_all(args.seed)

    config = set_config(args)

    if config.parallel_attn_type:
        dist.init_process_group(backend="nccl")

    print(f"config: {config}")

    with time_duration("Load models"):
        model, text_encoders, vae_model, image_encoder = load_models(config)

    if config["task"] in ["i2v"]:
        image_encoder_output = run_image_encoder(config, image_encoder, vae_model)
    else:
        image_encoder_output = {"clip_encoder_out": None, "vae_encode_out": None}

    with time_duration("Run Text Encoder"):
        text_encoder_output = run_text_encoder(config["prompt"], text_encoders, config, image_encoder_output)

    inputs = {"text_encoder_output": text_encoder_output, "image_encoder_output": image_encoder_output}

    set_target_shape(config, image_encoder_output)
    scheduler = init_scheduler(config, image_encoder_output)

    model.set_scheduler(scheduler)

    gc.collect()
    torch.cuda.empty_cache()
    latents, generator = run_main_inference(model, inputs)

    if config.cpu_offload:
        scheduler.clear()
        del text_encoder_output, image_encoder_output, model, text_encoders, scheduler
        torch.cuda.empty_cache()

    with time_duration("Run VAE"):
        images = run_vae(vae_model, latents, generator, config)

    if not config.parallel_attn_type or dist.get_rank() == 0:
        with time_duration("Save video"):
            if config.model_cls == "wan2.1":
                cache_video(tensor=images, save_file=config.save_video_path, fps=16, nrow=1, normalize=True, value_range=(-1, 1))
            else:
                save_videos_grid(images, config.save_video_path, fps=24)

    end_time = time.time()
    print(f"Total cost: {end_time - start_time:.2f} seconds")