# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Conversion script for the LDM checkpoints. """

import argparse
import json

import torch

from diffusers import DDPMScheduler, LDMPipeline, UNet2DModel, VQModel


def shave_segments(path, n_shave_prefix_segments=1):
    """
    Removes segments. Positive values shave the first segments, negative shave the last segments.
    """
    if n_shave_prefix_segments >= 0:
        return ".".join(path.split(".")[n_shave_prefix_segments:])
    else:
        return ".".join(path.split(".")[:n_shave_prefix_segments])


def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
    """
    Updates paths inside resnets to the new naming scheme (local renaming)
    """
    mapping = []
    for old_item in old_list:
        new_item = old_item.replace("in_layers.0", "norm1")
        new_item = new_item.replace("in_layers.2", "conv1")

        new_item = new_item.replace("out_layers.0", "norm2")
        new_item = new_item.replace("out_layers.3", "conv2")

        new_item = new_item.replace("emb_layers.1", "time_emb_proj")
        new_item = new_item.replace("skip_connection", "conv_shortcut")

        new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)

        mapping.append({"old": old_item, "new": new_item})

    return mapping
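
# For illustration, renew_resnet_paths(["input_blocks.1.0.in_layers.0.weight"]) returns
# [{"old": "input_blocks.1.0.in_layers.0.weight", "new": "input_blocks.1.0.norm1.weight"}];
# the surrounding "input_blocks.1.0" prefix is renamed globally later by assign_to_checkpoint.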


def renew_attention_paths(old_list, n_shave_prefix_segments=0):
    """
    Updates paths inside attentions to the new naming scheme (local renaming)
    """
    mapping = []
    for old_item in old_list:
        new_item = old_item

        new_item = new_item.replace("norm.weight", "group_norm.weight")
        new_item = new_item.replace("norm.bias", "group_norm.bias")

        new_item = new_item.replace("proj_out.weight", "proj_attn.weight")
        new_item = new_item.replace("proj_out.bias", "proj_attn.bias")

        new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)

        mapping.append({"old": old_item, "new": new_item})

    return mapping
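
# For illustration, renew_attention_paths(["middle_block.1.norm.weight", "middle_block.1.proj_out.bias"])
# maps them to "middle_block.1.group_norm.weight" and "middle_block.1.proj_attn.bias"; fused qkv weights
# keep their names here and are split into query/key/value by assign_to_checkpoint below.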


def assign_to_checkpoint(
    paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
):
    """
    This does the final conversion step: take locally converted weights and apply a global renaming
    to them. It splits attention layers, and takes into account additional replacements
    that may arise.

    Assigns the weights to the new checkpoint.
    """
    assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."

    # Splits the attention layers into three variables.
    if attention_paths_to_split is not None:
        for path, path_map in attention_paths_to_split.items():
            old_tensor = old_checkpoint[path]
            channels = old_tensor.shape[0] // 3

            target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)

            num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3
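
            # Shape bookkeeping, for illustration: a fused qkv conv1d weight has shape (3 * C, C, 1), so
            # channels == C and num_heads == C // num_head_channels. The reshape below produces
            # (num_heads, 3 * num_head_channels, C, 1); splitting along dim 1 in chunks of
            # num_head_channels yields per-head query/key/value blocks that reshape back to (C, C).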

            old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
            query, key, value = old_tensor.split(channels // num_heads, dim=1)

            checkpoint[path_map["query"]] = query.reshape(target_shape)
            checkpoint[path_map["key"]] = key.reshape(target_shape)
            checkpoint[path_map["value"]] = value.reshape(target_shape)

    for path in paths:
        new_path = path["new"]

        # These have already been assigned
        if attention_paths_to_split is not None and new_path in attention_paths_to_split:
            continue

        # Global renaming happens here
        new_path = new_path.replace("middle_block.0", "mid.resnets.0")
        new_path = new_path.replace("middle_block.1", "mid.attentions.0")
        new_path = new_path.replace("middle_block.2", "mid.resnets.1")

        if additional_replacements is not None:
            for replacement in additional_replacements:
                new_path = new_path.replace(replacement["old"], replacement["new"])

        # proj_attn.weight has to be converted from conv 1D to linear
        if "proj_attn.weight" in new_path:
            checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
        else:
            checkpoint[new_path] = old_checkpoint[path["old"]]


def convert_ldm_checkpoint(checkpoint, config):
    """
    Takes a state dict and a config, and returns a converted checkpoint.
    """
    new_checkpoint = {}

    new_checkpoint["time_embedding.linear_1.weight"] = checkpoint["time_embed.0.weight"]
    new_checkpoint["time_embedding.linear_1.bias"] = checkpoint["time_embed.0.bias"]
    new_checkpoint["time_embedding.linear_2.weight"] = checkpoint["time_embed.2.weight"]
    new_checkpoint["time_embedding.linear_2.bias"] = checkpoint["time_embed.2.bias"]

    new_checkpoint["conv_in.weight"] = checkpoint["input_blocks.0.0.weight"]
    new_checkpoint["conv_in.bias"] = checkpoint["input_blocks.0.0.bias"]

    new_checkpoint["conv_norm_out.weight"] = checkpoint["out.0.weight"]
    new_checkpoint["conv_norm_out.bias"] = checkpoint["out.0.bias"]
    new_checkpoint["conv_out.weight"] = checkpoint["out.2.weight"]
    new_checkpoint["conv_out.bias"] = checkpoint["out.2.bias"]

    # Retrieves the keys for the input blocks only
    num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in checkpoint if "input_blocks" in layer})
    input_blocks = {
        layer_id: [key for key in checkpoint if f"input_blocks.{layer_id}" in key]
        for layer_id in range(num_input_blocks)
    }

    # Retrieves the keys for the middle blocks only
    num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in checkpoint if "middle_block" in layer})
    middle_blocks = {
        layer_id: [key for key in checkpoint if f"middle_block.{layer_id}" in key]
        for layer_id in range(num_middle_blocks)
    }
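    # The conversion below assumes middle_block has exactly three sub-modules:
    # a resnet (index 0), an attention block (index 1), and a second resnet (index 2).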

    # Retrieves the keys for the output blocks only
    num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in checkpoint if "output_blocks" in layer})
    output_blocks = {
        layer_id: [key for key in checkpoint if f"output_blocks.{layer_id}" in key]
        for layer_id in range(num_output_blocks)
    }

    for i in range(1, num_input_blocks):
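        # For context: input_blocks.0 is the stem conv handled above; the remaining blocks come in runs
        # of (num_res_blocks + 1) per resolution (the level's resnets followed by its downsampler), which
        # is what the division / modulo below rely on to recover block_id and layer_in_block_id.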
        block_id = (i - 1) // (config["num_res_blocks"] + 1)
        layer_in_block_id = (i - 1) % (config["num_res_blocks"] + 1)

        resnets = [key for key in input_blocks[i] if f"input_blocks.{i}.0" in key]
        attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]

        if f"input_blocks.{i}.0.op.weight" in checkpoint:
            new_checkpoint[f"downsample_blocks.{block_id}.downsamplers.0.conv.weight"] = checkpoint[
                f"input_blocks.{i}.0.op.weight"
            ]
            new_checkpoint[f"downsample_blocks.{block_id}.downsamplers.0.conv.bias"] = checkpoint[
                f"input_blocks.{i}.0.op.bias"
            ]

        paths = renew_resnet_paths(resnets)
        meta_path = {"old": f"input_blocks.{i}.0", "new": f"downsample_blocks.{block_id}.resnets.{layer_in_block_id}"}
        resnet_op = {"old": "resnets.2.op", "new": "downsamplers.0.op"}
        assign_to_checkpoint(
            paths, new_checkpoint, checkpoint, additional_replacements=[meta_path, resnet_op], config=config
        )

        if len(attentions):
            paths = renew_attention_paths(attentions)
            meta_path = {
                "old": f"input_blocks.{i}.1",
                "new": f"downsample_blocks.{block_id}.attentions.{layer_in_block_id}",
            }
            to_split = {
                f"input_blocks.{i}.1.qkv.bias": {
                    "key": f"downsample_blocks.{block_id}.attentions.{layer_in_block_id}.key.bias",
                    "query": f"downsample_blocks.{block_id}.attentions.{layer_in_block_id}.query.bias",
                    "value": f"downsample_blocks.{block_id}.attentions.{layer_in_block_id}.value.bias",
                },
                f"input_blocks.{i}.1.qkv.weight": {
                    "key": f"downsample_blocks.{block_id}.attentions.{layer_in_block_id}.key.weight",
                    "query": f"downsample_blocks.{block_id}.attentions.{layer_in_block_id}.query.weight",
                    "value": f"downsample_blocks.{block_id}.attentions.{layer_in_block_id}.value.weight",
                },
            }
            assign_to_checkpoint(
                paths,
                new_checkpoint,
                checkpoint,
                additional_replacements=[meta_path],
                attention_paths_to_split=to_split,
                config=config,
            )

    resnet_0 = middle_blocks[0]
    attentions = middle_blocks[1]
    resnet_1 = middle_blocks[2]

    resnet_0_paths = renew_resnet_paths(resnet_0)
    assign_to_checkpoint(resnet_0_paths, new_checkpoint, checkpoint, config=config)

    resnet_1_paths = renew_resnet_paths(resnet_1)
    assign_to_checkpoint(resnet_1_paths, new_checkpoint, checkpoint, config=config)

    attentions_paths = renew_attention_paths(attentions)
    to_split = {
        "middle_block.1.qkv.bias": {
            "key": "mid_block.attentions.0.key.bias",
            "query": "mid_block.attentions.0.query.bias",
            "value": "mid_block.attentions.0.value.bias",
        },
        "middle_block.1.qkv.weight": {
            "key": "mid_block.attentions.0.key.weight",
            "query": "mid_block.attentions.0.query.weight",
            "value": "mid_block.attentions.0.value.weight",
        },
    }
    assign_to_checkpoint(
        attentions_paths, new_checkpoint, checkpoint, attention_paths_to_split=to_split, config=config
    )

    for i in range(num_output_blocks):
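        # Mirror of the input-block bookkeeping: output blocks come in runs of (num_res_blocks + 1) per
        # resolution, so block_id and layer_in_block_id locate the up block and the slot inside it.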
        block_id = i // (config["num_res_blocks"] + 1)
        layer_in_block_id = i % (config["num_res_blocks"] + 1)
        output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
        output_block_list = {}

        for layer in output_block_layers:
            layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
            if layer_id in output_block_list:
                output_block_list[layer_id].append(layer_name)
            else:
                output_block_list[layer_id] = [layer_name]

        if len(output_block_list) > 1:
            resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
            attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]

            resnet_0_paths = renew_resnet_paths(resnets)
            paths = renew_resnet_paths(resnets)

            meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}
            assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[meta_path], config=config)

            if ["conv.weight", "conv.bias"] in output_block_list.values():
                index = list(output_block_list.values()).index(["conv.weight", "conv.bias"])
                new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = checkpoint[
                    f"output_blocks.{i}.{index}.conv.weight"
                ]
                new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = checkpoint[
                    f"output_blocks.{i}.{index}.conv.bias"
                ]

                # The two keys matching output_blocks.{i}.1 were the upsampler conv handled just above,
                # not attention weights, so clear them here.
                if len(attentions) == 2:
                    attentions = []

            if len(attentions):
                paths = renew_attention_paths(attentions)
                meta_path = {
                    "old": f"output_blocks.{i}.1",
                    "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
                }
                to_split = {
                    f"output_blocks.{i}.1.qkv.bias": {
                        "key": f"up_blocks.{block_id}.attentions.{layer_in_block_id}.key.bias",
                        "query": f"up_blocks.{block_id}.attentions.{layer_in_block_id}.query.bias",
                        "value": f"up_blocks.{block_id}.attentions.{layer_in_block_id}.value.bias",
                    },
                    f"output_blocks.{i}.1.qkv.weight": {
                        "key": f"up_blocks.{block_id}.attentions.{layer_in_block_id}.key.weight",
                        "query": f"up_blocks.{block_id}.attentions.{layer_in_block_id}.query.weight",
                        "value": f"up_blocks.{block_id}.attentions.{layer_in_block_id}.value.weight",
                    },
                }
                assign_to_checkpoint(
                    paths,
                    new_checkpoint,
                    checkpoint,
                    additional_replacements=[meta_path],
                    attention_paths_to_split=to_split if any("qkv" in key for key in attentions) else None,
                    config=config,
                )
        else:
            resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
            for path in resnet_0_paths:
                old_path = ".".join(["output_blocks", str(i), path["old"]])
                new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])

                new_checkpoint[new_path] = checkpoint[old_path]

    return new_checkpoint


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert."
    )

    parser.add_argument(
        "--config_file",
        default=None,
        type=str,
        required=True,
        help="The config json file corresponding to the architecture.",
    )

    parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")

    args = parser.parse_args()

    checkpoint = torch.load(args.checkpoint_path)

    with open(args.config_file) as f:
        config = json.loads(f.read())

    converted_checkpoint = convert_ldm_checkpoint(checkpoint, config)

    if "ldm" in config:
        del config["ldm"]

    model = UNet2DModel(**config)
    model.load_state_dict(converted_checkpoint)

    try:
        scheduler = DDPMScheduler.from_config("/".join(args.checkpoint_path.split("/")[:-1]))
        vqvae = VQModel.from_pretrained("/".join(args.checkpoint_path.split("/")[:-1]))

        pipe = LDMPipeline(unet=model, scheduler=scheduler, vae=vqvae)
        pipe.save_pretrained(args.dump_path)
    except Exception:
        # Could not build the full LDMPipeline (e.g. no scheduler / VQ-VAE configs next to the
        # checkpoint); save the bare UNet instead.
        model.save_pretrained(args.dump_path)
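
# Example invocation (paths are placeholders):
#   python convert_ldm_original_checkpoint_to_diffusers.py \
#       --checkpoint_path /path/to/ldm/model.ckpt \
#       --config_file /path/to/unet_config.json \
#       --dump_path ./converted_ldm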