import os
import gc

import numpy as np
import torch
import torchvision.transforms.functional as TF
from PIL import Image
from lightx2v.utils.registry_factory import RUNNER_REGISTER
from lightx2v.models.runners.wan.wan_runner import WanRunner
from lightx2v.models.runners.default_runner import DefaultRunner
from lightx2v.models.schedulers.wan.scheduler import WanScheduler
from lightx2v.models.schedulers.wan.causvid.scheduler import WanCausVidScheduler
from lightx2v.utils.profiler import ProfilingContext4Debug, ProfilingContext
from lightx2v.models.input_encoders.hf.t5.model import T5EncoderModel
from lightx2v.models.input_encoders.hf.xlm_roberta.model import CLIPModel
from lightx2v.models.networks.wan.causvid_model import WanCausVidModel
from lightx2v.models.networks.wan.lora_adapter import WanLoraWrapper
from lightx2v.models.video_encoders.hf.wan.vae import WanVAE
from loguru import logger
import torch.distributed as dist

@RUNNER_REGISTER("wan2.1_causvid")
class WanCausVidRunner(WanRunner):
    """Runner for the Wan2.1 CausVid autoregressive video pipeline.

    Generates video latents block by block: the transformer keeps a KV cache
    over previously denoised frame blocks, and `run()` rolls the cache window
    forward one block at a time, optionally across several overlapping
    fragments.
    """

    def __init__(self, config):
        super().__init__(config)
        # Cache the block/frame geometry from the transformer's config so the
        # autoregressive loop in run() does not reach into the model each step.
        self.num_frame_per_block = self.model.config.num_frame_per_block
        self.num_frames = self.model.config.num_frames
        self.frame_seq_length = self.model.config.frame_seq_length
        self.infer_blocks = self.model.config.num_blocks
        self.num_fragments = self.model.config.num_fragments

    def _init_device(self):
        # Models are materialized on CPU when offloading is enabled,
        # otherwise directly on the GPU. Shared by load_transformer/load_model.
        return torch.device("cpu") if self.config.cpu_offload else torch.device("cuda")

    def load_transformer(self):
        """Instantiate the CausVid transformer on the configured init device."""
        return WanCausVidModel(self.config.model_path, self.config, self._init_device())

    @ProfilingContext("Load models")
    def load_model(self):
        """Load the transformer, T5 text encoder, VAE and (i2v only) CLIP encoder.

        Returns:
            tuple: (model, text_encoders, vae_model, image_encoder) where
            image_encoder is None unless config.task == "i2v".
        """
        if self.config["parallel_attn_type"]:
            # Attention-parallel mode: each distributed rank owns one GPU.
            cur_rank = dist.get_rank()
            torch.cuda.set_device(cur_rank)
        image_encoder = None
        init_device = self._init_device()

        text_encoder = T5EncoderModel(
            text_len=self.config["text_len"],
            dtype=torch.bfloat16,
            device=init_device,
            checkpoint_path=os.path.join(self.config.model_path, "models_t5_umt5-xxl-enc-bf16.pth"),
            tokenizer_path=os.path.join(self.config.model_path, "google/umt5-xxl"),
            shard_fn=None,
        )
        text_encoders = [text_encoder]

        model = WanCausVidModel(self.config.model_path, self.config, init_device)

        if self.config.lora_path:
            # Merge LoRA weights into the freshly loaded transformer.
            lora_wrapper = WanLoraWrapper(model)
            lora_name = lora_wrapper.load_lora(self.config.lora_path)
            lora_wrapper.apply_lora(lora_name, self.config.strength_model)
            logger.info(f"Loaded LoRA: {lora_name}")

        vae_model = WanVAE(vae_pth=os.path.join(self.config.model_path, "Wan2.1_VAE.pth"), device=init_device, parallel=self.config.parallel_vae)
        if self.config.task == "i2v":
            # Image-to-video additionally conditions on a CLIP image embedding.
            image_encoder = CLIPModel(
                dtype=torch.float16,
                device=init_device,
                checkpoint_path=os.path.join(self.config.model_path, "models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth"),
                tokenizer_path=os.path.join(self.config.model_path, "xlm-roberta-large"),
            )

        return model, text_encoders, vae_model, image_encoder

    def set_inputs(self, inputs):
        """Forward inputs to the base runner and pick up the fragment count."""
        super().set_inputs(inputs)
        self.config["num_fragments"] = inputs.get("num_fragments", 1)
        self.num_fragments = self.config["num_fragments"]

    def init_scheduler(self):
        """Attach a CausVid scheduler to the model."""
        scheduler = WanCausVidScheduler(self.config)
        self.model.set_scheduler(scheduler)

    def set_target_shape(self):
        """Derive the per-block latent target shape (C, F, H, W) for the task."""
        if self.config.task == "i2v":
            self.config.target_shape = (16, self.config.num_frame_per_block, self.config.lat_h, self.config.lat_w)
            # i2v: frame_seq_length must be recomputed from the actual input
            # latent shape (2x2 spatial patching, hence the //2 on each axis).
            frame_seq_length = (self.config.lat_h // 2) * (self.config.lat_w // 2)
            self.model.transformer_infer.frame_seq_length = frame_seq_length
            self.frame_seq_length = frame_seq_length
        elif self.config.task == "t2v":
            self.config.target_shape = (
                16,
                self.config.num_frame_per_block,
                int(self.config.target_height) // self.config.vae_stride[1],
                int(self.config.target_width) // self.config.vae_stride[2],
            )

    def run(self):
        """Autoregressively denoise all blocks across all fragments.

        Returns:
            tuple: (output_latents, generator) — the accumulated latent video
            of shape (C, total_frames, H, W) and the scheduler's RNG generator.
        """
        self.model.transformer_infer._init_kv_cache(dtype=torch.bfloat16, device="cuda")
        self.model.transformer_infer._init_crossattn_cache(dtype=torch.bfloat16, device="cuda")

        # Each fragment after the first re-uses the previous fragment's last
        # block (see infer_blocks below), so every extra fragment contributes
        # num_frames - num_frame_per_block new frames.
        output_latents = torch.zeros(
            (self.model.config.target_shape[0], self.num_frames + (self.num_fragments - 1) * (self.num_frames - self.num_frame_per_block), *self.model.config.target_shape[2:]),
            device="cuda",
            dtype=torch.bfloat16,
        )

        start_block_idx = 0

        for fragment_idx in range(self.num_fragments):
            logger.info(f"========> fragment_idx: {fragment_idx + 1} / {self.num_fragments}")

            # KV-cache window for the current block, in token units.
            kv_start = 0
            kv_end = kv_start + self.num_frame_per_block * self.frame_seq_length

            if fragment_idx > 0:
                # Re-prime the KV cache with the last denoised block of the
                # previous fragment before continuing the rollout.
                logger.info("recompute the kv_cache ...")
                with ProfilingContext4Debug("step_pre"):
                    self.model.scheduler.latents = self.model.scheduler.last_sample
                    self.model.scheduler.step_pre(step_index=self.model.scheduler.infer_steps - 1)

                with ProfilingContext4Debug("infer"):
                    self.model.infer(self.inputs, kv_start, kv_end)

                kv_start += self.num_frame_per_block * self.frame_seq_length
                kv_end += self.num_frame_per_block * self.frame_seq_length

            # The first block of a follow-up fragment is the re-primed one,
            # so one fewer block is generated for fragment_idx > 0.
            infer_blocks = self.infer_blocks - (fragment_idx > 0)

            for block_idx in range(infer_blocks):
                logger.info(f"=====> block_idx: {block_idx + 1} / {infer_blocks}")
                logger.info(f"=====> kv_start: {kv_start}, kv_end: {kv_end}")
                self.model.scheduler.reset()

                # Full denoising schedule for this block.
                for step_index in range(self.model.scheduler.infer_steps):
                    logger.info(f"==> step_index: {step_index + 1} / {self.model.scheduler.infer_steps}")

                    with ProfilingContext4Debug("step_pre"):
                        self.model.scheduler.step_pre(step_index=step_index)

                    with ProfilingContext4Debug("infer"):
                        self.model.infer(self.inputs, kv_start, kv_end)

                    with ProfilingContext4Debug("step_post"):
                        self.model.scheduler.step_post()

                # Slide the KV window forward one block.
                kv_start += self.num_frame_per_block * self.frame_seq_length
                kv_end += self.num_frame_per_block * self.frame_seq_length

                output_latents[:, start_block_idx * self.num_frame_per_block : (start_block_idx + 1) * self.num_frame_per_block] = self.model.scheduler.latents
                start_block_idx += 1

        return output_latents, self.model.scheduler.generator

    def end_run(self):
        """Release scheduler state and transformer caches, then free GPU memory."""
        self.model.scheduler.clear()
        del self.inputs, self.model.scheduler, self.model.transformer_infer.kv_cache, self.model.transformer_infer.crossattn_cache
        gc.collect()
        torch.cuda.empty_cache()