convert_bert_ckpt_to_deepspeed.py 14.1 KB
Newer Older
Pan,Huiwen's avatar
Pan,Huiwen committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
# coding=utf-8
# This script is based on the following HuggingFace file:
#   https://github.com/huggingface/transformers/blob/d541938/src/transformers/modeling_bert.py
#
# It converts TensorFlow and HuggingFace checkpoint files to DeepSpeed.

import os
import argparse
import logging
import torch
import re
import numpy as np

# Module-level logger shared by the checkpoint loaders for progress messages.
logger = logging.getLogger(__name__)
# NOTE: this configures the root logger at import time so INFO-level
# messages become visible.
logging.basicConfig(level=logging.INFO)

def set_data(param, array):
    """Replace ``param``'s data with a tensor built from a NumPy array.

    Args:
        param: tensor or parameter whose ``.data`` is overwritten.
        array: NumPy array; its shape must equal ``param.shape``.

    Raises:
        AssertionError: if the shapes differ. ``args`` is
            ``(param.shape, array.shape)`` to aid debugging.
    """
    # Explicit check rather than an ``assert`` statement, so the guard is not
    # stripped under ``python -O``. The exception type and args are the same
    # as those produced by the previous assert-based version.
    if param.shape != array.shape:
        raise AssertionError(param.shape, array.shape)
    # torch.from_numpy shares memory with ``array``; no copy is made.
    param.data = torch.from_numpy(array)

