Unverified Commit 808b0655 authored by J-shang, committed by GitHub

[Bugbash] update doc (#5075)

parent bb8114a4
docs/img/compression_pipeline.png (image replaced: 84.2 KB → 38.3 KB)
...@@ -170,13 +170,13 @@ Tutorials ...@@ -170,13 +170,13 @@ Tutorials
.. only:: html .. only:: html
.. image:: /tutorials/images/thumb/sphx_glr_pruning_bert_glue_thumb.png .. image:: /tutorials/images/thumb/sphx_glr_pruning_bert_glue_thumb.png
:alt: Pruning Transformer with NNI :alt: Pruning Bert on Task MNLI
:ref:`sphx_glr_tutorials_pruning_bert_glue.py` :ref:`sphx_glr_tutorials_pruning_bert_glue.py`
.. raw:: html .. raw:: html
<div class="sphx-glr-thumbnail-title">Pruning Transformer with NNI</div> <div class="sphx-glr-thumbnail-title">Pruning Bert on Task MNLI</div>
</div> </div>
......
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"\n# Pruning Transformer with NNI\n\n## Workable Pruning Process\n\nHere we show an effective transformer pruning process that NNI team has tried, and users can use NNI to discover better processes.\n\nThe entire pruning process can be divided into the following steps:\n\n1. Finetune the pre-trained model on the downstream task. From our experience,\n the final performance of pruning on the finetuned model is better than pruning directly on the pre-trained model.\n At the same time, the finetuned model obtained in this step will also be used as the teacher model for the following\n distillation training.\n2. Pruning the attention layer at first. Here we apply block-sparse on attention layer weight,\n and directly prune the head (condense the weight) if the head was fully masked.\n If the head was partially masked, we will not prune it and recover its weight.\n3. Retrain the head-pruned model with distillation. Recover the model precision before pruning FFN layer.\n4. Pruning the FFN layer. Here we apply the output channels pruning on the 1st FFN layer,\n and the 2nd FFN layer input channels will be pruned due to the pruning of 1st layer output channels.\n5. Retrain the final pruned model with distillation.\n\nDuring the process of pruning transformer, we gained some of the following experiences:\n\n* We using `movement-pruner` in step 2 and `taylor-fo-weight-pruner` in step 4. `movement-pruner` has good performance on attention layers,\n and `taylor-fo-weight-pruner` method has good performance on FFN layers. These two pruners are all some kinds of gradient-based pruning algorithms,\n we also try weight-based pruning algorithms like `l1-norm-pruner`, but it doesn't seem to work well in this scenario.\n* Distillation is a good way to recover model precision. In terms of results, usually 1~2% improvement in accuracy can be achieved when we prune bert on mnli task.\n* It is necessary to gradually increase the sparsity rather than reaching a very high sparsity all at once.\n\n## Experiment\n\n### Preparation\nPlease set ``dev_mode`` to ``False`` to run this tutorial. Here ``dev_mode`` is ``True`` by default is for generating documents.\n\nThe complete pruning process takes about 8 hours on one A100.\n" "\n# Pruning Bert on Task MNLI\n\n## Workable Pruning Process\n\nHere we show an effective transformer pruning process that NNI team has tried, and users can use NNI to discover better processes.\n\nThe entire pruning process can be divided into the following steps:\n\n1. Finetune the pre-trained model on the downstream task. From our experience,\n the final performance of pruning on the finetuned model is better than pruning directly on the pre-trained model.\n At the same time, the finetuned model obtained in this step will also be used as the teacher model for the following\n distillation training.\n2. Pruning the attention layer at first. Here we apply block-sparse on attention layer weight,\n and directly prune the head (condense the weight) if the head was fully masked.\n If the head was partially masked, we will not prune it and recover its weight.\n3. Retrain the head-pruned model with distillation. Recover the model precision before pruning FFN layer.\n4. Pruning the FFN layer. Here we apply the output channels pruning on the 1st FFN layer,\n and the 2nd FFN layer input channels will be pruned due to the pruning of 1st layer output channels.\n5. 
Retrain the final pruned model with distillation.\n\nDuring the process of pruning transformer, we gained some of the following experiences:\n\n* We using `movement-pruner` in step 2 and `taylor-fo-weight-pruner` in step 4. `movement-pruner` has good performance on attention layers,\n and `taylor-fo-weight-pruner` method has good performance on FFN layers. These two pruners are all some kinds of gradient-based pruning algorithms,\n we also try weight-based pruning algorithms like `l1-norm-pruner`, but it doesn't seem to work well in this scenario.\n* Distillation is a good way to recover model precision. In terms of results, usually 1~2% improvement in accuracy can be achieved when we prune bert on mnli task.\n* It is necessary to gradually increase the sparsity rather than reaching a very high sparsity all at once.\n\n## Experiment\n\nThe complete pruning process will take about 8 hours on one A100.\n\n### Preparation\n\nThis section is mainly to get a finetuned model on the downstream task.\nIf you are familiar with how to finetune Bert on GLUE dataset, you can skip this section.\n\n<div class=\"alert alert-info\"><h4>Note</h4><p>Please set ``dev_mode`` to ``False`` to run this tutorial. Here ``dev_mode`` is ``True`` by default is for generating documents.</p></div>\n"
] ]
}, },
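The distillation objective referred to above (and used in the training cells below) is a temperature-scaled KL divergence between student and teacher logits, mixed with the task loss at a 0.1/0.9 ratio. A minimal, self-contained sketch with temperature fixed to 2 as in this tutorial; the helper name and tensor shapes are only illustrative:

import torch
import torch.nn.functional as F

def distillation_objective(task_loss, student_logits, teacher_logits, T=2.0, alpha=0.9):
    # KL divergence between softened distributions, scaled by T**2 to keep gradient magnitudes stable
    distil_loss = F.kl_div(F.log_softmax(student_logits / T, dim=-1),
                           F.softmax(teacher_logits / T, dim=-1),
                           reduction='batchmean') * (T ** 2)
    return (1 - alpha) * task_loss + alpha * distil_loss

# toy usage: a batch of 4 examples with 3 MNLI classes
student_logits, teacher_logits = torch.randn(4, 3), torch.randn(4, 3)
print(distillation_objective(torch.tensor(1.2), student_logits, teacher_logits))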
{ {
...@@ -44,14 +44,14 @@ ...@@ -44,14 +44,14 @@
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"from pathlib import Path\nfrom typing import Callable\n\npretrained_model_name_or_path = 'bert-base-uncased'\ntask_name = 'mnli'\nexperiment_id = 'pruning_bert'\n\n# heads_num and layers_num should align with pretrained_model_name_or_path\nheads_num = 12\nlayers_num = 12\n\n# used to save the experiment log\nlog_dir = Path(f'./pruning_log/{pretrained_model_name_or_path}/{task_name}/{experiment_id}')\nlog_dir.mkdir(parents=True, exist_ok=True)\n\n# used to save the finetuned model and share between different experiemnts with same pretrained_model_name_or_path and task_name\nmodel_dir = Path(f'./models/{pretrained_model_name_or_path}/{task_name}')\nmodel_dir.mkdir(parents=True, exist_ok=True)\n\nfrom transformers import set_seed\nset_seed(1024)\n\nimport torch\ndevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')" "from pathlib import Path\nfrom typing import Callable, Dict\n\npretrained_model_name_or_path = 'bert-base-uncased'\ntask_name = 'mnli'\nexperiment_id = 'pruning_bert_mnli'\n\n# heads_num and layers_num should align with pretrained_model_name_or_path\nheads_num = 12\nlayers_num = 12\n\n# used to save the experiment log\nlog_dir = Path(f'./pruning_log/{pretrained_model_name_or_path}/{task_name}/{experiment_id}')\nlog_dir.mkdir(parents=True, exist_ok=True)\n\n# used to save the finetuned model and share between different experiemnts with same pretrained_model_name_or_path and task_name\nmodel_dir = Path(f'./models/{pretrained_model_name_or_path}/{task_name}')\nmodel_dir.mkdir(parents=True, exist_ok=True)\n\n# used to save GLUE data\ndata_dir = Path(f'./data')\ndata_dir.mkdir(parents=True, exist_ok=True)\n\n# set seed\nfrom transformers import set_seed\nset_seed(1024)\n\nimport torch\ndevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')"
] ]
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"The function used to create dataloaders, note that 'mnli' has two evaluation dataset.\nIf teacher_model is set, will run all dataset on teacher model to get the 'teacher_logits' for distillation.\n\n" "Create dataloaders.\n\n"
] ]
}, },
{ {
...@@ -62,7 +62,7 @@ ...@@ -62,7 +62,7 @@
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"from torch.utils.data import DataLoader\n\nfrom datasets import load_dataset\nfrom transformers import BertTokenizerFast, DataCollatorWithPadding\n\ntask_to_keys = {\n 'cola': ('sentence', None),\n 'mnli': ('premise', 'hypothesis'),\n 'mrpc': ('sentence1', 'sentence2'),\n 'qnli': ('question', 'sentence'),\n 'qqp': ('question1', 'question2'),\n 'rte': ('sentence1', 'sentence2'),\n 'sst2': ('sentence', None),\n 'stsb': ('sentence1', 'sentence2'),\n 'wnli': ('sentence1', 'sentence2'),\n}\n\ndef prepare_data(cache_dir='./data', train_batch_size=32, eval_batch_size=32,\n teacher_model: torch.nn.Module = None):\n tokenizer = BertTokenizerFast.from_pretrained(pretrained_model_name_or_path)\n sentence1_key, sentence2_key = task_to_keys[task_name]\n data_collator = DataCollatorWithPadding(tokenizer)\n\n # used to preprocess the raw data\n def preprocess_function(examples):\n # Tokenize the texts\n args = (\n (examples[sentence1_key],) if sentence2_key is None else (examples[sentence1_key], examples[sentence2_key])\n )\n result = tokenizer(*args, padding=False, max_length=128, truncation=True)\n\n if 'label' in examples:\n # In all cases, rename the column to labels because the model will expect that.\n result['labels'] = examples['label']\n return result\n\n raw_datasets = load_dataset('glue', task_name, cache_dir=cache_dir)\n for key in list(raw_datasets.keys()):\n if 'test' in key:\n raw_datasets.pop(key)\n\n processed_datasets = raw_datasets.map(preprocess_function, batched=True,\n remove_columns=raw_datasets['train'].column_names)\n\n # if has teacher model, add 'teacher_logits' to datasets who has 'labels'.\n # 'teacher_logits' is used for distillation and avoid the double counting.\n if teacher_model:\n teacher_model_training = teacher_model.training\n teacher_model.eval()\n model_device = next(teacher_model.parameters()).device\n\n def add_teacher_logits(examples):\n result = {k: v for k, v in examples.items()}\n samples = data_collator(result).to(model_device)\n if 'labels' in samples:\n with torch.no_grad():\n logits = teacher_model(**samples).logits.tolist()\n result['teacher_logits'] = logits\n return result\n\n processed_datasets = processed_datasets.map(add_teacher_logits, batched=True,\n batch_size=train_batch_size)\n teacher_model.train(teacher_model_training)\n\n train_dataset = processed_datasets['train']\n validation_dataset = processed_datasets['validation_matched' if task_name == 'mnli' else 'validation']\n validation_dataset2 = processed_datasets['validation_mismatched'] if task_name == 'mnli' else None\n\n train_dataloader = DataLoader(train_dataset,\n shuffle=True,\n collate_fn=data_collator,\n batch_size=train_batch_size)\n validation_dataloader = DataLoader(validation_dataset,\n collate_fn=data_collator,\n batch_size=eval_batch_size)\n validation_dataloader2 = DataLoader(validation_dataset2,\n collate_fn=data_collator,\n batch_size=eval_batch_size) if task_name == 'mnli' else None\n\n return train_dataloader, validation_dataloader, validation_dataloader2" "from torch.utils.data import DataLoader\n\nfrom datasets import load_dataset\nfrom transformers import BertTokenizerFast, DataCollatorWithPadding\n\ntask_to_keys = {\n 'cola': ('sentence', None),\n 'mnli': ('premise', 'hypothesis'),\n 'mrpc': ('sentence1', 'sentence2'),\n 'qnli': ('question', 'sentence'),\n 'qqp': ('question1', 'question2'),\n 'rte': ('sentence1', 'sentence2'),\n 'sst2': ('sentence', None),\n 'stsb': ('sentence1', 'sentence2'),\n 'wnli': ('sentence1', 'sentence2'),\n}\n\ndef 
prepare_dataloaders(cache_dir=data_dir, train_batch_size=32, eval_batch_size=32):\n tokenizer = BertTokenizerFast.from_pretrained(pretrained_model_name_or_path)\n sentence1_key, sentence2_key = task_to_keys[task_name]\n data_collator = DataCollatorWithPadding(tokenizer)\n\n # used to preprocess the raw data\n def preprocess_function(examples):\n # Tokenize the texts\n args = (\n (examples[sentence1_key],) if sentence2_key is None else (examples[sentence1_key], examples[sentence2_key])\n )\n result = tokenizer(*args, padding=False, max_length=128, truncation=True)\n\n if 'label' in examples:\n # In all cases, rename the column to labels because the model will expect that.\n result['labels'] = examples['label']\n return result\n\n raw_datasets = load_dataset('glue', task_name, cache_dir=cache_dir)\n for key in list(raw_datasets.keys()):\n if 'test' in key:\n raw_datasets.pop(key)\n\n processed_datasets = raw_datasets.map(preprocess_function, batched=True,\n remove_columns=raw_datasets['train'].column_names)\n\n train_dataset = processed_datasets['train']\n if task_name == 'mnli':\n validation_datasets = {\n 'validation_matched': processed_datasets['validation_matched'],\n 'validation_mismatched': processed_datasets['validation_mismatched']\n }\n else:\n validation_datasets = {\n 'validation': processed_datasets['validation']\n }\n\n train_dataloader = DataLoader(train_dataset, shuffle=True, collate_fn=data_collator, batch_size=train_batch_size)\n validation_dataloaders = {\n val_name: DataLoader(val_dataset, collate_fn=data_collator, batch_size=eval_batch_size) \\\n for val_name, val_dataset in validation_datasets.items()\n }\n\n return train_dataloader, validation_dataloaders\n\n\ntrain_dataloader, validation_dataloaders = prepare_dataloaders()"
] ]
}, },
{ {
...@@ -80,7 +80,7 @@ ...@@ -80,7 +80,7 @@
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"import time\nimport torch.nn.functional as F\nfrom datasets import load_metric\n\ndef training(train_dataloader: DataLoader,\n model: torch.nn.Module,\n optimizer: torch.optim.Optimizer,\n criterion: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],\n lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None,\n max_steps: int = None, max_epochs: int = None,\n save_best_model: bool = False, save_path: str = None,\n log_path: str = Path(log_dir) / 'training.log',\n distillation: bool = False,\n evaluation_func=None):\n model.train()\n current_step = 0\n best_result = 0\n\n for current_epoch in range(max_epochs if max_epochs else 1):\n for batch in train_dataloader:\n batch.to(device)\n teacher_logits = batch.pop('teacher_logits', None)\n optimizer.zero_grad()\n outputs = model(**batch)\n loss = outputs.loss\n\n if distillation:\n assert teacher_logits is not None\n distil_loss = F.kl_div(F.log_softmax(outputs.logits / 2, dim=-1),\n F.softmax(teacher_logits / 2, dim=-1), reduction='batchmean') * (2 ** 2)\n loss = 0.1 * loss + 0.9 * distil_loss\n\n loss = criterion(loss, None)\n loss.backward()\n optimizer.step()\n\n if lr_scheduler:\n lr_scheduler.step()\n\n current_step += 1\n\n # evaluation for every 1000 steps\n if current_step % 1000 == 0 or current_step % len(train_dataloader) == 0:\n result = evaluation_func(model) if evaluation_func else None\n with (log_path).open('a+') as f:\n msg = '[{}] Epoch {}, Step {}: {}\\n'.format(time.asctime(time.localtime(time.time())), current_epoch, current_step, result)\n f.write(msg)\n # if it's the best model, save it.\n if save_best_model and best_result < result['default']:\n assert save_path is not None\n torch.save(model.state_dict(), save_path)\n best_result = result['default']\n\n if max_steps and current_step >= max_steps:\n return\n\ndef evaluation(validation_dataloader: DataLoader,\n validation_dataloader2: DataLoader,\n model: torch.nn.Module):\n training = model.training\n model.eval()\n is_regression = task_name == 'stsb'\n metric = load_metric('glue', task_name)\n\n for batch in validation_dataloader:\n batch.pop('teacher_logits', None)\n batch.to(device)\n outputs = model(**batch)\n predictions = outputs.logits.argmax(dim=-1) if not is_regression else outputs.logits.squeeze()\n metric.add_batch(\n predictions=predictions,\n references=batch['labels'],\n )\n result = metric.compute()\n\n if validation_dataloader2:\n for batch in validation_dataloader2:\n batch.pop('teacher_logits', None)\n batch.to(device)\n outputs = model(**batch)\n predictions = outputs.logits.argmax(dim=-1) if not is_regression else outputs.logits.squeeze()\n metric.add_batch(\n predictions=predictions,\n references=batch['labels'],\n )\n result = {'matched': result, 'mismatched': metric.compute()}\n result['default'] = (result['matched']['accuracy'] + result['mismatched']['accuracy']) / 2\n else:\n result['default'] = result.get('f1', result.get('accuracy', None))\n\n model.train(training)\n return result\n\n# using huggingface native loss\ndef fake_criterion(outputs, targets):\n return outputs" "import functools\nimport time\n\nimport torch.nn.functional as F\nfrom datasets import load_metric\nfrom transformers.modeling_outputs import SequenceClassifierOutput\n\n\ndef training(model: torch.nn.Module,\n optimizer: torch.optim.Optimizer,\n criterion: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],\n lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None,\n max_steps: int = None,\n max_epochs: int = None,\n train_dataloader: DataLoader = None,\n 
distillation: bool = False,\n teacher_model: torch.nn.Module = None,\n distil_func: Callable = None,\n log_path: str = Path(log_dir) / 'training.log',\n save_best_model: bool = False,\n save_path: str = None,\n evaluation_func: Callable = None,\n eval_per_steps: int = 1000,\n device=None):\n\n assert train_dataloader is not None\n\n model.train()\n if teacher_model is not None:\n teacher_model.eval()\n current_step = 0\n best_result = 0\n\n total_epochs = max_steps // len(train_dataloader) + 1 if max_steps else max_epochs if max_epochs else 3\n total_steps = max_steps if max_steps else total_epochs * len(train_dataloader)\n\n print(f'Training {total_epochs} epochs, {total_steps} steps...')\n\n for current_epoch in range(total_epochs):\n for batch in train_dataloader:\n if current_step >= total_steps:\n return\n batch.to(device)\n outputs = model(**batch)\n loss = outputs.loss\n\n if distillation:\n assert teacher_model is not None\n with torch.no_grad():\n teacher_outputs = teacher_model(**batch)\n distil_loss = distil_func(outputs, teacher_outputs)\n loss = 0.1 * loss + 0.9 * distil_loss\n\n loss = criterion(loss, None)\n optimizer.zero_grad()\n loss.backward()\n optimizer.step()\n\n # per step schedule\n if lr_scheduler:\n lr_scheduler.step()\n\n current_step += 1\n\n if current_step % eval_per_steps == 0 or current_step % len(train_dataloader) == 0:\n result = evaluation_func(model) if evaluation_func else None\n with (log_path).open('a+') as f:\n msg = '[{}] Epoch {}, Step {}: {}\\n'.format(time.asctime(time.localtime(time.time())), current_epoch, current_step, result)\n f.write(msg)\n # if it's the best model, save it.\n if save_best_model and (result is None or best_result < result['default']):\n assert save_path is not None\n torch.save(model.state_dict(), save_path)\n best_result = None if result is None else result['default']\n\n\ndef distil_loss_func(stu_outputs: SequenceClassifierOutput, tea_outputs: SequenceClassifierOutput, encoder_layer_idxs=[]):\n encoder_hidden_state_loss = []\n for i, idx in enumerate(encoder_layer_idxs[:-1]):\n encoder_hidden_state_loss.append(F.mse_loss(stu_outputs.hidden_states[i], tea_outputs.hidden_states[idx]))\n logits_loss = F.kl_div(F.log_softmax(stu_outputs.logits / 2, dim=-1), F.softmax(tea_outputs.logits / 2, dim=-1), reduction='batchmean') * (2 ** 2)\n\n distil_loss = 0\n for loss in encoder_hidden_state_loss:\n distil_loss += loss\n distil_loss += logits_loss\n return distil_loss\n\n\ndef evaluation(model: torch.nn.Module, validation_dataloaders: Dict[str, DataLoader] = None, device=None):\n assert validation_dataloaders is not None\n training = model.training\n model.eval()\n\n is_regression = task_name == 'stsb'\n metric = load_metric('glue', task_name)\n\n result = {}\n default_result = 0\n for val_name, validation_dataloader in validation_dataloaders.items():\n for batch in validation_dataloader:\n batch.to(device)\n outputs = model(**batch)\n predictions = outputs.logits.argmax(dim=-1) if not is_regression else outputs.logits.squeeze()\n metric.add_batch(\n predictions=predictions,\n references=batch['labels'],\n )\n result[val_name] = metric.compute()\n default_result += result[val_name].get('f1', result[val_name].get('accuracy', 0))\n result['default'] = default_result / len(result)\n\n model.train(training)\n return result\n\n\nevaluation_func = functools.partial(evaluation, validation_dataloaders=validation_dataloaders, device=device)\n\n\ndef fake_criterion(loss, _):\n return loss"
] ]
}, },
{ {
...@@ -98,14 +98,14 @@ ...@@ -98,14 +98,14 @@
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"import functools\n\nfrom torch.optim import Adam\nfrom torch.optim.lr_scheduler import LambdaLR\nfrom transformers import BertForSequenceClassification\n\ndef create_pretrained_model():\n is_regression = task_name == 'stsb'\n num_labels = 1 if is_regression else (3 if task_name == 'mnli' else 2)\n return BertForSequenceClassification.from_pretrained(pretrained_model_name_or_path, num_labels=num_labels)\n\ndef create_finetuned_model():\n pretrained_model = create_pretrained_model().to(device)\n\n train_dataloader, validation_dataloader, validation_dataloader2 = prepare_data()\n evaluation_func = functools.partial(evaluation, validation_dataloader, validation_dataloader2)\n steps_per_epoch = len(train_dataloader)\n training_epochs = 3\n\n finetuned_model_state_path = Path(model_dir) / 'finetuned_model_state.pth'\n\n if finetuned_model_state_path.exists():\n pretrained_model.load_state_dict(torch.load(finetuned_model_state_path))\n elif dev_mode:\n pass\n else:\n optimizer = Adam(pretrained_model.parameters(), lr=3e-5, eps=1e-8)\n\n def lr_lambda(current_step: int):\n return max(0.0, float(training_epochs * steps_per_epoch - current_step) / float(training_epochs * steps_per_epoch))\n\n lr_scheduler = LambdaLR(optimizer, lr_lambda)\n training(train_dataloader, pretrained_model, optimizer, fake_criterion, lr_scheduler=lr_scheduler, max_epochs=training_epochs,\n save_best_model=True, save_path=finetuned_model_state_path, evaluation_func=evaluation_func)\n return pretrained_model\n\nfinetuned_model = create_finetuned_model()" "from torch.optim import Adam\nfrom torch.optim.lr_scheduler import LambdaLR\nfrom transformers import BertForSequenceClassification\n\n\ndef create_pretrained_model():\n is_regression = task_name == 'stsb'\n num_labels = 1 if is_regression else (3 if task_name == 'mnli' else 2)\n model = BertForSequenceClassification.from_pretrained(pretrained_model_name_or_path, num_labels=num_labels)\n model.bert.config.output_hidden_states = True\n return model\n\n\ndef create_finetuned_model():\n finetuned_model = create_pretrained_model()\n finetuned_model_state_path = Path(model_dir) / 'finetuned_model_state.pth'\n\n if finetuned_model_state_path.exists():\n finetuned_model.load_state_dict(torch.load(finetuned_model_state_path, map_location='cpu'))\n finetuned_model.to(device)\n elif dev_mode:\n pass\n else:\n steps_per_epoch = len(train_dataloader)\n training_epochs = 3\n optimizer = Adam(finetuned_model.parameters(), lr=3e-5, eps=1e-8)\n\n def lr_lambda(current_step: int):\n return max(0.0, float(training_epochs * steps_per_epoch - current_step) / float(training_epochs * steps_per_epoch))\n\n lr_scheduler = LambdaLR(optimizer, lr_lambda)\n training(finetuned_model, optimizer, fake_criterion, lr_scheduler=lr_scheduler,\n max_epochs=training_epochs, train_dataloader=train_dataloader, log_path=log_dir / 'finetuning_on_downstream.log',\n save_best_model=True, save_path=finetuned_model_state_path, evaluation_func=evaluation_func, device=device)\n return finetuned_model\n\n\nfinetuned_model = create_finetuned_model()"
] ]
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"Using finetuned model as teacher model to create dataloader.\nAdd 'teacher_logits' to dataset, it is used to do the distillation, it can be seen as a kind of data label.\n\n" "### Pruning\nAccording to experience, it is easier to achieve good results by pruning the attention part and the FFN part in stages.\nOf course, pruning together can also achieve the similar effect, but more parameter adjustment attempts are required.\nSo in this section, we do pruning in stages.\n\nFirst, we prune the attention layer with MovementPruner.\n\n"
] ]
}, },
{ {
...@@ -116,14 +116,14 @@ ...@@ -116,14 +116,14 @@
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"if not dev_mode:\n train_dataloader, validation_dataloader, validation_dataloader2 = prepare_data(teacher_model=finetuned_model)\nelse:\n train_dataloader, validation_dataloader, validation_dataloader2 = prepare_data()\n\nevaluation_func = functools.partial(evaluation, validation_dataloader, validation_dataloader2)" "steps_per_epoch = len(train_dataloader)\n\n# Set training steps/epochs for pruning.\n\nif not dev_mode:\n total_epochs = 4\n total_steps = total_epochs * steps_per_epoch\n warmup_steps = 1 * steps_per_epoch\n cooldown_steps = 1 * steps_per_epoch\nelse:\n total_epochs = 1\n total_steps = 3\n warmup_steps = 1\n cooldown_steps = 1\n\n# Initialize evaluator used by MovementPruner.\n\nimport nni\nfrom nni.algorithms.compression.v2.pytorch import TorchEvaluator\n\nmovement_training = functools.partial(training, train_dataloader=train_dataloader,\n log_path=log_dir / 'movement_pruning.log',\n evaluation_func=evaluation_func, device=device)\ntraced_optimizer = nni.trace(Adam)(finetuned_model.parameters(), lr=3e-5, eps=1e-8)\n\ndef lr_lambda(current_step: int):\n if current_step < warmup_steps:\n return float(current_step) / warmup_steps\n return max(0.0, float(total_steps - current_step) / float(total_steps - warmup_steps))\n\ntraced_scheduler = nni.trace(LambdaLR)(traced_optimizer, lr_lambda)\nevaluator = TorchEvaluator(movement_training, traced_optimizer, fake_criterion, traced_scheduler)\n\n# Apply block-soft-movement pruning on attention layers.\n# Note that block sparse is introduced by `sparse_granularity='auto'`, and only support `bert`, `bart`, `t5` right now.\n\nfrom nni.compression.pytorch.pruning import MovementPruner\n\nconfig_list = [{\n 'op_types': ['Linear'],\n 'op_partial_names': ['bert.encoder.layer.{}.attention'.format(i) for i in range(layers_num)],\n 'sparsity': 0.1\n}]\n\npruner = MovementPruner(model=finetuned_model,\n config_list=config_list,\n evaluator=evaluator,\n training_epochs=total_epochs,\n training_steps=total_steps,\n warm_up_step=warmup_steps,\n cool_down_beginning_step=total_steps - cooldown_steps,\n regular_scale=10,\n movement_mode='soft',\n sparse_granularity='auto')\n_, attention_masks = pruner.compress()\npruner.show_pruned_weights()\n\ntorch.save(attention_masks, Path(log_dir) / 'attention_masks.pth')"
] ]
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"### Pruning\nFirst, using MovementPruner to prune attention head.\n\n" "Load a new finetuned model to do speedup, you can think of this as using the finetuned state to initialize the pruned model weights.\nNote that nni speedup don't support replacing attention module, so here we manully replace the attention module.\n\nIf the head is entire masked, physically prune it and create config_list for FFN pruning.\n\n"
] ]
}, },
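A small worked example of the mask-to-head-index step described above, using the bert-base shapes assumed in this tutorial (12 heads of size 64, hidden size 768); the tensor here is a synthetic stand-in for the saved attention masks:

import torch

heads_num, head_size, hidden_size = 12, 64, 768
value_mask = torch.ones(hidden_size, hidden_size)   # mask over the value projection weight
value_mask[:2 * head_size, :] = 0.                  # pretend heads 0 and 1 are fully masked

# each contiguous block of 64 output rows belongs to one head
head_mask = (value_mask.reshape(heads_num, -1).sum(-1) == 0.)
head_idxs = torch.arange(heads_num)[head_mask].long().tolist()
print(head_idxs)                                    # -> [0, 1], the heads to pass to prune_heads()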
{ {
...@@ -134,25 +134,7 @@ ...@@ -134,25 +134,7 @@
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"steps_per_epoch = len(train_dataloader)\n\n# Set training steps/epochs for pruning.\n\nif not dev_mode:\n total_epochs = 4\n total_steps = total_epochs * steps_per_epoch\n warmup_steps = 1 * steps_per_epoch\n cooldown_steps = 1 * steps_per_epoch\nelse:\n total_epochs = 1\n total_steps = 3\n warmup_steps = 1\n cooldown_steps = 1\n\n# Initialize evaluator used by MovementPruner.\n\nimport nni\nfrom nni.algorithms.compression.v2.pytorch import TorchEvaluator\n\nmovement_training = functools.partial(training, train_dataloader, log_path=log_dir / 'movement_pruning.log',\n evaluation_func=evaluation_func)\ntraced_optimizer = nni.trace(Adam)(finetuned_model.parameters(), lr=3e-5, eps=1e-8)\n\ndef lr_lambda(current_step: int):\n if current_step < warmup_steps:\n return float(current_step) / warmup_steps\n return max(0.0, float(total_steps - current_step) / float(total_steps - warmup_steps))\n\ntraced_scheduler = nni.trace(LambdaLR)(traced_optimizer, lr_lambda)\nevaluator = TorchEvaluator(movement_training, traced_optimizer, fake_criterion, traced_scheduler)\n\n# Apply block-soft-movement pruning on attention layers.\n\nfrom nni.compression.pytorch.pruning import MovementPruner\n\nconfig_list = [{'op_types': ['Linear'], 'op_partial_names': ['bert.encoder.layer.{}.'.format(i) for i in range(layers_num)], 'sparsity': 0.1}]\npruner = MovementPruner(model=finetuned_model,\n config_list=config_list,\n evaluator=evaluator,\n training_epochs=total_epochs,\n training_steps=total_steps,\n warm_up_step=warmup_steps,\n cool_down_beginning_step=total_steps - cooldown_steps,\n regular_scale=10,\n movement_mode='soft',\n sparse_granularity='auto')\n_, attention_masks = pruner.compress()\npruner.show_pruned_weights()\n\ntorch.save(attention_masks, Path(log_dir) / 'attention_masks.pth')" "attention_pruned_model = create_finetuned_model().to(device)\nattention_masks = torch.load(Path(log_dir) / 'attention_masks.pth')\n\nffn_config_list = []\nlayer_remained_idxs = []\nmodule_list = []\nfor i in range(0, layers_num):\n prefix = f'bert.encoder.layer.{i}.'\n value_mask: torch.Tensor = attention_masks[prefix + 'attention.self.value']['weight']\n head_mask = (value_mask.reshape(heads_num, -1).sum(-1) == 0.)\n head_idxs = torch.arange(len(head_mask))[head_mask].long().tolist()\n print(f'layer {i} prune {len(head_idxs)} head: {head_idxs}')\n if len(head_idxs) != heads_num:\n attention_pruned_model.bert.encoder.layer[i].attention.prune_heads(head_idxs)\n module_list.append(attention_pruned_model.bert.encoder.layer[i])\n # The final ffn weight remaining ratio is the half of the attention weight remaining ratio.\n # This is just an empirical configuration, you can use any other method to determine this sparsity.\n sparsity = 1 - (1 - len(head_idxs) / heads_num) * 0.5\n # here we use a simple sparsity schedule, we will prune ffn in 12 iterations, each iteration prune `sparsity_per_iter`.\n sparsity_per_iter = 1 - (1 - sparsity) ** (1 / 12)\n ffn_config_list.append({\n 'op_names': [f'bert.encoder.layer.{len(layer_remained_idxs)}.intermediate.dense'],\n 'sparsity': sparsity_per_iter\n })\n layer_remained_idxs.append(i)\n\nattention_pruned_model.bert.encoder.layer = torch.nn.ModuleList(module_list)\ndistil_func = functools.partial(distil_loss_func, encoder_layer_idxs=layer_remained_idxs)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Load a new finetuned model to do the speedup.\nNote that nni speedup don't support replace attention module, so here we manully replace the attention module.\n\nIf the head is entire masked, physically prune it and create config_list for FFN pruning.\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"attention_pruned_model = create_finetuned_model().to(device)\nattention_masks = torch.load(Path(log_dir) / 'attention_masks.pth')\n\nffn_config_list = []\nlayer_count = 0\nmodule_list = []\nfor i in range(0, layers_num):\n prefix = f'bert.encoder.layer.{i}.'\n value_mask: torch.Tensor = attention_masks[prefix + 'attention.self.value']['weight']\n head_mask = (value_mask.reshape(heads_num, -1).sum(-1) == 0.)\n head_idx = torch.arange(len(head_mask))[head_mask].long().tolist()\n print(f'layer {i} pruner {len(head_idx)} head: {head_idx}')\n if len(head_idx) != heads_num:\n attention_pruned_model.bert.encoder.layer[i].attention.prune_heads(head_idx)\n module_list.append(attention_pruned_model.bert.encoder.layer[i])\n # The final ffn weight remaining ratio is the half of the attention weight remaining ratio.\n # This is just an empirical configuration, you can use any other method to determine this sparsity.\n sparsity = 1 - (1 - len(head_idx) / heads_num) * 0.5\n # here we use a simple sparsity schedule, we will prune ffn in 12 iterations, each iteration prune `sparsity_per_iter`.\n sparsity_per_iter = 1 - (1 - sparsity) ** (1 / heads_num)\n ffn_config_list.append({'op_names': [f'bert.encoder.layer.{layer_count}.intermediate.dense'], 'sparsity': sparsity_per_iter})\n layer_count += 1\n\nattention_pruned_model.bert.encoder.layer = torch.nn.ModuleList(module_list)"
] ]
}, },
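The per-iteration sparsity above is chosen so that pruning that fraction of the remaining channels 12 times compounds to the per-layer target. A short worked check with an illustrative target (a layer where 6 of 12 heads were removed, so the target FFN sparsity is 1 - (1 - 6/12) * 0.5 = 0.75):

target_sparsity = 0.75
sparsity_per_iter = 1 - (1 - target_sparsity) ** (1 / 12)   # ~0.109 per iteration

remaining = 1.0
for _ in range(12):
    remaining *= 1 - sparsity_per_iter   # each iteration keeps this fraction of what is left
print(round(sparsity_per_iter, 3), round(1 - remaining, 3))  # 0.109 0.75 -> the target is reached after 12 iterations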
{ {
...@@ -170,14 +152,14 @@ ...@@ -170,14 +152,14 @@
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"if not dev_mode:\n total_epochs = 5\n total_steps = None\n distillation = True\nelse:\n total_epochs = 1\n total_steps = 1\n distillation = False\n\noptimizer = Adam(attention_pruned_model.parameters(), lr=3e-5, eps=1e-8)\n\ndef lr_lambda(current_step: int):\n return max(0.0, float(total_epochs * steps_per_epoch - current_step) / float(total_epochs * steps_per_epoch))\n\nlr_scheduler = LambdaLR(optimizer, lr_lambda)\nat_model_save_path = log_dir / 'attention_pruned_model_state.pth'\ntraining(train_dataloader, attention_pruned_model, optimizer, fake_criterion, lr_scheduler=lr_scheduler,\n max_epochs=total_epochs, max_steps=total_steps, save_best_model=True, save_path=at_model_save_path,\n distillation=distillation, evaluation_func=evaluation_func)\n\nif not dev_mode:\n attention_pruned_model.load_state_dict(torch.load(at_model_save_path))" "if not dev_mode:\n total_epochs = 5\n total_steps = None\n distillation = True\nelse:\n total_epochs = 1\n total_steps = 1\n distillation = False\n\nteacher_model = create_finetuned_model()\noptimizer = Adam(attention_pruned_model.parameters(), lr=3e-5, eps=1e-8)\n\ndef lr_lambda(current_step: int):\n return max(0.0, float(total_epochs * steps_per_epoch - current_step) / float(total_epochs * steps_per_epoch))\n\nlr_scheduler = LambdaLR(optimizer, lr_lambda)\nat_model_save_path = log_dir / 'attention_pruned_model_state.pth'\ntraining(attention_pruned_model, optimizer, fake_criterion, lr_scheduler=lr_scheduler, max_epochs=total_epochs,\n max_steps=total_steps, train_dataloader=train_dataloader, distillation=distillation, teacher_model=teacher_model,\n distil_func=distil_func, log_path=log_dir / 'retraining.log', save_best_model=True, save_path=at_model_save_path,\n evaluation_func=evaluation_func, device=device)\n\nif not dev_mode:\n attention_pruned_model.load_state_dict(torch.load(at_model_save_path))"
] ]
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"Iterative pruning FFN with TaylorFOWeightPruner in 12 iterations.\nFinetuning 2000 steps after each iteration, then finetuning 2 epochs after pruning finished.\n\nNNI will support per-step-pruning-schedule in the future, then can use an pruner to replace the following code.\n\n" "Iterative pruning FFN with TaylorFOWeightPruner in 12 iterations.\nFinetuning 3000 steps after each pruning iteration, then finetuning 2 epochs after pruning finished.\n\nNNI will support per-step-pruning-schedule in the future, then can use an pruner to replace the following code.\n\n"
] ]
}, },
{ {
...@@ -188,14 +170,14 @@ ...@@ -188,14 +170,14 @@
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"if not dev_mode:\n total_epochs = 4\n total_steps = None\n taylor_pruner_steps = 1000\n steps_per_iteration = 2000\n total_pruning_steps = 24000\n distillation = True\nelse:\n total_epochs = 1\n total_steps = 6\n taylor_pruner_steps = 2\n steps_per_iteration = 2\n total_pruning_steps = 4\n distillation = False\n\nfrom nni.compression.pytorch.pruning import TaylorFOWeightPruner\nfrom nni.compression.pytorch.speedup import ModelSpeedup\n\ndistil_training = functools.partial(training, train_dataloader, log_path=log_dir / 'taylor_pruning.log',\n distillation=distillation, evaluation_func=evaluation_func)\ntraced_optimizer = nni.trace(Adam)(attention_pruned_model.parameters(), lr=3e-5, eps=1e-8)\nevaluator = TorchEvaluator(distil_training, traced_optimizer, fake_criterion)\n\ncurrent_step = 0\nbest_result = 0\ninit_lr = 3e-5\n\ndummy_input = torch.rand(8, 128, 768).to(device)\n\nattention_pruned_model.train()\nfor current_epoch in range(total_epochs):\n for batch in train_dataloader:\n if total_steps and current_step >= total_steps:\n break\n # pruning 12 times\n if current_step % steps_per_iteration == 0 and current_step < total_pruning_steps:\n check_point = attention_pruned_model.state_dict()\n pruner = TaylorFOWeightPruner(attention_pruned_model, ffn_config_list, evaluator, taylor_pruner_steps)\n _, ffn_masks = pruner.compress()\n renamed_ffn_masks = {}\n # rename the masks keys, because we only speedup the bert.encoder\n for model_name, targets_mask in ffn_masks.items():\n renamed_ffn_masks[model_name.split('bert.encoder.')[1]] = targets_mask\n pruner._unwrap_model()\n attention_pruned_model.load_state_dict(check_point)\n ModelSpeedup(attention_pruned_model.bert.encoder, dummy_input, renamed_ffn_masks).speedup_model()\n optimizer = Adam(attention_pruned_model.parameters(), lr=init_lr)\n\n batch.to(device)\n teacher_logits = batch.pop('teacher_logits', None)\n optimizer.zero_grad()\n\n # manually schedule lr\n for params_group in optimizer.param_groups:\n params_group['lr'] = (1 - current_step / (total_epochs * steps_per_epoch)) * init_lr\n\n outputs = attention_pruned_model(**batch)\n loss = outputs.loss\n\n # distillation\n if teacher_logits is not None:\n distil_loss = F.kl_div(F.log_softmax(outputs.logits / 2, dim=-1),\n F.softmax(teacher_logits / 2, dim=-1), reduction='batchmean') * (2 ** 2)\n loss = 0.1 * loss + 0.9 * distil_loss\n loss.backward()\n optimizer.step()\n\n current_step += 1\n if current_step % 1000 == 0 or current_step % len(train_dataloader) == 0:\n result = evaluation_func(attention_pruned_model)\n with (log_dir / 'ffn_pruning.log').open('a+') as f:\n msg = '[{}] Epoch {}, Step {}: {}\\n'.format(time.asctime(time.localtime(time.time())),\n current_epoch, current_step, result)\n f.write(msg)\n if current_step >= total_pruning_steps and best_result < result['default']:\n torch.save(attention_pruned_model, log_dir / 'best_model.pth')\n best_result = result['default']" "if not dev_mode:\n total_epochs = 7\n total_steps = None\n taylor_pruner_steps = 1000\n steps_per_iteration = 3000\n total_pruning_steps = 36000\n distillation = True\nelse:\n total_epochs = 1\n total_steps = 6\n taylor_pruner_steps = 2\n steps_per_iteration = 2\n total_pruning_steps = 4\n distillation = False\n\nfrom nni.compression.pytorch.pruning import TaylorFOWeightPruner\nfrom nni.compression.pytorch.speedup import ModelSpeedup\n\ndistil_training = functools.partial(training, train_dataloader=train_dataloader, distillation=distillation,\n teacher_model=teacher_model, distil_func=distil_func, 
device=device)\ntraced_optimizer = nni.trace(Adam)(attention_pruned_model.parameters(), lr=3e-5, eps=1e-8)\nevaluator = TorchEvaluator(distil_training, traced_optimizer, fake_criterion)\n\ncurrent_step = 0\nbest_result = 0\ninit_lr = 3e-5\n\ndummy_input = torch.rand(8, 128, 768).to(device)\n\nattention_pruned_model.train()\nfor current_epoch in range(total_epochs):\n for batch in train_dataloader:\n if total_steps and current_step >= total_steps:\n break\n # pruning with TaylorFOWeightPruner & reinitialize optimizer\n if current_step % steps_per_iteration == 0 and current_step < total_pruning_steps:\n check_point = attention_pruned_model.state_dict()\n pruner = TaylorFOWeightPruner(attention_pruned_model, ffn_config_list, evaluator, taylor_pruner_steps)\n _, ffn_masks = pruner.compress()\n renamed_ffn_masks = {}\n # rename the masks keys, because we only speedup the bert.encoder\n for model_name, targets_mask in ffn_masks.items():\n renamed_ffn_masks[model_name.split('bert.encoder.')[1]] = targets_mask\n pruner._unwrap_model()\n attention_pruned_model.load_state_dict(check_point)\n ModelSpeedup(attention_pruned_model.bert.encoder, dummy_input, renamed_ffn_masks).speedup_model()\n optimizer = Adam(attention_pruned_model.parameters(), lr=init_lr)\n\n batch.to(device)\n # manually schedule lr\n for params_group in optimizer.param_groups:\n params_group['lr'] = (1 - current_step / (total_epochs * steps_per_epoch)) * init_lr\n\n outputs = attention_pruned_model(**batch)\n loss = outputs.loss\n\n # distillation\n if distillation:\n assert teacher_model is not None\n with torch.no_grad():\n teacher_outputs = teacher_model(**batch)\n distil_loss = distil_func(outputs, teacher_outputs)\n loss = 0.1 * loss + 0.9 * distil_loss\n\n optimizer.zero_grad()\n loss.backward()\n optimizer.step()\n\n current_step += 1\n\n if current_step % 1000 == 0 or current_step % len(train_dataloader) == 0:\n result = evaluation_func(attention_pruned_model)\n with (log_dir / 'ffn_pruning.log').open('a+') as f:\n msg = '[{}] Epoch {}, Step {}: {}\\n'.format(time.asctime(time.localtime(time.time())),\n current_epoch, current_step, result)\n f.write(msg)\n if current_step >= total_pruning_steps and best_result < result['default']:\n torch.save(attention_pruned_model, log_dir / 'best_model.pth')\n best_result = result['default']"
] ]
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"## Result\nThe speedup is test on the entire validation dataset with batch size 32 on A100.\nWe test under two pytorch version and found the latency varying widely.\n\nSetting 1: pytorch 1.12.1\n\nSetting 2: pytorch 1.10.0\n\n.. list-table:: Prune Bert-base-uncased on MNLI\n :header-rows: 1\n :widths: auto\n\n * - Attention Pruning Method\n - FFN Pruning Method\n - Total Sparsity\n - Accuracy\n - Acc. Drop\n - Speedup (S1)\n - Speedup (S2)\n * -\n -\n - 0%\n - 84.73 / 84.63\n - +0.0 / +0.0\n - 12.56s (x1.00)\n - 4.05s (x1.00)\n * - `movement-pruner` (soft, th=0.1, lambda=5)\n - `taylor-fo-weight-pruner`\n - 51.39%\n - 84.25 / 84.96\n - -0.48 / +0.33\n - 6.85s (x1.83)\n - 2.7s (x1.50)\n * - `movement-pruner` (soft, th=0.1, lambda=10)\n - `taylor-fo-weight-pruner`\n - 66.67%\n - 83.98 / 83.75\n - -0.75 / -0.88\n - 4.73s (x2.66)\n - 2.16s (x1.86)\n * - `movement-pruner` (soft, th=0.1, lambda=20)\n - `taylor-fo-weight-pruner`\n - 77.78%\n - 83.02 / 83.06\n - -1.71 / -1.57\n - 3.35s (x3.75)\n - 1.72s (x2.35)\n * - `movement-pruner` (soft, th=0.1, lambda=30)\n - `taylor-fo-weight-pruner`\n - 87.04%\n - 81.24 / 80.99\n - -3.49 / -3.64\n - 2.19s (x5.74)\n - 1.31s (x3.09)\n\n" "## Result\nThe speedup is test on the entire validation dataset with batch size 32 on A100.\nWe test under two pytorch version and found the latency varying widely.\n\nSetting 1: pytorch 1.12.1\n\nSetting 2: pytorch 1.10.0\n\n.. list-table:: Prune Bert-base-uncased on MNLI\n :header-rows: 1\n :widths: auto\n\n * - Attention Pruning Method\n - FFN Pruning Method\n - Total Sparsity\n - Accuracy\n - Acc. Drop\n - Speedup (S1)\n - Speedup (S2)\n * -\n -\n - 0%\n - 84.73 / 84.63\n - +0.0 / +0.0\n - 12.56s (x1.00)\n - 4.05s (x1.00)\n * - `movement-pruner` (soft, sparsity=0.1, regular_scale=5)\n - `taylor-fo-weight-pruner`\n - 51.39%\n - 84.25 / 84.96\n - -0.48 / +0.33\n - 6.85s (x1.83)\n - 2.7s (x1.50)\n * - `movement-pruner` (soft, sparsity=0.1, regular_scale=10)\n - `taylor-fo-weight-pruner`\n - 66.67%\n - 83.98 / 83.75\n - -0.75 / -0.88\n - 4.73s (x2.66)\n - 2.16s (x1.86)\n * - `movement-pruner` (soft, sparsity=0.1, regular_scale=20)\n - `taylor-fo-weight-pruner`\n - 77.78%\n - 83.02 / 83.06\n - -1.71 / -1.57\n - 3.35s (x3.75)\n - 1.72s (x2.35)\n * - `movement-pruner` (soft, sparsity=0.1, regular_scale=30)\n - `taylor-fo-weight-pruner`\n - 87.04%\n - 81.24 / 80.99\n - -3.49 / -3.64\n - 2.19s (x5.74)\n - 1.31s (x3.09)\n\n"
] ]
} }
], ],
...@@ -215,7 +197,7 @@ ...@@ -215,7 +197,7 @@
"name": "python", "name": "python",
"nbconvert_exporter": "python", "nbconvert_exporter": "python",
"pygments_lexer": "ipython3", "pygments_lexer": "ipython3",
"version": "3.7.13" "version": "3.8.13"
} }
}, },
"nbformat": 4, "nbformat": 4,
......
""" """
Pruning Transformer with NNI Pruning Bert on Task MNLI
============================ =========================
Workable Pruning Process Workable Pruning Process
------------------------ ------------------------
...@@ -32,11 +32,18 @@ During the process of pruning transformer, we gained some of the following exper ...@@ -32,11 +32,18 @@ During the process of pruning transformer, we gained some of the following exper
Experiment Experiment
---------- ----------
The complete pruning process will take about 8 hours on one A100.
Preparation Preparation
^^^^^^^^^^^ ^^^^^^^^^^^
Please set ``dev_mode`` to ``False`` to run this tutorial. Here ``dev_mode`` is ``True`` by default is for generating documents.
The complete pruning process takes about 8 hours on one A100. This section is mainly to get a finetuned model on the downstream task.
If you are familiar with how to finetune Bert on GLUE dataset, you can skip this section.
.. note::
Please set ``dev_mode`` to ``False`` to run this tutorial. Here ``dev_mode`` is ``True`` by default for generating documents.
""" """
dev_mode = True dev_mode = True
...@@ -45,11 +52,11 @@ dev_mode = True ...@@ -45,11 +52,11 @@ dev_mode = True
# Some basic setting. # Some basic setting.
from pathlib import Path from pathlib import Path
from typing import Callable from typing import Callable, Dict
pretrained_model_name_or_path = 'bert-base-uncased' pretrained_model_name_or_path = 'bert-base-uncased'
task_name = 'mnli' task_name = 'mnli'
experiment_id = 'pruning_bert' experiment_id = 'pruning_bert_mnli'
# heads_num and layers_num should align with pretrained_model_name_or_path # heads_num and layers_num should align with pretrained_model_name_or_path
heads_num = 12 heads_num = 12
...@@ -63,6 +70,11 @@ log_dir.mkdir(parents=True, exist_ok=True) ...@@ -63,6 +70,11 @@ log_dir.mkdir(parents=True, exist_ok=True)
model_dir = Path(f'./models/{pretrained_model_name_or_path}/{task_name}') model_dir = Path(f'./models/{pretrained_model_name_or_path}/{task_name}')
model_dir.mkdir(parents=True, exist_ok=True) model_dir.mkdir(parents=True, exist_ok=True)
# used to save GLUE data
data_dir = Path(f'./data')
data_dir.mkdir(parents=True, exist_ok=True)
# set seed
from transformers import set_seed from transformers import set_seed
set_seed(1024) set_seed(1024)
...@@ -70,8 +82,7 @@ import torch ...@@ -70,8 +82,7 @@ import torch
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# %% # %%
# The function used to create dataloaders, note that 'mnli' has two evaluation dataset. # Create dataloaders.
# If teacher_model is set, will run all dataset on teacher model to get the 'teacher_logits' for distillation.
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
...@@ -90,8 +101,7 @@ task_to_keys = { ...@@ -90,8 +101,7 @@ task_to_keys = {
'wnli': ('sentence1', 'sentence2'), 'wnli': ('sentence1', 'sentence2'),
} }
def prepare_data(cache_dir='./data', train_batch_size=32, eval_batch_size=32, def prepare_dataloaders(cache_dir=data_dir, train_batch_size=32, eval_batch_size=32):
teacher_model: torch.nn.Module = None):
tokenizer = BertTokenizerFast.from_pretrained(pretrained_model_name_or_path) tokenizer = BertTokenizerFast.from_pretrained(pretrained_model_name_or_path)
sentence1_key, sentence2_key = task_to_keys[task_name] sentence1_key, sentence2_key = task_to_keys[task_name]
data_collator = DataCollatorWithPadding(tokenizer) data_collator = DataCollatorWithPadding(tokenizer)
...@@ -117,124 +127,132 @@ def prepare_data(cache_dir='./data', train_batch_size=32, eval_batch_size=32, ...@@ -117,124 +127,132 @@ def prepare_data(cache_dir='./data', train_batch_size=32, eval_batch_size=32,
processed_datasets = raw_datasets.map(preprocess_function, batched=True, processed_datasets = raw_datasets.map(preprocess_function, batched=True,
remove_columns=raw_datasets['train'].column_names) remove_columns=raw_datasets['train'].column_names)
# if has teacher model, add 'teacher_logits' to datasets who has 'labels'. train_dataset = processed_datasets['train']
# 'teacher_logits' is used for distillation and avoid the double counting. if task_name == 'mnli':
if teacher_model: validation_datasets = {
teacher_model_training = teacher_model.training 'validation_matched': processed_datasets['validation_matched'],
teacher_model.eval() 'validation_mismatched': processed_datasets['validation_mismatched']
model_device = next(teacher_model.parameters()).device }
else:
validation_datasets = {
'validation': processed_datasets['validation']
}
def add_teacher_logits(examples): train_dataloader = DataLoader(train_dataset, shuffle=True, collate_fn=data_collator, batch_size=train_batch_size)
result = {k: v for k, v in examples.items()} validation_dataloaders = {
samples = data_collator(result).to(model_device) val_name: DataLoader(val_dataset, collate_fn=data_collator, batch_size=eval_batch_size) \
if 'labels' in samples: for val_name, val_dataset in validation_datasets.items()
with torch.no_grad(): }
logits = teacher_model(**samples).logits.tolist()
result['teacher_logits'] = logits
return result
processed_datasets = processed_datasets.map(add_teacher_logits, batched=True, return train_dataloader, validation_dataloaders
batch_size=train_batch_size)
teacher_model.train(teacher_model_training)
train_dataset = processed_datasets['train']
validation_dataset = processed_datasets['validation_matched' if task_name == 'mnli' else 'validation'] train_dataloader, validation_dataloaders = prepare_dataloaders()
validation_dataset2 = processed_datasets['validation_mismatched'] if task_name == 'mnli' else None
train_dataloader = DataLoader(train_dataset,
shuffle=True,
collate_fn=data_collator,
batch_size=train_batch_size)
validation_dataloader = DataLoader(validation_dataset,
collate_fn=data_collator,
batch_size=eval_batch_size)
validation_dataloader2 = DataLoader(validation_dataset2,
collate_fn=data_collator,
batch_size=eval_batch_size) if task_name == 'mnli' else None
return train_dataloader, validation_dataloader, validation_dataloader2
# %% # %%
# Training function & evaluation function. # Training function & evaluation function.
import functools
import time import time
import torch.nn.functional as F import torch.nn.functional as F
from datasets import load_metric from datasets import load_metric
from transformers.modeling_outputs import SequenceClassifierOutput
def training(train_dataloader: DataLoader,
model: torch.nn.Module, def training(model: torch.nn.Module,
optimizer: torch.optim.Optimizer, optimizer: torch.optim.Optimizer,
criterion: Callable[[torch.Tensor, torch.Tensor], torch.Tensor], criterion: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None, lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None,
max_steps: int = None, max_epochs: int = None, max_steps: int = None,
save_best_model: bool = False, save_path: str = None, max_epochs: int = None,
log_path: str = Path(log_dir) / 'training.log', train_dataloader: DataLoader = None,
distillation: bool = False, distillation: bool = False,
evaluation_func=None): teacher_model: torch.nn.Module = None,
distil_func: Callable = None,
log_path: str = Path(log_dir) / 'training.log',
save_best_model: bool = False,
save_path: str = None,
evaluation_func: Callable = None,
eval_per_steps: int = 1000,
device=None):
assert train_dataloader is not None
model.train() model.train()
if teacher_model is not None:
teacher_model.eval()
current_step = 0 current_step = 0
best_result = 0 best_result = 0
for current_epoch in range(max_epochs if max_epochs else 1): total_epochs = max_steps // len(train_dataloader) + 1 if max_steps else max_epochs if max_epochs else 3
total_steps = max_steps if max_steps else total_epochs * len(train_dataloader)
print(f'Training {total_epochs} epochs, {total_steps} steps...')
for current_epoch in range(total_epochs):
for batch in train_dataloader: for batch in train_dataloader:
if current_step >= total_steps:
return
batch.to(device) batch.to(device)
teacher_logits = batch.pop('teacher_logits', None)
optimizer.zero_grad()
outputs = model(**batch) outputs = model(**batch)
loss = outputs.loss loss = outputs.loss
if distillation: if distillation:
assert teacher_logits is not None assert teacher_model is not None
distil_loss = F.kl_div(F.log_softmax(outputs.logits / 2, dim=-1), with torch.no_grad():
F.softmax(teacher_logits / 2, dim=-1), reduction='batchmean') * (2 ** 2) teacher_outputs = teacher_model(**batch)
distil_loss = distil_func(outputs, teacher_outputs)
loss = 0.1 * loss + 0.9 * distil_loss loss = 0.1 * loss + 0.9 * distil_loss
loss = criterion(loss, None) loss = criterion(loss, None)
optimizer.zero_grad()
loss.backward() loss.backward()
optimizer.step() optimizer.step()
# per step schedule
if lr_scheduler: if lr_scheduler:
lr_scheduler.step() lr_scheduler.step()
current_step += 1 current_step += 1
# evaluation for every 1000 steps if current_step % eval_per_steps == 0 or current_step % len(train_dataloader) == 0:
if current_step % 1000 == 0 or current_step % len(train_dataloader) == 0:
result = evaluation_func(model) if evaluation_func else None result = evaluation_func(model) if evaluation_func else None
with (log_path).open('a+') as f: with (log_path).open('a+') as f:
msg = '[{}] Epoch {}, Step {}: {}\n'.format(time.asctime(time.localtime(time.time())), current_epoch, current_step, result) msg = '[{}] Epoch {}, Step {}: {}\n'.format(time.asctime(time.localtime(time.time())), current_epoch, current_step, result)
f.write(msg) f.write(msg)
# if it's the best model, save it. # if it's the best model, save it.
if save_best_model and best_result < result['default']: if save_best_model and (result is None or best_result < result['default']):
assert save_path is not None assert save_path is not None
torch.save(model.state_dict(), save_path) torch.save(model.state_dict(), save_path)
best_result = result['default'] best_result = None if result is None else result['default']
if max_steps and current_step >= max_steps:
return
def evaluation(validation_dataloader: DataLoader, def distil_loss_func(stu_outputs: SequenceClassifierOutput, tea_outputs: SequenceClassifierOutput, encoder_layer_idxs=[]):
validation_dataloader2: DataLoader, encoder_hidden_state_loss = []
model: torch.nn.Module): for i, idx in enumerate(encoder_layer_idxs[:-1]):
encoder_hidden_state_loss.append(F.mse_loss(stu_outputs.hidden_states[i], tea_outputs.hidden_states[idx]))
logits_loss = F.kl_div(F.log_softmax(stu_outputs.logits / 2, dim=-1), F.softmax(tea_outputs.logits / 2, dim=-1), reduction='batchmean') * (2 ** 2)
distil_loss = 0
for loss in encoder_hidden_state_loss:
distil_loss += loss
distil_loss += logits_loss
return distil_loss
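The pruning section later binds ``distil_loss_func`` to the indices of the encoder layers that survive head pruning, so the student's i-th remaining layer is distilled against the matching teacher layer (every listed index except the last gets a hidden-state MSE term). A sketch of that binding, with hypothetical surviving-layer indices:

import functools

layer_remained_idxs = [i for i in range(12) if i not in (2, 7)]   # pretend layers 2 and 7 were removed
distil_func = functools.partial(distil_loss_func, encoder_layer_idxs=layer_remained_idxs)
# distil_func(student_outputs, teacher_outputs) then adds F.mse_loss(student.hidden_states[i],
# teacher.hidden_states[layer_remained_idxs[i]]) for each i, plus the temperature-2 logits KL term.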
def evaluation(model: torch.nn.Module, validation_dataloaders: Dict[str, DataLoader] = None, device=None):
    assert validation_dataloaders is not None
    training = model.training
    model.eval()
    is_regression = task_name == 'stsb'
    metric = load_metric('glue', task_name)

    result = {}
    default_result = 0
    for val_name, validation_dataloader in validation_dataloaders.items():
        for batch in validation_dataloader:
            batch.to(device)
            outputs = model(**batch)
            predictions = outputs.logits.argmax(dim=-1) if not is_regression else outputs.logits.squeeze()
            metric.add_batch(
                predictions=predictions,
                references=batch['labels'],
            )
        result[val_name] = metric.compute()
        default_result += result[val_name].get('f1', result[val_name].get('accuracy', 0))
    result['default'] = default_result / len(result)

    model.train(training)
    return result
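# %%
# For ``task_name == 'mnli'`` the evaluation above returns one metric dict per validation
# split plus a ``'default'`` key averaged over the splits. The numbers below are made up,
# only the shape of the result is real:

_example_result = {
    'validation_matched': {'accuracy': 0.8473},
    'validation_mismatched': {'accuracy': 0.8528},
}
_example_result['default'] = sum(
    m.get('f1', m.get('accuracy', 0)) for m in _example_result.values()
) / len(_example_result)
print(_example_result['default'])  # 0.85005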
evaluation_func = functools.partial(evaluation, validation_dataloaders=validation_dataloaders, device=device)


def fake_criterion(loss, _):
    return loss
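# %%
# ``fake_criterion`` exists only because the training/evaluator interface expects a
# ``criterion(outputs, targets)`` callable, while the HuggingFace model already computes
# its own loss; the "criterion" therefore just passes the precomputed loss through.
# A tiny check (the tensor value is arbitrary):

import torch

_loss = torch.tensor(1.25)
assert fake_criterion(_loss, None) is _loss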
# %%
# Prepare the pre-trained model and finetune it on the downstream task.

from torch.optim import Adam
from torch.optim.lr_scheduler import LambdaLR
from transformers import BertForSequenceClassification


def create_pretrained_model():
    is_regression = task_name == 'stsb'
    num_labels = 1 if is_regression else (3 if task_name == 'mnli' else 2)
    model = BertForSequenceClassification.from_pretrained(pretrained_model_name_or_path, num_labels=num_labels)
    model.bert.config.output_hidden_states = True
    return model


def create_finetuned_model():
    finetuned_model = create_pretrained_model()
    finetuned_model_state_path = Path(model_dir) / 'finetuned_model_state.pth'

    if finetuned_model_state_path.exists():
        finetuned_model.load_state_dict(torch.load(finetuned_model_state_path, map_location='cpu'))
        finetuned_model.to(device)
    elif dev_mode:
        pass
    else:
        steps_per_epoch = len(train_dataloader)
        training_epochs = 3
        optimizer = Adam(finetuned_model.parameters(), lr=3e-5, eps=1e-8)

        def lr_lambda(current_step: int):
            return max(0.0, float(training_epochs * steps_per_epoch - current_step) / float(training_epochs * steps_per_epoch))

        lr_scheduler = LambdaLR(optimizer, lr_lambda)
        training(finetuned_model, optimizer, fake_criterion, lr_scheduler=lr_scheduler,
                 max_epochs=training_epochs, train_dataloader=train_dataloader, log_path=log_dir / 'finetuning_on_downstream.log',
                 save_best_model=True, save_path=finetuned_model_state_path, evaluation_func=evaluation_func, device=device)
    return finetuned_model


finetuned_model = create_finetuned_model()

# %%
# Pruning
# ^^^^^^^
# According to our experience, it is easier to reach good results by pruning the attention part and the FFN part in stages.
# Pruning both together can reach a similar result, but it requires more hyper-parameter tuning.
# So in this section, we prune the model in two stages.
#
# First, we prune the attention layers with MovementPruner.
steps_per_epoch = len(train_dataloader)

@@ -332,8 +346,9 @@ else:
import nni
from nni.algorithms.compression.v2.pytorch import TorchEvaluator

movement_training = functools.partial(training, train_dataloader=train_dataloader,
                                      log_path=log_dir / 'movement_pruning.log',
                                      evaluation_func=evaluation_func, device=device)
traced_optimizer = nni.trace(Adam)(finetuned_model.parameters(), lr=3e-5, eps=1e-8)

def lr_lambda(current_step: int):
@@ -345,10 +360,16 @@
traced_scheduler = nni.trace(LambdaLR)(traced_optimizer, lr_lambda)
evaluator = TorchEvaluator(movement_training, traced_optimizer, fake_criterion, traced_scheduler)

# Apply block-soft-movement pruning on attention layers.
# Note that block sparsity is introduced by `sparse_granularity='auto'`, and it only supports `bert`, `bart` and `t5` right now.
from nni.compression.pytorch.pruning import MovementPruner

config_list = [{
    'op_types': ['Linear'],
    'op_partial_names': ['bert.encoder.layer.{}.attention'.format(i) for i in range(layers_num)],
    'sparsity': 0.1
}]

pruner = MovementPruner(model=finetuned_model,
                        config_list=config_list,
                        evaluator=evaluator,
@@ -365,8 +386,8 @@ pruner.show_pruned_weights()
torch.save(attention_masks, Path(log_dir) / 'attention_masks.pth')
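# %%
# A rough illustration of what the config above selects (module names follow the standard
# HuggingFace ``BertForSequenceClassification`` layout): every ``Linear`` whose name contains
# ``bert.encoder.layer.<i>.attention``, i.e. the query/key/value projections and the attention
# output projection, while the FFN Linears are left for the later TaylorFOWeightPruner stage.

_matched_by_config = [
    'bert.encoder.layer.0.attention.self.query',
    'bert.encoder.layer.0.attention.self.key',
    'bert.encoder.layer.0.attention.self.value',
    'bert.encoder.layer.0.attention.output.dense',
]
_not_matched = [
    'bert.encoder.layer.0.intermediate.dense',  # 1st FFN layer
    'bert.encoder.layer.0.output.dense',        # 2nd FFN layer
]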
# %%
# Load a new finetuned model to do speedup; you can think of this as using the finetuned state to initialize the pruned model weights.
# Note that nni speedup does not support replacing the attention module, so here we manually replace the attention module.
#
# If a head is entirely masked, physically prune it and create the config_list for FFN pruning.

@@ -374,26 +395,30 @@ attention_pruned_model = create_finetuned_model().to(device)
attention_masks = torch.load(Path(log_dir) / 'attention_masks.pth')

ffn_config_list = []
layer_remained_idxs = []
module_list = []
for i in range(0, layers_num):
    prefix = f'bert.encoder.layer.{i}.'
    value_mask: torch.Tensor = attention_masks[prefix + 'attention.self.value']['weight']
    head_mask = (value_mask.reshape(heads_num, -1).sum(-1) == 0.)
    head_idxs = torch.arange(len(head_mask))[head_mask].long().tolist()
    print(f'layer {i} prune {len(head_idxs)} head: {head_idxs}')
    if len(head_idxs) != heads_num:
        attention_pruned_model.bert.encoder.layer[i].attention.prune_heads(head_idxs)
        module_list.append(attention_pruned_model.bert.encoder.layer[i])
        # The final ffn weight remaining ratio is half of the attention weight remaining ratio.
        # This is just an empirical configuration, you can use any other method to determine this sparsity.
        sparsity = 1 - (1 - len(head_idxs) / heads_num) * 0.5
        # here we use a simple sparsity schedule, we will prune ffn in 12 iterations, each iteration prunes `sparsity_per_iter`.
        sparsity_per_iter = 1 - (1 - sparsity) ** (1 / 12)
        ffn_config_list.append({
            'op_names': [f'bert.encoder.layer.{len(layer_remained_idxs)}.intermediate.dense'],
            'sparsity': sparsity_per_iter
        })
        layer_remained_idxs.append(i)

attention_pruned_model.bert.encoder.layer = torch.nn.ModuleList(module_list)
distil_func = functools.partial(distil_loss_func, encoder_layer_idxs=layer_remained_idxs)
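# %%
# A quick sanity check of the per-iteration sparsity used above (illustrative values only):
# removing ``sparsity_per_iter`` of the remaining FFN channels in each of the 12 iterations
# compounds to the final target ``sparsity``.

_target_sparsity = 0.5                               # e.g. a layer with no pruned heads
_per_iter = 1 - (1 - _target_sparsity) ** (1 / 12)
_remaining_after_12 = (1 - _per_iter) ** 12
assert abs((1 - _remaining_after_12) - _target_sparsity) < 1e-9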
# %%
# Retrain the attention pruned model with distillation.

@@ -407,6 +432,7 @@ else:
    total_steps = 1
    distillation = False

teacher_model = create_finetuned_model()

optimizer = Adam(attention_pruned_model.parameters(), lr=3e-5, eps=1e-8)

def lr_lambda(current_step: int):
@@ -414,25 +440,26 @@
lr_scheduler = LambdaLR(optimizer, lr_lambda)
at_model_save_path = log_dir / 'attention_pruned_model_state.pth'
training(attention_pruned_model, optimizer, fake_criterion, lr_scheduler=lr_scheduler, max_epochs=total_epochs,
         max_steps=total_steps, train_dataloader=train_dataloader, distillation=distillation, teacher_model=teacher_model,
         distil_func=distil_func, log_path=log_dir / 'retraining.log', save_best_model=True, save_path=at_model_save_path,
         evaluation_func=evaluation_func, device=device)

if not dev_mode:
    attention_pruned_model.load_state_dict(torch.load(at_model_save_path))
# %%
# Iterative pruning FFN with TaylorFOWeightPruner in 12 iterations.
# Finetune for 3000 steps after each pruning iteration, then finetune for another 2 epochs after pruning finishes.
#
# NNI will support a per-step pruning schedule in the future, then a pruner can be used to replace the following code.

if not dev_mode:
    total_epochs = 7
    total_steps = None
    taylor_pruner_steps = 1000
    steps_per_iteration = 3000
    total_pruning_steps = 36000
    distillation = True
else:
    total_epochs = 1
@@ -445,8 +472,8 @@ else:
from nni.compression.pytorch.pruning import TaylorFOWeightPruner
from nni.compression.pytorch.speedup import ModelSpeedup

distil_training = functools.partial(training, train_dataloader=train_dataloader, distillation=distillation,
                                    teacher_model=teacher_model, distil_func=distil_func, device=device)
traced_optimizer = nni.trace(Adam)(attention_pruned_model.parameters(), lr=3e-5, eps=1e-8)
evaluator = TorchEvaluator(distil_training, traced_optimizer, fake_criterion)
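# %%
# The schedule above is easy to sanity check (using the non-dev settings): 12 pruning
# iterations spaced ``steps_per_iteration`` apart cover exactly ``total_pruning_steps``
# training steps, and each TaylorFOWeightPruner call estimates importance from roughly
# ``taylor_pruner_steps`` training steps before producing masks.

assert 12 * 3000 == 36000  # 12 iterations * steps_per_iteration == total_pruning_steps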
@@ -461,7 +488,7 @@ for current_epoch in range(total_epochs):
    for batch in train_dataloader:
        if total_steps and current_step >= total_steps:
            break
        # pruning with TaylorFOWeightPruner & reinitialize optimizer
        if current_step % steps_per_iteration == 0 and current_step < total_pruning_steps:
            check_point = attention_pruned_model.state_dict()
            pruner = TaylorFOWeightPruner(attention_pruned_model, ffn_config_list, evaluator, taylor_pruner_steps)
@@ -476,9 +503,6 @@
            optimizer = Adam(attention_pruned_model.parameters(), lr=init_lr)

        batch.to(device)
        # manually schedule lr
        for params_group in optimizer.param_groups:
            params_group['lr'] = (1 - current_step / (total_epochs * steps_per_epoch)) * init_lr
@@ -487,14 +511,19 @@
        loss = outputs.loss

        # distillation
        if distillation:
            assert teacher_model is not None
            with torch.no_grad():
                teacher_outputs = teacher_model(**batch)
            distil_loss = distil_func(outputs, teacher_outputs)
            loss = 0.1 * loss + 0.9 * distil_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        current_step += 1

        if current_step % 1000 == 0 or current_step % len(train_dataloader) == 0:
            result = evaluation_func(attention_pruned_model)
            with (log_dir / 'ffn_pruning.log').open('a+') as f:
@@ -533,28 +562,28 @@
#     - +0.0 / +0.0
#     - 12.56s (x1.00)
#     - 4.05s (x1.00)
#   * - :ref:`movement-pruner` (soft, sparsity=0.1, regular_scale=5)
#     - :ref:`taylor-fo-weight-pruner`
#     - 51.39%
#     - 84.25 / 84.96
#     - -0.48 / +0.33
#     - 6.85s (x1.83)
#     - 2.7s (x1.50)
#   * - :ref:`movement-pruner` (soft, sparsity=0.1, regular_scale=10)
#     - :ref:`taylor-fo-weight-pruner`
#     - 66.67%
#     - 83.98 / 83.75
#     - -0.75 / -0.88
#     - 4.73s (x2.66)
#     - 2.16s (x1.86)
#   * - :ref:`movement-pruner` (soft, sparsity=0.1, regular_scale=20)
#     - :ref:`taylor-fo-weight-pruner`
#     - 77.78%
#     - 83.02 / 83.06
#     - -1.71 / -1.57
#     - 3.35s (x3.75)
#     - 1.72s (x2.35)
#   * - :ref:`movement-pruner` (soft, sparsity=0.1, regular_scale=30)
#     - :ref:`taylor-fo-weight-pruner`
#     - 87.04%
#     - 81.24 / 80.99
@@ -18,8 +18,8 @@
.. _sphx_glr_tutorials_pruning_bert_glue.py:


Pruning Bert on Task MNLI
=========================

Workable Pruning Process
------------------------

@@ -51,13 +51,19 @@
Experiment
----------

The complete pruning process takes about 8 hours on one A100.

Preparation
^^^^^^^^^^^

This section mainly obtains a finetuned model on the downstream task.
If you are familiar with how to finetune Bert on the GLUE dataset, you can skip this section.

.. note::

    Please set ``dev_mode`` to ``False`` to run this tutorial. ``dev_mode`` is ``True`` by default for generating the documents.
.. GENERATED FROM PYTHON SOURCE LINES 48-51
.. code-block:: default

@@ -71,21 +77,21 @@

.. GENERATED FROM PYTHON SOURCE LINES 52-53

Some basic settings.

.. GENERATED FROM PYTHON SOURCE LINES 53-84

.. code-block:: default


    from pathlib import Path
    from typing import Callable, Dict

    pretrained_model_name_or_path = 'bert-base-uncased'
    task_name = 'mnli'
    experiment_id = 'pruning_bert_mnli'

    # heads_num and layers_num should align with pretrained_model_name_or_path
    heads_num = 12
@@ -99,6 +105,11 @@
    model_dir = Path(f'./models/{pretrained_model_name_or_path}/{task_name}')
    model_dir.mkdir(parents=True, exist_ok=True)

    # used to save GLUE data
    data_dir = Path(f'./data')
    data_dir.mkdir(parents=True, exist_ok=True)

    # set seed
    from transformers import set_seed
    set_seed(1024)
@@ -112,12 +123,11 @@
.. GENERATED FROM PYTHON SOURCE LINES 85-86

Create dataloaders.

.. GENERATED FROM PYTHON SOURCE LINES 86-152

.. code-block:: default

@@ -139,8 +149,7 @@
        'wnli': ('sentence1', 'sentence2'),
    }

    def prepare_dataloaders(cache_dir=data_dir, train_batch_size=32, eval_batch_size=32):
        tokenizer = BertTokenizerFast.from_pretrained(pretrained_model_name_or_path)
        sentence1_key, sentence2_key = task_to_keys[task_name]
        data_collator = DataCollatorWithPadding(tokenizer)
@@ -166,137 +175,141 @@
        processed_datasets = raw_datasets.map(preprocess_function, batched=True,
                                              remove_columns=raw_datasets['train'].column_names)

        train_dataset = processed_datasets['train']
        if task_name == 'mnli':
            validation_datasets = {
                'validation_matched': processed_datasets['validation_matched'],
                'validation_mismatched': processed_datasets['validation_mismatched']
            }
        else:
            validation_datasets = {
                'validation': processed_datasets['validation']
            }

        train_dataloader = DataLoader(train_dataset, shuffle=True, collate_fn=data_collator, batch_size=train_batch_size)
        validation_dataloaders = {
            val_name: DataLoader(val_dataset, collate_fn=data_collator, batch_size=eval_batch_size) \
                for val_name, val_dataset in validation_datasets.items()
        }

        return train_dataloader, validation_dataloaders


    train_dataloader, validation_dataloaders = prepare_dataloaders()
.. GENERATED FROM PYTHON SOURCE LINES 153-154

Training function & evaluation function.

.. GENERATED FROM PYTHON SOURCE LINES 154-277

.. code-block:: default
import functools
import time import time
import torch.nn.functional as F import torch.nn.functional as F
from datasets import load_metric from datasets import load_metric
from transformers.modeling_outputs import SequenceClassifierOutput
def training(train_dataloader: DataLoader,
model: torch.nn.Module, def training(model: torch.nn.Module,
optimizer: torch.optim.Optimizer, optimizer: torch.optim.Optimizer,
criterion: Callable[[torch.Tensor, torch.Tensor], torch.Tensor], criterion: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None, lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None,
max_steps: int = None, max_epochs: int = None, max_steps: int = None,
save_best_model: bool = False, save_path: str = None, max_epochs: int = None,
log_path: str = Path(log_dir) / 'training.log', train_dataloader: DataLoader = None,
distillation: bool = False, distillation: bool = False,
evaluation_func=None): teacher_model: torch.nn.Module = None,
distil_func: Callable = None,
log_path: str = Path(log_dir) / 'training.log',
save_best_model: bool = False,
save_path: str = None,
evaluation_func: Callable = None,
eval_per_steps: int = 1000,
device=None):
assert train_dataloader is not None
model.train() model.train()
if teacher_model is not None:
teacher_model.eval()
current_step = 0 current_step = 0
best_result = 0 best_result = 0
for current_epoch in range(max_epochs if max_epochs else 1): total_epochs = max_steps // len(train_dataloader) + 1 if max_steps else max_epochs if max_epochs else 3
total_steps = max_steps if max_steps else total_epochs * len(train_dataloader)
print(f'Training {total_epochs} epochs, {total_steps} steps...')
for current_epoch in range(total_epochs):
for batch in train_dataloader: for batch in train_dataloader:
if current_step >= total_steps:
return
batch.to(device) batch.to(device)
teacher_logits = batch.pop('teacher_logits', None)
optimizer.zero_grad()
outputs = model(**batch) outputs = model(**batch)
loss = outputs.loss loss = outputs.loss
if distillation: if distillation:
assert teacher_logits is not None assert teacher_model is not None
distil_loss = F.kl_div(F.log_softmax(outputs.logits / 2, dim=-1), with torch.no_grad():
F.softmax(teacher_logits / 2, dim=-1), reduction='batchmean') * (2 ** 2) teacher_outputs = teacher_model(**batch)
distil_loss = distil_func(outputs, teacher_outputs)
loss = 0.1 * loss + 0.9 * distil_loss loss = 0.1 * loss + 0.9 * distil_loss
loss = criterion(loss, None) loss = criterion(loss, None)
optimizer.zero_grad()
loss.backward() loss.backward()
optimizer.step() optimizer.step()
# per step schedule
if lr_scheduler: if lr_scheduler:
lr_scheduler.step() lr_scheduler.step()
current_step += 1 current_step += 1
# evaluation for every 1000 steps if current_step % eval_per_steps == 0 or current_step % len(train_dataloader) == 0:
if current_step % 1000 == 0 or current_step % len(train_dataloader) == 0:
result = evaluation_func(model) if evaluation_func else None result = evaluation_func(model) if evaluation_func else None
with (log_path).open('a+') as f: with (log_path).open('a+') as f:
msg = '[{}] Epoch {}, Step {}: {}\n'.format(time.asctime(time.localtime(time.time())), current_epoch, current_step, result) msg = '[{}] Epoch {}, Step {}: {}\n'.format(time.asctime(time.localtime(time.time())), current_epoch, current_step, result)
f.write(msg) f.write(msg)
# if it's the best model, save it. # if it's the best model, save it.
if save_best_model and best_result < result['default']: if save_best_model and (result is None or best_result < result['default']):
assert save_path is not None assert save_path is not None
torch.save(model.state_dict(), save_path) torch.save(model.state_dict(), save_path)
best_result = result['default'] best_result = None if result is None else result['default']
if max_steps and current_step >= max_steps:
return
def evaluation(validation_dataloader: DataLoader, def distil_loss_func(stu_outputs: SequenceClassifierOutput, tea_outputs: SequenceClassifierOutput, encoder_layer_idxs=[]):
validation_dataloader2: DataLoader, encoder_hidden_state_loss = []
model: torch.nn.Module): for i, idx in enumerate(encoder_layer_idxs[:-1]):
encoder_hidden_state_loss.append(F.mse_loss(stu_outputs.hidden_states[i], tea_outputs.hidden_states[idx]))
logits_loss = F.kl_div(F.log_softmax(stu_outputs.logits / 2, dim=-1), F.softmax(tea_outputs.logits / 2, dim=-1), reduction='batchmean') * (2 ** 2)
distil_loss = 0
for loss in encoder_hidden_state_loss:
distil_loss += loss
distil_loss += logits_loss
return distil_loss
def evaluation(model: torch.nn.Module, validation_dataloaders: Dict[str, DataLoader] = None, device=None):
assert validation_dataloaders is not None
training = model.training training = model.training
model.eval() model.eval()
is_regression = task_name == 'stsb' is_regression = task_name == 'stsb'
metric = load_metric('glue', task_name) metric = load_metric('glue', task_name)
result = {}
default_result = 0
for val_name, validation_dataloader in validation_dataloaders.items():
for batch in validation_dataloader: for batch in validation_dataloader:
batch.pop('teacher_logits', None)
batch.to(device)
outputs = model(**batch)
predictions = outputs.logits.argmax(dim=-1) if not is_regression else outputs.logits.squeeze()
metric.add_batch(
predictions=predictions,
references=batch['labels'],
)
result = metric.compute()
if validation_dataloader2:
for batch in validation_dataloader2:
batch.pop('teacher_logits', None)
batch.to(device) batch.to(device)
outputs = model(**batch) outputs = model(**batch)
predictions = outputs.logits.argmax(dim=-1) if not is_regression else outputs.logits.squeeze() predictions = outputs.logits.argmax(dim=-1) if not is_regression else outputs.logits.squeeze()
...@@ -304,19 +317,19 @@ Training function & evaluation function. ...@@ -304,19 +317,19 @@ Training function & evaluation function.
predictions=predictions, predictions=predictions,
references=batch['labels'], references=batch['labels'],
) )
result = {'matched': result, 'mismatched': metric.compute()} result[val_name] = metric.compute()
result['default'] = (result['matched']['accuracy'] + result['mismatched']['accuracy']) / 2 default_result += result[val_name].get('f1', result[val_name].get('accuracy', 0))
else: result['default'] = default_result / len(result)
result['default'] = result.get('f1', result.get('accuracy', None))
model.train(training) model.train(training)
return result return result
# using huggingface native loss
def fake_criterion(outputs, targets): evaluation_func = functools.partial(evaluation, validation_dataloaders=validation_dataloaders, device=device)
return outputs
def fake_criterion(loss, _):
return loss
@@ -324,116 +337,70 @@

.. GENERATED FROM PYTHON SOURCE LINES 278-279

Prepare the pre-trained model and finetune it on the downstream task.

.. GENERATED FROM PYTHON SOURCE LINES 279-320
.. code-block:: default .. code-block:: default
import functools
from torch.optim import Adam from torch.optim import Adam
from torch.optim.lr_scheduler import LambdaLR from torch.optim.lr_scheduler import LambdaLR
from transformers import BertForSequenceClassification from transformers import BertForSequenceClassification
def create_pretrained_model(): def create_pretrained_model():
is_regression = task_name == 'stsb' is_regression = task_name == 'stsb'
num_labels = 1 if is_regression else (3 if task_name == 'mnli' else 2) num_labels = 1 if is_regression else (3 if task_name == 'mnli' else 2)
return BertForSequenceClassification.from_pretrained(pretrained_model_name_or_path, num_labels=num_labels) model = BertForSequenceClassification.from_pretrained(pretrained_model_name_or_path, num_labels=num_labels)
model.bert.config.output_hidden_states = True
return model
def create_finetuned_model():
pretrained_model = create_pretrained_model().to(device)
train_dataloader, validation_dataloader, validation_dataloader2 = prepare_data()
evaluation_func = functools.partial(evaluation, validation_dataloader, validation_dataloader2)
steps_per_epoch = len(train_dataloader)
training_epochs = 3
def create_finetuned_model():
finetuned_model = create_pretrained_model()
finetuned_model_state_path = Path(model_dir) / 'finetuned_model_state.pth' finetuned_model_state_path = Path(model_dir) / 'finetuned_model_state.pth'
if finetuned_model_state_path.exists(): if finetuned_model_state_path.exists():
pretrained_model.load_state_dict(torch.load(finetuned_model_state_path)) finetuned_model.load_state_dict(torch.load(finetuned_model_state_path, map_location='cpu'))
finetuned_model.to(device)
elif dev_mode: elif dev_mode:
pass pass
else: else:
optimizer = Adam(pretrained_model.parameters(), lr=3e-5, eps=1e-8) steps_per_epoch = len(train_dataloader)
training_epochs = 3
optimizer = Adam(finetuned_model.parameters(), lr=3e-5, eps=1e-8)
def lr_lambda(current_step: int): def lr_lambda(current_step: int):
return max(0.0, float(training_epochs * steps_per_epoch - current_step) / float(training_epochs * steps_per_epoch)) return max(0.0, float(training_epochs * steps_per_epoch - current_step) / float(training_epochs * steps_per_epoch))
lr_scheduler = LambdaLR(optimizer, lr_lambda) lr_scheduler = LambdaLR(optimizer, lr_lambda)
training(train_dataloader, pretrained_model, optimizer, fake_criterion, lr_scheduler=lr_scheduler, max_epochs=training_epochs, training(finetuned_model, optimizer, fake_criterion, lr_scheduler=lr_scheduler,
save_best_model=True, save_path=finetuned_model_state_path, evaluation_func=evaluation_func) max_epochs=training_epochs, train_dataloader=train_dataloader, log_path=log_dir / 'finetuning_on_downstream.log',
return pretrained_model save_best_model=True, save_path=finetuned_model_state_path, evaluation_func=evaluation_func, device=device)
return finetuned_model
finetuned_model = create_finetuned_model()
finetuned_model = create_finetuned_model()
.. GENERATED FROM PYTHON SOURCE LINES 321-328

Pruning
^^^^^^^
According to our experience, it is easier to reach good results by pruning the attention part and the FFN part in stages.
Pruning both together can reach a similar result, but it requires more hyper-parameter tuning.
So in this section, we prune the model in two stages.

First, we prune the attention layers with MovementPruner.

.. GENERATED FROM PYTHON SOURCE LINES 328-388
.. code-block:: default .. code-block:: default
...@@ -458,8 +425,9 @@ First, using MovementPruner to prune attention head. ...@@ -458,8 +425,9 @@ First, using MovementPruner to prune attention head.
import nni import nni
from nni.algorithms.compression.v2.pytorch import TorchEvaluator from nni.algorithms.compression.v2.pytorch import TorchEvaluator
movement_training = functools.partial(training, train_dataloader, log_path=log_dir / 'movement_pruning.log', movement_training = functools.partial(training, train_dataloader=train_dataloader,
evaluation_func=evaluation_func) log_path=log_dir / 'movement_pruning.log',
evaluation_func=evaluation_func, device=device)
traced_optimizer = nni.trace(Adam)(finetuned_model.parameters(), lr=3e-5, eps=1e-8) traced_optimizer = nni.trace(Adam)(finetuned_model.parameters(), lr=3e-5, eps=1e-8)
def lr_lambda(current_step: int): def lr_lambda(current_step: int):
...@@ -471,10 +439,16 @@ First, using MovementPruner to prune attention head. ...@@ -471,10 +439,16 @@ First, using MovementPruner to prune attention head.
evaluator = TorchEvaluator(movement_training, traced_optimizer, fake_criterion, traced_scheduler) evaluator = TorchEvaluator(movement_training, traced_optimizer, fake_criterion, traced_scheduler)
# Apply block-soft-movement pruning on attention layers. # Apply block-soft-movement pruning on attention layers.
# Note that block sparse is introduced by `sparse_granularity='auto'`, and only support `bert`, `bart`, `t5` right now.
from nni.compression.pytorch.pruning import MovementPruner from nni.compression.pytorch.pruning import MovementPruner
config_list = [{'op_types': ['Linear'], 'op_partial_names': ['bert.encoder.layer.{}.'.format(i) for i in range(layers_num)], 'sparsity': 0.1}] config_list = [{
'op_types': ['Linear'],
'op_partial_names': ['bert.encoder.layer.{}.attention'.format(i) for i in range(layers_num)],
'sparsity': 0.1
}]
pruner = MovementPruner(model=finetuned_model, pruner = MovementPruner(model=finetuned_model,
config_list=config_list, config_list=config_list,
evaluator=evaluator, evaluator=evaluator,
@@ -493,25 +467,14 @@

.. GENERATED FROM PYTHON SOURCE LINES 389-393

Load a new finetuned model to do speedup; you can think of this as using the finetuned state to initialize the pruned model weights.
Note that nni speedup does not support replacing the attention module, so here we manually replace the attention module.

If a head is entirely masked, physically prune it and create the config_list for FFN pruning.

.. GENERATED FROM PYTHON SOURCE LINES 393-423
.. code-block:: default .. code-block:: default
...@@ -520,66 +483,39 @@ If the head is entire masked, physically prune it and create config_list for FFN ...@@ -520,66 +483,39 @@ If the head is entire masked, physically prune it and create config_list for FFN
attention_masks = torch.load(Path(log_dir) / 'attention_masks.pth') attention_masks = torch.load(Path(log_dir) / 'attention_masks.pth')
ffn_config_list = [] ffn_config_list = []
layer_count = 0 layer_remained_idxs = []
module_list = [] module_list = []
for i in range(0, layers_num): for i in range(0, layers_num):
prefix = f'bert.encoder.layer.{i}.' prefix = f'bert.encoder.layer.{i}.'
value_mask: torch.Tensor = attention_masks[prefix + 'attention.self.value']['weight'] value_mask: torch.Tensor = attention_masks[prefix + 'attention.self.value']['weight']
head_mask = (value_mask.reshape(heads_num, -1).sum(-1) == 0.) head_mask = (value_mask.reshape(heads_num, -1).sum(-1) == 0.)
head_idx = torch.arange(len(head_mask))[head_mask].long().tolist() head_idxs = torch.arange(len(head_mask))[head_mask].long().tolist()
print(f'layer {i} pruner {len(head_idx)} head: {head_idx}') print(f'layer {i} prune {len(head_idxs)} head: {head_idxs}')
if len(head_idx) != heads_num: if len(head_idxs) != heads_num:
attention_pruned_model.bert.encoder.layer[i].attention.prune_heads(head_idx) attention_pruned_model.bert.encoder.layer[i].attention.prune_heads(head_idxs)
module_list.append(attention_pruned_model.bert.encoder.layer[i]) module_list.append(attention_pruned_model.bert.encoder.layer[i])
# The final ffn weight remaining ratio is the half of the attention weight remaining ratio. # The final ffn weight remaining ratio is the half of the attention weight remaining ratio.
# This is just an empirical configuration, you can use any other method to determine this sparsity. # This is just an empirical configuration, you can use any other method to determine this sparsity.
sparsity = 1 - (1 - len(head_idx) / heads_num) * 0.5 sparsity = 1 - (1 - len(head_idxs) / heads_num) * 0.5
# here we use a simple sparsity schedule, we will prune ffn in 12 iterations, each iteration prune `sparsity_per_iter`. # here we use a simple sparsity schedule, we will prune ffn in 12 iterations, each iteration prune `sparsity_per_iter`.
sparsity_per_iter = 1 - (1 - sparsity) ** (1 / heads_num) sparsity_per_iter = 1 - (1 - sparsity) ** (1 / 12)
ffn_config_list.append({'op_names': [f'bert.encoder.layer.{layer_count}.intermediate.dense'], 'sparsity': sparsity_per_iter}) ffn_config_list.append({
layer_count += 1 'op_names': [f'bert.encoder.layer.{len(layer_remained_idxs)}.intermediate.dense'],
'sparsity': sparsity_per_iter
})
layer_remained_idxs.append(i)
attention_pruned_model.bert.encoder.layer = torch.nn.ModuleList(module_list) attention_pruned_model.bert.encoder.layer = torch.nn.ModuleList(module_list)
distil_func = functools.partial(distil_loss_func, encoder_layer_idxs=layer_remained_idxs)
.. GENERATED FROM PYTHON SOURCE LINES 424-425

Retrain the attention pruned model with distillation.

.. GENERATED FROM PYTHON SOURCE LINES 425-451
.. code-block:: default .. code-block:: default
...@@ -593,6 +529,7 @@ Retrain the attention pruned model with distillation. ...@@ -593,6 +529,7 @@ Retrain the attention pruned model with distillation.
total_steps = 1 total_steps = 1
distillation = False distillation = False
teacher_model = create_finetuned_model()
optimizer = Adam(attention_pruned_model.parameters(), lr=3e-5, eps=1e-8) optimizer = Adam(attention_pruned_model.parameters(), lr=3e-5, eps=1e-8)
def lr_lambda(current_step: int): def lr_lambda(current_step: int):
...@@ -600,9 +537,10 @@ Retrain the attention pruned model with distillation. ...@@ -600,9 +537,10 @@ Retrain the attention pruned model with distillation.
lr_scheduler = LambdaLR(optimizer, lr_lambda) lr_scheduler = LambdaLR(optimizer, lr_lambda)
at_model_save_path = log_dir / 'attention_pruned_model_state.pth' at_model_save_path = log_dir / 'attention_pruned_model_state.pth'
training(train_dataloader, attention_pruned_model, optimizer, fake_criterion, lr_scheduler=lr_scheduler, training(attention_pruned_model, optimizer, fake_criterion, lr_scheduler=lr_scheduler, max_epochs=total_epochs,
max_epochs=total_epochs, max_steps=total_steps, save_best_model=True, save_path=at_model_save_path, max_steps=total_steps, train_dataloader=train_dataloader, distillation=distillation, teacher_model=teacher_model,
distillation=distillation, evaluation_func=evaluation_func) distil_func=distil_func, log_path=log_dir / 'retraining.log', save_best_model=True, save_path=at_model_save_path,
evaluation_func=evaluation_func, device=device)
if not dev_mode: if not dev_mode:
attention_pruned_model.load_state_dict(torch.load(at_model_save_path)) attention_pruned_model.load_state_dict(torch.load(at_model_save_path))
...@@ -610,28 +548,24 @@ Retrain the attention pruned model with distillation. ...@@ -610,28 +548,24 @@ Retrain the attention pruned model with distillation.
.. GENERATED FROM PYTHON SOURCE LINES 452-456

Iterative pruning FFN with TaylorFOWeightPruner in 12 iterations.
Finetune for 3000 steps after each pruning iteration, then finetune for another 2 epochs after pruning finishes.

NNI will support a per-step pruning schedule in the future, then a pruner can be used to replace the following code.

.. GENERATED FROM PYTHON SOURCE LINES 456-537
.. code-block:: default .. code-block:: default
if not dev_mode: if not dev_mode:
total_epochs = 4 total_epochs = 7
total_steps = None total_steps = None
taylor_pruner_steps = 1000 taylor_pruner_steps = 1000
steps_per_iteration = 2000 steps_per_iteration = 3000
total_pruning_steps = 24000 total_pruning_steps = 36000
distillation = True distillation = True
else: else:
total_epochs = 1 total_epochs = 1
...@@ -644,8 +578,8 @@ NNI will support per-step-pruning-schedule in the future, then can use an pruner ...@@ -644,8 +578,8 @@ NNI will support per-step-pruning-schedule in the future, then can use an pruner
from nni.compression.pytorch.pruning import TaylorFOWeightPruner from nni.compression.pytorch.pruning import TaylorFOWeightPruner
from nni.compression.pytorch.speedup import ModelSpeedup from nni.compression.pytorch.speedup import ModelSpeedup
distil_training = functools.partial(training, train_dataloader, log_path=log_dir / 'taylor_pruning.log', distil_training = functools.partial(training, train_dataloader=train_dataloader, distillation=distillation,
distillation=distillation, evaluation_func=evaluation_func) teacher_model=teacher_model, distil_func=distil_func, device=device)
traced_optimizer = nni.trace(Adam)(attention_pruned_model.parameters(), lr=3e-5, eps=1e-8) traced_optimizer = nni.trace(Adam)(attention_pruned_model.parameters(), lr=3e-5, eps=1e-8)
evaluator = TorchEvaluator(distil_training, traced_optimizer, fake_criterion) evaluator = TorchEvaluator(distil_training, traced_optimizer, fake_criterion)
...@@ -660,7 +594,7 @@ NNI will support per-step-pruning-schedule in the future, then can use an pruner ...@@ -660,7 +594,7 @@ NNI will support per-step-pruning-schedule in the future, then can use an pruner
for batch in train_dataloader: for batch in train_dataloader:
if total_steps and current_step >= total_steps: if total_steps and current_step >= total_steps:
break break
# pruning 12 times # pruning with TaylorFOWeightPruner & reinitialize optimizer
if current_step % steps_per_iteration == 0 and current_step < total_pruning_steps: if current_step % steps_per_iteration == 0 and current_step < total_pruning_steps:
check_point = attention_pruned_model.state_dict() check_point = attention_pruned_model.state_dict()
pruner = TaylorFOWeightPruner(attention_pruned_model, ffn_config_list, evaluator, taylor_pruner_steps) pruner = TaylorFOWeightPruner(attention_pruned_model, ffn_config_list, evaluator, taylor_pruner_steps)
...@@ -675,9 +609,6 @@ NNI will support per-step-pruning-schedule in the future, then can use an pruner ...@@ -675,9 +609,6 @@ NNI will support per-step-pruning-schedule in the future, then can use an pruner
optimizer = Adam(attention_pruned_model.parameters(), lr=init_lr) optimizer = Adam(attention_pruned_model.parameters(), lr=init_lr)
batch.to(device) batch.to(device)
teacher_logits = batch.pop('teacher_logits', None)
optimizer.zero_grad()
# manually schedule lr # manually schedule lr
for params_group in optimizer.param_groups: for params_group in optimizer.param_groups:
params_group['lr'] = (1 - current_step / (total_epochs * steps_per_epoch)) * init_lr params_group['lr'] = (1 - current_step / (total_epochs * steps_per_epoch)) * init_lr
...@@ -686,14 +617,19 @@ NNI will support per-step-pruning-schedule in the future, then can use an pruner ...@@ -686,14 +617,19 @@ NNI will support per-step-pruning-schedule in the future, then can use an pruner
loss = outputs.loss loss = outputs.loss
# distillation # distillation
if teacher_logits is not None: if distillation:
distil_loss = F.kl_div(F.log_softmax(outputs.logits / 2, dim=-1), assert teacher_model is not None
F.softmax(teacher_logits / 2, dim=-1), reduction='batchmean') * (2 ** 2) with torch.no_grad():
teacher_outputs = teacher_model(**batch)
distil_loss = distil_func(outputs, teacher_outputs)
loss = 0.1 * loss + 0.9 * distil_loss loss = 0.1 * loss + 0.9 * distil_loss
optimizer.zero_grad()
loss.backward() loss.backward()
optimizer.step() optimizer.step()
current_step += 1 current_step += 1
if current_step % 1000 == 0 or current_step % len(train_dataloader) == 0: if current_step % 1000 == 0 or current_step % len(train_dataloader) == 0:
result = evaluation_func(attention_pruned_model) result = evaluation_func(attention_pruned_model)
with (log_dir / 'ffn_pruning.log').open('a+') as f: with (log_dir / 'ffn_pruning.log').open('a+') as f:
...@@ -707,22 +643,7 @@ NNI will support per-step-pruning-schedule in the future, then can use an pruner ...@@ -707,22 +643,7 @@ NNI will support per-step-pruning-schedule in the future, then can use an pruner
.. GENERATED FROM PYTHON SOURCE LINES 538-593

Result
------
@@ -751,28 +672,28 @@ Setting 2: pytorch 1.10.0
     - +0.0 / +0.0
     - 12.56s (x1.00)
     - 4.05s (x1.00)
   * - :ref:`movement-pruner` (soft, sparsity=0.1, regular_scale=5)
     - :ref:`taylor-fo-weight-pruner`
     - 51.39%
     - 84.25 / 84.96
     - -0.48 / +0.33
     - 6.85s (x1.83)
     - 2.7s (x1.50)
   * - :ref:`movement-pruner` (soft, sparsity=0.1, regular_scale=10)
     - :ref:`taylor-fo-weight-pruner`
     - 66.67%
     - 83.98 / 83.75
     - -0.75 / -0.88
     - 4.73s (x2.66)
     - 2.16s (x1.86)
   * - :ref:`movement-pruner` (soft, sparsity=0.1, regular_scale=20)
     - :ref:`taylor-fo-weight-pruner`
     - 77.78%
     - 83.02 / 83.06
     - -1.71 / -1.57
     - 3.35s (x3.75)
     - 1.72s (x2.35)
   * - :ref:`movement-pruner` (soft, sparsity=0.1, regular_scale=30)
     - :ref:`taylor-fo-weight-pruner`
     - 87.04%
     - 81.24 / 80.99
@@ -783,7 +704,7 @@
.. rst-class:: sphx-glr-timing .. rst-class:: sphx-glr-timing
**Total running time of the script:** ( 0 minutes 27.206 seconds) **Total running time of the script:** ( 0 minutes 41.637 seconds)
.. _sphx_glr_download_tutorials_pruning_bert_glue.py: .. _sphx_glr_download_tutorials_pruning_bert_glue.py:
......
...@@ -5,12 +5,12 @@ ...@@ -5,12 +5,12 @@
Computation times Computation times
================= =================
**01:38.004** total execution time for **tutorials** files: **00:41.637** total execution time for **tutorials** files:
+-----------------------------------------------------------------------------------------------------+-----------+--------+ +-----------------------------------------------------------------------------------------------------+-----------+--------+
| :ref:`sphx_glr_tutorials_darts.py` (``darts.py``) | 01:38.004 | 0.0 MB | | :ref:`sphx_glr_tutorials_pruning_bert_glue.py` (``pruning_bert_glue.py``) | 00:41.637 | 0.0 MB |
+-----------------------------------------------------------------------------------------------------+-----------+--------+ +-----------------------------------------------------------------------------------------------------+-----------+--------+
| :ref:`sphx_glr_tutorials_pruning_bert_glue.py` (``pruning_bert_glue.py``) | 00:27.206 | 0.0 MB | | :ref:`sphx_glr_tutorials_darts.py` (``darts.py``) | 00:00.000 | 0.0 MB |
+-----------------------------------------------------------------------------------------------------+-----------+--------+ +-----------------------------------------------------------------------------------------------------+-----------+--------+
| :ref:`sphx_glr_tutorials_hello_nas.py` (``hello_nas.py``) | 00:00.000 | 0.0 MB | | :ref:`sphx_glr_tutorials_hello_nas.py` (``hello_nas.py``) | 00:00.000 | 0.0 MB |
+-----------------------------------------------------------------------------------------------------+-----------+--------+ +-----------------------------------------------------------------------------------------------------+-----------+--------+
......
""" """
Pruning Transformer with NNI Pruning Bert on Task MNLI
============================ =========================
Workable Pruning Process Workable Pruning Process
------------------------ ------------------------
...@@ -32,11 +32,18 @@ During the process of pruning transformer, we gained some of the following exper ...@@ -32,11 +32,18 @@ During the process of pruning transformer, we gained some of the following exper
Experiment Experiment
---------- ----------
The complete pruning process will take about 8 hours on one A100.
Preparation Preparation
^^^^^^^^^^^ ^^^^^^^^^^^
Please set ``dev_mode`` to ``False`` to run this tutorial. Here ``dev_mode`` is ``True`` by default is for generating documents.
The complete pruning process takes about 8 hours on one A100. This section is mainly about getting a finetuned model on the downstream task.
If you are familiar with how to finetune Bert on the GLUE dataset, you can skip this section.
.. note::
    Please set ``dev_mode`` to ``False`` to run this tutorial. Here ``dev_mode`` is ``True`` by default for generating documents.
""" """
dev_mode = True dev_mode = True
...@@ -45,11 +52,11 @@ dev_mode = True ...@@ -45,11 +52,11 @@ dev_mode = True
# Some basic setting. # Some basic setting.
from pathlib import Path from pathlib import Path
from typing import Callable from typing import Callable, Dict
pretrained_model_name_or_path = 'bert-base-uncased' pretrained_model_name_or_path = 'bert-base-uncased'
task_name = 'mnli' task_name = 'mnli'
experiment_id = 'pruning_bert' experiment_id = 'pruning_bert_mnli'
# heads_num and layers_num should align with pretrained_model_name_or_path # heads_num and layers_num should align with pretrained_model_name_or_path
heads_num = 12 heads_num = 12
...@@ -63,6 +70,11 @@ log_dir.mkdir(parents=True, exist_ok=True) ...@@ -63,6 +70,11 @@ log_dir.mkdir(parents=True, exist_ok=True)
model_dir = Path(f'./models/{pretrained_model_name_or_path}/{task_name}') model_dir = Path(f'./models/{pretrained_model_name_or_path}/{task_name}')
model_dir.mkdir(parents=True, exist_ok=True) model_dir.mkdir(parents=True, exist_ok=True)
# used to save GLUE data
data_dir = Path('./data')
data_dir.mkdir(parents=True, exist_ok=True)
# set seed
from transformers import set_seed from transformers import set_seed
set_seed(1024) set_seed(1024)
...@@ -70,8 +82,7 @@ import torch ...@@ -70,8 +82,7 @@ import torch
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# %% # %%
# The function used to create dataloaders, note that 'mnli' has two evaluation dataset. # Create dataloaders.
# If teacher_model is set, will run all dataset on teacher model to get the 'teacher_logits' for distillation.
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
...@@ -90,8 +101,7 @@ task_to_keys = { ...@@ -90,8 +101,7 @@ task_to_keys = {
'wnli': ('sentence1', 'sentence2'), 'wnli': ('sentence1', 'sentence2'),
} }
def prepare_data(cache_dir='./data', train_batch_size=32, eval_batch_size=32, def prepare_dataloaders(cache_dir=data_dir, train_batch_size=32, eval_batch_size=32):
teacher_model: torch.nn.Module = None):
tokenizer = BertTokenizerFast.from_pretrained(pretrained_model_name_or_path) tokenizer = BertTokenizerFast.from_pretrained(pretrained_model_name_or_path)
sentence1_key, sentence2_key = task_to_keys[task_name] sentence1_key, sentence2_key = task_to_keys[task_name]
data_collator = DataCollatorWithPadding(tokenizer) data_collator = DataCollatorWithPadding(tokenizer)
...@@ -117,124 +127,132 @@ def prepare_data(cache_dir='./data', train_batch_size=32, eval_batch_size=32, ...@@ -117,124 +127,132 @@ def prepare_data(cache_dir='./data', train_batch_size=32, eval_batch_size=32,
processed_datasets = raw_datasets.map(preprocess_function, batched=True, processed_datasets = raw_datasets.map(preprocess_function, batched=True,
remove_columns=raw_datasets['train'].column_names) remove_columns=raw_datasets['train'].column_names)
# if has teacher model, add 'teacher_logits' to datasets who has 'labels'. train_dataset = processed_datasets['train']
# 'teacher_logits' is used for distillation and avoid the double counting. if task_name == 'mnli':
if teacher_model: validation_datasets = {
teacher_model_training = teacher_model.training 'validation_matched': processed_datasets['validation_matched'],
teacher_model.eval() 'validation_mismatched': processed_datasets['validation_mismatched']
model_device = next(teacher_model.parameters()).device }
else:
validation_datasets = {
'validation': processed_datasets['validation']
}
def add_teacher_logits(examples): train_dataloader = DataLoader(train_dataset, shuffle=True, collate_fn=data_collator, batch_size=train_batch_size)
result = {k: v for k, v in examples.items()} validation_dataloaders = {
samples = data_collator(result).to(model_device) val_name: DataLoader(val_dataset, collate_fn=data_collator, batch_size=eval_batch_size) \
if 'labels' in samples: for val_name, val_dataset in validation_datasets.items()
with torch.no_grad(): }
logits = teacher_model(**samples).logits.tolist()
result['teacher_logits'] = logits
return result
processed_datasets = processed_datasets.map(add_teacher_logits, batched=True, return train_dataloader, validation_dataloaders
batch_size=train_batch_size)
teacher_model.train(teacher_model_training)
train_dataset = processed_datasets['train']
validation_dataset = processed_datasets['validation_matched' if task_name == 'mnli' else 'validation'] train_dataloader, validation_dataloaders = prepare_dataloaders()
validation_dataset2 = processed_datasets['validation_mismatched'] if task_name == 'mnli' else None
train_dataloader = DataLoader(train_dataset,
shuffle=True,
collate_fn=data_collator,
batch_size=train_batch_size)
validation_dataloader = DataLoader(validation_dataset,
collate_fn=data_collator,
batch_size=eval_batch_size)
validation_dataloader2 = DataLoader(validation_dataset2,
collate_fn=data_collator,
batch_size=eval_batch_size) if task_name == 'mnli' else None
return train_dataloader, validation_dataloader, validation_dataloader2
# %% # %%
# Training function & evaluation function. # Training function & evaluation function.
import functools
import time import time
import torch.nn.functional as F import torch.nn.functional as F
from datasets import load_metric from datasets import load_metric
from transformers.modeling_outputs import SequenceClassifierOutput
def training(train_dataloader: DataLoader,
model: torch.nn.Module, def training(model: torch.nn.Module,
optimizer: torch.optim.Optimizer, optimizer: torch.optim.Optimizer,
criterion: Callable[[torch.Tensor, torch.Tensor], torch.Tensor], criterion: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None, lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None,
max_steps: int = None, max_epochs: int = None, max_steps: int = None,
save_best_model: bool = False, save_path: str = None, max_epochs: int = None,
log_path: str = Path(log_dir) / 'training.log', train_dataloader: DataLoader = None,
distillation: bool = False, distillation: bool = False,
evaluation_func=None): teacher_model: torch.nn.Module = None,
distil_func: Callable = None,
log_path: str = Path(log_dir) / 'training.log',
save_best_model: bool = False,
save_path: str = None,
evaluation_func: Callable = None,
eval_per_steps: int = 1000,
device=None):
assert train_dataloader is not None
model.train() model.train()
if teacher_model is not None:
teacher_model.eval()
current_step = 0 current_step = 0
best_result = 0 best_result = 0
for current_epoch in range(max_epochs if max_epochs else 1): total_epochs = max_steps // len(train_dataloader) + 1 if max_steps else max_epochs if max_epochs else 3
total_steps = max_steps if max_steps else total_epochs * len(train_dataloader)
print(f'Training {total_epochs} epochs, {total_steps} steps...')
for current_epoch in range(total_epochs):
for batch in train_dataloader: for batch in train_dataloader:
if current_step >= total_steps:
return
batch.to(device) batch.to(device)
teacher_logits = batch.pop('teacher_logits', None)
optimizer.zero_grad()
outputs = model(**batch) outputs = model(**batch)
loss = outputs.loss loss = outputs.loss
if distillation: if distillation:
assert teacher_logits is not None assert teacher_model is not None
distil_loss = F.kl_div(F.log_softmax(outputs.logits / 2, dim=-1), with torch.no_grad():
F.softmax(teacher_logits / 2, dim=-1), reduction='batchmean') * (2 ** 2) teacher_outputs = teacher_model(**batch)
distil_loss = distil_func(outputs, teacher_outputs)
loss = 0.1 * loss + 0.9 * distil_loss loss = 0.1 * loss + 0.9 * distil_loss
loss = criterion(loss, None) loss = criterion(loss, None)
optimizer.zero_grad()
loss.backward() loss.backward()
optimizer.step() optimizer.step()
# per step schedule
if lr_scheduler: if lr_scheduler:
lr_scheduler.step() lr_scheduler.step()
current_step += 1 current_step += 1
# evaluation for every 1000 steps if current_step % eval_per_steps == 0 or current_step % len(train_dataloader) == 0:
if current_step % 1000 == 0 or current_step % len(train_dataloader) == 0:
result = evaluation_func(model) if evaluation_func else None result = evaluation_func(model) if evaluation_func else None
with (log_path).open('a+') as f: with (log_path).open('a+') as f:
msg = '[{}] Epoch {}, Step {}: {}\n'.format(time.asctime(time.localtime(time.time())), current_epoch, current_step, result) msg = '[{}] Epoch {}, Step {}: {}\n'.format(time.asctime(time.localtime(time.time())), current_epoch, current_step, result)
f.write(msg) f.write(msg)
# if it's the best model, save it. # if it's the best model, save it.
if save_best_model and best_result < result['default']: if save_best_model and (result is None or best_result < result['default']):
assert save_path is not None assert save_path is not None
torch.save(model.state_dict(), save_path) torch.save(model.state_dict(), save_path)
best_result = result['default'] best_result = None if result is None else result['default']
if max_steps and current_step >= max_steps:
return
def evaluation(validation_dataloader: DataLoader, def distil_loss_func(stu_outputs: SequenceClassifierOutput, tea_outputs: SequenceClassifierOutput, encoder_layer_idxs=[]):
validation_dataloader2: DataLoader, encoder_hidden_state_loss = []
model: torch.nn.Module): for i, idx in enumerate(encoder_layer_idxs[:-1]):
encoder_hidden_state_loss.append(F.mse_loss(stu_outputs.hidden_states[i], tea_outputs.hidden_states[idx]))
logits_loss = F.kl_div(F.log_softmax(stu_outputs.logits / 2, dim=-1), F.softmax(tea_outputs.logits / 2, dim=-1), reduction='batchmean') * (2 ** 2)
distil_loss = 0
for loss in encoder_hidden_state_loss:
distil_loss += loss
distil_loss += logits_loss
return distil_loss
def evaluation(model: torch.nn.Module, validation_dataloaders: Dict[str, DataLoader] = None, device=None):
assert validation_dataloaders is not None
training = model.training training = model.training
model.eval() model.eval()
is_regression = task_name == 'stsb' is_regression = task_name == 'stsb'
metric = load_metric('glue', task_name) metric = load_metric('glue', task_name)
result = {}
default_result = 0
for val_name, validation_dataloader in validation_dataloaders.items():
for batch in validation_dataloader: for batch in validation_dataloader:
batch.pop('teacher_logits', None)
batch.to(device)
outputs = model(**batch)
predictions = outputs.logits.argmax(dim=-1) if not is_regression else outputs.logits.squeeze()
metric.add_batch(
predictions=predictions,
references=batch['labels'],
)
result = metric.compute()
if validation_dataloader2:
for batch in validation_dataloader2:
batch.pop('teacher_logits', None)
batch.to(device) batch.to(device)
outputs = model(**batch) outputs = model(**batch)
predictions = outputs.logits.argmax(dim=-1) if not is_regression else outputs.logits.squeeze() predictions = outputs.logits.argmax(dim=-1) if not is_regression else outputs.logits.squeeze()
...@@ -242,75 +260,71 @@ def evaluation(validation_dataloader: DataLoader, ...@@ -242,75 +260,71 @@ def evaluation(validation_dataloader: DataLoader,
predictions=predictions, predictions=predictions,
references=batch['labels'], references=batch['labels'],
) )
result = {'matched': result, 'mismatched': metric.compute()} result[val_name] = metric.compute()
result['default'] = (result['matched']['accuracy'] + result['mismatched']['accuracy']) / 2 default_result += result[val_name].get('f1', result[val_name].get('accuracy', 0))
else: result['default'] = default_result / len(result)
result['default'] = result.get('f1', result.get('accuracy', None))
model.train(training) model.train(training)
return result return result
# using huggingface native loss
def fake_criterion(outputs, targets):
return outputs
evaluation_func = functools.partial(evaluation, validation_dataloaders=validation_dataloaders, device=device)
def fake_criterion(loss, _):
return loss
# %% # %%
# Prepare pre-trained model and finetuning on downstream task. # Prepare pre-trained model and finetuning on downstream task.
import functools
from torch.optim import Adam from torch.optim import Adam
from torch.optim.lr_scheduler import LambdaLR from torch.optim.lr_scheduler import LambdaLR
from transformers import BertForSequenceClassification from transformers import BertForSequenceClassification
def create_pretrained_model(): def create_pretrained_model():
is_regression = task_name == 'stsb' is_regression = task_name == 'stsb'
num_labels = 1 if is_regression else (3 if task_name == 'mnli' else 2) num_labels = 1 if is_regression else (3 if task_name == 'mnli' else 2)
return BertForSequenceClassification.from_pretrained(pretrained_model_name_or_path, num_labels=num_labels) model = BertForSequenceClassification.from_pretrained(pretrained_model_name_or_path, num_labels=num_labels)
model.bert.config.output_hidden_states = True
return model
def create_finetuned_model():
pretrained_model = create_pretrained_model().to(device)
train_dataloader, validation_dataloader, validation_dataloader2 = prepare_data()
evaluation_func = functools.partial(evaluation, validation_dataloader, validation_dataloader2)
steps_per_epoch = len(train_dataloader)
training_epochs = 3
def create_finetuned_model():
finetuned_model = create_pretrained_model()
finetuned_model_state_path = Path(model_dir) / 'finetuned_model_state.pth' finetuned_model_state_path = Path(model_dir) / 'finetuned_model_state.pth'
if finetuned_model_state_path.exists(): if finetuned_model_state_path.exists():
pretrained_model.load_state_dict(torch.load(finetuned_model_state_path)) finetuned_model.load_state_dict(torch.load(finetuned_model_state_path, map_location='cpu'))
finetuned_model.to(device)
elif dev_mode: elif dev_mode:
pass pass
else: else:
optimizer = Adam(pretrained_model.parameters(), lr=3e-5, eps=1e-8) steps_per_epoch = len(train_dataloader)
training_epochs = 3
optimizer = Adam(finetuned_model.parameters(), lr=3e-5, eps=1e-8)
def lr_lambda(current_step: int): def lr_lambda(current_step: int):
return max(0.0, float(training_epochs * steps_per_epoch - current_step) / float(training_epochs * steps_per_epoch)) return max(0.0, float(training_epochs * steps_per_epoch - current_step) / float(training_epochs * steps_per_epoch))
lr_scheduler = LambdaLR(optimizer, lr_lambda) lr_scheduler = LambdaLR(optimizer, lr_lambda)
training(train_dataloader, pretrained_model, optimizer, fake_criterion, lr_scheduler=lr_scheduler, max_epochs=training_epochs, training(finetuned_model, optimizer, fake_criterion, lr_scheduler=lr_scheduler,
save_best_model=True, save_path=finetuned_model_state_path, evaluation_func=evaluation_func) max_epochs=training_epochs, train_dataloader=train_dataloader, log_path=log_dir / 'finetuning_on_downstream.log',
return pretrained_model save_best_model=True, save_path=finetuned_model_state_path, evaluation_func=evaluation_func, device=device)
return finetuned_model
finetuned_model = create_finetuned_model()
# %% finetuned_model = create_finetuned_model()
# Using finetuned model as teacher model to create dataloader.
# Add 'teacher_logits' to dataset, it is used to do the distillation, it can be seen as a kind of data label.
if not dev_mode:
train_dataloader, validation_dataloader, validation_dataloader2 = prepare_data(teacher_model=finetuned_model)
else:
train_dataloader, validation_dataloader, validation_dataloader2 = prepare_data()
evaluation_func = functools.partial(evaluation, validation_dataloader, validation_dataloader2)
# %% # %%
# Pruning # Pruning
# ^^^^^^^ # ^^^^^^^
# First, using MovementPruner to prune attention head. # According to our experience, it is easier to achieve good results by pruning the attention part and the FFN part in stages.
# Of course, pruning them together can also achieve a similar effect, but more parameter-tuning attempts are required.
# So in this section, we do pruning in stages.
#
# First, we prune the attention layer with MovementPruner.
steps_per_epoch = len(train_dataloader) steps_per_epoch = len(train_dataloader)
...@@ -332,8 +346,9 @@ else: ...@@ -332,8 +346,9 @@ else:
import nni import nni
from nni.algorithms.compression.v2.pytorch import TorchEvaluator from nni.algorithms.compression.v2.pytorch import TorchEvaluator
movement_training = functools.partial(training, train_dataloader, log_path=log_dir / 'movement_pruning.log', movement_training = functools.partial(training, train_dataloader=train_dataloader,
evaluation_func=evaluation_func) log_path=log_dir / 'movement_pruning.log',
evaluation_func=evaluation_func, device=device)
traced_optimizer = nni.trace(Adam)(finetuned_model.parameters(), lr=3e-5, eps=1e-8) traced_optimizer = nni.trace(Adam)(finetuned_model.parameters(), lr=3e-5, eps=1e-8)
def lr_lambda(current_step: int): def lr_lambda(current_step: int):
...@@ -345,10 +360,16 @@ traced_scheduler = nni.trace(LambdaLR)(traced_optimizer, lr_lambda) ...@@ -345,10 +360,16 @@ traced_scheduler = nni.trace(LambdaLR)(traced_optimizer, lr_lambda)
evaluator = TorchEvaluator(movement_training, traced_optimizer, fake_criterion, traced_scheduler) evaluator = TorchEvaluator(movement_training, traced_optimizer, fake_criterion, traced_scheduler)
# Apply block-soft-movement pruning on attention layers. # Apply block-soft-movement pruning on attention layers.
# Note that block sparsity is introduced by `sparse_granularity='auto'`, and it only supports `bert`, `bart`, `t5` right now.
from nni.compression.pytorch.pruning import MovementPruner from nni.compression.pytorch.pruning import MovementPruner
config_list = [{'op_types': ['Linear'], 'op_partial_names': ['bert.encoder.layer.{}.'.format(i) for i in range(layers_num)], 'sparsity': 0.1}] config_list = [{
'op_types': ['Linear'],
'op_partial_names': ['bert.encoder.layer.{}.attention'.format(i) for i in range(layers_num)],
'sparsity': 0.1
}]
pruner = MovementPruner(model=finetuned_model, pruner = MovementPruner(model=finetuned_model,
config_list=config_list, config_list=config_list,
evaluator=evaluator, evaluator=evaluator,
...@@ -365,8 +386,8 @@ pruner.show_pruned_weights() ...@@ -365,8 +386,8 @@ pruner.show_pruned_weights()
torch.save(attention_masks, Path(log_dir) / 'attention_masks.pth') torch.save(attention_masks, Path(log_dir) / 'attention_masks.pth')
# %% # %%
# Load a new finetuned model to do the speedup. # Load a new finetuned model to do the speedup; you can think of this as using the finetuned state to initialize the pruned model weights.
# Note that nni speedup don't support replace attention module, so here we manully replace the attention module. # Note that nni speedup doesn't support replacing the attention module, so here we manually replace the attention module.
# #
# If the head is entirely masked, physically prune it and create the config_list for FFN pruning. # If the head is entirely masked, physically prune it and create the config_list for FFN pruning.
...@@ -374,26 +395,30 @@ attention_pruned_model = create_finetuned_model().to(device) ...@@ -374,26 +395,30 @@ attention_pruned_model = create_finetuned_model().to(device)
attention_masks = torch.load(Path(log_dir) / 'attention_masks.pth') attention_masks = torch.load(Path(log_dir) / 'attention_masks.pth')
ffn_config_list = [] ffn_config_list = []
layer_count = 0 layer_remained_idxs = []
module_list = [] module_list = []
for i in range(0, layers_num): for i in range(0, layers_num):
prefix = f'bert.encoder.layer.{i}.' prefix = f'bert.encoder.layer.{i}.'
value_mask: torch.Tensor = attention_masks[prefix + 'attention.self.value']['weight'] value_mask: torch.Tensor = attention_masks[prefix + 'attention.self.value']['weight']
head_mask = (value_mask.reshape(heads_num, -1).sum(-1) == 0.) head_mask = (value_mask.reshape(heads_num, -1).sum(-1) == 0.)
head_idx = torch.arange(len(head_mask))[head_mask].long().tolist() head_idxs = torch.arange(len(head_mask))[head_mask].long().tolist()
    print(f'layer {i} pruner {len(head_idx)} head: {head_idx}') print(f'layer {i} prune {len(head_idxs)} heads: {head_idxs}')
if len(head_idx) != heads_num: if len(head_idxs) != heads_num:
attention_pruned_model.bert.encoder.layer[i].attention.prune_heads(head_idx) attention_pruned_model.bert.encoder.layer[i].attention.prune_heads(head_idxs)
module_list.append(attention_pruned_model.bert.encoder.layer[i]) module_list.append(attention_pruned_model.bert.encoder.layer[i])
        # The final ffn weight remaining ratio is half of the attention weight remaining ratio. # The final ffn weight remaining ratio is half of the attention weight remaining ratio.
        # This is just an empirical configuration; you can use any other method to determine this sparsity. # This is just an empirical configuration; you can use any other method to determine this sparsity.
sparsity = 1 - (1 - len(head_idx) / heads_num) * 0.5 sparsity = 1 - (1 - len(head_idxs) / heads_num) * 0.5
        # here we use a simple sparsity schedule: we prune the ffn in 12 iterations, pruning `sparsity_per_iter` in each iteration. # here we use a simple sparsity schedule: we prune the ffn in 12 iterations, pruning `sparsity_per_iter` in each iteration.
sparsity_per_iter = 1 - (1 - sparsity) ** (1 / heads_num) sparsity_per_iter = 1 - (1 - sparsity) ** (1 / 12)
ffn_config_list.append({'op_names': [f'bert.encoder.layer.{layer_count}.intermediate.dense'], 'sparsity': sparsity_per_iter}) ffn_config_list.append({
layer_count += 1 'op_names': [f'bert.encoder.layer.{len(layer_remained_idxs)}.intermediate.dense'],
'sparsity': sparsity_per_iter
})
layer_remained_idxs.append(i)
attention_pruned_model.bert.encoder.layer = torch.nn.ModuleList(module_list) attention_pruned_model.bert.encoder.layer = torch.nn.ModuleList(module_list)
distil_func = functools.partial(distil_loss_func, encoder_layer_idxs=layer_remained_idxs)
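
# Sanity check on the sparsity schedule above (an illustrative note, not part of the original tutorial):
# since sparsity_per_iter = 1 - (1 - sparsity) ** (1 / 12), pruning that fraction of the remaining
# channels in each of the 12 iterations leaves (1 - sparsity_per_iter) ** 12 = 1 - sparsity of them,
# i.e. exactly the per-layer target. For example, in a layer with no pruned head, sparsity = 0.5,
# sparsity_per_iter = 1 - 0.5 ** (1 / 12) ≈ 0.0561, and (1 - 0.0561) ** 12 ≈ 0.5 of the FFN weight remains.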
# %% # %%
# Retrain the attention pruned model with distillation. # Retrain the attention pruned model with distillation.
...@@ -407,6 +432,7 @@ else: ...@@ -407,6 +432,7 @@ else:
total_steps = 1 total_steps = 1
distillation = False distillation = False
teacher_model = create_finetuned_model()
optimizer = Adam(attention_pruned_model.parameters(), lr=3e-5, eps=1e-8) optimizer = Adam(attention_pruned_model.parameters(), lr=3e-5, eps=1e-8)
def lr_lambda(current_step: int): def lr_lambda(current_step: int):
...@@ -414,25 +440,26 @@ def lr_lambda(current_step: int): ...@@ -414,25 +440,26 @@ def lr_lambda(current_step: int):
lr_scheduler = LambdaLR(optimizer, lr_lambda) lr_scheduler = LambdaLR(optimizer, lr_lambda)
at_model_save_path = log_dir / 'attention_pruned_model_state.pth' at_model_save_path = log_dir / 'attention_pruned_model_state.pth'
training(train_dataloader, attention_pruned_model, optimizer, fake_criterion, lr_scheduler=lr_scheduler, training(attention_pruned_model, optimizer, fake_criterion, lr_scheduler=lr_scheduler, max_epochs=total_epochs,
max_epochs=total_epochs, max_steps=total_steps, save_best_model=True, save_path=at_model_save_path, max_steps=total_steps, train_dataloader=train_dataloader, distillation=distillation, teacher_model=teacher_model,
distillation=distillation, evaluation_func=evaluation_func) distil_func=distil_func, log_path=log_dir / 'retraining.log', save_best_model=True, save_path=at_model_save_path,
evaluation_func=evaluation_func, device=device)
if not dev_mode: if not dev_mode:
attention_pruned_model.load_state_dict(torch.load(at_model_save_path)) attention_pruned_model.load_state_dict(torch.load(at_model_save_path))
# %% # %%
# Iteratively prune the FFN layers with TaylorFOWeightPruner in 12 iterations. # Iteratively prune the FFN layers with TaylorFOWeightPruner in 12 iterations.
# Finetuning 2000 steps after each iteration, then finetuning 2 epochs after pruning finished. # Finetune for 3000 steps after each pruning iteration, then finetune for 2 epochs after pruning is finished.
# #
# NNI will support per-step pruning schedules in the future; then a pruner can be used to replace the following code. # NNI will support per-step pruning schedules in the future; then a pruner can be used to replace the following code.
if not dev_mode: if not dev_mode:
total_epochs = 4 total_epochs = 7
total_steps = None total_steps = None
taylor_pruner_steps = 1000 taylor_pruner_steps = 1000
steps_per_iteration = 2000 steps_per_iteration = 3000
total_pruning_steps = 24000 total_pruning_steps = 36000
distillation = True distillation = True
else: else:
total_epochs = 1 total_epochs = 1
...@@ -445,8 +472,8 @@ else: ...@@ -445,8 +472,8 @@ else:
from nni.compression.pytorch.pruning import TaylorFOWeightPruner from nni.compression.pytorch.pruning import TaylorFOWeightPruner
from nni.compression.pytorch.speedup import ModelSpeedup from nni.compression.pytorch.speedup import ModelSpeedup
distil_training = functools.partial(training, train_dataloader, log_path=log_dir / 'taylor_pruning.log', distil_training = functools.partial(training, train_dataloader=train_dataloader, distillation=distillation,
distillation=distillation, evaluation_func=evaluation_func) teacher_model=teacher_model, distil_func=distil_func, device=device)
traced_optimizer = nni.trace(Adam)(attention_pruned_model.parameters(), lr=3e-5, eps=1e-8) traced_optimizer = nni.trace(Adam)(attention_pruned_model.parameters(), lr=3e-5, eps=1e-8)
evaluator = TorchEvaluator(distil_training, traced_optimizer, fake_criterion) evaluator = TorchEvaluator(distil_training, traced_optimizer, fake_criterion)
...@@ -461,7 +488,7 @@ for current_epoch in range(total_epochs): ...@@ -461,7 +488,7 @@ for current_epoch in range(total_epochs):
for batch in train_dataloader: for batch in train_dataloader:
if total_steps and current_step >= total_steps: if total_steps and current_step >= total_steps:
break break
# pruning 12 times # pruning with TaylorFOWeightPruner & reinitialize optimizer
if current_step % steps_per_iteration == 0 and current_step < total_pruning_steps: if current_step % steps_per_iteration == 0 and current_step < total_pruning_steps:
check_point = attention_pruned_model.state_dict() check_point = attention_pruned_model.state_dict()
pruner = TaylorFOWeightPruner(attention_pruned_model, ffn_config_list, evaluator, taylor_pruner_steps) pruner = TaylorFOWeightPruner(attention_pruned_model, ffn_config_list, evaluator, taylor_pruner_steps)
...@@ -476,9 +503,6 @@ for current_epoch in range(total_epochs): ...@@ -476,9 +503,6 @@ for current_epoch in range(total_epochs):
optimizer = Adam(attention_pruned_model.parameters(), lr=init_lr) optimizer = Adam(attention_pruned_model.parameters(), lr=init_lr)
batch.to(device) batch.to(device)
teacher_logits = batch.pop('teacher_logits', None)
optimizer.zero_grad()
# manually schedule lr # manually schedule lr
for params_group in optimizer.param_groups: for params_group in optimizer.param_groups:
params_group['lr'] = (1 - current_step / (total_epochs * steps_per_epoch)) * init_lr params_group['lr'] = (1 - current_step / (total_epochs * steps_per_epoch)) * init_lr
...@@ -487,14 +511,19 @@ for current_epoch in range(total_epochs): ...@@ -487,14 +511,19 @@ for current_epoch in range(total_epochs):
loss = outputs.loss loss = outputs.loss
# distillation # distillation
if teacher_logits is not None: if distillation:
distil_loss = F.kl_div(F.log_softmax(outputs.logits / 2, dim=-1), assert teacher_model is not None
F.softmax(teacher_logits / 2, dim=-1), reduction='batchmean') * (2 ** 2) with torch.no_grad():
teacher_outputs = teacher_model(**batch)
distil_loss = distil_func(outputs, teacher_outputs)
loss = 0.1 * loss + 0.9 * distil_loss loss = 0.1 * loss + 0.9 * distil_loss
optimizer.zero_grad()
loss.backward() loss.backward()
optimizer.step() optimizer.step()
current_step += 1 current_step += 1
if current_step % 1000 == 0 or current_step % len(train_dataloader) == 0: if current_step % 1000 == 0 or current_step % len(train_dataloader) == 0:
result = evaluation_func(attention_pruned_model) result = evaluation_func(attention_pruned_model)
with (log_dir / 'ffn_pruning.log').open('a+') as f: with (log_dir / 'ffn_pruning.log').open('a+') as f:
...@@ -533,28 +562,28 @@ for current_epoch in range(total_epochs): ...@@ -533,28 +562,28 @@ for current_epoch in range(total_epochs):
# - +0.0 / +0.0 # - +0.0 / +0.0
# - 12.56s (x1.00) # - 12.56s (x1.00)
# - 4.05s (x1.00) # - 4.05s (x1.00)
# * - :ref:`movement-pruner` (soft, th=0.1, lambda=5) # * - :ref:`movement-pruner` (soft, sparsity=0.1, regular_scale=5)
# - :ref:`taylor-fo-weight-pruner` # - :ref:`taylor-fo-weight-pruner`
# - 51.39% # - 51.39%
# - 84.25 / 84.96 # - 84.25 / 84.96
# - -0.48 / +0.33 # - -0.48 / +0.33
# - 6.85s (x1.83) # - 6.85s (x1.83)
# - 2.7s (x1.50) # - 2.7s (x1.50)
# * - :ref:`movement-pruner` (soft, th=0.1, lambda=10) # * - :ref:`movement-pruner` (soft, sparsity=0.1, regular_scale=10)
# - :ref:`taylor-fo-weight-pruner` # - :ref:`taylor-fo-weight-pruner`
# - 66.67% # - 66.67%
# - 83.98 / 83.75 # - 83.98 / 83.75
# - -0.75 / -0.88 # - -0.75 / -0.88
# - 4.73s (x2.66) # - 4.73s (x2.66)
# - 2.16s (x1.86) # - 2.16s (x1.86)
# * - :ref:`movement-pruner` (soft, th=0.1, lambda=20) # * - :ref:`movement-pruner` (soft, sparsity=0.1, regular_scale=20)
# - :ref:`taylor-fo-weight-pruner` # - :ref:`taylor-fo-weight-pruner`
# - 77.78% # - 77.78%
# - 83.02 / 83.06 # - 83.02 / 83.06
# - -1.71 / -1.57 # - -1.71 / -1.57
# - 3.35s (x3.75) # - 3.35s (x3.75)
# - 1.72s (x2.35) # - 1.72s (x2.35)
# * - :ref:`movement-pruner` (soft, th=0.1, lambda=30) # * - :ref:`movement-pruner` (soft, sparsity=0.1, regular_scale=30)
# - :ref:`taylor-fo-weight-pruner` # - :ref:`taylor-fo-weight-pruner`
# - 87.04% # - 87.04%
# - 81.24 / 80.99 # - 81.24 / 80.99
......
...@@ -22,7 +22,7 @@ class NaiveQuantizer(Quantizer): ...@@ -22,7 +22,7 @@ class NaiveQuantizer(Quantizer):
config_list : List[Dict] config_list : List[Dict]
List of configurations for quantization. Supported keys: List of configurations for quantization. Supported keys:
- quant_types : List[str] - quant_types : List[str]
Type of quantization you want to apply, currently support 'weight', 'input', 'output'. Type of quantization you want to apply; currently only 'weight' is supported.
- quant_bits : Union[int, Dict[str, int]] - quant_bits : Union[int, Dict[str, int]]
Bits length of quantization, key is the quantization type, value is the length, eg. {'weight': 8}, Bits length of quantization, key is the quantization type, value is the length, eg. {'weight': 8},
when the type is int, all quantization types share same bits length. when the type is int, all quantization types share same bits length.
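
A minimal ``config_list`` along these lines might look as follows (an illustrative sketch; the ``op_types`` selector follows the common NNI compression config format and is not shown in this excerpt):

.. code-block:: python

    config_list = [{
        'quant_types': ['weight'],     # only 'weight' is currently supported
        'quant_bits': {'weight': 8},   # 8-bit weights; a plain int such as 8 would apply to all quant types
        'op_types': ['Linear'],        # assumed module selector; adjust to your model
    }]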
......
...@@ -228,27 +228,6 @@ class QAT_Quantizer(Quantizer): ...@@ -228,27 +228,6 @@ class QAT_Quantizer(Quantizer):
'quant_dtype': 'uint', 'quant_dtype': 'uint',
'quant_scheme': 'per_tensor_affine' 'quant_scheme': 'per_tensor_affine'
}] }]
**Multi-GPU training**
QAT quantizer natively supports multi-gpu training (DataParallel and DistributedDataParallel). Note that the quantizer
instantiation should happen before you wrap your model with DataParallel or DistributedDataParallel. For example:
.. code-block:: python
from torch.nn.parallel import DistributedDataParallel as DDP
from nni.algorithms.compression.pytorch.quantization import QAT_Quantizer
model = define_your_model()
model = QAT_Quantizer(model, **other_params) # <--- QAT_Quantizer instantiation
model = DDP(model)
for i in range(epochs):
train(model)
eval(model)
""" """
def __init__(self, model, config_list, optimizer, dummy_input=None): def __init__(self, model, config_list, optimizer, dummy_input=None):
......
...@@ -175,8 +175,8 @@ class EvaluatorBasedPruner(BasicPruner): ...@@ -175,8 +175,8 @@ class EvaluatorBasedPruner(BasicPruner):
else: else:
self.optimizer_helper = OptimizerConstructHelper.from_trace(model, traced_optimizer) self.optimizer_helper = OptimizerConstructHelper.from_trace(model, traced_optimizer)
self.using_evaluator = False self.using_evaluator = False
warn_msg = f"The old API ...{','.join(old_api)} will be deprecated after NNI v3.0, " + \ warn_msg = f"The old API {','.join(old_api)} will be deprecated after NNI v3.0, " + \
"please using the new one ...{','.join(new_api)}" f"please using the new one {','.join(new_api)}"
_logger.warning(warn_msg) _logger.warning(warn_msg)
return init_kwargs return init_kwargs
......
...@@ -760,7 +760,10 @@ class TorchEvaluator(Evaluator): ...@@ -760,7 +760,10 @@ class TorchEvaluator(Evaluator):
def evaluate(self) -> float | None | Tuple[float, Dict[str, Any]] | Tuple[None, Dict[str, Any]]: def evaluate(self) -> float | None | Tuple[float, Dict[str, Any]] | Tuple[None, Dict[str, Any]]:
assert self.model is not None assert self.model is not None
assert self.evaluating_func is not None if self.evaluating_func is None:
            warn_msg = f'Did not pass evaluation_func to {self.__class__.__name__}, calling evaluate() will return None'
_logger.warning(warn_msg)
return None
metric = self.evaluating_func(self.model) metric = self.evaluating_func(self.model)
if isinstance(metric, dict): if isinstance(metric, dict):
nni_used_metric = metric.get('default', None) nni_used_metric = metric.get('default', None)
......