#!/usr/bin/env python3
# Copyright 2018 CMU and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Bertology: this script shows how you can explore the internals of the models in the library to:
    - compute the entropy of the head attentions
    - compute the importance of each head
    - prune (remove) the low-importance heads.
    Some parts of this script are adapted from the code of Michel et al. (http://arxiv.org/abs/1905.10650)
    which is available at https://github.com/pmichel31415/are-16-heads-really-better-than-1
"""
import argparse
import logging
import os
from datetime import datetime

import numpy as np
import torch
from torch.utils.data import DataLoader, SequentialSampler, Subset
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm

import transformers
from transformers import (
    AutoConfig,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    GlueDataset,
    default_data_collator,
    glue_compute_metrics,
    glue_output_modes,
    glue_processors,
    set_seed,
)
from transformers.trainer_utils import is_main_process


logger = logging.getLogger(__name__)


def entropy(p):
    """ Compute the entropy of a probability distribution """
    plogp = p * torch.log(p)
    plogp[p == 0] = 0
    return -plogp.sum(dim=-1)


def print_2d_tensor(tensor):
    """ Print a 2D tensor """
    logger.info("lv, h >\t" + "\t".join(f"{x + 1}" for x in range(len(tensor))))
    for row in range(len(tensor)):
        if tensor.dtype != torch.long:
            logger.info(f"layer {row + 1}:\t" + "\t".join(f"{x:.5f}" for x in tensor[row].cpu().data))
        else:
            logger.info(f"layer {row + 1}:\t" + "\t".join(f"{x:d}" for x in tensor[row].cpu().data))


def compute_heads_importance(
    args, model, eval_dataloader, compute_entropy=True, compute_importance=True, head_mask=None, actually_pruned=False
):
    """This method shows how to compute:
    - head attention entropy
    - head importance scores according to http://arxiv.org/abs/1905.10650
    """
    # Prepare our tensors
    n_layers, n_heads = model.config.num_hidden_layers, model.config.num_attention_heads
    head_importance = torch.zeros(n_layers, n_heads).to(args.device)
    attn_entropy = torch.zeros(n_layers, n_heads).to(args.device)

    if head_mask is None:
        head_mask = torch.ones(n_layers, n_heads).to(args.device)

    head_mask.requires_grad_(requires_grad=True)
    # If the attention heads have actually been pruned, set the head mask to None to avoid a shape mismatch
    if actually_pruned:
        head_mask = None

    preds = None
    labels = None
    tot_tokens = 0.0

    for step, inputs in enumerate(tqdm(eval_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])):
        for k, v in inputs.items():
            inputs[k] = v.to(args.device)

        # Do a forward pass (not with torch.no_grad() since we need gradients for importance score - see below)
        outputs = model(**inputs, head_mask=head_mask)
        loss, logits, all_attentions = (
            outputs[0],
            outputs[1],
            outputs[-1],
        )  # Loss and logits are the first, attention the last
        loss.backward()  # Backpropagate to populate the gradients in the head mask

        if compute_entropy:
            for layer, attn in enumerate(all_attentions):
                masked_entropy = entropy(attn.detach()) * inputs["attention_mask"].float().unsqueeze(1)
                attn_entropy[layer] += masked_entropy.sum(-1).sum(0).detach()

        if compute_importance:
            head_importance += head_mask.grad.abs().detach()

        # Also store our logits/labels if we want to compute metrics afterwards
        if preds is None:
            preds = logits.detach().cpu().numpy()
            labels = inputs["labels"].detach().cpu().numpy()
        else:
            preds = np.append(preds, logits.detach().cpu().numpy(), axis=0)
            labels = np.append(labels, inputs["labels"].detach().cpu().numpy(), axis=0)

        tot_tokens += inputs["attention_mask"].float().detach().sum().data

    # Normalize
    attn_entropy /= tot_tokens
    head_importance /= tot_tokens
    # Layerwise importance normalization
    if not args.dont_normalize_importance_by_layer:
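        # Normalize each layer by the L2 norm of its head scores (exponent = 2 below) so that
        # importance values are comparable across layers.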
        exponent = 2
        norm_by_layer = torch.pow(torch.pow(head_importance, exponent).sum(-1), 1 / exponent)
        head_importance /= norm_by_layer.unsqueeze(-1) + 1e-20

    if not args.dont_normalize_global_importance:
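        # Min-max rescale all head scores to the [0, 1] range across the whole model.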
        head_importance = (head_importance - head_importance.min()) / (head_importance.max() - head_importance.min())

    # Print/save matrices
    np.save(os.path.join(args.output_dir, "attn_entropy.npy"), attn_entropy.detach().cpu().numpy())
    np.save(os.path.join(args.output_dir, "head_importance.npy"), head_importance.detach().cpu().numpy())

    logger.info("Attention entropies")
    print_2d_tensor(attn_entropy)
    logger.info("Head importance scores")
    print_2d_tensor(head_importance)
    logger.info("Head ranked by importance scores")
    head_ranks = torch.zeros(head_importance.numel(), dtype=torch.long, device=args.device)
    head_ranks[head_importance.view(-1).sort(descending=True)[1]] = torch.arange(
        head_importance.numel(), device=args.device
    )
    head_ranks = head_ranks.view_as(head_importance)
    print_2d_tensor(head_ranks)

    return attn_entropy, head_importance, preds, labels


def mask_heads(args, model, eval_dataloader):
    """This method shows how to mask head (set some heads to zero), to test the effect on the network,
    based on the head importance scores, as described in Michel et al. (http://arxiv.org/abs/1905.10650)
    """
    _, head_importance, preds, labels = compute_heads_importance(args, model, eval_dataloader, compute_entropy=False)
    preds = np.argmax(preds, axis=1) if args.output_mode == "classification" else np.squeeze(preds)
    original_score = glue_compute_metrics(args.task_name, preds, labels)[args.metric_name]
    logger.info("Pruning: original score: %f, threshold: %f", original_score, original_score * args.masking_threshold)

    new_head_mask = torch.ones_like(head_importance)
    num_to_mask = max(1, int(new_head_mask.numel() * args.masking_amount))

    current_score = original_score
    while current_score >= original_score * args.masking_threshold:
        head_mask = new_head_mask.clone()  # save current head mask
        # Sort heads from least to most important; already-masked heads are pushed to the end (importance = Inf)
        head_importance[head_mask == 0.0] = float("Inf")
        current_heads_to_mask = head_importance.view(-1).sort()[1]

        if len(current_heads_to_mask) <= num_to_mask:
            break

        # mask heads
        current_heads_to_mask = current_heads_to_mask[:num_to_mask]
        logger.info("Heads to mask: %s", str(current_heads_to_mask.tolist()))
        new_head_mask = new_head_mask.view(-1)
        new_head_mask[current_heads_to_mask] = 0.0
        new_head_mask = new_head_mask.view_as(head_mask)
        new_head_mask = new_head_mask.clone().detach()
        print_2d_tensor(new_head_mask)

        # Compute metric and head importance again
        _, head_importance, preds, labels = compute_heads_importance(
            args, model, eval_dataloader, compute_entropy=False, head_mask=new_head_mask
        )
        preds = np.argmax(preds, axis=1) if args.output_mode == "classification" else np.squeeze(preds)
        current_score = glue_compute_metrics(args.task_name, preds, labels)[args.metric_name]
        logger.info(
            "Masking: current score: %f, remaining heads %d (%.1f percents)",
            current_score,
            new_head_mask.sum(),
            new_head_mask.sum() / new_head_mask.numel() * 100,
        )

    logger.info("Final head mask")
    print_2d_tensor(head_mask)
    np.save(os.path.join(args.output_dir, "head_mask.npy"), head_mask.detach().cpu().numpy())

    return head_mask


def prune_heads(args, model, eval_dataloader, head_mask):
    """This method shows how to prune head (remove heads weights) based on
    the head importance scores as described in Michel et al. (http://arxiv.org/abs/1905.10650)
    """
    # Try pruning and test time speedup
    # Pruning is like masking but we actually remove the masked weights
    before_time = datetime.now()
    _, _, preds, labels = compute_heads_importance(
        args, model, eval_dataloader, compute_entropy=False, compute_importance=False, head_mask=head_mask
    )
    preds = np.argmax(preds, axis=1) if args.output_mode == "classification" else np.squeeze(preds)
    score_masking = glue_compute_metrics(args.task_name, preds, labels)[args.metric_name]
    original_time = datetime.now() - before_time

    original_num_params = sum(p.numel() for p in model.parameters())
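    # model.prune_heads expects a {layer_index: [head indices to remove]} dict; every head whose mask
    # entry is 0 is selected for removal here.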
    heads_to_prune = dict(
        (layer, (1 - head_mask[layer].long()).nonzero().squeeze().tolist()) for layer in range(len(head_mask))
    )

    assert sum(len(h) for h in heads_to_prune.values()) == (1 - head_mask.long()).sum().item()
    model.prune_heads(heads_to_prune)
    pruned_num_params = sum(p.numel() for p in model.parameters())

    before_time = datetime.now()
    _, _, preds, labels = compute_heads_importance(
        args,
        model,
        eval_dataloader,
        compute_entropy=False,
        compute_importance=False,
        head_mask=None,
        actually_pruned=True,
    )
    preds = np.argmax(preds, axis=1) if args.output_mode == "classification" else np.squeeze(preds)
    score_pruning = glue_compute_metrics(args.task_name, preds, labels)[args.metric_name]
    new_time = datetime.now() - before_time

    logger.info(
        "Pruning: original num of params: %.2e, after pruning %.2e (%.1f percents)",
        original_num_params,
        pruned_num_params,
        pruned_num_params / original_num_params * 100,
    )
    logger.info("Pruning: score with masking: %f score with pruning: %f", score_masking, score_pruning)
    logger.info("Pruning: speed ratio (new timing / original timing): %f percents", original_time / new_time * 100)


def main():
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument(
        "--data_dir",
        default=None,
        type=str,
        required=True,
        help="The input data dir. Should contain the .tsv files (or other data files) for the task.",
    )
    parser.add_argument(
        "--model_name_or_path",
        default=None,
        type=str,
        required=True,
        help="Path to pretrained model or model identifier from huggingface.co/models",
    )
    parser.add_argument(
        "--task_name",
        default=None,
        type=str,
        required=True,
        help="The name of the task to train selected in the list: " + ", ".join(glue_processors.keys()),
    )
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help="The output directory where the model predictions and checkpoints will be written.",
    )

    # Other parameters
    parser.add_argument(
        "--config_name",
        default="",
        type=str,
        help="Pretrained config name or path if not the same as model_name_or_path",
    )
    parser.add_argument(
        "--tokenizer_name",
        default="",
        type=str,
        help="Pretrained tokenizer name or path if not the same as model_name_or_path",
    )
    parser.add_argument(
        "--cache_dir",
        default=None,
        type=str,
        help="Where do you want to store the pre-trained models downloaded from s3",
    )
    parser.add_argument(
        "--data_subset", type=int, default=-1, help="If > 0: limit the data to a subset of data_subset instances."
    )
    parser.add_argument(
        "--overwrite_output_dir", action="store_true", help="Whether to overwrite data in output directory"
    )
    parser.add_argument(
        "--overwrite_cache", action="store_true", help="Overwrite the cached training and evaluation sets"
    )

    parser.add_argument(
        "--dont_normalize_importance_by_layer", action="store_true", help="Don't normalize importance score by layers"
    )
    parser.add_argument(
        "--dont_normalize_global_importance",
        action="store_true",
        help="Don't normalize all importance scores between 0 and 1",
    )

    parser.add_argument(
        "--try_masking", action="store_true", help="Whether to try to mask head until a threshold of accuracy."
    )
    parser.add_argument(
        "--masking_threshold",
        default=0.9,
        type=float,
        help="masking threshold in term of metrics (stop masking when metric < threshold * original metric value).",
    )
    parser.add_argument(
        "--masking_amount", default=0.1, type=float, help="Amount to heads to masking at each masking step."
    )
    parser.add_argument("--metric_name", default="acc", type=str, help="Metric to use for head masking.")

    parser.add_argument(
        "--max_seq_length",
        default=128,
        type=int,
        help="The maximum total input sequence length after WordPiece tokenization. \n"
        "Sequences longer than this will be truncated, sequences shorter padded.",
    )
    parser.add_argument("--batch_size", default=1, type=int, help="Batch size.")

    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--local_rank", type=int, default=-1, help="local_rank for distributed training on gpus")
    parser.add_argument("--no_cuda", action="store_true", help="Whether not to use CUDA when available")
    parser.add_argument("--server_ip", type=str, default="", help="Can be used for distant debugging.")
    parser.add_argument("--server_port", type=str, default="", help="Can be used for distant debugging.")
    args = parser.parse_args()

    if args.server_ip and args.server_port:
        # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
        import ptvsd

        print("Waiting for debugger attach")
        ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
        ptvsd.wait_for_attach()

    # Setup devices and distributed training
    if args.local_rank == -1 or args.no_cuda:
        args.device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
        args.n_gpu = 0 if args.no_cuda else torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        args.device = torch.device("cuda", args.local_rank)
        args.n_gpu = 1
        torch.distributed.init_process_group(backend="nccl")  # Initializes the distributed backend

    # Setup logging
    logging.basicConfig(level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
    logger.info("device: {} n_gpu: {}, distributed: {}".format(args.device, args.n_gpu, bool(args.local_rank != -1)))
    # Set the verbosity to info of the Transformers logger (on main process only):
    if is_main_process(args.local_rank):
        transformers.utils.logging.set_verbosity_info()
        transformers.utils.logging.enable_default_handler()
        transformers.utils.logging.enable_explicit_format()

    # Set seeds
    set_seed(args.seed)

    # Prepare GLUE task
    args.task_name = args.task_name.lower()
    if args.task_name not in glue_processors:
        raise ValueError("Task not found: %s" % (args.task_name))
    processor = glue_processors[args.task_name]()
    args.output_mode = glue_output_modes[args.task_name]
    label_list = processor.get_labels()
    num_labels = len(label_list)

    # Load pretrained model and tokenizer
    #
    # Distributed training:
    # The .from_pretrained methods guarantee that only one local process can concurrently
    # download model & vocab.

    config = AutoConfig.from_pretrained(
        args.config_name if args.config_name else args.model_name_or_path,
        num_labels=num_labels,
        finetuning_task=args.task_name,
        output_attentions=True,
        cache_dir=args.cache_dir,
    )
    tokenizer = AutoTokenizer.from_pretrained(
        args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
        cache_dir=args.cache_dir,
    )
    model = AutoModelForSequenceClassification.from_pretrained(
        args.model_name_or_path,
        from_tf=bool(".ckpt" in args.model_name_or_path),
        config=config,
        cache_dir=args.cache_dir,
    )

    # Distributed and parallel training
    model.to(args.device)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True
        )
    elif args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Print/save training arguments
    os.makedirs(args.output_dir, exist_ok=True)
    torch.save(args, os.path.join(args.output_dir, "run_args.bin"))
    logger.info("Training/evaluation parameters %s", args)

    # Prepare dataset for the GLUE task
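    # The analysis runs on the GLUE dev split; --data_subset can cap the number of evaluation examples.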
    eval_dataset = GlueDataset(args, tokenizer=tokenizer, mode="dev")
    if args.data_subset > 0:
        eval_dataset = Subset(eval_dataset, list(range(min(args.data_subset, len(eval_dataset)))))
    eval_sampler = SequentialSampler(eval_dataset) if args.local_rank == -1 else DistributedSampler(eval_dataset)
    eval_dataloader = DataLoader(
        eval_dataset, sampler=eval_sampler, batch_size=args.batch_size, collate_fn=default_data_collator
    )

    # Compute head entropy and importance score
    compute_heads_importance(args, model, eval_dataloader)

    # Try head masking (set heads to zero until the score goes under a threshold)
    # and head pruning (remove masked heads and see the effect on the network)
    if args.try_masking and args.masking_threshold > 0.0 and args.masking_threshold < 1.0:
        head_mask = mask_heads(args, model, eval_dataloader)
        prune_heads(args, model, eval_dataloader, head_mask)


if __name__ == "__main__":
    main()