# Conversion script: port the original OpenAI GLIDE checkpoints ("base.pt" and
# "upsample.pt") into diffusers-style modules — a CLIP-like text encoder, a
# text-to-image UNet, and a super-resolution UNet — then bundle and save them
# as a single GLIDE pipeline.

import torch
from torch import nn
from transformers import CLIPTextConfig, GPT2Tokenizer

from diffusers import (
    ClassifierFreeGuidanceScheduler,
    GlideDDIMScheduler,
    GLIDESuperResUNetModel,
    GLIDETextToImageUNetModel,
)

from modeling_glide import GLIDE, CLIPTextModel

# wget https://openaipublic.blob.core.windows.net/diffusion/dec-2021/base.pt
state_dict = torch.load("base.pt", map_location="cpu")
state_dict = {k: nn.Parameter(v) for k, v in state_dict.items()}
17
18
19

### Convert the text encoder

# Hyperparameters of the GLIDE base text transformer: 512-wide, 16 layers,
# 8 heads, 128-token context, GPT-2 vocabulary.
config = CLIPTextConfig(
    vocab_size=50257,
    max_position_embeddings=128,
    hidden_size=512,
    intermediate_size=2048,
    num_hidden_layers=16,
    num_attention_heads=8,
    use_padding_embeddings=True,
)
model = CLIPTextModel(config).eval()
# GLIDE reuses the GPT-2 BPE tokenizer files; pad with the end-of-text token.
tokenizer = GPT2Tokenizer(
    "./glide-base/tokenizer/vocab.json", "./glide-base/tokenizer/merges.txt", pad_token="<|endoftext|>"
)

hf_encoder = model.text_model

# Copy the embedding tables and the final layer norm from the checkpoint.
hf_encoder.embeddings.token_embedding.weight = state_dict["token_embedding.weight"]
hf_encoder.embeddings.position_embedding.weight.data = state_dict["positional_embedding"]
hf_encoder.embeddings.padding_embedding.weight.data = state_dict["padding_embedding"]

hf_encoder.final_layer_norm.weight = state_dict["final_ln.weight"]
hf_encoder.final_layer_norm.bias = state_dict["final_ln.bias"]
# Map each original transformer resblock onto the corresponding HF encoder
# layer: fused QKV + output attention projections, the two layer norms, and
# the two MLP linears.
for layer_idx in range(config.num_hidden_layers):
    hf_layer = hf_encoder.encoder.layers[layer_idx]

    # Attention: the checkpoint stores a single fused query/key/value matrix.
    hf_layer.self_attn.qkv_proj.weight = state_dict[f"transformer.resblocks.{layer_idx}.attn.c_qkv.weight"]
    hf_layer.self_attn.qkv_proj.bias = state_dict[f"transformer.resblocks.{layer_idx}.attn.c_qkv.bias"]

    hf_layer.self_attn.out_proj.weight = state_dict[f"transformer.resblocks.{layer_idx}.attn.c_proj.weight"]
    hf_layer.self_attn.out_proj.bias = state_dict[f"transformer.resblocks.{layer_idx}.attn.c_proj.bias"]

    # Pre-attention (ln_1) and pre-MLP (ln_2) layer norms.
    hf_layer.layer_norm1.weight = state_dict[f"transformer.resblocks.{layer_idx}.ln_1.weight"]
    hf_layer.layer_norm1.bias = state_dict[f"transformer.resblocks.{layer_idx}.ln_1.bias"]
    hf_layer.layer_norm2.weight = state_dict[f"transformer.resblocks.{layer_idx}.ln_2.weight"]
    hf_layer.layer_norm2.bias = state_dict[f"transformer.resblocks.{layer_idx}.ln_2.bias"]

    # MLP: c_fc -> fc1 (expansion), c_proj -> fc2 (projection back).
    hf_layer.mlp.fc1.weight = state_dict[f"transformer.resblocks.{layer_idx}.mlp.c_fc.weight"]
    hf_layer.mlp.fc1.bias = state_dict[f"transformer.resblocks.{layer_idx}.mlp.c_fc.bias"]
    hf_layer.mlp.fc2.weight = state_dict[f"transformer.resblocks.{layer_idx}.mlp.c_proj.weight"]
    hf_layer.mlp.fc2.bias = state_dict[f"transformer.resblocks.{layer_idx}.mlp.c_proj.bias"]
### Convert the Text-to-Image UNet

text2im_model = GLIDETextToImageUNetModel(
    in_channels=3,
    # NOTE(review): out_channels=6 — presumably 3 image channels plus 3
    # learned-variance channels; confirm against the GLIDE model definition.
    out_channels=6,
    model_channels=192,
    num_res_blocks=3,
    attention_resolutions=(2, 4, 8),
    dropout=0.1,
    channel_mult=(1, 2, 3, 4),
    num_heads=1,
    num_head_channels=64,
    num_heads_upsample=1,
    use_scale_shift_norm=True,
    resblock_updown=True,
    transformer_dim=512,
)

# strict=False: the same checkpoint also holds the text-transformer weights
# converted above, which do not belong to this UNet.
text2im_model.load_state_dict(state_dict, strict=False)

text_scheduler = ClassifierFreeGuidanceScheduler(timesteps=1000, beta_schedule="squaredcos_cap_v2")
### Convert the Super-Resolution UNet

# wget https://openaipublic.blob.core.windows.net/diffusion/dec-2021/upsample.pt
ups_state_dict = torch.load("upsample.pt", map_location="cpu")

superres_model = GLIDESuperResUNetModel(
    # NOTE(review): in_channels=6 — presumably the 3-channel noisy sample
    # concatenated with the 3-channel low-resolution conditioning image;
    # confirm against the GLIDE model definition.
    in_channels=6,
    model_channels=192,
    out_channels=6,
    num_res_blocks=2,
    attention_resolutions=(8, 16, 32),
    dropout=0.1,
    channel_mult=(1, 1, 2, 2, 4, 4),
    num_heads=1,
    num_head_channels=64,
    num_heads_upsample=1,
    use_scale_shift_norm=True,
    resblock_updown=True,
)

# strict=False tolerates checkpoint keys that this module does not own.
superres_model.load_state_dict(ups_state_dict, strict=False)

upscale_scheduler = GlideDDIMScheduler(timesteps=1000, beta_schedule="linear")
# Bundle the converted modules, tokenizer, and schedulers into a single GLIDE
# pipeline and serialize it alongside the tokenizer files used above.
glide = GLIDE(
    text_unet=text2im_model,
    text_noise_scheduler=text_scheduler,
    text_encoder=model,
    tokenizer=tokenizer,
    upscale_unet=superres_model,
    upscale_noise_scheduler=upscale_scheduler,
)

glide.save_pretrained("./glide-base")