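"""
extract_for_distil.py

Extract a subset of the weights of a pretrained BertForPreTraining model and
remap them to the parameter names of a smaller student model, producing a
checkpoint that can be used to initialize the student for transfer-learned
distillation.
"""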
from pytorch_transformers import BertForPreTraining
import torch
import argparse

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="Extract some layers of the full BertForPreTraining for transfer-learned distillation.")
    parser.add_argument("--bert_model", default='bert-base-uncased', type=str)
    parser.add_argument("--dump_checkpoint", default='serialization_dir/transfer_learning_checkpoint_0247911.pth', type=str)
    parser.add_argument("--vocab_transform", action='store_true')
    args = parser.parse_args()


    model = BertForPreTraining.from_pretrained(args.bert_model)

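    # Build a new state dict keyed by the student model's parameter names,
    # filled with the corresponding teacher weights.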
    state_dict = model.state_dict()
    compressed_sd = {}

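    # Word and position embeddings, plus the embedding LayerNorm, are copied over unchanged.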
    for w in ['word_embeddings', 'position_embeddings']:
        compressed_sd[f'dilbert.embeddings.{w}.weight'] = \
            state_dict[f'bert.embeddings.{w}.weight']
    for w in ['weight', 'bias']:
        compressed_sd[f'dilbert.embeddings.LayerNorm.{w}'] = \
            state_dict[f'bert.embeddings.LayerNorm.{w}']

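    # Keep 6 of the teacher's encoder layers (indices 0, 2, 4, 7, 9, 11 of the
    # 12-layer bert-base) and map each one onto the next student layer slot.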
    std_idx = 0
    for teacher_idx in [0, 2, 4, 7, 9, 11]:
        for w in ['weight', 'bias']:
            compressed_sd[f'dilbert.transformer.layer.{std_idx}.attention.q_lin.{w}'] = \
                state_dict[f'bert.encoder.layer.{teacher_idx}.attention.self.query.{w}']
            compressed_sd[f'dilbert.transformer.layer.{std_idx}.attention.k_lin.{w}'] = \
                state_dict[f'bert.encoder.layer.{teacher_idx}.attention.self.key.{w}']
            compressed_sd[f'dilbert.transformer.layer.{std_idx}.attention.v_lin.{w}'] = \
                state_dict[f'bert.encoder.layer.{teacher_idx}.attention.self.value.{w}']

            compressed_sd[f'dilbert.transformer.layer.{std_idx}.attention.out_lin.{w}'] = \
                state_dict[f'bert.encoder.layer.{teacher_idx}.attention.output.dense.{w}']
            compressed_sd[f'dilbert.transformer.layer.{std_idx}.sa_layer_norm.{w}'] = \
                state_dict[f'bert.encoder.layer.{teacher_idx}.attention.output.LayerNorm.{w}']

            compressed_sd[f'dilbert.transformer.layer.{std_idx}.ffn.lin1.{w}'] = \
                state_dict[f'bert.encoder.layer.{teacher_idx}.intermediate.dense.{w}']
            compressed_sd[f'dilbert.transformer.layer.{std_idx}.ffn.lin2.{w}'] = \
                state_dict[f'bert.encoder.layer.{teacher_idx}.output.dense.{w}']
            compressed_sd[f'dilbert.transformer.layer.{std_idx}.output_layer_norm.{w}'] = \
                state_dict[f'bert.encoder.layer.{teacher_idx}.output.LayerNorm.{w}']
        std_idx += 1

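    # Masked-language-modeling head: the decoder weight/bias and, optionally,
    # the transform dense layer and LayerNorm that precede it.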
    compressed_sd['vocab_projector.weight'] = state_dict['cls.predictions.decoder.weight']
    compressed_sd['vocab_projector.bias'] = state_dict['cls.predictions.bias']
    if args.vocab_transform:
        for w in ['weight', 'bias']:
            compressed_sd[f'vocab_transform.{w}'] = state_dict[f'cls.predictions.transform.dense.{w}']
            compressed_sd[f'vocab_layer_norm.{w}'] = state_dict[f'cls.predictions.transform.LayerNorm.{w}']

    print(f'Number of layers selected for distillation: {std_idx}')
    print(f'Number of params transferred for distillation: {len(compressed_sd)}')

    print(f'Saving transferred checkpoint to {args.dump_checkpoint}.')
    torch.save(compressed_sd, args.dump_checkpoint)
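
    # Optional sanity check (a minimal sketch, not part of the original script):
    # reload the dumped checkpoint and confirm the remapped keys are present.
    # reloaded = torch.load(args.dump_checkpoint)
    # assert 'dilbert.embeddings.word_embeddings.weight' in reloaded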