#!/usr/bin/env python3
# Copyright 2018 CMU and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Bertology: this script shows how you can explore the internals of the models in the library to:
    - compute the entropy of the head attentions
    - compute the importance of each head
    - prune (remove) the low-importance heads.
    Some parts of this script are adapted from the code of Michel et al. (http://arxiv.org/abs/1905.10650)
    which is available at https://github.com/pmichel31415/are-16-heads-really-better-than-1
"""
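# Example invocation (illustrative only; the model identifier, task name and paths are placeholders):
#   python run_bertology.py --model_name_or_path bert-base-cased --task_name mrpc \
#       --data_dir /path/to/glue/MRPC --output_dir ./bertology_output --batch_size 8 --try_masking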
import argparse
import logging
import os
from datetime import datetime

import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader, SequentialSampler, Subset
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm

import transformers
from transformers import (
    AutoConfig,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    GlueDataset,
    default_data_collator,
    glue_compute_metrics,
    glue_output_modes,
    glue_processors,
    set_seed,
)
from transformers.trainer_utils import is_main_process


logger = logging.getLogger(__name__)


def entropy(p):
    """Compute the entropy of a probability distribution"""
    plogp = p * torch.log(p)
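    # By convention 0 * log(0) = 0; torch produces NaN for those entries, so zero them explicitly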
    plogp[p == 0] = 0
    return -plogp.sum(dim=-1)


def print_2d_tensor(tensor):
    """Print a 2D tensor"""
    logger.info("lv, h >\t" + "\t".join(f"{x + 1}" for x in range(len(tensor))))
    for row in range(len(tensor)):
        if tensor.dtype != torch.long:
            logger.info(f"layer {row + 1}:\t" + "\t".join(f"{x:.5f}" for x in tensor[row].cpu().data))
        else:
            logger.info(f"layer {row + 1}:\t" + "\t".join(f"{x:d}" for x in tensor[row].cpu().data))


def compute_heads_importance(
    args, model, eval_dataloader, compute_entropy=True, compute_importance=True, head_mask=None, actually_pruned=False
):
    """This method shows how to compute:
    - head attention entropy
    - head importance scores according to http://arxiv.org/abs/1905.10650
    """
    # Prepare our tensors
    n_layers, n_heads = model.config.num_hidden_layers, model.config.num_attention_heads
    head_importance = torch.zeros(n_layers, n_heads).to(args.device)
    attn_entropy = torch.zeros(n_layers, n_heads).to(args.device)

    if head_mask is None:
        head_mask = torch.ones(n_layers, n_heads).to(args.device)

    head_mask.requires_grad_(requires_grad=True)
    # If the attention heads have actually been pruned, set the head mask to None to avoid a shape mismatch
    if actually_pruned:
        head_mask = None

    preds = None
    labels = None
    tot_tokens = 0.0

    for step, inputs in enumerate(tqdm(eval_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])):
        for k, v in inputs.items():
            inputs[k] = v.to(args.device)

        # Do a forward pass (not with torch.no_grad() since we need gradients for importance score - see below)
        outputs = model(**inputs, head_mask=head_mask)
        loss, logits, all_attentions = (
            outputs[0],
            outputs[1],
            outputs[-1],
        )  # Loss and logits are the first, attention the last
        loss.backward()  # Backpropagate to populate the gradients in the head mask

        if compute_entropy:
            for layer, attn in enumerate(all_attentions):
                masked_entropy = entropy(attn.detach()) * inputs["attention_mask"].float().unsqueeze(1)
                attn_entropy[layer] += masked_entropy.sum(-1).sum(0).detach()

        if compute_importance:
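            # Head importance is approximated by the absolute gradient of the loss w.r.t. the head mask
            # (Michel et al., 2019), accumulated over the evaluation batches; this is why head_mask was
            # created with requires_grad=True above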
            head_importance += head_mask.grad.abs().detach()

        # Also store our logits/labels if we want to compute metrics afterwards
        if preds is None:
            preds = logits.detach().cpu().numpy()
            labels = inputs["labels"].detach().cpu().numpy()
        else:
            preds = np.append(preds, logits.detach().cpu().numpy(), axis=0)
            labels = np.append(labels, inputs["labels"].detach().cpu().numpy(), axis=0)

        tot_tokens += inputs["attention_mask"].float().detach().sum().data

    # Normalize
    attn_entropy /= tot_tokens
    head_importance /= tot_tokens
    # Layerwise importance normalization
    if not args.dont_normalize_importance_by_layer:
        exponent = 2
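        # Divide each layer's head scores by the layer's L2 norm (exponent = 2) so layers are comparable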
        norm_by_layer = torch.pow(torch.pow(head_importance, exponent).sum(-1), 1 / exponent)
        head_importance /= norm_by_layer.unsqueeze(-1) + 1e-20

    if not args.dont_normalize_global_importance:
        head_importance = (head_importance - head_importance.min()) / (head_importance.max() - head_importance.min())

    # Print/save matrices
    np.save(os.path.join(args.output_dir, "attn_entropy.npy"), attn_entropy.detach().cpu().numpy())
    np.save(os.path.join(args.output_dir, "head_importance.npy"), head_importance.detach().cpu().numpy())

    logger.info("Attention entropies")
    print_2d_tensor(attn_entropy)
    logger.info("Head importance scores")
    print_2d_tensor(head_importance)
    logger.info("Head ranked by importance scores")
    head_ranks = torch.zeros(head_importance.numel(), dtype=torch.long, device=args.device)
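    # Rank 0 is assigned to the most important head, rank 1 to the next, and so on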
    head_ranks[head_importance.view(-1).sort(descending=True)[1]] = torch.arange(
        head_importance.numel(), device=args.device
    )
    head_ranks = head_ranks.view_as(head_importance)
    print_2d_tensor(head_ranks)

    return attn_entropy, head_importance, preds, labels


def mask_heads(args, model, eval_dataloader):
    """This method shows how to mask heads (set some heads to zero) to test the effect on the network,
    based on the head importance scores, as described in Michel et al. (http://arxiv.org/abs/1905.10650)
    """
    _, head_importance, preds, labels = compute_heads_importance(args, model, eval_dataloader, compute_entropy=False)
    preds = np.argmax(preds, axis=1) if args.output_mode == "classification" else np.squeeze(preds)
    original_score = glue_compute_metrics(args.task_name, preds, labels)[args.metric_name]
    logger.info("Pruning: original score: %f, threshold: %f", original_score, original_score * args.masking_threshold)

    new_head_mask = torch.ones_like(head_importance)
    num_to_mask = max(1, int(new_head_mask.numel() * args.masking_amount))

    current_score = original_score
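    # Iteratively mask the num_to_mask least important remaining heads, re-evaluating after each step,
    # until the metric drops below masking_threshold * original_score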
    while current_score >= original_score * args.masking_threshold:
        head_mask = new_head_mask.clone()  # save current head mask
        # heads from least important to most - keep only not-masked heads
        head_importance[head_mask == 0.0] = float("Inf")
        current_heads_to_mask = head_importance.view(-1).sort()[1]

        if len(current_heads_to_mask) <= num_to_mask:
            break

        # mask heads
        current_heads_to_mask = current_heads_to_mask[:num_to_mask]
        logger.info("Heads to mask: %s", str(current_heads_to_mask.tolist()))
        new_head_mask = new_head_mask.view(-1)
        new_head_mask[current_heads_to_mask] = 0.0
        new_head_mask = new_head_mask.view_as(head_mask)
        new_head_mask = new_head_mask.clone().detach()
        print_2d_tensor(new_head_mask)

        # Compute metric and head importance again
        _, head_importance, preds, labels = compute_heads_importance(
            args, model, eval_dataloader, compute_entropy=False, head_mask=new_head_mask
        )
        preds = np.argmax(preds, axis=1) if args.output_mode == "classification" else np.squeeze(preds)
        current_score = glue_compute_metrics(args.task_name, preds, labels)[args.metric_name]
        logger.info(
            "Masking: current score: %f, remaining heads %d (%.1f percent)",
            current_score,
            new_head_mask.sum(),
            new_head_mask.sum() / new_head_mask.numel() * 100,
        )

    logger.info("Final head mask")
    print_2d_tensor(head_mask)
    np.save(os.path.join(args.output_dir, "head_mask.npy"), head_mask.detach().cpu().numpy())

    return head_mask


def prune_heads(args, model, eval_dataloader, head_mask):
    """This method shows how to prune heads (remove head weights) based on
    the head importance scores as described in Michel et al. (http://arxiv.org/abs/1905.10650)
    """
    # Try pruning and test time speedup
    # Pruning is like masking but we actually remove the masked weights
    before_time = datetime.now()
    _, _, preds, labels = compute_heads_importance(
        args, model, eval_dataloader, compute_entropy=False, compute_importance=False, head_mask=head_mask
    )
    preds = np.argmax(preds, axis=1) if args.output_mode == "classification" else np.squeeze(preds)
    score_masking = glue_compute_metrics(args.task_name, preds, labels)[args.metric_name]
    original_time = datetime.now() - before_time

    original_num_params = sum(p.numel() for p in model.parameters())
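    # heads_to_prune maps each layer index to the list of head indices whose mask entry is 0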
    heads_to_prune = {
        layer: (1 - head_mask[layer].long()).nonzero().squeeze().tolist() for layer in range(len(head_mask))
    }

    assert sum(len(h) for h in heads_to_prune.values()) == (1 - head_mask.long()).sum().item()
    model.prune_heads(heads_to_prune)
    pruned_num_params = sum(p.numel() for p in model.parameters())

    before_time = datetime.now()
    _, _, preds, labels = compute_heads_importance(
        args,
        model,
        eval_dataloader,
        compute_entropy=False,
        compute_importance=False,
        head_mask=None,
        actually_pruned=True,
    )
    preds = np.argmax(preds, axis=1) if args.output_mode == "classification" else np.squeeze(preds)
    score_pruning = glue_compute_metrics(args.task_name, preds, labels)[args.metric_name]
    new_time = datetime.now() - before_time

    logger.info(
        "Pruning: original num of params: %.2e, after pruning %.2e (%.1f percent)",
        original_num_params,
        pruned_num_params,
        pruned_num_params / original_num_params * 100,
    )
    logger.info("Pruning: score with masking: %f score with pruning: %f", score_masking, score_pruning)
    logger.info("Pruning: speed ratio (original timing / new timing): %f percent", original_time / new_time * 100)


def main():
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument(
        "--data_dir",
        default=None,
        type=str,
        required=True,
        help="The input data dir. Should contain the .tsv files (or other data files) for the task.",
    )
    parser.add_argument(
        "--model_name_or_path",
        default=None,
        type=str,
        required=True,
        help="Path to pretrained model or model identifier from huggingface.co/models",
    )
    parser.add_argument(
        "--task_name",
        default=None,
        type=str,
        required=True,
        help="The name of the task to train, selected from the list: " + ", ".join(glue_processors.keys()),
    )
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help="The output directory where the model predictions and checkpoints will be written.",
    )

    # Other parameters
    parser.add_argument(
        "--config_name",
        default="",
        type=str,
        help="Pretrained config name or path if not the same as model_name_or_path",
    )
    parser.add_argument(
        "--tokenizer_name",
        default="",
        type=str,
        help="Pretrained tokenizer name or path if not the same as model_name_or_path",
    )
    parser.add_argument(
        "--cache_dir",
        default=None,
        type=str,
        help="Where do you want to store the pre-trained models downloaded from huggingface.co",
    )
    parser.add_argument(
        "--data_subset", type=int, default=-1, help="If > 0: limit the data to a subset of data_subset instances."
    )
    parser.add_argument(
        "--overwrite_output_dir", action="store_true", help="Whether to overwrite data in output directory"
    )
    parser.add_argument(
        "--overwrite_cache", action="store_true", help="Overwrite the cached training and evaluation sets"
    )

    parser.add_argument(
        "--dont_normalize_importance_by_layer", action="store_true", help="Don't normalize importance score by layers"
    )
    parser.add_argument(
        "--dont_normalize_global_importance",
        action="store_true",
        help="Don't normalize all importance scores between 0 and 1",
    )

    parser.add_argument(
        "--try_masking", action="store_true", help="Whether to try masking heads until the metric drops below the masking threshold."
    )
    parser.add_argument(
        "--masking_threshold",
        default=0.9,
        type=float,
        help="Masking threshold in terms of metrics (stop masking when metric < threshold * original metric value).",
    )
    parser.add_argument(
        "--masking_amount", default=0.1, type=float, help="Fraction of heads to mask at each masking step."
    )
    parser.add_argument("--metric_name", default="acc", type=str, help="Metric to use for head masking.")

    parser.add_argument(
        "--max_seq_length",
        default=128,
        type=int,
        help=(
            "The maximum total input sequence length after WordPiece tokenization. \n"
            "Sequences longer than this will be truncated, sequences shorter padded."
        ),
    )
    parser.add_argument("--batch_size", default=1, type=int, help="Batch size.")

    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--local_rank", type=int, default=-1, help="local_rank for distributed training on gpus")
    parser.add_argument("--no_cuda", action="store_true", help="Whether not to use CUDA when available")
    parser.add_argument("--server_ip", type=str, default="", help="Can be used for distant debugging.")
    parser.add_argument("--server_port", type=str, default="", help="Can be used for distant debugging.")
    args = parser.parse_args()

    if args.server_ip and args.server_port:
        # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
        import ptvsd

        print("Waiting for debugger attach")
        ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
        ptvsd.wait_for_attach()

    # Setup devices and distributed training
    if args.local_rank == -1 or args.no_cuda:
        args.device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
        args.n_gpu = 0 if args.no_cuda else torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        args.device = torch.device("cuda", args.local_rank)
        args.n_gpu = 1
        torch.distributed.init_process_group(backend="nccl")  # Initializes the distributed backend

    # Setup logging
    logging.basicConfig(level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
    logger.info("device: {} n_gpu: {}, distributed: {}".format(args.device, args.n_gpu, bool(args.local_rank != -1)))
    # Set the verbosity to info of the Transformers logger (on main process only):
    if is_main_process(args.local_rank):
        transformers.utils.logging.set_verbosity_info()
        transformers.utils.logging.enable_default_handler()
        transformers.utils.logging.enable_explicit_format()

    # Set seeds
    set_seed(args.seed)

    # Prepare GLUE task
    args.task_name = args.task_name.lower()
    if args.task_name not in glue_processors:
        raise ValueError("Task not found: %s" % (args.task_name))
    processor = glue_processors[args.task_name]()
    args.output_mode = glue_output_modes[args.task_name]
    label_list = processor.get_labels()
    num_labels = len(label_list)

    # Load pretrained model and tokenizer
    #
    # Distributed training:
    # The .from_pretrained methods guarantee that only one local process can concurrently
    # download model & vocab.

    config = AutoConfig.from_pretrained(
        args.config_name if args.config_name else args.model_name_or_path,
        num_labels=num_labels,
        finetuning_task=args.task_name,
        output_attentions=True,
        cache_dir=args.cache_dir,
    )
    tokenizer = AutoTokenizer.from_pretrained(
        args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
        cache_dir=args.cache_dir,
    )
    model = AutoModelForSequenceClassification.from_pretrained(
        args.model_name_or_path,
        from_tf=bool(".ckpt" in args.model_name_or_path),
        config=config,
        cache_dir=args.cache_dir,
    )

    # Distributed and parallel training
    model.to(args.device)
    if args.local_rank != -1:
        model = nn.parallel.DistributedDataParallel(
            model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True
        )
    elif args.n_gpu > 1:
        model = nn.DataParallel(model)

    # Print/save training arguments
    os.makedirs(args.output_dir, exist_ok=True)
    torch.save(args, os.path.join(args.output_dir, "run_args.bin"))
    logger.info("Training/evaluation parameters %s", args)

    # Prepare dataset for the GLUE task
    eval_dataset = GlueDataset(args, tokenizer=tokenizer, mode="dev")
    if args.data_subset > 0:
        eval_dataset = Subset(eval_dataset, list(range(min(args.data_subset, len(eval_dataset)))))
    eval_sampler = SequentialSampler(eval_dataset) if args.local_rank == -1 else DistributedSampler(eval_dataset)
    eval_dataloader = DataLoader(
        eval_dataset, sampler=eval_sampler, batch_size=args.batch_size, collate_fn=default_data_collator
    )

    # Compute head entropy and importance score
    compute_heads_importance(args, model, eval_dataloader)

    # Try head masking (set heads to zero until the score goes under a threshold)
    # and head pruning (remove masked heads and see the effect on the network)
    if args.try_masking and args.masking_threshold > 0.0 and args.masking_threshold < 1.0:
        head_mask = mask_heads(args, model, eval_dataloader)
        prune_heads(args, model, eval_dataloader, head_mask)


if __name__ == "__main__":
    main()