import glob
import os

import torch

from lightx2v.common.ops.attn.radial_attn import MaskMap
from lightx2v.models.networks.wan.infer.audio.post_wan_audio_infer import WanAudioPostInfer
from lightx2v.models.networks.wan.infer.audio.pre_wan_audio_infer import WanAudioPreInfer
from lightx2v.models.networks.wan.model import WanModel
from lightx2v.models.networks.wan.weights.post_weights import WanPostWeights
from lightx2v.models.networks.wan.weights.pre_weights import WanPreWeights
from lightx2v.models.networks.wan.weights.transformer_weights import (
    WanTransformerWeights,
)

from loguru import logger


class WanAudioModel(WanModel):
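    """Audio-conditioned variant of WanModel: reuses the Wan weight classes but
    swaps in the audio-aware pre/post inference stages."""
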
    pre_weight_class = WanPreWeights
    post_weight_class = WanPostWeights
    transformer_weight_class = WanTransformerWeights

    def __init__(self, model_path, config, device):
        super().__init__(model_path, config, device)

    def _init_infer_class(self):
        super()._init_infer_class()
        self.pre_infer_class = WanAudioPreInfer
        self.post_infer_class = WanAudioPostInfer

    @torch.no_grad()
    def infer(self, inputs):
        if self.config["cpu_offload"]:
            self.pre_weight.to_cuda()
            self.post_weight.to_cuda()

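        # Build the radial-attention mask map once, on the first call: the second
        # latent dim is treated as the frame count (+1 for the r2v reference
        # frame), and each frame contributes (h // 2) * (w // 2) tokens,
        # presumably reflecting a 2x2 spatial patchify.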
        if self.transformer_infer.mask_map is None:
            _, c, h, w = self.scheduler.latents.shape
            num_frame = c + 1  # for r2v
            video_token_num = num_frame * (h // 2) * (w // 2)
            self.transformer_infer.mask_map = MaskMap(video_token_num, num_frame)

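        # Conditional pass: pre-infer builds the embeddings and grid sizes from
        # the inputs, the transformer stack denoises, and post-infer projects back
        # to a noise prediction; valid_patch_length is threaded to the post step.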
        embed, grid_sizes, pre_infer_out, valid_patch_length = self.pre_infer.infer(self.pre_weight, inputs, positive=True)
        x = self.transformer_infer.infer(self.transformer_weights, grid_sizes, embed, *pre_infer_out)
        noise_pred_cond = self.post_infer.infer(self.post_weight, x, embed, grid_sizes, valid_patch_length)[0]

        # TeaCache-style feature caching: advance the per-step counter and wrap it
        # around once all denoising steps have run.
        if self.config["feature_caching"] == "Tea":
            self.scheduler.cnt += 1
            if self.scheduler.cnt >= self.scheduler.num_steps:
                self.scheduler.cnt = 0
        self.scheduler.noise_pred = noise_pred_cond

        if self.config["enable_cfg"]:
            embed, grid_sizes, pre_infer_out, valid_patch_length = self.pre_infer.infer(self.pre_weight, inputs, positive=False)
            x = self.transformer_infer.infer(self.transformer_weights, grid_sizes, embed, *pre_infer_out)
            noise_pred_uncond = self.post_infer.infer(self.post_weight, x, embed, grid_sizes, valid_patch_length)[0]

            if self.config["feature_caching"] == "Tea":
                self.scheduler.cnt += 1
                if self.scheduler.cnt >= self.scheduler.num_steps:
                    self.scheduler.cnt = 0

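            # Classifier-free guidance: start from the unconditional prediction and
            # push it toward the conditional one, scaled by sample_guide_scale.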
            self.scheduler.noise_pred = noise_pred_uncond + self.scheduler.sample_guide_scale * (noise_pred_cond - noise_pred_uncond)

        # Mirror the to_cuda() calls at the top of the method, whether or not CFG ran.
        if self.config["cpu_offload"]:
            self.pre_weight.to_cpu()
            self.post_weight.to_cpu()

    @torch.no_grad()
    def infer_wo_cfg_parallel(self, inputs):
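        # "model" offload granularity moves the whole network on/off the GPU once
        # per sampling run (at the first and last step); any other granularity
        # shuttles just the pre/post weights around every step.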
        if self.cpu_offload:
            if self.offload_granularity == "model" and self.scheduler.step_index == 0:
                self.to_cuda()
            elif self.offload_granularity != "model":
                self.pre_weight.to_cuda()
                self.post_weight.to_cuda()

        if self.transformer_infer.mask_map is None:
            _, c, h, w = self.scheduler.latents.shape
            num_frame = c + 1  # for r2v
            video_token_num = num_frame * (h // 2) * (w // 2)
            self.transformer_infer.mask_map = MaskMap(video_token_num, num_frame)

        embed, grid_sizes, pre_infer_out, valid_patch_length = self.pre_infer.infer(self.pre_weight, inputs, positive=True)
        x = self.transformer_infer.infer(self.transformer_weights, grid_sizes, embed, *pre_infer_out)
        noise_pred_cond = self.post_infer.infer(self.post_weight, x, embed, grid_sizes, valid_patch_length)[0]

        self.scheduler.noise_pred = noise_pred_cond

        if self.clean_cuda_cache:
            del x, embed, pre_infer_out, noise_pred_cond, grid_sizes
            torch.cuda.empty_cache()

        if self.config["enable_cfg"]:
            # pre_infer returns four values (see the conditional pass above); unpack
            # valid_patch_length and forward it so the unconditional pass matches.
            embed, grid_sizes, pre_infer_out, valid_patch_length = self.pre_infer.infer(self.pre_weight, inputs, positive=False)
            x = self.transformer_infer.infer(self.transformer_weights, grid_sizes, embed, *pre_infer_out)
            noise_pred_uncond = self.post_infer.infer(self.post_weight, x, embed, grid_sizes, valid_patch_length)[0]

            self.scheduler.noise_pred = noise_pred_uncond + self.scheduler.sample_guide_scale * (self.scheduler.noise_pred - noise_pred_uncond)

            if self.clean_cuda_cache:
                del x, embed, pre_infer_out, noise_pred_uncond, grid_sizes
                torch.cuda.empty_cache()

        if self.cpu_offload:
            if self.offload_granularity == "model" and self.scheduler.step_index == self.scheduler.infer_steps - 1:
                self.to_cpu()
            elif self.offload_granularity != "model":
                self.pre_weight.to_cpu()
                self.post_weight.to_cpu()


class Wan22MoeAudioModel(WanAudioModel):
    def _load_ckpt(self, unified_dtype, sensitive_layer):
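        # Wan2.2-MoE checkpoints may ship as multiple *.safetensors shards; load
        # each shard and merge them into one flat weight dict.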
        safetensors_files = glob.glob(os.path.join(self.model_path, "*.safetensors"))
        weight_dict = {}
        for file_path in safetensors_files:
            file_weights = self._load_safetensor_to_dict(file_path, unified_dtype, sensitive_layer)
            weight_dict.update(file_weights)
        return weight_dict
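

# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal driver, assuming a dict-like config with the keys read above
# ("cpu_offload", "enable_cfg", "feature_caching") and a scheduler attached by
# the surrounding pipeline (providing latents, cnt, num_steps,
# sample_guide_scale); `scheduler`, `inputs`, and the checkpoint path are
# placeholders:
#
#   config = {"cpu_offload": False, "enable_cfg": True, "feature_caching": "Tea"}
#   model = Wan22MoeAudioModel("/path/to/checkpoints", config, torch.device("cuda"))
#   model.scheduler = scheduler
#   model.infer(inputs)  # writes the guided prediction to scheduler.noise_pred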