import glob
import os

from lightx2v.models.networks.wan.infer.audio.post_wan_audio_infer import WanAudioPostInfer
from lightx2v.models.networks.wan.infer.audio.pre_wan_audio_infer import WanAudioPreInfer
from lightx2v.models.networks.wan.model import WanModel
from lightx2v.models.networks.wan.weights.post_weights import WanPostWeights
from lightx2v.models.networks.wan.weights.pre_weights import WanPreWeights
from lightx2v.models.networks.wan.weights.transformer_weights import (
    WanTransformerWeights,
)


class WanAudioModel(WanModel):
    """Wan model variant that uses the audio pre/post inference classes.

    The weight containers are the standard Wan ones (pre, post, transformer);
    only ``pre_infer_class`` and ``post_infer_class`` are replaced with their
    audio counterparts after the base class has initialised its own.
    """

    # Same weight layouts as the base Wan model.
    pre_weight_class = WanPreWeights
    post_weight_class = WanPostWeights
    transformer_weight_class = WanTransformerWeights

    def __init__(self, model_path, config, device, seq_p_group=None):
        """Forward all arguments unchanged to ``WanModel.__init__``."""
        super().__init__(model_path, config, device, seq_p_group)

    def _init_infer_class(self):
        """Initialise inference classes, then swap in the audio pre/post ones."""
        super()._init_infer_class()
        # Override only pre/post; whatever else the base class set is kept.
        self.pre_infer_class = WanAudioPreInfer
        self.post_infer_class = WanAudioPostInfer


class WanAudioModel22MoeCompat:  # noqa: D101  (placeholder removed below)
    pass


del WanAudioModel22MoeCompat


class Wan22MoeAudioModel(WanAudioModel):
    """Audio model whose checkpoint is split across several safetensors files.

    NOTE(review): the name suggests a Wan 2.2 MoE checkpoint; only the
    multi-file loading behaviour below is established by this code.
    """

    def _load_ckpt(self, unified_dtype, sensitive_layer):
        """Load and merge every ``*.safetensors`` file directly in ``model_path``.

        Args:
            unified_dtype: Passed through unchanged to
                ``self._load_safetensor_to_dict``.
            sensitive_layer: Passed through unchanged to
                ``self._load_safetensor_to_dict``.

        Returns:
            dict: Union of the per-file dicts. Files are processed in sorted
            path order and merged with ``dict.update``, so if two files share
            a key, the file that sorts last wins. The dict is empty if no
            matching files exist. The glob is not recursive, so files in
            subdirectories are ignored.
        """
        pattern = os.path.join(self.model_path, "*.safetensors")
        # glob.glob gives no ordering guarantee (it follows the filesystem's
        # directory order). Sort it so key collisions between shards resolve
        # the same way on every machine and every run.
        safetensors_files = sorted(glob.glob(pattern))
        weight_dict = {}
        for file_path in safetensors_files:
            weight_dict.update(
                self._load_safetensor_to_dict(file_path, unified_dtype, sensitive_layer)
            )
        return weight_dict