.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "tutorials/pruning_bert_glue.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        Click :ref:`here <sphx_glr_download_tutorials_pruning_bert_glue.py>` to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_tutorials_pruning_bert_glue.py:


Pruning Transformer with NNI
============================

Workable Pruning Process
------------------------

Here we show an effective transformer pruning process that the NNI team has tried; users can rely on NNI to discover even better processes.

The entire pruning process can be divided into the following steps:

1. Finetune the pre-trained model on the downstream task. In our experience, pruning the finetuned model
   gives a better final performance than pruning the pre-trained model directly.
   The finetuned model obtained in this step also serves as the teacher model for the distillation training used later.
2. Prune the attention layers first. Here we apply block-sparse pruning to the attention layer weights
   and directly prune a head (condense the weight) if it is fully masked.
   If a head is only partially masked, we do not prune it and recover its weight instead.
3. Retrain the head-pruned model with distillation to recover the model precision before pruning the FFN layers.
4. Prune the FFN layers. Here we prune the output channels of the first FFN layer;
   the input channels of the second FFN layer are then pruned as a consequence of that output-channel pruning.
5. Retrain the final pruned model with distillation.

During this process we gained the following experience:

* We use :ref:`movement-pruner` in step 2 and :ref:`taylor-fo-weight-pruner` in step 4.
  :ref:`movement-pruner` performs well on attention layers, and :ref:`taylor-fo-weight-pruner` performs well on FFN layers.
  Both are gradient-based pruning algorithms. We also tried weight-based pruning algorithms such as :ref:`l1-norm-pruner`,
  but they did not seem to work well in this scenario.
* Distillation is a good way to recover model precision. In terms of results, we usually gain 1~2% accuracy
  when pruning BERT on the MNLI task. The distillation objective used throughout this tutorial is sketched below.
* It is necessary to increase the sparsity gradually rather than reaching a very high sparsity all at once.
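For reference, the distillation objective implemented in the training loops below mixes the task loss with a
temperature-scaled KL term. With student logits :math:`z_s`, teacher logits :math:`z_t`, and temperature
:math:`T = 2`, the loss used in the code is

.. math::

    \mathcal{L} = 0.1 \cdot \mathcal{L}_{\text{task}}
    + 0.9 \cdot T^2 \cdot \mathrm{KL}\!\left(\mathrm{softmax}(z_t / T) \,\big\|\, \mathrm{softmax}(z_s / T)\right)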
Experiment
----------

Preparation
^^^^^^^^^^^

Please set ``dev_mode`` to ``False`` to run this tutorial in full. ``dev_mode`` is ``True`` by default only for generating the documentation.

The complete pruning process takes about 8 hours on one A100.

.. GENERATED FROM PYTHON SOURCE LINES 41-44

.. code-block:: default


    dev_mode = True


.. GENERATED FROM PYTHON SOURCE LINES 45-46

Some basic settings.

.. GENERATED FROM PYTHON SOURCE LINES 46-72

.. code-block:: default


    from pathlib import Path
    from typing import Callable

    pretrained_model_name_or_path = 'bert-base-uncased'
    task_name = 'mnli'
    experiment_id = 'pruning_bert'

    # heads_num and layers_num should align with pretrained_model_name_or_path
    heads_num = 12
    layers_num = 12

    # used to save the experiment log
    log_dir = Path(f'./pruning_log/{pretrained_model_name_or_path}/{task_name}/{experiment_id}')
    log_dir.mkdir(parents=True, exist_ok=True)

    # used to save the finetuned model and share it between different experiments with the same pretrained_model_name_or_path and task_name
    model_dir = Path(f'./models/{pretrained_model_name_or_path}/{task_name}')
    model_dir.mkdir(parents=True, exist_ok=True)

    from transformers import set_seed
    set_seed(1024)

    import torch
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


.. GENERATED FROM PYTHON SOURCE LINES 73-75

The function used to create the dataloaders. Note that 'mnli' has two evaluation datasets (matched and mismatched).
If ``teacher_model`` is set, the whole dataset is run through the teacher model to obtain the ``teacher_logits`` used for distillation.

.. GENERATED FROM PYTHON SOURCE LINES 75-157

.. code-block:: default


    from torch.utils.data import DataLoader

    from datasets import load_dataset
    from transformers import BertTokenizerFast, DataCollatorWithPadding

    task_to_keys = {
        'cola': ('sentence', None),
        'mnli': ('premise', 'hypothesis'),
        'mrpc': ('sentence1', 'sentence2'),
        'qnli': ('question', 'sentence'),
        'qqp': ('question1', 'question2'),
        'rte': ('sentence1', 'sentence2'),
        'sst2': ('sentence', None),
        'stsb': ('sentence1', 'sentence2'),
        'wnli': ('sentence1', 'sentence2'),
    }

    def prepare_data(cache_dir='./data', train_batch_size=32, eval_batch_size=32,
                     teacher_model: torch.nn.Module = None):
        tokenizer = BertTokenizerFast.from_pretrained(pretrained_model_name_or_path)
        sentence1_key, sentence2_key = task_to_keys[task_name]
        data_collator = DataCollatorWithPadding(tokenizer)

        # used to preprocess the raw data
        def preprocess_function(examples):
            # Tokenize the texts
            args = (
                (examples[sentence1_key],) if sentence2_key is None else (examples[sentence1_key], examples[sentence2_key])
            )
            result = tokenizer(*args, padding=False, max_length=128, truncation=True)

            if 'label' in examples:
                # In all cases, rename the column to labels because the model will expect that.
                result['labels'] = examples['label']
            return result

        raw_datasets = load_dataset('glue', task_name, cache_dir=cache_dir)
        for key in list(raw_datasets.keys()):
            if 'test' in key:
                raw_datasets.pop(key)

        processed_datasets = raw_datasets.map(preprocess_function, batched=True,
                                              remove_columns=raw_datasets['train'].column_names)

        # if a teacher model is given, add 'teacher_logits' to the datasets that have 'labels'.
        # 'teacher_logits' is used for distillation and avoids recomputing the teacher outputs during training.
        if teacher_model:
            teacher_model_training = teacher_model.training
            teacher_model.eval()
            model_device = next(teacher_model.parameters()).device

            def add_teacher_logits(examples):
                result = {k: v for k, v in examples.items()}
                samples = data_collator(result).to(model_device)
                if 'labels' in samples:
                    with torch.no_grad():
                        logits = teacher_model(**samples).logits.tolist()
                    result['teacher_logits'] = logits
                return result

            processed_datasets = processed_datasets.map(add_teacher_logits, batched=True,
                                                        batch_size=train_batch_size)
            teacher_model.train(teacher_model_training)

        train_dataset = processed_datasets['train']
        validation_dataset = processed_datasets['validation_matched' if task_name == 'mnli' else 'validation']
        validation_dataset2 = processed_datasets['validation_mismatched'] if task_name == 'mnli' else None

        train_dataloader = DataLoader(train_dataset,
                                      shuffle=True,
                                      collate_fn=data_collator,
                                      batch_size=train_batch_size)
        validation_dataloader = DataLoader(validation_dataset,
                                           collate_fn=data_collator,
                                           batch_size=eval_batch_size)
        validation_dataloader2 = DataLoader(validation_dataset2,
                                            collate_fn=data_collator,
                                            batch_size=eval_batch_size) if task_name == 'mnli' else None

        return train_dataloader, validation_dataloader, validation_dataloader2
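As a usage sketch (not part of the original script): calling ``prepare_data()`` without a teacher returns plain
dataloaders, while passing a finetuned teacher adds a ``teacher_logits`` column that the distillation loss consumes.
The ``finetuned_teacher`` name below is purely illustrative and stands for the finetuned model created later in this tutorial.

.. code-block:: default


    # plain dataloaders, e.g. for finetuning and evaluation
    train_dataloader, validation_dataloader, validation_dataloader2 = prepare_data()

    # dataloaders whose batches additionally carry 'teacher_logits', for distillation retraining;
    # `finetuned_teacher` stands for the finetuned model created later in this tutorial.
    # train_dataloader, validation_dataloader, validation_dataloader2 = prepare_data(teacher_model=finetuned_teacher)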
.. GENERATED FROM PYTHON SOURCE LINES 158-159

Training function & evaluation function.

.. GENERATED FROM PYTHON SOURCE LINES 159-258

.. code-block:: default


    import time

    import torch.nn.functional as F

    from datasets import load_metric

    def training(train_dataloader: DataLoader,
                 model: torch.nn.Module,
                 optimizer: torch.optim.Optimizer,
                 criterion: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
                 lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None,
                 max_steps: int = None,
                 max_epochs: int = None,
                 save_best_model: bool = False,
                 save_path: str = None,
                 log_path: Path = Path(log_dir) / 'training.log',
                 distillation: bool = False,
                 evaluation_func=None):
        model.train()
        current_step = 0
        best_result = 0

        for current_epoch in range(max_epochs if max_epochs else 1):
            for batch in train_dataloader:
                batch.to(device)
                teacher_logits = batch.pop('teacher_logits', None)
                optimizer.zero_grad()
                outputs = model(**batch)
                loss = outputs.loss

                if distillation:
                    assert teacher_logits is not None
                    distil_loss = F.kl_div(F.log_softmax(outputs.logits / 2, dim=-1),
                                           F.softmax(teacher_logits / 2, dim=-1), reduction='batchmean') * (2 ** 2)
                    loss = 0.1 * loss + 0.9 * distil_loss

                loss = criterion(loss, None)
                loss.backward()
                optimizer.step()

                if lr_scheduler:
                    lr_scheduler.step()
                current_step += 1

                # evaluate every 1000 steps and at the end of each epoch
                if current_step % 1000 == 0 or current_step % len(train_dataloader) == 0:
                    result = evaluation_func(model) if evaluation_func else None
                    with (log_path).open('a+') as f:
                        msg = '[{}] Epoch {}, Step {}: {}\n'.format(time.asctime(time.localtime(time.time())), current_epoch, current_step, result)
                        f.write(msg)
                    # if it's the best model so far, save it.
                    if save_best_model and best_result < result['default']:
                        assert save_path is not None
                        torch.save(model.state_dict(), save_path)
                        best_result = result['default']

                if max_steps and current_step >= max_steps:
                    return

    def evaluation(validation_dataloader: DataLoader,
                   validation_dataloader2: DataLoader,
                   model: torch.nn.Module):
        training = model.training
        model.eval()
        is_regression = task_name == 'stsb'
        metric = load_metric('glue', task_name)

        for batch in validation_dataloader:
            batch.pop('teacher_logits', None)
            batch.to(device)
            outputs = model(**batch)
            predictions = outputs.logits.argmax(dim=-1) if not is_regression else outputs.logits.squeeze()
            metric.add_batch(
                predictions=predictions,
                references=batch['labels'],
            )
        result = metric.compute()

        if validation_dataloader2:
            for batch in validation_dataloader2:
                batch.pop('teacher_logits', None)
                batch.to(device)
                outputs = model(**batch)
                predictions = outputs.logits.argmax(dim=-1) if not is_regression else outputs.logits.squeeze()
                metric.add_batch(
                    predictions=predictions,
                    references=batch['labels'],
                )
            result = {'matched': result, 'mismatched': metric.compute()}
            result['default'] = (result['matched']['accuracy'] + result['mismatched']['accuracy']) / 2
        else:
            result['default'] = result.get('f1', result.get('accuracy', None))

        model.train(training)
        return result

    # pass the HuggingFace native loss through unchanged
    def fake_criterion(outputs, targets):
        return outputs
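The helpers above are wired together in the same way for every retraining stage. As a purely illustrative sketch
(``student_model`` is an assumption standing for whichever pruned model is being retrained, ``finetuned_model`` is
the teacher created in the next block, and the epochs/learning rate differ per stage), a distillation retraining
run looks roughly like this:

.. code-block:: default


    import functools
    from torch.optim import Adam

    # dataloaders whose batches carry 'teacher_logits' produced by the finetuned teacher
    train_dl, val_dl, val_dl2 = prepare_data(teacher_model=finetuned_model)
    evaluate = functools.partial(evaluation, val_dl, val_dl2)

    # `student_model` stands for the pruned model being retrained at this stage
    optimizer = Adam(student_model.parameters(), lr=3e-5, eps=1e-8)
    training(train_dl, student_model, optimizer, fake_criterion,
             max_epochs=1, distillation=True, evaluation_func=evaluate)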
.. GENERATED FROM PYTHON SOURCE LINES 259-260

Prepare the pre-trained model and finetune it on the downstream task.

.. GENERATED FROM PYTHON SOURCE LINES 260-299

.. code-block:: default


    import functools

    from torch.optim import Adam
    from torch.optim.lr_scheduler import LambdaLR
    from transformers import BertForSequenceClassification

    def create_pretrained_model():
        is_regression = task_name == 'stsb'
        num_labels = 1 if is_regression else (3 if task_name == 'mnli' else 2)
        return BertForSequenceClassification.from_pretrained(pretrained_model_name_or_path, num_labels=num_labels)

    def create_finetuned_model():
        pretrained_model = create_pretrained_model().to(device)

        train_dataloader, validation_dataloader, validation_dataloader2 = prepare_data()
        evaluation_func = functools.partial(evaluation, validation_dataloader, validation_dataloader2)
        steps_per_epoch = len(train_dataloader)
        training_epochs = 3

        finetuned_model_state_path = Path(model_dir) / 'finetuned_model_state.pth'

        if finetuned_model_state_path.exists():
            pretrained_model.load_state_dict(torch.load(finetuned_model_state_path))
        elif dev_mode:
            pass
        else:
            optimizer = Adam(pretrained_model.parameters(), lr=3e-5, eps=1e-8)

            def lr_lambda(current_step: int):
                return max(0.0, float(training_epochs * steps_per_epoch - current_step) / float(training_epochs * steps_per_epoch))

            lr_scheduler = LambdaLR(optimizer, lr_lambda)
            training(train_dataloader, pretrained_model, optimizer, fake_criterion, lr_scheduler=lr_scheduler,
                     max_epochs=training_epochs, save_best_model=True, save_path=finetuned_model_state_path,
                     evaluation_func=evaluation_func)

        return pretrained_model

    finetuned_model = create_finetuned_model()


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight']
    - This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
    - This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
    Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.weight', 'classifier.bias']
    You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
    Reusing dataset glue (./data/glue/mnli/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)
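The FFN pruning loop shown next relies on several objects prepared beforehand: ``attention_pruned_model`` (the model
after attention-head pruning and distillation, steps 2-3 above), an NNI config list targeting the first FFN linear
layer of every encoder layer, an NNI ``evaluator`` that :ref:`taylor-fo-weight-pruner` uses to run a few training steps,
and some step/learning-rate bookkeeping. The sketch below is a minimal, illustrative reconstruction of that setup:
every concrete number and the ``attention_pruned_model`` placeholder are assumptions, not the exact settings behind the
reported results.

.. code-block:: default


    import functools
    from nni.compression.pytorch.pruning import TaylorFOWeightPruner
    from nni.compression.pytorch.speedup import ModelSpeedup

    # placeholder: in the full tutorial this is the head-pruned, distilled model from steps 2-3
    attention_pruned_model = finetuned_model

    # prune the output channels of the first FFN linear layer in every encoder layer
    ffn_config_list = [{
        'op_types': ['Linear'],
        'op_names': [f'bert.encoder.layer.{i}.intermediate.dense' for i in range(layers_num)],
        'sparsity': 0.5,  # illustrative sparsity per pruning iteration
    }]

    train_dataloader, validation_dataloader, validation_dataloader2 = prepare_data(teacher_model=finetuned_model)
    evaluation_func = functools.partial(evaluation, validation_dataloader, validation_dataloader2)

    steps_per_epoch = len(train_dataloader)
    total_epochs = 5                                  # illustrative
    total_steps = total_epochs * steps_per_epoch
    steps_per_iteration = 2000                        # prune once every 2000 steps ...
    total_pruning_steps = 12 * steps_per_iteration    # ... for 12 pruning iterations
    taylor_pruner_steps = 1000                        # training steps used to collect Taylor importance scores
    init_lr = 3e-5

    # only the encoder is sped up below, so the dummy input is the encoder's hidden states
    dummy_input = torch.rand(8, 128, 768).to(device)

    # `evaluator` wraps the training loop so that TaylorFOWeightPruner can collect gradients;
    # see the NNI documentation on evaluators for how to construct one for your training function.

    current_step, best_result = 0, 0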
.. code-block:: default


    for current_epoch in range(total_epochs):
        for batch in train_dataloader:
            if current_step >= total_steps:
                break
            # pruning 12 times
            if current_step % steps_per_iteration == 0 and current_step < total_pruning_steps:
                check_point = attention_pruned_model.state_dict()
                pruner = TaylorFOWeightPruner(attention_pruned_model, ffn_config_list, evaluator, taylor_pruner_steps)
                _, ffn_masks = pruner.compress()
                renamed_ffn_masks = {}
                # rename the mask keys, because we only speed up the bert.encoder
                for model_name, targets_mask in ffn_masks.items():
                    renamed_ffn_masks[model_name.split('bert.encoder.')[1]] = targets_mask
                pruner._unwrap_model()
                attention_pruned_model.load_state_dict(check_point)
                ModelSpeedup(attention_pruned_model.bert.encoder, dummy_input, renamed_ffn_masks).speedup_model()
                optimizer = Adam(attention_pruned_model.parameters(), lr=init_lr)

            batch.to(device)
            teacher_logits = batch.pop('teacher_logits', None)
            optimizer.zero_grad()

            # manually schedule lr
            for params_group in optimizer.param_groups:
                params_group['lr'] = (1 - current_step / (total_epochs * steps_per_epoch)) * init_lr

            outputs = attention_pruned_model(**batch)
            loss = outputs.loss

            # distillation
            if teacher_logits is not None:
                distil_loss = F.kl_div(F.log_softmax(outputs.logits / 2, dim=-1),
                                       F.softmax(teacher_logits / 2, dim=-1), reduction='batchmean') * (2 ** 2)
                loss = 0.1 * loss + 0.9 * distil_loss

            loss.backward()
            optimizer.step()
            current_step += 1

            if current_step % 1000 == 0 or current_step % len(train_dataloader) == 0:
                result = evaluation_func(attention_pruned_model)
                with (log_dir / 'ffn_pruning.log').open('a+') as f:
                    msg = '[{}] Epoch {}, Step {}: {}\n'.format(time.asctime(time.localtime(time.time())),
                                                                current_epoch, current_step, result)
                    f.write(msg)
                if current_step >= total_pruning_steps and best_result < result['default']:
                    torch.save(attention_pruned_model, log_dir / 'best_model.pth')
                    best_result = result['default']


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Did not bind any model, no need to unbind model.
    no multi-dimension masks found.
    /home/nishang/anaconda3/envs/nni-dev/lib/python3.7/site-packages/torch/_tensor.py:1083: UserWarning: The .grad attribute of a Tensor that is not a leaf Tensor is being accessed. Its .grad attribute won't be populated during autograd.backward(). If you indeed want the .grad field to be populated for a non-leaf Tensor, use .retain_grad() on the non-leaf Tensor. If you access the non-leaf Tensor by mistake, make sure you access the leaf Tensor instead. See github.com/pytorch/pytorch/pull/30531 for more informations. (Triggered internally at aten/src/ATen/core/TensorBody.h:477.)
      return self._grad
    Did not bind any model, no need to unbind model.
    no multi-dimension masks found.
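After the loop finishes, the best checkpoint it saved can be reloaded for a final check. A minimal sketch, using
``evaluation_func`` as defined earlier (the whole module object was saved with ``torch.save``, so ``torch.load``
restores it directly):

.. code-block:: default


    # reload the best pruned model saved by the loop above and evaluate it once more
    best_model = torch.load(log_dir / 'best_model.pth')
    final_result = evaluation_func(best_model)
    print(final_result)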
.. GENERATED FROM PYTHON SOURCE LINES 509-564

Result
------

The speedup is tested on the entire validation dataset with batch size 32 on an A100.
We tested under two PyTorch versions and found that the latency varies widely between them.

Setting 1: pytorch 1.12.1

Setting 2: pytorch 1.10.0

.. list-table:: Prune Bert-base-uncased on MNLI
    :header-rows: 1
    :widths: auto

    * - Attention Pruning Method
      - FFN Pruning Method
      - Total Sparsity
      - Accuracy (matched / mismatched)
      - Acc. Drop
      - Latency / Speedup (Setting 1)
      - Latency / Speedup (Setting 2)
    * -
      -
      - 0%
      - 84.73 / 84.63
      - +0.0 / +0.0
      - 12.56s (x1.00)
      - 4.05s (x1.00)
    * - :ref:`movement-pruner` (soft, th=0.1, lambda=5)
      - :ref:`taylor-fo-weight-pruner`
      - 51.39%
      - 84.25 / 84.96
      - -0.48 / +0.33
      - 6.85s (x1.83)
      - 2.7s (x1.50)
    * - :ref:`movement-pruner` (soft, th=0.1, lambda=10)
      - :ref:`taylor-fo-weight-pruner`
      - 66.67%
      - 83.98 / 83.75
      - -0.75 / -0.88
      - 4.73s (x2.66)
      - 2.16s (x1.86)
    * - :ref:`movement-pruner` (soft, th=0.1, lambda=20)
      - :ref:`taylor-fo-weight-pruner`
      - 77.78%
      - 83.02 / 83.06
      - -1.71 / -1.57
      - 3.35s (x3.75)
      - 1.72s (x2.35)
    * - :ref:`movement-pruner` (soft, th=0.1, lambda=30)
      - :ref:`taylor-fo-weight-pruner`
      - 87.04%
      - 81.24 / 80.99
      - -3.49 / -3.64
      - 2.19s (x5.74)
      - 1.31s (x3.09)


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 0 minutes 27.206 seconds)


.. _sphx_glr_download_tutorials_pruning_bert_glue.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: pruning_bert_glue.py <pruning_bert_glue.py>`

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: pruning_bert_glue.ipynb <pruning_bert_glue.ipynb>`

.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_