# __main__.py
import argparse
import torch
import torch.distributed as dist
import os
import time
import gc
import json
import torchvision
import torchvision.transforms.functional as TF
import numpy as np
from PIL import Image

from lightx2v.utils.utils import save_videos_grid, seed_all, cache_video
from lightx2v.utils.profiler import ProfilingContext, ProfilingContext4Debug
from lightx2v.utils.set_config import set_config

from lightx2v.models.input_encoders.hf.llama.model import TextEncoderHFLlamaModel
from lightx2v.models.input_encoders.hf.clip.model import TextEncoderHFClipModel
from lightx2v.models.input_encoders.hf.t5.model import T5EncoderModel
from lightx2v.models.input_encoders.hf.llava.model import TextEncoderHFLlavaModel
from lightx2v.models.input_encoders.hf.xlm_roberta.model import CLIPModel

from lightx2v.models.schedulers.hunyuan.scheduler import HunyuanScheduler
from lightx2v.models.schedulers.hunyuan.feature_caching.scheduler import HunyuanSchedulerTaylorCaching, HunyuanSchedulerTeaCaching
from lightx2v.models.schedulers.wan.scheduler import WanScheduler
from lightx2v.models.schedulers.wan.feature_caching.scheduler import WanSchedulerTeaCaching

from lightx2v.models.networks.hunyuan.model import HunyuanModel
from lightx2v.models.networks.wan.model import WanModel
from lightx2v.models.networks.wan.lora_adapter import WanLoraWrapper

from lightx2v.models.video_encoders.hf.autoencoder_kl_causal_3d.model import VideoEncoderKLCausal3DModel
from lightx2v.models.video_encoders.hf.wan.vae import WanVAE

from lightx2v.common.ops import *


def load_models(config):
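    """Instantiate the text encoder(s), diffusion model, VAE and optional image encoder for the configured model class."""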
    if config["parallel_attn_type"]:
        cur_rank = dist.get_rank()  # rank of the current process
        torch.cuda.set_device(cur_rank)  # bind this process to its own CUDA device
    image_encoder = None
    if config.cpu_offload:
        init_device = torch.device("cpu")
    else:
        init_device = torch.device("cuda")

    if config.model_cls == "hunyuan":
        if config.task == "t2v":
            text_encoder_1 = TextEncoderHFLlamaModel(os.path.join(config.model_path, "text_encoder"), init_device)
        else:
            text_encoder_1 = TextEncoderHFLlavaModel(os.path.join(config.model_path, "text_encoder_i2v"), init_device)
        text_encoder_2 = TextEncoderHFClipModel(os.path.join(config.model_path, "text_encoder_2"), init_device)
        text_encoders = [text_encoder_1, text_encoder_2]
        model = HunyuanModel(config.model_path, config, init_device, config)
        vae_model = VideoEncoderKLCausal3DModel(config.model_path, dtype=torch.float16, device=init_device, config=config)

    elif config.model_cls == "wan2.1":
        with ProfilingContext("Load Text Encoder"):
            text_encoder = T5EncoderModel(
                text_len=config["text_len"],
                dtype=torch.bfloat16,
                device=init_device,
                checkpoint_path=os.path.join(config.model_path, "models_t5_umt5-xxl-enc-bf16.pth"),
                tokenizer_path=os.path.join(config.model_path, "google/umt5-xxl"),
                shard_fn=None,
            )
            text_encoders = [text_encoder]
        with ProfilingContext("Load Wan Model"):
            model = WanModel(config.model_path, config, init_device)

        if config.lora_path:
            lora_wrapper = WanLoraWrapper(model)
            with ProfilingContext("Load LoRA Model"):
                lora_name = lora_wrapper.load_lora(config.lora_path)
                lora_wrapper.apply_lora(lora_name, config.strength_model)
                print(f"Loaded LoRA: {lora_name}")

        with ProfilingContext("Load WAN VAE Model"):
            vae_model = WanVAE(vae_pth=os.path.join(config.model_path, "Wan2.1_VAE.pth"), device=init_device, parallel=config.parallel_vae)
        if config.task == "i2v":
            with ProfilingContext("Load Image Encoder"):
                image_encoder = CLIPModel(
                    dtype=torch.float16,
                    device=init_device,
                    checkpoint_path=os.path.join(config.model_path, "models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth"),
                    tokenizer_path=os.path.join(config.model_path, "xlm-roberta-large"),
                )
    else:
        raise NotImplementedError(f"Unsupported model class: {config.model_cls}")

    return model, text_encoders, vae_model, image_encoder


def set_target_shape(config, image_encoder_output):
    """Compute the latent-space shape of the video to generate and store it as config.target_shape."""
    if config.model_cls == "hunyuan":
        if config.task == "t2v":
            vae_scale_factor = 2 ** (4 - 1)
            config.target_shape = (
                1,
                16,
                (config.target_video_length - 1) // 4 + 1,
                int(config.target_height) // vae_scale_factor,
                int(config.target_width) // vae_scale_factor,
            )
        elif config.task == "i2v":
            vae_scale_factor = 2 ** (4 - 1)
            config.target_shape = (
                1,
                16,
                (config.target_video_length - 1) // 4 + 1,
                int(image_encoder_output["target_height"]) // vae_scale_factor,
                int(image_encoder_output["target_width"]) // vae_scale_factor,
            )

    elif config.model_cls == "wan2.1":
        if config.task == "i2v":
            # 21 latent frames corresponds to the fixed 81-frame i2v setting: (81 - 1) // 4 + 1.
            config.target_shape = (16, 21, config.lat_h, config.lat_w)
        elif config.task == "t2v":
            config.target_shape = (
                16,
                (config.target_video_length - 1) // 4 + 1,
                int(config.target_height) // config.vae_stride[1],
                int(config.target_width) // config.vae_stride[2],
            )


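# Enumerate candidate crop sizes (in multiples of patch_size) under a fixed patch budget, keeping aspect ratios within max_ratio.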
def generate_crop_size_list(base_size=256, patch_size=32, max_ratio=4.0):
    num_patches = round((base_size / patch_size) ** 2)
    assert max_ratio >= 1.0
    crop_size_list = []
    wp, hp = num_patches, 1
    while wp > 0:
        if max(wp, hp) / min(wp, hp) <= max_ratio:
            crop_size_list.append((wp * patch_size, hp * patch_size))
        if (hp + 1) * wp <= num_patches:
            hp += 1
        else:
            wp -= 1
    return crop_size_list


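# Choose the crop bucket whose aspect ratio is closest to the input image's.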
def get_closest_ratio(height: float, width: float, ratios: list, buckets: list):
    aspect_ratio = float(height) / float(width)
    diff_ratios = ratios - aspect_ratio

    if aspect_ratio >= 1:
        indices = [(index, x) for index, x in enumerate(diff_ratios) if x <= 0]
    else:
        indices = [(index, x) for index, x in enumerate(diff_ratios) if x > 0]

    closest_ratio_id = min(indices, key=lambda pair: abs(pair[1]))[0]
    closest_size = buckets[closest_ratio_id]
    closest_ratio = ratios[closest_ratio_id]

    return closest_size, closest_ratio


def run_image_encoder(config, image_encoder, vae_model):
    """Encode the reference image into the conditioning inputs expected by the selected model."""
    if config.model_cls == "hunyuan":
        img = Image.open(config.image_path).convert("RGB")
        origin_size = img.size

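        # The i2v resolution is hard-coded to 720p here; the base size below drives the crop-size buckets.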
        i2v_resolution = "720p"
        if i2v_resolution == "720p":
            bucket_hw_base_size = 960
        elif i2v_resolution == "540p":
            bucket_hw_base_size = 720
        elif i2v_resolution == "360p":
            bucket_hw_base_size = 480
        else:
            raise ValueError(f"i2v_resolution: {i2v_resolution} must be in [360p, 540p, 720p]")

        crop_size_list = generate_crop_size_list(bucket_hw_base_size, 32)
        aspect_ratios = np.array([round(float(h) / float(w), 5) for h, w in crop_size_list])
        closest_size, closest_ratio = get_closest_ratio(origin_size[1], origin_size[0], aspect_ratios, crop_size_list)

        resize_param = min(closest_size)
        center_crop_param = closest_size

        ref_image_transform = torchvision.transforms.Compose(
            [torchvision.transforms.Resize(resize_param), torchvision.transforms.CenterCrop(center_crop_param), torchvision.transforms.ToTensor(), torchvision.transforms.Normalize([0.5], [0.5])]
        )

        semantic_image_pixel_values = [ref_image_transform(img)]
        semantic_image_pixel_values = torch.cat(semantic_image_pixel_values).unsqueeze(0).unsqueeze(2).to(torch.float16).to(torch.device("cuda"))

        img_latents = vae_model.encode(semantic_image_pixel_values, config).mode()

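        # Scale latents by the VAE scaling factor (0.476986 is the HunyuanVideo causal 3D VAE value).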
        scaling_factor = 0.476986
        img_latents.mul_(scaling_factor)

        target_height, target_width = closest_size

        return {"img": img, "img_latents": img_latents, "target_height": target_height, "target_width": target_width}

    elif config.model_cls == "wan2.1":
        img = Image.open(config.image_path).convert("RGB")
        img = TF.to_tensor(img).sub_(0.5).div_(0.5).cuda()
        clip_encoder_out = image_encoder.visual([img[:, None, :, :]], config).squeeze(0).to(torch.bfloat16)
        h, w = img.shape[1:]
        aspect_ratio = h / w
        max_area = config.target_height * config.target_width
        lat_h = round(np.sqrt(max_area * aspect_ratio) // config.vae_stride[1] // config.patch_size[1] * config.patch_size[1])
        lat_w = round(np.sqrt(max_area / aspect_ratio) // config.vae_stride[2] // config.patch_size[2] * config.patch_size[2])
        h = lat_h * config.vae_stride[1]
        w = lat_w * config.vae_stride[2]

        config.lat_h = lat_h
        config.lat_w = lat_w

        # Temporal mask: only the first of the 81 frames is conditioned on the reference image; it is then packed into 21 latent frames.
        msk = torch.ones(1, 81, lat_h, lat_w, device=torch.device("cuda"))
        msk[:, 1:] = 0
        msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
        msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w)
        msk = msk.transpose(1, 2)[0]
        vae_encode_out = vae_model.encode(
            [torch.concat([torch.nn.functional.interpolate(img[None].cpu(), size=(h, w), mode="bicubic").transpose(0, 1), torch.zeros(3, 80, h, w)], dim=1).cuda()], config
        )[0]
        vae_encode_out = torch.concat([msk, vae_encode_out]).to(torch.bfloat16)
        return {"clip_encoder_out": clip_encoder_out, "vae_encode_out": vae_encode_out}

    else:
        raise NotImplementedError(f"Unsupported model class: {config.model_cls}")


def run_text_encoder(text, text_encoders, config, image_encoder_output):
    """Encode the prompt and return a dict of text-conditioning tensors keyed per model class."""
    text_encoder_output = {}
    if config.model_cls == "hunyuan":
        for i, encoder in enumerate(text_encoders):
            if config.task == "i2v" and i == 0:
                text_state, attention_mask = encoder.infer(text, image_encoder_output["img"], config)
            else:
                text_state, attention_mask = encoder.infer(text, config)
            text_encoder_output[f"text_encoder_{i + 1}_text_states"] = text_state.to(dtype=torch.bfloat16)
            text_encoder_output[f"text_encoder_{i + 1}_attention_mask"] = attention_mask

    elif config.model_cls == "wan2.1":
        n_prompt = config.get("sample_neg_prompt", "")
        context = text_encoders[0].infer([text], config)
        context_null = text_encoders[0].infer([n_prompt if n_prompt else ""], config)
        text_encoder_output["context"] = context
        text_encoder_output["context_null"] = context_null

    else:
        raise NotImplementedError(f"Unsupported model class: {config.model_cls}")

    return text_encoder_output


def init_scheduler(config, image_encoder_output):
    # Choose the sampling scheduler; "Tea" and "TaylorSeer" select feature-caching variants that reuse intermediate features across steps.
    if config.model_cls == "hunyuan":
        if config.feature_caching == "NoCaching":
            scheduler = HunyuanScheduler(config, image_encoder_output)
        elif config.feature_caching == "Tea":
            scheduler = HunyuanSchedulerTeaCaching(config, image_encoder_output)
        elif config.feature_caching == "TaylorSeer":
            scheduler = HunyuanSchedulerTaylorCaching(config, image_encoder_output)
        else:
            raise NotImplementedError(f"Unsupported feature_caching type: {config.feature_caching}")

    elif config.model_cls == "wan2.1":
        if config.feature_caching == "NoCaching":
            scheduler = WanScheduler(config)
        elif config.feature_caching == "Tea":
            scheduler = WanSchedulerTeaCaching(config)
        else:
            raise NotImplementedError(f"Unsupported feature_caching type: {config.feature_caching}")

    else:
        raise NotImplementedError(f"Unsupported model class: {config.model_cls}")
    return scheduler


def run_main_inference(model, inputs):
    # Denoising loop: each step prepares scheduler state, runs the diffusion model, then advances the scheduler.
    for step_index in range(model.scheduler.infer_steps):
        print(f"==> step_index: {step_index + 1} / {model.scheduler.infer_steps}")

        with ProfilingContext4Debug("step_pre"):
            model.scheduler.step_pre(step_index=step_index)

        with ProfilingContext4Debug("infer"):
            model.infer(inputs)

        with ProfilingContext4Debug("step_post"):
            model.scheduler.step_post()

    return model.scheduler.latents, model.scheduler.generator


def run_vae(vae_model, latents, generator, config):
    # Decode the denoised latents back to pixel space with the VAE.
    images = vae_model.decode(latents, generator=generator, config=config)
    return images


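# CLI entry point: parse arguments, load the pipeline, run the denoising loop, then decode and save the video.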
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_cls", type=str, required=True, choices=["wan2.1", "hunyuan"])
    parser.add_argument("--task", type=str, choices=["t2v", "i2v"], default="t2v")
    parser.add_argument("--model_path", type=str, required=True)
    parser.add_argument("--config_path", type=str, default=None)
    parser.add_argument("--image_path", type=str, default=None)
    parser.add_argument("--save_video_path", type=str, default="./output_lightx2v.mp4")
    parser.add_argument("--prompt", type=str, required=True)
    parser.add_argument("--infer_steps", type=int, required=True)
    parser.add_argument("--target_video_length", type=int, required=True)
    parser.add_argument("--target_width", type=int, required=True)
    parser.add_argument("--target_height", type=int, required=True)
    parser.add_argument("--attention_type", type=str, required=True)
    parser.add_argument("--sample_neg_prompt", type=str, default="")
    parser.add_argument("--sample_guide_scale", type=float, default=5.0)
    parser.add_argument("--sample_shift", type=float, default=5.0)
    parser.add_argument("--do_mm_calib", action="store_true")
    parser.add_argument("--cpu_offload", action="store_true")
    parser.add_argument("--feature_caching", choices=["NoCaching", "TaylorSeer", "Tea"], default="NoCaching")
    parser.add_argument("--mm_config", default=None)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--parallel_attn_type", default=None, choices=["ulysses", "ring"])
    parser.add_argument("--parallel_vae", action="store_true")
    parser.add_argument("--max_area", action="store_true")
    parser.add_argument("--vae_stride", default=(4, 8, 8))
    parser.add_argument("--patch_size", default=(1, 2, 2))
    parser.add_argument("--teacache_thresh", type=float, default=0.26)
    parser.add_argument("--use_ret_steps", action="store_true", default=False)
    parser.add_argument("--use_bfloat16", action="store_true", default=True)
    parser.add_argument("--lora_path", type=str, default=None)
    parser.add_argument("--strength_model", type=float, default=1.0)
    args = parser.parse_args()

    start_time = time.perf_counter()
    print(f"args: {args}")

    seed_all(args.seed)

    config = set_config(args)

    if config.parallel_attn_type:
        dist.init_process_group(backend="nccl")

    print(f"config: {config}")

    with ProfilingContext("Load models"):
        model, text_encoders, vae_model, image_encoder = load_models(config)

    if config["task"] in ["i2v"]:
        image_encoder_output = run_image_encoder(config, image_encoder, vae_model)
    else:
        image_encoder_output = {"clip_encoder_out": None, "vae_encode_out": None}

    with ProfilingContext("Run Text Encoder"):
        text_encoder_output = run_text_encoder(config["prompt"], text_encoders, config, image_encoder_output)

    inputs = {"text_encoder_output": text_encoder_output, "image_encoder_output": image_encoder_output}

    set_target_shape(config, image_encoder_output)
    scheduler = init_scheduler(config, image_encoder_output)

    model.set_scheduler(scheduler)

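    # Release host/GPU memory accumulated during model loading before the denoising loop.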
    gc.collect()
    torch.cuda.empty_cache()
    latents, generator = run_main_inference(model, inputs)

    if config.cpu_offload:
        scheduler.clear()
        del text_encoder_output, image_encoder_output, model, text_encoders, scheduler
        torch.cuda.empty_cache()

    with ProfilingContext("Run VAE"):
        images = run_vae(vae_model, latents, generator, config)

    if not config.parallel_attn_type or dist.get_rank() == 0:
        with ProfilingContext("Save video"):
            if config.model_cls == "wan2.1":
                cache_video(tensor=images, save_file=config.save_video_path, fps=16, nrow=1, normalize=True, value_range=(-1, 1))
            else:
                save_videos_grid(images, config.save_video_path, fps=24)

    end_time = time.perf_counter()
    print(f"Total cost: {end_time - start_time:.2f} s")