{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "%matplotlib inline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n# Pruning Bert on Task MNLI\n\n## Workable Pruning Process\n\nHere we show an effective transformer pruning process that NNI team has tried, and users can use NNI to discover better processes.\n\nThe entire pruning process can be divided into the following steps:\n\n1. Finetune the pre-trained model on the downstream task. From our experience,\n   the final performance of pruning on the finetuned model is better than pruning directly on the pre-trained model.\n   At the same time, the finetuned model obtained in this step will also be used as the teacher model for the following\n   distillation training.\n2. Pruning the attention layer at first. Here we apply block-sparse on attention layer weight,\n   and directly prune the head (condense the weight) if the head was fully masked.\n   If the head was partially masked, we will not prune it and recover its weight.\n3. Retrain the head-pruned model with distillation. Recover the model precision before pruning FFN layer.\n4. Pruning the FFN layer. Here we apply the output channels pruning on the 1st FFN layer,\n   and the 2nd FFN layer input channels will be pruned due to the pruning of 1st layer output channels.\n5. Retrain the final pruned model with distillation.\n\nDuring the process of pruning transformer, we gained some of the following experiences:\n\n* We using `movement-pruner` in step 2 and `taylor-fo-weight-pruner` in step 4. `movement-pruner` has good performance on attention layers,\n  and `taylor-fo-weight-pruner` method has good performance on FFN layers. These two pruners are all some kinds of gradient-based pruning algorithms,\n  we also try weight-based pruning algorithms like `l1-norm-pruner`, but it doesn't seem to work well in this scenario.\n* Distillation is a good way to recover model precision. In terms of results, usually 1~2% improvement in accuracy can be achieved when we prune bert on mnli task.\n* It is necessary to gradually increase the sparsity rather than reaching a very high sparsity all at once.\n\n## Experiment\n\nThe complete pruning process will take about 8 hours on one A100.\n\n### Preparation\n\nThis section is mainly to get a finetuned model on the downstream task.\nIf you are familiar with how to finetune Bert on GLUE dataset, you can skip this section.\n\n<div class=\"alert alert-info\"><h4>Note</h4><p>Please set ``dev_mode`` to ``False`` to run this tutorial. Here ``dev_mode`` is ``True`` by default is for generating documents.</p></div>\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "dev_mode = True"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Some basic setting.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "from pathlib import Path\nfrom typing import Callable, Dict\n\npretrained_model_name_or_path = 'bert-base-uncased'\ntask_name = 'mnli'\nexperiment_id = 'pruning_bert_mnli'\n\n# heads_num and layers_num should align with pretrained_model_name_or_path\nheads_num = 12\nlayers_num = 12\n\n# used to save the experiment log\nlog_dir = Path(f'./pruning_log/{pretrained_model_name_or_path}/{task_name}/{experiment_id}')\nlog_dir.mkdir(parents=True, exist_ok=True)\n\n# used to save the finetuned model and share between different experiemnts with same pretrained_model_name_or_path and task_name\nmodel_dir = Path(f'./models/{pretrained_model_name_or_path}/{task_name}')\nmodel_dir.mkdir(parents=True, exist_ok=True)\n\n# used to save GLUE data\ndata_dir = Path(f'./data')\ndata_dir.mkdir(parents=True, exist_ok=True)\n\n# set seed\nfrom transformers import set_seed\nset_seed(1024)\n\nimport torch\ndevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Create dataloaders.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "from torch.utils.data import DataLoader\n\nfrom datasets import load_dataset\nfrom transformers import BertTokenizerFast, DataCollatorWithPadding\n\ntask_to_keys = {\n    'cola': ('sentence', None),\n    'mnli': ('premise', 'hypothesis'),\n    'mrpc': ('sentence1', 'sentence2'),\n    'qnli': ('question', 'sentence'),\n    'qqp': ('question1', 'question2'),\n    'rte': ('sentence1', 'sentence2'),\n    'sst2': ('sentence', None),\n    'stsb': ('sentence1', 'sentence2'),\n    'wnli': ('sentence1', 'sentence2'),\n}\n\ndef prepare_dataloaders(cache_dir=data_dir, train_batch_size=32, eval_batch_size=32):\n    tokenizer = BertTokenizerFast.from_pretrained(pretrained_model_name_or_path)\n    sentence1_key, sentence2_key = task_to_keys[task_name]\n    data_collator = DataCollatorWithPadding(tokenizer)\n\n    # used to preprocess the raw data\n    def preprocess_function(examples):\n        # Tokenize the texts\n        args = (\n            (examples[sentence1_key],) if sentence2_key is None else (examples[sentence1_key], examples[sentence2_key])\n        )\n        result = tokenizer(*args, padding=False, max_length=128, truncation=True)\n\n        if 'label' in examples:\n            # In all cases, rename the column to labels because the model will expect that.\n            result['labels'] = examples['label']\n        return result\n\n    raw_datasets = load_dataset('glue', task_name, cache_dir=cache_dir)\n    for key in list(raw_datasets.keys()):\n        if 'test' in key:\n            raw_datasets.pop(key)\n\n    processed_datasets = raw_datasets.map(preprocess_function, batched=True,\n                                          remove_columns=raw_datasets['train'].column_names)\n\n    train_dataset = processed_datasets['train']\n    if task_name == 'mnli':\n        validation_datasets = {\n            'validation_matched': processed_datasets['validation_matched'],\n            'validation_mismatched': processed_datasets['validation_mismatched']\n        }\n    else:\n        validation_datasets = {\n            'validation': processed_datasets['validation']\n        }\n\n    train_dataloader = DataLoader(train_dataset, shuffle=True, collate_fn=data_collator, batch_size=train_batch_size)\n    validation_dataloaders = {\n        val_name: DataLoader(val_dataset, collate_fn=data_collator, batch_size=eval_batch_size) \\\n            for val_name, val_dataset in validation_datasets.items()\n    }\n\n    return train_dataloader, validation_dataloaders\n\n\ntrain_dataloader, validation_dataloaders = prepare_dataloaders()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Training function & evaluation function.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import functools\nimport time\n\nimport torch.nn.functional as F\nfrom datasets import load_metric\nfrom transformers.modeling_outputs import SequenceClassifierOutput\n\n\ndef training(model: torch.nn.Module,\n             optimizer: torch.optim.Optimizer,\n             criterion: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],\n             lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None,\n             max_steps: int = None,\n             max_epochs: int = None,\n             train_dataloader: DataLoader = None,\n             distillation: bool = False,\n             teacher_model: torch.nn.Module = None,\n             distil_func: Callable = None,\n             log_path: str = Path(log_dir) / 'training.log',\n             save_best_model: bool = False,\n             save_path: str = None,\n             evaluation_func: Callable = None,\n             eval_per_steps: int = 1000,\n             device=None):\n\n    assert train_dataloader is not None\n\n    model.train()\n    if teacher_model is not None:\n        teacher_model.eval()\n    current_step = 0\n    best_result = 0\n\n    total_epochs = max_steps // len(train_dataloader) + 1 if max_steps else max_epochs if max_epochs else 3\n    total_steps = max_steps if max_steps else total_epochs * len(train_dataloader)\n\n    print(f'Training {total_epochs} epochs, {total_steps} steps...')\n\n    for current_epoch in range(total_epochs):\n        for batch in train_dataloader:\n            if current_step >= total_steps:\n                return\n            batch.to(device)\n            outputs = model(**batch)\n            loss = outputs.loss\n\n            if distillation:\n                assert teacher_model is not None\n                with torch.no_grad():\n                    teacher_outputs = teacher_model(**batch)\n                distil_loss = distil_func(outputs, teacher_outputs)\n                loss = 0.1 * loss + 0.9 * distil_loss\n\n            loss = criterion(loss, None)\n            optimizer.zero_grad()\n            loss.backward()\n            optimizer.step()\n\n            # per step schedule\n            if lr_scheduler:\n                lr_scheduler.step()\n\n            current_step += 1\n\n            if current_step % eval_per_steps == 0 or current_step % len(train_dataloader) == 0:\n                result = evaluation_func(model) if evaluation_func else None\n                with (log_path).open('a+') as f:\n                    msg = '[{}] Epoch {}, Step {}: {}\\n'.format(time.asctime(time.localtime(time.time())), current_epoch, current_step, result)\n                    f.write(msg)\n                # if it's the best model, save it.\n                if save_best_model and (result is None or best_result < result['default']):\n                    assert save_path is not None\n                    torch.save(model.state_dict(), save_path)\n                    best_result = None if result is None else result['default']\n\n\ndef distil_loss_func(stu_outputs: SequenceClassifierOutput, tea_outputs: SequenceClassifierOutput, encoder_layer_idxs=[]):\n    encoder_hidden_state_loss = []\n    for i, idx in enumerate(encoder_layer_idxs[:-1]):\n        encoder_hidden_state_loss.append(F.mse_loss(stu_outputs.hidden_states[i], tea_outputs.hidden_states[idx]))\n    logits_loss = F.kl_div(F.log_softmax(stu_outputs.logits / 2, dim=-1), F.softmax(tea_outputs.logits / 2, dim=-1), reduction='batchmean') * (2 ** 2)\n\n    distil_loss = 0\n    for loss in encoder_hidden_state_loss:\n        distil_loss += loss\n 
   distil_loss += logits_loss\n    return distil_loss\n\n\ndef evaluation(model: torch.nn.Module, validation_dataloaders: Dict[str, DataLoader] = None, device=None):\n    assert validation_dataloaders is not None\n    training = model.training\n    model.eval()\n\n    is_regression = task_name == 'stsb'\n    metric = load_metric('glue', task_name)\n\n    result = {}\n    default_result = 0\n    for val_name, validation_dataloader in validation_dataloaders.items():\n        for batch in validation_dataloader:\n            batch.to(device)\n            outputs = model(**batch)\n            predictions = outputs.logits.argmax(dim=-1) if not is_regression else outputs.logits.squeeze()\n            metric.add_batch(\n                predictions=predictions,\n                references=batch['labels'],\n            )\n        result[val_name] = metric.compute()\n        default_result += result[val_name].get('f1', result[val_name].get('accuracy', 0))\n    result['default'] = default_result / len(result)\n\n    model.train(training)\n    return result\n\n\nevaluation_func = functools.partial(evaluation, validation_dataloaders=validation_dataloaders, device=device)\n\n\ndef fake_criterion(loss, _):\n    return loss"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Prepare pre-trained model and finetuning on downstream task.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "from torch.optim import Adam\nfrom torch.optim.lr_scheduler import LambdaLR\nfrom transformers import BertForSequenceClassification\n\n\ndef create_pretrained_model():\n    is_regression = task_name == 'stsb'\n    num_labels = 1 if is_regression else (3 if task_name == 'mnli' else 2)\n    model = BertForSequenceClassification.from_pretrained(pretrained_model_name_or_path, num_labels=num_labels)\n    model.bert.config.output_hidden_states = True\n    return model\n\n\ndef create_finetuned_model():\n    finetuned_model = create_pretrained_model()\n    finetuned_model_state_path = Path(model_dir) / 'finetuned_model_state.pth'\n\n    if finetuned_model_state_path.exists():\n        finetuned_model.load_state_dict(torch.load(finetuned_model_state_path, map_location='cpu'))\n        finetuned_model.to(device)\n    elif dev_mode:\n        pass\n    else:\n        steps_per_epoch = len(train_dataloader)\n        training_epochs = 3\n        optimizer = Adam(finetuned_model.parameters(), lr=3e-5, eps=1e-8)\n\n        def lr_lambda(current_step: int):\n            return max(0.0, float(training_epochs * steps_per_epoch - current_step) / float(training_epochs * steps_per_epoch))\n\n        lr_scheduler = LambdaLR(optimizer, lr_lambda)\n        training(finetuned_model, optimizer, fake_criterion, lr_scheduler=lr_scheduler,\n                 max_epochs=training_epochs, train_dataloader=train_dataloader, log_path=log_dir / 'finetuning_on_downstream.log',\n                 save_best_model=True, save_path=finetuned_model_state_path, evaluation_func=evaluation_func, device=device)\n    return finetuned_model\n\n\nfinetuned_model = create_finetuned_model()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "### Pruning\nAccording to experience, it is easier to achieve good results by pruning the attention part and the FFN part in stages.\nOf course, pruning together can also achieve the similar effect, but more parameter adjustment attempts are required.\nSo in this section, we do pruning in stages.\n\nFirst, we prune the attention layer with MovementPruner.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "steps_per_epoch = len(train_dataloader)\n\n# Set training steps/epochs for pruning.\n\nif not dev_mode:\n    total_epochs = 4\n    total_steps = total_epochs * steps_per_epoch\n    warmup_steps = 1 * steps_per_epoch\n    cooldown_steps = 1 * steps_per_epoch\nelse:\n    total_epochs = 1\n    total_steps = 3\n    warmup_steps = 1\n    cooldown_steps = 1\n\n# Initialize evaluator used by MovementPruner.\n\nimport nni\nfrom nni.algorithms.compression.v2.pytorch import TorchEvaluator\n\nmovement_training = functools.partial(training, train_dataloader=train_dataloader,\n                                      log_path=log_dir / 'movement_pruning.log',\n                                      evaluation_func=evaluation_func, device=device)\ntraced_optimizer = nni.trace(Adam)(finetuned_model.parameters(), lr=3e-5, eps=1e-8)\n\ndef lr_lambda(current_step: int):\n    if current_step < warmup_steps:\n        return float(current_step) / warmup_steps\n    return max(0.0, float(total_steps - current_step) / float(total_steps - warmup_steps))\n\ntraced_scheduler = nni.trace(LambdaLR)(traced_optimizer, lr_lambda)\nevaluator = TorchEvaluator(movement_training, traced_optimizer, fake_criterion, traced_scheduler)\n\n# Apply block-soft-movement pruning on attention layers.\n# Note that block sparse is introduced by `sparse_granularity='auto'`, and only support `bert`, `bart`, `t5` right now.\n\nfrom nni.compression.pytorch.pruning import MovementPruner\n\nconfig_list = [{\n    'op_types': ['Linear'],\n    'op_partial_names': ['bert.encoder.layer.{}.attention'.format(i) for i in range(layers_num)],\n    'sparsity': 0.1\n}]\n\npruner = MovementPruner(model=finetuned_model,\n                        config_list=config_list,\n                        evaluator=evaluator,\n                        training_epochs=total_epochs,\n                        training_steps=total_steps,\n                        warm_up_step=warmup_steps,\n                        cool_down_beginning_step=total_steps - cooldown_steps,\n                        regular_scale=10,\n                        movement_mode='soft',\n                        sparse_granularity='auto')\n_, attention_masks = pruner.compress()\npruner.show_pruned_weights()\n\ntorch.save(attention_masks, Path(log_dir) / 'attention_masks.pth')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Load a new finetuned model to do speedup, you can think of this as using the finetuned state to initialize the pruned model weights.\nNote that nni speedup don't support replacing attention module, so here we manully replace the attention module.\n\nIf the head is entire masked, physically prune it and create config_list for FFN pruning.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "attention_pruned_model = create_finetuned_model().to(device)\nattention_masks = torch.load(Path(log_dir) / 'attention_masks.pth')\n\nffn_config_list = []\nlayer_remained_idxs = []\nmodule_list = []\nfor i in range(0, layers_num):\n    prefix = f'bert.encoder.layer.{i}.'\n    value_mask: torch.Tensor = attention_masks[prefix + 'attention.self.value']['weight']\n    head_mask = (value_mask.reshape(heads_num, -1).sum(-1) == 0.)\n    head_idxs = torch.arange(len(head_mask))[head_mask].long().tolist()\n    print(f'layer {i} prune {len(head_idxs)} head: {head_idxs}')\n    if len(head_idxs) != heads_num:\n        attention_pruned_model.bert.encoder.layer[i].attention.prune_heads(head_idxs)\n        module_list.append(attention_pruned_model.bert.encoder.layer[i])\n        # The final ffn weight remaining ratio is the half of the attention weight remaining ratio.\n        # This is just an empirical configuration, you can use any other method to determine this sparsity.\n        sparsity = 1 - (1 - len(head_idxs) / heads_num) * 0.5\n        # here we use a simple sparsity schedule, we will prune ffn in 12 iterations, each iteration prune `sparsity_per_iter`.\n        sparsity_per_iter = 1 - (1 - sparsity) ** (1 / 12)\n        ffn_config_list.append({\n            'op_names': [f'bert.encoder.layer.{len(layer_remained_idxs)}.intermediate.dense'],\n            'sparsity': sparsity_per_iter\n        })\n        layer_remained_idxs.append(i)\n\nattention_pruned_model.bert.encoder.layer = torch.nn.ModuleList(module_list)\ndistil_func = functools.partial(distil_loss_func, encoder_layer_idxs=layer_remained_idxs)"
      ]
    },
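    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Optionally, evaluate the head-pruned model at this point to get a sense of how much accuracy the following distillation retraining needs to recover. The next cell is a small illustrative sketch added to this tutorial; it simply reuses ``evaluation_func`` defined above and is skipped in ``dev_mode``.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Optional sanity check (an illustrative sketch, not part of the original pruning process):\n# evaluate the head-pruned model before distillation retraining to see how much accuracy\n# needs to be recovered. Skipped in dev_mode to keep the documentation build fast.\nif not dev_mode:\n    print(evaluation_func(attention_pruned_model))"
      ]
    },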
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Retrain the attention pruned model with distillation.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "if not dev_mode:\n    total_epochs = 5\n    total_steps = None\n    distillation = True\nelse:\n    total_epochs = 1\n    total_steps = 1\n    distillation = False\n\nteacher_model = create_finetuned_model()\noptimizer = Adam(attention_pruned_model.parameters(), lr=3e-5, eps=1e-8)\n\ndef lr_lambda(current_step: int):\n    return max(0.0, float(total_epochs * steps_per_epoch - current_step) / float(total_epochs * steps_per_epoch))\n\nlr_scheduler = LambdaLR(optimizer, lr_lambda)\nat_model_save_path = log_dir / 'attention_pruned_model_state.pth'\ntraining(attention_pruned_model, optimizer, fake_criterion, lr_scheduler=lr_scheduler, max_epochs=total_epochs,\n         max_steps=total_steps, train_dataloader=train_dataloader, distillation=distillation, teacher_model=teacher_model,\n         distil_func=distil_func, log_path=log_dir / 'retraining.log', save_best_model=True, save_path=at_model_save_path,\n         evaluation_func=evaluation_func, device=device)\n\nif not dev_mode:\n    attention_pruned_model.load_state_dict(torch.load(at_model_save_path))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Iterative pruning FFN with TaylorFOWeightPruner in 12 iterations.\nFinetuning 3000 steps after each pruning iteration, then finetuning 2 epochs after pruning finished.\n\nNNI will support per-step-pruning-schedule in the future, then can use an pruner to replace the following code.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "if not dev_mode:\n    total_epochs = 7\n    total_steps = None\n    taylor_pruner_steps = 1000\n    steps_per_iteration = 3000\n    total_pruning_steps = 36000\n    distillation = True\nelse:\n    total_epochs = 1\n    total_steps = 6\n    taylor_pruner_steps = 2\n    steps_per_iteration = 2\n    total_pruning_steps = 4\n    distillation = False\n\nfrom nni.compression.pytorch.pruning import TaylorFOWeightPruner\nfrom nni.compression.pytorch.speedup import ModelSpeedup\n\ndistil_training = functools.partial(training, train_dataloader=train_dataloader, distillation=distillation,\n                                    teacher_model=teacher_model, distil_func=distil_func, device=device)\ntraced_optimizer = nni.trace(Adam)(attention_pruned_model.parameters(), lr=3e-5, eps=1e-8)\nevaluator = TorchEvaluator(distil_training, traced_optimizer, fake_criterion)\n\ncurrent_step = 0\nbest_result = 0\ninit_lr = 3e-5\n\ndummy_input = torch.rand(8, 128, 768).to(device)\n\nattention_pruned_model.train()\nfor current_epoch in range(total_epochs):\n    for batch in train_dataloader:\n        if total_steps and current_step >= total_steps:\n            break\n        # pruning with TaylorFOWeightPruner & reinitialize optimizer\n        if current_step % steps_per_iteration == 0 and current_step < total_pruning_steps:\n            check_point = attention_pruned_model.state_dict()\n            pruner = TaylorFOWeightPruner(attention_pruned_model, ffn_config_list, evaluator, taylor_pruner_steps)\n            _, ffn_masks = pruner.compress()\n            renamed_ffn_masks = {}\n            # rename the masks keys, because we only speedup the bert.encoder\n            for model_name, targets_mask in ffn_masks.items():\n                renamed_ffn_masks[model_name.split('bert.encoder.')[1]] = targets_mask\n            pruner._unwrap_model()\n            attention_pruned_model.load_state_dict(check_point)\n            ModelSpeedup(attention_pruned_model.bert.encoder, dummy_input, renamed_ffn_masks).speedup_model()\n            optimizer = Adam(attention_pruned_model.parameters(), lr=init_lr)\n\n        batch.to(device)\n        # manually schedule lr\n        for params_group in optimizer.param_groups:\n            params_group['lr'] = (1 - current_step / (total_epochs * steps_per_epoch)) * init_lr\n\n        outputs = attention_pruned_model(**batch)\n        loss = outputs.loss\n\n        # distillation\n        if distillation:\n            assert teacher_model is not None\n            with torch.no_grad():\n                teacher_outputs = teacher_model(**batch)\n            distil_loss = distil_func(outputs, teacher_outputs)\n            loss = 0.1 * loss + 0.9 * distil_loss\n\n        optimizer.zero_grad()\n        loss.backward()\n        optimizer.step()\n\n        current_step += 1\n\n        if current_step % 1000 == 0 or current_step % len(train_dataloader) == 0:\n            result = evaluation_func(attention_pruned_model)\n            with (log_dir / 'ffn_pruning.log').open('a+') as f:\n                msg = '[{}] Epoch {}, Step {}: {}\\n'.format(time.asctime(time.localtime(time.time())),\n                                                            current_epoch, current_step, result)\n                f.write(msg)\n            if current_step >= total_pruning_steps and best_result < result['default']:\n                torch.save(attention_pruned_model, log_dir / 'best_model.pth')\n                best_result = result['default']"
      ]
    },
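    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Before looking at the results, the next cell is a minimal latency-measurement sketch (an illustrative addition, not part of the original pruning process). It times one full pass over the validation dataloaders with the pruned model; the helper name ``measure_latency`` is our own, and the numbers will vary with hardware and PyTorch version.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# A minimal latency-measurement sketch (illustrative; `measure_latency` is our own helper name).\n# It reuses `time`, `torch`, `Dict` and `DataLoader` imported earlier in this notebook.\ndef measure_latency(model: torch.nn.Module, dataloaders: Dict[str, DataLoader]) -> float:\n    model.eval()\n    if torch.cuda.is_available():\n        torch.cuda.synchronize()\n    start = time.time()\n    with torch.no_grad():\n        for dataloader in dataloaders.values():\n            for batch in dataloader:\n                batch.to(device)\n                model(**batch)\n    if torch.cuda.is_available():\n        torch.cuda.synchronize()\n    return time.time() - start\n\nif not dev_mode:\n    latency = measure_latency(attention_pruned_model, validation_dataloaders)\n    print(f'validation latency: {latency:.2f}s')"
      ]
    },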
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Result\nThe speedup is test on the entire validation dataset with batch size 32 on A100.\nWe test under two pytorch version and found the latency varying widely.\n\nSetting 1: pytorch 1.12.1\n\nSetting 2: pytorch 1.10.0\n\n.. list-table:: Prune Bert-base-uncased on MNLI\n    :header-rows: 1\n    :widths: auto\n\n    * - Attention Pruning Method\n      - FFN Pruning Method\n      - Total Sparsity\n      - Accuracy\n      - Acc. Drop\n      - Speedup (S1)\n      - Speedup (S2)\n    * -\n      -\n      - 0%\n      - 84.73 / 84.63\n      - +0.0 / +0.0\n      - 12.56s (x1.00)\n      - 4.05s (x1.00)\n    * - `movement-pruner` (soft, sparsity=0.1, regular_scale=5)\n      - `taylor-fo-weight-pruner`\n      - 51.39%\n      - 84.25 / 84.96\n      - -0.48 / +0.33\n      - 6.85s (x1.83)\n      - 2.7s (x1.50)\n    * - `movement-pruner` (soft, sparsity=0.1, regular_scale=10)\n      - `taylor-fo-weight-pruner`\n      - 66.67%\n      - 83.98 / 83.75\n      - -0.75 / -0.88\n      - 4.73s (x2.66)\n      - 2.16s (x1.86)\n    * - `movement-pruner` (soft, sparsity=0.1, regular_scale=20)\n      - `taylor-fo-weight-pruner`\n      - 77.78%\n      - 83.02 / 83.06\n      - -1.71 / -1.57\n      - 3.35s (x3.75)\n      - 1.72s (x2.35)\n    * - `movement-pruner` (soft, sparsity=0.1, regular_scale=30)\n      - `taylor-fo-weight-pruner`\n      - 87.04%\n      - 81.24 / 80.99\n      - -3.49 / -3.64\n      - 2.19s (x5.74)\n      - 1.31s (x3.09)\n\n"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.8.13"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}