# coding=utf-8
# Copyright 2019-present, the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Preprocessing script before training DistilBERT.
Specific to BERT -> DistilBERT.
"""
import argparse

import torch

from transformers import BertForMaskedLM


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Extraction some layers of the full BertForMaskedLM or RObertaForMaskedLM for Transfer Learned Distillation"
    )
    parser.add_argument("--model_type", default="bert", choices=["bert"])
    parser.add_argument("--model_name", default="bert-base-uncased", type=str)
    parser.add_argument("--dump_checkpoint", default="serialization_dir/tf_bert-base-uncased_0247911.pth", type=str)
    parser.add_argument("--vocab_transform", action="store_true")
    args = parser.parse_args()

    if args.model_type == "bert":
        model = BertForMaskedLM.from_pretrained(args.model_name)
        prefix = "bert"
    else:
        raise ValueError('args.model_type should be "bert".')

    state_dict = model.state_dict()
    compressed_sd = {}  # maps student (DistilBERT) parameter names to teacher tensors

    for w in ["word_embeddings", "position_embeddings"]:
        compressed_sd[f"distilbert.embeddings.{w}.weight"] = state_dict[f"{prefix}.embeddings.{w}.weight"]
    for w in ["weight", "bias"]:
        compressed_sd[f"distilbert.embeddings.LayerNorm.{w}"] = state_dict[f"{prefix}.embeddings.LayerNorm.{w}"]

    # Initialize the student's 6 layers from 6 of the teacher's 12 encoder layers:
    # teacher layers [0, 2, 4, 7, 9, 11] become student layers 0-5.
    std_idx = 0
    for teacher_idx in [0, 2, 4, 7, 9, 11]:
        for w in ["weight", "bias"]:
            compressed_sd[f"distilbert.transformer.layer.{std_idx}.attention.q_lin.{w}"] = state_dict[
                f"{prefix}.encoder.layer.{teacher_idx}.attention.self.query.{w}"
            ]
            compressed_sd[f"distilbert.transformer.layer.{std_idx}.attention.k_lin.{w}"] = state_dict[
                f"{prefix}.encoder.layer.{teacher_idx}.attention.self.key.{w}"
            ]
            compressed_sd[f"distilbert.transformer.layer.{std_idx}.attention.v_lin.{w}"] = state_dict[
                f"{prefix}.encoder.layer.{teacher_idx}.attention.self.value.{w}"
            ]

            compressed_sd[f"distilbert.transformer.layer.{std_idx}.attention.out_lin.{w}"] = state_dict[
                f"{prefix}.encoder.layer.{teacher_idx}.attention.output.dense.{w}"
            ]
            compressed_sd[f"distilbert.transformer.layer.{std_idx}.sa_layer_norm.{w}"] = state_dict[
                f"{prefix}.encoder.layer.{teacher_idx}.attention.output.LayerNorm.{w}"
            ]

            compressed_sd[f"distilbert.transformer.layer.{std_idx}.ffn.lin1.{w}"] = state_dict[
                f"{prefix}.encoder.layer.{teacher_idx}.intermediate.dense.{w}"
            ]
            compressed_sd[f"distilbert.transformer.layer.{std_idx}.ffn.lin2.{w}"] = state_dict[
                f"{prefix}.encoder.layer.{teacher_idx}.output.dense.{w}"
            ]
            compressed_sd[f"distilbert.transformer.layer.{std_idx}.output_layer_norm.{w}"] = state_dict[
                f"{prefix}.encoder.layer.{teacher_idx}.output.LayerNorm.{w}"
            ]
        std_idx += 1

    compressed_sd[f"vocab_projector.weight"] = state_dict[f"cls.predictions.decoder.weight"]
    compressed_sd[f"vocab_projector.bias"] = state_dict[f"cls.predictions.bias"]
    # Optionally copy the MLM head's intermediate dense layer and LayerNorm too.
    if args.vocab_transform:
        for w in ["weight", "bias"]:
            compressed_sd[f"vocab_transform.{w}"] = state_dict[f"cls.predictions.transform.dense.{w}"]
            compressed_sd[f"vocab_layer_norm.{w}"] = state_dict[f"cls.predictions.transform.LayerNorm.{w}"]

    print(f"N layers selected for distillation: {std_idx}")
    print(f"Number of params transfered for distillation: {len(compressed_sd.keys())}")

    print(f"Save transfered checkpoint to {args.dump_checkpoint}.")
    torch.save(compressed_sd, args.dump_checkpoint)
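
    # Optional sanity check (a sketch, not part of the original script): the
    # extracted state dict should load cleanly into a freshly initialized
    # DistilBertForMaskedLM, whose default config (6 layers, dim 768,
    # vocab 30522) matches what is extracted from bert-base-uncased.
    #
    #   from transformers import DistilBertConfig, DistilBertForMaskedLM
    #   student = DistilBertForMaskedLM(DistilBertConfig())
    #   load_info = student.load_state_dict(torch.load(args.dump_checkpoint), strict=False)
    #   print(load_info.missing_keys, load_info.unexpected_keys)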