bertology.py 16.2 KB
Newer Older
1
#!/usr/bin/env python3
thomwolf's avatar
thomwolf committed
2
import os
3
4
import argparse
import logging
thomwolf's avatar
thomwolf committed
5
from datetime import timedelta, datetime
thomwolf's avatar
thomwolf committed
6
from tqdm import tqdm
7
8
9

import numpy as np

thomwolf's avatar
thomwolf committed
10
11
12
13
14
import torch
from torch.utils.data import DataLoader, SequentialSampler, TensorDataset, Subset
from torch.utils.data.distributed import DistributedSampler
from torch.nn import CrossEntropyLoss, MSELoss

thomwolf's avatar
thomwolf committed
15
from pytorch_pretrained_bert import BertForSequenceClassification, BertTokenizer
16

thomwolf's avatar
thomwolf committed
17
18
19
from run_classifier_dataset_utils import processors, output_modes, convert_examples_to_features, compute_metrics


20
21
logger = logging.getLogger(__name__)

thomwolf's avatar
thomwolf committed
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38

def entropy(p):
    """Return the Shannon entropy (natural log) of the distributions along the last dim of ``p``.

    Zero probabilities contribute nothing: the ``0 * log(0)`` terms (which evaluate to nan)
    are replaced by 0 before summing.
    """
    p_log_p = (p * p.log()).masked_fill(p == 0, 0)
    return -p_log_p.sum(dim=-1)

def print_1d_tensor(tensor, prefix=""):
    """Log the values of a 1D tensor on one tab-separated line, preceded by ``prefix``.

    ``torch.long`` tensors are printed as integers; any other dtype with five decimals.
    """
    spec = "d" if tensor.dtype == torch.long else ".5f"
    cells = [format(value, spec) for value in tensor.cpu().data]
    logger.info(prefix + "\t".join(cells))

def print_2d_tensor(tensor):
    """Log a 2D tensor as a table: one line per row ("layer N"), columns labelled 1..n_cols.

    The header enumerates the columns (heads), not the rows, so non-square tensors
    such as a 24x16 layers-by-heads matrix get a correctly sized header.
    """
    n_cols = tensor.size(-1)
    logger.info("lv, h >\t" + "\t".join(f"{x + 1}" for x in range(n_cols)))
    for row, row_values in enumerate(tensor):
        print_1d_tensor(row_values, prefix=f"layer {row + 1}:\t")