def load_tf_weights_in_bert_kernel(model, ckpt_path, voc_size_diff):
    """ Load tf checkpoints in DeepSpeed model.

    Copies every variable of a TensorFlow BERT checkpoint into the matching
    parameter of ``model``. The model is expected to use the DeepSpeed fused
    transformer kernel: each encoder layer exposes ``attn_qkvw``/``attn_qkvb``,
    ``attn_ow``/``attn_ob``, ``attn_nw``/``attn_nb``, ``inter_w``/``inter_b``,
    ``output_w``/``output_b`` and ``norm_w``/``norm_b``, and the pooler's
    dense layer is named ``dense_act``.

    Args:
        model: DeepSpeed BERT model; its parameters are overwritten in place.
        ckpt_path: TensorFlow checkpoint path, passed to
            ``tf.train.list_variables`` / ``tf.train.load_variable``.
        voc_size_diff: number of zero rows appended to the word-embedding
            matrix (the DeepSpeed model's vocabulary size is 8-aligned);
            ignored when not positive.

    Returns:
        The same ``model`` object.

    Raises:
        ImportError: if TensorFlow cannot be imported.
        ValueError: if a variable inside a ``layer_N`` scope matches none of
            the known transformer-layer weights.
        AssertionError: from ``set_data`` when a checkpoint array's shape
            differs from the target parameter's shape.
    """
    try:
        # NOTE(review): re and numpy are already imported at module level;
        # only the TensorFlow import is actually required here.
        import re
        import numpy as np
        import tensorflow as tf
    except ImportError:
        logger.error(
            "Loading a TensorFlow model in DeepSpeed, requires TensorFlow to be installed. Please see "
            "https://www.tensorflow.org/install/ for installation instructions."
        )
        raise
    tf_path = os.path.abspath(ckpt_path)
    logger.info("Converting TensorFlow checkpoint from {}".format(tf_path))
    # Load weights from TF model: read every variable into memory up front.
    init_vars = tf.train.list_variables(tf_path)
    names = []
    arrays = []
    for name, shape in init_vars:
        logger.info("Loading TF weight {} with shape {}".format(name, shape))
        array = tf.train.load_variable(tf_path, name)
        names.append(name)
        arrays.append(array)

    # Pending Q/K/V pieces ("qw", "kw", "vw", "qb", "kb", "vb") waiting to be
    # fused into a single attn_qkvw / attn_qkvb parameter.
    qkv = {}
    for name_str, array in zip(names, arrays):
        name = name_str.split("/")
        # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculate m and v
        # which are not required for using pretrained model
        if any(
            n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"]
            for n in name
        ):
            logger.info("Skipping {}".format("/".join(name)))
            continue
        pointer = model
        # Set to one of the pending-QKV keys when this variable is a Q/K/V piece.
        key = None
        skipping = False
        # Walk the model's attribute tree along the TF variable path.
        for m_name in name:
            # Split indexed scopes, e.g. "layer_3" -> ["layer", "3", ""].
            if re.fullmatch(r"[A-Za-z]+_\d+", m_name):
                scope_names = re.split(r"_(\d+)", m_name)
            else:
                scope_names = [m_name]

            # Translate TF leaf names to PyTorch attribute names.
            if scope_names[0] == "kernel" or scope_names[0] == "gamma":
                pointer = getattr(pointer, "weight")
            elif scope_names[0] == "output_bias" or scope_names[0] == "beta":
                pointer = getattr(pointer, "bias")
            elif scope_names[0] == "output_weights":
                pointer = getattr(pointer, "weight")
            elif scope_names[0] == "squad":
                pointer = getattr(pointer, "classifier")
            # Special in deepspeed: the pooler's dense layer is named dense_act.
            elif name_str.find("bert/pooler/dense") >= 0 and scope_names[0] == "dense":
                pointer = getattr(pointer, "dense_act")
            # NOTE(review): the two LayerNorm branches below are unreachable --
            # "gamma"/"beta" are already caught by the first two branches,
            # which map them to the same weight/bias attributes.
            elif name_str.find("bert/embeddings/LayerNorm/gamma") >= 0 and scope_names[0] == "gamma":
                pointer = getattr(pointer, "weight")
            elif name_str.find("bert/embeddings/LayerNorm/beta") >= 0 and scope_names[0] == "beta":
                pointer = getattr(pointer, "bias")
            else:
                try:
                    pointer = getattr(pointer, scope_names[0])
                except AttributeError:
                    # The model has no such attribute; drop this variable.
                    logger.info("Skipping {}".format("/".join(name)))
                    skipping = True
                    break

            if len(scope_names) >= 2:
                num = int(scope_names[1])

                pointer = pointer[num]

                # For transformer kernel layers: map the remainder of the
                # variable path onto the fused layer's attribute names. Q/K/V
                # pieces only record a key here; they are fused after the loop.
                if scope_names[0] == 'layer':
                    if name_str.find("attention/self/query/kernel") > 0:
                        key = "qw"
                    elif name_str.find("attention/self/query/bias") > 0:
                        key = "qb"
                    elif name_str.find("attention/self/key/kernel") > 0:
                        key = "kw"
                    elif name_str.find("attention/self/key/bias") > 0:
                        key = "kb"
                    elif name_str.find("attention/self/value/kernel") > 0:
                        key = "vw"
                    elif name_str.find("attention/self/value/bias") > 0:
                        key = "vb"
                    elif name_str.find("attention/output/dense/kernel") > 0:
                        pointer = getattr(pointer, "attn_ow")
                    elif name_str.find("attention/output/dense/bias") > 0:
                        pointer = getattr(pointer, "attn_ob")
                    elif name_str.find("attention/output/LayerNorm/gamma") > 0:
                        pointer = getattr(pointer, "attn_nw")
                    elif name_str.find("attention/output/LayerNorm/beta") > 0:
                        pointer = getattr(pointer, "attn_nb")
                    elif name_str.find("intermediate/dense/kernel") > 0:
                        pointer = getattr(pointer, "inter_w")
                    elif name_str.find("intermediate/dense/bias") > 0:
                        pointer = getattr(pointer, "inter_b")
                    elif name_str.find("output/dense/kernel") > 0 and name_str.find("attention") < 0:
                        pointer = getattr(pointer, "output_w")
                    elif name_str.find("output/dense/bias") > 0 and name_str.find("attention") < 0:
                        pointer = getattr(pointer, "output_b")
                    elif name_str.find("output/LayerNorm/gamma") > 0 and name_str.find("attention") < 0:
                        pointer = getattr(pointer, "norm_w")
                    elif name_str.find("output/LayerNorm/beta") > 0 and name_str.find("attention") < 0:
                        pointer = getattr(pointer, "norm_b")
                    else:
                        raise ValueError(f"unexpect scope name {name_str} in transformer layer.")
                    # The whole path has been resolved; stop walking it.
                    break

        if skipping:
            continue

        # m_name still holds the last path component the loop visited.
        if m_name[-11:] == "_embeddings":
            pointer = getattr(pointer, "weight")
        elif "kernel" in name:
            # NOTE(review): presumably TF stores dense kernels as (in, out)
            # while the DeepSpeed parameters are (out, in); confirm.
            array = np.transpose(array)

        if key is not None:
            qkv[key] = array

        # Once all three Q/K/V weights (or biases) are pending, concatenate
        # them along axis 0 into the fused parameter; ``pointer`` is still the
        # encoder layer resolved above. NOTE(review): assumes the checkpoint
        # lists a layer's Q/K/V variables together, before the next layer's.
        if all(k in qkv for k in ("qw", "kw", "vw")):
            array = np.concatenate((qkv["qw"], qkv["kw"], qkv["vw"]), axis=0)
            pointer = getattr(pointer, "attn_qkvw")
            qkv.pop("qw")
            qkv.pop("kw")
            qkv.pop("vw")
        elif all(k in qkv for k in ("qb", "kb", "vb")):
            array = np.concatenate((qkv["qb"], qkv["kb"], qkv["vb"]), axis=0)
            pointer = getattr(pointer, "attn_qkvb")
            qkv.pop("qb")
            qkv.pop("kb")
            qkv.pop("vb")
        elif key is not None:
            # For Q/K/V weight/bias in TF, do nothing if not all ready to merge.
            continue

        # DeepSpeed BERT model has voc_size 8 aligned: pad the word-embedding
        # matrix with voc_size_diff zero rows.
        if voc_size_diff > 0 and name_str.find("embeddings/word_embeddings") >= 0:
            z = np.zeros((voc_size_diff, array.shape[1]), dtype=array.dtype)
            array = np.concatenate((array, z), axis=0)

        set_data(pointer, array)
        logger.info("Initialize DeepSpeed weight {}".format(name))

    return model

