audio_model.py 1.34 KB
Newer Older
PengGao's avatar
PengGao committed
1
import glob
wangshankun's avatar
wangshankun committed
2
import os
PengGao's avatar
PengGao committed
3
4
5

from lightx2v.models.networks.wan.infer.audio.post_wan_audio_infer import WanAudioPostInfer
from lightx2v.models.networks.wan.infer.audio.pre_wan_audio_infer import WanAudioPreInfer
wangshankun's avatar
wangshankun committed
6
7
from lightx2v.models.networks.wan.model import WanModel
from lightx2v.models.networks.wan.weights.post_weights import WanPostWeights
PengGao's avatar
PengGao committed
8
from lightx2v.models.networks.wan.weights.pre_weights import WanPreWeights
wangshankun's avatar
wangshankun committed
9
10
11
12
13
14
15
16
17
18
from lightx2v.models.networks.wan.weights.transformer_weights import (
    WanTransformerWeights,
)


class WanAudioModel(WanModel):
    """Wan model variant that uses audio-specific pre/post inference classes.

    The weight containers are the standard Wan pre/post/transformer weight
    classes. Only the pre- and post-inference classes differ from the parent:
    they are replaced with their audio counterparts after the parent's
    ``_init_infer_class`` has run.
    """

    # Weight container classes read by the parent WanModel
    # (presumably instantiated during model loading there — verify).
    pre_weight_class = WanPreWeights
    post_weight_class = WanPostWeights
    transformer_weight_class = WanTransformerWeights

    def __init__(self, model_path, config, device):
        # No extra state of its own; construction is fully delegated to WanModel.
        super(WanAudioModel, self).__init__(model_path, config, device)

    def _init_infer_class(self):
        # Let the parent choose its infer classes first, then override the
        # pre- and post-inference stages with the audio implementations.
        super(WanAudioModel, self)._init_infer_class()
        self.pre_infer_class, self.post_infer_class = (
            WanAudioPreInfer,
            WanAudioPostInfer,
        )

wangshankun's avatar
wangshankun committed
27
28

class Wan22MoeAudioModel(WanAudioModel):
    """Audio model that merges every ``*.safetensors`` file in ``model_path``."""

    def _load_ckpt(self, unified_dtype, sensitive_layer):
        """Load every ``*.safetensors`` file directly under ``self.model_path``.

        Each file is loaded with ``self._load_safetensor_to_dict`` and the
        results are merged into a single dict.

        Args:
            unified_dtype: Forwarded unchanged to ``_load_safetensor_to_dict``
                (presumably the dtype tensors are cast to — confirm).
            sensitive_layer: Forwarded unchanged to ``_load_safetensor_to_dict``
                (presumably layers exempt from that cast — confirm).

        Returns:
            dict: The union of the per-file dicts. Files are merged in sorted
            path order, so if a key appears in several files, the value from
            the file whose path sorts last wins. The dict is empty when no
            matching file exists.
        """
        # Escape the directory so glob metacharacters in it ("[", "*", "?")
        # are matched literally. os.fspath keeps pathlib.Path inputs working.
        pattern = os.path.join(glob.escape(os.fspath(self.model_path)), "*.safetensors")
        weight_dict = {}
        # glob() order is filesystem-dependent; sorting makes the merge order,
        # and therefore the winner on duplicate keys, deterministic.
        for file_path in sorted(glob.glob(pattern)):
            file_weights = self._load_safetensor_to_dict(file_path, unified_dtype, sensitive_layer)
            weight_dict.update(file_weights)
        return weight_dict