import os
import torch


def extract_vae_weights(vae_checkpoint_path: str,
                        save_path: str) -> None:
    """Extract a filtered, re-keyed weight dict from a checkpoint and save it.

    Loads the checkpoint at ``vae_checkpoint_path``, reads its ``'model'``
    entry, drops every entry whose name contains ``"dis"``, ``"prece"`` or
    ``"logvar"``, strips the first dot-separated component from each
    remaining name, and writes the resulting dict to ``save_path`` with
    ``torch.save``.

    Args:
        vae_checkpoint_path: Path of the checkpoint to read. The loaded
            object must be indexable by the key ``'model'`` and that entry
            must be a mapping of parameter name to value.
        save_path: Path the extracted weight dict is written to. Missing
            parent directories are created.

    Raises:
        KeyError: If the checkpoint has no ``'model'`` entry.
    """
    # Map everything to CPU so a checkpoint saved from a GPU can still be
    # processed on a machine without that device.
    checkpoint = torch.load(vae_checkpoint_path, map_location="cpu")

    weights = checkpoint['model']

    # Entries whose name contains any of these substrings are dropped.
    # NOTE(review): presumably these are discriminator / loss-side entries,
    # and "prece" may be a typo for "perce" (perceptual) -- confirm against
    # the real checkpoint key names before changing it.
    skipped_substrings = ("dis", "prece", "logvar")

    new_weights = {}
    for name, params in weights.items():
        if any(token in name for token in skipped_substrings):
            continue
        # Drop the leading "<prefix>." component. A name without any dot is
        # kept unchanged rather than raising IndexError.
        _, dot, remainder = name.partition(".")
        new_weights[remainder if dot else name] = params

    # torch.save does not create missing directories itself.
    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    torch.save(new_weights, save_path)


if __name__ == "__main__":
    # Command-line entry point: parse the two paths and run the extraction.
    from argparse import ArgumentParser

    parser = ArgumentParser(
        description="Extract a filtered weight dict from a checkpoint "
                    "and save it to a new file.")

    # Both paths are mandatory; without them None would be passed on to
    # torch.load / torch.save and fail with an obscure error.
    parser.add_argument("--vae_checkpoint_path", type=str, required=True,
                        help="Path of the checkpoint file to read.")

    parser.add_argument("--save_path", type=str, required=True,
                        help="Path the extracted weights are written to.")

    args = parser.parse_args()

    extract_vae_weights(args.vae_checkpoint_path,
                        args.save_path)
