#!/usr/bin/env python3
# Copyright 2018 CMU and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Bertology: this script shows how you can explore the internals of the models in the library to:
    - compute the entropy of the head attentions
    - compute the importance of each head
    - prune (remove) the low importance head.
    Some parts of this script are adapted from the code of Michel et al. (http://arxiv.org/abs/1905.10650)
    which is available at https://github.com/pmichel31415/are-16-heads-really-better-than-1
"""
22
23
import argparse
import logging
Aymeric Augustin's avatar
Aymeric Augustin committed
24
import os
25
from datetime import datetime
26
27

import numpy as np
thomwolf's avatar
thomwolf committed
28
import torch
29
from torch.utils.data import DataLoader, SequentialSampler, Subset
Aymeric Augustin's avatar
Aymeric Augustin committed
30
31
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm
thomwolf's avatar
thomwolf committed
32

Aymeric Augustin's avatar
Aymeric Augustin committed
33
from run_glue import ALL_MODELS, MODEL_CLASSES, load_and_cache_examples, set_seed
34
35
36
37
38
39
40
41
42
43
44
45
from transformers import (
    WEIGHTS_NAME,
    BertConfig,
    BertForSequenceClassification,
    BertTokenizer,
    XLMConfig,
    XLMForSequenceClassification,
    XLMTokenizer,
    XLNetConfig,
    XLNetForSequenceClassification,
    XLNetTokenizer,
)
Adrian Bauer's avatar
Adrian Bauer committed
46
47
48
from transformers import glue_compute_metrics as compute_metrics
from transformers import glue_output_modes as output_modes
from transformers import glue_processors as processors
thomwolf's avatar
thomwolf committed
49

Aymeric Augustin's avatar
Aymeric Augustin committed
50

51
52
logger = logging.getLogger(__name__)

thomwolf's avatar
thomwolf committed
53
54

def entropy(p):
    """ Compute the entropy of a probability distribution.

    Args:
        p: tensor of probabilities; the last dimension is reduced as the distribution axis.

    Returns:
        Tensor holding ``-sum(p * log(p))`` over the last dimension (natural log, i.e. nats).
    """
    plogp = p * torch.log(p)
    # 0 * log(0) evaluates to NaN; substitute the limit value 0 so that
    # zero-probability entries contribute nothing to the sum.
    plogp = torch.where(p == 0, torch.zeros_like(plogp), plogp)
    return -plogp.sum(dim=-1)

thomwolf's avatar
thomwolf committed
60

thomwolf's avatar
thomwolf committed
61
def print_2d_tensor(tensor):
    """ Log a 2D tensor as a tab-separated table.

    The first logged line is a header numbering the columns from 1; each following
    line is one row, prefixed with ``layer <row index + 1>``. Values of dtype
    ``torch.long`` are printed as integers, everything else with 5 decimals.

    Args:
        tensor: 2D tensor to print (rows are labelled "layer", columns are the header entries).
    """
    # The header enumerates columns, so it must use the size of the last
    # dimension, not the number of rows (they differ for non-square tensors).
    n_cols = tensor.shape[-1]
    logger.info("lv, h >\t" + "\t".join(f"{col + 1}" for col in range(n_cols)))

    as_integers = tensor.dtype == torch.long
    for row, values in enumerate(tensor.detach().cpu().tolist()):
        if as_integers:
            cells = "\t".join(f"{x:d}" for x in values)
        else:
            cells = "\t".join(f"{x:.5f}" for x in values)
        logger.info(f"layer {row + 1}:\t" + cells)
thomwolf's avatar
thomwolf committed
69

thomwolf's avatar
thomwolf committed
70

71
72
73
def compute_heads_importance(
    args, model, eval_dataloader, compute_entropy=True, compute_importance=True, head_mask=None
):
    """ This method shows how to compute:
        - head attention entropy
        - head importance scores according to http://arxiv.org/abs/1905.10650

    Args:
        args: namespace read for ``device``, ``local_rank``, ``output_dir``,
            ``dont_normalize_importance_by_layer`` and ``dont_normalize_global_importance``.
        model: model called with ``input_ids``, ``token_type_ids``, ``attention_mask``, ``labels``
            and ``head_mask``; its output must hold the loss first, the logits second and the
            attentions last.
        eval_dataloader: iterable of ``(input_ids, input_mask, segment_ids, label_ids)`` batches.
        compute_entropy: accumulate the attention entropy of every head.
        compute_importance: accumulate ``|d loss / d head_mask|`` for every head.
        head_mask: optional ``(n_layers, n_heads)`` mask; all ones when ``None``.
            The tensor passed in is not modified.

    Returns:
        ``(attn_entropy, head_importance, preds, labels)``: two ``(n_layers, n_heads)`` tensors,
        followed by the concatenated logits and labels as numpy arrays
        (``None`` if the dataloader yielded no batch).

    Side effects:
        Writes ``attn_entropy.npy`` and ``head_importance.npy`` into ``args.output_dir``
        and logs both matrices plus the head ranking.
    """
    # Prepare our tensors. Unwrap (Distributed)DataParallel so the config is reachable,
    # and read it from the model itself rather than from a BERT-specific attribute.
    base_model = model.module if hasattr(model, "module") else model
    config = base_model.config
    n_layers, n_heads = config.num_hidden_layers, config.num_attention_heads
    head_importance = torch.zeros(n_layers, n_heads).to(args.device)
    attn_entropy = torch.zeros(n_layers, n_heads).to(args.device)

    if head_mask is None:
        head_mask = torch.ones(n_layers, n_heads).to(args.device)
    else:
        # Work on a private leaf copy: turning on requires_grad on the caller's tensor
        # would make any later in-place edit of that tensor fail.
        head_mask = head_mask.detach().clone()
    head_mask.requires_grad_(requires_grad=True)

    logits_chunks = []
    label_chunks = []
    tot_tokens = 0.0

    for batch in tqdm(eval_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0]):
        batch = tuple(t.to(args.device) for t in batch)
        input_ids, input_mask, segment_ids, label_ids = batch

        # Do a forward pass (not with torch.no_grad() since we need gradients for importance score - see below)
        outputs = model(
            input_ids, token_type_ids=segment_ids, attention_mask=input_mask, labels=label_ids, head_mask=head_mask
        )
        # Loss and logits are the first, attention the last
        loss, logits, all_attentions = outputs[0], outputs[1], outputs[-1]
        loss.backward()  # Backpropagate to populate the gradients in the head mask

        if compute_entropy:
            for layer, attn in enumerate(all_attentions):
                masked_entropy = entropy(attn.detach()) * input_mask.float().unsqueeze(1)
                attn_entropy[layer] += masked_entropy.sum(-1).sum(0).detach()

        if compute_importance:
            head_importance += head_mask.grad.abs().detach()

        # backward() accumulates into .grad; drop it so the next step sees only its own
        # batch gradient instead of the running sum over all previous batches.
        head_mask.grad = None

        # Also store our logits/labels if we want to compute metrics afterwards
        logits_chunks.append(logits.detach().cpu().numpy())
        label_chunks.append(label_ids.detach().cpu().numpy())

        tot_tokens += input_mask.float().sum().item()

    # Single concatenation instead of re-copying the growing array at every step
    preds = np.concatenate(logits_chunks, axis=0) if logits_chunks else None
    labels = np.concatenate(label_chunks, axis=0) if label_chunks else None

    # Normalize
    attn_entropy /= tot_tokens
    head_importance /= tot_tokens
    # Layerwise importance normalization
    if not args.dont_normalize_importance_by_layer:
        exponent = 2
        norm_by_layer = torch.pow(torch.pow(head_importance, exponent).sum(-1), 1 / exponent)
        head_importance /= norm_by_layer.unsqueeze(-1) + 1e-20

    if not args.dont_normalize_global_importance:
        lowest = head_importance.min()
        spread = head_importance.max() - lowest
        # All-equal scores (e.g. compute_importance=False) would give 0 / 0 = NaN; leave them as is.
        if spread > 0:
            head_importance = (head_importance - lowest) / spread

    # Print/save matrices
    np.save(os.path.join(args.output_dir, "attn_entropy.npy"), attn_entropy.detach().cpu().numpy())
    np.save(os.path.join(args.output_dir, "head_importance.npy"), head_importance.detach().cpu().numpy())

    logger.info("Attention entropies")
    print_2d_tensor(attn_entropy)
    logger.info("Head importance scores")
    print_2d_tensor(head_importance)
    logger.info("Head ranked by importance scores")
    # head_ranks[i] is the position of head i when heads are sorted by decreasing importance
    head_ranks = torch.zeros(head_importance.numel(), dtype=torch.long, device=args.device)
    head_ranks[head_importance.view(-1).sort(descending=True)[1]] = torch.arange(
        head_importance.numel(), device=args.device
    )
    head_ranks = head_ranks.view_as(head_importance)
    print_2d_tensor(head_ranks)

    return attn_entropy, head_importance, preds, labels
thomwolf's avatar
thomwolf committed
152

thomwolf's avatar
thomwolf committed
153

thomwolf's avatar
thomwolf committed
154
155
156
157
158
159
160
161
162
163
164
165
166
167
def mask_heads(args, model, eval_dataloader):
    """ This method shows how to mask head (set some heads to zero), to test the effect on the network,
        based on the head importance scores, as described in Michel et al. (http://arxiv.org/abs/1905.10650)

    Heads are masked ``num_to_mask`` at a time, least important first, and the importance scores
    are recomputed after every step. The loop stops once the metric drops below
    ``args.masking_threshold * original_score`` or when no more than ``num_to_mask`` unmasked
    heads are left.

    Args:
        args: namespace read for ``output_mode``, ``task_name``, ``metric_name``,
            ``masking_threshold``, ``masking_amount`` and ``output_dir``.
        model: model forwarded to ``compute_heads_importance``.
        eval_dataloader: evaluation batches forwarded to ``compute_heads_importance``.

    Returns:
        The last head mask saved before the loop ended (1.0 = kept, 0.0 = masked).
        It is also written to ``head_mask.npy`` in ``args.output_dir``.
    """
    _, head_importance, preds, labels = compute_heads_importance(args, model, eval_dataloader, compute_entropy=False)
    preds = np.argmax(preds, axis=1) if args.output_mode == "classification" else np.squeeze(preds)
    original_score = compute_metrics(args.task_name, preds, labels)[args.metric_name]
    logger.info("Pruning: original score: %f, threshold: %f", original_score, original_score * args.masking_threshold)

    new_head_mask = torch.ones_like(head_importance)
    num_to_mask = max(1, int(new_head_mask.numel() * args.masking_amount))

    current_score = original_score
    while current_score >= original_score * args.masking_threshold:
        # Save current head mask. Detach so that it never shares autograd state with a
        # tensor that was handed to the model.
        head_mask = new_head_mask.detach().clone()

        # Count only the heads that are still active: the sorted index list always has
        # one entry per head, so its length cannot tell when we are running out of heads.
        num_remaining = int((head_mask != 0.0).sum().item())
        if num_remaining <= num_to_mask:
            break

        # heads from least important to most - already masked heads are pushed to the end
        ranked_importance = head_importance.detach().reshape(-1).clone()
        ranked_importance[head_mask.reshape(-1) == 0.0] = float("Inf")
        current_heads_to_mask = ranked_importance.sort()[1][:num_to_mask]
        logger.info("Heads to mask: %s", str(current_heads_to_mask.tolist()))

        # mask heads, on a fresh tensor that does not require grad: writing in place through
        # a view of a leaf that requires grad raises a RuntimeError
        new_head_mask = head_mask.clone()
        new_head_mask.view(-1)[current_heads_to_mask] = 0.0
        print_2d_tensor(new_head_mask)

        # Compute metric and head importance again
        _, head_importance, preds, labels = compute_heads_importance(
            args, model, eval_dataloader, compute_entropy=False, head_mask=new_head_mask
        )
        preds = np.argmax(preds, axis=1) if args.output_mode == "classification" else np.squeeze(preds)
        current_score = compute_metrics(args.task_name, preds, labels)[args.metric_name]
        heads_left = new_head_mask.detach().sum().item()
        logger.info(
            "Masking: current score: %f, remaning heads %d (%.1f percents)",
            current_score,
            heads_left,
            heads_left / new_head_mask.numel() * 100,
        )

    logger.info("Final head mask")
    print_2d_tensor(head_mask)
    np.save(os.path.join(args.output_dir, "head_mask.npy"), head_mask.detach().cpu().numpy())

    return head_mask


def prune_heads(args, model, eval_dataloader, head_mask):
    """ This method shows how to prune head (remove heads weights) based on
        the head importance scores as described in Michel et al. (http://arxiv.org/abs/1905.10650)

    The evaluation is run twice, first with ``head_mask`` applied and then after the masked
    heads have been removed from the model, and the score, parameter count and wall-clock
    time of both runs are logged.

    Args:
        args: namespace read for ``output_mode``, ``task_name`` and ``metric_name``.
        model: model to evaluate; it is pruned in place.
        eval_dataloader: evaluation batches forwarded to ``compute_heads_importance``.
        head_mask: ``(n_layers, n_heads)`` tensor; entries equal to 0 mark the heads to remove.
    """
    # Try pruning and test time speedup
    # Pruning is like masking but we actually remove the masked weights
    before_time = datetime.now()
    _, _, preds, labels = compute_heads_importance(
        args, model, eval_dataloader, compute_entropy=False, compute_importance=False, head_mask=head_mask
    )
    preds = np.argmax(preds, axis=1) if args.output_mode == "classification" else np.squeeze(preds)
    score_masking = compute_metrics(args.task_name, preds, labels)[args.metric_name]
    original_time = datetime.now() - before_time

    original_num_params = sum(p.numel() for p in model.parameters())
    # nonzero() yields one [index] row per hit; flatten so every layer maps to a plain
    # list of head indices rather than a list of one-element lists.
    to_remove = 1 - head_mask.detach().long()
    heads_to_prune = {layer: to_remove[layer].nonzero().reshape(-1).tolist() for layer in range(len(head_mask))}
    assert sum(len(h) for h in heads_to_prune.values()) == to_remove.sum().item()
    # prune_heads lives on the underlying model, not on a (Distributed)DataParallel wrapper
    base_model = model.module if hasattr(model, "module") else model
    base_model.prune_heads(heads_to_prune)
    pruned_num_params = sum(p.numel() for p in model.parameters())

    # NOTE(review): this second pass builds a full-size all-ones mask for a model whose
    # layers may now have fewer heads - confirm the model accepts that shape.
    before_time = datetime.now()
    _, _, preds, labels = compute_heads_importance(
        args, model, eval_dataloader, compute_entropy=False, compute_importance=False, head_mask=None
    )
    preds = np.argmax(preds, axis=1) if args.output_mode == "classification" else np.squeeze(preds)
    score_pruning = compute_metrics(args.task_name, preds, labels)[args.metric_name]
    new_time = datetime.now() - before_time

    logger.info(
        "Pruning: original num of params: %.2e, after pruning %.2e (%.1f percents)",
        original_num_params,
        pruned_num_params,
        pruned_num_params / original_num_params * 100,
    )
    logger.info("Pruning: score with masking: %f score with pruning: %f", score_masking, score_pruning)
    # The value is original / new (above 100 means the pruned model is faster); the label says so.
    logger.info("Pruning: speed ratio (original timing / new timing): %f percents", original_time / new_time * 100)
thomwolf's avatar
thomwolf committed
240
241
242


def main():
    """ Command-line entry point.

    Parses the arguments, sets up the device / distributed backend, loads the model,
    tokenizer and GLUE evaluation data, computes head entropy and importance and,
    with ``--try_masking``, runs head masking followed by head pruning.

    Raises:
        ValueError: if the task name is unknown or no model type can be inferred
            from ``--model_name_or_path``.
    """
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument(
        "--data_dir",
        default=None,
        type=str,
        required=True,
        help="The input data dir. Should contain the .tsv files (or other data files) for the task.",
    )
    parser.add_argument(
        "--model_name_or_path",
        default=None,
        type=str,
        required=True,
        help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS),
    )
    parser.add_argument(
        "--task_name",
        default=None,
        type=str,
        required=True,
        help="The name of the task to train selected in the list: " + ", ".join(processors.keys()),
    )
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help="The output directory where the model predictions and checkpoints will be written.",
    )

    # Other parameters
    parser.add_argument(
        "--config_name",
        default="",
        type=str,
        help="Pretrained config name or path if not the same as model_name_or_path",
    )
    parser.add_argument(
        "--tokenizer_name",
        default="",
        type=str,
        help="Pretrained tokenizer name or path if not the same as model_name_or_path",
    )
    parser.add_argument(
        "--cache_dir",
        default="",
        type=str,
        help="Where do you want to store the pre-trained models downloaded from s3",
    )
    parser.add_argument(
        "--data_subset", type=int, default=-1, help="If > 0: limit the data to a subset of data_subset instances."
    )
    parser.add_argument(
        "--overwrite_output_dir", action="store_true", help="Whether to overwrite data in output directory"
    )
    parser.add_argument(
        "--overwrite_cache", action="store_true", help="Overwrite the cached training and evaluation sets"
    )

    parser.add_argument(
        "--dont_normalize_importance_by_layer", action="store_true", help="Don't normalize importance score by layers"
    )
    parser.add_argument(
        "--dont_normalize_global_importance",
        action="store_true",
        help="Don't normalize all importance scores between 0 and 1",
    )

    parser.add_argument(
        "--try_masking", action="store_true", help="Whether to try to mask head until a threshold of accuracy."
    )
    parser.add_argument(
        "--masking_threshold",
        default=0.9,
        type=float,
        help="masking threshold in term of metrics (stop masking when metric < threshold * original metric value).",
    )
    parser.add_argument(
        "--masking_amount", default=0.1, type=float, help="Amount to heads to masking at each masking step."
    )
    parser.add_argument("--metric_name", default="acc", type=str, help="Metric to use for head masking.")

    parser.add_argument(
        "--max_seq_length",
        default=128,
        type=int,
        help="The maximum total input sequence length after WordPiece tokenization. \n"
        "Sequences longer than this will be truncated, sequences shorter padded.",
    )
    parser.add_argument("--batch_size", default=1, type=int, help="Batch size.")

    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--local_rank", type=int, default=-1, help="local_rank for distributed training on gpus")
    parser.add_argument("--no_cuda", action="store_true", help="Whether not to use CUDA when available")
    parser.add_argument("--server_ip", type=str, default="", help="Can be used for distant debugging.")
    parser.add_argument("--server_port", type=str, default="", help="Can be used for distant debugging.")
    args = parser.parse_args()

    if args.server_ip and args.server_port:
        # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
        import ptvsd

        print("Waiting for debugger attach")
        ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
        ptvsd.wait_for_attach()

    # Setup devices and distributed training
    if args.local_rank == -1 or args.no_cuda:
        args.device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
        args.n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        args.device = torch.device("cuda", args.local_rank)
        args.n_gpu = 1
        torch.distributed.init_process_group(backend="nccl")  # Initializes the distributed backend

    # Setup logging
    logging.basicConfig(level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
    logger.info("device: %s n_gpu: %s, distributed: %s", args.device, args.n_gpu, bool(args.local_rank != -1))

    # Set seeds
    set_seed(args)

    # Prepare GLUE task
    args.task_name = args.task_name.lower()
    if args.task_name not in processors:
        raise ValueError("Task not found: %s" % (args.task_name))
    processor = processors[args.task_name]()
    args.output_mode = output_modes[args.task_name]
    label_list = processor.get_labels()
    num_labels = len(label_list)

    # Load pretrained model and tokenizer
    if args.local_rank not in [-1, 0]:
        torch.distributed.barrier()  # Make sure only the first process in distributed training will download model & vocab

    args.model_type = ""
    for key in MODEL_CLASSES:
        if key in args.model_name_or_path.lower():
            args.model_type = key  # take the first match in model types
            break
    if not args.model_type:
        # Fail with an explicit message instead of an opaque KeyError on the empty string
        raise ValueError(
            "Could not infer the model type from '%s'; expected it to contain one of: %s"
            % (args.model_name_or_path, ", ".join(MODEL_CLASSES))
        )
    config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
    config = config_class.from_pretrained(
        args.config_name if args.config_name else args.model_name_or_path,
        num_labels=num_labels,
        finetuning_task=args.task_name,
        output_attentions=True,
        cache_dir=args.cache_dir if args.cache_dir else None,
    )
    tokenizer = tokenizer_class.from_pretrained(
        args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
        cache_dir=args.cache_dir if args.cache_dir else None,
    )
    model = model_class.from_pretrained(
        args.model_name_or_path,
        from_tf=bool(".ckpt" in args.model_name_or_path),
        config=config,
        cache_dir=args.cache_dir if args.cache_dir else None,
    )

    if args.local_rank == 0:
        torch.distributed.barrier()  # Make sure only the first process in distributed training will download model & vocab

    # Distributed and parallel training
    model.to(args.device)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True
        )
    elif args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Print/save training arguments
    # The output directory is written to right away, so make sure it exists
    # (exist_ok keeps this safe when several processes reach this line).
    os.makedirs(args.output_dir, exist_ok=True)
    torch.save(args, os.path.join(args.output_dir, "run_args.bin"))
    logger.info("Training/evaluation parameters %s", args)

    # Prepare dataset for the GLUE task
    eval_data = load_and_cache_examples(args, args.task_name, tokenizer, evaluate=True)
    if args.data_subset > 0:
        eval_data = Subset(eval_data, list(range(min(args.data_subset, len(eval_data)))))
    eval_sampler = SequentialSampler(eval_data) if args.local_rank == -1 else DistributedSampler(eval_data)
    eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.batch_size)

    # Compute head entropy and importance score
    compute_heads_importance(args, model, eval_dataloader)

    # Try head masking (set heads to zero until the score goes under a threshold)
    # and head pruning (remove masked heads and see the effect on the network)
    if args.try_masking and 0.0 < args.masking_threshold < 1.0:
        head_mask = mask_heads(args, model, eval_dataloader)
        prune_heads(args, model, eval_dataloader, head_mask)
thomwolf's avatar
thomwolf committed
435

thomwolf's avatar
thomwolf committed
436

437
if __name__ == "__main__":
    # Run the analysis only when the file is executed as a script, not when it is imported.
    main()