# conversion_ldm_uncond.py
#
# Convert an original unconditional latent-diffusion (LDM) checkpoint plus its
# YAML config into a diffusers ``LDMPipeline`` saved with ``save_pretrained``.
import argparse

import torch
4
import yaml
5

6
7
from diffusers import DDIMScheduler, LDMPipeline, UNetLDMModel, VQModel

8
9

def _extract_sub_state_dict(state_dict, prefix):
    """Return the entries of *state_dict* whose key starts with *prefix*.

    The prefix is stripped from the returned keys. Only the leading match is
    removed, so a prefix-like substring appearing later inside a key is kept
    intact (``str.replace`` would have removed every occurrence).
    """
    return {
        key[len(prefix):]: value
        for key, value in state_dict.items()
        if key.startswith(prefix)
    }


def convert_ldm_original(checkpoint_path, config_path, output_path):
    """Convert an original LDM checkpoint into a diffusers ``LDMPipeline``.

    Args:
        checkpoint_path: Path to the original checkpoint. It is loaded with
            ``torch.load`` onto the CPU and its weights are read from the
            ``"model"`` entry.
        config_path: Path to the original YAML config. It must provide
            ``model.params.first_stage_config.params``,
            ``model.params.unet_config.params``, ``model.params.timesteps``,
            ``model.params.linear_start`` and ``model.params.linear_end``.
        output_path: Directory handed to ``LDMPipeline.save_pretrained``.
    """
    # Parse the YAML *file*. Passing the path string straight to
    # ``yaml.safe_load`` would parse the path text itself and return a str.
    with open(config_path, encoding="utf-8") as config_file:
        config = yaml.safe_load(config_file)
    model_params = config["model"]["params"]

    state_dict = torch.load(checkpoint_path, map_location="cpu")["model"]

    # Split the combined checkpoint into the VQVAE (first stage) weights and
    # the UNetLDM weights, with keys made relative to each sub-module.
    first_stage_dict = _extract_sub_state_dict(state_dict, "first_stage_model.")
    unet_state_dict = _extract_sub_state_dict(state_dict, "model.diffusion_model.")

    vqvae_init_args = model_params["first_stage_config"]["params"]
    unet_init_args = model_params["unet_config"]["params"]

    vqvae = VQModel(**vqvae_init_args).eval()
    vqvae.load_state_dict(first_stage_dict)

    unet = UNetLDMModel(**unet_init_args).eval()
    unet.load_state_dict(unet_state_dict)

    noise_scheduler = DDIMScheduler(
        timesteps=model_params["timesteps"],
        beta_schedule="scaled_linear",
        beta_start=model_params["linear_start"],
        beta_end=model_params["linear_end"],
        clip_sample=False,
    )

    pipeline = LDMPipeline(vqvae, unet, noise_scheduler)
    pipeline.save_pretrained(output_path)


if __name__ == "__main__":
    # Command-line entry point: every path argument is mandatory.
    cli = argparse.ArgumentParser()
    for flag in ("--checkpoint_path", "--config_path", "--output_path"):
        cli.add_argument(flag, type=str, required=True)
    parsed = cli.parse_args()

    convert_ldm_original(
        checkpoint_path=parsed.checkpoint_path,
        config_path=parsed.config_path,
        output_path=parsed.output_path,
    )