"...lm-evaluation-harness.git" did not exist on "80a739eb37a6bfd7f0ec489ee86e35a684af1c76"
bertology.py 16.7 KB
Newer Older
1
#!/usr/bin/env python3
thomwolf's avatar
thomwolf committed
2
import os
3
4
import argparse
import logging
thomwolf's avatar
thomwolf committed
5
from datetime import timedelta, datetime
thomwolf's avatar
thomwolf committed
6
from tqdm import tqdm
7
8
9

import numpy as np

thomwolf's avatar
thomwolf committed
10
11
12
13
14
import torch
from torch.utils.data import DataLoader, SequentialSampler, TensorDataset, Subset
from torch.utils.data.distributed import DistributedSampler
from torch.nn import CrossEntropyLoss, MSELoss

thomwolf's avatar
thomwolf committed
15
from pytorch_pretrained_bert import BertForSequenceClassification, BertTokenizer
16

thomwolf's avatar
thomwolf committed
17
18
19
from run_classifier_dataset_utils import processors, output_modes, convert_examples_to_features, compute_metrics


20
21
# Module-level logger; its level/handlers are configured by logging.basicConfig() inside run_model().
logger = logging.getLogger(__name__)

thomwolf's avatar
thomwolf committed
22
23
24
25
26
27

def entropy(p):
    """ Entropy (natural log) of `p`, summed over its last dimension.

        Entries of `p` that are exactly zero contribute 0 to the sum instead of NaN.
    """
    weighted_log = torch.log(p) * p
    # log(0) is -inf and 0 * -inf is NaN, so overwrite those terms with 0
    weighted_log.masked_fill_(p == 0, 0)
    return weighted_log.sum(dim=-1).neg()

thomwolf's avatar
thomwolf committed
28

thomwolf's avatar
thomwolf committed
29
30
31
32
33
34
def print_1d_tensor(tensor, prefix=""):
    """ Log the elements of a 1D tensor on a single tab-separated line, preceded by `prefix`.

        torch.long tensors are printed as integers, any other dtype with 5 decimals.
    """
    cell_format = "{:d}" if tensor.dtype == torch.long else "{:.5f}"
    cells = [cell_format.format(value) for value in tensor.cpu().data]
    logger.info(prefix + "\t".join(cells))

thomwolf's avatar
thomwolf committed
35

thomwolf's avatar
thomwolf committed
36
37
38
39
40
def print_2d_tensor(tensor):
    """ Log a 2D tensor as a table: a header line numbering the columns (heads),
        then one line per row (layer) printed with print_1d_tensor.
    """
    # The header labels columns, so it must follow the size of the second dimension;
    # using the number of rows is only correct when the tensor happens to be square.
    num_columns = tensor.shape[1]
    logger.info("lv, h >\t" + "\t".join(f"{col + 1}" for col in range(num_columns)))
    for row_idx, row_values in enumerate(tensor):
        print_1d_tensor(row_values, prefix=f"layer {row_idx + 1}:\t")

thomwolf's avatar
thomwolf committed
41