def load_hf_weights_in_bert_kernel(model, ckpt_path, voc_size_diff):
    """ Load huggingface checkpoints and convert to a deepspeed model.

    Copies every entry of a HuggingFace BERT checkpoint into the matching
    parameter of ``model``, which is expected to use the DeepSpeed fused
    transformer kernel (per-layer ``attn_qkvw``/``attn_qkvb``, ``attn_ow``,
    ``inter_w``, ``output_w``, ``norm_w``, ... attributes; the pooler's
    dense layer is named ``dense_act``).

    Args:
        model: DeepSpeed BERT model; its parameters are overwritten in place.
        ckpt_path: path to a file loadable with ``torch.load`` whose top-level
            object maps dotted parameter names to tensors.
        voc_size_diff: number of zero rows appended to the word-embedding
            matrix (the DeepSpeed model's vocabulary size is 8-aligned);
            ignored when not positive.

    Returns:
        The same ``model`` object.

    Raises:
        ValueError: if an entry under ``...layer.<N>`` matches none of the
            known transformer-layer weights.
        AssertionError: from ``set_data`` when a checkpoint tensor's shape
            differs from the target parameter's shape.
    """
    hf_path = os.path.abspath(ckpt_path)
    logger.info("Converting Huggingface checkpoint from {}".format(hf_path))
    # Load weights from Huggingface model (tensors mapped to CPU so that
    # .numpy() below works).
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    ckpt = torch.load(hf_path, map_location=torch.device("cpu"))

    # Pending Q/K/V pieces ("qw", "kw", "vw", "qb", "kb", "vb") waiting to be
    # fused into a single attn_qkvw / attn_qkvb parameter.
    qkv = {}
    for name_str in ckpt.keys():
        array = ckpt[name_str].numpy()
        logger.info("Loading Huggingface weight {} with shape {}".format(name_str, array.shape))
        name = name_str.split(".")
        pointer = model
        # Set to one of the pending-QKV keys when this entry is a Q/K/V piece.
        key = None
        # True right after visiting the "layer" container: the next path
        # component is a numeric layer index handled by pointer[num] below.
        is_layer = False
        skipping = False
        # Walk the model's attribute tree along the dotted parameter name.
        for m_name in name:
            # Special in deepspeed: the pooler's dense layer is named dense_act.
            if name_str.find("bert.pooler.dense") >= 0 and m_name == "dense":
                pointer = getattr(pointer, "dense_act")
            elif is_layer:
                # Numeric index: do not getattr, it is applied by indexing below.
                pass
            else:
                try:
                    pointer = getattr(pointer, m_name)
                except AttributeError:
                    # The model has no such attribute; drop this entry.
                    logger.info("Skipping {}".format(".".join(name)))
                    skipping = True
                    break

            if m_name == "layer":
                is_layer = True
                continue

            if m_name.isnumeric() and is_layer:
                num = int(m_name)
                pointer = pointer[num]
                is_layer = False

                # For transformer kernel layers: map the rest of the name onto
                # the fused layer's attribute names. Q/K/V pieces only record
                # a key here; they are fused after the loop.
                if name_str.find("attention.self.query.weight") > 0:
                    key = "qw"
                elif name_str.find("attention.self.query.bias") > 0:
                    key = "qb"
                elif name_str.find("attention.self.key.weight") > 0:
                    key = "kw"
                elif name_str.find("attention.self.key.bias") > 0:
                    key = "kb"
                elif name_str.find("attention.self.value.weight") > 0:
                    key = "vw"
                elif name_str.find("attention.self.value.bias") > 0:
                    key = "vb"
                elif name_str.find("attention.output.dense.weight") > 0:
                    pointer = getattr(pointer, "attn_ow")
                elif name_str.find("attention.output.dense.bias") > 0:
                    pointer = getattr(pointer, "attn_ob")
                elif name_str.find("attention.output.LayerNorm.weight") > 0:
                    pointer = getattr(pointer, "attn_nw")
                elif name_str.find("attention.output.LayerNorm.bias") > 0:
                    pointer = getattr(pointer, "attn_nb")
                elif name_str.find("intermediate.dense.weight") > 0:
                    pointer = getattr(pointer, "inter_w")
                elif name_str.find("intermediate.dense.bias") > 0:
                    pointer = getattr(pointer, "inter_b")
                elif name_str.find("output.dense.weight") > 0 and name_str.find("attention") < 0:
                    pointer = getattr(pointer, "output_w")
                elif name_str.find("output.dense.bias") > 0 and name_str.find("attention") < 0:
                    pointer = getattr(pointer, "output_b")
                elif name_str.find("output.LayerNorm.weight") > 0 and name_str.find("attention") < 0:
                    pointer = getattr(pointer, "norm_w")
                elif name_str.find("output.LayerNorm.bias") > 0 and name_str.find("attention") < 0:
                    pointer = getattr(pointer, "norm_b")
                else:
                    raise ValueError(f"unexpect scope name {name_str} in transformer layer.")
                # The whole name has been resolved; stop walking it.
                break

        if skipping:
            continue

        # Unlike the TF loader, HF arrays are used without transposing.
        if key is not None:
            qkv[key] = array

        # Once all three Q/K/V weights (or biases) are pending, concatenate
        # them along axis 0 into the fused parameter; ``pointer`` is still the
        # encoder layer resolved above. NOTE(review): assumes each layer's
        # Q/K/V entries appear together in ckpt.keys(), before the next layer's.
        if all(k in qkv for k in ("qw", "kw", "vw")):
            array = np.concatenate((qkv["qw"], qkv["kw"], qkv["vw"]), axis=0)
            pointer = getattr(pointer, "attn_qkvw")
            qkv.pop("qw")
            qkv.pop("kw")
            qkv.pop("vw")
        elif all(k in qkv for k in ("qb", "kb", "vb")):
            array = np.concatenate((qkv["qb"], qkv["kb"], qkv["vb"]), axis=0)
            pointer = getattr(pointer, "attn_qkvb")
            qkv.pop("qb")
            qkv.pop("kb")
            qkv.pop("vb")
        elif key is not None:
            # For Q/K/V weight/bias in HF, do nothing if not all ready to merge.
            continue

        # DeepSpeed BERT model has voc_size 8 aligned: pad the word-embedding
        # matrix with voc_size_diff zero rows.
        if voc_size_diff > 0 and name_str.find("embeddings.word_embeddings") >= 0:
            z = np.zeros((voc_size_diff, array.shape[1]), dtype=array.dtype)
            array = np.concatenate((array, z), axis=0)

        set_data(pointer, array)
        logger.info("Initialize DeepSpeed weight {}".format(name))

    return model

