# conversion_glide.py
# Convert OpenAI GLIDE checkpoints ("base.pt", "upsample.pt") into a diffusers Glide pipeline.
import torch
from torch import nn

from diffusers import ClassifierFreeGuidanceScheduler, DDIMScheduler, GlideSuperResUNetModel, GlideTextToImageUNetModel
from diffusers.pipelines.pipeline_glide import Glide, CLIPTextModel
from transformers import CLIPTextConfig, GPT2Tokenizer


# Base (64x64 text-to-image) checkpoint:
# wget https://openaipublic.blob.core.windows.net/diffusion/dec-2021/base.pt
base_checkpoint = torch.load("base.pt", map_location="cpu")
# Wrap every tensor as nn.Parameter so the values can be assigned directly
# onto module attributes below.
state_dict = {name: nn.Parameter(tensor) for name, tensor in base_checkpoint.items()}

### Convert the text encoder

# GLIDE's text transformer expressed as a (modified) CLIP text model.
# NOTE(review): `use_padding_embeddings` is a custom flag from the diffusers
# fork's CLIPTextModel/CLIPTextConfig, not upstream transformers — confirm
# before bumping dependencies.
config = CLIPTextConfig(
    vocab_size=50257,
    max_position_embeddings=128,
    hidden_size=512,
    intermediate_size=2048,
    num_hidden_layers=16,
    num_attention_heads=8,
    use_padding_embeddings=True,
)
model = CLIPTextModel(config).eval()
# GLIDE reuses a GPT-2-style BPE vocabulary; pad with the EOS token.
tokenizer = GPT2Tokenizer(
    "./glide-base/tokenizer/vocab.json", "./glide-base/tokenizer/merges.txt", pad_token="<|endoftext|>"
)

hf_encoder = model.text_model

# All state_dict values were wrapped in nn.Parameter above, so they can be
# assigned directly onto the module attributes. The original mixed plain
# assignment with `.weight.data = ...`; use direct assignment uniformly for
# consistency with the rest of the file.
hf_encoder.embeddings.token_embedding.weight = state_dict["token_embedding.weight"]
hf_encoder.embeddings.position_embedding.weight = state_dict["positional_embedding"]
hf_encoder.embeddings.padding_embedding.weight = state_dict["padding_embedding"]

hf_encoder.final_layer_norm.weight = state_dict["final_ln.weight"]
hf_encoder.final_layer_norm.bias = state_dict["final_ln.bias"]

# Per-layer weight copy, driven by a table mapping HF CLIP encoder-layer
# attribute paths to the corresponding GLIDE resblock key suffixes.
_LAYER_PARAM_MAP = [
    ("self_attn.qkv_proj.weight", "attn.c_qkv.weight"),
    ("self_attn.qkv_proj.bias", "attn.c_qkv.bias"),
    ("self_attn.out_proj.weight", "attn.c_proj.weight"),
    ("self_attn.out_proj.bias", "attn.c_proj.bias"),
    ("layer_norm1.weight", "ln_1.weight"),
    ("layer_norm1.bias", "ln_1.bias"),
    ("layer_norm2.weight", "ln_2.weight"),
    ("layer_norm2.bias", "ln_2.bias"),
    ("mlp.fc1.weight", "mlp.c_fc.weight"),
    ("mlp.fc1.bias", "mlp.c_fc.bias"),
    ("mlp.fc2.weight", "mlp.c_proj.weight"),
    ("mlp.fc2.bias", "mlp.c_proj.bias"),
]

for layer_idx in range(config.num_hidden_layers):
    hf_layer = hf_encoder.encoder.layers[layer_idx]
    for hf_path, glide_suffix in _LAYER_PARAM_MAP:
        # Walk down to the submodule owning the leaf attribute, then assign.
        target = hf_layer
        *parents, leaf = hf_path.split(".")
        for attr in parents:
            target = getattr(target, attr)
        setattr(target, leaf, state_dict[f"transformer.resblocks.{layer_idx}.{glide_suffix}"])

### Convert the Text-to-Image UNet

# Base 64x64 text-conditioned UNet. out_channels is twice in_channels —
# presumably predicted mean plus learned variance; verify against the model.
text2im_kwargs = dict(
    in_channels=3,
    model_channels=192,
    out_channels=6,
    num_res_blocks=3,
    attention_resolutions=(2, 4, 8),
    dropout=0.1,
    channel_mult=(1, 2, 3, 4),
    num_heads=1,
    num_head_channels=64,
    num_heads_upsample=1,
    use_scale_shift_norm=True,
    resblock_updown=True,
    transformer_dim=512,
)
text2im_model = GlideTextToImageUNetModel(**text2im_kwargs)

# strict=False: the same checkpoint also carries the text-transformer
# weights converted above, which are not part of this UNet module.
text2im_model.load_state_dict(state_dict, strict=False)

text_scheduler = ClassifierFreeGuidanceScheduler(timesteps=1000, beta_schedule="squaredcos_cap_v2")

### Convert the Super-Resolution UNet

# Upsampling checkpoint:
# wget https://openaipublic.blob.core.windows.net/diffusion/dec-2021/upsample.pt
ups_state_dict = torch.load("upsample.pt", map_location="cpu")

# in_channels=6 — presumably the noisy high-res image concatenated with the
# low-res conditioning image; verify against the model implementation.
superres_kwargs = dict(
    in_channels=6,
    model_channels=192,
    out_channels=6,
    num_res_blocks=2,
    attention_resolutions=(8, 16, 32),
    dropout=0.1,
    channel_mult=(1, 1, 2, 2, 4, 4),
    num_heads=1,
    num_head_channels=64,
    num_heads_upsample=1,
    use_scale_shift_norm=True,
    resblock_updown=True,
)
superres_model = GlideSuperResUNetModel(**superres_kwargs)

# strict=False mirrors the base-model load above.
superres_model.load_state_dict(ups_state_dict, strict=False)

# DDIM scheduler for the super-resolution stage (linear beta schedule,
# torch tensor format).
upscale_scheduler = DDIMScheduler(
    timesteps=1000,
    beta_schedule="linear",
    beta_start=0.0001,
    beta_end=0.02,
    tensor_format="pt",
)

# Assemble the full two-stage pipeline and serialize it to disk.
glide_pipeline = Glide(
    text_unet=text2im_model,
    text_noise_scheduler=text_scheduler,
    text_encoder=model,
    tokenizer=tokenizer,
    upscale_unet=superres_model,
    upscale_noise_scheduler=upscale_scheduler,
)

glide_pipeline.save_pretrained("./glide-base")