audio_model.py 1.2 KB
Newer Older
1
2
3
from lightx2v.models.networks.wan.infer.audio.post_infer import WanAudioPostInfer
from lightx2v.models.networks.wan.infer.audio.pre_infer import WanAudioPreInfer
from lightx2v.models.networks.wan.infer.audio.transformer_infer import WanAudioTransformerInfer
wangshankun's avatar
wangshankun committed
4
5
from lightx2v.models.networks.wan.model import WanModel
from lightx2v.models.networks.wan.weights.post_weights import WanPostWeights
PengGao's avatar
PengGao committed
6
from lightx2v.models.networks.wan.weights.pre_weights import WanPreWeights
wangshankun's avatar
wangshankun committed
7
8
9
10
11
12
13
14
15
16
from lightx2v.models.networks.wan.weights.transformer_weights import (
    WanTransformerWeights,
)


class WanAudioModel(WanModel):
    """Audio-conditioned variant of ``WanModel``.

    The generic Wan weight classes are used as-is; only the inference classes
    are swapped for their audio-specific counterparts. An audio adapter can be
    attached after construction with :meth:`set_audio_adapter`.

    The constructor is inherited from ``WanModel``, so instances are still
    created as ``WanAudioModel(model_path, config, device)``.
    """

    # Weight containers: same classes as the base Wan model.
    pre_weight_class = WanPreWeights
    post_weight_class = WanPostWeights
    transformer_weight_class = WanTransformerWeights

    def _init_infer_class(self) -> None:
        """Select the audio-specific inference classes.

        The parent implementation runs first. Presumably this is so that any
        other state it sets up is kept (confirm against ``WanModel``). The
        pre-, post- and transformer-inference classes are then overridden.
        """
        super()._init_infer_class()
        self.pre_infer_class = WanAudioPreInfer
        self.post_infer_class = WanAudioPostInfer
        self.transformer_infer_class = WanAudioTransformerInfer

    def set_audio_adapter(self, audio_adapter) -> None:
        """Attach an audio adapter and forward it to the transformer inference.

        Args:
            audio_adapter: Adapter object. It is stored on
                ``self.audio_adapter`` and passed to
                ``self.transformer_infer.set_audio_adapter``.

        NOTE(review): requires ``self.transformer_infer`` to already exist.
        It is presumably created by ``WanModel`` during initialization;
        confirm before calling this on a partially constructed model.
        """
        self.audio_adapter = audio_adapter
        self.transformer_infer.set_audio_adapter(self.audio_adapter)