Commit 86a63070 authored by erenup

Merge branch 'huggingface/master'

parents b5d73976 82f6abd9
@@ -5,11 +5,35 @@ similar API between the different models.
| Section | Description |
|----------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------|
| [TensorFlow 2.0 models on GLUE](#tensorflow-20-bert-models-on-glue) | Examples running BERT TensorFlow 2.0 model on the GLUE tasks. |
| [Language Model fine-tuning](#language-model-fine-tuning) | Fine-tuning the library models for language modeling on a text dataset. Causal language modeling for GPT/GPT-2, masked language modeling for BERT/RoBERTa. |
| [Language Generation](#language-generation) | Conditional text generation using the auto-regressive models of the library: GPT, GPT-2, Transformer-XL and XLNet. |
| [GLUE](#glue) | Examples running BERT/XLM/XLNet/RoBERTa on the 9 GLUE tasks. Examples feature distributed training as well as half-precision. |
| [SQuAD](#squad) | Using BERT/XLM/XLNet/RoBERTa for question answering, examples with distributed training. |
| [Multiple Choice](#multiple-choice) | Examples running BERT/XLNet/RoBERTa on the SWAG/RACE/ARC tasks. |
| [Named Entity Recognition](#named-entity-recognition) | Using BERT for Named Entity Recognition (NER) on the CoNLL 2003 dataset, examples with distributed training. |
## TensorFlow 2.0 Bert models on GLUE
Based on the script [`run_tf_glue.py`](https://github.com/huggingface/transformers/blob/master/examples/run_tf_glue.py).
Fine-tuning the library TensorFlow 2.0 Bert model for sequence classification on the MRPC task of the GLUE benchmark: [General Language Understanding Evaluation](https://gluebenchmark.com/).
This script has options for mixed precision (Automatic Mixed Precision / AMP), which runs models on Tensor Cores (NVIDIA Volta/Turing GPUs) and future hardware, and for XLA, which uses the XLA compiler to reduce model runtime.
These options are toggled with the `USE_XLA` and `USE_AMP` variables in the script.
These options and the benchmarks below were contributed by @tlkh.
Quick benchmarks from the script (no other modifications):
| GPU | Mode | Time (2nd epoch) | Val Acc (3 runs) |
| --------- | -------- | ----------------------- | ----------------------|
| Titan V | FP32 | 41s | 0.8438/0.8281/0.8333 |
| Titan V | AMP | 26s | 0.8281/0.8568/0.8411 |
| V100 | FP32 | 35s | 0.8646/0.8359/0.8464 |
| V100 | AMP | 22s | 0.8646/0.8385/0.8411 |
| 1080 Ti | FP32 | 55s | - |
Mixed precision (AMP) reduces the training time considerably for the same hardware and hyper-parameters (same batch size was used).
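
For reference, the two toggles map onto standard TensorFlow 2.0 APIs roughly as follows (a minimal sketch; see the script itself for the exact wiring):

```python
import tensorflow as tf

USE_XLA = False  # compile the model with the XLA compiler to fuse kernels
USE_AMP = True   # run eligible float32 ops in float16 on Tensor Cores

tf.config.optimizer.set_jit(USE_XLA)
tf.config.optimizer.set_experimental_options({"auto_mixed_precision": USE_AMP})
```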
## Language model fine-tuning
@@ -283,17 +307,17 @@ The results are the following:
loss = 0.04755385363816904
```
## Multiple Choice

Based on the script [`run_multiple_choice.py`](https://github.com/huggingface/transformers/blob/master/examples/run_multiple_choice.py).

#### Fine-tuning on SWAG

Download the [SWAG](https://github.com/rowanz/swagaf/tree/master/data) data.
```bash
# training on 4 Tesla V100 (16GB) GPUs
export SWAG_DIR=/path/to/swag_data_dir
python ./examples/run_multiple_choice.py \
    --model_type roberta \
    --task_name swag \
    --model_name_or_path roberta-base \
@@ -390,3 +414,103 @@ exact_match = 86.91
This fine-tuned model is available as a checkpoint under the reference
`bert-large-uncased-whole-word-masking-finetuned-squad`.
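
The checkpoint can be loaded like any other pretrained model. A minimal usage sketch (the greedy start/end decoding below is a simplification of the official evaluation logic):

```python
import torch
from transformers import BertTokenizer, BertForQuestionAnswering

name = 'bert-large-uncased-whole-word-masking-finetuned-squad'
tokenizer = BertTokenizer.from_pretrained(name)
model = BertForQuestionAnswering.from_pretrained(name)

question = "Who maintains the transformers library?"
context = "The transformers library is maintained by Hugging Face."
input_ids = tokenizer.encode(question, context, add_special_tokens=True)
start_logits, end_logits = model(torch.tensor([input_ids]))

# Greedy decoding of the most likely answer span.
start, end = int(start_logits.argmax()), int(end_logits.argmax())
print(tokenizer.decode(input_ids[start:end + 1]))
```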
## Named Entity Recognition
Based on the script [`run_ner.py`](https://github.com/huggingface/transformers/blob/master/examples/run_ner.py).
This example fine-tunes BERT Multilingual on GermEval 2014 (German NER). Details and results for the fine-tuning were provided by @stefan-it.
### Data (Download and pre-processing steps)
Data can be obtained from the [GermEval 2014](https://sites.google.com/site/germeval2014ner/data) shared task page.
Here are the commands for downloading and pre-processing the train, dev and test datasets. The original data format has four tab-separated columns; in a pre-processing step, only the two relevant columns (token and outer span NER annotation) are extracted:
```bash
curl -L 'https://sites.google.com/site/germeval2014ner/data/NER-de-train.tsv?attredirects=0&d=1' \
| grep -v "^#" | cut -f 2,3 | tr '\t' ' ' > train.txt.tmp
curl -L 'https://sites.google.com/site/germeval2014ner/data/NER-de-dev.tsv?attredirects=0&d=1' \
| grep -v "^#" | cut -f 2,3 | tr '\t' ' ' > dev.txt.tmp
curl -L 'https://sites.google.com/site/germeval2014ner/data/NER-de-test.tsv?attredirects=0&d=1' \
| grep -v "^#" | cut -f 2,3 | tr '\t' ' ' > test.txt.tmp
```
The GermEval 2014 dataset contains some strange "control character" tokens like `'\x96', '\u200e', '\x95', '\xad' or '\x80'`. One problem with these tokens is that `BertTokenizer` returns an empty token for them, resulting in misaligned `InputExample`s. I wrote a script that a) filters these tokens and b) splits longer sentences into smaller ones (once the max. subtoken length is reached).
```bash
wget "https://raw.githubusercontent.com/stefan-it/fine-tuned-berts-seq/master/scripts/preprocess.py"
```
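
Conceptually, the script walks over the token/label pairs, drops tokens for which the tokenizer produces no subwords, and starts a new sentence once the subtoken budget is spent. A simplified sketch of that logic (not the script itself):

```python
from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-multilingual-cased')
MAX_LENGTH = 128

subword_len = 0
with open('train.txt.tmp') as f_in, open('train.txt', 'w') as f_out:
    for line in f_in:
        line = line.strip()
        if not line:  # empty line = sentence boundary
            f_out.write('\n')
            subword_len = 0
            continue
        token, label = line.split()[0], line.split()[-1]
        n_subwords = len(tokenizer.tokenize(token))
        if n_subwords == 0:  # a) control-character tokens vanish entirely
            continue
        if subword_len + n_subwords > MAX_LENGTH - 2:  # b) reserve [CLS]/[SEP]
            f_out.write('\n')
            subword_len = 0
        f_out.write(f'{token} {label}\n')
        subword_len += n_subwords
```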
Let's define some variables that we need for further pre-processing steps and training the model:
```bash
export MAX_LENGTH=128
export BERT_MODEL=bert-base-multilingual-cased
```
Run the pre-processing script on training, dev and test datasets:
```bash
python3 preprocess.py train.txt.tmp $BERT_MODEL $MAX_LENGTH > train.txt
python3 preprocess.py dev.txt.tmp $BERT_MODEL $MAX_LENGTH > dev.txt
python3 preprocess.py test.txt.tmp $BERT_MODEL $MAX_LENGTH > test.txt
```
The GermEval 2014 dataset has many more labels than the CoNLL-2002/2003 datasets, so a dedicated set of labels must be used:
```bash
cat train.txt dev.txt test.txt | cut -d " " -f 2 | grep -v "^$"| sort | uniq > labels.txt
```
### Training
Additional environment variables must be set:
```bash
export OUTPUT_DIR=germeval-model
export BATCH_SIZE=32
export NUM_EPOCHS=3
export SAVE_STEPS=750
export SEED=1
```
To start training, just run:
```bash
python3 run_ner.py --data_dir ./ \
--model_type bert \
--labels ./labels.txt \
--model_name_or_path $BERT_MODEL \
--output_dir $OUTPUT_DIR \
--max_seq_length $MAX_LENGTH \
--num_train_epochs $NUM_EPOCHS \
--per_gpu_train_batch_size $BATCH_SIZE \
--save_steps $SAVE_STEPS \
--seed $SEED \
--do_train \
--do_eval \
--do_predict
```
If your GPU supports half-precision training, just add the `--fp16` flag. After training, the model will be evaluated on both the development and test datasets.
### Evaluation
Evaluation on development dataset outputs the following for our example:
```bash
10/04/2019 00:42:06 - INFO - __main__ - ***** Eval results *****
10/04/2019 00:42:06 - INFO - __main__ - f1 = 0.8623348017621146
10/04/2019 00:42:06 - INFO - __main__ - loss = 0.07183869666975543
10/04/2019 00:42:06 - INFO - __main__ - precision = 0.8467916366258111
10/04/2019 00:42:06 - INFO - __main__ - recall = 0.8784592370979806
```
On the test dataset, the following results were achieved:
```bash
10/04/2019 00:42:42 - INFO - __main__ - ***** Eval results *****
10/04/2019 00:42:42 - INFO - __main__ - f1 = 0.8614389652384803
10/04/2019 00:42:42 - INFO - __main__ - loss = 0.07064602487454782
10/04/2019 00:42:42 - INFO - __main__ - precision = 0.8604651162790697
10/04/2019 00:42:42 - INFO - __main__ - recall = 0.8624150210424085
```
@@ -31,9 +31,13 @@ import torch
from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler,
                              TensorDataset)
from torch.utils.data.distributed import DistributedSampler

try:
    from torch.utils.tensorboard import SummaryWriter
except ImportError:
    from tensorboardX import SummaryWriter

from tqdm import tqdm, trange

from transformers import (WEIGHTS_NAME, BertConfig,
                          BertForMultipleChoice, BertTokenizer)
...
# Distil*

This folder contains the original code used to train Distil* as well as examples showcasing how to use DistilBERT and DistilGPT2.
**2019, October 3rd - Update** We release our [NeurIPS workshop paper](https://arxiv.org/abs/1910.01108) explaining our approach on **DistilBERT**. It includes updated results and further experiments. We applied the same method to GPT2 and release the weights of **DistilGPT2**. DistilGPT2 is two times faster and 33% smaller than GPT2. **The paper supersedes our [previous blogpost](https://medium.com/huggingface/distilbert-8cf3380435b5) with a different distillation loss and better performance. Please use the paper as a reference when comparing/reporting results on DistilBERT.**

**2019, September 19th - Update:** We fixed bugs in the code and released an updated version of the weights trained with a modification of the distillation loss. DistilBERT now reaches 97% of `BERT-base`'s performance on GLUE, and an 86.9 F1 score on the SQuAD v1.1 dev set (compared to 88.5 for `BERT-base`). We will publish a formal write-up of our approach in the near future!
## What is Distil*

Distil* is a class of compressed models that started with DistilBERT. DistilBERT stands for Distilled-BERT. DistilBERT is a small, fast, cheap and light Transformer model based on the BERT architecture. It has 40% fewer parameters than `bert-base-uncased` and runs 60% faster while preserving 97% of BERT's performance as measured on the GLUE language understanding benchmark. DistilBERT is trained using knowledge distillation, a technique to compress a large model, called the teacher, into a smaller model, called the student. By distilling BERT, we obtain a smaller Transformer model that bears a lot of similarities with the original BERT model while being lighter, smaller and faster to run. DistilBERT is thus an interesting option to put large-scale trained Transformer models into production.

We have applied the same method to GPT2 and release the weights of the compressed model. On the [WikiText-103](https://blog.einstein.ai/the-wikitext-long-term-dependency-language-modeling-dataset/) benchmark, GPT2 reaches a perplexity on the test set of 15.0 compared to 18.5 for DistilGPT2 (after fine-tuning on the train set).

For more information on DistilBERT, please refer to our [NeurIPS workshop paper](https://arxiv.org/abs/1910.01108).
Here are the results on the dev sets of GLUE:

| Model | Macro-score | CoLA | MNLI | MRPC | QNLI | QQP | RTE | SST-2 | STS-B | WNLI |
| :---: | :---: | :---:| :---:| :---:| :---:| :---:| :---:| :---:| :---:| :---:|
| BERT-base | **77.6** | 48.9 | 84.3 | 88.6 | 89.3 | 89.5 | 71.3 | 91.7 | 91.2 | 43.7 |
| DistilBERT | **76.8** | 49.1 | 81.8 | 90.2 | 90.2 | 89.2 | 62.9 | 92.7 | 90.7 | 44.4 |
## Setup

This part of the library has only been tested with Python 3.6+. There are a few specific dependencies to install before launching a distillation; you can install them with the command `pip install -r requirements.txt`.

**Important note:** The training scripts have been updated to support PyTorch v1.2.0 (there are breaking changes compared to v1.1.0).
## How to use DistilBERT

Transformers includes three pre-trained Distil* models, currently only provided for English (we are investigating the possibility to train and release a multilingual version of DistilBERT):

- `distilbert-base-uncased`: DistilBERT English language model pretrained on the same data used to pretrain BERT (concatenation of the Toronto Book Corpus and full English Wikipedia) using distillation with the supervision of the `bert-base-uncased` version of BERT. The model has 6 layers, 768 dimensions and 12 heads, totaling 66M parameters.
- `distilbert-base-uncased-distilled-squad`: A version of `distilbert-base-uncased` finetuned using (a second step of) knowledge distillation on SQuAD 1.0. This model reaches an F1 score of 86.9 on the dev set (for comparison, the BERT `bert-base-uncased` version reaches an 88.5 F1 score).
- `distilgpt2`: DistilGPT2 English language model pretrained with the supervision of `gpt2` (the smallest version of GPT2) on [OpenWebTextCorpus](https://skylion007.github.io/OpenWebTextCorpus/), a reproduction of OpenAI's WebText dataset. The model has 6 layers, 768 dimensions and 12 heads, totaling 82M parameters (compared to 124M parameters for GPT2). On average, DistilGPT2 is two times faster than GPT2.
- and more to come! 🤗🤗🤗

Using DistilBERT is very similar to using BERT. DistilBERT shares the same tokenizer as BERT's `bert-base-uncased`, even though we provide a link to this tokenizer under the `DistilBertTokenizer` name to have consistent naming between the library models.
@@ -42,9 +47,11 @@ outputs = model(input_ids)
last_hidden_states = outputs[0]  # The last hidden-state is the first element of the output tuple
```
Similarly, using DistilGPT2 simply consists of calling the GPT2 classes from a different pretrained checkpoint: `model = GPT2Model.from_pretrained('distilgpt2')`.
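
For instance, a minimal sketch of greedy next-token prediction with DistilGPT2:

```python
import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel

tokenizer = GPT2Tokenizer.from_pretrained('distilgpt2')
model = GPT2LMHeadModel.from_pretrained('distilgpt2')

input_ids = torch.tensor([tokenizer.encode("The distilled model is")])
logits = model(input_ids)[0]                 # (batch, seq_len, vocab_size)
next_token_id = int(logits[0, -1].argmax())  # greedy choice for the next token
print(tokenizer.decode([next_token_id]))
```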
## How to train Distil*
In the following, we will explain how you can train DistilBERT.
### A. Preparing the data
@@ -57,7 +64,8 @@ First, we will binarize the data, i.e. tokenize the data and convert each token
```bash
python scripts/binarized_data.py \
    --file_path data/dump.txt \
    --tokenizer_type bert \
    --tokenizer_name bert-base-uncased \
    --dump_file data/binarized_text
```
@@ -66,7 +74,8 @@ Our implementation of masked language modeling loss follows [XLM](https://github
```bash
python scripts/token_counts.py \
    --data_file data/binarized_text.bert-base-uncased.pickle \
    --token_counts_dump data/token_counts.bert-base-uncased.pickle \
    --vocab_size 30522
```
### B. Training

@@ -75,6 +84,12 @@ Training with distillation is really simple once you have pre-processed the data

```bash
python train.py \
    --student_type distilbert \
    --student_config training_configs/distilbert-base-uncased.json \
    --teacher_type bert \
    --teacher_name bert-base-uncased \
    --alpha_ce 5.0 --alpha_mlm 2.0 --alpha_cos 1.0 --mlm \
    --freeze_pos_embs \
    --dump_path serialization_dir/my_first_training \
    --data_file data/binarized_text.bert-base-uncased.pickle \
    --token_counts data/token_counts.bert-base-uncased.pickle \
@@ -83,7 +98,7 @@ python train.py \

By default, this will launch a training on a single GPU (even if more are available on the cluster). Other parameters are available in the command line; please look in `train.py` or run `python train.py --help` to list them.

We highly encourage you to use distributed training for training DistilBERT as the training corpus is quite large. Here's an example that runs a distributed training on a single node with 4 GPUs:
```bash
export NODE_RANK=0
@@ -105,11 +120,30 @@ python -m torch.distributed.launch \
    train.py \
        --force \
        --n_gpu $WORLD_SIZE \
        --student_type distilbert \
        --student_config training_configs/distilbert-base-uncased.json \
        --teacher_type bert \
        --teacher_name bert-base-uncased \
        --alpha_ce 0.33 --alpha_mlm 0.33 --alpha_cos 0.33 --mlm \
        --freeze_pos_embs \
        --dump_path serialization_dir/my_first_training \
        --data_file data/binarized_text.bert-base-uncased.pickle \
        --token_counts data/token_counts.bert-base-uncased.pickle
```
**Tips:** Starting distilled training with good initialization of the model weights is crucial to reach decent performance. In our experiments, we initialized our model from a few layers of the teacher (BERT) itself! Please refer to `scripts/extract.py` and `scripts/extract_distilbert.py` to create a valid initialization checkpoint, and use the `--student_pretrained_weights` argument to use this initialization for the distilled training!
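
For intuition, the extraction step boils down to copying the teacher's embeddings, head and a subset of its encoder layers into a smaller state dict, e.g. every other layer for a 6-layer student. A conceptual sketch (the real parameter-name mapping and layer choice live in the scripts above):

```python
import torch
from transformers import BertForMaskedLM

teacher = BertForMaskedLM.from_pretrained('bert-base-uncased')
state_dict = teacher.state_dict()
compressed_sd = {}

# Keep the embeddings and the MLM head as-is.
for name, param in state_dict.items():
    if name.startswith('bert.embeddings') or name.startswith('cls'):
        compressed_sd[name] = param

# Copy every other teacher layer and renumber it for a 6-layer student.
for student_layer, teacher_layer in enumerate([0, 2, 4, 6, 8, 10]):
    t_prefix = f'bert.encoder.layer.{teacher_layer}.'
    s_prefix = f'bert.encoder.layer.{student_layer}.'
    for name, param in state_dict.items():
        if name.startswith(t_prefix):
            compressed_sd[s_prefix + name[len(t_prefix):]] = param

torch.save(compressed_sd, 'serialization_dir/student_init.pth')
```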
Happy distillation!
## Citation
If you find the resource useful, you should cite the following paper:
```
@inproceedings{sanh2019distilbert,
title={DistilBERT, a distilled version of BERT: smaller, faster, cheaper and lighter},
author={Sanh, Victor and Debut, Lysandre and Chaumond, Julien and Wolf, Thomas},
booktitle={NeurIPS EMC^2 Workshop},
year={2019}
}
```
@@ -12,14 +12,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" The distiller to distil the student.
Adapted in part from Facebook, Inc XLM model (https://github.com/facebookresearch/XLM)
"""
import os
import math
import psutil
import time

from tqdm import trange, tqdm
import numpy as np
@@ -28,16 +27,24 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data import RandomSampler, BatchSampler, DataLoader

try:
    from torch.utils.tensorboard import SummaryWriter
except ImportError:
    from tensorboardX import SummaryWriter

from transformers import WarmupLinearSchedule

from utils import logger
from lm_seqs_dataset import LmSeqsDataset
from grouped_batch_sampler import GroupedBatchSampler, create_lengths_groups
class Distiller:
    def __init__(self,
                 params: dict,
                 dataset: LmSeqsDataset,
                 token_probs: torch.tensor,
                 student: nn.Module,
                 teacher: nn.Module):
@@ -50,33 +57,47 @@ class Distiller:
        self.student = student
        self.teacher = teacher

        self.student_config = student.config
        self.vocab_size = student.config.vocab_size

        if params.n_gpu <= 1:
            sampler = RandomSampler(dataset)
        else:
            sampler = DistributedSampler(dataset)

        if params.group_by_size:
            groups = create_lengths_groups(lengths=dataset.lengths, k=params.max_model_input_size)
            sampler = GroupedBatchSampler(sampler=sampler, group_ids=groups, batch_size=params.batch_size)
        else:
            sampler = BatchSampler(sampler=sampler, batch_size=params.batch_size, drop_last=False)

        self.dataloader = DataLoader(dataset=dataset,
                                     batch_sampler=sampler,
                                     collate_fn=dataset.batch_sequences)
        self.temperature = params.temperature
        assert self.temperature > 0.

        self.alpha_ce = params.alpha_ce
        self.alpha_mlm = params.alpha_mlm
        self.alpha_clm = params.alpha_clm
        self.alpha_mse = params.alpha_mse
        self.alpha_cos = params.alpha_cos
        self.mlm = params.mlm
        if self.mlm:
            logger.info(f'Using MLM loss for LM step.')
            self.mlm_mask_prop = params.mlm_mask_prop
            assert 0.0 <= self.mlm_mask_prop <= 1.0
            assert params.word_mask + params.word_keep + params.word_rand == 1.0
            self.pred_probs = torch.FloatTensor([params.word_mask, params.word_keep, params.word_rand])
            self.pred_probs = self.pred_probs.to(f'cuda:{params.local_rank}') if params.n_gpu > 0 else self.pred_probs
            self.token_probs = token_probs.to(f'cuda:{params.local_rank}') if params.n_gpu > 0 else token_probs
            if self.fp16:
                self.pred_probs = self.pred_probs.half()
                self.token_probs = self.token_probs.half()
        else:
            logger.info(f'Using CLM loss for LM step.')
        self.epoch = 0
        self.n_iter = 0
@@ -86,12 +107,13 @@ class Distiller:
        self.last_loss = 0
        self.last_loss_ce = 0
        self.last_loss_mlm = 0
        self.last_loss_clm = 0
        if self.alpha_mse > 0.: self.last_loss_mse = 0
        if self.alpha_cos > 0.: self.last_loss_cos = 0
        self.last_log = 0
        self.ce_loss_fct = nn.KLDivLoss(reduction='batchmean')
        self.lm_loss_fct = nn.CrossEntropyLoss(ignore_index=-1)
        if self.alpha_mse > 0.:
            self.mse_loss_fct = nn.MSELoss(reduction='sum')
        if self.alpha_cos > 0.:
@@ -99,7 +121,7 @@ class Distiller:
        logger.info('--- Initializing model optimizer')
        assert params.gradient_accumulation_steps >= 1
        self.num_steps_epoch = len(self.dataloader)
        num_train_optimization_steps = int(self.num_steps_epoch / params.gradient_accumulation_steps * params.n_epoch) + 1

        no_decay = ['bias', 'LayerNorm.weight']
@@ -140,43 +162,18 @@ class Distiller:
            logger.info("Using nn.parallel.DistributedDataParallel for distributed training.")
            self.student = DistributedDataParallel(self.student,
                                                   device_ids=[params.local_rank],
                                                   output_device=params.local_rank,
                                                   find_unused_parameters=True)
        self.is_master = params.is_master
        if self.is_master:
            logger.info('--- Initializing Tensorboard')
            self.tensorboard = SummaryWriter(log_dir=os.path.join(self.dump_path, 'log', 'train'))
            self.tensorboard.add_text(tag='config/training', text_string=str(self.params), global_step=0)
            self.tensorboard.add_text(tag='config/student', text_string=str(self.student_config), global_step=0)
    def prepare_batch_mlm(self,
                          batch):
        """
        Prepare the batch: from the token_ids and the lengths, compute the attention mask and the masked labels for MLM.
@@ -222,7 +219,7 @@ class Distiller:
        assert pred_mask.sum().item() % 8 == 0, pred_mask.sum().item()
        _token_ids_real = token_ids[pred_mask]
        _token_ids_rand = _token_ids_real.clone().random_(self.vocab_size)
        _token_ids_mask = _token_ids_real.clone().fill_(self.params.special_tok_ids['mask_token'])
        probs = torch.multinomial(self.pred_probs, len(_token_ids_real), replacement=True)
        _token_ids = _token_ids_mask * (probs == 0).long() + _token_ids_real * (probs == 1).long() + _token_ids_rand * (probs == 2).long()
@@ -230,8 +227,41 @@ class Distiller:
        mlm_labels[~pred_mask] = -1  # previously `mlm_labels[1-pred_mask] = -1`, cf pytorch 1.2.0 compatibility

        # sanity checks
        assert 0 <= token_ids.min() <= token_ids.max() < self.vocab_size

        return token_ids, attn_mask, mlm_labels
    def prepare_batch_clm(self,
                          batch):
        """
        Prepare the batch: from the token_ids and the lengths, compute the attention mask and the labels for CLM.

        Input:
        ------
        batch: `Tuple`
            token_ids: `torch.tensor(bs, seq_length)` - The token ids for each of the sequence. It is padded.
            lengths: `torch.tensor(bs)` - The lengths of each of the sequences in the batch.

        Output:
        -------
        token_ids: `torch.tensor(bs, seq_length)` - The token ids after the modifications for CLM.
        attn_mask: `torch.tensor(bs, seq_length)` - The attention mask for the self-attention.
        clm_labels: `torch.tensor(bs, seq_length)` - The causal language modeling labels. There is a -1 where there is nothing to predict.
        """
        token_ids, lengths = batch
        token_ids, lengths = self.round_batch(x=token_ids, lengths=lengths)
        assert token_ids.size(0) == lengths.size(0)

        attn_mask = (torch.arange(token_ids.size(1), dtype=torch.long, device=lengths.device) < lengths[:, None])
        clm_labels = token_ids.new(token_ids.size()).copy_(token_ids)
        clm_labels[~attn_mask] = -1  # previously `clm_labels[1-attn_mask] = -1`, cf pytorch 1.2.0 compatibility

        # sanity checks
        assert 0 <= token_ids.min() <= token_ids.max() < self.vocab_size

        return token_ids, attn_mask, clm_labels
    def round_batch(self,
                    x: torch.tensor,
                    lengths: torch.tensor):
@@ -269,7 +299,10 @@ class Distiller:
        if ml1 % 8 != 0:
            pad = 8 - (ml1 % 8)
            ml2 = ml1 + pad
            if self.mlm:
                pad_id = self.params.special_tok_ids['pad_token']
            else:
                pad_id = self.params.special_tok_ids['unk_token']
            padding_tensor = torch.zeros(bs2, pad, dtype=torch.long, device=x.device).fill_(pad_id)
            x = torch.cat([x, padding_tensor], 1)
            assert x.size() == (bs2, ml2)
@@ -292,14 +325,16 @@ class Distiller:
        if self.multi_gpu:
            torch.distributed.barrier()

        iter_bar = tqdm(self.dataloader, desc="-Iter", disable=self.params.local_rank not in [-1, 0])
        for batch in iter_bar:
            if self.params.n_gpu > 0:
                batch = tuple(t.to(f'cuda:{self.params.local_rank}') for t in batch)

            if self.mlm:
                token_ids, attn_mask, lm_labels = self.prepare_batch_mlm(batch=batch)
            else:
                token_ids, attn_mask, lm_labels = self.prepare_batch_clm(batch=batch)
            self.step(input_ids=token_ids, attention_mask=attn_mask, lm_labels=lm_labels)

            iter_bar.update()
            iter_bar.set_postfix({'Last_loss': f'{self.last_loss:.2f}',
@@ -317,7 +352,7 @@ class Distiller:
    def step(self,
             input_ids: torch.tensor,
             attention_mask: torch.tensor,
             lm_labels: torch.tensor):
        """
        One optimization step: forward of student AND teacher, backward on the loss (for gradient accumulation),
        and possibly a parameter update (depending on the gradient accumulation).
@@ -326,17 +361,22 @@ class Distiller:
        ------
        input_ids: `torch.tensor(bs, seq_length)` - The token ids.
        attention_mask: `torch.tensor(bs, seq_length)` - The attention mask for self attention.
        lm_labels: `torch.tensor(bs, seq_length)` - The language modeling labels (mlm labels for MLM and clm labels for CLM).
        """
        if self.mlm:
            s_logits, s_hidden_states = self.student(input_ids=input_ids, attention_mask=attention_mask)  # (bs, seq_length, voc_size)
            with torch.no_grad():
                t_logits, t_hidden_states = self.teacher(input_ids=input_ids, attention_mask=attention_mask)  # (bs, seq_length, voc_size)
        else:
            s_logits, _, s_hidden_states = self.student(input_ids=input_ids, attention_mask=None)  # (bs, seq_length, voc_size)
            with torch.no_grad():
                t_logits, _, t_hidden_states = self.teacher(input_ids=input_ids, attention_mask=None)  # (bs, seq_length, voc_size)
        assert s_logits.size() == t_logits.size()

        # https://github.com/peterliht/knowledge-distillation-pytorch/blob/master/model/net.py#L100
        # https://github.com/peterliht/knowledge-distillation-pytorch/issues/2
        if self.params.restrict_ce_to_mask:
            mask = (lm_labels > -1).unsqueeze(-1).expand_as(s_logits)   # (bs, seq_length, voc_size)
        else:
            mask = attention_mask.unsqueeze(-1).expand_as(s_logits)    # (bs, seq_length, voc_size)
        s_logits_slct = torch.masked_select(s_logits, mask)            # (bs * seq_length * voc_size) modulo the 1s in mask
@@ -348,13 +388,20 @@ class Distiller:
        loss_ce = self.ce_loss_fct(F.log_softmax(s_logits_slct/self.temperature, dim=-1),
                                   F.softmax(t_logits_slct/self.temperature, dim=-1)) * (self.temperature)**2
        loss = self.alpha_ce*loss_ce
        if self.alpha_mlm > 0.:
            loss_mlm = self.lm_loss_fct(s_logits.view(-1, s_logits.size(-1)), lm_labels.view(-1))
            loss += self.alpha_mlm * loss_mlm
        if self.alpha_clm > 0.:
            shift_logits = s_logits[..., :-1, :].contiguous()
            shift_labels = lm_labels[..., 1:].contiguous()
            loss_clm = self.lm_loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
                                        shift_labels.view(-1))
            loss += self.alpha_clm * loss_clm
        if self.alpha_mse > 0.:
            loss_mse = self.mse_loss_fct(s_logits_slct, t_logits_slct)/s_logits_slct.size(0)  # Reproducing batchmean reduction
            loss += self.alpha_mse * loss_mse

        if self.alpha_cos > 0.:
            s_hidden_states = s_hidden_states[-1]  # (bs, seq_length, dim)
            t_hidden_states = t_hidden_states[-1]  # (bs, seq_length, dim)
@@ -376,6 +423,8 @@ class Distiller:
        self.last_loss_ce = loss_ce.item()
        if self.alpha_mlm > 0.:
            self.last_loss_mlm = loss_mlm.item()
        if self.alpha_clm > 0.:
            self.last_loss_clm = loss_clm.item()
        if self.alpha_mse > 0.:
            self.last_loss_mse = loss_mse.item()
        if self.alpha_cos > 0.:
@@ -452,6 +501,8 @@ class Distiller:
        self.tensorboard.add_scalar(tag="losses/loss_ce", scalar_value=self.last_loss_ce, global_step=self.n_total_iter)
        if self.alpha_mlm > 0.:
            self.tensorboard.add_scalar(tag="losses/loss_mlm", scalar_value=self.last_loss_mlm, global_step=self.n_total_iter)
        if self.alpha_clm > 0.:
            self.tensorboard.add_scalar(tag="losses/loss_clm", scalar_value=self.last_loss_clm, global_step=self.n_total_iter)
        if self.alpha_mse > 0.:
            self.tensorboard.add_scalar(tag="losses/loss_mse", scalar_value=self.last_loss_mse, global_step=self.n_total_iter)
        if self.alpha_cos > 0.:
...
# coding=utf-8
# Copyright 2019-present, the HuggingFace Inc. team and Facebook, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Adapted from PyTorch Vision (https://github.com/pytorch/vision/blob/master/references/detection/group_by_aspect_ratio.py)
"""
import bisect
import copy
from collections import defaultdict
import numpy as np
from torch.utils.data.sampler import BatchSampler, Sampler
from utils import logger
def _quantize(x, bins):
bins = copy.deepcopy(bins)
bins = sorted(bins)
quantized = list(map(lambda y: bisect.bisect_right(bins, y), x))
return quantized
def create_lengths_groups(lengths, k=0):
bins = np.arange(start=3, stop=k, step=4).tolist() if k > 0 else [10]
groups = _quantize(lengths, bins)
# count number of elements per group
counts = np.unique(groups, return_counts=True)[1]
fbins = [0] + bins + [np.inf]
logger.info("Using {} as bins for aspect lengths quantization".format(fbins))
logger.info("Count of instances per bin: {}".format(counts))
return groups
class GroupedBatchSampler(BatchSampler):
"""
Wraps another sampler to yield a mini-batch of indices.
    It enforces that the batch only contains elements from the same group.
    It also tries to provide mini-batches which follow an ordering as close
    as possible to the ordering from the original sampler.
Arguments:
sampler (Sampler): Base sampler.
group_ids (list[int]): If the sampler produces indices in range [0, N),
`group_ids` must be a list of `N` ints which contains the group id of each sample.
The group ids must be a continuous set of integers starting from
0, i.e. they must be in the range [0, num_groups).
batch_size (int): Size of mini-batch.
"""
def __init__(self, sampler, group_ids, batch_size):
if not isinstance(sampler, Sampler):
raise ValueError(
"sampler should be an instance of "
"torch.utils.data.Sampler, but got sampler={}".format(sampler)
)
self.sampler = sampler
self.group_ids = group_ids
self.batch_size = batch_size
def __iter__(self):
buffer_per_group = defaultdict(list)
samples_per_group = defaultdict(list)
num_batches = 0
for idx in self.sampler:
group_id = self.group_ids[idx]
buffer_per_group[group_id].append(idx)
samples_per_group[group_id].append(idx)
if len(buffer_per_group[group_id]) == self.batch_size:
yield buffer_per_group[group_id] #TODO
num_batches += 1
del buffer_per_group[group_id]
assert len(buffer_per_group[group_id]) < self.batch_size
# now we have run out of elements that satisfy
# the group criteria, let's return the remaining
# elements so that the size of the sampler is
# deterministic
expected_num_batches = len(self)
num_remaining = expected_num_batches - num_batches
if num_remaining > 0:
# for the remaining batches, group the batches by similar lengths
batch_idx = []
for group_id, idxs in sorted(buffer_per_group.items(), key=lambda x: x[0]):
batch_idx.extend(idxs)
if len(batch_idx) >= self.batch_size:
yield batch_idx[:self.batch_size]
batch_idx = batch_idx[self.batch_size:]
num_remaining -= 1
if len(batch_idx) > 0:
yield batch_idx
num_remaining -= 1
assert num_remaining == 0
def __len__(self):
"""
Return the number of mini-batches rather than the number of samples.
"""
return (len(self.sampler) + self.batch_size - 1) // self.batch_size
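
Put together, the sampler plugs into a standard `DataLoader` much as in the `Distiller` constructor shown earlier. A usage sketch, assuming `dataset` is an `LmSeqsDataset` exposing per-sample `lengths` and a `batch_sequences` collate function:

```python
from torch.utils.data import DataLoader, RandomSampler

# Bucket samples by length so each batch pads to a similar sequence length.
groups = create_lengths_groups(lengths=dataset.lengths, k=512)
batch_sampler = GroupedBatchSampler(sampler=RandomSampler(dataset),
                                    group_ids=groups,
                                    batch_size=32)
dataloader = DataLoader(dataset=dataset,
                        batch_sampler=batch_sampler,
                        collate_fn=dataset.batch_sequences)
```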
@@ -12,30 +12,33 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Dataset for distillation.
adapted in part from Facebook, Inc XLM model (https://github.com/facebookresearch/XLM)
"""
import numpy as np
import torch
from torch.utils.data import Dataset

from utils import logger
class LmSeqsDataset(Dataset):
    """Custom Dataset wrapping language modeling sequences.

    Each sample will be retrieved by indexing the list of token_ids and their corresponding lengths.

    Input:
    ------
    params: `NameSpace` parameters
    data: `List[np.array[int]]`
    """
    def __init__(self,
                 params,
                 data):
        self.params = params

        self.token_ids = np.array(data)
        self.lengths = np.array([len(t) for t in data])

        self.check()
        self.remove_long_sequences()
@@ -43,6 +46,9 @@ class Dataset:
        self.check()
        self.print_statistics()

    def __getitem__(self, index):
        return (self.token_ids[index], self.lengths[index])

    def __len__(self):
        return len(self.lengths)
@@ -51,12 +57,14 @@
        Some sanity checks
        """
        assert len(self.token_ids) == len(self.lengths)
        assert all(self.lengths[i] == len(self.token_ids[i]) for i in range(len(self.lengths)))
    def remove_long_sequences(self):
        """
        Sequences that are too long are split into chunks of max_model_input_size.
        """
        max_len = self.params.max_model_input_size
        indices = self.lengths > max_len
        logger.info(f'Splitting {sum(indices)} too long sequences.')
        def divide_chunks(l, n):
@@ -64,10 +72,13 @@ class Dataset:
        new_tok_ids = []
        new_lengths = []
        if self.params.mlm:
            cls_id, sep_id = self.params.special_tok_ids['cls_token'], self.params.special_tok_ids['sep_token']
        else:
            cls_id, sep_id = self.params.special_tok_ids['bos_token'], self.params.special_tok_ids['eos_token']

        for seq_, len_ in zip(self.token_ids, self.lengths):
            assert (seq_[0] == cls_id) and (seq_[-1] == sep_id), seq_
            if len_ <= max_len:
                new_tok_ids.append(seq_)
                new_lengths.append(len_)
@@ -79,6 +90,7 @@ class Dataset:
                    if sub_s[-1] != sep_id:
                        sub_s = np.insert(sub_s, len(sub_s), sep_id)
                    assert len(sub_s) <= max_len
                    assert (sub_s[0] == cls_id) and (sub_s[-1] == sep_id), sub_s
                    sub_seqs.append(sub_s)

                new_tok_ids.extend(sub_seqs)
@@ -113,89 +125,27 @@ class Dataset:
        # nb_unkown = sum([(t==unk_idx).sum() for t in self.token_ids])
        # logger.info(f'{nb_unkown} unknown tokens (covering {100*nb_unkown/data_len:.2f}% of the data)')
    def batch_sequences(self,
                        batch):
        """
        Do the padding and transform into torch.tensor.
        """
        token_ids = [t[0] for t in batch]
        lengths = [t[1] for t in batch]
        assert len(token_ids) == len(lengths)
        # Max for paddings
        max_seq_len_ = max(lengths)

        # Pad token ids
        if self.params.mlm:
            pad_idx = self.params.special_tok_ids['pad_token']
        else:
            pad_idx = self.params.special_tok_ids['unk_token']
        tk_ = [list(t.astype(int)) + [pad_idx]*(max_seq_len_-len(t)) for t in token_ids]
        assert len(tk_) == len(token_ids)
        assert all(len(t) == max_seq_len_ for t in tk_)

        tk_t = torch.tensor(tk_)      # (bs, max_seq_len_)
        lg_t = torch.tensor(lengths)  # (bs)
        return tk_t, lg_t
@@ -3,4 +3,4 @@ tensorboard>=1.14.0
tensorboardX==1.8
psutil==5.6.3
scipy==1.3.1
transformers==2.0.0
@@ -13,14 +13,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Preprocessing script before distillation.
"""
import argparse
import pickle
import random
import time

import numpy as np

from transformers import BertTokenizer, RobertaTokenizer, GPT2Tokenizer

import logging
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
@@ -32,7 +32,7 @@ def main():
    parser = argparse.ArgumentParser(description="Preprocess the data to avoid re-doing it several times by (tokenization + token_to_ids).")
    parser.add_argument('--file_path', type=str, default='data/dump.txt',
                        help='The path to the data.')
    parser.add_argument('--tokenizer_type', type=str, default='bert', choices=['bert', 'roberta', 'gpt2'])
    parser.add_argument('--tokenizer_name', type=str, default='bert-base-uncased',
                        help="The tokenizer to use.")
    parser.add_argument('--dump_file', type=str, default='data/dump',
...@@ -43,10 +43,16 @@ def main(): ...@@ -43,10 +43,16 @@ def main():
logger.info(f'Loading Tokenizer ({args.tokenizer_name})') logger.info(f'Loading Tokenizer ({args.tokenizer_name})')
if args.tokenizer_type == 'bert': if args.tokenizer_type == 'bert':
tokenizer = BertTokenizer.from_pretrained(args.tokenizer_name) tokenizer = BertTokenizer.from_pretrained(args.tokenizer_name)
bos = tokenizer.special_tokens_map['cls_token'] # `[CLS]`
sep = tokenizer.special_tokens_map['sep_token'] # `[SEP]`
elif args.tokenizer_type == 'roberta': elif args.tokenizer_type == 'roberta':
tokenizer = RobertaTokenizer.from_pretrained(args.tokenizer_name) tokenizer = RobertaTokenizer.from_pretrained(args.tokenizer_name)
bos = tokenizer.special_tokens_map['bos_token'] # `[CLS]` for bert, `<s>` for roberta bos = tokenizer.special_tokens_map['cls_token'] # `<s>`
sep = tokenizer.special_tokens_map['sep_token'] # `[SEP]` for bert, `</s>` for roberta sep = tokenizer.special_tokens_map['sep_token'] # `</s>`
elif args.tokenizer_type == 'gpt2':
tokenizer = GPT2Tokenizer.from_pretrained(args.tokenizer_name)
bos = tokenizer.special_tokens_map['bos_token'] # `<|endoftext|>`
sep = tokenizer.special_tokens_map['eos_token'] # `<|endoftext|>`
logger.info(f'Loading text from {args.file_path}') logger.info(f'Loading text from {args.file_path}')
with open(args.file_path, 'r', encoding='utf8') as fp: with open(args.file_path, 'r', encoding='utf8') as fp:
......
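The per-line body of the loop is elided above (`......`). As a hedged sketch, assuming the script wraps each sequence with the chosen `bos`/`sep` symbols before converting it to ids (the sample sentence is invented), the binarization step would look roughly like:

```python
from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
bos = tokenizer.special_tokens_map['cls_token']  # `[CLS]`
sep = tokenizer.special_tokens_map['sep_token']  # `[SEP]`

line = "Distillation compresses a large teacher into a smaller student."  # invented sample
text = f'{bos} {line.strip()} {sep}'
token_ids = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text))
```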
# coding=utf-8
# Copyright 2019-present, the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Preprocessing script before training the distilled model.
Specific to RoBERTa -> DistilRoBERTa and GPT2 -> DistilGPT2.
"""
from transformers import BertForMaskedLM, RobertaForMaskedLM, GPT2LMHeadModel
import torch
import argparse
if __name__ == '__main__':
parser = argparse.ArgumentParser(description="Extract some layers of the full RobertaForMaskedLM or GPT2LMHeadModel for Transfer Learned Distillation")
parser.add_argument("--model_type", default="roberta", choices=["roberta", "gpt2"])
parser.add_argument("--model_name", default='roberta-large', type=str)
parser.add_argument("--dump_checkpoint", default='serialization_dir/tf_roberta_048131723.pth', type=str)
parser.add_argument("--vocab_transform", action='store_true')
args = parser.parse_args()
if args.model_type == 'roberta':
model = RobertaForMaskedLM.from_pretrained(args.model_name)
prefix = 'roberta'
elif args.model_type == 'gpt2':
model = GPT2LMHeadModel.from_pretrained(args.model_name)
prefix = 'transformer'
state_dict = model.state_dict()
compressed_sd = {}
### Embeddings ###
if args.model_type == 'gpt2':
for param_name in ['wte.weight', 'wpe.weight']:
compressed_sd[f'{prefix}.{param_name}'] = state_dict[f'{prefix}.{param_name}']
else:
for w in ['word_embeddings', 'position_embeddings', 'token_type_embeddings']:
param_name = f'{prefix}.embeddings.{w}.weight'
compressed_sd[param_name] = state_dict[param_name]
for w in ['weight', 'bias']:
param_name = f'{prefix}.embeddings.LayerNorm.{w}'
compressed_sd[param_name] = state_dict[param_name]
### Transformer Blocks ###
std_idx = 0
for teacher_idx in [0, 2, 4, 7, 9, 11]:
if args.model_type == 'gpt2':
for layer in ['ln_1', 'attn.c_attn', 'attn.c_proj', 'ln_2', 'mlp.c_fc', 'mlp.c_proj']:
for w in ['weight', 'bias']:
compressed_sd[f'{prefix}.h.{std_idx}.{layer}.{w}'] = \
state_dict[f'{prefix}.h.{teacher_idx}.{layer}.{w}']
compressed_sd[f'{prefix}.h.{std_idx}.attn.bias'] = state_dict[f'{prefix}.h.{teacher_idx}.attn.bias']
else:
for layer in ['attention.self.query', 'attention.self.key', 'attention.self.value',
'attention.output.dense', 'attention.output.LayerNorm',
'intermediate.dense', 'output.dense', 'output.LayerNorm']:
for w in ['weight', 'bias']:
compressed_sd[f'{prefix}.encoder.layer.{std_idx}.{layer}.{w}'] = \
state_dict[f'{prefix}.encoder.layer.{teacher_idx}.{layer}.{w}']
std_idx += 1
### Language Modeling Head ###
if args.model_type == 'roberta':
for layer in ['lm_head.decoder.weight', 'lm_head.bias']:
compressed_sd[f'{layer}'] = state_dict[f'{layer}']
if args.vocab_transform:
for w in ['weight', 'bias']:
compressed_sd[f'lm_head.dense.{w}'] = state_dict[f'lm_head.dense.{w}']
compressed_sd[f'lm_head.layer_norm.{w}'] = state_dict[f'lm_head.layer_norm.{w}']
elif args.model_type == 'gpt2':
for w in ['weight', 'bias']:
compressed_sd[f'{prefix}.ln_f.{w}'] = state_dict[f'{prefix}.ln_f.{w}']
compressed_sd[f'lm_head.weight'] = state_dict[f'lm_head.weight']
print(f'N layers selected for distillation: {std_idx}')
print(f'Number of params transferred for distillation: {len(compressed_sd.keys())}')
print(f'Saving transferred checkpoint to {args.dump_checkpoint}.')
torch.save(compressed_sd, args.dump_checkpoint)
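A hypothetical follow-up, not part of this diff: loading the extracted checkpoint into a 6-layer GPT-2 student. The path reuses the script's default dump location, and `strict=False` tolerates buffers that were not copied:

```python
import torch
from transformers import GPT2Config, GPT2LMHeadModel

config = GPT2Config(n_layer=6)  # matches the six teacher layers selected above
student = GPT2LMHeadModel(config)
state_dict = torch.load('serialization_dir/tf_roberta_048131723.pth')
student.load_state_dict(state_dict, strict=False)
```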
...@@ -14,6 +14,7 @@
# limitations under the License.
"""
Preprocessing script before training DistilBERT.
Specific to BERT -> DistilBERT.
"""
from transformers import BertForMaskedLM, RobertaForMaskedLM
import torch
...@@ -21,7 +22,7 @@ import argparse
if __name__ == '__main__':
parser = argparse.ArgumentParser(description="Extract some layers of the full BertForMaskedLM for Transfer Learned Distillation")
parser.add_argument("--model_type", default="bert", choices=["bert"])
parser.add_argument("--model_name", default='bert-base-uncased', type=str)
parser.add_argument("--dump_checkpoint", default='serialization_dir/tf_bert-base-uncased_0247911.pth', type=str)
parser.add_argument("--vocab_transform", action='store_true')
...@@ -31,9 +32,8 @@ if __name__ == '__main__':
if args.model_type == 'bert':
model = BertForMaskedLM.from_pretrained(args.model_name)
prefix = 'bert'
else:
raise ValueError('args.model_type should be "bert".')
state_dict = model.state_dict()
compressed_sd = {}
...@@ -68,20 +68,12 @@ if __name__ == '__main__':
state_dict[f'{prefix}.encoder.layer.{teacher_idx}.output.LayerNorm.{w}']
std_idx += 1
compressed_sd[f'vocab_projector.weight'] = state_dict[f'cls.predictions.decoder.weight']
compressed_sd[f'vocab_projector.bias'] = state_dict[f'cls.predictions.bias']
if args.vocab_transform:
for w in ['weight', 'bias']:
compressed_sd[f'vocab_transform.{w}'] = state_dict[f'cls.predictions.transform.dense.{w}']
compressed_sd[f'vocab_layer_norm.{w}'] = state_dict[f'cls.predictions.transform.LayerNorm.{w}']
print(f'N layers selected for distillation: {std_idx}')
print(f'Number of params transferred for distillation: {len(compressed_sd.keys())}')
......
...@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Preprocessing script before training the distilled model.
"""
from collections import Counter
import argparse
......
...@@ -13,7 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Training the distilled model.
Supported architectures include: BERT -> DistilBERT, RoBERTa -> DistilRoBERTa, GPT2 -> DistilGPT2.
"""
import os
import argparse
...@@ -23,68 +24,96 @@ import shutil
import numpy as np
import torch
from transformers import BertConfig, BertForMaskedLM, BertTokenizer
from transformers import RobertaConfig, RobertaForMaskedLM, RobertaTokenizer
from transformers import DistilBertConfig, DistilBertForMaskedLM, DistilBertTokenizer
from transformers import GPT2Config, GPT2LMHeadModel, GPT2Tokenizer
from distiller import Distiller
from utils import git_log, logger, init_gpu_params, set_seed
from lm_seqs_dataset import LmSeqsDataset
MODEL_CLASSES = {
'distilbert': (DistilBertConfig, DistilBertForMaskedLM, DistilBertTokenizer),
'roberta': (RobertaConfig, RobertaForMaskedLM, RobertaTokenizer),
'bert': (BertConfig, BertForMaskedLM, BertTokenizer),
'gpt2': (GPT2Config, GPT2LMHeadModel, GPT2Tokenizer)
}
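For reference, a minimal sketch of how these registry tuples are unpacked further down in `main()`; the model names here are illustrative:

```python
# Assumes the imports and the MODEL_CLASSES dict defined above.
teacher_config_class, teacher_model_class, teacher_tokenizer_class = MODEL_CLASSES['bert']
student_config_class, student_model_class, _ = MODEL_CLASSES['distilbert']

tokenizer = teacher_tokenizer_class.from_pretrained('bert-base-uncased')
teacher = teacher_model_class.from_pretrained('bert-base-uncased', output_hidden_states=True)
```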
def sanity_checks(args):
"""
A bunch of args sanity checks to perform before even starting...
"""
assert (args.mlm and args.alpha_mlm > 0.) or (not args.mlm and args.alpha_mlm == 0.)
assert (args.alpha_mlm > 0. and args.alpha_clm == 0.) or (args.alpha_mlm == 0. and args.alpha_clm > 0.)
if args.mlm:
assert os.path.isfile(args.token_counts)
assert (args.student_type in ['roberta', 'distilbert']) and (args.teacher_type in ['roberta', 'bert'])
else:
assert (args.student_type in ['gpt2']) and (args.teacher_type in ['gpt2'])
assert args.teacher_type == args.student_type or (args.student_type=='distilbert' and args.teacher_type=='bert')
assert os.path.isfile(args.student_config)
if args.student_pretrained_weights is not None:
assert os.path.isfile(args.student_pretrained_weights)
if args.freeze_token_type_embds: assert args.student_type in ['roberta']
assert args.alpha_ce >= 0.
assert args.alpha_mlm >= 0.
assert args.alpha_clm >= 0.
assert args.alpha_mse >= 0.
assert args.alpha_cos >= 0.
assert args.alpha_ce + args.alpha_mlm + args.alpha_clm + args.alpha_mse + args.alpha_cos > 0.
def freeze_pos_embeddings(student, args):
if args.student_type == 'roberta':
student.roberta.embeddings.position_embeddings.weight.requires_grad = False
elif args.student_type == 'gpt2':
student.transformer.wpe.weight.requires_grad = False
def freeze_token_type_embeddings(student, args):
if args.student_type == 'roberta':
student.roberta.embeddings.token_type_embeddings.weight.requires_grad = False
def main():
parser = argparse.ArgumentParser(description="Training")
parser.add_argument("--force", action='store_true',
help="Overwrite dump_path if it already exists.")
parser.add_argument("--dump_path", type=str, required=True,
help="The output directory (log, checkpoints, parameters, etc.)")
parser.add_argument("--data_file", type=str, required=True,
help="The binarized file (tokenized + tokens_to_ids) and grouped by sequence.")
parser.add_argument("--student_type", type=str, choices=["distilbert", "roberta", "gpt2"], required=True,
help="The student type (DistilBERT, RoBERTa, GPT2).")
parser.add_argument("--student_config", type=str, required=True,
help="Path to the student configuration.")
parser.add_argument("--student_pretrained_weights", default=None, type=str,
help="Load student initialization checkpoint.")
parser.add_argument("--teacher_type", choices=["bert", "roberta", "gpt2"], required=True,
help="Teacher type (BERT, RoBERTa, GPT2).")
parser.add_argument("--teacher_name", type=str, required=True,
help="The teacher model.")
parser.add_argument("--temperature", default=2., type=float,
help="Temperature for the softmax.")
parser.add_argument("--alpha_ce", default=0.5, type=float,
help="Linear weight for the distillation loss. Must be >=0.")
parser.add_argument("--alpha_mlm", default=0.0, type=float,
help="Linear weight for the MLM loss. Must be >=0. Should be used in conjunction with the `mlm` flag.")
parser.add_argument("--alpha_clm", default=0.5, type=float,
help="Linear weight for the CLM loss. Must be >=0.")
parser.add_argument("--alpha_mse", default=0.0, type=float,
help="Linear weight of the MSE loss. Must be >=0.")
parser.add_argument("--alpha_cos", default=0.0, type=float,
help="Linear weight of the cosine embedding loss. Must be >=0.")
parser.add_argument("--mlm", action="store_true",
help="The LM step: MLM or CLM. If `mlm` is True, MLM is used over CLM.")
parser.add_argument("--mlm_mask_prop", default=0.15, type=float,
help="Proportion of tokens for which we need to make a prediction.")
parser.add_argument("--word_mask", default=0.8, type=float,
...@@ -95,17 +124,20 @@ def main():
help="Proportion of tokens to randomly replace.")
parser.add_argument("--mlm_smoothing", default=0.7, type=float,
help="Smoothing parameter to emphasize more rare tokens (see XLM, similar to word2vec).")
parser.add_argument("--token_counts", type=str,
help="The token counts in the data_file for MLM.")
parser.add_argument("--restrict_ce_to_mask", action='store_true',
help="If true, compute the distillation loss only on the [MLM] prediction distribution.")
parser.add_argument("--freeze_pos_embs", action="store_true",
help="Freeze positional embeddings during distillation. For student_type in ['roberta', 'gpt2'] only.")
parser.add_argument("--freeze_token_type_embds", action="store_true",
help="Freeze token type embeddings during distillation if existent. For student_type in ['roberta'] only.")
parser.add_argument("--n_epoch", type=int, default=3,
help="Number of passes on the whole dataset.")
parser.add_argument("--batch_size", type=int, default=5,
help="Batch size (for each process).")
parser.add_argument("--group_by_size", action='store_false',
help="If true, group sequences that have similar length into the same batch. Default is true.")
...@@ -141,6 +173,7 @@ def main():
parser.add_argument("--checkpoint_interval", type=int, default=4000,
help="Checkpoint interval.")
args = parser.parse_args()
sanity_checks(args)
## ARGS ##
...@@ -164,21 +197,19 @@ def main():
with open(os.path.join(args.dump_path, 'parameters.json'), 'w') as f:
json.dump(vars(args), f, indent=4)
git_log(args.dump_path)
student_config_class, student_model_class, _ = MODEL_CLASSES[args.student_type]
teacher_config_class, teacher_model_class, teacher_tokenizer_class = MODEL_CLASSES[args.teacher_type]
### TOKENIZER ###
tokenizer = teacher_tokenizer_class.from_pretrained(args.teacher_name)
special_tok_ids = {}
for tok_name, tok_symbol in tokenizer.special_tokens_map.items():
idx = tokenizer.all_special_tokens.index(tok_symbol)
special_tok_ids[tok_name] = tokenizer.all_special_ids[idx]
logger.info(f'Special tokens {special_tok_ids}')
args.special_tok_ids = special_tok_ids
args.max_model_input_size = tokenizer.max_model_input_sizes[args.teacher_name]
## DATA LOADER ##
...@@ -187,35 +218,34 @@ def main():
data = pickle.load(fp)
if args.mlm:
logger.info(f'Loading token counts from {args.token_counts} (already pre-computed)')
with open(args.token_counts, 'rb') as fp:
counts = pickle.load(fp)
token_probs = np.maximum(counts, 1) ** -args.mlm_smoothing
for idx in special_tok_ids.values():
token_probs[idx] = 0.  # do not predict special tokens
token_probs = torch.from_numpy(token_probs)
else:
token_probs = None
train_lm_seq_dataset = LmSeqsDataset(params=args, data=data)
logger.info(f'Data loader created.')
## STUDENT ##
logger.info(f'Loading student config from {args.student_config}')
stu_architecture_config = student_config_class.from_pretrained(args.student_config)
stu_architecture_config.output_hidden_states = True
if args.student_pretrained_weights is not None:
logger.info(f'Loading pretrained weights from {args.student_pretrained_weights}')
student = student_model_class.from_pretrained(args.student_pretrained_weights,
config=stu_architecture_config)
else:
student = student_model_class(stu_architecture_config)
if args.n_gpu > 0:
...@@ -224,18 +254,31 @@ def main():
## TEACHER ##
teacher = teacher_model_class.from_pretrained(args.teacher_name, output_hidden_states=True)
if args.n_gpu > 0:
teacher.to(f'cuda:{args.local_rank}')
logger.info(f'Teacher loaded from {args.teacher_name}.')
## FREEZING ##
if args.freeze_pos_embs:
freeze_pos_embeddings(student, args)
if args.freeze_token_type_embds:
freeze_token_type_embeddings(student, args)
## SANITY CHECKS ##
assert student.config.vocab_size == teacher.config.vocab_size
assert student.config.hidden_size == teacher.config.hidden_size
assert student.config.max_position_embeddings == teacher.config.max_position_embeddings
if args.mlm:
assert token_probs.size(0) == stu_architecture_config.vocab_size
## DISTILLER ##
torch.cuda.empty_cache()
distiller = Distiller(params=args,
dataset=train_lm_seq_dataset,
token_probs=token_probs,
student=student,
teacher=teacher)
......
{
"activation": "gelu",
"attention_dropout": 0.1,
"dim": 768,
"dropout": 0.1,
"hidden_dim": 3072,
"initializer_range": 0.02,
"max_position_embeddings": 512,
"n_heads": 12,
"n_layers": 6,
"sinusoidal_pos_embds": true,
"tie_weights_": true,
"vocab_size": 30522
}
{
"initializer_range": 0.02,
"layer_norm_epsilon": 0.00001,
"n_ctx": 1024,
"n_embd": 768,
"n_head": 12,
"n_layer": 6,
"n_positions": 1024,
"vocab_size": 50257
}
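Both student configurations can be sanity-checked against the config classes registered in `train.py`; a small sketch, with assumed file names:

```python
from transformers import DistilBertConfig, GPT2Config

distilbert_cfg = DistilBertConfig.from_json_file('distilbert-base-uncased.json')  # assumed filename
gpt2_cfg = GPT2Config.from_json_file('distilgpt2.json')                           # assumed filename
assert distilbert_cfg.n_layers == 6 and gpt2_cfg.n_layer == 6                     # both students are 6 layers deep
```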
tensorboardX
tensorboard
scikit-learn
seqeval
...@@ -14,7 +14,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Conditional text generation with the auto-regressive models of the library (GPT/GPT-2/CTRL/Transformer-XL/XLNet)
"""
from __future__ import absolute_import, division, print_function, unicode_literals
...@@ -26,12 +26,14 @@ import torch
import torch.nn.functional as F
import numpy as np
from transformers import GPT2Config, OpenAIGPTConfig, XLNetConfig, TransfoXLConfig, XLMConfig, CTRLConfig
from transformers import GPT2LMHeadModel, GPT2Tokenizer
from transformers import OpenAIGPTLMHeadModel, OpenAIGPTTokenizer
from transformers import XLNetLMHeadModel, XLNetTokenizer
from transformers import TransfoXLLMHeadModel, TransfoXLTokenizer
from transformers import CTRLLMHeadModel, CTRLTokenizer
from transformers import XLMWithLMHeadModel, XLMTokenizer
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
...@@ -41,13 +43,15 @@ logger = logging.getLogger(__name__)
MAX_LENGTH = int(10000)  # Hardcoded max length to avoid infinite loop
ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in (GPT2Config, OpenAIGPTConfig, XLNetConfig, TransfoXLConfig, XLMConfig, CTRLConfig)), ())
MODEL_CLASSES = {
'gpt2': (GPT2LMHeadModel, GPT2Tokenizer),
'ctrl': (CTRLLMHeadModel, CTRLTokenizer),
'openai-gpt': (OpenAIGPTLMHeadModel, OpenAIGPTTokenizer),
'xlnet': (XLNetLMHeadModel, XLNetTokenizer),
'transfo-xl': (TransfoXLLMHeadModel, TransfoXLTokenizer),
'xlm': (XLMWithLMHeadModel, XLMTokenizer),
}
# Padding text to help Transformer-XL and XLNet with short prompts as proposed by Aman Rusia
...@@ -103,7 +107,8 @@ def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')
return logits
def sample_sequence(model, length, context, num_samples=1, temperature=1, top_k=0, top_p=0.0, repetition_penalty=1.0,
is_xlnet=False, is_xlm_mlm=False, xlm_mask_token=None, xlm_lang=None, device='cpu'):
context = torch.tensor(context, dtype=torch.long, device=device)
context = context.unsqueeze(0).repeat(num_samples, 1)
generated = context
...@@ -121,10 +126,27 @@ def sample_sequence(model, length, context, num_samples=1, temperature=1, top_k=
target_mapping[0, 0, -1] = 1.0  # predict last token
inputs = {'input_ids': input_ids, 'perm_mask': perm_mask, 'target_mapping': target_mapping}
if is_xlm_mlm and xlm_mask_token:
# XLM MLM models are direct models (predict same token, not next token)
# => need one additional dummy token in the input (will be masked and guessed)
input_ids = torch.cat((generated, torch.full((1, 1), xlm_mask_token, dtype=torch.long, device=device)), dim=1)
inputs = {'input_ids': input_ids}
if xlm_lang is not None:
inputs["langs"] = torch.tensor([xlm_lang] * inputs["input_ids"].shape[1], device=device).view(1, -1)
outputs = model(**inputs)  # Note: we could also use 'past' with GPT-2/Transfo-XL/XLNet/CTRL (cached hidden-states)
next_token_logits = outputs[0][0, -1, :] / (temperature if temperature > 0 else 1.)
# repetition penalty from CTRL (https://arxiv.org/abs/1909.05858)
for _ in set(generated.view(-1).tolist()):
next_token_logits[_] /= repetition_penalty
filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
if temperature == 0:  # greedy sampling
next_token = torch.argmax(filtered_logits).unsqueeze(0)
else:
next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
generated = torch.cat((generated, next_token.unsqueeze(0)), dim=1)
return generated
...@@ -137,14 +159,20 @@ def main():
help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS))
parser.add_argument("--prompt", type=str, default="")
parser.add_argument("--padding_text", type=str, default="")
parser.add_argument("--xlm_lang", type=str, default="", help="Optional language when used with the XLM model.")
parser.add_argument("--length", type=int, default=20)
parser.add_argument("--temperature", type=float, default=1.0,
help="temperature of 0 implies greedy sampling")
parser.add_argument("--repetition_penalty", type=float, default=1.0,
help="primarily useful for CTRL model; in that case, use 1.2")
parser.add_argument("--top_k", type=int, default=0)
parser.add_argument("--top_p", type=float, default=0.9)
parser.add_argument("--no_cuda", action='store_true',
help="Avoid using CUDA when available")
parser.add_argument('--seed', type=int, default=42,
help="random seed for initialization")
parser.add_argument('--stop_token', type=str, default=None,
help="Token at which text generation is stopped")
args = parser.parse_args()
args.device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
...@@ -166,8 +194,31 @@ def main():
elif args.length < 0:
args.length = MAX_LENGTH  # avoid infinite loop
logger.info(args)
if args.model_type in ["ctrl"]:
if args.temperature > 0.7:
logger.info('CTRL typically works better with lower temperatures (and lower top_k).')
while True:
xlm_lang = None
# XLM language usage is detailed in issue #1414
if args.model_type in ["xlm"] and hasattr(tokenizer, 'lang2id') and hasattr(model.config, 'use_lang_emb') \
and model.config.use_lang_emb:
if args.xlm_lang:
language = args.xlm_lang
else:
language = None
while language not in tokenizer.lang2id.keys():
language = input("Using XLM. Select language in " + str(list(tokenizer.lang2id.keys())) + " >>> ")
xlm_lang = tokenizer.lang2id[language]
# XLM masked-language modeling (MLM) models need a masked token (see details in sample_sequence)
is_xlm_mlm = args.model_type in ["xlm"] and 'mlm' in args.model_name_or_path
if is_xlm_mlm:
xlm_mask_token = tokenizer.mask_token_id
else:
xlm_mask_token = None
raw_text = args.prompt if args.prompt else input("Model prompt >>> ")
if args.model_type in ["transfo-xl", "xlnet"]:
# Models with memory like to have a long prompt for short inputs.
...@@ -180,11 +231,18 @@ def main():
temperature=args.temperature,
top_k=args.top_k,
top_p=args.top_p,
repetition_penalty=args.repetition_penalty,
is_xlnet=bool(args.model_type == "xlnet"),
is_xlm_mlm=is_xlm_mlm,
xlm_mask_token=xlm_mask_token,
xlm_lang=xlm_lang,
device=args.device,
)
out = out[0, len(context_tokens):].tolist()
text = tokenizer.decode(out, clean_up_tokenization_spaces=True, skip_special_tokens=True)
text = text[: text.find(args.stop_token) if args.stop_token else None]
print(text)
if args.prompt:
break
......
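One subtlety in the `stop_token` slice above, shown with a toy string: `str.find` returns -1 when the token is absent, so a stop token that never occurs silently drops the final character.

```python
text = "Hello world!"
stop_token = "<eos>"  # hypothetical stop token that does not occur in the text
print(text[: text.find(stop_token) if stop_token else None])  # 'Hello world' -- trailing '!' lost
```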
...@@ -28,7 +28,12 @@ import torch
from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler,
TensorDataset)
from torch.utils.data.distributed import DistributedSampler
try:
from torch.utils.tensorboard import SummaryWriter
except ImportError:
from tensorboardX import SummaryWriter
from tqdm import tqdm, trange
from transformers import (WEIGHTS_NAME, BertConfig,
...@@ -53,7 +58,8 @@ from transformers import glue_convert_examples_to_features as convert_examples_t
logger = logging.getLogger(__name__)
ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in (BertConfig, XLNetConfig, XLMConfig,
RobertaConfig, DistilBertConfig)), ())
MODEL_CLASSES = {
'bert': (BertConfig, BertForSequenceClassification, BertTokenizer),
...@@ -154,7 +160,7 @@ def train(args, train_dataset, model, tokenizer):
torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
tr_loss += loss.item()
if (step + 1) % args.gradient_accumulation_steps == 0 and not args.tpu:
optimizer.step()
scheduler.step()  # Update learning rate schedule
model.zero_grad()
...@@ -180,6 +186,11 @@ def train(args, train_dataset, model, tokenizer):
torch.save(args, os.path.join(output_dir, 'training_args.bin'))
logger.info("Saving model checkpoint to %s", output_dir)
if args.tpu:
args.xla_model.optimizer_step(optimizer, barrier=True)
model.zero_grad()
global_step += 1
if args.max_steps > 0 and global_step > args.max_steps:
epoch_iterator.close()
break
...@@ -248,7 +259,7 @@ def evaluate(args, model, tokenizer, prefix=""):
result = compute_metrics(eval_task, preds, out_label_ids)
results.update(result)
output_eval_file = os.path.join(eval_output_dir, prefix, "eval_results.txt")
with open(output_eval_file, "w") as writer:
logger.info("***** Eval results {} *****".format(prefix))
for key in sorted(result.keys()):
...@@ -270,7 +281,7 @@ def load_and_cache_examples(args, task, tokenizer, evaluate=False):
list(filter(None, args.model_name_or_path.split('/'))).pop(),
str(args.max_seq_length),
str(task)))
if os.path.exists(cached_features_file) and not args.overwrite_cache:
logger.info("Loading features from cached file %s", cached_features_file)
features = torch.load(cached_features_file)
else:
...@@ -379,6 +390,15 @@ def main():
parser.add_argument('--seed', type=int, default=42,
help="random seed for initialization")
parser.add_argument('--tpu', action='store_true',
help="Whether to run on the TPU defined in the environment variables")
parser.add_argument('--tpu_ip_address', type=str, default='',
help="TPU IP address if none are set in the environment variables")
parser.add_argument('--tpu_name', type=str, default='',
help="TPU name if none are set in the environment variables")
parser.add_argument('--xrt_tpu_config', type=str, default='',
help="XRT TPU config if none are set in the environment variables")
parser.add_argument('--fp16', action='store_true',
help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit")
parser.add_argument('--fp16_opt_level', type=str, default='O1',
...@@ -412,6 +432,23 @@ def main():
args.n_gpu = 1
args.device = device
if args.tpu:
if args.tpu_ip_address:
os.environ["TPU_IP_ADDRESS"] = args.tpu_ip_address
if args.tpu_name:
os.environ["TPU_NAME"] = args.tpu_name
if args.xrt_tpu_config:
os.environ["XRT_TPU_CONFIG"] = args.xrt_tpu_config
assert "TPU_IP_ADDRESS" in os.environ
assert "TPU_NAME" in os.environ
assert "XRT_TPU_CONFIG" in os.environ
import torch_xla
import torch_xla.core.xla_model as xm
args.device = xm.xla_device()
args.xla_model = xm
# Setup logging
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
datefmt = '%m/%d/%Y %H:%M:%S',
...@@ -457,7 +494,7 @@ def main():
# Saving best-practices: if you use defaults names for the model, you can reload it using from_pretrained()
if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0) and not args.tpu:
# Create output directory if needed
if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]:
os.makedirs(args.output_dir)
...@@ -489,9 +526,11 @@ def main():
logger.info("Evaluate the following checkpoints: %s", checkpoints)
for checkpoint in checkpoints:
global_step = checkpoint.split('-')[-1] if len(checkpoints) > 1 else ""
prefix = checkpoint.split('/')[-1] if checkpoint.find('checkpoint') != -1 else ""
model = model_class.from_pretrained(checkpoint)
model.to(args.device)
result = evaluate(args, model, tokenizer, prefix=prefix)
result = dict((k + '_{}'.format(global_step), v) for k, v in result.items())
results.update(result)
......
...@@ -27,12 +27,19 @@ import logging
import os
import pickle
import random
import re
import shutil
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset, SequentialSampler, RandomSampler
from torch.utils.data.distributed import DistributedSampler
try:
from torch.utils.tensorboard import SummaryWriter
except ImportError:
from tensorboardX import SummaryWriter
from tqdm import tqdm, trange
from transformers import (WEIGHTS_NAME, AdamW, WarmupLinearSchedule,
...@@ -59,7 +66,7 @@ class TextDataset(Dataset):
def __init__(self, tokenizer, file_path='train', block_size=512):
assert os.path.isfile(file_path)
directory, filename = os.path.split(file_path)
cached_features_file = os.path.join(directory, 'cached_lm_' + str(block_size) + '_' + filename)
if os.path.exists(cached_features_file):
logger.info("Loading features from cached file %s", cached_features_file)
...@@ -75,7 +82,7 @@ class TextDataset(Dataset):
tokenized_text = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text))
for i in range(0, len(tokenized_text)-block_size+1, block_size):  # Truncate in blocks of block_size
self.examples.append(tokenizer.build_inputs_with_special_tokens(tokenized_text[i:i+block_size]))
# Note that we are losing the last truncated example here for the sake of simplicity (no padding)
# If your dataset is small, first you should look for a bigger one :-) and second you
# can change this behavior by adding (model specific) padding.
...@@ -104,11 +111,43 @@ def set_seed(args):
torch.cuda.manual_seed_all(args.seed)
def _rotate_checkpoints(args, checkpoint_prefix, use_mtime=False):
if not args.save_total_limit:
return
if args.save_total_limit <= 0:
return
# Check if we should delete older checkpoint(s)
glob_checkpoints = glob.glob(os.path.join(args.output_dir, '{}-*'.format(checkpoint_prefix)))
if len(glob_checkpoints) <= args.save_total_limit:
return
ordering_and_checkpoint_path = []
for path in glob_checkpoints:
if use_mtime:
ordering_and_checkpoint_path.append((os.path.getmtime(path), path))
else:
regex_match = re.match('.*{}-([0-9]+)'.format(checkpoint_prefix), path)
if regex_match and regex_match.groups():
ordering_and_checkpoint_path.append((int(regex_match.groups()[0]), path))
checkpoints_sorted = sorted(ordering_and_checkpoint_path)
checkpoints_sorted = [checkpoint[1] for checkpoint in checkpoints_sorted]
number_of_checkpoints_to_delete = max(0, len(checkpoints_sorted) - args.save_total_limit)
checkpoints_to_be_deleted = checkpoints_sorted[:number_of_checkpoints_to_delete]
for checkpoint in checkpoints_to_be_deleted:
logger.info("Deleting older checkpoint [{}] due to args.save_total_limit".format(checkpoint))
shutil.rmtree(checkpoint)
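A quick, hypothetical demonstration of the rotation logic above; the temporary directory and step numbers are invented, and the sketch assumes `_rotate_checkpoints` (with the script's `glob`/`re`/`shutil` imports) is in scope. With `save_total_limit=2`, only the two newest step-numbered checkpoints survive:

```python
import argparse
import os
import tempfile

args = argparse.Namespace(save_total_limit=2, output_dir=tempfile.mkdtemp())
for step in (50, 100, 150, 200):
    os.makedirs(os.path.join(args.output_dir, 'checkpoint-{}'.format(step)))

_rotate_checkpoints(args, 'checkpoint')
print(sorted(os.listdir(args.output_dir)))  # ['checkpoint-150', 'checkpoint-200']
```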
def mask_tokens(inputs, tokenizer, args):
""" Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original. """
labels = inputs.clone()
# We sample a few tokens in each sequence for masked-LM training (with probability args.mlm_probability, which defaults to 0.15 in Bert/RoBERTa)
probability_matrix = torch.full(labels.shape, args.mlm_probability)
special_tokens_mask = [tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()]
probability_matrix.masked_fill_(torch.tensor(special_tokens_mask, dtype=torch.bool), value=0.0)
masked_indices = torch.bernoulli(probability_matrix).bool()
labels[~masked_indices] = -1  # We only compute loss on masked tokens
# 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
...@@ -222,8 +261,9 @@ def train(args, train_dataset, model, tokenizer):
logging_loss = tr_loss
if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
checkpoint_prefix = 'checkpoint'
# Save model checkpoint
output_dir = os.path.join(args.output_dir, '{}-{}'.format(checkpoint_prefix, global_step))
if not os.path.exists(output_dir):
os.makedirs(output_dir)
model_to_save = model.module if hasattr(model, 'module') else model  # Take care of distributed/parallel training
...@@ -231,6 +271,8 @@ def train(args, train_dataset, model, tokenizer):
torch.save(args, os.path.join(output_dir, 'training_args.bin'))
logger.info("Saving model checkpoint to %s", output_dir)
_rotate_checkpoints(args, checkpoint_prefix)
if args.max_steps > 0 and global_step > args.max_steps:
epoch_iterator.close()
break
...@@ -282,7 +324,7 @@ def evaluate(args, model, tokenizer, prefix=""):
"perplexity": perplexity
}
output_eval_file = os.path.join(eval_output_dir, prefix, "eval_results.txt")
with open(output_eval_file, "w") as writer:
logger.info("***** Eval results {} *****".format(prefix))
for key in sorted(result.keys()):
...@@ -359,6 +401,8 @@ def main():
help="Log every X updates steps.")
parser.add_argument('--save_steps', type=int, default=50,
help="Save checkpoint every X updates steps.")
parser.add_argument('--save_total_limit', type=int, default=None,
help='Limit the total number of checkpoints; deletes the older checkpoints in output_dir. Does not delete by default.')
parser.add_argument("--eval_all_checkpoints", action='store_true',
help="Evaluate all checkpoints starting with the same prefix as model_name_or_path and ending with step number")
parser.add_argument("--no_cuda", action='store_true', parser.add_argument("--no_cuda", action='store_true',
...@@ -484,9 +528,11 @@ def main(): ...@@ -484,9 +528,11 @@ def main():
logger.info("Evaluate the following checkpoints: %s", checkpoints) logger.info("Evaluate the following checkpoints: %s", checkpoints)
for checkpoint in checkpoints: for checkpoint in checkpoints:
global_step = checkpoint.split('-')[-1] if len(checkpoints) > 1 else "" global_step = checkpoint.split('-')[-1] if len(checkpoints) > 1 else ""
prefix = checkpoint.split('/')[-1] if checkpoint.find('checkpoint') != -1 else ""
model = model_class.from_pretrained(checkpoint) model = model_class.from_pretrained(checkpoint)
model.to(args.device) model.to(args.device)
result = evaluate(args, model, tokenizer, prefix=global_step) result = evaluate(args, model, tokenizer, prefix=prefix)
result = dict((k + '_{}'.format(global_step), v) for k, v in result.items()) result = dict((k + '_{}'.format(global_step), v) for k, v in result.items())
results.update(result) results.update(result)
......