wan_vace_runner.py 12 KB
Newer Older
litzh's avatar
litzh committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
import gc
import os

import torch
import torch.nn.functional as F
import torchvision.transforms.functional as TF
from PIL import Image

from lightx2v.models.input_encoders.hf.vace.vace_processor import VaceVideoProcessor
from lightx2v.models.networks.wan.vace_model import WanVaceModel
from lightx2v.models.runners.wan.wan_runner import MultiModelStruct, WanRunner, build_wan_model_with_lora
from lightx2v.server.metrics import monitor_cli
from lightx2v.utils.envs import *
from lightx2v.utils.profiler import *
from lightx2v.utils.registry_factory import RUNNER_REGISTER


@RUNNER_REGISTER("wan2.1_vace")
class WanVaceRunner(WanRunner):
    def __init__(self, config):
        """Initialise the base runner and create the VACE video pre-processor.

        The processor starts out configured for a 720x1280 frame area;
        ``prepare_source`` re-targets its area and sequence length per call.
        """
        super().__init__(config)
        assert self.config["task"] == "vace"

        # Per-axis downsample factor: VAE stride multiplied by the patch size.
        downsample = tuple(stride * patch for stride, patch in zip(self.config["vae_stride"], self.config["patch_size"]))
        # One frame rate serves as both bounds; 16 is used when "fps" is not configured.
        fps = self.config["fps"] if "fps" in self.config else 16
        default_area = 720 * 1280

        self.vid_proc = VaceVideoProcessor(
            downsample=downsample,
            min_area=default_area,
            max_area=default_area,
            min_fps=fps,
            max_fps=fps,
            zero_start=True,
            seq_len=75600,
            keep_last=True,
        )

    def load_transformer(self):
        """Build the VACE transformer, going through the LoRA builder when ``lora_configs`` is set."""
        model_kwargs = {
            "model_path": self.config["model_path"],
            "config": self.config,
            "device": self.init_device,
        }
        lora_configs = self.config.get("lora_configs")
        if lora_configs:
            return build_wan_model_with_lora(WanVaceModel, self.config, model_kwargs, lora_configs, model_type="wan2.1")
        return WanVaceModel(**model_kwargs)

    def prepare_source(self, src_video, src_mask, src_ref_images, image_size, device=torch.device("cuda")):
        """Load source videos, masks and reference images as tensors on ``device``.

        The three input lists are filled in place and are also returned.

        Args:
            src_video: List of video sources handed to the video processor; ``None`` entries
                are replaced by an all-zero placeholder clip.
            src_mask: List of mask sources parallel to ``src_video``; when an entry (or its
                video) is ``None`` an all-ones mask is used instead.
            src_ref_images: List whose entries are ``None`` or a list of reference-image
                sources accepted by ``PIL.Image.open``; ``None`` items are left untouched.
            image_size: Two-element size. Only areas of 720*1280 and 480*832 are supported.
                Its elements are swapped before being used as the spatial shape of placeholders.
            device: Device on which the resulting tensors are placed.

        Returns:
            The tuple ``(src_video, src_mask, src_ref_images)``.

        Raises:
            NotImplementedError: If the area of ``image_size`` is not supported.
        """
        area = image_size[0] * image_size[1]
        self.vid_proc.set_area(area)
        if area == 720 * 1280:
            self.vid_proc.set_seq_len(75600)
        elif area == 480 * 832:
            self.vid_proc.set_seq_len(32760)
        else:
            raise NotImplementedError(f"image_size {image_size} is not supported")

        image_size = (image_size[1], image_size[0])
        image_sizes = []
        for idx, (video_src, mask_src) in enumerate(zip(src_video, src_mask)):
            if video_src is None:
                # No source video: zero clip with an all-ones mask of the same shape.
                placeholder = torch.zeros((3, self.config["target_video_length"], image_size[0], image_size[1]), device=device)
                src_video[idx] = placeholder
                src_mask[idx] = torch.ones_like(placeholder, device=device)
                image_sizes.append(image_size)
                continue

            if mask_src is not None:
                video, mask, _, _, _ = self.vid_proc.load_video_pair(video_src, mask_src)
                src_video[idx] = video.to(device)
                mask = mask.to(device)
                # Keep the first channel only and map it into [0, 1]
                # (presumably the processor returns values in [-1, 1] — TODO confirm).
                src_mask[idx] = torch.clamp((mask[:1, :, :, :] + 1) / 2, min=0, max=1)
            else:
                video, _, _, _ = self.vid_proc.load_video(video_src)
                src_video[idx] = video.to(device)
                src_mask[idx] = torch.ones_like(src_video[idx], device=device)
            image_sizes.append(src_video[idx].shape[2:])

        for idx, ref_group in enumerate(src_ref_images):
            if ref_group is None:
                continue
            canvas_size = image_sizes[idx]
            for j, ref_src in enumerate(ref_group):
                if ref_src is None:
                    continue
                # Close the underlying file as soon as the RGB copy exists.
                with Image.open(ref_src) as opened:
                    rgb = opened.convert("RGB")
                ref_img = TF.to_tensor(rgb).sub_(0.5).div_(0.5).unsqueeze(1)  # (3, 1, H, W) in [-1, 1]
                if ref_img.shape[-2:] != canvas_size:
                    # Letterbox: scale to fit, then centre on a white canvas of the video's size.
                    canvas_height, canvas_width = canvas_size
                    ref_height, ref_width = ref_img.shape[-2:]
                    white_canvas = torch.ones((3, 1, canvas_height, canvas_width), device=device)  # [-1, 1]
                    scale = min(canvas_height / ref_height, canvas_width / ref_width)
                    # Never let an extreme aspect ratio collapse a side to zero pixels.
                    new_height = max(1, int(ref_height * scale))
                    new_width = max(1, int(ref_width * scale))
                    resized_image = F.interpolate(ref_img.squeeze(1).unsqueeze(0), size=(new_height, new_width), mode="bilinear", align_corners=False).squeeze(0).unsqueeze(1)
                    top = (canvas_height - new_height) // 2
                    left = (canvas_width - new_width) // 2
                    white_canvas[:, :, top : top + new_height, left : left + new_width] = resized_image
                    ref_img = white_canvas
                ref_group[j] = ref_img.to(device)
        return src_video, src_mask, src_ref_images

    @ProfilingContext4DebugL1(
        "Run VAE Encoder",
        recorder_mode=GET_RECORDER_MODE(),
        metrics_func=monitor_cli.lightx2v_run_vae_encoder_image_duration,
        metrics_labels=["WanVaceRunner"],
    )
    def run_vae_encoder(self, frames, ref_images, masks):
        """Encode frames (and optional reference images) with the VAE and build the conditioning.

        Args:
            frames: List of frame tensors; each gets a leading dimension added before encoding.
            ref_images: ``None`` or a list parallel to ``frames`` whose entries are ``None``
                or a list of reference-image tensors.
            masks: ``None`` or a list of masks parallel to ``frames``; masks are binarised at 0.5.

        Returns:
            A tuple of the result of ``get_vae_encoder_output`` and the result of
            ``set_input_info_latent_shape``.

        Side effects:
            Sets ``self.latent_shape``. When ``lazy_load`` or ``unload_modules`` is enabled the
            VAE encoder is loaded on entry and deleted again before returning.
        """
        # NOTE(review): ``masks=None`` is accepted while encoding, but ``get_vae_encoder_output``
        # iterates over ``masks`` — confirm whether that path is actually supported.
        offload = self.config.get("lazy_load", False) or self.config.get("unload_modules", False)
        if offload:
            self.vae_encoder = self.load_vae_encoder()

        def encode(tensor):
            return self.vae_encoder.encode(tensor.unsqueeze(0).to(GET_DTYPE()))

        if ref_images is None:
            ref_images = [None] * len(frames)
        else:
            assert len(frames) == len(ref_images)

        if masks is None:
            latents = [encode(frame) for frame in frames]
        else:
            masks = [torch.where(m > 0.5, 1.0, 0.0) for m in masks]
            # All "inactive" (outside-mask) parts are encoded first, then all "reactive" parts.
            inactive_latents = [encode(frame * (1 - m) + 0 * m) for frame, m in zip(frames, masks)]
            reactive_latents = [encode(frame * m + 0 * (1 - m)) for frame, m in zip(frames, masks)]
            latents = [torch.cat((inact, react), dim=0) for inact, react in zip(inactive_latents, reactive_latents)]

        cat_latents = []
        for latent, refs in zip(latents, ref_images):
            if refs is not None:
                ref_latent = [encode(ref) for ref in refs]
                if masks is not None:
                    # Match the doubled leading dimension of masked latents with a zero half.
                    ref_latent = [torch.cat((u, torch.zeros_like(u)), dim=0) for u in ref_latent]
                assert all(x.shape[1] == 1 for x in ref_latent)
                # Reference latents are prepended along dim 1.
                latent = torch.cat([*ref_latent, latent], dim=1)
            cat_latents.append(latent)

        self.latent_shape = list(cat_latents[0].shape)
        if offload:
            del self.vae_encoder
            torch.cuda.empty_cache()
            gc.collect()
        return self.get_vae_encoder_output(cat_latents, masks, ref_images), self.set_input_info_latent_shape()

    def get_vae_encoder_output(self, cat_latents, masks, ref_images):
        """Fold each mask into latent resolution and append it to the matching latent.

        Every mask's first channel is rearranged so that each spatial stride block becomes
        ``stride_h * stride_w`` channels, resampled along the frame axis to
        ``(depth + 3) // stride_t`` frames, optionally left-padded with one zero frame per
        reference image, and concatenated to its latent along dim 0.

        Args:
            cat_latents: List of latent tensors, one per mask.
            masks: List of 4-D mask tensors ``(channels, depth, height, width)``; only
                channel 0 is used. Height and width are expected to be multiples of twice
                the corresponding spatial stride (the ``view`` fails otherwise).
            ref_images: ``None`` or a list parallel to ``masks`` whose entries are ``None``
                or a sized collection of reference images.

        Returns:
            A list with one tensor per input: ``torch.cat([latent, mask], dim=0)``.
        """
        if ref_images is None:
            ref_images = [None] * len(masks)
        else:
            assert len(masks) == len(ref_images)

        vae_stride = self.config["vae_stride"]
        stride_t, stride_h, stride_w = vae_stride[0], vae_stride[1], vae_stride[2]

        result_masks = []
        for mask, refs in zip(masks, ref_images):
            _, depth, height, width = mask.shape
            new_depth = int((depth + 3) // stride_t)
            height = 2 * (int(height) // (stride_h * 2))
            width = 2 * (int(width) // (stride_w * 2))

            # Move each (stride_h x stride_w) pixel block into the channel dimension.
            # The width factor must be stride_w so the view agrees with the reshape below.
            mask = mask[0, :, :, :]
            mask = mask.view(depth, height, stride_h, width, stride_w)  # depth, height, sh, width, sw
            mask = mask.permute(2, 4, 0, 1, 3)  # sh, sw, depth, height, width
            mask = mask.reshape(stride_h * stride_w, depth, height, width)  # sh*sw, depth, height, width

            # Resample the frame axis to the latent frame count.
            mask = F.interpolate(mask.unsqueeze(0), size=(new_depth, height, width), mode="nearest-exact").squeeze(0)

            if refs is not None:
                # One zero frame per reference image, allocated explicitly so the padding is
                # correct even when there are more references than mask frames.
                mask_pad = mask.new_zeros((mask.shape[0], len(refs), mask.shape[2], mask.shape[3]))
                mask = torch.cat((mask_pad, mask), dim=1)
            result_masks.append(mask)

        return [torch.cat([zz, mm], dim=0) for zz, mm in zip(cat_latents, result_masks)]

    def set_input_info_latent_shape(self):
        """Halve the first entry of ``self.latent_shape`` in place and return that same list."""
        self.latent_shape[0] = int(self.latent_shape[0] / 2)
        return self.latent_shape

    @ProfilingContext4DebugL1(
        "Run VAE Decoder",
        recorder_mode=GET_RECORDER_MODE(),
        metrics_func=monitor_cli.lightx2v_run_vae_decode_duration,
        metrics_labels=["WanVaceRunner"],
    )
    def run_vae_decoder(self, latents):
        """Decode latents to images, first dropping the leading reference-image frames.

        When ``self.src_ref_images`` holds a (single) list of references, that many entries
        are removed from the front of dim 1 of ``latents`` before decoding. With ``lazy_load``
        or ``unload_modules`` enabled the VAE decoder is loaded on entry and deleted afterwards.
        """
        offload = self.config.get("lazy_load", False) or self.config.get("unload_modules", False)
        if offload:
            self.vae_decoder = self.load_vae_decoder()

        refs = None
        if self.src_ref_images is not None:
            assert len(self.src_ref_images) == 1
            refs = self.src_ref_images[0]
        if refs is not None:
            num_refs = len(refs)
            latents = latents[:, num_refs:, :, :]

        images = self.vae_decoder.decode(latents.to(GET_DTYPE()))

        if offload:
            del self.vae_decoder
            torch.cuda.empty_cache()
            gc.collect()

        return images


@RUNNER_REGISTER("wan2.2_moe_vace")
class Wan22MoeVaceRunner(WanVaceRunner):
    """VACE runner that pairs a high-noise and a low-noise transformer."""

    def __init__(self, config):
        """Resolve where the high-noise and low-noise checkpoints are loaded from.

        For each model the quantized checkpoint is preferred when ``dit_quantized`` is set,
        then the configured original checkpoint, and finally the ``high_noise_model`` /
        ``low_noise_model`` sub-directory of ``model_path``.

        Raises:
            FileNotFoundError: If the fallback sub-directory has to be used but does not exist.
        """
        super().__init__(config)
        dit_quantized = self.config.get("dit_quantized", False)

        if dit_quantized and self.config.get("high_noise_quantized_ckpt", None):
            self.high_noise_model_path = self.config["high_noise_quantized_ckpt"]
        elif self.config.get("high_noise_original_ckpt", None):
            self.high_noise_model_path = self.config["high_noise_original_ckpt"]
        else:
            self.high_noise_model_path = os.path.join(self.config["model_path"], "high_noise_model")
            if not os.path.isdir(self.high_noise_model_path):
                raise FileNotFoundError(f"High noise model directory not found: {self.high_noise_model_path}")

        # NOTE(review): unlike the high-noise branch above, the original low-noise checkpoint
        # is only used when quantization is off — confirm whether this asymmetry is intended.
        if dit_quantized and self.config.get("low_noise_quantized_ckpt", None):
            self.low_noise_model_path = self.config["low_noise_quantized_ckpt"]
        elif not dit_quantized and self.config.get("low_noise_original_ckpt", None):
            self.low_noise_model_path = self.config["low_noise_original_ckpt"]
        else:
            self.low_noise_model_path = os.path.join(self.config["model_path"], "low_noise_model")
            if not os.path.isdir(self.low_noise_model_path):
                raise FileNotFoundError(f"Low noise model directory not found: {self.low_noise_model_path}")

    def load_transformer(self):
        """Build both transformers (high-noise first) and wrap them in a ``MultiModelStruct``.

        When ``lora_configs`` is set, each model is created through
        ``build_wan_model_with_lora``; otherwise ``WanVaceModel`` is instantiated directly.
        """
        lora_configs = self.config.get("lora_configs")
        # (checkpoint path, model_type passed to the model, model_type passed to the LoRA builder)
        model_specs = (
            (self.high_noise_model_path, "wan2.2_moe_high_noise", "high_noise_model"),
            (self.low_noise_model_path, "wan2.2_moe_low_noise", "low_noise_model"),
        )

        models = []
        for model_path, model_type, lora_model_type in model_specs:
            model_kwargs = {
                "model_path": model_path,
                "config": self.config,
                "device": self.init_device,
                "model_type": model_type,
            }
            if lora_configs:
                model = build_wan_model_with_lora(WanVaceModel, self.config, model_kwargs, lora_configs, model_type=lora_model_type)
            else:
                model = WanVaceModel(**model_kwargs)
            models.append(model)

        return MultiModelStruct(models, self.config, self.config["boundary"])