#!/usr/bin/env python3
import os
import argparse
import logging
from tqdm import tqdm

import numpy as np

import torch
from torch.utils.data import DataLoader, SequentialSampler, TensorDataset, Subset
from torch.utils.data.distributed import DistributedSampler
from torch.nn import CrossEntropyLoss, MSELoss

from pytorch_pretrained_bert import BertForSequenceClassification, BertTokenizer

from run_classifier_dataset_utils import processors, output_modes, convert_examples_to_features, compute_metrics


logger = logging.getLogger(__name__)


def entropy(p):
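    """Entropy of a probability distribution along its last dimension (taking 0 * log 0 = 0)."""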
    plogp = p * torch.log(p)
    plogp[p == 0] = 0
    return -plogp.sum(dim=-1)

def print_1d_tensor(tensor, prefix=""):
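    """Log a 1D tensor on one line, tab-separated, with integer or float formatting depending on its dtype."""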
    if tensor.dtype != torch.long:
        logger.info(prefix + "\t".join(f"{x:.5f}" for x in tensor.cpu().data))
    else:
        logger.info(prefix + "\t".join(f"{x:d}" for x in tensor.cpu().data))

def print_2d_tensor(tensor):
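    """Log a 2D tensor as a table: a header row of head indices, then one row per layer."""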
    logger.info("lv, h >\t" + "\t".join(f"{x + 1}" for x in range(len(tensor))))
    for row in range(len(tensor)):
        print_1d_tensor(tensor[row], prefix=f"layer {row + 1}:\t")

def compute_heads_importance(args, model, eval_dataloader, output_mode, num_labels):
    """ Example of how to use model outputs to compute:
        - head attention entropy (activated by setting output_attentions=True when we created the model)
        - head importance scores according to http://arxiv.org/abs/1905.10650
          (activated by setting keep_multihead_output=True when we created the model)
    """
    # Prepare the accumulators on the right device
    n_layers, n_heads = model.bert.config.num_hidden_layers, model.bert.config.num_attention_heads
    attn_entropy = torch.zeros(n_layers, n_heads).to(args.device)
    head_importance = torch.zeros(n_layers, n_heads).to(args.device)
    tot_tokens = 0.0
    for step, batch in enumerate(tqdm(eval_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])):
        batch = tuple(t.to(args.device) for t in batch)
        input_ids, input_mask, segment_ids, label_ids = batch

        # Do a forward pass
        all_attentions, logits = model(input_ids, segment_ids, input_mask)

        # Update head attention entropy
        for layer, attn in enumerate(all_attentions):
            masked_entropy = entropy(attn.detach()) * input_mask.float().unsqueeze(1)
            attn_entropy[layer] += masked_entropy.sum(-1).sum(0).detach()

        # Update head importance scores with regards to our loss
        # First backpropagate to populate the gradients
        if output_mode == "classification":
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, num_labels), label_ids.view(-1))
        elif output_mode == "regression":
            loss_fct = MSELoss()
            loss = loss_fct(logits.view(-1), label_ids.view(-1))
        loss.backward()
        # Second compute importance scores according to http://arxiv.org/abs/1905.10650
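        # Importance of a head (Michel et al., 2019): |h · dL/dh|, the dot product of a head's
        # output with its gradient over the hidden dimension, accumulated over positions and examples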
        multihead_outputs = model.bert.get_multihead_outputs()
        for layer, mh_layer_output in enumerate(multihead_outputs):
            dot = torch.einsum("bhli,bhli->bhl", [mh_layer_output.grad, mh_layer_output])
            head_importance[layer] += dot.abs().sum(-1).sum(0).detach()

        tot_tokens += input_mask.float().detach().sum().item()

    # Normalize
    attn_entropy /= tot_tokens
    head_importance /= tot_tokens
    if args.normalize_importance:
        head_importance = (head_importance - head_importance.min()) / (head_importance.max() - head_importance.min())

    return attn_entropy, head_importance

def run_model():
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_name_or_path', type=str, default='bert-base-cased-finetuned-mrpc', help='pretrained model name or path to local checkpoint')
    parser.add_argument("--task_name", type=str, default='mrpc', help="The name of the task to train.")
    parser.add_argument("--data_dir", type=str, required=True, help="The input data dir. Should contain the .tsv files (or other data files) for the task.")
    parser.add_argument("--output_dir", type=str, required=True, help="The output directory where the model predictions and checkpoints will be written.")
    parser.add_argument("--data_subset", type=int, default=-1, help="If > 0: limit the data to a subset of data_subset instances.")
    parser.add_argument("--overwrite_output_dir", action='store_true', help="Whether to overwrite data in output directory")

    parser.add_argument("--normalize_importance", action='store_true', help="Whether to normalize importance score between 0 and 1")

    parser.add_argument("--try_pruning", action='store_true', help="Whether to try to prune head until a threshold of accuracy.")
    parser.add_argument("--pruning_threshold", default=0.9, type=float, help="Pruning threshold of accuracy.")

    parser.add_argument("--max_seq_length", default=128, type=int, help="The maximum total input sequence length after WordPiece tokenization. \n"
                             "Sequences longer than this will be truncated, and sequences shorter \n"
                             "than this will be padded.")
    parser.add_argument("--batch_size", default=1, type=int, help="Batch size.")

    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--local_rank", type=int, default=-1, help="local_rank for distributed training on gpus")
    parser.add_argument("--no_cuda", action='store_true', help="Whether not to use CUDA when available")
    args = parser.parse_args()

    # Setup devices and distributed training
    if args.local_rank == -1 or args.no_cuda:
        args.device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        args.device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        torch.distributed.init_process_group(backend='nccl')  # Initializes the distributed backend

    # Setup logging
    logging.basicConfig(level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
    logger.info("device: {} n_gpu: {}, distributed: {}".format(args.device, n_gpu, bool(args.local_rank != -1)))

    # Set seeds
    np.random.seed(args.seed)
    torch.random.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed(args.seed)

    # Prepare GLUE task
    task_name = args.task_name.lower()
    processor = processors[task_name]()
    output_mode = output_modes[task_name]
    label_list = processor.get_labels()
    num_labels = len(label_list)

    # Prepare output directory
    if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and not args.overwrite_output_dir:
        raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
    if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]:
        os.makedirs(args.output_dir)

    # Load model & tokenizer
    if args.local_rank not in [-1, 0]:
        torch.distributed.barrier()  # Make sure only the first process in distributed training downloads the model & vocab
    tokenizer = BertTokenizer.from_pretrained(args.model_name_or_path)

    # Load a model with all BERTology options on:
    #   output_attentions => will output attention weights
    #   keep_multihead_output => will store gradient of attention head outputs for head importance computation
    #       see: http://arxiv.org/abs/1905.10650
    model = BertForSequenceClassification.from_pretrained(args.model_name_or_path,
                                                          num_labels=num_labels,
                                                          output_attentions=True,
                                                          keep_multihead_output=True)
    if args.local_rank == 0:
        torch.distributed.barrier()  # First process is done downloading: the other processes can now load the model & vocab
    model.to(args.device)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True)
    model.eval()

    # Prepare dataset for the GLUE task
    eval_examples = processor.get_dev_examples(args.data_dir)
    cached_eval_features_file = os.path.join(args.data_dir, 'dev_{0}_{1}_{2}'.format(
        list(filter(None, args.model_name_or_path.split('/'))).pop(), str(args.max_seq_length), str(task_name)))
    try:
        eval_features = torch.load(cached_eval_features_file)
    except (OSError, EOFError):  # No usable cached features file: build the features from the examples
        eval_features = convert_examples_to_features(eval_examples, label_list, args.max_seq_length, tokenizer, output_mode)
        if args.local_rank in [-1, 0]:
            logger.info("Saving eval features to cache file %s", cached_eval_features_file)
            torch.save(eval_features, cached_eval_features_file)

    all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long)
    all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long)
    all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long)
    all_label_ids = torch.tensor([f.label_id for f in eval_features], dtype=torch.long if output_mode == "classification" else torch.float)
    eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)

    if args.data_subset > 0:
        eval_data = Subset(eval_data, list(range(min(args.data_subset, len(eval_data)))))

    eval_sampler = SequentialSampler(eval_data) if args.local_rank == -1 else DistributedSampler(eval_data)
    eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.batch_size)

    # Print/save training arguments
    print(args)
    torch.save(args, os.path.join(args.output_dir, 'run_args.bin'))

    # To showcase some BERTology methods, we will compute:
    #   - the average entropy of each head over the dev set
    #   - the importance score of each head over the dev set as explained in http://arxiv.org/abs/1905.10650
    n_layers, n_heads = model.bert.config.num_hidden_layers, model.bert.config.num_attention_heads

    # Compute head entropy and importance score
    attn_entropy, head_importance = compute_heads_importance(args, model, eval_dataloader, output_mode, num_labels)

    # Print/save matrices
    np.save(os.path.join(args.output_dir, 'attn_entropy.npy'), attn_entropy.detach().cpu().numpy())
    np.save(os.path.join(args.output_dir, 'head_importance.npy'), head_importance.detach().cpu().numpy())

    logger.info("Attention entropies")
    print_2d_tensor(attn_entropy)
    logger.info("Head importance scores")
    print_2d_tensor(head_importance)
    logger.info("Head ranked by importance scores")
    head_ranks = torch.zeros(n_layers * n_heads, dtype=torch.long, device=args.device)
    head_ranks[head_importance.view(-1).sort(descending=True)[1]] = torch.arange(head_importance.numel())
    print_2d_tensor(head_ranks.view_as(head_importance))

    # Do pruning if we want to
    if args.try_pruning and 0.0 < args.pruning_threshold < 1.0:
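        # What follows is only a minimal sketch of greedy head masking, not a reference
        # implementation. It assumes (a) the model's forward pass accepts a head_mask
        # argument (shipped with the pytorch_pretrained_bert versions that also provide
        # output_attentions / keep_multihead_output) and (b) the task is classification
        # and its compute_metrics dict includes an "acc" entry (true for MRPC).
        def evaluate(head_mask):
            # Hypothetical helper: dev-set accuracy of the model with some heads masked out
            preds, labels = [], []
            for batch in eval_dataloader:
                input_ids, input_mask, segment_ids, label_ids = tuple(t.to(args.device) for t in batch)
                with torch.no_grad():
                    _, logits = model(input_ids, segment_ids, input_mask, head_mask=head_mask)
                preds.append(logits.detach().cpu().numpy())
                labels.append(label_ids.detach().cpu().numpy())
            preds = np.argmax(np.concatenate(preds, axis=0), axis=1)
            return compute_metrics(task_name, preds, np.concatenate(labels, axis=0))["acc"]

        head_mask = torch.ones(n_layers, n_heads).to(args.device)
        original_score = evaluate(head_mask)
        logger.info("Pruning: original accuracy %f, threshold %f", original_score, original_score * args.pruning_threshold)
        # Mask heads one by one, from least to most important, until accuracy drops below the threshold
        for head_idx in head_importance.view(-1).sort()[1].tolist():
            layer, head = divmod(head_idx, n_heads)
            head_mask[layer, head] = 0.0
            score = evaluate(head_mask)
            logger.info("Masking head %d of layer %d: accuracy %f", head + 1, layer + 1, score)
            if score < original_score * args.pruning_threshold:
                head_mask[layer, head] = 1.0  # restore the last masked head and stop
                break
        logger.info("Final head mask")
        print_2d_tensor(head_mask)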
        
        

if __name__ == '__main__':
    run_model()