thomwolf's avatar
thomwolf committed
42
def compute_heads_importance(args, model, eval_dataloader, compute_entropy=True, compute_importance=True, head_mask=None):
    """ Example on how to use model outputs to compute:
        - head attention entropy (activated by setting output_attentions=True when we created the model)
        - head importance scores according to http://arxiv.org/abs/1905.10650
            (activated by setting keep_multihead_output=True when we created the model)

        Args:
            args: namespace providing `device`, `local_rank`, `output_mode`, `num_labels`,
                `dont_normalize_importance_by_layer` and `dont_normalize_global_importance`.
            model: the classification model (optionally wrapped in DistributedDataParallel).
            eval_dataloader: yields (input_ids, input_mask, segment_ids, label_ids) batches.
            compute_entropy: accumulate the per-head attention entropy.
            compute_importance: backpropagate the loss and accumulate per-head importance scores.
            head_mask: optional head mask forwarded to the model.

        Returns:
            (attn_entropy, head_importance, preds, labels) where the first two are
            (n_layers, n_heads) tensors and the last two are numpy arrays stacked over
            all batches (None if the dataloader yielded nothing).

        Raises:
            ValueError: if `compute_importance` is set and `args.output_mode` is neither
                "classification" nor "regression".
    """
    # DistributedDataParallel exposes the wrapped model as `.module`; unwrap it to reach `.bert`
    bert = (model.module if hasattr(model, "module") else model).bert

    # Prepare our tensors
    n_layers, n_heads = bert.config.num_hidden_layers, bert.config.num_attention_heads
    head_importance = torch.zeros(n_layers, n_heads).to(args.device)
    attn_entropy = torch.zeros(n_layers, n_heads).to(args.device)
    all_preds = []
    all_labels = []
    tot_tokens = 0.0

    # Pick the loss once, and fail early with a clear message on an unsupported output mode
    loss_fct = None
    if compute_importance:
        if args.output_mode == "classification":
            loss_fct = CrossEntropyLoss()
        elif args.output_mode == "regression":
            loss_fct = MSELoss()
        else:
            raise ValueError("Unsupported output mode: {}".format(args.output_mode))

    for batch in tqdm(eval_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0]):
        batch = tuple(t.to(args.device) for t in batch)
        input_ids, input_mask, segment_ids, label_ids = batch

        # Do a forward pass (not with torch.no_grad() since we need gradients for importance score - see below)
        all_attentions, logits = model(input_ids, token_type_ids=segment_ids, attention_mask=input_mask, head_mask=head_mask)

        if compute_entropy:
            # Update head attention entropy, ignoring padded positions
            token_mask = input_mask.float().unsqueeze(1)
            for layer, attn in enumerate(all_attentions):
                masked_entropy = entropy(attn.detach()) * token_mask
                attn_entropy[layer] += masked_entropy.sum(-1).sum(0).detach()

        if compute_importance:
            # Update head importance scores with regards to our loss
            # First, backpropagate to populate the gradients
            if args.output_mode == "classification":
                loss = loss_fct(logits.view(-1, args.num_labels), label_ids.view(-1))
            else:
                loss = loss_fct(logits.view(-1), label_ids.view(-1))
            loss.backward()
            # Second, compute importance scores according to http://arxiv.org/abs/1905.10650
            multihead_outputs = bert.get_multihead_outputs()
            for layer, mh_layer_output in enumerate(multihead_outputs):
                dot = torch.einsum("bhli,bhli->bhl", [mh_layer_output.grad, mh_layer_output])
                head_importance[layer] += dot.abs().sum(-1).sum(0).detach()

        # Also store our logits/labels if we want to compute metrics afterwards
        # (collected in lists and concatenated once, instead of re-copying the arrays every batch)
        all_preds.append(logits.detach().cpu().numpy())
        all_labels.append(label_ids.detach().cpu().numpy())

        tot_tokens += input_mask.float().detach().sum().data

    preds = np.concatenate(all_preds, axis=0) if all_preds else None
    labels = np.concatenate(all_labels, axis=0) if all_labels else None

    # Normalize
    attn_entropy /= tot_tokens
    head_importance /= tot_tokens
    # Layerwise importance normalization
    if not args.dont_normalize_importance_by_layer:
        exponent = 2
        norm_by_layer = torch.pow(torch.pow(head_importance, exponent).sum(-1), 1 / exponent)
        head_importance /= norm_by_layer.unsqueeze(-1) + 1e-20

    if not args.dont_normalize_global_importance:
        lowest = head_importance.min()
        spread = head_importance.max() - lowest
        head_importance = head_importance - lowest
        # When every score is identical (e.g. compute_importance=False) the spread is 0:
        # skip the division so the result is all zeros rather than all NaN
        if spread > 0:
            head_importance = head_importance / spread

    return attn_entropy, head_importance, preds, labels
thomwolf's avatar
thomwolf committed
108

thomwolf's avatar
thomwolf committed
109