def load_hf_weights_in_bert_torch(model, ckpt_path, voc_size_diff):
    """ Load huggingface checkpoints and convert to a deepspeed model.

    Variant for the DeepSpeed BERT model *without* the fused transformer
    kernel. Each checkpoint key is resolved as a dotted attribute path on
    ``model``, so no Q/K/V fusion is needed. The only renaming is that the
    ``intermediate`` and ``pooler`` dense layers are named ``dense_act``.

    Args:
        model: DeepSpeed BERT model; its parameters are overwritten in place.
        ckpt_path: path to a file loadable with ``torch.load`` whose top-level
            object maps dotted parameter names to tensors.
        voc_size_diff: number of zero rows appended to the word-embedding
            matrix (the DeepSpeed model's vocabulary size is 8-aligned);
            ignored when not positive.

    Returns:
        The same ``model`` object.

    Raises:
        AssertionError: from ``set_data`` when a checkpoint tensor's shape
            differs from the target parameter's shape.
    """
    hf_path = os.path.abspath(ckpt_path)
    logger.info("Converting Huggingface checkpoint from {}".format(hf_path))
    # Load weights from Huggingface model (tensors mapped to CPU so that
    # .numpy() below works).
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    ckpt = torch.load(hf_path, map_location=torch.device("cpu"))

    for name_str, tensor in ckpt.items():
        array = tensor.numpy()
        logger.info("Loading Huggingface weight {} with shape {}".format(name_str, array.shape))
        name = name_str.split(".")
        # DeepSpeed names these dense sub-modules ``dense_act``.
        renames_dense = "intermediate.dense" in name_str or "pooler.dense" in name_str
        pointer = model
        skipping = False
        for m_name in name:
            if renames_dense and m_name == "dense":
                pointer = getattr(pointer, "dense_act")
                continue
            try:
                pointer = getattr(pointer, m_name)
            except AttributeError:
                # The model has no such attribute; drop this entry.
                logger.info("Skipping {}".format(".".join(name)))
                skipping = True
                break

        if skipping:
            continue

        # DeepSpeed BERT model has voc_size 8 aligned: pad the word-embedding
        # matrix with voc_size_diff zero rows.
        if voc_size_diff > 0 and "embeddings.word_embeddings" in name_str:
            z = np.zeros((voc_size_diff, array.shape[1]), dtype=array.dtype)
            array = np.concatenate((array, z), axis=0)

        set_data(pointer, array)
        logger.info("Initialize DeepSpeed weight {}".format(name))

    return model

