bertology.py 15 KB
Newer Older
1
#!/usr/bin/env python3
thomwolf's avatar
thomwolf committed
2
import os
3
4
import argparse
import logging
thomwolf's avatar
thomwolf committed
5
from datetime import timedelta, datetime
thomwolf's avatar
thomwolf committed
6
from tqdm import tqdm
7
8
9

import numpy as np

thomwolf's avatar
thomwolf committed
10
11
12
13
14
import torch
from torch.utils.data import DataLoader, SequentialSampler, TensorDataset, Subset
from torch.utils.data.distributed import DistributedSampler
from torch.nn import CrossEntropyLoss, MSELoss

thomwolf's avatar
thomwolf committed
15
from pytorch_pretrained_bert import BertForSequenceClassification, BertTokenizer
16

thomwolf's avatar
thomwolf committed
17
18
19
from run_classifier_dataset_utils import processors, output_modes, convert_examples_to_features, compute_metrics


20
21
# Module-level logger shared by every function in this script; its level is set by the
# logging.basicConfig() call made in run_model().
logger = logging.getLogger(__name__)

thomwolf's avatar
thomwolf committed
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38

def entropy(p):
    """Return ``-sum(p * log(p))`` (natural log) over the last dimension of ``p``.

    Entries of ``p`` that are exactly zero contribute zero to the sum
    (the ``0 * log(0) = 0`` convention) instead of turning the result into NaN.
    """
    zero_entries = p == 0
    weighted_log = torch.log(p) * p
    # log(0) is -inf and 0 * -inf is NaN, so overwrite those terms explicitly.
    weighted_log[zero_entries] = 0
    return -(weighted_log.sum(dim=-1))

def print_1d_tensor(tensor, prefix=""):
    """Log the elements of a 1-D tensor as a single tab-separated line.

    Tensors of dtype ``torch.long`` are rendered as integers; any other dtype is
    rendered with five decimal places. ``prefix`` is prepended verbatim to the line.
    """
    integer_valued = tensor.dtype == torch.long
    cells = [f"{value:d}" if integer_valued else f"{value:.5f}" for value in tensor.cpu().data]
    logger.info(prefix + "\t".join(cells))

def print_2d_tensor(tensor):
    """Log a 2-D tensor as a table: one header line, then one line per row.

    The header numbers the columns starting from 1 and each row is logged through
    ``print_1d_tensor`` with a ``"layer <row number>:"`` prefix.

    The header is sized from the number of columns (last dimension). It was
    previously sized from the number of rows, which only produced a correct
    header for square tensors.
    """
    n_columns = tensor.shape[-1]
    logger.info("lv, h >\t" + "\t".join(f"{col + 1}" for col in range(n_columns)))
    for row, values in enumerate(tensor):
        print_1d_tensor(values, prefix=f"layer {row + 1}:\t")

