# wan_animate_runner.py
import gc
from copy import deepcopy

import cv2
import numpy as np
import torch
import torch.nn.functional as F
from loguru import logger

try:
    from decord import VideoReader
except ImportError:
    # decord is an optional dependency: keep the module importable without it
    # and leave a placeholder that the animate runner can detect later.
    VideoReader = None
    logger.info("If you want to run animate model, please install decord.")
from lightx2v.models.input_encoders.hf.animate.face_encoder import FaceEncoder
from lightx2v.models.input_encoders.hf.animate.motion_encoder import Generator
from lightx2v.models.networks.wan.animate_model import WanAnimateModel
from lightx2v.models.runners.wan.wan_runner import WanRunner
from lightx2v.utils.envs import *
from lightx2v.utils.profiler import *
from lightx2v.utils.registry_factory import RUNNER_REGISTER
from lightx2v.utils.utils import load_weights, remove_substrings_from_keys


@RUNNER_REGISTER("wan2.2_animate")
class WanAnimateRunner(WanRunner):
    def __init__(self, config):
        """Build the runner via the parent class and check the configured task.

        Args:
            config: Runner configuration; it is forwarded to ``WanRunner`` and
                then read back through ``self.config`` (presumably stored by
                the parent constructor -- confirm in ``WanRunner``).

        Raises:
            AssertionError: If ``config["task"]`` is not ``"animate"``.
        """
        super().__init__(config)
        assert self.config["task"] == "animate", "WanAnimateRunner only supports task == 'animate'"
    def inputs_padding(self, array, target_len):
        """Extend ``array`` to ``target_len`` items by walking it back and forth.

        The items are visited in ping-pong order (0, 1, ..., n-1, n-2, ..., 1,
        0, 1, ...) and every visited item is deep-copied into the result.

        Args:
            array: Indexable sequence with a length (list or array of frames).
            target_len: Number of items the returned list must contain.

        Returns:
            A new list with exactly ``target_len`` deep-copied items, or an
            empty list when ``target_len`` is not positive.

        Raises:
            ValueError: If ``array`` is empty while ``target_len`` is positive.
        """
        if target_len <= 0:
            return []

        n = len(array)
        if n == 0:
            raise ValueError("inputs_padding: cannot pad an empty sequence")

        # One full forward+backward sweep covers 2 * (n - 1) positions. A
        # single-item input has no sweep at all, so it is simply repeated.
        period = max(2 * (n - 1), 1)

        padded = []
        for step in range(target_len):
            pos = step % period
            idx = pos if pos < n else period - pos
            padded.append(deepcopy(array[idx]))
        return padded

    def get_valid_len(self, real_len, clip_len=81, overlap=1):
        """Round ``real_len`` up so it splits into whole overlapping clips.

        Each clip after the first contributes ``clip_len - overlap`` new
        frames; the result is the smallest length >= ``real_len`` for which
        ``length - overlap`` is a multiple of that stride.

        Args:
            real_len: Number of frames actually available.
            clip_len: Frames per clip.
            overlap: Frames shared between consecutive clips.

        Returns:
            The padded frame count.
        """
        stride = clip_len - overlap
        # Frames missing from the last, partially filled clip (0 if it is full).
        shortfall = -(real_len - overlap) % stride
        return real_len + shortfall

    def get_i2v_mask(self, lat_t, lat_h, lat_w, mask_len=1, mask_pixel_values=None, device="cuda"):
        """Build a frame mask regrouped into blocks of four frames.

        Args:
            lat_t: Latent frame count; only used to size the all-zero mask
                created when ``mask_pixel_values`` is None.
            lat_h: Mask height used for the final reshape.
            lat_w: Mask width used for the final reshape.
            mask_len: Number of leading frames forced to 1.
            mask_pixel_values: Optional starting mask, indexed as
                (1, frames, lat_h, lat_w); it is cloned, not modified.
            device: Device for the freshly created mask.

        Returns:
            Tensor of shape (4, frames_after_expansion // 4, lat_h, lat_w).
        """
        if mask_pixel_values is not None:
            mask = mask_pixel_values.clone()
        else:
            num_frames = (lat_t - 1) * 4 + 1
            mask = torch.zeros(1, num_frames, lat_h, lat_w, dtype=GET_DTYPE(), device=device)

        mask[:, :mask_len] = 1

        # The first frame is repeated four times so it fills a whole group of 4.
        first_frame = mask[:, 0:1].repeat_interleave(4, dim=1)
        mask = torch.cat((first_frame, mask[:, 1:]), dim=1)

        grouped = mask.view(1, mask.shape[1] // 4, 4, lat_h, lat_w)
        return grouped.transpose(1, 2)[0]

    def padding_resize(
        self,
        img_ori,
        height=512,
        width=512,
        padding_color=(0, 0, 0),
        interpolation=cv2.INTER_LINEAR,
    ):
        """Resize an image to fit ``height`` x ``width`` and pad the remainder.

        The aspect ratio is kept: the image is scaled to fill one target
        dimension, centred along the other, and the uncovered area is filled
        with ``padding_color``.

        Args:
            img_ori: Image array indexed as (height, width, channels).
            height: Output height.
            width: Output width.
            padding_color: Fill values; only the first entry is used for
                single-channel input, the first three otherwise.
            interpolation: OpenCV interpolation flag passed to ``cv2.resize``.

        Returns:
            ``uint8`` array of shape (height, width, channels).
        """
        src_h, src_w, n_ch = img_ori.shape[0], img_ori.shape[1], img_ori.shape[2]

        canvas = np.zeros((height, width, n_ch))
        for c in range(1 if n_ch == 1 else 3):
            canvas[:, :, c] = padding_color[c]

        if (src_h / src_w) > (height / width):
            # Source is relatively taller: fill the height, centre horizontally.
            scaled_w = int(height / src_h * src_w)
            resized = cv2.resize(img_ori, (scaled_w, height), interpolation=interpolation)
            left = int((width - scaled_w) / 2)
            if resized.ndim == 2:
                # cv2.resize can drop a trailing singleton channel axis.
                resized = resized[:, :, np.newaxis]
            canvas[:, left : left + scaled_w, :] = resized
        else:
            # Source is relatively wider (or equal): fill the width, centre vertically.
            scaled_h = int(width / src_w * src_h)
            resized = cv2.resize(img_ori, (width, scaled_h), interpolation=interpolation)
            top = int((height - scaled_h) / 2)
            if resized.ndim == 2:
                resized = resized[:, :, np.newaxis]
            canvas[top : top + scaled_h, :, :] = resized

        return np.uint8(canvas)

    def prepare_source(self, src_pose_path, src_face_path, src_ref_path):
        """Load the pose video, the face video and the reference image.

        Args:
            src_pose_path: Path of the pose (conditioning) video.
            src_face_path: Path of the face video.
            src_ref_path: Path of the reference image.

        Returns:
            Tuple ``(cond_images, face_images, refer_images)``: all frames of
            the pose video, all frames of the face video, and the reference
            image with its channel order reversed and padded/resized to the
            pose frame size.

        Raises:
            ImportError: If decord is not installed.
            ValueError: If the reference image cannot be read.
        """
        if VideoReader is None:
            # The optional import at module level failed; fail with a clear
            # message instead of "'NoneType' object is not callable".
            raise ImportError("decord is required to read the pose/face videos; please install decord.")

        pose_video_reader = VideoReader(src_pose_path)
        cond_images = pose_video_reader.get_batch(list(range(len(pose_video_reader)))).asnumpy()

        face_video_reader = VideoReader(src_face_path)
        face_images = face_video_reader.get_batch(list(range(len(face_video_reader)))).asnumpy()

        height, width = cond_images[0].shape[:2]

        # cv2.imread signals failure by returning None rather than raising.
        refer_bgr = cv2.imread(src_ref_path)
        if refer_bgr is None:
            raise ValueError(f"Failed to read reference image: {src_ref_path!r}")

        # Reverse the channel axis (OpenCV loads BGR) before resizing.
        refer_images = self.padding_resize(refer_bgr[..., ::-1], height=height, width=width)
        return cond_images, face_images, refer_images

    def prepare_source_for_replace(self, src_bg_path, src_mask_path):
        """Load the background and mask videos used when replacing a subject.

        Args:
            src_bg_path: Path of the background video.
            src_mask_path: Path of the mask video.

        Returns:
            Tuple ``(bg_images, mask_images)``: every background frame, and the
            first channel of every mask frame divided by 255.
        """
        bg_reader = VideoReader(src_bg_path)
        bg_images = bg_reader.get_batch(list(range(len(bg_reader)))).asnumpy()

        mask_reader = VideoReader(src_mask_path)
        mask_frames = mask_reader.get_batch(list(range(len(mask_reader)))).asnumpy()
        # Only the first channel is kept; dividing by 255 rescales 8-bit values.
        mask_images = mask_frames[:, :, :, 0] / 255

        return bg_images, mask_images

    @ProfilingContext4DebugL2("Run Image Encoders")
    def run_image_encoders(
        self,
        conditioning_pixel_values,
        refer_t_pixel_values,
        bg_pixel_values,
        mask_pixel_values,
        face_pixel_values,
    ):
        """Run the image encoder and the VAE encoder and bundle their outputs.

        The image encoder is applied to ``self.refer_pixel_values``; the other
        arguments, except ``face_pixel_values``, are forwarded to
        ``run_vae_encoder``. ``face_pixel_values`` is passed through unchanged.

        Returns:
            ``{"image_encoder_output": {...}}`` holding ``clip_encoder_out``,
            ``vae_encoder_out``, ``pose_latents`` and ``face_pixel_values``.
        """
        clip_features = self.run_image_encoder(self.refer_pixel_values)
        vae_features, pose_latents = self.run_vae_encoder(
            conditioning_pixel_values,
            refer_t_pixel_values,
            bg_pixel_values,
            mask_pixel_values,
        )
        encoder_outputs = {
            "clip_encoder_out": clip_features,
            "vae_encoder_out": vae_features,
            "pose_latents": pose_latents,
            "face_pixel_values": face_pixel_values,
        }
        return {"image_encoder_output": encoder_outputs}

    def run_vae_encoder(
        self,
        conditioning_pixel_values,
        refer_t_pixel_values,
        bg_pixel_values,
        mask_pixel_values,
    ):
        """Encode the conditioning inputs of one segment with the VAE.

        Reads ``self.refer_pixel_values``, ``self.latent_t/h/w`` and
        ``self.mask_reft_len`` (set by ``prepare_input`` / ``init_run_segment``).

        Args:
            conditioning_pixel_values: Pose frames of the segment.
            refer_t_pixel_values: Frames carried over from the previous
                segment; only read when ``self.mask_reft_len > 0``.
            bg_pixel_values: Background frames; only read in replace mode.
            mask_pixel_values: Per-frame mask; only read in replace mode, where
                it is inverted and resized to (H // 8, W // 8).

        Returns:
            Tuple ``(y, pose_latents)`` where ``y`` concatenates the
            reference-image block and the temporal block along dim 1, each
            block being its mask stacked on top of its latents.
        """
        H, W = self.refer_pixel_values.shape[-2], self.refer_pixel_values.shape[-1]
        # A missing "replace_flag" means "no replacement", as in prepare_input
        # and init_run_segment; direct indexing would raise KeyError here.
        replace_flag = self.config.get("replace_flag", False)

        pose_latents = self.vae_encoder.encode(conditioning_pixel_values.unsqueeze(0))  #  c t h w
        ref_latents = self.vae_encoder.encode(self.refer_pixel_values.unsqueeze(1).unsqueeze(0))  #  c t h w

        mask_ref = self.get_i2v_mask(1, self.latent_h, self.latent_w, 1)
        y_ref = torch.concat([mask_ref, ref_latents])

        if self.mask_reft_len > 0:
            if replace_flag:
                # Carried-over frames first, background frames for the rest.
                # NOTE(review): the indexing assumes a particular layout of
                # refer_t_pixel_values -- confirm against init_run_segment.
                y_reft = self.vae_encoder.encode(
                    torch.concat(
                        [
                            refer_t_pixel_values.unsqueeze(2)[0, :, : self.mask_reft_len],
                            bg_pixel_values[:, self.mask_reft_len :],
                        ],
                        dim=1,
                    )
                    .cuda()
                    .unsqueeze(0)
                )
                mask_pixel_values = 1 - mask_pixel_values
                mask_pixel_values = mask_pixel_values.permute(1, 0, 2, 3)
                mask_pixel_values = F.interpolate(mask_pixel_values, size=(H // 8, W // 8), mode="nearest")
                mask_pixel_values = mask_pixel_values[:, 0, :, :]

                msk_reft = self.get_i2v_mask(
                    self.latent_t,
                    self.latent_h,
                    self.latent_w,
                    self.mask_reft_len,
                    mask_pixel_values=mask_pixel_values.unsqueeze(0),
                )
            else:
                # Carried-over frames first, zeros for the rest of the clip.
                y_reft = self.vae_encoder.encode(
                    torch.concat(
                        [
                            torch.nn.functional.interpolate(
                                refer_t_pixel_values.unsqueeze(2)[0, :, : self.mask_reft_len].cpu(),
                                size=(H, W),
                                mode="bicubic",
                            ),
                            torch.zeros(3, self.config["target_video_length"] - self.mask_reft_len, H, W, dtype=GET_DTYPE()),
                        ],
                        dim=1,
                    )
                    .cuda()
                    .unsqueeze(0)
                )
                msk_reft = self.get_i2v_mask(self.latent_t, self.latent_h, self.latent_w, self.mask_reft_len)
        else:
            if replace_flag:
                mask_pixel_values = 1 - mask_pixel_values
                mask_pixel_values = mask_pixel_values.permute(1, 0, 2, 3)
                mask_pixel_values = F.interpolate(mask_pixel_values, size=(H // 8, W // 8), mode="nearest")
                mask_pixel_values = mask_pixel_values[:, 0, :, :]
                y_reft = self.vae_encoder.encode(bg_pixel_values.unsqueeze(0))
                msk_reft = self.get_i2v_mask(
                    self.latent_t,
                    self.latent_h,
                    self.latent_w,
                    self.mask_reft_len,
                    mask_pixel_values=mask_pixel_values.unsqueeze(0),
                )
            else:
                # No carried-over frames and no background: encode all zeros.
                y_reft = self.vae_encoder.encode(torch.zeros(1, 3, self.config["target_video_length"] - self.mask_reft_len, H, W, dtype=GET_DTYPE(), device="cuda"))
                msk_reft = self.get_i2v_mask(self.latent_t, self.latent_h, self.latent_w, self.mask_reft_len)

        y_reft = torch.concat([msk_reft, y_reft])
        y = torch.concat([y_ref, y_reft], dim=1)

        return y, pose_latents

    def prepare_input(self):
        """Load all source media and derive the latent geometry for the run.

        Sets ``cond_images``, ``face_images``, ``refer_images``,
        ``refer_pixel_values``, ``latent_t/h/w``, ``real_frame_len`` and
        ``input_info.latent_shape``; in replace mode it also sets
        ``bg_images`` and ``mask_images``. Every frame sequence is padded to
        the length returned by ``get_valid_len``.
        """
        src_pose_path = self.config.get("src_pose_path", None)
        src_face_path = self.config.get("src_face_path", None)
        src_ref_path = self.config.get("src_ref_images", None)
        self.cond_images, self.face_images, self.refer_images = self.prepare_source(src_pose_path, src_face_path, src_ref_path)
        # Map pixel values from [0, 255] to [-1, 1] and move channels first.
        self.refer_pixel_values = torch.tensor(self.refer_images / 127.5 - 1, dtype=GET_DTYPE(), device="cuda").permute(2, 0, 1)  # chw

        vae_stride = self.config["vae_stride"]
        self.latent_t = self.config["target_video_length"] // vae_stride[0] + 1
        self.latent_h = self.refer_pixel_values.shape[-2] // vae_stride[1]
        self.latent_w = self.refer_pixel_values.shape[-1] // vae_stride[2]
        # One extra temporal slot on top of latent_t (run_vae_decoder drops the
        # first slot again; presumably it holds the reference image -- confirm).
        self.input_info.latent_shape = [self.config.get("num_channels_latents", 16), self.latent_t + 1, self.latent_h, self.latent_w]

        self.real_frame_len = len(self.cond_images)
        target_len = self.get_valid_len(
            self.real_frame_len,
            self.config["target_video_length"],
            overlap=self.config.get("refert_num", 1),
        )
        logger.info("real frames: {} target frames: {}".format(self.real_frame_len, target_len))
        self.cond_images = self.inputs_padding(self.cond_images, target_len)
        self.face_images = self.inputs_padding(self.face_images, target_len)

        if self.config.get("replace_flag", False):
            src_bg_path = self.config["src_bg_path"]
            src_mask_path = self.config["src_mask_path"]
            self.bg_images, self.mask_images = self.prepare_source_for_replace(src_bg_path, src_mask_path)
            self.bg_images = self.inputs_padding(self.bg_images, target_len)
            self.mask_images = self.inputs_padding(self.mask_images, target_len)

    def get_video_segment_num(self):
        total_frames = len(self.cond_images)
261
262
        self.move_frames = self.config["target_video_length"] - self.config["refert_num"]
        if total_frames <= self.config["target_video_length"]:
263
264
            self.video_segment_num = 1
        else:
265
            self.video_segment_num = 1 + (total_frames - self.config["target_video_length"] + self.move_frames - 1) // self.move_frames
    def init_run(self):
        """Reset the per-run frame buffer, load inputs, then run the parent setup."""
        # Decoded segments are collected here by end_run_segment.
        self.all_out_frames = list()
        self.prepare_input()
        super().init_run()

    @ProfilingContext4DebugL1("Run VAE Decoder")
    def run_vae_decoder(self, latents):
        """Decode latents to images, skipping the first slot along dim 1.

        When ``lazy_load`` or ``unload_modules`` is enabled in the config, the
        decoder is loaded just before decoding and released right after.

        Args:
            latents: Latent tensor; ``latents[:, 1:]`` is what gets decoded
                (the dropped slot is presumably the reference-image latent --
                confirm against the model).

        Returns:
            The output of ``self.vae_decoder.decode``.
        """
        offload = self.config.get("lazy_load", False) or self.config.get("unload_modules", False)
        if offload:
            self.vae_decoder = self.load_vae_decoder()
        images = self.vae_decoder.decode(latents[:, 1:].to(GET_DTYPE()))
        if offload:
            # Drop the decoder and return cached GPU memory to the allocator.
            del self.vae_decoder
            torch.cuda.empty_cache()
            gc.collect()
        return images

    def init_run_segment(self, segment_idx):
        """Prepare encoder inputs for one segment and store them in ``self.inputs``.

        Slices the padded frame lists to the segment window, converts them to
        tensors on the GPU, runs the image/VAE encoders and, for every segment
        after the first, resets the scheduler.

        Args:
            segment_idx: Zero-based index of the segment to prepare.
        """
        # A missing key falls back to the same defaults prepare_input uses.
        refert_num = self.config.get("refert_num", 1)
        replace_flag = self.config.get("replace_flag", False)

        start = segment_idx * self.move_frames
        end = start + self.config["target_video_length"]
        is_first_segment = start == 0

        # The first segment has no previously generated frames to condition on.
        self.mask_reft_len = 0 if is_first_segment else refert_num

        conditioning_pixel_values = torch.tensor(
            np.stack(self.cond_images[start:end]) / 127.5 - 1,
            device="cuda",
            dtype=GET_DTYPE(),
        ).permute(3, 0, 1, 2)  # c t h w

        face_pixel_values = torch.tensor(
            np.stack(self.face_images[start:end]) / 127.5 - 1,
            device="cuda",
            dtype=GET_DTYPE(),
        ).permute(0, 3, 1, 2)  # thwc->tchw

        if is_first_segment:
            height, width = self.refer_images.shape[:2]
            refer_t_pixel_values = torch.zeros(
                3,
                refert_num,
                height,
                width,
                device="cuda",
                dtype=GET_DTYPE(),
            )  # c t h w
        else:
            # Tail of the previously generated segment.
            # NOTE(review): the resulting axis order depends on the layout of
            # self.gen_video -- confirm before relying on it.
            refer_t_pixel_values = self.gen_video[0, :, -refert_num:].transpose(0, 1).clone().detach().cuda()

        bg_pixel_values, mask_pixel_values = None, None
        if replace_flag:
            bg_pixel_values = torch.tensor(
                np.stack(self.bg_images[start:end]) / 127.5 - 1,
                device="cuda",
                dtype=GET_DTYPE(),
            ).permute(3, 0, 1, 2)  # c t h w

            mask_pixel_values = torch.tensor(
                np.stack(self.mask_images[start:end])[:, :, :, None],
                device="cuda",
                dtype=GET_DTYPE(),
            ).permute(3, 0, 1, 2)  # c t h w

        self.inputs.update(
            self.run_image_encoders(
                conditioning_pixel_values,
                refer_t_pixel_values,
                bg_pixel_values,
                mask_pixel_values,
                face_pixel_values,
            )
        )

        if not is_first_segment:
            self.model.scheduler.reset(self.input_info.seed, self.input_info.latent_shape)
    def end_run_segment(self, segment_idx):
        if segment_idx != 0:
            self.gen_video = self.gen_video[:, :, self.config["refert_num"] :]
        self.all_out_frames.append(self.gen_video.cpu())

    def process_images_after_vae_decoder(self):
        """Join all stored segments, trim the padding, then defer to the parent.

        The segments collected in ``self.all_out_frames`` are concatenated
        along dim 2 and cut back to ``self.real_frame_len`` frames; the result
        is stored in ``self.gen_video_final``.
        """
        self.gen_video_final = torch.cat(self.all_out_frames, dim=2)[:, :, : self.real_frame_len]
        # The per-segment copies are no longer needed once concatenated.
        del self.all_out_frames
        gc.collect()
        super().process_images_after_vae_decoder()
    def run_image_encoder(self, img):  # CHW
        """Encode one image with the image encoder.

        When ``lazy_load`` or ``unload_modules`` is enabled in the config, the
        encoder is loaded just before use and released right after.

        Args:
            img: Image tensor; a leading batch axis is added before encoding.

        Returns:
            The encoder output with its leading axis squeezed, cast to
            ``GET_DTYPE()``.
        """
        offload = self.config.get("lazy_load", False) or self.config.get("unload_modules", False)
        if offload:
            self.image_encoder = self.load_image_encoder()
        clip_encoder_out = self.image_encoder.visual([img.unsqueeze(0)]).squeeze(0).to(GET_DTYPE())
        if offload:
            # Drop the encoder and return cached GPU memory to the allocator.
            del self.image_encoder
            torch.cuda.empty_cache()
            gc.collect()
        return clip_encoder_out

    def load_transformer(self):
        """Create the animate model and attach its motion and face encoders.

        Returns:
            The ``WanAnimateModel`` built from ``config["model_path"]``, the
            config and ``self.init_device``, after ``set_animate_encoders``
            has been called with the encoders from ``load_encoders``.
        """
        model = WanAnimateModel(
            self.config["model_path"],
            self.config,
            self.init_device,
        )
        motion_encoder, face_encoder = self.load_encoders()
        model.set_animate_encoders(motion_encoder, face_encoder)
        return model

    def load_encoders(self):
        """Build the motion and face encoders and load their weights.

        Both modules are put in eval mode with gradients disabled, cast to
        ``GET_DTYPE()`` and moved to the GPU. Weights are read from
        ``config["model_path"]``, selected by the ``motion_encoder`` /
        ``face_encoder`` key filters, with that prefix removed from each key.

        Returns:
            Tuple ``(motion_encoder, face_encoder)``.
        """
        motion_encoder = Generator(size=512, style_dim=512, motion_dim=20).eval().requires_grad_(False).to(GET_DTYPE()).cuda()
        face_encoder = FaceEncoder(in_dim=512, hidden_dim=5120, num_heads=4).eval().requires_grad_(False).to(GET_DTYPE()).cuda()

        # Checkpoint keys carry a module prefix that the bare modules do not expect.
        motion_weight_dict = remove_substrings_from_keys(load_weights(self.config["model_path"], include_keys=["motion_encoder"]), "motion_encoder.")
        face_weight_dict = remove_substrings_from_keys(load_weights(self.config["model_path"], include_keys=["face_encoder"]), "face_encoder.")

        motion_encoder.load_state_dict(motion_weight_dict)
        face_encoder.load_state_dict(face_weight_dict)
        return motion_encoder, face_encoder