# wan_vace_runner.py
import gc

import torch
import torch.nn.functional as F
import torchvision.transforms.functional as TF
from PIL import Image

from lightx2v.models.input_encoders.hf.vace.vace_processor import VaceVideoProcessor
from lightx2v.models.networks.wan.vace_model import WanVaceModel
from lightx2v.models.runners.wan.wan_runner import WanRunner
from lightx2v.utils.envs import *
from lightx2v.utils.profiler import *
from lightx2v.utils.registry_factory import RUNNER_REGISTER


@RUNNER_REGISTER("wan2.1_vace")
class WanVaceRunner(WanRunner):
    """Runner registered as ``wan2.1_vace``.

    On top of ``WanRunner`` it prepares source videos, masks and reference
    images, encodes them with the VAE encoder into the conditioning tensors
    built by ``get_vae_encoder_output``, and removes the reference frames
    from the latents again before decoding.
    """

    # Sequence-length preset passed to the video processor, keyed by pixel area.
    _SEQ_LEN_BY_AREA = {720 * 1280: 75600, 480 * 832: 32760}

    def __init__(self, config):
        """Initialise the base runner and build the VACE video processor.

        Args:
            config: Runner configuration. Must have ``task == "vace"`` and
                provide ``vae_stride`` and ``patch_size``; ``fps`` is optional
                and falls back to 16.
        """
        super().__init__(config)
        assert self.config["task"] == "vace"

        fps = self.config["fps"] if "fps" in self.config else 16
        self.vid_proc = VaceVideoProcessor(
            # Element-wise product of the VAE stride and the patch size.
            downsample=tuple(stride * patch for stride, patch in zip(self.config["vae_stride"], self.config["patch_size"])),
            min_area=720 * 1280,
            max_area=720 * 1280,
            min_fps=fps,
            max_fps=fps,
            zero_start=True,
            seq_len=75600,
            keep_last=True,
        )

    def load_transformer(self):
        """Build and return the ``WanVaceModel`` for the configured model path."""
        return WanVaceModel(
            self.config["model_path"],
            self.config,
            self.init_device,
        )

    def _vae_offload_enabled(self):
        """Return True when ``lazy_load`` or ``unload_modules`` is set in the config.

        When enabled, the VAE encoder/decoder is loaded right before use and
        released right after.
        """
        lazy_load = self.config["lazy_load"] if "lazy_load" in self.config else False
        unload_modules = self.config["unload_modules"] if "unload_modules" in self.config else False
        return bool(lazy_load or unload_modules)

    def _release_module(self, attr_name):
        """Delete ``self.<attr_name>``, then empty the CUDA cache and run the GC."""
        delattr(self, attr_name)
        torch.cuda.empty_cache()
        gc.collect()

    def _encode_to_latent(self, clip):
        """Encode one tensor with the VAE encoder after adding a leading dim and casting to ``GET_DTYPE()``."""
        return self.vae_encoder.encode(clip.unsqueeze(0).to(GET_DTYPE()))

    @staticmethod
    def _paste_on_white_canvas(ref_img, canvas_hw, device):
        """Resize ``ref_img`` to fit ``canvas_hw`` (aspect ratio kept) and centre it on an all-ones canvas.

        Args:
            ref_img: Tensor whose last two dims are (height, width) and whose
                dim 1 has size 1 (it is squeezed before interpolation).
            canvas_hw: Target ``(height, width)``.
            device: Device on which the canvas is allocated.

        Returns:
            Tensor of shape ``(3, 1, height, width)`` on ``device``.
        """
        canvas_h, canvas_w = canvas_hw
        ref_h, ref_w = ref_img.shape[-2:]
        scale = min(canvas_h / ref_h, canvas_w / ref_w)
        new_h = int(ref_h * scale)
        new_w = int(ref_w * scale)

        resized = F.interpolate(ref_img.squeeze(1).unsqueeze(0), size=(new_h, new_w), mode="bilinear", align_corners=False).squeeze(0).unsqueeze(1)

        # 1.0 is the top of the [-1, 1] value range used for the images.
        canvas = torch.ones((3, 1, canvas_h, canvas_w), device=device)
        top = (canvas_h - new_h) // 2
        left = (canvas_w - new_w) // 2
        canvas[:, :, top : top + new_h, left : left + new_w] = resized
        return canvas

    def prepare_source(self, src_video, src_mask, src_ref_images, image_size, device=torch.device("cuda")):
        """Load source videos, masks and reference images onto ``device``.

        The three input lists are modified in place and also returned.

        Args:
            src_video: List of video sources (or ``None`` entries). A ``None``
                entry is replaced by an all-zero video of
                ``config["target_video_length"]`` frames.
            src_mask: List of mask sources (or ``None`` entries), parallel to
                ``src_video``. A missing mask is replaced by an all-ones mask.
            src_ref_images: List (parallel to ``src_video``) whose entries are
                ``None`` or lists of image sources accepted by
                ``PIL.Image.open``.
            image_size: Two-element size. Its product selects the sequence
                length preset; only 720*1280 and 480*832 are supported. The
                elements are swapped before being used as the last two dims of
                placeholder tensors, so it is presumably (width, height) --
                confirm against the caller.
            device: Target device for every produced tensor.

        Returns:
            ``(src_video, src_mask, src_ref_images)`` with entries replaced by
            tensors.

        Raises:
            NotImplementedError: If the area of ``image_size`` has no preset.
        """
        area = image_size[0] * image_size[1]
        self.vid_proc.set_area(area)
        seq_len = self._SEQ_LEN_BY_AREA.get(area)
        if seq_len is None:
            raise NotImplementedError(f"image_size {image_size} is not supported")
        self.vid_proc.set_seq_len(seq_len)

        default_hw = (image_size[1], image_size[0])
        frame_sizes = []  # per-sample spatial size, used to fit the reference images
        for i, (video, mask) in enumerate(zip(src_video, src_mask)):
            if video is None:
                # No source video: all-zero frames with an all-ones mask.
                src_video[i] = torch.zeros((3, self.config["target_video_length"], default_hw[0], default_hw[1]), device=device)
                src_mask[i] = torch.ones_like(src_video[i], device=device)
                frame_sizes.append(default_hw)
            elif mask is None:
                # Video without a mask: all-ones mask.
                src_video[i], _, _, _ = self.vid_proc.load_video(video)
                src_video[i] = src_video[i].to(device)
                src_mask[i] = torch.ones_like(src_video[i], device=device)
                frame_sizes.append(src_video[i].shape[2:])
            else:
                src_video[i], src_mask[i], _, _, _ = self.vid_proc.load_video_pair(video, mask)
                src_video[i] = src_video[i].to(device)
                src_mask[i] = src_mask[i].to(device)
                # Keep the first channel only and map it from [-1, 1] to [0, 1].
                src_mask[i] = torch.clamp((src_mask[i][:1, :, :, :] + 1) / 2, min=0, max=1)
                frame_sizes.append(src_video[i].shape[2:])

        for i, ref_images in enumerate(src_ref_images):
            if ref_images is None:
                continue
            target_hw = frame_sizes[i]
            for j, ref_source in enumerate(ref_images):
                if ref_source is None:
                    continue
                # convert() returns an independent image, so the file can be
                # closed as soon as the conversion is done.
                with Image.open(ref_source) as opened:
                    rgb = opened.convert("RGB")
                # [0, 1] -> [-1, 1], then insert a singleton dim at position 1.
                ref_img = TF.to_tensor(rgb).sub_(0.5).div_(0.5).unsqueeze(1)
                if ref_img.shape[-2:] != target_hw:
                    ref_img = self._paste_on_white_canvas(ref_img, target_hw, device)
                src_ref_images[i][j] = ref_img.to(device)
        return src_video, src_mask, src_ref_images

    def run_vae_encoder(self, frames, ref_images, masks):
        """Encode frames (and optional reference images) into VACE conditioning.

        Args:
            frames: List of video tensors.
            ref_images: ``None`` or a list parallel to ``frames`` whose entries
                are ``None`` or lists of reference-image tensors.
            masks: List of mask tensors parallel to ``frames``. With masks,
                each frame tensor is split into a part multiplied by
                ``1 - mask`` and a part multiplied by ``mask``; both are
                encoded and concatenated on dim 0.
                NOTE(review): the ``masks is None`` branches are kept from the
                original, but ``get_vae_encoder_output`` cannot handle
                ``None`` masks, so that path fails at the end -- confirm the
                intended behaviour.

        Returns:
            Tuple ``(encoder_output, latent_shape)`` from
            ``get_vae_encoder_output`` and ``set_input_info_latent_shape``.
        """
        offload = self._vae_offload_enabled()
        if offload:
            self.vae_encoder = self.load_vae_encoder()

        try:
            if ref_images is None:
                ref_images = [None] * len(frames)
            else:
                assert len(frames) == len(ref_images)

            if masks is None:
                latents = [self._encode_to_latent(frame) for frame in frames]
            else:
                # Binarise the masks at 0.5.
                masks = [torch.where(m > 0.5, 1.0, 0.0) for m in masks]
                # The "+ 0 * ..." terms are kept exactly as in the original expressions.
                inactive = [i * (1 - m) + 0 * m for i, m in zip(frames, masks)]
                reactive = [i * m + 0 * (1 - m) for i, m in zip(frames, masks)]
                inactive = [self._encode_to_latent(inact) for inact in inactive]
                reactive = [self._encode_to_latent(react) for react in reactive]
                latents = [torch.cat((u, c), dim=0) for u, c in zip(inactive, reactive)]

            cat_latents = []
            for latent, refs in zip(latents, ref_images):
                if refs is not None:
                    ref_latent = [self._encode_to_latent(ref) for ref in refs]
                    if masks is not None:
                        # Match the doubled dim 0 of the inactive/reactive latents.
                        ref_latent = [torch.cat((u, torch.zeros_like(u)), dim=0) for u in ref_latent]
                    assert all(x.shape[1] == 1 for x in ref_latent)
                    # Reference latents are prepended along dim 1.
                    latent = torch.cat([*ref_latent, latent], dim=1)
                cat_latents.append(latent)
            self.latent_shape = list(cat_latents[0].shape)
        finally:
            # Release the encoder even when encoding fails part-way.
            if offload:
                self._release_module("vae_encoder")

        return self.get_vae_encoder_output(cat_latents, masks, ref_images), self.set_input_info_latent_shape()

    def get_vae_encoder_output(self, cat_latents, masks, ref_images):
        """Fold the masks to latent resolution and append them to the latents.

        Each mask's first channel is split into ``stride_h * stride_w``
        channels (one per position inside a spatial stride cell), resized to
        the latent frame count with nearest-exact interpolation, padded with
        zero frames for the reference images, and concatenated to the matching
        latent on dim 0.

        Args:
            cat_latents: List of latent tensors from ``run_vae_encoder``.
            masks: List of 4-D mask tensors, parallel to ``cat_latents``. The
                last two dims must be multiples of twice the corresponding
                spatial stride for the ``view`` below to succeed.
            ref_images: ``None`` or a list parallel to ``masks`` whose entries
                are ``None`` or lists of reference images.

        Returns:
            List with one ``torch.cat([latent, mask], dim=0)`` tensor per sample.
        """
        if ref_images is None:
            ref_images = [None] * len(masks)
        else:
            assert len(masks) == len(ref_images)

        vae_stride = self.config["vae_stride"]
        stride_t, stride_h, stride_w = vae_stride[0], vae_stride[1], vae_stride[2]

        result_masks = []
        for mask, refs in zip(masks, ref_images):
            _, depth, height, width = mask.shape
            new_depth = int((depth + 3) // stride_t)
            height = 2 * (int(height) // (stride_h * 2))
            width = 2 * (int(width) // (stride_w * 2))

            # reshape: only the first mask channel is used
            mask = mask[0, :, :, :]
            # The width axis is split by the width stride (the original reused
            # the height stride here, which is only right when both are equal).
            mask = mask.view(depth, height, stride_h, width, stride_w)  # depth, height, sh, width, sw
            mask = mask.permute(2, 4, 0, 1, 3)  # sh, sw, depth, height, width
            mask = mask.reshape(stride_h * stride_w, depth, height, width)  # sh*sw, depth, height, width

            # interpolation
            mask = F.interpolate(mask.unsqueeze(0), size=(new_depth, height, width), mode="nearest-exact").squeeze(0)

            if refs is not None:
                # One zero frame per reference image, allocated at exactly that
                # length instead of being sliced from the mask, which would be
                # too short when there are more references than latent frames.
                mask_pad = mask.new_zeros((mask.shape[0], len(refs), mask.shape[2], mask.shape[3]))
                mask = torch.cat((mask_pad, mask), dim=1)
            result_masks.append(mask)

        return [torch.cat([zz, mm], dim=0) for zz, mm in zip(cat_latents, result_masks)]

    def set_input_info_latent_shape(self):
        """Halve the first entry of ``self.latent_shape`` and return the list.

        ``run_vae_encoder`` concatenates the inactive and reactive latents on
        dim 0 when masks are given, so the recorded dim 0 is twice that of one
        latent. The list stored on ``self`` is updated in place, so every call
        halves it again.
        """
        latent_shape = self.latent_shape
        latent_shape[0] = latent_shape[0] // 2
        return latent_shape

    @ProfilingContext4DebugL1("Run VAE Decoder")
    def run_vae_decoder(self, latents):
        """Decode latents to images, dropping any reference frames first.

        Args:
            latents: Latent tensor; when ``self.src_ref_images[0]`` is set,
                that many leading entries on dim 1 are removed before decoding.

        Returns:
            The output of ``self.vae_decoder.decode``.
        """
        offload = self._vae_offload_enabled()
        if offload:
            self.vae_decoder = self.load_vae_decoder()

        try:
            if self.src_ref_images is not None:
                assert len(self.src_ref_images) == 1
                refs = self.src_ref_images[0]
                if refs is not None:
                    # Remove the reference frames prepended in run_vae_encoder.
                    latents = latents[:, len(refs) :, :, :]

            images = self.vae_decoder.decode(latents.to(GET_DTYPE()))
        finally:
            # Release the decoder even when decoding fails.
            if offload:
                self._release_module("vae_decoder")

        return images