thomwolf's avatar
thomwolf committed
39
def compute_heads_importance(args, model, eval_dataloader, compute_entropy=True, compute_importance=True, head_mask=None):
    """ Example on how to use model outputs to compute:
        - head attention entropy (activated by setting output_attentions=True when we created the model)
        - head importance scores according to http://arxiv.org/abs/1905.10650
            (activated by setting keep_multihead_output=True when we created the model)

    Args:
        args: namespace providing ``device``, ``local_rank``, ``output_mode``, ``num_labels``,
            ``dont_normalize_importance_by_layer`` and ``dont_normalize_global_importance``.
        model: model whose forward returns ``(all_attentions, logits)`` and which exposes
            ``.bert`` (directly, or via ``.module`` when wrapped in DistributedDataParallel).
        eval_dataloader: yields ``(input_ids, input_mask, segment_ids, label_ids)`` batches.
        compute_entropy: accumulate per-head attention entropy.
        compute_importance: backpropagate the loss and accumulate per-head importance.
        head_mask: optional head mask forwarded unchanged to the model.

    Returns:
        ``(attn_entropy, head_importance, preds, labels)``: two ``(n_layers, n_heads)`` tensors
        divided by the number of non-padding tokens, plus the concatenated logits and labels
        as numpy arrays (``None`` when the dataloader is empty).

    Raises:
        ValueError: if ``compute_importance`` is set and ``args.output_mode`` is neither
            "classification" nor "regression".
    """
    # DistributedDataParallel exposes the wrapped model as `.module`; unwrap it to reach `.bert`
    bert = getattr(model, "module", model).bert

    # Prepare our tensors
    n_layers, n_heads = bert.config.num_hidden_layers, bert.config.num_attention_heads
    head_importance = torch.zeros(n_layers, n_heads).to(args.device)
    attn_entropy = torch.zeros(n_layers, n_heads).to(args.device)
    all_preds, all_labels = [], []
    tot_tokens = 0.0

    for batch in tqdm(eval_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0]):
        batch = tuple(t.to(args.device) for t in batch)
        input_ids, input_mask, segment_ids, label_ids = batch

        # Do a forward pass (not in torch.no_grad() since we need gradients for importance score - see below)
        all_attentions, logits = model(input_ids, token_type_ids=segment_ids, attention_mask=input_mask, head_mask=head_mask)

        if compute_entropy:
            # Update head attention entropy; padding positions are zeroed by the input mask.
            # NOTE(review): assumes attn is (batch, heads, seq, seq) so the mask broadcasts as (batch, 1, seq)
            for layer, attn in enumerate(all_attentions):
                masked_entropy = entropy(attn.detach()) * input_mask.float().unsqueeze(1)
                attn_entropy[layer] += masked_entropy.sum(-1).sum(0).detach()

        if compute_importance:
            # Update head importance scores with regards to our loss
            # First, backpropagate to populate the gradients
            if args.output_mode == "classification":
                loss = CrossEntropyLoss()(logits.view(-1, args.num_labels), label_ids.view(-1))
            elif args.output_mode == "regression":
                loss = MSELoss()(logits.view(-1), label_ids.view(-1))
            else:
                raise ValueError(f"Unsupported output_mode: {args.output_mode}")
            loss.backward()
            # Second, compute importance scores according to http://arxiv.org/abs/1905.10650
            multihead_outputs = bert.get_multihead_outputs()
            for layer, mh_layer_output in enumerate(multihead_outputs):
                dot = torch.einsum("bhli,bhli->bhl", [mh_layer_output.grad, mh_layer_output])
                head_importance[layer] += dot.abs().sum(-1).sum(0).detach()

        # Also store our logits/labels if we want to compute metrics afterwards
        # (collected in lists and concatenated once, instead of quadratic np.append)
        all_preds.append(logits.detach().cpu().numpy())
        all_labels.append(label_ids.detach().cpu().numpy())

        tot_tokens += input_mask.float().detach().sum().data

    preds = np.concatenate(all_preds, axis=0) if all_preds else None
    labels = np.concatenate(all_labels, axis=0) if all_labels else None

    # Normalize by the number of non-padding tokens (skipped for an empty dataloader)
    if tot_tokens > 0:
        attn_entropy /= tot_tokens
        head_importance /= tot_tokens

    # Layerwise importance normalization
    if not args.dont_normalize_importance_by_layer:
        exponent = 2
        norm_by_layer = torch.pow(torch.pow(head_importance, exponent).sum(-1), 1 / exponent)
        head_importance /= norm_by_layer.unsqueeze(-1) + 1e-20

    if not args.dont_normalize_global_importance:
        # Min-max rescale to [0, 1]; skipped when all scores are equal to avoid 0/0 = nan
        importance_range = head_importance.max() - head_importance.min()
        if importance_range > 0:
            head_importance = (head_importance - head_importance.min()) / importance_range

    return attn_entropy, head_importance, preds, labels
thomwolf's avatar
thomwolf committed
105

106
107
def _score_predictions(args, task_name, preds, labels):
    """Turn raw model outputs into predictions and return the metric named by --metric_name.

    Classification outputs are arg-maxed over axis 1; regression outputs are squeezed.
    """
    preds = np.argmax(preds, axis=1) if args.output_mode == "classification" else np.squeeze(preds)
    return compute_metrics(task_name, preds, labels)[args.metric_name]


def _mask_heads(args, model, eval_dataloader, task_name):
    """Iteratively mask the least important still-active heads while the score stays above threshold.

    Each step masks ``masking_amount`` of all heads (at least one) and re-evaluates the model.

    Returns:
        The last head mask (same shape as the importance matrix, 1.0 = keep, 0.0 = masked)
        whose score was >= ``masking_threshold`` * original score.
    """
    _, head_importance, preds, labels = compute_heads_importance(args, model, eval_dataloader, compute_entropy=False)
    original_score = _score_predictions(args, task_name, preds, labels)
    logger.info("Pruning: original score: %f, threshold: %f", original_score, original_score * args.masking_threshold)

    new_head_mask = torch.ones_like(head_importance)
    # Mask at least one head per step, otherwise a tiny masking_amount would loop forever
    num_to_mask = max(1, int(new_head_mask.numel() * args.masking_amount))

    # Defined before the loop so it exists even if the loop body never runs
    # (e.g. a negative original score fails the first threshold comparison)
    head_mask = new_head_mask.clone()
    current_score = original_score
    while current_score >= original_score * args.masking_threshold:
        # Save the current (still acceptable) mask; clone() so the in-place edits below
        # cannot reach the saved mask through a shared view
        head_mask = new_head_mask.clone()

        if int(head_mask.sum().item()) <= num_to_mask:
            break

        # Rank heads from least to most important. Already-masked heads get +inf so they sort
        # last and are never picked again; indices refer to the flattened full mask.
        head_importance[head_mask == 0.0] = float("inf")
        current_heads_to_mask = head_importance.reshape(-1).sort()[1][:num_to_mask]
        logger.info("Heads to mask: %s", str(current_heads_to_mask.tolist()))

        # mask heads
        new_head_mask = head_mask.clone().view(-1)
        new_head_mask[current_heads_to_mask] = 0.0
        new_head_mask = new_head_mask.view_as(head_mask)
        print_2d_tensor(new_head_mask)

        # Compute metric and head importance again
        _, head_importance, preds, labels = compute_heads_importance(args, model, eval_dataloader, compute_entropy=False, head_mask=new_head_mask)
        current_score = _score_predictions(args, task_name, preds, labels)
        remaining_heads = new_head_mask.sum().item()
        logger.info("Masking: current score: %f, remaning heads %d (%.1f percents)", current_score, remaining_heads, remaining_heads / new_head_mask.numel() * 100)

    return head_mask