def convert_ckpt_to_deepspeed(model, ckpt_type, ckpt_path, vocab_diff, kernel_enabled):
    """Load a BERT checkpoint into ``model`` using the loader for its format.

    Args:
        model: DeepSpeed BERT model; its parameters are overwritten in place.
        ckpt_type: ``"HF"`` for a HuggingFace checkpoint or ``"TF"`` for a
            TensorFlow checkpoint.
        ckpt_path: path to the checkpoint.
        vocab_diff: number of zero rows appended to the word-embedding matrix.
        kernel_enabled: whether ``model`` uses the DeepSpeed fused transformer
            kernel. TF checkpoints are only supported in that mode.

    Returns:
        ``model``, as returned by the selected loader.

    Raises:
        ValueError: for a TF checkpoint without the transformer kernel, or for
            an unknown ``ckpt_type``.
    """
    # Load weights from checkpoint
    if ckpt_type == "HF":
        if kernel_enabled:
            return load_hf_weights_in_bert_kernel(model, ckpt_path, vocab_diff)
        return load_hf_weights_in_bert_torch(model, ckpt_path, vocab_diff)
    if ckpt_type == "TF":
        if kernel_enabled:
            return load_tf_weights_in_bert_kernel(model, ckpt_path, vocab_diff)
        raise ValueError("--deepspeed_transformer_kernel is required for loading TF checkpoint.")
    raise ValueError(f"Invalid ckpt_type {ckpt_type!r}; expected 'HF' or 'TF'.")