audio_model.py 7.92 KB
Newer Older
1
2
import os

import torch
import torch.distributed as dist
import torch.nn.functional as F
from loguru import logger

from lightx2v.models.networks.wan.infer.audio.post_infer import WanAudioPostInfer
from lightx2v.models.networks.wan.infer.audio.pre_infer import WanAudioPreInfer
from lightx2v.models.networks.wan.infer.audio.transformer_infer import WanAudioTransformerInfer
from lightx2v.models.networks.wan.model import WanModel
from lightx2v.models.networks.wan.weights.audio.transformer_weights import WanAudioTransformerWeights
from lightx2v.models.networks.wan.weights.post_weights import WanPostWeights
from lightx2v.models.networks.wan.weights.pre_weights import WanPreWeights
from lightx2v.utils.utils import load_weights


class WanAudioModel(WanModel):
    """Wan transformer variant that adds an audio adapter checkpoint.

    It reuses the standard Wan pre/post weight classes and swaps in the
    audio-specific transformer weights and infer classes. It also provides
    helpers to pre-compile ``_infer_cond_uncond`` graphs for each
    (resolution, audio count) pair and to shard pre-infer outputs for
    sequence parallelism.
    """

    pre_weight_class = WanPreWeights
    post_weight_class = WanPostWeights
    transformer_weight_class = WanAudioTransformerWeights

    # (height, width) pairs accepted by ``check_compile_shapes``.
    SUPPORTED_COMPILE_SHAPES = (
        (480, 832),
        (544, 960),
        (720, 1280),
        (832, 480),
        (960, 544),
        (1280, 720),
        (480, 480),
        (576, 576),
        (704, 704),
        (960, 960),
    )

    def __init__(self, model_path, config, device):
        # ``_load_adapter_ckpt`` reads ``self.config``, so config must be set
        # before it runs. The adapter is loaded before ``WanModel.__init__``,
        # presumably because the parent's weight loading consumes
        # ``self.adapter_weights_dict`` — TODO confirm.
        self.config = config
        self._load_adapter_ckpt()
        super().__init__(model_path, config, device)

    def _load_adapter_ckpt(self):
        """Load the audio-adapter weights into ``self.adapter_weights_dict``.

        If ``config.adapter_model_path`` is unset, a default file under
        ``config.model_path`` is chosen from ``adapter_quantized`` and
        ``adapter_quant_scheme``. That path is written back to
        ``config.adapter_model_path``.

        Raises:
            ValueError: ``adapter_quantized`` is set but ``adapter_quant_scheme``
                is not one of ``"fp8"``, ``"fp8-q8f"`` or ``"int8"``.
        """
        if self.config.get("adapter_model_path", None) is None:
            if self.config.get("adapter_quantized", False):
                quant_scheme = self.config.get("adapter_quant_scheme", None)
                if quant_scheme in ("fp8", "fp8-q8f"):
                    adapter_model_name = "audio_adapter_model_fp8.safetensors"
                elif quant_scheme == "int8":
                    adapter_model_name = "audio_adapter_model_int8.safetensors"
                else:
                    raise ValueError(f"Unsupported quant_scheme: {quant_scheme}")
            else:
                adapter_model_name = "audio_adapter_model.safetensors"
            self.config.adapter_model_path = os.path.join(self.config.model_path, adapter_model_name)

        adapter_offload = self.config.get("cpu_offload", False)
        load_from_rank0 = self.config.get("load_from_rank0", False)
        self.adapter_weights_dict = load_weights(
            self.config.adapter_model_path,
            cpu_offload=adapter_offload,
            remove_key="audio",
            load_from_rank0=load_from_rank0,
        )

        # Without CPU offload the adapter weights are moved to the GPU here.
        # The one exception is a distributed run loading from rank 0; there,
        # placement is presumably handled by load_weights — TODO confirm.
        if not adapter_offload and (not dist.is_initialized() or not load_from_rank0):
            # Reassigning existing keys does not resize the dict, so this is
            # safe during iteration.
            for key, tensor in self.adapter_weights_dict.items():
                self.adapter_weights_dict[key] = tensor.cuda()

    def _init_infer_class(self):
        """Use the audio-aware pre/post/transformer infer classes."""
        super()._init_infer_class()
        self.pre_infer_class = WanAudioPreInfer
        self.post_infer_class = WanAudioPostInfer
        self.transformer_infer_class = WanAudioTransformerInfer

    def get_graph_name(self, shape, audio_num):
        """Return the compiled-graph key for a ``(height, width)`` shape and audio count."""
        return f"graph_{shape[0]}x{shape[1]}_{audio_num}audio"

    def start_compile(self, shape, audio_num):
        """Run one dummy ``_infer_cond_uncond`` pass under ``get_graph_name(shape, audio_num)``.

        Args:
            shape: ``[height, width]`` of the target video, in pixels.
            audio_num: number of audio streams the graph is compiled for.

        Inputs are random placeholders. Presumably only their shapes and dtypes
        matter for compilation; the hard-coded sizes (512x4096 text context,
        257x1280 CLIP tokens, 128x1024 audio features) are TODO confirm against
        the real encoders.

        Side effects: overwrites ``scheduler.latents``,
        ``scheduler.timestep_input`` and ``scheduler.audio_adapter_t_emb``.
        """
        graph_name = self.get_graph_name(shape, audio_num)
        logger.info(f"[Compile] Compile shape: {shape}, audio_num:{audio_num}, graph_name: {graph_name}")

        target_video_length = self.config.get("target_video_length", 81)
        # The default of 81 frames maps to 21 latent frames.
        latents_length = (target_video_length - 1) // 16 * 4 + 1
        latents_h = shape[0] // self.config.vae_stride[1]
        latents_w = shape[1] // self.config.vae_stride[2]

        def rand_bf16(*size):
            # Allocate directly on the GPU instead of building on the CPU and copying.
            return torch.randn(*size, dtype=torch.bfloat16, device="cuda")

        new_inputs = {
            "text_encoder_output": {
                "context": rand_bf16(1, 512, 4096),
                "context_null": rand_bf16(1, 512, 4096),
            },
            "image_encoder_output": {
                "clip_encoder_out": rand_bf16(257, 1280),
                "vae_encoder_out": rand_bf16(16, 1, latents_h, latents_w),
            },
            "audio_encoder_output": rand_bf16(audio_num, latents_length, 128, 1024),
            "person_mask_latens": torch.zeros(audio_num, 1, latents_h // 2, latents_w // 2, dtype=torch.int8, device="cuda"),
            "previmg_encoder_output": {
                "prev_latents": rand_bf16(16, latents_length, latents_h, latents_w),
                "prev_mask": rand_bf16(4, latents_length, latents_h, latents_w),
            },
        }

        self.scheduler.latents = rand_bf16(16, latents_length, latents_h, latents_w)
        self.scheduler.timestep_input = torch.tensor([600.0], dtype=torch.float32, device="cuda")
        self.scheduler.audio_adapter_t_emb = rand_bf16(1, 3, 5120)

        self._infer_cond_uncond(new_inputs, infer_condition=True, graph_name=graph_name)

    def _onload_for_compile(self):
        """Move the weights compilation needs onto the GPU when CPU offload is on."""
        if not self.cpu_offload:
            return
        if self.offload_granularity == "model" and self.scheduler.step_index == 0 and "wan2.2_moe" not in self.config.model_cls:
            self.to_cuda()
        elif self.offload_granularity != "model":
            self.pre_weight.to_cuda()
            self.transformer_weights.non_block_weights_to_cuda()

    def _offload_after_compile(self):
        """Undo ``_onload_for_compile`` under the original offload conditions."""
        if not self.cpu_offload:
            return
        if self.offload_granularity == "model" and self.scheduler.step_index == self.scheduler.infer_steps - 1 and "wan2.2_moe" not in self.config.model_cls:
            self.to_cpu()
        elif self.offload_granularity != "model":
            self.pre_weight.to_cpu()
            self.transformer_weights.non_block_weights_to_cpu()

    def compile(self, compile_shapes):
        """Compile ``_infer_cond_uncond`` for each shape and audio count.

        Graphs are compiled for every shape in *compile_shapes* and every audio
        count from 1 to ``config.compile_max_audios`` (default 1).

        Weight offloading and compile mode are restored in ``finally`` blocks.
        A failing compilation therefore does not leave the model stuck in
        compile mode or with offloaded weights left on the GPU.

        Raises:
            ValueError: a shape is not in ``SUPPORTED_COMPILE_SHAPES``.
        """
        self.check_compile_shapes(compile_shapes)
        self.enable_compile_mode("_infer_cond_uncond")
        try:
            self._onload_for_compile()
            try:
                max_audio_num = self.config.get("compile_max_audios", 1)
                for audio_num in range(1, max_audio_num + 1):
                    for shape in compile_shapes:
                        self.start_compile(shape, audio_num)
            finally:
                self._offload_after_compile()
        finally:
            self.disable_compile_mode("_infer_cond_uncond")

        logger.info(f"[Compile] Compile status: {self.get_compile_status()}")

    def check_compile_shapes(self, compile_shapes):
        """Validate that each ``[height, width]`` in *compile_shapes* is supported.

        Raises an explicit error rather than using ``assert``, which is
        stripped under ``python -O``.

        Raises:
            ValueError: for the first shape not in ``SUPPORTED_COMPILE_SHAPES``.
        """
        for shape in compile_shapes:
            if tuple(shape) not in self.SUPPORTED_COMPILE_SHAPES:
                raise ValueError(f"Unsupported compile shape {shape}; supported shapes: {[list(s) for s in self.SUPPORTED_COMPILE_SHAPES]}")

    def select_graph_for_compile(self):
        """Select the pre-compiled graph matching the configured size and audio count.

        The graph is chosen from ``config.tgt_h`` x ``config.tgt_w`` and
        ``config.audio_num``, using the same naming as ``get_graph_name``.
        """
        tgt_h = self.config.get("tgt_h")
        tgt_w = self.config.get("tgt_w")
        audio_num = self.config.get("audio_num")
        logger.info(f"tgt_h, tgt_w : {tgt_h}, {tgt_w}, audio_num: {audio_num}")
        self.select_graph("_infer_cond_uncond", self.get_graph_name((tgt_h, tgt_w), audio_num))
        logger.info(f"[Compile] Compile status: {self.get_compile_status()}")

    @torch.no_grad()
    def _seq_parallel_pre_process(self, pre_infer_out):
        """Keep only this rank's slice of the pre-infer outputs for sequence parallelism.

        ``x`` is padded along dim 0 to a multiple of the ``seq_p_group`` size
        and chunked on dim 0. ``adapter_args["person_mask_latens"]``, when not
        None, is padded on its last dim and chunked on dim 1. For ``wan2.2`` /
        ``wan2.2_audio`` i2v models, ``embed`` and ``embed0`` are padded and
        chunked on dim 0 the same way.

        Returns:
            The same ``pre_infer_out`` object, mutated in place.
        """
        world_size = dist.get_world_size(self.seq_p_group)
        cur_rank = dist.get_rank(self.seq_p_group)

        x = pre_infer_out.x
        person_mask_latens = pre_infer_out.adapter_args["person_mask_latens"]

        padding_size = (world_size - (x.shape[0] % world_size)) % world_size
        if padding_size > 0:
            # (0, 0, 0, p) pads the second-to-last dim. Chunking happens on dim 0,
            # so x is assumed to be 2-D (tokens, channels) — TODO confirm.
            x = F.pad(x, (0, 0, 0, padding_size))
            if person_mask_latens is not None:
                # Pads the last dim, which is chunked as dim 1 below. This assumes
                # a 2-D (audio_num, tokens) mask whose token count matches x — TODO confirm.
                person_mask_latens = F.pad(person_mask_latens, (0, padding_size))

        pre_infer_out.x = torch.chunk(x, world_size, dim=0)[cur_rank]
        if person_mask_latens is not None:
            pre_infer_out.adapter_args["person_mask_latens"] = torch.chunk(person_mask_latens, world_size, dim=1)[cur_rank]

        if self.config["model_cls"] in ["wan2.2", "wan2.2_audio"] and self.config["task"] == "i2v":
            embed, embed0 = pre_infer_out.embed, pre_infer_out.embed0
            padding_size = (world_size - (embed.shape[0] % world_size)) % world_size
            if padding_size > 0:
                # embed: pads the second-to-last dim. embed0: pads the
                # third-from-last dim. Both are chunked on dim 0, so they are
                # assumed to be 2-D and 3-D respectively — TODO confirm.
                embed = F.pad(embed, (0, 0, 0, padding_size))
                embed0 = F.pad(embed0, (0, 0, 0, 0, 0, padding_size))
            pre_infer_out.embed = torch.chunk(embed, world_size, dim=0)[cur_rank]
            pre_infer_out.embed0 = torch.chunk(embed0, world_size, dim=0)[cur_rank]
        return pre_infer_out