def _prune_heads(args, model, eval_dataloader, task_name, head_mask):
    """Compare the masked model with a physically pruned one (score and wall-clock time).

    Pruning is like masking but actually removes the masked weights; ``model`` is pruned in place.
    """
    before_time = datetime.now()
    _, _, preds, labels = compute_heads_importance(args, model, eval_dataloader,
                                                   compute_entropy=False, compute_importance=False, head_mask=head_mask)
    score_masking = _score_predictions(args, task_name, preds, labels)
    original_time = datetime.now() - before_time

    # layer index -> flat list of the head indices whose mask value is 0
    heads_to_prune = {
        layer: (1 - layer_mask.long()).nonzero().view(-1).tolist()
        for layer, layer_mask in enumerate(head_mask)
    }
    assert sum(len(h) for h in heads_to_prune.values()) == (1 - head_mask.long()).sum().item()
    # DistributedDataParallel exposes the wrapped model as `.module`
    getattr(model, "module", model).bert.prune_heads(heads_to_prune)

    before_time = datetime.now()
    _, _, preds, labels = compute_heads_importance(args, model, eval_dataloader,
                                                   compute_entropy=False, compute_importance=False, head_mask=None)
    score_pruning = _score_predictions(args, task_name, preds, labels)
    new_time = datetime.now() - before_time

    logger.info("Pruning: score with masking: %f score with pruning: %f", score_masking, score_pruning)
    # > 100 means the pruned model ran faster than the masked one
    speed_ratio = original_time / new_time * 100 if new_time.total_seconds() > 0 else float("inf")
    logger.info("Pruning: speed ratio (original timing / new timing): %f percents", speed_ratio)