110
111
def run_model():
    """ Command-line entry point.

        Parses the arguments, loads a BertForSequenceClassification model and the dev set of a GLUE task,
        computes and saves the attention entropy and importance score of every head, and optionally
        (--try_masking) masks the least important heads step by step until the metric drops below
        `masking_threshold * original score`, then prunes the masked heads and compares score,
        parameter count and timing before/after pruning.

        Raises:
            ValueError: if the output directory exists, is not empty and --overwrite_output_dir is not set.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_name_or_path', type=str, default='bert-base-cased-finetuned-mrpc', help='pretrained model name or path to local checkpoint')
    parser.add_argument("--task_name", type=str, default='mrpc', help="The name of the task to train.")
    parser.add_argument("--data_dir", type=str, required=True, help="The input data dir. Should contain the .tsv files (or other data files) for the task.")
    parser.add_argument("--output_dir", type=str, required=True, help="The output directory where the model predictions and checkpoints will be written.")
    parser.add_argument("--data_subset", type=int, default=-1, help="If > 0: limit the data to a subset of data_subset instances.")
    parser.add_argument("--overwrite_output_dir", action='store_true', help="Whether to overwrite data in output directory")

    parser.add_argument("--dont_normalize_importance_by_layer", action='store_true', help="Don't normalize importance score by layers")
    parser.add_argument("--dont_normalize_global_importance", action='store_true', help="Don't normalize all importance scores between 0 and 1")

    parser.add_argument("--try_masking", action='store_true', help="Whether to try to mask head until a threshold of accuracy.")
    parser.add_argument("--masking_threshold", default=0.9, type=float, help="masking threshold in term of metrics"
                                                                             "(stop masking when metric < threshold * original metric value).")
    parser.add_argument("--masking_amount", default=0.1, type=float, help="Amount to heads to masking at each masking step.")
    parser.add_argument("--metric_name", default="acc", type=str, help="Metric to use for head masking.")

    parser.add_argument("--max_seq_length", default=128, type=int, help="The maximum total input sequence length after WordPiece tokenization. \n"
                             "Sequences longer than this will be truncated, and sequences shorter \n"
                             "than this will be padded.")
    parser.add_argument("--batch_size", default=1, type=int, help="Batch size.")

    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--local_rank", type=int, default=-1, help="local_rank for distributed training on gpus")
    parser.add_argument("--no_cuda", action='store_true', help="Whether not to use CUDA when available")
    parser.add_argument('--server_ip', type=str, default='', help="Can be used for distant debugging.")
    parser.add_argument('--server_port', type=str, default='', help="Can be used for distant debugging.")
    args = parser.parse_args()

    if args.server_ip and args.server_port:
        # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
        import ptvsd
        print("Waiting for debugger attach")
        ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
        ptvsd.wait_for_attach()

    # Setup devices and distributed training
    if args.local_rank == -1 or args.no_cuda:
        args.device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        args.device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        torch.distributed.init_process_group(backend='nccl')  # Initializes the distributed backend

    # Setup logging (only the main process logs at INFO level)
    logging.basicConfig(level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
    logger.info("device: {} n_gpu: {}, distributed: {}".format(args.device, n_gpu, bool(args.local_rank != -1)))

    # Set seeds
    np.random.seed(args.seed)
    torch.random.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed(args.seed)

    # Prepare GLUE task
    task_name = args.task_name.lower()
    processor = processors[task_name]()
    label_list = processor.get_labels()
    args.output_mode = output_modes[task_name]
    args.num_labels = len(label_list)

    # Prepare output directory
    if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and not args.overwrite_output_dir:
        raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
    if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]:
        os.makedirs(args.output_dir)

    # Load model & tokenizer
    if args.local_rank not in [-1, 0]:
        torch.distributed.barrier()  # Make sure only one distributed process download model & vocab
    tokenizer = BertTokenizer.from_pretrained(args.model_name_or_path)

    # Load a model with all BERTology options on:
    #   output_attentions => will output attention weights
    #   keep_multihead_output => will store gradient of attention head outputs for head importance computation
    #       see: http://arxiv.org/abs/1905.10650
    model = BertForSequenceClassification.from_pretrained(args.model_name_or_path,
                                                          num_labels=args.num_labels,
                                                          output_attentions=True,
                                                          keep_multihead_output=True)
    if args.local_rank == 0:
        torch.distributed.barrier()  # Make sure only one distributed process download model & vocab
    model.to(args.device)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True)
    model.eval()

    # Prepare dataset for the GLUE task
    eval_examples = processor.get_dev_examples(args.data_dir)
    cached_eval_features_file = os.path.join(args.data_dir, 'dev_{0}_{1}_{2}'.format(
        list(filter(None, args.model_name_or_path.split('/'))).pop(), str(args.max_seq_length), str(task_name)))
    # The cache is best-effort: any failure to read it falls back to rebuilding the features.
    # `except Exception` (rather than a bare except) lets KeyboardInterrupt/SystemExit propagate.
    try:
        eval_features = torch.load(cached_eval_features_file)
    except Exception as load_error:
        logger.info("Could not load cached eval features from %s (%s), rebuilding them", cached_eval_features_file, load_error)
        eval_features = None
    if eval_features is None:
        eval_features = convert_examples_to_features(eval_examples, label_list, args.max_seq_length, tokenizer, args.output_mode)
        if args.local_rank in [-1, 0]:
            logger.info("Saving eval features to cache file %s", cached_eval_features_file)
            torch.save(eval_features, cached_eval_features_file)

    all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long)
    all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long)
    all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long)
    all_label_ids = torch.tensor([f.label_id for f in eval_features], dtype=torch.long if args.output_mode == "classification" else torch.float)
    eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)

    if args.data_subset > 0:
        eval_data = Subset(eval_data, list(range(min(args.data_subset, len(eval_data)))))

    eval_sampler = SequentialSampler(eval_data) if args.local_rank == -1 else DistributedSampler(eval_data)
    eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.batch_size)

    # Print/save training arguments
    print(args)
    torch.save(args, os.path.join(args.output_dir, 'run_args.bin'))

    # Compute head entropy and importance score
    attn_entropy, head_importance, _, _ = compute_heads_importance(args, model, eval_dataloader)

    # Print/save matrices
    np.save(os.path.join(args.output_dir, 'attn_entropy.npy'), attn_entropy.detach().cpu().numpy())
    np.save(os.path.join(args.output_dir, 'head_importance.npy'), head_importance.detach().cpu().numpy())

    logger.info("Attention entropies")
    print_2d_tensor(attn_entropy)
    logger.info("Head importance scores")
    print_2d_tensor(head_importance)
    logger.info("Head ranked by importance scores")
    # head_ranks[i] is the rank (0 = most important) of the i-th head in flattened order
    head_ranks = torch.zeros(head_importance.numel(), dtype=torch.long, device=args.device)
    head_ranks[head_importance.view(-1).sort(descending=True)[1]] = torch.arange(head_importance.numel(), device=args.device)
    head_ranks = head_ranks.view_as(head_importance)
    print_2d_tensor(head_ranks)

    # Do masking if we want to
    if args.try_masking and 0.0 < args.masking_threshold < 1.0:
        _, head_importance, preds, labels = compute_heads_importance(args, model, eval_dataloader, compute_entropy=False)
        preds = np.argmax(preds, axis=1) if args.output_mode == "classification" else np.squeeze(preds)
        original_score = compute_metrics(task_name, preds, labels)[args.metric_name]
        logger.info("Pruning: original score: %f, threshold: %f", original_score, original_score * args.masking_threshold)

        new_head_mask = torch.ones_like(head_importance)
        num_to_mask = max(1, int(new_head_mask.numel() * args.masking_amount))

        # Defined up front so the final mask exists even if the loop body never runs
        head_mask = new_head_mask.clone()
        current_score = original_score
        while current_score >= original_score * args.masking_threshold:
            head_mask = new_head_mask.clone()  # save current head mask (last one known to pass the threshold)

            # Stop when not enough unmasked heads are left for another step. Counting the sorted
            # indices instead would always give the total number of heads, so already-masked heads
            # would be re-masked and the loop could spin forever once every head is masked.
            num_remaining = int(head_mask.sum().item())
            if num_remaining <= num_to_mask:
                break

            # heads from least important to most - masked heads get +Inf so they sort last
            head_importance[head_mask == 0.0] = float('Inf')
            current_heads_to_mask = head_importance.view(-1).sort()[1]

            # mask heads
            current_heads_to_mask = current_heads_to_mask[:num_to_mask]
            logger.info("Heads to mask: %s", str(current_heads_to_mask.tolist()))
            new_head_mask = new_head_mask.view(-1)
            new_head_mask[current_heads_to_mask] = 0.0
            new_head_mask = new_head_mask.view_as(head_mask)
            print_2d_tensor(new_head_mask)

            # Compute metric and head importance again
            _, head_importance, preds, labels = compute_heads_importance(args, model, eval_dataloader, compute_entropy=False, head_mask=new_head_mask)
            preds = np.argmax(preds, axis=1) if args.output_mode == "classification" else np.squeeze(preds)
            current_score = compute_metrics(task_name, preds, labels)[args.metric_name]
            logger.info("Masking: current score: %f, remaning heads %d (%.1f percents)", current_score, new_head_mask.sum(), new_head_mask.sum()/new_head_mask.numel() * 100)

        logger.info("Final head mask")
        print_2d_tensor(head_mask)
        np.save(os.path.join(args.output_dir, 'head_mask.npy'), head_mask.detach().cpu().numpy())

        # Try pruning and test time speedup
        # Pruning is like masking but we actually remove the masked weights
        before_time = datetime.now()
        _, _, preds, labels = compute_heads_importance(args, model, eval_dataloader,
                                                       compute_entropy=False, compute_importance=False, head_mask=head_mask)
        preds = np.argmax(preds, axis=1) if args.output_mode == "classification" else np.squeeze(preds)
        score_masking = compute_metrics(task_name, preds, labels)[args.metric_name]
        original_time = datetime.now() - before_time

        original_num_params = sum(p.numel() for p in model.parameters())
        # nonzero() returns a (k, 1) tensor: flatten it so each layer maps to a flat list of head indices
        heads_to_prune = {layer: (1 - head_mask[layer].long()).nonzero().reshape(-1).tolist()
                          for layer in range(len(head_mask))}
        assert sum(len(h) for h in heads_to_prune.values()) == (1 - head_mask.long()).sum().item()
        # DistributedDataParallel exposes the wrapped model as `.module`; unwrap it to reach `.bert`
        base_model = model.module if hasattr(model, "module") else model
        base_model.bert.prune_heads(heads_to_prune)
        pruned_num_params = sum(p.numel() for p in model.parameters())

        before_time = datetime.now()
        _, _, preds, labels = compute_heads_importance(args, model, eval_dataloader,
                                                       compute_entropy=False, compute_importance=False, head_mask=None)
        preds = np.argmax(preds, axis=1) if args.output_mode == "classification" else np.squeeze(preds)
        score_pruning = compute_metrics(task_name, preds, labels)[args.metric_name]
        new_time = datetime.now() - before_time

        logger.info("Pruning: original num of params: %.2e, after pruning %.2e (%.1f percents)", original_num_params, pruned_num_params, pruned_num_params/original_num_params * 100)
        logger.info("Pruning: score with masking: %f score with pruning: %f", score_masking, score_pruning)
        # The value logged is original/new (above 100 means the pruned model is faster); the label now says so
        logger.info("Pruning: speed ratio (original timing / new timing): %f percents", original_time/new_time * 100)
thomwolf's avatar
thomwolf committed
308
309
310

# Run the analysis only when this file is executed as a script, not when it is imported.
if __name__ == '__main__':
    run_model()