convert_megatron_gpt2_checkpoint.py 16.8 KB
Newer Older
liangjing's avatar
update  
liangjing committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
####################################################################################################

# Copyright (c) 2021-, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

####################################################################################################

#
# Note: If running this conversion script raises the exception:
#     ModuleNotFoundError: No module named 'megatron.model.enums'
# you need to tell python where to find the clone of Megatron-LM, e.g.:
#
# cd /tmp
# git clone https://github.com/NVIDIA/Megatron-LM
# PYTHONPATH=/tmp/Megatron-LM python src/transformers/models/megatron_gpt2/convert_megatron_gpt2_checkpoint.py ...
#
# if you already have it cloned elsewhere, simply adjust the path to the existing path
#
# If the training was done using a Megatron-LM fork, e.g.,
# https://github.com/microsoft/Megatron-DeepSpeed/ then chances are that you need to have that one
# in your path, i.e., /path/to/Megatron-DeepSpeed/
#

import argparse
import os
import re
import zipfile

import torch

from transformers import AutoTokenizer, GPT2Config
import pdb

####################################################################################################


def recursive_print(name, val, spaces=0):
    """Print a nested dict of tensors/values as an indented tree.

    Args:
        name: Label printed for ``val``. ``None`` means "no label", which is how
            the root of the tree is printed.
        val: A dict (each entry is printed recursively), a ``torch.Tensor``
            (only its size is printed) or any other value (printed as-is).
        spaces: Current indentation depth; it grows by 2 per nesting level and
            shrinks the padded label width accordingly.
    """
    label = None
    if name is not None:
        dots = "." * max(0, spaces - 2)
        label = f"{dots}# {name:{50 - spaces}s}"

    # Leaves: print the label followed by the value (or the tensor's size).
    if not isinstance(val, dict):
        shown = val.size() if isinstance(val, torch.Tensor) else val
        print(label, ":", shown)
        return

    # Inner node: print its own label (if any), then descend one level.
    if label is not None:
        print(label)
    for child_name in val.keys():
        recursive_print(child_name, val[child_name], spaces + 2)


def fix_query_key_value_ordering(param, checkpoint_version, num_splits, num_heads, hidden_size):
    """Reorder the rows of a fused attention tensor to ``[num_splits, num_heads, hidden_size]``.

    This gives compatibility with later versions of NVIDIA Megatron-LM. Megatron performs the
    inverse operation when reading checkpoints:
    https://github.com/NVIDIA/Megatron-LM/blob/v2.4/megatron/checkpointing.py#L209

    Only dim 0 is regrouped. Trailing dims (for example the input features of a weight) are
    carried along unchanged, so 2-D weights and 1-D biases are both accepted. The returned
    tensor has the same shape as ``param``.

    Row layout by ``checkpoint_version``:
      * ``1.0``   -> stored as ``[num_heads, hidden_size, num_splits]``
      * ``>= 2.0`` -> stored as ``[num_heads, num_splits, hidden_size]``
      * anything else (e.g. ``0.0``) -> returned in its original layout.

    If this is the weight of the self-attention block, it still has to be transposed
    once more before HuggingFace GPT2 can read it.
    """
    original_shape = param.size()
    trailing = tuple(original_shape[1:])

    if checkpoint_version == 1.0:
        stored = param.view(num_heads, hidden_size, num_splits, *trailing)
        rest_axes = tuple(range(3, stored.dim()))
        # (heads, hidden, splits, ...) -> (splits, heads, hidden, ...)
        param = stored.permute(2, 0, 1, *rest_axes).contiguous()
    elif checkpoint_version >= 2.0:
        stored = param.view(num_heads, num_splits, hidden_size, *trailing)
        rest_axes = tuple(range(3, stored.dim()))
        # (heads, splits, hidden, ...) -> (splits, heads, hidden, ...)
        param = stored.permute(1, 0, 2, *rest_axes).contiguous()

    return param.view(*original_shape)


####################################################################################################


def _print_model_keys(model):
    """Print the first two levels of keys of the Megatron ``model`` dict (debug aid).

    Values that are not dicts are reported by type name instead of failing on ``.keys()``.
    """
    for key, sub in model.items():
        if not isinstance(sub, dict):
            print(f">> {key} in model: {type(sub).__name__}")
            continue
        print(f">> {key} in model: {sub.keys()}")
        for sub_key, leaf in sub.items():
            if isinstance(leaf, dict):
                print(f"\t>> {sub_key} in {key} in model: {leaf.keys()}")


