import os

import torch.distributed as dist

from lightx2v.models.networks.wan.infer.audio.post_infer import WanAudioPostInfer
from lightx2v.models.networks.wan.infer.audio.pre_infer import WanAudioPreInfer
from lightx2v.models.networks.wan.infer.audio.transformer_infer import WanAudioTransformerInfer
from lightx2v.models.networks.wan.model import WanModel
from lightx2v.models.networks.wan.weights.audio.transformer_weights import WanAudioTransformerWeights
from lightx2v.models.networks.wan.weights.post_weights import WanPostWeights
from lightx2v.models.networks.wan.weights.pre_weights import WanPreWeights
from lightx2v.utils.utils import load_weights


class WanAudioModel(WanModel):
    """Wan model variant wired for audio-conditioned inference.

    Swaps in the audio transformer weights plus the audio pre/post/transformer
    infer classes, and loads a separate audio-adapter checkpoint on top of the
    base Wan model.
    """

    pre_weight_class = WanPreWeights
    post_weight_class = WanPostWeights
    transformer_weight_class = WanAudioTransformerWeights

    def __init__(self, model_path, config, device):
        # Resolve and load the adapter checkpoint before the base-class setup,
        # so config.adapter_model_path is populated when super().__init__ runs.
        self.config = config
        self._load_adapter_ckpt()
        super().__init__(model_path, config, device)

    def _adapter_file_name(self):
        """Pick the adapter checkpoint file name for the configured quantization."""
        if not self.config.get("adapter_quantized", False):
            return "audio_adapter_model.safetensors"
        scheme = self.config.get("adapter_quant_scheme", None)
        if scheme in ["fp8", "fp8-q8f"]:
            return "audio_adapter_model_fp8.safetensors"
        if scheme == "int8":
            return "audio_adapter_model_int8.safetensors"
        raise ValueError(f"Unsupported quant_scheme: {scheme}")

    def _load_adapter_ckpt(self):
        """Load the audio-adapter weights, moving them to GPU when appropriate."""
        # Derive a default checkpoint path inside model_path unless the config
        # already points at an explicit adapter file.
        if self.config.get("adapter_model_path", None) is None:
            self.config.adapter_model_path = os.path.join(self.config.model_path, self._adapter_file_name())

        offload_to_cpu = self.config.get("cpu_offload", False)
        weights = load_weights(self.config.adapter_model_path, cpu_offload=offload_to_cpu, remove_key="audio")
        if not offload_to_cpu and not dist.is_initialized():
            # Single-process, non-offloaded run: keep every adapter tensor on GPU.
            weights = {name: tensor.cuda() for name, tensor in weights.items()}
        self.adapter_weights_dict = weights

    def _init_infer_class(self):
        """Install the audio-specific infer classes on top of the base ones."""
        super()._init_infer_class()
        self.pre_infer_class = WanAudioPreInfer
        self.post_infer_class = WanAudioPostInfer
        self.transformer_infer_class = WanAudioTransformerInfer