import glob
import os

from lightx2v.models.networks.wan.infer.audio.post_infer import WanAudioPostInfer
from lightx2v.models.networks.wan.infer.audio.pre_infer import WanAudioPreInfer
from lightx2v.models.networks.wan.infer.audio.transformer_infer import WanAudioTransformerInfer
from lightx2v.models.networks.wan.model import WanModel
from lightx2v.models.networks.wan.weights.post_weights import WanPostWeights
from lightx2v.models.networks.wan.weights.pre_weights import WanPreWeights
from lightx2v.models.networks.wan.weights.transformer_weights import (
    WanTransformerWeights,
)


class WanAudioModel(WanModel):
    """Wan model variant that swaps in audio-aware pre/post/transformer inference classes."""

    pre_weight_class = WanPreWeights
    post_weight_class = WanPostWeights
    transformer_weight_class = WanTransformerWeights

    def __init__(self, model_path, config, device):
        super().__init__(model_path, config, device)

    def _init_infer_class(self):
        # Keep the base class wiring, then override with the audio-specific infer classes.
        super()._init_infer_class()
        self.pre_infer_class = WanAudioPreInfer
        self.post_infer_class = WanAudioPostInfer
        self.transformer_infer_class = WanAudioTransformerInfer

    def set_audio_adapter(self, audio_adapter):
        # Store the adapter and forward it to the transformer inference module.
        self.audio_adapter = audio_adapter
        self.transformer_infer.set_audio_adapter(self.audio_adapter)


class Wan22MoeAudioModel(WanAudioModel):
    """Wan 2.2 MoE audio model; merges all *.safetensors shards under model_path into one weight dict."""

    def _load_ckpt(self, unified_dtype, sensitive_layer):
        safetensors_files = glob.glob(os.path.join(self.model_path, "*.safetensors"))
        weight_dict = {}
        for file_path in safetensors_files:
            file_weights = self._load_safetensor_to_dict(file_path, unified_dtype, sensitive_layer)
            weight_dict.update(file_weights)
        return weight_dict
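
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the module): how these classes
# are typically wired together. The `model_path`, `config`, and `audio_adapter`
# values below are assumptions, not objects defined in this file; adapt them to
# your own lightx2v setup.
#
#     model = Wan22MoeAudioModel(
#         model_path="/path/to/wan22_moe_audio",  # hypothetical dir containing *.safetensors shards
#         config=config,                          # assumed lightx2v runtime config object
#         device="cuda",
#     )
#     model.set_audio_adapter(audio_adapter)      # assumed adapter; forwarded to the transformer infer
# ---------------------------------------------------------------------------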