import os
from argparse import ArgumentParser
from typing import Iterable

import torch

# Substrings identifying checkpoint entries that are dropped from the export.
# NOTE(review): these presumably mark discriminator ("dis"), perceptual-loss
# ("prece" -- possibly a misspelling of "perce"; confirm against real key
# names) and log-variance ("logvar") parameters. Matching is a plain substring
# test, so ANY key containing these letters anywhere is dropped.
DEFAULT_EXCLUDE = ("dis", "prece", "logvar")


def extract_vae_weights(
    vae_checkpoint_path: str,
    save_path: str,
    *,
    exclude: Iterable[str] = DEFAULT_EXCLUDE,
    map_location="cpu",
) -> None:
    """Strip a training checkpoint down to the VAE weights and save them.

    Loads ``vae_checkpoint_path`` and takes the state dict stored under its
    ``"model"`` key. It drops every entry whose name contains one of the
    ``exclude`` substrings, then removes the first dotted component of each
    remaining name (e.g. ``"vae.encoder.w"`` -> ``"encoder.w"``). The result
    is saved to ``save_path`` with ``torch.save``.

    Args:
        vae_checkpoint_path: Path to the checkpoint file. It is unpickled by
            ``torch.load``, so only pass files from a trusted source.
        save_path: Destination file. Missing parent directories are created.
        exclude: Substrings; a parameter is skipped if its name contains any
            of them. Defaults to ``DEFAULT_EXCLUDE``.
        map_location: Forwarded to ``torch.load``. Defaults to ``"cpu"`` so a
            checkpoint saved on a GPU can be processed on a machine without
            that device. Pass ``None`` to restore tensors to their saved
            devices instead.

    Raises:
        KeyError: If the loaded checkpoint has no ``"model"`` entry.
    """
    # Materialise once: `exclude` is scanned for every key, so a one-shot
    # iterator (e.g. a generator) would otherwise be exhausted after one key.
    excluded = tuple(exclude)

    # torch.load unpickles arbitrary objects -- never run it on untrusted files.
    checkpoint = torch.load(vae_checkpoint_path, map_location=map_location)
    weights = checkpoint["model"]

    new_weights = {}
    for name, params in weights.items():
        if any(token in name for token in excluded):
            continue
        # Drop the leading dotted component (presumably the wrapper module's
        # attribute name). A name with no "." has no such prefix and is kept
        # unchanged rather than raising IndexError.
        _, sep, rest = name.partition(".")
        new_weights[rest if sep else name] = params

    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    torch.save(new_weights, save_path)


def main(argv=None) -> None:
    """Command-line entry point: parse ``argv`` (default: sys.argv) and extract."""
    parser = ArgumentParser(
        description="Extract VAE weights from a training checkpoint."
    )
    # Both paths are mandatory; without them torch.load/torch.save would be
    # handed None and fail with an unhelpful error.
    parser.add_argument(
        "--vae_checkpoint_path",
        type=str,
        required=True,
        help="checkpoint file containing a 'model' state dict",
    )
    parser.add_argument(
        "--save_path",
        type=str,
        required=True,
        help="output file for the filtered VAE state dict",
    )
    args = parser.parse_args(argv)
    extract_vae_weights(args.vae_checkpoint_path, args.save_path)


if __name__ == "__main__":
    main()