thomwolf's avatar
thomwolf committed
39
def compute_heads_importance(args, model, eval_dataloader, compute_entropy=True, compute_importance=True, head_mask=None):
    """ Example on how to use model outputs to compute:
        - head attention entropy (activated by setting output_attentions=True when we created the model)
        - head importance scores according to http://arxiv.org/abs/1905.10650
            (activated by setting keep_multihead_output=True when we created the model)

    Args:
        args: namespace providing ``device``, ``local_rank``, ``normalize_importance`` and, when
            ``compute_importance`` is set, ``output_mode`` and ``num_labels``.
        model: model called as ``model(input_ids, token_type_ids=..., attention_mask=..., head_mask=...)``
            and returning ``(all_attentions, logits)``. It may be wrapped (e.g. in
            DistributedDataParallel); the underlying network is then read from ``model.module``.
        eval_dataloader: iterable yielding ``(input_ids, input_mask, segment_ids, label_ids)`` batches.
        compute_entropy: accumulate the attention entropy of every head.
        compute_importance: back-propagate the task loss and accumulate head importance scores.
        head_mask: forwarded unchanged to the model.

    Returns:
        ``(attn_entropy, head_importance, preds, labels)``: two ``(n_layers, n_heads)`` tensors on
        ``args.device`` divided by the total number of unmasked input tokens, and the logits / labels
        of all batches concatenated along axis 0 as numpy arrays (``None`` if the dataloader was empty).

    Raises:
        ValueError: if ``compute_importance`` is set and ``args.output_mode`` is neither
            ``"classification"`` nor ``"regression"``.
    """
    # A wrapped model (e.g. DistributedDataParallel) exposes the real network as ``.module``.
    bert = getattr(model, "module", model).bert

    # Prepare our tensors
    n_layers, n_heads = bert.config.num_hidden_layers, bert.config.num_attention_heads
    head_importance = torch.zeros(n_layers, n_heads).to(args.device)
    attn_entropy = torch.zeros(n_layers, n_heads).to(args.device)
    all_preds, all_labels = [], []
    tot_tokens = 0.0

    # Pick the loss once, and fail early with a clear message on an unsupported mode.
    loss_fct = None
    if compute_importance:
        if args.output_mode == "classification":
            loss_fct = CrossEntropyLoss()
        elif args.output_mode == "regression":
            loss_fct = MSELoss()
        else:
            raise ValueError("Unsupported output mode: {}".format(args.output_mode))

    for batch in tqdm(eval_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0]):
        input_ids, input_mask, segment_ids, label_ids = (t.to(args.device) for t in batch)

        # Do a forward pass (not in torch.no_grad() since we need gradients for importance score - see below)
        all_attentions, logits = model(input_ids, token_type_ids=segment_ids, attention_mask=input_mask, head_mask=head_mask)

        if compute_entropy:
            # Update head attention entropy; the input mask zeroes the entropy of masked-out positions
            token_mask = input_mask.float().unsqueeze(1)
            for layer, attn in enumerate(all_attentions):
                masked_entropy = entropy(attn.detach()) * token_mask
                attn_entropy[layer] += masked_entropy.sum(-1).sum(0).detach()

        if compute_importance:
            # Update head importance scores with regards to our loss
            # First, backpropagate to populate the gradients
            if args.output_mode == "classification":
                loss = loss_fct(logits.view(-1, args.num_labels), label_ids.view(-1))
            else:
                loss = loss_fct(logits.view(-1), label_ids.view(-1))
            loss.backward()
            # Second, compute importance scores according to http://arxiv.org/abs/1905.10650:
            # |output . d(loss)/d(output)| summed over positions and over the batch
            for layer, mh_layer_output in enumerate(bert.get_multihead_outputs()):
                dot = torch.einsum("bhli,bhli->bhl", [mh_layer_output.grad, mh_layer_output])
                head_importance[layer] += dot.abs().sum(-1).sum(0).detach()

        # Also store our logits/labels if we want to compute metrics afterwards
        # (collected in lists and concatenated once to avoid re-copying the arrays at every step)
        all_preds.append(logits.detach().cpu().numpy())
        all_labels.append(label_ids.detach().cpu().numpy())

        tot_tokens += input_mask.float().detach().sum().data

    # Normalize by the number of unmasked tokens seen
    attn_entropy /= tot_tokens
    head_importance /= tot_tokens
    if args.normalize_importance:
        lowest, highest = head_importance.min(), head_importance.max()
        head_importance = head_importance - lowest
        # Only rescale when the scores actually differ, otherwise 0 / 0 would produce NaNs
        if highest > lowest:
            head_importance = head_importance / (highest - lowest)

    preds = np.concatenate(all_preds, axis=0) if all_preds else None
    labels = np.concatenate(all_labels, axis=0) if all_labels else None

    return attn_entropy, head_importance, preds, labels
thomwolf's avatar
thomwolf committed
99

