# File: convert_ncsnpp_original_checkpoint_to_diffusers.py
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Conversion script for the NCSNPP checkpoints. """

import argparse
import json
import os

import torch

from diffusers import ScoreSdeVePipeline, ScoreSdeVeScheduler, UNet2DModel


def convert_ncsnpp_checkpoint(checkpoint, config):
    """
    Convert an original NCSN++ (score_sde) state dict into a diffusers ``UNet2DModel`` state dict.

    The original network keeps its sub-modules in one flat ``all_modules`` list, so every parameter is
    addressed as ``all_modules.<index>.<name>``. A fresh ``UNet2DModel`` is built from ``config``. Its layers
    are then visited in a fixed order while a running module index advances through the original entries:
    time projection/embedding, ``conv_in``, down blocks, mid block, up blocks, output norm/conv.
    NOTE(review): correctness relies on this visiting order matching the order in which the original model
    registered its modules.

    Args:
        checkpoint: Mapping of original parameter names to tensors, e.g. the output of ``torch.load``.
        config: Keyword arguments for ``UNet2DModel``; every key must be accepted by its constructor.

    Returns:
        The ``state_dict()`` of the populated ``UNet2DModel``.

    Raises:
        KeyError: If an expected ``all_modules.*`` entry is missing from ``checkpoint``.
    """
    new_model_architecture = UNet2DModel(**config)

    # Gaussian Fourier time projection. The original script filled both ``W`` and ``weight``.
    # ``W`` is only filled when the attribute exists, so models exposing just ``weight`` also convert.
    fourier_weights = checkpoint["all_modules.0.W"].data
    if hasattr(new_model_architecture.time_proj, "W"):
        new_model_architecture.time_proj.W.data = fourier_weights
    new_model_architecture.time_proj.weight.data = fourier_weights

    new_model_architecture.time_embedding.linear_1.weight.data = checkpoint["all_modules.1.weight"].data
    new_model_architecture.time_embedding.linear_1.bias.data = checkpoint["all_modules.1.bias"].data

    new_model_architecture.time_embedding.linear_2.weight.data = checkpoint["all_modules.2.weight"].data
    new_model_architecture.time_embedding.linear_2.bias.data = checkpoint["all_modules.2.bias"].data

    new_model_architecture.conv_in.weight.data = checkpoint["all_modules.3.weight"].data
    new_model_architecture.conv_in.bias.data = checkpoint["all_modules.3.bias"].data

    # Index of the next unread ``all_modules`` entry; entries 0-3 were consumed above.
    module_index = 4

    def set_attention_weights(new_layer, old_checkpoint, index):
        """Copy the original attention module ``all_modules.<index>`` into ``new_layer``."""
        prefix = f"all_modules.{index}"
        # The NIN_* weights are transposed before copying. Presumably the original NIN layers store
        # (in, out) while nn.Linear expects (out, in) — confirm if shapes ever mismatch.
        new_layer.query.weight.data = old_checkpoint[f"{prefix}.NIN_0.W"].data.T
        new_layer.key.weight.data = old_checkpoint[f"{prefix}.NIN_1.W"].data.T
        new_layer.value.weight.data = old_checkpoint[f"{prefix}.NIN_2.W"].data.T

        new_layer.query.bias.data = old_checkpoint[f"{prefix}.NIN_0.b"].data
        new_layer.key.bias.data = old_checkpoint[f"{prefix}.NIN_1.b"].data
        new_layer.value.bias.data = old_checkpoint[f"{prefix}.NIN_2.b"].data

        new_layer.proj_attn.weight.data = old_checkpoint[f"{prefix}.NIN_3.W"].data.T
        new_layer.proj_attn.bias.data = old_checkpoint[f"{prefix}.NIN_3.b"].data

        new_layer.group_norm.weight.data = old_checkpoint[f"{prefix}.GroupNorm_0.weight"].data
        new_layer.group_norm.bias.data = old_checkpoint[f"{prefix}.GroupNorm_0.bias"].data

    def set_resnet_weights(new_layer, old_checkpoint, index):
        """Copy the original residual block ``all_modules.<index>`` into ``new_layer``."""
        prefix = f"all_modules.{index}"
        new_layer.conv1.weight.data = old_checkpoint[f"{prefix}.Conv_0.weight"].data
        new_layer.conv1.bias.data = old_checkpoint[f"{prefix}.Conv_0.bias"].data
        new_layer.norm1.weight.data = old_checkpoint[f"{prefix}.GroupNorm_0.weight"].data
        new_layer.norm1.bias.data = old_checkpoint[f"{prefix}.GroupNorm_0.bias"].data

        new_layer.conv2.weight.data = old_checkpoint[f"{prefix}.Conv_1.weight"].data
        new_layer.conv2.bias.data = old_checkpoint[f"{prefix}.Conv_1.bias"].data
        new_layer.norm2.weight.data = old_checkpoint[f"{prefix}.GroupNorm_1.weight"].data
        new_layer.norm2.bias.data = old_checkpoint[f"{prefix}.GroupNorm_1.bias"].data

        new_layer.time_emb_proj.weight.data = old_checkpoint[f"{prefix}.Dense_0.weight"].data
        new_layer.time_emb_proj.bias.data = old_checkpoint[f"{prefix}.Dense_0.bias"].data

        # The shortcut convolution (Conv_2) is only copied when the block changes channels or resolution.
        if new_layer.in_channels != new_layer.out_channels or new_layer.up or new_layer.down:
            new_layer.conv_shortcut.weight.data = old_checkpoint[f"{prefix}.Conv_2.weight"].data
            new_layer.conv_shortcut.bias.data = old_checkpoint[f"{prefix}.Conv_2.bias"].data

    # ``down_blocks`` matches the ``mid_block``/``up_blocks`` naming used below.
    # ``downsample_blocks`` (the name this script used previously) is kept as a fallback.
    if hasattr(new_model_architecture, "down_blocks"):
        down_blocks = new_model_architecture.down_blocks
    else:
        down_blocks = new_model_architecture.downsample_blocks

    for block in down_blocks:
        has_attentions = hasattr(block, "attentions")
        for j, resnet in enumerate(block.resnets):
            set_resnet_weights(resnet, checkpoint, module_index)
            module_index += 1
            # In the down path every resnet is followed by its own attention layer (when present).
            if has_attentions:
                set_attention_weights(block.attentions[j], checkpoint, module_index)
                module_index += 1

        if getattr(block, "downsamplers", None) is not None:
            set_resnet_weights(block.resnet_down, checkpoint, module_index)
            module_index += 1
            block.skip_conv.weight.data = checkpoint[f"all_modules.{module_index}.Conv_0.weight"].data
            block.skip_conv.bias.data = checkpoint[f"all_modules.{module_index}.Conv_0.bias"].data
            module_index += 1

    # Mid block: resnet -> attention -> resnet.
    set_resnet_weights(new_model_architecture.mid_block.resnets[0], checkpoint, module_index)
    module_index += 1
    set_attention_weights(new_model_architecture.mid_block.attentions[0], checkpoint, module_index)
    module_index += 1
    set_resnet_weights(new_model_architecture.mid_block.resnets[1], checkpoint, module_index)
    module_index += 1

    for block in new_model_architecture.up_blocks:
        for resnet in block.resnets:
            set_resnet_weights(resnet, checkpoint, module_index)
            module_index += 1
        if hasattr(block, "attentions"):
            # NOTE(review): only a single attention layer per up block is loaded, after all resnets.
            # Presumably this mirrors the original NCSN++ up path — confirm.
            set_attention_weights(block.attentions[0], checkpoint, module_index)
            module_index += 1

        if getattr(block, "resnet_up", None) is not None:
            block.skip_norm.weight.data = checkpoint[f"all_modules.{module_index}.weight"].data
            block.skip_norm.bias.data = checkpoint[f"all_modules.{module_index}.bias"].data
            module_index += 1
            block.skip_conv.weight.data = checkpoint[f"all_modules.{module_index}.weight"].data
            block.skip_conv.bias.data = checkpoint[f"all_modules.{module_index}.bias"].data
            module_index += 1
            set_resnet_weights(block.resnet_up, checkpoint, module_index)
            module_index += 1

    # Output head. These are the only assignments to conv_norm_out/conv_out; the old "last four keys"
    # lookup was order-dependent and always overwritten here.
    new_model_architecture.conv_norm_out.weight.data = checkpoint[f"all_modules.{module_index}.weight"].data
    new_model_architecture.conv_norm_out.bias.data = checkpoint[f"all_modules.{module_index}.bias"].data
    module_index += 1
    new_model_architecture.conv_out.weight.data = checkpoint[f"all_modules.{module_index}.weight"].data
    new_model_architecture.conv_out.bias.data = checkpoint[f"all_modules.{module_index}.bias"].data

    return new_model_architecture.state_dict()

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Convert an original NCSN++ checkpoint to the diffusers format.")

    parser.add_argument(
        "--checkpoint_path",
        default="/Users/arthurzucker/Work/diffusers/ArthurZ/diffusion_pytorch_model.bin",
        type=str,
        required=False,
        help="Path to the checkpoint to convert.",
    )

    parser.add_argument(
        "--config_file",
        default="/Users/arthurzucker/Work/diffusers/ArthurZ/config.json",
        type=str,
        required=False,
        help="The config json file corresponding to the architecture.",
    )

    parser.add_argument(
        "--dump_path",
        default="/Users/arthurzucker/Work/diffusers/ArthurZ/diffusion_model_new.pt",
        type=str,
        required=False,
        help="Path to the output model.",
    )

    args = parser.parse_args()

    # SECURITY NOTE: torch.load unpickles arbitrary objects; only convert checkpoints from trusted sources.
    checkpoint = torch.load(args.checkpoint_path, map_location="cpu")

    with open(args.config_file, encoding="utf-8") as f:
        config = json.load(f)

    # ``sde`` is removed before any UNet2DModel(**config) call, including the one inside
    # convert_ncsnpp_checkpoint. Previously it was removed only after conversion had already
    # built a model with it.
    config.pop("sde", None)

    converted_checkpoint = convert_ncsnpp_checkpoint(checkpoint, config)

    model = UNet2DModel(**config)
    model.load_state_dict(converted_checkpoint)

    # A scheduler config is looked up in the checkpoint's directory. If none can be loaded, only the
    # UNet is saved. Failures while building/saving the pipeline itself are no longer swallowed.
    checkpoint_dir = os.path.dirname(os.path.abspath(args.checkpoint_path))
    try:
        scheduler = ScoreSdeVeScheduler.from_config(checkpoint_dir)
    except Exception as err:  # top-level script boundary: any load failure means "no usable scheduler"
        print(f"Could not load a scheduler config from {checkpoint_dir} ({err}); saving the UNet only.")
        model.save_pretrained(args.dump_path)
    else:
        pipe = ScoreSdeVePipeline(unet=model, scheduler=scheduler)
        pipe.save_pretrained(args.dump_path)