# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Conversion script for the LDM checkpoints. """

import argparse
import json

import torch

from diffusers import DDPMScheduler, LDMPipeline, UNet2DModel, VQModel


def shave_segments(path, n_shave_prefix_segments=1):
    """
    Removes segments from a dot-separated path. Positive values shave the first segments, negative values shave the last segments.
    """
    if n_shave_prefix_segments >= 0:
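        # Illustrative example (not part of the original script):
        #   shave_segments("output_blocks.1.0.in_layers.0.weight", 2) -> "0.in_layers.0.weight"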
        return ".".join(path.split(".")[n_shave_prefix_segments:])
    else:
        return ".".join(path.split(".")[:n_shave_prefix_segments])


def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
    """
    Updates paths inside resnets to the new naming scheme (local renaming)
    """
    mapping = []
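    # Illustrative example (not part of the original script): with the default
    # n_shave_prefix_segments=0, "input_blocks.1.0.in_layers.0.weight" becomes
    # "input_blocks.1.0.norm1.weight".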
    for old_item in old_list:
        new_item = old_item.replace("in_layers.0", "norm1")
        new_item = new_item.replace("in_layers.2", "conv1")

        new_item = new_item.replace("out_layers.0", "norm2")
        new_item = new_item.replace("out_layers.3", "conv2")

        new_item = new_item.replace("emb_layers.1", "time_emb_proj")
        new_item = new_item.replace("skip_connection", "conv_shortcut")

        new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)

        mapping.append({"old": old_item, "new": new_item})

    return mapping


def renew_attention_paths(old_list, n_shave_prefix_segments=0):
    """
    Updates paths inside attentions to the new naming scheme (local renaming)
    """
    mapping = []
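    # Illustrative example (not part of the original script):
    #   "middle_block.1.norm.weight" becomes "middle_block.1.group_norm.weight".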
    for old_item in old_list:
        new_item = old_item

        new_item = new_item.replace("norm.weight", "group_norm.weight")
        new_item = new_item.replace("norm.bias", "group_norm.bias")

        new_item = new_item.replace("proj_out.weight", "proj_attn.weight")
        new_item = new_item.replace("proj_out.bias", "proj_attn.bias")

        new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)

        mapping.append({"old": old_item, "new": new_item})

    return mapping


def assign_to_checkpoint(
    paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
):
    """
    This does the final conversion step: it takes the locally converted weights, applies the
    global renaming to them, splits attention layers where requested, and applies any additional
    replacements before assigning the weights to the new checkpoint.
    """
    assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."

    # Splits the attention layers into three variables.
    if attention_paths_to_split is not None:
        for path, path_map in attention_paths_to_split.items():
            old_tensor = old_checkpoint[path]
            channels = old_tensor.shape[0] // 3

            target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)

            num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3

            old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
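            # With the heads pulled out into dim 0, splitting dim 1 into chunks of
            # channels // num_heads yields the per-head query, key and value tensors.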
            query, key, value = old_tensor.split(channels // num_heads, dim=1)

            checkpoint[path_map["query"]] = query.reshape(target_shape)
            checkpoint[path_map["key"]] = key.reshape(target_shape)
            checkpoint[path_map["value"]] = value.reshape(target_shape)

    for path in paths:
        new_path = path["new"]

        # These have already been assigned
        if attention_paths_to_split is not None and new_path in attention_paths_to_split:
            continue

        # Global renaming happens here
        new_path = new_path.replace("middle_block.0", "mid_block.resnets.0")
        new_path = new_path.replace("middle_block.1", "mid_block.attentions.0")
        new_path = new_path.replace("middle_block.2", "mid_block.resnets.1")

        if additional_replacements is not None:
            for replacement in additional_replacements:
                new_path = new_path.replace(replacement["old"], replacement["new"])

        # proj_attn.weight has to be converted from conv 1D to linear
        if "proj_attn.weight" in new_path:
            checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
        else:
            checkpoint[new_path] = old_checkpoint[path["old"]]


def convert_ldm_checkpoint(checkpoint, config):
    """
    Takes a state dict and a config, and returns a converted checkpoint.
    """
    new_checkpoint = {}

    new_checkpoint["time_embedding.linear_1.weight"] = checkpoint["time_embed.0.weight"]
    new_checkpoint["time_embedding.linear_1.bias"] = checkpoint["time_embed.0.bias"]
    new_checkpoint["time_embedding.linear_2.weight"] = checkpoint["time_embed.2.weight"]
    new_checkpoint["time_embedding.linear_2.bias"] = checkpoint["time_embed.2.bias"]

    new_checkpoint["conv_in.weight"] = checkpoint["input_blocks.0.0.weight"]
    new_checkpoint["conv_in.bias"] = checkpoint["input_blocks.0.0.bias"]

    new_checkpoint["conv_norm_out.weight"] = checkpoint["out.0.weight"]
    new_checkpoint["conv_norm_out.bias"] = checkpoint["out.0.bias"]
    new_checkpoint["conv_out.weight"] = checkpoint["out.2.weight"]
    new_checkpoint["conv_out.bias"] = checkpoint["out.2.bias"]

    # Retrieves the keys for the input blocks only
    num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in checkpoint if "input_blocks" in layer})
    input_blocks = {
        layer_id: [key for key in checkpoint if f"input_blocks.{layer_id}" in key]
        for layer_id in range(num_input_blocks)
    }

    # Retrieves the keys for the middle blocks only
    num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in checkpoint if "middle_block" in layer})
    middle_blocks = {
        layer_id: [key for key in checkpoint if f"middle_block.{layer_id}" in key]
        for layer_id in range(num_middle_blocks)
    }

    # Retrieves the keys for the output blocks only
    num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in checkpoint if "output_blocks" in layer})
    output_blocks = {
        layer_id: [key for key in checkpoint if f"output_blocks.{layer_id}" in key]
        for layer_id in range(num_output_blocks)
    }

    for i in range(1, num_input_blocks):
        block_id = (i - 1) // (config["num_res_blocks"] + 1)
        layer_in_block_id = (i - 1) % (config["num_res_blocks"] + 1)

        resnets = [key for key in input_blocks[i] if f"input_blocks.{i}.0" in key]
        attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]

        if f"input_blocks.{i}.0.op.weight" in checkpoint:
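            # The "op" entry is the downsampling convolution of this LDM input block; assign
            # it to the diffusers downsampler and move on to the next input block.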
            new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = checkpoint[
                f"input_blocks.{i}.0.op.weight"
            ]
            new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = checkpoint[
                f"input_blocks.{i}.0.op.bias"
            ]
            continue

        paths = renew_resnet_paths(resnets)
        meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
        resnet_op = {"old": "resnets.2.op", "new": "downsamplers.0.op"}
        assign_to_checkpoint(
            paths, new_checkpoint, checkpoint, additional_replacements=[meta_path, resnet_op], config=config
        )

        if len(attentions):
            paths = renew_attention_paths(attentions)
            meta_path = {
                "old": f"input_blocks.{i}.1",
                "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}",
            }
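            # The LDM checkpoint stores a single fused qkv projection per attention block;
            # split it into the separate query/key/value parameters expected by diffusers.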
            to_split = {
                f"input_blocks.{i}.1.qkv.bias": {
                    "key": f"down_blocks.{block_id}.attentions.{layer_in_block_id}.key.bias",
                    "query": f"down_blocks.{block_id}.attentions.{layer_in_block_id}.query.bias",
                    "value": f"down_blocks.{block_id}.attentions.{layer_in_block_id}.value.bias",
                },
                f"input_blocks.{i}.1.qkv.weight": {
                    "key": f"down_blocks.{block_id}.attentions.{layer_in_block_id}.key.weight",
                    "query": f"down_blocks.{block_id}.attentions.{layer_in_block_id}.query.weight",
                    "value": f"down_blocks.{block_id}.attentions.{layer_in_block_id}.value.weight",
                },
            }
            assign_to_checkpoint(
                paths,
                new_checkpoint,
                checkpoint,
                additional_replacements=[meta_path],
                attention_paths_to_split=to_split,
                config=config,
            )

    resnet_0 = middle_blocks[0]
    attentions = middle_blocks[1]
    resnet_1 = middle_blocks[2]

    resnet_0_paths = renew_resnet_paths(resnet_0)
    assign_to_checkpoint(resnet_0_paths, new_checkpoint, checkpoint, config=config)

    resnet_1_paths = renew_resnet_paths(resnet_1)
    assign_to_checkpoint(resnet_1_paths, new_checkpoint, checkpoint, config=config)

    attentions_paths = renew_attention_paths(attentions)
    to_split = {
        "middle_block.1.qkv.bias": {
            "key": "mid_block.attentions.0.key.bias",
            "query": "mid_block.attentions.0.query.bias",
            "value": "mid_block.attentions.0.value.bias",
        },
        "middle_block.1.qkv.weight": {
            "key": "mid_block.attentions.0.key.weight",
            "query": "mid_block.attentions.0.query.weight",
            "value": "mid_block.attentions.0.value.weight",
        },
    }
    assign_to_checkpoint(
        attentions_paths, new_checkpoint, checkpoint, attention_paths_to_split=to_split, config=config
    )

    for i in range(num_output_blocks):
        block_id = i // (config["num_res_blocks"] + 1)
        layer_in_block_id = i % (config["num_res_blocks"] + 1)
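        # Group the parameter names of this output block by sub-layer index so that resnet,
        # attention and upsampler weights can be routed to the right diffusers sub-module.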
        output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
        output_block_list = {}

        for layer in output_block_layers:
            layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
            if layer_id in output_block_list:
                output_block_list[layer_id].append(layer_name)
            else:
                output_block_list[layer_id] = [layer_name]

        if len(output_block_list) > 1:
            resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
            attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]

            resnet_0_paths = renew_resnet_paths(resnets)
            paths = renew_resnet_paths(resnets)

            meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}
            assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[meta_path], config=config)

            if ["conv.weight", "conv.bias"] in output_block_list.values():
                index = list(output_block_list.values()).index(["conv.weight", "conv.bias"])
                new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = checkpoint[
                    f"output_blocks.{i}.{index}.conv.weight"
                ]
                new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = checkpoint[
                    f"output_blocks.{i}.{index}.conv.bias"
                ]

                # Clear attentions as they have been attributed above.
                if len(attentions) == 2:
                    attentions = []

            if len(attentions):
                paths = renew_attention_paths(attentions)
                meta_path = {
                    "old": f"output_blocks.{i}.1",
                    "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
                }
                to_split = {
                    f"output_blocks.{i}.1.qkv.bias": {
                        "key": f"up_blocks.{block_id}.attentions.{layer_in_block_id}.key.bias",
                        "query": f"up_blocks.{block_id}.attentions.{layer_in_block_id}.query.bias",
                        "value": f"up_blocks.{block_id}.attentions.{layer_in_block_id}.value.bias",
                    },
                    f"output_blocks.{i}.1.qkv.weight": {
                        "key": f"up_blocks.{block_id}.attentions.{layer_in_block_id}.key.weight",
                        "query": f"up_blocks.{block_id}.attentions.{layer_in_block_id}.query.weight",
                        "value": f"up_blocks.{block_id}.attentions.{layer_in_block_id}.value.weight",
                    },
                }
                assign_to_checkpoint(
                    paths,
                    new_checkpoint,
                    checkpoint,
                    additional_replacements=[meta_path],
                    attention_paths_to_split=to_split if any("qkv" in key for key in attentions) else None,
                    config=config,
                )
        else:
            resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
            for path in resnet_0_paths:
                old_path = ".".join(["output_blocks", str(i), path["old"]])
                new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])

                new_checkpoint[new_path] = checkpoint[old_path]

    return new_checkpoint


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert."
    )

    parser.add_argument(
        "--config_file",
        default=None,
        type=str,
        required=True,
        help="The config json file corresponding to the architecture.",
    )

    parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")
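    # Example invocation (paths are illustrative, not part of the original script):
    #   python convert_ldm_original_checkpoint_to_diffusers.py \
    #       --checkpoint_path ./ldm_model.ckpt \
    #       --config_file ./ldm_config.json \
    #       --dump_path ./converted_model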

    args = parser.parse_args()

    checkpoint = torch.load(args.checkpoint_path)

    with open(args.config_file) as f:
        config = json.loads(f.read())
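    # The conversion itself only reads "num_res_blocks" and "num_head_channels" from this
    # config; after dropping the "ldm" key below, the remaining entries are forwarded to
    # UNet2DModel.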

    converted_checkpoint = convert_ldm_checkpoint(checkpoint, config)

    if "ldm" in config:
        del config["ldm"]

    model = UNet2DModel(**config)
    model.load_state_dict(converted_checkpoint)

    try:
        scheduler = DDPMScheduler.from_config("/".join(args.checkpoint_path.split("/")[:-1]))
        vqvae = VQModel.from_pretrained("/".join(args.checkpoint_path.split("/")[:-1]))

        pipe = LDMPipeline(unet=model, scheduler=scheduler, vae=vqvae)
        pipe.save_pretrained(args.dump_path)
    except:  # noqa: E722
        model.save_pretrained(args.dump_path)