run_bertology.py 17.9 KB
Newer Older
1
#!/usr/bin/env python3
thomwolf's avatar
thomwolf committed
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
# Copyright 2018 CMU and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Bertology: this script shows how you can explore the internals of the models in the library to:
    - compute the entropy of the head attentions
    - compute the importance of each head
    - prune (remove) the low-importance heads.
    Some parts of this script are adapted from the code of Michel et al. (http://arxiv.org/abs/1905.10650)
    which is available at https://github.com/pmichel31415/are-16-heads-really-better-than-1
"""
22
23
import argparse
import logging
Aymeric Augustin's avatar
Aymeric Augustin committed
24
import os
25
from datetime import datetime
26
27

import numpy as np
thomwolf's avatar
thomwolf committed
28
import torch
29
from torch.utils.data import DataLoader, SequentialSampler, Subset
Aymeric Augustin's avatar
Aymeric Augustin committed
30
31
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm
thomwolf's avatar
thomwolf committed
32

Aymeric Augustin's avatar
Aymeric Augustin committed
33
from run_glue import ALL_MODELS, MODEL_CLASSES, load_and_cache_examples, set_seed
Adrian Bauer's avatar
Adrian Bauer committed
34
35
36
from transformers import glue_compute_metrics as compute_metrics
from transformers import glue_output_modes as output_modes
from transformers import glue_processors as processors
thomwolf's avatar
thomwolf committed
37

Aymeric Augustin's avatar
Aymeric Augustin committed
38

39
40
# Module-level logger shared by every function of this script
logger = logging.getLogger(name=__name__)


def entropy(p):
    """ Compute the entropy of a probability distribution along its last dimension.

    Entries where ``p == 0`` contribute 0 (the limit of p * log(p) as p -> 0) instead of the
    NaN that ``0 * log(0)`` would otherwise produce.
    """
    p_log_p = torch.where(p == 0, torch.zeros_like(p), p * torch.log(p))
    return -p_log_p.sum(dim=-1)


def print_2d_tensor(tensor):
    """ Log a 2D tensor as a table: one line per row ("layer"), one column per head.

    ``torch.long`` tensors are printed as integers, every other dtype with 5 decimals.
    """
    rows = tensor.detach().cpu()
    # The header enumerates the columns (heads), so its length must follow the number of
    # columns - not len(tensor), which is the number of rows (layers)
    logger.info("lv, h >\t" + "\t".join(f"{x + 1}" for x in range(rows.size(-1))))
    cell_format = "d" if tensor.dtype == torch.long else ".5f"
    for row_idx, row in enumerate(rows):
        logger.info(f"layer {row_idx + 1}:\t" + "\t".join(f"{x:{cell_format}}" for x in row))


def compute_heads_importance(
    args, model, eval_dataloader, compute_entropy=True, compute_importance=True, head_mask=None
):
    """ This method shows how to compute:
        - head attention entropy
        - head importance scores according to http://arxiv.org/abs/1905.10650

    Args:
        args: parsed command-line arguments. Reads ``device``, ``local_rank``, ``output_dir``,
            ``dont_normalize_importance_by_layer`` and ``dont_normalize_global_importance``.
        model: model returning ``(loss, logits, ..., attentions)`` when called with labels
            (attentions must be enabled). It may be wrapped in DataParallel /
            DistributedDataParallel; the config is then read from ``model.module``.
        eval_dataloader: yields ``(input_ids, input_mask, segment_ids, label_ids)`` batches.
        compute_entropy: accumulate the attention entropy of every head.
        compute_importance: accumulate |d loss / d head_mask| for every head. The backward
            pass (and gradient tracking) only happens when this is True.
        head_mask: optional ``(n_layers, n_heads)`` mask. It is copied, never modified in place.
            When None and ``compute_importance`` is False, no mask is given to the model at all,
            which also works for models whose heads were actually pruned.

    Returns:
        ``(attn_entropy, head_importance, preds, labels)``: two ``(n_layers, n_heads)`` tensors
        (left at zero for a quantity that was not requested) and the concatenated logits and
        labels as numpy arrays (None if the dataloader is empty).

    Side effects:
        Logs the computed matrices and saves ``attn_entropy.npy`` / ``head_importance.npy`` in
        ``args.output_dir`` - each only when that quantity was computed, so a partial run does not
        overwrite earlier results with placeholder zeros.
    """
    # Prepare our tensors. Read the config from the unwrapped model instead of ``model.bert`` so
    # that other architectures and (Distributed)DataParallel wrappers work too.
    config = getattr(model, "module", model).config
    n_layers, n_heads = config.num_hidden_layers, config.num_attention_heads
    head_importance = torch.zeros(n_layers, n_heads).to(args.device)
    attn_entropy = torch.zeros(n_layers, n_heads).to(args.device)

    if head_mask is not None:
        # Private copy: turning on gradients on the caller's tensor would make later in-place
        # updates of that tensor fail (in-place op on a view of a leaf that requires grad)
        head_mask = head_mask.clone().detach()
    elif compute_importance:
        head_mask = torch.ones(n_layers, n_heads).to(args.device)
    if compute_importance:
        head_mask.requires_grad_(requires_grad=True)

    all_preds = []
    all_labels = []
    tot_tokens = 0.0

    for batch in tqdm(eval_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0]):
        batch = tuple(t.to(args.device) for t in batch)
        input_ids, input_mask, segment_ids, label_ids = batch

        # Do a forward pass with gradients only when they are needed for the importance score
        with torch.set_grad_enabled(compute_importance):
            outputs = model(
                input_ids, token_type_ids=segment_ids, attention_mask=input_mask, labels=label_ids, head_mask=head_mask
            )
        # Loss and logits are the first outputs, attentions the last
        loss, logits, all_attentions = outputs[0], outputs[1], outputs[-1]

        if compute_importance:
            loss.backward()  # Backpropagate to populate the gradients in the head mask
            head_importance += head_mask.grad.abs().detach()
            # .grad accumulates across backward() calls: reset it so each batch contributes the
            # magnitude of its own gradient, not of the running sum of all previous gradients
            head_mask.grad = None

        if compute_entropy:
            for layer, attn in enumerate(all_attentions):
                # assumes attn is (batch, heads, seq, seq) so the token mask broadcasts over heads
                # - TODO confirm for non-BERT models
                masked_entropy = entropy(attn.detach()) * input_mask.float().unsqueeze(1)
                attn_entropy[layer] += masked_entropy.sum(-1).sum(0).detach()

        # Also store our logits/labels if we want to compute metrics afterwards
        all_preds.append(logits.detach().cpu().numpy())
        all_labels.append(label_ids.detach().cpu().numpy())

        tot_tokens += input_mask.float().detach().sum().item()

    # One concatenation instead of repeated np.append (which copies the whole array every batch)
    preds = np.concatenate(all_preds, axis=0) if all_preds else None
    labels = np.concatenate(all_labels, axis=0) if all_labels else None

    # Normalize
    attn_entropy /= tot_tokens
    head_importance /= tot_tokens

    if compute_importance:
        # Layerwise importance normalization
        if not args.dont_normalize_importance_by_layer:
            exponent = 2
            norm_by_layer = torch.pow(torch.pow(head_importance, exponent).sum(-1), 1 / exponent)
            head_importance /= norm_by_layer.unsqueeze(-1) + 1e-20

        if not args.dont_normalize_global_importance:
            importance_range = head_importance.max() - head_importance.min()
            # Skip the min-max rescaling when all scores are equal (it would be 0 / 0 = NaN)
            if importance_range > 0:
                head_importance = (head_importance - head_importance.min()) / importance_range

    # Print/save matrices
    if compute_entropy:
        np.save(os.path.join(args.output_dir, "attn_entropy.npy"), attn_entropy.detach().cpu().numpy())
        logger.info("Attention entropies")
        print_2d_tensor(attn_entropy)

    if compute_importance:
        np.save(os.path.join(args.output_dir, "head_importance.npy"), head_importance.detach().cpu().numpy())
        logger.info("Head importance scores")
        print_2d_tensor(head_importance)
        logger.info("Head ranked by importance scores")
        head_ranks = torch.zeros(head_importance.numel(), dtype=torch.long, device=args.device)
        head_ranks[head_importance.view(-1).sort(descending=True)[1]] = torch.arange(
            head_importance.numel(), device=args.device
        )
        head_ranks = head_ranks.view_as(head_importance)
        print_2d_tensor(head_ranks)

    return attn_entropy, head_importance, preds, labels


def mask_heads(args, model, eval_dataloader):
    """ This method shows how to mask head (set some heads to zero), to test the effect on the network,
        based on the head importance scores, as described in Michel et al. (http://arxiv.org/abs/1905.10650)

    Heads are masked greedily, least important first, ``args.masking_amount`` of all heads (at
    least one) per step, re-estimating importance after every step. Masking stops once the metric
    ``args.metric_name`` falls below ``args.masking_threshold`` times the unmasked score, or when
    no more than one step's worth of unmasked heads is left.

    Returns:
        The last head mask (1.0 = kept, 0.0 = masked) whose score was still above the threshold.
        It is also saved as ``head_mask.npy`` in ``args.output_dir``.
    """
    _, head_importance, preds, labels = compute_heads_importance(args, model, eval_dataloader, compute_entropy=False)
    preds = np.argmax(preds, axis=1) if args.output_mode == "classification" else np.squeeze(preds)
    original_score = compute_metrics(args.task_name, preds, labels)[args.metric_name]
    logger.info("Pruning: original score: %f, threshold: %f", original_score, original_score * args.masking_threshold)

    new_head_mask = torch.ones_like(head_importance)
    num_to_mask = max(1, int(new_head_mask.numel() * args.masking_amount))

    # Defined up-front so the function still returns a mask if the loop body never runs
    head_mask = new_head_mask.clone()
    current_score = original_score
    while current_score >= original_score * args.masking_threshold:
        # Save the current head mask as a plain tensor without autograd history
        head_mask = new_head_mask.detach().clone()
        # heads from least important to most - keep only not-masked heads
        head_importance[head_mask == 0.0] = float("Inf")
        current_heads_to_mask = head_importance.view(-1).sort()[1]

        # Already-masked heads are still in the sorted list (pushed last by the Inf above), so the
        # stop condition must count the heads that are still unmasked
        num_remaining = int((head_mask != 0.0).sum().item())
        if num_remaining <= num_to_mask:
            break

        # mask heads - on a fresh copy, because compute_heads_importance may have enabled gradients
        # on the previous mask and in-place writes into a view of such a leaf raise an error
        current_heads_to_mask = current_heads_to_mask[:num_to_mask]
        logger.info("Heads to mask: %s", str(current_heads_to_mask.tolist()))
        new_head_mask = new_head_mask.detach().clone().view(-1)
        new_head_mask[current_heads_to_mask] = 0.0
        new_head_mask = new_head_mask.view_as(head_mask)
        print_2d_tensor(new_head_mask)

        # Compute metric and head importance again
        _, head_importance, preds, labels = compute_heads_importance(
            args, model, eval_dataloader, compute_entropy=False, head_mask=new_head_mask
        )
        preds = np.argmax(preds, axis=1) if args.output_mode == "classification" else np.squeeze(preds)
        current_score = compute_metrics(args.task_name, preds, labels)[args.metric_name]
        num_kept = new_head_mask.sum().item()
        logger.info(
            "Masking: current score: %f, remaining heads %d (%.1f percents)",
            current_score,
            int(num_kept),
            num_kept / new_head_mask.numel() * 100,
        )

    logger.info("Final head mask")
    print_2d_tensor(head_mask)
    np.save(os.path.join(args.output_dir, "head_mask.npy"), head_mask.detach().cpu().numpy())

    return head_mask


def _evaluate_score(args, model, eval_dataloader, head_mask=None):
    """ Evaluate ``model`` on ``eval_dataloader`` without gradients; return the ``args.metric_name`` score.

    Only forward passes are run and no file is written, so the call can be timed fairly and works
    on a model whose heads were physically pruned (pass ``head_mask=None`` in that case).
    """
    all_preds = []
    all_labels = []
    with torch.no_grad():
        for batch in eval_dataloader:
            input_ids, input_mask, segment_ids, label_ids = (t.to(args.device) for t in batch)
            outputs = model(
                input_ids, token_type_ids=segment_ids, attention_mask=input_mask, labels=label_ids, head_mask=head_mask
            )
            all_preds.append(outputs[1].detach().cpu().numpy())  # logits come right after the loss
            all_labels.append(label_ids.detach().cpu().numpy())
    preds = np.concatenate(all_preds, axis=0)
    labels = np.concatenate(all_labels, axis=0)
    preds = np.argmax(preds, axis=1) if args.output_mode == "classification" else np.squeeze(preds)
    return compute_metrics(args.task_name, preds, labels)[args.metric_name]


def prune_heads(args, model, eval_dataloader, head_mask):
    """ This method shows how to prune head (remove heads weights) based on
        the head importance scores as described in Michel et al. (http://arxiv.org/abs/1905.10650)

    Evaluates the model with ``head_mask`` applied, physically removes every head whose mask is 0,
    evaluates the pruned model without a mask, then logs parameter counts, both scores and the
    timing ratio. ``model`` is modified in place.
    """
    # Try pruning and test time speedup
    # Pruning is like masking but we actually remove the masked weights
    before_time = datetime.now()
    score_masking = _evaluate_score(args, model, eval_dataloader, head_mask=head_mask)
    original_time = datetime.now() - before_time

    original_num_params = sum(p.numel() for p in model.parameters())
    # {layer: flat list of int head indices whose mask is 0}. nonzero() returns a (k, 1) tensor, so
    # flatten it: prune_heads needs hashable ints, not one-element lists.
    heads_to_prune = {}
    for layer in range(len(head_mask)):
        layer_heads = (1 - head_mask[layer].long()).nonzero().view(-1).tolist()
        if layer_heads:
            heads_to_prune[layer] = layer_heads
    assert sum(len(h) for h in heads_to_prune.values()) == (1 - head_mask.long()).sum().item()
    # prune_heads lives on the underlying model, not on a (Distributed)DataParallel wrapper
    getattr(model, "module", model).prune_heads(heads_to_prune)
    pruned_num_params = sum(p.numel() for p in model.parameters())

    # The pruned model has fewer heads per layer, so it must run without a (full-size) head mask
    before_time = datetime.now()
    score_pruning = _evaluate_score(args, model, eval_dataloader, head_mask=None)
    new_time = datetime.now() - before_time

    logger.info(
        "Pruning: original num of params: %.2e, after pruning %.2e (%.1f percents)",
        original_num_params,
        pruned_num_params,
        pruned_num_params / original_num_params * 100,
    )
    logger.info("Pruning: score with masking: %f score with pruning: %f", score_masking, score_pruning)
    logger.info("Pruning: speed ratio (original timing / new timing): %f percents", original_time / new_time * 100)


def main():
    """ Parse the command line, load a GLUE model and its evaluation set, compute head entropy and
    importance, and optionally try head masking followed by head pruning. """
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument(
        "--data_dir",
        default=None,
        type=str,
        required=True,
        help="The input data dir. Should contain the .tsv files (or other data files) for the task.",
    )
    parser.add_argument(
        "--model_name_or_path",
        default=None,
        type=str,
        required=True,
        help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS),
    )
    parser.add_argument(
        "--task_name",
        default=None,
        type=str,
        required=True,
        help="The name of the task to train selected in the list: " + ", ".join(processors.keys()),
    )
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help="The output directory where the model predictions and checkpoints will be written.",
    )

    # Other parameters
    parser.add_argument(
        "--config_name",
        default="",
        type=str,
        help="Pretrained config name or path if not the same as model_name_or_path",
    )
    parser.add_argument(
        "--tokenizer_name",
        default="",
        type=str,
        help="Pretrained tokenizer name or path if not the same as model_name_or_path",
    )
    parser.add_argument(
        "--cache_dir",
        default="",
        type=str,
        help="Where do you want to store the pre-trained models downloaded from s3",
    )
    parser.add_argument(
        "--data_subset", type=int, default=-1, help="If > 0: limit the data to a subset of data_subset instances."
    )
    parser.add_argument(
        "--overwrite_output_dir", action="store_true", help="Whether to overwrite data in output directory"
    )
    parser.add_argument(
        "--overwrite_cache", action="store_true", help="Overwrite the cached training and evaluation sets"
    )

    parser.add_argument(
        "--dont_normalize_importance_by_layer", action="store_true", help="Don't normalize importance score by layers"
    )
    parser.add_argument(
        "--dont_normalize_global_importance",
        action="store_true",
        help="Don't normalize all importance scores between 0 and 1",
    )

    parser.add_argument(
        "--try_masking", action="store_true", help="Whether to try to mask head until a threshold of accuracy."
    )
    parser.add_argument(
        "--masking_threshold",
        default=0.9,
        type=float,
        help="masking threshold in term of metrics (stop masking when metric < threshold * original metric value).",
    )
    parser.add_argument(
        "--masking_amount", default=0.1, type=float, help="Amount to heads to masking at each masking step."
    )
    parser.add_argument("--metric_name", default="acc", type=str, help="Metric to use for head masking.")

    parser.add_argument(
        "--max_seq_length",
        default=128,
        type=int,
        help="The maximum total input sequence length after WordPiece tokenization. \n"
        "Sequences longer than this will be truncated, sequences shorter padded.",
    )
    parser.add_argument("--batch_size", default=1, type=int, help="Batch size.")

    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--local_rank", type=int, default=-1, help="local_rank for distributed training on gpus")
    parser.add_argument("--no_cuda", action="store_true", help="Whether not to use CUDA when available")
    parser.add_argument("--server_ip", type=str, default="", help="Can be used for distant debugging.")
    parser.add_argument("--server_port", type=str, default="", help="Can be used for distant debugging.")
    args = parser.parse_args()

    if args.server_ip and args.server_port:
        # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
        import ptvsd

        print("Waiting for debugger attach")
        ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
        ptvsd.wait_for_attach()

    # Setup devices and distributed training
    if args.local_rank == -1 or args.no_cuda:
        args.device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
        args.n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        args.device = torch.device("cuda", args.local_rank)
        args.n_gpu = 1
        torch.distributed.init_process_group(backend="nccl")  # Initializes the distributed backend

    # Setup logging
    logging.basicConfig(level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
    logger.info("device: {} n_gpu: {}, distributed: {}".format(args.device, args.n_gpu, bool(args.local_rank != -1)))

    # Set seeds
    set_seed(args)

    # Prepare GLUE task
    args.task_name = args.task_name.lower()
    if args.task_name not in processors:
        raise ValueError("Task not found: %s" % (args.task_name))
    processor = processors[args.task_name]()
    args.output_mode = output_modes[args.task_name]
    label_list = processor.get_labels()
    num_labels = len(label_list)

    # Infer the model type from the checkpoint name (before any distributed barrier, so a failure
    # is raised by every process). Prefer the longest matching key: a plain substring test would
    # classify "roberta-base" or "distilbert-base-uncased" as "bert".
    model_name = args.model_name_or_path.lower()
    matching_types = [key for key in MODEL_CLASSES if key in model_name]
    if not matching_types:
        raise ValueError(
            "Could not infer the model type from %s; its name should contain one of: %s"
            % (args.model_name_or_path, ", ".join(MODEL_CLASSES))
        )
    args.model_type = max(matching_types, key=len)  # ties keep the first match, as before
    config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]

    # Load pretrained model and tokenizer
    if args.local_rank not in [-1, 0]:
        torch.distributed.barrier()  # Make sure only the first process in distributed training will download model & vocab

    config = config_class.from_pretrained(
        args.config_name if args.config_name else args.model_name_or_path,
        num_labels=num_labels,
        finetuning_task=args.task_name,
        output_attentions=True,
        cache_dir=args.cache_dir if args.cache_dir else None,
    )
    tokenizer = tokenizer_class.from_pretrained(
        args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
        cache_dir=args.cache_dir if args.cache_dir else None,
    )
    model = model_class.from_pretrained(
        args.model_name_or_path,
        from_tf=bool(".ckpt" in args.model_name_or_path),
        config=config,
        cache_dir=args.cache_dir if args.cache_dir else None,
    )

    if args.local_rank == 0:
        torch.distributed.barrier()  # Make sure only the first process in distributed training will download model & vocab

    # Distributed and parallel training
    model.to(args.device)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True
        )
    elif args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Every result file below is written into output_dir: make sure it exists
    os.makedirs(args.output_dir, exist_ok=True)

    # Print/save training arguments
    torch.save(args, os.path.join(args.output_dir, "run_args.bin"))
    logger.info("Training/evaluation parameters %s", args)

    # Prepare dataset for the GLUE task
    eval_data = load_and_cache_examples(args, args.task_name, tokenizer, evaluate=True)
    if args.data_subset > 0:
        eval_data = Subset(eval_data, list(range(min(args.data_subset, len(eval_data)))))
    eval_sampler = SequentialSampler(eval_data) if args.local_rank == -1 else DistributedSampler(eval_data)
    eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.batch_size)

    # Compute head entropy and importance score
    compute_heads_importance(args, model, eval_dataloader)

    # Try head masking (set heads to zero until the score goes under a threshold)
    # and head pruning (remove masked heads and see the effect on the network)
    if args.try_masking and args.masking_threshold > 0.0 and args.masking_threshold < 1.0:
        head_mask = mask_heads(args, model, eval_dataloader)
        prune_heads(args, model, eval_dataloader, head_mask)


if __name__ == "__main__":
    main()