import os

import torch
import torch.distributed as dist
import torch.nn.functional as F
from loguru import logger

from lightx2v.models.networks.wan.infer.audio.post_infer import WanAudioPostInfer
from lightx2v.models.networks.wan.infer.audio.pre_infer import WanAudioPreInfer
from lightx2v.models.networks.wan.infer.audio.transformer_infer import WanAudioTransformerInfer
from lightx2v.models.networks.wan.model import WanModel
from lightx2v.models.networks.wan.weights.audio.transformer_weights import WanAudioTransformerWeights
from lightx2v.models.networks.wan.weights.post_weights import WanPostWeights
from lightx2v.models.networks.wan.weights.pre_weights import WanPreWeights
from lightx2v.utils.utils import load_weights


class WanAudioModel(WanModel):
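    """Audio-conditioned Wan video model.

    Extends WanModel with audio-aware pre/post/transformer inference classes
    and a separately loaded audio-adapter checkpoint.
    """
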
    pre_weight_class = WanPreWeights
    post_weight_class = WanPostWeights
    transformer_weight_class = WanAudioTransformerWeights

    def __init__(self, model_path, config, device):
        self.config = config
        self._load_adapter_ckpt()
        super().__init__(model_path, config, device)

    def _load_adapter_ckpt(self):
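        """Resolve the audio-adapter checkpoint path and load its weights.

        If no explicit ``adapter_model_path`` is configured, a default filename
        is chosen based on the (optional) adapter quantization scheme.
        """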
        if self.config.get("adapter_model_path", None) is None:
            if self.config.get("adapter_quantized", False):
                if self.config.get("adapter_quant_scheme", None) in ["fp8", "fp8-q8f"]:
                    adapter_model_name = "audio_adapter_model_fp8.safetensors"
                elif self.config.get("adapter_quant_scheme", None) == "int8":
                    adapter_model_name = "audio_adapter_model_int8.safetensors"
                else:
                    raise ValueError(f"Unsupported quant_scheme: {self.config.get('adapter_quant_scheme', None)}")
            else:
                adapter_model_name = "audio_adapter_model.safetensors"
            self.config.adapter_model_path = os.path.join(self.config.model_path, adapter_model_name)

        adapter_offload = self.config.get("cpu_offload", False)
        load_from_rank0 = self.config.get("load_from_rank0", False)
        self.adapter_weights_dict = load_weights(self.config.adapter_model_path, cpu_offload=adapter_offload, remove_key="audio", load_from_rank0=load_from_rank0)
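
        # Keep the adapter weights on GPU when CPU offload is disabled. When the
        # weights were loaded via rank 0 in a distributed run (load_from_rank0),
        # device placement is assumed to be handled by load_weights.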
        if not adapter_offload:
            if not dist.is_initialized() or not load_from_rank0:
                for key in self.adapter_weights_dict:
                    self.adapter_weights_dict[key] = self.adapter_weights_dict[key].cuda()

    def _init_infer_class(self):
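        """Register the audio-specific pre/post/transformer infer classes on top of the base Wan setup."""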
        super()._init_infer_class()
        self.pre_infer_class = WanAudioPreInfer
        self.post_infer_class = WanAudioPostInfer
        self.transformer_infer_class = WanAudioTransformerInfer

    def get_graph_name(self, shape, audio_num):
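        """Build the compiled-graph name for a spatial shape and audio stream count, e.g. 'graph_480x832_1audio'."""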
        return f"graph_{shape[0]}x{shape[1]}_{audio_num}audio"

    def start_compile(self, shape, audio_num):
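        """Capture the '_infer_cond_uncond' graph for one (shape, audio_num) pair using dummy inputs of matching sizes."""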
        graph_name = self.get_graph_name(shape, audio_num)
        logger.info(f"[Compile] Compile shape: {shape}, audio_num:{audio_num}, graph_name: {graph_name}")

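        # Latent sizes derived from the target video length and spatial shape
        # (temporal chunking of 16 frames -> 4 latents, spatial downsampling by the VAE strides).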
        target_video_length = self.config.get("target_video_length", 81)
        latents_length = (target_video_length - 1) // 16 * 4 + 1
        latents_h = shape[0] // self.config.vae_stride[1]
        latents_w = shape[1] // self.config.vae_stride[2]

        new_inputs = {}
        new_inputs["text_encoder_output"] = {}
        new_inputs["text_encoder_output"]["context"] = torch.randn(1, 512, 4096, dtype=torch.bfloat16).cuda()
        new_inputs["text_encoder_output"]["context_null"] = torch.randn(1, 512, 4096, dtype=torch.bfloat16).cuda()

        new_inputs["image_encoder_output"] = {}
        new_inputs["image_encoder_output"]["clip_encoder_out"] = torch.randn(257, 1280, dtype=torch.bfloat16).cuda()
        new_inputs["image_encoder_output"]["vae_encoder_out"] = torch.randn(16, 1, latents_h, latents_w, dtype=torch.bfloat16).cuda()

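        # Dummy per-stream audio features and person masks; audio_num controls
        # how many audio streams the captured graph supports.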
        new_inputs["audio_encoder_output"] = torch.randn(audio_num, latents_length, 128, 1024, dtype=torch.bfloat16).cuda()
        new_inputs["person_mask_latens"] = torch.zeros(audio_num, 1, (latents_h // 2), (latents_w // 2), dtype=torch.int8).cuda()

        new_inputs["previmg_encoder_output"] = {}
        new_inputs["previmg_encoder_output"]["prev_latents"] = torch.randn(16, latents_length, latents_h, latents_w, dtype=torch.bfloat16).cuda()
        new_inputs["previmg_encoder_output"]["prev_mask"] = torch.randn(4, latents_length, latents_h, latents_w, dtype=torch.bfloat16).cuda()

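        # Dummy scheduler state (latents, timestep, audio-adapter time embedding) for the warm-up pass.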
        self.scheduler.latents = torch.randn(16, latents_length, latents_h, latents_w, dtype=torch.bfloat16).cuda()
        self.scheduler.timestep_input = torch.tensor([600.0], dtype=torch.float32).cuda()
        self.scheduler.audio_adapter_t_emb = torch.randn(1, 3, 5120, dtype=torch.bfloat16).cuda()

        self._infer_cond_uncond(new_inputs, infer_condition=True, graph_name=graph_name)

    def compile(self, compile_shapes):
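        """Pre-compile the inference graph for every requested shape and every audio stream count up to compile_max_audios."""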
        self.check_compile_shapes(compile_shapes)
        self.enable_compile_mode("_infer_cond_uncond")

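        # With CPU offload enabled, temporarily move the weights needed for graph capture onto the GPU.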
        if self.cpu_offload:
            if self.offload_granularity == "model" and self.scheduler.step_index == 0 and "wan2.2_moe" not in self.config.model_cls:
                self.to_cuda()
            elif self.offload_granularity != "model":
                self.pre_weight.to_cuda()
                self.transformer_weights.non_block_weights_to_cuda()

        max_audio_num = self.config.get("compile_max_audios", 1)
        for audio_num in range(1, max_audio_num + 1):
            for shape in compile_shapes:
                self.start_compile(shape, audio_num)

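        # Move weights back to CPU once all graphs have been captured.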
        if self.cpu_offload:
            if self.offload_granularity == "model" and self.scheduler.step_index == self.scheduler.infer_steps - 1 and "wan2.2_moe" not in self.config.model_cls:
                self.to_cpu()
            elif self.offload_granularity != "model":
                self.pre_weight.to_cpu()
                self.transformer_weights.non_block_weights_to_cpu()

        self.disable_compile_mode("_infer_cond_uncond")
        logger.info(f"[Compile] Compile status: {self.get_compile_status()}")

    def check_compile_shapes(self, compile_shapes):
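        """Ensure every requested compile shape is one of the supported (height, width) resolutions."""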
        supported_shapes = [[480, 832], [544, 960], [720, 1280], [832, 480], [960, 544], [1280, 720], [480, 480], [576, 576], [704, 704], [960, 960]]
        for shape in compile_shapes:
            assert shape in supported_shapes, f"Unsupported compile shape: {shape}"

    def select_graph_for_compile(self):
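        """Select the pre-compiled graph that matches the configured tgt_h/tgt_w and audio_num."""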
        logger.info(f"tgt_h, tgt_w : {self.config.get('tgt_h')}, {self.config.get('tgt_w')}, audio_num: {self.config.get('audio_num')}")
        self.select_graph("_infer_cond_uncond", f"graph_{self.config.get('tgt_h')}x{self.config.get('tgt_w')}_{self.config.get('audio_num')}audio")
        logger.info(f"[Compile] Compile status: {self.get_compile_status()}")

    @torch.no_grad()
    def _seq_parallel_pre_process(self, pre_infer_out):
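        """Shard pre-infer outputs across sequence-parallel ranks.

        The token dimension is padded so it divides evenly by the world size;
        each rank then keeps its own chunk of x, the person-mask latents, and
        (for wan2.2 i2v variants) the time embeddings.
        """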
        x = pre_infer_out.x
        person_mask_latens = pre_infer_out.adapter_output["person_mask_latens"]

        world_size = dist.get_world_size(self.seq_p_group)
        cur_rank = dist.get_rank(self.seq_p_group)

        padding_size = (world_size - (x.shape[0] % world_size)) % world_size
        if padding_size > 0:
            x = F.pad(x, (0, 0, 0, padding_size))
            if person_mask_latens is not None:
                person_mask_latens = F.pad(person_mask_latens, (0, padding_size))

        pre_infer_out.x = torch.chunk(x, world_size, dim=0)[cur_rank]
        if person_mask_latens is not None:
            pre_infer_out.adapter_output["person_mask_latens"] = torch.chunk(person_mask_latens, world_size, dim=1)[cur_rank]

        if self.config["model_cls"] in ["wan2.2", "wan2.2_audio"] and self.config["task"] == "i2v":
            embed, embed0 = pre_infer_out.embed, pre_infer_out.embed0
            padding_size = (world_size - (embed.shape[0] % world_size)) % world_size
            if padding_size > 0:
                embed = F.pad(embed, (0, 0, 0, padding_size))
                embed0 = F.pad(embed0, (0, 0, 0, 0, 0, padding_size))
            pre_infer_out.embed = torch.chunk(embed, world_size, dim=0)[cur_rank]
            pre_infer_out.embed0 = torch.chunk(embed0, world_size, dim=0)[cur_rank]
        return pre_infer_out