"examples/seq2seq/finetune_t5.sh" did not exist on "02e5f79662d72cccdca81a47e3001a5f6d36e5b1"
extract.py 4.35 KB
Newer Older
VictorSanh's avatar
VictorSanh committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
# coding=utf-8
# Copyright 2019-present, the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Preprocessing script before training the distilled model.
Specific to RoBERTa -> DistilRoBERTa and GPT2 -> DistilGPT2.
"""
import argparse

import torch

from transformers import GPT2LMHeadModel, RobertaForMaskedLM


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Extraction some layers of the full RobertaForMaskedLM or GPT2LMHeadModel for Transfer Learned Distillation"
    )
    parser.add_argument("--model_type", default="roberta", choices=["roberta", "gpt2"])
    parser.add_argument("--model_name", default="roberta-large", type=str)
    parser.add_argument("--dump_checkpoint", default="serialization_dir/tf_roberta_048131723.pth", type=str)
    parser.add_argument("--vocab_transform", action="store_true")
    args = parser.parse_args()

    if args.model_type == "roberta":
        model = RobertaForMaskedLM.from_pretrained(args.model_name)
        prefix = "roberta"
    elif args.model_type == "gpt2":
        model = GPT2LMHeadModel.from_pretrained(args.model_name)
        prefix = "transformer"

    state_dict = model.state_dict()
    compressed_sd = {}

    # Embeddings #
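    # The teacher embeddings are copied unchanged: token (wte) and position (wpe)
    # embeddings for GPT2; word / position / token-type embeddings plus the
    # embedding LayerNorm for RoBERTa.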
    if args.model_type == "gpt2":
        for param_name in ["wte.weight", "wpe.weight"]:
            compressed_sd[f"{prefix}.{param_name}"] = state_dict[f"{prefix}.{param_name}"]
    else:
        for w in ["word_embeddings", "position_embeddings", "token_type_embeddings"]:
            param_name = f"{prefix}.embeddings.{w}.weight"
            compressed_sd[param_name] = state_dict[param_name]
        for w in ["weight", "bias"]:
            param_name = f"{prefix}.embeddings.LayerNorm.{w}"
            compressed_sd[param_name] = state_dict[param_name]

    # Transformer Blocks #
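    # Only the teacher blocks listed below (indices 0, 2, 4, 7, 9, 11) are kept;
    # they are copied into consecutive student positions std_idx = 0..5, so the
    # student ends up with six transformer blocks.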
    std_idx = 0
    for teacher_idx in [0, 2, 4, 7, 9, 11]:
        if args.model_type == "gpt2":
            for layer in ["ln_1", "attn.c_attn", "attn.c_proj", "ln_2", "mlp.c_fc", "mlp.c_proj"]:
                for w in ["weight", "bias"]:
                    compressed_sd[f"{prefix}.h.{std_idx}.{layer}.{w}"] = state_dict[
                        f"{prefix}.h.{teacher_idx}.{layer}.{w}"
                    ]
            compressed_sd[f"{prefix}.h.{std_idx}.attn.bias"] = state_dict[f"{prefix}.h.{teacher_idx}.attn.bias"]
        else:
            for layer in [
                "attention.self.query",
                "attention.self.key",
                "attention.self.value",
                "attention.output.dense",
                "attention.output.LayerNorm",
                "intermediate.dense",
                "output.dense",
                "output.LayerNorm",
            ]:
                for w in ["weight", "bias"]:
                    compressed_sd[f"{prefix}.encoder.layer.{std_idx}.{layer}.{w}"] = state_dict[
                        f"{prefix}.encoder.layer.{teacher_idx}.{layer}.{w}"
                    ]
        std_idx += 1

    # Language Modeling Head #
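    # The vocabulary projection is kept so the student scores the same vocabulary
    # as the teacher; for RoBERTa, --vocab_transform additionally copies the dense
    # and layer-norm transform of the MLM head.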
    if args.model_type == "roberta":
        for layer in ["lm_head.decoder.weight", "lm_head.bias"]:
            compressed_sd[f"{layer}"] = state_dict[f"{layer}"]
        if args.vocab_transform:
            for w in ["weight", "bias"]:
                compressed_sd[f"lm_head.dense.{w}"] = state_dict[f"lm_head.dense.{w}"]
                compressed_sd[f"lm_head.layer_norm.{w}"] = state_dict[f"lm_head.layer_norm.{w}"]
    elif args.model_type == "gpt2":
        for w in ["weight", "bias"]:
            compressed_sd[f"{prefix}.ln_f.{w}"] = state_dict[f"{prefix}.ln_f.{w}"]
        compressed_sd[f"lm_head.weight"] = state_dict[f"lm_head.weight"]

    print(f"N layers selected for distillation: {std_idx}")
    print(f"Number of params transfered for distillation: {len(compressed_sd.keys())}")

    print(f"Save transfered checkpoint to {args.dump_checkpoint}.")
    torch.save(compressed_sd, args.dump_checkpoint)
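    # Hedged follow-up sketch (not part of the original workflow): the dump is a
    # plain PyTorch state dict, so it can be reloaded and inspected before
    # distillation with, e.g.:
    #
    #   sd = torch.load(args.dump_checkpoint, map_location="cpu")
    #   print(len(sd), sorted(sd)[:5])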