100
101
def run_model():
    """Command-line entry point: analyse (and optionally mask/prune) BERT attention heads on a GLUE task.

    Steps:
        1. parse the command line, set up the device, optional distributed backend, logging and seeds;
        2. load the tokenizer and the sequence classification model with attention outputs and
           multi-head outputs enabled;
        3. build (or load from a cache file in ``--data_dir``) the dev-set features;
        4. compute, save (``attn_entropy.npy`` / ``head_importance.npy`` in ``--output_dir``) and log
           the attention entropy and importance score of every head;
        5. with ``--try_masking``: iteratively mask the least important heads while the metric stays
           above ``masking_threshold * original score``, then prune the masked heads and compare the
           score and run time of the masked and pruned models.

    Raises:
        ValueError: if ``--output_dir`` exists, is not empty and ``--overwrite_output_dir`` is not set.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_name_or_path', type=str, default='bert-base-cased-finetuned-mrpc', help='pretrained model name or path to local checkpoint')
    parser.add_argument("--task_name", type=str, default='mrpc', help="The name of the task to train.")
    parser.add_argument("--data_dir", type=str, required=True, help="The input data dir. Should contain the .tsv files (or other data files) for the task.")
    parser.add_argument("--output_dir", type=str, required=True, help="The output directory where the model predictions and checkpoints will be written.")
    parser.add_argument("--data_subset", type=int, default=-1, help="If > 0: limit the data to a subset of data_subset instances.")
    parser.add_argument("--overwrite_output_dir", action='store_true', help="Whether to overwrite data in output directory")

    parser.add_argument("--normalize_importance", action='store_true', help="Whether to normalize importance score between 0 and 1")

    parser.add_argument("--try_masking", action='store_true', help="Whether to try to mask head until a threshold of accuracy.")
    parser.add_argument("--masking_threshold", default=0.9, type=float, help="masking threshold in term of metrics"
                                                                             "(stop masking when metric < threshold * original metric value).")
    parser.add_argument("--masking_amount", default=0.1, type=float, help="Amount to heads to masking at each masking step.")
    parser.add_argument("--metric_name", default="acc", type=str, help="Metric to use for head masking.")

    parser.add_argument("--max_seq_length", default=128, type=int, help="The maximum total input sequence length after WordPiece tokenization. \n"
                             "Sequences longer than this will be truncated, and sequences shorter \n"
                             "than this will be padded.")
    parser.add_argument("--batch_size", default=1, type=int, help="Batch size.")

    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--local_rank", type=int, default=-1, help="local_rank for distributed training on gpus")
    parser.add_argument("--no_cuda", action='store_true', help="Whether not to use CUDA when available")
    args = parser.parse_args()

    # Setup devices and distributed training
    if args.local_rank == -1 or args.no_cuda:
        args.device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        args.device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        torch.distributed.init_process_group(backend='nccl')  # Initializes the distributed backend

    # Setup logging (only the main process logs at INFO level)
    logging.basicConfig(level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
    logger.info("device: %s n_gpu: %s, distributed: %s", args.device, n_gpu, bool(args.local_rank != -1))

    # Set seeds
    np.random.seed(args.seed)
    torch.random.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed(args.seed)

    # Prepare GLUE task
    task_name = args.task_name.lower()
    processor = processors[task_name]()
    label_list = processor.get_labels()
    args.output_mode = output_modes[task_name]
    args.num_labels = len(label_list)

    # Prepare output directory
    if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and not args.overwrite_output_dir:
        raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
    if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]:
        os.makedirs(args.output_dir)

    # Load model & tokenizer
    if args.local_rank not in [-1, 0]:
        torch.distributed.barrier()  # Make sure only one distributed process download model & vocab
    tokenizer = BertTokenizer.from_pretrained(args.model_name_or_path)

    # Load a model with all BERTology options on:
    #   output_attentions => will output attention weights
    #   keep_multihead_output => will store gradient of attention head outputs for head importance computation
    #       see: http://arxiv.org/abs/1905.10650
    model = BertForSequenceClassification.from_pretrained(args.model_name_or_path,
                                                          num_labels=args.num_labels,
                                                          output_attentions=True,
                                                          keep_multihead_output=True)
    if args.local_rank == 0:
        torch.distributed.barrier()  # Make sure only one distributed process download model & vocab
    model.to(args.device)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True)
    model.eval()

    # Prepare dataset for the GLUE task
    eval_examples = processor.get_dev_examples(args.data_dir)
    cached_eval_features_file = os.path.join(args.data_dir, 'dev_{0}_{1}_{2}'.format(
        list(filter(None, args.model_name_or_path.split('/'))).pop(), str(args.max_seq_length), str(task_name)))
    try:
        # NOTE(review): torch.load unpickles the file; only point --data_dir at trusted locations.
        eval_features = torch.load(cached_eval_features_file)
    except Exception:
        # The cache is best-effort: a missing or unreadable file just means the features are rebuilt.
        # (A bare "except:" would also have swallowed KeyboardInterrupt / SystemExit.)
        logger.info("Could not load cached features from %s, rebuilding them", cached_eval_features_file)
        eval_features = convert_examples_to_features(eval_examples, label_list, args.max_seq_length, tokenizer, args.output_mode)
        if args.local_rank in [-1, 0]:
            logger.info("Saving eval features to cache file %s", cached_eval_features_file)
            torch.save(eval_features, cached_eval_features_file)

    all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long)
    all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long)
    all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long)
    all_label_ids = torch.tensor([f.label_id for f in eval_features], dtype=torch.long if args.output_mode == "classification" else torch.float)
    eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)

    if args.data_subset > 0:
        # Clamp so that asking for more instances than available does not index out of range
        eval_data = Subset(eval_data, list(range(min(args.data_subset, len(eval_data)))))

    eval_sampler = SequentialSampler(eval_data) if args.local_rank == -1 else DistributedSampler(eval_data)
    eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.batch_size)

    # Print/save training arguments
    print(args)
    torch.save(args, os.path.join(args.output_dir, 'run_args.bin'))

    # Compute head entropy and importance score
    attn_entropy, head_importance, _, _ = compute_heads_importance(args, model, eval_dataloader)

    # Print/save matrices (moved to CPU first: numpy cannot read tensors living on a CUDA device)
    np.save(os.path.join(args.output_dir, 'attn_entropy.npy'), attn_entropy.detach().cpu().numpy())
    np.save(os.path.join(args.output_dir, 'head_importance.npy'), head_importance.detach().cpu().numpy())

    logger.info("Attention entropies")
    print_2d_tensor(attn_entropy)
    logger.info("Head importance scores")
    print_2d_tensor(head_importance)
    logger.info("Head ranked by importance scores")
    # head_ranks[i] is the rank (0 = most important) of flattened head i
    head_ranks = torch.zeros(head_importance.numel(), dtype=torch.long, device=args.device)
    head_ranks[head_importance.view(-1).sort(descending=True)[1]] = torch.arange(head_importance.numel(), device=args.device)
    head_ranks = head_ranks.view_as(head_importance)
    print_2d_tensor(head_ranks)

    # Do masking if we want to
    if args.try_masking and 0.0 < args.masking_threshold < 1.0:
        _, head_importance, preds, labels = compute_heads_importance(args, model, eval_dataloader, compute_entropy=False)
        preds = np.argmax(preds, axis=1) if args.output_mode == "classification" else np.squeeze(preds)
        original_score = compute_metrics(task_name, preds, labels)[args.metric_name]
        logger.info("Pruning: original score: %f", original_score)

        new_head_mask = torch.ones_like(head_importance)
        # Always mask at least one head per step: with 0, the "[-0:]" slice below would select every head
        num_to_mask = max(1, int(new_head_mask.numel() * args.masking_amount))

        current_score = original_score
        while current_score >= original_score * args.masking_threshold:
            # Snapshot the last mask that satisfied the threshold. It must be a copy: the new mask is
            # edited in place below and would otherwise overwrite it.
            head_mask = new_head_mask.clone()
            # heads from most important to least
            ranked_heads = head_importance.view(-1).sort(descending=True)[1]
            # keep only not-masked heads (filter the ranked head ids by their mask value)
            ranked_heads = ranked_heads[head_mask.view(-1)[ranked_heads] != 0]

            if len(ranked_heads) <= num_to_mask:
                break

            # mask the least important remaining heads
            heads_to_mask = ranked_heads[-num_to_mask:]
            new_head_mask = head_mask.clone().view(-1)
            new_head_mask[heads_to_mask] = 0.0
            new_head_mask = new_head_mask.view_as(head_importance)
            print_2d_tensor(new_head_mask)

            # Compute metric and head importance again
            _, head_importance, preds, labels = compute_heads_importance(args, model, eval_dataloader, compute_entropy=False, head_mask=new_head_mask)
            preds = np.argmax(preds, axis=1) if args.output_mode == "classification" else np.squeeze(preds)
            current_score = compute_metrics(task_name, preds, labels)[args.metric_name]
            remaining_percent = (new_head_mask.sum() / new_head_mask.numel() * 100).item()
            logger.info("Masking: current score: %f, remaning heads %.1f percents", current_score, remaining_percent)

        # Try pruning and test time speedup
        # Pruning is like masking but we actually remove the masked weights
        before_time = datetime.now()
        _, _, preds, labels = compute_heads_importance(args, model, eval_dataloader,
                                                       compute_entropy=False, compute_importance=False, head_mask=head_mask)
        preds = np.argmax(preds, axis=1) if args.output_mode == "classification" else np.squeeze(preds)
        score_masking = compute_metrics(task_name, preds, labels)[args.metric_name]
        original_time = datetime.now() - before_time

        # {layer index: flat list of head indices to remove in that layer}
        heads_to_prune = {layer: (1 - head_mask[layer].long()).nonzero().view(-1).tolist() for layer in range(len(head_mask))}
        assert sum(len(h) for h in heads_to_prune.values()) == (1 - head_mask.long()).sum().item()
        # A DistributedDataParallel wrapper exposes the real network as ``.module``
        getattr(model, "module", model).bert.prune_heads(heads_to_prune)

        before_time = datetime.now()
        _, _, preds, labels = compute_heads_importance(args, model, eval_dataloader,
                                                       compute_entropy=False, compute_importance=False, head_mask=None)
        preds = np.argmax(preds, axis=1) if args.output_mode == "classification" else np.squeeze(preds)
        score_pruning = compute_metrics(task_name, preds, labels)[args.metric_name]
        new_time = datetime.now() - before_time

        logger.info("Pruning: score with masking: %f score with pruning: %f", score_masking, score_pruning)
        # The label now matches the value actually computed (original / new, i.e. > 100 means faster)
        logger.info("Pruning: speed ratio (original timing / new timing): %f percents", original_time/new_time * 100)
thomwolf's avatar
thomwolf committed
281
282
283

# Script entry point: run the analysis only when executed directly, not when imported.
if __name__ == "__main__":
    run_model()