import glob
import os

import torch

from lightx2v.common.ops.attn.radial_attn import MaskMap
from lightx2v.models.networks.wan.infer.audio.post_wan_audio_infer import WanAudioPostInfer
from lightx2v.models.networks.wan.infer.audio.pre_wan_audio_infer import WanAudioPreInfer
from lightx2v.models.networks.wan.model import WanModel
from lightx2v.models.networks.wan.weights.post_weights import WanPostWeights
from lightx2v.models.networks.wan.weights.pre_weights import WanPreWeights
from lightx2v.models.networks.wan.weights.transformer_weights import (
    WanTransformerWeights,
)


class WanAudioModel(WanModel):
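    """Audio-conditioned variant of WanModel.

    Reuses the standard Wan weight classes but swaps in the audio-aware
    pre/post inference stages (WanAudioPreInfer / WanAudioPostInfer).
    """
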
    pre_weight_class = WanPreWeights
    post_weight_class = WanPostWeights
    transformer_weight_class = WanTransformerWeights

    def __init__(self, model_path, config, device):
        super().__init__(model_path, config, device)

    def _init_infer_class(self):
        super()._init_infer_class()
        self.pre_infer_class = WanAudioPreInfer
        self.post_infer_class = WanAudioPostInfer

    @torch.no_grad()
    def infer(self, inputs):
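        """Run one denoising step: a conditional pass, then an optional
        unconditional pass combined via classifier-free guidance."""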
        if self.config["cpu_offload"]:
            self.pre_weight.to_cuda()
            self.post_weight.to_cuda()

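        # Lazily build the radial-attention mask map from the latent shape on
        # the first step; each 2x2 latent patch contributes one token.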
        if self.transformer_infer.mask_map is None:
            _, c, h, w = self.scheduler.latents.shape
            num_frame = c + 1  # for r2v
            video_token_num = num_frame * (h // 2) * (w // 2)
            self.transformer_infer.mask_map = MaskMap(video_token_num, num_frame)

        embed, grid_sizes, pre_infer_out, valid_patch_length = self.pre_infer.infer(self.pre_weight, inputs, positive=True)
        x = self.transformer_infer.infer(self.transformer_weights, grid_sizes, embed, *pre_infer_out)
        noise_pred_cond = self.post_infer.infer(self.post_weight, x, embed, grid_sizes, valid_patch_length)[0]

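        # TeaCache bookkeeping: advance the step counter and wrap it at the
        # end of the schedule.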
        if self.config["feature_caching"] == "Tea":
            self.scheduler.cnt += 1
            if self.scheduler.cnt >= self.scheduler.num_steps:
                self.scheduler.cnt = 0
        self.scheduler.noise_pred = noise_pred_cond

        if self.config["enable_cfg"]:
            embed, grid_sizes, pre_infer_out, valid_patch_length = self.pre_infer.infer(self.pre_weight, inputs, positive=False)
            x = self.transformer_infer.infer(self.transformer_weights, grid_sizes, embed, *pre_infer_out)
            noise_pred_uncond = self.post_infer.infer(self.post_weight, x, embed, grid_sizes, valid_patch_length)[0]

            if self.config["feature_caching"] == "Tea":
                self.scheduler.cnt += 1
                if self.scheduler.cnt >= self.scheduler.num_steps:
                    self.scheduler.cnt = 0

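            # Classifier-free guidance: extrapolate from the unconditional
            # prediction toward the conditional one by sample_guide_scale.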
            self.scheduler.noise_pred = noise_pred_uncond + self.scheduler.sample_guide_scale * (noise_pred_cond - noise_pred_uncond)

        # Mirror the to_cuda() calls at the top of infer(): offload back to
        # CPU regardless of whether the CFG branch ran.
        if self.config["cpu_offload"]:
            self.pre_weight.to_cpu()
            self.post_weight.to_cpu()

    @torch.no_grad()
    def infer_wo_cfg_parallel(self, inputs):
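        """infer() without CFG parallelism: the conditional and optional
        unconditional passes run sequentially, with model-granularity offload
        handled at the first and last scheduler steps."""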
        if self.cpu_offload:
            if self.offload_granularity == "model" and self.scheduler.step_index == 0:
                self.to_cuda()
            elif self.offload_granularity != "model":
                self.pre_weight.to_cuda()
                self.post_weight.to_cuda()

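        # Same lazy radial-attention mask-map initialization as in infer().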
        if self.transformer_infer.mask_map is None:
            _, c, h, w = self.scheduler.latents.shape
            num_frame = c + 1  # for r2v
            video_token_num = num_frame * (h // 2) * (w // 2)
            self.transformer_infer.mask_map = MaskMap(video_token_num, num_frame)

        embed, grid_sizes, pre_infer_out, valid_patch_length = self.pre_infer.infer(self.pre_weight, inputs, positive=True)
        x = self.transformer_infer.infer(self.transformer_weights, grid_sizes, embed, *pre_infer_out)
        noise_pred_cond = self.post_infer.infer(self.post_weight, x, embed, grid_sizes, valid_patch_length)[0]

        self.scheduler.noise_pred = noise_pred_cond

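        # Optionally free the conditional pass's intermediates before the
        # unconditional pass.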
        if self.clean_cuda_cache:
            del x, embed, pre_infer_out, noise_pred_cond, grid_sizes
            torch.cuda.empty_cache()

        if self.config["enable_cfg"]:
            # pre_infer.infer() returns four values (see the positive pass
            # above); unpack valid_patch_length here too and pass it through
            # to post_infer.
            embed, grid_sizes, pre_infer_out, valid_patch_length = self.pre_infer.infer(self.pre_weight, inputs, positive=False)
            x = self.transformer_infer.infer(self.transformer_weights, grid_sizes, embed, *pre_infer_out)
            noise_pred_uncond = self.post_infer.infer(self.post_weight, x, embed, grid_sizes, valid_patch_length)[0]

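            # self.scheduler.noise_pred still holds the conditional
            # prediction; the local noise_pred_cond may have been freed above.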
            self.scheduler.noise_pred = noise_pred_uncond + self.scheduler.sample_guide_scale * (self.scheduler.noise_pred - noise_pred_uncond)

            if self.clean_cuda_cache:
                del x, embed, pre_infer_out, noise_pred_uncond, grid_sizes
                torch.cuda.empty_cache()

        if self.cpu_offload:
            if self.offload_granularity == "model" and self.scheduler.step_index == self.scheduler.infer_steps - 1:
                self.to_cpu()
            elif self.offload_granularity != "model":
                self.pre_weight.to_cpu()
                self.post_weight.to_cpu()


class Wan22MoeAudioModel(WanAudioModel):
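    """Audio model for Wan 2.2 MoE checkpoints, whose weights are sharded
    across multiple safetensors files."""
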
    def _load_ckpt(self, unified_dtype, sensitive_layer):
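        """Load every *.safetensors shard under model_path and merge them
        into a single weight dict."""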
        safetensors_files = glob.glob(os.path.join(self.model_path, "*.safetensors"))
        weight_dict = {}
        for file_path in safetensors_files:
            file_weights = self._load_safetensor_to_dict(file_path, unified_dtype, sensitive_layer)
            weight_dict.update(file_weights)
        return weight_dict