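"""
Prune BERT on a GLUE task (MNLI by default) with NNI's MovementPruner.

The script fine-tunes bert-base-cased while gradually pruning the Linear layers inside
the encoder to 90% sparsity, then evaluates the pruned model and fine-tunes it for one more epoch.
"""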
import functools
import time
from tqdm import tqdm

import torch
from torch.optim import Adam
from torch.utils.data import DataLoader

from datasets import load_metric, load_dataset
from transformers import (
    BertForSequenceClassification,
    BertTokenizerFast,
    DataCollatorWithPadding,
    set_seed
)

import nni
from nni.compression.pytorch.pruning import MovementPruner


task_to_keys = {
    "cola": ("sentence", None),
    "mnli": ("premise", "hypothesis"),
    "mrpc": ("sentence1", "sentence2"),
    "qnli": ("question", "sentence"),
    "qqp": ("question1", "question2"),
    "rte": ("sentence1", "sentence2"),
    "sst2": ("sentence", None),
    "stsb": ("sentence1", "sentence2"),
    "wnli": ("sentence1", "sentence2"),
}

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

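# gradients are accumulated over several micro-batches before each optimizer step,
# so the effective batch size is train_batch_size * gradient_accumulation_steps (4 * 8 = 32 here)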
gradient_accumulation_steps = 8

# a fake criterion: the huggingface model output already contains the loss, so just return it
def criterion(input, target):
    return input.loss

def trainer(model, optimizer, criterion, train_dataloader):
    model.train()
    counter = 0
    for batch in tqdm(train_dataloader):
        counter += 1
        batch = batch.to(device)
        outputs = model(**batch)
        # the pruner may wrap the criterion (for example, loss = origin_loss + norm(weight)), so call criterion to get the loss here
        loss = criterion(outputs, None)
        # scale the loss so that the accumulated gradients match a single large batch
        loss = loss / gradient_accumulation_steps
        loss.backward()
        # step and clear the gradients only once every gradient_accumulation_steps micro-batches
        if counter % gradient_accumulation_steps == 0 or counter == len(train_dataloader):
            optimizer.step()
            optimizer.zero_grad()
        if counter % 800 == 0:
            print('[{}]: {}'.format(time.asctime(time.localtime(time.time())), counter))
        if counter % 8000 == 0:
            print('Step {}: {}'.format(counter // gradient_accumulation_steps, evaluator(model, metric, is_regression, validate_dataloader)))
            # evaluator() switches the model to eval mode, so switch back to train mode before continuing
            model.train()

def evaluator(model, metric, is_regression, eval_dataloader):
    model.eval()
    for batch in tqdm(eval_dataloader):
        batch = batch.to(device)
        # no gradients are needed during evaluation
        with torch.no_grad():
            outputs = model(**batch)
        predictions = outputs.logits.argmax(dim=-1) if not is_regression else outputs.logits.squeeze()
        metric.add_batch(
            predictions=predictions,
            references=batch["labels"],
        )
    return metric.compute()

if __name__ == '__main__':
    task_name = 'mnli'
    is_regression = False
    num_labels = 1 if is_regression else (3 if task_name == 'mnli' else 2)
    train_batch_size = 4
    eval_batch_size = 4

    set_seed(1024)

    tokenizer = BertTokenizerFast.from_pretrained('bert-base-cased')
    sentence1_key, sentence2_key = task_to_keys[task_name]

    # used to preprocess the raw data
    def preprocess_function(examples):
        # Tokenize the texts
        args = (
            (examples[sentence1_key],) if sentence2_key is None else (examples[sentence1_key], examples[sentence2_key])
        )
        result = tokenizer(*args, padding=False, max_length=128, truncation=True)

        if "label" in examples:
            # In all cases, rename the column to labels because the model will expect that.
            result["labels"] = examples["label"]
        return result

    raw_datasets = load_dataset('glue', task_name, cache_dir='./data')
    processed_datasets = raw_datasets.map(preprocess_function, batched=True, remove_columns=raw_datasets["train"].column_names)

    train_dataset = processed_datasets['train']
    validate_dataset = processed_datasets['validation_matched' if task_name == "mnli" else 'validation']

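    # tokenization above uses padding=False, so DataCollatorWithPadding pads each batch dynamically
    # to the longest sequence in that batch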
    data_collator = DataCollatorWithPadding(tokenizer)
    train_dataloader = DataLoader(train_dataset, shuffle=True, collate_fn=data_collator, batch_size=train_batch_size)
    validate_dataloader = DataLoader(validate_dataset, collate_fn=data_collator, batch_size=eval_batch_size)

    metric = load_metric("glue", task_name)

    model = BertForSequenceClassification.from_pretrained('bert-base-cased', num_labels=num_labels).to(device)

    print('Initial: {}'.format(evaluator(model, metric, is_regression, validate_dataloader)))

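    # prune only the Linear modules inside bert.encoder to 90% sparsity
    # (the embeddings, the pooler, and the classification head are left dense)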
    config_list = [{'op_types': ['Linear'], 'op_partial_names': ['bert.encoder'], 'sparsity': 0.9}]
    p_trainer = functools.partial(trainer, train_dataloader=train_dataloader)

    # make sure the optimizer class is wrapped with nni.trace before it is instantiated
    traced_optimizer = nni.trace(Adam)(model.parameters(), lr=2e-5)
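    # warm_up_step and cool_down_beginning_step are measured in optimizer steps; with MNLI (~393k
    # examples, batch size 4, 8 accumulation steps) one epoch is about 12272 steps, so the sparsity
    # schedule ramps up after the first epoch and reaches the final 90% around the start of the last epoch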
    pruner = MovementPruner(model, config_list, p_trainer, traced_optimizer, criterion, training_epochs=10,
                            warm_up_step=12272, cool_down_beginning_step=110448)

    _, masks = pruner.compress()
    pruner.show_pruned_weights()

    print('Final: {}'.format(evaluator(model, metric, is_regression, validate_dataloader)))

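    # fine-tune the pruned model for one more epoch; the masks produced by the pruner remain applied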
    optimizer = Adam(model.parameters(), lr=2e-5)
    trainer(model, optimizer, criterion, train_dataloader)
    print('After 1 epoch finetuning: {}'.format(evaluator(model, metric, is_regression, validate_dataloader)))