import gc

import torch
import torch.nn.functional as F
import torchvision.transforms.functional as TF
from PIL import Image

from lightx2v.models.input_encoders.hf.vace.vace_processor import VaceVideoProcessor
from lightx2v.models.networks.wan.vace_model import WanVaceModel
from lightx2v.models.runners.wan.wan_runner import WanRunner
from lightx2v.utils.envs import *
from lightx2v.utils.profiler import ProfilingContext
from lightx2v.utils.registry_factory import RUNNER_REGISTER


@RUNNER_REGISTER("wan2.1_vace")
class WanVaceRunner(WanRunner):
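    """Runner for the Wan2.1 VACE task.

    Prepares VACE source inputs (videos, masks, reference images), encodes
    them into context latents for the VACE transformer, and decodes the
    generated latents back into frames.
    """
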
    def __init__(self, config):
        super().__init__(config)
        assert self.config.task == "vace"
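        # Overall downsampling per axis = VAE stride x transformer patch size.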
        self.vid_proc = VaceVideoProcessor(
            downsample=tuple([x * y for x, y in zip(self.config.vae_stride, self.config.patch_size)]),
            min_area=720 * 1280,
            max_area=720 * 1280,
            min_fps=self.config.get("fps", 16),
            max_fps=self.config.get("fps", 16),
            zero_start=True,
            seq_len=75600,
            keep_last=True,
        )

    def load_transformer(self):
        model = WanVaceModel(
            self.config.model_path,
            self.config,
            self.init_device,
        )
        return model

    def prepare_source(self, src_video, src_mask, src_ref_images, image_size, device=torch.device("cuda")):
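        """Load source videos, masks, and reference images onto `device` at the target resolution.

        Note: `image_size` is swapped once below and treated as (height, width) afterwards.
        """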
        area = image_size[0] * image_size[1]
        self.vid_proc.set_area(area)
        if area == 720 * 1280:
            self.vid_proc.set_seq_len(75600)
        elif area == 480 * 832:
            self.vid_proc.set_seq_len(32760)
        else:
            raise NotImplementedError(f"image_size {image_size} is not supported")

        image_size = (image_size[1], image_size[0])
        image_sizes = []
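        # Per clip: use video + mask when both exist, a blank canvas with a full
        # mask when the video is missing, or the video with a full mask otherwise.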
        for i, (sub_src_video, sub_src_mask) in enumerate(zip(src_video, src_mask)):
            if sub_src_mask is not None and sub_src_video is not None:
                src_video[i], src_mask[i], _, _, _ = self.vid_proc.load_video_pair(sub_src_video, sub_src_mask)
                src_video[i] = src_video[i].to(device)
                src_mask[i] = src_mask[i].to(device)
                src_mask[i] = torch.clamp((src_mask[i][:1, :, :, :] + 1) / 2, min=0, max=1)
                image_sizes.append(src_video[i].shape[2:])
            elif sub_src_video is None:
                src_video[i] = torch.zeros((3, self.config.target_video_length, image_size[0], image_size[1]), device=device)
                src_mask[i] = torch.ones_like(src_video[i], device=device)
                image_sizes.append(image_size)
            else:
                src_video[i], _, _, _ = self.vid_proc.load_video(sub_src_video)
                src_video[i] = src_video[i].to(device)
                src_mask[i] = torch.ones_like(src_video[i], device=device)
                image_sizes.append(src_video[i].shape[2:])

        for i, ref_images in enumerate(src_ref_images):
            if ref_images is not None:
                image_size = image_sizes[i]
                for j, ref_img in enumerate(ref_images):
                    if ref_img is not None:
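                        # Scale to [-1, 1]; if the size differs from the clip,
                        # letterbox onto a white canvas at the clip resolution.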
                        ref_img = Image.open(ref_img).convert("RGB")
                        ref_img = TF.to_tensor(ref_img).sub_(0.5).div_(0.5).unsqueeze(1)
                        if ref_img.shape[-2:] != image_size:
                            canvas_height, canvas_width = image_size
                            ref_height, ref_width = ref_img.shape[-2:]
                            white_canvas = torch.ones((3, 1, canvas_height, canvas_width), device=device)  # [-1, 1]
                            scale = min(canvas_height / ref_height, canvas_width / ref_width)
                            new_height = int(ref_height * scale)
                            new_width = int(ref_width * scale)
                            resized_image = F.interpolate(ref_img.squeeze(1).unsqueeze(0), size=(new_height, new_width), mode="bilinear", align_corners=False).squeeze(0).unsqueeze(1)
                            top = (canvas_height - new_height) // 2
                            left = (canvas_width - new_width) // 2
                            white_canvas[:, :, top : top + new_height, left : left + new_width] = resized_image
                            ref_img = white_canvas
                        src_ref_images[i][j] = ref_img.to(device)
        return src_video, src_mask, src_ref_images

    def run_vae_encoder(self, frames, ref_images, masks):
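        """Encode frames, masks, and reference images into the VACE context latents."""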
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            self.vae_encoder = self.load_vae_encoder()
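
        # Align ref_images with frames so they can be zipped one-to-one below.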
        if ref_images is None:
            ref_images = [None] * len(frames)
        else:
            assert len(frames) == len(ref_images)

        if masks is None:
            latents = [self.vae_encoder.encode(frame.unsqueeze(0).to(GET_DTYPE())) for frame in frames]
        else:
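            # Split each frame into masked-out (inactive) and masked-in (reactive)
            # parts, encode both, and stack them along the channel dimension.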
            masks = [torch.where(m > 0.5, 1.0, 0.0) for m in masks]
            inactive = [i * (1 - m) + 0 * m for i, m in zip(frames, masks)]
            reactive = [i * m + 0 * (1 - m) for i, m in zip(frames, masks)]
            inactive = [self.vae_encoder.encode(inact.unsqueeze(0).to(GET_DTYPE())) for inact in inactive]
            reactive = [self.vae_encoder.encode(react.unsqueeze(0).to(GET_DTYPE())) for react in reactive]
            latents = [torch.cat((u, c), dim=0) for u, c in zip(inactive, reactive)]

        cat_latents = []
        for latent, refs in zip(latents, ref_images):
            if refs is not None:
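                # Encode the reference images and prepend their latents along
                # the temporal axis of the clip latent.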
                if masks is None:
                    ref_latent = [self.vae_encoder.encode(ref.unsqueeze(0).to(GET_DTYPE())) for ref in refs]
                else:
                    ref_latent = [self.vae_encoder.encode(ref.unsqueeze(0).to(GET_DTYPE())) for ref in refs]
                    ref_latent = [torch.cat((u, torch.zeros_like(u)), dim=0) for u in ref_latent]
                assert all([x.shape[1] == 1 for x in ref_latent])
                latent = torch.cat([*ref_latent, latent], dim=1)
            cat_latents.append(latent)
        self.latent_shape = list(cat_latents[0].shape)
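
        # Release the VAE encoder if lazy loading / module unloading is enabled.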
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            del self.vae_encoder
            torch.cuda.empty_cache()
            gc.collect()
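
        # Attach the downsampled masks to the latents to form the final context.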
        return self.get_vae_encoder_output(cat_latents, masks, ref_images)

    def get_vae_encoder_output(self, cat_latents, masks, ref_images):
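        """Fold masks down to the latent grid and concatenate them onto the latents channel-wise."""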
        if ref_images is None:
            ref_images = [None] * len(masks)
        else:
            assert len(masks) == len(ref_images)

        result_masks = []
        for mask, refs in zip(masks, ref_images):
            c, depth, height, width = mask.shape
            new_depth = int((depth + 3) // self.config.vae_stride[0])
            height = 2 * (int(height) // (self.config.vae_stride[1] * 2))
            width = 2 * (int(width) // (self.config.vae_stride[2] * 2))

            # reshape
            mask = mask[0, :, :, :]
            mask = mask.view(depth, height, self.config.vae_stride[1], width, self.config.vae_stride[1])  # depth, height, 8, width, 8
            mask = mask.permute(2, 4, 0, 1, 3)  # 8, 8, depth, height, width
            mask = mask.reshape(self.config.vae_stride[1] * self.config.vae_stride[2], depth, height, width)  # 8*8, depth, height, width

            # interpolation
            mask = F.interpolate(mask.unsqueeze(0), size=(new_depth, height, width), mode="nearest-exact").squeeze(0)

            if refs is not None:
                length = len(refs)
                mask_pad = torch.zeros_like(mask[:, :length, :, :])
                mask = torch.cat((mask_pad, mask), dim=1)
            result_masks.append(mask)

        return [torch.cat([zz, mm], dim=0) for zz, mm in zip(cat_latents, result_masks)]

    def set_target_shape(self):
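        # The VACE context stacks inactive/reactive latents channel-wise, so the
        # latent being denoised has half that many channels.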
        target_shape = self.latent_shape
        target_shape[0] = int(target_shape[0] / 2)
        self.config.target_shape = target_shape

    @ProfilingContext("Run VAE Decoder")
    def run_vae_decoder(self, latents):
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            self.vae_decoder = self.load_vae_decoder()

        if self.src_ref_images is not None:
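            # Reference latents were prepended along the temporal axis during
            # encoding; drop them so only generated frames get decoded.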
            assert len(self.src_ref_images) == 1
            refs = self.src_ref_images[0]
            if refs is not None:
                latents = latents[:, len(refs) :, :, :]

        images = self.vae_decoder.decode(latents.to(GET_DTYPE()))

        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            del self.vae_decoder
            torch.cuda.empty_cache()
            gc.collect()

        return images