# vace_model.py — last committed by gushiqiao.
import torch

from lightx2v.models.networks.wan.infer.post_infer import WanPostInfer
from lightx2v.models.networks.wan.infer.pre_infer import WanPreInfer
from lightx2v.models.networks.wan.infer.vace.transformer_infer import WanVaceTransformerInfer
from lightx2v.models.networks.wan.model import WanModel
from lightx2v.models.networks.wan.weights.pre_weights import WanPreWeights
from lightx2v.models.networks.wan.weights.vace.transformer_weights import (
    WanVaceTransformerWeights,
)
from lightx2v.utils.envs import *
from lightx2v.utils.utils import *


class WanVaceModel(WanModel):
    """Wan model variant wired up for VACE conditioning.

    Compared with ``WanModel`` it swaps in the VACE transformer weight and
    infer classes, shares the VACE offload buffers with the transformer infer
    object, and attaches a ``vace_context`` taken from the encoder outputs to
    the pre-infer result on every forward pass.
    """

    # Weight container classes consumed by the WanModel base class
    # (presumably instantiated there from these attributes — confirm in WanModel).
    pre_weight_class = WanPreWeights
    transformer_weight_class = WanVaceTransformerWeights

    def __init__(self, model_path, config, device):
        """Construct the model; all setup is delegated to ``WanModel``.

        Args:
            model_path: Forwarded unchanged to ``WanModel.__init__``.
            config: Forwarded unchanged to ``WanModel.__init__``.
            device: Forwarded unchanged to ``WanModel.__init__``.
        """
        super().__init__(model_path, config, device)

    def _init_infer(self):
        """Build the infer objects, then share offload buffers with the transformer.

        After the base class has created the infer objects, the transformer
        infer object is handed the block- and phase-level offload buffers for
        both the main and the VACE blocks, taken from ``self.transformer_weights``.
        This only happens when the transformer infer object has an
        ``offload_manager`` attribute (presumably only set when weight
        offloading is enabled — confirm in WanVaceTransformerInfer).
        """
        super()._init_infer()
        if hasattr(self.transformer_infer, "offload_manager"):
            # Main transformer blocks.
            self.transformer_infer.offload_block_buffers = self.transformer_weights.offload_block_buffers
            self.transformer_infer.offload_phase_buffers = self.transformer_weights.offload_phase_buffers
            # VACE blocks keep their own, separate buffers.
            self.transformer_infer.vace_offload_block_buffers = self.transformer_weights.vace_offload_block_buffers
            self.transformer_infer.vace_offload_phase_buffers = self.transformer_weights.vace_offload_phase_buffers

    def _init_infer_class(self):
        """Select the infer classes: stock Wan pre/post infer plus the VACE transformer infer."""
        self.pre_infer_class = WanPreInfer
        self.post_infer_class = WanPostInfer
        self.transformer_infer_class = WanVaceTransformerInfer

    @torch.no_grad()
    def _infer_cond_uncond(self, inputs, infer_condition=True):
        """Run pre-infer, the VACE transformer and post-infer once; return the noise prediction.

        Args:
            inputs: Mapping passed to ``pre_infer.infer``. It must contain
                ``inputs["image_encoder_output"]["vae_encoder_out"]``, whose
                first element is used as the VACE context (assumed to be the
                VAE-encoded VACE conditioning — TODO confirm shape).
            infer_condition: Stored on ``self.scheduler.infer_condition``
                before the pass (presumably True for the conditional branch
                and False for the unconditional one).

        Returns:
            The first element of ``post_infer.infer``'s output.
        """
        self.scheduler.infer_condition = infer_condition

        pre_infer_out = self.pre_infer.infer(self.pre_weight, inputs)
        # The VACE transformer infer reads this extra context off the pre-infer output.
        pre_infer_out.vace_context = inputs["image_encoder_output"]["vae_encoder_out"][0]

        x = self.transformer_infer.infer(self.transformer_weights, pre_infer_out)

        noise_pred = self.post_infer.infer(x, pre_infer_out)[0]

        if self.clean_cuda_cache:
            # Drop references to intermediates first so empty_cache can release their memory.
            del x, pre_infer_out
            torch.cuda.empty_cache()

        return noise_pred