# wan_animate_runner.py
import gc
from copy import deepcopy

import cv2
import numpy as np
import torch
import torch.nn.functional as F
gushiqiao's avatar
gushiqiao committed
8
9
10
11
12
13
from loguru import logger

try:
    from decord import VideoReader
except ImportError:
    VideoReader = None
14
    logger.info("If you want to run animate model, please install decord.")
gushiqiao's avatar
gushiqiao committed
15

16
17
18
19

from lightx2v.models.input_encoders.hf.animate.face_encoder import FaceEncoder
from lightx2v.models.input_encoders.hf.animate.motion_encoder import Generator
from lightx2v.models.networks.wan.animate_model import WanAnimateModel
20
from lightx2v.models.networks.wan.lora_adapter import WanLoraWrapper
21
from lightx2v.models.runners.wan.wan_runner import WanRunner
yihuiwen's avatar
yihuiwen committed
22
from lightx2v.server.metrics import monitor_cli
23
24
25
26
from lightx2v.utils.envs import *
from lightx2v.utils.profiler import *
from lightx2v.utils.registry_factory import RUNNER_REGISTER
from lightx2v.utils.utils import load_weights, remove_substrings_from_keys
27
from lightx2v_platform.base.global_var import AI_DEVICE
28
29
30
31
32
33


