#!/usr/bin/env python3
# Copyright 2018 CMU and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Bertology: this script shows how you can explore the internals of the models in the library to:
    - compute the entropy of the head attentions
    - compute the importance of each head
    - prune (remove) the low-importance heads.
    Some parts of this script are adapted from the code of Michel et al. (http://arxiv.org/abs/1905.10650)
    which is available at https://github.com/pmichel31415/are-16-heads-really-better-than-1
"""
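# Example invocation (illustrative only; the model identifier, GLUE task and paths are placeholders):
#
#   python run_bertology.py --model_name_or_path bert-base-cased --task_name mrpc \
#       --data_dir /path/to/glue/MRPC --output_dir ./bertology_output --try_masking
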
import argparse
import logging
import os
from datetime import datetime

import numpy as np
import torch
from torch.utils.data import DataLoader, SequentialSampler, Subset
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm

from transformers import (
    AutoConfig,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    GlueDataset,
    default_data_collator,
    glue_compute_metrics,
    glue_output_modes,
    glue_processors,
    set_seed,
)


logger = logging.getLogger(__name__)


def entropy(p):
    """ Compute the entropy of a probability distribution """
    plogp = p * torch.log(p)
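    # torch.log(0) is -inf, and 0 * -inf is nan, so explicitly zero those entries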
    plogp[p == 0] = 0
    return -plogp.sum(dim=-1)


def print_2d_tensor(tensor):
    """ Print a 2D tensor """
    logger.info("lv, h >\t" + "\t".join(f"{x + 1}" for x in range(len(tensor))))
    for row in range(len(tensor)):
        if tensor.dtype != torch.long:
            logger.info(f"layer {row + 1}:\t" + "\t".join(f"{x:.5f}" for x in tensor[row].cpu().data))
        else:
            logger.info(f"layer {row + 1}:\t" + "\t".join(f"{x:d}" for x in tensor[row].cpu().data))


def compute_heads_importance(
    args, model, eval_dataloader, compute_entropy=True, compute_importance=True, head_mask=None, actually_pruned=False
):
    """This method shows how to compute:
    - head attention entropy
    - head importance scores according to http://arxiv.org/abs/1905.10650
    """
    # Prepare our tensors
    n_layers, n_heads = model.config.num_hidden_layers, model.config.num_attention_heads
    head_importance = torch.zeros(n_layers, n_heads).to(args.device)
    attn_entropy = torch.zeros(n_layers, n_heads).to(args.device)

    if head_mask is None:
        head_mask = torch.ones(n_layers, n_heads).to(args.device)

    head_mask.requires_grad_(requires_grad=True)
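    # The gradient of the loss w.r.t. each mask entry is accumulated below as that head's importance score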
    # If actually pruned attention multi-head, set head mask to None to avoid shape mismatch
    if actually_pruned:
        head_mask = None

    preds = None
    labels = None
    tot_tokens = 0.0

    for step, inputs in enumerate(tqdm(eval_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])):
        for k, v in inputs.items():
            inputs[k] = v.to(args.device)

        # Do a forward pass (not with torch.no_grad() since we need gradients for importance score - see below)
        outputs = model(**inputs, head_mask=head_mask)
        loss, logits, all_attentions = (
            outputs[0],
            outputs[1],
            outputs[-1],
        )  # Loss and logits are the first, attention the last
        loss.backward()  # Backpropagate to populate the gradients in the head mask

        if compute_entropy:
            for layer, attn in enumerate(all_attentions):
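                # Weight per-token attention entropy by the attention mask so padding tokens do not contribute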
                masked_entropy = entropy(attn.detach()) * inputs["attention_mask"].float().unsqueeze(1)
                attn_entropy[layer] += masked_entropy.sum(-1).sum(0).detach()

        if compute_importance:
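            # Head importance is the accumulated absolute gradient of the loss w.r.t. the head mask (Michel et al., 2019)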
            head_importance += head_mask.grad.abs().detach()

        # Also store our logits/labels if we want to compute metrics afterwards
        if preds is None:
            preds = logits.detach().cpu().numpy()
            labels = inputs["labels"].detach().cpu().numpy()
        else:
            preds = np.append(preds, logits.detach().cpu().numpy(), axis=0)
            labels = np.append(labels, inputs["labels"].detach().cpu().numpy(), axis=0)

        tot_tokens += inputs["attention_mask"].float().detach().sum().data

    # Normalize
    attn_entropy /= tot_tokens
    head_importance /= tot_tokens
    # Layerwise importance normalization
    if not args.dont_normalize_importance_by_layer:
        exponent = 2
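        # Normalize by the L2 norm of the importance scores within each layer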
        norm_by_layer = torch.pow(torch.pow(head_importance, exponent).sum(-1), 1 / exponent)
        head_importance /= norm_by_layer.unsqueeze(-1) + 1e-20

    if not args.dont_normalize_global_importance:
        head_importance = (head_importance - head_importance.min()) / (head_importance.max() - head_importance.min())

    # Print/save matrices
    np.save(os.path.join(args.output_dir, "attn_entropy.npy"), attn_entropy.detach().cpu().numpy())
    np.save(os.path.join(args.output_dir, "head_importance.npy"), head_importance.detach().cpu().numpy())

    logger.info("Attention entropies")
    print_2d_tensor(attn_entropy)
    logger.info("Head importance scores")
    print_2d_tensor(head_importance)
    logger.info("Head ranked by importance scores")
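    # head_ranks[i] is the rank of head i when heads are sorted by decreasing importance (0 = most important)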
    head_ranks = torch.zeros(head_importance.numel(), dtype=torch.long, device=args.device)
    head_ranks[head_importance.view(-1).sort(descending=True)[1]] = torch.arange(
        head_importance.numel(), device=args.device
    )
    head_ranks = head_ranks.view_as(head_importance)
    print_2d_tensor(head_ranks)

    return attn_entropy, head_importance, preds, labels


def mask_heads(args, model, eval_dataloader):
    """This method shows how to mask heads (set some heads to zero), to test the effect on the network,
    based on the head importance scores, as described in Michel et al. (http://arxiv.org/abs/1905.10650)
    """
    _, head_importance, preds, labels = compute_heads_importance(args, model, eval_dataloader, compute_entropy=False)
    preds = np.argmax(preds, axis=1) if args.output_mode == "classification" else np.squeeze(preds)
    original_score = glue_compute_metrics(args.task_name, preds, labels)[args.metric_name]
    logger.info("Pruning: original score: %f, threshold: %f", original_score, original_score * args.masking_threshold)

    new_head_mask = torch.ones_like(head_importance)
    num_to_mask = max(1, int(new_head_mask.numel() * args.masking_amount))

    current_score = original_score
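    # Iteratively mask the num_to_mask least important remaining heads until the score drops below the threshold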
    while current_score >= original_score * args.masking_threshold:
        head_mask = new_head_mask.clone()  # save current head mask
        # Sort heads from least to most important; heads that are already masked get importance Inf so they sort last
        head_importance[head_mask == 0.0] = float("Inf")
        current_heads_to_mask = head_importance.view(-1).sort()[1]

        if len(current_heads_to_mask) <= num_to_mask:
            break

        # mask heads
        current_heads_to_mask = current_heads_to_mask[:num_to_mask]
        logger.info("Heads to mask: %s", str(current_heads_to_mask.tolist()))
        new_head_mask = new_head_mask.view(-1)
        new_head_mask[current_heads_to_mask] = 0.0
        new_head_mask = new_head_mask.view_as(head_mask)
        new_head_mask = new_head_mask.clone().detach()
        print_2d_tensor(new_head_mask)

        # Compute metric and head importance again
        _, head_importance, preds, labels = compute_heads_importance(
            args, model, eval_dataloader, compute_entropy=False, head_mask=new_head_mask
        )
        preds = np.argmax(preds, axis=1) if args.output_mode == "classification" else np.squeeze(preds)
        current_score = glue_compute_metrics(args.task_name, preds, labels)[args.metric_name]
        logger.info(
            "Masking: current score: %f, remaining heads %d (%.1f percent)",
            current_score,
            new_head_mask.sum(),
            new_head_mask.sum() / new_head_mask.numel() * 100,
        )

    logger.info("Final head mask")
    print_2d_tensor(head_mask)
    np.save(os.path.join(args.output_dir, "head_mask.npy"), head_mask.detach().cpu().numpy())

    return head_mask


def prune_heads(args, model, eval_dataloader, head_mask):
    """This method shows how to prune heads (remove head weights) based on
    the head importance scores as described in Michel et al. (http://arxiv.org/abs/1905.10650)
    """
    # Try pruning and test time speedup
    # Pruning is like masking but we actually remove the masked weights
    before_time = datetime.now()
    _, _, preds, labels = compute_heads_importance(
        args, model, eval_dataloader, compute_entropy=False, compute_importance=False, head_mask=head_mask
    )
    preds = np.argmax(preds, axis=1) if args.output_mode == "classification" else np.squeeze(preds)
    score_masking = glue_compute_metrics(args.task_name, preds, labels)[args.metric_name]
    original_time = datetime.now() - before_time

    original_num_params = sum(p.numel() for p in model.parameters())
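    # Build {layer index: [indices of heads to prune]} from the zero entries of the head mask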
    heads_to_prune = dict(
        (layer, (1 - head_mask[layer].long()).nonzero().squeeze().tolist()) for layer in range(len(head_mask))
    )

    assert sum(len(h) for h in heads_to_prune.values()) == (1 - head_mask.long()).sum().item()
    model.prune_heads(heads_to_prune)
    pruned_num_params = sum(p.numel() for p in model.parameters())

    before_time = datetime.now()
    _, _, preds, labels = compute_heads_importance(
        args,
        model,
        eval_dataloader,
        compute_entropy=False,
        compute_importance=False,
        head_mask=None,
        actually_pruned=True,
    )
    preds = np.argmax(preds, axis=1) if args.output_mode == "classification" else np.squeeze(preds)
    score_pruning = glue_compute_metrics(args.task_name, preds, labels)[args.metric_name]
    new_time = datetime.now() - before_time

    logger.info(
        "Pruning: original num of params: %.2e, after pruning %.2e (%.1f percent)",
        original_num_params,
        pruned_num_params,
        pruned_num_params / original_num_params * 100,
    )
    logger.info("Pruning: score with masking: %f score with pruning: %f", score_masking, score_pruning)
    logger.info("Pruning: speed ratio (original timing / new timing): %f percent", original_time / new_time * 100)


def main():
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument(
        "--data_dir",
        default=None,
        type=str,
        required=True,
        help="The input data dir. Should contain the .tsv files (or other data files) for the task.",
    )
    parser.add_argument(
        "--model_name_or_path",
        default=None,
        type=str,
        required=True,
        help="Path to pretrained model or model identifier from huggingface.co/models",
    )
    parser.add_argument(
        "--task_name",
        default=None,
        type=str,
        required=True,
        help="The name of the task to train selected in the list: " + ", ".join(glue_processors.keys()),
    )
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help="The output directory where the model predictions and checkpoints will be written.",
    )

    # Other parameters
    parser.add_argument(
        "--config_name",
        default="",
        type=str,
        help="Pretrained config name or path if not the same as model_name_or_path",
    )
    parser.add_argument(
        "--tokenizer_name",
        default="",
        type=str,
        help="Pretrained tokenizer name or path if not the same as model_name_or_path",
    )
    parser.add_argument(
        "--cache_dir",
        default=None,
        type=str,
        help="Where do you want to store the pre-trained models downloaded from s3",
    )
    parser.add_argument(
        "--data_subset", type=int, default=-1, help="If > 0: limit the data to a subset of data_subset instances."
    )
    parser.add_argument(
        "--overwrite_output_dir", action="store_true", help="Whether to overwrite data in output directory"
    )
    parser.add_argument(
        "--overwrite_cache", action="store_true", help="Overwrite the cached training and evaluation sets"
    )

    parser.add_argument(
        "--dont_normalize_importance_by_layer", action="store_true", help="Don't normalize importance scores by layer"
    )
    parser.add_argument(
        "--dont_normalize_global_importance",
        action="store_true",
        help="Don't normalize all importance scores between 0 and 1",
    )

    parser.add_argument(
        "--try_masking", action="store_true", help="Whether to try masking heads until the metric drops below a threshold."
    )
    parser.add_argument(
        "--masking_threshold",
        default=0.9,
        type=float,
        help="Masking threshold in terms of the metric (stop masking when metric < threshold * original metric value).",
    )
    parser.add_argument(
        "--masking_amount", default=0.1, type=float, help="Fraction of heads to mask at each masking step."
    )
    parser.add_argument("--metric_name", default="acc", type=str, help="Metric to use for head masking.")

    parser.add_argument(
        "--max_seq_length",
        default=128,
        type=int,
        help="The maximum total input sequence length after WordPiece tokenization. \n"
        "Sequences longer than this will be truncated, sequences shorter padded.",
    )
    parser.add_argument("--batch_size", default=1, type=int, help="Batch size.")

    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--local_rank", type=int, default=-1, help="local_rank for distributed training on gpus")
    parser.add_argument("--no_cuda", action="store_true", help="Whether not to use CUDA when available")
    parser.add_argument("--server_ip", type=str, default="", help="Can be used for distant debugging.")
    parser.add_argument("--server_port", type=str, default="", help="Can be used for distant debugging.")
    args = parser.parse_args()

    if args.server_ip and args.server_port:
        # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
        import ptvsd

        print("Waiting for debugger attach")
        ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
        ptvsd.wait_for_attach()

    # Setup devices and distributed training
    if args.local_rank == -1 or args.no_cuda:
        args.device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
        args.n_gpu = 0 if args.no_cuda else torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        args.device = torch.device("cuda", args.local_rank)
        args.n_gpu = 1
        torch.distributed.init_process_group(backend="nccl")  # Initializes the distributed backend

    # Setup logging
    logging.basicConfig(level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
    logger.info("device: {} n_gpu: {}, distributed: {}".format(args.device, args.n_gpu, bool(args.local_rank != -1)))

    # Set seeds
    set_seed(args.seed)

    # Prepare GLUE task
    args.task_name = args.task_name.lower()
    if args.task_name not in glue_processors:
        raise ValueError("Task not found: %s" % (args.task_name))
    processor = glue_processors[args.task_name]()
    args.output_mode = glue_output_modes[args.task_name]
    label_list = processor.get_labels()
    num_labels = len(label_list)

    # Load pretrained model and tokenizer
    #
    # Distributed training:
    # The .from_pretrained methods guarantee that only one local process can concurrently
    # download model & vocab.

    config = AutoConfig.from_pretrained(
        args.config_name if args.config_name else args.model_name_or_path,
        num_labels=num_labels,
        finetuning_task=args.task_name,
        output_attentions=True,
        cache_dir=args.cache_dir,
    )
    tokenizer = AutoTokenizer.from_pretrained(
        args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
        cache_dir=args.cache_dir,
    )
    model = AutoModelForSequenceClassification.from_pretrained(
        args.model_name_or_path,
        from_tf=bool(".ckpt" in args.model_name_or_path),
        config=config,
        cache_dir=args.cache_dir,
    )

    # Distributed and parallel training
    model.to(args.device)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True
        )
    elif args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Print/save training arguments
    os.makedirs(args.output_dir, exist_ok=True)
    torch.save(args, os.path.join(args.output_dir, "run_args.bin"))
    logger.info("Training/evaluation parameters %s", args)

    # Prepare dataset for the GLUE task
    eval_dataset = GlueDataset(args, tokenizer=tokenizer, mode="dev")
    if args.data_subset > 0:
        eval_dataset = Subset(eval_dataset, list(range(min(args.data_subset, len(eval_dataset)))))
    eval_sampler = SequentialSampler(eval_dataset) if args.local_rank == -1 else DistributedSampler(eval_dataset)
    eval_dataloader = DataLoader(
        eval_dataset, sampler=eval_sampler, batch_size=args.batch_size, collate_fn=default_data_collator
    )

    # Compute head entropy and importance score
    compute_heads_importance(args, model, eval_dataloader)

    # Try head masking (set heads to zero until the score goes under a threshold)
    # and head pruning (remove masked heads and see the effect on the network)
    if args.try_masking and args.masking_threshold > 0.0 and args.masking_threshold < 1.0:
        head_mask = mask_heads(args, model, eval_dataloader)
        prune_heads(args, model, eval_dataloader, head_mask)


if __name__ == "__main__":
    main()