def _split_swiglu_fc1(val, ffn_hidden_size, tp_degree):
    """Split a merged SwiGLU ``mlp.dense_h_to_4h`` tensor into ``(gate, up)``.

    The rows are expected in tensor-parallel interleaved order
    ``[gate_0, up_0, gate_1, up_1, ..., gate_{tp-1}, up_{tp-1}]``. Each piece has
    ``ffn_hidden_size // tp_degree`` rows, one gate/up pair per original TP rank.
    Slicing is along dim 0 only, so 2-D weights and 1-D biases are both handled.

    Raises:
        ValueError: ``tp_degree`` does not evenly divide ``ffn_hidden_size``, or ``val``
            does not have ``2 * ffn_hidden_size`` rows.
    """
    if tp_degree < 1 or ffn_hidden_size % tp_degree != 0:
        raise ValueError(
            f"Cannot split ffn_hidden_size={ffn_hidden_size} evenly across origin_tp_degree={tp_degree}."
        )
    if val.size(0) != 2 * ffn_hidden_size:
        raise ValueError(
            f"mlp.dense_h_to_4h has {val.size(0)} rows but a SwiGLU layer with "
            f"ffn_hidden_size={ffn_hidden_size} needs {2 * ffn_hidden_size}."
        )
    pieces = torch.split(val, ffn_hidden_size // tp_degree, dim=0)
    # Even-numbered pieces belong to the gate projection, odd-numbered ones to the up projection.
    return torch.cat(pieces[0::2], dim=0), torch.cat(pieces[1::2], dim=0)


def convert_megatron_checkpoint(args, input_state_dict, config, origin_tp_degree=1):
    """Convert a Megatron-LM checkpoint into a Hugging Face LLaMA-layout state dict.

    Args:
        args: Parsed CLI arguments. Currently unused; kept for interface compatibility.
        input_state_dict: The loaded Megatron checkpoint. Weights are read from
            ``"model"`` (or ``"module"``). Optional ``"args"`` and ``"checkpoint_version"``
            entries are honoured.
        config: Config object. Its ``vocab_size``, ``n_positions``, ``n_embd``, ``n_layer``,
            ``n_head`` and ``n_inner`` are overwritten in place from
            ``input_state_dict["args"]`` when that entry is present.
        origin_tp_degree: Tensor-parallel degree the checkpoint was trained with. It
            determines how the fused SwiGLU ``mlp.dense_h_to_4h`` rows are de-interleaved.

    Returns:
        A dict mapping HF parameter names (``model.embed_tokens.weight``,
        ``model.layers.{i}.*``, ``model.norm.weight``, ``lm_head.weight``) to tensors.

    Raises:
        KeyError: a per-layer parameter has no known HF mapping.
        ValueError: tensor sizes are inconsistent with the config, or the layer indices
            found do not cover exactly ``range(config.n_layer)``.
    """
    output_state_dict = {}

    # Old versions did not store training args. When they are present, take the exact
    # dimensions from the checkpoint instead of requiring a config file.
    ds_args = input_state_dict.get("args", None)
    if ds_args is not None:
        config.vocab_size = ds_args.padded_vocab_size
        config.n_positions = ds_args.max_position_embeddings
        config.n_embd = ds_args.hidden_size
        config.n_layer = ds_args.num_layers
        config.n_head = ds_args.num_attention_heads
        config.n_inner = ds_args.ffn_hidden_size

    heads = config.n_head
    hidden_size_per_head = config.n_embd // config.n_head
    n_embd = config.n_embd
    checkpoint_version = input_state_dict.get("checkpoint_version", 0.0)

    model = input_state_dict["model"] if "model" in input_state_dict else input_state_dict["module"]
    _print_model_keys(model)

    lm = model["language_model"]
    embeddings = lm["embedding"]

    # Word embeddings, truncated to vocab_size rows.
    word_embeddings = embeddings["word_embeddings"]["weight"][: config.vocab_size, :]
    output_state_dict["model.embed_tokens.weight"] = word_embeddings

    # Untied output layer (LLaMA2-style) if present, otherwise tie to the embeddings.
    # Truncated the same way so lm_head and embed_tokens agree on vocab size.
    if "output_layer" in lm:
        lm_head = lm["output_layer"]["weight"][: config.vocab_size, :]
    else:
        lm_head = word_embeddings

    # This converter exports no position-embedding table (no "embeddings.position_embeddings" read).

    transformer = lm["transformer"] if "transformer" in lm else lm["encoder"]

    layer_re = re.compile(r"layers\.(\d+)\.([a-z0-9_.]+)\.([a-z]+)")

    # Straight renames for projections that need no reshaping.
    megatron_to_transformers = {
        "attention.dense": ".self_attn.o_proj.",
        "self_attention.dense": ".self_attn.o_proj.",
        "mlp.dense_4h_to_h": ".mlp.down_proj.",
    }

    seen_layers = set()
    for key, val in transformer.items():
        m = layer_re.match(key)
        if m is None:
            # Not a per-layer parameter (e.g. final_layernorm.*); handled after the loop.
            continue

        layer_idx = int(m.group(1))
        op_name = m.group(2)
        weight_or_bias = m.group(3)
        seen_layers.add(layer_idx)
        if weight_or_bias not in ("weight", "bias"):
            continue

        layer_name = f"model.layers.{layer_idx}"
        attn_prefix = layer_name + ".self_attn."

        if op_name.endswith("layernorm"):
            # input_layernorm / post_attention_layernorm keep their Megatron names.
            output_state_dict[f"{layer_name}.{op_name}.{weight_or_bias}"] = val

        elif op_name in ("attention.query_key_value", "self_attention.query_key_value"):
            # Fused multi-head QKV: regroup rows to [q; k; v], then cut into three projections.
            qkv = fix_query_key_value_ordering(val, checkpoint_version, 3, heads, hidden_size_per_head)
            qkv = qkv.contiguous()
            output_state_dict[attn_prefix + "q_proj." + weight_or_bias] = qkv[:n_embd]
            output_state_dict[attn_prefix + "k_proj." + weight_or_bias] = qkv[n_embd : 2 * n_embd]
            output_state_dict[attn_prefix + "v_proj." + weight_or_bias] = qkv[2 * n_embd :]

        elif op_name == "self_attention.query":
            q = fix_query_key_value_ordering(val, checkpoint_version, 1, heads, hidden_size_per_head)
            output_state_dict[attn_prefix + "q_proj." + weight_or_bias] = q.contiguous()

        elif op_name == "self_attention.key_value":
            # Grouped-query attention: the tensor holds 2 * num_kv_heads * head_dim rows, so
            # the KV head count follows from its shape instead of being hard-coded.
            num_kv_heads, remainder = divmod(val.size(0), 2 * hidden_size_per_head)
            if remainder or num_kv_heads == 0:
                raise ValueError(
                    f"{key}: {val.size(0)} rows is not a multiple of 2 * head_dim = {2 * hidden_size_per_head}."
                )
            kv = fix_query_key_value_ordering(val, checkpoint_version, 2, num_kv_heads, hidden_size_per_head)
            kv = kv.contiguous()
            size_per_weight = val.size(0) // 2
            output_state_dict[attn_prefix + "k_proj." + weight_or_bias] = kv[:size_per_weight]
            output_state_dict[attn_prefix + "v_proj." + weight_or_bias] = kv[size_per_weight:]

        elif op_name == "mlp.dense_h_to_4h":
            # SwiGLU fc1 fuses gate and up projections; fall back to half the rows if n_inner is unset.
            ffn_hidden_size = config.n_inner if config.n_inner is not None else val.size(0) // 2
            gate, up = _split_swiglu_fc1(val, ffn_hidden_size, origin_tp_degree)
            output_state_dict[f"{layer_name}.mlp.gate_proj.{weight_or_bias}"] = gate
            output_state_dict[f"{layer_name}.mlp.up_proj.{weight_or_bias}"] = up

        else:
            out_name = megatron_to_transformers.get(op_name)
            if out_name is None:
                raise KeyError(f"No Hugging Face mapping for Megatron op '{op_name}' (key '{key}').")
            output_state_dict[layer_name + out_name + weight_or_bias] = val

    if seen_layers != set(range(config.n_layer)):
        raise ValueError(
            f"Expected layers 0..{config.n_layer - 1}, found {sorted(seen_layers)} in the checkpoint."
        )

    # Final norm (weight only; no bias is exported).
    output_state_dict["model.norm.weight"] = transformer["final_layernorm.weight"]
    output_state_dict["lm_head.weight"] = lm_head

    return output_state_dict


####################################################################################################


def main():
    """Command-line entry point: convert a Megatron-LM checkpoint to ``pytorch_model.bin``.

    The converted state dict is written to the directory that contains the input
    checkpoint. The config built here only drives the conversion; this function writes
    neither a config file nor tokenizer files.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--print-checkpoint-structure", action="store_true")
    parser.add_argument(
        "path_to_checkpoint",
        type=str,
        help="Path to the checkpoint file (.zip archive or direct .pt file)",
    )
    parser.add_argument(
        "--config_file",
        default="",
        type=str,
        help="An optional config json file describing the pre-trained model.",
    )
    parser.add_argument(
        "--origin_tp_degree",
        default=1,
        type=int,
        help="Tensor-parallel degree the checkpoint was trained with (controls how mlp.dense_h_to_4h is split).",
    )
    parser.add_argument(
        "--tokenizer",
        default="",
        type=str,
        help="Optional tokenizer name or path; if given, its class name is recorded in config.tokenizer_class.",
    )
    args = parser.parse_args()

    # Outputs are written next to the input checkpoint.
    output_dir = os.path.dirname(args.path_to_checkpoint)

    # Load the model. The .zip path is optional and kept for backward compatibility.
    # NOTE(review): torch.load unpickles arbitrary objects; only convert trusted checkpoints.
    print(f"Extracting PyTorch state dictionary from {args.path_to_checkpoint}")
    if args.path_to_checkpoint.endswith(".zip"):
        with zipfile.ZipFile(args.path_to_checkpoint, "r") as checkpoint:
            with checkpoint.open("release/mp_rank_00/model_optim_rng.pt") as pytorch_dict:
                input_state_dict = torch.load(pytorch_dict, map_location="cpu")
    else:
        input_state_dict = torch.load(args.path_to_checkpoint, map_location="cpu")
    print(f">> keys: {input_state_dict.keys()}")
    ds_args = input_state_dict.get("args", None)

    # Read the config, or default to the model released by NVIDIA.
    if args.config_file == "":
        if ds_args is not None:
            # Not every Megatron version defines these flags; treat a missing flag as False.
            if getattr(ds_args, "bias_gelu_fusion", False):
                activation_function = "gelu_fast"
            elif getattr(ds_args, "openai_gelu", False):
                activation_function = "gelu_new"
            else:
                activation_function = "gelu"
        else:
            # In the very early days this used to be "gelu_new".
            activation_function = "gelu_new"

        # Spell out all parameters in case the defaults change.
        config = GPT2Config(
            vocab_size=50257,
            n_positions=1024,
            n_embd=1024,
            n_layer=24,
            n_head=16,
            n_inner=4096,
            activation_function=activation_function,
            resid_pdrop=0.1,
            embd_pdrop=0.1,
            attn_pdrop=0.1,
            layer_norm_epsilon=1e-5,
            initializer_range=0.02,
            summary_type="cls_index",
            summary_use_proj=True,
            summary_activation=None,
            summary_proj_to_labels=True,
            summary_first_dropout=0.1,
            scale_attn_weights=True,
            use_cache=True,
            bos_token_id=50256,
            eos_token_id=50256,
        )
    else:
        config = GPT2Config.from_json_file(args.config_file)

    config.architectures = ["GPT2LMHeadModel"]

    print("Converting")
    output_state_dict = convert_megatron_checkpoint(
        args, input_state_dict, config, origin_tp_degree=args.origin_tp_degree
    )

    if args.print_checkpoint_structure:
        recursive_print(None, output_state_dict)

    # Loading a tokenizer from an empty name always fails, so only do it when one is given.
    if args.tokenizer:
        tokenizer = AutoTokenizer.from_pretrained(args.tokenizer)
        config.tokenizer_class = type(tokenizer).__name__
        print(f"Using {config.tokenizer_class} tokenizer")

    # Only the weights are written; config and tokenizer files must be supplied separately.
    print("Note: config and tokenizer files are not saved by this converter")

    output_checkpoint_file = os.path.join(output_dir, "pytorch_model.bin")
    print(f'Saving checkpoint to "{output_checkpoint_file}"')
    torch.save(output_state_dict, output_checkpoint_file)


####################################################################################################

if __name__ == "__main__":
    # Run the converter only when this file is executed as a script, not when it is imported.
    main()

####################################################################################################