Commit bf760c80 authored by Victor Sanh

finish README

parent 9d7d9b3a
# Movement Pruning: Adaptive Sparsity by Fine-Tuning
*Magnitude pruning is a widely used strategy for reducing model size in pure supervised learning; however, it is less effective in the transfer learning regime that has become standard for state-of-the-art natural language processing applications. We propose the use of movement pruning, a simple, deterministic first-order weight pruning method that is more adaptive to pretrained model fine-tuning. Experiments show that when pruning large pretrained language models, movement pruning shows significant improvements in high-sparsity regimes. When combined with distillation, the approach achieves minimal accuracy loss with down to only 3% of the model parameters:*
| Fine-pruning+Distillation<br>(Teacher=BERT-base fine-tuned) | BERT base<br>fine-tuned | Remaining<br>Weights (%) | Magnitude Pruning | L0 Regularization | Movement Pruning | Soft Movement Pruning |
| :---: | :---: | :---: | :---: | :---: | :---: | :---: |
| SQuAD - Dev<br>EM/F1 | 80.4/88.1 | 10%<br>3% | 70.2/80.1<br>45.5/59.6 | 72.4/81.9<br>64.3/75.8 | 75.6/84.3<br>67.5/78.0 | **76.6/84.9**<br>**72.7/82.3** |
| MNLI - Dev<br>acc/MM acc | 84.5/84.9 | 10%<br>3% | 78.3/79.3<br>69.4/70.6 | 78.7/79.7<br>76.0/76.2 | 80.1/80.4<br>76.5/77.4 | **81.2/81.8**<br>**79.5/80.1** |
| QQP - Dev<br>acc/F1 | 91.4/88.4 | 10%<br>3% | 79.8/65.0<br>72.4/57.8 | 88.1/82.8<br>87.0/81.9 | 89.7/86.2<br>86.1/81.5 | **90.2/86.8**<br>**89.1/85.5** |
This page contains information on how to fine-prune pre-trained models such as `BERT` to obtain extremely sparse models with movement pruning. In contrast to magnitude pruning which selects weights that are far from 0, movement pruning retains weights that are moving away from 0.
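As a toy illustration of the difference (a sketch for intuition, not the repository's implementation), the two selection criteria can be contrasted as follows: magnitude pruning keeps the weights with the largest absolute value, while movement pruning keeps the weights with the largest accumulated importance score `S = -∑ (∂L/∂W) · W`, which is high precisely for weights that fine-tuning pushes away from 0. The weights, gradients, and helper names below are made up for the example:

```python
import numpy as np

def magnitude_mask(weights, keep_frac):
    # keep the weights with the largest absolute value
    k = int(round(keep_frac * weights.size))
    cutoff = np.sort(np.abs(weights).ravel())[-k]
    return np.abs(weights) >= cutoff

def movement_mask(scores, keep_frac):
    # keep the weights with the largest movement score
    k = int(round(keep_frac * scores.size))
    cutoff = np.sort(scores.ravel())[-k]
    return scores >= cutoff

# toy weights and (made-up) gradients for a single SGD step
w = np.array([0.9, -0.8, 0.05, -0.1])
g = np.array([0.5, -0.4, -0.3, 0.2])
s = -w * g  # one-step movement score: positive when the update grows |w|

mag = magnitude_mask(w, 0.5)  # keeps the two largest weights
mov = movement_mask(s, 0.5)   # keeps the two weights moving away from 0
```

Here magnitude pruning keeps the two large pre-trained weights, while movement pruning keeps the two small weights that the gradients are pushing away from 0; in the actual method the scores are accumulated (or learned) over the whole fine-tuning run.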
For more information, we invite you to check out [our paper](https://arxiv.org/abs/2005.07683).
You can also have a look at this fun *Explain Like I'm Five* introductory [slide deck](https://www.slideshare.net/VictorSanh/movement-pruning-explain-like-im-five-234205241).
<div align="center">
<img src="https://www.seekpng.com/png/detail/166-1669328_how-to-make-emmental-cheese-at-home-icooker.png" width="400">
</div>
## Extreme sparsity and efficient storage
One promise of extreme pruning is to obtain extremely small models that can be easily sent (and stored) on edge devices. By setting weights to 0, we reduce the amount of information we need to store, and thus decrease the memory size. We are able to obtain extremely sparse fine-pruned models with movement pruning: ~95% of the dense performance with ~5% of total remaining weights in the BERT encoder.
In [this notebook](https://github.com/huggingface/transformers/blob/master/examples/movement-pruning/Saving_PruneBERT.ipynb), we showcase how we can leverage standard tools that exist out-of-the-box to efficiently store an extremely sparse question answering model (only 6% of total remaining weights in the encoder). We are able to reduce the memory size of the encoder **from 340MB (the original dense BERT) to 11MB**, without any additional training of the model (every operation is performed *post fine-pruning*). It is sufficiently small to store it on a [91' floppy disk](https://en.wikipedia.org/wiki/Floptical) 📎!
While movement pruning does not directly optimize for memory footprint (but rather the number of non-null weights), we hypothesize that further memory compression ratios can be achieved with specific quantization-aware training (see for instance [Q8BERT](https://arxiv.org/abs/1910.06188), [And the Bit Goes Down](https://arxiv.org/abs/1907.05686) or [Quant-Noise](https://arxiv.org/abs/2004.07320)).
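As a rough back-of-the-envelope check (a sketch, not the notebook's exact recipe, which additionally quantizes the surviving weights), storing only the non-null entries of a ~6%-dense matrix in a CSR-like layout already shrinks it substantially; the matrix shape and density here are illustrative:

```python
import numpy as np

rng = np.random.default_rng(0)
dense = rng.standard_normal((768, 768)).astype(np.float32)
dense[rng.random(dense.shape) >= 0.06] = 0.0  # keep ~6% of the weights

# CSR-style cost: one fp32 value + one int32 column index per non-zero,
# plus one int32 row pointer per row (+1)
nnz = np.count_nonzero(dense)
dense_bytes = dense.nbytes
csr_bytes = nnz * (4 + 4) + (dense.shape[0] + 1) * 4

print(f"{dense_bytes / csr_bytes:.1f}x smaller")  # roughly 8x at ~6% density
```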
## Fine-pruned models
As examples, we release two English PruneBERT checkpoints (models fine-pruned from a pre-trained `BERT` checkpoint), one on SQuAD and the other on MNLI.
- **`prunebert-base-uncased-6-finepruned-w-distil-squad`**<br/>
Pre-trained `BERT-base-uncased` fine-pruned with soft movement pruning on SQuAD v1.1. We use an additional distillation signal from `BERT-base-uncased` finetuned on SQuAD. The encoder counts 6% of total non-null weights and reaches 83.8 F1 score. The model can be accessed with: `pruned_bert = BertForQuestionAnswering.from_pretrained("huggingface/prunebert-base-uncased-6-finepruned-w-distil-squad")`
- **`prunebert-base-uncased-6-finepruned-w-distil-mnli`**<br/>
Pre-trained `BERT-base-uncased` fine-pruned with soft movement pruning on MNLI. We use an additional distillation signal from `BERT-base-uncased` finetuned on MNLI. The encoder counts 6% of total non-null weights and reaches 80.7 (matched) accuracy. The model can be accessed with: `pruned_bert = BertForSequenceClassification.from_pretrained("huggingface/prunebert-base-uncased-6-finepruned-w-distil-mnli")`
## How to fine-prune?

### Setup
The code relies on the 🤗 Transformers library. In addition to the dependencies listed in the [`examples`](https://github.com/huggingface/transformers/tree/master/examples) folder, you should install a few additional dependencies listed in the `requirements.txt` file: `pip install -r requirements.txt`.
Note that we built our experiments on top of a stabilized version of the library (commit https://github.com/huggingface/transformers/commit/352d5472b0c1dec0f420d606d16747d851b4bda8): we do not guarantee that everything is still compatible with the latest version of the master branch.
### Fine-pruning with movement pruning

Below, we detail how to reproduce the results reported in the paper. We use SQuAD as a running example. Commands (and scripts) can be easily adapted for other tasks.
The following command fine-prunes a pre-trained `BERT-base` on SQuAD using movement pruning towards 15% of remaining weights (85% sparsity). Note that we freeze all the embeddings modules (at their pre-trained values) and only prune the Fully Connected layers in the encoder (the 12 Transformer blocks).
```bash
SERIALIZATION_DIR=<OUTPUT_DIR>
SQUAD_DATA=<SQUAD_DATA>

python examples/movement-pruning/masked_run_squad.py \
    --output_dir $SERIALIZATION_DIR \
    --data_dir $SQUAD_DATA \
    --train_file train-v1.1.json \
    --predict_file dev-v1.1.json \
    --warmup_steps 5400 \
    --num_train_epochs 10 \
    --learning_rate 3e-5 --mask_scores_learning_rate 1e-2 \
    --initial_threshold 1 --final_threshold 0.15 \
    --initial_warmup 1 --final_warmup 2 \
    --pruning_method topK --mask_init constant --mask_scale 0.
```
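The `--initial_threshold`/`--final_threshold` flags, together with the warmup flags, control how the proportion of remaining weights is annealed during fine-pruning. A minimal sketch of a cubic schedule in this spirit (a hypothetical helper mirroring the flag names above; the exact implementation lives in `masked_run_squad.py`): the threshold is held at its initial value during a warmup phase, annealed cubically, then held at its final value:

```python
def schedule_threshold(step, total_steps, warmup_steps,
                       initial_threshold, final_threshold,
                       initial_warmup, final_warmup):
    # hold the initial threshold, anneal cubically, then hold the final one
    if step <= initial_warmup * warmup_steps:
        return initial_threshold
    if step > total_steps - final_warmup * warmup_steps:
        return final_threshold
    anneal_start = initial_warmup * warmup_steps
    anneal_span = total_steps - (initial_warmup + final_warmup) * warmup_steps
    coeff = 1 - (step - anneal_start) / anneal_span
    return final_threshold + (initial_threshold - final_threshold) * coeff ** 3
```

With the flags above (`--initial_threshold 1 --final_threshold 0.15 --warmup_steps 5400 --initial_warmup 1 --final_warmup 2`), such a schedule starts at 1 (no pruning) and lands at 0.15 (15% of remaining weights).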
We can also explore other fine-pruning methods by changing the `pruning_method` parameter:
Soft movement pruning
```bash
python examples/movement-pruning/masked_run_squad.py \
    --output_dir $SERIALIZATION_DIR \
    --data_dir $SQUAD_DATA \
    --train_file train-v1.1.json \
    --predict_file dev-v1.1.json \
    --initial_threshold 0 --final_threshold 0.1 \
    --initial_warmup 1 --final_warmup 2 \
    --pruning_method sigmoied_threshold --mask_init constant --mask_scale 0. \
    --regularization l1 --final_lambda 400.
```
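In soft movement pruning, the mask is not a fixed top-k: a weight survives when the sigmoid of its learned score clears a global threshold (the `sigmoied_threshold` method), and the `--regularization l1 --final_lambda` penalty pushes the scores down so that sparsity emerges from the objective itself. A minimal numpy sketch of the two ingredients (hypothetical helpers, for intuition only):

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def soft_movement_mask(scores, threshold=0.1):
    # a weight is kept when sigmoid(score) clears the global threshold
    return sigmoid(scores) > threshold

def l1_penalty(scores, final_lambda=400.0):
    # penalizing sigmoid(scores) drives the scores down, i.e. the mask towards empty
    return final_lambda * sigmoid(scores).mean()
```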
L0 regularization
```bash
python examples/movement-pruning/masked_run_squad.py \
    --output_dir $SERIALIZATION_DIR \
    --data_dir $SQUAD_DATA \
    --train_file train-v1.1.json \
    --predict_file dev-v1.1.json \
    --initial_threshold 1. --final_threshold 1. \
    --initial_warmup 1 --final_warmup 1 \
    --pruning_method l0 --mask_init constant --mask_scale 2.197 \
    --regularization l0 --final_lambda 125.
```
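The `l0` method learns stochastic gates following the hard-concrete distribution of [Louizos et al., 2018](https://arxiv.org/abs/1712.01312). A minimal numpy sketch of such a gate (a hypothetical helper with the usual `gamma`/`zeta` stretching constants; note that `--mask_scale 2.197` initializes the scores so the gates start nearly open, since sigmoid(2.197) ≈ 0.9):

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def hard_concrete_gate(loga, u=None, beta=2.0 / 3.0, gamma=-0.1, zeta=1.1):
    # hard-concrete gate: stochastic when uniform noise u in (0, 1) is given,
    # deterministic (expected gate) otherwise
    if u is None:
        s = sigmoid(loga)
    else:
        s = sigmoid((np.log(u) - np.log(1.0 - u) + loga) / beta)
    # stretch to (gamma, zeta) and clip to [0, 1] so exact 0s and 1s have mass
    return np.clip(s * (zeta - gamma) + gamma, 0.0, 1.0)
```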
Iterative Magnitude Pruning
```bash
python examples/movement-pruning/masked_run_squad.py \
    --output_dir $SERIALIZATION_DIR \
    --data_dir $SQUAD_DATA \
    --train_file train-v1.1.json \
    --predict_file dev-v1.1.json \
    --warmup_steps 5400 \
    --num_train_epochs 10 \
    --learning_rate 3e-5 \
    --initial_threshold 1 --final_threshold 0.15 \
    --initial_warmup 1 --final_warmup 2 \
    --pruning_method magnitude
```
**Counting parameters**
Regularization-based pruning methods (soft movement pruning and L0 regularization) rely on the penalty to induce sparsity. The multiplicative coefficient controls the sparsity level.
To obtain the effective sparsity level in the encoder, we simply count the number of activated (non-null) weights:
```bash
python examples/movement-pruning/count_parameters.py \
    --pruning_method sigmoied_threshold \
    --threshold 0.1 \
    --serialization_dir $SERIALIZATION_DIR
```
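The quantity being measured is simply the fraction of non-null entries across the pruned matrices; a minimal sketch (a hypothetical helper, not the script itself):

```python
import numpy as np

def remaining_weights(matrices):
    # fraction of non-null (activated) weights across all pruned matrices
    nnz = sum(np.count_nonzero(m) for m in matrices)
    total = sum(m.size for m in matrices)
    return nnz / total
```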
**Pruning once for all**
Once the model has been fine-pruned, the pruned weights can be set to 0 once and for all (reducing the amount of information to store). In our running experiments, we can convert a `MaskedBertForQuestionAnswering` (a BERT model augmented to enable on-the-fly pruning capabilities) to a standard `BertForQuestionAnswering`:
```bash
python examples/movement-pruning/bertarize.py \
    --pruning_method sigmoied_threshold \
    --threshold 0.1 \
    --model_name_or_path $SERIALIZATION_DIR
```
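Conceptually, for the `sigmoied_threshold` method, this conversion bakes the mask into the weights: every weight whose `sigmoid(score)` falls below the threshold is set to 0, and the scores are then discarded. A minimal sketch of that operation on a single matrix (a hypothetical helper, not the script itself):

```python
import numpy as np

def bake_in_mask(weight, scores, threshold=0.1):
    # zero out pruned weights once and for all; no scores needed at inference
    keep = 1.0 / (1.0 + np.exp(-scores)) > threshold
    return np.where(keep, weight, 0.0)
```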
## Hyper-parameters
For reproducibility purposes, we share the detailed results presented in the paper. These [tables](https://docs.google.com/spreadsheets/d/17JgRq_OFFTniUrz6BZWW_87DjFkKXpI1kYDSsseT_7g/edit?usp=sharing) exhaustively describe the individual hyper-parameters used for each data point.
## Inference speed