# audio_model.py
import os

import torch
import torch.distributed as dist
from loguru import logger

from lightx2v.models.networks.wan.infer.audio.post_infer import WanAudioPostInfer
from lightx2v.models.networks.wan.infer.audio.pre_infer import WanAudioPreInfer
from lightx2v.models.networks.wan.infer.audio.transformer_infer import WanAudioTransformerInfer
from lightx2v.models.networks.wan.model import WanModel
from lightx2v.models.networks.wan.weights.audio.transformer_weights import WanAudioTransformerWeights
from lightx2v.models.networks.wan.weights.post_weights import WanPostWeights
from lightx2v.models.networks.wan.weights.pre_weights import WanPreWeights
from lightx2v.utils.utils import load_weights


class WanAudioModel(WanModel):
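    """Audio-conditioned variant of WanModel.

    Swaps in audio-aware pre/post/transformer infer and weight classes,
    loads an additional audio-adapter checkpoint next to the base model,
    and provides graph-compilation helpers for a fixed set of resolutions.
    """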
    pre_weight_class = WanPreWeights
    post_weight_class = WanPostWeights
    transformer_weight_class = WanAudioTransformerWeights

    def __init__(self, model_path, config, device):
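        # Resolve and load the audio-adapter checkpoint before initializing
        # the base WanModel.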
        self.config = config
        self._load_adapter_ckpt()
        super().__init__(model_path, config, device)

    def _load_adapter_ckpt(self):
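        # If no explicit adapter_model_path is configured, pick the checkpoint
        # file name based on the requested quantization scheme (fp8/fp8-q8f,
        # int8, or unquantized) and resolve it relative to model_path.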
        if self.config.get("adapter_model_path", None) is None:
            if self.config.get("adapter_quantized", False):
                if self.config.get("adapter_quant_scheme", None) in ["fp8", "fp8-q8f"]:
                    adapter_model_name = "audio_adapter_model_fp8.safetensors"
                elif self.config.get("adapter_quant_scheme", None) == "int8":
                    adapter_model_name = "audio_adapter_model_int8.safetensors"
                else:
                    raise ValueError(f"Unsupported quant_scheme: {self.config.get('adapter_quant_scheme', None)}")
            else:
                adapter_model_name = "audio_adapter_model.safetensors"
            self.config.adapter_model_path = os.path.join(self.config.model_path, adapter_model_name)

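        # Load the adapter state dict, keeping it on CPU when cpu_offload is
        # enabled and respecting load_from_rank0 in distributed runs.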
        adapter_offload = self.config.get("cpu_offload", False)
        load_from_rank0 = self.config.get("load_from_rank0", False)
        self.adapter_weights_dict = load_weights(self.config.adapter_model_path, cpu_offload=adapter_offload, remove_key="audio", load_from_rank0=load_from_rank0)
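        # Without CPU offload, move the adapter weights onto the GPU, except
        # when a distributed run loads them from rank 0.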
        if not adapter_offload:
            if not dist.is_initialized() or not load_from_rank0:
                for key in self.adapter_weights_dict:
                    self.adapter_weights_dict[key] = self.adapter_weights_dict[key].cuda()

    def _init_infer_class(self):
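        # Start from the base WanModel infer classes, then override the
        # pre/post/transformer infer classes with their audio-aware variants.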
        super()._init_infer_class()
        self.pre_infer_class = WanAudioPreInfer
        self.post_infer_class = WanAudioPostInfer
        self.transformer_infer_class = WanAudioTransformerInfer

    def get_graph_name(self, shape):
        return f"graph_{shape[0]}x{shape[1]}"

    def start_compile(self, shape):
        graph_name = self.get_graph_name(shape)
        logger.info(f"[Compile] Compile shape: {shape}, graph_name: {graph_name}")

        target_video_length = self.config.get("target_video_length", 81)
        latents_length = (target_video_length - 1) // 16 * 4 + 1
        latents_h = shape[0] // self.config.vae_stride[1]
        latents_w = shape[1] // self.config.vae_stride[2]

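        # Dummy inputs shaped like the real encoder outputs at this resolution;
        # random values suffice because graph capture depends only on shapes
        # and dtypes.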
        new_inputs = {}
        new_inputs["text_encoder_output"] = {}
        new_inputs["text_encoder_output"]["context"] = torch.randn(1, 512, 4096, dtype=torch.bfloat16).cuda()
        new_inputs["text_encoder_output"]["context_null"] = torch.randn(1, 512, 4096, dtype=torch.bfloat16).cuda()

        new_inputs["image_encoder_output"] = {}
        new_inputs["image_encoder_output"]["clip_encoder_out"] = torch.randn(257, 1280, dtype=torch.bfloat16).cuda()
        new_inputs["image_encoder_output"]["vae_encoder_out"] = torch.randn(16, 1, latents_h, latents_w, dtype=torch.bfloat16).cuda()

        new_inputs["audio_encoder_output"] = torch.randn(1, latents_length, 128, 1024, dtype=torch.bfloat16).cuda()

        new_inputs["previmg_encoder_output"] = {}
        new_inputs["previmg_encoder_output"]["prev_latents"] = torch.randn(16, latents_length, latents_h, latents_w, dtype=torch.bfloat16).cuda()
        new_inputs["previmg_encoder_output"]["prev_mask"] = torch.randn(4, latents_length, latents_h, latents_w, dtype=torch.bfloat16).cuda()

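        # Stub the scheduler state (latents, timestep, audio-adapter time
        # embedding) with tensors of the shapes used during real denoising steps.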
        self.scheduler.latents = torch.randn(16, latents_length, latents_h, latents_w, dtype=torch.bfloat16).cuda()
        self.scheduler.timestep_input = torch.tensor([600.0], dtype=torch.float32).cuda()
        self.scheduler.audio_adapter_t_emb = torch.randn(1, 3, 5120, dtype=torch.bfloat16).cuda()

        self._infer_cond_uncond(new_inputs, infer_condition=True, graph_name=graph_name)

    def compile(self, compile_shapes):
        self.check_compile_shapes(compile_shapes)
        self.enable_compile_mode("_infer_cond_uncond")

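        # With CPU offload enabled, bring the weights needed for graph capture
        # onto the GPU before compiling; they are moved back below once all
        # shapes have been compiled.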
        if self.cpu_offload:
            if self.offload_granularity == "model" and self.scheduler.step_index == 0:
                self.to_cuda()
            elif self.offload_granularity != "model":
                self.pre_weight.to_cuda()
                self.transformer_weights.non_block_weights_to_cuda()

        for shape in compile_shapes:
            self.start_compile(shape)

        if self.cpu_offload:
            if self.offload_granularity == "model" and self.scheduler.step_index == self.scheduler.infer_steps - 1:
                self.to_cpu()
            elif self.offload_granularity != "model":
                self.pre_weight.to_cpu()
                self.transformer_weights.non_block_weights_to_cpu()

        self.disable_compile_mode("_infer_cond_uncond")
        logger.info(f"[Compile] Compile status: {self.get_compile_status()}")

    def check_compile_shapes(self, compile_shapes):
        supported_shapes = [[480, 832], [544, 960], [720, 1280], [832, 480], [960, 544], [1280, 720], [480, 480], [576, 576], [704, 704], [960, 960]]
        for shape in compile_shapes:
            assert shape in supported_shapes, f"Unsupported compile shape: {shape}, expected one of {supported_shapes}"

    def select_graph_for_compile(self):
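        # Pick the pre-captured graph whose name matches the configured target
        # resolution (same naming scheme as get_graph_name).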
        logger.info(f"tgt_h, tgt_w : {self.config.get('tgt_h')}, {self.config.get('tgt_w')}")
        self.select_graph("_infer_cond_uncond", f"graph_{self.config.get('tgt_h')}x{self.config.get('tgt_w')}")
        logger.info(f"[Compile] Compile status: {self.get_compile_status()}")