wan_animate_runner.py 17.1 KB
Newer Older
1
2
3
4
5
6
7
import gc
from copy import deepcopy

import cv2
import numpy as np
import torch
import torch.nn.functional as F
gushiqiao's avatar
gushiqiao committed
8
9
10
11
12
13
from loguru import logger

try:
    from decord import VideoReader
except ImportError:
    VideoReader = None
14
    logger.info("If you want to run animate model, please install decord.")
gushiqiao's avatar
gushiqiao committed
15

16
17
18
19
20
21
22
23
24

from lightx2v.models.input_encoders.hf.animate.face_encoder import FaceEncoder
from lightx2v.models.input_encoders.hf.animate.motion_encoder import Generator
from lightx2v.models.networks.wan.animate_model import WanAnimateModel
from lightx2v.models.runners.wan.wan_runner import WanRunner
from lightx2v.utils.envs import *
from lightx2v.utils.profiler import *
from lightx2v.utils.registry_factory import RUNNER_REGISTER
from lightx2v.utils.utils import load_weights, remove_substrings_from_keys
yihuiwen's avatar
yihuiwen committed
25
from lightx2v.server.metrics import monitor_cli
26
27
28
29
30
31


@RUNNER_REGISTER("wan2.2_animate")
class WanAnimateRunner(WanRunner):
    def __init__(self, config):
        """Initialise the base runner, then check that the configured task is "animate"."""
        super().__init__(config)
        task = self.config["task"]
        assert task == "animate"
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153

    def inputs_padding(self, array, target_len):
        """Pad (or truncate) a frame sequence to ``target_len`` by ping-pong replay.

        Items are emitted forwards, then backwards, then forwards again
        (``0, 1, ..., n-1, n-2, ..., 1, 0, 1, ...``) until ``target_len`` items
        have been produced. Every emitted item is a deep copy of the source item.

        Args:
            array: Indexable sequence supporting ``len()``.
            target_len: Number of items to return.

        Returns:
            A list of ``target_len`` items (empty when ``target_len <= 0``).

        Raises:
            IndexError: If ``array`` is empty and ``target_len`` is positive.
        """
        if target_len <= 0:
            return []
        count = len(array)
        if count == 0:
            raise IndexError("cannot pad an empty sequence")

        # One full forward-and-back sweep visits 2 * (count - 1) positions.
        # A single item has nothing to bounce between (period 0), so it is
        # simply repeated instead of indexing past the end.
        period = 2 * (count - 1)
        padded = []
        for step in range(target_len):
            phase = step % period if period else 0
            source_idx = phase if phase < count else period - phase
            padded.append(deepcopy(array[source_idx]))
        return padded

    def get_valid_len(self, real_len, clip_len=81, overlap=1):
        """Round ``real_len`` up so the frames split evenly into overlapping clips.

        Consecutive clips share ``overlap`` frames, so each clip after the first
        adds ``clip_len - overlap`` new frames. The result is ``real_len`` plus
        the number of frames needed to make ``(length - overlap)`` a multiple of
        that stride.
        """
        stride = clip_len - overlap
        # Modulo of the negated remainder is the distance up to the next
        # multiple of ``stride`` (0 when already aligned).
        shortfall = (overlap - real_len) % stride
        return real_len + shortfall

    def get_i2v_mask(self, lat_t, lat_h, lat_w, mask_len=1, mask_pixel_values=None, device="cuda"):
        """Build a 4-channel mask at latent temporal resolution.

        Starts from ``mask_pixel_values`` (cloned) or, when that is ``None``, a
        zero tensor with ``(lat_t - 1) * 4 + 1`` frames. The first ``mask_len``
        frames are set to 1, the first frame is repeated four times, and the
        frames are then folded in groups of four into the leading axis.

        Returns:
            Tensor whose leading axis has size 4, followed by the grouped frame
            axis and ``(lat_h, lat_w)``.
        """
        if mask_pixel_values is not None:
            frame_mask = mask_pixel_values.clone()
        else:
            frame_mask = torch.zeros(1, (lat_t - 1) * 4 + 1, lat_h, lat_w, dtype=GET_DTYPE(), device=device)
        frame_mask[:, :mask_len] = 1

        # Repeat frame 0 four times so the frame count becomes a multiple of 4.
        first_frame_x4 = torch.repeat_interleave(frame_mask[:, 0:1], repeats=4, dim=1)
        remaining = frame_mask[:, 1:]
        frame_mask = torch.concat([first_frame_x4, remaining], dim=1)

        grouped = frame_mask.view(1, frame_mask.shape[1] // 4, 4, lat_h, lat_w)
        return grouped.transpose(1, 2)[0]

    def padding_resize(
        self,
        img_ori,
        height=512,
        width=512,
        padding_color=(0, 0, 0),
        interpolation=cv2.INTER_LINEAR,
    ):
        """Letterbox ``img_ori`` onto a ``height`` x ``width`` canvas.

        The image is scaled with its aspect ratio preserved so that it fills the
        canvas along one axis, then centred along the other; the uncovered area
        keeps ``padding_color``. The result is returned as ``uint8``.

        Args:
            img_ori: Image array indexed as (row, column, channel).
            height: Canvas height.
            width: Canvas width.
            padding_color: Per-channel fill value for the border.
            interpolation: OpenCV interpolation flag passed to ``cv2.resize``.
        """
        src_h = img_ori.shape[0]
        src_w = img_ori.shape[1]
        channels = img_ori.shape[2]

        canvas = np.zeros((height, width, channels))
        # A single-channel canvas is filled on plane 0 only; anything else gets
        # the first three planes filled.
        for plane in range(1 if channels == 1 else 3):
            canvas[:, :, plane] = padding_color[plane]

        source_is_narrower = (src_h / src_w) > (height / width)
        if source_is_narrower:
            # Fit to full height, centre horizontally.
            scaled_w = int(height / src_h * src_w)
            resized = cv2.resize(img_ori, (scaled_w, height), interpolation=interpolation)
            left = int((width - scaled_w) / 2)
            if resized.ndim == 2:
                # cv2.resize can hand back a 2-D array; restore the channel axis.
                resized = resized[:, :, np.newaxis]
            canvas[:, left : left + scaled_w, :] = resized
        else:
            # Fit to full width, centre vertically.
            scaled_h = int(width / src_w * src_h)
            resized = cv2.resize(img_ori, (width, scaled_h), interpolation=interpolation)
            top = int((height - scaled_h) / 2)
            if resized.ndim == 2:
                resized = resized[:, :, np.newaxis]
            canvas[top : top + scaled_h, :, :] = resized

        return np.uint8(canvas)

    def prepare_source(self, src_pose_path, src_face_path, src_ref_path):
        """Load the pose video, the face video and the reference image.

        Every frame of both videos is decoded. The reference image is read with
        OpenCV, its channel order is reversed, and it is letterboxed to the
        height/width of the first pose frame.

        Args:
            src_pose_path: Path of the pose video.
            src_face_path: Path of the face video.
            src_ref_path: Path of the reference image.

        Returns:
            Tuple ``(cond_images, face_images, refer_images)`` of numpy arrays.

        Raises:
            ImportError: If decord is not installed.
            ValueError: If the reference image cannot be read.
        """
        # The module-level import falls back to None when decord is missing;
        # fail with a clear message instead of "'NoneType' is not callable".
        if VideoReader is None:
            raise ImportError("decord is required to read the animate source videos; please install decord.")

        pose_video_reader = VideoReader(src_pose_path)
        pose_idxs = list(range(len(pose_video_reader)))
        cond_images = pose_video_reader.get_batch(pose_idxs).asnumpy()

        face_video_reader = VideoReader(src_face_path)
        face_idxs = list(range(len(face_video_reader)))
        face_images = face_video_reader.get_batch(face_idxs).asnumpy()

        height, width = cond_images[0].shape[:2]

        # cv2.imread signals failure by returning None rather than raising.
        refer_bgr = cv2.imread(src_ref_path)
        if refer_bgr is None:
            raise ValueError(f"Failed to read reference image: {src_ref_path}")
        refer_images = refer_bgr[..., ::-1]  # reverse the channel order
        refer_images = self.padding_resize(refer_images, height=height, width=width)
        return cond_images, face_images, refer_images

    def prepare_source_for_replace(self, src_bg_path, src_mask_path):
        """Load the background and mask videos used in replace mode.

        Every frame of both videos is decoded. Only channel 0 of the mask video
        is kept, divided by 255.

        Args:
            src_bg_path: Path of the background video.
            src_mask_path: Path of the mask video.

        Returns:
            Tuple ``(bg_images, mask_images)`` of numpy arrays.

        Raises:
            ImportError: If decord is not installed.
        """
        # The module-level import falls back to None when decord is missing;
        # fail with a clear message instead of "'NoneType' is not callable".
        if VideoReader is None:
            raise ImportError("decord is required to read the animate source videos; please install decord.")

        bg_video_reader = VideoReader(src_bg_path)
        bg_idxs = list(range(len(bg_video_reader)))
        bg_images = bg_video_reader.get_batch(bg_idxs).asnumpy()

        mask_video_reader = VideoReader(src_mask_path)
        mask_idxs = list(range(len(mask_video_reader)))
        mask_images = mask_video_reader.get_batch(mask_idxs).asnumpy()
        mask_images = mask_images[:, :, :, 0] / 255
        return bg_images, mask_images

    @ProfilingContext4DebugL2("Run Image Encoders")
    def run_image_encoders(
        self,
        conditioning_pixel_values,
        refer_t_pixel_values,
        bg_pixel_values,
        mask_pixel_values,
        face_pixel_values,
    ):
        """Run the image encoder and the VAE encoder and bundle their outputs.

        The image encoder is fed ``self.refer_pixel_values``; the VAE encoder
        receives the first four arguments. ``face_pixel_values`` is passed
        through unchanged.

        Returns:
            ``{"image_encoder_output": {...}}`` holding the CLIP output, the VAE
            conditioning tensor, the pose latents and the face pixels.
        """
        clip_out = self.run_image_encoder(self.refer_pixel_values)
        vae_out, pose_latents = self.run_vae_encoder(
            conditioning_pixel_values,
            refer_t_pixel_values,
            bg_pixel_values,
            mask_pixel_values,
        )
        encoder_outputs = {
            "clip_encoder_out": clip_out,
            "vae_encoder_out": vae_out,
            "pose_latents": pose_latents,
            "face_pixel_values": face_pixel_values,
        }
        return {"image_encoder_output": encoder_outputs}

yihuiwen's avatar
yihuiwen committed
154
155
156
157
158
159
    @ProfilingContext4DebugL1(
        "Run VAE Encoder",
        recorder_mode=GET_RECORDER_MODE(),
        metrics_func=monitor_cli.lightx2v_run_vae_encode_duration,
        metrics_labels=["WanAnimateRunner"],
    )
    def run_vae_encoder(
        self,
        conditioning_pixel_values,
        refer_t_pixel_values,
        bg_pixel_values,
        mask_pixel_values,
    ):
        """Encode the pose clip and assemble the mask + latent conditioning tensor.

        Reads ``self.refer_pixel_values``, ``self.mask_reft_len`` and the
        ``self.latent_t`` / ``self.latent_h`` / ``self.latent_w`` sizes, which
        are set by ``prepare_input`` and ``init_run_segment``.

        Args:
            conditioning_pixel_values: Pose frames of the current segment; a
                leading batch axis is added before encoding.
            refer_t_pixel_values: Frames carried over from the previous segment;
                only read when ``self.mask_reft_len > 0``.
            bg_pixel_values: Background frames; only read in replace mode.
            mask_pixel_values: Mask frames; only read in replace mode, where
                they are inverted and resized to ``(H // 8, W // 8)``.

        Returns:
            Tuple ``(y, pose_latents)``. ``y`` is the reference-image
            ``[mask, latent]`` stack concatenated along dim 1 with the
            ``[mask, latent]`` stack of the clip.
        """
        # A missing "replace_flag" means "off", matching how prepare_input and
        # init_run_segment read it; indexing it directly raised KeyError here.
        replace_flag = self.config["replace_flag"] if "replace_flag" in self.config else False

        H, W = self.refer_pixel_values.shape[-2], self.refer_pixel_values.shape[-1]
        pose_latents = self.vae_encoder.encode(conditioning_pixel_values.unsqueeze(0))  # c t h w
        ref_latents = self.vae_encoder.encode(self.refer_pixel_values.unsqueeze(1).unsqueeze(0))  # c t h w

        # Single-frame, fully "known" mask for the reference image.
        mask_ref = self.get_i2v_mask(1, self.latent_h, self.latent_w, 1)
        y_ref = torch.concat([mask_ref, ref_latents])

        # NOTE(review): assuming refer_t_pixel_values is 4-D, the
        # ``unsqueeze(2)[0, :, :k]`` slices below keep at most one entry on
        # axis 1, so the concatenated clip only reaches target_video_length
        # frames when mask_reft_len == 1 - confirm larger refert_num values
        # are not expected to work.
        if self.mask_reft_len > 0:
            if replace_flag:
                # Carried-over frame(s) followed by the background frames.
                y_reft = self.vae_encoder.encode(
                    torch.concat(
                        [
                            refer_t_pixel_values.unsqueeze(2)[0, :, : self.mask_reft_len],
                            bg_pixel_values[:, self.mask_reft_len :],
                        ],
                        dim=1,
                    )
                    .cuda()
                    .unsqueeze(0)
                )
                # Invert the mask and resize it to (H // 8, W // 8).
                mask_pixel_values = 1 - mask_pixel_values
                mask_pixel_values = mask_pixel_values.permute(1, 0, 2, 3)
                mask_pixel_values = F.interpolate(mask_pixel_values, size=(H // 8, W // 8), mode="nearest")
                mask_pixel_values = mask_pixel_values[:, 0, :, :]

                msk_reft = self.get_i2v_mask(
                    self.latent_t,
                    self.latent_h,
                    self.latent_w,
                    self.mask_reft_len,
                    mask_pixel_values=mask_pixel_values.unsqueeze(0),
                )
            else:
                # Carried-over frame(s) resized to (H, W), followed by zero frames.
                y_reft = self.vae_encoder.encode(
                    torch.concat(
                        [
                            torch.nn.functional.interpolate(
                                refer_t_pixel_values.unsqueeze(2)[0, :, : self.mask_reft_len].cpu(),
                                size=(H, W),
                                mode="bicubic",
                            ),
                            torch.zeros(3, self.config["target_video_length"] - self.mask_reft_len, H, W, dtype=GET_DTYPE()),
                        ],
                        dim=1,
                    )
                    .cuda()
                    .unsqueeze(0)
                )
                msk_reft = self.get_i2v_mask(self.latent_t, self.latent_h, self.latent_w, self.mask_reft_len)
        else:
            if replace_flag:
                # Invert the mask and resize it to (H // 8, W // 8).
                mask_pixel_values = 1 - mask_pixel_values
                mask_pixel_values = mask_pixel_values.permute(1, 0, 2, 3)
                mask_pixel_values = F.interpolate(mask_pixel_values, size=(H // 8, W // 8), mode="nearest")
                mask_pixel_values = mask_pixel_values[:, 0, :, :]
                y_reft = self.vae_encoder.encode(bg_pixel_values.unsqueeze(0))
                msk_reft = self.get_i2v_mask(
                    self.latent_t,
                    self.latent_h,
                    self.latent_w,
                    self.mask_reft_len,
                    mask_pixel_values=mask_pixel_values.unsqueeze(0),
                )
            else:
                # No carried-over frames and no background: encode an all-zero clip.
                y_reft = self.vae_encoder.encode(torch.zeros(1, 3, self.config["target_video_length"] - self.mask_reft_len, H, W, dtype=GET_DTYPE(), device="cuda"))
                msk_reft = self.get_i2v_mask(self.latent_t, self.latent_h, self.latent_w, self.mask_reft_len)

        y_reft = torch.concat([msk_reft, y_reft])
        y = torch.concat([y_ref, y_reft], dim=1)

        return y, pose_latents

    def prepare_input(self):
        """Load the source media and derive latent sizes and padded frame lists.

        Sets ``cond_images``, ``face_images``, ``refer_images``,
        ``refer_pixel_values``, ``latent_t`` / ``latent_h`` / ``latent_w``,
        ``input_info.latent_shape`` and ``real_frame_len``; in replace mode also
        ``bg_images`` and ``mask_images``. All frame lists are padded to the same
        clip-aligned length.
        """
        cfg = self.config

        # Source paths are optional in the config; a missing one becomes None.
        pose_path = cfg["src_pose_path"] if "src_pose_path" in cfg else None
        face_path = cfg["src_face_path"] if "src_face_path" in cfg else None
        ref_path = cfg["src_ref_images"] if "src_ref_images" in cfg else None

        self.cond_images, self.face_images, self.refer_images = self.prepare_source(pose_path, face_path, ref_path)

        # Scale pixels from [0, 255] to [-1, 1] and move channels first (chw).
        scaled_ref = self.refer_images / 127.5 - 1
        self.refer_pixel_values = torch.tensor(scaled_ref, dtype=GET_DTYPE(), device="cuda").permute(2, 0, 1)

        vae_stride = cfg["vae_stride"]
        clip_len = cfg["target_video_length"]
        self.latent_t = clip_len // vae_stride[0] + 1
        self.latent_h = self.refer_pixel_values.shape[-2] // vae_stride[1]
        self.latent_w = self.refer_pixel_values.shape[-1] // vae_stride[2]
        self.input_info.latent_shape = [
            cfg.get("num_channels_latents", 16),
            self.latent_t + 1,
            self.latent_h,
            self.latent_w,
        ]

        self.real_frame_len = len(self.cond_images)
        overlap = cfg["refert_num"] if "refert_num" in cfg else 1
        target_len = self.get_valid_len(self.real_frame_len, clip_len, overlap=overlap)
        logger.info("real frames: {} target frames: {}".format(self.real_frame_len, target_len))

        self.cond_images = self.inputs_padding(self.cond_images, target_len)
        self.face_images = self.inputs_padding(self.face_images, target_len)

        replace_mode = cfg["replace_flag"] if "replace_flag" in cfg else False
        if replace_mode:
            bg_path = cfg["src_bg_path"]
            mask_path = cfg["src_mask_path"]
            self.bg_images, self.mask_images = self.prepare_source_for_replace(bg_path, mask_path)
            self.bg_images = self.inputs_padding(self.bg_images, target_len)
            self.mask_images = self.inputs_padding(self.mask_images, target_len)

    def get_video_segment_num(self):
        """Compute how many overlapping clips are needed to cover all frames.

        Sets ``self.move_frames`` (new frames contributed by each clip after the
        first) and ``self.video_segment_num``.
        """
        total_frames = len(self.cond_images)
        clip_len = self.config["target_video_length"]
        # "refert_num" is optional and defaults to 1, the same default
        # prepare_input uses for the overlap when padding the frame lists.
        refert_num = self.config["refert_num"] if "refert_num" in self.config else 1

        self.move_frames = clip_len - refert_num
        if total_frames <= clip_len:
            self.video_segment_num = 1
        else:
            # One full clip plus ceil(remaining / move_frames) further clips.
            remaining = total_frames - clip_len
            self.video_segment_num = 1 + (remaining + self.move_frames - 1) // self.move_frames
273
274
275
276
277
278

    def init_run(self):
        """Reset the per-run frame buffer, load the inputs, then run the base initialisation."""
        self.all_out_frames = list()
        self.prepare_input()
        super().init_run()

yihuiwen's avatar
yihuiwen committed
279
280
281
282
283
284
    @ProfilingContext4DebugL1(
        "Run VAE Decoder",
        recorder_mode=GET_RECORDER_MODE(),
        metrics_func=monitor_cli.lightx2v_run_vae_decode_duration,
        metrics_labels=["WanAnimateRunner"],
    )
    def run_vae_decoder(self, latents):
        """Decode ``latents`` to images, skipping index 0 on axis 1.

        When ``lazy_load`` or ``unload_modules`` is enabled in the config, the
        decoder is loaded just before decoding and released again afterwards.
        """
        cfg = self.config
        transient_decoder = (cfg["lazy_load"] if "lazy_load" in cfg else False) or (cfg["unload_modules"] if "unload_modules" in cfg else False)

        if transient_decoder:
            self.vae_decoder = self.load_vae_decoder()

        decoder_input = latents[:, 1:].to(GET_DTYPE())
        images = self.vae_decoder.decode(decoder_input)

        if transient_decoder:
            del self.vae_decoder
            torch.cuda.empty_cache()
            gc.collect()
        return images

    def init_run_segment(self, segment_idx):
        """Build the inputs for one clip and run the image encoders on them.

        Slices the padded frame lists for segment ``segment_idx``, scales them
        from [0, 255] to [-1, 1], sets ``self.mask_reft_len``, and stores the
        encoder outputs in ``self.inputs``. For every segment after the first,
        the scheduler is reset as well.

        Args:
            segment_idx: Zero-based index of the clip to prepare.
        """
        clip_len = self.config["target_video_length"]
        # "refert_num" is optional and defaults to 1, the same default
        # prepare_input uses for the overlap when padding the frame lists.
        refert_num = self.config["refert_num"] if "refert_num" in self.config else 1
        replace_flag = self.config["replace_flag"] if "replace_flag" in self.config else False

        start = segment_idx * self.move_frames
        end = start + clip_len
        is_first_segment = start == 0

        # The first clip has no previously generated frames to condition on.
        self.mask_reft_len = 0 if is_first_segment else refert_num

        conditioning_pixel_values = torch.tensor(
            np.stack(self.cond_images[start:end]) / 127.5 - 1,
            device="cuda",
            dtype=GET_DTYPE(),
        ).permute(3, 0, 1, 2)  # thwc -> cthw

        face_pixel_values = torch.tensor(
            np.stack(self.face_images[start:end]) / 127.5 - 1,
            device="cuda",
            dtype=GET_DTYPE(),
        ).permute(0, 3, 1, 2)  # thwc -> tchw

        if is_first_segment:
            height, width = self.refer_images.shape[:2]
            refer_t_pixel_values = torch.zeros(
                3,
                refert_num,
                height,
                width,
                device="cuda",
                dtype=GET_DTYPE(),
            )  # c t h w
        else:
            # NOTE(review): if gen_video is (b, c, t, h, w), as the dim=2 slicing
            # elsewhere in this class suggests, this transpose yields
            # (t, c, h, w) rather than (c, t, h, w) - confirm the intended layout.
            refer_t_pixel_values = self.gen_video[0, :, -refert_num:].transpose(0, 1).clone().detach().cuda()

        bg_pixel_values, mask_pixel_values = None, None
        if replace_flag:
            bg_pixel_values = torch.tensor(
                np.stack(self.bg_images[start:end]) / 127.5 - 1,
                device="cuda",
                dtype=GET_DTYPE(),
            ).permute(3, 0, 1, 2)  # thwc -> cthw

            # Mask frames get a trailing channel axis before the permute.
            mask_pixel_values = torch.tensor(
                np.stack(self.mask_images[start:end])[:, :, :, None],
                device="cuda",
                dtype=GET_DTYPE(),
            ).permute(3, 0, 1, 2)  # thwc -> cthw

        self.inputs.update(
            self.run_image_encoders(
                conditioning_pixel_values,
                refer_t_pixel_values,
                bg_pixel_values,
                mask_pixel_values,
                face_pixel_values,
            )
        )

        if not is_first_segment:
            self.model.scheduler.reset(self.input_info.seed, self.input_info.latent_shape)
354
355
356
357
358
359

    def end_run_segment(self, segment_idx):
        """Store the frames generated for one clip, dropping the overlap.

        For every segment after the first, the leading ``refert_num`` entries on
        axis 2 of ``self.gen_video`` are removed before the result is moved to
        the CPU and appended to ``self.all_out_frames``.

        Args:
            segment_idx: Zero-based index of the clip that just finished.
        """
        if segment_idx != 0:
            # "refert_num" is optional and defaults to 1, the same default
            # prepare_input uses for the overlap when padding the frame lists.
            refert_num = self.config["refert_num"] if "refert_num" in self.config else 1
            self.gen_video = self.gen_video[:, :, refert_num:]
        self.all_out_frames.append(self.gen_video.cpu())

360
361
    def process_images_after_vae_decoder(self):
        """Join all stored clips along axis 2, trim to the real frame count, then defer to the base class."""
        stitched = torch.cat(self.all_out_frames, dim=2)
        # Drop the frames that only exist because the inputs were padded.
        self.gen_video_final = stitched[:, :, : self.real_frame_len]
        del self.all_out_frames
        gc.collect()
        super().process_images_after_vae_decoder()
365

yihuiwen's avatar
yihuiwen committed
366
367
368
369
370
371
    @ProfilingContext4DebugL1(
        "Run Image Encoder",
        recorder_mode=GET_RECORDER_MODE(),
        metrics_func=monitor_cli.lightx2v_run_img_encode_duration,
        metrics_labels=["WanAnimateRunner"],
    )
    def run_image_encoder(self, img):  # CHW
        """Return the image encoder's visual output for ``img`` in the runtime dtype.

        When ``lazy_load`` or ``unload_modules`` is enabled in the config, the
        encoder is loaded just before use and released again afterwards.
        """
        cfg = self.config
        transient_encoder = (cfg["lazy_load"] if "lazy_load" in cfg else False) or (cfg["unload_modules"] if "unload_modules" in cfg else False)

        if transient_encoder:
            self.image_encoder = self.load_image_encoder()

        batch = [img.unsqueeze(0)]
        clip_encoder_out = self.image_encoder.visual(batch).squeeze(0).to(GET_DTYPE())

        if transient_encoder:
            del self.image_encoder
            torch.cuda.empty_cache()
            gc.collect()
        return clip_encoder_out

    def load_transformer(self):
        """Create the animate transformer and attach its motion and face encoders."""
        model_path = self.config["model_path"]
        model = WanAnimateModel(model_path, self.config, self.init_device)
        encoders = self.load_encoders()
        model.set_animate_encoders(encoders[0], encoders[1])
        return model

gushiqiao's avatar
gushiqiao committed
392
393
394
    def load_encoders(self):
        """Build the motion and face encoders and load their weights from ``model_path``.

        Both modules are put in eval mode with gradients disabled, cast to the
        runtime dtype and moved to CUDA. Weights are selected by key and have the
        ``motion_encoder.`` / ``face_encoder.`` substring removed from their keys
        before loading.

        Returns:
            Tuple ``(motion_encoder, face_encoder)``.
        """
        motion_encoder = Generator(size=512, style_dim=512, motion_dim=20)
        motion_encoder = motion_encoder.eval().requires_grad_(False).to(GET_DTYPE()).cuda()
        face_encoder = FaceEncoder(in_dim=512, hidden_dim=5120, num_heads=4)
        face_encoder = face_encoder.eval().requires_grad_(False).to(GET_DTYPE()).cuda()

        model_path = self.config["model_path"]
        motion_raw = load_weights(model_path, include_keys=["motion_encoder"])
        motion_state = remove_substrings_from_keys(motion_raw, "motion_encoder.")
        face_raw = load_weights(model_path, include_keys=["face_encoder"])
        face_state = remove_substrings_from_keys(face_raw, "face_encoder.")

        motion_encoder.load_state_dict(motion_state)
        face_encoder.load_state_dict(face_state)
        return motion_encoder, face_encoder