@RUNNER_REGISTER("wan2.2_animate")
class WanAnimateRunner(WanRunner):
    def __init__(self, config):
        """Initialise the runner via ``WanRunner`` and check the configured task.

        Args:
            config: Runner configuration; its ``"task"`` entry must equal ``"animate"``.
        """
        super().__init__(config)
        task = self.config["task"]
        assert task == "animate"
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155

    def inputs_padding(self, array, target_len):
        """Extend ``array`` to ``target_len`` items by bouncing back and forth over it.

        Items are taken in ping-pong order (0, 1, ..., n-1, n-2, ..., 1, 0, 1, ...),
        so the padded part mirrors the sequence instead of jumping back to the
        start. Every returned item is a deep copy of the source item.

        Args:
            array: Indexable sequence supporting ``len()``.
            target_len: Number of items to return; values <= 0 give an empty list.

        Returns:
            A list of ``target_len`` deep-copied items.

        Raises:
            IndexError: If ``array`` is empty and ``target_len`` is positive.
        """
        last = len(array) - 1
        # One full bounce visits 2 * last positions. A single-item input has no
        # bounce at all, so it is simply repeated; stepping the index by hand
        # walks past the end in that case and raises IndexError.
        period = max(2 * last, 1)
        padded = []
        for step in range(target_len):
            pos = step % period
            if pos > last:
                # Second half of the period: walk back towards the start.
                pos = period - pos
            padded.append(deepcopy(array[pos]))
        return padded

    def get_valid_len(self, real_len, clip_len=81, overlap=1):
        """Round ``real_len`` up so that it splits evenly into overlapping clips.

        Each clip advances by ``clip_len - overlap`` frames. The result is the
        smallest length ``>= real_len`` for which ``(length - overlap)`` is a
        multiple of that advance.

        Args:
            real_len: Actual number of frames.
            clip_len: Frames per clip.
            overlap: Frames shared between consecutive clips.

        Returns:
            The padded frame count.
        """
        stride = clip_len - overlap
        remainder = (real_len - overlap) % stride
        missing = stride - remainder if remainder else 0
        return real_len + missing

    def get_i2v_mask(self, lat_t, lat_h, lat_w, mask_len=1, mask_pixel_values=None, device="cuda"):
        """Build a per-frame mask and fold it into groups of four frames.

        Starts from ``mask_pixel_values`` (cloned, so the argument is not
        modified) or, when that is ``None``, from zeros with
        ``(lat_t - 1) * 4 + 1`` frames. The first ``mask_len`` frames are set
        to 1, the first frame is repeated four times, and the frames are then
        grouped in fours.

        Args:
            lat_t: Latent frame count; only used to size the zero mask.
            lat_h: Mask height.
            lat_w: Mask width.
            mask_len: Number of leading frames marked with 1.
            mask_pixel_values: Optional starting mask indexed as
                ``[batch, frame, h, w]``.
            device: Device for the zero mask; unused when ``mask_pixel_values``
                is given.

        Returns:
            Tensor of shape ``(4, frames // 4, lat_h, lat_w)``, where ``frames``
            is the frame count after the first frame has been repeated.
        """
        if mask_pixel_values is not None:
            frame_mask = mask_pixel_values.clone()
        else:
            frame_mask = torch.zeros(1, (lat_t - 1) * 4 + 1, lat_h, lat_w, dtype=GET_DTYPE(), device=device)
        frame_mask[:, :mask_len] = 1

        # Repeat the first frame so the frame count becomes a multiple of four.
        leading = frame_mask[:, 0:1].repeat_interleave(4, dim=1)
        frame_mask = torch.cat((leading, frame_mask[:, 1:]), dim=1)
        folded = frame_mask.view(1, frame_mask.shape[1] // 4, 4, lat_h, lat_w)
        return folded.transpose(1, 2)[0]

    def padding_resize(
        self,
        img_ori,
        height=512,
        width=512,
        padding_color=(0, 0, 0),
        interpolation=cv2.INTER_LINEAR,
    ):
        """Resize an image to fit ``height`` x ``width`` and centre it on a padded canvas.

        The aspect ratio is preserved: the image is scaled to fill the target
        along one axis and the remaining border is filled with ``padding_color``.

        Args:
            img_ori: Image array indexed as ``[row, column, channel]``.
            height: Output height.
            width: Output width.
            padding_color: Fill value per channel; only the first entry is used
                for single-channel images.
            interpolation: OpenCV interpolation flag passed to ``cv2.resize``.

        Returns:
            ``uint8`` array of shape ``(height, width, channel)``.
        """
        src_h = img_ori.shape[0]
        src_w = img_ori.shape[1]
        n_channels = img_ori.shape[2]

        canvas = np.zeros((height, width, n_channels))
        # Single-channel images take one fill value, everything else three.
        for c in range(1 if n_channels == 1 else 3):
            canvas[:, :, c] = padding_color[c]

        if (src_h / src_w) > (height / width):
            # Source is relatively taller: fill the height, pad left and right.
            scaled_w = int(height / src_h * src_w)
            resized = cv2.resize(img_ori, (scaled_w, height), interpolation=interpolation)
            offset = int((width - scaled_w) / 2)
            region = (slice(None), slice(offset, offset + scaled_w))
        else:
            # Source is relatively wider: fill the width, pad top and bottom.
            scaled_h = int(width / src_w * src_h)
            resized = cv2.resize(img_ori, (width, scaled_h), interpolation=interpolation)
            offset = int((height - scaled_h) / 2)
            region = (slice(offset, offset + scaled_h), slice(None))

        # Restore the channel axis if the resized image comes back two-dimensional.
        if resized.ndim == 2:
            resized = resized[:, :, np.newaxis]
        canvas[region] = resized

        return np.uint8(canvas)

    def prepare_source(self, src_pose_path, src_face_path, src_ref_path):
        """Load the pose video, the face video and the reference image.

        Both videos are decoded completely. The reference image is read with
        OpenCV, its channel order is reversed, and it is letterboxed to the
        frame size of the pose video via ``padding_resize``.

        Args:
            src_pose_path: Path of the pose video.
            src_face_path: Path of the face video.
            src_ref_path: Path of the reference image.

        Returns:
            Tuple ``(cond_images, face_images, refer_images)``: all pose
            frames, all face frames, and the resized reference image.

        Raises:
            ImportError: If decord is not installed (``VideoReader`` is ``None``).
            OSError: If the reference image cannot be read.
        """
        if VideoReader is None:
            # The module-level import fallback leaves VideoReader as None;
            # fail with a clear message instead of "'NoneType' is not callable".
            raise ImportError("decord is required to read the animate source videos, please install decord.")

        pose_reader = VideoReader(src_pose_path)
        cond_images = pose_reader.get_batch(list(range(len(pose_reader)))).asnumpy()

        face_reader = VideoReader(src_face_path)
        face_images = face_reader.get_batch(list(range(len(face_reader)))).asnumpy()

        height, width = cond_images[0].shape[:2]

        # cv2.imread reports failure by returning None rather than raising.
        refer_raw = cv2.imread(src_ref_path)
        if refer_raw is None:
            raise OSError(f"Failed to read reference image: {src_ref_path}")
        # Reverse the channel axis (OpenCV loads colour images as BGR).
        refer_images = self.padding_resize(refer_raw[..., ::-1], height=height, width=width)
        return cond_images, face_images, refer_images

    def prepare_source_for_replace(self, src_bg_path, src_mask_path):
        """Load the background video and the mask video used in replace mode.

        Both videos are decoded completely. Only the first channel of the mask
        video is kept, and it is divided by 255.

        Args:
            src_bg_path: Path of the background video.
            src_mask_path: Path of the mask video.

        Returns:
            Tuple ``(bg_images, mask_images)``: all background frames, and the
            scaled first-channel mask frames.

        Raises:
            ImportError: If decord is not installed (``VideoReader`` is ``None``).
        """
        if VideoReader is None:
            # The module-level import fallback leaves VideoReader as None;
            # fail with a clear message instead of "'NoneType' is not callable".
            raise ImportError("decord is required to read the animate source videos, please install decord.")

        bg_reader = VideoReader(src_bg_path)
        bg_images = bg_reader.get_batch(list(range(len(bg_reader)))).asnumpy()

        mask_reader = VideoReader(src_mask_path)
        mask_frames = mask_reader.get_batch(list(range(len(mask_reader)))).asnumpy()
        mask_images = mask_frames[:, :, :, 0] / 255
        return bg_images, mask_images

    @ProfilingContext4DebugL2("Run Image Encoders")
    def run_image_encoders(
        self,
        conditioning_pixel_values,
        refer_t_pixel_values,
        bg_pixel_values,
        mask_pixel_values,
        face_pixel_values,
    ):
        """Run the image encoder and the VAE encoder and bundle their outputs.

        The image encoder is applied to ``self.refer_pixel_values``; the VAE
        encoder receives the conditioning, temporal-reference, background and
        mask inputs. ``face_pixel_values`` is passed through unchanged.

        Returns:
            ``{"image_encoder_output": {...}}`` with the keys
            ``clip_encoder_out``, ``vae_encoder_out``, ``pose_latents`` and
            ``face_pixel_values``.
        """
        clip_features = self.run_image_encoder(self.refer_pixel_values)
        vae_features, pose_latents = self.run_vae_encoder(
            conditioning_pixel_values,
            refer_t_pixel_values,
            bg_pixel_values,
            mask_pixel_values,
        )
        encoder_output = {
            "clip_encoder_out": clip_features,
            "vae_encoder_out": vae_features,
            "pose_latents": pose_latents,
            "face_pixel_values": face_pixel_values,
        }
        return {"image_encoder_output": encoder_output}

yihuiwen's avatar
yihuiwen committed
156
157
158
    @ProfilingContext4DebugL1(
        "Run VAE Encoder",
        recorder_mode=GET_RECORDER_MODE(),
        metrics_func=monitor_cli.lightx2v_run_vae_encoder_image_duration,
        metrics_labels=["WanAnimateRunner"],
    )
    def run_vae_encoder(
        self,
        conditioning_pixel_values,
        refer_t_pixel_values,
        bg_pixel_values,
        mask_pixel_values,
    ):
        """Encode the conditioning inputs of one segment with the VAE.

        Two mask+latent stacks are built and joined along dim 1:

        * the reference part: the encoded ``self.refer_pixel_values`` with a
          single-frame mask;
        * the temporal part: in replace mode the encoded background video
          (its first ``self.mask_reft_len`` frames swapped for the carried-over
          reference frames) with the inverted, downsampled user mask; otherwise
          encoded zero frames (optionally preceded by the carried-over
          reference frames) with a plain mask.

        Args:
            conditioning_pixel_values: Pose frames, ``c t h w``.
            refer_t_pixel_values: Frames carried over from the previous segment.
            bg_pixel_values: Background frames (replace mode only, else ``None``).
            mask_pixel_values: Mask frames (replace mode only, else ``None``).

        Returns:
            Tuple ``(y, pose_latents)``: the joined conditioning tensor and the
            encoded pose frames.
        """
        H, W = self.refer_pixel_values.shape[-2], self.refer_pixel_values.shape[-1]
        pose_latents = self.vae_encoder.encode(conditioning_pixel_values.unsqueeze(0))  # c t h w
        ref_latents = self.vae_encoder.encode(self.refer_pixel_values.unsqueeze(1).unsqueeze(0))  # c t h w

        # Masks are created on AI_DEVICE, the device the encoder inputs are moved to below.
        mask_ref = self.get_i2v_mask(1, self.latent_h, self.latent_w, 1, device=AI_DEVICE)
        y_ref = torch.concat([mask_ref, ref_latents])

        # replace_flag is optional in the config (prepare_input treats it the same way).
        replace_flag = self.config.get("replace_flag", False)
        reft_len = self.mask_reft_len

        if replace_flag:
            if reft_len > 0:
                # Leading frames come from the previous segment, the rest from the background.
                frames = torch.concat(
                    [
                        refer_t_pixel_values.unsqueeze(2)[0, :, :reft_len],
                        bg_pixel_values[:, reft_len:],
                    ],
                    dim=1,
                ).to(AI_DEVICE)
            else:
                frames = bg_pixel_values
            y_reft = self.vae_encoder.encode(frames.unsqueeze(0))

            # Invert the mask and downsample it by 8 in both spatial dimensions.
            latent_mask = 1 - mask_pixel_values
            latent_mask = latent_mask.permute(1, 0, 2, 3)
            latent_mask = F.interpolate(latent_mask, size=(H // 8, W // 8), mode="nearest")
            latent_mask = latent_mask[:, 0, :, :]

            msk_reft = self.get_i2v_mask(
                self.latent_t,
                self.latent_h,
                self.latent_w,
                reft_len,
                mask_pixel_values=latent_mask.unsqueeze(0),
            )
        else:
            blank_len = self.config["target_video_length"] - reft_len
            if reft_len > 0:
                # The carried-over frames are resized on CPU and moved to the device after concatenation.
                carried = F.interpolate(
                    refer_t_pixel_values.unsqueeze(2)[0, :, :reft_len].cpu(),
                    size=(H, W),
                    mode="bicubic",
                )
                frames = (
                    torch.concat(
                        [carried, torch.zeros(3, blank_len, H, W, dtype=GET_DTYPE())],
                        dim=1,
                    )
                    .to(AI_DEVICE)
                    .unsqueeze(0)
                )
            else:
                frames = torch.zeros(1, 3, blank_len, H, W, dtype=GET_DTYPE(), device=AI_DEVICE)
            y_reft = self.vae_encoder.encode(frames)
            msk_reft = self.get_i2v_mask(self.latent_t, self.latent_h, self.latent_w, reft_len, device=AI_DEVICE)

        y_reft = torch.concat([msk_reft, y_reft])
        y = torch.concat([y_ref, y_reft], dim=1)

        return y, pose_latents

    def prepare_input(self):
        """Load all source media for the run and pad it to a whole number of clips.

        Reads the paths from ``self.input_info``, stores the decoded frames on
        the instance, derives the latent sizes from the reference image and the
        ``vae_stride`` config entry, and writes ``self.input_info.latent_shape``.
        The pose and face frames (and, in replace mode, the background and mask
        frames) are padded with ``inputs_padding`` to the length returned by
        ``get_valid_len``.
        """
        src_pose_path = self.input_info.src_pose_path
        src_face_path = self.input_info.src_face_path
        src_ref_path = self.input_info.src_ref_images
        self.cond_images, self.face_images, self.refer_images = self.prepare_source(src_pose_path, src_face_path, src_ref_path)
        # Scale pixel values from [0, 255] to [-1, 1] and move channels first.
        self.refer_pixel_values = torch.tensor(self.refer_images / 127.5 - 1, dtype=GET_DTYPE(), device=AI_DEVICE).permute(2, 0, 1)  # chw

        vae_stride = self.config["vae_stride"]
        self.latent_t = self.config["target_video_length"] // vae_stride[0] + 1
        self.latent_h = self.refer_pixel_values.shape[-2] // vae_stride[1]
        self.latent_w = self.refer_pixel_values.shape[-1] // vae_stride[2]
        # One extra latent frame on top of latent_t; run_vae_decoder drops the first latent frame again.
        self.input_info.latent_shape = [self.config.get("num_channels_latents", 16), self.latent_t + 1, self.latent_h, self.latent_w]

        self.real_frame_len = len(self.cond_images)
        target_len = self.get_valid_len(
            self.real_frame_len,
            self.config["target_video_length"],
            overlap=self.config.get("refert_num", 1),
        )
        logger.info(f"real frames: {self.real_frame_len} target frames: {target_len}")
        self.cond_images = self.inputs_padding(self.cond_images, target_len)
        self.face_images = self.inputs_padding(self.face_images, target_len)

        if self.config.get("replace_flag", False):
            src_bg_path = self.input_info.src_bg_path
            src_mask_path = self.input_info.src_mask_path
            self.bg_images, self.mask_images = self.prepare_source_for_replace(src_bg_path, src_mask_path)
            self.bg_images = self.inputs_padding(self.bg_images, target_len)
            self.mask_images = self.inputs_padding(self.mask_images, target_len)

    def get_video_segment_num(self):
        total_frames = len(self.cond_images)
270
271
        self.move_frames = self.config["target_video_length"] - self.config["refert_num"]
        if total_frames <= self.config["target_video_length"]:
272
273
            self.video_segment_num = 1
        else:
274
            self.video_segment_num = 1 + (total_frames - self.config["target_video_length"] + self.move_frames - 1) // self.move_frames
275
276
277
278
279
280

    def init_run(self):
        """Reset the output buffer, load the inputs, then run the parent initialisation."""
        # Decoded segments are appended to this list by end_run_segment.
        self.all_out_frames = []
        # prepare_input sets the latent sizes and input_info.latent_shape
        # before the parent initialisation runs.
        self.prepare_input()
        super().init_run()

yihuiwen's avatar
yihuiwen committed
281
282
283
284
285
286
    @ProfilingContext4DebugL1(
        "Run VAE Decoder",
        recorder_mode=GET_RECORDER_MODE(),
        metrics_func=monitor_cli.lightx2v_run_vae_decode_duration,
        metrics_labels=["WanAnimateRunner"],
    )
    def run_vae_decoder(self, latents):
        """Decode latents to frames, loading and releasing the decoder on demand.

        When ``lazy_load`` or ``unload_modules`` is enabled in the config, the
        VAE decoder is loaded right before decoding and released afterwards,
        including when decoding raises.

        Args:
            latents: Latent tensor; the first entry along dim 1 is dropped
                before decoding.

        Returns:
            The decoder output for ``latents[:, 1:]``.
        """
        offload = self.config.get("lazy_load", False) or self.config.get("unload_modules", False)
        if offload:
            self.vae_decoder = self.load_vae_decoder()
        try:
            images = self.vae_decoder.decode(latents[:, 1:].to(GET_DTYPE()))
        finally:
            # Release the decoder even if decode() fails, so a failed run does
            # not leave it resident in memory.
            if offload:
                del self.vae_decoder
                torch.cuda.empty_cache()
                gc.collect()
        return images

297
298
299
300
301
302
    @ProfilingContext4DebugL1(
        "Init run segment",
        recorder_mode=GET_RECORDER_MODE(),
        metrics_func=monitor_cli.lightx2v_run_init_run_segment_duration,
        metrics_labels=["WanAnimateRunner"],
    )
    def init_run_segment(self, segment_idx):
        """Prepare the encoder inputs for one clip and store them in ``self.inputs``.

        Slices the padded source frames for the clip, builds the temporal
        reference (zeros for the first clip, the tail of the previously
        generated video otherwise), runs the encoders, and resets the scheduler
        for every clip after the first.

        Args:
            segment_idx: Zero-based clip index.
        """
        start = segment_idx * self.move_frames
        end = start + self.config["target_video_length"]
        is_first_segment = start == 0
        # refert_num is optional and defaults to 1, matching prepare_input.
        refert_num = self.config.get("refert_num", 1)

        # The first clip has no previously generated frames to condition on.
        self.mask_reft_len = 0 if is_first_segment else refert_num

        # Pixel values are scaled from [0, 255] to [-1, 1].
        conditioning_pixel_values = torch.tensor(
            np.stack(self.cond_images[start:end]) / 127.5 - 1,
            device=AI_DEVICE,
            dtype=GET_DTYPE(),
        ).permute(3, 0, 1, 2)  # c t h w

        face_pixel_values = torch.tensor(
            np.stack(self.face_images[start:end]) / 127.5 - 1,
            device=AI_DEVICE,
            dtype=GET_DTYPE(),
        ).permute(0, 3, 1, 2)  # thwc->tchw

        if is_first_segment:
            height, width = self.refer_images.shape[:2]
            refer_t_pixel_values = torch.zeros(
                3,
                refert_num,
                height,
                width,
                device=AI_DEVICE,
                dtype=GET_DTYPE(),
            )  # c t h w
        else:
            # NOTE(review): transpose(0, 1) swaps the first two axes of the
            # sliced tail; confirm this is the layout run_vae_encoder expects.
            refer_t_pixel_values = self.gen_video[0, :, -refert_num:].transpose(0, 1).clone().detach().to(AI_DEVICE)

        bg_pixel_values, mask_pixel_values = None, None
        if self.config.get("replace_flag", False):
            bg_pixel_values = torch.tensor(
                np.stack(self.bg_images[start:end]) / 127.5 - 1,
                device=AI_DEVICE,
                dtype=GET_DTYPE(),
            ).permute(3, 0, 1, 2)  # c t h w

            mask_pixel_values = torch.tensor(
                np.stack(self.mask_images[start:end])[:, :, :, None],
                device=AI_DEVICE,
                dtype=GET_DTYPE(),
            ).permute(3, 0, 1, 2)  # c t h w

        self.inputs.update(
            self.run_image_encoders(
                conditioning_pixel_values,
                refer_t_pixel_values,
                bg_pixel_values,
                mask_pixel_values,
                face_pixel_values,
            )
        )

        if not is_first_segment:
            self.model.scheduler.reset(self.input_info.seed, self.input_info.latent_shape)
362
363
364
365
366
367

    def end_run_segment(self, segment_idx):
        if segment_idx != 0:
            self.gen_video = self.gen_video[:, :, self.config["refert_num"] :]
        self.all_out_frames.append(self.gen_video.cpu())

368
369
    def process_images_after_vae_decoder(self):
        """Join all stored clips, trim the padding, and hand over to the parent class.

        The clips in ``self.all_out_frames`` are concatenated along dim 2 and
        cut to ``self.real_frame_len`` frames; the buffer is then dropped
        before the parent post-processing runs.
        """
        joined = torch.cat(self.all_out_frames, dim=2)
        self.gen_video_final = joined[:, :, : self.real_frame_len]
        del self.all_out_frames
        gc.collect()
        super().process_images_after_vae_decoder()
373

yihuiwen's avatar
yihuiwen committed
374
375
376
377
378
379
    @ProfilingContext4DebugL1(
        "Run Image Encoder",
        recorder_mode=GET_RECORDER_MODE(),
        metrics_func=monitor_cli.lightx2v_run_img_encode_duration,
        metrics_labels=["WanAnimateRunner"],
    )
    def run_image_encoder(self, img):  # CHW
        """Encode an image with the image encoder, loading and releasing it on demand.

        When ``lazy_load`` or ``unload_modules`` is enabled in the config, the
        image encoder is loaded right before use and released afterwards,
        including when encoding raises.

        Args:
            img: Image tensor in CHW layout.

        Returns:
            The encoder output with the leading dimension squeezed, cast to
            ``GET_DTYPE()``.
        """
        offload = self.config.get("lazy_load", False) or self.config.get("unload_modules", False)
        if offload:
            self.image_encoder = self.load_image_encoder()
        try:
            clip_encoder_out = self.image_encoder.visual([img.unsqueeze(0)]).squeeze(0).to(GET_DTYPE())
        finally:
            # Release the encoder even if the forward pass fails, so a failed
            # run does not leave it resident in memory.
            if offload:
                del self.image_encoder
                torch.cuda.empty_cache()
                gc.collect()
        return clip_encoder_out

    def load_transformer(self):
        """Build the animate model, apply configured LoRAs, and attach the encoders.

        Returns:
            The ``WanAnimateModel`` with the motion and face encoders set.
        """
        model = WanAnimateModel(self.config["model_path"], self.config, self.init_device)

        lora_enabled = self.config.get("lora_configs") and self.config.lora_configs
        if lora_enabled:
            # LoRA is only applied to a non-quantized model here.
            assert not self.config.get("dit_quantized", False)
            wrapper = WanLoraWrapper(model)
            for entry in self.config.lora_configs:
                path = entry["path"]
                strength = entry.get("strength", 1.0)
                name = wrapper.load_lora(path)
                wrapper.apply_lora(name, strength)
                logger.info(f"Loaded LoRA: {name} with strength: {strength}")

        encoders = self.load_encoders()
        motion_encoder, face_encoder = encoders
        model.set_animate_encoders(motion_encoder, face_encoder)
        return model

gushiqiao's avatar
gushiqiao committed
411
    def load_encoders(self):
        """Create the motion and face encoders and load their weights.

        Both modules are put in eval mode with gradients disabled, cast to
        ``GET_DTYPE()`` and moved to ``AI_DEVICE``. Their weights are read from
        ``config["model_path"]`` using the ``motion_encoder`` / ``face_encoder``
        key filters, with that prefix stripped from the keys.

        Returns:
            Tuple ``(motion_encoder, face_encoder)``.
        """
        model_path = self.config["model_path"]

        motion_encoder = Generator(size=512, style_dim=512, motion_dim=20)
        motion_encoder = motion_encoder.eval().requires_grad_(False).to(GET_DTYPE()).to(AI_DEVICE)

        face_encoder = FaceEncoder(in_dim=512, hidden_dim=5120, num_heads=4)
        face_encoder = face_encoder.eval().requires_grad_(False).to(GET_DTYPE()).to(AI_DEVICE)

        motion_weights = load_weights(model_path, include_keys=["motion_encoder"])
        motion_weights = remove_substrings_from_keys(motion_weights, "motion_encoder.")
        face_weights = load_weights(model_path, include_keys=["face_encoder"])
        face_weights = remove_substrings_from_keys(face_weights, "face_encoder.")

        motion_encoder.load_state_dict(motion_weights)
        face_encoder.load_state_dict(face_weights)
        return motion_encoder, face_encoder