def run_model():
    """Script entry point: compute head entropy/importance on a GLUE dev set, then optionally mask and prune heads."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_name_or_path', type=str, default='bert-base-cased-finetuned-mrpc', help='pretrained model name or path to local checkpoint')
    parser.add_argument("--task_name", type=str, default='mrpc', help="The name of the task to train.")
    parser.add_argument("--data_dir", type=str, required=True, help="The input data dir. Should contain the .tsv files (or other data files) for the task.")
    parser.add_argument("--output_dir", type=str, required=True, help="The output directory where the model predictions and checkpoints will be written.")
    parser.add_argument("--data_subset", type=int, default=-1, help="If > 0: limit the data to a subset of data_subset instances.")
    parser.add_argument("--overwrite_output_dir", action='store_true', help="Whether to overwrite data in output directory")

    parser.add_argument("--dont_normalize_importance_by_layer", action='store_true', help="Don't normalize importance score by layers")
    parser.add_argument("--dont_normalize_global_importance", action='store_true', help="Don't normalize all importance scores between 0 and 1")

    parser.add_argument("--try_masking", action='store_true', help="Whether to try to mask head until a threshold of accuracy.")
    parser.add_argument("--masking_threshold", default=0.9, type=float, help="masking threshold in term of metrics"
                                                                             "(stop masking when metric < threshold * original metric value).")
    parser.add_argument("--masking_amount", default=0.1, type=float, help="Amount to heads to masking at each masking step.")
    parser.add_argument("--metric_name", default="acc", type=str, help="Metric to use for head masking.")

    parser.add_argument("--max_seq_length", default=128, type=int, help="The maximum total input sequence length after WordPiece tokenization. \n"
                             "Sequences longer than this will be truncated, and sequences shorter \n"
                             "than this will be padded.")
    parser.add_argument("--batch_size", default=1, type=int, help="Batch size.")

    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--local_rank", type=int, default=-1, help="local_rank for distributed training on gpus")
    parser.add_argument("--no_cuda", action='store_true', help="Whether not to use CUDA when available")
    parser.add_argument('--server_ip', type=str, default='', help="Can be used for distant debugging.")
    parser.add_argument('--server_port', type=str, default='', help="Can be used for distant debugging.")
    args = parser.parse_args()

    if args.server_ip and args.server_port:
        # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
        import ptvsd
        print("Waiting for debugger attach")
        ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
        ptvsd.wait_for_attach()

    # Setup devices and distributed training
    if args.local_rank == -1 or args.no_cuda:
        args.device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        args.device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        torch.distributed.init_process_group(backend='nccl')  # Initializes the distributed backend

    # Setup logging
    logging.basicConfig(level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
    logger.info("device: {} n_gpu: {}, distributed: {}".format(args.device, n_gpu, bool(args.local_rank != -1)))

    # Set seeds
    np.random.seed(args.seed)
    torch.random.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed(args.seed)

    # Prepare GLUE task
    task_name = args.task_name.lower()
    processor = processors[task_name]()
    label_list = processor.get_labels()
    args.output_mode = output_modes[task_name]
    args.num_labels = len(label_list)

    # Prepare output directory
    if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and not args.overwrite_output_dir:
        raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
    if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]:
        os.makedirs(args.output_dir)

    # Load model & tokenizer
    if args.local_rank not in [-1, 0]:
        torch.distributed.barrier()  # Make sure only one distributed process download model & vocab
    tokenizer = BertTokenizer.from_pretrained(args.model_name_or_path)

    # Load a model with all BERTology options on:
    #   output_attentions => will output attention weights
    #   keep_multihead_output => will store gradient of attention head outputs for head importance computation
    #       see: http://arxiv.org/abs/1905.10650
    model = BertForSequenceClassification.from_pretrained(args.model_name_or_path,
                                                          num_labels=args.num_labels,
                                                          output_attentions=True,
                                                          keep_multihead_output=True)
    if args.local_rank == 0:
        torch.distributed.barrier()  # Make sure only one distributed process download model & vocab
    model.to(args.device)

    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True)
    model.eval()

    # Prepare dataset for the GLUE task
    eval_examples = processor.get_dev_examples(args.data_dir)
    cached_eval_features_file = os.path.join(args.data_dir, 'dev_{0}_{1}_{2}'.format(
        list(filter(None, args.model_name_or_path.split('/'))).pop(), str(args.max_seq_length), str(task_name)))
    eval_features = None
    if os.path.exists(cached_eval_features_file):
        # NOTE: torch.load unpickles the file; only point data_dir at trusted caches
        try:
            eval_features = torch.load(cached_eval_features_file)
        except Exception as err:  # unreadable/stale cache: best-effort, rebuild it below
            logger.warning("Could not load cached eval features from %s (%s), rebuilding them", cached_eval_features_file, err)
    if eval_features is None:
        eval_features = convert_examples_to_features(eval_examples, label_list, args.max_seq_length, tokenizer, args.output_mode)
        if args.local_rank in [-1, 0]:
            logger.info("Saving eval features to cache file %s", cached_eval_features_file)
            torch.save(eval_features, cached_eval_features_file)

    all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long)
    all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long)
    all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long)
    all_label_ids = torch.tensor([f.label_id for f in eval_features], dtype=torch.long if args.output_mode == "classification" else torch.float)
    eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)

    if args.data_subset > 0:
        # Clamp so a subset larger than the dataset does not index past its end
        eval_data = Subset(eval_data, list(range(min(args.data_subset, len(eval_data)))))

    eval_sampler = SequentialSampler(eval_data) if args.local_rank == -1 else DistributedSampler(eval_data)
    eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.batch_size)

    # Print/save training arguments
    print(args)
    torch.save(args, os.path.join(args.output_dir, 'run_args.bin'))

    # Compute head entropy and importance score
    attn_entropy, head_importance, _, _ = compute_heads_importance(args, model, eval_dataloader)

    # Print/save matrices
    np.save(os.path.join(args.output_dir, 'attn_entropy.npy'), attn_entropy.detach().cpu().numpy())
    np.save(os.path.join(args.output_dir, 'head_importance.npy'), head_importance.detach().cpu().numpy())

    logger.info("Attention entropies")
    print_2d_tensor(attn_entropy)
    logger.info("Head importance scores")
    print_2d_tensor(head_importance)
    logger.info("Head ranked by importance scores")
    # rank 0 = most important head
    head_ranks = torch.zeros(head_importance.numel(), dtype=torch.long, device=args.device)
    head_ranks[head_importance.view(-1).sort(descending=True)[1]] = torch.arange(head_importance.numel(), device=args.device)
    head_ranks = head_ranks.view_as(head_importance)
    print_2d_tensor(head_ranks)

    # Do masking if we want to
    if args.try_masking and 0.0 < args.masking_threshold < 1.0:
        head_mask = _mask_heads(args, model, eval_dataloader, task_name)
        # Try pruning and test time speedup
        _prune_heads(args, model, eval_dataloader, task_name, head_mask)
thomwolf's avatar
thomwolf committed
297
298
299

if __name__ == "__main__":
    # Run only when executed as a script, not when imported.
    run_model()