audio_model.py 1.64 KB
Newer Older
PengGao's avatar
PengGao committed
1
import glob
wangshankun's avatar
wangshankun committed
2
import os
PengGao's avatar
PengGao committed
3

4
5
6
from lightx2v.models.networks.wan.infer.audio.post_infer import WanAudioPostInfer
from lightx2v.models.networks.wan.infer.audio.pre_infer import WanAudioPreInfer
from lightx2v.models.networks.wan.infer.audio.transformer_infer import WanAudioTransformerInfer
wangshankun's avatar
wangshankun committed
7
8
from lightx2v.models.networks.wan.model import WanModel
from lightx2v.models.networks.wan.weights.post_weights import WanPostWeights
PengGao's avatar
PengGao committed
9
from lightx2v.models.networks.wan.weights.pre_weights import WanPreWeights
wangshankun's avatar
wangshankun committed
10
11
12
13
14
15
16
17
18
19
from lightx2v.models.networks.wan.weights.transformer_weights import (
    WanTransformerWeights,
)


class WanAudioModel(WanModel):
    """Wan model whose inference stages are replaced by the audio-conditioned ones.

    Weight loading uses the same pre/post/transformer weight classes as the
    plain Wan model; only the inference classes differ.
    """

    # Weight containers are reused unchanged from the standard Wan model.
    pre_weight_class = WanPreWeights
    post_weight_class = WanPostWeights
    transformer_weight_class = WanTransformerWeights

    def __init__(self, model_path, config, device):
        """Construct the model; all work is delegated to ``WanModel.__init__``."""
        super().__init__(model_path, config, device)

    def _init_infer_class(self):
        """Select the audio-specific inference classes.

        The parent hook runs first, then the pre/transformer/post infer-class
        attributes are (re)assigned to their audio variants.
        """
        super()._init_infer_class()
        # These assignments take effect after the parent hook, so they win
        # over anything it may have chosen for the same attributes.
        self.transformer_infer_class = WanAudioTransformerInfer
        self.pre_infer_class = WanAudioPreInfer
        self.post_infer_class = WanAudioPostInfer

    def set_audio_adapter(self, audio_adapter):
        """Attach ``audio_adapter`` to the model and its transformer inference stage.

        Args:
            audio_adapter: Adapter object; stored on ``self.audio_adapter`` and
                passed to ``self.transformer_infer.set_audio_adapter``.
        """
        # NOTE(review): requires self.transformer_infer to exist already
        # (presumably built by WanModel from transformer_infer_class) — confirm.
        self.audio_adapter = audio_adapter
        self.transformer_infer.set_audio_adapter(audio_adapter)


class Wan22MoeAudioModel(WanAudioModel):
    """Audio model whose checkpoint is every ``*.safetensors`` file directly in ``model_path``."""

    def _load_ckpt(self, unified_dtype, sensitive_layer):
        """Load and merge all ``*.safetensors`` files found in ``self.model_path``.

        Files are read in sorted filename order, so if the same key occurs in
        more than one file, the file that sorts last deterministically wins.

        Args:
            unified_dtype: Forwarded unchanged to ``_load_safetensor_to_dict``.
            sensitive_layer: Forwarded unchanged to ``_load_safetensor_to_dict``.

        Returns:
            dict: Union of the per-file dicts returned by
            ``_load_safetensor_to_dict``.

        Raises:
            FileNotFoundError: If ``self.model_path`` contains no
                ``*.safetensors`` file.
        """
        # Escape the directory part so characters such as "[" or "*" in the
        # path are matched literally instead of as glob wildcards.
        pattern = os.path.join(glob.escape(self.model_path), "*.safetensors")
        # glob() order is filesystem-dependent; sort so duplicate keys are
        # resolved identically on every machine.
        safetensors_files = sorted(glob.glob(pattern))
        if not safetensors_files:
            # An empty weight dict would only fail later with a less clear error.
            raise FileNotFoundError(f"No *.safetensors files found in {self.model_path!r}")

        weight_dict = {}
        for file_path in safetensors_files:
            weight_dict.update(self._load_safetensor_to_dict(file_path, unified_dtype, sensitive_layer))
        return weight_dict